diff --git a/dtmsvr/api.go b/dtmsvr/api.go index a826622..2f7abf4 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -1,8 +1,9 @@ package dtmsvr import ( + "fmt" + "github.com/gin-gonic/gin" - "github.com/sirupsen/logrus" "github.com/yedf/dtm/common" "gorm.io/gorm/clause" ) @@ -15,29 +16,30 @@ func AddRoute(engine *gin.Engine) { } func Prepare(c *gin.Context) (interface{}, error) { - db := dbGet() - m := getTransFromContext(c) + m := TransFromContext(c) m.Status = "prepared" - writeTransLog(m.Gid, "save prepared", m.Status, "", m.Data) - db.Must().Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(&m) + m.SaveNew(dbGet()) return M{"message": "SUCCESS"}, nil } func Commit(c *gin.Context) (interface{}, error) { - m := getTransFromContext(c) - saveCommitted(m) - go ProcessTrans(m) + db := dbGet() + m := TransFromContext(c) + m.Status = "committed" + m.SaveNew(db) + go m.Process(db) return M{"message": "SUCCESS"}, nil } func Rollback(c *gin.Context) (interface{}, error) { - m := getTransFromContext(c) - trans := TransGlobal{} - dbGet().Must().Model(&m).First(&trans) + db := dbGet() + m := TransFromContext(c) + m = TransFromDb(db, m.Gid) + if m.TransType != "xa" || m.Status != "prepared" { + return nil, fmt.Errorf("unkown trans data. type: %s status: %s for gid: %s", m.TransType, m.Status, m.Gid) + } // 当前xa trans的状态为prepared,直接处理,则是回滚 - go ProcessTrans(&trans) + go m.Process(db) return M{"message": "SUCCESS"}, nil } @@ -51,17 +53,3 @@ func Branch(c *gin.Context) (interface{}, error) { }).Create(&branch) return M{"message": "SUCCESS"}, nil } - -func getTransFromContext(c *gin.Context) *TransGlobal { - data := M{} - b, err := c.GetRawData() - e2p(err) - common.MustUnmarshal(b, &data) - logrus.Printf("creating trans model in prepare") - if data["trans_type"].(string) == "saga" { - data["data"] = common.MustMarshalString(data["steps"]) - } - m := TransGlobal{} - common.MustRemarshal(data, &m) - return &m -} diff --git a/dtmsvr/cron.go b/dtmsvr/cron.go index 207860a..7acbbf5 100644 --- a/dtmsvr/cron.go +++ b/dtmsvr/cron.go @@ -32,8 +32,9 @@ func CronPreparedOnce(expire time.Duration) { writeTransLog(sm.Gid, "saga canceled", status, "", "") db.Must().Model(&sm).Where("status = ?", "prepared").Update("status", status) } else if strings.Contains(body, "SUCCESS") { - saveCommitted(&sm) - ProcessTrans(&sm) + sm.Status = "committed" + sm.SaveNew(db) + sm.Process(db) } } } @@ -77,7 +78,7 @@ func lockOneTrans(expire time.Duration, status string) *TransGlobal { owner := common.GenGid() db := dbGet() dbr := db.Must().Model(&trans). - Where("update_time < date_sub(now(), interval ? second) and satus=?", int(expire/time.Second), status). + Where("update_time < date_sub(now(), interval ? second) and status=?", int(expire/time.Second), status). Limit(1).Update("owner", owner) if dbr.RowsAffected == 0 { return nil diff --git a/dtmsvr/service.go b/dtmsvr/service.go index df659eb..bbcd652 100644 --- a/dtmsvr/service.go +++ b/dtmsvr/service.go @@ -2,36 +2,8 @@ package dtmsvr import ( "github.com/sirupsen/logrus" - "github.com/yedf/dtm/common" - "gorm.io/gorm" - "gorm.io/gorm/clause" ) -func saveCommitted(m *TransGlobal) { - db := dbGet() - m.Status = "committed" - err := db.Transaction(func(db1 *gorm.DB) error { - db := &common.MyDb{DB: db1} - writeTransLog(m.Gid, "save committed", m.Status, "", m.Data) - dbr := db.Must().Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(m) - if dbr.RowsAffected == 0 { - writeTransLog(m.Gid, "change status", m.Status, "", "") - db.Must().Model(m).Where("status=?", "prepared").Update("status", "committed") - } - nsteps := m.getProcessor().GenBranches() - if len(nsteps) > 0 { - writeTransLog(m.Gid, "save steps", m.Status, "", common.MustMarshalString(nsteps)) - db.Must().Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(&nsteps) - } - return nil - }) - e2p(err) -} - var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束 func WaitTransProcessed(gid string) { @@ -41,13 +13,3 @@ func WaitTransProcessed(gid string) { id = <-TransProcessedTestChan } } - -func ProcessTrans(trans *TransGlobal) { - branches := []TransBranch{} - db := dbGet() - db.Must().Where("gid=?", trans.Gid).Order("id asc").Find(&branches) - trans.getProcessor().ProcessOnce(db, branches) - if TransProcessedTestChan != nil { - TransProcessedTestChan <- trans.Gid - } -} diff --git a/dtmsvr/trans.go b/dtmsvr/trans.go index 9cb9b64..089f51d 100644 --- a/dtmsvr/trans.go +++ b/dtmsvr/trans.go @@ -52,9 +52,9 @@ func (t *TransSagaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch t.touch(db.Must()) if strings.Contains(body, "SUCCESS") { - step.saveStatus(db.Must(), "finished") + step.changeStatus(db.Must(), "finished") } else if strings.Contains(body, "FAIL") { - step.saveStatus(db.Must(), "rollbacked") + step.changeStatus(db.Must(), "rollbacked") break } else { return fmt.Errorf("unknown response: %s, will be retried", body) @@ -76,7 +76,7 @@ func (t *TransSagaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch } body := resp.String() if strings.Contains(body, "SUCCESS") { - step.saveStatus(db.Must(), "rollbacked") + step.changeStatus(db.Must(), "rollbacked") } else { return fmt.Errorf("expect compensate return SUCCESS") } diff --git a/dtmsvr/types.go b/dtmsvr/types.go index dad515b..e80af74 100644 --- a/dtmsvr/types.go +++ b/dtmsvr/types.go @@ -4,8 +4,11 @@ import ( "fmt" "time" + "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" "github.com/yedf/dtm/common" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type M = map[string]interface{} @@ -61,7 +64,7 @@ func (*TransBranch) TableName() string { return "trans_branch" } -func (t *TransBranch) saveStatus(db *common.MyDb, status string) *gorm.DB { +func (t *TransBranch) changeStatus(db *common.MyDb, status string) *gorm.DB { writeTransLog(t.Gid, "step change", status, t.Branch, "") dbr := db.Must().Model(t).Where("status=?", t.Status).Updates(M{ "status": status, @@ -88,3 +91,60 @@ func (trans *TransGlobal) getProcessor() TransProcessor { } return nil } + +func (trans *TransGlobal) Process(db *common.MyDb) { + branches := []TransBranch{} + db.Must().Where("gid=?", trans.Gid).Order("id asc").Find(&branches) + trans.getProcessor().ProcessOnce(db, branches) + if TransProcessedTestChan != nil { + TransProcessedTestChan <- trans.Gid + } +} + +func (t *TransGlobal) SaveNew(db *common.MyDb) { + err := db.Transaction(func(db1 *gorm.DB) error { + db := &common.MyDb{DB: db1} + + writeTransLog(t.Gid, "create trans", t.Status, "", t.Data) + dbr := db.Must().Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(t) + if dbr.RowsAffected == 0 && t.Status == "committed" { // 如果数据库已经存放了prepared的事务,则修改状态 + dbr = db.Must().Model(&TransGlobal{}).Where("gid=? and status=?", t.Gid, "prepared").Update("status", t.Status) + } + if dbr.RowsAffected == 0 { // 未保存任何数据,直接返回 + return nil + } + // 保存所有的分支 + nsteps := t.getProcessor().GenBranches() + if len(nsteps) > 0 { + writeTransLog(t.Gid, "save steps", t.Status, "", common.MustMarshalString(nsteps)) + db.Must().Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(&nsteps) + } + return nil + }) + e2p(err) +} + +func TransFromContext(c *gin.Context) *TransGlobal { + data := M{} + b, err := c.GetRawData() + e2p(err) + common.MustUnmarshal(b, &data) + logrus.Printf("creating trans in prepare") + if data["steps"] != nil { + data["data"] = common.MustMarshalString(data["steps"]) + } + m := TransGlobal{} + common.MustRemarshal(data, &m) + return &m +} + +func TransFromDb(db *common.MyDb, gid string) *TransGlobal { + m := TransGlobal{} + dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m) + e2p(dbr.Error) + return &m +}