diff --git a/README.md b/README.md index f24d459..c877c71 100644 --- a/README.md +++ b/README.md @@ -35,9 +35,11 @@ DTM是一款go语言的分布式事务管理器,在微服务架构中,提供 ### 使用 ``` go + // 具体业务微服务地址 + const qsBusi = "http://localhost:8081/api/busi_saga" req := &gin.H{"amount": 30} // 微服务的载荷 // DtmServer为DTM服务的地址,是一个url - saga := dtmcli.NewSaga(DtmServer). + saga := dtmcli.NewSaga("http://localhost:8080/api/dtmsvr"). // 添加一个TransOut的子事务,正向操作为url: qsBusi+"/TransOut", 逆向操作为url: qsBusi+"/TransOutCompensate" Add(qsBusi+"/TransOut", qsBusi+"/TransOutCompensate", req). // 添加一个TransIn的子事务,正向操作为url: qsBusi+"/TransOut", 逆向操作为url: qsBusi+"/TransInCompensate" diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index 42d204a..bfc019b 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" + "github.com/gin-gonic/gin" "github.com/yedf/dtm/common" ) @@ -21,6 +22,15 @@ func (t *TransInfo) String() string { return fmt.Sprintf("transInfo: %s %s %s %s", t.TransType, t.Gid, t.BranchID, t.BranchType) } +func TransInfoFromReq(c *gin.Context) *TransInfo { + return &TransInfo{ + TransType: c.Query("trans_type"), + Gid: c.Query("gid"), + BranchID: c.Query("branch_id"), + BranchType: c.Query("branch_type"), + } +} + type BarrierModel struct { common.ModelBase TransInfo @@ -39,7 +49,7 @@ func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, br return res.RowsAffected() } -func ThroughBarrierCall(db *sql.DB, transType string, gid string, branchId string, branchType string, busiCall BusiFunc) (res interface{}, rerr error) { +func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (res interface{}, rerr error) { tx, rerr := db.BeginTx(context.Background(), &sql.TxOptions{}) if rerr != nil { return @@ -58,9 +68,9 @@ func ThroughBarrierCall(db *sql.DB, transType string, gid string, branchId strin originType := map[string]string{ "cancel": "action", "compensate": "action", - }[branchType] - originAffected, _ := insertBarrier(tx, transType, gid, branchId, originType) - currentAffected, rerr := insertBarrier(tx, transType, gid, branchId, branchType) + }[transInfo.BranchType] + originAffected, _ := insertBarrier(tx, transInfo.TransType, transInfo.Gid, transInfo.BranchID, originType) + currentAffected, rerr := insertBarrier(tx, transInfo.TransType, transInfo.Gid, transInfo.BranchID, transInfo.TransType) if currentAffected == 0 || (originType == "cancel" || originType == "compensate") && originAffected > 0 { return } diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index 75d1afb..bec6dd7 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -250,8 +250,14 @@ func transQuery(t *testing.T, gid string) { func TestSqlDB(t *testing.T) { asserts := assert.New(t) db := common.DbGet(config.Mysql) + transInfo := &dtmcli.TransInfo{ + TransType: "saga", + Gid: "gid2", + BranchID: "branch_id2", + BranchType: "compensate", + } db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type) values('saga', 'gid1', 'branch_id1', 'action')") - _, err := dtmcli.ThroughBarrierCall(db.ToSqlDB(), "saga", "gid2", "branch_id2", "compensate", func(db *sql.DB) (interface{}, error) { + _, err := dtmcli.ThroughBarrierCall(db.ToSqlDB(), transInfo, func(db *sql.DB) (interface{}, error) { logrus.Printf("rollback gid2") return nil, fmt.Errorf("gid2 error") }) @@ -260,7 +266,7 @@ func TestSqlDB(t *testing.T) { asserts.Equal(dbr.RowsAffected, int64(1)) dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(0)) - _, err = dtmcli.ThroughBarrierCall(db.ToSqlDB(), "saga", "gid2", "branch_id2", "compensate", func(db *sql.DB) (interface{}, error) { + _, err = dtmcli.ThroughBarrierCall(db.ToSqlDB(), transInfo, func(db *sql.DB) (interface{}, error) { logrus.Printf("submit gid2") return nil, nil }) diff --git a/examples/main_saga_barrier.go b/examples/main_saga_barrier.go index 6584c04..db0d425 100644 --- a/examples/main_saga_barrier.go +++ b/examples/main_saga_barrier.go @@ -55,47 +55,42 @@ func SagaBarrierAddRoute(app *gin.Engine) { logrus.Printf("examples listening at %d", SagaBarrierBusiPort) } -var SagaBarrierTransInResult = "" -var SagaBarrierTransOutResult = "" -var SagaBarrierTransInCompensateResult = "" -var SagaBarrierTransOutCompensateResult = "" - func sagaBarrierTransIn(c *gin.Context) (interface{}, error) { - gid := c.Query("gid") req := reqFrom(c) - res := common.OrString(SagaBarrierTransInResult, req.TransInResult, "SUCCESS") - logrus.Printf("%s TransIn: %v result: %s", gid, req, res) - return M{"result": res}, nil + return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + db := common.SqlDB2DB(sdb) + dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). + Update("balance", gorm.Expr("balance + ?", req.Amount)) + return "SUCCESS", dbr.Error + }) } func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) { - gid := c.Query("gid") req := reqFrom(c) - res := common.OrString(SagaBarrierTransInCompensateResult, "SUCCESS") - logrus.Printf("%s TransInCompensate: %v result: %s", gid, req, res) - return M{"result": res}, nil -} - -func sagaBarrierTransOut(c *gin.Context) (interface{}, error) { - gid := c.Query("gid") - lid := c.Query("lid") - req := reqFrom(c) - return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), "saga", gid, lid, "action", func(sdb *sql.DB) (interface{}, error) { + return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { db := common.SqlDB2DB(sdb) dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). Update("balance", gorm.Expr("balance - ?", req.Amount)) - return nil, dbr.Error + return "SUCCESS", dbr.Error }) +} - // res := common.OrString(SagaBarrierTransOutResult, req.TransOutResult, "SUCCESS") - // logrus.Printf("%s TransOut: %v result: %s", gid, req, res) - // return M{"result": res}, nil +func sagaBarrierTransOut(c *gin.Context) (interface{}, error) { + req := reqFrom(c) + return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + db := common.SqlDB2DB(sdb) + dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). + Update("balance", gorm.Expr("balance - ?", req.Amount)) + return "SUCCESS", dbr.Error + }) } func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) { - gid := c.Query("gid") req := reqFrom(c) - res := common.OrString(SagaBarrierTransOutCompensateResult, "SUCCESS") - logrus.Printf("%s TransOutCompensate: %v result: %s", gid, req, res) - return M{"result": res}, nil + return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + db := common.SqlDB2DB(sdb) + dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). + Update("balance", gorm.Expr("balance + ?", req.Amount)) + return "SUCCESS", dbr.Error + }) } diff --git a/examples/quick_start.go b/examples/quick_start.go index 3c02539..d0292b4 100644 --- a/examples/quick_start.go +++ b/examples/quick_start.go @@ -31,7 +31,7 @@ func qsStartSvr() { func qsFireRequest() { req := &gin.H{"amount": 30} // 微服务的载荷 - // DtmServer为DTM服务的地址,是一个url + // DtmServer为DTM服务的地址 saga := dtmcli.NewSaga(DtmServer). // 添加一个TransOut的子事务,正向操作为url: qsBusi+"/TransOut", 逆向操作为url: qsBusi+"/TransOutCompensate" Add(qsBusi+"/TransOut", qsBusi+"/TransOutCompensate", req).