common tested

This commit is contained in:
yedongfu 2021-05-29 11:50:10 +08:00
parent faca20d756
commit 4864f20058
15 changed files with 591 additions and 430 deletions

View File

@ -14,7 +14,9 @@ func main() {
if len(os.Args) == 1 { // 所有服务都启动 if len(os.Args) == 1 { // 所有服务都启动
go dtmsvr.StartSvr() go dtmsvr.StartSvr()
go examples.SagaStartSvr() go examples.SagaStartSvr()
} else if len(os.Args) > 1 && os.Args[1] == "dtmsvr" { go examples.TccStartSvr()
go examples.XaStartSvr()
} else if os.Args[1] == "dtmsvr" {
go dtmsvr.StartSvr() go dtmsvr.StartSvr()
} }
for { for {

0
common/conf.yml.sample Normal file
View File

39
common/types_test.go Normal file
View File

@ -0,0 +1,39 @@
package common
import (
"testing"
"github.com/go-playground/assert/v2"
)
type testConfig struct {
Mysql map[string]string
}
var config = testConfig{}
var myinit int = func() int {
InitApp(&config)
return 0
}()
func TestDb(t *testing.T) {
db := DbGet(config.Mysql)
err := func() (rerr error) {
defer P2E(&rerr)
dbr := db.NoMust().Exec("select a")
assert.NotEqual(t, nil, dbr.Error)
db.Must().Exec("select a")
return nil
}()
assert.NotEqual(t, nil, err)
}
func TestDbAlone(t *testing.T) {
db, con := DbAlone(config.Mysql)
dbr := db.Exec("select 1")
assert.Equal(t, nil, dbr.Error)
con.Close()
dbr = db.Exec("select 1")
assert.NotEqual(t, nil, dbr.Error)
}

View File

@ -20,15 +20,6 @@ import (
type M = map[string]interface{} type M = map[string]interface{}
func OrString(ss ...string) string {
for _, s := range ss {
if s != "" {
return s
}
}
return ""
}
func P2E(perr *error) { func P2E(perr *error) {
if x := recover(); x != nil { if x := recover(); x != nil {
if e, ok := x.(error); ok { if e, ok := x.(error); ok {
@ -39,6 +30,18 @@ func P2E(perr *error) {
} }
} }
func E2P(err error) {
if err != nil {
panic(err)
}
}
func CatchP(f func()) (rerr error) {
defer P2E(&rerr)
f()
return nil
}
func PanicIf(cond bool, err error) { func PanicIf(cond bool, err error) {
if cond { if cond {
panic(err) panic(err)
@ -53,17 +56,18 @@ var gNode *snowflake.Node = nil
func init() { func init() {
node, err := snowflake.NewNode(1) node, err := snowflake.NewNode(1)
if err != nil { E2P(err)
panic(err)
}
gNode = node gNode = node
} }
func E2P(err error) { func OrString(ss ...string) string {
if err != nil { for _, s := range ss {
panic(err) if s != "" {
return s
} }
} }
return ""
}
func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} { func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} {
if condition { if condition {
@ -102,11 +106,10 @@ func GetGinApp() *gin.Engine {
app := gin.Default() app := gin.Default()
app.Use(func(c *gin.Context) { app.Use(func(c *gin.Context) {
body := "" body := ""
if c.Request.Method == "POST" { if c.Request.Body != nil {
rb, err := c.GetRawData() rb, err := c.GetRawData()
if err != nil { E2P(err)
logrus.Printf("GetRawData error: %s", err.Error()) if len(rb) > 0 {
} else {
body = string(rb) body = string(rb)
c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(rb)) c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(rb))
} }
@ -196,13 +199,12 @@ func InitApp(config interface{}) {
logrus.SetFormatter(&formatter{}) logrus.SetFormatter(&formatter{})
_, file, _, _ := runtime.Caller(1) _, file, _, _ := runtime.Caller(1)
fileName := filepath.Dir(file) + "/conf.yml" fileName := filepath.Dir(file) + "/conf.yml"
if configLoaded[fileName] { if !configLoaded[fileName] {
return
}
configLoaded[fileName] = true configLoaded[fileName] = true
viper.SetConfigFile(fileName) viper.SetConfigFile(fileName)
err := viper.ReadInConfig() err := viper.ReadInConfig()
E2P(err) E2P(err)
err = viper.Unmarshal(config) }
err := viper.Unmarshal(config)
E2P(err) E2P(err)
} }

95
common/utils_test.go Normal file
View File

@ -0,0 +1,95 @@
package common
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/go-playground/assert/v2"
)
func TestEP(t *testing.T) {
skipped := true
err := func() (rerr error) {
defer P2E(&rerr)
E2P(errors.New("err1"))
skipped = false
return nil
}()
assert.Equal(t, true, skipped)
assert.Equal(t, "err1", err.Error())
err = CatchP(func() {
PanicIf(true, errors.New("err2"))
})
assert.Equal(t, "err2", err.Error())
err = func() (rerr error) {
defer func() {
x := recover()
assert.Equal(t, 1, x)
}()
defer P2E(&rerr)
panic(1)
}()
}
func TestGid(t *testing.T) {
id1 := GenGid()
id2 := GenGid()
assert.NotEqual(t, id1, id2)
}
func TestTernary(t *testing.T) {
assert.Equal(t, "1", OrString("", "", "1"))
assert.Equal(t, "", OrString("", "", ""))
assert.Equal(t, "1", If(true, "1", "2"))
assert.Equal(t, "2", If(false, "1", "2"))
}
func TestMarshal(t *testing.T) {
a := 0
type e struct {
A int
}
e1 := e{A: 10}
m := map[string]int{}
assert.Equal(t, "1", MustMarshalString(1))
assert.Equal(t, []byte("1"), MustMarshal(1))
MustUnmarshal([]byte("2"), &a)
assert.Equal(t, 2, a)
MustUnmarshalString("3", &a)
assert.Equal(t, 3, a)
MustRemarshal(&e1, &m)
assert.Equal(t, 10, m["A"])
}
func TestGin(t *testing.T) {
app := GetGinApp()
app.GET("/api/sample", WrapHandler(func(c *gin.Context) (interface{}, error) {
return 1, nil
}))
app.GET("/api/error", WrapHandler(func(c *gin.Context) (interface{}, error) {
return nil, errors.New("err1")
}))
getResultString := func(api string, body io.Reader) string {
req, _ := http.NewRequest("GET", api, body)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
return string(w.Body.Bytes())
}
assert.Equal(t, "{\"msg\":\"pong\"}", getResultString("/api/ping", nil))
assert.Equal(t, "1", getResultString("/api/sample", nil))
assert.Equal(t, "{\"code\":500,\"message\":\"err1\"}", getResultString("/api/error", strings.NewReader("{}")))
}
func TestResty(t *testing.T) {
resp, err := RestyClient.R().Get("http://baidu.com")
assert.Equal(t, nil, err)
err2 := CatchP(func() {
CheckRestySuccess(resp, err)
})
assert.NotEqual(t, nil, err2)
}

View File

@ -5,7 +5,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
"gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
@ -49,15 +48,12 @@ func Branch(c *gin.Context) (interface{}, error) {
err := c.BindJSON(&branch) err := c.BindJSON(&branch)
e2p(err) e2p(err)
branches := []TransBranch{branch, branch} branches := []TransBranch{branch, branch}
err = dbGet().Transaction(func(tx *gorm.DB) error { db := dbGet()
db := &common.MyDb{DB: tx}
branches[0].BranchType = "rollback" branches[0].BranchType = "rollback"
branches[1].BranchType = "commit" branches[1].BranchType = "commit"
db.Must().Clauses(clause.OnConflict{ db.Must().Clauses(clause.OnConflict{
DoNothing: true, DoNothing: true,
}).Create(branches) }).Create(branches)
return nil
})
e2p(err) e2p(err)
return M{"message": "SUCCESS"}, nil return M{"message": "SUCCESS"}, nil
} }

View File

@ -1,20 +0,0 @@
package dtmsvr
import "github.com/yedf/dtm/common"
func dbGet() *common.MyDb {
return common.DbGet(config.Mysql)
}
func writeTransLog(gid string, action string, status string, branch string, detail string) {
db := dbGet()
if detail == "" {
detail = "{}"
}
db.Must().Table("trans_log").Create(M{
"gid": gid,
"action": action,
"new_status": status,
"branch": branch,
"detail": detail,
})
}

View File

@ -37,12 +37,14 @@ func TestDtmSvr(t *testing.T) {
e2p(dbGet().Exec("truncate trans_branch").Error) e2p(dbGet().Exec("truncate trans_branch").Error)
e2p(dbGet().Exec("truncate trans_log").Error) e2p(dbGet().Exec("truncate trans_log").Error)
examples.ResetXaData() examples.ResetXaData()
tccNormal(t) tccNormal(t)
tccRollback(t) tccRollback(t)
tccRollbackPending(t)
xaNormal(t)
xaRollback(t)
sagaCommittedPending(t) sagaCommittedPending(t)
sagaPreparePending(t) sagaPreparePending(t)
xaRollback(t)
xaNormal(t)
sagaPrepareCancel(t) sagaPrepareCancel(t)
sagaNormal(t) sagaNormal(t)
sagaRollback(t) sagaRollback(t)
@ -60,20 +62,20 @@ func TestCover(t *testing.T) {
// 测试使用的全局对象 // 测试使用的全局对象
var initdb = dbGet() var initdb = dbGet()
func getTransStatus(gid string) *TransGlobal { func getTransStatus(gid string) string {
sm := TransGlobal{} sm := TransGlobal{}
dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm) dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm)
e2p(dbr.Error) e2p(dbr.Error)
return &sm return sm.Status
} }
func getBranchesStatus(gid string) []string { func getBranchesStatus(gid string) []string {
steps := []TransBranch{} branches := []TransBranch{}
dbr := dbGet().Model(&TransBranch{}).Where("gid=?", gid).Order("id").Find(&steps) dbr := dbGet().Model(&TransBranch{}).Where("gid=?", gid).Order("id").Find(&branches)
e2p(dbr.Error) e2p(dbr.Error)
status := []string{} status := []string{}
for _, step := range steps { for _, branch := range branches {
status = append(status, step.Status) status = append(status, branch.Status)
} }
return status return status
} }
@ -122,15 +124,15 @@ func xaRollback(t *testing.T) {
} }
WaitTransProcessed(gid) WaitTransProcessed(gid)
assert.Equal(t, []string{"succeed", "prepared"}, getBranchesStatus(gid)) assert.Equal(t, []string{"succeed", "prepared"}, getBranchesStatus(gid))
assert.Equal(t, "failed", getTransStatus(gid).Status) assert.Equal(t, "failed", getTransStatus(gid))
} }
func tccNormal(t *testing.T) { func tccNormal(t *testing.T) {
tcc := genTcc("gid-tcc-normal", false, false) tcc := genTcc("gid-tcc-normal", false, false)
tcc.Prepare(tcc.QueryPrepared) tcc.Prepare(tcc.QueryPrepared)
assert.Equal(t, "prepared", getTransStatus(tcc.Gid).Status) assert.Equal(t, "prepared", getTransStatus(tcc.Gid))
tcc.Commit() tcc.Commit()
assert.Equal(t, "committed", getTransStatus(tcc.Gid).Status) assert.Equal(t, "committed", getTransStatus(tcc.Gid))
WaitTransProcessed(tcc.Gid) WaitTransProcessed(tcc.Gid)
assert.Equal(t, []string{"prepared", "succeed", "succeed", "prepared", "succeed", "succeed"}, getBranchesStatus(tcc.Gid)) assert.Equal(t, []string{"prepared", "succeed", "succeed", "prepared", "succeed", "succeed"}, getBranchesStatus(tcc.Gid))
} }
@ -140,12 +142,22 @@ func tccRollback(t *testing.T) {
WaitTransProcessed(tcc.Gid) WaitTransProcessed(tcc.Gid)
assert.Equal(t, []string{"succeed", "prepared", "succeed", "succeed", "prepared", "failed"}, getBranchesStatus(tcc.Gid)) assert.Equal(t, []string{"succeed", "prepared", "succeed", "succeed", "prepared", "failed"}, getBranchesStatus(tcc.Gid))
} }
func tccRollbackPending(t *testing.T) {
tcc := genTcc("gid-tcc-rollback-pending", false, true)
examples.TccTransInCancelResult = "PENDING"
tcc.Commit()
WaitTransProcessed(tcc.Gid)
assert.Equal(t, "committed", getTransStatus(tcc.Gid))
examples.TccTransInCancelResult = ""
CronTransOnce(-10*time.Second, "committed")
assert.Equal(t, []string{"succeed", "prepared", "succeed", "succeed", "prepared", "failed"}, getBranchesStatus(tcc.Gid))
}
func sagaNormal(t *testing.T) { func sagaNormal(t *testing.T) {
saga := genSaga("gid-noramlSaga", false, false) saga := genSaga("gid-noramlSaga", false, false)
saga.Prepare(saga.QueryPrepared) saga.Prepare(saga.QueryPrepared)
assert.Equal(t, "prepared", getTransStatus(saga.Gid).Status) assert.Equal(t, "prepared", getTransStatus(saga.Gid))
saga.Commit() saga.Commit()
assert.Equal(t, "committed", getTransStatus(saga.Gid).Status) assert.Equal(t, "committed", getTransStatus(saga.Gid))
WaitTransProcessed(saga.Gid) WaitTransProcessed(saga.Gid)
assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(saga.Gid))
} }
@ -155,8 +167,8 @@ func sagaRollback(t *testing.T) {
saga.Commit() saga.Commit()
WaitTransProcessed(saga.Gid) WaitTransProcessed(saga.Gid)
saga.Prepare(saga.QueryPrepared) saga.Prepare(saga.QueryPrepared)
assert.Equal(t, "failed", getTransStatus(saga.Gid).Status) assert.Equal(t, "failed", getTransStatus(saga.Gid))
assert.Equal(t, []string{"failed", "succeed", "failed", "failed"}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{"succeed", "succeed", "succeed", "failed"}, getBranchesStatus(saga.Gid))
} }
func sagaPrepareCancel(t *testing.T) { func sagaPrepareCancel(t *testing.T) {
@ -167,7 +179,7 @@ func sagaPrepareCancel(t *testing.T) {
CronTransOnce(-10*time.Second, "prepared") CronTransOnce(-10*time.Second, "prepared")
examples.SagaTransQueryResult = "" examples.SagaTransQueryResult = ""
config.PreparedExpire = 60 config.PreparedExpire = 60
assert.Equal(t, "canceled", getTransStatus(saga.Gid).Status) assert.Equal(t, "canceled", getTransStatus(saga.Gid))
} }
func sagaPreparePending(t *testing.T) { func sagaPreparePending(t *testing.T) {
@ -176,9 +188,9 @@ func sagaPreparePending(t *testing.T) {
examples.SagaTransQueryResult = "PENDING" examples.SagaTransQueryResult = "PENDING"
CronTransOnce(-10*time.Second, "prepared") CronTransOnce(-10*time.Second, "prepared")
examples.SagaTransQueryResult = "" examples.SagaTransQueryResult = ""
assert.Equal(t, "prepared", getTransStatus(saga.Gid).Status) assert.Equal(t, "prepared", getTransStatus(saga.Gid))
CronTransOnce(-10*time.Second, "prepared") CronTransOnce(-10*time.Second, "prepared")
assert.Equal(t, "succeed", getTransStatus(saga.Gid).Status) assert.Equal(t, "succeed", getTransStatus(saga.Gid))
} }
func sagaCommittedPending(t *testing.T) { func sagaCommittedPending(t *testing.T) {
@ -191,7 +203,7 @@ func sagaCommittedPending(t *testing.T) {
assert.Equal(t, []string{"prepared", "succeed", "prepared", "prepared"}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{"prepared", "succeed", "prepared", "prepared"}, getBranchesStatus(saga.Gid))
CronTransOnce(-10*time.Second, "committed") CronTransOnce(-10*time.Second, "committed")
assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(saga.Gid)) assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(saga.Gid))
assert.Equal(t, "succeed", getTransStatus(saga.Gid).Status) assert.Equal(t, "succeed", getTransStatus(saga.Gid))
} }
func genSaga(gid string, outFailed bool, inFailed bool) *dtm.Saga { func genSaga(gid string, outFailed bool, inFailed bool) *dtm.Saga {

View File

@ -1,17 +0,0 @@
package dtmsvr
import (
"github.com/sirupsen/logrus"
)
var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束
func WaitTransProcessed(gid string) {
logrus.Printf("waiting for gid %s", gid)
id := <-TransProcessedTestChan
for id != gid {
logrus.Errorf("-------id %s not match gid %s", id, gid)
id = <-TransProcessedTestChan
}
logrus.Printf("finish for gid %s", gid)
}

View File

@ -5,45 +5,104 @@ import (
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
"gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type TransGlobal struct {
common.ModelBase
Gid string `json:"gid"`
TransType string `json:"trans_type"`
Data string `json:"data"`
Status string `json:"status"`
QueryPrepared string `json:"query_prepared"`
CommitTime *time.Time
FinishTime *time.Time
RollbackTime *time.Time
}
func (*TransGlobal) TableName() string {
return "trans_global"
}
type TransProcessor interface { type TransProcessor interface {
GenBranches() []TransBranch GenBranches() []TransBranch
ProcessOnce(db *common.MyDb, branches []TransBranch) ProcessOnce(db *common.MyDb, branches []TransBranch)
ExecBranch(db *common.MyDb, branch *TransBranch) string ExecBranch(db *common.MyDb, branch *TransBranch)
} }
type TransSagaProcessor struct { func (t *TransGlobal) touch(db *common.MyDb) *gorm.DB {
*TransGlobal writeTransLog(t.Gid, "touch trans", "", "", "")
return db.Model(&TransGlobal{}).Where("gid=?", t.Gid).Update("gid", t.Gid) // 更新update_time避免被定时任务再次
} }
func (t *TransSagaProcessor) GenBranches() []TransBranch { func (t *TransGlobal) changeStatus(db *common.MyDb, status string) *gorm.DB {
nsteps := []TransBranch{} writeTransLog(t.Gid, "change status", status, "", "")
steps := []M{} updates := M{
common.MustUnmarshalString(t.Data, &steps) "status": status,
for _, step := range steps { }
for _, branchType := range []string{"compensate", "action"} { if status == "succeed" {
nsteps = append(nsteps, TransBranch{ updates["finish_time"] = time.Now()
Gid: t.Gid, } else if status == "failed" {
Branch: fmt.Sprintf("%d", len(nsteps)+1), updates["rollback_time"] = time.Now()
Data: step["data"].(string), }
Url: step[branchType].(string), dbr := db.Must().Model(t).Where("status=?", t.Status).Updates(updates)
BranchType: branchType, checkAffected(dbr)
Status: "prepared", t.Status = status
return dbr
}
type TransBranch struct {
common.ModelBase
Gid string
Url string
Data string
Branch string
BranchType string
Status string
FinishTime *time.Time
RollbackTime *time.Time
}
func (*TransBranch) TableName() string {
return "trans_branch"
}
func (t *TransBranch) changeStatus(db *common.MyDb, status string) *gorm.DB {
writeTransLog(t.Gid, "branch change", status, t.Branch, "")
dbr := db.Must().Model(t).Where("status=?", t.Status).Updates(M{
"status": status,
"finish_time": time.Now(),
}) })
} checkAffected(dbr)
} t.Status = status
return nsteps return dbr
} }
func (t *TransSagaProcessor) ExecBranch(db *common.MyDb, branche *TransBranch) string { func checkAffected(db1 *gorm.DB) {
return "" if db1.RowsAffected == 0 {
panic(fmt.Errorf("duplicate updating"))
}
} }
func (t *TransSagaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) { func (trans *TransGlobal) getProcessor() TransProcessor {
if t.Status == "prepared" { if trans.TransType == "saga" {
return &TransSagaProcessor{TransGlobal: trans}
} else if trans.TransType == "tcc" {
return &TransTccProcessor{TransGlobal: trans}
} else if trans.TransType == "xa" {
return &TransXaProcessor{TransGlobal: trans}
}
return nil
}
func (t *TransGlobal) MayQueryPrepared(db *common.MyDb) {
if t.Status != "prepared" {
return
}
resp, err := common.RestyClient.R().SetQueryParam("gid", t.Gid).Get(t.QueryPrepared) resp, err := common.RestyClient.R().SetQueryParam("gid", t.Gid).Get(t.QueryPrepared)
e2p(err) e2p(err)
body := resp.String() body := resp.String()
@ -53,159 +112,70 @@ func (t *TransSagaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch
status := common.If(t.CreateTime.Before(preparedExpire), "canceled", "prepared").(string) status := common.If(t.CreateTime.Before(preparedExpire), "canceled", "prepared").(string)
if status != t.Status { if status != t.Status {
t.changeStatus(db, status) t.changeStatus(db, status)
} else {
t.touch(db)
} }
return
} else if strings.Contains(body, "SUCCESS") { } else if strings.Contains(body, "SUCCESS") {
t.Status = "committed" t.changeStatus(db, "committed")
t.SaveNew(db)
} else {
panic(fmt.Errorf("unknown result, will be retried: %s", body))
} }
} }
current := 0 // 当前正在处理的步骤
for ; current < len(branches); current++ {
step := branches[current]
if step.BranchType == "compensate" && step.Status == "prepared" || step.BranchType == "action" && step.Status == "succeed" {
continue
}
if step.BranchType == "action" && step.Status == "prepared" {
resp, err := common.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url)
e2p(err)
body := resp.String()
t.touch(db.Must())
if strings.Contains(body, "SUCCESS") {
step.changeStatus(db.Must(), "succeed")
} else if strings.Contains(body, "FAIL") {
step.changeStatus(db.Must(), "failed")
break
} else {
panic(fmt.Errorf("unknown response: %s, will be retried", body))
}
}
}
if current == len(branches) { // saga 事务完成
t.changeStatus(db.Must(), "succeed")
return
}
for current = current - 1; current >= 0; current-- {
step := branches[current]
if step.BranchType != "compensate" || step.Status != "prepared" {
continue
}
resp, err := common.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url)
e2p(err)
body := resp.String()
if strings.Contains(body, "SUCCESS") {
step.changeStatus(db.Must(), "failed")
} else {
panic(fmt.Errorf("expect compensate return SUCCESS"))
}
}
if current != -1 {
panic(fmt.Errorf("saga current not -1"))
}
t.changeStatus(db.Must(), "failed")
}
type TransTccProcessor struct { func (trans *TransGlobal) Process(db *common.MyDb) {
*TransGlobal defer handlePanic()
defer func() {
if TransProcessedTestChan != nil {
TransProcessedTestChan <- trans.Gid
}
}()
branches := []TransBranch{}
db.Must().Where("gid=?", trans.Gid).Order("id asc").Find(&branches)
trans.getProcessor().ProcessOnce(db, branches)
} }
func (t *TransTccProcessor) GenBranches() []TransBranch { func (t *TransGlobal) SaveNew(db *common.MyDb) {
nsteps := []TransBranch{} err := db.Transaction(func(db1 *gorm.DB) error {
steps := []M{} db := &common.MyDb{DB: db1}
common.MustUnmarshalString(t.Data, &steps)
for _, step := range steps { writeTransLog(t.Gid, "create trans", t.Status, "", t.Data)
for _, branchType := range []string{"cancel", "confirm", "try"} { dbr := db.Must().Clauses(clause.OnConflict{
nsteps = append(nsteps, TransBranch{ DoNothing: true,
Gid: t.Gid, }).Create(t)
Branch: fmt.Sprintf("%d", len(nsteps)+1), if dbr.RowsAffected == 0 && t.Status == "committed" { // 如果数据库已经存放了prepared的事务则修改状态
Data: step["data"].(string), dbr = db.Must().Model(&TransGlobal{}).Where("gid=? and status=?", t.Gid, "prepared").Update("status", t.Status)
Url: step[branchType].(string), }
BranchType: branchType, if dbr.RowsAffected == 0 { // 未保存任何数据,直接返回
Status: "prepared", return nil
}
// 保存所有的分支
branches := t.getProcessor().GenBranches()
if len(branches) > 0 {
writeTransLog(t.Gid, "save branches", t.Status, "", common.MustMarshalString(branches))
db.Must().Clauses(clause.OnConflict{
DoNothing: true,
}).Create(&branches)
}
return nil
}) })
}
}
return nsteps
}
func (t *TransTccProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) string {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParam("gid", branch.Gid).Post(branch.Url)
e2p(err) e2p(err)
body := resp.String()
t.touch(db)
if strings.Contains(body, "SUCCESS") {
branch.changeStatus(db, "succeed")
return "SUCCESS"
}
if branch.BranchType == "try" && strings.Contains(body, "FAIL") {
branch.changeStatus(db, "failed")
return "FAIL"
}
panic(fmt.Errorf("unknown response: %s, will be retried", body))
} }
func (t *TransTccProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) { func TransFromContext(c *gin.Context) *TransGlobal {
current := 0 // 当前正在处理的步骤 data := M{}
// 先处理一轮正常try状态 b, err := c.GetRawData()
for ; current < len(branches); current++ {
step := &branches[current]
if step.BranchType != "try" || step.Status == "succeed" {
continue
}
if step.BranchType == "try" && step.Status == "prepared" {
result := t.ExecBranch(db, step)
if result == "FAIL" {
break
}
}
}
// 如果try全部成功则处理confirm分支否则处理cancel分支
currentType := common.If(current == len(branches), "confirm", "cancel")
for current--; current >= 0; current-- {
branch := &branches[current]
if branch.BranchType != currentType || branch.Status != "prepared" {
continue
}
t.ExecBranch(db, branch)
}
t.changeStatus(db, common.If(currentType == "confirm", "succeed", "failed").(string))
}
type TransXaProcessor struct {
*TransGlobal
}
func (t *TransXaProcessor) GenBranches() []TransBranch {
return []TransBranch{}
}
func (t *TransXaProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) string {
resp, err := common.RestyClient.R().SetBody(M{
"branch": branch.Branch,
"action": common.If(t.Status == "prepared", "rollback", "commit"),
"gid": branch.Gid,
}).Post(branch.Url)
e2p(err) e2p(err)
body := resp.String() common.MustUnmarshal(b, &data)
if !strings.Contains(body, "SUCCESS") { logrus.Printf("creating trans in prepare")
panic(fmt.Errorf("bad response: %s", body)) if data["steps"] != nil {
data["data"] = common.MustMarshalString(data["steps"])
} }
branch.changeStatus(db, "succeed") m := TransGlobal{}
return "SUCCESS" common.MustRemarshal(data, &m)
return &m
} }
func (t *TransXaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) { func TransFromDb(db *common.MyDb, gid string) *TransGlobal {
if t.Status == "succeed" { m := TransGlobal{}
return dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m)
} e2p(dbr.Error)
currentType := common.If(t.Status == "committed", "commit", "rollback").(string) return &m
for _, branch := range branches {
if branch.BranchType == currentType && branch.Status != "succeed" {
_ = t.ExecBranch(db, &branch)
t.touch(db)
}
}
t.changeStatus(db, common.If(t.Status == "committed", "succeed", "failed").(string))
} }

78
dtmsvr/trans_saga.go Normal file
View File

@ -0,0 +1,78 @@
package dtmsvr
import (
"fmt"
"strings"
"github.com/yedf/dtm/common"
)
type TransSagaProcessor struct {
*TransGlobal
}
func (t *TransSagaProcessor) GenBranches() []TransBranch {
branches := []TransBranch{}
steps := []M{}
common.MustUnmarshalString(t.Data, &steps)
for _, step := range steps {
for _, branchType := range []string{"compensate", "action"} {
branches = append(branches, TransBranch{
Gid: t.Gid,
Branch: fmt.Sprintf("%d", len(branches)+1),
Data: step["data"].(string),
Url: step[branchType].(string),
BranchType: branchType,
Status: "prepared",
})
}
}
return branches
}
func (t *TransSagaProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParam("gid", branch.Gid).Post(branch.Url)
e2p(err)
body := resp.String()
t.touch(db)
if strings.Contains(body, "SUCCESS") {
branch.changeStatus(db, "succeed")
} else if branch.BranchType == "action" && strings.Contains(body, "FAIL") {
branch.changeStatus(db, "failed")
} else {
panic(fmt.Errorf("unknown response: %s, will be retried", body))
}
}
func (t *TransSagaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) {
t.MayQueryPrepared(db)
if t.Status != "committed" {
return
}
current := 0 // 当前正在处理的步骤
for ; current < len(branches); current++ {
branch := &branches[current]
if branch.BranchType != "action" || branch.Status != "prepared" {
continue
}
t.ExecBranch(db, branch)
if branch.Status != "succeed" {
break
}
}
if current == len(branches) { // saga 事务完成
t.changeStatus(db, "succeed")
return
}
for current = current - 1; current >= 0; current-- {
branch := &branches[current]
if branch.BranchType != "compensate" || branch.Status != "prepared" {
continue
}
t.ExecBranch(db, branch)
}
if current != -1 {
panic(fmt.Errorf("saga current not -1"))
}
t.changeStatus(db.Must(), "failed")
}

78
dtmsvr/trans_tcc.go Normal file
View File

@ -0,0 +1,78 @@
package dtmsvr
import (
"fmt"
"strings"
"github.com/yedf/dtm/common"
)
type TransTccProcessor struct {
*TransGlobal
}
func (t *TransTccProcessor) GenBranches() []TransBranch {
branches := []TransBranch{}
steps := []M{}
common.MustUnmarshalString(t.Data, &steps)
for _, step := range steps {
for _, branchType := range []string{"cancel", "confirm", "try"} {
branches = append(branches, TransBranch{
Gid: t.Gid,
Branch: fmt.Sprintf("%d", len(branches)+1),
Data: step["data"].(string),
Url: step[branchType].(string),
BranchType: branchType,
Status: "prepared",
})
}
}
return branches
}
func (t *TransTccProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParam("gid", branch.Gid).Post(branch.Url)
e2p(err)
body := resp.String()
t.touch(db)
if strings.Contains(body, "SUCCESS") {
branch.changeStatus(db, "succeed")
} else if branch.BranchType == "try" && strings.Contains(body, "FAIL") {
branch.changeStatus(db, "failed")
} else {
panic(fmt.Errorf("unknown response: %s, will be retried", body))
}
}
func (t *TransTccProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) {
t.MayQueryPrepared(db)
if t.Status != "committed" {
return
}
current := 0 // 当前正在处理的步骤
// 先处理一轮正常try状态
for ; current < len(branches); current++ {
branch := &branches[current]
if branch.BranchType != "try" || branch.Status == "succeed" {
continue
}
if branch.BranchType == "try" && branch.Status == "prepared" {
t.ExecBranch(db, branch)
if branch.Status != "succeed" {
break
}
} else {
break
}
}
// 如果try全部成功则处理confirm分支否则处理cancel分支
currentType := common.If(current == len(branches), "confirm", "cancel")
for current--; current >= 0; current-- {
branch := &branches[current]
if branch.BranchType != currentType || branch.Status != "prepared" {
continue
}
t.ExecBranch(db, branch)
}
t.changeStatus(db, common.If(currentType == "confirm", "succeed", "failed").(string))
}

44
dtmsvr/trans_xa.go Normal file
View File

@ -0,0 +1,44 @@
package dtmsvr
import (
"fmt"
"strings"
"github.com/yedf/dtm/common"
)
type TransXaProcessor struct {
*TransGlobal
}
func (t *TransXaProcessor) GenBranches() []TransBranch {
return []TransBranch{}
}
func (t *TransXaProcessor) ExecBranch(db *common.MyDb, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(M{
"branch": branch.Branch,
"action": common.If(t.Status == "prepared", "rollback", "commit"),
"gid": branch.Gid,
}).Post(branch.Url)
e2p(err)
body := resp.String()
t.touch(db)
if strings.Contains(body, "SUCCESS") {
branch.changeStatus(db, "succeed")
} else {
panic(fmt.Errorf("bad response: %s", body))
}
}
func (t *TransXaProcessor) ProcessOnce(db *common.MyDb, branches []TransBranch) {
if t.Status == "succeed" {
return
}
currentType := common.If(t.Status == "committed", "commit", "rollback").(string)
for _, branch := range branches {
if branch.BranchType == currentType && branch.Status != "succeed" {
t.ExecBranch(db, &branch)
}
}
t.changeStatus(db, common.If(t.Status == "committed", "succeed", "failed").(string))
}

View File

@ -1,158 +0,0 @@
package dtmsvr
import (
"fmt"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/yedf/dtm/common"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type M = map[string]interface{}
var p2e = common.P2E
var e2p = common.E2P
type TransGlobal struct {
common.ModelBase
Gid string `json:"gid"`
TransType string `json:"trans_type"`
Data string `json:"data"`
Status string `json:"status"`
QueryPrepared string `json:"query_prepared"`
CommitTime *time.Time
FinishTime *time.Time
RollbackTime *time.Time
}
func (*TransGlobal) TableName() string {
return "trans_global"
}
func (t *TransGlobal) touch(db *common.MyDb) *gorm.DB {
writeTransLog(t.Gid, "touch trans", "", "", "")
return db.Model(&TransGlobal{}).Where("gid=?", t.Gid).Update("gid", t.Gid) // 更新update_time避免被定时任务再次
}
func (t *TransGlobal) changeStatus(db *common.MyDb, status string) *gorm.DB {
writeTransLog(t.Gid, "change status", status, "", "")
updates := M{
"status": status,
}
if status == "succeed" {
updates["finish_time"] = time.Now()
} else if status == "failed" {
updates["rollback_time"] = time.Now()
}
dbr := db.Must().Model(t).Where("status=?", t.Status).Updates(updates)
checkAffected(dbr)
t.Status = status
return dbr
}
type TransBranch struct {
common.ModelBase
Gid string
Url string
Data string
Branch string
BranchType string
Status string
FinishTime *time.Time
RollbackTime *time.Time
}
func (*TransBranch) TableName() string {
return "trans_branch"
}
func (t *TransBranch) changeStatus(db *common.MyDb, status string) *gorm.DB {
writeTransLog(t.Gid, "step change", status, t.Branch, "")
dbr := db.Must().Model(t).Where("status=?", t.Status).Updates(M{
"status": status,
"finish_time": time.Now(),
})
checkAffected(dbr)
t.Status = status
return dbr
}
func checkAffected(db1 *gorm.DB) {
if db1.RowsAffected == 0 {
panic(fmt.Errorf("duplicate updating"))
}
}
func (trans *TransGlobal) getProcessor() TransProcessor {
if trans.TransType == "saga" {
return &TransSagaProcessor{TransGlobal: trans}
} else if trans.TransType == "tcc" {
return &TransTccProcessor{TransGlobal: trans}
} else if trans.TransType == "xa" {
return &TransXaProcessor{TransGlobal: trans}
}
return nil
}
func (trans *TransGlobal) Process(db *common.MyDb) {
defer handlePanic()
defer func() {
if TransProcessedTestChan != nil {
TransProcessedTestChan <- trans.Gid
}
}()
branches := []TransBranch{}
db.Must().Where("gid=?", trans.Gid).Order("id asc").Find(&branches)
trans.getProcessor().ProcessOnce(db, branches)
}
func (t *TransGlobal) SaveNew(db *common.MyDb) {
err := db.Transaction(func(db1 *gorm.DB) error {
db := &common.MyDb{DB: db1}
writeTransLog(t.Gid, "create trans", t.Status, "", t.Data)
dbr := db.Must().Clauses(clause.OnConflict{
DoNothing: true,
}).Create(t)
if dbr.RowsAffected == 0 && t.Status == "committed" { // 如果数据库已经存放了prepared的事务则修改状态
dbr = db.Must().Model(&TransGlobal{}).Where("gid=? and status=?", t.Gid, "prepared").Update("status", t.Status)
}
if dbr.RowsAffected == 0 { // 未保存任何数据,直接返回
return nil
}
// 保存所有的分支
nsteps := t.getProcessor().GenBranches()
if len(nsteps) > 0 {
writeTransLog(t.Gid, "save steps", t.Status, "", common.MustMarshalString(nsteps))
db.Must().Clauses(clause.OnConflict{
DoNothing: true,
}).Create(&nsteps)
}
return nil
})
e2p(err)
}
func TransFromContext(c *gin.Context) *TransGlobal {
data := M{}
b, err := c.GetRawData()
e2p(err)
common.MustUnmarshal(b, &data)
logrus.Printf("creating trans in prepare")
if data["steps"] != nil {
data["data"] = common.MustMarshalString(data["steps"])
}
m := TransGlobal{}
common.MustRemarshal(data, &m)
return &m
}
func TransFromDb(db *common.MyDb, gid string) *TransGlobal {
m := TransGlobal{}
dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m)
e2p(dbr.Error)
return &m
}

40
dtmsvr/utils.go Normal file
View File

@ -0,0 +1,40 @@
package dtmsvr
import (
"github.com/sirupsen/logrus"
"github.com/yedf/dtm/common"
)
type M = map[string]interface{}
var p2e = common.P2E
var e2p = common.E2P
func dbGet() *common.MyDb {
return common.DbGet(config.Mysql)
}
func writeTransLog(gid string, action string, status string, branch string, detail string) {
db := dbGet()
if detail == "" {
detail = "{}"
}
db.Must().Table("trans_log").Create(M{
"gid": gid,
"action": action,
"new_status": status,
"branch": branch,
"detail": detail,
})
}
var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束
func WaitTransProcessed(gid string) {
logrus.Printf("waiting for gid %s", gid)
id := <-TransProcessedTestChan
for id != gid {
logrus.Errorf("-------id %s not match gid %s", id, gid)
id = <-TransProcessedTestChan
}
logrus.Printf("finish for gid %s", gid)
}