add barrier test

This commit is contained in:
yedongfu 2021-07-19 20:25:52 +08:00
parent 933df4b706
commit 4a7f397567
9 changed files with 142 additions and 46 deletions

View File

@ -32,6 +32,7 @@ type SagaStep struct {
func NewSaga(server string) *Saga { func NewSaga(server string) *Saga {
return &Saga{ return &Saga{
SagaData: SagaData{ SagaData: SagaData{
Gid: GenGid(server),
TransType: "saga", TransType: "saga",
}, },
Server: server, Server: server,

View File

@ -5,6 +5,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
) )
@ -26,10 +27,14 @@ func TccGlobalTransaction(dtm string, tccFunc TccGlobalFunc) (gid string, rerr e
"trans_type": "tcc", "trans_type": "tcc",
} }
defer func() { defer func() {
var err error
if x := recover(); x != nil || rerr != nil { if x := recover(); x != nil || rerr != nil {
_, rerr = common.RestyClient.R().SetBody(data).Post(dtm + "/abort") _, err = common.RestyClient.R().SetBody(data).Post(dtm + "/abort")
} else { } else {
_, rerr = common.RestyClient.R().SetBody(data).Post(dtm + "/submit") _, err = common.RestyClient.R().SetBody(data).Post(dtm + "/submit")
}
if err != nil {
logrus.Errorf("submitting or abort global transaction error: %v", err)
} }
}() }()
tcc := &Tcc{Dtm: dtm, Gid: gid} tcc := &Tcc{Dtm: dtm, Gid: gid}

View File

@ -25,38 +25,46 @@ func newGid(c *gin.Context) (interface{}, error) {
} }
func prepare(c *gin.Context) (interface{}, error) { func prepare(c *gin.Context) (interface{}, error) {
m := TransFromContext(c) t := TransFromContext(c)
m.Status = "prepared" t.Status = "prepared"
m.saveNew(dbGet()) t.saveNew(dbGet())
return M{"message": "SUCCESS", "gid": m.Gid}, nil return M{"dtm_result": "SUCCESS", "gid": t.Gid}, nil
} }
func submit(c *gin.Context) (interface{}, error) { func submit(c *gin.Context) (interface{}, error) {
db := dbGet() db := dbGet()
m := TransFromContext(c) t := TransFromContext(c)
m.Status = "submitted" dbt := TransFromDb(db, t.Gid)
m.saveNew(db) if dbt != nil && dbt.Status != "prepared" && dbt.Status != "submitted" {
go m.Process(db) return M{"dtm_result": "FAILURE", "message": fmt.Sprintf("current status %s, cannot sumbmit", dbt.Status)}, nil
return M{"message": "SUCCESS", "gid": m.Gid}, nil }
t.Status = "submitted"
t.saveNew(db)
go t.Process(db)
return M{"dtm_result": "SUCCESS", "gid": t.Gid}, nil
} }
func abort(c *gin.Context) (interface{}, error) { func abort(c *gin.Context) (interface{}, error) {
db := dbGet() db := dbGet()
m := TransFromContext(c) t := TransFromContext(c)
m = TransFromDb(db, m.Gid) dbt := TransFromDb(db, t.Gid)
if m.TransType != "xa" && m.TransType != "tcc" || m.Status != "prepared" { if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != "prepared" && dbt.Status != "aborting" {
return nil, fmt.Errorf("unexpected trans data. type: %s status: %s for gid: %s", m.TransType, m.Status, m.Gid) return M{"dtm_result": "FAILURE", "message": fmt.Sprintf("trans type: %s current status %s, cannot abort", dbt.TransType, dbt.Status)}, nil
} }
go m.Process(db) go dbt.Process(db)
return M{"message": "SUCCESS"}, nil return M{"dtm_result": "SUCCESS"}, nil
} }
func registerXaBranch(c *gin.Context) (interface{}, error) { func registerXaBranch(c *gin.Context) (interface{}, error) {
branch := TransBranch{} branch := TransBranch{}
err := c.BindJSON(&branch) err := c.BindJSON(&branch)
e2p(err) e2p(err)
branches := []TransBranch{branch, branch}
db := dbGet() db := dbGet()
dbt := TransFromDb(db, branch.Gid)
if dbt.Status != "prepared" {
return M{"dtm_result": "FAILURE", "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)}, nil
}
branches := []TransBranch{branch, branch}
branches[0].BranchType = "rollback" branches[0].BranchType = "rollback"
branches[1].BranchType = "commit" branches[1].BranchType = "commit"
db.Must().Clauses(clause.OnConflict{ db.Must().Clauses(clause.OnConflict{
@ -65,7 +73,7 @@ func registerXaBranch(c *gin.Context) (interface{}, error) {
e2p(err) e2p(err)
global := TransGlobal{Gid: branch.Gid} global := TransGlobal{Gid: branch.Gid}
global.touch(db, config.TransCronInterval) global.touch(db, config.TransCronInterval)
return M{"message": "SUCCESS"}, nil return M{"dtm_result": "SUCCESS"}, nil
} }
func registerTccBranch(c *gin.Context) (interface{}, error) { func registerTccBranch(c *gin.Context) (interface{}, error) {
@ -78,6 +86,11 @@ func registerTccBranch(c *gin.Context) (interface{}, error) {
Status: data["status"], Status: data["status"],
Data: data["data"], Data: data["data"],
} }
db := dbGet()
dbt := TransFromDb(db, branch.Gid)
if dbt.Status != "prepared" {
return M{"dtm_result": "FAILURE", "message": fmt.Sprintf("current status: %s cannot register branch", dbt.Status)}, nil
}
branches := []TransBranch{branch, branch, branch} branches := []TransBranch{branch, branch, branch}
for i, b := range []string{"cancel", "confirm", "try"} { for i, b := range []string{"cancel", "confirm", "try"} {
@ -85,13 +98,13 @@ func registerTccBranch(c *gin.Context) (interface{}, error) {
branches[i].URL = data[b] branches[i].URL = data[b]
} }
dbGet().Must().Clauses(clause.OnConflict{ db.Must().Clauses(clause.OnConflict{
DoNothing: true, DoNothing: true,
}).Create(branches) }).Create(branches)
e2p(err) e2p(err)
global := TransGlobal{Gid: branch.Gid} global := TransGlobal{Gid: branch.Gid}
global.touch(dbGet(), config.TransCronInterval) global.touch(dbGet(), config.TransCronInterval)
return M{"message": "SUCCESS"}, nil return M{"dtm_result": "SUCCESS"}, nil
} }
func query(c *gin.Context) (interface{}, error) { func query(c *gin.Context) (interface{}, error) {

View File

@ -9,10 +9,10 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// CronTransOnce cron expired trans who's status match param status for once. use expireIn as expire time // CronTransOnce cron expired trans. use expireIn as expire time
func CronTransOnce(expireIn time.Duration, status string) bool { func CronTransOnce(expireIn time.Duration) bool {
defer handlePanic() defer handlePanic()
trans := lockOneTrans(expireIn, status) trans := lockOneTrans(expireIn)
if trans == nil { if trans == nil {
return false return false
} }
@ -22,21 +22,21 @@ func CronTransOnce(expireIn time.Duration, status string) bool {
} }
// CronExpiredTrans cron expired trans, num == -1 indicate for ever // CronExpiredTrans cron expired trans, num == -1 indicate for ever
func CronExpiredTrans(status string, num int) { func CronExpiredTrans(num int) {
for i := 0; i < num || num == -1; i++ { for i := 0; i < num || num == -1; i++ {
notEmpty := CronTransOnce(time.Duration(0), status) notEmpty := CronTransOnce(time.Duration(0))
if !notEmpty { if !notEmpty {
sleepCronTime() sleepCronTime()
} }
} }
} }
func lockOneTrans(expireIn time.Duration, status string) *TransGlobal { func lockOneTrans(expireIn time.Duration) *TransGlobal {
trans := TransGlobal{} trans := TransGlobal{}
owner := GenGid() owner := GenGid()
db := dbGet() db := dbGet()
dbr := db.Must().Model(&trans). dbr := db.Must().Model(&trans).
Where("next_cron_time < date_add(now(), interval ? second) and status=?", int(expireIn/time.Second), status). Where("next_cron_time < date_add(now(), interval ? second) and status in ('prepared', 'aborting', 'submitted')", int(expireIn/time.Second)).
Limit(1).Update("owner", owner) Limit(1).Update("owner", owner)
if dbr.RowsAffected == 0 { if dbr.RowsAffected == 0 {
return nil return nil

View File

@ -7,6 +7,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
@ -16,19 +17,17 @@ import (
var DtmServer = examples.DtmServer var DtmServer = examples.DtmServer
var Busi = examples.Busi var Busi = examples.Busi
var app *gin.Engine
func init() { func init() {
TransProcessedTestChan = make(chan string, 1)
common.InitApp(common.GetProjectDir(), &config) common.InitApp(common.GetProjectDir(), &config)
config.Mysql["database"] = dbName config.Mysql["database"] = dbName
PopulateMysql() PopulateMysql()
examples.PopulateMysql() examples.PopulateMysql()
}
func TestDtmSvr(t *testing.T) {
TransProcessedTestChan = make(chan string, 1)
// 启动组件 // 启动组件
go StartSvr() go StartSvr()
app := examples.BaseAppStartup() app = examples.BaseAppStartup()
examples.SagaSetup(app) examples.SagaSetup(app)
examples.TccSetup(app) examples.TccSetup(app)
examples.XaSetup(app) examples.XaSetup(app)
@ -41,7 +40,11 @@ func TestDtmSvr(t *testing.T) {
e2p(dbGet().Exec("truncate trans_branch").Error) e2p(dbGet().Exec("truncate trans_branch").Error)
e2p(dbGet().Exec("truncate trans_log").Error) e2p(dbGet().Exec("truncate trans_log").Error)
examples.ResetXaData() examples.ResetXaData()
}
func TestDtmSvr(t *testing.T) {
tccBarrierDisorder(t)
tccBarrierNormal(t) tccBarrierNormal(t)
tccBarrierRollback(t) tccBarrierRollback(t)
sagaBarrierNormal(t) sagaBarrierNormal(t)
@ -70,12 +73,11 @@ func TestDtmSvr(t *testing.T) {
func TestCover(t *testing.T) { func TestCover(t *testing.T) {
db := dbGet() db := dbGet()
db.NoMust() db.NoMust()
CronTransOnce(0, "prepared") CronTransOnce(0)
CronTransOnce(0, "submitted")
defer handlePanic() defer handlePanic()
checkAffected(db.DB) checkAffected(db.DB)
go CronExpiredTrans("submitted", 1) go CronExpiredTrans(1)
} }
func getTransStatus(gid string) string { func getTransStatus(gid string) string {
@ -176,7 +178,7 @@ func tccBarrierRollback(t *testing.T) {
logrus.Printf("tcc returns: %s, %s", res1.String(), res2.String()) logrus.Printf("tcc returns: %s, %s", res1.String(), res2.String())
return return
}) })
e2p(err) assert.Equal(t, err, fmt.Errorf("branch trans in fail"))
WaitTransProcessed(gid) WaitTransProcessed(gid)
assert.Equal(t, "failed", getTransStatus(gid)) assert.Equal(t, "failed", getTransStatus(gid))
} }
@ -206,12 +208,12 @@ func msgPending(t *testing.T) {
msg.Prepare("") msg.Prepare("")
assert.Equal(t, "prepared", getTransStatus(msg.Gid)) assert.Equal(t, "prepared", getTransStatus(msg.Gid))
examples.MainSwitch.CanSubmitResult.SetOnce("PENDING") examples.MainSwitch.CanSubmitResult.SetOnce("PENDING")
CronTransOnce(60*time.Second, "prepared") CronTransOnce(60 * time.Second)
assert.Equal(t, "prepared", getTransStatus(msg.Gid)) assert.Equal(t, "prepared", getTransStatus(msg.Gid))
examples.MainSwitch.TransInResult.SetOnce("PENDING") examples.MainSwitch.TransInResult.SetOnce("PENDING")
CronTransOnce(60*time.Second, "prepared") CronTransOnce(60 * time.Second)
assert.Equal(t, "submitted", getTransStatus(msg.Gid)) assert.Equal(t, "submitted", getTransStatus(msg.Gid))
CronTransOnce(60*time.Second, "submitted") CronTransOnce(60 * time.Second)
assert.Equal(t, []string{"succeed", "succeed"}, getBranchesStatus(msg.Gid)) assert.Equal(t, []string{"succeed", "succeed"}, getBranchesStatus(msg.Gid))
assert.Equal(t, "succeed", getTransStatus(msg.Gid)) assert.Equal(t, "succeed", getTransStatus(msg.Gid))
} }
@ -262,7 +264,7 @@ func sagaCommittedPending(t *testing.T) {
saga.Submit() saga.Submit()
WaitTransProcessed(saga.Gid) WaitTransProcessed(saga.Gid)
assert.Equal(t, []string{"prepared", "prepared", "prepared", "prepared"}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{"prepared", "prepared", "prepared", "prepared"}, getBranchesStatus(saga.Gid))
CronTransOnce(60*time.Second, "submitted") CronTransOnce(60 * time.Second)
assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(saga.Gid))
assert.Equal(t, "succeed", getTransStatus(saga.Gid)) assert.Equal(t, "succeed", getTransStatus(saga.Gid))
} }
@ -341,3 +343,70 @@ func TestSqlDB(t *testing.T) {
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{}) dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1)) asserts.Equal(dbr.RowsAffected, int64(1))
} }
func tccBarrierDisorder(t *testing.T) {
timeoutChan := make(chan string, 2)
finishedChan := make(chan string, 2)
gid, err := dtmcli.TccGlobalTransaction(DtmServer, func(tcc *dtmcli.Tcc) (rerr error) {
body := &examples.TransReq{Amount: 30}
tryURL := Busi + "/TccBTransOutTry"
confirmURL := Busi + "/TccBTransOutConfirm"
cancelURL := Busi + "/TccBSleepCancel"
// 请参见子事务屏障里的时序图这里为了模拟该时序图手动拆解了callbranch
branchID := tcc.NewBranchID()
sleeped := false
app.POST(examples.BusiAPI+"/TccBSleepCancel", common.WrapHandler(func(c *gin.Context) (interface{}, error) {
res, err := examples.TccBarrierTransOutCancel(c)
if !sleeped {
sleeped = true
logrus.Printf("sleep before cancel return")
<-timeoutChan
finishedChan <- "1"
}
return res, err
}))
// 注册子事务
_, err := common.RestyClient.R().
SetBody(&M{
"gid": tcc.Gid,
"branch_id": branchID,
"trans_type": "tcc",
"status": "prepared",
"data": string(common.MustMarshal(body)),
"try": tryURL,
"confirm": confirmURL,
"cancel": cancelURL,
}).
Post(tcc.Dtm + "/registerTccBranch")
e2p(err)
go func() {
logrus.Printf("sleeping to wait for tcc try timeout")
<-timeoutChan
_, _ = common.RestyClient.R().
SetBody(body).
SetQueryParams(common.MS{
"dtm": tcc.Dtm,
"gid": tcc.Gid,
"branch_id": branchID,
"trans_type": "tcc",
"branch_type": "try",
}).
Post(tryURL)
finishedChan <- "1"
}()
logrus.Printf("cron to timeout and then call cancel")
go CronTransOnce(60 * time.Second)
time.Sleep(100 * time.Millisecond)
logrus.Printf("cron to timeout and then call cancelled twice")
CronTransOnce(60 * time.Second)
timeoutChan <- "wake"
timeoutChan <- "wake"
<-finishedChan
<-finishedChan
time.Sleep(100 * time.Millisecond)
return fmt.Errorf("a cancelled tcc")
})
assert.Error(t, err, fmt.Errorf("a cancelled tcc"))
assert.Equal(t, []string{"succeed", "prepared", "prepared"}, getBranchesStatus(gid))
assert.Equal(t, "failed", getTransStatus(gid))
}

View File

@ -14,8 +14,7 @@ var dtmsvrPort = 8080
// MainStart main // MainStart main
func MainStart() { func MainStart() {
StartSvr() StartSvr()
go CronExpiredTrans("submitted", -1) go CronExpiredTrans(-1)
go CronExpiredTrans("prepared", -1)
} }
// StartSvr StartSvr // StartSvr StartSvr

View File

@ -118,7 +118,10 @@ func (t *TransGlobal) Process(db *common.DB) {
TransProcessedTestChan <- t.Gid TransProcessedTestChan <- t.Gid
} }
}() }()
logrus.Printf("processing: %s", t.Gid) logrus.Printf("processing: %s status: %s", t.Gid, t.Status)
if t.Status == "prepared" && t.TransType != "msg" {
t.changeStatus(db, "aborting")
}
branches := []TransBranch{} branches := []TransBranch{}
db.Must().Where("gid=?", t.Gid).Order("id asc").Find(&branches) db.Must().Where("gid=?", t.Gid).Order("id asc").Find(&branches)
t.getProcessor().ProcessOnce(db, branches) t.getProcessor().ProcessOnce(db, branches)
@ -186,6 +189,9 @@ func TransFromContext(c *gin.Context) *TransGlobal {
func TransFromDb(db *common.DB, gid string) *TransGlobal { func TransFromDb(db *common.DB, gid string) *TransGlobal {
m := TransGlobal{} m := TransGlobal{}
dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m) dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m)
if dbr.Error == gorm.ErrRecordNotFound {
return nil
}
e2p(dbr.Error) e2p(dbr.Error)
return &m return &m
} }

View File

@ -69,6 +69,9 @@ func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch)
t.changeStatus(db, "succeed") t.changeStatus(db, "succeed")
return return
} }
if t.Status != "aborting" && t.Status != "failed" {
t.changeStatus(db, "aborting")
}
for current = current - 1; current >= 0; current-- { for current = current - 1; current >= 0; current-- {
branch := &branches[current] branch := &branches[current]
if branch.BranchType != "compensate" || branch.Status != "prepared" { if branch.BranchType != "compensate" || branch.Status != "prepared" {

View File

@ -42,7 +42,7 @@ func TccBarrierAddRoute(app *gin.Engine) {
app.POST(BusiAPI+"/TccBTransInCancel", common.WrapHandler(tccBarrierTransInCancel)) app.POST(BusiAPI+"/TccBTransInCancel", common.WrapHandler(tccBarrierTransInCancel))
app.POST(BusiAPI+"/TccBTransOutTry", common.WrapHandler(tccBarrierTransOutTry)) app.POST(BusiAPI+"/TccBTransOutTry", common.WrapHandler(tccBarrierTransOutTry))
app.POST(BusiAPI+"/TccBTransOutConfirm", common.WrapHandler(tccBarrierTransOutConfirm)) app.POST(BusiAPI+"/TccBTransOutConfirm", common.WrapHandler(tccBarrierTransOutConfirm))
app.POST(BusiAPI+"/TccBTransOutCancel", common.WrapHandler(tccBarrierTransOutCancel)) app.POST(BusiAPI+"/TccBTransOutCancel", common.WrapHandler(TccBarrierTransOutCancel))
logrus.Printf("examples listening at %d", BusiPort) logrus.Printf("examples listening at %d", BusiPort)
} }
@ -112,7 +112,7 @@ func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) {
}) })
} }
func tccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustTrading(sdb, transOutUID, reqFrom(c).Amount) return adjustTrading(sdb, transOutUID, reqFrom(c).Amount)
}) })