fix golint

This commit is contained in:
yedongfu 2021-07-16 11:37:41 +08:00
parent c1a2f5c965
commit 193f54e7f3
16 changed files with 151 additions and 98 deletions

View File

@ -1,6 +1,7 @@
![license](https://img.shields.io/github/license/yedf/dtm) ![license](https://img.shields.io/github/license/yedf/dtm)
[![Go Reference](https://pkg.go.dev/badge/github.com/yedf/dtm.svg)](https://pkg.go.dev/github.com/yedf/dtm)
[![Build Status](https://travis-ci.com/yedf/dtm.svg?branch=main)](https://travis-ci.com/yedf/dtm) [![Build Status](https://travis-ci.com/yedf/dtm.svg?branch=main)](https://travis-ci.com/yedf/dtm)
[![Go Report Card](https://goreportcard.com/badge/github.com/yedf/dtm)](https://goreportcard.com/report/github.com/yedf/dtm)
[![Go Reference](https://pkg.go.dev/badge/github.com/yedf/dtm.svg)](https://pkg.go.dev/github.com/yedf/dtm)
[English](https://github.com/yedf/dtm/blob/master/README-en.md) [English](https://github.com/yedf/dtm/blob/master/README-en.md)

View File

@ -12,9 +12,13 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
// M a short name
type M = map[string]interface{} type M = map[string]interface{}
// MS a short name
type MS = map[string]string type MS = map[string]string
// ModelBase model base for gorm to provide base fields
type ModelBase struct { type ModelBase struct {
ID uint ID uint
CreateTime *time.Time `gorm:"autoCreateTime"` CreateTime *time.Time `gorm:"autoCreateTime"`
@ -23,21 +27,25 @@ type ModelBase struct {
var dbs = map[string]*DB{} var dbs = map[string]*DB{}
// DB provide more func over gorm.DB
type DB struct { type DB struct {
*gorm.DB *gorm.DB
} }
// Must set must flag, panic when error occur
func (m *DB) Must() *DB { func (m *DB) Must() *DB {
db := m.InstanceSet("ivy.must", true) db := m.InstanceSet("ivy.must", true)
return &DB{DB: db} return &DB{DB: db}
} }
// NoMust unset must flag, don't panic when error occur
func (m *DB) NoMust() *DB { func (m *DB) NoMust() *DB {
db := m.InstanceSet("ivy.must", false) db := m.InstanceSet("ivy.must", false)
return &DB{DB: db} return &DB{DB: db}
} }
func (m *DB) ToSqlDB() *sql.DB { // ToSQLDB get the sql.DB
func (m *DB) ToSQLDB() *sql.DB {
d, err := m.DB.DB() d, err := m.DB.DB()
E2P(err) E2P(err)
return d return d
@ -87,6 +95,7 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) {
return return
} }
// GetDsn get dsn from map config
func GetDsn(conf map[string]string) string { func GetDsn(conf map[string]string) string {
if IsDockerCompose() { if IsDockerCompose() {
conf["host"] = strings.Replace(conf["host"], "localhost", "host.docker.internal", 1) conf["host"] = strings.Replace(conf["host"], "localhost", "host.docker.internal", 1)
@ -95,11 +104,13 @@ func GetDsn(conf map[string]string) string {
return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]) return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", conf["user"], conf["password"], conf["host"], conf["port"], conf["database"])
} }
// ReplaceDsnPassword replace password for log output
func ReplaceDsnPassword(dsn string) string { func ReplaceDsnPassword(dsn string) string {
reg := regexp.MustCompile(`:.*@`) reg := regexp.MustCompile(`:.*@`)
return reg.ReplaceAllString(dsn, ":****@") return reg.ReplaceAllString(dsn, ":****@")
} }
// DbGet get db connection for specified conf
func DbGet(conf map[string]string) *DB { func DbGet(conf map[string]string) *DB {
dsn := GetDsn(conf) dsn := GetDsn(conf)
if dbs[dsn] == nil { if dbs[dsn] == nil {
@ -114,7 +125,8 @@ func DbGet(conf map[string]string) *DB {
return dbs[dsn] return dbs[dsn]
} }
func SqlDB2DB(sdb *sql.DB) *DB { // SQLDB2DB name is clear
func SQLDB2DB(sdb *sql.DB) *DB {
db, err := gorm.Open(mysql.New(mysql.Config{ db, err := gorm.Open(mysql.New(mysql.Config{
Conn: sdb, Conn: sdb,
}), &gorm.Config{}) }), &gorm.Config{})
@ -123,16 +135,19 @@ func SqlDB2DB(sdb *sql.DB) *DB {
return &DB{DB: db} return &DB{DB: db}
} }
// MyConn for xa alone connection
type MyConn struct { type MyConn struct {
Conn *sql.DB Conn *sql.DB
Dsn string Dsn string
} }
// Close name is clear
func (conn *MyConn) Close() { func (conn *MyConn) Close() {
logrus.Printf("closing alone mysql: %s", ReplaceDsnPassword(conn.Dsn)) logrus.Printf("closing alone mysql: %s", ReplaceDsnPassword(conn.Dsn))
conn.Conn.Close() conn.Conn.Close()
} }
// DbAlone get a standalone db connection
func DbAlone(conf map[string]string) (*DB, *MyConn) { func DbAlone(conf map[string]string) (*DB, *MyConn) {
dsn := GetDsn(conf) dsn := GetDsn(conf)
logrus.Printf("opening alone mysql: %s", ReplaceDsnPassword(dsn)) logrus.Printf("opening alone mysql: %s", ReplaceDsnPassword(dsn))

View File

@ -20,6 +20,7 @@ import (
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
) )
// P2E panic to error
func P2E(perr *error) { func P2E(perr *error) {
if x := recover(); x != nil { if x := recover(); x != nil {
if e, ok := x.(error); ok { if e, ok := x.(error); ok {
@ -30,18 +31,21 @@ func P2E(perr *error) {
} }
} }
// E2P error to panic
func E2P(err error) { func E2P(err error) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
// CatchP catch panic to error
func CatchP(f func()) (rerr error) { func CatchP(f func()) (rerr error) {
defer P2E(&rerr) defer P2E(&rerr)
f() f()
return nil return nil
} }
// PanicIf name is clear
func PanicIf(cond bool, err error) { func PanicIf(cond bool, err error) {
if cond { if cond {
panic(err) panic(err)
@ -57,6 +61,7 @@ func MustAtoi(s string) int {
return r return r
} }
// OrString return the first not empty string
func OrString(ss ...string) string { func OrString(ss ...string) string {
for _, s := range ss { for _, s := range ss {
if s != "" { if s != "" {
@ -66,6 +71,7 @@ func OrString(ss ...string) string {
return "" return ""
} }
// If ternary operator
func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} { func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} {
if condition { if condition {
return trueObj return trueObj
@ -73,24 +79,30 @@ func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} {
return falseObj return falseObj
} }
// MustMarshal checked version for marshal
func MustMarshal(v interface{}) []byte { func MustMarshal(v interface{}) []byte {
b, err := json.Marshal(v) b, err := json.Marshal(v)
E2P(err) E2P(err)
return b return b
} }
// MustMarshalString string version of MustMarshal
func MustMarshalString(v interface{}) string { func MustMarshalString(v interface{}) string {
return string(MustMarshal(v)) return string(MustMarshal(v))
} }
// MustUnmarshal checked version for unmarshal
func MustUnmarshal(b []byte, obj interface{}) { func MustUnmarshal(b []byte, obj interface{}) {
err := json.Unmarshal(b, obj) err := json.Unmarshal(b, obj)
E2P(err) E2P(err)
} }
// MustUnmarshalString string version of MustUnmarshal
func MustUnmarshalString(s string, obj interface{}) { func MustUnmarshalString(s string, obj interface{}) {
MustUnmarshal([]byte(s), obj) MustUnmarshal([]byte(s), obj)
} }
// MustRemarshal marshal and unmarshal, and check error
func MustRemarshal(from interface{}, to interface{}) { func MustRemarshal(from interface{}, to interface{}) {
b, err := json.Marshal(from) b, err := json.Marshal(from)
E2P(err) E2P(err)
@ -98,6 +110,7 @@ func MustRemarshal(from interface{}, to interface{}) {
E2P(err) E2P(err)
} }
// GetGinApp init and return gin
func GetGinApp() *gin.Engine { func GetGinApp() *gin.Engine {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
app := gin.Default() app := gin.Default()
@ -121,6 +134,7 @@ func GetGinApp() *gin.Engine {
return app return app
} }
// WrapHandler name is clear
func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc { func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
r, err := fn(c) r, err := fn(c)
@ -141,7 +155,7 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
} }
} }
// 辅助工具与代码 // RestyClient the resty object
var RestyClient = resty.New() var RestyClient = resty.New()
func init() { func init() {
@ -162,6 +176,7 @@ func init() {
}) })
} }
// CheckRestySuccess panic if error or resp not success
func CheckRestySuccess(resp *resty.Response, err error) { func CheckRestySuccess(resp *resty.Response, err error) {
E2P(err) E2P(err)
if !strings.Contains(resp.String(), "SUCCESS") { if !strings.Contains(resp.String(), "SUCCESS") {
@ -192,7 +207,7 @@ func (f *formatter) Format(entry *logrus.Entry) ([]byte, error) {
return b.Bytes(), nil return b.Bytes(), nil
} }
// 加载调用者文件相同目录下的配置文件 // InitApp init config
func InitApp(dir string, config interface{}) { func InitApp(dir string, config interface{}) {
logrus.SetFormatter(&formatter{}) logrus.SetFormatter(&formatter{})
cont, err := ioutil.ReadFile(dir + "/conf.yml") cont, err := ioutil.ReadFile(dir + "/conf.yml")
@ -205,17 +220,20 @@ func InitApp(dir string, config interface{}) {
E2P(err) E2P(err)
} }
func Getwd() string { // MustGetwd must version of os.Getwd
func MustGetwd() string {
wd, err := os.Getwd() wd, err := os.Getwd()
E2P(err) E2P(err)
return wd return wd
} }
func GetCurrentDir() string { // GetCurrentCodeDir name is clear
func GetCurrentCodeDir() string {
_, file, _, _ := runtime.Caller(1) _, file, _, _ := runtime.Caller(1)
return filepath.Dir(file) return filepath.Dir(file)
} }
// GetProjectDir name is clear
func GetProjectDir() string { func GetProjectDir() string {
_, file, _, _ := runtime.Caller(1) _, file, _, _ := runtime.Caller(1)
for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) { for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) {
@ -223,11 +241,13 @@ func GetProjectDir() string {
return file return file
} }
// GetFuncName get current call func name
func GetFuncName() string { func GetFuncName() string {
pc, _, _, _ := runtime.Caller(1) pc, _, _, _ := runtime.Caller(1)
return runtime.FuncForPC(pc).Name() return runtime.FuncForPC(pc).Name()
} }
// IsDockerCompose name is clear
func IsDockerCompose() bool { func IsDockerCompose() bool {
return os.Getenv("IS_DOCKER_COMPOSE") != "" return os.Getenv("IS_DOCKER_COMPOSE") != ""
} }

View File

@ -10,37 +10,37 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
func AddRoute(engine *gin.Engine) { func addRoute(engine *gin.Engine) {
engine.POST("/api/dtmsvr/prepare", common.WrapHandler(Prepare)) engine.POST("/api/dtmsvr/prepare", common.WrapHandler(prepare))
engine.POST("/api/dtmsvr/submit", common.WrapHandler(Submit)) engine.POST("/api/dtmsvr/submit", common.WrapHandler(submit))
engine.POST("/api/dtmsvr/registerXaBranch", common.WrapHandler(RegisterXaBranch)) engine.POST("/api/dtmsvr/registerXaBranch", common.WrapHandler(registerXaBranch))
engine.POST("/api/dtmsvr/registerTccBranch", common.WrapHandler(RegisterTccBranch)) engine.POST("/api/dtmsvr/registerTccBranch", common.WrapHandler(registerTccBranch))
engine.POST("/api/dtmsvr/abort", common.WrapHandler(Abort)) engine.POST("/api/dtmsvr/abort", common.WrapHandler(abort))
engine.GET("/api/dtmsvr/query", common.WrapHandler(Query)) engine.GET("/api/dtmsvr/query", common.WrapHandler(query))
engine.GET("/api/dtmsvr/newGid", common.WrapHandler(NewGid)) engine.GET("/api/dtmsvr/newGid", common.WrapHandler(newGid))
} }
func NewGid(c *gin.Context) (interface{}, error) { func newGid(c *gin.Context) (interface{}, error) {
return M{"gid": GenGid()}, nil return M{"gid": GenGid()}, nil
} }
func Prepare(c *gin.Context) (interface{}, error) { func prepare(c *gin.Context) (interface{}, error) {
m := TransFromContext(c) m := TransFromContext(c)
m.Status = "prepared" m.Status = "prepared"
m.SaveNew(dbGet()) m.saveNew(dbGet())
return M{"message": "SUCCESS", "gid": m.Gid}, nil return M{"message": "SUCCESS", "gid": m.Gid}, nil
} }
func Submit(c *gin.Context) (interface{}, error) { func submit(c *gin.Context) (interface{}, error) {
db := dbGet() db := dbGet()
m := TransFromContext(c) m := TransFromContext(c)
m.Status = "submitted" m.Status = "submitted"
m.SaveNew(db) m.saveNew(db)
go m.Process(db) go m.Process(db)
return M{"message": "SUCCESS", "gid": m.Gid}, nil return M{"message": "SUCCESS", "gid": m.Gid}, nil
} }
func Abort(c *gin.Context) (interface{}, error) { func abort(c *gin.Context) (interface{}, error) {
db := dbGet() db := dbGet()
m := TransFromContext(c) m := TransFromContext(c)
m = TransFromDb(db, m.Gid) m = TransFromDb(db, m.Gid)
@ -51,7 +51,7 @@ func Abort(c *gin.Context) (interface{}, error) {
return M{"message": "SUCCESS"}, nil return M{"message": "SUCCESS"}, nil
} }
func RegisterXaBranch(c *gin.Context) (interface{}, error) { func registerXaBranch(c *gin.Context) (interface{}, error) {
branch := TransBranch{} branch := TransBranch{}
err := c.BindJSON(&branch) err := c.BindJSON(&branch)
e2p(err) e2p(err)
@ -68,7 +68,7 @@ func RegisterXaBranch(c *gin.Context) (interface{}, error) {
return M{"message": "SUCCESS"}, nil return M{"message": "SUCCESS"}, nil
} }
func RegisterTccBranch(c *gin.Context) (interface{}, error) { func registerTccBranch(c *gin.Context) (interface{}, error) {
data := common.MS{} data := common.MS{}
err := c.BindJSON(&data) err := c.BindJSON(&data)
e2p(err) e2p(err)
@ -82,7 +82,7 @@ func RegisterTccBranch(c *gin.Context) (interface{}, error) {
branches := []TransBranch{branch, branch, branch} branches := []TransBranch{branch, branch, branch}
for i, b := range []string{"cancel", "confirm", "try"} { for i, b := range []string{"cancel", "confirm", "try"} {
branches[i].BranchType = b branches[i].BranchType = b
branches[i].Url = data[b] branches[i].URL = data[b]
} }
dbGet().Must().Clauses(clause.OnConflict{ dbGet().Must().Clauses(clause.OnConflict{
@ -94,7 +94,7 @@ func RegisterTccBranch(c *gin.Context) (interface{}, error) {
return M{"message": "SUCCESS"}, nil return M{"message": "SUCCESS"}, nil
} }
func Query(c *gin.Context) (interface{}, error) { func query(c *gin.Context) (interface{}, error) {
gid := c.Query("gid") gid := c.Query("gid")
if gid == "" { if gid == "" {
return nil, errors.New("no gid specified") return nil, errors.New("no gid specified")

View File

@ -9,13 +9,17 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// CronPrepared cron expired prepared trans forever
func CronPrepared() { func CronPrepared() {
for { for {
CronTransOnce(time.Duration(0), "prepared") notEmpty := CronTransOnce(time.Duration(0), "prepared")
sleepCronTime() if !notEmpty {
sleepCronTime()
}
} }
} }
// CronTransOnce cron expired trans who's status match param status for once. use expireIn as expire time
func CronTransOnce(expireIn time.Duration, status string) bool { func CronTransOnce(expireIn time.Duration, status string) bool {
defer handlePanic() defer handlePanic()
trans := lockOneTrans(expireIn, status) trans := lockOneTrans(expireIn, status)
@ -27,6 +31,7 @@ func CronTransOnce(expireIn time.Duration, status string) bool {
return true return true
} }
// CronSubmitted cron expired submitted trans forever
func CronSubmitted() { func CronSubmitted() {
for { for {
notEmpty := CronTransOnce(time.Duration(0), "submitted") notEmpty := CronTransOnce(time.Duration(0), "submitted")

View File

@ -238,7 +238,7 @@ func TestSqlDB(t *testing.T) {
BranchType: "compensate", BranchType: "compensate",
} }
db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type) values('saga', 'gid1', 'branch_id1', 'action')") db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type) values('saga', 'gid1', 'branch_id1', 'action')")
_, err := dtmcli.ThroughBarrierCall(db.ToSqlDB(), transInfo, func(db *sql.DB) (interface{}, error) { _, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) {
logrus.Printf("rollback gid2") logrus.Printf("rollback gid2")
return nil, fmt.Errorf("gid2 error") return nil, fmt.Errorf("gid2 error")
}) })
@ -247,7 +247,7 @@ func TestSqlDB(t *testing.T) {
asserts.Equal(dbr.RowsAffected, int64(1)) asserts.Equal(dbr.RowsAffected, int64(1))
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{}) dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(0)) asserts.Equal(dbr.RowsAffected, int64(0))
_, err = dtmcli.ThroughBarrierCall(db.ToSqlDB(), transInfo, func(db *sql.DB) (interface{}, error) { _, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) {
logrus.Printf("submit gid2") logrus.Printf("submit gid2")
return nil, nil return nil, nil
}) })

View File

@ -11,25 +11,28 @@ import (
var dtmsvrPort = 8080 var dtmsvrPort = 8080
// MainStart main
func MainStart() { func MainStart() {
StartSvr() StartSvr()
go CronSubmitted() go CronSubmitted()
go CronPrepared() go CronPrepared()
} }
// StartSvr StartSvr
func StartSvr() { func StartSvr() {
logrus.Printf("start dtmsvr") logrus.Printf("start dtmsvr")
common.InitApp(common.GetProjectDir(), &config) common.InitApp(common.GetProjectDir(), &config)
config.Mysql["database"] = dbName config.Mysql["database"] = dbName
app := common.GetGinApp() app := common.GetGinApp()
AddRoute(app) addRoute(app)
logrus.Printf("dtmsvr listen at: %d", dtmsvrPort) logrus.Printf("dtmsvr listen at: %d", dtmsvrPort)
go app.Run(fmt.Sprintf(":%d", dtmsvrPort)) go app.Run(fmt.Sprintf(":%d", dtmsvrPort))
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
// PopulateMysql setup mysql data
func PopulateMysql() { func PopulateMysql() {
common.InitApp(common.GetProjectDir(), &config) common.InitApp(common.GetProjectDir(), &config)
config.Mysql["database"] = "" config.Mysql["database"] = ""
examples.RunSqlScript(config.Mysql, common.GetCurrentDir()+"/dtmsvr.sql") examples.RunSqlScript(config.Mysql, common.GetCurrentCodeDir()+"/dtmsvr.sql")
} }

View File

@ -11,6 +11,7 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// TransGlobal global transaction
type TransGlobal struct { type TransGlobal struct {
common.ModelBase common.ModelBase
Gid string `json:"gid"` Gid string `json:"gid"`
@ -25,11 +26,12 @@ type TransGlobal struct {
NextCronTime *time.Time NextCronTime *time.Time
} }
// TableName TableName
func (*TransGlobal) TableName() string { func (*TransGlobal) TableName() string {
return "dtm.trans_global" return "dtm.trans_global"
} }
type TransProcessor interface { type transProcessor interface {
GenBranches() []TransBranch GenBranches() []TransBranch
ProcessOnce(db *common.DB, branches []TransBranch) ProcessOnce(db *common.DB, branches []TransBranch)
ExecBranch(db *common.DB, branch *TransBranch) ExecBranch(db *common.DB, branch *TransBranch)
@ -60,10 +62,11 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB {
return dbr return dbr
} }
// TransBranch branch transaction
type TransBranch struct { type TransBranch struct {
common.ModelBase common.ModelBase
Gid string Gid string
Url string URL string `json:"url"`
Data string Data string
BranchID string `json:"branch_id"` BranchID string `json:"branch_id"`
BranchType string BranchType string
@ -72,6 +75,7 @@ type TransBranch struct {
RollbackTime *time.Time RollbackTime *time.Time
} }
// TableName TableName
func (*TransBranch) TableName() string { func (*TransBranch) TableName() string {
return "dtm.trans_branch" return "dtm.trans_branch"
} }
@ -93,7 +97,7 @@ func checkAffected(db1 *gorm.DB) {
} }
} }
type processorCreator func(*TransGlobal) TransProcessor type processorCreator func(*TransGlobal) transProcessor
var processorFac = map[string]processorCreator{} var processorFac = map[string]processorCreator{}
@ -101,28 +105,29 @@ func registorProcessorCreator(transType string, creator processorCreator) {
processorFac[transType] = creator processorFac[transType] = creator
} }
func (trans *TransGlobal) getProcessor() TransProcessor { func (t *TransGlobal) getProcessor() transProcessor {
return processorFac[trans.TransType](trans) return processorFac[t.TransType](t)
} }
func (trans *TransGlobal) Process(db *common.DB) { // Process process global transaction once
func (t *TransGlobal) Process(db *common.DB) {
defer handlePanic() defer handlePanic()
defer func() { defer func() {
if TransProcessedTestChan != nil { if TransProcessedTestChan != nil {
logrus.Printf("processed: %s", trans.Gid) logrus.Printf("processed: %s", t.Gid)
TransProcessedTestChan <- trans.Gid TransProcessedTestChan <- t.Gid
} }
}() }()
logrus.Printf("processing: %s", trans.Gid) logrus.Printf("processing: %s", t.Gid)
branches := []TransBranch{} branches := []TransBranch{}
db.Must().Where("gid=?", trans.Gid).Order("id asc").Find(&branches) db.Must().Where("gid=?", t.Gid).Order("id asc").Find(&branches)
trans.getProcessor().ProcessOnce(db, branches) t.getProcessor().ProcessOnce(db, branches)
} }
func (trans *TransGlobal) getBranchParams(branch *TransBranch) common.MS { func (t *TransGlobal) getBranchParams(branch *TransBranch) common.MS {
return common.MS{ return common.MS{
"gid": trans.Gid, "gid": t.Gid,
"trans_type": trans.TransType, "trans_type": t.TransType,
"branch_id": branch.BranchID, "branch_id": branch.BranchID,
"branch_type": branch.BranchType, "branch_type": branch.BranchType,
} }
@ -135,7 +140,7 @@ func (t *TransGlobal) setNextCron(expireIn int64) []string {
return []string{"next_cron_interval", "next_cron_time"} return []string{"next_cron_interval", "next_cron_time"}
} }
func (t *TransGlobal) SaveNew(db *common.DB) { func (t *TransGlobal) saveNew(db *common.DB) {
if t.Gid == "" { if t.Gid == "" {
t.Gid = GenGid() t.Gid = GenGid()
} }
@ -162,6 +167,7 @@ func (t *TransGlobal) SaveNew(db *common.DB) {
e2p(err) e2p(err)
} }
// TransFromContext TransFromContext
func TransFromContext(c *gin.Context) *TransGlobal { func TransFromContext(c *gin.Context) *TransGlobal {
data := M{} data := M{}
b, err := c.GetRawData() b, err := c.GetRawData()
@ -176,6 +182,7 @@ func TransFromContext(c *gin.Context) *TransGlobal {
return &m return &m
} }
// TransFromDb construct trans from db
func TransFromDb(db *common.DB, gid string) *TransGlobal { func TransFromDb(db *common.DB, gid string) *TransGlobal {
m := TransGlobal{} m := TransGlobal{}
dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m) dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m)

View File

@ -7,15 +7,15 @@ import (
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
) )
type TransMsgProcessor struct { type transMsgProcessor struct {
*TransGlobal *TransGlobal
} }
func init() { func init() {
registorProcessorCreator("msg", func(trans *TransGlobal) TransProcessor { return &TransMsgProcessor{TransGlobal: trans} }) registorProcessorCreator("msg", func(trans *TransGlobal) transProcessor { return &transMsgProcessor{TransGlobal: trans} })
} }
func (t *TransMsgProcessor) GenBranches() []TransBranch { func (t *transMsgProcessor) GenBranches() []TransBranch {
branches := []TransBranch{} branches := []TransBranch{}
steps := []M{} steps := []M{}
common.MustUnmarshalString(t.Data, &steps) common.MustUnmarshalString(t.Data, &steps)
@ -24,7 +24,7 @@ func (t *TransMsgProcessor) GenBranches() []TransBranch {
Gid: t.Gid, Gid: t.Gid,
BranchID: GenGid(), BranchID: GenGid(),
Data: step["data"].(string), Data: step["data"].(string),
Url: step["action"].(string), URL: step["action"].(string),
BranchType: "action", BranchType: "action",
Status: "prepared", Status: "prepared",
}) })
@ -32,8 +32,8 @@ func (t *TransMsgProcessor) GenBranches() []TransBranch {
return branches return branches
} }
func (t *TransMsgProcessor) ExecBranch(db *common.DB, branch *TransBranch) { func (t *transMsgProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParams(t.getBranchParams(branch)).Post(branch.Url) resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParams(t.getBranchParams(branch)).Post(branch.URL)
e2p(err) e2p(err)
body := resp.String() body := resp.String()
if strings.Contains(body, "SUCCESS") { if strings.Contains(body, "SUCCESS") {
@ -58,7 +58,7 @@ func (t *TransGlobal) mayQueryPrepared(db *common.DB) {
} }
} }
func (t *TransMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
t.mayQueryPrepared(db) t.mayQueryPrepared(db)
if t.Status != "submitted" { if t.Status != "submitted" {
return return

View File

@ -7,15 +7,15 @@ import (
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
) )
type TransSagaProcessor struct { type transSagaProcessor struct {
*TransGlobal *TransGlobal
} }
func init() { func init() {
registorProcessorCreator("saga", func(trans *TransGlobal) TransProcessor { return &TransSagaProcessor{TransGlobal: trans} }) registorProcessorCreator("saga", func(trans *TransGlobal) transProcessor { return &transSagaProcessor{TransGlobal: trans} })
} }
func (t *TransSagaProcessor) GenBranches() []TransBranch { func (t *transSagaProcessor) GenBranches() []TransBranch {
branches := []TransBranch{} branches := []TransBranch{}
steps := []M{} steps := []M{}
common.MustUnmarshalString(t.Data, &steps) common.MustUnmarshalString(t.Data, &steps)
@ -26,7 +26,7 @@ func (t *TransSagaProcessor) GenBranches() []TransBranch {
Gid: t.Gid, Gid: t.Gid,
BranchID: branch, BranchID: branch,
Data: step["data"].(string), Data: step["data"].(string),
Url: step[branchType].(string), URL: step[branchType].(string),
BranchType: branchType, BranchType: branchType,
Status: "prepared", Status: "prepared",
}) })
@ -35,8 +35,8 @@ func (t *TransSagaProcessor) GenBranches() []TransBranch {
return branches return branches
} }
func (t *TransSagaProcessor) ExecBranch(db *common.DB, branch *TransBranch) { func (t *transSagaProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParams(t.getBranchParams(branch)).Post(branch.Url) resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParams(t.getBranchParams(branch)).Post(branch.URL)
e2p(err) e2p(err)
body := resp.String() body := resp.String()
if strings.Contains(body, "SUCCESS") { if strings.Contains(body, "SUCCESS") {
@ -50,7 +50,7 @@ func (t *TransSagaProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
} }
} }
func (t *TransSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
if t.Status != "submitted" { if t.Status != "submitted" {
return return
} }

View File

@ -7,20 +7,20 @@ import (
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
) )
type TransTccProcessor struct { type transTccProcessor struct {
*TransGlobal *TransGlobal
} }
func init() { func init() {
registorProcessorCreator("tcc", func(trans *TransGlobal) TransProcessor { return &TransTccProcessor{TransGlobal: trans} }) registorProcessorCreator("tcc", func(trans *TransGlobal) transProcessor { return &transTccProcessor{TransGlobal: trans} })
} }
func (t *TransTccProcessor) GenBranches() []TransBranch { func (t *transTccProcessor) GenBranches() []TransBranch {
return []TransBranch{} return []TransBranch{}
} }
func (t *TransTccProcessor) ExecBranch(db *common.DB, branch *TransBranch) { func (t *transTccProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetHeader("Content-type", "application/json").SetQueryParams(t.getBranchParams(branch)).Post(branch.Url) resp, err := common.RestyClient.R().SetBody(branch.Data).SetHeader("Content-type", "application/json").SetQueryParams(t.getBranchParams(branch)).Post(branch.URL)
e2p(err) e2p(err)
body := resp.String() body := resp.String()
if strings.Contains(body, "SUCCESS") { if strings.Contains(body, "SUCCESS") {
@ -34,7 +34,7 @@ func (t *TransTccProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
} }
} }
func (t *TransTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { func (t *transTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
if t.Status == "succeed" || t.Status == "failed" { if t.Status == "succeed" || t.Status == "failed" {
return return
} }

View File

@ -7,23 +7,23 @@ import (
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
) )
type TransXaProcessor struct { type transXaProcessor struct {
*TransGlobal *TransGlobal
} }
func init() { func init() {
registorProcessorCreator("xa", func(trans *TransGlobal) TransProcessor { return &TransXaProcessor{TransGlobal: trans} }) registorProcessorCreator("xa", func(trans *TransGlobal) transProcessor { return &transXaProcessor{TransGlobal: trans} })
} }
func (t *TransXaProcessor) GenBranches() []TransBranch { func (t *transXaProcessor) GenBranches() []TransBranch {
return []TransBranch{} return []TransBranch{}
} }
func (t *TransXaProcessor) ExecBranch(db *common.DB, branch *TransBranch) { func (t *transXaProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(M{ resp, err := common.RestyClient.R().SetBody(M{
"branch_id": branch.BranchID, "branch_id": branch.BranchID,
"action": common.If(t.Status == "prepared", "rollback", "commit"), "action": common.If(t.Status == "prepared", "rollback", "commit"),
"gid": branch.Gid, "gid": branch.Gid,
}).Post(branch.Url) }).Post(branch.URL)
e2p(err) e2p(err)
body := resp.String() body := resp.String()
if strings.Contains(body, "SUCCESS") { if strings.Contains(body, "SUCCESS") {
@ -34,7 +34,7 @@ func (t *TransXaProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
} }
} }
func (t *TransXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) { func (t *transXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
if t.Status == "succeed" { if t.Status == "succeed" {
return return
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
) )
// M a short name
type M = map[string]interface{} type M = map[string]interface{}
var p2e = common.P2E var p2e = common.P2E
@ -20,22 +21,22 @@ func dbGet() *common.DB {
return common.DbGet(config.Mysql) return common.DbGet(config.Mysql)
} }
func writeTransLog(gid string, action string, status string, branch string, detail string) { func writeTransLog(gid string, action string, status string, branch string, detail string) {
return
db := dbGet()
if detail == "" { if detail == "" {
detail = "{}" detail = "{}"
} }
db.Must().Table("trans_log").Create(M{ // dbGet().Must().Table("trans_log").Create(M{
"gid": gid, // "gid": gid,
"action": action, // "action": action,
"new_status": status, // "new_status": status,
"branch": branch, // "branch": branch,
"detail": detail, // "detail": detail,
}) // })
} }
var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束 // TransProcessedTestChan only for test usage. when transaction processed once, write gid to this chan
var TransProcessedTestChan chan string = nil
// WaitTransProcessed only for test usage. wait for transaction processed once
func WaitTransProcessed(gid string) { func WaitTransProcessed(gid string) {
logrus.Printf("waiting for gid %s", gid) logrus.Printf("waiting for gid %s", gid)
id := <-TransProcessedTestChan id := <-TransProcessedTestChan
@ -54,11 +55,12 @@ func init() {
gNode = node gNode = node
} }
// GenGid generate gid, use ip + snowflake
func GenGid() string { func GenGid() string {
return getOneHexIp() + "_" + gNode.Generate().Base58() return getOneHexIP() + "_" + gNode.Generate().Base58()
} }
func getOneHexIp() string { func getOneHexIP() string {
addrs, err := net.InterfaceAddrs() addrs, err := net.InterfaceAddrs()
if err != nil { if err != nil {
fmt.Printf("cannot get ip, default to another call") fmt.Printf("cannot get ip, default to another call")

View File

@ -32,5 +32,5 @@ func RunSqlScript(mysql map[string]string, script string) {
func PopulateMysql() { func PopulateMysql() {
common.InitApp(common.GetProjectDir(), &Config) common.InitApp(common.GetProjectDir(), &Config)
Config.Mysql["database"] = dbName Config.Mysql["database"] = dbName
RunSqlScript(Config.Mysql, common.GetCurrentDir()+"/examples.sql") RunSqlScript(Config.Mysql, common.GetCurrentCodeDir()+"/examples.sql")
} }

View File

@ -32,32 +32,32 @@ func SagaBarrierAddRoute(app *gin.Engine) {
} }
func sagaBarrierAdjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) { func sagaBarrierAdjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) {
db := common.SqlDB2DB(sdb) db := common.SQLDB2DB(sdb)
dbr := db.Model(&UserAccount{}).Where("user_id = ?", 1).Update("balance", gorm.Expr("balance + ?", amount)) dbr := db.Model(&UserAccount{}).Where("user_id = ?", 1).Update("balance", gorm.Expr("balance + ?", amount))
return "SUCCESS", dbr.Error return "SUCCESS", dbr.Error
} }
func sagaBarrierTransIn(c *gin.Context) (interface{}, error) { func sagaBarrierTransIn(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 1, reqFrom(c).Amount) return sagaBarrierAdjustBalance(sdb, 1, reqFrom(c).Amount)
}) })
} }
func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) { func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount) return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount)
}) })
} }
func sagaBarrierTransOut(c *gin.Context) (interface{}, error) { func sagaBarrierTransOut(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 2, -reqFrom(c).Amount) return sagaBarrierAdjustBalance(sdb, 2, -reqFrom(c).Amount)
}) })
} }
func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) { func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount) return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount)
}) })
} }

View File

@ -49,7 +49,7 @@ const transInUid = 1
const transOutUid = 2 const transOutUid = 2
func adjustTrading(sdb *sql.DB, uid int, amount int) (interface{}, error) { func adjustTrading(sdb *sql.DB, uid int, amount int) (interface{}, error) {
db := common.SqlDB2DB(sdb) db := common.SQLDB2DB(sdb)
dbr := db.Exec("update dtm_busi.user_account_trading t join dtm_busi.user_account a on t.user_id=a.user_id and t.user_id=? set t.trading_balance=t.trading_balance + ? where a.balance + t.trading_balance + ? >= 0", uid, amount, amount) dbr := db.Exec("update dtm_busi.user_account_trading t join dtm_busi.user_account a on t.user_id=a.user_id and t.user_id=? set t.trading_balance=t.trading_balance + ? where a.balance + t.trading_balance + ? >= 0", uid, amount, amount)
if dbr.Error == nil && dbr.RowsAffected == 0 { if dbr.Error == nil && dbr.RowsAffected == 0 {
return nil, fmt.Errorf("update error, maybe balance not enough") return nil, fmt.Errorf("update error, maybe balance not enough")
@ -58,7 +58,7 @@ func adjustTrading(sdb *sql.DB, uid int, amount int) (interface{}, error) {
} }
func adjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) { func adjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) {
db := common.SqlDB2DB(sdb) db := common.SQLDB2DB(sdb)
dbr := db.Exec("update dtm_busi.user_account_trading t join dtm_busi.user_account a on t.user_id=a.user_id and t.user_id=? set t.trading_balance=t.trading_balance + ?", uid, -amount, -amount) dbr := db.Exec("update dtm_busi.user_account_trading t join dtm_busi.user_account a on t.user_id=a.user_id and t.user_id=? set t.trading_balance=t.trading_balance + ?", uid, -amount, -amount)
if dbr.Error == nil && dbr.RowsAffected == 1 { if dbr.Error == nil && dbr.RowsAffected == 1 {
dbr = db.Exec("update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid) dbr = db.Exec("update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid)
@ -74,37 +74,37 @@ func adjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) {
// TCC下转入 // TCC下转入
func tccBarrierTransInTry(c *gin.Context) (interface{}, error) { func tccBarrierTransInTry(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustTrading(sdb, transInUid, reqFrom(c).Amount) return adjustTrading(sdb, transInUid, reqFrom(c).Amount)
}) })
} }
func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) { func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustBalance(sdb, transInUid, reqFrom(c).Amount) return adjustBalance(sdb, transInUid, reqFrom(c).Amount)
}) })
} }
func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) { func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustTrading(sdb, transInUid, -reqFrom(c).Amount) return adjustTrading(sdb, transInUid, -reqFrom(c).Amount)
}) })
} }
func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) { func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustTrading(sdb, transOutUid, -reqFrom(c).Amount) return adjustTrading(sdb, transOutUid, -reqFrom(c).Amount)
}) })
} }
func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) { func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustBalance(sdb, transOutUid, -reqFrom(c).Amount) return adjustBalance(sdb, transOutUid, -reqFrom(c).Amount)
}) })
} }
func tccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { func tccBarrierTransOutCancel(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustTrading(sdb, transOutUid, reqFrom(c).Amount) return adjustTrading(sdb, transOutUid, reqFrom(c).Amount)
}) })
} }