dtmcli ok

This commit is contained in:
yedf2 2021-08-04 14:31:23 +08:00
parent 0270648c6c
commit e4392bac43
14 changed files with 177 additions and 171 deletions

View File

@ -3,7 +3,7 @@ package common
import ( import (
"testing" "testing"
"github.com/go-playground/assert/v2" "github.com/stretchr/testify/assert"
"github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli"
) )
@ -14,8 +14,7 @@ type testConfig struct {
var config = testConfig{} var config = testConfig{}
func init() { func init() {
InitConfig(dtmcli.GetProjectDir(), &config) InitConfig(&config)
config.DB["database"] = ""
} }
func TestDb(t *testing.T) { func TestDb(t *testing.T) {
@ -31,8 +30,9 @@ func TestDb(t *testing.T) {
} }
func TestDbAlone(t *testing.T) { func TestDbAlone(t *testing.T) {
db := dtmcli.SdbAlone(config.DB) db, err := dtmcli.SdbAlone(config.DB)
_, err := dtmcli.SdbExec(db, "select 1") assert.Nil(t, err)
_, err = dtmcli.SdbExec(db, "select 1")
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
db.Close() db.Close()
_, err = dtmcli.SdbExec(db, "select 1") _, err = dtmcli.SdbExec(db, "select 1")

View File

@ -4,6 +4,9 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"os"
"path/filepath"
"runtime"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -59,14 +62,36 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
} }
} }
// MustGetwd must version of os.Getwd
func MustGetwd() string {
wd, err := os.Getwd()
dtmcli.E2P(err)
return wd
}
// GetCurrentCodeDir 获取当前源代码的目录,主要用于测试时,查找相关文件
func GetCurrentCodeDir() string {
_, file, _, _ := runtime.Caller(1)
return filepath.Dir(file)
}
// InitConfig init config // InitConfig init config
func InitConfig(dir string, config interface{}) { func InitConfig(config interface{}) {
cont, err := ioutil.ReadFile(dir + "/conf.yml") cont := []byte{}
if err != nil { for d := MustGetwd(); d != ""; d = filepath.Dir(d) {
cont, err = ioutil.ReadFile(dir + "/conf.sample.yml") cont1, err := ioutil.ReadFile(d + "/conf.yml")
if err != nil {
cont1, err = ioutil.ReadFile(d + "/conf.sample.yml")
}
if cont1 != nil {
cont = cont1
break
}
}
if cont == nil {
dtmcli.LogFatalf("no config file conf.yml/conf.sample.yml found in current and parent path: %s", MustGetwd())
} }
dtmcli.Logf("cont is: \n%s", string(cont)) dtmcli.Logf("cont is: \n%s", string(cont))
dtmcli.E2P(err) err := yaml.Unmarshal(cont, config)
err = yaml.Unmarshal(cont, config) dtmcli.FatalIfError(err)
dtmcli.E2P(err)
} }

View File

@ -30,3 +30,12 @@ func TestGin(t *testing.T) {
assert.Equal(t, "1", getResultString("/api/sample", nil)) assert.Equal(t, "1", getResultString("/api/sample", nil))
assert.Equal(t, "{\"code\":500,\"message\":\"err1\"}", getResultString("/api/error", strings.NewReader("{}"))) assert.Equal(t, "{\"code\":500,\"message\":\"err1\"}", getResultString("/api/error", strings.NewReader("{}")))
} }
func TestFuncs(t *testing.T) {
wd := MustGetwd()
assert.NotEqual(t, "", wd)
dir1 := GetCurrentCodeDir()
assert.Equal(t, true, strings.HasSuffix(dir1, "common"))
}

View File

@ -1,13 +1,9 @@
package dtmcli package dtmcli
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"strings"
"github.com/go-resty/resty/v2"
) )
// M a short name // M a short name
@ -26,94 +22,6 @@ func MustGenGid(server string) string {
return res["gid"] return res["gid"]
} }
var sqlDbs = map[string]*sql.DB{}
// SdbGet get pooled sql.DB
func SdbGet(conf map[string]string) *sql.DB {
dsn := GetDsn(conf)
if sqlDbs[dsn] == nil {
sqlDbs[dsn] = SdbAlone(conf)
}
return sqlDbs[dsn]
}
// SdbAlone get a standalone db connection
func SdbAlone(conf map[string]string) *sql.DB {
dsn := GetDsn(conf)
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
mdb, err := sql.Open(conf["driver"], dsn)
E2P(err)
return mdb
}
// SdbExec use raw db to exec
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := db.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxExec use raw tx to exec
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := tx.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxQueryRow use raw tx to query row
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
Logf("querying: "+query, args...)
return tx.QueryRow(query, args...)
}
// GetDsn get dsn from map config
func GetDsn(conf map[string]string) string {
conf["host"] = MayReplaceLocalhost(conf["host"])
driver := conf["driver"]
dsn := MS{
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
}[driver]
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
return dsn
}
// CheckResponse 检查Response返回错误
func CheckResponse(resp *resty.Response, err error) error {
if err == nil && resp != nil {
if resp.IsError() {
return errors.New(resp.String())
} else if strings.Contains(resp.String(), "FAILURE") {
return ErrFailure
}
}
return err
}
// CheckResult 检查Result返回错误
func CheckResult(res interface{}, err error) error {
resp, ok := res.(*resty.Response)
if ok {
return CheckResponse(resp, err)
}
if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") {
return ErrFailure
}
return err
}
// IDGenerator used to generate a branch id // IDGenerator used to generate a branch id
type IDGenerator struct { type IDGenerator struct {
parentID string parentID string

View File

@ -17,7 +17,12 @@ func TestTypes(t *testing.T) {
idGen := IDGenerator{branchID: 99} idGen := IDGenerator{branchID: 99}
idGen.NewBranchID() idGen.NewBranchID()
}) })
err = CatchP(func() {
MustGenGid("http://localhost:8080/api/no")
})
assert.Error(t, err)
assert.Error(t, err) assert.Error(t, err)
_, err = TransInfoFromQuery(url.Values{}) _, err = TransInfoFromQuery(url.Values{})
assert.Error(t, err) assert.Error(t, err)
} }

View File

@ -1,12 +1,12 @@
package dtmcli package dtmcli
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"os" "os"
"path" "path"
"path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@ -126,6 +126,21 @@ func LogRedf(fmt string, args ...interface{}) {
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...) Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
} }
// LogFatalf 采用红色打印错误类信息, 并退出
func LogFatalf(fmt string, args ...interface{}) {
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
os.Exit(1)
}
// FatalIfError 采用红色打印错误类信息, 并退出
func FatalIfError(err error) {
if err == nil {
return
}
Logf("\x1b[31m\nFatal error: %v\x1b[0m\n", err)
os.Exit(1)
}
// RestyClient the resty object // RestyClient the resty object
var RestyClient = resty.New() var RestyClient = resty.New()
@ -145,35 +160,6 @@ func init() {
}) })
} }
// CheckRestySuccess panic if error or resp not success
func CheckRestySuccess(resp *resty.Response, err error) {
E2P(err)
if !strings.Contains(resp.String(), "SUCCESS") {
panic(fmt.Errorf("resty response not success: %s", resp.String()))
}
}
// MustGetwd must version of os.Getwd
func MustGetwd() string {
wd, err := os.Getwd()
E2P(err)
return wd
}
// GetCurrentCodeDir name is clear
func GetCurrentCodeDir() string {
_, file, _, _ := runtime.Caller(1)
return filepath.Dir(file)
}
// GetProjectDir name is clear
func GetProjectDir() string {
_, file, _, _ := runtime.Caller(1)
for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) {
}
return file
}
// GetFuncName get current call func name // GetFuncName get current call func name
func GetFuncName() string { func GetFuncName() string {
pc, _, _, _ := runtime.Caller(1) pc, _, _, _ := runtime.Caller(1)
@ -187,3 +173,93 @@ func MayReplaceLocalhost(host string) string {
} }
return host return host
} }
var sqlDbs = map[string]*sql.DB{}
// SdbGet get pooled sql.DB
func SdbGet(conf map[string]string) (*sql.DB, error) {
dsn := GetDsn(conf)
if sqlDbs[dsn] == nil {
db, err := SdbAlone(conf)
if err != nil {
return nil, err
}
sqlDbs[dsn] = db
}
return sqlDbs[dsn], nil
}
// SdbAlone get a standalone db connection
func SdbAlone(conf map[string]string) (*sql.DB, error) {
dsn := GetDsn(conf)
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
return sql.Open(conf["driver"], dsn)
}
// SdbExec use raw db to exec
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := db.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxExec use raw tx to exec
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := tx.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxQueryRow use raw tx to query row
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
Logf("querying: "+query, args...)
return tx.QueryRow(query, args...)
}
// GetDsn get dsn from map config
func GetDsn(conf map[string]string) string {
conf["host"] = MayReplaceLocalhost(conf["host"])
driver := conf["driver"]
dsn := MS{
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
}[driver]
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
return dsn
}
// CheckResponse 检查Response返回错误
func CheckResponse(resp *resty.Response, err error) error {
if err == nil && resp != nil {
if resp.IsError() {
return errors.New(resp.String())
} else if strings.Contains(resp.String(), "FAILURE") {
return ErrFailure
}
}
return err
}
// CheckResult 检查Result返回错误
func CheckResult(res interface{}, err error) error {
resp, ok := res.(*resty.Response)
if ok {
return CheckResponse(resp, err)
}
if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") {
return ErrFailure
}
return err
}

View File

@ -65,11 +65,6 @@ func TestSome(t *testing.T) {
MustAtoi("abc") MustAtoi("abc")
}) })
assert.Error(t, err) assert.Error(t, err)
wd := MustGetwd()
assert.NotEqual(t, "", wd)
dir1 := GetCurrentCodeDir()
assert.Equal(t, true, strings.HasSuffix(dir1, "dtmcli"))
func1 := GetFuncName() func1 := GetFuncName()
assert.Equal(t, true, strings.HasSuffix(func1, "TestSome")) assert.Equal(t, true, strings.HasSuffix(func1, "TestSome"))

View File

@ -8,8 +8,6 @@ import (
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
) )
var e2p = E2P
// XaGlobalFunc type of xa global function // XaGlobalFunc type of xa global function
type XaGlobalFunc func(xa *Xa) (*resty.Response, error) type XaGlobalFunc func(xa *Xa) (*resty.Response, error)
@ -58,10 +56,13 @@ func NewXaClient(server string, mysqlConf map[string]string, callbackURL string,
// HandleCallback 处理commit/rollback的回调 // HandleCallback 处理commit/rollback的回调
func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) { func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) {
db := SdbAlone(xc.Conf) db, err := SdbAlone(xc.Conf)
if err != nil {
return nil, err
}
defer db.Close() defer db.Close()
xaID := gid + "-" + branchID xaID := gid + "-" + branchID
_, err := SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID)) _, err = SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
return ResultSuccess, err return ResultSuccess, err
} }
@ -75,7 +76,10 @@ func (xc *XaClient) XaLocalTransaction(qs url.Values, xaFunc XaLocalFunc) (ret i
xa.Dtm = xc.Server xa.Dtm = xc.Server
branchID := xa.NewBranchID() branchID := xa.NewBranchID()
xaBranch := xa.Gid + "-" + branchID xaBranch := xa.Gid + "-" + branchID
db := SdbAlone(xc.Conf) db, rerr := SdbAlone(xc.Conf)
if rerr != nil {
return
}
defer func() { db.Close() }() defer func() { db.Close() }()
defer func() { defer func() {
x := recover() x := recover()

View File

@ -2,7 +2,6 @@ package dtmsvr
import ( import (
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
) )
type dtmsvrConfig struct { type dtmsvrConfig struct {
@ -14,9 +13,6 @@ var config = &dtmsvrConfig{
TransCronInterval: 10, TransCronInterval: 10,
} }
var dbName = "dtm"
func init() { func init() {
common.InitConfig(dtmcli.GetProjectDir(), &config) common.InitConfig(&config)
config.DB["database"] = ""
} }

View File

@ -42,7 +42,7 @@ func resetXaData() {
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
TransProcessedTestChan = make(chan string, 1) TransProcessedTestChan = make(chan string, 1)
common.InitConfig(dtmcli.GetProjectDir(), &config) common.InitConfig(&config)
PopulateDB(false) PopulateDB(false)
examples.PopulateDB(false) examples.PopulateDB(false)
// 启动组件 // 启动组件
@ -72,21 +72,6 @@ func TestCover(t *testing.T) {
go sleepCronTime() go sleepCronTime()
} }
func TestType(t *testing.T) {
err := dtmcli.CatchP(func() {
dtmcli.MustGenGid("http://localhost:8080/api/no")
})
assert.Error(t, err)
err = dtmcli.CatchP(func() {
resp, err := dtmcli.RestyClient.R().SetBody(dtmcli.M{
"gid": "1",
"trans_type": "msg",
}).Get("http://localhost:8080/api/dtmsvr/abort")
dtmcli.CheckRestySuccess(resp, err)
})
assert.Error(t, err)
}
func getTransStatus(gid string) string { func getTransStatus(gid string) string {
sm := TransGlobal{} sm := TransGlobal{}
dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm) dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm)

View File

@ -23,6 +23,6 @@ func StartSvr() {
// PopulateDB setup mysql data // PopulateDB setup mysql data
func PopulateDB(skipDrop bool) { func PopulateDB(skipDrop bool) {
file := fmt.Sprintf("%s/dtmsvr.%s.sql", dtmcli.GetCurrentCodeDir(), config.DB["driver"]) file := fmt.Sprintf("%s/dtmsvr.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"])
examples.RunSQLScript(config.DB, file, skipDrop) examples.RunSQLScript(config.DB, file, skipDrop)
} }

View File

@ -2,7 +2,6 @@ package examples
import ( import (
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
) )
type exampleConfig struct { type exampleConfig struct {
@ -12,5 +11,5 @@ type exampleConfig struct {
var config = exampleConfig{} var config = exampleConfig{}
func init() { func init() {
common.InitConfig(dtmcli.GetProjectDir(), &config) common.InitConfig(&config)
} }

View File

@ -5,12 +5,14 @@ import (
"io/ioutil" "io/ioutil"
"strings" "strings"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli"
) )
// RunSQLScript 1 // RunSQLScript 1
func RunSQLScript(conf map[string]string, script string, skipDrop bool) { func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
con := dtmcli.SdbAlone(conf) con, err := dtmcli.SdbAlone(conf)
e2p(err)
defer func() { con.Close() }() defer func() { con.Close() }()
content, err := ioutil.ReadFile(script) content, err := ioutil.ReadFile(script)
e2p(err) e2p(err)
@ -27,6 +29,6 @@ func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
// PopulateDB populate example mysql data // PopulateDB populate example mysql data
func PopulateDB(skipDrop bool) { func PopulateDB(skipDrop bool) {
file := fmt.Sprintf("%s/examples.%s.sql", dtmcli.GetCurrentCodeDir(), config.DB["driver"]) file := fmt.Sprintf("%s/examples.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"])
RunSQLScript(config.DB, file, skipDrop) RunSQLScript(config.DB, file, skipDrop)
} }

View File

@ -64,7 +64,9 @@ func dbGet() *common.DB {
} }
func sdbGet() *sql.DB { func sdbGet() *sql.DB {
return dtmcli.SdbGet(config.DB) db, err := dtmcli.SdbGet(config.DB)
e2p(err)
return db
} }
// MustGetTrans construct transaction info from request // MustGetTrans construct transaction info from request