fix golint

This commit is contained in:
yedongfu 2021-07-16 11:37:41 +08:00
parent c1a2f5c965
commit 193f54e7f3
16 changed files with 151 additions and 98 deletions

View File

@ -1,6 +1,7 @@
![license](https://img.shields.io/github/license/yedf/dtm)
[![Go Reference](https://pkg.go.dev/badge/github.com/yedf/dtm.svg)](https://pkg.go.dev/github.com/yedf/dtm)
[![Build Status](https://travis-ci.com/yedf/dtm.svg?branch=main)](https://travis-ci.com/yedf/dtm)
[![Go Report Card](https://goreportcard.com/badge/github.com/yedf/dtm)](https://goreportcard.com/report/github.com/yedf/dtm)
[![Go Reference](https://pkg.go.dev/badge/github.com/yedf/dtm.svg)](https://pkg.go.dev/github.com/yedf/dtm)
[English](https://github.com/yedf/dtm/blob/master/README-en.md)

View File

@ -12,9 +12,13 @@ import (
"gorm.io/gorm"
)
// M a short name
type M = map[string]interface{}
// MS a short name
type MS = map[string]string
// ModelBase model base for gorm to provide base fields
type ModelBase struct {
ID uint
CreateTime *time.Time `gorm:"autoCreateTime"`
@ -23,21 +27,25 @@ type ModelBase struct {
var dbs = map[string]*DB{}
// DB provide more func over gorm.DB
type DB struct {
*gorm.DB
}
// Must set must flag, panic when error occur
func (m *DB) Must() *DB {
db := m.InstanceSet("ivy.must", true)
return &DB{DB: db}
}
// NoMust unset must flag, don't panic when error occur
func (m *DB) NoMust() *DB {
db := m.InstanceSet("ivy.must", false)
return &DB{DB: db}
}
func (m *DB) ToSqlDB() *sql.DB {
// ToSQLDB get the sql.DB
func (m *DB) ToSQLDB() *sql.DB {
d, err := m.DB.DB()
E2P(err)
return d
@ -87,6 +95,7 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) {
return
}
// GetDsn get dsn from map config
func GetDsn(conf map[string]string) string {
if IsDockerCompose() {
conf["host"] = strings.Replace(conf["host"], "localhost", "host.docker.internal", 1)
@ -95,11 +104,13 @@ func GetDsn(conf map[string]string) string {
return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", conf["user"], conf["password"], conf["host"], conf["port"], conf["database"])
}
// ReplaceDsnPassword replace password for log output
func ReplaceDsnPassword(dsn string) string {
reg := regexp.MustCompile(`:.*@`)
return reg.ReplaceAllString(dsn, ":****@")
}
// DbGet get db connection for specified conf
func DbGet(conf map[string]string) *DB {
dsn := GetDsn(conf)
if dbs[dsn] == nil {
@ -114,7 +125,8 @@ func DbGet(conf map[string]string) *DB {
return dbs[dsn]
}
func SqlDB2DB(sdb *sql.DB) *DB {
// SQLDB2DB name is clear
func SQLDB2DB(sdb *sql.DB) *DB {
db, err := gorm.Open(mysql.New(mysql.Config{
Conn: sdb,
}), &gorm.Config{})
@ -123,16 +135,19 @@ func SqlDB2DB(sdb *sql.DB) *DB {
return &DB{DB: db}
}
// MyConn for xa alone connection
type MyConn struct {
Conn *sql.DB
Dsn string
}
// Close name is clear
func (conn *MyConn) Close() {
logrus.Printf("closing alone mysql: %s", ReplaceDsnPassword(conn.Dsn))
conn.Conn.Close()
}
// DbAlone get a standalone db connection
func DbAlone(conf map[string]string) (*DB, *MyConn) {
dsn := GetDsn(conf)
logrus.Printf("opening alone mysql: %s", ReplaceDsnPassword(dsn))

View File

@ -20,6 +20,7 @@ import (
yaml "gopkg.in/yaml.v2"
)
// P2E panic to error
func P2E(perr *error) {
if x := recover(); x != nil {
if e, ok := x.(error); ok {
@ -30,18 +31,21 @@ func P2E(perr *error) {
}
}
// E2P error to panic
func E2P(err error) {
if err != nil {
panic(err)
}
}
// CatchP catch panic to error
func CatchP(f func()) (rerr error) {
defer P2E(&rerr)
f()
return nil
}
// PanicIf name is clear
func PanicIf(cond bool, err error) {
if cond {
panic(err)
@ -57,6 +61,7 @@ func MustAtoi(s string) int {
return r
}
// OrString return the first not empty string
func OrString(ss ...string) string {
for _, s := range ss {
if s != "" {
@ -66,6 +71,7 @@ func OrString(ss ...string) string {
return ""
}
// If ternary operator
func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} {
if condition {
return trueObj
@ -73,24 +79,30 @@ func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} {
return falseObj
}
// MustMarshal checked version for marshal
func MustMarshal(v interface{}) []byte {
b, err := json.Marshal(v)
E2P(err)
return b
}
// MustMarshalString string version of MustMarshal
func MustMarshalString(v interface{}) string {
return string(MustMarshal(v))
}
// MustUnmarshal checked version for unmarshal
func MustUnmarshal(b []byte, obj interface{}) {
err := json.Unmarshal(b, obj)
E2P(err)
}
// MustUnmarshalString string version of MustUnmarshal
func MustUnmarshalString(s string, obj interface{}) {
MustUnmarshal([]byte(s), obj)
}
// MustRemarshal marshal and unmarshal, and check error
func MustRemarshal(from interface{}, to interface{}) {
b, err := json.Marshal(from)
E2P(err)
@ -98,6 +110,7 @@ func MustRemarshal(from interface{}, to interface{}) {
E2P(err)
}
// GetGinApp init and return gin
func GetGinApp() *gin.Engine {
gin.SetMode(gin.ReleaseMode)
app := gin.Default()
@ -121,6 +134,7 @@ func GetGinApp() *gin.Engine {
return app
}
// WrapHandler name is clear
func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
return func(c *gin.Context) {
r, err := fn(c)
@ -141,7 +155,7 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
}
}
// 辅助工具与代码
// RestyClient the resty object
var RestyClient = resty.New()
func init() {
@ -162,6 +176,7 @@ func init() {
})
}
// CheckRestySuccess panic if error or resp not success
func CheckRestySuccess(resp *resty.Response, err error) {
E2P(err)
if !strings.Contains(resp.String(), "SUCCESS") {
@ -192,7 +207,7 @@ func (f *formatter) Format(entry *logrus.Entry) ([]byte, error) {
return b.Bytes(), nil
}
// 加载调用者文件相同目录下的配置文件
// InitApp init config
func InitApp(dir string, config interface{}) {
logrus.SetFormatter(&formatter{})
cont, err := ioutil.ReadFile(dir + "/conf.yml")
@ -205,17 +220,20 @@ func InitApp(dir string, config interface{}) {
E2P(err)
}
func Getwd() string {
// MustGetwd must version of os.Getwd
func MustGetwd() string {
wd, err := os.Getwd()
E2P(err)
return wd
}
func GetCurrentDir() string {
// GetCurrentCodeDir name is clear
func GetCurrentCodeDir() string {
_, file, _, _ := runtime.Caller(1)
return filepath.Dir(file)
}
// GetProjectDir name is clear
func GetProjectDir() string {
_, file, _, _ := runtime.Caller(1)
for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) {
@ -223,11 +241,13 @@ func GetProjectDir() string {
return file
}
// GetFuncName get current call func name
func GetFuncName() string {
pc, _, _, _ := runtime.Caller(1)
return runtime.FuncForPC(pc).Name()
}
// IsDockerCompose name is clear
func IsDockerCompose() bool {
return os.Getenv("IS_DOCKER_COMPOSE") != ""
}

View File

@ -10,37 +10,37 @@ import (
"gorm.io/gorm/clause"
)
func AddRoute(engine *gin.Engine) {
engine.POST("/api/dtmsvr/prepare", common.WrapHandler(Prepare))
engine.POST("/api/dtmsvr/submit", common.WrapHandler(Submit))
engine.POST("/api/dtmsvr/registerXaBranch", common.WrapHandler(RegisterXaBranch))
engine.POST("/api/dtmsvr/registerTccBranch", common.WrapHandler(RegisterTccBranch))
engine.POST("/api/dtmsvr/abort", common.WrapHandler(Abort))
engine.GET("/api/dtmsvr/query", common.WrapHandler(Query))
engine.GET("/api/dtmsvr/newGid", common.WrapHandler(NewGid))
func addRoute(engine *gin.Engine) {
engine.POST("/api/dtmsvr/prepare", common.WrapHandler(prepare))
engine.POST("/api/dtmsvr/submit", common.WrapHandler(submit))
engine.POST("/api/dtmsvr/registerXaBranch", common.WrapHandler(registerXaBranch))
engine.POST("/api/dtmsvr/registerTccBranch", common.WrapHandler(registerTccBranch))
engine.POST("/api/dtmsvr/abort", common.WrapHandler(abort))
engine.GET("/api/dtmsvr/query", common.WrapHandler(query))
engine.GET("/api/dtmsvr/newGid", common.WrapHandler(newGid))
}
func NewGid(c *gin.Context) (interface{}, error) {
func newGid(c *gin.Context) (interface{}, error) {
return M{"gid": GenGid()}, nil
}
func Prepare(c *gin.Context) (interface{}, error) {
func prepare(c *gin.Context) (interface{}, error) {
m := TransFromContext(c)
m.Status = "prepared"
m.SaveNew(dbGet())
m.saveNew(dbGet())
return M{"message": "SUCCESS", "gid": m.Gid}, nil
}
func Submit(c *gin.Context) (interface{}, error) {
func submit(c *gin.Context) (interface{}, error) {
db := dbGet()
m := TransFromContext(c)
m.Status = "submitted"
m.SaveNew(db)
m.saveNew(db)
go m.Process(db)
return M{"message": "SUCCESS", "gid": m.Gid}, nil
}
func Abort(c *gin.Context) (interface{}, error) {
func abort(c *gin.Context) (interface{}, error) {
db := dbGet()
m := TransFromContext(c)
m = TransFromDb(db, m.Gid)
@ -51,7 +51,7 @@ func Abort(c *gin.Context) (interface{}, error) {
return M{"message": "SUCCESS"}, nil
}
func RegisterXaBranch(c *gin.Context) (interface{}, error) {
func registerXaBranch(c *gin.Context) (interface{}, error) {
branch := TransBranch{}
err := c.BindJSON(&branch)
e2p(err)
@ -68,7 +68,7 @@ func RegisterXaBranch(c *gin.Context) (interface{}, error) {
return M{"message": "SUCCESS"}, nil
}
func RegisterTccBranch(c *gin.Context) (interface{}, error) {
func registerTccBranch(c *gin.Context) (interface{}, error) {
data := common.MS{}
err := c.BindJSON(&data)
e2p(err)
@ -82,7 +82,7 @@ func RegisterTccBranch(c *gin.Context) (interface{}, error) {
branches := []TransBranch{branch, branch, branch}
for i, b := range []string{"cancel", "confirm", "try"} {
branches[i].BranchType = b
branches[i].Url = data[b]
branches[i].URL = data[b]
}
dbGet().Must().Clauses(clause.OnConflict{
@ -94,7 +94,7 @@ func RegisterTccBranch(c *gin.Context) (interface{}, error) {
return M{"message": "SUCCESS"}, nil
}
func Query(c *gin.Context) (interface{}, error) {
func query(c *gin.Context) (interface{}, error) {
gid := c.Query("gid")
if gid == "" {
return nil, errors.New("no gid specified")

View File

@ -9,13 +9,17 @@ import (
"github.com/sirupsen/logrus"
)
// CronPrepared cron expired prepared trans forever
func CronPrepared() {
for {
CronTransOnce(time.Duration(0), "prepared")
sleepCronTime()
notEmpty := CronTransOnce(time.Duration(0), "prepared")
if !notEmpty {
sleepCronTime()
}
}
}
// CronTransOnce cron expired trans who's status match param status for once. use expireIn as expire time
func CronTransOnce(expireIn time.Duration, status string) bool {
defer handlePanic()
trans := lockOneTrans(expireIn, status)
@ -27,6 +31,7 @@ func CronTransOnce(expireIn time.Duration, status string) bool {
return true
}
// CronSubmitted cron expired submitted trans forever
func CronSubmitted() {
for {
notEmpty := CronTransOnce(time.Duration(0), "submitted")

View File

@ -238,7 +238,7 @@ func TestSqlDB(t *testing.T) {
BranchType: "compensate",
}
db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type) values('saga', 'gid1', 'branch_id1', 'action')")
_, err := dtmcli.ThroughBarrierCall(db.ToSqlDB(), transInfo, func(db *sql.DB) (interface{}, error) {
_, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) {
logrus.Printf("rollback gid2")
return nil, fmt.Errorf("gid2 error")
})
@ -247,7 +247,7 @@ func TestSqlDB(t *testing.T) {
asserts.Equal(dbr.RowsAffected, int64(1))
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(0))
_, err = dtmcli.ThroughBarrierCall(db.ToSqlDB(), transInfo, func(db *sql.DB) (interface{}, error) {
_, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) {
logrus.Printf("submit gid2")
return nil, nil
})

View File

@ -11,25 +11,28 @@ import (
var dtmsvrPort = 8080
// MainStart main
func MainStart() {
StartSvr()
go CronSubmitted()
go CronPrepared()
}
// StartSvr StartSvr
func StartSvr() {
logrus.Printf("start dtmsvr")
common.InitApp(common.GetProjectDir(), &config)
config.Mysql["database"] = dbName
app := common.GetGinApp()
AddRoute(app)
addRoute(app)
logrus.Printf("dtmsvr listen at: %d", dtmsvrPort)
go app.Run(fmt.Sprintf(":%d", dtmsvrPort))
time.Sleep(100 * time.Millisecond)
}
// PopulateMysql setup mysql data
func PopulateMysql() {
common.InitApp(common.GetProjectDir(), &config)
config.Mysql["database"] = ""
examples.RunSqlScript(config.Mysql, common.GetCurrentDir()+"/dtmsvr.sql")
examples.RunSqlScript(config.Mysql, common.GetCurrentCodeDir()+"/dtmsvr.sql")
}

View File

@ -11,6 +11,7 @@ import (
"gorm.io/gorm/clause"
)
// TransGlobal global transaction
type TransGlobal struct {
common.ModelBase
Gid string `json:"gid"`
@ -25,11 +26,12 @@ type TransGlobal struct {
NextCronTime *time.Time
}
// TableName TableName
func (*TransGlobal) TableName() string {
return "dtm.trans_global"
}
type TransProcessor interface {
type transProcessor interface {
GenBranches() []TransBranch
ProcessOnce(db *common.DB, branches []TransBranch)
ExecBranch(db *common.DB, branch *TransBranch)
@ -60,10 +62,11 @@ func (t *TransGlobal) changeStatus(db *common.DB, status string) *gorm.DB {
return dbr
}
// TransBranch branch transaction
type TransBranch struct {
common.ModelBase
Gid string
Url string
URL string `json:"url"`
Data string
BranchID string `json:"branch_id"`
BranchType string
@ -72,6 +75,7 @@ type TransBranch struct {
RollbackTime *time.Time
}
// TableName TableName
func (*TransBranch) TableName() string {
return "dtm.trans_branch"
}
@ -93,7 +97,7 @@ func checkAffected(db1 *gorm.DB) {
}
}
type processorCreator func(*TransGlobal) TransProcessor
type processorCreator func(*TransGlobal) transProcessor
var processorFac = map[string]processorCreator{}
@ -101,28 +105,29 @@ func registorProcessorCreator(transType string, creator processorCreator) {
processorFac[transType] = creator
}
func (trans *TransGlobal) getProcessor() TransProcessor {
return processorFac[trans.TransType](trans)
func (t *TransGlobal) getProcessor() transProcessor {
return processorFac[t.TransType](t)
}
func (trans *TransGlobal) Process(db *common.DB) {
// Process process global transaction once
func (t *TransGlobal) Process(db *common.DB) {
defer handlePanic()
defer func() {
if TransProcessedTestChan != nil {
logrus.Printf("processed: %s", trans.Gid)
TransProcessedTestChan <- trans.Gid
logrus.Printf("processed: %s", t.Gid)
TransProcessedTestChan <- t.Gid
}
}()
logrus.Printf("processing: %s", trans.Gid)
logrus.Printf("processing: %s", t.Gid)
branches := []TransBranch{}
db.Must().Where("gid=?", trans.Gid).Order("id asc").Find(&branches)
trans.getProcessor().ProcessOnce(db, branches)
db.Must().Where("gid=?", t.Gid).Order("id asc").Find(&branches)
t.getProcessor().ProcessOnce(db, branches)
}
func (trans *TransGlobal) getBranchParams(branch *TransBranch) common.MS {
func (t *TransGlobal) getBranchParams(branch *TransBranch) common.MS {
return common.MS{
"gid": trans.Gid,
"trans_type": trans.TransType,
"gid": t.Gid,
"trans_type": t.TransType,
"branch_id": branch.BranchID,
"branch_type": branch.BranchType,
}
@ -135,7 +140,7 @@ func (t *TransGlobal) setNextCron(expireIn int64) []string {
return []string{"next_cron_interval", "next_cron_time"}
}
func (t *TransGlobal) SaveNew(db *common.DB) {
func (t *TransGlobal) saveNew(db *common.DB) {
if t.Gid == "" {
t.Gid = GenGid()
}
@ -162,6 +167,7 @@ func (t *TransGlobal) SaveNew(db *common.DB) {
e2p(err)
}
// TransFromContext TransFromContext
func TransFromContext(c *gin.Context) *TransGlobal {
data := M{}
b, err := c.GetRawData()
@ -176,6 +182,7 @@ func TransFromContext(c *gin.Context) *TransGlobal {
return &m
}
// TransFromDb construct trans from db
func TransFromDb(db *common.DB, gid string) *TransGlobal {
m := TransGlobal{}
dbr := db.Must().Model(&m).Where("gid=?", gid).First(&m)

View File

@ -7,15 +7,15 @@ import (
"github.com/yedf/dtm/common"
)
type TransMsgProcessor struct {
type transMsgProcessor struct {
*TransGlobal
}
func init() {
registorProcessorCreator("msg", func(trans *TransGlobal) TransProcessor { return &TransMsgProcessor{TransGlobal: trans} })
registorProcessorCreator("msg", func(trans *TransGlobal) transProcessor { return &transMsgProcessor{TransGlobal: trans} })
}
func (t *TransMsgProcessor) GenBranches() []TransBranch {
func (t *transMsgProcessor) GenBranches() []TransBranch {
branches := []TransBranch{}
steps := []M{}
common.MustUnmarshalString(t.Data, &steps)
@ -24,7 +24,7 @@ func (t *TransMsgProcessor) GenBranches() []TransBranch {
Gid: t.Gid,
BranchID: GenGid(),
Data: step["data"].(string),
Url: step["action"].(string),
URL: step["action"].(string),
BranchType: "action",
Status: "prepared",
})
@ -32,8 +32,8 @@ func (t *TransMsgProcessor) GenBranches() []TransBranch {
return branches
}
func (t *TransMsgProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParams(t.getBranchParams(branch)).Post(branch.Url)
func (t *transMsgProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParams(t.getBranchParams(branch)).Post(branch.URL)
e2p(err)
body := resp.String()
if strings.Contains(body, "SUCCESS") {
@ -58,7 +58,7 @@ func (t *TransGlobal) mayQueryPrepared(db *common.DB) {
}
}
func (t *TransMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
func (t *transMsgProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
t.mayQueryPrepared(db)
if t.Status != "submitted" {
return

View File

@ -7,15 +7,15 @@ import (
"github.com/yedf/dtm/common"
)
type TransSagaProcessor struct {
type transSagaProcessor struct {
*TransGlobal
}
func init() {
registorProcessorCreator("saga", func(trans *TransGlobal) TransProcessor { return &TransSagaProcessor{TransGlobal: trans} })
registorProcessorCreator("saga", func(trans *TransGlobal) transProcessor { return &transSagaProcessor{TransGlobal: trans} })
}
func (t *TransSagaProcessor) GenBranches() []TransBranch {
func (t *transSagaProcessor) GenBranches() []TransBranch {
branches := []TransBranch{}
steps := []M{}
common.MustUnmarshalString(t.Data, &steps)
@ -26,7 +26,7 @@ func (t *TransSagaProcessor) GenBranches() []TransBranch {
Gid: t.Gid,
BranchID: branch,
Data: step["data"].(string),
Url: step[branchType].(string),
URL: step[branchType].(string),
BranchType: branchType,
Status: "prepared",
})
@ -35,8 +35,8 @@ func (t *TransSagaProcessor) GenBranches() []TransBranch {
return branches
}
func (t *TransSagaProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParams(t.getBranchParams(branch)).Post(branch.Url)
func (t *transSagaProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetQueryParams(t.getBranchParams(branch)).Post(branch.URL)
e2p(err)
body := resp.String()
if strings.Contains(body, "SUCCESS") {
@ -50,7 +50,7 @@ func (t *TransSagaProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
}
}
func (t *TransSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
func (t *transSagaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
if t.Status != "submitted" {
return
}

View File

@ -7,20 +7,20 @@ import (
"github.com/yedf/dtm/common"
)
type TransTccProcessor struct {
type transTccProcessor struct {
*TransGlobal
}
func init() {
registorProcessorCreator("tcc", func(trans *TransGlobal) TransProcessor { return &TransTccProcessor{TransGlobal: trans} })
registorProcessorCreator("tcc", func(trans *TransGlobal) transProcessor { return &transTccProcessor{TransGlobal: trans} })
}
func (t *TransTccProcessor) GenBranches() []TransBranch {
func (t *transTccProcessor) GenBranches() []TransBranch {
return []TransBranch{}
}
func (t *TransTccProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetHeader("Content-type", "application/json").SetQueryParams(t.getBranchParams(branch)).Post(branch.Url)
func (t *transTccProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(branch.Data).SetHeader("Content-type", "application/json").SetQueryParams(t.getBranchParams(branch)).Post(branch.URL)
e2p(err)
body := resp.String()
if strings.Contains(body, "SUCCESS") {
@ -34,7 +34,7 @@ func (t *TransTccProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
}
}
func (t *TransTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
func (t *transTccProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
if t.Status == "succeed" || t.Status == "failed" {
return
}

View File

@ -7,23 +7,23 @@ import (
"github.com/yedf/dtm/common"
)
type TransXaProcessor struct {
type transXaProcessor struct {
*TransGlobal
}
func init() {
registorProcessorCreator("xa", func(trans *TransGlobal) TransProcessor { return &TransXaProcessor{TransGlobal: trans} })
registorProcessorCreator("xa", func(trans *TransGlobal) transProcessor { return &transXaProcessor{TransGlobal: trans} })
}
func (t *TransXaProcessor) GenBranches() []TransBranch {
func (t *transXaProcessor) GenBranches() []TransBranch {
return []TransBranch{}
}
func (t *TransXaProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
func (t *transXaProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
resp, err := common.RestyClient.R().SetBody(M{
"branch_id": branch.BranchID,
"action": common.If(t.Status == "prepared", "rollback", "commit"),
"gid": branch.Gid,
}).Post(branch.Url)
}).Post(branch.URL)
e2p(err)
body := resp.String()
if strings.Contains(body, "SUCCESS") {
@ -34,7 +34,7 @@ func (t *TransXaProcessor) ExecBranch(db *common.DB, branch *TransBranch) {
}
}
func (t *TransXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
func (t *transXaProcessor) ProcessOnce(db *common.DB, branches []TransBranch) {
if t.Status == "succeed" {
return
}

View File

@ -11,6 +11,7 @@ import (
"github.com/yedf/dtm/common"
)
// M a short name
type M = map[string]interface{}
var p2e = common.P2E
@ -20,22 +21,22 @@ func dbGet() *common.DB {
return common.DbGet(config.Mysql)
}
func writeTransLog(gid string, action string, status string, branch string, detail string) {
return
db := dbGet()
if detail == "" {
detail = "{}"
}
db.Must().Table("trans_log").Create(M{
"gid": gid,
"action": action,
"new_status": status,
"branch": branch,
"detail": detail,
})
// dbGet().Must().Table("trans_log").Create(M{
// "gid": gid,
// "action": action,
// "new_status": status,
// "branch": branch,
// "detail": detail,
// })
}
var TransProcessedTestChan chan string = nil // 用于测试时,通知处理结束
// TransProcessedTestChan only for test usage. when transaction processed once, write gid to this chan
var TransProcessedTestChan chan string = nil
// WaitTransProcessed only for test usage. wait for transaction processed once
func WaitTransProcessed(gid string) {
logrus.Printf("waiting for gid %s", gid)
id := <-TransProcessedTestChan
@ -54,11 +55,12 @@ func init() {
gNode = node
}
// GenGid generate gid, use ip + snowflake
func GenGid() string {
return getOneHexIp() + "_" + gNode.Generate().Base58()
return getOneHexIP() + "_" + gNode.Generate().Base58()
}
func getOneHexIp() string {
func getOneHexIP() string {
addrs, err := net.InterfaceAddrs()
if err != nil {
fmt.Printf("cannot get ip, default to another call")

View File

@ -32,5 +32,5 @@ func RunSqlScript(mysql map[string]string, script string) {
func PopulateMysql() {
common.InitApp(common.GetProjectDir(), &Config)
Config.Mysql["database"] = dbName
RunSqlScript(Config.Mysql, common.GetCurrentDir()+"/examples.sql")
RunSqlScript(Config.Mysql, common.GetCurrentCodeDir()+"/examples.sql")
}

View File

@ -32,32 +32,32 @@ func SagaBarrierAddRoute(app *gin.Engine) {
}
func sagaBarrierAdjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) {
db := common.SqlDB2DB(sdb)
db := common.SQLDB2DB(sdb)
dbr := db.Model(&UserAccount{}).Where("user_id = ?", 1).Update("balance", gorm.Expr("balance + ?", amount))
return "SUCCESS", dbr.Error
}
func sagaBarrierTransIn(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 1, reqFrom(c).Amount)
})
}
func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount)
})
}
func sagaBarrierTransOut(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 2, -reqFrom(c).Amount)
})
}
func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount)
})
}

View File

@ -49,7 +49,7 @@ const transInUid = 1
const transOutUid = 2
func adjustTrading(sdb *sql.DB, uid int, amount int) (interface{}, error) {
db := common.SqlDB2DB(sdb)
db := common.SQLDB2DB(sdb)
dbr := db.Exec("update dtm_busi.user_account_trading t join dtm_busi.user_account a on t.user_id=a.user_id and t.user_id=? set t.trading_balance=t.trading_balance + ? where a.balance + t.trading_balance + ? >= 0", uid, amount, amount)
if dbr.Error == nil && dbr.RowsAffected == 0 {
return nil, fmt.Errorf("update error, maybe balance not enough")
@ -58,7 +58,7 @@ func adjustTrading(sdb *sql.DB, uid int, amount int) (interface{}, error) {
}
func adjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) {
db := common.SqlDB2DB(sdb)
db := common.SQLDB2DB(sdb)
dbr := db.Exec("update dtm_busi.user_account_trading t join dtm_busi.user_account a on t.user_id=a.user_id and t.user_id=? set t.trading_balance=t.trading_balance + ?", uid, -amount, -amount)
if dbr.Error == nil && dbr.RowsAffected == 1 {
dbr = db.Exec("update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid)
@ -74,37 +74,37 @@ func adjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) {
// TCC下转入
func tccBarrierTransInTry(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustTrading(sdb, transInUid, reqFrom(c).Amount)
})
}
func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustBalance(sdb, transInUid, reqFrom(c).Amount)
})
}
func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustTrading(sdb, transInUid, -reqFrom(c).Amount)
})
}
func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustTrading(sdb, transOutUid, -reqFrom(c).Amount)
})
}
func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustBalance(sdb, transOutUid, -reqFrom(c).Amount)
})
}
func tccBarrierTransOutCancel(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSqlDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustTrading(sdb, transOutUid, reqFrom(c).Amount)
})
}