diff --git a/dtmcli/saga.go b/dtmcli/saga.go index 7399570..0d4deb1 100644 --- a/dtmcli/saga.go +++ b/dtmcli/saga.go @@ -32,6 +32,7 @@ type SagaStep struct { func NewSaga(server string) *Saga { return &Saga{ SagaData: SagaData{ + Gid: GenGid(server), TransType: "saga", }, Server: server, diff --git a/dtmcli/tcc.go b/dtmcli/tcc.go index e29f1dc..302b66b 100644 --- a/dtmcli/tcc.go +++ b/dtmcli/tcc.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" + "github.com/sirupsen/logrus" "github.com/yedf/dtm/common" ) @@ -26,10 +27,14 @@ func TccGlobalTransaction(dtm string, tccFunc TccGlobalFunc) (gid string, rerr e "trans_type": "tcc", } defer func() { + var err error if x := recover(); x != nil || rerr != nil { - _, rerr = common.RestyClient.R().SetBody(data).Post(dtm + "/abort") + _, err = common.RestyClient.R().SetBody(data).Post(dtm + "/abort") } else { - _, rerr = common.RestyClient.R().SetBody(data).Post(dtm + "/submit") + _, err = common.RestyClient.R().SetBody(data).Post(dtm + "/submit") + } + if err != nil { + logrus.Errorf("submitting or abort global transaction error: %v", err) } }() tcc := &Tcc{Dtm: dtm, Gid: gid} diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 060ec36..e9a4d5d 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -25,38 +25,46 @@ func newGid(c *gin.Context) (interface{}, error) { } func prepare(c *gin.Context) (interface{}, error) { - m := TransFromContext(c) - m.Status = "prepared" - m.saveNew(dbGet()) - return M{"message": "SUCCESS", "gid": m.Gid}, nil + t := TransFromContext(c) + t.Status = "prepared" + t.saveNew(dbGet()) + return M{"dtm_result": "SUCCESS", "gid": t.Gid}, nil } func submit(c *gin.Context) (interface{}, error) { db := dbGet() - m := TransFromContext(c) - m.Status = "submitted" - m.saveNew(db) - go m.Process(db) - return M{"message": "SUCCESS", "gid": m.Gid}, nil + t := TransFromContext(c) + dbt := TransFromDb(db, t.Gid) + if dbt != nil && dbt.Status != "prepared" && dbt.Status != "submitted" { + return M{"dtm_result": "FAILURE", "message": fmt.Sprintf("current status %s, cannot sumbmit", dbt.Status)}, nil + } + t.Status = "submitted" + t.saveNew(db) + go t.Process(db) + return M{"dtm_result": "SUCCESS", "gid": t.Gid}, nil } func abort(c *gin.Context) (interface{}, error) { db := dbGet() - m := TransFromContext(c) - m = TransFromDb(db, m.Gid) - if m.TransType != "xa" && m.TransType != "tcc" || m.Status != "prepared" { - return nil, fmt.Errorf("unexpected trans data. type: %s status: %s for gid: %s", m.TransType, m.Status, m.Gid) + t := TransFromContext(c) + dbt := TransFromDb(db, t.Gid) + if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != "prepared" && dbt.Status != "aborting" { + return M{"dtm_result": "FAILURE", "message": fmt.Sprintf("trans type: %s current status %s, cannot abort", dbt.TransType, dbt.Status)}, nil } - go m.Process(db) - return M{"message": "SUCCESS"}, nil + go dbt.Process(db) + return M{"dtm_result": "SUCCESS"}, nil } func registerXaBranch(c *gin.Context) (interface{}, error) { branch := TransBranch{} err := c.BindJSON(&branch) e2p(err) - branches := []TransBranch{branch, branch} db := dbGet() + dbt := TransFromDb(db, branch.Gid) + if dbt.Status != "prepared" { + return M{"dtm_result": "FAILURE", "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)}, nil + } + branches := []TransBranch{branch, branch} branches[0].BranchType = "rollback" branches[1].BranchType = "commit" db.Must().Clauses(clause.OnConflict{ @@ -65,7 +73,7 @@ func registerXaBranch(c *gin.Context) (interface{}, error) { e2p(err) global := TransGlobal{Gid: branch.Gid} global.touch(db, config.TransCronInterval) - return M{"message": "SUCCESS"}, nil + return M{"dtm_result": "SUCCESS"}, nil } func registerTccBranch(c *gin.Context) (interface{}, error) { @@ -78,6 +86,11 @@ func registerTccBranch(c *gin.Context) (interface{}, error) { Status: data["status"], Data: data["data"], } + db := dbGet() + dbt := TransFromDb(db, branch.Gid) + if dbt.Status != "prepared" { + return M{"dtm_result": "FAILURE", "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)}, nil + } branches := []TransBranch{branch, branch, branch} for i, b := range []string{"cancel", "confirm", "try"} { @@ -85,13 +98,13 @@ func registerTccBranch(c *gin.Context) (interface{}, error) { branches[i].URL = data[b] } - dbGet().Must().Clauses(clause.OnConflict{ + db.Must().Clauses(clause.OnConflict{ DoNothing: true, }).Create(branches) e2p(err) global := TransGlobal{Gid: branch.Gid} global.touch(dbGet(), config.TransCronInterval) - return M{"message": "SUCCESS"}, nil + return M{"dtm_result": "SUCCESS"}, nil } func query(c *gin.Context) (interface{}, error) { diff --git a/dtmsvr/cron.go b/dtmsvr/cron.go index 69913d7..3f79f63 100644 --- a/dtmsvr/cron.go +++ b/dtmsvr/cron.go @@ -9,10 +9,10 @@ import ( "github.com/sirupsen/logrus" ) -// CronTransOnce cron expired trans who's status match param status for once. use expireIn as expire time -func CronTransOnce(expireIn time.Duration, status string) bool { +// CronTransOnce cron expired trans. use expireIn as expire time +func CronTransOnce(expireIn time.Duration) bool { defer handlePanic() - trans := lockOneTrans(expireIn, status) + trans := lockOneTrans(expireIn) if trans == nil { return false } @@ -22,21 +22,21 @@ func CronTransOnce(expireIn time.Duration, status string) bool { } // CronExpiredTrans cron expired trans, num == -1 indicate for ever -func CronExpiredTrans(status string, num int) { +func CronExpiredTrans(num int) { for i := 0; i < num || num == -1; i++ { - notEmpty := CronTransOnce(time.Duration(0), status) + notEmpty := CronTransOnce(time.Duration(0)) if !notEmpty { sleepCronTime() } } } -func lockOneTrans(expireIn time.Duration, status string) *TransGlobal { +func lockOneTrans(expireIn time.Duration) *TransGlobal { trans := TransGlobal{} owner := GenGid() db := dbGet() dbr := db.Must().Model(&trans). - Where("next_cron_time < date_add(now(), interval ? second) and status=?", int(expireIn/time.Second), status). + Where("next_cron_time < date_add(now(), interval ? second) and status in ('prepared', 'aborting', 'submitted')", int(expireIn/time.Second)). Limit(1).Update("owner", owner) if dbr.RowsAffected == 0 { return nil diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index 4a7bcb9..5196d3c 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/yedf/dtm/common" @@ -16,19 +17,17 @@ import ( var DtmServer = examples.DtmServer var Busi = examples.Busi +var app *gin.Engine func init() { + TransProcessedTestChan = make(chan string, 1) common.InitApp(common.GetProjectDir(), &config) config.Mysql["database"] = dbName PopulateMysql() examples.PopulateMysql() -} - -func TestDtmSvr(t *testing.T) { - TransProcessedTestChan = make(chan string, 1) // 启动组件 go StartSvr() - app := examples.BaseAppStartup() + app = examples.BaseAppStartup() examples.SagaSetup(app) examples.TccSetup(app) examples.XaSetup(app) @@ -41,7 +40,11 @@ func TestDtmSvr(t *testing.T) { e2p(dbGet().Exec("truncate trans_branch").Error) e2p(dbGet().Exec("truncate trans_log").Error) examples.ResetXaData() +} +func TestDtmSvr(t *testing.T) { + + tccBarrierDisorder(t) tccBarrierNormal(t) tccBarrierRollback(t) sagaBarrierNormal(t) @@ -70,12 +73,11 @@ func TestDtmSvr(t *testing.T) { func TestCover(t *testing.T) { db := dbGet() db.NoMust() - CronTransOnce(0, "prepared") - CronTransOnce(0, "submitted") + CronTransOnce(0) defer handlePanic() checkAffected(db.DB) - go CronExpiredTrans("submitted", 1) + go CronExpiredTrans(1) } func getTransStatus(gid string) string { @@ -176,7 +178,7 @@ func tccBarrierRollback(t *testing.T) { logrus.Printf("tcc returns: %s, %s", res1.String(), res2.String()) return }) - e2p(err) + assert.Equal(t, err, fmt.Errorf("branch trans in fail")) WaitTransProcessed(gid) assert.Equal(t, "failed", getTransStatus(gid)) } @@ -206,12 +208,12 @@ func msgPending(t *testing.T) { msg.Prepare("") assert.Equal(t, "prepared", getTransStatus(msg.Gid)) examples.MainSwitch.CanSubmitResult.SetOnce("PENDING") - CronTransOnce(60*time.Second, "prepared") + CronTransOnce(60 * time.Second) assert.Equal(t, "prepared", getTransStatus(msg.Gid)) examples.MainSwitch.TransInResult.SetOnce("PENDING") - CronTransOnce(60*time.Second, "prepared") + CronTransOnce(60 * time.Second) assert.Equal(t, "submitted", getTransStatus(msg.Gid)) - CronTransOnce(60*time.Second, "submitted") + CronTransOnce(60 * time.Second) assert.Equal(t, []string{"succeed", "succeed"}, getBranchesStatus(msg.Gid)) assert.Equal(t, "succeed", getTransStatus(msg.Gid)) } @@ -262,7 +264,7 @@ func sagaCommittedPending(t *testing.T) { saga.Submit() WaitTransProcessed(saga.Gid) assert.Equal(t, []string{"prepared", "prepared", "prepared", "prepared"}, getBranchesStatus(saga.Gid)) - CronTransOnce(60*time.Second, "submitted") + CronTransOnce(60 * time.Second) assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(saga.Gid)) assert.Equal(t, "succeed", getTransStatus(saga.Gid)) } @@ -341,3 +343,70 @@ func TestSqlDB(t *testing.T) { dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(1)) } + +func tccBarrierDisorder(t *testing.T) { + timeoutChan := make(chan string, 2) + finishedChan := make(chan string, 2) + gid, err := dtmcli.TccGlobalTransaction(DtmServer, func(tcc *dtmcli.Tcc) (rerr error) { + body := &examples.TransReq{Amount: 30} + tryURL := Busi + "/TccBTransOutTry" + confirmURL := Busi + "/TccBTransOutConfirm" + cancelURL := Busi + "/TccBSleepCancel" + // 请参见子事务屏障里的时序图,这里为了模拟该时序图,手动拆解了callbranch + branchID := tcc.NewBranchID() + sleeped := false + app.POST(examples.BusiAPI+"/TccBSleepCancel", common.WrapHandler(func(c *gin.Context) (interface{}, error) { + res, err := examples.TccBarrierTransOutCancel(c) + if !sleeped { + sleeped = true + logrus.Printf("sleep before cancel return") + <-timeoutChan + finishedChan <- "1" + } + return res, err + })) + // 注册子事务 + _, err := common.RestyClient.R(). + SetBody(&M{ + "gid": tcc.Gid, + "branch_id": branchID, + "trans_type": "tcc", + "status": "prepared", + "data": string(common.MustMarshal(body)), + "try": tryURL, + "confirm": confirmURL, + "cancel": cancelURL, + }). + Post(tcc.Dtm + "/registerTccBranch") + e2p(err) + go func() { + logrus.Printf("sleeping to wait for tcc try timeout") + <-timeoutChan + _, _ = common.RestyClient.R(). + SetBody(body). + SetQueryParams(common.MS{ + "dtm": tcc.Dtm, + "gid": tcc.Gid, + "branch_id": branchID, + "trans_type": "tcc", + "branch_type": "try", + }). + Post(tryURL) + finishedChan <- "1" + }() + logrus.Printf("cron to timeout and then call cancel") + go CronTransOnce(60 * time.Second) + time.Sleep(100 * time.Millisecond) + logrus.Printf("cron to timeout and then call cancelled twice") + CronTransOnce(60 * time.Second) + timeoutChan <- "wake" + timeoutChan <- "wake" + <-finishedChan + <-finishedChan + time.Sleep(100 * time.Millisecond) + return fmt.Errorf("a cancelled tcc") + }) + assert.Error(t, err, fmt.Errorf("a cancelled tcc")) + assert.Equal(t, []string{"succeed", "prepared", "prepared"}, getBranchesStatus(gid)) + assert.Equal(t, "failed", getTransStatus(gid)) +} diff --git a/dtmsvr/main.go b/dtmsvr/main.go index 1083eb3..a19a61b 100644 --- a/dtmsvr/main.go +++ b/dtmsvr/main.go @@ -14,8 +14,7 @@ var dtmsvrPort = 8080 // MainStart main func MainStart() { StartSvr() - go CronExpiredTrans("submitted", -1) - go CronExpiredTrans("prepared", -1) + go CronExpiredTrans(-1) } // StartSvr StartSvr diff --git a/dtmsvr/trans.go b/dtmsvr/trans.go index 79b6d38..3e8a156 100644 --- a/dtmsvr/trans.go +++ b/dtmsvr/trans.go @@ -118,7 +118,10 @@ func (t *TransGlobal) Process(db *common.DB) { TransProcessedTestChan <- t.Gid } }() - logrus.Printf("processing: %s", t.Gid) + logrus.Printf("processing: %s status: %s", t.Gid, t.Status) + if t.Status == "prepared" && t.TransType != "msg" { + t.changeStatus(db, "aborting") + } branches := []TransBranch{} db.Must().Where("gid=?", t.Gid).Order("id asc").Find(&branches) t.getProcessor().ProcessOnce(db, branches) @@ -186,6 +189,9 @@ func TransFromContext(c *gin.Context) *TransGlobal { func TransFromDb(db *common.DB, gid string) *TransGlobal { m := TransGlobal{} dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m) + if dbr.Error == gorm.ErrRecordNotFound { + return nil + } e2p(dbr.Error) return &m } diff --git a/dtmsvr/trans_saga.go b/dtmsvr/trans_saga.go index bc9e641..8be49dd 100644 --- a/dtmsvr/trans_saga.go +++ b/dtmsvr/trans_saga.go @@ -69,6 +69,9 @@ func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) t.changeStatus(db, "succeed") return } + if t.Status != "aborting" && t.Status != "failed" { + t.changeStatus(db, "aborting") + } for current = current - 1; current >= 0; current-- { branch := &branches[current] if branch.BranchType != "compensate" || branch.Status != "prepared" { diff --git a/examples/main_tcc_barrier.go b/examples/main_tcc_barrier.go index f02191b..22df1d5 100644 --- a/examples/main_tcc_barrier.go +++ b/examples/main_tcc_barrier.go @@ -42,7 +42,7 @@ func TccBarrierAddRoute(app *gin.Engine) { app.POST(BusiAPI+"/TccBTransInCancel", common.WrapHandler(tccBarrierTransInCancel)) app.POST(BusiAPI+"/TccBTransOutTry", common.WrapHandler(tccBarrierTransOutTry)) app.POST(BusiAPI+"/TccBTransOutConfirm", common.WrapHandler(tccBarrierTransOutConfirm)) - app.POST(BusiAPI+"/TccBTransOutCancel", common.WrapHandler(tccBarrierTransOutCancel)) + app.POST(BusiAPI+"/TccBTransOutCancel", common.WrapHandler(TccBarrierTransOutCancel)) logrus.Printf("examples listening at %d", BusiPort) } @@ -112,7 +112,7 @@ func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) { }) } -func tccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { +func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return adjustTrading(sdb, transOutUID, reqFrom(c).Amount) })