diff --git a/.gitignore b/.gitignore index 07fd594..d2e3296 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ -dtmsvr/dtmsvr.yml +*/**/*.yml *.out -main +*/**/main +main \ No newline at end of file diff --git a/app/main.go b/app/main.go index f4f92bf..95aef70 100644 --- a/app/main.go +++ b/app/main.go @@ -4,7 +4,6 @@ import ( "os" "time" - "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmsvr" "github.com/yedf/dtm/examples" ) @@ -12,12 +11,10 @@ import ( type M = map[string]interface{} func main() { - cmd := common.If(len(os.Args) > 1, os.Args[1], "").(string) - dtmsvr.LoadConfig() - if cmd == "" { // 所有服务都启动 + if len(os.Args) == 1 { // 所有服务都启动 go dtmsvr.StartSvr() go examples.SagaStartSvr() - } else if cmd == "dtmsvr" { + } else if len(os.Args) > 1 && os.Args[1] == "dtmsvr" { go dtmsvr.StartSvr() } for { diff --git a/common/types.go b/common/types.go new file mode 100644 index 0000000..cdd9539 --- /dev/null +++ b/common/types.go @@ -0,0 +1,108 @@ +package common + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/sirupsen/logrus" + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +type ModelBase struct { + ID uint + CreateTime *time.Time `gorm:"autoCreateTime"` + UpdateTime *time.Time `gorm:"autoUpdateTime"` +} + +var dbs = map[string]*MyDb{} + +type MyDb struct { + *gorm.DB +} + +func (m *MyDb) Must() *MyDb { + db := m.InstanceSet("ivy.must", true) + return &MyDb{DB: db} +} + +func (m *MyDb) NoMust() *MyDb { + db := m.InstanceSet("ivy.must", false) + return &MyDb{DB: db} +} + +type tracePlugin struct{} + +func (op *tracePlugin) Name() string { + return "tracePlugin" +} + +func (op *tracePlugin) Initialize(db *gorm.DB) (err error) { + before := func(db *gorm.DB) { + db.InstanceSet("ivy.startTime", time.Now()) + } + + after := func(db *gorm.DB) { + _ts, _ := db.InstanceGet("ivy.startTime") + sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...) + logrus.Printf("used: %d ms affected: %d sql is: %s", time.Since(_ts.(time.Time)).Milliseconds(), db.RowsAffected, sql) + if v, ok := db.InstanceGet("ivy.must"); ok && v.(bool) { + if db.Error != nil && db.Error != gorm.ErrRecordNotFound { + panic(db.Error) + } + } + } + + beforeName := "cb_before" + afterName := "cb_after" + + logrus.Printf("installing db plugin: %s", op.Name()) + // 开始前 + _ = db.Callback().Create().Before("gorm:before_create").Register(beforeName, before) + _ = db.Callback().Query().Before("gorm:query").Register(beforeName, before) + _ = db.Callback().Delete().Before("gorm:before_delete").Register(beforeName, before) + _ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(beforeName, before) + _ = db.Callback().Row().Before("gorm:row").Register(beforeName, before) + _ = db.Callback().Raw().Before("gorm:raw").Register(beforeName, before) + + // 结束后 + _ = db.Callback().Create().After("gorm:after_create").Register(afterName, after) + _ = db.Callback().Query().After("gorm:after_query").Register(afterName, after) + _ = db.Callback().Delete().After("gorm:after_delete").Register(afterName, after) + _ = db.Callback().Update().After("gorm:after_update").Register(afterName, after) + _ = db.Callback().Row().After("gorm:row").Register(afterName, after) + _ = db.Callback().Raw().After("gorm:raw").Register(afterName, after) + return +} + +func GetDsn(conf map[string]string) string { + return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]) +} + +func DbGet(conf map[string]string) *MyDb { + dsn := GetDsn(conf) + if dbs[dsn] == nil { + logrus.Printf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1)) + db1, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ + SkipDefaultTransaction: true, + }) + PanicIfError(err) + db1.Use(&tracePlugin{}) + dbs[dsn] = &MyDb{DB: db1} + } + return dbs[dsn] +} + +func DbAlone(conf map[string]string) (*MyDb, *sql.DB) { + logrus.Printf("opening alone mysql: %s", GetDsn(conf)) + mdb, err := sql.Open("mysql", GetDsn(conf)) + PanicIfError(err) + gormDB, err := gorm.Open(mysql.New(mysql.Config{ + Conn: mdb, + }), &gorm.Config{}) + PanicIfError(err) + gormDB.Use(&tracePlugin{}) + return &MyDb{DB: gormDB}, mdb +} diff --git a/common/utils.go b/common/utils.go index cd74637..03f035d 100644 --- a/common/utils.go +++ b/common/utils.go @@ -3,12 +3,19 @@ package common import ( "bytes" "encoding/json" + "fmt" "io/ioutil" + "path" + "path/filepath" + "runtime" + "strings" "time" "github.com/bwmarrin/snowflake" "github.com/gin-gonic/gin" + "github.com/go-resty/resty/v2" "github.com/sirupsen/logrus" + "github.com/spf13/viper" ) type M = map[string]interface{} @@ -22,6 +29,16 @@ func OrString(ss ...string) string { return "" } +func Panic2Error(perr *error) { + if x := recover(); x != nil { + if e, ok := x.(error); ok { + *perr = e + } else { + panic(x) + } + } +} + func GenGid() string { return gNode.Generate().Base58() } @@ -117,3 +134,69 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc { } } } + +// 辅助工具与代码 +var RestyClient = resty.New() + +func init() { + // RestyClient.SetTimeout(3 * time.Second) + // RestyClient.SetRetryCount(2) + // RestyClient.SetRetryWaitTime(1 * time.Second) + RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error { + logrus.Printf("requesting: %s %s %v", r.Method, r.URL, r.Body) + return nil + }) + RestyClient.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error { + r := resp.Request + logrus.Printf("requested: %s %s %s", r.Method, r.URL, resp.String()) + return nil + }) +} + +func CheckRestySuccess(resp *resty.Response, err error) { + PanicIfError(err) + if !strings.Contains(resp.String(), "SUCCESS") { + panic(fmt.Errorf("resty response not success: %s", resp.String())) + } +} + +// formatter 自定义formatter +type formatter struct{} + +// Format 进行格式化 +func (f *formatter) Format(entry *logrus.Entry) ([]byte, error) { + var b *bytes.Buffer = &bytes.Buffer{} + if entry.Buffer != nil { + b = entry.Buffer + } + n := time.Now() + ts := fmt.Sprintf("%d-%02d-%02d %02d:%02d:%02d.%03d", n.Year(), n.Month(), n.Day(), n.Hour(), n.Minute(), n.Second(), n.Nanosecond()/1000000) + var file string + var line int + for i := 1; ; i++ { + _, file, line, _ = runtime.Caller(i) + if strings.Contains(file, "dtm") { + break + } + } + b.WriteString(fmt.Sprintf("%s %s:%d %s\n", ts, path.Base(file), line, entry.Message)) + return b.Bytes(), nil +} + +var configLoaded = map[string]bool{} + +// 加载调用者文件相同目录下的配置文件 +func InitApp(config interface{}) { + logrus.SetFormatter(&formatter{}) + _, file, _, _ := runtime.Caller(1) + fileName := filepath.Dir(file) + "/conf.yml" + if configLoaded[fileName] { + return + } + configLoaded[fileName] = true + viper.SetConfigFile(fileName) + err := viper.ReadInConfig() + PanicIfError(err) + err = viper.Unmarshal(config) + PanicIfError(err) +} diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 87e8e2c..291282e 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -10,13 +10,14 @@ import ( func AddRoute(engine *gin.Engine) { engine.POST("/api/dtmsvr/prepare", common.WrapHandler(Prepare)) engine.POST("/api/dtmsvr/commit", common.WrapHandler(Commit)) + engine.POST("/api/dtmsvr/branch", common.WrapHandler(Branch)) } func Prepare(c *gin.Context) (interface{}, error) { - db := DbGet() - m := getSagaModelFromContext(c) + db := dbGet() + m := getTransFromContext(c) m.Status = "prepared" - writeTransLog(m.Gid, "save prepared", m.Status, -1, m.Steps) + writeTransLog(m.Gid, "save prepared", m.Status, "", m.Data) db.Must().Clauses(clause.OnConflict{ DoNothing: true, }).Create(&m) @@ -24,20 +25,33 @@ func Prepare(c *gin.Context) (interface{}, error) { } func Commit(c *gin.Context) (interface{}, error) { - m := getSagaModelFromContext(c) - saveCommitedSagaModel(m) - go ProcessCommitedSaga(m.Gid) + m := getTransFromContext(c) + saveCommitted(m) + go ProcessCommitted(m) return M{"message": "SUCCESS"}, nil } -func getSagaModelFromContext(c *gin.Context) *SagaModel { +func Branch(c *gin.Context) (interface{}, error) { + branch := TransBranchModel{} + err := c.BindJSON(&branch) + common.PanicIfError(err) + db := dbGet() + db.Must().Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(&branch) + return M{"message": "SUCCESS"}, nil +} + +func getTransFromContext(c *gin.Context) *TransGlobalModel { data := M{} b, err := c.GetRawData() common.PanicIfError(err) common.MustUnmarshal(b, &data) - logrus.Printf("creating saga model in prepare") - data["steps"] = common.MustMarshalString(data["steps"]) - m := SagaModel{} + logrus.Printf("creating trans model in prepare") + if data["trans_type"].(string) == "saga" { + data["data"] = common.MustMarshalString(data["steps"]) + } + m := TransGlobalModel{} common.MustRemarshal(data, &m) return &m } diff --git a/dtmsvr/dtmsvr.yml.sample b/dtmsvr/conf.sample similarity index 100% rename from dtmsvr/dtmsvr.yml.sample rename to dtmsvr/conf.sample diff --git a/dtmsvr/config.go b/dtmsvr/config.go index fad69fc..0420f7d 100644 --- a/dtmsvr/config.go +++ b/dtmsvr/config.go @@ -1,62 +1,10 @@ package dtmsvr -import ( - "bytes" - "fmt" - "path" - "path/filepath" - "runtime" - "strings" - "time" - - "github.com/sirupsen/logrus" - "github.com/spf13/viper" - "github.com/yedf/dtm/common" -) - -// formatter 自定义formatter -type formatter struct{} - -// Format 进行格式化 -func (f *formatter) Format(entry *logrus.Entry) ([]byte, error) { - var b *bytes.Buffer = &bytes.Buffer{} - if entry.Buffer != nil { - b = entry.Buffer - } - n := time.Now() - ts := fmt.Sprintf("%d-%02d-%02d %02d:%02d:%02d.%03d", n.Year(), n.Month(), n.Day(), n.Hour(), n.Minute(), n.Second(), n.Nanosecond()/1000000) - var file string - var line int - for i := 1; ; i++ { - _, file, line, _ = runtime.Caller(i) - if strings.Contains(file, "dtm") { - break - } - } - b.WriteString(fmt.Sprintf("%s %s:%d %s\n", ts, path.Base(file), line, entry.Message)) - return b.Bytes(), nil -} - type dtmsvrConfig struct { PreparedExpire int64 `json:"prepare_expire"` // 单位秒,当prepared的状态超过该时间,才能够转变成canceled,避免cancel了之后,才进入prepared + Mysql map[string]string } -var Config = &dtmsvrConfig{ +var config = &dtmsvrConfig{ PreparedExpire: 60, } - -var configLoaded = false - -func LoadConfig() { - if configLoaded { - return - } - configLoaded = true - logrus.SetFormatter(&formatter{}) - _, file, _, _ := runtime.Caller(0) - viper.SetConfigFile(filepath.Dir(file) + "/dtmsvr.yml") - err := viper.ReadInConfig() - common.PanicIfError(err) - err = viper.Unmarshal(&Config) - common.PanicIfError(err) -} diff --git a/dtmsvr/cron.go b/dtmsvr/cron.go index d6bf24c..80bb85c 100644 --- a/dtmsvr/cron.go +++ b/dtmsvr/cron.go @@ -6,33 +6,32 @@ import ( "time" "github.com/sirupsen/logrus" - "github.com/yedf/dtm" "github.com/yedf/dtm/common" ) func CronPreparedOnce(expire time.Duration) { - db := DbGet() - ss := []SagaModel{} - db.Must().Model(&SagaModel{}).Where("update_time < date_sub(now(), interval ? second)", int(expire/time.Second)).Where("status = ?", "prepared").Find(&ss) - writeTransLog("", "saga fetch prepared", fmt.Sprint(len(ss)), -1, "") + db := dbGet() + ss := []TransGlobalModel{} + db.Must().Model(&TransGlobalModel{}).Where("update_time < date_sub(now(), interval ? second)", int(expire/time.Second)).Where("status = ?", "prepared").Find(&ss) + writeTransLog("", "saga fetch prepared", fmt.Sprint(len(ss)), "", "") if len(ss) == 0 { return } for _, sm := range ss { - writeTransLog(sm.Gid, "saga touch prepared", "", -1, "") + writeTransLog(sm.Gid, "saga touch prepared", "", "", "") db.Must().Model(&sm).Update("id", sm.ID) - resp, err := dtm.RestyClient.R().SetQueryParam("gid", sm.Gid).Get(sm.TransQuery) + resp, err := common.RestyClient.R().SetQueryParam("gid", sm.Gid).Get(sm.QueryPrepared) common.PanicIfError(err) body := resp.String() if strings.Contains(body, "FAIL") { - preparedExpire := time.Now().Add(time.Duration(-Config.PreparedExpire) * time.Second) + preparedExpire := time.Now().Add(time.Duration(-config.PreparedExpire) * time.Second) logrus.Printf("create time: %s prepared expire: %s ", sm.CreateTime.Local(), preparedExpire.Local()) status := common.If(sm.CreateTime.Before(preparedExpire), "canceled", "prepared").(string) - writeTransLog(sm.Gid, "saga canceled", status, -1, "") + writeTransLog(sm.Gid, "saga canceled", status, "", "") db.Must().Model(&sm).Where("status = ?", "prepared").Update("status", status) } else if strings.Contains(body, "SUCCESS") { - saveCommitedSagaModel(&sm) - ProcessCommitedSaga(sm.Gid) + saveCommitted(&sm) + ProcessCommitted(&sm) } } } @@ -44,25 +43,25 @@ func CronPrepared() { } } -func CronCommitedOnce(expire time.Duration) { - db := DbGet() - ss := []SagaModel{} - db.Must().Model(&SagaModel{}).Where("update_time < date_sub(now(), interval ? second)", int(expire/time.Second)).Where("status = ?", "commited").Find(&ss) - writeTransLog("", "saga fetch commited", fmt.Sprint(len(ss)), -1, "") +func CronCommittedOnce(expire time.Duration) { + db := dbGet() + ss := []TransGlobalModel{} + db.Must().Model(&TransGlobalModel{}).Where("update_time < date_sub(now(), interval ? second)", int(expire/time.Second)).Where("status = ?", "committed").Find(&ss) + writeTransLog("", "saga fetch committed", fmt.Sprint(len(ss)), "", "") if len(ss) == 0 { return } for _, sm := range ss { - writeTransLog(sm.Gid, "saga touch commited", "", -1, "") + writeTransLog(sm.Gid, "saga touch committed", "", "", "") db.Must().Model(&sm).Update("id", sm.ID) - ProcessCommitedSaga(sm.Gid) + ProcessCommitted(&sm) } } -func CronCommited() { +func CronCommitted() { for { defer handlePanic() - CronCommitedOnce(10 * time.Second) + CronCommittedOnce(10 * time.Second) } } diff --git a/dtmsvr/db.go b/dtmsvr/db.go index 8fe9276..2820c9a 100644 --- a/dtmsvr/db.go +++ b/dtmsvr/db.go @@ -1,103 +1,20 @@ package dtmsvr -import ( - "fmt" - "strings" - "time" +import "github.com/yedf/dtm/common" - "github.com/sirupsen/logrus" - "github.com/spf13/viper" - "github.com/yedf/dtm/common" - "gorm.io/driver/mysql" - "gorm.io/gorm" -) - -var db *gorm.DB = nil - -type MyDb struct { - *gorm.DB +func dbGet() *common.MyDb { + return common.DbGet(config.Mysql) } - -func (m *MyDb) Must() *MyDb { - db := m.InstanceSet("ivy.must", true) - return &MyDb{DB: db} -} - -func (m *MyDb) NoMust() *MyDb { - db := m.InstanceSet("ivy.must", false) - return &MyDb{DB: db} -} - -func DbGet() *MyDb { - LoadConfig() - if db == nil { - conf := viper.GetStringMapString("mysql") - dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]) - logrus.Printf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1)) - db1, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ - SkipDefaultTransaction: true, - }) - common.PanicIfError(err) - db1.Use(&tracePlugin{}) - db = db1 - } - return &MyDb{DB: db} -} - -func writeTransLog(gid string, action string, status string, step int, detail string) { - db := DbGet() +func writeTransLog(gid string, action string, status string, branch string, detail string) { + db := dbGet() if detail == "" { detail = "{}" } db.Must().Table("trans_log").Create(M{ - "gid": gid, - "action": action, - "status": status, - "step": step, - "detail": detail, + "gid": gid, + "action": action, + "new_status": status, + "branch": branch, + "detail": detail, }) } - -type tracePlugin struct{} - -func (op *tracePlugin) Name() string { - return "tracePlugin" -} - -func (op *tracePlugin) Initialize(db *gorm.DB) (err error) { - before := func(db *gorm.DB) { - db.InstanceSet("ivy.startTime", time.Now()) - } - - after := func(db *gorm.DB) { - _ts, _ := db.InstanceGet("ivy.startTime") - sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...) - logrus.Printf("used: %d ms affected: %d sql is: %s", time.Since(_ts.(time.Time)).Milliseconds(), db.RowsAffected, sql) - if v, ok := db.InstanceGet("ivy.must"); ok && v.(bool) { - if db.Error != nil && db.Error != gorm.ErrRecordNotFound { - panic(db.Error) - } - } - } - - beforeName := "cb_before" - afterName := "cb_after" - - logrus.Printf("installing db plugin: %s", op.Name()) - // 开始前 - _ = db.Callback().Create().Before("gorm:before_create").Register(beforeName, before) - _ = db.Callback().Query().Before("gorm:query").Register(beforeName, before) - _ = db.Callback().Delete().Before("gorm:before_delete").Register(beforeName, before) - _ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(beforeName, before) - _ = db.Callback().Row().Before("gorm:row").Register(beforeName, before) - _ = db.Callback().Raw().Before("gorm:raw").Register(beforeName, before) - - // 结束后 - _ = db.Callback().Create().After("gorm:after_create").Register(afterName, after) - _ = db.Callback().Query().After("gorm:after_query").Register(afterName, after) - _ = db.Callback().Delete().After("gorm:after_delete").Register(afterName, after) - _ = db.Callback().Update().After("gorm:after_update").Register(afterName, after) - _ = db.Callback().Row().After("gorm:row").Register(afterName, after) - _ = db.Callback().Raw().After("gorm:raw").Register(afterName, after) - return -} diff --git a/dtmsvr/dtmsvr.sql b/dtmsvr/dtmsvr.sql index ff0997b..e448100 100644 --- a/dtmsvr/dtmsvr.sql +++ b/dtmsvr/dtmsvr.sql @@ -27,7 +27,7 @@ CREATE TABLE `saga_step` ( `step` int(11) NOT NULL COMMENT '处于saga中的第几步', `url` varchar(128) NOT NULL COMMENT '动作关联的url', `type` varchar(45) NOT NULL COMMENT 'saga的所有步骤', - `status` varchar(45) NOT NULL COMMENT '步骤的状态 pending | finished | rollbacked', + `status` varchar(45) NOT NULL COMMENT '步骤的状态 prepared | finished | rollbacked', `finish_time` datetime DEFAULT NULL, `rollback_time` datetime DEFAULT NULL, `create_time` datetime DEFAULT NULL, @@ -52,15 +52,3 @@ CREATE TABLE `trans_log` ( KEY `create_time` (`create_time`) ) ENGINE=InnoDB AUTO_INCREMENT=48 DEFAULT CHARSET=utf8mb4; -drop table if EXISTS user_account; -CREATE TABLE `user_account` ( - `id` int(11) NOT NULL AUTO_INCREMENT, - `user_id` int(11) DEFAULT NULL, - `balance` decimal(10,2) NOT NULL DEFAULT '0.00', - `create_time` datetime DEFAULT CURRENT_TIMESTAMP, - `update_time` datetime DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY (`id`), - UNIQUE KEY `user_id` (`user_id`), - KEY `create_time` (`create_time`), - KEY `update_time` (`update_time`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index 0423e0c..856f7e0 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -13,56 +13,60 @@ import ( ) var myinit int = func() int { - LoadConfig() + common.InitApp(&config) return 0 }() func TestViper(t *testing.T) { assert.Equal(t, true, viper.Get("mysql") != nil) - assert.Equal(t, int64(90), Config.PreparedExpire) + assert.Equal(t, int64(90), config.PreparedExpire) } func TestDtmSvr(t *testing.T) { - SagaProcessedTestChan = make(chan string, 1) - // 清理数据 - common.PanicIfError(db.Exec("truncate saga").Error) - common.PanicIfError(db.Exec("truncate saga_step").Error) - common.PanicIfError(db.Exec("truncate trans_log").Error) + TransProcessedTestChan = make(chan string, 1) // 启动组件 go StartSvr() go examples.SagaStartSvr() + go examples.XaStartSvr() time.Sleep(time.Duration(100 * 1000 * 1000)) - preparePending(t) - prepareCancel(t) - commitedPending(t) - noramlSaga(t) - rollbackSaga2(t) + // 清理数据 + common.PanicIfError(dbGet().Exec("truncate trans_global").Error) + common.PanicIfError(dbGet().Exec("truncate trans_branch").Error) + common.PanicIfError(dbGet().Exec("truncate trans_log").Error) + examples.ResetXaData() + + xaNormal(t) + sagaPreparePending(t) + sagaPrepareCancel(t) + sagaCommittedPending(t) + sagaNormal(t) + sagaRollback(t) } func TestCover(t *testing.T) { - db := DbGet() + db := dbGet() db.NoMust() CronPreparedOnce(0) - CronCommitedOnce(0) + CronCommittedOnce(0) defer handlePanic() checkAffected(db.DB) } // 测试使用的全局对象 -var initdb = DbGet() +var initdb = dbGet() -func getSagaModel(gid string) *SagaModel { - sm := SagaModel{} - dbr := db.Model(&sm).Where("gid=?", gid).First(&sm) +func getSagaModel(gid string) *TransGlobalModel { + sm := TransGlobalModel{} + dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm) common.PanicIfError(dbr.Error) return &sm } -func getSagaStepStatus(gid string) []string { - steps := []SagaStepModel{} - dbr := db.Model(&SagaStepModel{}).Where("gid=?", gid).Find(&steps) +func getBranchesStatus(gid string) []string { + steps := []TransBranchModel{} + dbr := dbGet().Model(&TransBranchModel{}).Where("gid=?", gid).Find(&steps) common.PanicIfError(dbr.Error) status := []string{} for _, step := range steps { @@ -71,37 +75,81 @@ func getSagaStepStatus(gid string) []string { return status } -func noramlSaga(t *testing.T) { +func xaNormal(t *testing.T) { + xa := examples.XaClient + gid := "xa-normal" + err := xa.XaGlobalTransaction(gid, func() error { + req := examples.GenTransReq(30, false, false) + resp, err := common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{ + "gid": gid, + "user_id": "1", + }).Post(examples.XaBusi + "/TransOut") + common.CheckRestySuccess(resp, err) + resp, err = common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{ + "gid": gid, + "user_id": "2", + }).Post(examples.XaBusi + "/TransIn") + common.CheckRestySuccess(resp, err) + return nil + }) + common.PanicIfError(err) + WaitTransCommitted(gid) + assert.Equal(t, []string{"finished", "finished"}, getBranchesStatus(gid)) +} + +func xaRollback(t *testing.T) { + xa := examples.XaClient + gid := "xa-rollback" + err := xa.XaGlobalTransaction(gid, func() error { + req := examples.GenTransReq(30, false, true) + resp, err := common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{ + "gid": gid, + "user_id": "1", + }).Post(examples.XaBusi + "/TransOut") + common.CheckRestySuccess(resp, err) + resp, err = common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{ + "gid": gid, + "user_id": "2", + }).Post(examples.XaBusi + "/TransIn") + common.CheckRestySuccess(resp, err) + return nil + }) + common.PanicIfError(err) + WaitTransCommitted(gid) + assert.Equal(t, []string{"rollbacked", "rollbacked"}, getBranchesStatus(gid)) +} + +func sagaNormal(t *testing.T) { saga := genSaga("gid-noramlSaga", false, false) saga.Prepare() assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status) saga.Commit() - assert.Equal(t, "commited", getSagaModel(saga.Gid).Status) - WaitCommitedSaga(saga.Gid) - assert.Equal(t, []string{"pending", "finished", "pending", "finished"}, getSagaStepStatus(saga.Gid)) + assert.Equal(t, "committed", getSagaModel(saga.Gid).Status) + WaitTransCommitted(saga.Gid) + assert.Equal(t, []string{"prepared", "finished", "prepared", "finished"}, getBranchesStatus(saga.Gid)) } -func rollbackSaga2(t *testing.T) { +func sagaRollback(t *testing.T) { saga := genSaga("gid-rollbackSaga2", false, true) saga.Commit() - WaitCommitedSaga(saga.Gid) + WaitTransCommitted(saga.Gid) saga.Prepare() assert.Equal(t, "rollbacked", getSagaModel(saga.Gid).Status) - assert.Equal(t, []string{"rollbacked", "finished", "rollbacked", "rollbacked"}, getSagaStepStatus(saga.Gid)) + assert.Equal(t, []string{"rollbacked", "finished", "rollbacked", "rollbacked"}, getBranchesStatus(saga.Gid)) } -func prepareCancel(t *testing.T) { +func sagaPrepareCancel(t *testing.T) { saga := genSaga("gid1-prepareCancel", false, true) saga.Prepare() examples.TransQueryResult = "FAIL" - Config.PreparedExpire = -10 + config.PreparedExpire = -10 CronPreparedOnce(-10 * time.Second) examples.TransQueryResult = "" - Config.PreparedExpire = 60 + config.PreparedExpire = 60 assert.Equal(t, "canceled", getSagaModel(saga.Gid).Status) } -func preparePending(t *testing.T) { +func sagaPreparePending(t *testing.T) { saga := genSaga("gid1-preparePending", false, false) saga.Prepare() examples.TransQueryResult = "PENDING" @@ -109,33 +157,29 @@ func preparePending(t *testing.T) { examples.TransQueryResult = "" assert.Equal(t, "prepared", getSagaModel(saga.Gid).Status) CronPreparedOnce(-10 * time.Second) - WaitCommitedSaga(saga.Gid) + WaitTransCommitted(saga.Gid) assert.Equal(t, "finished", getSagaModel(saga.Gid).Status) } -func commitedPending(t *testing.T) { - saga := genSaga("gid-commitedPending", false, false) +func sagaCommittedPending(t *testing.T) { + saga := genSaga("gid-committedPending", false, false) saga.Prepare() - examples.TransOutResult = "PENDING" + examples.TransInResult = "PENDING" saga.Commit() - WaitCommitedSaga(saga.Gid) - examples.TransOutResult = "" - assert.Equal(t, []string{"pending", "finished", "pending", "pending"}, getSagaStepStatus(saga.Gid)) - CronCommitedOnce(-10 * time.Second) - WaitCommitedSaga(saga.Gid) - assert.Equal(t, []string{"pending", "finished", "pending", "finished"}, getSagaStepStatus(saga.Gid)) + WaitTransCommitted(saga.Gid) + examples.TransInResult = "" + assert.Equal(t, []string{"prepared", "finished", "prepared", "prepared"}, getBranchesStatus(saga.Gid)) + CronCommittedOnce(-10 * time.Second) + WaitTransCommitted(saga.Gid) + assert.Equal(t, []string{"prepared", "finished", "prepared", "finished"}, getBranchesStatus(saga.Gid)) assert.Equal(t, "finished", getSagaModel(saga.Gid).Status) } -func genSaga(gid string, inFailed bool, outFailed bool) *dtm.Saga { +func genSaga(gid string, outFailed bool, inFailed bool) *dtm.Saga { logrus.Printf("beginning a saga test ---------------- %s", gid) saga := dtm.SagaNew(examples.DtmServer, gid, examples.SagaBusi+"/TransQuery") - req := examples.TransReq{ - Amount: 30, - TransInResult: common.If(inFailed, "FAIL", "SUCCESS").(string), - TransOutResult: common.If(outFailed, "FAIL", "SUCCESS").(string), - } - saga.Add(examples.SagaBusi+"/TransIn", examples.SagaBusi+"/TransInCompensate", &req) + req := examples.GenTransReq(30, outFailed, inFailed) saga.Add(examples.SagaBusi+"/TransOut", examples.SagaBusi+"/TransOutCompensate", &req) + saga.Add(examples.SagaBusi+"/TransIn", examples.SagaBusi+"/TransInCompensate", &req) return saga } diff --git a/dtmsvr/main.go b/dtmsvr/main.go index b99e3dc..9118d3a 100644 --- a/dtmsvr/main.go +++ b/dtmsvr/main.go @@ -11,6 +11,7 @@ func Main() { func StartSvr() { logrus.Printf("start dtmsvr") + common.InitApp(&config) app := common.GetGinApp() AddRoute(app) logrus.Printf("dtmsvr listen at: 8080") diff --git a/dtmsvr/service.go b/dtmsvr/service.go index 186a6c6..8fda0d0 100644 --- a/dtmsvr/service.go +++ b/dtmsvr/service.go @@ -6,107 +6,145 @@ import ( "time" "github.com/sirupsen/logrus" - "github.com/yedf/dtm" "github.com/yedf/dtm/common" "gorm.io/gorm" "gorm.io/gorm/clause" ) -func saveCommitedSagaModel(m *SagaModel) { - db := DbGet() - m.Status = "commited" +func saveCommitted(m *TransGlobalModel) { + db := dbGet() + m.Status = "committed" err := db.Transaction(func(db1 *gorm.DB) error { - db := &MyDb{DB: db1} - writeTransLog(m.Gid, "save commited", m.Status, -1, m.Steps) + db := &common.MyDb{DB: db1} + writeTransLog(m.Gid, "save committed", m.Status, "", m.Data) dbr := db.Must().Clauses(clause.OnConflict{ DoNothing: true, }).Create(&m) if dbr.RowsAffected == 0 { - writeTransLog(m.Gid, "change status", m.Status, -1, "") - db.Must().Model(&m).Where("status=?", "prepared").Update("status", "commited") + writeTransLog(m.Gid, "change status", m.Status, "", "") + db.Must().Model(&m).Where("status=?", "prepared").Update("status", "committed") } - nsteps := []SagaStepModel{} - steps := []M{} - common.MustUnmarshalString(m.Steps, &steps) - for _, step := range steps { - nsteps = append(nsteps, SagaStepModel{ - Gid: m.Gid, - Step: len(nsteps) + 1, - Data: step["post_data"].(string), - Url: step["compensate"].(string), - Type: "compensate", - Status: "pending", - }) - nsteps = append(nsteps, SagaStepModel{ - Gid: m.Gid, - Step: len(nsteps) + 1, - Data: step["post_data"].(string), - Url: step["action"].(string), - Type: "action", - Status: "pending", - }) + if m.TransType == "saga" { + nsteps := []TransBranchModel{} + steps := []M{} + common.MustUnmarshalString(m.Data, &steps) + for _, step := range steps { + nsteps = append(nsteps, TransBranchModel{ + Gid: m.Gid, + Branch: fmt.Sprintf("%d", len(nsteps)+1), + Data: step["data"].(string), + Url: step["compensate"].(string), + BranchType: "compensate", + Status: "prepared", + }) + nsteps = append(nsteps, TransBranchModel{ + Gid: m.Gid, + Branch: fmt.Sprintf("%d", len(nsteps)+1), + Data: step["data"].(string), + Url: step["action"].(string), + BranchType: "action", + Status: "prepared", + }) + } + writeTransLog(m.Gid, "save steps", m.Status, "", common.MustMarshalString(nsteps)) + db.Must().Clauses(clause.OnConflict{ + DoNothing: true, + }).Create(&nsteps) } - writeTransLog(m.Gid, "save steps", m.Status, -1, common.MustMarshalString(nsteps)) - db.Must().Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(&nsteps) return nil }) common.PanicIfError(err) } -var SagaProcessedTestChan chan string = nil // 用于测试时,通知处理结束 +var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束 -func WaitCommitedSaga(gid string) { - id := <-SagaProcessedTestChan +func WaitTransCommitted(gid string) { + id := <-TransProcessedTestChan for id != gid { logrus.Errorf("-------id %s not match gid %s", id, gid) - id = <-SagaProcessedTestChan + id = <-TransProcessedTestChan } } -func ProcessCommitedSaga(gid string) { - err := innerProcessCommitedSaga(gid) +func ProcessCommitted(trans *TransGlobalModel) { + err := innerProcessCommitted(trans) if err != nil { - logrus.Errorf("process commited saga error: %s", err.Error()) + logrus.Errorf("process committed error: %s", err.Error()) } - if SagaProcessedTestChan != nil { - SagaProcessedTestChan <- gid + if TransProcessedTestChan != nil { + TransProcessedTestChan <- trans.Gid } } -func checkAffected(db1 *gorm.DB) { - if db1.RowsAffected == 0 { - panic(fmt.Errorf("duplicate updating")) +func innerProcessCommitted(trans *TransGlobalModel) (rerr error) { + branches := []TransBranchModel{} + db := dbGet() + db.Must().Order("id asc").Find(&branches) + if trans.TransType == "saga" { + return innerProcessCommittedSaga(trans, db, branches) + } else if trans.TransType == "xa" { + return innerProcessCommittedXa(trans, db, branches) } + panic(fmt.Errorf("unkown trans type: %s", trans.TransType)) } -func innerProcessCommitedSaga(gid string) (rerr error) { - steps := []SagaStepModel{} - db := DbGet() - db.Must().Order("id asc").Find(&steps) - current := 0 // 当前正在处理的步骤 - for ; current < len(steps); current++ { - step := steps[current] - if step.Type == "compensate" && step.Status == "pending" || step.Type == "action" && step.Status == "finished" { +func innerProcessCommittedXa(trans *TransGlobalModel, db *common.MyDb, branches []TransBranchModel) error { + gid := trans.Gid + for _, branch := range branches { + if branch.Status == "finished" { continue } - if step.Type == "action" && step.Status == "pending" { - resp, err := dtm.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url) + db.Must().Model(&TransGlobalModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time,避免被定时任务再次 + resp, err := common.RestyClient.R().SetBody(M{ + "branch": branch.Branch, + "action": "commit", + "gid": branch.Gid, + }).Post(branch.Url) + if err != nil { + return err + } + body := resp.String() + if !strings.Contains(body, "SUCCESS") { + return fmt.Errorf("bad response: %s", body) + } + writeTransLog(gid, "step finished", "finished", branch.Branch, "") + db.Must().Model(&branch).Where("status=?", "prepared").Updates(M{ + "status": "finished", + "finish_time": time.Now(), + }) + } + writeTransLog(gid, "xa finished", "finished", "", "") + db.Must().Model(&TransGlobalModel{}).Where("gid=? and status=?", gid, "committed").Updates(M{ + "status": "finished", + "finish_time": time.Now(), + }) + return nil +} + +func innerProcessCommittedSaga(trans *TransGlobalModel, db *common.MyDb, branches []TransBranchModel) error { + gid := trans.Gid + current := 0 // 当前正在处理的步骤 + for ; current < len(branches); current++ { + step := branches[current] + if step.BranchType == "compensate" && step.Status == "prepared" || step.BranchType == "action" && step.Status == "finished" { + continue + } + if step.BranchType == "action" && step.Status == "prepared" { + resp, err := common.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url) if err != nil { return err } body := resp.String() - db.Must().Model(&SagaModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time,避免被定时任务再次 + db.Must().Model(&TransGlobalModel{}).Where("gid=?", gid).Update("gid", gid) // 更新update_time,避免被定时任务再次 if strings.Contains(body, "SUCCESS") { - writeTransLog(gid, "step finished", "finished", step.Step, "") - dbr := db.Must().Model(&step).Where("status=?", "pending").Updates(M{ + writeTransLog(gid, "step finished", "finished", step.Branch, "") + dbr := db.Must().Model(&step).Where("status=?", "prepared").Updates(M{ "status": "finished", "finish_time": time.Now(), }) checkAffected(dbr) } else if strings.Contains(body, "FAIL") { - writeTransLog(gid, "step rollbacked", "rollbacked", step.Step, "") - dbr := db.Must().Model(&step).Where("status=?", "pending").Updates(M{ + writeTransLog(gid, "step rollbacked", "rollbacked", step.Branch, "") + dbr := db.Must().Model(&step).Where("status=?", "prepared").Updates(M{ "status": "rollbacked", "rollback_time": time.Now(), }) @@ -117,9 +155,9 @@ func innerProcessCommitedSaga(gid string) (rerr error) { } } } - if current == len(steps) { // saga 事务完成 - writeTransLog(gid, "saga finished", "finished", -1, "") - dbr := db.Must().Model(&SagaModel{}).Where("gid=? and status=?", gid, "commited").Updates(M{ + if current == len(branches) { // saga 事务完成 + writeTransLog(gid, "saga finished", "finished", "", "") + dbr := db.Must().Model(&TransGlobalModel{}).Where("gid=? and status=?", gid, "committed").Updates(M{ "status": "finished", "finish_time": time.Now(), }) @@ -127,17 +165,17 @@ func innerProcessCommitedSaga(gid string) (rerr error) { return nil } for current = current - 1; current >= 0; current-- { - step := steps[current] - if step.Type != "compensate" || step.Status != "pending" { + step := branches[current] + if step.BranchType != "compensate" || step.Status != "prepared" { continue } - resp, err := dtm.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url) + resp, err := common.RestyClient.R().SetBody(step.Data).SetQueryParam("gid", step.Gid).Post(step.Url) if err != nil { return err } body := resp.String() if strings.Contains(body, "SUCCESS") { - writeTransLog(gid, "step rollbacked", "rollbacked", step.Step, "") + writeTransLog(gid, "step rollbacked", "rollbacked", step.Branch, "") dbr := db.Must().Model(&step).Where("status=?", step.Status).Updates(M{ "status": "rollbacked", "rollback_time": time.Now(), @@ -150,11 +188,17 @@ func innerProcessCommitedSaga(gid string) (rerr error) { if current != -1 { return fmt.Errorf("saga current not -1") } - writeTransLog(gid, "saga rollbacked", "rollbacked", -1, "") - dbr := db.Must().Model(&SagaModel{}).Where("status=? and gid=?", "commited", gid).Updates(M{ + writeTransLog(gid, "saga rollbacked", "rollbacked", "", "") + dbr := db.Must().Model(&TransGlobalModel{}).Where("status=? and gid=?", "committed", gid).Updates(M{ "status": "rollbacked", "rollback_time": time.Now(), }) checkAffected(dbr) return nil } + +func checkAffected(db1 *gorm.DB) { + if db1.RowsAffected == 0 { + panic(fmt.Errorf("duplicate updating")) + } +} diff --git a/dtmsvr/types.go b/dtmsvr/types.go index 2a13f28..7dbf9ad 100644 --- a/dtmsvr/types.go +++ b/dtmsvr/types.go @@ -2,41 +2,40 @@ package dtmsvr import ( "time" + + "github.com/yedf/dtm/common" ) type M = map[string]interface{} -type ModelBase struct { - ID uint - CreateTime *time.Time `gorm:"autoCreateTime"` - UpdateTime *time.Time `gorm:"autoUpdateTime"` -} -type SagaModel struct { - ModelBase - Gid string `json:"gid"` - Steps string `json:"steps"` - TransQuery string `json:"trans_query"` - Status string `json:"status"` - FinishTime *time.Time - RollbackTime *time.Time +type TransGlobalModel struct { + common.ModelBase + Gid string `json:"gid"` + TransType string `json:"trans_type"` + Data string `json:"data"` + Status string `json:"status"` + QueryPrepared string `json:"query_prepared"` + CommitTime *time.Time + FinishTime *time.Time + RollbackTime *time.Time } -func (*SagaModel) TableName() string { - return "saga" +func (*TransGlobalModel) TableName() string { + return "trans_global" } -type SagaStepModel struct { - ModelBase +type TransBranchModel struct { + common.ModelBase Gid string - Data string - Step int Url string - Type string + Data string + Branch string + BranchType string Status string FinishTime *time.Time RollbackTime *time.Time } -func (*SagaStepModel) TableName() string { - return "saga_step" +func (*TransBranchModel) TableName() string { + return "trans_branch" } diff --git a/examples/conf.yml.sample b/examples/conf.yml.sample new file mode 100644 index 0000000..e69de29 diff --git a/examples/config.go b/examples/config.go index 502fb27..1f787be 100644 --- a/examples/config.go +++ b/examples/config.go @@ -1,4 +1,7 @@ package examples -// 指定dtm服务地址 -const DtmServer = "http://localhost:8080/api/dtmsvr" +type exampleConfig struct { + Mysql map[string]string +} + +var Config = exampleConfig{} diff --git a/examples/examples.sql b/examples/examples.sql new file mode 100644 index 0000000..b5e2d9d --- /dev/null +++ b/examples/examples.sql @@ -0,0 +1,13 @@ +use dtm_busi; +drop table if exists user_account; +create table user_account( + id int(11) PRIMARY KEY AUTO_INCREMENT, + user_id int(11) UNIQUE , + balance DECIMAL(10, 2) not null default '0', + create_time datetime DEFAULT now(), + update_time datetime DEFAULT now(), + key(create_time), + key(update_time) +); + +insert into user_account (user_id, balance) values (1, 10000), (2, 10000); \ No newline at end of file diff --git a/examples/saga_main.go b/examples/saga_main.go index fa18441..b4991cc 100644 --- a/examples/saga_main.go +++ b/examples/saga_main.go @@ -39,8 +39,8 @@ func sagaFireRequest() { } saga := dtm.SagaNew(DtmServer, gid, SagaBusi+"/TransQuery") - saga.Add(SagaBusi+"/TransIn", SagaBusi+"/TransInCompensate", req) saga.Add(SagaBusi+"/TransOut", SagaBusi+"/TransOutCompensate", req) + saga.Add(SagaBusi+"/TransIn", SagaBusi+"/TransInCompensate", req) err := saga.Prepare() common.PanicIfError(err) logrus.Printf("busi trans commit") @@ -67,12 +67,6 @@ var TransInCompensateResult = "" var TransOutCompensateResult = "" var TransQueryResult = "" -type TransReq struct { - Amount int `json:"amount"` - TransInResult string `json:"transInResult"` - TransOutResult string `json:"transOutResult"` -} - func transReqFromContext(c *gin.Context) *TransReq { req := TransReq{} err := c.BindJSON(&req) diff --git a/examples/types.go b/examples/types.go new file mode 100644 index 0000000..49d6c0b --- /dev/null +++ b/examples/types.go @@ -0,0 +1,11 @@ +package examples + +import "github.com/yedf/dtm/common" + +type UserAccount struct { + common.ModelBase + UserId int + Balance string +} + +func (u *UserAccount) TableName() string { return "user_account" } diff --git a/examples/utils.go b/examples/utils.go new file mode 100644 index 0000000..4efc547 --- /dev/null +++ b/examples/utils.go @@ -0,0 +1,20 @@ +package examples + +import "github.com/yedf/dtm/common" + +// 指定dtm服务地址 +const DtmServer = "http://localhost:8080/api/dtmsvr" + +type TransReq struct { + Amount int `json:"amount"` + TransInResult string `json:"transInResult"` + TransOutResult string `json:"transOutResult"` +} + +func GenTransReq(amount int, outFailed bool, inFailed bool) *TransReq { + return &TransReq{ + Amount: amount, + TransOutResult: common.If(outFailed, "FAIL", "SUCCESS").(string), + TransInResult: common.If(inFailed, "FAIL", "SUCCESS").(string), + } +} diff --git a/examples/xa_main.go b/examples/xa_main.go index b223e51..4e5b2e5 100644 --- a/examples/xa_main.go +++ b/examples/xa_main.go @@ -4,8 +4,11 @@ import ( "fmt" "time" + "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" + "github.com/yedf/dtm" "github.com/yedf/dtm/common" + "gorm.io/gorm" ) // 事务参与者的服务地址 @@ -14,21 +17,84 @@ const XaBusiApi = "/api/busi_xa" var XaBusi = fmt.Sprintf("http://localhost:%d%s", XaBusiPort, XaBusiApi) +var XaClient *dtm.XaClient = nil + func XaMain() { go XaStartSvr() - xaFireRequest() + time.Sleep(100 * time.Millisecond) + XaFireRequest() time.Sleep(1000 * time.Second) } func XaStartSvr() { + common.InitApp(&Config) logrus.Printf("xa examples starting") app := common.GetGinApp() - AddRoute(app) - app.Run(":8081") + XaClient = dtm.XaClientNew(DtmServer, Config.Mysql, app, XaBusi+"/xa") + XaAddRoute(app) + app.Run(fmt.Sprintf(":%d", XaBusiPort)) } -func xaFireRequest() { - +func XaFireRequest() { + gid := common.GenGid() + err := XaClient.XaGlobalTransaction(gid, func() (rerr error) { + defer common.Panic2Error(&rerr) + req := GenTransReq(30, false, false) + resp, err := common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{ + "gid": gid, + "user_id": "1", + }).Post(XaBusi + "/TransOut") + common.CheckRestySuccess(resp, err) + resp, err = common.RestyClient.R().SetBody(req).SetQueryParams(map[string]string{ + "gid": gid, + "user_id": "2", + }).Post(XaBusi + "/TransOut") + common.CheckRestySuccess(resp, err) + return nil + }) + common.PanicIfError(err) } // api +func XaAddRoute(app *gin.Engine) { + app.POST(XaBusiApi+"/TransIn", common.WrapHandler(XaTransIn)) + app.POST(XaBusiApi+"/TransOut", common.WrapHandler(XaTransOut)) +} + +func XaTransIn(c *gin.Context) (interface{}, error) { + err := XaClient.XaLocalTransaction(c.Query("gid"), func(db *common.MyDb) (rerr error) { + dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). + Update("balance", gorm.Expr("balance - ?", transReqFromContext(c).Amount)) + return dbr.Error + }) + common.PanicIfError(err) + return M{"result": "SUCCESS"}, nil +} + +func XaTransOut(c *gin.Context) (interface{}, error) { + err := XaClient.XaLocalTransaction(c.Query("gid"), func(db *common.MyDb) (rerr error) { + dbr := db.Model(&UserAccount{}).Where("user_id = ?", c.Query("user_id")). + Update("balance", gorm.Expr("balance + ?", transReqFromContext(c).Amount)) + return dbr.Error + }) + common.PanicIfError(err) + return M{"result": "SUCCESS"}, nil +} + +func ResetXaData() { + db := dbGet() + db.Must().Exec("truncate user_account") + db.Must().Exec("insert into user_account (user_id, balance) values (1, 10000), (2, 10000)") + type XaRow struct { + Data string + } + xas := []XaRow{} + db.Must().Raw("xa recover").Scan(&xas) + for _, xa := range xas { + db.Must().Exec(fmt.Sprintf("xa rollback '%s'", xa.Data)) + } +} + +func dbGet() *common.MyDb { + return common.DbGet(Config.Mysql) +} diff --git a/saga.go b/saga.go index ca989df..4b41fb0 100644 --- a/saga.go +++ b/saga.go @@ -3,10 +3,9 @@ package dtm import ( "encoding/json" "fmt" - "time" - "github.com/go-resty/resty/v2" "github.com/sirupsen/logrus" + "github.com/yedf/dtm/common" ) type Saga struct { @@ -15,21 +14,23 @@ type Saga struct { } type SagaData struct { - Gid string `json:"gid"` - Steps []SagaStep `json:"steps"` - TransQuery string `json:"trans_query"` + Gid string `json:"gid"` + TransType string `json:"trans_type"` + Steps []SagaStep `json:"steps"` + QueryPrepared string `json:"query_prepared"` } type SagaStep struct { Action string `json:"action"` Compensate string `json:"compensate"` - PostData string `json:"post_data"` + Data string `json:"data"` } -func SagaNew(server string, gid string, transQuery string) *Saga { +func SagaNew(server string, gid string, queryPrepared string) *Saga { return &Saga{ SagaData: SagaData{ - Gid: gid, - TransQuery: transQuery, + Gid: gid, + TransType: "saga", + QueryPrepared: queryPrepared, }, Server: server, } @@ -43,7 +44,7 @@ func (s *Saga) Add(action string, compensate string, postData interface{}) error step := SagaStep{ Action: action, Compensate: compensate, - PostData: string(d), + Data: string(d), } s.Steps = append(s.Steps, step) return nil @@ -51,7 +52,7 @@ func (s *Saga) Add(action string, compensate string, postData interface{}) error func (s *Saga) Prepare() error { logrus.Printf("preparing %s body: %v", s.Gid, &s.SagaData) - resp, err := RestyClient.R().SetBody(&s.SagaData).Post(fmt.Sprintf("%s/prepare", s.Server)) + resp, err := common.RestyClient.R().SetBody(&s.SagaData).Post(fmt.Sprintf("%s/prepare", s.Server)) if err != nil { return err } @@ -63,7 +64,7 @@ func (s *Saga) Prepare() error { func (s *Saga) Commit() error { logrus.Printf("committing %s body: %v", s.Gid, &s.SagaData) - resp, err := RestyClient.R().SetBody(&s.SagaData).Post(fmt.Sprintf("%s/commit", s.Server)) + resp, err := common.RestyClient.R().SetBody(&s.SagaData).Post(fmt.Sprintf("%s/commit", s.Server)) if err != nil { return err } @@ -72,21 +73,3 @@ func (s *Saga) Commit() error { } return nil } - -// 辅助工具与代码 -var RestyClient = resty.New() - -func init() { - RestyClient.SetTimeout(3 * time.Second) - RestyClient.SetRetryCount(2) - RestyClient.SetRetryWaitTime(1 * time.Second) - RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error { - logrus.Printf("requesting: %s %s %v", r.Method, r.URL, r.Body) - return nil - }) - RestyClient.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error { - r := resp.Request - logrus.Printf("requested: %s %s %s", r.Method, r.URL, resp.String()) - return nil - }) -} diff --git a/xa.go b/xa.go new file mode 100644 index 0000000..1e5a2d2 --- /dev/null +++ b/xa.go @@ -0,0 +1,113 @@ +package dtm + +import ( + "fmt" + "net/url" + "strings" + + "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" + "github.com/yedf/dtm/common" +) + +type M = map[string]interface{} +type XaGlobalFunc func() error + +type XaLocalFunc func(db *common.MyDb) error + +type XaClient struct { + Server string + Conf map[string]string + CallbackUrl string +} + +func XaClientNew(server string, mysqlConf map[string]string, app *gin.Engine, callbackUrl string) *XaClient { + xa := &XaClient{ + Server: server, + Conf: mysqlConf, + CallbackUrl: callbackUrl, + } + u, err := url.Parse(callbackUrl) + common.PanicIfError(err) + app.POST(u.Path, common.WrapHandler(func(c *gin.Context) (interface{}, error) { + type CallbackReq struct { + Gid string `json:"gid"` + Branch string `json:"branch"` + Action string `json:"action"` + } + req := CallbackReq{} + b, err := c.GetRawData() + common.PanicIfError(err) + common.MustUnmarshal(b, &req) + tx, my := common.DbAlone(xa.Conf) + defer func() { + logrus.Printf("closing conn %v", xa.Conf) + my.Close() + }() + if req.Action == "commit" { + tx.Must().Exec(fmt.Sprintf("xa commit '%s'", req.Branch)) + } else if req.Action == "rollback" { + tx.Must().Exec(fmt.Sprintf("xa rollback '%s'", req.Branch)) + } else { + panic(fmt.Errorf("unknown action: %s", req.Action)) + } + return M{"result": "SUCCESS"}, nil + })) + return xa +} +func (xa *XaClient) XaLocalTransaction(gid string, transFunc XaLocalFunc) (rerr error) { + defer common.Panic2Error(&rerr) + branch := common.GenGid() + tx, my := common.DbAlone(xa.Conf) + defer func() { + logrus.Printf("closing conn %v", xa.Conf) + my.Close() + }() + // tx1 := db.Session(&gorm.Session{SkipDefaultTransaction: true}) + // common.PanicIfError(tx1.Error) + // tx := common.MyDb{DB: tx1} + tx.Must().Exec(fmt.Sprintf("XA start '%s'", branch)) + err := transFunc(tx) + common.PanicIfError(err) + resp, err := common.RestyClient.R(). + SetBody(&M{"gid": gid, "branch": branch, "trans_type": "xa", "status": "prepared", "url": xa.CallbackUrl}). + Post(xa.Server + "/branch") + common.PanicIfError(err) + if !strings.Contains(resp.String(), "SUCCESS") { + common.PanicIfError(fmt.Errorf("unknown server response: %s", resp.String())) + } + tx.Must().Exec(fmt.Sprintf("XA end '%s'", branch)) + tx.Must().Exec(fmt.Sprintf("XA prepare '%s'", branch)) + return nil +} + +func (xa *XaClient) XaGlobalTransaction(gid string, transFunc XaGlobalFunc) (rerr error) { + data := &M{ + "gid": gid, + "trans_type": "xa", + } + defer func() { + x := recover() + if x != nil { + _, _ = common.RestyClient.R().SetBody(data).Post(xa.Server + "/rollback") + rerr = x.(error) + } + }() + resp, err := common.RestyClient.R().SetBody(data).Post(xa.Server + "/prepare") + common.PanicIfError(err) + if !strings.Contains(resp.String(), "SUCCESS") { + panic(fmt.Errorf("unexpected result: %s", resp.String())) + } + err = transFunc() + common.PanicIfError(err) + resp, err = common.RestyClient.R().SetBody(data).Post(xa.Server + "/commit") + common.PanicIfError(err) + if !strings.Contains(resp.String(), "SUCCESS") { + panic(fmt.Errorf("unexpected result: %s", resp.String())) + } + return nil +} + +func getDb(conf map[string]string) *common.MyDb { + return common.DbGet(conf) +}