xa normal ok

This commit is contained in:
yedongfu 2021-05-25 18:07:31 +08:00
parent 0c42bdeeb9
commit 828de82f14
23 changed files with 726 additions and 380 deletions

3
.gitignore vendored
View File

@ -1,3 +1,4 @@
dtmsvr/dtmsvr.yml
*/**/*.yml
*.out
*/**/main
main

View File

@ -4,7 +4,6 @@ import (
"os"
"time"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmsvr"
"github.com/yedf/dtm/examples"
)
@ -12,12 +11,10 @@ import (
type M = map[string]interface{}
func main() {
cmd := common.If(len(os.Args) > 1, os.Args[1], "").(string)
dtmsvr.LoadConfig()
if cmd == "" { // 所有服务都启动
if len(os.Args) == 1 { // 所有服务都启动
go dtmsvr.StartSvr()
go examples.SagaStartSvr()
} else if cmd == "dtmsvr" {
} else if len(os.Args) > 1 && os.Args[1] == "dtmsvr" {
go dtmsvr.StartSvr()
}
for {

108
common/types.go Normal file
View File

@ -0,0 +1,108 @@
package common
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
type ModelBase struct {
ID uint
CreateTime *time.Time `gorm:"autoCreateTime"`
UpdateTime *time.Time `gorm:"autoUpdateTime"`
}
var dbs = map[string]*MyDb{}
type MyDb struct {
*gorm.DB
}
func (m *MyDb) Must() *MyDb {
db := m.InstanceSet("ivy.must", true)
return &MyDb{DB: db}
}
func (m *MyDb) NoMust() *MyDb {
db := m.InstanceSet("ivy.must", false)
return &MyDb{DB: db}
}
type tracePlugin struct{}
func (op *tracePlugin) Name() string {
return "tracePlugin"
}
func (op *tracePlugin) Initialize(db *gorm.DB) (err error) {
before := func(db *gorm.DB) {
db.InstanceSet("ivy.startTime", time.Now())
}
after := func(db *gorm.DB) {
_ts, _ := db.InstanceGet("ivy.startTime")
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
logrus.Printf("used: %d ms affected: %d sql is: %s", time.Since(_ts.(time.Time)).Milliseconds(), db.RowsAffected, sql)
if v, ok := db.InstanceGet("ivy.must"); ok && v.(bool) {
if db.Error != nil && db.Error != gorm.ErrRecordNotFound {
panic(db.Error)
}
}
}
beforeName := "cb_before"
afterName := "cb_after"
logrus.Printf("installing db plugin: %s", op.Name())
// 开始前
_ = db.Callback().Create().Before("gorm:before_create").Register(beforeName, before)
_ = db.Callback().Query().Before("gorm:query").Register(beforeName, before)
_ = db.Callback().Delete().Before("gorm:before_delete").Register(beforeName, before)
_ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(beforeName, before)
_ = db.Callback().Row().Before("gorm:row").Register(beforeName, before)
_ = db.Callback().Raw().Before("gorm:raw").Register(beforeName, before)
// 结束后
_ = db.Callback().Create().After("gorm:after_create").Register(afterName, after)
_ = db.Callback().Query().After("gorm:after_query").Register(afterName, after)
_ = db.Callback().Delete().After("gorm:after_delete").Register(afterName, after)
_ = db.Callback().Update().After("gorm:after_update").Register(afterName, after)
_ = db.Callback().Row().After("gorm:row").Register(afterName, after)
_ = db.Callback().Raw().After("gorm:raw").Register(afterName, after)
return
}
func GetDsn(conf map[string]string) string {
return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", conf["user"], conf["password"], conf["host"], conf["port"], conf["database"])
}
func DbGet(conf map[string]string) *MyDb {
dsn := GetDsn(conf)
if dbs[dsn] == nil {
logrus.Printf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1))
db1, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
SkipDefaultTransaction: true,
})
PanicIfError(err)
db1.Use(&tracePlugin{})
dbs[dsn] = &MyDb{DB: db1}
}
return dbs[dsn]
}
func DbAlone(conf map[string]string) (*MyDb, *sql.DB) {
logrus.Printf("opening alone mysql: %s", GetDsn(conf))
mdb, err := sql.Open("mysql", GetDsn(conf))
PanicIfError(err)
gormDB, err := gorm.Open(mysql.New(mysql.Config{
Conn: mdb,
}), &gorm.Config{})
PanicIfError(err)
gormDB.Use(&tracePlugin{})
return &MyDb{DB: gormDB}, mdb
}

View File

@ -3,12 +3,19 @@ package common
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"path"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/bwmarrin/snowflake"
"github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
)
type M = map[string]interface{}
@ -22,6 +29,16 @@ func OrString(ss ...string) string {
return ""
}
func Panic2Error(perr *error) {
if x := recover(); x != nil {
if e, ok := x.(error); ok {
*perr = e
} else {
panic(x)
}
}
}
func GenGid() string {
return gNode.Generate().Base58()
}
@ -117,3 +134,69 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
}
}
}
// 辅助工具与代码
var RestyClient = resty.New()
func init() {
// RestyClient.SetTimeout(3 * time.Second)
// RestyClient.SetRetryCount(2)
// RestyClient.SetRetryWaitTime(1 * time.Second)
RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error {
logrus.Printf("requesting: %s %s %v", r.Method, r.URL, r.Body)
return nil
})
RestyClient.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error {
r := resp.Request
logrus.Printf("requested: %s %s %s", r.Method, r.URL, resp.String())
return nil
})
}
func CheckRestySuccess(resp *resty.Response, err error) {
PanicIfError(err)
if !strings.Contains(resp.String(), "SUCCESS") {
panic(fmt.Errorf("resty response not success: %s", resp.String()))
}
}
// formatter 自定义formatter
type formatter struct{}
// Format 进行格式化
func (f *formatter) Format(entry *logrus.Entry) ([]byte, error) {
var b *bytes.Buffer = &bytes.Buffer{}
if entry.Buffer != nil {
b = entry.Buffer
}
n := time.Now()
ts := fmt.Sprintf("%d-%02d-%02d %02d:%02d:%02d.%03d", n.Year(), n.Month(), n.Day(), n.Hour(), n.Minute(), n.Second(), n.Nanosecond()/1000000)
var file string
var line int
for i := 1; ; i++ {
_, file, line, _ = runtime.Caller(i)
if strings.Contains(file, "dtm") {
break
}
}
b.WriteString(fmt.Sprintf("%s %s:%d %s\n", ts, path.Base(file), line, entry.Message))
return b.Bytes(), nil
}
var configLoaded = map[string]bool{}
// 加载调用者文件相同目录下的配置文件
func InitApp(config interface{}) {
logrus.SetFormatter(&formatter{})
_, file, _, _ := runtime.Caller(1)
fileName := filepath.Dir(file) + "/conf.yml"
if configLoaded[fileName] {
return
}
configLoaded[fileName] = true
viper.SetConfigFile(fileName)
err := viper.ReadInConfig()
PanicIfError(err)
err = viper.Unmarshal(config)
PanicIfError(err)
}

View File

@ -10,13 +10,14 @@ import (
func AddRoute(engine *gin.Engine) {
engine.POST("/api/dtmsvr/prepare", common.WrapHandler(Prepare))
engine.POST("/api/dtmsvr/commit", common.WrapHandler(Commit))
engine.POST("/api/dtmsvr/branch", common.WrapHandler(Branch))
}
func Prepare(c *gin.Context) (interface{}, error) {
db := DbGet()
m := getSagaModelFromContext(c)
db := dbGet()
m := getTransFromContext(c)
m.Status = "prepared"
writeTransLog(m.Gid, "save prepared", m.Status, -1, m.Steps)
writeTransLog(m.Gid, "save prepared", m.Status, "", m.Data)
db.Must().Clauses(clause.OnConflict{
DoNothing: true,
}).Create(&m)
@ -24,20 +25,33 @@ func Prepare(c *gin.Context) (interface{}, error) {
}
func Commit(c *gin.Context) (interface{}, error) {
m := getSagaModelFromContext(c)
saveCommitedSagaModel(m)
go ProcessCommitedSaga(m.Gid)
m := getTransFromContext(c)
saveCommitted(m)
go ProcessCommitted(m)
return M{"message": "SUCCESS"}, nil
}
func getSagaModelFromContext(c *gin.Context) *SagaModel {
func Branch(c *gin.Context) (interface{}, error) {
branch := TransBranchModel{}
err := c.BindJSON(&branch)
common.PanicIfError(err)
db := dbGet()
db.Must().Clauses(clause.OnConflict{
DoNothing: true,
}).Create(&branch)
return M{"message": "SUCCESS"}, nil
}
func getTransFromContext(c *gin.Context) *TransGlobalModel {
data := M{}
b, err := c.GetRawData()
common.PanicIfError(err)
common.MustUnmarshal(b, &data)
logrus.Printf("creating saga model in prepare")
data["steps"] = common.MustMarshalString(data["steps"])
m := SagaModel{}
logrus.Printf("creating trans model in prepare")
if data["trans_type"].(string) == "saga" {
data["data"] = common.MustMarshalString(data["steps"])
}
m := TransGlobalModel{}
common.MustRemarshal(data, &m)
return &m
}

View File

@ -1,62 +1,10 @@
package dtmsvr
import (
"bytes"
"fmt"
"path"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/yedf/dtm/common"
)
// formatter 自定义formatter
type formatter struct{}
// Format 进行格式化
func (f *formatter) Format(entry *logrus.Entry) ([]byte, error) {
var b *bytes.Buffer = &bytes.Buffer{}
if entry.Buffer != nil {
b = entry.Buffer
}
n := time.Now()
ts := fmt.Sprintf("%d-%02d-%02d %02d:%02d:%02d.%03d", n.Year(), n.Month(), n.Day(), n.Hour(), n.Minute(), n.Second(), n.Nanosecond()/1000000)
var file string
var line int
for i := 1; ; i++ {
_, file, line, _ = runtime.Caller(i)
if strings.Contains(file, "dtm") {
break
}
}
b.WriteString(fmt.Sprintf("%s %s:%d %s\n", ts, path.Base(file), line, entry.Message))
return b.Bytes(), nil
}
type dtmsvrConfig struct {
PreparedExpire int64 `json:"prepare_expire"` // 单位秒当prepared的状态超过该时间才能够转变成canceled避免cancel了之后才进入prepared
Mysql map[string]string
}
var Config = &dtmsvrConfig{
var config = &dtmsvrConfig{
PreparedExpire: 60,
}
var configLoaded = false
func LoadConfig() {
if configLoaded {
return
}
configLoaded = true
logrus.SetFormatter(&formatter{})
_, file, _, _ := runtime.Caller(0)
viper.SetConfigFile(filepath.Dir(file) + "/dtmsvr.yml")
err := viper.ReadInConfig()
common.PanicIfError(err)
err = viper.Unmarshal(&Config)
common.PanicIfError(err)
}

View File

@ -6,33 +6,32 @@ import (
"time"
"github.com/sirupsen/logrus"
"github.com/yedf/dtm"
"github.com/yedf/dtm/common"
)
func CronPreparedOnce(expire time.Duration) {
db := DbGet()
ss := []SagaModel{}
db.Must().Model(&SagaModel{}).Where("update_time < date_sub(now(), interval ? second)", int(expire/time.Second)).Where("status = ?", "prepared").Find(&ss)
writeTransLog("", "saga fetch prepared", fmt.Sprint(len(ss)), -1, "")
db := dbGet()
ss := []TransGlobalModel{}
db.Must().Model(&TransGlobalModel{}).Where("update_time < date_sub(now(), interval ? second)", int(expire/time.Second)).Where("status = ?", "prepared").Find(&ss)
writeTransLog("", "saga fetch prepared", fmt.Sprint(len(ss)), "", "")
if len(ss) == 0 {
return
}
for _, sm := range ss {
writeTransLog(sm.Gid, "saga touch prepared", "", -1, "")
writeTransLog(sm.Gid, "saga touch prepared", "", "", "")
db.Must().Model(&sm).Update("id", sm.ID)
resp, err := dtm.RestyClient.R().SetQueryParam("gid", sm.Gid).Get(sm.TransQuery)
resp, err := common.RestyClient.R().SetQueryParam("gid", sm.Gid).Get(sm.QueryPrepared)
common.PanicIfError(err)
body := resp.String()
if strings.Contains(body, "FAIL") {
preparedExpire := time.Now().Add(time.Duration(-Config.PreparedExpire) * time.Second)
preparedExpire := time.Now().Add(time.Duration(-config.PreparedExpire) * time.Second)
logrus.Printf("create time: %s prepared expire: %s ", sm.CreateTime.Local(), preparedExpire.Local())
status := common.If(sm.CreateTime.Before(preparedExpire), "canceled", "prepared").(string)
writeTransLog(sm.Gid, "saga canceled", status, -1, "")
writeTransLog(sm.Gid, "saga canceled", status, "", "")
db.Must().Model(&sm).Where("status = ?", "prepared").Update("status", status)
} else if strings.Contains(body, "SUCCESS") {
saveCommitedSagaModel(&sm)
ProcessCommitedSaga(sm.Gid)
saveCommitted(&sm)
ProcessCommitted(&sm)
}
}
}
@ -44,25 +43,25 @@ func CronPrepared() {
}
}
func CronCommitedOnce(expire time.Duration) {
db := DbGet()
ss := []SagaModel{}
db.Must().Model(&SagaModel{}).Where("update_time < date_sub(now(), interval ? second)", int(expire/time.Second)).Where("status = ?", "commited").Find(&ss)
writeTransLog("", "saga fetch commited", fmt.Sprint(len(ss)), -1, "")
func CronCommittedOnce(expire time.Duration) {
db := dbGet()
ss := []TransGlobalModel{}
db.Must().Model(&TransGlobalModel{}).Where("update_time < date_sub(now(), interval ? second)", int(expire/time.Second)).Where("status = ?", "committed").Find(&ss)
writeTransLog("", "saga fetch committed", fmt.Sprint(len(ss)), "", "")
if len(ss) == 0 {
return
}
for _, sm := range ss {
writeTransLog(sm.Gid, "saga touch commited", "", -1, "")
writeTransLog(sm.Gid, "saga touch committed", "", "", "")
db.Must().Model(&sm).Update("id", sm.ID)
ProcessCommitedSaga(sm.Gid)
ProcessCommitted(&sm)
}
}
func CronCommited() {
func CronCommitted() {
for {
defer handlePanic()
CronCommitedOnce(10 * time.Second)
CronCommittedOnce(10 * time.Second)
}
}

View File

@ -1,103 +1,20 @@
package dtmsvr
import (
"fmt"
"strings"
"time"
import "github.com/yedf/dtm/common"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/yedf/dtm/common"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
var db *gorm.DB = nil
type MyDb struct {
*gorm.DB
func dbGet() *common.MyDb {
return common.DbGet(config.Mysql)
}
func (m *MyDb) Must() *MyDb {
db := m.InstanceSet("ivy.must", true)
return &MyDb{DB: db}
}
func (m *MyDb) NoMust() *MyDb {
db := m.InstanceSet("ivy.must", false)
return &MyDb{DB: db}
}
func DbGet() *MyDb {
LoadConfig()
if db == nil {
conf := viper.GetStringMapString("mysql")
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", conf["user"], conf["password"], conf["host"], conf["port"], conf["database"])
logrus.Printf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1))
db1, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
SkipDefaultTransaction: true,
})
common.PanicIfError(err)
db1.Use(&tracePlugin{})
db = db1
}
return &MyDb{DB: db}
}
func writeTransLog(gid string, action string, status string, step int, detail string) {
db := DbGet()
func writeTransLog(gid string, action string, status string, branch string, detail string) {
db := dbGet()
if detail == "" {
detail = "{}"
}
db.Must().Table("trans_log").Create(M{
"gid": gid,
"action": action,
"status": status,
"step": step,
"new_status": status,
"branch": branch,
"detail": detail,
})
}
type tracePlugin struct{}
func (op *tracePlugin) Name() string {
return "tracePlugin"
}
func (op *tracePlugin) Initialize(db *gorm.DB) (err error) {
before := func(db *gorm.DB) {
db.InstanceSet("ivy.startTime", time.Now())
}
after := func(db *gorm.DB) {
_ts, _ := db.InstanceGet("ivy.startTime")
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
logrus.Printf("used: %d ms affected: %d sql is: %s", time.Since(_ts.(time.Time)).Milliseconds(), db.RowsAffected, sql)
if v, ok := db.InstanceGet("ivy.must"); ok && v.(bool) {
if db.Error != nil && db.Error != gorm.ErrRecordNotFound {
panic(db.Error)
}
}
}
beforeName := "cb_before"
afterName := "cb_after"
logrus.Printf("installing db plugin: %s", op.Name())
// 开始前
_ = db.Callback().Create().Before("gorm:before_create").Register(beforeName, before)
_ = db.Callback().Query().Before("gorm:query").Register(beforeName, before)
_ = db.Callback().Delete().Before("gorm:before_delete").Register(beforeName, before)
_ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(beforeName, before)
_ = db.Callback().Row().Before("gorm:row").Register(beforeName, before)
_ = db.Callback().Raw().Before("gorm:raw").Register(beforeName, before)
// 结束后
_ = db.Callback().Create().After("gorm:after_create").Register(afterName, after)
_ = db.Callback().Query().After("gorm:after_query").Register(afterName, after)
_ = db.Callback().Delete().After("gorm:after_delete").Register(afterName, after)
_ = db.Callback().Update().After("gorm:after_update").Register(afterName, after)
_ = db.Callback().Row().After("gorm:row").Register(afterName, after)
_ = db.Callback().Raw().After("gorm:raw").Register(afterName, after)
return
}

View File

@ -27,7 +27,7 @@ CREATE TABLE `saga_step` (
`step` int(11) NOT NULL COMMENT '处于saga中的第几步',
`url` varchar(128) NOT NULL COMMENT '动作关联的url',
`type` varchar(45) NOT NULL COMMENT 'saga的所有步骤',
`status` varchar(45) NOT NULL COMMENT '步骤的状态 pending | finished | rollbacked',
`status` varchar(45) NOT NULL COMMENT '步骤的状态 prepared | finished | rollbacked',
`finish_time` datetime DEFAULT NULL,
`rollback_time` datetime DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
@ -52,15 +52,3 @@ CREATE TABLE `trans_log` (
KEY `create_time` (`create_time`)
) ENGINE=InnoDB AUTO_INCREMENT=48 DEFAULT CHARSET=utf8mb4;
drop table if EXISTS user_account;
CREATE TABLE `user_account` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`user_id` int(11) DEFAULT NULL,
`balance` decimal(10,2) NOT NULL DEFAULT '0.00',
`create_time` datetime DEFAULT CURRENT_TIMESTAMP,
`update_time` datetime DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
UNIQUE KEY `user_id` (`user_id`),
KEY `create_time` (`create_time`),
KEY `update_time` (`update_time`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

View File

@ -13,56 +13,60 @@ import (
)
var myinit int = func() int {
LoadConfig()
common.InitApp(&config)
return 0
}()
func TestViper(t *testing.T) {
assert.Equal(t, true, viper.Get("mysql") != nil)
assert.Equal(t, int64(90), Config.PreparedExpire)
assert.Equal(t, int64(90), config.PreparedExpire)
}
func TestDtmSvr(t *testing.T) {
SagaProcessedTestChan = make(chan string, 1)
// 清理数据
common.PanicIfError(db.Exec("truncate saga").Error)
common.PanicIfError(db.Exec("truncate saga_step").Error)
common.PanicIfError(db.Exec("truncate trans_log").Error)
TransProcessedTestChan = make(chan string, 1)
// 启动组件
go StartSvr()
go examples.SagaStartSvr()
go examples.XaStartSvr()
time.Sleep(time.Duration(100 * 1000 * 1000))
preparePending(t)
prepareCancel(t)
commitedPending(t)
noramlSaga(t)
rollbackSaga2(t)
// 清理数据
common.PanicIfError(dbGet().Exec("truncate trans_global").Error)
common.PanicIfError(dbGet().Exec("truncate trans_branch").Error)
common.PanicIfError(dbGet().Exec("truncate trans_log").Error)
examples.ResetXaData()
xaNormal(t)
sagaPreparePending(t)
sagaPrepareCancel(t)
sagaCommittedPending(t)
sagaNormal(t)
sagaRollback(t)
}
func TestCover(t *testing.T) {
db := DbGet()
db := dbGet()
db.NoMust()
CronPreparedOnce(0)
CronCommitedOnce(0)
CronCommittedOnce(0)
defer handlePanic()
checkAffected(db.DB)
}
// 测试使用的全局对象
var initdb = DbGet()
var initdb = dbGet()
func getSagaModel(gid string) *SagaModel {
sm := SagaModel{}
dbr := db.Model(&sm).Where("gid=?", gid).First(&sm)
func getSagaModel(gid string) *TransGlobalModel {
sm := TransGlobalModel{}
dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm)
common.PanicIfError(dbr.Error)
return &sm
}
func getSagaStepStatus(gid string) []string {
steps := []SagaStepModel{}
dbr := db.Model(&SagaStepModel{}).Where("gid=?", gid).Find(&steps)
func getBranchesStatus(gid string) []string {
steps := []TransBranchModel{}
dbr := dbGet().Model(&TransBranchModel{}).Where("gid=?", gid).Find(&steps)
common.PanicIfError(dbr.Error)
status := []string{}
for _, step := range steps {
@ -71,37 +75,81 @@ func getSagaStepStatus(gid string) []string {
return status
}
func noramlSaga(t *testing.T) {
func xaNormal(t *testing.T) {
xa := examples.XaClient
gid := "xa-normal"
err := xa.XaGlobalTransaction(gid, func() error {
req := examples.GenTransReq(30, false, false)
resp, err := common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{
"gid": gid,
"user_id": "1",
}).Post(examples.XaBusi + "/TransOut")
common.CheckRestySuccess(resp, err)
resp, err = common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{
"gid": gid,
"user_id": "2",
}).Post(examples.XaBusi + "/TransIn")
common.CheckRestySuccess(resp, err)
return nil
})
common.PanicIfError(err)
WaitTransCommitted(gid)
assert.Equal(t, []string{"finished", "finished"}, getBranchesStatus(gid))
}
func xaRollback(t *testing.T) {
xa := examples.XaClient
gid := "xa-rollback"
err := xa.XaGlobalTransaction(gid, func() error {
req := examples.GenTransReq(30, false, true)
resp, err := common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{
"gid": gid,
"user_id": "1",
}).Post(examples.XaBusi + "/TransOut")
common.CheckRestySuccess(resp, err)
resp, err = common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{
"gid": gid,
"user_id": "2",
}).Post(examples.XaBusi + "/TransIn")
common.CheckRestySuccess(resp, err)
return nil
})
common.PanicIfError(err)
WaitTransCommitted(gid)
assert.Equal(t, []string{"rollbacked", "rollbacked"}, getBranchesStatus(gid))
}
func sagaNormal(t *testing.T) {
saga := genSaga("gid-noramlSaga", false, false)
saga.Prepare()
assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status)
saga.Commit()
assert.Equal(t, "commited", getSagaModel(saga.Gid).Status)
WaitCommitedSaga(saga.Gid)
assert.Equal(t, []string{"pending", "finished", "pending", "finished"}, getSagaStepStatus(saga.Gid))
assert.Equal(t, "committed", getSagaModel(saga.Gid).Status)
WaitTransCommitted(saga.Gid)
assert.Equal(t, []string{"prepared", "finished", "prepared", "finished"}, getBranchesStatus(saga.Gid))
}
func rollbackSaga2(t *testing.T) {
func sagaRollback(t *testing.T) {
saga := genSaga("gid-rollbackSaga2", false, true)
saga.Commit()
WaitCommitedSaga(saga.Gid)
WaitTransCommitted(saga.Gid)
saga.Prepare()
assert.Equal(t, "rollbacked", getSagaModel(saga.Gid).Status)
assert.Equal(t, []string{"rollbacked", "finished", "rollbacked", "rollbacked"}, getSagaStepStatus(saga.Gid))
assert.Equal(t, []string{"rollbacked", "finished", "rollbacked", "rollbacked"}, getBranchesStatus(saga.Gid))
}
func prepareCancel(t *testing.T) {
func sagaPrepareCancel(t *testing.T) {
saga := genSaga("gid1-prepareCancel", false, true)
saga.Prepare()
examples.TransQueryResult = "FAIL"
Config.PreparedExpire = -10
config.PreparedExpire = -10
CronPreparedOnce(-10 * time.Second)
examples.TransQueryResult = ""
Config.PreparedExpire = 60
config.PreparedExpire = 60
assert.Equal(t, "canceled", getSagaModel(saga.Gid).Status)
}
func preparePending(t *testing.T) {
func sagaPreparePending(t *testing.T) {
saga := genSaga("gid1-preparePending", false, false)
saga.Prepare()
examples.TransQueryResult = "PENDING"
@ -109,33 +157,29 @@ func preparePending(t *testing.T) {
examples.TransQueryResult = ""
assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status)
CronPreparedOnce(-10 * time.Second)
WaitCommitedSaga(saga.Gid)
WaitTransCommitted(saga.Gid)
assert.Equal(t, "finished", getSagaModel(saga.Gid).Status)
}
func commitedPending(t *testing.T) {
saga := genSaga("gid-commitedPending", false, false)
func sagaCommittedPending(t *testing.T) {
saga := genSaga("gid-committedPending", false, false)
saga.Prepare()
examples.TransOutResult = "PENDING"
examples.TransInResult = "PENDING"
saga.Commit()
WaitCommitedSaga(saga.Gid)
examples.TransOutResult = ""
assert.Equal(t, []string{"pending", "finished", "pending", "pending"}, getSagaStepStatus(saga.Gid))
CronCommitedOnce(-10 * time.Second)
WaitCommitedSaga(saga.Gid)
assert.Equal(t, []string{"pending", "finished", "pending", "finished"}, getSagaStepStatus(saga.Gid))
WaitTransCommitted(saga.Gid)
examples.TransInResult = ""
assert.Equal(t, []string{"prepared", "finished", "prepared", "prepared"}, getBranchesStatus(saga.Gid))
CronCommittedOnce(-10 * time.Second)
WaitTransCommitted(saga.Gid)
assert.Equal(t, []string{"prepared", "finished", "prepared", "finished"}, getBranchesStatus(saga.Gid))
assert.Equal(t, "finished", getSagaModel(saga.Gid).Status)
}
func genSaga(gid string, inFailed bool, outFailed bool) *dtm.Saga {
func genSaga(gid string, outFailed bool, inFailed bool) *dtm.Saga {
logrus.Printf("beginning a saga test ---------------- %s", gid)
saga := dtm.SagaNew(examples.DtmServer, gid, examples.SagaBusi+"/TransQuery")
req := examples.TransReq{
Amount: 30,
TransInResult: common.If(inFailed, "FAIL", "SUCCESS").(string),
TransOutResult: common.If(outFailed, "FAIL", "SUCCESS").(string),
}
saga.Add(examples.SagaBusi+"/TransIn", examples.SagaBusi+"/TransInCompensate", &req)
req := examples.GenTransReq(30, outFailed, inFailed)
saga.Add(examples.SagaBusi+"/TransOut", examples.SagaBusi+"/TransOutCompensate", &req)
saga.Add(examples.SagaBusi+"/TransIn", examples.SagaBusi+"/TransInCompensate", &req)
return saga
}

View File

@ -11,6 +11,7 @@ func Main() {
func StartSvr() {
logrus.Printf("start dtmsvr")
common.InitApp(&config)
app := common.GetGinApp()
AddRoute(app)
logrus.Printf("dtmsvr listen at: 8080")

View File

@ -6,107 +6,145 @@ import (
"time"
"github.com/sirupsen/logrus"
"github.com/yedf/dtm"
"github.com/yedf/dtm/common"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
func saveCommitedSagaModel(m *SagaModel) {
db := DbGet()
m.Status = "commited"
func saveCommitted(m *TransGlobalModel) {
db := dbGet()
m.Status = "committed"
err := db.Transaction(func(db1 *gorm.DB) error {
db := &MyDb{DB: db1}
writeTransLog(m.Gid, "save commited", m.Status, -1, m.Steps)
db := &common.MyDb{DB: db1}
writeTransLog(m.Gid, "save committed", m.Status, "", m.Data)
dbr := db.Must().Clauses(clause.OnConflict{
DoNothing: true,
}).Create(&m)
if dbr.RowsAffected == 0 {
writeTransLog(m.Gid, "change status", m.Status, -1, "")
db.Must().Model(&m).Where("status=?", "prepared").Update("status", "commited")
writeTransLog(m.Gid, "change status", m.Status, "", "")
db.Must().Model(&m).Where("status=?", "prepared").Update("status", "committed")
}
nsteps := []SagaStepModel{}
if m.TransType == "saga" {
nsteps := []TransBranchModel{}
steps := []M{}
common.MustUnmarshalString(m.Steps, &steps)
common.MustUnmarshalString(m.Data, &steps)
for _, step := range steps {
nsteps = append(nsteps, SagaStepModel{
nsteps = append(nsteps, TransBranchModel{
Gid: m.Gid,
Step: len(nsteps) + 1,
Data: step["post_data"].(string),
Branch: fmt.Sprintf("%d", len(nsteps)+1),
Data: step["data"].(string),
Url: step["compensate"].(string),
Type: "compensate",
Status: "pending",
BranchType: "compensate",
Status: "prepared",
})
nsteps = append(nsteps, SagaStepModel{
nsteps = append(nsteps, TransBranchModel{
Gid: m.Gid,
Step: len(nsteps) + 1,
Data: step["post_data"].(string),
Branch: fmt.Sprintf("%d", len(nsteps)+1),
Data: step["data"].(string),
Url: step["action"].(string),
Type: "action",
Status: "pending",
BranchType: "action",
Status: "prepared",
})
}
writeTransLog(m.Gid, "save steps", m.Status, -1, common.MustMarshalString(nsteps))
writeTransLog(m.Gid, "save steps", m.Status, "", common.MustMarshalString(nsteps))
db.Must().Clauses(clause.OnConflict{
DoNothing: true,
}).Create(&nsteps)
}
return nil
})
common.PanicIfError(err)
}
var SagaProcessedTestChan chan string = nil // 用于测试时,通知处理结束
var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束
func WaitCommitedSaga(gid string) {
id := <-SagaProcessedTestChan
func WaitTransCommitted(gid string) {
id := <-TransProcessedTestChan
for id != gid {
logrus.Errorf("-------id %s not match gid %s", id, gid)
id = <-SagaProcessedTestChan
id = <-TransProcessedTestChan
}
}
func ProcessCommitedSaga(gid string) {
err := innerProcessCommitedSaga(gid)
func ProcessCommitted(trans *TransGlobalModel) {
err := innerProcessCommitted(trans)
if err != nil {
logrus.Errorf("process commited saga error: %s", err.Error())
logrus.Errorf("process committed error: %s", err.Error())
}
if SagaProcessedTestChan != nil {
SagaProcessedTestChan <- gid
if TransProcessedTestChan != nil {
TransProcessedTestChan <- trans.Gid
}
}
func checkAffected(db1 *gorm.DB) {
if db1.RowsAffected == 0 {
panic(fmt.Errorf("duplicate updating"))
func innerProcessCommitted(trans *TransGlobalModel) (rerr error) {
branches := []TransBranchModel{}
db := dbGet()
db.Must().Order("id asc").Find(&branches)
if trans.TransType == "saga" {
return innerProcessCommittedSaga(trans, db, branches)
} else if trans.TransType == "xa" {
return innerProcessCommittedXa(trans, db, branches)
}
panic(fmt.Errorf("unkown trans type: %s", trans.TransType))
}
func innerProcessCommitedSaga(gid string) (rerr error) {
steps := []SagaStepModel{}
db := DbGet()
db.Must().Order("id asc").Find(&steps)
current := 0 // 当前正在处理的步骤
for ; current < len(steps); current++ {
step := steps[current]
if step.Type == "compensate" && step.Status == "pending" || step.Type == "action" && step.Status == "finished" {
func innerProcessCommittedXa(trans *TransGlobalModel, db *common.MyDb, branches []TransBranchModel) error {
gid := trans.Gid
for _, branch := range branches {
if branch.Status == "finished" {
continue
}
if step.Type == "action" && step.Status == "pending" {
resp, err := dtm.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url)
db.Must().Model(&TransGlobalModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time避免被定时任务再次
resp, err := common.RestyClient.R().SetBody(M{
"branch": branch.Branch,
"action": "commit",
"gid": branch.Gid,
}).Post(branch.Url)
if err != nil {
return err
}
body := resp.String()
db.Must().Model(&SagaModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time避免被定时任务再次
if !strings.Contains(body, "SUCCESS") {
return fmt.Errorf("bad response: %s", body)
}
writeTransLog(gid, "step finished", "finished", branch.Branch, "")
db.Must().Model(&branch).Where("status=?", "prepared").Updates(M{
"status": "finished",
"finish_time": time.Now(),
})
}
writeTransLog(gid, "xa finished", "finished", "", "")
db.Must().Model(&TransGlobalModel{}).Where("gid=? and status=?", gid, "committed").Updates(M{
"status": "finished",
"finish_time": time.Now(),
})
return nil
}
func innerProcessCommittedSaga(trans *TransGlobalModel, db *common.MyDb, branches []TransBranchModel) error {
gid := trans.Gid
current := 0 // 当前正在处理的步骤
for ; current < len(branches); current++ {
step := branches[current]
if step.BranchType == "compensate" && step.Status == "prepared" || step.BranchType == "action" && step.Status == "finished" {
continue
}
if step.BranchType == "action" && step.Status == "prepared" {
resp, err := common.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url)
if err != nil {
return err
}
body := resp.String()
db.Must().Model(&TransGlobalModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time避免被定时任务再次
if strings.Contains(body, "SUCCESS") {
writeTransLog(gid, "step finished", "finished", step.Step, "")
dbr := db.Must().Model(&step).Where("status=?", "pending").Updates(M{
writeTransLog(gid, "step finished", "finished", step.Branch, "")
dbr := db.Must().Model(&step).Where("status=?", "prepared").Updates(M{
"status": "finished",
"finish_time": time.Now(),
})
checkAffected(dbr)
} else if strings.Contains(body, "FAIL") {
writeTransLog(gid, "step rollbacked", "rollbacked", step.Step, "")
dbr := db.Must().Model(&step).Where("status=?", "pending").Updates(M{
writeTransLog(gid, "step rollbacked", "rollbacked", step.Branch, "")
dbr := db.Must().Model(&step).Where("status=?", "prepared").Updates(M{
"status": "rollbacked",
"rollback_time": time.Now(),
})
@ -117,9 +155,9 @@ func innerProcessCommitedSaga(gid string) (rerr error) {
}
}
}
if current == len(steps) { // saga 事务完成
writeTransLog(gid, "saga finished", "finished", -1, "")
dbr := db.Must().Model(&SagaModel{}).Where("gid=? and status=?", gid, "commited").Updates(M{
if current == len(branches) { // saga 事务完成
writeTransLog(gid, "saga finished", "finished", "", "")
dbr := db.Must().Model(&TransGlobalModel{}).Where("gid=? and status=?", gid, "committed").Updates(M{
"status": "finished",
"finish_time": time.Now(),
})
@ -127,17 +165,17 @@ func innerProcessCommitedSaga(gid string) (rerr error) {
return nil
}
for current = current - 1; current >= 0; current-- {
step := steps[current]
if step.Type != "compensate" || step.Status != "pending" {
step := branches[current]
if step.BranchType != "compensate" || step.Status != "prepared" {
continue
}
resp, err := dtm.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url)
resp, err := common.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url)
if err != nil {
return err
}
body := resp.String()
if strings.Contains(body, "SUCCESS") {
writeTransLog(gid, "step rollbacked", "rollbacked", step.Step, "")
writeTransLog(gid, "step rollbacked", "rollbacked", step.Branch, "")
dbr := db.Must().Model(&step).Where("status=?", step.Status).Updates(M{
"status": "rollbacked",
"rollback_time": time.Now(),
@ -150,11 +188,17 @@ func innerProcessCommitedSaga(gid string) (rerr error) {
if current != -1 {
return fmt.Errorf("saga current not -1")
}
writeTransLog(gid, "saga rollbacked", "rollbacked", -1, "")
dbr := db.Must().Model(&SagaModel{}).Where("status=? and gid=?", "commited", gid).Updates(M{
writeTransLog(gid, "saga rollbacked", "rollbacked", "", "")
dbr := db.Must().Model(&TransGlobalModel{}).Where("status=? and gid=?", "committed", gid).Updates(M{
"status": "rollbacked",
"rollback_time": time.Now(),
})
checkAffected(dbr)
return nil
}
func checkAffected(db1 *gorm.DB) {
if db1.RowsAffected == 0 {
panic(fmt.Errorf("duplicate updating"))
}
}

View File

@ -2,41 +2,40 @@ package dtmsvr
import (
"time"
"github.com/yedf/dtm/common"
)
type M = map[string]interface{}
type ModelBase struct {
ID uint
CreateTime *time.Time `gorm:"autoCreateTime"`
UpdateTime *time.Time `gorm:"autoUpdateTime"`
}
type SagaModel struct {
ModelBase
type TransGlobalModel struct {
common.ModelBase
Gid string `json:"gid"`
Steps string `json:"steps"`
TransQuery string `json:"trans_query"`
TransType string `json:"trans_type"`
Data string `json:"data"`
Status string `json:"status"`
QueryPrepared string `json:"query_prepared"`
CommitTime *time.Time
FinishTime *time.Time
RollbackTime *time.Time
}
func (*SagaModel) TableName() string {
return "saga"
func (*TransGlobalModel) TableName() string {
return "trans_global"
}
type SagaStepModel struct {
ModelBase
type TransBranchModel struct {
common.ModelBase
Gid string
Data string
Step int
Url string
Type string
Data string
Branch string
BranchType string
Status string
FinishTime *time.Time
RollbackTime *time.Time
}
func (*SagaStepModel) TableName() string {
return "saga_step"
func (*TransBranchModel) TableName() string {
return "trans_branch"
}

0
examples/conf.yml.sample Normal file
View File

View File

@ -1,4 +1,7 @@
package examples
// 指定dtm服务地址
const DtmServer = "http://localhost:8080/api/dtmsvr"
type exampleConfig struct {
Mysql map[string]string
}
var Config = exampleConfig{}

13
examples/examples.sql Normal file
View File

@ -0,0 +1,13 @@
use dtm_busi;
drop table if exists user_account;
create table user_account(
id int(11) PRIMARY KEY AUTO_INCREMENT,
user_id int(11) UNIQUE ,
balance DECIMAL(10, 2) not null default '0',
create_time datetime DEFAULT now(),
update_time datetime DEFAULT now(),
key(create_time),
key(update_time)
);
insert into user_account (user_id, balance) values (1, 10000), (2, 10000);

View File

@ -39,8 +39,8 @@ func sagaFireRequest() {
}
saga := dtm.SagaNew(DtmServer, gid, SagaBusi+"/TransQuery")
saga.Add(SagaBusi+"/TransIn", SagaBusi+"/TransInCompensate", req)
saga.Add(SagaBusi+"/TransOut", SagaBusi+"/TransOutCompensate", req)
saga.Add(SagaBusi+"/TransIn", SagaBusi+"/TransInCompensate", req)
err := saga.Prepare()
common.PanicIfError(err)
logrus.Printf("busi trans commit")
@ -67,12 +67,6 @@ var TransInCompensateResult = ""
var TransOutCompensateResult = ""
var TransQueryResult = ""
type TransReq struct {
Amount int `json:"amount"`
TransInResult string `json:"transInResult"`
TransOutResult string `json:"transOutResult"`
}
func transReqFromContext(c *gin.Context) *TransReq {
req := TransReq{}
err := c.BindJSON(&req)

11
examples/types.go Normal file
View File

@ -0,0 +1,11 @@
package examples
import "github.com/yedf/dtm/common"
type UserAccount struct {
common.ModelBase
UserId int
Balance string
}
func (u *UserAccount) TableName() string { return "user_account" }

20
examples/utils.go Normal file
View File

@ -0,0 +1,20 @@
package examples
import "github.com/yedf/dtm/common"
// 指定dtm服务地址
const DtmServer = "http://localhost:8080/api/dtmsvr"
type TransReq struct {
Amount int `json:"amount"`
TransInResult string `json:"transInResult"`
TransOutResult string `json:"transOutResult"`
}
func GenTransReq(amount int, outFailed bool, inFailed bool) *TransReq {
return &TransReq{
Amount: amount,
TransOutResult: common.If(outFailed, "FAIL", "SUCCESS").(string),
TransInResult: common.If(inFailed, "FAIL", "SUCCESS").(string),
}
}

View File

@ -4,8 +4,11 @@ import (
"fmt"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/yedf/dtm"
"github.com/yedf/dtm/common"
"gorm.io/gorm"
)
// 事务参与者的服务地址
@ -14,21 +17,84 @@ const XaBusiApi = "/api/busi_xa"
var XaBusi = fmt.Sprintf("http://localhost:%d%s", XaBusiPort, XaBusiApi)
var XaClient *dtm.XaClient = nil
func XaMain() {
go XaStartSvr()
xaFireRequest()
time.Sleep(100 * time.Millisecond)
XaFireRequest()
time.Sleep(1000 * time.Second)
}
func XaStartSvr() {
common.InitApp(&Config)
logrus.Printf("xa examples starting")
app := common.GetGinApp()
AddRoute(app)
app.Run(":8081")
XaClient = dtm.XaClientNew(DtmServer, Config.Mysql, app, XaBusi+"/xa")
XaAddRoute(app)
app.Run(fmt.Sprintf(":%d", XaBusiPort))
}
func xaFireRequest() {
func XaFireRequest() {
gid := common.GenGid()
err := XaClient.XaGlobalTransaction(gid, func() (rerr error) {
defer common.Panic2Error(&rerr)
req := GenTransReq(30, false, false)
resp, err := common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{
"gid": gid,
"user_id": "1",
}).Post(XaBusi + "/TransOut")
common.CheckRestySuccess(resp, err)
resp, err = common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{
"gid": gid,
"user_id": "2",
}).Post(XaBusi + "/TransOut")
common.CheckRestySuccess(resp, err)
return nil
})
common.PanicIfError(err)
}
// api
func XaAddRoute(app *gin.Engine) {
app.POST(XaBusiApi+"/TransIn", common.WrapHandler(XaTransIn))
app.POST(XaBusiApi+"/TransOut", common.WrapHandler(XaTransOut))
}
func XaTransIn(c *gin.Context) (interface{}, error) {
err := XaClient.XaLocalTransaction(c.Query("gid"), func(db *common.MyDb) (rerr error) {
dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")).
Update("balance", gorm.Expr("balance - ?", transReqFromContext(c).Amount))
return dbr.Error
})
common.PanicIfError(err)
return M{"result": "SUCCESS"}, nil
}
func XaTransOut(c *gin.Context) (interface{}, error) {
err := XaClient.XaLocalTransaction(c.Query("gid"), func(db *common.MyDb) (rerr error) {
dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")).
Update("balance", gorm.Expr("balance + ?", transReqFromContext(c).Amount))
return dbr.Error
})
common.PanicIfError(err)
return M{"result": "SUCCESS"}, nil
}
func ResetXaData() {
db := dbGet()
db.Must().Exec("truncate user_account")
db.Must().Exec("insert into user_account (user_id, balance) values (1, 10000), (2, 10000)")
type XaRow struct {
Data string
}
xas := []XaRow{}
db.Must().Raw("xa recover").Scan(&xas)
for _, xa := range xas {
db.Must().Exec(fmt.Sprintf("xa rollback '%s'", xa.Data))
}
}
func dbGet() *common.MyDb {
return common.DbGet(Config.Mysql)
}

37
saga.go
View File

@ -3,10 +3,9 @@ package dtm
import (
"encoding/json"
"fmt"
"time"
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
"github.com/yedf/dtm/common"
)
type Saga struct {
@ -16,20 +15,22 @@ type Saga struct {
type SagaData struct {
Gid string `json:"gid"`
TransType string `json:"trans_type"`
Steps []SagaStep `json:"steps"`
TransQuery string `json:"trans_query"`
QueryPrepared string `json:"query_prepared"`
}
type SagaStep struct {
Action string `json:"action"`
Compensate string `json:"compensate"`
PostData string `json:"post_data"`
Data string `json:"data"`
}
func SagaNew(server string, gid string, transQuery string) *Saga {
func SagaNew(server string, gid string, queryPrepared string) *Saga {
return &Saga{
SagaData: SagaData{
Gid: gid,
TransQuery: transQuery,
TransType: "saga",
QueryPrepared: queryPrepared,
},
Server: server,
}
@ -43,7 +44,7 @@ func (s *Saga) Add(action string, compensate string, postData interface{}) error
step := SagaStep{
Action: action,
Compensate: compensate,
PostData: string(d),
Data: string(d),
}
s.Steps = append(s.Steps, step)
return nil
@ -51,7 +52,7 @@ func (s *Saga) Add(action string, compensate string, postData interface{}) error
func (s *Saga) Prepare() error {
logrus.Printf("preparing %s body: %v", s.Gid, &s.SagaData)
resp, err := RestyClient.R().SetBody(&s.SagaData).Post(fmt.Sprintf("%s/prepare", s.Server))
resp, err := common.RestyClient.R().SetBody(&s.SagaData).Post(fmt.Sprintf("%s/prepare", s.Server))
if err != nil {
return err
}
@ -63,7 +64,7 @@ func (s *Saga) Prepare() error {
func (s *Saga) Commit() error {
logrus.Printf("committing %s body: %v", s.Gid, &s.SagaData)
resp, err := RestyClient.R().SetBody(&s.SagaData).Post(fmt.Sprintf("%s/commit", s.Server))
resp, err := common.RestyClient.R().SetBody(&s.SagaData).Post(fmt.Sprintf("%s/commit", s.Server))
if err != nil {
return err
}
@ -72,21 +73,3 @@ func (s *Saga) Commit() error {
}
return nil
}
// 辅助工具与代码
var RestyClient = resty.New()
func init() {
RestyClient.SetTimeout(3 * time.Second)
RestyClient.SetRetryCount(2)
RestyClient.SetRetryWaitTime(1 * time.Second)
RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error {
logrus.Printf("requesting: %s %s %v", r.Method, r.URL, r.Body)
return nil
})
RestyClient.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error {
r := resp.Request
logrus.Printf("requested: %s %s %s", r.Method, r.URL, resp.String())
return nil
})
}

113
xa.go Normal file
View File

@ -0,0 +1,113 @@
package dtm
import (
"fmt"
"net/url"
"strings"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/yedf/dtm/common"
)
type M = map[string]interface{}
type XaGlobalFunc func() error
type XaLocalFunc func(db *common.MyDb) error
type XaClient struct {
Server string
Conf map[string]string
CallbackUrl string
}
func XaClientNew(server string, mysqlConf map[string]string, app *gin.Engine, callbackUrl string) *XaClient {
xa := &XaClient{
Server: server,
Conf: mysqlConf,
CallbackUrl: callbackUrl,
}
u, err := url.Parse(callbackUrl)
common.PanicIfError(err)
app.POST(u.Path, common.WrapHandler(func(c *gin.Context) (interface{}, error) {
type CallbackReq struct {
Gid string `json:"gid"`
Branch string `json:"branch"`
Action string `json:"action"`
}
req := CallbackReq{}
b, err := c.GetRawData()
common.PanicIfError(err)
common.MustUnmarshal(b, &req)
tx, my := common.DbAlone(xa.Conf)
defer func() {
logrus.Printf("closing conn %v", xa.Conf)
my.Close()
}()
if req.Action == "commit" {
tx.Must().Exec(fmt.Sprintf("xa commit '%s'", req.Branch))
} else if req.Action == "rollback" {
tx.Must().Exec(fmt.Sprintf("xa rollback '%s'", req.Branch))
} else {
panic(fmt.Errorf("unknown action: %s", req.Action))
}
return M{"result": "SUCCESS"}, nil
}))
return xa
}
func (xa *XaClient) XaLocalTransaction(gid string, transFunc XaLocalFunc) (rerr error) {
defer common.Panic2Error(&rerr)
branch := common.GenGid()
tx, my := common.DbAlone(xa.Conf)
defer func() {
logrus.Printf("closing conn %v", xa.Conf)
my.Close()
}()
// tx1 := db.Session(&gorm.Session{SkipDefaultTransaction: true})
// common.PanicIfError(tx1.Error)
// tx := common.MyDb{DB: tx1}
tx.Must().Exec(fmt.Sprintf("XA start '%s'", branch))
err := transFunc(tx)
common.PanicIfError(err)
resp, err := common.RestyClient.R().
SetBody(&M{"gid": gid, "branch": branch, "trans_type": "xa", "status": "prepared", "url": xa.CallbackUrl}).
Post(xa.Server + "/branch")
common.PanicIfError(err)
if !strings.Contains(resp.String(), "SUCCESS") {
common.PanicIfError(fmt.Errorf("unknown server response: %s", resp.String()))
}
tx.Must().Exec(fmt.Sprintf("XA end '%s'", branch))
tx.Must().Exec(fmt.Sprintf("XA prepare '%s'", branch))
return nil
}
func (xa *XaClient) XaGlobalTransaction(gid string, transFunc XaGlobalFunc) (rerr error) {
data := &M{
"gid": gid,
"trans_type": "xa",
}
defer func() {
x := recover()
if x != nil {
_, _ = common.RestyClient.R().SetBody(data).Post(xa.Server + "/rollback")
rerr = x.(error)
}
}()
resp, err := common.RestyClient.R().SetBody(data).Post(xa.Server + "/prepare")
common.PanicIfError(err)
if !strings.Contains(resp.String(), "SUCCESS") {
panic(fmt.Errorf("unexpected result: %s", resp.String()))
}
err = transFunc()
common.PanicIfError(err)
resp, err = common.RestyClient.R().SetBody(data).Post(xa.Server + "/commit")
common.PanicIfError(err)
if !strings.Contains(resp.String(), "SUCCESS") {
panic(fmt.Errorf("unexpected result: %s", resp.String()))
}
return nil
}
func getDb(conf map[string]string) *common.MyDb {
return common.DbGet(conf)
}