diff --git a/app/main.go b/app/main.go index 95aef70..2f35461 100644 --- a/app/main.go +++ b/app/main.go @@ -14,7 +14,9 @@ func main() { if len(os.Args) == 1 { // 所有服务都启动 go dtmsvr.StartSvr() go examples.SagaStartSvr() - } else if len(os.Args) > 1 && os.Args[1] == "dtmsvr" { + go examples.TccStartSvr() + go examples.XaStartSvr() + } else if os.Args[1] == "dtmsvr" { go dtmsvr.StartSvr() } for { diff --git a/common/conf.yml.sample b/common/conf.yml.sample new file mode 100644 index 0000000..e69de29 diff --git a/common/types_test.go b/common/types_test.go new file mode 100644 index 0000000..9ac4fff --- /dev/null +++ b/common/types_test.go @@ -0,0 +1,39 @@ +package common + +import ( + "testing" + + "github.com/go-playground/assert/v2" +) + +type testConfig struct { + Mysql map[string]string +} + +var config = testConfig{} + +var myinit int = func() int { + InitApp(&config) + return 0 +}() + +func TestDb(t *testing.T) { + db := DbGet(config.Mysql) + err := func() (rerr error) { + defer P2E(&rerr) + dbr := db.NoMust().Exec("select a") + assert.NotEqual(t, nil, dbr.Error) + db.Must().Exec("select a") + return nil + }() + assert.NotEqual(t, nil, err) +} + +func TestDbAlone(t *testing.T) { + db, con := DbAlone(config.Mysql) + dbr := db.Exec("select 1") + assert.Equal(t, nil, dbr.Error) + con.Close() + dbr = db.Exec("select 1") + assert.NotEqual(t, nil, dbr.Error) +} diff --git a/common/utils.go b/common/utils.go index 932935f..90c6189 100644 --- a/common/utils.go +++ b/common/utils.go @@ -20,15 +20,6 @@ import ( type M = map[string]interface{} -func OrString(ss ...string) string { - for _, s := range ss { - if s != "" { - return s - } - } - return "" -} - func P2E(perr *error) { if x := recover(); x != nil { if e, ok := x.(error); ok { @@ -39,6 +30,18 @@ func P2E(perr *error) { } } +func E2P(err error) { + if err != nil { + panic(err) + } +} + +func CatchP(f func()) (rerr error) { + defer P2E(&rerr) + f() + return nil +} + func PanicIf(cond bool, err error) { if cond { panic(err) @@ -53,16 +56,17 @@ var gNode *snowflake.Node = nil func init() { node, err := snowflake.NewNode(1) - if err != nil { - panic(err) - } + E2P(err) gNode = node } -func E2P(err error) { - if err != nil { - panic(err) +func OrString(ss ...string) string { + for _, s := range ss { + if s != "" { + return s + } } + return "" } func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} { @@ -102,11 +106,10 @@ func GetGinApp() *gin.Engine { app := gin.Default() app.Use(func(c *gin.Context) { body := "" - if c.Request.Method == "POST" { + if c.Request.Body != nil { rb, err := c.GetRawData() - if err != nil { - logrus.Printf("GetRawData error: %s", err.Error()) - } else { + E2P(err) + if len(rb) > 0 { body = string(rb) c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(rb)) } @@ -196,13 +199,12 @@ func InitApp(config interface{}) { logrus.SetFormatter(&formatter{}) _, file, _, _ := runtime.Caller(1) fileName := filepath.Dir(file) + "/conf.yml" - if configLoaded[fileName] { - return + if !configLoaded[fileName] { + configLoaded[fileName] = true + viper.SetConfigFile(fileName) + err := viper.ReadInConfig() + E2P(err) } - configLoaded[fileName] = true - viper.SetConfigFile(fileName) - err := viper.ReadInConfig() - E2P(err) - err = viper.Unmarshal(config) + err := viper.Unmarshal(config) E2P(err) } diff --git a/common/utils_test.go b/common/utils_test.go new file mode 100644 index 0000000..1c92d23 --- /dev/null +++ b/common/utils_test.go @@ -0,0 +1,95 @@ +package common + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/go-playground/assert/v2" +) + +func TestEP(t *testing.T) { + skipped := true + err := func() (rerr error) { + defer P2E(&rerr) + E2P(errors.New("err1")) + skipped = false + return nil + }() + assert.Equal(t, true, skipped) + assert.Equal(t, "err1", err.Error()) + err = CatchP(func() { + PanicIf(true, errors.New("err2")) + }) + assert.Equal(t, "err2", err.Error()) + err = func() (rerr error) { + defer func() { + x := recover() + assert.Equal(t, 1, x) + }() + defer P2E(&rerr) + panic(1) + }() +} + +func TestGid(t *testing.T) { + id1 := GenGid() + id2 := GenGid() + assert.NotEqual(t, id1, id2) +} + +func TestTernary(t *testing.T) { + assert.Equal(t, "1", OrString("", "", "1")) + assert.Equal(t, "", OrString("", "", "")) + assert.Equal(t, "1", If(true, "1", "2")) + assert.Equal(t, "2", If(false, "1", "2")) +} + +func TestMarshal(t *testing.T) { + a := 0 + type e struct { + A int + } + e1 := e{A: 10} + m := map[string]int{} + assert.Equal(t, "1", MustMarshalString(1)) + assert.Equal(t, []byte("1"), MustMarshal(1)) + MustUnmarshal([]byte("2"), &a) + assert.Equal(t, 2, a) + MustUnmarshalString("3", &a) + assert.Equal(t, 3, a) + MustRemarshal(&e1, &m) + assert.Equal(t, 10, m["A"]) +} + +func TestGin(t *testing.T) { + app := GetGinApp() + app.GET("/api/sample", WrapHandler(func(c *gin.Context) (interface{}, error) { + return 1, nil + })) + app.GET("/api/error", WrapHandler(func(c *gin.Context) (interface{}, error) { + return nil, errors.New("err1") + })) + getResultString := func(api string, body io.Reader) string { + req, _ := http.NewRequest("GET", api, body) + w := httptest.NewRecorder() + app.ServeHTTP(w, req) + return string(w.Body.Bytes()) + } + assert.Equal(t, "{\"msg\":\"pong\"}", getResultString("/api/ping", nil)) + assert.Equal(t, "1", getResultString("/api/sample", nil)) + assert.Equal(t, "{\"code\":500,\"message\":\"err1\"}", getResultString("/api/error", strings.NewReader("{}"))) +} + +func TestResty(t *testing.T) { + resp, err := RestyClient.R().Get("http://baidu.com") + assert.Equal(t, nil, err) + err2 := CatchP(func() { + CheckRestySuccess(resp, err) + }) + assert.NotEqual(t, nil, err2) +} diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 0dbe103..14dea34 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -5,7 +5,6 @@ import ( "github.com/gin-gonic/gin" "github.com/yedf/dtm/common" - "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -49,15 +48,12 @@ func Branch(c *gin.Context) (interface{}, error) { err := c.BindJSON(&branch) e2p(err) branches := []TransBranch{branch, branch} - err = dbGet().Transaction(func(tx *gorm.DB) error { - db := &common.MyDb{DB: tx} - branches[0].BranchType = "rollback" - branches[1].BranchType = "commit" - db.Must().Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(branches) - return nil - }) + db := dbGet() + branches[0].BranchType = "rollback" + branches[1].BranchType = "commit" + db.Must().Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(branches) e2p(err) return M{"message": "SUCCESS"}, nil } diff --git a/dtmsvr/db.go b/dtmsvr/db.go deleted file mode 100644 index 2820c9a..0000000 --- a/dtmsvr/db.go +++ /dev/null @@ -1,20 +0,0 @@ -package dtmsvr - -import "github.com/yedf/dtm/common" - -func dbGet() *common.MyDb { - return common.DbGet(config.Mysql) -} -func writeTransLog(gid string, action string, status string, branch string, detail string) { - db := dbGet() - if detail == "" { - detail = "{}" - } - db.Must().Table("trans_log").Create(M{ - "gid": gid, - "action": action, - "new_status": status, - "branch": branch, - "detail": detail, - }) -} diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index 6acefdc..d71ecb1 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -37,12 +37,14 @@ func TestDtmSvr(t *testing.T) { e2p(dbGet().Exec("truncate trans_branch").Error) e2p(dbGet().Exec("truncate trans_log").Error) examples.ResetXaData() + tccNormal(t) tccRollback(t) + tccRollbackPending(t) + xaNormal(t) + xaRollback(t) sagaCommittedPending(t) sagaPreparePending(t) - xaRollback(t) - xaNormal(t) sagaPrepareCancel(t) sagaNormal(t) sagaRollback(t) @@ -60,20 +62,20 @@ func TestCover(t *testing.T) { // 测试使用的全局对象 var initdb = dbGet() -func getTransStatus(gid string) *TransGlobal { +func getTransStatus(gid string) string { sm := TransGlobal{} dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm) e2p(dbr.Error) - return &sm + return sm.Status } func getBranchesStatus(gid string) []string { - steps := []TransBranch{} - dbr := dbGet().Model(&TransBranch{}).Where("gid=?", gid).Order("id").Find(&steps) + branches := []TransBranch{} + dbr := dbGet().Model(&TransBranch{}).Where("gid=?", gid).Order("id").Find(&branches) e2p(dbr.Error) status := []string{} - for _, step := range steps { - status = append(status, step.Status) + for _, branch := range branches { + status = append(status, branch.Status) } return status } @@ -122,15 +124,15 @@ func xaRollback(t *testing.T) { } WaitTransProcessed(gid) assert.Equal(t, []string{"succeed", "prepared"}, getBranchesStatus(gid)) - assert.Equal(t, "failed", getTransStatus(gid).Status) + assert.Equal(t, "failed", getTransStatus(gid)) } func tccNormal(t *testing.T) { tcc := genTcc("gid-tcc-normal", false, false) tcc.Prepare(tcc.QueryPrepared) - assert.Equal(t, "prepared", getTransStatus(tcc.Gid).Status) + assert.Equal(t, "prepared", getTransStatus(tcc.Gid)) tcc.Commit() - assert.Equal(t, "committed", getTransStatus(tcc.Gid).Status) + assert.Equal(t, "committed", getTransStatus(tcc.Gid)) WaitTransProcessed(tcc.Gid) assert.Equal(t, []string{"prepared", "succeed", "succeed", "prepared", "succeed", "succeed"}, getBranchesStatus(tcc.Gid)) } @@ -140,12 +142,22 @@ func tccRollback(t *testing.T) { WaitTransProcessed(tcc.Gid) assert.Equal(t, []string{"succeed", "prepared", "succeed", "succeed", "prepared", "failed"}, getBranchesStatus(tcc.Gid)) } +func tccRollbackPending(t *testing.T) { + tcc := genTcc("gid-tcc-rollback-pending", false, true) + examples.TccTransInCancelResult = "PENDING" + tcc.Commit() + WaitTransProcessed(tcc.Gid) + assert.Equal(t, "committed", getTransStatus(tcc.Gid)) + examples.TccTransInCancelResult = "" + CronTransOnce(-10*time.Second, "committed") + assert.Equal(t, []string{"succeed", "prepared", "succeed", "succeed", "prepared", "failed"}, getBranchesStatus(tcc.Gid)) +} func sagaNormal(t *testing.T) { saga := genSaga("gid-noramlSaga", false, false) saga.Prepare(saga.QueryPrepared) - assert.Equal(t, "prepared", getTransStatus(saga.Gid).Status) + assert.Equal(t, "prepared", getTransStatus(saga.Gid)) saga.Commit() - assert.Equal(t, "committed", getTransStatus(saga.Gid).Status) + assert.Equal(t, "committed", getTransStatus(saga.Gid)) WaitTransProcessed(saga.Gid) assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(saga.Gid)) } @@ -155,8 +167,8 @@ func sagaRollback(t *testing.T) { saga.Commit() WaitTransProcessed(saga.Gid) saga.Prepare(saga.QueryPrepared) - assert.Equal(t, "failed", getTransStatus(saga.Gid).Status) - assert.Equal(t, []string{"failed", "succeed", "failed", "failed"}, getBranchesStatus(saga.Gid)) + assert.Equal(t, "failed", getTransStatus(saga.Gid)) + assert.Equal(t, []string{"succeed", "succeed", "succeed", "failed"}, getBranchesStatus(saga.Gid)) } func sagaPrepareCancel(t *testing.T) { @@ -167,7 +179,7 @@ func sagaPrepareCancel(t *testing.T) { CronTransOnce(-10*time.Second, "prepared") examples.SagaTransQueryResult = "" config.PreparedExpire = 60 - assert.Equal(t, "canceled", getTransStatus(saga.Gid).Status) + assert.Equal(t, "canceled", getTransStatus(saga.Gid)) } func sagaPreparePending(t *testing.T) { @@ -176,9 +188,9 @@ func sagaPreparePending(t *testing.T) { examples.SagaTransQueryResult = "PENDING" CronTransOnce(-10*time.Second, "prepared") examples.SagaTransQueryResult = "" - assert.Equal(t, "prepared", getTransStatus(saga.Gid).Status) + assert.Equal(t, "prepared", getTransStatus(saga.Gid)) CronTransOnce(-10*time.Second, "prepared") - assert.Equal(t, "succeed", getTransStatus(saga.Gid).Status) + assert.Equal(t, "succeed", getTransStatus(saga.Gid)) } func sagaCommittedPending(t *testing.T) { @@ -191,7 +203,7 @@ func sagaCommittedPending(t *testing.T) { assert.Equal(t, []string{"prepared", "succeed", "prepared", "prepared"}, getBranchesStatus(saga.Gid)) CronTransOnce(-10*time.Second, "committed") assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(saga.Gid)) - assert.Equal(t, "succeed", getTransStatus(saga.Gid).Status) + assert.Equal(t, "succeed", getTransStatus(saga.Gid)) } func genSaga(gid string, outFailed bool, inFailed bool) *dtm.Saga { diff --git a/dtmsvr/service.go b/dtmsvr/service.go deleted file mode 100644 index cb2b85e..0000000 --- a/dtmsvr/service.go +++ /dev/null @@ -1,17 +0,0 @@ -package dtmsvr - -import ( - "github.com/sirupsen/logrus" -) - -var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束 - -func WaitTransProcessed(gid string) { - logrus.Printf("waiting for gid %s", gid) - id := <-TransProcessedTestChan - for id != gid { - logrus.Errorf("-------id %s not match gid %s", id, gid) - id = <-TransProcessedTestChan - } - logrus.Printf("finish for gid %s", gid) -} diff --git a/dtmsvr/trans.go b/dtmsvr/trans.go index 7672f09..cf227cc 100644 --- a/dtmsvr/trans.go +++ b/dtmsvr/trans.go @@ -5,207 +5,177 @@ import ( "strings" "time" + "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" "github.com/yedf/dtm/common" + "gorm.io/gorm" + "gorm.io/gorm/clause" ) +type TransGlobal struct { + common.ModelBase + Gid string `json:"gid"` + TransType string `json:"trans_type"` + Data string `json:"data"` + Status string `json:"status"` + QueryPrepared string `json:"query_prepared"` + CommitTime *time.Time + FinishTime *time.Time + RollbackTime *time.Time +} + +func (*TransGlobal) TableName() string { + return "trans_global" +} + type TransProcessor interface { GenBranches() []TransBranch ProcessOnce(db *common.MyDb, branches []TransBranch) - ExecBranch(db *common.MyDb, branch *TransBranch) string + ExecBranch(db *common.MyDb, branch *TransBranch) } -type TransSagaProcessor struct { - *TransGlobal +func (t *TransGlobal) touch(db *common.MyDb) *gorm.DB { + writeTransLog(t.Gid, "touch trans", "", "", "") + return db.Model(&TransGlobal{}).Where("gid=?", t.Gid).Update("gid", t.Gid) // 更新update_time,避免被定时任务再次 } -func (t *TransSagaProcessor) GenBranches() []TransBranch { - nsteps := []TransBranch{} - steps := []M{} - common.MustUnmarshalString(t.Data, &steps) - for _, step := range steps { - for _, branchType := range []string{"compensate", "action"} { - nsteps = append(nsteps, TransBranch{ - Gid: t.Gid, - Branch: fmt.Sprintf("%d", len(nsteps)+1), - Data: step["data"].(string), - Url: step[branchType].(string), - BranchType: branchType, - Status: "prepared", - }) - } +func (t *TransGlobal) changeStatus(db *common.MyDb, status string) *gorm.DB { + writeTransLog(t.Gid, "change status", status, "", "") + updates := M{ + "status": status, } - return nsteps + if status == "succeed" { + updates["finish_time"] = time.Now() + } else if status == "failed" { + updates["rollback_time"] = time.Now() + } + dbr := db.Must().Model(t).Where("status=?", t.Status).Updates(updates) + checkAffected(dbr) + t.Status = status + return dbr } -func (t *TransSagaProcessor) ExecBranch(db *common.MyDb, branche *TransBranch) string { - return "" +type TransBranch struct { + common.ModelBase + Gid string + Url string + Data string + Branch string + BranchType string + Status string + FinishTime *time.Time + RollbackTime *time.Time } -func (t *TransSagaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) { - if t.Status == "prepared" { - resp, err := common.RestyClient.R().SetQueryParam("gid", t.Gid).Get(t.QueryPrepared) - e2p(err) - body := resp.String() - if strings.Contains(body, "FAIL") { - preparedExpire := time.Now().Add(time.Duration(-config.PreparedExpire) * time.Second) - logrus.Printf("create time: %s prepared expire: %s ", t.CreateTime.Local(), preparedExpire.Local()) - status := common.If(t.CreateTime.Before(preparedExpire), "canceled", "prepared").(string) - if status != t.Status { - t.changeStatus(db, status) - } - return - } else if strings.Contains(body, "SUCCESS") { - t.Status = "committed" - t.SaveNew(db) - } else { - panic(fmt.Errorf("unknown result, will be retried: %s", body)) - } - } - current := 0 // 当前正在处理的步骤 - for ; current < len(branches); current++ { - step := branches[current] - if step.BranchType == "compensate" && step.Status == "prepared" || step.BranchType == "action" && step.Status == "succeed" { - continue - } - if step.BranchType == "action" && step.Status == "prepared" { - resp, err := common.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url) - e2p(err) - body := resp.String() +func (*TransBranch) TableName() string { + return "trans_branch" +} - t.touch(db.Must()) - if strings.Contains(body, "SUCCESS") { - step.changeStatus(db.Must(), "succeed") - } else if strings.Contains(body, "FAIL") { - step.changeStatus(db.Must(), "failed") - break - } else { - panic(fmt.Errorf("unknown response: %s, will be retried", body)) - } - } +func (t *TransBranch) changeStatus(db *common.MyDb, status string) *gorm.DB { + writeTransLog(t.Gid, "branch change", status, t.Branch, "") + dbr := db.Must().Model(t).Where("status=?", t.Status).Updates(M{ + "status": status, + "finish_time": time.Now(), + }) + checkAffected(dbr) + t.Status = status + return dbr +} + +func checkAffected(db1 *gorm.DB) { + if db1.RowsAffected == 0 { + panic(fmt.Errorf("duplicate updating")) } - if current == len(branches) { // saga 事务完成 - t.changeStatus(db.Must(), "succeed") +} + +func (trans *TransGlobal) getProcessor() TransProcessor { + if trans.TransType == "saga" { + return &TransSagaProcessor{TransGlobal: trans} + } else if trans.TransType == "tcc" { + return &TransTccProcessor{TransGlobal: trans} + } else if trans.TransType == "xa" { + return &TransXaProcessor{TransGlobal: trans} + } + return nil +} + +func (t *TransGlobal) MayQueryPrepared(db *common.MyDb) { + if t.Status != "prepared" { return } - for current = current - 1; current >= 0; current-- { - step := branches[current] - if step.BranchType != "compensate" || step.Status != "prepared" { - continue - } - resp, err := common.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url) - e2p(err) - body := resp.String() - if strings.Contains(body, "SUCCESS") { - step.changeStatus(db.Must(), "failed") + resp, err := common.RestyClient.R().SetQueryParam("gid", t.Gid).Get(t.QueryPrepared) + e2p(err) + body := resp.String() + if strings.Contains(body, "FAIL") { + preparedExpire := time.Now().Add(time.Duration(-config.PreparedExpire) * time.Second) + logrus.Printf("create time: %s prepared expire: %s ", t.CreateTime.Local(), preparedExpire.Local()) + status := common.If(t.CreateTime.Before(preparedExpire), "canceled", "prepared").(string) + if status != t.Status { + t.changeStatus(db, status) } else { - panic(fmt.Errorf("expect compensate return SUCCESS")) - } - } - if current != -1 { - panic(fmt.Errorf("saga current not -1")) - } - t.changeStatus(db.Must(), "failed") -} - -type TransTccProcessor struct { - *TransGlobal -} - -func (t *TransTccProcessor) GenBranches() []TransBranch { - nsteps := []TransBranch{} - steps := []M{} - common.MustUnmarshalString(t.Data, &steps) - for _, step := range steps { - for _, branchType := range []string{"cancel", "confirm", "try"} { - nsteps = append(nsteps, TransBranch{ - Gid: t.Gid, - Branch: fmt.Sprintf("%d", len(nsteps)+1), - Data: step["data"].(string), - Url: step[branchType].(string), - BranchType: branchType, - Status: "prepared", - }) - } - } - return nsteps -} - -func (t *TransTccProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) string { - resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParam("gid", branch.Gid).Post(branch.Url) - e2p(err) - body := resp.String() - t.touch(db) - if strings.Contains(body, "SUCCESS") { - branch.changeStatus(db, "succeed") - return "SUCCESS" - } - if branch.BranchType == "try" && strings.Contains(body, "FAIL") { - branch.changeStatus(db, "failed") - return "FAIL" - } - panic(fmt.Errorf("unknown response: %s, will be retried", body)) -} - -func (t *TransTccProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) { - current := 0 // 当前正在处理的步骤 - // 先处理一轮正常try状态 - for ; current < len(branches); current++ { - step := &branches[current] - if step.BranchType != "try" || step.Status == "succeed" { - continue - } - if step.BranchType == "try" && step.Status == "prepared" { - result := t.ExecBranch(db, step) - if result == "FAIL" { - break - } - } - } - // 如果try全部成功,则处理confirm分支,否则处理cancel分支 - currentType := common.If(current == len(branches), "confirm", "cancel") - for current--; current >= 0; current-- { - branch := &branches[current] - if branch.BranchType != currentType || branch.Status != "prepared" { - continue - } - t.ExecBranch(db, branch) - } - t.changeStatus(db, common.If(currentType == "confirm", "succeed", "failed").(string)) -} - -type TransXaProcessor struct { - *TransGlobal -} - -func (t *TransXaProcessor) GenBranches() []TransBranch { - return []TransBranch{} -} -func (t *TransXaProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) string { - resp, err := common.RestyClient.R().SetBody(M{ - "branch": branch.Branch, - "action": common.If(t.Status == "prepared", "rollback", "commit"), - "gid": branch.Gid, - }).Post(branch.Url) - e2p(err) - body := resp.String() - if !strings.Contains(body, "SUCCESS") { - panic(fmt.Errorf("bad response: %s", body)) - } - branch.changeStatus(db, "succeed") - return "SUCCESS" -} - -func (t *TransXaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) { - if t.Status == "succeed" { - return - } - currentType := common.If(t.Status == "committed", "commit", "rollback").(string) - for _, branch := range branches { - if branch.BranchType == currentType && branch.Status != "succeed" { - _ = t.ExecBranch(db, &branch) t.touch(db) } + } else if strings.Contains(body, "SUCCESS") { + t.changeStatus(db, "committed") } - t.changeStatus(db, common.If(t.Status == "committed", "succeed", "failed").(string)) +} + +func (trans *TransGlobal) Process(db *common.MyDb) { + defer handlePanic() + defer func() { + if TransProcessedTestChan != nil { + TransProcessedTestChan <- trans.Gid + } + }() + branches := []TransBranch{} + db.Must().Where("gid=?", trans.Gid).Order("id asc").Find(&branches) + trans.getProcessor().ProcessOnce(db, branches) +} + +func (t *TransGlobal) SaveNew(db *common.MyDb) { + err := db.Transaction(func(db1 *gorm.DB) error { + db := &common.MyDb{DB: db1} + + writeTransLog(t.Gid, "create trans", t.Status, "", t.Data) + dbr := db.Must().Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(t) + if dbr.RowsAffected == 0 && t.Status == "committed" { // 如果数据库已经存放了prepared的事务,则修改状态 + dbr = db.Must().Model(&TransGlobal{}).Where("gid=? and status=?", t.Gid, "prepared").Update("status", t.Status) + } + if dbr.RowsAffected == 0 { // 未保存任何数据,直接返回 + return nil + } + // 保存所有的分支 + branches := t.getProcessor().GenBranches() + if len(branches) > 0 { + writeTransLog(t.Gid, "save branches", t.Status, "", common.MustMarshalString(branches)) + db.Must().Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(&branches) + } + return nil + }) + e2p(err) +} + +func TransFromContext(c *gin.Context) *TransGlobal { + data := M{} + b, err := c.GetRawData() + e2p(err) + common.MustUnmarshal(b, &data) + logrus.Printf("creating trans in prepare") + if data["steps"] != nil { + data["data"] = common.MustMarshalString(data["steps"]) + } + m := TransGlobal{} + common.MustRemarshal(data, &m) + return &m +} + +func TransFromDb(db *common.MyDb, gid string) *TransGlobal { + m := TransGlobal{} + dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m) + e2p(dbr.Error) + return &m } diff --git a/dtmsvr/trans_saga.go b/dtmsvr/trans_saga.go new file mode 100644 index 0000000..0219eec --- /dev/null +++ b/dtmsvr/trans_saga.go @@ -0,0 +1,78 @@ +package dtmsvr + +import ( + "fmt" + "strings" + + "github.com/yedf/dtm/common" +) + +type TransSagaProcessor struct { + *TransGlobal +} + +func (t *TransSagaProcessor) GenBranches() []TransBranch { + branches := []TransBranch{} + steps := []M{} + common.MustUnmarshalString(t.Data, &steps) + for _, step := range steps { + for _, branchType := range []string{"compensate", "action"} { + branches = append(branches, TransBranch{ + Gid: t.Gid, + Branch: fmt.Sprintf("%d", len(branches)+1), + Data: step["data"].(string), + Url: step[branchType].(string), + BranchType: branchType, + Status: "prepared", + }) + } + } + return branches +} + +func (t *TransSagaProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) { + resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParam("gid", branch.Gid).Post(branch.Url) + e2p(err) + body := resp.String() + t.touch(db) + if strings.Contains(body, "SUCCESS") { + branch.changeStatus(db, "succeed") + } else if branch.BranchType == "action" && strings.Contains(body, "FAIL") { + branch.changeStatus(db, "failed") + } else { + panic(fmt.Errorf("unknown response: %s, will be retried", body)) + } +} + +func (t *TransSagaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) { + t.MayQueryPrepared(db) + if t.Status != "committed" { + return + } + current := 0 // 当前正在处理的步骤 + for ; current < len(branches); current++ { + branch := &branches[current] + if branch.BranchType != "action" || branch.Status != "prepared" { + continue + } + t.ExecBranch(db, branch) + if branch.Status != "succeed" { + break + } + } + if current == len(branches) { // saga 事务完成 + t.changeStatus(db, "succeed") + return + } + for current = current - 1; current >= 0; current-- { + branch := &branches[current] + if branch.BranchType != "compensate" || branch.Status != "prepared" { + continue + } + t.ExecBranch(db, branch) + } + if current != -1 { + panic(fmt.Errorf("saga current not -1")) + } + t.changeStatus(db.Must(), "failed") +} diff --git a/dtmsvr/trans_tcc.go b/dtmsvr/trans_tcc.go new file mode 100644 index 0000000..d6d6382 --- /dev/null +++ b/dtmsvr/trans_tcc.go @@ -0,0 +1,78 @@ +package dtmsvr + +import ( + "fmt" + "strings" + + "github.com/yedf/dtm/common" +) + +type TransTccProcessor struct { + *TransGlobal +} + +func (t *TransTccProcessor) GenBranches() []TransBranch { + branches := []TransBranch{} + steps := []M{} + common.MustUnmarshalString(t.Data, &steps) + for _, step := range steps { + for _, branchType := range []string{"cancel", "confirm", "try"} { + branches = append(branches, TransBranch{ + Gid: t.Gid, + Branch: fmt.Sprintf("%d", len(branches)+1), + Data: step["data"].(string), + Url: step[branchType].(string), + BranchType: branchType, + Status: "prepared", + }) + } + } + return branches +} + +func (t *TransTccProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) { + resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParam("gid", branch.Gid).Post(branch.Url) + e2p(err) + body := resp.String() + t.touch(db) + if strings.Contains(body, "SUCCESS") { + branch.changeStatus(db, "succeed") + } else if branch.BranchType == "try" && strings.Contains(body, "FAIL") { + branch.changeStatus(db, "failed") + } else { + panic(fmt.Errorf("unknown response: %s, will be retried", body)) + } +} + +func (t *TransTccProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) { + t.MayQueryPrepared(db) + if t.Status != "committed" { + return + } + current := 0 // 当前正在处理的步骤 + // 先处理一轮正常try状态 + for ; current < len(branches); current++ { + branch := &branches[current] + if branch.BranchType != "try" || branch.Status == "succeed" { + continue + } + if branch.BranchType == "try" && branch.Status == "prepared" { + t.ExecBranch(db, branch) + if branch.Status != "succeed" { + break + } + } else { + break + } + } + // 如果try全部成功,则处理confirm分支,否则处理cancel分支 + currentType := common.If(current == len(branches), "confirm", "cancel") + for current--; current >= 0; current-- { + branch := &branches[current] + if branch.BranchType != currentType || branch.Status != "prepared" { + continue + } + t.ExecBranch(db, branch) + } + t.changeStatus(db, common.If(currentType == "confirm", "succeed", "failed").(string)) +} diff --git a/dtmsvr/trans_xa.go b/dtmsvr/trans_xa.go new file mode 100644 index 0000000..9b6d635 --- /dev/null +++ b/dtmsvr/trans_xa.go @@ -0,0 +1,44 @@ +package dtmsvr + +import ( + "fmt" + "strings" + + "github.com/yedf/dtm/common" +) + +type TransXaProcessor struct { + *TransGlobal +} + +func (t *TransXaProcessor) GenBranches() []TransBranch { + return []TransBranch{} +} +func (t *TransXaProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) { + resp, err := common.RestyClient.R().SetBody(M{ + "branch": branch.Branch, + "action": common.If(t.Status == "prepared", "rollback", "commit"), + "gid": branch.Gid, + }).Post(branch.Url) + e2p(err) + body := resp.String() + t.touch(db) + if strings.Contains(body, "SUCCESS") { + branch.changeStatus(db, "succeed") + } else { + panic(fmt.Errorf("bad response: %s", body)) + } +} + +func (t *TransXaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) { + if t.Status == "succeed" { + return + } + currentType := common.If(t.Status == "committed", "commit", "rollback").(string) + for _, branch := range branches { + if branch.BranchType == currentType && branch.Status != "succeed" { + t.ExecBranch(db, &branch) + } + } + t.changeStatus(db, common.If(t.Status == "committed", "succeed", "failed").(string)) +} diff --git a/dtmsvr/types.go b/dtmsvr/types.go deleted file mode 100644 index 16f797e..0000000 --- a/dtmsvr/types.go +++ /dev/null @@ -1,158 +0,0 @@ -package dtmsvr - -import ( - "fmt" - "time" - - "github.com/gin-gonic/gin" - "github.com/sirupsen/logrus" - "github.com/yedf/dtm/common" - "gorm.io/gorm" - "gorm.io/gorm/clause" -) - -type M = map[string]interface{} - -var p2e = common.P2E -var e2p = common.E2P - -type TransGlobal struct { - common.ModelBase - Gid string `json:"gid"` - TransType string `json:"trans_type"` - Data string `json:"data"` - Status string `json:"status"` - QueryPrepared string `json:"query_prepared"` - CommitTime *time.Time - FinishTime *time.Time - RollbackTime *time.Time -} - -func (*TransGlobal) TableName() string { - return "trans_global" -} - -func (t *TransGlobal) touch(db *common.MyDb) *gorm.DB { - writeTransLog(t.Gid, "touch trans", "", "", "") - return db.Model(&TransGlobal{}).Where("gid=?", t.Gid).Update("gid", t.Gid) // 更新update_time,避免被定时任务再次 -} - -func (t *TransGlobal) changeStatus(db *common.MyDb, status string) *gorm.DB { - writeTransLog(t.Gid, "change status", status, "", "") - updates := M{ - "status": status, - } - if status == "succeed" { - updates["finish_time"] = time.Now() - } else if status == "failed" { - updates["rollback_time"] = time.Now() - } - dbr := db.Must().Model(t).Where("status=?", t.Status).Updates(updates) - checkAffected(dbr) - t.Status = status - return dbr -} - -type TransBranch struct { - common.ModelBase - Gid string - Url string - Data string - Branch string - BranchType string - Status string - FinishTime *time.Time - RollbackTime *time.Time -} - -func (*TransBranch) TableName() string { - return "trans_branch" -} - -func (t *TransBranch) changeStatus(db *common.MyDb, status string) *gorm.DB { - writeTransLog(t.Gid, "step change", status, t.Branch, "") - dbr := db.Must().Model(t).Where("status=?", t.Status).Updates(M{ - "status": status, - "finish_time": time.Now(), - }) - checkAffected(dbr) - t.Status = status - return dbr -} - -func checkAffected(db1 *gorm.DB) { - if db1.RowsAffected == 0 { - panic(fmt.Errorf("duplicate updating")) - } -} - -func (trans *TransGlobal) getProcessor() TransProcessor { - if trans.TransType == "saga" { - return &TransSagaProcessor{TransGlobal: trans} - } else if trans.TransType == "tcc" { - return &TransTccProcessor{TransGlobal: trans} - } else if trans.TransType == "xa" { - return &TransXaProcessor{TransGlobal: trans} - } - return nil -} - -func (trans *TransGlobal) Process(db *common.MyDb) { - defer handlePanic() - defer func() { - if TransProcessedTestChan != nil { - TransProcessedTestChan <- trans.Gid - } - }() - branches := []TransBranch{} - db.Must().Where("gid=?", trans.Gid).Order("id asc").Find(&branches) - trans.getProcessor().ProcessOnce(db, branches) -} - -func (t *TransGlobal) SaveNew(db *common.MyDb) { - err := db.Transaction(func(db1 *gorm.DB) error { - db := &common.MyDb{DB: db1} - - writeTransLog(t.Gid, "create trans", t.Status, "", t.Data) - dbr := db.Must().Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(t) - if dbr.RowsAffected == 0 && t.Status == "committed" { // 如果数据库已经存放了prepared的事务,则修改状态 - dbr = db.Must().Model(&TransGlobal{}).Where("gid=? and status=?", t.Gid, "prepared").Update("status", t.Status) - } - if dbr.RowsAffected == 0 { // 未保存任何数据,直接返回 - return nil - } - // 保存所有的分支 - nsteps := t.getProcessor().GenBranches() - if len(nsteps) > 0 { - writeTransLog(t.Gid, "save steps", t.Status, "", common.MustMarshalString(nsteps)) - db.Must().Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(&nsteps) - } - return nil - }) - e2p(err) -} - -func TransFromContext(c *gin.Context) *TransGlobal { - data := M{} - b, err := c.GetRawData() - e2p(err) - common.MustUnmarshal(b, &data) - logrus.Printf("creating trans in prepare") - if data["steps"] != nil { - data["data"] = common.MustMarshalString(data["steps"]) - } - m := TransGlobal{} - common.MustRemarshal(data, &m) - return &m -} - -func TransFromDb(db *common.MyDb, gid string) *TransGlobal { - m := TransGlobal{} - dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m) - e2p(dbr.Error) - return &m -} diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go new file mode 100644 index 0000000..f971558 --- /dev/null +++ b/dtmsvr/utils.go @@ -0,0 +1,40 @@ +package dtmsvr + +import ( + "github.com/sirupsen/logrus" + "github.com/yedf/dtm/common" +) + +type M = map[string]interface{} + +var p2e = common.P2E +var e2p = common.E2P + +func dbGet() *common.MyDb { + return common.DbGet(config.Mysql) +} +func writeTransLog(gid string, action string, status string, branch string, detail string) { + db := dbGet() + if detail == "" { + detail = "{}" + } + db.Must().Table("trans_log").Create(M{ + "gid": gid, + "action": action, + "new_status": status, + "branch": branch, + "detail": detail, + }) +} + +var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束 + +func WaitTransProcessed(gid string) { + logrus.Printf("waiting for gid %s", gid) + id := <-TransProcessedTestChan + for id != gid { + logrus.Errorf("-------id %s not match gid %s", id, gid) + id = <-TransProcessedTestChan + } + logrus.Printf("finish for gid %s", gid) +}