From fffeb902a65cb74ae2d0691ae3020315d92db252 Mon Sep 17 00:00:00 2001 From: yedf2 <120050102@qq.com> Date: Wed, 11 Aug 2021 16:41:57 +0800 Subject: [PATCH] barrier refactored --- dtmcli/barrier.go | 35 +++++++++++++++++++---------------- dtmcli/barrier.mysql.sql | 3 ++- dtmcli/trans_test.go | 2 +- dtmcli/types_test.go | 2 +- examples/base_types.go | 8 ++++---- examples/saga_barrier.go | 12 ++++++++---- examples/tcc_barrier.go | 18 ++++++++++++------ test/dtmsvr_test.go | 12 +++++++----- 8 files changed, 54 insertions(+), 38 deletions(-) diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index 9872078..223dc51 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -11,21 +11,22 @@ import ( // BusiFunc type for busi func type BusiFunc func(db *sql.Tx) (interface{}, error) -// TransInfo every branch info -type TransInfo struct { +// BranchBarrier every branch info +type BranchBarrier struct { TransType string Gid string BranchID string BranchType string + BarrierID int } -func (t *TransInfo) String() string { - return fmt.Sprintf("transInfo: %s %s %s %s", t.TransType, t.Gid, t.BranchID, t.BranchType) +func (bb *BranchBarrier) String() string { + return fmt.Sprintf("transInfo: %s %s %s %s", bb.TransType, bb.Gid, bb.BranchID, bb.BranchType) } -// TransInfoFromQuery construct transaction info from request -func TransInfoFromQuery(qs url.Values) (*TransInfo, error) { - ti := &TransInfo{ +// BarrierFromQuery construct transaction info from request +func BarrierFromQuery(qs url.Values) (*BranchBarrier, error) { + ti := &BranchBarrier{ TransType: qs.Get("trans_type"), Gid: qs.Get("gid"), BranchID: qs.Get("branch_id"), @@ -37,14 +38,14 @@ func TransInfoFromQuery(qs url.Values) (*TransInfo, error) { return ti, nil } -func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, reason string) (int64, error) { +func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, barrierID string, reason string) (int64, error) { if branchType == "" { return 0, nil } - return StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason) + return StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)", transType, gid, branchID, branchType, barrierID, reason) } -// ThroughBarrierCall 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465 +// Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465 // db: 本地数据库 // transInfo: 事务信息 // bisiCall: 业务函数,仅在必要时被调用 @@ -53,7 +54,9 @@ func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, br // 如果发生重复调用,则busiCall不会被重复调用,直接对保存在数据库中上一次的结果,进行unmarshal,通常是一个map[string]interface{},直接作为http的resp // 如果发生悬挂,则busiCall不会被调用,直接返回错误 {"dtm_result": "FAILURE"} // 如果发生空补偿,则busiCall不会被调用,直接返回 {"dtm_result": "SUCCESS"} -func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (res interface{}, rerr error) { +func (bb *BranchBarrier) Call(db *sql.DB, busiCall BusiFunc) (res interface{}, rerr error) { + bb.BarrierID = bb.BarrierID + 1 + bid := fmt.Sprintf("%02d", bb.BarrierID) tx, rerr := db.BeginTx(context.Background(), &sql.TxOptions{}) if rerr != nil { return @@ -69,21 +72,21 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re tx.Commit() } }() - ti := transInfo + ti := bb originType := map[string]string{ "cancel": "try", "compensate": "action", }[ti.BranchType] - originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, ti.BranchType) - currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType) + originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, bid, ti.BranchType) + currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, bid, ti.BranchType) Logf("originAffected: %d currentAffected: %d", originAffected, currentAffected) if (ti.BranchType == "cancel" || ti.BranchType == "compensate") && originAffected > 0 { // 这个是空补偿,返回成功 res = ResultSuccess return } else if currentAffected == 0 { // 插入不成功 var result sql.NullString - err := StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?", - ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType).Scan(&result) + err := StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and barrier_id=? and reason=?", + ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, bid, ti.BranchType).Scan(&result) if err == sql.ErrNoRows { // 这个是悬挂操作,返回失败,AP收到这个返回,会尽快回滚 res = ResultFailure return diff --git a/dtmcli/barrier.mysql.sql b/dtmcli/barrier.mysql.sql index ccb837c..67eaa5d 100644 --- a/dtmcli/barrier.mysql.sql +++ b/dtmcli/barrier.mysql.sql @@ -7,11 +7,12 @@ create table if not exists dtm_barrier.barrier( gid varchar(128) default'', branch_id varchar(128) default '', branch_type varchar(45) default '', + barrier_id varchar(45) default '', reason varchar(45) default '' comment 'the branch type who insert this record', result varchar(2047) default null comment 'the business result of this branch', create_time datetime DEFAULT now(), update_time datetime DEFAULT now(), key(create_time), key(update_time), - UNIQUE key(gid, branch_id, branch_type) + UNIQUE key(gid, branch_id, branch_type, barrier_id) ); diff --git a/dtmcli/trans_test.go b/dtmcli/trans_test.go index 9b75a82..05ab78e 100644 --- a/dtmcli/trans_test.go +++ b/dtmcli/trans_test.go @@ -14,7 +14,7 @@ func TestQuery(t *testing.T) { assert.Error(t, err) _, err = TccFromQuery(qs) assert.Error(t, err) - _, err = TransInfoFromQuery(qs) + _, err = BarrierFromQuery(qs) assert.Error(t, err) } diff --git a/dtmcli/types_test.go b/dtmcli/types_test.go index 5cdb5a1..f8c647e 100644 --- a/dtmcli/types_test.go +++ b/dtmcli/types_test.go @@ -22,7 +22,7 @@ func TestTypes(t *testing.T) { }) assert.Error(t, err) assert.Error(t, err) - _, err = TransInfoFromQuery(url.Values{}) + _, err = BarrierFromQuery(url.Values{}) assert.Error(t, err) } diff --git a/examples/base_types.go b/examples/base_types.go index 25134b8..eb11406 100644 --- a/examples/base_types.go +++ b/examples/base_types.go @@ -52,8 +52,8 @@ func reqFrom(c *gin.Context) *TransReq { return v.(*TransReq) } -func infoFromContext(c *gin.Context) *dtmcli.TransInfo { - info := dtmcli.TransInfo{ +func infoFromContext(c *gin.Context) *dtmcli.BranchBarrier { + info := dtmcli.BranchBarrier{ TransType: c.Query("trans_type"), Gid: c.Query("gid"), BranchID: c.Query("branch_id"), @@ -73,8 +73,8 @@ func sdbGet() *sql.DB { } // MustGetTrans construct transaction info from request -func MustGetTrans(c *gin.Context) *dtmcli.TransInfo { - ti, err := dtmcli.TransInfoFromQuery(c.Request.URL.Query()) +func MustGetTrans(c *gin.Context) *dtmcli.BranchBarrier { + ti, err := dtmcli.BarrierFromQuery(c.Request.URL.Query()) e2p(err) return ti } diff --git a/examples/saga_barrier.go b/examples/saga_barrier.go index e044b16..9da16a0 100644 --- a/examples/saga_barrier.go +++ b/examples/saga_barrier.go @@ -42,13 +42,15 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) { if req.TransInResult != "" { return req.TransInResult, nil } - return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { + barrier := MustGetTrans(c) + return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { return sagaBarrierAdjustBalance(sdb, 1, req.Amount) }) } func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { + barrier := MustGetTrans(c) + return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount) }) } @@ -58,13 +60,15 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) { if req.TransInResult != "" { return req.TransInResult, nil } - return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { + barrier := MustGetTrans(c) + return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { return sagaBarrierAdjustBalance(sdb, 2, -req.Amount) }) } func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(sdbGet(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { + barrier := MustGetTrans(c) + return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount) }) } diff --git a/examples/tcc_barrier.go b/examples/tcc_barrier.go index d2eb33f..95b7780 100644 --- a/examples/tcc_barrier.go +++ b/examples/tcc_barrier.go @@ -65,19 +65,22 @@ func tccBarrierTransInTry(c *gin.Context) (interface{}, error) { if req.TransInResult != "" { return req.TransInResult, nil } - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { + barrier := MustGetTrans(c) + return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) { return adjustTrading(sdb, transInUID, req.Amount) }) } func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { + barrier := MustGetTrans(c) + return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) { return adjustBalance(sdb, transInUID, reqFrom(c).Amount) }) } func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { + barrier := MustGetTrans(c) + return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) { return adjustTrading(sdb, transInUID, -reqFrom(c).Amount) }) } @@ -87,20 +90,23 @@ func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) { if req.TransInResult != "" { return req.TransInResult, nil } - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { + barrier := MustGetTrans(c) + return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) { return adjustTrading(sdb, transOutUID, -req.Amount) }) } func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { + barrier := MustGetTrans(c) + return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) { return adjustBalance(sdb, transOutUID, -reqFrom(c).Amount) }) } // TccBarrierTransOutCancel will be use in test func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) { + barrier := MustGetTrans(c) + return barrier.Call(dbGet().ToSQLDB(), func(sdb *sql.Tx) (interface{}, error) { return adjustTrading(sdb, transOutUID, reqFrom(c).Amount) }) } diff --git a/test/dtmsvr_test.go b/test/dtmsvr_test.go index 8888334..847e56e 100644 --- a/test/dtmsvr_test.go +++ b/test/dtmsvr_test.go @@ -20,7 +20,7 @@ var app *gin.Engine // BarrierModel barrier model for gorm type BarrierModel struct { common.ModelBase - dtmcli.TransInfo + dtmcli.BranchBarrier } // TableName gorm table name @@ -111,14 +111,14 @@ func transQuery(t *testing.T, gid string) { func TestSqlDB(t *testing.T) { asserts := assert.New(t) db := common.DbGet(config.DB) - transInfo := &dtmcli.TransInfo{ + barrier := &dtmcli.BranchBarrier{ TransType: "saga", Gid: "gid2", BranchID: "branch_id2", BranchType: "action", } db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')") - _, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) { + _, err := barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) { dtmcli.Logf("rollback gid2") return nil, fmt.Errorf("gid2 error") }) @@ -128,14 +128,16 @@ func TestSqlDB(t *testing.T) { dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(0)) gid2Res := dtmcli.M{"result": "first"} - _, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) { + barrier.BarrierID = 0 + _, err = barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) { dtmcli.Logf("submit gid2") return gid2Res, nil }) asserts.Nil(err) dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(1)) - newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) { + barrier.BarrierID = 0 + newResult, err := barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) { dtmcli.Logf("submit gid2") return dtmcli.MS{"result": "ignored"}, nil })