dtmcli ok

This commit is contained in:
yedf2 2021-08-04 14:31:23 +08:00
parent 0270648c6c
commit e4392bac43
14 changed files with 177 additions and 171 deletions

View File

@ -3,7 +3,7 @@ package common
import (
"testing"
"github.com/go-playground/assert/v2"
"github.com/stretchr/testify/assert"
"github.com/yedf/dtm/dtmcli"
)
@ -14,8 +14,7 @@ type testConfig struct {
var config = testConfig{}
func init() {
InitConfig(dtmcli.GetProjectDir(), &config)
config.DB["database"] = ""
InitConfig(&config)
}
func TestDb(t *testing.T) {
@ -31,8 +30,9 @@ func TestDb(t *testing.T) {
}
func TestDbAlone(t *testing.T) {
db := dtmcli.SdbAlone(config.DB)
_, err := dtmcli.SdbExec(db, "select 1")
db, err := dtmcli.SdbAlone(config.DB)
assert.Nil(t, err)
_, err = dtmcli.SdbExec(db, "select 1")
assert.Equal(t, nil, err)
db.Close()
_, err = dtmcli.SdbExec(db, "select 1")

View File

@ -4,6 +4,9 @@ import (
"bytes"
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"time"
"github.com/gin-gonic/gin"
@ -59,14 +62,36 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
}
}
// MustGetwd must version of os.Getwd
func MustGetwd() string {
wd, err := os.Getwd()
dtmcli.E2P(err)
return wd
}
// GetCurrentCodeDir 获取当前源代码的目录,主要用于测试时,查找相关文件
func GetCurrentCodeDir() string {
_, file, _, _ := runtime.Caller(1)
return filepath.Dir(file)
}
// InitConfig init config
func InitConfig(dir string, config interface{}) {
cont, err := ioutil.ReadFile(dir + "/conf.yml")
if err != nil {
cont, err = ioutil.ReadFile(dir + "/conf.sample.yml")
func InitConfig(config interface{}) {
cont := []byte{}
for d := MustGetwd(); d != ""; d = filepath.Dir(d) {
cont1, err := ioutil.ReadFile(d + "/conf.yml")
if err != nil {
cont1, err = ioutil.ReadFile(d + "/conf.sample.yml")
}
if cont1 != nil {
cont = cont1
break
}
}
if cont == nil {
dtmcli.LogFatalf("no config file conf.yml/conf.sample.yml found in current and parent path: %s", MustGetwd())
}
dtmcli.Logf("cont is: \n%s", string(cont))
dtmcli.E2P(err)
err = yaml.Unmarshal(cont, config)
dtmcli.E2P(err)
err := yaml.Unmarshal(cont, config)
dtmcli.FatalIfError(err)
}

View File

@ -30,3 +30,12 @@ func TestGin(t *testing.T) {
assert.Equal(t, "1", getResultString("/api/sample", nil))
assert.Equal(t, "{\"code\":500,\"message\":\"err1\"}", getResultString("/api/error", strings.NewReader("{}")))
}
func TestFuncs(t *testing.T) {
wd := MustGetwd()
assert.NotEqual(t, "", wd)
dir1 := GetCurrentCodeDir()
assert.Equal(t, true, strings.HasSuffix(dir1, "common"))
}

View File

@ -1,13 +1,9 @@
package dtmcli
import (
"database/sql"
"errors"
"fmt"
"net/url"
"strings"
"github.com/go-resty/resty/v2"
)
// M a short name
@ -26,94 +22,6 @@ func MustGenGid(server string) string {
return res["gid"]
}
var sqlDbs = map[string]*sql.DB{}
// SdbGet get pooled sql.DB
func SdbGet(conf map[string]string) *sql.DB {
dsn := GetDsn(conf)
if sqlDbs[dsn] == nil {
sqlDbs[dsn] = SdbAlone(conf)
}
return sqlDbs[dsn]
}
// SdbAlone get a standalone db connection
func SdbAlone(conf map[string]string) *sql.DB {
dsn := GetDsn(conf)
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
mdb, err := sql.Open(conf["driver"], dsn)
E2P(err)
return mdb
}
// SdbExec use raw db to exec
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := db.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxExec use raw tx to exec
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := tx.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxQueryRow use raw tx to query row
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
Logf("querying: "+query, args...)
return tx.QueryRow(query, args...)
}
// GetDsn get dsn from map config
func GetDsn(conf map[string]string) string {
conf["host"] = MayReplaceLocalhost(conf["host"])
driver := conf["driver"]
dsn := MS{
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
}[driver]
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
return dsn
}
// CheckResponse 检查Response返回错误
func CheckResponse(resp *resty.Response, err error) error {
if err == nil && resp != nil {
if resp.IsError() {
return errors.New(resp.String())
} else if strings.Contains(resp.String(), "FAILURE") {
return ErrFailure
}
}
return err
}
// CheckResult 检查Result返回错误
func CheckResult(res interface{}, err error) error {
resp, ok := res.(*resty.Response)
if ok {
return CheckResponse(resp, err)
}
if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") {
return ErrFailure
}
return err
}
// IDGenerator used to generate a branch id
type IDGenerator struct {
parentID string

View File

@ -17,7 +17,12 @@ func TestTypes(t *testing.T) {
idGen := IDGenerator{branchID: 99}
idGen.NewBranchID()
})
err = CatchP(func() {
MustGenGid("http://localhost:8080/api/no")
})
assert.Error(t, err)
assert.Error(t, err)
_, err = TransInfoFromQuery(url.Values{})
assert.Error(t, err)
}

View File

@ -1,12 +1,12 @@
package dtmcli
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"strings"
@ -126,6 +126,21 @@ func LogRedf(fmt string, args ...interface{}) {
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
}
// LogFatalf 采用红色打印错误类信息, 并退出
func LogFatalf(fmt string, args ...interface{}) {
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
os.Exit(1)
}
// FatalIfError 采用红色打印错误类信息, 并退出
func FatalIfError(err error) {
if err == nil {
return
}
Logf("\x1b[31m\nFatal error: %v\x1b[0m\n", err)
os.Exit(1)
}
// RestyClient the resty object
var RestyClient = resty.New()
@ -145,35 +160,6 @@ func init() {
})
}
// CheckRestySuccess panic if error or resp not success
func CheckRestySuccess(resp *resty.Response, err error) {
E2P(err)
if !strings.Contains(resp.String(), "SUCCESS") {
panic(fmt.Errorf("resty response not success: %s", resp.String()))
}
}
// MustGetwd must version of os.Getwd
func MustGetwd() string {
wd, err := os.Getwd()
E2P(err)
return wd
}
// GetCurrentCodeDir name is clear
func GetCurrentCodeDir() string {
_, file, _, _ := runtime.Caller(1)
return filepath.Dir(file)
}
// GetProjectDir name is clear
func GetProjectDir() string {
_, file, _, _ := runtime.Caller(1)
for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) {
}
return file
}
// GetFuncName get current call func name
func GetFuncName() string {
pc, _, _, _ := runtime.Caller(1)
@ -187,3 +173,93 @@ func MayReplaceLocalhost(host string) string {
}
return host
}
var sqlDbs = map[string]*sql.DB{}
// SdbGet get pooled sql.DB
func SdbGet(conf map[string]string) (*sql.DB, error) {
dsn := GetDsn(conf)
if sqlDbs[dsn] == nil {
db, err := SdbAlone(conf)
if err != nil {
return nil, err
}
sqlDbs[dsn] = db
}
return sqlDbs[dsn], nil
}
// SdbAlone get a standalone db connection
func SdbAlone(conf map[string]string) (*sql.DB, error) {
dsn := GetDsn(conf)
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
return sql.Open(conf["driver"], dsn)
}
// SdbExec use raw db to exec
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := db.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxExec use raw tx to exec
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
r, rerr := tx.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Logf("affected: %d for %s %v", affected, sql, values)
} else {
LogRedf("exec error: %v for %s %v", rerr, sql, values)
}
return
}
// StxQueryRow use raw tx to query row
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
Logf("querying: "+query, args...)
return tx.QueryRow(query, args...)
}
// GetDsn get dsn from map config
func GetDsn(conf map[string]string) string {
conf["host"] = MayReplaceLocalhost(conf["host"])
driver := conf["driver"]
dsn := MS{
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
}[driver]
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
return dsn
}
// CheckResponse 检查Response返回错误
func CheckResponse(resp *resty.Response, err error) error {
if err == nil && resp != nil {
if resp.IsError() {
return errors.New(resp.String())
} else if strings.Contains(resp.String(), "FAILURE") {
return ErrFailure
}
}
return err
}
// CheckResult 检查Result返回错误
func CheckResult(res interface{}, err error) error {
resp, ok := res.(*resty.Response)
if ok {
return CheckResponse(resp, err)
}
if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") {
return ErrFailure
}
return err
}

View File

@ -65,11 +65,6 @@ func TestSome(t *testing.T) {
MustAtoi("abc")
})
assert.Error(t, err)
wd := MustGetwd()
assert.NotEqual(t, "", wd)
dir1 := GetCurrentCodeDir()
assert.Equal(t, true, strings.HasSuffix(dir1, "dtmcli"))
func1 := GetFuncName()
assert.Equal(t, true, strings.HasSuffix(func1, "TestSome"))

View File

@ -8,8 +8,6 @@ import (
"github.com/go-resty/resty/v2"
)
var e2p = E2P
// XaGlobalFunc type of xa global function
type XaGlobalFunc func(xa *Xa) (*resty.Response, error)
@ -58,10 +56,13 @@ func NewXaClient(server string, mysqlConf map[string]string, callbackURL string,
// HandleCallback 处理commit/rollback的回调
func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) {
db := SdbAlone(xc.Conf)
db, err := SdbAlone(xc.Conf)
if err != nil {
return nil, err
}
defer db.Close()
xaID := gid + "-" + branchID
_, err := SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
_, err = SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
return ResultSuccess, err
}
@ -75,7 +76,10 @@ func (xc *XaClient) XaLocalTransaction(qs url.Values, xaFunc XaLocalFunc) (ret i
xa.Dtm = xc.Server
branchID := xa.NewBranchID()
xaBranch := xa.Gid + "-" + branchID
db := SdbAlone(xc.Conf)
db, rerr := SdbAlone(xc.Conf)
if rerr != nil {
return
}
defer func() { db.Close() }()
defer func() {
x := recover()

View File

@ -2,7 +2,6 @@ package dtmsvr
import (
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
)
type dtmsvrConfig struct {
@ -14,9 +13,6 @@ var config = &dtmsvrConfig{
TransCronInterval: 10,
}
var dbName = "dtm"
func init() {
common.InitConfig(dtmcli.GetProjectDir(), &config)
config.DB["database"] = ""
common.InitConfig(&config)
}

View File

@ -42,7 +42,7 @@ func resetXaData() {
func TestMain(m *testing.M) {
TransProcessedTestChan = make(chan string, 1)
common.InitConfig(dtmcli.GetProjectDir(), &config)
common.InitConfig(&config)
PopulateDB(false)
examples.PopulateDB(false)
// 启动组件
@ -72,21 +72,6 @@ func TestCover(t *testing.T) {
go sleepCronTime()
}
func TestType(t *testing.T) {
err := dtmcli.CatchP(func() {
dtmcli.MustGenGid("http://localhost:8080/api/no")
})
assert.Error(t, err)
err = dtmcli.CatchP(func() {
resp, err := dtmcli.RestyClient.R().SetBody(dtmcli.M{
"gid": "1",
"trans_type": "msg",
}).Get("http://localhost:8080/api/dtmsvr/abort")
dtmcli.CheckRestySuccess(resp, err)
})
assert.Error(t, err)
}
func getTransStatus(gid string) string {
sm := TransGlobal{}
dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm)

View File

@ -23,6 +23,6 @@ func StartSvr() {
// PopulateDB setup mysql data
func PopulateDB(skipDrop bool) {
file := fmt.Sprintf("%s/dtmsvr.%s.sql", dtmcli.GetCurrentCodeDir(), config.DB["driver"])
file := fmt.Sprintf("%s/dtmsvr.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"])
examples.RunSQLScript(config.DB, file, skipDrop)
}

View File

@ -2,7 +2,6 @@ package examples
import (
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
)
type exampleConfig struct {
@ -12,5 +11,5 @@ type exampleConfig struct {
var config = exampleConfig{}
func init() {
common.InitConfig(dtmcli.GetProjectDir(), &config)
common.InitConfig(&config)
}

View File

@ -5,12 +5,14 @@ import (
"io/ioutil"
"strings"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
)
// RunSQLScript 1
func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
con := dtmcli.SdbAlone(conf)
con, err := dtmcli.SdbAlone(conf)
e2p(err)
defer func() { con.Close() }()
content, err := ioutil.ReadFile(script)
e2p(err)
@ -27,6 +29,6 @@ func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
// PopulateDB populate example mysql data
func PopulateDB(skipDrop bool) {
file := fmt.Sprintf("%s/examples.%s.sql", dtmcli.GetCurrentCodeDir(), config.DB["driver"])
file := fmt.Sprintf("%s/examples.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"])
RunSQLScript(config.DB, file, skipDrop)
}

View File

@ -64,7 +64,9 @@ func dbGet() *common.DB {
}
func sdbGet() *sql.DB {
return dtmcli.SdbGet(config.DB)
db, err := dtmcli.SdbGet(config.DB)
e2p(err)
return db
}
// MustGetTrans construct transaction info from request