partial dtmcli refactor

This commit is contained in:
yedf2 2021-08-04 11:19:06 +08:00
parent 9e8359c893
commit fe65ff96e4
15 changed files with 326 additions and 331 deletions

View File

@ -4,7 +4,7 @@ import (
"os"
"time"
"github.com/sirupsen/logrus"
"github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/dtmsvr"
"github.com/yedf/dtm/examples"
)
@ -68,7 +68,7 @@ func main() {
examples.TccBarrierAddRoute(app)
examples.TccBarrierFireRequest()
} else {
logrus.Fatalf("unknown arg: %s", os.Args[1])
dtmcli.LogRedf("unknown arg: %s", os.Args[1])
}
wait()
}

View File

@ -7,6 +7,8 @@ import (
"time"
_ "github.com/go-sql-driver/mysql"
"github.com/yedf/dtm/dtmcli"
// _ "github.com/lib/pq"
"gorm.io/driver/mysql"
@ -15,12 +17,6 @@ import (
"gorm.io/gorm"
)
// M a short name
type M = map[string]interface{}
// MS a short name
type MS = map[string]string
// ModelBase model base for gorm to provide base fields
type ModelBase struct {
ID uint
@ -38,7 +34,6 @@ func getGormDialator(driver string, dsn string) gorm.Dialector {
}
var dbs = map[string]*DB{}
var sqlDbs = map[string]*sql.DB{}
// DB provide more func over gorm.DB
type DB struct {
@ -60,7 +55,7 @@ func (m *DB) NoMust() *DB {
// ToSQLDB get the sql.DB
func (m *DB) ToSQLDB() *sql.DB {
d, err := m.DB.DB()
E2P(err)
dtmcli.E2P(err)
return d
}
@ -78,7 +73,7 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) {
after := func(db *gorm.DB) {
_ts, _ := db.InstanceGet("ivy.startTime")
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
Logf("used: %d ms affected: %d sql is: %s", time.Since(_ts.(time.Time)).Milliseconds(), db.RowsAffected, sql)
dtmcli.Logf("used: %d ms affected: %d sql is: %s", time.Since(_ts.(time.Time)).Milliseconds(), db.RowsAffected, sql)
if v, ok := db.InstanceGet("ivy.must"); ok && v.(bool) {
if db.Error != nil && db.Error != gorm.ErrRecordNotFound {
panic(db.Error)
@ -89,7 +84,7 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) {
beforeName := "cb_before"
afterName := "cb_after"
Logf("installing db plugin: %s", op.Name())
dtmcli.Logf("installing db plugin: %s", op.Name())
// 开始前
_ = db.Callback().Create().Before("gorm:before_create").Register(beforeName, before)
_ = db.Callback().Query().Before("gorm:query").Register(beforeName, before)
@ -108,79 +103,17 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) {
return
}
// GetDsn get dsn from map config
func GetDsn(conf map[string]string) string {
conf["host"] = MayReplaceLocalhost(conf["host"])
driver := conf["driver"]
dsn := MS{
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
}[driver]
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
return dsn
}
// DbGet get db connection for specified conf
func DbGet(conf map[string]string) *DB {
dsn := GetDsn(conf)
dsn := dtmcli.GetDsn(conf)
if dbs[dsn] == nil {
Logf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1))
dtmcli.Logf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1))
db1, err := gorm.Open(getGormDialator(conf["driver"], dsn), &gorm.Config{
SkipDefaultTransaction: true,
})
E2P(err)
dtmcli.E2P(err)
db1.Use(&tracePlugin{})
dbs[dsn] = &DB{DB: db1}
}
return dbs[dsn]
}
// SdbGet get pooled sql.DB
func SdbGet(conf map[string]string) *sql.DB {
dsn := GetDsn(conf)
if sqlDbs[dsn] == nil {
sqlDbs[dsn] = SdbAlone(conf)
}
return sqlDbs[dsn]
}
// SdbAlone get a standalone db connection
func SdbAlone(conf map[string]string) *sql.DB {
dsn := GetDsn(conf)
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
mdb, err := sql.Open(conf["driver"], dsn)
E2P(err)
return mdb
}
// SdbExec use raw db to exec
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := db.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxExec use raw tx to exec
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := tx.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxQueryRow use raw tx to query row
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
Logf("querying: "+query, args...)
return tx.QueryRow(query, args...)
}

View File

@ -4,6 +4,7 @@ import (
"testing"
"github.com/go-playground/assert/v2"
"github.com/yedf/dtm/dtmcli"
)
type testConfig struct {
@ -13,14 +14,14 @@ type testConfig struct {
var config = testConfig{}
func init() {
InitConfig(GetProjectDir(), &config)
InitConfig(dtmcli.GetProjectDir(), &config)
config.DB["database"] = ""
}
func TestDb(t *testing.T) {
db := DbGet(config.DB)
err := func() (rerr error) {
defer P2E(&rerr)
defer dtmcli.P2E(&rerr)
dbr := db.NoMust().Exec("select a")
assert.NotEqual(t, nil, dbr.Error)
db.Must().Exec("select a")
@ -30,10 +31,10 @@ func TestDb(t *testing.T) {
}
func TestDbAlone(t *testing.T) {
db := SdbAlone(config.DB)
_, err := SdbExec(db, "select 1")
db := dtmcli.SdbAlone(config.DB)
_, err := dtmcli.SdbExec(db, "select 1")
assert.Equal(t, nil, err)
db.Close()
_, err = SdbExec(db, "select 1")
_, err = dtmcli.SdbExec(db, "select 1")
assert.NotEqual(t, nil, err)
}

View File

@ -3,113 +3,14 @@ package common
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
yaml "gopkg.in/yaml.v2"
)
// P2E panic to error
func P2E(perr *error) {
if x := recover(); x != nil {
if e, ok := x.(error); ok {
*perr = e
} else {
panic(x)
}
}
}
// E2P error to panic
func E2P(err error) {
if err != nil {
panic(err)
}
}
// CatchP catch panic to error
func CatchP(f func()) (rerr error) {
defer P2E(&rerr)
f()
return nil
}
// PanicIf name is clear
func PanicIf(cond bool, err error) {
if cond {
panic(err)
}
}
// MustAtoi 走must逻辑
func MustAtoi(s string) int {
r, err := strconv.Atoi(s)
if err != nil {
E2P(errors.New("convert to int error: " + s))
}
return r
}
// OrString return the first not empty string
func OrString(ss ...string) string {
for _, s := range ss {
if s != "" {
return s
}
}
return ""
}
// If ternary operator
func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} {
if condition {
return trueObj
}
return falseObj
}
// MustMarshal checked version for marshal
func MustMarshal(v interface{}) []byte {
b, err := json.Marshal(v)
E2P(err)
return b
}
// MustMarshalString string version of MustMarshal
func MustMarshalString(v interface{}) string {
return string(MustMarshal(v))
}
// MustUnmarshal checked version for unmarshal
func MustUnmarshal(b []byte, obj interface{}) {
err := json.Unmarshal(b, obj)
E2P(err)
}
// MustUnmarshalString string version of MustUnmarshal
func MustUnmarshalString(s string, obj interface{}) {
MustUnmarshal([]byte(s), obj)
}
// MustRemarshal marshal and unmarshal, and check error
func MustRemarshal(from interface{}, to interface{}) {
b, err := json.Marshal(from)
E2P(err)
err = json.Unmarshal(b, to)
E2P(err)
}
// GetGinApp init and return gin
func GetGinApp() *gin.Engine {
gin.SetMode(gin.ReleaseMode)
@ -157,54 +58,6 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
}
}
// RestyClient the resty object
var RestyClient = resty.New()
func init() {
// RestyClient.SetTimeout(3 * time.Second)
// RestyClient.SetRetryCount(2)
// RestyClient.SetRetryWaitTime(1 * time.Second)
RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error {
r.URL = MayReplaceLocalhost(r.URL)
Logf("requesting: %s %s %v %v", r.Method, r.URL, r.Body, r.QueryParam)
return nil
})
RestyClient.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error {
r := resp.Request
Logf("requested: %s %s %s", r.Method, r.URL, resp.String())
return nil
})
}
// CheckRestySuccess panic if error or resp not success
func CheckRestySuccess(resp *resty.Response, err error) {
E2P(err)
if !strings.Contains(resp.String(), "SUCCESS") {
panic(fmt.Errorf("resty response not success: %s", resp.String()))
}
}
// Logf 输出日志
func Logf(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
n := time.Now()
ts := fmt.Sprintf("%d-%02d-%02d %02d:%02d:%02d.%03d", n.Year(), n.Month(), n.Day(), n.Hour(), n.Minute(), n.Second(), n.Nanosecond()/1000000)
var file string
var line int
for i := 1; ; i++ {
_, file, line, _ = runtime.Caller(i)
if strings.Contains(file, "dtm") {
break
}
}
fmt.Printf("%s %s:%d %s\n", ts, path.Base(file), line, msg)
}
// LogRedf 采用红色打印错误类信息
func LogRedf(fmt string, args ...interface{}) {
logrus.Errorf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
}
// InitConfig init config
func InitConfig(dir string, config interface{}) {
cont, err := ioutil.ReadFile(dir + "/conf.yml")
@ -216,38 +69,3 @@ func InitConfig(dir string, config interface{}) {
err = yaml.Unmarshal(cont, config)
E2P(err)
}
// MustGetwd must version of os.Getwd
func MustGetwd() string {
wd, err := os.Getwd()
E2P(err)
return wd
}
// GetCurrentCodeDir name is clear
func GetCurrentCodeDir() string {
_, file, _, _ := runtime.Caller(1)
return filepath.Dir(file)
}
// GetProjectDir name is clear
func GetProjectDir() string {
_, file, _, _ := runtime.Caller(1)
for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) {
}
return file
}
// GetFuncName get current call func name
func GetFuncName() string {
pc, _, _, _ := runtime.Caller(1)
return runtime.FuncForPC(pc).Name()
}
// MayReplaceLocalhost when run in docker compose, change localhost to host.docker.internal for accessing host network
func MayReplaceLocalhost(host string) string {
if os.Getenv("IS_DOCKER_COMPOSE") != "" {
return strings.Replace(host, "localhost", "host.docker.internal", 1)
}
return host
}

View File

@ -8,7 +8,6 @@ import (
"net/url"
"github.com/gin-gonic/gin"
"github.com/yedf/dtm/common"
)
// BusiFunc type for busi func
@ -47,20 +46,11 @@ func TransInfoFromQuery(qs url.Values) (*TransInfo, error) {
return ti, nil
}
// BarrierModel barrier model for gorm
type BarrierModel struct {
common.ModelBase
TransInfo
}
// TableName gorm table name
func (BarrierModel) TableName() string { return "dtm_barrier.barrier" }
func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, reason string) (int64, error) {
if branchType == "" {
return 0, nil
}
return common.StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason)
return StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason)
}
// ThroughBarrierCall 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
@ -78,7 +68,7 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
return
}
defer func() {
common.Logf("result is %v error is %v", res, rerr)
Logf("result is %v error is %v", res, rerr)
if x := recover(); x != nil {
tx.Rollback()
panic(x)
@ -95,13 +85,13 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
}[ti.BranchType]
originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, ti.BranchType)
currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType)
common.Logf("originAffected: %d currentAffected: %d", originAffected, currentAffected)
Logf("originAffected: %d currentAffected: %d", originAffected, currentAffected)
if (ti.BranchType == "cancel" || ti.BranchType == "compensate") && originAffected > 0 { // 这个是空补偿,返回成功
res = ResultSuccess
return
} else if currentAffected == 0 { // 插入不成功
var result sql.NullString
err := common.StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?",
err := StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?",
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType).Scan(&result)
if err == sql.ErrNoRows { // 这个是悬挂操作返回失败AP收到这个返回会尽快回滚
res = ResultFailure
@ -121,8 +111,8 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
}
res, rerr = busiCall(tx)
if rerr == nil { // 正确返回了,需要将结果保存到数据库
sval := common.MustMarshalString(res)
_, rerr = common.StxExec(tx, "update dtm_barrier.barrier set result=? where trans_type=? and gid=? and branch_id=? and branch_type=?", sval,
sval := MustMarshalString(res)
_, rerr = StxExec(tx, "update dtm_barrier.barrier set result=? where trans_type=? and gid=? and branch_id=? and branch_type=?", sval,
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType)
}
return

189
dtmcli/common.go Normal file
View File

@ -0,0 +1,189 @@
package dtmcli
import (
"encoding/json"
"errors"
"fmt"
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
"github.com/go-resty/resty/v2"
)
// P2E panic to error
func P2E(perr *error) {
if x := recover(); x != nil {
if e, ok := x.(error); ok {
*perr = e
} else {
panic(x)
}
}
}
// E2P error to panic
func E2P(err error) {
if err != nil {
panic(err)
}
}
// CatchP catch panic to error
func CatchP(f func()) (rerr error) {
defer P2E(&rerr)
f()
return nil
}
// PanicIf name is clear
func PanicIf(cond bool, err error) {
if cond {
panic(err)
}
}
// MustAtoi 走must逻辑
func MustAtoi(s string) int {
r, err := strconv.Atoi(s)
if err != nil {
E2P(errors.New("convert to int error: " + s))
}
return r
}
// OrString return the first not empty string
func OrString(ss ...string) string {
for _, s := range ss {
if s != "" {
return s
}
}
return ""
}
// If ternary operator
func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} {
if condition {
return trueObj
}
return falseObj
}
// MustMarshal checked version for marshal
func MustMarshal(v interface{}) []byte {
b, err := json.Marshal(v)
E2P(err)
return b
}
// MustMarshalString string version of MustMarshal
func MustMarshalString(v interface{}) string {
return string(MustMarshal(v))
}
// MustUnmarshal checked version for unmarshal
func MustUnmarshal(b []byte, obj interface{}) {
err := json.Unmarshal(b, obj)
E2P(err)
}
// MustUnmarshalString string version of MustUnmarshal
func MustUnmarshalString(s string, obj interface{}) {
MustUnmarshal([]byte(s), obj)
}
// MustRemarshal marshal and unmarshal, and check error
func MustRemarshal(from interface{}, to interface{}) {
b, err := json.Marshal(from)
E2P(err)
err = json.Unmarshal(b, to)
E2P(err)
}
// Logf 输出日志
func Logf(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
n := time.Now()
ts := fmt.Sprintf("%d-%02d-%02d %02d:%02d:%02d.%03d", n.Year(), n.Month(), n.Day(), n.Hour(), n.Minute(), n.Second(), n.Nanosecond()/1000000)
var file string
var line int
for i := 1; ; i++ {
_, file, line, _ = runtime.Caller(i)
if strings.Contains(file, "dtm") {
break
}
}
fmt.Printf("%s %s:%d %s\n", ts, path.Base(file), line, msg)
}
// LogRedf 采用红色打印错误类信息
func LogRedf(fmt string, args ...interface{}) {
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
}
// RestyClient the resty object
var RestyClient = resty.New()
func init() {
// RestyClient.SetTimeout(3 * time.Second)
// RestyClient.SetRetryCount(2)
// RestyClient.SetRetryWaitTime(1 * time.Second)
RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error {
r.URL = MayReplaceLocalhost(r.URL)
Logf("requesting: %s %s %v %v", r.Method, r.URL, r.Body, r.QueryParam)
return nil
})
RestyClient.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error {
r := resp.Request
Logf("requested: %s %s %s", r.Method, r.URL, resp.String())
return nil
})
}
// CheckRestySuccess panic if error or resp not success
func CheckRestySuccess(resp *resty.Response, err error) {
E2P(err)
if !strings.Contains(resp.String(), "SUCCESS") {
panic(fmt.Errorf("resty response not success: %s", resp.String()))
}
}
// MustGetwd must version of os.Getwd
func MustGetwd() string {
wd, err := os.Getwd()
E2P(err)
return wd
}
// GetCurrentCodeDir name is clear
func GetCurrentCodeDir() string {
_, file, _, _ := runtime.Caller(1)
return filepath.Dir(file)
}
// GetProjectDir name is clear
func GetProjectDir() string {
_, file, _, _ := runtime.Caller(1)
for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) {
}
return file
}
// GetFuncName get current call func name
func GetFuncName() string {
pc, _, _, _ := runtime.Caller(1)
return runtime.FuncForPC(pc).Name()
}
// MayReplaceLocalhost when run in docker compose, change localhost to host.docker.internal for accessing host network
func MayReplaceLocalhost(host string) string {
if os.Getenv("IS_DOCKER_COMPOSE") != "" {
return strings.Replace(host, "localhost", "host.docker.internal", 1)
}
return host
}

View File

@ -1,9 +1,5 @@
package dtmcli
import (
"github.com/yedf/dtm/common"
)
// Msg reliable msg type
type Msg struct {
MsgData
@ -39,10 +35,10 @@ func NewMsg(server string, gid string) *Msg {
// Add add a new step
func (s *Msg) Add(action string, postData interface{}) *Msg {
common.Logf("msg %s Add %s %v", s.MsgData.Gid, action, postData)
Logf("msg %s Add %s %v", s.MsgData.Gid, action, postData)
step := MsgStep{
Action: action,
Data: common.MustMarshalString(postData),
Data: MustMarshalString(postData),
}
s.Steps = append(s.Steps, step)
return s
@ -50,7 +46,7 @@ func (s *Msg) Add(action string, postData interface{}) *Msg {
// Prepare prepare the msg
func (s *Msg) Prepare(queryPrepared string) error {
s.QueryPrepared = common.OrString(queryPrepared, s.QueryPrepared)
s.QueryPrepared = OrString(queryPrepared, s.QueryPrepared)
return s.CallDtm(&s.MsgData, "prepare")
}

View File

@ -1,9 +1,5 @@
package dtmcli
import (
"github.com/yedf/dtm/common"
)
// Saga struct of saga
type Saga struct {
SagaData
@ -39,11 +35,11 @@ func NewSaga(server string, gid string) *Saga {
// Add add a saga step
func (s *Saga) Add(action string, compensate string, postData interface{}) *Saga {
common.Logf("saga %s Add %s %s %v", s.SagaData.Gid, action, compensate, postData)
Logf("saga %s Add %s %s %v", s.SagaData.Gid, action, compensate, postData)
step := SagaStep{
Action: action,
Compensate: compensate,
Data: common.MustMarshalString(postData),
Data: MustMarshalString(postData),
}
s.Steps = append(s.Steps, step)
return s

View File

@ -5,7 +5,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2"
"github.com/yedf/dtm/common"
)
// Tcc struct of tcc
@ -34,7 +33,7 @@ func TccGlobalTransaction(dtm string, gid string, tccFunc TccGlobalFunc) (rerr e
// 小概率情况下prepare成功了但是由于网络状况导致上面Failure那么不执行下面defer的内容等待超时后再回滚标记事务失败也没有问题
defer func() {
x := recover()
operation := common.If(x == nil && rerr == nil, "submit", "abort").(string)
operation := If(x == nil && rerr == nil, "submit", "abort").(string)
err := tcc.CallDtm(data, operation)
if rerr == nil {
rerr = err
@ -69,7 +68,7 @@ func (t *Tcc) CallBranch(body interface{}, tryURL string, confirmURL string, can
"branch_id": branchID,
"trans_type": "tcc",
"status": "prepared",
"data": string(common.MustMarshal(body)),
"data": string(MustMarshal(body)),
"try": tryURL,
"confirm": confirmURL,
"cancel": cancelURL,
@ -77,9 +76,9 @@ func (t *Tcc) CallBranch(body interface{}, tryURL string, confirmURL string, can
if err != nil {
return nil, err
}
resp, err := common.RestyClient.R().
resp, err := RestyClient.R().
SetBody(body).
SetQueryParams(common.MS{
SetQueryParams(MS{
"dtm": t.Dtm,
"gid": t.Gid,
"branch_id": branchID,

View File

@ -1,25 +1,95 @@
package dtmcli
import (
"database/sql"
"errors"
"fmt"
"strings"
"github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2"
"github.com/yedf/dtm/common"
)
// M a short name
type M = map[string]interface{}
// MS a short name
type MS = map[string]string
// MustGenGid generate a new gid
func MustGenGid(server string) string {
res := common.MS{}
resp, err := common.RestyClient.R().SetResult(&res).Get(server + "/newGid")
res := MS{}
resp, err := RestyClient.R().SetResult(&res).Get(server + "/newGid")
if err != nil || res["gid"] == "" {
panic(fmt.Errorf("newGid error: %v, resp: %s", err, resp))
}
return res["gid"]
}
var sqlDbs = map[string]*sql.DB{}
// SdbGet get pooled sql.DB
func SdbGet(conf map[string]string) *sql.DB {
dsn := GetDsn(conf)
if sqlDbs[dsn] == nil {
sqlDbs[dsn] = SdbAlone(conf)
}
return sqlDbs[dsn]
}
// SdbAlone get a standalone db connection
func SdbAlone(conf map[string]string) *sql.DB {
dsn := GetDsn(conf)
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
mdb, err := sql.Open(conf["driver"], dsn)
E2P(err)
return mdb
}
// SdbExec use raw db to exec
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := db.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxExec use raw tx to exec
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := tx.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxQueryRow use raw tx to query row
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
Logf("querying: "+query, args...)
return tx.QueryRow(query, args...)
}
// GetDsn get dsn from map config
func GetDsn(conf map[string]string) string {
conf["host"] = MayReplaceLocalhost(conf["host"])
driver := conf["driver"]
dsn := MS{
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
}[driver]
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
return dsn
}
// CheckResponse 检查Response返回错误
func CheckResponse(resp *resty.Response, err error) error {
if err == nil && resp != nil {
@ -38,7 +108,7 @@ func CheckResult(res interface{}, err error) error {
if ok {
return CheckResponse(resp, err)
}
if res != nil && strings.Contains(common.MustMarshalString(res), "FAILURE") {
if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") {
return ErrFailure
}
return err
@ -86,11 +156,11 @@ func TransBaseFromReq(c *gin.Context) *TransBase {
// CallDtm 调用dtm服务器返回事务的状态
func (tb *TransBase) CallDtm(body interface{}, operation string) error {
params := common.MS{}
params := MS{}
if tb.WaitResult {
params["wait_result"] = "1"
}
resp, err := common.RestyClient.R().SetQueryParams(params).
resp, err := RestyClient.R().SetQueryParams(params).
SetResult(&TransResult{}).SetBody(body).Post(fmt.Sprintf("%s/%s", tb.Dtm, operation))
if err != nil {
return err
@ -106,7 +176,7 @@ func (tb *TransBase) CallDtm(body interface{}, operation string) error {
var ErrFailure = errors.New("transaction FAILURE")
// ResultSuccess 表示返回成功,可以进行下一步
var ResultSuccess = common.M{"dtm_result": "SUCCESS"}
var ResultSuccess = M{"dtm_result": "SUCCESS"}
// ResultFailure 表示返回失败,要求回滚
var ResultFailure = common.M{"dtm_result": "FAILURE"}
var ResultFailure = M{"dtm_result": "FAILURE"}

View File

@ -5,16 +5,15 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/yedf/dtm/common"
)
func TestTypes(t *testing.T) {
err := common.CatchP(func() {
err := CatchP(func() {
idGen := IDGenerator{parentID: "12345678901234567890123"}
idGen.NewBranchID()
})
assert.Error(t, err)
err = common.CatchP(func() {
err = CatchP(func() {
idGen := IDGenerator{branchID: 99}
idGen.NewBranchID()
})

View File

@ -7,13 +7,9 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2"
"github.com/yedf/dtm/common"
)
// M alias
type M = map[string]interface{}
var e2p = common.E2P
var e2p = E2P
// XaGlobalFunc type of xa global function
type XaGlobalFunc func(xa *Xa) (*resty.Response, error)
@ -59,10 +55,10 @@ func NewXaClient(server string, mysqlConf map[string]string, callbackURL string,
// HandleCallback 处理commit/rollback的回调
func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) {
db := common.SdbAlone(xc.Conf)
db := SdbAlone(xc.Conf)
defer db.Close()
xaID := gid + "-" + branchID
_, err := common.SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
_, err := SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
return ResultSuccess, err
}
@ -73,13 +69,13 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, xaFunc XaLocalFunc) (ret
xa.Dtm = xc.Server
branchID := xa.NewBranchID()
xaBranch := xa.Gid + "-" + branchID
db := common.SdbAlone(xc.Conf)
db := SdbAlone(xc.Conf)
defer func() { db.Close() }()
defer func() {
x := recover()
_, err := common.SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
_, err := SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
if x == nil && rerr == nil && err == nil {
_, err = common.SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
_, err = SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
}
if rerr == nil {
rerr = err
@ -88,7 +84,7 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, xaFunc XaLocalFunc) (ret
panic(x)
}
}()
_, rerr = common.SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
_, rerr = SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
if rerr != nil {
return
}
@ -116,7 +112,7 @@ func (xc *XaClient) XaGlobalTransaction(gid string, xaFunc XaGlobalFunc) (rerr e
// 小概率情况下prepare成功了但是由于网络状况导致上面Failure那么不执行下面defer的内容等待超时后再回滚标记事务失败也没有问题
defer func() {
x := recover()
operation := common.If(x != nil || rerr != nil, "abort", "submit").(string)
operation := If(x != nil || rerr != nil, "abort", "submit").(string)
err := xa.CallDtm(data, operation)
if rerr == nil { // 如果用户函数没有返回错误那么返回dtm的
rerr = err
@ -133,9 +129,9 @@ func (xc *XaClient) XaGlobalTransaction(gid string, xaFunc XaGlobalFunc) (rerr e
// CallBranch call a xa branch
func (x *Xa) CallBranch(body interface{}, url string) (*resty.Response, error) {
branchID := x.NewBranchID()
resp, err := common.RestyClient.R().
resp, err := RestyClient.R().
SetBody(body).
SetQueryParams(common.MS{
SetQueryParams(MS{
"gid": x.Gid,
"branch_id": branchID,
"trans_type": "xa",

View File

@ -16,6 +16,15 @@ var DtmServer = examples.DtmServer
var Busi = examples.Busi
var app *gin.Engine
// BarrierModel barrier model for gorm
type BarrierModel struct {
common.ModelBase
dtmcli.TransInfo
}
// TableName gorm table name
func (BarrierModel) TableName() string { return "dtm_barrier.barrier" }
func resetXaData() {
if config.DB["driver"] != "mysql" {
return
@ -156,9 +165,9 @@ func TestSqlDB(t *testing.T) {
return nil, fmt.Errorf("gid2 error")
})
asserts.Error(err, fmt.Errorf("gid2 error"))
dbr := db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid1").Find(&[]dtmcli.BarrierModel{})
dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(0))
gid2Res := common.M{"result": "first"}
_, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
@ -166,7 +175,7 @@ func TestSqlDB(t *testing.T) {
return gid2Res, nil
})
asserts.Nil(err)
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
asserts.Equal(dbr.RowsAffected, int64(1))
newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
common.Logf("submit gid2")

View File

@ -7,8 +7,8 @@ import (
"strings"
"github.com/bwmarrin/snowflake"
"github.com/sirupsen/logrus"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
)
// M a short name
@ -38,13 +38,13 @@ var TransProcessedTestChan chan string = nil
// WaitTransProcessed only for test usage. wait for transaction processed once
func WaitTransProcessed(gid string) {
common.Logf("waiting for gid %s", gid)
dtmcli.Logf("waiting for gid %s", gid)
id := <-TransProcessedTestChan
for id != gid {
logrus.Errorf("-------id %s not match gid %s", id, gid)
dtmcli.LogRedf("-------id %s not match gid %s", id, gid)
id = <-TransProcessedTestChan
}
common.Logf("finish for gid %s", gid)
dtmcli.Logf("finish for gid %s", gid)
}
var gNode *snowflake.Node = nil

1
go.mod
View File

@ -12,7 +12,6 @@ require (
github.com/kr/pretty v0.1.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.1 // indirect
github.com/sirupsen/logrus v1.7.0
github.com/stretchr/testify v1.7.0
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/yaml.v2 v2.3.0