From fa8abb239f19858454592783b5aeab06611e664a Mon Sep 17 00:00:00 2001 From: yedongfu Date: Fri, 28 May 2021 20:29:33 +0800 Subject: [PATCH] saga should refactor next --- dtmsvr/api.go | 16 ++++++++++++---- dtmsvr/dtmsvr_test.go | 40 ++++++++++++++++++++++++---------------- dtmsvr/trans.go | 26 ++++++-------------------- 3 files changed, 42 insertions(+), 40 deletions(-) diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 2f7abf4..0dbe103 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/yedf/dtm/common" + "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -47,9 +48,16 @@ func Branch(c *gin.Context) (interface{}, error) { branch := TransBranch{} err := c.BindJSON(&branch) e2p(err) - db := dbGet() - db.Must().Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(&branch) + branches := []TransBranch{branch, branch} + err = dbGet().Transaction(func(tx *gorm.DB) error { + db := &common.MyDb{DB: tx} + branches[0].BranchType = "rollback" + branches[1].BranchType = "commit" + db.Must().Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(branches) + return nil + }) + e2p(err) return M{"message": "SUCCESS"}, nil } diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index 3a1d02b..6acefdc 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -29,6 +29,7 @@ func TestDtmSvr(t *testing.T) { go StartSvr() go examples.SagaStartSvr() go examples.XaStartSvr() + go examples.TccStartSvr() time.Sleep(time.Duration(100 * 1000 * 1000)) // 清理数据 @@ -36,7 +37,8 @@ func TestDtmSvr(t *testing.T) { e2p(dbGet().Exec("truncate trans_branch").Error) e2p(dbGet().Exec("truncate trans_log").Error) examples.ResetXaData() - // tccNormal(t) + tccNormal(t) + tccRollback(t) sagaCommittedPending(t) sagaPreparePending(t) xaRollback(t) @@ -58,7 +60,7 @@ func TestCover(t *testing.T) { // 测试使用的全局对象 var initdb = dbGet() -func getSagaModel(gid string) *TransGlobal { +func getTransStatus(gid string) *TransGlobal { sm := TransGlobal{} dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm) e2p(dbr.Error) @@ -67,7 +69,7 @@ func getSagaModel(gid string) *TransGlobal { func getBranchesStatus(gid string) []string { steps := []TransBranch{} - dbr := dbGet().Model(&TransBranch{}).Where("gid=?", gid).Find(&steps) + dbr := dbGet().Model(&TransBranch{}).Where("gid=?", gid).Order("id").Find(&steps) e2p(dbr.Error) status := []string{} for _, step := range steps { @@ -95,7 +97,7 @@ func xaNormal(t *testing.T) { }) e2p(err) WaitTransProcessed(gid) - assert.Equal(t, []string{"succeed", "succeed"}, getBranchesStatus(gid)) + assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(gid)) } func xaRollback(t *testing.T) { @@ -119,25 +121,31 @@ func xaRollback(t *testing.T) { logrus.Errorf("global transaction failed, so rollback") } WaitTransProcessed(gid) - assert.Equal(t, []string{"failed"}, getBranchesStatus(gid)) + assert.Equal(t, []string{"succeed", "prepared"}, getBranchesStatus(gid)) + assert.Equal(t, "failed", getTransStatus(gid).Status) } func tccNormal(t *testing.T) { - tcc := genTcc("gid-normal-tcc", false, false) + tcc := genTcc("gid-tcc-normal", false, false) tcc.Prepare(tcc.QueryPrepared) - assert.Equal(t, "prepared", getSagaModel(tcc.Gid).Status) + assert.Equal(t, "prepared", getTransStatus(tcc.Gid).Status) tcc.Commit() - assert.Equal(t, "committed", getSagaModel(tcc.Gid).Status) + assert.Equal(t, "committed", getTransStatus(tcc.Gid).Status) WaitTransProcessed(tcc.Gid) assert.Equal(t, []string{"prepared", "succeed", "succeed", "prepared", "succeed", "succeed"}, getBranchesStatus(tcc.Gid)) - +} +func tccRollback(t *testing.T) { + tcc := genTcc("gid-tcc-rollback", false, true) + tcc.Commit() + WaitTransProcessed(tcc.Gid) + assert.Equal(t, []string{"succeed", "prepared", "succeed", "succeed", "prepared", "failed"}, getBranchesStatus(tcc.Gid)) } func sagaNormal(t *testing.T) { saga := genSaga("gid-noramlSaga", false, false) saga.Prepare(saga.QueryPrepared) - assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status) + assert.Equal(t, "prepared", getTransStatus(saga.Gid).Status) saga.Commit() - assert.Equal(t, "committed", getSagaModel(saga.Gid).Status) + assert.Equal(t, "committed", getTransStatus(saga.Gid).Status) WaitTransProcessed(saga.Gid) assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(saga.Gid)) } @@ -147,7 +155,7 @@ func sagaRollback(t *testing.T) { saga.Commit() WaitTransProcessed(saga.Gid) saga.Prepare(saga.QueryPrepared) - assert.Equal(t, "failed", getSagaModel(saga.Gid).Status) + assert.Equal(t, "failed", getTransStatus(saga.Gid).Status) assert.Equal(t, []string{"failed", "succeed", "failed", "failed"}, getBranchesStatus(saga.Gid)) } @@ -159,7 +167,7 @@ func sagaPrepareCancel(t *testing.T) { CronTransOnce(-10*time.Second, "prepared") examples.SagaTransQueryResult = "" config.PreparedExpire = 60 - assert.Equal(t, "canceled", getSagaModel(saga.Gid).Status) + assert.Equal(t, "canceled", getTransStatus(saga.Gid).Status) } func sagaPreparePending(t *testing.T) { @@ -168,9 +176,9 @@ func sagaPreparePending(t *testing.T) { examples.SagaTransQueryResult = "PENDING" CronTransOnce(-10*time.Second, "prepared") examples.SagaTransQueryResult = "" - assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status) + assert.Equal(t, "prepared", getTransStatus(saga.Gid).Status) CronTransOnce(-10*time.Second, "prepared") - assert.Equal(t, "succeed", getSagaModel(saga.Gid).Status) + assert.Equal(t, "succeed", getTransStatus(saga.Gid).Status) } func sagaCommittedPending(t *testing.T) { @@ -183,7 +191,7 @@ func sagaCommittedPending(t *testing.T) { assert.Equal(t, []string{"prepared", "succeed", "prepared", "prepared"}, getBranchesStatus(saga.Gid)) CronTransOnce(-10*time.Second, "committed") assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(saga.Gid)) - assert.Equal(t, "succeed", getSagaModel(saga.Gid).Status) + assert.Equal(t, "succeed", getTransStatus(saga.Gid).Status) } func genSaga(gid string, outFailed bool, inFailed bool) *dtm.Saga { diff --git a/dtmsvr/trans.go b/dtmsvr/trans.go index 7bc9ca9..7672f09 100644 --- a/dtmsvr/trans.go +++ b/dtmsvr/trans.go @@ -137,8 +137,7 @@ func (t *TransTccProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) str body := resp.String() t.touch(db) if strings.Contains(body, "SUCCESS") { - status := common.If(branch.BranchType == "cancel", "failed", "succeed").(string) - branch.changeStatus(db, status) + branch.changeStatus(db, "succeed") return "SUCCESS" } if branch.BranchType == "try" && strings.Contains(body, "FAIL") { @@ -193,7 +192,7 @@ func (t *TransXaProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) stri if !strings.Contains(body, "SUCCESS") { panic(fmt.Errorf("bad response: %s", body)) } - branch.changeStatus(db, common.If(t.Status == "prepared", "failed", "succeed").(string)) + branch.changeStatus(db, "succeed") return "SUCCESS" } @@ -201,25 +200,12 @@ func (t *TransXaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) if t.Status == "succeed" { return } - if t.Status == "committed" { - for _, branch := range branches { - if branch.Status == "succeed" { - continue - } - _ = t.ExecBranch(db, &branch) - t.touch(db) // 更新update_time,避免被定时任务再次 - } - t.changeStatus(db, "succeed") - } else if t.Status == "prepared" { // 未commit直接处理的情况为回滚场景 - for _, branch := range branches { - if branch.Status == "failed" { - continue - } + currentType := common.If(t.Status == "committed", "commit", "rollback").(string) + for _, branch := range branches { + if branch.BranchType == currentType && branch.Status != "succeed" { _ = t.ExecBranch(db, &branch) t.touch(db) } - t.changeStatus(db, "failed") - } else { - e2p(fmt.Errorf("bad trans status: %s", t.Status)) } + t.changeStatus(db, common.If(t.Status == "committed", "succeed", "failed").(string)) }