diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 291282e..e125592 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -11,6 +11,7 @@ func AddRoute(engine *gin.Engine) { engine.POST("/api/dtmsvr/prepare", common.WrapHandler(Prepare)) engine.POST("/api/dtmsvr/commit", common.WrapHandler(Commit)) engine.POST("/api/dtmsvr/branch", common.WrapHandler(Branch)) + engine.POST("/api/dtmsvr/rollback", common.WrapHandler(Rollback)) } func Prepare(c *gin.Context) (interface{}, error) { @@ -27,7 +28,16 @@ func Prepare(c *gin.Context) (interface{}, error) { func Commit(c *gin.Context) (interface{}, error) { m := getTransFromContext(c) saveCommitted(m) - go ProcessCommitted(m) + go ProcessTrans(m) + return M{"message": "SUCCESS"}, nil +} + +func Rollback(c *gin.Context) (interface{}, error) { + m := getTransFromContext(c) + trans := TransGlobalModel{} + dbGet().Must().Model(&m).First(&trans) + // 当前xa trans的状态为prepared,直接处理,则是回滚 + go ProcessTrans(&trans) return M{"message": "SUCCESS"}, nil } diff --git a/dtmsvr/cron.go b/dtmsvr/cron.go index 80bb85c..495c995 100644 --- a/dtmsvr/cron.go +++ b/dtmsvr/cron.go @@ -31,7 +31,7 @@ func CronPreparedOnce(expire time.Duration) { db.Must().Model(&sm).Where("status = ?", "prepared").Update("status", status) } else if strings.Contains(body, "SUCCESS") { saveCommitted(&sm) - ProcessCommitted(&sm) + ProcessTrans(&sm) } } } @@ -54,7 +54,7 @@ func CronCommittedOnce(expire time.Duration) { for _, sm := range ss { writeTransLog(sm.Gid, "saga touch committed", "", "", "") db.Must().Model(&sm).Update("id", sm.ID) - ProcessCommitted(&sm) + ProcessTrans(&sm) } } diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index 856f7e0..d708195 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -37,6 +37,7 @@ func TestDtmSvr(t *testing.T) { common.PanicIfError(dbGet().Exec("truncate trans_log").Error) examples.ResetXaData() + xaRollback(t) xaNormal(t) sagaPreparePending(t) sagaPrepareCancel(t) @@ -93,7 +94,7 @@ func xaNormal(t *testing.T) { return nil }) common.PanicIfError(err) - WaitTransCommitted(gid) + WaitTransProcessed(gid) assert.Equal(t, []string{"finished", "finished"}, getBranchesStatus(gid)) } @@ -114,9 +115,11 @@ func xaRollback(t *testing.T) { common.CheckRestySuccess(resp, err) return nil }) - common.PanicIfError(err) - WaitTransCommitted(gid) - assert.Equal(t, []string{"rollbacked", "rollbacked"}, getBranchesStatus(gid)) + if err != nil { + logrus.Errorf("global transaction failed, so rollback") + } + WaitTransProcessed(gid) + assert.Equal(t, []string{"rollbacked"}, getBranchesStatus(gid)) } func sagaNormal(t *testing.T) { @@ -125,14 +128,14 @@ func sagaNormal(t *testing.T) { assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status) saga.Commit() assert.Equal(t, "committed", getSagaModel(saga.Gid).Status) - WaitTransCommitted(saga.Gid) + WaitTransProcessed(saga.Gid) assert.Equal(t, []string{"prepared", "finished", "prepared", "finished"}, getBranchesStatus(saga.Gid)) } func sagaRollback(t *testing.T) { saga := genSaga("gid-rollbackSaga2", false, true) saga.Commit() - WaitTransCommitted(saga.Gid) + WaitTransProcessed(saga.Gid) saga.Prepare() assert.Equal(t, "rollbacked", getSagaModel(saga.Gid).Status) assert.Equal(t, []string{"rollbacked", "finished", "rollbacked", "rollbacked"}, getBranchesStatus(saga.Gid)) @@ -157,7 +160,7 @@ func sagaPreparePending(t *testing.T) { examples.TransQueryResult = "" assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status) CronPreparedOnce(-10 * time.Second) - WaitTransCommitted(saga.Gid) + WaitTransProcessed(saga.Gid) assert.Equal(t, "finished", getSagaModel(saga.Gid).Status) } @@ -166,11 +169,11 @@ func sagaCommittedPending(t *testing.T) { saga.Prepare() examples.TransInResult = "PENDING" saga.Commit() - WaitTransCommitted(saga.Gid) + WaitTransProcessed(saga.Gid) examples.TransInResult = "" assert.Equal(t, []string{"prepared", "finished", "prepared", "prepared"}, getBranchesStatus(saga.Gid)) CronCommittedOnce(-10 * time.Second) - WaitTransCommitted(saga.Gid) + WaitTransProcessed(saga.Gid) assert.Equal(t, []string{"prepared", "finished", "prepared", "finished"}, getBranchesStatus(saga.Gid)) assert.Equal(t, "finished", getSagaModel(saga.Gid).Status) } diff --git a/dtmsvr/service.go b/dtmsvr/service.go index 8fda0d0..d6c7576 100644 --- a/dtmsvr/service.go +++ b/dtmsvr/service.go @@ -58,7 +58,7 @@ func saveCommitted(m *TransGlobalModel) { var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束 -func WaitTransCommitted(gid string) { +func WaitTransProcessed(gid string) { id := <-TransProcessedTestChan for id != gid { logrus.Errorf("-------id %s not match gid %s", id, gid) @@ -66,19 +66,19 @@ func WaitTransCommitted(gid string) { } } -func ProcessCommitted(trans *TransGlobalModel) { - err := innerProcessCommitted(trans) +func ProcessTrans(trans *TransGlobalModel) { + err := innerProcessTrans(trans) if err != nil { - logrus.Errorf("process committed error: %s", err.Error()) + logrus.Errorf("process trans ignore error: %s", err.Error()) } if TransProcessedTestChan != nil { TransProcessedTestChan <- trans.Gid } } -func innerProcessCommitted(trans *TransGlobalModel) (rerr error) { +func innerProcessTrans(trans *TransGlobalModel) (rerr error) { branches := []TransBranchModel{} db := dbGet() - db.Must().Order("id asc").Find(&branches) + db.Must().Where("gid=?", trans.Gid).Order("id asc").Find(&branches) if trans.TransType == "saga" { return innerProcessCommittedSaga(trans, db, branches) } else if trans.TransType == "xa" { @@ -89,34 +89,70 @@ func innerProcessCommitted(trans *TransGlobalModel) (rerr error) { func innerProcessCommittedXa(trans *TransGlobalModel, db *common.MyDb, branches []TransBranchModel) error { gid := trans.Gid - for _, branch := range branches { - if branch.Status == "finished" { - continue + if trans.Status == "finished" { + return nil + } + if trans.Status == "committed" { + for _, branch := range branches { + if branch.Status == "finished" { + continue + } + db.Must().Model(&TransGlobalModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time,避免被定时任务再次 + resp, err := common.RestyClient.R().SetBody(M{ + "branch": branch.Branch, + "action": "commit", + "gid": branch.Gid, + }).Post(branch.Url) + if err != nil { + return err + } + body := resp.String() + if !strings.Contains(body, "SUCCESS") { + return fmt.Errorf("bad response: %s", body) + } + writeTransLog(gid, "step finished", "finished", branch.Branch, "") + db.Must().Model(&branch).Where("status=?", "prepared").Updates(M{ + "status": "finished", + "finish_time": time.Now(), + }) } - db.Must().Model(&TransGlobalModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time,避免被定时任务再次 - resp, err := common.RestyClient.R().SetBody(M{ - "branch": branch.Branch, - "action": "commit", - "gid": branch.Gid, - }).Post(branch.Url) - if err != nil { - return err - } - body := resp.String() - if !strings.Contains(body, "SUCCESS") { - return fmt.Errorf("bad response: %s", body) - } - writeTransLog(gid, "step finished", "finished", branch.Branch, "") - db.Must().Model(&branch).Where("status=?", "prepared").Updates(M{ + writeTransLog(gid, "xa finished", "finished", "", "") + db.Must().Model(&TransGlobalModel{}).Where("gid=? and status=?", gid, "committed").Updates(M{ "status": "finished", "finish_time": time.Now(), }) + } else if trans.Status == "prepared" { // 未commit直接处理的情况为回滚场景 + for _, branch := range branches { + if branch.Status == "rollbacked" { + continue + } + db.Must().Model(&TransGlobalModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time,避免被定时任务再次 + resp, err := common.RestyClient.R().SetBody(M{ + "branch": branch.Branch, + "action": "rollback", + "gid": branch.Gid, + }).Post(branch.Url) + if err != nil { + return err + } + body := resp.String() + if !strings.Contains(body, "SUCCESS") { + return fmt.Errorf("bad response: %s", body) + } + writeTransLog(gid, "step rollbacked", "rollbacked", branch.Branch, "") + db.Must().Model(&branch).Where("status=?", "prepared").Updates(M{ + "status": "rollbacked", + "finish_time": time.Now(), + }) + } + writeTransLog(gid, "xa rollbacked", "rollbacked", "", "") + db.Must().Model(&TransGlobalModel{}).Where("gid=? and status=?", gid, "prepared").Updates(M{ + "status": "rollbacked", + "finish_time": time.Now(), + }) + } else { + return fmt.Errorf("bad trans status: %s", trans.Status) } - writeTransLog(gid, "xa finished", "finished", "", "") - db.Must().Model(&TransGlobalModel{}).Where("gid=? and status=?", gid, "committed").Updates(M{ - "status": "finished", - "finish_time": time.Now(), - }) return nil } diff --git a/examples/xa_main.go b/examples/xa_main.go index 4e5b2e5..03df76c 100644 --- a/examples/xa_main.go +++ b/examples/xa_main.go @@ -63,8 +63,12 @@ func XaAddRoute(app *gin.Engine) { func XaTransIn(c *gin.Context) (interface{}, error) { err := XaClient.XaLocalTransaction(c.Query("gid"), func(db *common.MyDb) (rerr error) { + req := transReqFromContext(c) + if req.TransInResult != "SUCCESS" { + return fmt.Errorf("tranIn failed") + } dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). - Update("balance", gorm.Expr("balance - ?", transReqFromContext(c).Amount)) + Update("balance", gorm.Expr("balance - ?", req.Amount)) return dbr.Error }) common.PanicIfError(err) @@ -73,8 +77,12 @@ func XaTransIn(c *gin.Context) (interface{}, error) { func XaTransOut(c *gin.Context) (interface{}, error) { err := XaClient.XaLocalTransaction(c.Query("gid"), func(db *common.MyDb) (rerr error) { + req := transReqFromContext(c) + if req.TransOutResult != "SUCCESS" { + return fmt.Errorf("tranOut failed") + } dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). - Update("balance", gorm.Expr("balance + ?", transReqFromContext(c).Amount)) + Update("balance", gorm.Expr("balance + ?", req.Amount)) return dbr.Error }) common.PanicIfError(err)