partial dtmcli refactor
This commit is contained in:
parent
9e8359c893
commit
fe65ff96e4
@ -4,7 +4,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/yedf/dtm/dtmcli"
|
||||||
"github.com/yedf/dtm/dtmsvr"
|
"github.com/yedf/dtm/dtmsvr"
|
||||||
"github.com/yedf/dtm/examples"
|
"github.com/yedf/dtm/examples"
|
||||||
)
|
)
|
||||||
@ -68,7 +68,7 @@ func main() {
|
|||||||
examples.TccBarrierAddRoute(app)
|
examples.TccBarrierAddRoute(app)
|
||||||
examples.TccBarrierFireRequest()
|
examples.TccBarrierFireRequest()
|
||||||
} else {
|
} else {
|
||||||
logrus.Fatalf("unknown arg: %s", os.Args[1])
|
dtmcli.LogRedf("unknown arg: %s", os.Args[1])
|
||||||
}
|
}
|
||||||
wait()
|
wait()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,6 +7,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
"github.com/yedf/dtm/dtmcli"
|
||||||
|
|
||||||
// _ "github.com/lib/pq"
|
// _ "github.com/lib/pq"
|
||||||
|
|
||||||
"gorm.io/driver/mysql"
|
"gorm.io/driver/mysql"
|
||||||
@ -15,12 +17,6 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// M a short name
|
|
||||||
type M = map[string]interface{}
|
|
||||||
|
|
||||||
// MS a short name
|
|
||||||
type MS = map[string]string
|
|
||||||
|
|
||||||
// ModelBase model base for gorm to provide base fields
|
// ModelBase model base for gorm to provide base fields
|
||||||
type ModelBase struct {
|
type ModelBase struct {
|
||||||
ID uint
|
ID uint
|
||||||
@ -38,7 +34,6 @@ func getGormDialator(driver string, dsn string) gorm.Dialector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var dbs = map[string]*DB{}
|
var dbs = map[string]*DB{}
|
||||||
var sqlDbs = map[string]*sql.DB{}
|
|
||||||
|
|
||||||
// DB provide more func over gorm.DB
|
// DB provide more func over gorm.DB
|
||||||
type DB struct {
|
type DB struct {
|
||||||
@ -60,7 +55,7 @@ func (m *DB) NoMust() *DB {
|
|||||||
// ToSQLDB get the sql.DB
|
// ToSQLDB get the sql.DB
|
||||||
func (m *DB) ToSQLDB() *sql.DB {
|
func (m *DB) ToSQLDB() *sql.DB {
|
||||||
d, err := m.DB.DB()
|
d, err := m.DB.DB()
|
||||||
E2P(err)
|
dtmcli.E2P(err)
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,7 +73,7 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) {
|
|||||||
after := func(db *gorm.DB) {
|
after := func(db *gorm.DB) {
|
||||||
_ts, _ := db.InstanceGet("ivy.startTime")
|
_ts, _ := db.InstanceGet("ivy.startTime")
|
||||||
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
|
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
Logf("used: %d ms affected: %d sql is: %s", time.Since(_ts.(time.Time)).Milliseconds(), db.RowsAffected, sql)
|
dtmcli.Logf("used: %d ms affected: %d sql is: %s", time.Since(_ts.(time.Time)).Milliseconds(), db.RowsAffected, sql)
|
||||||
if v, ok := db.InstanceGet("ivy.must"); ok && v.(bool) {
|
if v, ok := db.InstanceGet("ivy.must"); ok && v.(bool) {
|
||||||
if db.Error != nil && db.Error != gorm.ErrRecordNotFound {
|
if db.Error != nil && db.Error != gorm.ErrRecordNotFound {
|
||||||
panic(db.Error)
|
panic(db.Error)
|
||||||
@ -89,7 +84,7 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) {
|
|||||||
beforeName := "cb_before"
|
beforeName := "cb_before"
|
||||||
afterName := "cb_after"
|
afterName := "cb_after"
|
||||||
|
|
||||||
Logf("installing db plugin: %s", op.Name())
|
dtmcli.Logf("installing db plugin: %s", op.Name())
|
||||||
// 开始前
|
// 开始前
|
||||||
_ = db.Callback().Create().Before("gorm:before_create").Register(beforeName, before)
|
_ = db.Callback().Create().Before("gorm:before_create").Register(beforeName, before)
|
||||||
_ = db.Callback().Query().Before("gorm:query").Register(beforeName, before)
|
_ = db.Callback().Query().Before("gorm:query").Register(beforeName, before)
|
||||||
@ -108,79 +103,17 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDsn get dsn from map config
|
|
||||||
func GetDsn(conf map[string]string) string {
|
|
||||||
conf["host"] = MayReplaceLocalhost(conf["host"])
|
|
||||||
driver := conf["driver"]
|
|
||||||
dsn := MS{
|
|
||||||
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
|
||||||
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
|
|
||||||
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
|
|
||||||
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
|
|
||||||
}[driver]
|
|
||||||
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
|
|
||||||
return dsn
|
|
||||||
}
|
|
||||||
|
|
||||||
// DbGet get db connection for specified conf
|
// DbGet get db connection for specified conf
|
||||||
func DbGet(conf map[string]string) *DB {
|
func DbGet(conf map[string]string) *DB {
|
||||||
dsn := GetDsn(conf)
|
dsn := dtmcli.GetDsn(conf)
|
||||||
if dbs[dsn] == nil {
|
if dbs[dsn] == nil {
|
||||||
Logf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1))
|
dtmcli.Logf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1))
|
||||||
db1, err := gorm.Open(getGormDialator(conf["driver"], dsn), &gorm.Config{
|
db1, err := gorm.Open(getGormDialator(conf["driver"], dsn), &gorm.Config{
|
||||||
SkipDefaultTransaction: true,
|
SkipDefaultTransaction: true,
|
||||||
})
|
})
|
||||||
E2P(err)
|
dtmcli.E2P(err)
|
||||||
db1.Use(&tracePlugin{})
|
db1.Use(&tracePlugin{})
|
||||||
dbs[dsn] = &DB{DB: db1}
|
dbs[dsn] = &DB{DB: db1}
|
||||||
}
|
}
|
||||||
return dbs[dsn]
|
return dbs[dsn]
|
||||||
}
|
}
|
||||||
|
|
||||||
// SdbGet get pooled sql.DB
|
|
||||||
func SdbGet(conf map[string]string) *sql.DB {
|
|
||||||
dsn := GetDsn(conf)
|
|
||||||
if sqlDbs[dsn] == nil {
|
|
||||||
sqlDbs[dsn] = SdbAlone(conf)
|
|
||||||
}
|
|
||||||
return sqlDbs[dsn]
|
|
||||||
}
|
|
||||||
|
|
||||||
// SdbAlone get a standalone db connection
|
|
||||||
func SdbAlone(conf map[string]string) *sql.DB {
|
|
||||||
dsn := GetDsn(conf)
|
|
||||||
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
|
|
||||||
mdb, err := sql.Open(conf["driver"], dsn)
|
|
||||||
E2P(err)
|
|
||||||
return mdb
|
|
||||||
}
|
|
||||||
|
|
||||||
// SdbExec use raw db to exec
|
|
||||||
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
|
|
||||||
r, rerr := db.Exec(sql, values...)
|
|
||||||
if rerr == nil {
|
|
||||||
affected, rerr = r.RowsAffected()
|
|
||||||
Logf("affected: %d for %s %v", affected, sql, values)
|
|
||||||
} else {
|
|
||||||
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// StxExec use raw tx to exec
|
|
||||||
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
|
|
||||||
r, rerr := tx.Exec(sql, values...)
|
|
||||||
if rerr == nil {
|
|
||||||
affected, rerr = r.RowsAffected()
|
|
||||||
Logf("affected: %d for %s %v", affected, sql, values)
|
|
||||||
} else {
|
|
||||||
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// StxQueryRow use raw tx to query row
|
|
||||||
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
|
|
||||||
Logf("querying: "+query, args...)
|
|
||||||
return tx.QueryRow(query, args...)
|
|
||||||
}
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/go-playground/assert/v2"
|
"github.com/go-playground/assert/v2"
|
||||||
|
"github.com/yedf/dtm/dtmcli"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testConfig struct {
|
type testConfig struct {
|
||||||
@ -13,14 +14,14 @@ type testConfig struct {
|
|||||||
var config = testConfig{}
|
var config = testConfig{}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
InitConfig(GetProjectDir(), &config)
|
InitConfig(dtmcli.GetProjectDir(), &config)
|
||||||
config.DB["database"] = ""
|
config.DB["database"] = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDb(t *testing.T) {
|
func TestDb(t *testing.T) {
|
||||||
db := DbGet(config.DB)
|
db := DbGet(config.DB)
|
||||||
err := func() (rerr error) {
|
err := func() (rerr error) {
|
||||||
defer P2E(&rerr)
|
defer dtmcli.P2E(&rerr)
|
||||||
dbr := db.NoMust().Exec("select a")
|
dbr := db.NoMust().Exec("select a")
|
||||||
assert.NotEqual(t, nil, dbr.Error)
|
assert.NotEqual(t, nil, dbr.Error)
|
||||||
db.Must().Exec("select a")
|
db.Must().Exec("select a")
|
||||||
@ -30,10 +31,10 @@ func TestDb(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDbAlone(t *testing.T) {
|
func TestDbAlone(t *testing.T) {
|
||||||
db := SdbAlone(config.DB)
|
db := dtmcli.SdbAlone(config.DB)
|
||||||
_, err := SdbExec(db, "select 1")
|
_, err := dtmcli.SdbExec(db, "select 1")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
db.Close()
|
db.Close()
|
||||||
_, err = SdbExec(db, "select 1")
|
_, err = dtmcli.SdbExec(db, "select 1")
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|||||||
182
common/utils.go
182
common/utils.go
@ -3,113 +3,14 @@ package common
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
yaml "gopkg.in/yaml.v2"
|
yaml "gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// P2E panic to error
|
|
||||||
func P2E(perr *error) {
|
|
||||||
if x := recover(); x != nil {
|
|
||||||
if e, ok := x.(error); ok {
|
|
||||||
*perr = e
|
|
||||||
} else {
|
|
||||||
panic(x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// E2P error to panic
|
|
||||||
func E2P(err error) {
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CatchP catch panic to error
|
|
||||||
func CatchP(f func()) (rerr error) {
|
|
||||||
defer P2E(&rerr)
|
|
||||||
f()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PanicIf name is clear
|
|
||||||
func PanicIf(cond bool, err error) {
|
|
||||||
if cond {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustAtoi 走must逻辑
|
|
||||||
func MustAtoi(s string) int {
|
|
||||||
r, err := strconv.Atoi(s)
|
|
||||||
if err != nil {
|
|
||||||
E2P(errors.New("convert to int error: " + s))
|
|
||||||
}
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// OrString return the first not empty string
|
|
||||||
func OrString(ss ...string) string {
|
|
||||||
for _, s := range ss {
|
|
||||||
if s != "" {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// If ternary operator
|
|
||||||
func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} {
|
|
||||||
if condition {
|
|
||||||
return trueObj
|
|
||||||
}
|
|
||||||
return falseObj
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustMarshal checked version for marshal
|
|
||||||
func MustMarshal(v interface{}) []byte {
|
|
||||||
b, err := json.Marshal(v)
|
|
||||||
E2P(err)
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustMarshalString string version of MustMarshal
|
|
||||||
func MustMarshalString(v interface{}) string {
|
|
||||||
return string(MustMarshal(v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustUnmarshal checked version for unmarshal
|
|
||||||
func MustUnmarshal(b []byte, obj interface{}) {
|
|
||||||
err := json.Unmarshal(b, obj)
|
|
||||||
E2P(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustUnmarshalString string version of MustUnmarshal
|
|
||||||
func MustUnmarshalString(s string, obj interface{}) {
|
|
||||||
MustUnmarshal([]byte(s), obj)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustRemarshal marshal and unmarshal, and check error
|
|
||||||
func MustRemarshal(from interface{}, to interface{}) {
|
|
||||||
b, err := json.Marshal(from)
|
|
||||||
E2P(err)
|
|
||||||
err = json.Unmarshal(b, to)
|
|
||||||
E2P(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetGinApp init and return gin
|
// GetGinApp init and return gin
|
||||||
func GetGinApp() *gin.Engine {
|
func GetGinApp() *gin.Engine {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
@ -157,54 +58,6 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RestyClient the resty object
|
|
||||||
var RestyClient = resty.New()
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// RestyClient.SetTimeout(3 * time.Second)
|
|
||||||
// RestyClient.SetRetryCount(2)
|
|
||||||
// RestyClient.SetRetryWaitTime(1 * time.Second)
|
|
||||||
RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error {
|
|
||||||
r.URL = MayReplaceLocalhost(r.URL)
|
|
||||||
Logf("requesting: %s %s %v %v", r.Method, r.URL, r.Body, r.QueryParam)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
RestyClient.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error {
|
|
||||||
r := resp.Request
|
|
||||||
Logf("requested: %s %s %s", r.Method, r.URL, resp.String())
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckRestySuccess panic if error or resp not success
|
|
||||||
func CheckRestySuccess(resp *resty.Response, err error) {
|
|
||||||
E2P(err)
|
|
||||||
if !strings.Contains(resp.String(), "SUCCESS") {
|
|
||||||
panic(fmt.Errorf("resty response not success: %s", resp.String()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Logf 输出日志
|
|
||||||
func Logf(format string, args ...interface{}) {
|
|
||||||
msg := fmt.Sprintf(format, args...)
|
|
||||||
n := time.Now()
|
|
||||||
ts := fmt.Sprintf("%d-%02d-%02d %02d:%02d:%02d.%03d", n.Year(), n.Month(), n.Day(), n.Hour(), n.Minute(), n.Second(), n.Nanosecond()/1000000)
|
|
||||||
var file string
|
|
||||||
var line int
|
|
||||||
for i := 1; ; i++ {
|
|
||||||
_, file, line, _ = runtime.Caller(i)
|
|
||||||
if strings.Contains(file, "dtm") {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Printf("%s %s:%d %s\n", ts, path.Base(file), line, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LogRedf 采用红色打印错误类信息
|
|
||||||
func LogRedf(fmt string, args ...interface{}) {
|
|
||||||
logrus.Errorf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitConfig init config
|
// InitConfig init config
|
||||||
func InitConfig(dir string, config interface{}) {
|
func InitConfig(dir string, config interface{}) {
|
||||||
cont, err := ioutil.ReadFile(dir + "/conf.yml")
|
cont, err := ioutil.ReadFile(dir + "/conf.yml")
|
||||||
@ -216,38 +69,3 @@ func InitConfig(dir string, config interface{}) {
|
|||||||
err = yaml.Unmarshal(cont, config)
|
err = yaml.Unmarshal(cont, config)
|
||||||
E2P(err)
|
E2P(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustGetwd must version of os.Getwd
|
|
||||||
func MustGetwd() string {
|
|
||||||
wd, err := os.Getwd()
|
|
||||||
E2P(err)
|
|
||||||
return wd
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCurrentCodeDir name is clear
|
|
||||||
func GetCurrentCodeDir() string {
|
|
||||||
_, file, _, _ := runtime.Caller(1)
|
|
||||||
return filepath.Dir(file)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetProjectDir name is clear
|
|
||||||
func GetProjectDir() string {
|
|
||||||
_, file, _, _ := runtime.Caller(1)
|
|
||||||
for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) {
|
|
||||||
}
|
|
||||||
return file
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFuncName get current call func name
|
|
||||||
func GetFuncName() string {
|
|
||||||
pc, _, _, _ := runtime.Caller(1)
|
|
||||||
return runtime.FuncForPC(pc).Name()
|
|
||||||
}
|
|
||||||
|
|
||||||
// MayReplaceLocalhost when run in docker compose, change localhost to host.docker.internal for accessing host network
|
|
||||||
func MayReplaceLocalhost(host string) string {
|
|
||||||
if os.Getenv("IS_DOCKER_COMPOSE") != "" {
|
|
||||||
return strings.Replace(host, "localhost", "host.docker.internal", 1)
|
|
||||||
}
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|||||||
@ -8,7 +8,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/yedf/dtm/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// BusiFunc type for busi func
|
// BusiFunc type for busi func
|
||||||
@ -47,20 +46,11 @@ func TransInfoFromQuery(qs url.Values) (*TransInfo, error) {
|
|||||||
return ti, nil
|
return ti, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// BarrierModel barrier model for gorm
|
|
||||||
type BarrierModel struct {
|
|
||||||
common.ModelBase
|
|
||||||
TransInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
// TableName gorm table name
|
|
||||||
func (BarrierModel) TableName() string { return "dtm_barrier.barrier" }
|
|
||||||
|
|
||||||
func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, reason string) (int64, error) {
|
func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, reason string) (int64, error) {
|
||||||
if branchType == "" {
|
if branchType == "" {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
return common.StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason)
|
return StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ThroughBarrierCall 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
|
// ThroughBarrierCall 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
|
||||||
@ -78,7 +68,7 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
common.Logf("result is %v error is %v", res, rerr)
|
Logf("result is %v error is %v", res, rerr)
|
||||||
if x := recover(); x != nil {
|
if x := recover(); x != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
panic(x)
|
panic(x)
|
||||||
@ -95,13 +85,13 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
|
|||||||
}[ti.BranchType]
|
}[ti.BranchType]
|
||||||
originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, ti.BranchType)
|
originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, ti.BranchType)
|
||||||
currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType)
|
currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType)
|
||||||
common.Logf("originAffected: %d currentAffected: %d", originAffected, currentAffected)
|
Logf("originAffected: %d currentAffected: %d", originAffected, currentAffected)
|
||||||
if (ti.BranchType == "cancel" || ti.BranchType == "compensate") && originAffected > 0 { // 这个是空补偿,返回成功
|
if (ti.BranchType == "cancel" || ti.BranchType == "compensate") && originAffected > 0 { // 这个是空补偿,返回成功
|
||||||
res = ResultSuccess
|
res = ResultSuccess
|
||||||
return
|
return
|
||||||
} else if currentAffected == 0 { // 插入不成功
|
} else if currentAffected == 0 { // 插入不成功
|
||||||
var result sql.NullString
|
var result sql.NullString
|
||||||
err := common.StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?",
|
err := StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?",
|
||||||
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType).Scan(&result)
|
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType).Scan(&result)
|
||||||
if err == sql.ErrNoRows { // 这个是悬挂操作,返回失败,AP收到这个返回,会尽快回滚
|
if err == sql.ErrNoRows { // 这个是悬挂操作,返回失败,AP收到这个返回,会尽快回滚
|
||||||
res = ResultFailure
|
res = ResultFailure
|
||||||
@ -121,8 +111,8 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re
|
|||||||
}
|
}
|
||||||
res, rerr = busiCall(tx)
|
res, rerr = busiCall(tx)
|
||||||
if rerr == nil { // 正确返回了,需要将结果保存到数据库
|
if rerr == nil { // 正确返回了,需要将结果保存到数据库
|
||||||
sval := common.MustMarshalString(res)
|
sval := MustMarshalString(res)
|
||||||
_, rerr = common.StxExec(tx, "update dtm_barrier.barrier set result=? where trans_type=? and gid=? and branch_id=? and branch_type=?", sval,
|
_, rerr = StxExec(tx, "update dtm_barrier.barrier set result=? where trans_type=? and gid=? and branch_id=? and branch_type=?", sval,
|
||||||
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType)
|
ti.TransType, ti.Gid, ti.BranchID, ti.BranchType)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
189
dtmcli/common.go
Normal file
189
dtmcli/common.go
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
package dtmcli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-resty/resty/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// P2E panic to error
|
||||||
|
func P2E(perr *error) {
|
||||||
|
if x := recover(); x != nil {
|
||||||
|
if e, ok := x.(error); ok {
|
||||||
|
*perr = e
|
||||||
|
} else {
|
||||||
|
panic(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// E2P error to panic
|
||||||
|
func E2P(err error) {
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CatchP catch panic to error
|
||||||
|
func CatchP(f func()) (rerr error) {
|
||||||
|
defer P2E(&rerr)
|
||||||
|
f()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PanicIf name is clear
|
||||||
|
func PanicIf(cond bool, err error) {
|
||||||
|
if cond {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustAtoi 走must逻辑
|
||||||
|
func MustAtoi(s string) int {
|
||||||
|
r, err := strconv.Atoi(s)
|
||||||
|
if err != nil {
|
||||||
|
E2P(errors.New("convert to int error: " + s))
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrString return the first not empty string
|
||||||
|
func OrString(ss ...string) string {
|
||||||
|
for _, s := range ss {
|
||||||
|
if s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// If ternary operator
|
||||||
|
func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} {
|
||||||
|
if condition {
|
||||||
|
return trueObj
|
||||||
|
}
|
||||||
|
return falseObj
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustMarshal checked version for marshal
|
||||||
|
func MustMarshal(v interface{}) []byte {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
E2P(err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustMarshalString string version of MustMarshal
|
||||||
|
func MustMarshalString(v interface{}) string {
|
||||||
|
return string(MustMarshal(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustUnmarshal checked version for unmarshal
|
||||||
|
func MustUnmarshal(b []byte, obj interface{}) {
|
||||||
|
err := json.Unmarshal(b, obj)
|
||||||
|
E2P(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustUnmarshalString string version of MustUnmarshal
|
||||||
|
func MustUnmarshalString(s string, obj interface{}) {
|
||||||
|
MustUnmarshal([]byte(s), obj)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustRemarshal marshal and unmarshal, and check error
|
||||||
|
func MustRemarshal(from interface{}, to interface{}) {
|
||||||
|
b, err := json.Marshal(from)
|
||||||
|
E2P(err)
|
||||||
|
err = json.Unmarshal(b, to)
|
||||||
|
E2P(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logf 输出日志
|
||||||
|
func Logf(format string, args ...interface{}) {
|
||||||
|
msg := fmt.Sprintf(format, args...)
|
||||||
|
n := time.Now()
|
||||||
|
ts := fmt.Sprintf("%d-%02d-%02d %02d:%02d:%02d.%03d", n.Year(), n.Month(), n.Day(), n.Hour(), n.Minute(), n.Second(), n.Nanosecond()/1000000)
|
||||||
|
var file string
|
||||||
|
var line int
|
||||||
|
for i := 1; ; i++ {
|
||||||
|
_, file, line, _ = runtime.Caller(i)
|
||||||
|
if strings.Contains(file, "dtm") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Printf("%s %s:%d %s\n", ts, path.Base(file), line, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogRedf 采用红色打印错误类信息
|
||||||
|
func LogRedf(fmt string, args ...interface{}) {
|
||||||
|
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestyClient the resty object
|
||||||
|
var RestyClient = resty.New()
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// RestyClient.SetTimeout(3 * time.Second)
|
||||||
|
// RestyClient.SetRetryCount(2)
|
||||||
|
// RestyClient.SetRetryWaitTime(1 * time.Second)
|
||||||
|
RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error {
|
||||||
|
r.URL = MayReplaceLocalhost(r.URL)
|
||||||
|
Logf("requesting: %s %s %v %v", r.Method, r.URL, r.Body, r.QueryParam)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
RestyClient.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error {
|
||||||
|
r := resp.Request
|
||||||
|
Logf("requested: %s %s %s", r.Method, r.URL, resp.String())
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckRestySuccess panic if error or resp not success
|
||||||
|
func CheckRestySuccess(resp *resty.Response, err error) {
|
||||||
|
E2P(err)
|
||||||
|
if !strings.Contains(resp.String(), "SUCCESS") {
|
||||||
|
panic(fmt.Errorf("resty response not success: %s", resp.String()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustGetwd must version of os.Getwd
|
||||||
|
func MustGetwd() string {
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
E2P(err)
|
||||||
|
return wd
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCurrentCodeDir name is clear
|
||||||
|
func GetCurrentCodeDir() string {
|
||||||
|
_, file, _, _ := runtime.Caller(1)
|
||||||
|
return filepath.Dir(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProjectDir name is clear
|
||||||
|
func GetProjectDir() string {
|
||||||
|
_, file, _, _ := runtime.Caller(1)
|
||||||
|
for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) {
|
||||||
|
}
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFuncName get current call func name
|
||||||
|
func GetFuncName() string {
|
||||||
|
pc, _, _, _ := runtime.Caller(1)
|
||||||
|
return runtime.FuncForPC(pc).Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MayReplaceLocalhost when run in docker compose, change localhost to host.docker.internal for accessing host network
|
||||||
|
func MayReplaceLocalhost(host string) string {
|
||||||
|
if os.Getenv("IS_DOCKER_COMPOSE") != "" {
|
||||||
|
return strings.Replace(host, "localhost", "host.docker.internal", 1)
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
@ -1,9 +1,5 @@
|
|||||||
package dtmcli
|
package dtmcli
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/yedf/dtm/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Msg reliable msg type
|
// Msg reliable msg type
|
||||||
type Msg struct {
|
type Msg struct {
|
||||||
MsgData
|
MsgData
|
||||||
@ -39,10 +35,10 @@ func NewMsg(server string, gid string) *Msg {
|
|||||||
|
|
||||||
// Add add a new step
|
// Add add a new step
|
||||||
func (s *Msg) Add(action string, postData interface{}) *Msg {
|
func (s *Msg) Add(action string, postData interface{}) *Msg {
|
||||||
common.Logf("msg %s Add %s %v", s.MsgData.Gid, action, postData)
|
Logf("msg %s Add %s %v", s.MsgData.Gid, action, postData)
|
||||||
step := MsgStep{
|
step := MsgStep{
|
||||||
Action: action,
|
Action: action,
|
||||||
Data: common.MustMarshalString(postData),
|
Data: MustMarshalString(postData),
|
||||||
}
|
}
|
||||||
s.Steps = append(s.Steps, step)
|
s.Steps = append(s.Steps, step)
|
||||||
return s
|
return s
|
||||||
@ -50,7 +46,7 @@ func (s *Msg) Add(action string, postData interface{}) *Msg {
|
|||||||
|
|
||||||
// Prepare prepare the msg
|
// Prepare prepare the msg
|
||||||
func (s *Msg) Prepare(queryPrepared string) error {
|
func (s *Msg) Prepare(queryPrepared string) error {
|
||||||
s.QueryPrepared = common.OrString(queryPrepared, s.QueryPrepared)
|
s.QueryPrepared = OrString(queryPrepared, s.QueryPrepared)
|
||||||
return s.CallDtm(&s.MsgData, "prepare")
|
return s.CallDtm(&s.MsgData, "prepare")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,5 @@
|
|||||||
package dtmcli
|
package dtmcli
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/yedf/dtm/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Saga struct of saga
|
// Saga struct of saga
|
||||||
type Saga struct {
|
type Saga struct {
|
||||||
SagaData
|
SagaData
|
||||||
@ -39,11 +35,11 @@ func NewSaga(server string, gid string) *Saga {
|
|||||||
|
|
||||||
// Add add a saga step
|
// Add add a saga step
|
||||||
func (s *Saga) Add(action string, compensate string, postData interface{}) *Saga {
|
func (s *Saga) Add(action string, compensate string, postData interface{}) *Saga {
|
||||||
common.Logf("saga %s Add %s %s %v", s.SagaData.Gid, action, compensate, postData)
|
Logf("saga %s Add %s %s %v", s.SagaData.Gid, action, compensate, postData)
|
||||||
step := SagaStep{
|
step := SagaStep{
|
||||||
Action: action,
|
Action: action,
|
||||||
Compensate: compensate,
|
Compensate: compensate,
|
||||||
Data: common.MustMarshalString(postData),
|
Data: MustMarshalString(postData),
|
||||||
}
|
}
|
||||||
s.Steps = append(s.Steps, step)
|
s.Steps = append(s.Steps, step)
|
||||||
return s
|
return s
|
||||||
|
|||||||
@ -5,7 +5,6 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
"github.com/yedf/dtm/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Tcc struct of tcc
|
// Tcc struct of tcc
|
||||||
@ -34,7 +33,7 @@ func TccGlobalTransaction(dtm string, gid string, tccFunc TccGlobalFunc) (rerr e
|
|||||||
// 小概率情况下,prepare成功了,但是由于网络状况导致上面Failure,那么不执行下面defer的内容,等待超时后再回滚标记事务失败,也没有问题
|
// 小概率情况下,prepare成功了,但是由于网络状况导致上面Failure,那么不执行下面defer的内容,等待超时后再回滚标记事务失败,也没有问题
|
||||||
defer func() {
|
defer func() {
|
||||||
x := recover()
|
x := recover()
|
||||||
operation := common.If(x == nil && rerr == nil, "submit", "abort").(string)
|
operation := If(x == nil && rerr == nil, "submit", "abort").(string)
|
||||||
err := tcc.CallDtm(data, operation)
|
err := tcc.CallDtm(data, operation)
|
||||||
if rerr == nil {
|
if rerr == nil {
|
||||||
rerr = err
|
rerr = err
|
||||||
@ -69,7 +68,7 @@ func (t *Tcc) CallBranch(body interface{}, tryURL string, confirmURL string, can
|
|||||||
"branch_id": branchID,
|
"branch_id": branchID,
|
||||||
"trans_type": "tcc",
|
"trans_type": "tcc",
|
||||||
"status": "prepared",
|
"status": "prepared",
|
||||||
"data": string(common.MustMarshal(body)),
|
"data": string(MustMarshal(body)),
|
||||||
"try": tryURL,
|
"try": tryURL,
|
||||||
"confirm": confirmURL,
|
"confirm": confirmURL,
|
||||||
"cancel": cancelURL,
|
"cancel": cancelURL,
|
||||||
@ -77,9 +76,9 @@ func (t *Tcc) CallBranch(body interface{}, tryURL string, confirmURL string, can
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
resp, err := common.RestyClient.R().
|
resp, err := RestyClient.R().
|
||||||
SetBody(body).
|
SetBody(body).
|
||||||
SetQueryParams(common.MS{
|
SetQueryParams(MS{
|
||||||
"dtm": t.Dtm,
|
"dtm": t.Dtm,
|
||||||
"gid": t.Gid,
|
"gid": t.Gid,
|
||||||
"branch_id": branchID,
|
"branch_id": branchID,
|
||||||
|
|||||||
@ -1,25 +1,95 @@
|
|||||||
package dtmcli
|
package dtmcli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
"github.com/yedf/dtm/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// M a short name
|
||||||
|
type M = map[string]interface{}
|
||||||
|
|
||||||
|
// MS a short name
|
||||||
|
type MS = map[string]string
|
||||||
|
|
||||||
// MustGenGid generate a new gid
|
// MustGenGid generate a new gid
|
||||||
func MustGenGid(server string) string {
|
func MustGenGid(server string) string {
|
||||||
res := common.MS{}
|
res := MS{}
|
||||||
resp, err := common.RestyClient.R().SetResult(&res).Get(server + "/newGid")
|
resp, err := RestyClient.R().SetResult(&res).Get(server + "/newGid")
|
||||||
if err != nil || res["gid"] == "" {
|
if err != nil || res["gid"] == "" {
|
||||||
panic(fmt.Errorf("newGid error: %v, resp: %s", err, resp))
|
panic(fmt.Errorf("newGid error: %v, resp: %s", err, resp))
|
||||||
}
|
}
|
||||||
return res["gid"]
|
return res["gid"]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var sqlDbs = map[string]*sql.DB{}
|
||||||
|
|
||||||
|
// SdbGet get pooled sql.DB
|
||||||
|
func SdbGet(conf map[string]string) *sql.DB {
|
||||||
|
dsn := GetDsn(conf)
|
||||||
|
if sqlDbs[dsn] == nil {
|
||||||
|
sqlDbs[dsn] = SdbAlone(conf)
|
||||||
|
}
|
||||||
|
return sqlDbs[dsn]
|
||||||
|
}
|
||||||
|
|
||||||
|
// SdbAlone get a standalone db connection
|
||||||
|
func SdbAlone(conf map[string]string) *sql.DB {
|
||||||
|
dsn := GetDsn(conf)
|
||||||
|
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
|
||||||
|
mdb, err := sql.Open(conf["driver"], dsn)
|
||||||
|
E2P(err)
|
||||||
|
return mdb
|
||||||
|
}
|
||||||
|
|
||||||
|
// SdbExec use raw db to exec
|
||||||
|
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
|
||||||
|
r, rerr := db.Exec(sql, values...)
|
||||||
|
if rerr == nil {
|
||||||
|
affected, rerr = r.RowsAffected()
|
||||||
|
Logf("affected: %d for %s %v", affected, sql, values)
|
||||||
|
} else {
|
||||||
|
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// StxExec use raw tx to exec
|
||||||
|
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
|
||||||
|
r, rerr := tx.Exec(sql, values...)
|
||||||
|
if rerr == nil {
|
||||||
|
affected, rerr = r.RowsAffected()
|
||||||
|
Logf("affected: %d for %s %v", affected, sql, values)
|
||||||
|
} else {
|
||||||
|
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// StxQueryRow use raw tx to query row
|
||||||
|
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
|
||||||
|
Logf("querying: "+query, args...)
|
||||||
|
return tx.QueryRow(query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDsn get dsn from map config
|
||||||
|
func GetDsn(conf map[string]string) string {
|
||||||
|
conf["host"] = MayReplaceLocalhost(conf["host"])
|
||||||
|
driver := conf["driver"]
|
||||||
|
dsn := MS{
|
||||||
|
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||||
|
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
|
||||||
|
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
|
||||||
|
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
|
||||||
|
}[driver]
|
||||||
|
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
|
||||||
|
return dsn
|
||||||
|
}
|
||||||
|
|
||||||
// CheckResponse 检查Response,返回错误
|
// CheckResponse 检查Response,返回错误
|
||||||
func CheckResponse(resp *resty.Response, err error) error {
|
func CheckResponse(resp *resty.Response, err error) error {
|
||||||
if err == nil && resp != nil {
|
if err == nil && resp != nil {
|
||||||
@ -38,7 +108,7 @@ func CheckResult(res interface{}, err error) error {
|
|||||||
if ok {
|
if ok {
|
||||||
return CheckResponse(resp, err)
|
return CheckResponse(resp, err)
|
||||||
}
|
}
|
||||||
if res != nil && strings.Contains(common.MustMarshalString(res), "FAILURE") {
|
if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") {
|
||||||
return ErrFailure
|
return ErrFailure
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@ -86,11 +156,11 @@ func TransBaseFromReq(c *gin.Context) *TransBase {
|
|||||||
|
|
||||||
// CallDtm 调用dtm服务器,返回事务的状态
|
// CallDtm 调用dtm服务器,返回事务的状态
|
||||||
func (tb *TransBase) CallDtm(body interface{}, operation string) error {
|
func (tb *TransBase) CallDtm(body interface{}, operation string) error {
|
||||||
params := common.MS{}
|
params := MS{}
|
||||||
if tb.WaitResult {
|
if tb.WaitResult {
|
||||||
params["wait_result"] = "1"
|
params["wait_result"] = "1"
|
||||||
}
|
}
|
||||||
resp, err := common.RestyClient.R().SetQueryParams(params).
|
resp, err := RestyClient.R().SetQueryParams(params).
|
||||||
SetResult(&TransResult{}).SetBody(body).Post(fmt.Sprintf("%s/%s", tb.Dtm, operation))
|
SetResult(&TransResult{}).SetBody(body).Post(fmt.Sprintf("%s/%s", tb.Dtm, operation))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -106,7 +176,7 @@ func (tb *TransBase) CallDtm(body interface{}, operation string) error {
|
|||||||
var ErrFailure = errors.New("transaction FAILURE")
|
var ErrFailure = errors.New("transaction FAILURE")
|
||||||
|
|
||||||
// ResultSuccess 表示返回成功,可以进行下一步
|
// ResultSuccess 表示返回成功,可以进行下一步
|
||||||
var ResultSuccess = common.M{"dtm_result": "SUCCESS"}
|
var ResultSuccess = M{"dtm_result": "SUCCESS"}
|
||||||
|
|
||||||
// ResultFailure 表示返回失败,要求回滚
|
// ResultFailure 表示返回失败,要求回滚
|
||||||
var ResultFailure = common.M{"dtm_result": "FAILURE"}
|
var ResultFailure = M{"dtm_result": "FAILURE"}
|
||||||
|
|||||||
@ -5,16 +5,15 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/yedf/dtm/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTypes(t *testing.T) {
|
func TestTypes(t *testing.T) {
|
||||||
err := common.CatchP(func() {
|
err := CatchP(func() {
|
||||||
idGen := IDGenerator{parentID: "12345678901234567890123"}
|
idGen := IDGenerator{parentID: "12345678901234567890123"}
|
||||||
idGen.NewBranchID()
|
idGen.NewBranchID()
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
err = common.CatchP(func() {
|
err = CatchP(func() {
|
||||||
idGen := IDGenerator{branchID: 99}
|
idGen := IDGenerator{branchID: 99}
|
||||||
idGen.NewBranchID()
|
idGen.NewBranchID()
|
||||||
})
|
})
|
||||||
|
|||||||
24
dtmcli/xa.go
24
dtmcli/xa.go
@ -7,13 +7,9 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
"github.com/yedf/dtm/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// M alias
|
var e2p = E2P
|
||||||
type M = map[string]interface{}
|
|
||||||
|
|
||||||
var e2p = common.E2P
|
|
||||||
|
|
||||||
// XaGlobalFunc type of xa global function
|
// XaGlobalFunc type of xa global function
|
||||||
type XaGlobalFunc func(xa *Xa) (*resty.Response, error)
|
type XaGlobalFunc func(xa *Xa) (*resty.Response, error)
|
||||||
@ -59,10 +55,10 @@ func NewXaClient(server string, mysqlConf map[string]string, callbackURL string,
|
|||||||
|
|
||||||
// HandleCallback 处理commit/rollback的回调
|
// HandleCallback 处理commit/rollback的回调
|
||||||
func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) {
|
func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) {
|
||||||
db := common.SdbAlone(xc.Conf)
|
db := SdbAlone(xc.Conf)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
xaID := gid + "-" + branchID
|
xaID := gid + "-" + branchID
|
||||||
_, err := common.SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
|
_, err := SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
|
||||||
return ResultSuccess, err
|
return ResultSuccess, err
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -73,13 +69,13 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, xaFunc XaLocalFunc) (ret
|
|||||||
xa.Dtm = xc.Server
|
xa.Dtm = xc.Server
|
||||||
branchID := xa.NewBranchID()
|
branchID := xa.NewBranchID()
|
||||||
xaBranch := xa.Gid + "-" + branchID
|
xaBranch := xa.Gid + "-" + branchID
|
||||||
db := common.SdbAlone(xc.Conf)
|
db := SdbAlone(xc.Conf)
|
||||||
defer func() { db.Close() }()
|
defer func() { db.Close() }()
|
||||||
defer func() {
|
defer func() {
|
||||||
x := recover()
|
x := recover()
|
||||||
_, err := common.SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
|
_, err := SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
|
||||||
if x == nil && rerr == nil && err == nil {
|
if x == nil && rerr == nil && err == nil {
|
||||||
_, err = common.SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
|
_, err = SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
|
||||||
}
|
}
|
||||||
if rerr == nil {
|
if rerr == nil {
|
||||||
rerr = err
|
rerr = err
|
||||||
@ -88,7 +84,7 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, xaFunc XaLocalFunc) (ret
|
|||||||
panic(x)
|
panic(x)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
_, rerr = common.SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
|
_, rerr = SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -116,7 +112,7 @@ func (xc *XaClient) XaGlobalTransaction(gid string, xaFunc XaGlobalFunc) (rerr e
|
|||||||
// 小概率情况下,prepare成功了,但是由于网络状况导致上面Failure,那么不执行下面defer的内容,等待超时后再回滚标记事务失败,也没有问题
|
// 小概率情况下,prepare成功了,但是由于网络状况导致上面Failure,那么不执行下面defer的内容,等待超时后再回滚标记事务失败,也没有问题
|
||||||
defer func() {
|
defer func() {
|
||||||
x := recover()
|
x := recover()
|
||||||
operation := common.If(x != nil || rerr != nil, "abort", "submit").(string)
|
operation := If(x != nil || rerr != nil, "abort", "submit").(string)
|
||||||
err := xa.CallDtm(data, operation)
|
err := xa.CallDtm(data, operation)
|
||||||
if rerr == nil { // 如果用户函数没有返回错误,那么返回dtm的
|
if rerr == nil { // 如果用户函数没有返回错误,那么返回dtm的
|
||||||
rerr = err
|
rerr = err
|
||||||
@ -133,9 +129,9 @@ func (xc *XaClient) XaGlobalTransaction(gid string, xaFunc XaGlobalFunc) (rerr e
|
|||||||
// CallBranch call a xa branch
|
// CallBranch call a xa branch
|
||||||
func (x *Xa) CallBranch(body interface{}, url string) (*resty.Response, error) {
|
func (x *Xa) CallBranch(body interface{}, url string) (*resty.Response, error) {
|
||||||
branchID := x.NewBranchID()
|
branchID := x.NewBranchID()
|
||||||
resp, err := common.RestyClient.R().
|
resp, err := RestyClient.R().
|
||||||
SetBody(body).
|
SetBody(body).
|
||||||
SetQueryParams(common.MS{
|
SetQueryParams(MS{
|
||||||
"gid": x.Gid,
|
"gid": x.Gid,
|
||||||
"branch_id": branchID,
|
"branch_id": branchID,
|
||||||
"trans_type": "xa",
|
"trans_type": "xa",
|
||||||
|
|||||||
@ -16,6 +16,15 @@ var DtmServer = examples.DtmServer
|
|||||||
var Busi = examples.Busi
|
var Busi = examples.Busi
|
||||||
var app *gin.Engine
|
var app *gin.Engine
|
||||||
|
|
||||||
|
// BarrierModel barrier model for gorm
|
||||||
|
type BarrierModel struct {
|
||||||
|
common.ModelBase
|
||||||
|
dtmcli.TransInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName gorm table name
|
||||||
|
func (BarrierModel) TableName() string { return "dtm_barrier.barrier" }
|
||||||
|
|
||||||
func resetXaData() {
|
func resetXaData() {
|
||||||
if config.DB["driver"] != "mysql" {
|
if config.DB["driver"] != "mysql" {
|
||||||
return
|
return
|
||||||
@ -156,9 +165,9 @@ func TestSqlDB(t *testing.T) {
|
|||||||
return nil, fmt.Errorf("gid2 error")
|
return nil, fmt.Errorf("gid2 error")
|
||||||
})
|
})
|
||||||
asserts.Error(err, fmt.Errorf("gid2 error"))
|
asserts.Error(err, fmt.Errorf("gid2 error"))
|
||||||
dbr := db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid1").Find(&[]dtmcli.BarrierModel{})
|
dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{})
|
||||||
asserts.Equal(dbr.RowsAffected, int64(1))
|
asserts.Equal(dbr.RowsAffected, int64(1))
|
||||||
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
|
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
|
||||||
asserts.Equal(dbr.RowsAffected, int64(0))
|
asserts.Equal(dbr.RowsAffected, int64(0))
|
||||||
gid2Res := common.M{"result": "first"}
|
gid2Res := common.M{"result": "first"}
|
||||||
_, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
|
_, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
|
||||||
@ -166,7 +175,7 @@ func TestSqlDB(t *testing.T) {
|
|||||||
return gid2Res, nil
|
return gid2Res, nil
|
||||||
})
|
})
|
||||||
asserts.Nil(err)
|
asserts.Nil(err)
|
||||||
dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{})
|
dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{})
|
||||||
asserts.Equal(dbr.RowsAffected, int64(1))
|
asserts.Equal(dbr.RowsAffected, int64(1))
|
||||||
newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
|
newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) {
|
||||||
common.Logf("submit gid2")
|
common.Logf("submit gid2")
|
||||||
|
|||||||
@ -7,8 +7,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bwmarrin/snowflake"
|
"github.com/bwmarrin/snowflake"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/yedf/dtm/common"
|
"github.com/yedf/dtm/common"
|
||||||
|
"github.com/yedf/dtm/dtmcli"
|
||||||
)
|
)
|
||||||
|
|
||||||
// M a short name
|
// M a short name
|
||||||
@ -38,13 +38,13 @@ var TransProcessedTestChan chan string = nil
|
|||||||
|
|
||||||
// WaitTransProcessed only for test usage. wait for transaction processed once
|
// WaitTransProcessed only for test usage. wait for transaction processed once
|
||||||
func WaitTransProcessed(gid string) {
|
func WaitTransProcessed(gid string) {
|
||||||
common.Logf("waiting for gid %s", gid)
|
dtmcli.Logf("waiting for gid %s", gid)
|
||||||
id := <-TransProcessedTestChan
|
id := <-TransProcessedTestChan
|
||||||
for id != gid {
|
for id != gid {
|
||||||
logrus.Errorf("-------id %s not match gid %s", id, gid)
|
dtmcli.LogRedf("-------id %s not match gid %s", id, gid)
|
||||||
id = <-TransProcessedTestChan
|
id = <-TransProcessedTestChan
|
||||||
}
|
}
|
||||||
common.Logf("finish for gid %s", gid)
|
dtmcli.Logf("finish for gid %s", gid)
|
||||||
}
|
}
|
||||||
|
|
||||||
var gNode *snowflake.Node = nil
|
var gNode *snowflake.Node = nil
|
||||||
|
|||||||
1
go.mod
1
go.mod
@ -12,7 +12,6 @@ require (
|
|||||||
github.com/kr/pretty v0.1.0 // indirect
|
github.com/kr/pretty v0.1.0 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.1 // indirect
|
github.com/modern-go/reflect2 v1.0.1 // indirect
|
||||||
github.com/sirupsen/logrus v1.7.0
|
|
||||||
github.com/stretchr/testify v1.7.0
|
github.com/stretchr/testify v1.7.0
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||||
gopkg.in/yaml.v2 v2.3.0
|
gopkg.in/yaml.v2 v2.3.0
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user