From 47d0ba5a7b6e9417ee23277dad359e7e72847b58 Mon Sep 17 00:00:00 2001 From: yedf2 <120050102@qq.com> Date: Thu, 12 Aug 2021 18:00:08 +0800 Subject: [PATCH] grpc barrier ok --- dtmcli/barrier.go | 38 ++++++-------------- dtmcli/barrier.mysql.sql | 1 - examples/grpc_saga_barrier.go | 66 ++++++++++++----------------------- examples/http_saga_barrier.go | 12 +++---- examples/http_tcc_barrier.go | 24 ++++++------- test/dtmsvr_test.go | 15 +++----- 6 files changed, 55 insertions(+), 101 deletions(-) diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index a4eae9c..1096983 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -3,13 +3,12 @@ package dtmcli import ( "context" "database/sql" - "encoding/json" "fmt" "net/url" ) // BusiFunc type for busi func -type BusiFunc func(db *sql.Tx) (interface{}, error) +type BusiFunc func(db *sql.Tx) error // BranchBarrier every branch info type BranchBarrier struct { @@ -55,11 +54,9 @@ func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, br // transInfo: 事务信息 // bisiCall: 业务函数,仅在必要时被调用 // 返回值: -// 如果正常调用,返回bisiCall的结果 -// 如果发生重复调用,则busiCall不会被重复调用,直接对保存在数据库中上一次的结果,进行unmarshal,通常是一个map[string]interface{},直接作为http的resp -// 如果发生悬挂,则busiCall不会被调用,直接返回错误 {"dtm_result": "FAILURE"} -// 如果发生空补偿,则busiCall不会被调用,直接返回 {"dtm_result": "SUCCESS"} -func (bb *BranchBarrier) Call(db *sql.DB, busiCall BusiFunc) (res interface{}, rerr error) { +// 如果发生悬挂,则busiCall不会被调用,直接返回错误 ErrFailure,全局事务尽早进行回滚 +// 如果正常调用,重复调用,空补偿,返回的错误值为nil,正常往下进行 +func (bb *BranchBarrier) Call(db *sql.DB, busiCall BusiFunc) (rerr error) { bb.BarrierID = bb.BarrierID + 1 bid := fmt.Sprintf("%02d", bb.BarrierID) tx, rerr := db.BeginTx(context.Background(), &sql.TxOptions{}) @@ -67,7 +64,7 @@ func (bb *BranchBarrier) Call(db *sql.DB, busiCall BusiFunc) (res interface{}, r return } defer func() { - Logf("result is %v error is %v", res, rerr) + Logf("barrier call error is %v", rerr) if x := recover(); x != nil { tx.Rollback() panic(x) @@ -86,33 +83,18 @@ func (bb *BranchBarrier) Call(db *sql.DB, busiCall BusiFunc) (res interface{}, r currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, bid, ti.BranchType) Logf("originAffected: %d currentAffected: %d", originAffected, currentAffected) if (ti.BranchType == "cancel" || ti.BranchType == "compensate") && originAffected > 0 { // 这个是空补偿,返回成功 - res = ResultSuccess return } else if currentAffected == 0 { // 插入不成功 var result sql.NullString - err := StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and barrier_id=? and reason=?", + err := StxQueryRow(tx, "select 1 from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and barrier_id=? and reason=?", ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, bid, ti.BranchType).Scan(&result) - if err == sql.ErrNoRows { // 这个是悬挂操作,返回失败,AP收到这个返回,会尽快回滚 - res = ResultFailure + if err == sql.ErrNoRows { // 不是当前分支插入的,那么是cancel插入的,因此是悬挂操作,返回失败,AP收到这个返回,会尽快回滚 + rerr = ErrFailure return } - if err != nil { - rerr = err - return - } - if result.Valid { // 数据库里有上一次结果,返回上一次的结果 - rerr = json.Unmarshal([]byte(result.String), &res) - return - } - // 数据库里没有上次的结果,属于重复空补偿,直接返回成功 - res = ResultSuccess + rerr = err //幂等和空补偿,直接返回 return } - res, rerr = busiCall(tx) - if rerr == nil { // 正确返回了,需要将结果保存到数据库 - sval := MustMarshalString(res) - _, rerr = StxExec(tx, "update dtm_barrier.barrier set result=? where trans_type=? and gid=? and branch_id=? and branch_type=?", sval, - ti.TransType, ti.Gid, ti.BranchID, ti.BranchType) - } + rerr = busiCall(tx) return } diff --git a/dtmcli/barrier.mysql.sql b/dtmcli/barrier.mysql.sql index 67eaa5d..a676f34 100644 --- a/dtmcli/barrier.mysql.sql +++ b/dtmcli/barrier.mysql.sql @@ -9,7 +9,6 @@ create table if not exists dtm_barrier.barrier( branch_type varchar(45) default '', barrier_id varchar(45) default '', reason varchar(45) default '' comment 'the branch type who insert this record', - result varchar(2047) default null comment 'the business result of this branch', create_time datetime DEFAULT now(), update_time datetime DEFAULT now(), key(create_time), diff --git a/examples/grpc_saga_barrier.go b/examples/grpc_saga_barrier.go index 94183cc..5337a50 100644 --- a/examples/grpc_saga_barrier.go +++ b/examples/grpc_saga_barrier.go @@ -4,9 +4,10 @@ import ( "context" "database/sql" - "github.com/gin-gonic/gin" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmgrpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" ) @@ -16,75 +17,54 @@ func init() { gid := dtmgrpc.MustGenGid(DtmGrpcServer) saga := dtmgrpc.NewSaga(DtmGrpcServer, gid). Add(BusiGrpc+"/examples.Busi/TransOutBSaga", BusiGrpc+"/examples.Busi/TransOutRevertBSaga", req). - Add(BusiGrpc+"/examples.Busi/TransInBSaga", BusiGrpc+"/examples.Busi/TransOutRevertBSaga", req) + Add(BusiGrpc+"/examples.Busi/TransInBSaga", BusiGrpc+"/examples.Busi/TransInRevertBSaga", req) err := saga.Submit() dtmcli.FatalIfError(err) return saga.Gid }) } -func sagaGrpcBarrierAdjustBalance(sdb *sql.Tx, uid int, amount int) (interface{}, error) { +func sagaGrpcBarrierAdjustBalance(sdb *sql.Tx, uid int, amount int, result string) error { + if result == "FAILURE" { + return status.New(codes.Aborted, "user rollback").Err() + } _, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) - return dtmcli.ResultSuccess, err + return err } func (s *busiServer) TransInBSaga(ctx context.Context, in *dtmgrpc.BusiRequest) (*emptypb.Empty, error) { req := TransReq{} dtmcli.MustUnmarshal(in.BusiData, &req) - return &emptypb.Empty{}, handleGrpcBusiness(in, MainSwitch.TransInResult.Fetch(), req.TransInResult, dtmcli.GetFuncName()) + barrier := MustBarrierFromGrpc(in) + return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { + return sagaGrpcBarrierAdjustBalance(sdb, 2, req.Amount, req.TransInResult) + }) } func (s *busiServer) TransOutBSaga(ctx context.Context, in *dtmgrpc.BusiRequest) (*emptypb.Empty, error) { req := TransReq{} dtmcli.MustUnmarshal(in.BusiData, &req) - return &emptypb.Empty{}, handleGrpcBusiness(in, MainSwitch.TransOutResult.Fetch(), req.TransOutResult, dtmcli.GetFuncName()) + barrier := MustBarrierFromGrpc(in) + return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { + return sagaGrpcBarrierAdjustBalance(sdb, 1, -req.Amount, req.TransOutResult) + }) } func (s *busiServer) TransInRevertBSaga(ctx context.Context, in *dtmgrpc.BusiRequest) (*emptypb.Empty, error) { req := TransReq{} dtmcli.MustUnmarshal(in.BusiData, &req) - return &emptypb.Empty{}, handleGrpcBusiness(in, MainSwitch.TransInRevertResult.Fetch(), "", dtmcli.GetFuncName()) + barrier := MustBarrierFromGrpc(in) + return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { + return sagaGrpcBarrierAdjustBalance(sdb, 2, -req.Amount, "") + }) } func (s *busiServer) TransOutRevertBSaga(ctx context.Context, in *dtmgrpc.BusiRequest) (*emptypb.Empty, error) { req := TransReq{} dtmcli.MustUnmarshal(in.BusiData, &req) - return &emptypb.Empty{}, handleGrpcBusiness(in, MainSwitch.TransOutRevertResult.Fetch(), "", dtmcli.GetFuncName()) -} - -func sagaBarrierTransIn(c *gin.Context) (interface{}, error) { - req := reqFrom(c) - if req.TransInResult != "" { - return req.TransInResult, nil - } - barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { - return sagaBarrierAdjustBalance(sdb, 1, req.Amount) - }) -} - -func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) { - barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { - return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount) - }) -} - -func sagaBarrierTransOut(c *gin.Context) (interface{}, error) { - req := reqFrom(c) - if req.TransInResult != "" { - return req.TransInResult, nil - } - barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { - return sagaBarrierAdjustBalance(sdb, 2, -req.Amount) - }) -} - -func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) { - barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { - return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount) + barrier := MustBarrierFromGrpc(in) + return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { + return sagaGrpcBarrierAdjustBalance(sdb, 1, req.Amount, "") }) } diff --git a/examples/http_saga_barrier.go b/examples/http_saga_barrier.go index 9e60888..c5e092e 100644 --- a/examples/http_saga_barrier.go +++ b/examples/http_saga_barrier.go @@ -28,9 +28,9 @@ func init() { }) } -func sagaBarrierAdjustBalance(sdb *sql.Tx, uid int, amount int) (interface{}, error) { +func sagaBarrierAdjustBalance(sdb *sql.Tx, uid int, amount int) error { _, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) - return dtmcli.ResultSuccess, err + return err } @@ -40,14 +40,14 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) { return req.TransInResult, nil } barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { + return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { return sagaBarrierAdjustBalance(sdb, 1, req.Amount) }) } func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { + return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount) }) } @@ -58,14 +58,14 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) { return req.TransInResult, nil } barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { + return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { return sagaBarrierAdjustBalance(sdb, 2, -req.Amount) }) } func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { + return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount) }) } diff --git a/examples/http_tcc_barrier.go b/examples/http_tcc_barrier.go index 755313a..99edf3a 100644 --- a/examples/http_tcc_barrier.go +++ b/examples/http_tcc_barrier.go @@ -37,23 +37,23 @@ func init() { const transInUID = 1 const transOutUID = 2 -func adjustTrading(sdb *sql.Tx, uid int, amount int) (interface{}, error) { +func adjustTrading(sdb *sql.Tx, uid int, amount int) error { affected, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account_trading set trading_balance=trading_balance + ? where user_id=? and trading_balance + ? + (select balance from dtm_busi.user_account where id=?) >= 0", amount, uid, amount, uid) if err == nil && affected == 0 { - return nil, fmt.Errorf("update error, maybe balance not enough") + return fmt.Errorf("update error, maybe balance not enough") } - return dtmcli.MS{"dtm_server": "SUCCESS"}, nil + return err } -func adjustBalance(sdb *sql.Tx, uid int, amount int) (interface{}, error) { +func adjustBalance(sdb *sql.Tx, uid int, amount int) error { affected, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account_trading set trading_balance = trading_balance + ? where user_id=?;", -amount, uid) if err == nil && affected == 1 { affected, err = dtmcli.StxExec(sdb, "update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid) } if err == nil && affected == 0 { - return nil, fmt.Errorf("update 0 rows") + return fmt.Errorf("update 0 rows") } - return dtmcli.ResultSuccess, err + return err } // TCC下,转入 @@ -63,21 +63,21 @@ func tccBarrierTransInTry(c *gin.Context) (interface{}, error) { return req.TransInResult, nil } barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { + return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { return adjustTrading(sdb, transInUID, req.Amount) }) } func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { + return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { return adjustBalance(sdb, transInUID, reqFrom(c).Amount) }) } func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { + return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { return adjustTrading(sdb, transInUID, -reqFrom(c).Amount) }) } @@ -88,14 +88,14 @@ func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) { return req.TransInResult, nil } barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { + return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { return adjustTrading(sdb, transOutUID, -req.Amount) }) } func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { + return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { return adjustBalance(sdb, transOutUID, -reqFrom(c).Amount) }) } @@ -103,7 +103,7 @@ func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) { // TccBarrierTransOutCancel will be use in test func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { barrier := MustBarrierFromGin(c) - return barrier.Call(sdbGet(), func(sdb *sql.Tx) (interface{}, error) { + return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error { return adjustTrading(sdb, transOutUID, reqFrom(c).Amount) }) } diff --git a/test/dtmsvr_test.go b/test/dtmsvr_test.go index 847e56e..db87cfb 100644 --- a/test/dtmsvr_test.go +++ b/test/dtmsvr_test.go @@ -118,28 +118,21 @@ func TestSqlDB(t *testing.T) { BranchType: "action", } db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')") - _, err := barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) { + err := barrier.Call(db.ToSQLDB(), func(db *sql.Tx) error { dtmcli.Logf("rollback gid2") - return nil, fmt.Errorf("gid2 error") + return fmt.Errorf("gid2 error") }) asserts.Error(err, fmt.Errorf("gid2 error")) dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(1)) dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(0)) - gid2Res := dtmcli.M{"result": "first"} barrier.BarrierID = 0 - _, err = barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) { + err = barrier.Call(db.ToSQLDB(), func(db *sql.Tx) error { dtmcli.Logf("submit gid2") - return gid2Res, nil + return nil }) asserts.Nil(err) dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(1)) - barrier.BarrierID = 0 - newResult, err := barrier.Call(db.ToSQLDB(), func(db *sql.Tx) (interface{}, error) { - dtmcli.Logf("submit gid2") - return dtmcli.MS{"result": "ignored"}, nil - }) - asserts.Equal(newResult, gid2Res) }