barrier refactored

This commit is contained in:
yedf2 2021-08-11 16:41:57 +08:00
parent a7c72162da
commit fffeb902a6
8 changed files with 54 additions and 38 deletions

View File

@ -11,21 +11,22 @@ import (
// BusiFunc type for busi func // BusiFunc type for busi func
type BusiFunc func(db *sql.Tx) (interface{}, error) type BusiFunc func(db *sql.Tx) (interface{}, error)
// TransInfo every branch info // BranchBarrier every branch info
type TransInfo struct { type BranchBarrier struct {
TransType string TransType string
Gid string Gid string
BranchID string BranchID string
BranchType string BranchType string
BarrierID int
} }
func (t *TransInfo) String() string { func (bb *BranchBarrier) String() string {
return fmt.Sprintf("transInfo: %s %s %s %s", t.TransType, t.Gid, t.BranchID, t.BranchType) return fmt.Sprintf("transInfo: %s %s %s %s", bb.TransType, bb.Gid, bb.BranchID, bb.BranchType)
} }
// TransInfoFromQuery construct transaction info from request // BarrierFromQuery construct transaction info from request
func TransInfoFromQuery(qs url.Values) (*TransInfo, error) { func BarrierFromQuery(qs url.Values) (*BranchBarrier, error) {
ti := &TransInfo{ ti := &BranchBarrier{
TransType: qs.Get("trans_type"), TransType: qs.Get("trans_type"),
Gid: qs.Get("gid"), Gid: qs.Get("gid"),
BranchID: qs.Get("branch_id"), BranchID: qs.Get("branch_id"),
@ -37,14 +38,14 @@ func TransInfoFromQuery(qs url.Values) (*TransInfo, error) {
return ti, nil return ti, nil
} }
func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, reason string) (int64, error) { func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, barrierID string, reason string) (int64, error) {
if branchType == "" { if branchType == "" {
return 0, nil return 0, nil
} }
return StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason) return StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)", transType, gid, branchID, branchType, barrierID, reason)
} }
// ThroughBarrierCall 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465 // Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
// db: 本地数据库 // db: 本地数据库
// transInfo: 事务信息 // transInfo: 事务信息
// bisiCall: 业务函数,仅在必要时被调用 // bisiCall: 业务函数,仅在必要时被调用
@ -53,7 +54,9 @@ func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, br
// 如果发生重复调用则busiCall不会被重复调用直接对保存在数据库中上一次的结果进行unmarshal通常是一个map[string]interface{}直接作为http的resp // 如果发生重复调用则busiCall不会被重复调用直接对保存在数据库中上一次的结果进行unmarshal通常是一个map[string]interface{}直接作为http的resp
// 如果发生悬挂则busiCall不会被调用直接返回错误 {"dtm_result": "FAILURE"} // 如果发生悬挂则busiCall不会被调用直接返回错误 {"dtm_result": "FAILURE"}
// 如果发生空补偿则busiCall不会被调用直接返回 {"dtm_result": "SUCCESS"} // 如果发生空补偿则busiCall不会被调用直接返回 {"dtm_result": "SUCCESS"}
func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (res interface{}, rerr error) { func (bb *BranchBarrier) Call(db *sql.DB, busiCall BusiFunc) (res interface{}, rerr error) {
bb.BarrierID = bb.BarrierID + 1
bid := fmt.Sprintf("%02d", bb.BarrierID)
tx, rerr := db.BeginTx(context.Background(), &sql.TxOptions{}) tx, rerr := db.BeginTx(context.Background(), &sql.TxOptions{})
if rerr != nil { if rerr != nil {
return return
@ -69,21 +72,21 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
tx.Commit() tx.Commit()
} }
}() }()
ti := transInfo ti := bb
originType := map[string]string{ originType := map[string]string{
"cancel": "try", "cancel": "try",
"compensate": "action", "compensate": "action",
}[ti.BranchType] }[ti.BranchType]
originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, ti.BranchType) originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, bid, ti.BranchType)
currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType) currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, bid, ti.BranchType)
Logf("originAffected: %d currentAffected: %d", originAffected, currentAffected) Logf("originAffected: %d currentAffected: %d", originAffected, currentAffected)
if (ti.BranchType == "cancel" || ti.BranchType == "compensate") && originAffected > 0 { // 这个是空补偿,返回成功 if (ti.BranchType == "cancel" || ti.BranchType == "compensate") && originAffected > 0 { // 这个是空补偿,返回成功
res = ResultSuccess res = ResultSuccess
return return
} else if currentAffected == 0 { // 插入不成功 } else if currentAffected == 0 { // 插入不成功
var result sql.NullString var result sql.NullString
err := StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?", err := StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and barrier_id=? and reason=?",
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType).Scan(&result) ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, bid, ti.BranchType).Scan(&result)
if err == sql.ErrNoRows { // 这个是悬挂操作返回失败AP收到这个返回会尽快回滚 if err == sql.ErrNoRows { // 这个是悬挂操作返回失败AP收到这个返回会尽快回滚
res = ResultFailure res = ResultFailure
return return

View File

@ -7,11 +7,12 @@ create table if not exists dtm_barrier.barrier(
gid varchar(128) default'', gid varchar(128) default'',
branch_id varchar(128) default '', branch_id varchar(128) default '',
branch_type varchar(45) default '', branch_type varchar(45) default '',
barrier_id varchar(45) default '',
reason varchar(45) default '' comment 'the branch type who insert this record', reason varchar(45) default '' comment 'the branch type who insert this record',
result varchar(2047) default null comment 'the business result of this branch', result varchar(2047) default null comment 'the business result of this branch',
create_time datetime DEFAULT now(), create_time datetime DEFAULT now(),
update_time datetime DEFAULT now(), update_time datetime DEFAULT now(),
key(create_time), key(create_time),
key(update_time), key(update_time),
UNIQUE key(gid, branch_id, branch_type) UNIQUE key(gid, branch_id, branch_type, barrier_id)
); );

View File

@ -14,7 +14,7 @@ func TestQuery(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
_, err = TccFromQuery(qs) _, err = TccFromQuery(qs)
assert.Error(t, err) assert.Error(t, err)
_, err = TransInfoFromQuery(qs) _, err = BarrierFromQuery(qs)
assert.Error(t, err) assert.Error(t, err)
} }

View File

@ -22,7 +22,7 @@ func TestTypes(t *testing.T) {
}) })
assert.Error(t, err) assert.Error(t, err)
assert.Error(t, err) assert.Error(t, err)
_, err = TransInfoFromQuery(url.Values{}) _, err = BarrierFromQuery(url.Values{})
assert.Error(t, err) assert.Error(t, err)
} }

View File

@ -52,8 +52,8 @@ func reqFrom(c *gin.Context) *TransReq {
return v.(*TransReq) return v.(*TransReq)
} }
func infoFromContext(c *gin.Context) *dtmcli.TransInfo { func infoFromContext(c *gin.Context) *dtmcli.BranchBarrier {
info := dtmcli.TransInfo{ info := dtmcli.BranchBarrier{
TransType: c.Query("trans_type"), TransType: c.Query("trans_type"),
Gid: c.Query("gid"), Gid: c.Query("gid"),
BranchID: c.Query("branch_id"), BranchID: c.Query("branch_id"),
@ -73,8 +73,8 @@ func sdbGet() *sql.DB {
} }
// MustGetTrans construct transaction info from request // MustGetTrans construct transaction info from request
func MustGetTrans(c *gin.Context) *dtmcli.TransInfo { func MustGetTrans(c *gin.Context) *dtmcli.BranchBarrier {
ti, err := dtmcli.TransInfoFromQuery(c.Request.URL.Query()) ti, err := dtmcli.BarrierFromQuery(c.Request.URL.Query())
e2p(err) e2p(err)
return ti return ti
} }

View File

@ -42,13 +42,15 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" { if req.TransInResult != "" {
return req.TransInResult, nil return req.TransInResult, nil
} }
return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { barrier := MustGetTrans(c)
return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 1, req.Amount) return sagaBarrierAdjustBalance(sdb, 1, req.Amount)
}) })
} }
func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) { func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { barrier := MustGetTrans(c)
return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount) return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount)
}) })
} }
@ -58,13 +60,15 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" { if req.TransInResult != "" {
return req.TransInResult, nil return req.TransInResult, nil
} }
return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { barrier := MustGetTrans(c)
return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 2, -req.Amount) return sagaBarrierAdjustBalance(sdb, 2, -req.Amount)
}) })
} }
func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) { func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { barrier := MustGetTrans(c)
return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount) return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount)
}) })
} }

View File

@ -65,19 +65,22 @@ func tccBarrierTransInTry(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" { if req.TransInResult != "" {
return req.TransInResult, nil return req.TransInResult, nil
} }
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transInUID, req.Amount) return adjustTrading(sdb, transInUID, req.Amount)
}) })
} }
func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) { func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustBalance(sdb, transInUID, reqFrom(c).Amount) return adjustBalance(sdb, transInUID, reqFrom(c).Amount)
}) })
} }
func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) { func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transInUID, -reqFrom(c).Amount) return adjustTrading(sdb, transInUID, -reqFrom(c).Amount)
}) })
} }
@ -87,20 +90,23 @@ func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" { if req.TransInResult != "" {
return req.TransInResult, nil return req.TransInResult, nil
} }
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transOutUID, -req.Amount) return adjustTrading(sdb, transOutUID, -req.Amount)
}) })
} }
func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) { func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustBalance(sdb, transOutUID, -reqFrom(c).Amount) return adjustBalance(sdb, transOutUID, -reqFrom(c).Amount)
}) })
} }
// TccBarrierTransOutCancel will be use in test // TccBarrierTransOutCancel will be use in test
func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { barrier := MustGetTrans(c)
return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transOutUID, reqFrom(c).Amount) return adjustTrading(sdb, transOutUID, reqFrom(c).Amount)
}) })
} }

View File

@ -20,7 +20,7 @@ var app *gin.Engine
// BarrierModel barrier model for gorm // BarrierModel barrier model for gorm
type BarrierModel struct { type BarrierModel struct {
common.ModelBase common.ModelBase
dtmcli.TransInfo dtmcli.BranchBarrier
} }
// TableName gorm table name // TableName gorm table name
@ -111,14 +111,14 @@ func transQuery(t *testing.T, gid string) {
func TestSqlDB(t *testing.T) { func TestSqlDB(t *testing.T) {
asserts := assert.New(t) asserts := assert.New(t)
db := common.DbGet(config.DB) db := common.DbGet(config.DB)
transInfo := &dtmcli.TransInfo{ barrier := &dtmcli.BranchBarrier{
TransType: "saga", TransType: "saga",
Gid: "gid2", Gid: "gid2",
BranchID: "branch_id2", BranchID: "branch_id2",
BranchType: "action", BranchType: "action",
} }
db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')") db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')")
_, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) { _, err := barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) {
dtmcli.Logf("rollback gid2") dtmcli.Logf("rollback gid2")
return nil, fmt.Errorf("gid2 error") return nil, fmt.Errorf("gid2 error")
}) })
@ -128,14 +128,16 @@ func TestSqlDB(t *testing.T) {
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(0)) asserts.Equal(dbr.RowsAffected, int64(0))
gid2Res := dtmcli.M{"result": "first"} gid2Res := dtmcli.M{"result": "first"}
_, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) { barrier.BarrierID = 0
_, err = barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) {
dtmcli.Logf("submit gid2") dtmcli.Logf("submit gid2")
return gid2Res, nil return gid2Res, nil
}) })
asserts.Nil(err) asserts.Nil(err)
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1)) asserts.Equal(dbr.RowsAffected, int64(1))
newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) { barrier.BarrierID = 0
newResult, err := barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) {
dtmcli.Logf("submit gid2") dtmcli.Logf("submit gid2")
return dtmcli.MS{"result": "ignored"}, nil return dtmcli.MS{"result": "ignored"}, nil
}) })