From b0102048a95dd06d3c487bdf01aac7b9fbef135a Mon Sep 17 00:00:00 2001 From: yedongfu Date: Thu, 8 Jul 2021 22:21:34 +0800 Subject: [PATCH] saga barrier seems ok --- app/main.go | 27 ++++++++++++++++----------- common/types.go | 1 + dtmcli/barrier.go | 6 +++++- dtmsvr/main.go | 8 +++++--- dtmsvr/trans.go | 9 +++++++++ dtmsvr/trans_msg.go | 2 +- dtmsvr/trans_saga.go | 2 +- dtmsvr/trans_tcc.go | 2 +- examples/main_saga_barrier.go | 16 ++++++++-------- examples/quick_start.go | 6 +++--- 10 files changed, 50 insertions(+), 29 deletions(-) diff --git a/app/main.go b/app/main.go index fad2d42..dbe9a62 100644 --- a/app/main.go +++ b/app/main.go @@ -11,36 +11,38 @@ import ( type M = map[string]interface{} +func wait() { + time.Sleep(10000 * time.Second) +} + func main() { if len(os.Args) > 1 && (os.Args[1] == "quick_start" || os.Args[1] == "qs") { dtmsvr.PopulateMysql() - dtmsvr.Main() + dtmsvr.MainStart() examples.StartMain() - for { - time.Sleep(1000 * time.Second) - } + wait() } app := examples.BaseAppNew() examples.BaseAppSetup(app) if len(os.Args) == 1 || os.Args[1] == "saga" { // 默认情况下,展示saga例子 dtmsvr.PopulateMysql() - dtmsvr.Main() + dtmsvr.MainStart() examples.SagaSetup(app) examples.BaseAppStart(app) examples.SagaFireRequest() } else if os.Args[1] == "xa" { // 启动xa示例 dtmsvr.PopulateMysql() - dtmsvr.Main() + dtmsvr.MainStart() examples.PopulateMysql() examples.XaSetup(app) examples.BaseAppStart(app) examples.XaFireRequest() } else if os.Args[1] == "dtmsvr" { // 只启动dtmsvr - go dtmsvr.StartSvr() + go dtmsvr.MainStart() } else if os.Args[1] == "all" { // 运行所有示例 dtmsvr.PopulateMysql() examples.PopulateMysql() - dtmsvr.Main() + dtmsvr.MainStart() examples.SagaSetup(app) examples.TccSetup(app) examples.XaSetup(app) @@ -48,10 +50,13 @@ func main() { examples.SagaFireRequest() examples.TccFireRequest() examples.XaFireRequest() + } else if os.Args[1] == "saga_barrier" { + dtmsvr.PopulateMysql() + dtmsvr.MainStart() + examples.PopulateMysql() + examples.SagaBarrierMainStart() } else { logrus.Fatalf("unknown arg: %s", os.Args[1]) } - for { - time.Sleep(1000 * time.Second) - } + wait() } diff --git a/common/types.go b/common/types.go index 770f400..d82dc54 100644 --- a/common/types.go +++ b/common/types.go @@ -114,6 +114,7 @@ func SqlDB2DB(sdb *sql.DB) *DB { Conn: sdb, }), &gorm.Config{}) E2P(err) + db.Use(&tracePlugin{}) return &DB{DB: db} } diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index 7f4849e..258df7f 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -23,12 +23,16 @@ func (t *TransInfo) String() string { } func TransInfoFromReq(c *gin.Context) *TransInfo { - return &TransInfo{ + ti := &TransInfo{ TransType: c.Query("trans_type"), Gid: c.Query("gid"), BranchID: c.Query("branch_id"), BranchType: c.Query("branch_type"), } + if ti.TransType == "" || ti.Gid == "" || ti.BranchID == "" || ti.BranchType == "" { + panic(fmt.Errorf("invlid trans info: %v", ti)) + } + return ti } type BarrierModel struct { diff --git a/dtmsvr/main.go b/dtmsvr/main.go index a07a172..b1c6a2f 100644 --- a/dtmsvr/main.go +++ b/dtmsvr/main.go @@ -2,6 +2,7 @@ package dtmsvr import ( "fmt" + "time" "github.com/sirupsen/logrus" "github.com/yedf/dtm/common" @@ -10,8 +11,8 @@ import ( var dtmsvrPort = 8080 -func Main() { - go StartSvr() +func MainStart() { + StartSvr() go CronCommitted() go CronPrepared() } @@ -23,7 +24,8 @@ func StartSvr() { app := common.GetGinApp() AddRoute(app) logrus.Printf("dtmsvr listen at: %d", dtmsvrPort) - app.Run(fmt.Sprintf(":%d", dtmsvrPort)) + go app.Run(fmt.Sprintf(":%d", dtmsvrPort)) + time.Sleep(100 * time.Millisecond) } func PopulateMysql() { diff --git a/dtmsvr/trans.go b/dtmsvr/trans.go index 7677896..c84bb28 100644 --- a/dtmsvr/trans.go +++ b/dtmsvr/trans.go @@ -117,6 +117,15 @@ func (trans *TransGlobal) Process(db *common.DB) { trans.getProcessor().ProcessOnce(db, branches) } +func (trans *TransGlobal) getBranchParams(branch *TransBranch) common.MS { + return common.MS{ + "gid": trans.Gid, + "trans_type": trans.TransType, + "branch_id": branch.BranchID, + "branch_type": branch.BranchType, + } +} + func (t *TransGlobal) setNextCron(expireIn int64) []string { t.NextCronInterval = expireIn next := time.Now().Add(time.Duration(config.TransCronInterval) * time.Second) diff --git a/dtmsvr/trans_msg.go b/dtmsvr/trans_msg.go index 91388e5..b814bb2 100644 --- a/dtmsvr/trans_msg.go +++ b/dtmsvr/trans_msg.go @@ -33,7 +33,7 @@ func (t *TransMsgProcessor) GenBranches() []TransBranch { } func (t *TransMsgProcessor) ExecBranch(db *common.DB, branch *TransBranch) { - resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParam("gid", branch.Gid).Post(branch.Url) + resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParams(t.getBranchParams(branch)).Post(branch.Url) e2p(err) body := resp.String() if strings.Contains(body, "SUCCESS") { diff --git a/dtmsvr/trans_saga.go b/dtmsvr/trans_saga.go index ac69e86..7abb70c 100644 --- a/dtmsvr/trans_saga.go +++ b/dtmsvr/trans_saga.go @@ -36,7 +36,7 @@ func (t *TransSagaProcessor) GenBranches() []TransBranch { } func (t *TransSagaProcessor) ExecBranch(db *common.DB, branch *TransBranch) { - resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParam("gid", branch.Gid).Post(branch.Url) + resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParams(t.getBranchParams(branch)).Post(branch.Url) e2p(err) body := resp.String() if strings.Contains(body, "SUCCESS") { diff --git a/dtmsvr/trans_tcc.go b/dtmsvr/trans_tcc.go index 0f0e27c..49b3444 100644 --- a/dtmsvr/trans_tcc.go +++ b/dtmsvr/trans_tcc.go @@ -20,7 +20,7 @@ func (t *TransTccProcessor) GenBranches() []TransBranch { } func (t *TransTccProcessor) ExecBranch(db *common.DB, branch *TransBranch) { - resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParam("gid", branch.Gid).Post(branch.Url) + resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParams(t.getBranchParams(branch)).Post(branch.Url) e2p(err) body := resp.String() if strings.Contains(body, "SUCCESS") { diff --git a/examples/main_saga_barrier.go b/examples/main_saga_barrier.go index 2eee16e..a1f6d2d 100644 --- a/examples/main_saga_barrier.go +++ b/examples/main_saga_barrier.go @@ -17,17 +17,17 @@ const SagaBarrierBusiApi = "/api/busi_saga_barrier" var SagaBarrierBusi = fmt.Sprintf("http://localhost:%d%s", SagaBarrierBusiPort, SagaBarrierBusiApi) -func SagaBarrierMain() { - go SagaBarrierStartSvr() +func SagaBarrierMainStart() { + SagaBarrierStartSvr() SagaBarrierFireRequest() - time.Sleep(1000 * time.Second) } func SagaBarrierStartSvr() { logrus.Printf("saga barrier examples starting") app := common.GetGinApp() SagaBarrierAddRoute(app) - app.Run(fmt.Sprintf(":%d", SagaBarrierBusiPort)) + go app.Run(fmt.Sprintf(":%d", SagaBarrierBusiPort)) + time.Sleep(100 * time.Millisecond) } func SagaBarrierFireRequest() { @@ -55,7 +55,7 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) { req := reqFrom(c) return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { db := common.SqlDB2DB(sdb) - dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). + dbr := db.Model(&UserAccount{}).Where("user_id = ?", 1). Update("balance", gorm.Expr("balance + ?", req.Amount)) return "SUCCESS", dbr.Error }) @@ -65,7 +65,7 @@ func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) { req := reqFrom(c) return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { db := common.SqlDB2DB(sdb) - dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). + dbr := db.Model(&UserAccount{}).Where("user_id = ?", 1). Update("balance", gorm.Expr("balance - ?", req.Amount)) return "SUCCESS", dbr.Error }) @@ -75,7 +75,7 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) { req := reqFrom(c) return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { db := common.SqlDB2DB(sdb) - dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). + dbr := db.Model(&UserAccount{}).Where("user_id = ?", 2). Update("balance", gorm.Expr("balance - ?", req.Amount)) return "SUCCESS", dbr.Error }) @@ -85,7 +85,7 @@ func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) { req := reqFrom(c) return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { db := common.SqlDB2DB(sdb) - dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). + dbr := db.Model(&UserAccount{}).Where("user_id = ?", 2). Update("balance", gorm.Expr("balance + ?", req.Amount)) return "SUCCESS", dbr.Error }) diff --git a/examples/quick_start.go b/examples/quick_start.go index 2fc40cc..8155622 100644 --- a/examples/quick_start.go +++ b/examples/quick_start.go @@ -20,16 +20,16 @@ var qsBusi = fmt.Sprintf("http://localhost:%d%s", qsBusiPort, qsBusiApi) // 被app/main.go调用,启动服务并运行示例 func StartMain() { - go qsStartSvr() + qsStartSvr() qsFireRequest() - time.Sleep(1000 * time.Second) } func qsStartSvr() { app := common.GetGinApp() qsAddRoute(app) logrus.Printf("quick qs examples listening at %d", qsBusiPort) - app.Run(fmt.Sprintf(":%d", qsBusiPort)) + go app.Run(fmt.Sprintf(":%d", qsBusiPort)) + time.Sleep(100 * time.Millisecond) } func qsFireRequest() {