dtmcli ok
This commit is contained in:
parent
0270648c6c
commit
e4392bac43
@ -3,7 +3,7 @@ package common
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-playground/assert/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/yedf/dtm/dtmcli"
|
||||
)
|
||||
|
||||
@ -14,8 +14,7 @@ type testConfig struct {
|
||||
var config = testConfig{}
|
||||
|
||||
func init() {
|
||||
InitConfig(dtmcli.GetProjectDir(), &config)
|
||||
config.DB["database"] = ""
|
||||
InitConfig(&config)
|
||||
}
|
||||
|
||||
func TestDb(t *testing.T) {
|
||||
@ -31,8 +30,9 @@ func TestDb(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDbAlone(t *testing.T) {
|
||||
db := dtmcli.SdbAlone(config.DB)
|
||||
_, err := dtmcli.SdbExec(db, "select 1")
|
||||
db, err := dtmcli.SdbAlone(config.DB)
|
||||
assert.Nil(t, err)
|
||||
_, err = dtmcli.SdbExec(db, "select 1")
|
||||
assert.Equal(t, nil, err)
|
||||
db.Close()
|
||||
_, err = dtmcli.SdbExec(db, "select 1")
|
||||
|
||||
@ -4,6 +4,9 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -59,14 +62,36 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// MustGetwd must version of os.Getwd
|
||||
func MustGetwd() string {
|
||||
wd, err := os.Getwd()
|
||||
dtmcli.E2P(err)
|
||||
return wd
|
||||
}
|
||||
|
||||
// GetCurrentCodeDir 获取当前源代码的目录,主要用于测试时,查找相关文件
|
||||
func GetCurrentCodeDir() string {
|
||||
_, file, _, _ := runtime.Caller(1)
|
||||
return filepath.Dir(file)
|
||||
}
|
||||
|
||||
// InitConfig init config
|
||||
func InitConfig(dir string, config interface{}) {
|
||||
cont, err := ioutil.ReadFile(dir + "/conf.yml")
|
||||
if err != nil {
|
||||
cont, err = ioutil.ReadFile(dir + "/conf.sample.yml")
|
||||
func InitConfig(config interface{}) {
|
||||
cont := []byte{}
|
||||
for d := MustGetwd(); d != ""; d = filepath.Dir(d) {
|
||||
cont1, err := ioutil.ReadFile(d + "/conf.yml")
|
||||
if err != nil {
|
||||
cont1, err = ioutil.ReadFile(d + "/conf.sample.yml")
|
||||
}
|
||||
if cont1 != nil {
|
||||
cont = cont1
|
||||
break
|
||||
}
|
||||
}
|
||||
if cont == nil {
|
||||
dtmcli.LogFatalf("no config file conf.yml/conf.sample.yml found in current and parent path: %s", MustGetwd())
|
||||
}
|
||||
dtmcli.Logf("cont is: \n%s", string(cont))
|
||||
dtmcli.E2P(err)
|
||||
err = yaml.Unmarshal(cont, config)
|
||||
dtmcli.E2P(err)
|
||||
err := yaml.Unmarshal(cont, config)
|
||||
dtmcli.FatalIfError(err)
|
||||
}
|
||||
|
||||
@ -30,3 +30,12 @@ func TestGin(t *testing.T) {
|
||||
assert.Equal(t, "1", getResultString("/api/sample", nil))
|
||||
assert.Equal(t, "{\"code\":500,\"message\":\"err1\"}", getResultString("/api/error", strings.NewReader("{}")))
|
||||
}
|
||||
|
||||
func TestFuncs(t *testing.T) {
|
||||
wd := MustGetwd()
|
||||
assert.NotEqual(t, "", wd)
|
||||
|
||||
dir1 := GetCurrentCodeDir()
|
||||
assert.Equal(t, true, strings.HasSuffix(dir1, "common"))
|
||||
|
||||
}
|
||||
|
||||
@ -1,13 +1,9 @@
|
||||
package dtmcli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// M a short name
|
||||
@ -26,94 +22,6 @@ func MustGenGid(server string) string {
|
||||
return res["gid"]
|
||||
}
|
||||
|
||||
var sqlDbs = map[string]*sql.DB{}
|
||||
|
||||
// SdbGet get pooled sql.DB
|
||||
func SdbGet(conf map[string]string) *sql.DB {
|
||||
dsn := GetDsn(conf)
|
||||
if sqlDbs[dsn] == nil {
|
||||
sqlDbs[dsn] = SdbAlone(conf)
|
||||
}
|
||||
return sqlDbs[dsn]
|
||||
}
|
||||
|
||||
// SdbAlone get a standalone db connection
|
||||
func SdbAlone(conf map[string]string) *sql.DB {
|
||||
dsn := GetDsn(conf)
|
||||
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
|
||||
mdb, err := sql.Open(conf["driver"], dsn)
|
||||
E2P(err)
|
||||
return mdb
|
||||
}
|
||||
|
||||
// SdbExec use raw db to exec
|
||||
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
|
||||
r, rerr := db.Exec(sql, values...)
|
||||
if rerr == nil {
|
||||
affected, rerr = r.RowsAffected()
|
||||
Logf("affected: %d for %s %v", affected, sql, values)
|
||||
} else {
|
||||
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// StxExec use raw tx to exec
|
||||
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
|
||||
r, rerr := tx.Exec(sql, values...)
|
||||
if rerr == nil {
|
||||
affected, rerr = r.RowsAffected()
|
||||
Logf("affected: %d for %s %v", affected, sql, values)
|
||||
} else {
|
||||
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// StxQueryRow use raw tx to query row
|
||||
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
|
||||
Logf("querying: "+query, args...)
|
||||
return tx.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
// GetDsn get dsn from map config
|
||||
func GetDsn(conf map[string]string) string {
|
||||
conf["host"] = MayReplaceLocalhost(conf["host"])
|
||||
driver := conf["driver"]
|
||||
dsn := MS{
|
||||
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
|
||||
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
|
||||
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
|
||||
}[driver]
|
||||
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
|
||||
return dsn
|
||||
}
|
||||
|
||||
// CheckResponse 检查Response,返回错误
|
||||
func CheckResponse(resp *resty.Response, err error) error {
|
||||
if err == nil && resp != nil {
|
||||
if resp.IsError() {
|
||||
return errors.New(resp.String())
|
||||
} else if strings.Contains(resp.String(), "FAILURE") {
|
||||
return ErrFailure
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// CheckResult 检查Result,返回错误
|
||||
func CheckResult(res interface{}, err error) error {
|
||||
resp, ok := res.(*resty.Response)
|
||||
if ok {
|
||||
return CheckResponse(resp, err)
|
||||
}
|
||||
if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") {
|
||||
return ErrFailure
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// IDGenerator used to generate a branch id
|
||||
type IDGenerator struct {
|
||||
parentID string
|
||||
|
||||
@ -17,7 +17,12 @@ func TestTypes(t *testing.T) {
|
||||
idGen := IDGenerator{branchID: 99}
|
||||
idGen.NewBranchID()
|
||||
})
|
||||
err = CatchP(func() {
|
||||
MustGenGid("http://localhost:8080/api/no")
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Error(t, err)
|
||||
_, err = TransInfoFromQuery(url.Values{})
|
||||
assert.Error(t, err)
|
||||
|
||||
}
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
package dtmcli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -126,6 +126,21 @@ func LogRedf(fmt string, args ...interface{}) {
|
||||
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
|
||||
}
|
||||
|
||||
// LogFatalf 采用红色打印错误类信息, 并退出
|
||||
func LogFatalf(fmt string, args ...interface{}) {
|
||||
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// FatalIfError 采用红色打印错误类信息, 并退出
|
||||
func FatalIfError(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
Logf("\x1b[31m\nFatal error: %v\x1b[0m\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// RestyClient the resty object
|
||||
var RestyClient = resty.New()
|
||||
|
||||
@ -145,35 +160,6 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
// CheckRestySuccess panic if error or resp not success
|
||||
func CheckRestySuccess(resp *resty.Response, err error) {
|
||||
E2P(err)
|
||||
if !strings.Contains(resp.String(), "SUCCESS") {
|
||||
panic(fmt.Errorf("resty response not success: %s", resp.String()))
|
||||
}
|
||||
}
|
||||
|
||||
// MustGetwd must version of os.Getwd
|
||||
func MustGetwd() string {
|
||||
wd, err := os.Getwd()
|
||||
E2P(err)
|
||||
return wd
|
||||
}
|
||||
|
||||
// GetCurrentCodeDir name is clear
|
||||
func GetCurrentCodeDir() string {
|
||||
_, file, _, _ := runtime.Caller(1)
|
||||
return filepath.Dir(file)
|
||||
}
|
||||
|
||||
// GetProjectDir name is clear
|
||||
func GetProjectDir() string {
|
||||
_, file, _, _ := runtime.Caller(1)
|
||||
for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) {
|
||||
}
|
||||
return file
|
||||
}
|
||||
|
||||
// GetFuncName get current call func name
|
||||
func GetFuncName() string {
|
||||
pc, _, _, _ := runtime.Caller(1)
|
||||
@ -187,3 +173,93 @@ func MayReplaceLocalhost(host string) string {
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
var sqlDbs = map[string]*sql.DB{}
|
||||
|
||||
// SdbGet get pooled sql.DB
|
||||
func SdbGet(conf map[string]string) (*sql.DB, error) {
|
||||
dsn := GetDsn(conf)
|
||||
if sqlDbs[dsn] == nil {
|
||||
db, err := SdbAlone(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlDbs[dsn] = db
|
||||
}
|
||||
return sqlDbs[dsn], nil
|
||||
}
|
||||
|
||||
// SdbAlone get a standalone db connection
|
||||
func SdbAlone(conf map[string]string) (*sql.DB, error) {
|
||||
dsn := GetDsn(conf)
|
||||
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
|
||||
return sql.Open(conf["driver"], dsn)
|
||||
}
|
||||
|
||||
// SdbExec use raw db to exec
|
||||
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
|
||||
r, rerr := db.Exec(sql, values...)
|
||||
if rerr == nil {
|
||||
affected, rerr = r.RowsAffected()
|
||||
Logf("affected: %d for %s %v", affected, sql, values)
|
||||
} else {
|
||||
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// StxExec use raw tx to exec
|
||||
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
|
||||
r, rerr := tx.Exec(sql, values...)
|
||||
if rerr == nil {
|
||||
affected, rerr = r.RowsAffected()
|
||||
Logf("affected: %d for %s %v", affected, sql, values)
|
||||
} else {
|
||||
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// StxQueryRow use raw tx to query row
|
||||
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
|
||||
Logf("querying: "+query, args...)
|
||||
return tx.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
// GetDsn get dsn from map config
|
||||
func GetDsn(conf map[string]string) string {
|
||||
conf["host"] = MayReplaceLocalhost(conf["host"])
|
||||
driver := conf["driver"]
|
||||
dsn := MS{
|
||||
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
|
||||
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
|
||||
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
|
||||
}[driver]
|
||||
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
|
||||
return dsn
|
||||
}
|
||||
|
||||
// CheckResponse 检查Response,返回错误
|
||||
func CheckResponse(resp *resty.Response, err error) error {
|
||||
if err == nil && resp != nil {
|
||||
if resp.IsError() {
|
||||
return errors.New(resp.String())
|
||||
} else if strings.Contains(resp.String(), "FAILURE") {
|
||||
return ErrFailure
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// CheckResult 检查Result,返回错误
|
||||
func CheckResult(res interface{}, err error) error {
|
||||
resp, ok := res.(*resty.Response)
|
||||
if ok {
|
||||
return CheckResponse(resp, err)
|
||||
}
|
||||
if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") {
|
||||
return ErrFailure
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -65,11 +65,6 @@ func TestSome(t *testing.T) {
|
||||
MustAtoi("abc")
|
||||
})
|
||||
assert.Error(t, err)
|
||||
wd := MustGetwd()
|
||||
assert.NotEqual(t, "", wd)
|
||||
|
||||
dir1 := GetCurrentCodeDir()
|
||||
assert.Equal(t, true, strings.HasSuffix(dir1, "dtmcli"))
|
||||
|
||||
func1 := GetFuncName()
|
||||
assert.Equal(t, true, strings.HasSuffix(func1, "TestSome"))
|
||||
|
||||
14
dtmcli/xa.go
14
dtmcli/xa.go
@ -8,8 +8,6 @@ import (
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
var e2p = E2P
|
||||
|
||||
// XaGlobalFunc type of xa global function
|
||||
type XaGlobalFunc func(xa *Xa) (*resty.Response, error)
|
||||
|
||||
@ -58,10 +56,13 @@ func NewXaClient(server string, mysqlConf map[string]string, callbackURL string,
|
||||
|
||||
// HandleCallback 处理commit/rollback的回调
|
||||
func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) {
|
||||
db := SdbAlone(xc.Conf)
|
||||
db, err := SdbAlone(xc.Conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer db.Close()
|
||||
xaID := gid + "-" + branchID
|
||||
_, err := SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
|
||||
_, err = SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
|
||||
return ResultSuccess, err
|
||||
|
||||
}
|
||||
@ -75,7 +76,10 @@ func (xc *XaClient) XaLocalTransaction(qs url.Values, xaFunc XaLocalFunc) (ret i
|
||||
xa.Dtm = xc.Server
|
||||
branchID := xa.NewBranchID()
|
||||
xaBranch := xa.Gid + "-" + branchID
|
||||
db := SdbAlone(xc.Conf)
|
||||
db, rerr := SdbAlone(xc.Conf)
|
||||
if rerr != nil {
|
||||
return
|
||||
}
|
||||
defer func() { db.Close() }()
|
||||
defer func() {
|
||||
x := recover()
|
||||
|
||||
@ -2,7 +2,6 @@ package dtmsvr
|
||||
|
||||
import (
|
||||
"github.com/yedf/dtm/common"
|
||||
"github.com/yedf/dtm/dtmcli"
|
||||
)
|
||||
|
||||
type dtmsvrConfig struct {
|
||||
@ -14,9 +13,6 @@ var config = &dtmsvrConfig{
|
||||
TransCronInterval: 10,
|
||||
}
|
||||
|
||||
var dbName = "dtm"
|
||||
|
||||
func init() {
|
||||
common.InitConfig(dtmcli.GetProjectDir(), &config)
|
||||
config.DB["database"] = ""
|
||||
common.InitConfig(&config)
|
||||
}
|
||||
|
||||
@ -42,7 +42,7 @@ func resetXaData() {
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
TransProcessedTestChan = make(chan string, 1)
|
||||
common.InitConfig(dtmcli.GetProjectDir(), &config)
|
||||
common.InitConfig(&config)
|
||||
PopulateDB(false)
|
||||
examples.PopulateDB(false)
|
||||
// 启动组件
|
||||
@ -72,21 +72,6 @@ func TestCover(t *testing.T) {
|
||||
go sleepCronTime()
|
||||
}
|
||||
|
||||
func TestType(t *testing.T) {
|
||||
err := dtmcli.CatchP(func() {
|
||||
dtmcli.MustGenGid("http://localhost:8080/api/no")
|
||||
})
|
||||
assert.Error(t, err)
|
||||
err = dtmcli.CatchP(func() {
|
||||
resp, err := dtmcli.RestyClient.R().SetBody(dtmcli.M{
|
||||
"gid": "1",
|
||||
"trans_type": "msg",
|
||||
}).Get("http://localhost:8080/api/dtmsvr/abort")
|
||||
dtmcli.CheckRestySuccess(resp, err)
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func getTransStatus(gid string) string {
|
||||
sm := TransGlobal{}
|
||||
dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm)
|
||||
|
||||
@ -23,6 +23,6 @@ func StartSvr() {
|
||||
|
||||
// PopulateDB setup mysql data
|
||||
func PopulateDB(skipDrop bool) {
|
||||
file := fmt.Sprintf("%s/dtmsvr.%s.sql", dtmcli.GetCurrentCodeDir(), config.DB["driver"])
|
||||
file := fmt.Sprintf("%s/dtmsvr.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"])
|
||||
examples.RunSQLScript(config.DB, file, skipDrop)
|
||||
}
|
||||
|
||||
@ -2,7 +2,6 @@ package examples
|
||||
|
||||
import (
|
||||
"github.com/yedf/dtm/common"
|
||||
"github.com/yedf/dtm/dtmcli"
|
||||
)
|
||||
|
||||
type exampleConfig struct {
|
||||
@ -12,5 +11,5 @@ type exampleConfig struct {
|
||||
var config = exampleConfig{}
|
||||
|
||||
func init() {
|
||||
common.InitConfig(dtmcli.GetProjectDir(), &config)
|
||||
common.InitConfig(&config)
|
||||
}
|
||||
|
||||
@ -5,12 +5,14 @@ import (
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
"github.com/yedf/dtm/common"
|
||||
"github.com/yedf/dtm/dtmcli"
|
||||
)
|
||||
|
||||
// RunSQLScript 1
|
||||
func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
|
||||
con := dtmcli.SdbAlone(conf)
|
||||
con, err := dtmcli.SdbAlone(conf)
|
||||
e2p(err)
|
||||
defer func() { con.Close() }()
|
||||
content, err := ioutil.ReadFile(script)
|
||||
e2p(err)
|
||||
@ -27,6 +29,6 @@ func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
|
||||
|
||||
// PopulateDB populate example mysql data
|
||||
func PopulateDB(skipDrop bool) {
|
||||
file := fmt.Sprintf("%s/examples.%s.sql", dtmcli.GetCurrentCodeDir(), config.DB["driver"])
|
||||
file := fmt.Sprintf("%s/examples.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"])
|
||||
RunSQLScript(config.DB, file, skipDrop)
|
||||
}
|
||||
|
||||
@ -64,7 +64,9 @@ func dbGet() *common.DB {
|
||||
}
|
||||
|
||||
func sdbGet() *sql.DB {
|
||||
return dtmcli.SdbGet(config.DB)
|
||||
db, err := dtmcli.SdbGet(config.DB)
|
||||
e2p(err)
|
||||
return db
|
||||
}
|
||||
|
||||
// MustGetTrans construct transaction info from request
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user