barrier refactored

This commit is contained in:
yedf2 2021-08-11 16:41:57 +08:00
parent a7c72162da
commit fffeb902a6
8 changed files with 54 additions and 38 deletions

View File

@ -11,21 +11,22 @@ import (
// BusiFunc type for busi func
type BusiFunc func(db *sql.Tx) (interface{}, error)
// TransInfo every branch info
type TransInfo struct {
// BranchBarrier every branch info
type BranchBarrier struct {
TransType string
Gid string
BranchID string
BranchType string
BarrierID int
}
func (t *TransInfo) String() string {
return fmt.Sprintf("transInfo: %s %s %s %s", t.TransType, t.Gid, t.BranchID, t.BranchType)
func (bb *BranchBarrier) String() string {
return fmt.Sprintf("transInfo: %s %s %s %s", bb.TransType, bb.Gid, bb.BranchID, bb.BranchType)
}
// TransInfoFromQuery construct transaction info from request
func TransInfoFromQuery(qs url.Values) (*TransInfo, error) {
ti := &TransInfo{
// BarrierFromQuery construct transaction info from request
func BarrierFromQuery(qs url.Values) (*BranchBarrier, error) {
ti := &BranchBarrier{
TransType: qs.Get("trans_type"),
Gid: qs.Get("gid"),
BranchID: qs.Get("branch_id"),
@ -37,14 +38,14 @@ func TransInfoFromQuery(qs url.Values) (*TransInfo, error) {
return ti, nil
}
func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, reason string) (int64, error) {
func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, barrierID string, reason string) (int64, error) {
if branchType == "" {
return 0, nil
}
return StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason)
return StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)", transType, gid, branchID, branchType, barrierID, reason)
}
// ThroughBarrierCall 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
// Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
// db: 本地数据库
// transInfo: 事务信息
// bisiCall: 业务函数,仅在必要时被调用
@ -53,7 +54,9 @@ func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, br
// 如果发生重复调用则busiCall不会被重复调用直接对保存在数据库中上一次的结果进行unmarshal通常是一个map[string]interface{}直接作为http的resp
// 如果发生悬挂则busiCall不会被调用直接返回错误 {"dtm_result": "FAILURE"}
// 如果发生空补偿则busiCall不会被调用直接返回 {"dtm_result": "SUCCESS"}
func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (res interface{}, rerr error) {
func (bb *BranchBarrier) Call(db *sql.DB, busiCall BusiFunc) (res interface{}, rerr error) {
bb.BarrierID = bb.BarrierID + 1
bid := fmt.Sprintf("%02d", bb.BarrierID)
tx, rerr := db.BeginTx(context.Background(), &sql.TxOptions{})
if rerr != nil {
return
@ -69,21 +72,21 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
tx.Commit()
}
}()
ti := transInfo
ti := bb
originType := map[string]string{
"cancel": "try",
"compensate": "action",
}[ti.BranchType]
originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, ti.BranchType)
currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType)
originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, bid, ti.BranchType)
currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, bid, ti.BranchType)
Logf("originAffected: %d currentAffected: %d", originAffected, currentAffected)
if (ti.BranchType == "cancel" || ti.BranchType == "compensate") && originAffected > 0 { // 这个是空补偿,返回成功
res = ResultSuccess
return
} else if currentAffected == 0 { // 插入不成功
var result sql.NullString
err := StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?",
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType).Scan(&result)
err := StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and barrier_id=? and reason=?",
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, bid, ti.BranchType).Scan(&result)
if err == sql.ErrNoRows { // 这个是悬挂操作返回失败AP收到这个返回会尽快回滚
res = ResultFailure
return

View File

@ -7,11 +7,12 @@ create table if not exists dtm_barrier.barrier(
gid varchar(128) default'',
branch_id varchar(128) default '',
branch_type varchar(45) default '',
barrier_id varchar(45) default '',
reason varchar(45) default '' comment 'the branch type who insert this record',
result varchar(2047) default null comment 'the business result of this branch',
create_time datetime DEFAULT now(),
update_time datetime DEFAULT now(),
key(create_time),
key(update_time),
UNIQUE key(gid, branch_id, branch_type)
UNIQUE key(gid, branch_id, branch_type, barrier_id)
);

View File

@ -14,7 +14,7 @@ func TestQuery(t *testing.T) {
assert.Error(t, err)
_, err = TccFromQuery(qs)
assert.Error(t, err)
_, err = TransInfoFromQuery(qs)
_, err = BarrierFromQuery(qs)
assert.Error(t, err)
}

View File

@ -22,7 +22,7 @@ func TestTypes(t *testing.T) {
})
assert.Error(t, err)
assert.Error(t, err)
_, err = TransInfoFromQuery(url.Values{})
_, err = BarrierFromQuery(url.Values{})
assert.Error(t, err)
}

View File

@ -52,8 +52,8 @@ func reqFrom(c *gin.Context) *TransReq {
return v.(*TransReq)
}
func infoFromContext(c *gin.Context) *dtmcli.TransInfo {
info := dtmcli.TransInfo{
func infoFromContext(c *gin.Context) *dtmcli.BranchBarrier {
info := dtmcli.BranchBarrier{
TransType: c.Query("trans_type"),
Gid: c.Query("gid"),
BranchID: c.Query("branch_id"),
@ -73,8 +73,8 @@ func sdbGet() *sql.DB {
}
// MustGetTrans construct transaction info from request
func MustGetTrans(c *gin.Context) *dtmcli.TransInfo {
ti, err := dtmcli.TransInfoFromQuery(c.Request.URL.Query())
func MustGetTrans(c *gin.Context) *dtmcli.BranchBarrier {
ti, err := dtmcli.BarrierFromQuery(c.Request.URL.Query())
e2p(err)
return ti
}

View File

@ -42,13 +42,15 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" {
return req.TransInResult, nil
}
return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
barrier := MustGetTrans(c)
return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 1, req.Amount)
})
}
func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
barrier := MustGetTrans(c)
return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount)
})
}
@ -58,13 +60,15 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" {
return req.TransInResult, nil
}
return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
barrier := MustGetTrans(c)
return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 2, -req.Amount)
})
}
func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
barrier := MustGetTrans(c)
return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount)
})
}

View File

@ -65,19 +65,22 @@ func tccBarrierTransInTry(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" {
return req.TransInResult, nil
}
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transInUID, req.Amount)
})
}
func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustBalance(sdb, transInUID, reqFrom(c).Amount)
})
}
func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transInUID, -reqFrom(c).Amount)
})
}
@ -87,20 +90,23 @@ func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" {
return req.TransInResult, nil
}
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transOutUID, -req.Amount)
})
}
func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustBalance(sdb, transOutUID, -reqFrom(c).Amount)
})
}
// TccBarrierTransOutCancel will be use in test
func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transOutUID, reqFrom(c).Amount)
})
}

View File

@ -20,7 +20,7 @@ var app *gin.Engine
// BarrierModel barrier model for gorm
type BarrierModel struct {
common.ModelBase
dtmcli.TransInfo
dtmcli.BranchBarrier
}
// TableName gorm table name
@ -111,14 +111,14 @@ func transQuery(t *testing.T, gid string) {
func TestSqlDB(t *testing.T) {
asserts := assert.New(t)
db := common.DbGet(config.DB)
transInfo := &dtmcli.TransInfo{
barrier := &dtmcli.BranchBarrier{
TransType: "saga",
Gid: "gid2",
BranchID: "branch_id2",
BranchType: "action",
}
db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')")
_, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
_, err := barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) {
dtmcli.Logf("rollback gid2")
return nil, fmt.Errorf("gid2 error")
})
@ -128,14 +128,16 @@ func TestSqlDB(t *testing.T) {
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(0))
gid2Res := dtmcli.M{"result": "first"}
_, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
barrier.BarrierID = 0
_, err = barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) {
dtmcli.Logf("submit gid2")
return gid2Res, nil
})
asserts.Nil(err)
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
barrier.BarrierID = 0
newResult, err := barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) {
dtmcli.Logf("submit gid2")
return dtmcli.MS{"result": "ignored"}, nil
})