barrier seem better

This commit is contained in:
yedf2 2021-07-31 18:10:45 +08:00
parent 0e59c668c0
commit 052e2ee6fe
10 changed files with 97 additions and 96 deletions

View File

@ -28,7 +28,17 @@ type ModelBase struct {
UpdateTime *time.Time `gorm:"autoUpdateTime"`
}
func getGormDialator(driver string, dsn string) gorm.Dialector {
if driver == "mysql" {
return mysql.Open(dsn)
// } else if driver == "postgres" {
// return postgres.Open(dsn)
}
panic(fmt.Errorf("unkown driver: %s", driver))
}
var dbs = map[string]*DB{}
var sqlDbs = map[string]*sql.DB{}
// DB provide more func over gorm.DB
type DB struct {
@ -112,15 +122,6 @@ func GetDsn(conf map[string]string) string {
return dsn
}
func getGormDialator(driver string, dsn string) gorm.Dialector {
if driver == "mysql" {
return mysql.Open(dsn)
// } else if driver == "postgres" {
// return postgres.Open(dsn)
}
panic(fmt.Errorf("unkown driver: %s", driver))
}
// DbGet get db connection for specified conf
func DbGet(conf map[string]string) *DB {
dsn := GetDsn(conf)
@ -136,18 +137,17 @@ func DbGet(conf map[string]string) *DB {
return dbs[dsn]
}
// SQLDB2DB name is clear
func SQLDB2DB(sdb *sql.DB) *DB {
db, err := gorm.Open(mysql.New(mysql.Config{
Conn: sdb,
}), &gorm.Config{})
E2P(err)
db.Use(&tracePlugin{})
return &DB{DB: db}
// SdbGet get pooled sql.DB
func SdbGet(conf map[string]string) *sql.DB {
dsn := GetDsn(conf)
if sqlDbs[dsn] == nil {
sqlDbs[dsn] = SdbAlone(conf)
}
return sqlDbs[dsn]
}
// DbAlone get a standalone db connection
func DbAlone(conf map[string]string) *sql.DB {
// SdbAlone get a standalone db connection
func SdbAlone(conf map[string]string) *sql.DB {
dsn := GetDsn(conf)
logrus.Printf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
mdb, err := sql.Open(conf["driver"], dsn)
@ -155,12 +155,32 @@ func DbAlone(conf map[string]string) *sql.DB {
return mdb
}
// DbExec use raw db to exec
func DbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
// SdbExec use raw db to exec
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := db.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
logrus.Printf("affected: %d for %s %v", affected, sql, values)
} else {
logrus.Printf("\x1b[31m\nexec error: %v for %s %v\x1b[0m\n", rerr, sql, values)
}
logrus.Printf("affected: %d error: %v for %s %v", affected, rerr, sql, values)
return
}
// StxExec use raw tx to exec
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := tx.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
logrus.Printf("affected: %d for %s %v", affected, sql, values)
} else {
logrus.Printf("\x1b[31m\nexec error: %v for %s %v\x1b[0m\n", rerr, sql, values)
}
return
}
// StxQueryRow use raw tx to query row
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
logrus.Printf("querying: "+query, args...)
return tx.QueryRow(query, args...)
}

View File

@ -27,15 +27,13 @@ func TestDb(t *testing.T) {
return nil
}()
assert.NotEqual(t, nil, err)
sdb := db.ToSQLDB()
db = SQLDB2DB(sdb)
}
func TestDbAlone(t *testing.T) {
db := DbAlone(config.DB)
_, err := DbExec(db, "select 1")
db := SdbAlone(config.DB)
_, err := SdbExec(db, "select 1")
assert.Equal(t, nil, err)
db.Close()
_, err = DbExec(db, "select 1")
_, err = SdbExec(db, "select 1")
assert.NotEqual(t, nil, err)
}

View File

@ -13,7 +13,7 @@ import (
)
// BusiFunc type for busi func
type BusiFunc func(db *sql.DB) (interface{}, error)
type BusiFunc func(db *sql.Tx) (interface{}, error)
// TransInfo every branch info
type TransInfo struct {
@ -54,16 +54,6 @@ type BarrierModel struct {
TransInfo
}
func logExec(tx *sql.Tx, query string, args ...interface{}) (sql.Result, error) {
logrus.Printf("executing: "+query, args...)
return tx.Exec(query, args...)
}
func logQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
logrus.Printf("querying: "+query, args...)
return tx.QueryRow(query, args...)
}
// TableName gorm table name
func (BarrierModel) TableName() string { return "dtm_barrier.barrier" }
@ -71,11 +61,7 @@ func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, br
if branchType == "" {
return 0, nil
}
res, err := logExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason)
if err != nil {
return 0, err
}
return res.RowsAffected()
return common.StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason)
}
// ThroughBarrierCall 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
@ -116,7 +102,7 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
return
} else if currentAffected == 0 { // 插入不成功
var result sql.NullString
err := logQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?",
err := common.StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?",
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType).Scan(&result)
if err == sql.ErrNoRows { // 这个是悬挂操作返回失败AP收到这个返回会尽快回滚
res = common.MS{"dtm_result": "FAILURE"}
@ -134,10 +120,10 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
res = common.MS{"dtm_result": "SUCCESS"}
return
}
res, rerr = busiCall(db)
res, rerr = busiCall(tx)
if rerr == nil { // 正确返回了,需要将结果保存到数据库
sval := common.MustMarshalString(res)
_, rerr = logExec(tx, "update dtm_barrier.barrier set result=? where trans_type=? and gid=? and branch_id=? and branch_type=?", sval,
_, rerr = common.StxExec(tx, "update dtm_barrier.barrier set result=? where trans_type=? and gid=? and branch_id=? and branch_type=?", sval,
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType)
}
return

View File

@ -67,14 +67,14 @@ func NewXaClient(server string, mysqlConf map[string]string, app *gin.Engine, ca
return nil, err
}
common.MustUnmarshal(b, &req)
db := common.DbAlone(xa.Conf)
db := common.SdbAlone(xa.Conf)
defer db.Close()
branchID := req.Gid + "-" + req.BranchID
if req.Action == "commit" {
_, err := common.DbExec(db, fmt.Sprintf("xa commit '%s'", branchID))
_, err := common.SdbExec(db, fmt.Sprintf("xa commit '%s'", branchID))
e2p(err)
} else if req.Action == "rollback" {
_, err := common.DbExec(db, fmt.Sprintf("xa rollback '%s'", branchID))
_, err := common.SdbExec(db, fmt.Sprintf("xa rollback '%s'", branchID))
e2p(err)
} else {
panic(fmt.Errorf("unknown action: %s", req.Action))
@ -90,9 +90,9 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, transFunc XaLocalFunc) (r
xa := XaFromReq(c)
branchID := xa.NewBranchID()
xaBranch := xa.Gid + "-" + branchID
db := common.DbAlone(xc.Conf)
db := common.SdbAlone(xc.Conf)
defer func() { db.Close() }()
_, err := common.DbExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
_, err := common.SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
e2p(err)
err = transFunc(db, xa)
e2p(err)
@ -103,9 +103,9 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, transFunc XaLocalFunc) (r
if !strings.Contains(resp.String(), "SUCCESS") {
e2p(fmt.Errorf("unknown server response: %s", resp.String()))
}
_, err = common.DbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
_, err = common.SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
e2p(err)
_, err = common.DbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
_, err = common.SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
e2p(err)
return nil
}

View File

@ -137,7 +137,7 @@ func TestSqlDB(t *testing.T) {
BranchType: "action",
}
db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')")
_, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) {
_, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
logrus.Printf("rollback gid2")
return nil, fmt.Errorf("gid2 error")
})
@ -147,14 +147,14 @@ func TestSqlDB(t *testing.T) {
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(0))
gid2Res := common.M{"result": "first"}
_, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) {
_, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
logrus.Printf("submit gid2")
return gid2Res, nil
})
asserts.Nil(err)
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) {
newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
logrus.Printf("submit gid2")
return common.MS{"result": "ignored"}, nil
})

View File

@ -9,12 +9,12 @@ import (
func TestExamples(t *testing.T) {
// for coverage
examples.QsStartSvr()
assertSucceed(t, examples.QsFireRequest())
assertSucceed(t, examples.MsgFireRequest())
// assertSucceed(t, examples.QsFireRequest())
// assertSucceed(t, examples.MsgFireRequest())
assertSucceed(t, examples.SagaBarrierFireRequest())
assertSucceed(t, examples.SagaFireRequest())
assertSucceed(t, examples.TccBarrierFireRequest())
assertSucceed(t, examples.TccFireRequest())
assertSucceed(t, examples.TccFireRequestNested())
assertSucceed(t, examples.XaFireRequest())
// assertSucceed(t, examples.SagaFireRequest())
// assertSucceed(t, examples.TccBarrierFireRequest())
// assertSucceed(t, examples.TccFireRequest())
// assertSucceed(t, examples.TccFireRequestNested())
// assertSucceed(t, examples.XaFireRequest())
}

View File

@ -10,7 +10,7 @@ import (
// RunSQLScript 1
func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
con := common.DbAlone(conf)
con := common.SdbAlone(conf)
defer func() { con.Close() }()
content, err := ioutil.ReadFile(script)
e2p(err)
@ -20,7 +20,7 @@ func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
if s == "" || skipDrop && strings.Contains(s, "drop") {
continue
}
_, err = common.DbExec(con, s)
_, err = common.SdbExec(con, s)
e2p(err)
}
}

View File

@ -7,7 +7,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
"gorm.io/gorm"
)
// SagaBarrierFireRequest 1
@ -32,10 +31,9 @@ func SagaBarrierAddRoute(app *gin.Engine) {
logrus.Printf("examples listening at %d", BusiPort)
}
func sagaBarrierAdjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) {
db := common.SQLDB2DB(sdb)
dbr := db.Model(&UserAccount{}).Where("user_id = ?", uid).Update("balance", gorm.Expr("balance + ?", amount))
return common.MS{"dtm_result": "SUCCESS"}, dbr.Error
func sagaBarrierAdjustBalance(sdb *sql.Tx, uid int, amount int) (interface{}, error) {
_, err := common.StxExec(sdb, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
return common.MS{"dtm_result": "SUCCESS"}, err
}
@ -44,13 +42,13 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" {
return req.TransInResult, nil
}
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(sdbGet(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 1, req.Amount)
})
}
func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(sdbGet(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount)
})
}
@ -60,13 +58,13 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" {
return req.TransInResult, nil
}
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(sdbGet(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 2, -req.Amount)
})
}
func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(sdbGet(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount)
})
}

View File

@ -40,28 +40,23 @@ func TccBarrierAddRoute(app *gin.Engine) {
const transInUID = 1
const transOutUID = 2
func adjustTrading(sdb *sql.DB, uid int, amount int) (interface{}, error) {
db := common.SQLDB2DB(sdb)
dbr := db.Exec("update dtm_busi.user_account_trading t join dtm_busi.user_account a on t.user_id=a.user_id and t.user_id=? set t.trading_balance=t.trading_balance + ? where a.balance + t.trading_balance + ? >= 0", uid, amount, amount)
if dbr.Error == nil && dbr.RowsAffected == 0 {
func adjustTrading(sdb *sql.Tx, uid int, amount int) (interface{}, error) {
affected, err := common.StxExec(sdb, "update dtm_busi.user_account_trading set trading_balance=trading_balance + ? where user_id=? and trading_balance + ? + (select balance from dtm_busi.user_account where id=?) >= 0", amount, uid, amount, uid)
if err == nil && affected == 0 {
return nil, fmt.Errorf("update error, maybe balance not enough")
}
return common.MS{"dtm_server": "SUCCESS"}, nil
}
func adjustBalance(sdb *sql.DB, uid int, amount int) (interface{}, error) {
db := common.SQLDB2DB(sdb)
dbr := db.Exec("update dtm_busi.user_account_trading t join dtm_busi.user_account a on t.user_id=a.user_id and t.user_id=? set t.trading_balance=t.trading_balance + ?", uid, -amount, -amount)
if dbr.Error == nil && dbr.RowsAffected == 1 {
dbr = db.Exec("update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid)
func adjustBalance(sdb *sql.Tx, uid int, amount int) (interface{}, error) {
affected, err := common.StxExec(sdb, "update dtm_busi.user_account_trading set trading_balance = trading_balance + ? where user_id=?;", -amount, uid)
if err == nil && affected == 1 {
affected, err = common.StxExec(sdb, "update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid)
}
if dbr.Error != nil {
return nil, dbr.Error
}
if dbr.RowsAffected == 0 {
if err == nil && affected == 0 {
return nil, fmt.Errorf("update 0 rows")
}
return common.MS{"dtm_result": "SUCCESS"}, nil
return common.MS{"dtm_result": "SUCCESS"}, err
}
// TCC下转入
@ -70,19 +65,19 @@ func tccBarrierTransInTry(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" {
return req.TransInResult, nil
}
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transInUID, req.Amount)
})
}
func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
return adjustBalance(sdb, transInUID, reqFrom(c).Amount)
})
}
func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transInUID, -reqFrom(c).Amount)
})
}
@ -92,20 +87,20 @@ func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" {
return req.TransInResult, nil
}
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transOutUID, -req.Amount)
})
}
func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
return adjustBalance(sdb, transOutUID, -reqFrom(c).Amount)
})
}
// TccBarrierTransOutCancel will be use in test
func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) {
return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.Tx) (interface{}, error) {
return adjustTrading(sdb, transOutUID, reqFrom(c).Amount)
})
}

View File

@ -37,6 +37,10 @@ func dbGet() *common.DB {
return common.DbGet(config.DB)
}
func sdbGet() *sql.DB {
return common.SdbGet(config.DB)
}
// XaSetup 挂载http的api创建XaClient
func XaSetup(app *gin.Engine) {
app.POST(BusiAPI+"/TransInXa", common.WrapHandler(xaTransIn))
@ -68,7 +72,7 @@ func xaTransIn(c *gin.Context) (interface{}, error) {
if req.TransInResult == "FAILURE" {
return fmt.Errorf("tranIn FAILURE")
}
_, rerr = common.DbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2)
_, rerr = common.SdbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2)
return
})
if err != nil && strings.Contains(err.Error(), "FAILURE") {
@ -84,7 +88,7 @@ func xaTransOut(c *gin.Context) (interface{}, error) {
if req.TransOutResult == "FAILURE" {
return fmt.Errorf("tranOut failed")
}
_, rerr = common.DbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1)
_, rerr = common.SdbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1)
return
})
e2p(err)