db adapter ok, add gorm example

This commit is contained in:
yedf2 2021-08-21 04:06:59 +08:00
parent 94df49b7cb
commit eff856d392
19 changed files with 180 additions and 96 deletions

View File

@ -22,9 +22,9 @@ func TestDb(t *testing.T) {
func TestDbAlone(t *testing.T) {
db, err := dtmcli.SdbAlone(DtmConfig.DB)
assert.Nil(t, err)
_, err = dtmcli.SdbExec(db, "select 1")
_, err = dtmcli.DBExec(db, "select 1")
assert.Equal(t, nil, err)
db.Close()
_, err = dtmcli.SdbExec(db, "select 1")
_, err = dtmcli.DBExec(db, "select 1")
assert.NotEqual(t, nil, err)
}

View File

@ -17,8 +17,8 @@ Lets take a look at a sequence diagram of network abnormalities to better und
![abnormal network image](https://pic2.zhimg.com/80/v2-04c577b69ab7145ab493a8158a048a08_1440w.png)
- When business processing step 4, Cancel is executed before Try, and empty rollback needs to be processed.
- When business processing step 6, Cancel is executed repeatedly and needs to be idempotent.
- When business processing step 4, Cancel is executed before Try, and empty rollback needs to be processed.
- When business processing step 6, Cancel is executed repeatedly and needs to be idempotent.
- When business processing step 8, Try is executed after Cancel and needs to be processed.
For the above-mentioned complex network abnormalities, all distributed transaction systems currently recommend that business developers use unique keys to query whether the associated operations have been completed. If it is completed, it returns directly to success. The related logic is complex, error-prone, and the business burden is heavy.
@ -72,19 +72,21 @@ Under this mechanism, problems related to network abnormalities are solved
- Idempotent control: No single key can be inserted repeatedly in any branch, which ensures that it will not be executed repeatedly
- Dangling action control: Try is to be executed after Cancel, then the inserted gid-branchid-try is not successful, it will not be executed
Let's take a look at the example in main_tcc_barrier.go in dtm:
Let's take a look at the example in http_tcc_barrier.go in dtm:
``` GO
func tccBarrierTransInTry(c *gin.Context) (interface{}, error) {
return dtmcli.ThroughBarrierCall(db, dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) {
return adjustTrading(sdb, transInUid, reqFrom(c).Amount)
})
req := reqFrom(c) // 去重构一下,改成可以重复使用的输入
barrier := MustBarrierFromGin(c)
return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return adjustTrading(db, transInUID, req.Amount)
})
}
```
The Try in the TransIn business only needs one ThroughBarrierCall call to handle the above abnormal situation, which greatly simplifies the work of business developers. For SAGA transactions, reliable messages, etc., a similar mechanism can also be used.
The Try in the TransIn business only needs one barrier.Call call to handle the above abnormal situation, which greatly simplifies the work of business developers. For SAGA transactions, reliable messages, etc., a similar mechanism can also be used.
## summary
The sub-transaction barrier technology proposed in this project systematically solves the problem of network disorder in distributed transactions and greatly reduces the difficulty of sub-transaction disorder processing.
The sub-transaction barrier technology proposed in this project systematically solves the problem of network disorder in distributed transactions and greatly reduces the difficulty of sub-transaction disorder processing.
Other development languages can also quickly access the technology

18
dtmcli/adapter.go Normal file
View File

@ -0,0 +1,18 @@
package dtmcli
import (
"database/sql"
)
// DB inteface of dtmcli db
type DB interface {
Exec(query string, args ...interface{}) (sql.Result, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
// Tx interface of dtmcli tx
type Tx interface {
Rollback() error
Commit() error
DB
}

View File

@ -1,14 +1,13 @@
package dtmcli
import (
"context"
"database/sql"
"fmt"
"net/url"
)
// BusiFunc type for busi func
type BusiFunc func(db *sql.Tx) error
type BusiFunc func(db DB) error
// BranchBarrier every branch info
type BranchBarrier struct {
@ -42,11 +41,11 @@ func BarrierFrom(transType, gid, branchID, branchType string) (*BranchBarrier, e
return ti, nil
}
func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, barrierID string, reason string) (int64, error) {
func insertBarrier(tx Tx, transType string, gid string, branchID string, branchType string, barrierID string, reason string) (int64, error) {
if branchType == "" {
return 0, nil
}
return StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)", transType, gid, branchID, branchType, barrierID, reason)
return DBExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)", transType, gid, branchID, branchType, barrierID, reason)
}
// Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
@ -56,10 +55,9 @@ func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, br
// 返回值:
// 如果发生悬挂则busiCall不会被调用直接返回错误 ErrFailure全局事务尽早进行回滚
// 如果正常调用重复调用空补偿返回的错误值为nil正常往下进行
func (bb *BranchBarrier) Call(db *sql.DB, busiCall BusiFunc) (rerr error) {
func (bb *BranchBarrier) Call(tx Tx, busiCall BusiFunc) (rerr error) {
bb.BarrierID = bb.BarrierID + 1
bid := fmt.Sprintf("%02d", bb.BarrierID)
tx, rerr := db.BeginTx(context.Background(), &sql.TxOptions{})
if rerr != nil {
return
}
@ -86,7 +84,7 @@ func (bb *BranchBarrier) Call(db *sql.DB, busiCall BusiFunc) (rerr error) {
return
} else if currentAffected == 0 { // 插入不成功
var result sql.NullString
err := StxQueryRow(tx, "select 1 from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and barrier_id=? and reason=?",
err := DBQueryRow(tx, "select 1 from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and barrier_id=? and reason=?",
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, bid, ti.BranchType).Scan(&result)
if err == sql.ErrNoRows { // 不是当前分支插入的那么是cancel插入的因此是悬挂操作返回失败AP收到这个返回会尽快回滚
rerr = ErrFailure

View File

@ -202,8 +202,8 @@ func SdbAlone(conf map[string]string) (*sql.DB, error) {
return sql.Open(conf["driver"], dsn)
}
// SdbExec use raw db to exec
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
// DBExec use raw db to exec
func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := db.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
@ -214,22 +214,10 @@ func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rer
return
}
// StxExec use raw tx to exec
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := tx.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxQueryRow use raw tx to query row
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
// DBQueryRow use raw tx to query row
func DBQueryRow(db DB, query string, args ...interface{}) *sql.Row {
Logf("querying: "+query, args...)
return tx.QueryRow(query, args...)
return db.QueryRow(query, args...)
}
// GetDsn get dsn from map config

View File

@ -61,7 +61,7 @@ func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (
}
defer db.Close()
xaID := gid + "-" + branchID
_, err = SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
_, err = DBExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
return ResultSuccess, err
}
@ -82,9 +82,9 @@ func (xc *XaClient) XaLocalTransaction(qs url.Values, xaFunc XaLocalFunc) (ret i
defer func() { db.Close() }()
defer func() {
x := recover()
_, err := SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
_, err := DBExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
if x == nil && rerr == nil && err == nil {
_, err = SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
_, err = DBExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
}
if rerr == nil {
rerr = err
@ -93,7 +93,7 @@ func (xc *XaClient) XaLocalTransaction(qs url.Values, xaFunc XaLocalFunc) (ret i
panic(x)
}
}()
_, rerr = SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
_, rerr = DBExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
if rerr != nil {
return
}

View File

@ -1,8 +1,6 @@
package dtmgrpc
import (
"database/sql"
"github.com/yedf/dtm/dtmcli"
"google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
@ -20,8 +18,8 @@ type BranchBarrier struct {
// 返回值:
// 如果发生悬挂则busiCall不会被调用直接返回错误 ErrFailure全局事务尽早进行回滚
// 如果正常调用重复调用空补偿返回的错误值为nil正常往下进行
func (bb *BranchBarrier) Call(db *sql.DB, busiCall dtmcli.BusiFunc) (rerr error) {
err := bb.BranchBarrier.Call(db, busiCall)
func (bb *BranchBarrier) Call(tx dtmcli.Tx, busiCall dtmcli.BusiFunc) (rerr error) {
err := bb.BranchBarrier.Call(tx, busiCall)
if err == dtmcli.ErrFailure {
return status.New(codes.Aborted, "user rollback").Err()
}

View File

@ -55,7 +55,7 @@ func (xc *XaGrpcClient) HandleCallback(gid string, branchID string, action strin
}
defer db.Close()
xaID := gid + "-" + branchID
_, err = dtmcli.SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
_, err = dtmcli.DBExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
return err
}
@ -76,9 +76,9 @@ func (xc *XaGrpcClient) XaLocalTransaction(br *BusiRequest, xaFunc XaGrpcLocalFu
defer func() { db.Close() }()
defer func() {
x := recover()
_, err := dtmcli.SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
_, err := dtmcli.DBExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
if x == nil && rerr == nil && err == nil {
_, err = dtmcli.SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
_, err = dtmcli.DBExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
}
if rerr == nil {
rerr = err
@ -87,7 +87,7 @@ func (xc *XaGrpcClient) XaLocalTransaction(br *BusiRequest, xaFunc XaGrpcLocalFu
panic(x)
}
}()
_, rerr = dtmcli.SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
_, rerr = dtmcli.DBExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
if rerr != nil {
return
}

View File

@ -116,7 +116,7 @@ func (s *busiServer) TransInXa(ctx context.Context, in *dtmgrpc.BusiRequest) (*d
if req.TransInResult == "FAILURE" {
return status.New(codes.Aborted, "user return failure").Err()
}
_, err := dtmcli.SdbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2)
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2)
return err
})
}
@ -128,7 +128,7 @@ func (s *busiServer) TransOutXa(ctx context.Context, in *dtmgrpc.BusiRequest) (*
if req.TransOutResult == "FAILURE" {
return status.New(codes.Aborted, "user return failure").Err()
}
_, err := dtmcli.SdbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1)
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1)
return err
})
}

View File

@ -8,6 +8,8 @@ import (
"github.com/gin-gonic/gin"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
const (
@ -108,7 +110,7 @@ func BaseAddRoute(app *gin.Engine) {
if reqFrom(c).TransInResult == "FAILURE" {
return dtmcli.ResultFailure, nil
}
_, err := dtmcli.SdbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", reqFrom(c).Amount, 2)
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", reqFrom(c).Amount, 2)
return dtmcli.ResultSuccess, err
})
}))
@ -117,9 +119,25 @@ func BaseAddRoute(app *gin.Engine) {
if reqFrom(c).TransOutResult == "FAILURE" {
return dtmcli.ResultFailure, nil
}
_, err := dtmcli.SdbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, 1)
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, 1)
return dtmcli.ResultSuccess, err
})
}))
app.POST(BusiAPI+"/TransOutXaGorm", common.WrapHandler(func(c *gin.Context) (interface{}, error) {
return XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) (interface{}, error) {
if reqFrom(c).TransOutResult == "FAILURE" {
return dtmcli.ResultFailure, nil
}
gdb, err := gorm.Open(mysql.New(mysql.Config{
Conn: db,
}), &gorm.Config{})
if err != nil {
return nil, err
}
dbr := gdb.Exec("update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, 1)
return dtmcli.ResultSuccess, dbr.Error
})
}))
}

View File

@ -71,6 +71,13 @@ func sdbGet() *sql.DB {
return db
}
func txGet() *sql.Tx {
db := sdbGet()
tx, err := db.Begin()
dtmcli.FatalIfError(err)
return tx
}
// MustBarrierFromGin 1
func MustBarrierFromGin(c *gin.Context) *dtmcli.BranchBarrier {
ti, err := dtmcli.BarrierFromQuery(c.Request.URL.Query())

View File

@ -24,7 +24,7 @@ func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
if s == "" || skipDrop && strings.Contains(s, "drop") {
continue
}
_, err = dtmcli.SdbExec(con, s)
_, err = dtmcli.DBExec(con, s)
dtmcli.FatalIfError(err)
}
}

View File

@ -2,7 +2,6 @@ package examples
import (
"context"
"database/sql"
"github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/dtmgrpc"
@ -24,11 +23,11 @@ func init() {
})
}
func sagaGrpcBarrierAdjustBalance(sdb *sql.Tx, uid int, amount int, result string) error {
func sagaGrpcBarrierAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error {
if result == "FAILURE" {
return status.New(codes.Aborted, "user rollback").Err()
}
_, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
return err
}
@ -37,8 +36,8 @@ func (s *busiServer) TransInBSaga(ctx context.Context, in *dtmgrpc.BusiRequest)
req := TransReq{}
dtmcli.MustUnmarshal(in.BusiData, &req)
barrier := MustBarrierFromGrpc(in)
return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return sagaGrpcBarrierAdjustBalance(sdb, 2, req.Amount, req.TransInResult)
return &emptypb.Empty{}, barrier.Call(txGet(), func(tx dtmcli.DB) error {
return sagaGrpcBarrierAdjustBalance(tx, 2, req.Amount, req.TransInResult)
})
}
@ -46,8 +45,8 @@ func (s *busiServer) TransOutBSaga(ctx context.Context, in *dtmgrpc.BusiRequest)
req := TransReq{}
dtmcli.MustUnmarshal(in.BusiData, &req)
barrier := MustBarrierFromGrpc(in)
return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return sagaGrpcBarrierAdjustBalance(sdb, 1, -req.Amount, req.TransOutResult)
return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaGrpcBarrierAdjustBalance(db, 1, -req.Amount, req.TransOutResult)
})
}
@ -55,8 +54,8 @@ func (s *busiServer) TransInRevertBSaga(ctx context.Context, in *dtmgrpc.BusiReq
req := TransReq{}
dtmcli.MustUnmarshal(in.BusiData, &req)
barrier := MustBarrierFromGrpc(in)
return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return sagaGrpcBarrierAdjustBalance(sdb, 2, -req.Amount, "")
return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaGrpcBarrierAdjustBalance(db, 2, -req.Amount, "")
})
}
@ -64,7 +63,7 @@ func (s *busiServer) TransOutRevertBSaga(ctx context.Context, in *dtmgrpc.BusiRe
req := TransReq{}
dtmcli.MustUnmarshal(in.BusiData, &req)
barrier := MustBarrierFromGrpc(in)
return &emptypb.Empty{}, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return sagaGrpcBarrierAdjustBalance(sdb, 1, req.Amount, "")
return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaGrpcBarrierAdjustBalance(db, 1, req.Amount, "")
})
}

22
examples/http_gorm_xa.go Normal file
View File

@ -0,0 +1,22 @@
package examples
import (
"github.com/go-resty/resty/v2"
"github.com/yedf/dtm/dtmcli"
)
func init() {
addSample("xa_gorm", func() string {
gid := dtmcli.MustGenGid(DtmServer)
err := XaClient.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (*resty.Response, error) {
resp, err := xa.CallBranch(&TransReq{Amount: 30}, Busi+"/TransOutXaGorm")
if err != nil {
return resp, err
}
return xa.CallBranch(&TransReq{Amount: 30}, Busi+"/TransInXa")
})
dtmcli.FatalIfError(err)
return gid
})
}

View File

@ -1,8 +1,6 @@
package examples
import (
"database/sql"
"github.com/gin-gonic/gin"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
@ -28,8 +26,8 @@ func init() {
})
}
func sagaBarrierAdjustBalance(sdb *sql.Tx, uid int, amount int) error {
_, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
func sagaBarrierAdjustBalance(db dtmcli.DB, uid int, amount int) error {
_, err := dtmcli.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
return err
}
@ -40,15 +38,15 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) {
return req.TransInResult, nil
}
barrier := MustBarrierFromGin(c)
return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return sagaBarrierAdjustBalance(sdb, 1, req.Amount)
return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaBarrierAdjustBalance(db, 1, req.Amount)
})
}
func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount)
return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaBarrierAdjustBalance(db, 1, -reqFrom(c).Amount)
})
}
@ -58,14 +56,14 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) {
return req.TransInResult, nil
}
barrier := MustBarrierFromGin(c)
return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return sagaBarrierAdjustBalance(sdb, 2, -req.Amount)
return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaBarrierAdjustBalance(db, 2, -req.Amount)
})
}
func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount)
return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaBarrierAdjustBalance(db, 2, reqFrom(c).Amount)
})
}

View File

@ -0,0 +1,34 @@
package examples
import (
"github.com/gin-gonic/gin"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
)
func init() {
setupFuncs["SagaGormBarrierSetup"] = func(app *gin.Engine) {
app.POST(BusiAPI+"/SagaBTransOutGorm", common.WrapHandler(sagaGormBarrierTransOut))
}
addSample("saga_gorm_barrier", func() string {
dtmcli.Logf("a busi transaction begin")
req := &TransReq{Amount: 30}
saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)).
Add(Busi+"/SagaBTransOutGorm", Busi+"/SagaBTransOutCompensate", req).
Add(Busi+"/SagaBTransIn", Busi+"/SagaBTransInCompensate", req)
dtmcli.Logf("busi trans submit")
err := saga.Submit()
dtmcli.FatalIfError(err)
return saga.Gid
})
}
func sagaGormBarrierTransOut(c *gin.Context) (interface{}, error) {
req := reqFrom(c)
barrier := MustBarrierFromGin(c)
tx := dbGet().DB.Begin()
return dtmcli.ResultSuccess, barrier.Call(tx.Statement.ConnPool.(dtmcli.Tx), func(db dtmcli.DB) error {
return tx.Exec("update dtm_busi.user_account set balance = balance + ? where user_id = ?", -req.Amount, 2).Error
})
}

View File

@ -1,7 +1,6 @@
package examples
import (
"database/sql"
"fmt"
"github.com/gin-gonic/gin"
@ -37,18 +36,18 @@ func init() {
const transInUID = 1
const transOutUID = 2
func adjustTrading(sdb *sql.Tx, uid int, amount int) error {
affected, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account_trading set trading_balance=trading_balance + ? where user_id=? and trading_balance + ? + (select balance from dtm_busi.user_account where id=?) >= 0", amount, uid, amount, uid)
func adjustTrading(db dtmcli.DB, uid int, amount int) error {
affected, err := dtmcli.DBExec(db, "update dtm_busi.user_account_trading set trading_balance=trading_balance + ? where user_id=? and trading_balance + ? + (select balance from dtm_busi.user_account where id=?) >= 0", amount, uid, amount, uid)
if err == nil && affected == 0 {
return fmt.Errorf("update error, maybe balance not enough")
}
return err
}
func adjustBalance(sdb *sql.Tx, uid int, amount int) error {
affected, err := dtmcli.StxExec(sdb, "update dtm_busi.user_account_trading set trading_balance = trading_balance + ? where user_id=?;", -amount, uid)
func adjustBalance(db dtmcli.DB, uid int, amount int) error {
affected, err := dtmcli.DBExec(db, "update dtm_busi.user_account_trading set trading_balance = trading_balance + ? where user_id=?;", -amount, uid)
if err == nil && affected == 1 {
affected, err = dtmcli.StxExec(sdb, "update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid)
affected, err = dtmcli.DBExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", amount, uid)
}
if err == nil && affected == 0 {
return fmt.Errorf("update 0 rows")
@ -63,22 +62,22 @@ func tccBarrierTransInTry(c *gin.Context) (interface{}, error) {
return req.TransInResult, nil
}
barrier := MustBarrierFromGin(c)
return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return adjustTrading(sdb, transInUID, req.Amount)
return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return adjustTrading(db, transInUID, req.Amount)
})
}
func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return adjustBalance(sdb, transInUID, reqFrom(c).Amount)
return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return adjustBalance(db, transInUID, reqFrom(c).Amount)
})
}
func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return adjustTrading(sdb, transInUID, -reqFrom(c).Amount)
return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return adjustTrading(db, transInUID, -reqFrom(c).Amount)
})
}
@ -88,22 +87,22 @@ func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) {
return req.TransInResult, nil
}
barrier := MustBarrierFromGin(c)
return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return adjustTrading(sdb, transOutUID, -req.Amount)
return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return adjustTrading(db, transOutUID, -req.Amount)
})
}
func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return adjustBalance(sdb, transOutUID, -reqFrom(c).Amount)
return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return adjustBalance(db, transOutUID, -reqFrom(c).Amount)
})
}
// TccBarrierTransOutCancel will be use in test
func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.ResultSuccess, barrier.Call(sdbGet(), func(sdb *sql.Tx) error {
return adjustTrading(sdb, transOutUID, reqFrom(c).Amount)
return dtmcli.ResultSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return adjustTrading(db, transOutUID, reqFrom(c).Amount)
})
}

View File

@ -42,7 +42,7 @@ func QsFireRequest() string {
}
func qsAdjustBalance(uid int, amount int) (interface{}, error) {
_, err := dtmcli.SdbExec(sdbGet(), "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
_, err := dtmcli.DBExec(sdbGet(), "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid)
return dtmcli.ResultSuccess, err
}

View File

@ -1,7 +1,6 @@
package test
import (
"database/sql"
"fmt"
"testing"
@ -118,7 +117,9 @@ func TestSqlDB(t *testing.T) {
BranchType: "action",
}
db.Must().Exec("insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')")
err := barrier.Call(db.ToSQLDB(), func(db *sql.Tx) error {
tx, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx, func(db dtmcli.DB) error {
dtmcli.Logf("rollback gid2")
return fmt.Errorf("gid2 error")
})
@ -128,7 +129,9 @@ func TestSqlDB(t *testing.T) {
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(0))
barrier.BarrierID = 0
err = barrier.Call(db.ToSQLDB(), func(db *sql.Tx) error {
tx2, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx2, func(db dtmcli.DB) error {
dtmcli.Logf("submit gid2")
return nil
})