diff --git a/common/types_test.go b/common/types_test.go index bce7cfc..906d7a0 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -22,9 +22,9 @@ func TestDb(t *testing.T) { func TestDbAlone(t *testing.T) { db, err := dtmcli.SdbAlone(DtmConfig.DB) assert.Nil(t, err) - _, err = dtmcli.SdbExec(db, "select 1") + _, err = dtmcli.DBExec(db, "select 1") assert.Equal(t, nil, err) db.Close() - _, err = dtmcli.SdbExec(db, "select 1") + _, err = dtmcli.DBExec(db, "select 1") assert.NotEqual(t, nil, err) } diff --git a/doc/barrier-en.md b/doc/barrier-en.md index 0f320c8..1dcfe6a 100644 --- a/doc/barrier-en.md +++ b/doc/barrier-en.md @@ -17,8 +17,8 @@ Let’s take a look at a sequence diagram of network abnormalities to better und ![abnormal network image](https://pic2.zhimg.com/80/v2-04c577b69ab7145ab493a8158a048a08_1440w.png) -- When business processing step 4, Cancel is executed before Try, and empty rollback needs to be processed. -- When business processing step 6, Cancel is executed repeatedly and needs to be idempotent. +- When business processing step 4, Cancel is executed before Try, and empty rollback needs to be processed. +- When business processing step 6, Cancel is executed repeatedly and needs to be idempotent. - When business processing step 8, Try is executed after Cancel and needs to be processed. For the above-mentioned complex network abnormalities, all distributed transaction systems currently recommend that business developers use unique keys to query whether the associated operations have been completed. If it is completed, it returns directly to success. The related logic is complex, error-prone, and the business burden is heavy. @@ -72,19 +72,21 @@ Under this mechanism, problems related to network abnormalities are solved - Idempotent control: No single key can be inserted repeatedly in any branch, which ensures that it will not be executed repeatedly - Dangling action control: Try is to be executed after Cancel, then the inserted gid-branchid-try is not successful, it will not be executed -Let's take a look at the example in main_tcc_barrier.go in dtm: +Let's take a look at the example in http_tcc_barrier.go in dtm: ``` GO func tccBarrierTransInTry(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(db, dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { - return adjustTrading(sdb, transInUid, reqFrom(c).Amount) - }) + req := reqFrom(c) // 去重构一下,改成可以重复使用的输入 + barrier := MustBarrierFromGin(c) + return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { + return adjustTrading(db, transInUID, req.Amount) + }) } ``` -The Try in the TransIn business only needs one ThroughBarrierCall call to handle the above abnormal situation, which greatly simplifies the work of business developers. For SAGA transactions, reliable messages, etc., a similar mechanism can also be used. +The Try in the TransIn business only needs one barrier.Call call to handle the above abnormal situation, which greatly simplifies the work of business developers. For SAGA transactions, reliable messages, etc., a similar mechanism can also be used. ## summary -The sub-transaction barrier technology proposed in this project systematically solves the problem of network disorder in distributed transactions and greatly reduces the difficulty of sub-transaction disorder processing. +The sub-transaction barrier technology proposed in this project systematically solves the problem of network disorder in distributed transactions and greatly reduces the difficulty of sub-transaction disorder processing. Other development languages ​​can also quickly access the technology \ No newline at end of file diff --git a/dtmcli/adapter.go b/dtmcli/adapter.go new file mode 100644 index 0000000..d549caa --- /dev/null +++ b/dtmcli/adapter.go @@ -0,0 +1,18 @@ +package dtmcli + +import ( + "database/sql" +) + +// DB inteface of dtmcli db +type DB interface { + Exec(query string, args ...interface{}) (sql.Result, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + +// Tx interface of dtmcli tx +type Tx interface { + Rollback() error + Commit() error + DB +} diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index 1096983..bbd982f 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -1,14 +1,13 @@ package dtmcli import ( - "context" "database/sql" "fmt" "net/url" ) // BusiFunc type for busi func -type BusiFunc func(db *sql.Tx) error +type BusiFunc func(db DB) error // BranchBarrier every branch info type BranchBarrier struct { @@ -42,11 +41,11 @@ func BarrierFrom(transType, gid, branchID, branchType string) (*BranchBarrier, e return ti, nil } -func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, barrierID string, reason string) (int64, error) { +func insertBarrier(tx Tx, transType string, gid string, branchID string, branchType string, barrierID string, reason string) (int64, error) { if branchType == "" { return 0, nil } - return StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)", transType, gid, branchID, branchType, barrierID, reason) + return DBExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)", transType, gid, branchID, branchType, barrierID, reason) } // Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465 @@ -56,10 +55,9 @@ func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, br // 返回值: // 如果发生悬挂,则busiCall不会被调用,直接返回错误 ErrFailure,全局事务尽早进行回滚 // 如果正常调用,重复调用,空补偿,返回的错误值为nil,正常往下进行 -func (bb *BranchBarrier) Call(db *sql.DB, busiCall BusiFunc) (rerr error) { +func (bb *BranchBarrier) Call(tx Tx, busiCall BusiFunc) (rerr error) { bb.BarrierID = bb.BarrierID + 1 bid := fmt.Sprintf("%02d", bb.BarrierID) - tx, rerr := db.BeginTx(context.Background(), &sql.TxOptions{}) if rerr != nil { return } @@ -86,7 +84,7 @@ func (bb *BranchBarrier) Call(db *sql.DB, busiCall BusiFunc) (rerr error) { return } else if currentAffected == 0 { // 插入不成功 var result sql.NullString - err := StxQueryRow(tx, "select 1 from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and barrier_id=? and reason=?", + err := DBQueryRow(tx, "select 1 from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and barrier_id=? and reason=?", ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, bid, ti.BranchType).Scan(&result) if err == sql.ErrNoRows { // 不是当前分支插入的,那么是cancel插入的,因此是悬挂操作,返回失败,AP收到这个返回,会尽快回滚 rerr = ErrFailure diff --git a/dtmcli/utils.go b/dtmcli/utils.go index d8ac577..33bc8c0 100644 --- a/dtmcli/utils.go +++ b/dtmcli/utils.go @@ -202,8 +202,8 @@ func SdbAlone(conf map[string]string) (*sql.DB, error) { return sql.Open(conf["driver"], dsn) } -// SdbExec use raw db to exec -func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) { +// DBExec use raw db to exec +func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr error) { r, rerr := db.Exec(sql, values...) if rerr == nil { affected, rerr = r.RowsAffected() @@ -214,22 +214,10 @@ func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rer return } -// StxExec use raw tx to exec -func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) { - r, rerr := tx.Exec(sql, values...) - if rerr == nil { - affected, rerr = r.RowsAffected() - Logf("affected: %d for %s %v", affected, sql, values) - } else { - LogRedf("exec error: %v for %s %v", rerr, sql, values) - } - return -} - -// StxQueryRow use raw tx to query row -func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row { +// DBQueryRow use raw tx to query row +func DBQueryRow(db DB, query string, args ...interface{}) *sql.Row { Logf("querying: "+query, args...) - return tx.QueryRow(query, args...) + return db.QueryRow(query, args...) } // GetDsn get dsn from map config diff --git a/dtmcli/xa.go b/dtmcli/xa.go index c1c9542..bb774af 100644 --- a/dtmcli/xa.go +++ b/dtmcli/xa.go @@ -61,7 +61,7 @@ func (xc *XaClient) HandleCallback(gid string, branchID string, action string) ( } defer db.Close() xaID := gid + "-" + branchID - _, err = SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID)) + _, err = DBExec(db, fmt.Sprintf("xa %s '%s'", action, xaID)) return ResultSuccess, err } @@ -82,9 +82,9 @@ func (xc *XaClient) XaLocalTransaction(qs url.Values, xaFunc XaLocalFunc) (ret i defer func() { db.Close() }() defer func() { x := recover() - _, err := SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch)) + _, err := DBExec(db, fmt.Sprintf("XA end '%s'", xaBranch)) if x == nil && rerr == nil && err == nil { - _, err = SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch)) + _, err = DBExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch)) } if rerr == nil { rerr = err @@ -93,7 +93,7 @@ func (xc *XaClient) XaLocalTransaction(qs url.Values, xaFunc XaLocalFunc) (ret i panic(x) } }() - _, rerr = SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch)) + _, rerr = DBExec(db, fmt.Sprintf("XA start '%s'", xaBranch)) if rerr != nil { return } diff --git a/dtmgrpc/barrier.go b/dtmgrpc/barrier.go index b68c5ee..1389da6 100644 --- a/dtmgrpc/barrier.go +++ b/dtmgrpc/barrier.go @@ -1,8 +1,6 @@ package dtmgrpc import ( - "database/sql" - "github.com/yedf/dtm/dtmcli" "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" @@ -20,8 +18,8 @@ type BranchBarrier struct { // 返回值: // 如果发生悬挂,则busiCall不会被调用,直接返回错误 ErrFailure,全局事务尽早进行回滚 // 如果正常调用,重复调用,空补偿,返回的错误值为nil,正常往下进行 -func (bb *BranchBarrier) Call(db *sql.DB, busiCall dtmcli.BusiFunc) (rerr error) { - err := bb.BranchBarrier.Call(db, busiCall) +func (bb *BranchBarrier) Call(tx dtmcli.Tx, busiCall dtmcli.BusiFunc) (rerr error) { + err := bb.BranchBarrier.Call(tx, busiCall) if err == dtmcli.ErrFailure { return status.New(codes.Aborted, "user rollback").Err() } diff --git a/dtmgrpc/xa.go b/dtmgrpc/xa.go index 204d3dc..0a79697 100644 --- a/dtmgrpc/xa.go +++ b/dtmgrpc/xa.go @@ -55,7 +55,7 @@ func (xc *XaGrpcClient) HandleCallback(gid string, branchID string, action strin } defer db.Close() xaID := gid + "-" + branchID - _, err = dtmcli.SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID)) + _, err = dtmcli.DBExec(db, fmt.Sprintf("xa %s '%s'", action, xaID)) return err } @@ -76,9 +76,9 @@ func (xc *XaGrpcClient) XaLocalTransaction(br *BusiRequest, xaFunc XaGrpcLocalFu defer func() { db.Close() }() defer func() { x := recover() - _, err := dtmcli.SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch)) + _, err := dtmcli.DBExec(db, fmt.Sprintf("XA end '%s'", xaBranch)) if x == nil && rerr == nil && err == nil { - _, err = dtmcli.SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch)) + _, err = dtmcli.DBExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch)) } if rerr == nil { rerr = err @@ -87,7 +87,7 @@ func (xc *XaGrpcClient) XaLocalTransaction(br *BusiRequest, xaFunc XaGrpcLocalFu panic(x) } }() - _, rerr = dtmcli.SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch)) + _, rerr = dtmcli.DBExec(db, fmt.Sprintf("XA start '%s'", xaBranch)) if rerr != nil { return } diff --git a/examples/base_grpc.go b/examples/base_grpc.go index 856ead3..a64d484 100644 --- a/examples/base_grpc.go +++ b/examples/base_grpc.go @@ -116,7 +116,7 @@ func (s *busiServer) TransInXa(ctx context.Context, in *dtmgrpc.BusiRequest) (*d if req.TransInResult == "FAILURE" { return status.New(codes.Aborted, "user return failure").Err() } - _, err := dtmcli.SdbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2) + _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2) return err }) } @@ -128,7 +128,7 @@ func (s *busiServer) TransOutXa(ctx context.Context, in *dtmgrpc.BusiRequest) (* if req.TransOutResult == "FAILURE" { return status.New(codes.Aborted, "user return failure").Err() } - _, err := dtmcli.SdbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1) + _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1) return err }) } diff --git a/examples/base_http.go b/examples/base_http.go index 4de4112..28154b0 100644 --- a/examples/base_http.go +++ b/examples/base_http.go @@ -8,6 +8,8 @@ import ( "github.com/gin-gonic/gin" "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" + "gorm.io/driver/mysql" + "gorm.io/gorm" ) const ( @@ -108,7 +110,7 @@ func BaseAddRoute(app *gin.Engine) { if reqFrom(c).TransInResult == "FAILURE" { return dtmcli.ResultFailure, nil } - _, err := dtmcli.SdbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", reqFrom(c).Amount, 2) + _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", reqFrom(c).Amount, 2) return dtmcli.ResultSuccess, err }) })) @@ -117,9 +119,25 @@ func BaseAddRoute(app *gin.Engine) { if reqFrom(c).TransOutResult == "FAILURE" { return dtmcli.ResultFailure, nil } - _, err := dtmcli.SdbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, 1) + _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, 1) return dtmcli.ResultSuccess, err }) })) + app.POST(BusiAPI+"/TransOutXaGorm", common.WrapHandler(func(c *gin.Context) (interface{}, error) { + return XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) (interface{}, error) { + if reqFrom(c).TransOutResult == "FAILURE" { + return dtmcli.ResultFailure, nil + } + gdb, err := gorm.Open(mysql.New(mysql.Config{ + Conn: db, + }), &gorm.Config{}) + if err != nil { + return nil, err + } + dbr := gdb.Exec("update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, 1) + return dtmcli.ResultSuccess, dbr.Error + }) + })) + } diff --git a/examples/base_types.go b/examples/base_types.go index 87a57b3..cefc688 100644 --- a/examples/base_types.go +++ b/examples/base_types.go @@ -71,6 +71,13 @@ func sdbGet() *sql.DB { return db } +func txGet() *sql.Tx { + db := sdbGet() + tx, err := db.Begin() + dtmcli.FatalIfError(err) + return tx +} + // MustBarrierFromGin 1 func MustBarrierFromGin(c *gin.Context) *dtmcli.BranchBarrier { ti, err := dtmcli.BarrierFromQuery(c.Request.URL.Query()) diff --git a/examples/data.go b/examples/data.go index f4d9938..9bbc5bc 100644 --- a/examples/data.go +++ b/examples/data.go @@ -24,7 +24,7 @@ func RunSQLScript(conf map[string]string, script string, skipDrop bool) { if s == "" || skipDrop && strings.Contains(s, "drop") { continue } - _, err = dtmcli.SdbExec(con, s) + _, err = dtmcli.DBExec(con, s) dtmcli.FatalIfError(err) } } diff --git a/examples/grpc_saga_barrier.go b/examples/grpc_saga_barrier.go index 5337a50..cabb1c7 100644 --- a/examples/grpc_saga_barrier.go +++ b/examples/grpc_saga_barrier.go @@ -2,7 +2,6 @@ package examples import ( "context" - "database/sql" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmgrpc" @@ -24,11 +23,11 @@ func init() { }) } -func sagaGrpcBarrierAdjustBalance(sdb *sql.Tx, uid int, amount int, result string) error { +func sagaGrpcBarrierAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error { if result == "FAILURE" { return status.New(codes.Aborted, "user rollback").Err() } - _, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) + _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) return err } @@ -37,8 +36,8 @@ func (s *busiServer) TransInBSaga(ctx context.Context, in *dtmgrpc.BusiRequest) req := TransReq{} dtmcli.MustUnmarshal(in.BusiData, &req) barrier := MustBarrierFromGrpc(in) - return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return sagaGrpcBarrierAdjustBalance(sdb, 2, req.Amount, req.TransInResult) + return &emptypb.Empty{}, barrier.Call(txGet(), func(tx dtmcli.DB) error { + return sagaGrpcBarrierAdjustBalance(tx, 2, req.Amount, req.TransInResult) }) } @@ -46,8 +45,8 @@ func (s *busiServer) TransOutBSaga(ctx context.Context, in *dtmgrpc.BusiRequest) req := TransReq{} dtmcli.MustUnmarshal(in.BusiData, &req) barrier := MustBarrierFromGrpc(in) - return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return sagaGrpcBarrierAdjustBalance(sdb, 1, -req.Amount, req.TransOutResult) + return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error { + return sagaGrpcBarrierAdjustBalance(db, 1, -req.Amount, req.TransOutResult) }) } @@ -55,8 +54,8 @@ func (s *busiServer) TransInRevertBSaga(ctx context.Context, in *dtmgrpc.BusiReq req := TransReq{} dtmcli.MustUnmarshal(in.BusiData, &req) barrier := MustBarrierFromGrpc(in) - return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return sagaGrpcBarrierAdjustBalance(sdb, 2, -req.Amount, "") + return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error { + return sagaGrpcBarrierAdjustBalance(db, 2, -req.Amount, "") }) } @@ -64,7 +63,7 @@ func (s *busiServer) TransOutRevertBSaga(ctx context.Context, in *dtmgrpc.BusiRe req := TransReq{} dtmcli.MustUnmarshal(in.BusiData, &req) barrier := MustBarrierFromGrpc(in) - return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return sagaGrpcBarrierAdjustBalance(sdb, 1, req.Amount, "") + return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error { + return sagaGrpcBarrierAdjustBalance(db, 1, req.Amount, "") }) } diff --git a/examples/http_gorm_xa.go b/examples/http_gorm_xa.go new file mode 100644 index 0000000..a6d817e --- /dev/null +++ b/examples/http_gorm_xa.go @@ -0,0 +1,22 @@ +package examples + +import ( + "github.com/go-resty/resty/v2" + "github.com/yedf/dtm/dtmcli" +) + +func init() { + addSample("xa_gorm", func() string { + gid := dtmcli.MustGenGid(DtmServer) + err := XaClient.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) { + resp, err := xa.CallBranch(&TransReq{Amount: 30}, Busi+"/TransOutXaGorm") + if err != nil { + return resp, err + } + return xa.CallBranch(&TransReq{Amount: 30}, Busi+"/TransInXa") + }) + dtmcli.FatalIfError(err) + return gid + }) + +} diff --git a/examples/http_saga_barrier.go b/examples/http_saga_barrier.go index c5e092e..a6c865a 100644 --- a/examples/http_saga_barrier.go +++ b/examples/http_saga_barrier.go @@ -1,8 +1,6 @@ package examples import ( - "database/sql" - "github.com/gin-gonic/gin" "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" @@ -28,8 +26,8 @@ func init() { }) } -func sagaBarrierAdjustBalance(sdb *sql.Tx, uid int, amount int) error { - _, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) +func sagaBarrierAdjustBalance(db dtmcli.DB, uid int, amount int) error { + _, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) return err } @@ -40,15 +38,15 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) { return req.TransInResult, nil } barrier := MustBarrierFromGin(c) - return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return sagaBarrierAdjustBalance(sdb, 1, req.Amount) + return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { + return sagaBarrierAdjustBalance(db, 1, req.Amount) }) } func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount) + return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { + return sagaBarrierAdjustBalance(db, 1, -reqFrom(c).Amount) }) } @@ -58,14 +56,14 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) { return req.TransInResult, nil } barrier := MustBarrierFromGin(c) - return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return sagaBarrierAdjustBalance(sdb, 2, -req.Amount) + return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { + return sagaBarrierAdjustBalance(db, 2, -req.Amount) }) } func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount) + return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { + return sagaBarrierAdjustBalance(db, 2, reqFrom(c).Amount) }) } diff --git a/examples/http_saga_gorm_barrier.go b/examples/http_saga_gorm_barrier.go new file mode 100644 index 0000000..5b23a25 --- /dev/null +++ b/examples/http_saga_gorm_barrier.go @@ -0,0 +1,34 @@ +package examples + +import ( + "github.com/gin-gonic/gin" + "github.com/yedf/dtm/common" + "github.com/yedf/dtm/dtmcli" +) + +func init() { + setupFuncs["SagaGormBarrierSetup"] = func(app *gin.Engine) { + app.POST(BusiAPI+"/SagaBTransOutGorm", common.WrapHandler(sagaGormBarrierTransOut)) + } + addSample("saga_gorm_barrier", func() string { + dtmcli.Logf("a busi transaction begin") + req := &TransReq{Amount: 30} + saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)). + Add(Busi+"/SagaBTransOutGorm", Busi+"/SagaBTransOutCompensate", req). + Add(Busi+"/SagaBTransIn", Busi+"/SagaBTransInCompensate", req) + dtmcli.Logf("busi trans submit") + err := saga.Submit() + dtmcli.FatalIfError(err) + return saga.Gid + }) + +} + +func sagaGormBarrierTransOut(c *gin.Context) (interface{}, error) { + req := reqFrom(c) + barrier := MustBarrierFromGin(c) + tx := dbGet().DB.Begin() + return dtmcli.ResultSuccess, barrier.Call(tx.Statement.ConnPool.(dtmcli.Tx), func(db dtmcli.DB) error { + return tx.Exec("update dtm_busi.user_account set balance = balance + ? where user_id = ?", -req.Amount, 2).Error + }) +} diff --git a/examples/http_tcc_barrier.go b/examples/http_tcc_barrier.go index 99edf3a..8026f5b 100644 --- a/examples/http_tcc_barrier.go +++ b/examples/http_tcc_barrier.go @@ -1,7 +1,6 @@ package examples import ( - "database/sql" "fmt" "github.com/gin-gonic/gin" @@ -37,18 +36,18 @@ func init() { const transInUID = 1 const transOutUID = 2 -func adjustTrading(sdb *sql.Tx, uid int, amount int) error { - affected, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account_trading set trading_balance=trading_balance + ? where user_id=? and trading_balance + ? + (select balance from dtm_busi.user_account where id=?) >= 0", amount, uid, amount, uid) +func adjustTrading(db dtmcli.DB, uid int, amount int) error { + affected, err := dtmcli.DBExec(db, "update dtm_busi.user_account_trading set trading_balance=trading_balance + ? where user_id=? and trading_balance + ? + (select balance from dtm_busi.user_account where id=?) >= 0", amount, uid, amount, uid) if err == nil && affected == 0 { return fmt.Errorf("update error, maybe balance not enough") } return err } -func adjustBalance(sdb *sql.Tx, uid int, amount int) error { - affected, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account_trading set trading_balance = trading_balance + ? where user_id=?;", -amount, uid) +func adjustBalance(db dtmcli.DB, uid int, amount int) error { + affected, err := dtmcli.DBExec(db, "update dtm_busi.user_account_trading set trading_balance = trading_balance + ? where user_id=?;", -amount, uid) if err == nil && affected == 1 { - affected, err = dtmcli.StxExec(sdb, "update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid) + affected, err = dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid) } if err == nil && affected == 0 { return fmt.Errorf("update 0 rows") @@ -63,22 +62,22 @@ func tccBarrierTransInTry(c *gin.Context) (interface{}, error) { return req.TransInResult, nil } barrier := MustBarrierFromGin(c) - return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return adjustTrading(sdb, transInUID, req.Amount) + return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { + return adjustTrading(db, transInUID, req.Amount) }) } func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return adjustBalance(sdb, transInUID, reqFrom(c).Amount) + return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { + return adjustBalance(db, transInUID, reqFrom(c).Amount) }) } func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return adjustTrading(sdb, transInUID, -reqFrom(c).Amount) + return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { + return adjustTrading(db, transInUID, -reqFrom(c).Amount) }) } @@ -88,22 +87,22 @@ func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) { return req.TransInResult, nil } barrier := MustBarrierFromGin(c) - return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return adjustTrading(sdb, transOutUID, -req.Amount) + return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { + return adjustTrading(db, transOutUID, -req.Amount) }) } func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return adjustBalance(sdb, transOutUID, -reqFrom(c).Amount) + return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { + return adjustBalance(db, transOutUID, -reqFrom(c).Amount) }) } // TccBarrierTransOutCancel will be use in test func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { - return adjustTrading(sdb, transOutUID, reqFrom(c).Amount) + return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error { + return adjustTrading(db, transOutUID, reqFrom(c).Amount) }) } diff --git a/examples/quick_start.go b/examples/quick_start.go index dafeaf1..5a9a76a 100644 --- a/examples/quick_start.go +++ b/examples/quick_start.go @@ -42,7 +42,7 @@ func QsFireRequest() string { } func qsAdjustBalance(uid int, amount int) (interface{}, error) { - _, err := dtmcli.SdbExec(sdbGet(), "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) + _, err := dtmcli.DBExec(sdbGet(), "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) return dtmcli.ResultSuccess, err } diff --git a/test/dtmsvr_test.go b/test/dtmsvr_test.go index db87cfb..703d1e0 100644 --- a/test/dtmsvr_test.go +++ b/test/dtmsvr_test.go @@ -1,7 +1,6 @@ package test import ( - "database/sql" "fmt" "testing" @@ -118,7 +117,9 @@ func TestSqlDB(t *testing.T) { BranchType: "action", } db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')") - err := barrier.Call(db.ToSQLDB(), func(db *sql.Tx) error { + tx, err := db.ToSQLDB().Begin() + asserts.Nil(err) + err = barrier.Call(tx, func(db dtmcli.DB) error { dtmcli.Logf("rollback gid2") return fmt.Errorf("gid2 error") }) @@ -128,7 +129,9 @@ func TestSqlDB(t *testing.T) { dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(0)) barrier.BarrierID = 0 - err = barrier.Call(db.ToSQLDB(), func(db *sql.Tx) error { + tx2, err := db.ToSQLDB().Begin() + asserts.Nil(err) + err = barrier.Call(tx2, func(db dtmcli.DB) error { dtmcli.Logf("submit gid2") return nil })