diff --git a/app/main.go b/app/main.go index 4d3dcf3..ac57d90 100644 --- a/app/main.go +++ b/app/main.go @@ -4,7 +4,7 @@ import ( "os" "time" - "github.com/sirupsen/logrus" + "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmsvr" "github.com/yedf/dtm/examples" ) @@ -68,7 +68,7 @@ func main() { examples.TccBarrierAddRoute(app) examples.TccBarrierFireRequest() } else { - logrus.Fatalf("unknown arg: %s", os.Args[1]) + dtmcli.LogRedf("unknown arg: %s", os.Args[1]) } wait() } diff --git a/common/types.go b/common/types.go index 1918975..b822e10 100644 --- a/common/types.go +++ b/common/types.go @@ -7,6 +7,8 @@ import ( "time" _ "github.com/go-sql-driver/mysql" + "github.com/yedf/dtm/dtmcli" + // _ "github.com/lib/pq" "gorm.io/driver/mysql" @@ -15,12 +17,6 @@ import ( "gorm.io/gorm" ) -// M a short name -type M = map[string]interface{} - -// MS a short name -type MS = map[string]string - // ModelBase model base for gorm to provide base fields type ModelBase struct { ID uint @@ -38,7 +34,6 @@ func getGormDialator(driver string, dsn string) gorm.Dialector { } var dbs = map[string]*DB{} -var sqlDbs = map[string]*sql.DB{} // DB provide more func over gorm.DB type DB struct { @@ -60,7 +55,7 @@ func (m *DB) NoMust() *DB { // ToSQLDB get the sql.DB func (m *DB) ToSQLDB() *sql.DB { d, err := m.DB.DB() - E2P(err) + dtmcli.E2P(err) return d } @@ -78,7 +73,7 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) { after := func(db *gorm.DB) { _ts, _ := db.InstanceGet("ivy.startTime") sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...) - Logf("used: %d ms affected: %d sql is: %s", time.Since(_ts.(time.Time)).Milliseconds(), db.RowsAffected, sql) + dtmcli.Logf("used: %d ms affected: %d sql is: %s", time.Since(_ts.(time.Time)).Milliseconds(), db.RowsAffected, sql) if v, ok := db.InstanceGet("ivy.must"); ok && v.(bool) { if db.Error != nil && db.Error != gorm.ErrRecordNotFound { panic(db.Error) @@ -89,7 +84,7 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) { beforeName := "cb_before" afterName := "cb_after" - Logf("installing db plugin: %s", op.Name()) + dtmcli.Logf("installing db plugin: %s", op.Name()) // 开始前 _ = db.Callback().Create().Before("gorm:before_create").Register(beforeName, before) _ = db.Callback().Query().Before("gorm:query").Register(beforeName, before) @@ -108,79 +103,17 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) { return } -// GetDsn get dsn from map config -func GetDsn(conf map[string]string) string { - conf["host"] = MayReplaceLocalhost(conf["host"]) - driver := conf["driver"] - dsn := MS{ - "mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", - conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]), - "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai", - conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]), - }[driver] - PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver)) - return dsn -} - // DbGet get db connection for specified conf func DbGet(conf map[string]string) *DB { - dsn := GetDsn(conf) + dsn := dtmcli.GetDsn(conf) if dbs[dsn] == nil { - Logf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1)) + dtmcli.Logf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1)) db1, err := gorm.Open(getGormDialator(conf["driver"], dsn), &gorm.Config{ SkipDefaultTransaction: true, }) - E2P(err) + dtmcli.E2P(err) db1.Use(&tracePlugin{}) dbs[dsn] = &DB{DB: db1} } return dbs[dsn] } - -// SdbGet get pooled sql.DB -func SdbGet(conf map[string]string) *sql.DB { - dsn := GetDsn(conf) - if sqlDbs[dsn] == nil { - sqlDbs[dsn] = SdbAlone(conf) - } - return sqlDbs[dsn] -} - -// SdbAlone get a standalone db connection -func SdbAlone(conf map[string]string) *sql.DB { - dsn := GetDsn(conf) - Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1)) - mdb, err := sql.Open(conf["driver"], dsn) - E2P(err) - return mdb -} - -// SdbExec use raw db to exec -func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) { - r, rerr := db.Exec(sql, values...) - if rerr == nil { - affected, rerr = r.RowsAffected() - Logf("affected: %d for %s %v", affected, sql, values) - } else { - LogRedf("exec error: %v for %s %v", rerr, sql, values) - } - return -} - -// StxExec use raw tx to exec -func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) { - r, rerr := tx.Exec(sql, values...) - if rerr == nil { - affected, rerr = r.RowsAffected() - Logf("affected: %d for %s %v", affected, sql, values) - } else { - LogRedf("exec error: %v for %s %v", rerr, sql, values) - } - return -} - -// StxQueryRow use raw tx to query row -func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row { - Logf("querying: "+query, args...) - return tx.QueryRow(query, args...) -} diff --git a/common/types_test.go b/common/types_test.go index ab1f0ef..44e177b 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/go-playground/assert/v2" + "github.com/yedf/dtm/dtmcli" ) type testConfig struct { @@ -13,14 +14,14 @@ type testConfig struct { var config = testConfig{} func init() { - InitConfig(GetProjectDir(), &config) + InitConfig(dtmcli.GetProjectDir(), &config) config.DB["database"] = "" } func TestDb(t *testing.T) { db := DbGet(config.DB) err := func() (rerr error) { - defer P2E(&rerr) + defer dtmcli.P2E(&rerr) dbr := db.NoMust().Exec("select a") assert.NotEqual(t, nil, dbr.Error) db.Must().Exec("select a") @@ -30,10 +31,10 @@ func TestDb(t *testing.T) { } func TestDbAlone(t *testing.T) { - db := SdbAlone(config.DB) - _, err := SdbExec(db, "select 1") + db := dtmcli.SdbAlone(config.DB) + _, err := dtmcli.SdbExec(db, "select 1") assert.Equal(t, nil, err) db.Close() - _, err = SdbExec(db, "select 1") + _, err = dtmcli.SdbExec(db, "select 1") assert.NotEqual(t, nil, err) } diff --git a/common/utils.go b/common/utils.go index 9e6dc8d..8d70903 100644 --- a/common/utils.go +++ b/common/utils.go @@ -3,113 +3,14 @@ package common import ( "bytes" "encoding/json" - "errors" - "fmt" "io/ioutil" - "os" - "path" - "path/filepath" - "runtime" - "strconv" - "strings" "time" "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" - "github.com/sirupsen/logrus" yaml "gopkg.in/yaml.v2" ) -// P2E panic to error -func P2E(perr *error) { - if x := recover(); x != nil { - if e, ok := x.(error); ok { - *perr = e - } else { - panic(x) - } - } -} - -// E2P error to panic -func E2P(err error) { - if err != nil { - panic(err) - } -} - -// CatchP catch panic to error -func CatchP(f func()) (rerr error) { - defer P2E(&rerr) - f() - return nil -} - -// PanicIf name is clear -func PanicIf(cond bool, err error) { - if cond { - panic(err) - } -} - -// MustAtoi 走must逻辑 -func MustAtoi(s string) int { - r, err := strconv.Atoi(s) - if err != nil { - E2P(errors.New("convert to int error: " + s)) - } - return r -} - -// OrString return the first not empty string -func OrString(ss ...string) string { - for _, s := range ss { - if s != "" { - return s - } - } - return "" -} - -// If ternary operator -func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} { - if condition { - return trueObj - } - return falseObj -} - -// MustMarshal checked version for marshal -func MustMarshal(v interface{}) []byte { - b, err := json.Marshal(v) - E2P(err) - return b -} - -// MustMarshalString string version of MustMarshal -func MustMarshalString(v interface{}) string { - return string(MustMarshal(v)) -} - -// MustUnmarshal checked version for unmarshal -func MustUnmarshal(b []byte, obj interface{}) { - err := json.Unmarshal(b, obj) - E2P(err) -} - -// MustUnmarshalString string version of MustUnmarshal -func MustUnmarshalString(s string, obj interface{}) { - MustUnmarshal([]byte(s), obj) -} - -// MustRemarshal marshal and unmarshal, and check error -func MustRemarshal(from interface{}, to interface{}) { - b, err := json.Marshal(from) - E2P(err) - err = json.Unmarshal(b, to) - E2P(err) -} - // GetGinApp init and return gin func GetGinApp() *gin.Engine { gin.SetMode(gin.ReleaseMode) @@ -157,54 +58,6 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc { } } -// RestyClient the resty object -var RestyClient = resty.New() - -func init() { - // RestyClient.SetTimeout(3 * time.Second) - // RestyClient.SetRetryCount(2) - // RestyClient.SetRetryWaitTime(1 * time.Second) - RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error { - r.URL = MayReplaceLocalhost(r.URL) - Logf("requesting: %s %s %v %v", r.Method, r.URL, r.Body, r.QueryParam) - return nil - }) - RestyClient.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error { - r := resp.Request - Logf("requested: %s %s %s", r.Method, r.URL, resp.String()) - return nil - }) -} - -// CheckRestySuccess panic if error or resp not success -func CheckRestySuccess(resp *resty.Response, err error) { - E2P(err) - if !strings.Contains(resp.String(), "SUCCESS") { - panic(fmt.Errorf("resty response not success: %s", resp.String())) - } -} - -// Logf 输出日志 -func Logf(format string, args ...interface{}) { - msg := fmt.Sprintf(format, args...) - n := time.Now() - ts := fmt.Sprintf("%d-%02d-%02d %02d:%02d:%02d.%03d", n.Year(), n.Month(), n.Day(), n.Hour(), n.Minute(), n.Second(), n.Nanosecond()/1000000) - var file string - var line int - for i := 1; ; i++ { - _, file, line, _ = runtime.Caller(i) - if strings.Contains(file, "dtm") { - break - } - } - fmt.Printf("%s %s:%d %s\n", ts, path.Base(file), line, msg) -} - -// LogRedf 采用红色打印错误类信息 -func LogRedf(fmt string, args ...interface{}) { - logrus.Errorf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...) -} - // InitConfig init config func InitConfig(dir string, config interface{}) { cont, err := ioutil.ReadFile(dir + "/conf.yml") @@ -216,38 +69,3 @@ func InitConfig(dir string, config interface{}) { err = yaml.Unmarshal(cont, config) E2P(err) } - -// MustGetwd must version of os.Getwd -func MustGetwd() string { - wd, err := os.Getwd() - E2P(err) - return wd -} - -// GetCurrentCodeDir name is clear -func GetCurrentCodeDir() string { - _, file, _, _ := runtime.Caller(1) - return filepath.Dir(file) -} - -// GetProjectDir name is clear -func GetProjectDir() string { - _, file, _, _ := runtime.Caller(1) - for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) { - } - return file -} - -// GetFuncName get current call func name -func GetFuncName() string { - pc, _, _, _ := runtime.Caller(1) - return runtime.FuncForPC(pc).Name() -} - -// MayReplaceLocalhost when run in docker compose, change localhost to host.docker.internal for accessing host network -func MayReplaceLocalhost(host string) string { - if os.Getenv("IS_DOCKER_COMPOSE") != "" { - return strings.Replace(host, "localhost", "host.docker.internal", 1) - } - return host -} diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index 9147413..d278a5f 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -8,7 +8,6 @@ import ( "net/url" "github.com/gin-gonic/gin" - "github.com/yedf/dtm/common" ) // BusiFunc type for busi func @@ -47,20 +46,11 @@ func TransInfoFromQuery(qs url.Values) (*TransInfo, error) { return ti, nil } -// BarrierModel barrier model for gorm -type BarrierModel struct { - common.ModelBase - TransInfo -} - -// TableName gorm table name -func (BarrierModel) TableName() string { return "dtm_barrier.barrier" } - func insertBarrier(tx *sql.Tx, transType string, gid string, branchID string, branchType string, reason string) (int64, error) { if branchType == "" { return 0, nil } - return common.StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason) + return StxExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, reason) values(?,?,?,?,?)", transType, gid, branchID, branchType, reason) } // ThroughBarrierCall 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465 @@ -78,7 +68,7 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re return } defer func() { - common.Logf("result is %v error is %v", res, rerr) + Logf("result is %v error is %v", res, rerr) if x := recover(); x != nil { tx.Rollback() panic(x) @@ -95,13 +85,13 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re }[ti.BranchType] originAffected, _ := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, originType, ti.BranchType) currentAffected, rerr := insertBarrier(tx, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType) - common.Logf("originAffected: %d currentAffected: %d", originAffected, currentAffected) + Logf("originAffected: %d currentAffected: %d", originAffected, currentAffected) if (ti.BranchType == "cancel" || ti.BranchType == "compensate") && originAffected > 0 { // 这个是空补偿,返回成功 res = ResultSuccess return } else if currentAffected == 0 { // 插入不成功 var result sql.NullString - err := common.StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?", + err := StxQueryRow(tx, "select result from dtm_barrier.barrier where trans_type=? and gid=? and branch_id=? and branch_type=? and reason=?", ti.TransType, ti.Gid, ti.BranchID, ti.BranchType, ti.BranchType).Scan(&result) if err == sql.ErrNoRows { // 这个是悬挂操作,返回失败,AP收到这个返回,会尽快回滚 res = ResultFailure @@ -121,8 +111,8 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re } res, rerr = busiCall(tx) if rerr == nil { // 正确返回了,需要将结果保存到数据库 - sval := common.MustMarshalString(res) - _, rerr = common.StxExec(tx, "update dtm_barrier.barrier set result=? where trans_type=? and gid=? and branch_id=? and branch_type=?", sval, + sval := MustMarshalString(res) + _, rerr = StxExec(tx, "update dtm_barrier.barrier set result=? where trans_type=? and gid=? and branch_id=? and branch_type=?", sval, ti.TransType, ti.Gid, ti.BranchID, ti.BranchType) } return diff --git a/dtmcli/common.go b/dtmcli/common.go new file mode 100644 index 0000000..aed8c0e --- /dev/null +++ b/dtmcli/common.go @@ -0,0 +1,189 @@ +package dtmcli + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + "github.com/go-resty/resty/v2" +) + +// P2E panic to error +func P2E(perr *error) { + if x := recover(); x != nil { + if e, ok := x.(error); ok { + *perr = e + } else { + panic(x) + } + } +} + +// E2P error to panic +func E2P(err error) { + if err != nil { + panic(err) + } +} + +// CatchP catch panic to error +func CatchP(f func()) (rerr error) { + defer P2E(&rerr) + f() + return nil +} + +// PanicIf name is clear +func PanicIf(cond bool, err error) { + if cond { + panic(err) + } +} + +// MustAtoi 走must逻辑 +func MustAtoi(s string) int { + r, err := strconv.Atoi(s) + if err != nil { + E2P(errors.New("convert to int error: " + s)) + } + return r +} + +// OrString return the first not empty string +func OrString(ss ...string) string { + for _, s := range ss { + if s != "" { + return s + } + } + return "" +} + +// If ternary operator +func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} { + if condition { + return trueObj + } + return falseObj +} + +// MustMarshal checked version for marshal +func MustMarshal(v interface{}) []byte { + b, err := json.Marshal(v) + E2P(err) + return b +} + +// MustMarshalString string version of MustMarshal +func MustMarshalString(v interface{}) string { + return string(MustMarshal(v)) +} + +// MustUnmarshal checked version for unmarshal +func MustUnmarshal(b []byte, obj interface{}) { + err := json.Unmarshal(b, obj) + E2P(err) +} + +// MustUnmarshalString string version of MustUnmarshal +func MustUnmarshalString(s string, obj interface{}) { + MustUnmarshal([]byte(s), obj) +} + +// MustRemarshal marshal and unmarshal, and check error +func MustRemarshal(from interface{}, to interface{}) { + b, err := json.Marshal(from) + E2P(err) + err = json.Unmarshal(b, to) + E2P(err) +} + +// Logf 输出日志 +func Logf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + n := time.Now() + ts := fmt.Sprintf("%d-%02d-%02d %02d:%02d:%02d.%03d", n.Year(), n.Month(), n.Day(), n.Hour(), n.Minute(), n.Second(), n.Nanosecond()/1000000) + var file string + var line int + for i := 1; ; i++ { + _, file, line, _ = runtime.Caller(i) + if strings.Contains(file, "dtm") { + break + } + } + fmt.Printf("%s %s:%d %s\n", ts, path.Base(file), line, msg) +} + +// LogRedf 采用红色打印错误类信息 +func LogRedf(fmt string, args ...interface{}) { + Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...) +} + +// RestyClient the resty object +var RestyClient = resty.New() + +func init() { + // RestyClient.SetTimeout(3 * time.Second) + // RestyClient.SetRetryCount(2) + // RestyClient.SetRetryWaitTime(1 * time.Second) + RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error { + r.URL = MayReplaceLocalhost(r.URL) + Logf("requesting: %s %s %v %v", r.Method, r.URL, r.Body, r.QueryParam) + return nil + }) + RestyClient.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error { + r := resp.Request + Logf("requested: %s %s %s", r.Method, r.URL, resp.String()) + return nil + }) +} + +// CheckRestySuccess panic if error or resp not success +func CheckRestySuccess(resp *resty.Response, err error) { + E2P(err) + if !strings.Contains(resp.String(), "SUCCESS") { + panic(fmt.Errorf("resty response not success: %s", resp.String())) + } +} + +// MustGetwd must version of os.Getwd +func MustGetwd() string { + wd, err := os.Getwd() + E2P(err) + return wd +} + +// GetCurrentCodeDir name is clear +func GetCurrentCodeDir() string { + _, file, _, _ := runtime.Caller(1) + return filepath.Dir(file) +} + +// GetProjectDir name is clear +func GetProjectDir() string { + _, file, _, _ := runtime.Caller(1) + for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) { + } + return file +} + +// GetFuncName get current call func name +func GetFuncName() string { + pc, _, _, _ := runtime.Caller(1) + return runtime.FuncForPC(pc).Name() +} + +// MayReplaceLocalhost when run in docker compose, change localhost to host.docker.internal for accessing host network +func MayReplaceLocalhost(host string) string { + if os.Getenv("IS_DOCKER_COMPOSE") != "" { + return strings.Replace(host, "localhost", "host.docker.internal", 1) + } + return host +} diff --git a/dtmcli/message.go b/dtmcli/message.go index 2605fda..61c962a 100644 --- a/dtmcli/message.go +++ b/dtmcli/message.go @@ -1,9 +1,5 @@ package dtmcli -import ( - "github.com/yedf/dtm/common" -) - // Msg reliable msg type type Msg struct { MsgData @@ -39,10 +35,10 @@ func NewMsg(server string, gid string) *Msg { // Add add a new step func (s *Msg) Add(action string, postData interface{}) *Msg { - common.Logf("msg %s Add %s %v", s.MsgData.Gid, action, postData) + Logf("msg %s Add %s %v", s.MsgData.Gid, action, postData) step := MsgStep{ Action: action, - Data: common.MustMarshalString(postData), + Data: MustMarshalString(postData), } s.Steps = append(s.Steps, step) return s @@ -50,7 +46,7 @@ func (s *Msg) Add(action string, postData interface{}) *Msg { // Prepare prepare the msg func (s *Msg) Prepare(queryPrepared string) error { - s.QueryPrepared = common.OrString(queryPrepared, s.QueryPrepared) + s.QueryPrepared = OrString(queryPrepared, s.QueryPrepared) return s.CallDtm(&s.MsgData, "prepare") } diff --git a/dtmcli/saga.go b/dtmcli/saga.go index b2312d5..f92c4c6 100644 --- a/dtmcli/saga.go +++ b/dtmcli/saga.go @@ -1,9 +1,5 @@ package dtmcli -import ( - "github.com/yedf/dtm/common" -) - // Saga struct of saga type Saga struct { SagaData @@ -39,11 +35,11 @@ func NewSaga(server string, gid string) *Saga { // Add add a saga step func (s *Saga) Add(action string, compensate string, postData interface{}) *Saga { - common.Logf("saga %s Add %s %s %v", s.SagaData.Gid, action, compensate, postData) + Logf("saga %s Add %s %s %v", s.SagaData.Gid, action, compensate, postData) step := SagaStep{ Action: action, Compensate: compensate, - Data: common.MustMarshalString(postData), + Data: MustMarshalString(postData), } s.Steps = append(s.Steps, step) return s diff --git a/dtmcli/tcc.go b/dtmcli/tcc.go index ced6910..35afc00 100644 --- a/dtmcli/tcc.go +++ b/dtmcli/tcc.go @@ -5,7 +5,6 @@ import ( "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" - "github.com/yedf/dtm/common" ) // Tcc struct of tcc @@ -34,7 +33,7 @@ func TccGlobalTransaction(dtm string, gid string, tccFunc TccGlobalFunc) (rerr e // 小概率情况下,prepare成功了,但是由于网络状况导致上面Failure,那么不执行下面defer的内容,等待超时后再回滚标记事务失败,也没有问题 defer func() { x := recover() - operation := common.If(x == nil && rerr == nil, "submit", "abort").(string) + operation := If(x == nil && rerr == nil, "submit", "abort").(string) err := tcc.CallDtm(data, operation) if rerr == nil { rerr = err @@ -69,7 +68,7 @@ func (t *Tcc) CallBranch(body interface{}, tryURL string, confirmURL string, can "branch_id": branchID, "trans_type": "tcc", "status": "prepared", - "data": string(common.MustMarshal(body)), + "data": string(MustMarshal(body)), "try": tryURL, "confirm": confirmURL, "cancel": cancelURL, @@ -77,9 +76,9 @@ func (t *Tcc) CallBranch(body interface{}, tryURL string, confirmURL string, can if err != nil { return nil, err } - resp, err := common.RestyClient.R(). + resp, err := RestyClient.R(). SetBody(body). - SetQueryParams(common.MS{ + SetQueryParams(MS{ "dtm": t.Dtm, "gid": t.Gid, "branch_id": branchID, diff --git a/dtmcli/types.go b/dtmcli/types.go index 0aff9fa..8b97770 100644 --- a/dtmcli/types.go +++ b/dtmcli/types.go @@ -1,25 +1,95 @@ package dtmcli import ( + "database/sql" "errors" "fmt" "strings" "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" - "github.com/yedf/dtm/common" ) +// M a short name +type M = map[string]interface{} + +// MS a short name +type MS = map[string]string + // MustGenGid generate a new gid func MustGenGid(server string) string { - res := common.MS{} - resp, err := common.RestyClient.R().SetResult(&res).Get(server + "/newGid") + res := MS{} + resp, err := RestyClient.R().SetResult(&res).Get(server + "/newGid") if err != nil || res["gid"] == "" { panic(fmt.Errorf("newGid error: %v, resp: %s", err, resp)) } return res["gid"] } +var sqlDbs = map[string]*sql.DB{} + +// SdbGet get pooled sql.DB +func SdbGet(conf map[string]string) *sql.DB { + dsn := GetDsn(conf) + if sqlDbs[dsn] == nil { + sqlDbs[dsn] = SdbAlone(conf) + } + return sqlDbs[dsn] +} + +// SdbAlone get a standalone db connection +func SdbAlone(conf map[string]string) *sql.DB { + dsn := GetDsn(conf) + Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1)) + mdb, err := sql.Open(conf["driver"], dsn) + E2P(err) + return mdb +} + +// SdbExec use raw db to exec +func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) { + r, rerr := db.Exec(sql, values...) + if rerr == nil { + affected, rerr = r.RowsAffected() + Logf("affected: %d for %s %v", affected, sql, values) + } else { + LogRedf("exec error: %v for %s %v", rerr, sql, values) + } + return +} + +// StxExec use raw tx to exec +func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) { + r, rerr := tx.Exec(sql, values...) + if rerr == nil { + affected, rerr = r.RowsAffected() + Logf("affected: %d for %s %v", affected, sql, values) + } else { + LogRedf("exec error: %v for %s %v", rerr, sql, values) + } + return +} + +// StxQueryRow use raw tx to query row +func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row { + Logf("querying: "+query, args...) + return tx.QueryRow(query, args...) +} + +// GetDsn get dsn from map config +func GetDsn(conf map[string]string) string { + conf["host"] = MayReplaceLocalhost(conf["host"]) + driver := conf["driver"] + dsn := MS{ + "mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", + conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]), + "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai", + conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]), + }[driver] + PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver)) + return dsn +} + // CheckResponse 检查Response,返回错误 func CheckResponse(resp *resty.Response, err error) error { if err == nil && resp != nil { @@ -38,7 +108,7 @@ func CheckResult(res interface{}, err error) error { if ok { return CheckResponse(resp, err) } - if res != nil && strings.Contains(common.MustMarshalString(res), "FAILURE") { + if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") { return ErrFailure } return err @@ -86,11 +156,11 @@ func TransBaseFromReq(c *gin.Context) *TransBase { // CallDtm 调用dtm服务器,返回事务的状态 func (tb *TransBase) CallDtm(body interface{}, operation string) error { - params := common.MS{} + params := MS{} if tb.WaitResult { params["wait_result"] = "1" } - resp, err := common.RestyClient.R().SetQueryParams(params). + resp, err := RestyClient.R().SetQueryParams(params). SetResult(&TransResult{}).SetBody(body).Post(fmt.Sprintf("%s/%s", tb.Dtm, operation)) if err != nil { return err @@ -106,7 +176,7 @@ func (tb *TransBase) CallDtm(body interface{}, operation string) error { var ErrFailure = errors.New("transaction FAILURE") // ResultSuccess 表示返回成功,可以进行下一步 -var ResultSuccess = common.M{"dtm_result": "SUCCESS"} +var ResultSuccess = M{"dtm_result": "SUCCESS"} // ResultFailure 表示返回失败,要求回滚 -var ResultFailure = common.M{"dtm_result": "FAILURE"} +var ResultFailure = M{"dtm_result": "FAILURE"} diff --git a/dtmcli/types_test.go b/dtmcli/types_test.go index 97a3d88..bc87bbc 100644 --- a/dtmcli/types_test.go +++ b/dtmcli/types_test.go @@ -5,16 +5,15 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/yedf/dtm/common" ) func TestTypes(t *testing.T) { - err := common.CatchP(func() { + err := CatchP(func() { idGen := IDGenerator{parentID: "12345678901234567890123"} idGen.NewBranchID() }) assert.Error(t, err) - err = common.CatchP(func() { + err = CatchP(func() { idGen := IDGenerator{branchID: 99} idGen.NewBranchID() }) diff --git a/dtmcli/xa.go b/dtmcli/xa.go index 860f093..aed9cb8 100644 --- a/dtmcli/xa.go +++ b/dtmcli/xa.go @@ -7,13 +7,9 @@ import ( "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" - "github.com/yedf/dtm/common" ) -// M alias -type M = map[string]interface{} - -var e2p = common.E2P +var e2p = E2P // XaGlobalFunc type of xa global function type XaGlobalFunc func(xa *Xa) (*resty.Response, error) @@ -59,10 +55,10 @@ func NewXaClient(server string, mysqlConf map[string]string, callbackURL string, // HandleCallback 处理commit/rollback的回调 func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) { - db := common.SdbAlone(xc.Conf) + db := SdbAlone(xc.Conf) defer db.Close() xaID := gid + "-" + branchID - _, err := common.SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID)) + _, err := SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID)) return ResultSuccess, err } @@ -73,13 +69,13 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, xaFunc XaLocalFunc) (ret xa.Dtm = xc.Server branchID := xa.NewBranchID() xaBranch := xa.Gid + "-" + branchID - db := common.SdbAlone(xc.Conf) + db := SdbAlone(xc.Conf) defer func() { db.Close() }() defer func() { x := recover() - _, err := common.SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch)) + _, err := SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch)) if x == nil && rerr == nil && err == nil { - _, err = common.SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch)) + _, err = SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch)) } if rerr == nil { rerr = err @@ -88,7 +84,7 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, xaFunc XaLocalFunc) (ret panic(x) } }() - _, rerr = common.SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch)) + _, rerr = SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch)) if rerr != nil { return } @@ -116,7 +112,7 @@ func (xc *XaClient) XaGlobalTransaction(gid string, xaFunc XaGlobalFunc) (rerr e // 小概率情况下,prepare成功了,但是由于网络状况导致上面Failure,那么不执行下面defer的内容,等待超时后再回滚标记事务失败,也没有问题 defer func() { x := recover() - operation := common.If(x != nil || rerr != nil, "abort", "submit").(string) + operation := If(x != nil || rerr != nil, "abort", "submit").(string) err := xa.CallDtm(data, operation) if rerr == nil { // 如果用户函数没有返回错误,那么返回dtm的 rerr = err @@ -133,9 +129,9 @@ func (xc *XaClient) XaGlobalTransaction(gid string, xaFunc XaGlobalFunc) (rerr e // CallBranch call a xa branch func (x *Xa) CallBranch(body interface{}, url string) (*resty.Response, error) { branchID := x.NewBranchID() - resp, err := common.RestyClient.R(). + resp, err := RestyClient.R(). SetBody(body). - SetQueryParams(common.MS{ + SetQueryParams(MS{ "gid": x.Gid, "branch_id": branchID, "trans_type": "xa", diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index fb359c8..a4bd8d0 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -16,6 +16,15 @@ var DtmServer = examples.DtmServer var Busi = examples.Busi var app *gin.Engine +// BarrierModel barrier model for gorm +type BarrierModel struct { + common.ModelBase + dtmcli.TransInfo +} + +// TableName gorm table name +func (BarrierModel) TableName() string { return "dtm_barrier.barrier" } + func resetXaData() { if config.DB["driver"] != "mysql" { return @@ -156,9 +165,9 @@ func TestSqlDB(t *testing.T) { return nil, fmt.Errorf("gid2 error") }) asserts.Error(err, fmt.Errorf("gid2 error")) - dbr := db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid1").Find(&[]dtmcli.BarrierModel{}) + dbr := db.Model(&BarrierModel{}).Where("gid=?", "gid1").Find(&[]BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(1)) - dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{}) + dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(0)) gid2Res := common.M{"result": "first"} _, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) { @@ -166,7 +175,7 @@ func TestSqlDB(t *testing.T) { return gid2Res, nil }) asserts.Nil(err) - dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{}) + dbr = db.Model(&BarrierModel{}).Where("gid=?", "gid2").Find(&[]BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(1)) newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.Tx) (interface{}, error) { common.Logf("submit gid2") diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index 45176c0..94838d0 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -7,8 +7,8 @@ import ( "strings" "github.com/bwmarrin/snowflake" - "github.com/sirupsen/logrus" "github.com/yedf/dtm/common" + "github.com/yedf/dtm/dtmcli" ) // M a short name @@ -38,13 +38,13 @@ var TransProcessedTestChan chan string = nil // WaitTransProcessed only for test usage. wait for transaction processed once func WaitTransProcessed(gid string) { - common.Logf("waiting for gid %s", gid) + dtmcli.Logf("waiting for gid %s", gid) id := <-TransProcessedTestChan for id != gid { - logrus.Errorf("-------id %s not match gid %s", id, gid) + dtmcli.LogRedf("-------id %s not match gid %s", id, gid) id = <-TransProcessedTestChan } - common.Logf("finish for gid %s", gid) + dtmcli.Logf("finish for gid %s", gid) } var gNode *snowflake.Node = nil diff --git a/go.mod b/go.mod index 9401013..681c2ed 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/kr/pretty v0.1.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.1 // indirect - github.com/sirupsen/logrus v1.7.0 github.com/stretchr/testify v1.7.0 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v2 v2.3.0