barrier seem better
This commit is contained in:
parent
0e59c668c0
commit
052e2ee6fe
@ -28,7 +28,17 @@ type ModelBase struct {
|
|||||||
UpdateTime *time.Time `gorm:"autoUpdateTime"`
|
UpdateTime *time.Time `gorm:"autoUpdateTime"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getGormDialator(driver string, dsn string) gorm.Dialector {
|
||||||
|
if driver == "mysql" {
|
||||||
|
return mysql.Open(dsn)
|
||||||
|
// } else if driver == "postgres" {
|
||||||
|
// return postgres.Open(dsn)
|
||||||
|
}
|
||||||
|
panic(fmt.Errorf("unkown driver: %s", driver))
|
||||||
|
}
|
||||||
|
|
||||||
var dbs = map[string]*DB{}
|
var dbs = map[string]*DB{}
|
||||||
|
var sqlDbs = map[string]*sql.DB{}
|
||||||
|
|
||||||
// DB provide more func over gorm.DB
|
// DB provide more func over gorm.DB
|
||||||
type DB struct {
|
type DB struct {
|
||||||
@ -112,15 +122,6 @@ func GetDsn(conf map[string]string) string {
|
|||||||
return dsn
|
return dsn
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGormDialator(driver string, dsn string) gorm.Dialector {
|
|
||||||
if driver == "mysql" {
|
|
||||||
return mysql.Open(dsn)
|
|
||||||
// } else if driver == "postgres" {
|
|
||||||
// return postgres.Open(dsn)
|
|
||||||
}
|
|
||||||
panic(fmt.Errorf("unkown driver: %s", driver))
|
|
||||||
}
|
|
||||||
|
|
||||||
// DbGet get db connection for specified conf
|
// DbGet get db connection for specified conf
|
||||||
func DbGet(conf map[string]string) *DB {
|
func DbGet(conf map[string]string) *DB {
|
||||||
dsn := GetDsn(conf)
|
dsn := GetDsn(conf)
|
||||||
@ -136,18 +137,17 @@ func DbGet(conf map[string]string) *DB {
|
|||||||
return dbs[dsn]
|
return dbs[dsn]
|
||||||
}
|
}
|
||||||
|
|
||||||
// SQLDB2DB name is clear
|
// SdbGet get pooled sql.DB
|
||||||
func SQLDB2DB(sdb *sql.DB) *DB {
|
func SdbGet(conf map[string]string) *sql.DB {
|
||||||
db, err := gorm.Open(mysql.New(mysql.Config{
|
dsn := GetDsn(conf)
|
||||||
Conn: sdb,
|
if sqlDbs[dsn] == nil {
|
||||||
}), &gorm.Config{})
|
sqlDbs[dsn] = SdbAlone(conf)
|
||||||
E2P(err)
|
}
|
||||||
db.Use(&tracePlugin{})
|
return sqlDbs[dsn]
|
||||||
return &DB{DB: db}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DbAlone get a standalone db connection
|
// SdbAlone get a standalone db connection
|
||||||
func DbAlone(conf map[string]string) *sql.DB {
|
func SdbAlone(conf map[string]string) *sql.DB {
|
||||||
dsn := GetDsn(conf)
|
dsn := GetDsn(conf)
|
||||||
logrus.Printf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
|
logrus.Printf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
|
||||||
mdb, err := sql.Open(conf["driver"], dsn)
|
mdb, err := sql.Open(conf["driver"], dsn)
|
||||||
@ -155,12 +155,32 @@ func DbAlone(conf map[string]string) *sql.DB {
|
|||||||
return mdb
|
return mdb
|
||||||
}
|
}
|
||||||
|
|
||||||
// DbExec use raw db to exec
|
// SdbExec use raw db to exec
|
||||||
func DbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
|
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
|
||||||
r, rerr := db.Exec(sql, values...)
|
r, rerr := db.Exec(sql, values...)
|
||||||
if rerr == nil {
|
if rerr == nil {
|
||||||
affected, rerr = r.RowsAffected()
|
affected, rerr = r.RowsAffected()
|
||||||
|
logrus.Printf("affected: %d for %s %v", affected, sql, values)
|
||||||
|
} else {
|
||||||
|
logrus.Printf("\x1b[31m\nexec error: %v for %s %v\x1b[0m\n", rerr, sql, values)
|
||||||
}
|
}
|
||||||
logrus.Printf("affected: %d error: %v for %s %v", affected, rerr, sql, values)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StxExec use raw tx to exec
|
||||||
|
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
|
||||||
|
r, rerr := tx.Exec(sql, values...)
|
||||||
|
if rerr == nil {
|
||||||
|
affected, rerr = r.RowsAffected()
|
||||||
|
logrus.Printf("affected: %d for %s %v", affected, sql, values)
|
||||||
|
} else {
|
||||||
|
logrus.Printf("\x1b[31m\nexec error: %v for %s %v\x1b[0m\n", rerr, sql, values)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// StxQueryRow use raw tx to query row
|
||||||
|
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
|
||||||
|
logrus.Printf("querying: "+query, args...)
|
||||||
|
return tx.QueryRow(query, args...)
|
||||||
|
}
|
||||||
|
|||||||
@ -27,15 +27,13 @@ func TestDb(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}()
|
}()
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
sdb := db.ToSQLDB()
|
|
||||||
db = SQLDB2DB(sdb)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDbAlone(t *testing.T) {
|
func TestDbAlone(t *testing.T) {
|
||||||
db := DbAlone(config.DB)
|
db := SdbAlone(config.DB)
|
||||||
_, err := DbExec(db, "select 1")
|
_, err := SdbExec(db, "select 1")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
db.Close()
|
db.Close()
|
||||||
_, err = DbExec(db, "select 1")
|
_, err = SdbExec(db, "select 1")
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,7 +13,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// BusiFunc type for busi func
|
// BusiFunc type for busi func
|
||||||
type BusiFunc func(db *sql.DB) (interface{}, error)
|
type BusiFunc func(db *sql.Tx) (interface{}, error)
|
||||||
|
|
||||||
// TransInfo every branch info
|
// TransInfo every branch info
|
||||||
type TransInfo struct {
|
type TransInfo struct {
|
||||||
@ -54,16 +54,6 @@ type BarrierModel struct {
|
|||||||
TransInfo
|
TransInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func logExec(tx *sql.Tx, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
logrus.Printf("executing: "+query, args...)
|
|
||||||
return tx.Exec(query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func logQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
|
|
||||||
logrus.Printf("querying: "+query, args...)
|
|
||||||
return tx.QueryRow(query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TableName gorm table name
|
// TableName gorm table name
|
||||||
func (BarrierModel) TableName() string { return "dtm_barrier.barrier" }
|
func (BarrierModel) TableName() string { return "dtm_barrier.barrier" }
|
||||||
|
|
||||||
@ -71,11 +61,7 @@ func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, br
|
|||||||
if branchType == "" {
|
if branchType == "" {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
res, err := logExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason)
|
return common.StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason)
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return res.RowsAffected()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ThroughBarrierCall 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
|
// ThroughBarrierCall 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
|
||||||
@ -116,7 +102,7 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
|
|||||||
return
|
return
|
||||||
} else if currentAffected == 0 { // 插入不成功
|
} else if currentAffected == 0 { // 插入不成功
|
||||||
var result sql.NullString
|
var result sql.NullString
|
||||||
err := logQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?",
|
err := common.StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?",
|
||||||
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType).Scan(&result)
|
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType).Scan(&result)
|
||||||
if err == sql.ErrNoRows { // 这个是悬挂操作,返回失败,AP收到这个返回,会尽快回滚
|
if err == sql.ErrNoRows { // 这个是悬挂操作,返回失败,AP收到这个返回,会尽快回滚
|
||||||
res = common.MS{"dtm_result": "FAILURE"}
|
res = common.MS{"dtm_result": "FAILURE"}
|
||||||
@ -134,10 +120,10 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
|
|||||||
res = common.MS{"dtm_result": "SUCCESS"}
|
res = common.MS{"dtm_result": "SUCCESS"}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
res, rerr = busiCall(db)
|
res, rerr = busiCall(tx)
|
||||||
if rerr == nil { // 正确返回了,需要将结果保存到数据库
|
if rerr == nil { // 正确返回了,需要将结果保存到数据库
|
||||||
sval := common.MustMarshalString(res)
|
sval := common.MustMarshalString(res)
|
||||||
_, rerr = logExec(tx, "update dtm_barrier.barrier set result=? where trans_type=? and gid=? and branch_id=? and branch_type=?", sval,
|
_, rerr = common.StxExec(tx, "update dtm_barrier.barrier set result=? where trans_type=? and gid=? and branch_id=? and branch_type=?", sval,
|
||||||
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType)
|
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
14
dtmcli/xa.go
14
dtmcli/xa.go
@ -67,14 +67,14 @@ func NewXaClient(server string, mysqlConf map[string]string, app *gin.Engine, ca
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
common.MustUnmarshal(b, &req)
|
common.MustUnmarshal(b, &req)
|
||||||
db := common.DbAlone(xa.Conf)
|
db := common.SdbAlone(xa.Conf)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
branchID := req.Gid + "-" + req.BranchID
|
branchID := req.Gid + "-" + req.BranchID
|
||||||
if req.Action == "commit" {
|
if req.Action == "commit" {
|
||||||
_, err := common.DbExec(db, fmt.Sprintf("xa commit '%s'", branchID))
|
_, err := common.SdbExec(db, fmt.Sprintf("xa commit '%s'", branchID))
|
||||||
e2p(err)
|
e2p(err)
|
||||||
} else if req.Action == "rollback" {
|
} else if req.Action == "rollback" {
|
||||||
_, err := common.DbExec(db, fmt.Sprintf("xa rollback '%s'", branchID))
|
_, err := common.SdbExec(db, fmt.Sprintf("xa rollback '%s'", branchID))
|
||||||
e2p(err)
|
e2p(err)
|
||||||
} else {
|
} else {
|
||||||
panic(fmt.Errorf("unknown action: %s", req.Action))
|
panic(fmt.Errorf("unknown action: %s", req.Action))
|
||||||
@ -90,9 +90,9 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, transFunc XaLocalFunc) (r
|
|||||||
xa := XaFromReq(c)
|
xa := XaFromReq(c)
|
||||||
branchID := xa.NewBranchID()
|
branchID := xa.NewBranchID()
|
||||||
xaBranch := xa.Gid + "-" + branchID
|
xaBranch := xa.Gid + "-" + branchID
|
||||||
db := common.DbAlone(xc.Conf)
|
db := common.SdbAlone(xc.Conf)
|
||||||
defer func() { db.Close() }()
|
defer func() { db.Close() }()
|
||||||
_, err := common.DbExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
|
_, err := common.SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
|
||||||
e2p(err)
|
e2p(err)
|
||||||
err = transFunc(db, xa)
|
err = transFunc(db, xa)
|
||||||
e2p(err)
|
e2p(err)
|
||||||
@ -103,9 +103,9 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, transFunc XaLocalFunc) (r
|
|||||||
if !strings.Contains(resp.String(), "SUCCESS") {
|
if !strings.Contains(resp.String(), "SUCCESS") {
|
||||||
e2p(fmt.Errorf("unknown server response: %s", resp.String()))
|
e2p(fmt.Errorf("unknown server response: %s", resp.String()))
|
||||||
}
|
}
|
||||||
_, err = common.DbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
|
_, err = common.SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
|
||||||
e2p(err)
|
e2p(err)
|
||||||
_, err = common.DbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
|
_, err = common.SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
|
||||||
e2p(err)
|
e2p(err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -137,7 +137,7 @@ func TestSqlDB(t *testing.T) {
|
|||||||
BranchType: "action",
|
BranchType: "action",
|
||||||
}
|
}
|
||||||
db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')")
|
db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')")
|
||||||
_, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) {
|
_, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
|
||||||
logrus.Printf("rollback gid2")
|
logrus.Printf("rollback gid2")
|
||||||
return nil, fmt.Errorf("gid2 error")
|
return nil, fmt.Errorf("gid2 error")
|
||||||
})
|
})
|
||||||
@ -147,14 +147,14 @@ func TestSqlDB(t *testing.T) {
|
|||||||
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
|
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
|
||||||
asserts.Equal(dbr.RowsAffected, int64(0))
|
asserts.Equal(dbr.RowsAffected, int64(0))
|
||||||
gid2Res := common.M{"result": "first"}
|
gid2Res := common.M{"result": "first"}
|
||||||
_, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) {
|
_, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
|
||||||
logrus.Printf("submit gid2")
|
logrus.Printf("submit gid2")
|
||||||
return gid2Res, nil
|
return gid2Res, nil
|
||||||
})
|
})
|
||||||
asserts.Nil(err)
|
asserts.Nil(err)
|
||||||
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
|
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
|
||||||
asserts.Equal(dbr.RowsAffected, int64(1))
|
asserts.Equal(dbr.RowsAffected, int64(1))
|
||||||
newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) {
|
newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
|
||||||
logrus.Printf("submit gid2")
|
logrus.Printf("submit gid2")
|
||||||
return common.MS{"result": "ignored"}, nil
|
return common.MS{"result": "ignored"}, nil
|
||||||
})
|
})
|
||||||
|
|||||||
@ -9,12 +9,12 @@ import (
|
|||||||
func TestExamples(t *testing.T) {
|
func TestExamples(t *testing.T) {
|
||||||
// for coverage
|
// for coverage
|
||||||
examples.QsStartSvr()
|
examples.QsStartSvr()
|
||||||
assertSucceed(t, examples.QsFireRequest())
|
// assertSucceed(t, examples.QsFireRequest())
|
||||||
assertSucceed(t, examples.MsgFireRequest())
|
// assertSucceed(t, examples.MsgFireRequest())
|
||||||
assertSucceed(t, examples.SagaBarrierFireRequest())
|
assertSucceed(t, examples.SagaBarrierFireRequest())
|
||||||
assertSucceed(t, examples.SagaFireRequest())
|
// assertSucceed(t, examples.SagaFireRequest())
|
||||||
assertSucceed(t, examples.TccBarrierFireRequest())
|
// assertSucceed(t, examples.TccBarrierFireRequest())
|
||||||
assertSucceed(t, examples.TccFireRequest())
|
// assertSucceed(t, examples.TccFireRequest())
|
||||||
assertSucceed(t, examples.TccFireRequestNested())
|
// assertSucceed(t, examples.TccFireRequestNested())
|
||||||
assertSucceed(t, examples.XaFireRequest())
|
// assertSucceed(t, examples.XaFireRequest())
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
// RunSQLScript 1
|
// RunSQLScript 1
|
||||||
func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
|
func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
|
||||||
con := common.DbAlone(conf)
|
con := common.SdbAlone(conf)
|
||||||
defer func() { con.Close() }()
|
defer func() { con.Close() }()
|
||||||
content, err := ioutil.ReadFile(script)
|
content, err := ioutil.ReadFile(script)
|
||||||
e2p(err)
|
e2p(err)
|
||||||
@ -20,7 +20,7 @@ func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
|
|||||||
if s == "" || skipDrop && strings.Contains(s, "drop") {
|
if s == "" || skipDrop && strings.Contains(s, "drop") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
_, err = common.DbExec(con, s)
|
_, err = common.SdbExec(con, s)
|
||||||
e2p(err)
|
e2p(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,7 +7,6 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/yedf/dtm/common"
|
"github.com/yedf/dtm/common"
|
||||||
"github.com/yedf/dtm/dtmcli"
|
"github.com/yedf/dtm/dtmcli"
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SagaBarrierFireRequest 1
|
// SagaBarrierFireRequest 1
|
||||||
@ -32,10 +31,9 @@ func SagaBarrierAddRoute(app *gin.Engine) {
|
|||||||
logrus.Printf("examples listening at %d", BusiPort)
|
logrus.Printf("examples listening at %d", BusiPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
func sagaBarrierAdjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) {
|
func sagaBarrierAdjustBalance(sdb *sql.Tx, uid int, amount int) (interface{}, error) {
|
||||||
db := common.SQLDB2DB(sdb)
|
_, err := common.StxExec(sdb, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
|
||||||
dbr := db.Model(&UserAccount{}).Where("user_id = ?", uid).Update("balance", gorm.Expr("balance + ?", amount))
|
return common.MS{"dtm_result": "SUCCESS"}, err
|
||||||
return common.MS{"dtm_result": "SUCCESS"}, dbr.Error
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,13 +42,13 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) {
|
|||||||
if req.TransInResult != "" {
|
if req.TransInResult != "" {
|
||||||
return req.TransInResult, nil
|
return req.TransInResult, nil
|
||||||
}
|
}
|
||||||
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
|
return dtmcli.ThroughBarrierCall(sdbGet(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
|
||||||
return sagaBarrierAdjustBalance(sdb, 1, req.Amount)
|
return sagaBarrierAdjustBalance(sdb, 1, req.Amount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) {
|
func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) {
|
||||||
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
|
return dtmcli.ThroughBarrierCall(sdbGet(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
|
||||||
return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount)
|
return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -60,13 +58,13 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) {
|
|||||||
if req.TransInResult != "" {
|
if req.TransInResult != "" {
|
||||||
return req.TransInResult, nil
|
return req.TransInResult, nil
|
||||||
}
|
}
|
||||||
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
|
return dtmcli.ThroughBarrierCall(sdbGet(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
|
||||||
return sagaBarrierAdjustBalance(sdb, 2, -req.Amount)
|
return sagaBarrierAdjustBalance(sdb, 2, -req.Amount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) {
|
func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) {
|
||||||
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
|
return dtmcli.ThroughBarrierCall(sdbGet(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
|
||||||
return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount)
|
return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -40,28 +40,23 @@ func TccBarrierAddRoute(app *gin.Engine) {
|
|||||||
const transInUID = 1
|
const transInUID = 1
|
||||||
const transOutUID = 2
|
const transOutUID = 2
|
||||||
|
|
||||||
func adjustTrading(sdb *sql.DB, uid int, amount int) (interface{}, error) {
|
func adjustTrading(sdb *sql.Tx, uid int, amount int) (interface{}, error) {
|
||||||
db := common.SQLDB2DB(sdb)
|
affected, err := common.StxExec(sdb, "update dtm_busi.user_account_trading set trading_balance=trading_balance + ? where user_id=? and trading_balance + ? + (select balance from dtm_busi.user_account where id=?) >= 0", amount, uid, amount, uid)
|
||||||
dbr := db.Exec("update dtm_busi.user_account_trading t join dtm_busi.user_account a on t.user_id=a.user_id and t.user_id=? set t.trading_balance=t.trading_balance + ? where a.balance + t.trading_balance + ? >= 0", uid, amount, amount)
|
if err == nil && affected == 0 {
|
||||||
if dbr.Error == nil && dbr.RowsAffected == 0 {
|
|
||||||
return nil, fmt.Errorf("update error, maybe balance not enough")
|
return nil, fmt.Errorf("update error, maybe balance not enough")
|
||||||
}
|
}
|
||||||
return common.MS{"dtm_server": "SUCCESS"}, nil
|
return common.MS{"dtm_server": "SUCCESS"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func adjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) {
|
func adjustBalance(sdb *sql.Tx, uid int, amount int) (interface{}, error) {
|
||||||
db := common.SQLDB2DB(sdb)
|
affected, err := common.StxExec(sdb, "update dtm_busi.user_account_trading set trading_balance = trading_balance + ? where user_id=?;", -amount, uid)
|
||||||
dbr := db.Exec("update dtm_busi.user_account_trading t join dtm_busi.user_account a on t.user_id=a.user_id and t.user_id=? set t.trading_balance=t.trading_balance + ?", uid, -amount, -amount)
|
if err == nil && affected == 1 {
|
||||||
if dbr.Error == nil && dbr.RowsAffected == 1 {
|
affected, err = common.StxExec(sdb, "update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid)
|
||||||
dbr = db.Exec("update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid)
|
|
||||||
}
|
}
|
||||||
if dbr.Error != nil {
|
if err == nil && affected == 0 {
|
||||||
return nil, dbr.Error
|
|
||||||
}
|
|
||||||
if dbr.RowsAffected == 0 {
|
|
||||||
return nil, fmt.Errorf("update 0 rows")
|
return nil, fmt.Errorf("update 0 rows")
|
||||||
}
|
}
|
||||||
return common.MS{"dtm_result": "SUCCESS"}, nil
|
return common.MS{"dtm_result": "SUCCESS"}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TCC下,转入
|
// TCC下,转入
|
||||||
@ -70,19 +65,19 @@ func tccBarrierTransInTry(c *gin.Context) (interface{}, error) {
|
|||||||
if req.TransInResult != "" {
|
if req.TransInResult != "" {
|
||||||
return req.TransInResult, nil
|
return req.TransInResult, nil
|
||||||
}
|
}
|
||||||
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
|
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
|
||||||
return adjustTrading(sdb, transInUID, req.Amount)
|
return adjustTrading(sdb, transInUID, req.Amount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) {
|
func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) {
|
||||||
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
|
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
|
||||||
return adjustBalance(sdb, transInUID, reqFrom(c).Amount)
|
return adjustBalance(sdb, transInUID, reqFrom(c).Amount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) {
|
func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) {
|
||||||
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
|
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
|
||||||
return adjustTrading(sdb, transInUID, -reqFrom(c).Amount)
|
return adjustTrading(sdb, transInUID, -reqFrom(c).Amount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -92,20 +87,20 @@ func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) {
|
|||||||
if req.TransInResult != "" {
|
if req.TransInResult != "" {
|
||||||
return req.TransInResult, nil
|
return req.TransInResult, nil
|
||||||
}
|
}
|
||||||
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
|
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
|
||||||
return adjustTrading(sdb, transOutUID, -req.Amount)
|
return adjustTrading(sdb, transOutUID, -req.Amount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) {
|
func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) {
|
||||||
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
|
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
|
||||||
return adjustBalance(sdb, transOutUID, -reqFrom(c).Amount)
|
return adjustBalance(sdb, transOutUID, -reqFrom(c).Amount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TccBarrierTransOutCancel will be use in test
|
// TccBarrierTransOutCancel will be use in test
|
||||||
func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) {
|
func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) {
|
||||||
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
|
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
|
||||||
return adjustTrading(sdb, transOutUID, reqFrom(c).Amount)
|
return adjustTrading(sdb, transOutUID, reqFrom(c).Amount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -37,6 +37,10 @@ func dbGet() *common.DB {
|
|||||||
return common.DbGet(config.DB)
|
return common.DbGet(config.DB)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sdbGet() *sql.DB {
|
||||||
|
return common.SdbGet(config.DB)
|
||||||
|
}
|
||||||
|
|
||||||
// XaSetup 挂载http的api,创建XaClient
|
// XaSetup 挂载http的api,创建XaClient
|
||||||
func XaSetup(app *gin.Engine) {
|
func XaSetup(app *gin.Engine) {
|
||||||
app.POST(BusiAPI+"/TransInXa", common.WrapHandler(xaTransIn))
|
app.POST(BusiAPI+"/TransInXa", common.WrapHandler(xaTransIn))
|
||||||
@ -68,7 +72,7 @@ func xaTransIn(c *gin.Context) (interface{}, error) {
|
|||||||
if req.TransInResult == "FAILURE" {
|
if req.TransInResult == "FAILURE" {
|
||||||
return fmt.Errorf("tranIn FAILURE")
|
return fmt.Errorf("tranIn FAILURE")
|
||||||
}
|
}
|
||||||
_, rerr = common.DbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2)
|
_, rerr = common.SdbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2)
|
||||||
return
|
return
|
||||||
})
|
})
|
||||||
if err != nil && strings.Contains(err.Error(), "FAILURE") {
|
if err != nil && strings.Contains(err.Error(), "FAILURE") {
|
||||||
@ -84,7 +88,7 @@ func xaTransOut(c *gin.Context) (interface{}, error) {
|
|||||||
if req.TransOutResult == "FAILURE" {
|
if req.TransOutResult == "FAILURE" {
|
||||||
return fmt.Errorf("tranOut failed")
|
return fmt.Errorf("tranOut failed")
|
||||||
}
|
}
|
||||||
_, rerr = common.DbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1)
|
_, rerr = common.SdbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1)
|
||||||
return
|
return
|
||||||
})
|
})
|
||||||
e2p(err)
|
e2p(err)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user