xa rollback ok

This commit is contained in:
yedongfu 2021-05-25 22:46:29 +08:00
parent 828de82f14
commit ef323fd70d
5 changed files with 100 additions and 43 deletions

View File

@ -11,6 +11,7 @@ func AddRoute(engine *gin.Engine) {
engine.POST("/api/dtmsvr/prepare", common.WrapHandler(Prepare)) engine.POST("/api/dtmsvr/prepare", common.WrapHandler(Prepare))
engine.POST("/api/dtmsvr/commit", common.WrapHandler(Commit)) engine.POST("/api/dtmsvr/commit", common.WrapHandler(Commit))
engine.POST("/api/dtmsvr/branch", common.WrapHandler(Branch)) engine.POST("/api/dtmsvr/branch", common.WrapHandler(Branch))
engine.POST("/api/dtmsvr/rollback", common.WrapHandler(Rollback))
} }
func Prepare(c *gin.Context) (interface{}, error) { func Prepare(c *gin.Context) (interface{}, error) {
@ -27,7 +28,16 @@ func Prepare(c *gin.Context) (interface{}, error) {
func Commit(c *gin.Context) (interface{}, error) { func Commit(c *gin.Context) (interface{}, error) {
m := getTransFromContext(c) m := getTransFromContext(c)
saveCommitted(m) saveCommitted(m)
go ProcessCommitted(m) go ProcessTrans(m)
return M{"message": "SUCCESS"}, nil
}
func Rollback(c *gin.Context) (interface{}, error) {
m := getTransFromContext(c)
trans := TransGlobalModel{}
dbGet().Must().Model(&m).First(&trans)
// 当前xa trans的状态为prepared直接处理则是回滚
go ProcessTrans(&trans)
return M{"message": "SUCCESS"}, nil return M{"message": "SUCCESS"}, nil
} }

View File

@ -31,7 +31,7 @@ func CronPreparedOnce(expire time.Duration) {
db.Must().Model(&sm).Where("status = ?", "prepared").Update("status", status) db.Must().Model(&sm).Where("status = ?", "prepared").Update("status", status)
} else if strings.Contains(body, "SUCCESS") { } else if strings.Contains(body, "SUCCESS") {
saveCommitted(&sm) saveCommitted(&sm)
ProcessCommitted(&sm) ProcessTrans(&sm)
} }
} }
} }
@ -54,7 +54,7 @@ func CronCommittedOnce(expire time.Duration) {
for _, sm := range ss { for _, sm := range ss {
writeTransLog(sm.Gid, "saga touch committed", "", "", "") writeTransLog(sm.Gid, "saga touch committed", "", "", "")
db.Must().Model(&sm).Update("id", sm.ID) db.Must().Model(&sm).Update("id", sm.ID)
ProcessCommitted(&sm) ProcessTrans(&sm)
} }
} }

View File

@ -37,6 +37,7 @@ func TestDtmSvr(t *testing.T) {
common.PanicIfError(dbGet().Exec("truncate trans_log").Error) common.PanicIfError(dbGet().Exec("truncate trans_log").Error)
examples.ResetXaData() examples.ResetXaData()
xaRollback(t)
xaNormal(t) xaNormal(t)
sagaPreparePending(t) sagaPreparePending(t)
sagaPrepareCancel(t) sagaPrepareCancel(t)
@ -93,7 +94,7 @@ func xaNormal(t *testing.T) {
return nil return nil
}) })
common.PanicIfError(err) common.PanicIfError(err)
WaitTransCommitted(gid) WaitTransProcessed(gid)
assert.Equal(t, []string{"finished", "finished"}, getBranchesStatus(gid)) assert.Equal(t, []string{"finished", "finished"}, getBranchesStatus(gid))
} }
@ -114,9 +115,11 @@ func xaRollback(t *testing.T) {
common.CheckRestySuccess(resp, err) common.CheckRestySuccess(resp, err)
return nil return nil
}) })
common.PanicIfError(err) if err != nil {
WaitTransCommitted(gid) logrus.Errorf("global transaction failed, so rollback")
assert.Equal(t, []string{"rollbacked", "rollbacked"}, getBranchesStatus(gid)) }
WaitTransProcessed(gid)
assert.Equal(t, []string{"rollbacked"}, getBranchesStatus(gid))
} }
func sagaNormal(t *testing.T) { func sagaNormal(t *testing.T) {
@ -125,14 +128,14 @@ func sagaNormal(t *testing.T) {
assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status) assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status)
saga.Commit() saga.Commit()
assert.Equal(t, "committed", getSagaModel(saga.Gid).Status) assert.Equal(t, "committed", getSagaModel(saga.Gid).Status)
WaitTransCommitted(saga.Gid) WaitTransProcessed(saga.Gid)
assert.Equal(t, []string{"prepared", "finished", "prepared", "finished"}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{"prepared", "finished", "prepared", "finished"}, getBranchesStatus(saga.Gid))
} }
func sagaRollback(t *testing.T) { func sagaRollback(t *testing.T) {
saga := genSaga("gid-rollbackSaga2", false, true) saga := genSaga("gid-rollbackSaga2", false, true)
saga.Commit() saga.Commit()
WaitTransCommitted(saga.Gid) WaitTransProcessed(saga.Gid)
saga.Prepare() saga.Prepare()
assert.Equal(t, "rollbacked", getSagaModel(saga.Gid).Status) assert.Equal(t, "rollbacked", getSagaModel(saga.Gid).Status)
assert.Equal(t, []string{"rollbacked", "finished", "rollbacked", "rollbacked"}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{"rollbacked", "finished", "rollbacked", "rollbacked"}, getBranchesStatus(saga.Gid))
@ -157,7 +160,7 @@ func sagaPreparePending(t *testing.T) {
examples.TransQueryResult = "" examples.TransQueryResult = ""
assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status) assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status)
CronPreparedOnce(-10 * time.Second) CronPreparedOnce(-10 * time.Second)
WaitTransCommitted(saga.Gid) WaitTransProcessed(saga.Gid)
assert.Equal(t, "finished", getSagaModel(saga.Gid).Status) assert.Equal(t, "finished", getSagaModel(saga.Gid).Status)
} }
@ -166,11 +169,11 @@ func sagaCommittedPending(t *testing.T) {
saga.Prepare() saga.Prepare()
examples.TransInResult = "PENDING" examples.TransInResult = "PENDING"
saga.Commit() saga.Commit()
WaitTransCommitted(saga.Gid) WaitTransProcessed(saga.Gid)
examples.TransInResult = "" examples.TransInResult = ""
assert.Equal(t, []string{"prepared", "finished", "prepared", "prepared"}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{"prepared", "finished", "prepared", "prepared"}, getBranchesStatus(saga.Gid))
CronCommittedOnce(-10 * time.Second) CronCommittedOnce(-10 * time.Second)
WaitTransCommitted(saga.Gid) WaitTransProcessed(saga.Gid)
assert.Equal(t, []string{"prepared", "finished", "prepared", "finished"}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{"prepared", "finished", "prepared", "finished"}, getBranchesStatus(saga.Gid))
assert.Equal(t, "finished", getSagaModel(saga.Gid).Status) assert.Equal(t, "finished", getSagaModel(saga.Gid).Status)
} }

View File

@ -58,7 +58,7 @@ func saveCommitted(m *TransGlobalModel) {
var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束 var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束
func WaitTransCommitted(gid string) { func WaitTransProcessed(gid string) {
id := <-TransProcessedTestChan id := <-TransProcessedTestChan
for id != gid { for id != gid {
logrus.Errorf("-------id %s not match gid %s", id, gid) logrus.Errorf("-------id %s not match gid %s", id, gid)
@ -66,19 +66,19 @@ func WaitTransCommitted(gid string) {
} }
} }
func ProcessCommitted(trans *TransGlobalModel) { func ProcessTrans(trans *TransGlobalModel) {
err := innerProcessCommitted(trans) err := innerProcessTrans(trans)
if err != nil { if err != nil {
logrus.Errorf("process committed error: %s", err.Error()) logrus.Errorf("process trans ignore error: %s", err.Error())
} }
if TransProcessedTestChan != nil { if TransProcessedTestChan != nil {
TransProcessedTestChan <- trans.Gid TransProcessedTestChan <- trans.Gid
} }
} }
func innerProcessCommitted(trans *TransGlobalModel) (rerr error) { func innerProcessTrans(trans *TransGlobalModel) (rerr error) {
branches := []TransBranchModel{} branches := []TransBranchModel{}
db := dbGet() db := dbGet()
db.Must().Order("id asc").Find(&branches) db.Must().Where("gid=?", trans.Gid).Order("id asc").Find(&branches)
if trans.TransType == "saga" { if trans.TransType == "saga" {
return innerProcessCommittedSaga(trans, db, branches) return innerProcessCommittedSaga(trans, db, branches)
} else if trans.TransType == "xa" { } else if trans.TransType == "xa" {
@ -89,34 +89,70 @@ func innerProcessCommitted(trans *TransGlobalModel) (rerr error) {
func innerProcessCommittedXa(trans *TransGlobalModel, db *common.MyDb, branches []TransBranchModel) error { func innerProcessCommittedXa(trans *TransGlobalModel, db *common.MyDb, branches []TransBranchModel) error {
gid := trans.Gid gid := trans.Gid
for _, branch := range branches { if trans.Status == "finished" {
if branch.Status == "finished" { return nil
continue }
if trans.Status == "committed" {
for _, branch := range branches {
if branch.Status == "finished" {
continue
}
db.Must().Model(&TransGlobalModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time避免被定时任务再次
resp, err := common.RestyClient.R().SetBody(M{
"branch": branch.Branch,
"action": "commit",
"gid": branch.Gid,
}).Post(branch.Url)
if err != nil {
return err
}
body := resp.String()
if !strings.Contains(body, "SUCCESS") {
return fmt.Errorf("bad response: %s", body)
}
writeTransLog(gid, "step finished", "finished", branch.Branch, "")
db.Must().Model(&branch).Where("status=?", "prepared").Updates(M{
"status": "finished",
"finish_time": time.Now(),
})
} }
db.Must().Model(&TransGlobalModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time避免被定时任务再次 writeTransLog(gid, "xa finished", "finished", "", "")
resp, err := common.RestyClient.R().SetBody(M{ db.Must().Model(&TransGlobalModel{}).Where("gid=? and status=?", gid, "committed").Updates(M{
"branch": branch.Branch,
"action": "commit",
"gid": branch.Gid,
}).Post(branch.Url)
if err != nil {
return err
}
body := resp.String()
if !strings.Contains(body, "SUCCESS") {
return fmt.Errorf("bad response: %s", body)
}
writeTransLog(gid, "step finished", "finished", branch.Branch, "")
db.Must().Model(&branch).Where("status=?", "prepared").Updates(M{
"status": "finished", "status": "finished",
"finish_time": time.Now(), "finish_time": time.Now(),
}) })
} else if trans.Status == "prepared" { // 未commit直接处理的情况为回滚场景
for _, branch := range branches {
if branch.Status == "rollbacked" {
continue
}
db.Must().Model(&TransGlobalModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time避免被定时任务再次
resp, err := common.RestyClient.R().SetBody(M{
"branch": branch.Branch,
"action": "rollback",
"gid": branch.Gid,
}).Post(branch.Url)
if err != nil {
return err
}
body := resp.String()
if !strings.Contains(body, "SUCCESS") {
return fmt.Errorf("bad response: %s", body)
}
writeTransLog(gid, "step rollbacked", "rollbacked", branch.Branch, "")
db.Must().Model(&branch).Where("status=?", "prepared").Updates(M{
"status": "rollbacked",
"finish_time": time.Now(),
})
}
writeTransLog(gid, "xa rollbacked", "rollbacked", "", "")
db.Must().Model(&TransGlobalModel{}).Where("gid=? and status=?", gid, "prepared").Updates(M{
"status": "rollbacked",
"finish_time": time.Now(),
})
} else {
return fmt.Errorf("bad trans status: %s", trans.Status)
} }
writeTransLog(gid, "xa finished", "finished", "", "")
db.Must().Model(&TransGlobalModel{}).Where("gid=? and status=?", gid, "committed").Updates(M{
"status": "finished",
"finish_time": time.Now(),
})
return nil return nil
} }

View File

@ -63,8 +63,12 @@ func XaAddRoute(app *gin.Engine) {
func XaTransIn(c *gin.Context) (interface{}, error) { func XaTransIn(c *gin.Context) (interface{}, error) {
err := XaClient.XaLocalTransaction(c.Query("gid"), func(db *common.MyDb) (rerr error) { err := XaClient.XaLocalTransaction(c.Query("gid"), func(db *common.MyDb) (rerr error) {
req := transReqFromContext(c)
if req.TransInResult != "SUCCESS" {
return fmt.Errorf("tranIn failed")
}
dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")).
Update("balance", gorm.Expr("balance - ?", transReqFromContext(c).Amount)) Update("balance", gorm.Expr("balance - ?", req.Amount))
return dbr.Error return dbr.Error
}) })
common.PanicIfError(err) common.PanicIfError(err)
@ -73,8 +77,12 @@ func XaTransIn(c *gin.Context) (interface{}, error) {
func XaTransOut(c *gin.Context) (interface{}, error) { func XaTransOut(c *gin.Context) (interface{}, error) {
err := XaClient.XaLocalTransaction(c.Query("gid"), func(db *common.MyDb) (rerr error) { err := XaClient.XaLocalTransaction(c.Query("gid"), func(db *common.MyDb) (rerr error) {
req := transReqFromContext(c)
if req.TransOutResult != "SUCCESS" {
return fmt.Errorf("tranOut failed")
}
dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")).
Update("balance", gorm.Expr("balance + ?", transReqFromContext(c).Amount)) Update("balance", gorm.Expr("balance + ?", req.Amount))
return dbr.Error return dbr.Error
}) })
common.PanicIfError(err) common.PanicIfError(err)