dtmcli ok
This commit is contained in:
parent
0270648c6c
commit
e4392bac43
@ -3,7 +3,7 @@ package common
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/go-playground/assert/v2"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/yedf/dtm/dtmcli"
|
"github.com/yedf/dtm/dtmcli"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -14,8 +14,7 @@ type testConfig struct {
|
|||||||
var config = testConfig{}
|
var config = testConfig{}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
InitConfig(dtmcli.GetProjectDir(), &config)
|
InitConfig(&config)
|
||||||
config.DB["database"] = ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDb(t *testing.T) {
|
func TestDb(t *testing.T) {
|
||||||
@ -31,8 +30,9 @@ func TestDb(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDbAlone(t *testing.T) {
|
func TestDbAlone(t *testing.T) {
|
||||||
db := dtmcli.SdbAlone(config.DB)
|
db, err := dtmcli.SdbAlone(config.DB)
|
||||||
_, err := dtmcli.SdbExec(db, "select 1")
|
assert.Nil(t, err)
|
||||||
|
_, err = dtmcli.SdbExec(db, "select 1")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
db.Close()
|
db.Close()
|
||||||
_, err = dtmcli.SdbExec(db, "select 1")
|
_, err = dtmcli.SdbExec(db, "select 1")
|
||||||
|
|||||||
@ -4,6 +4,9 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -59,14 +62,36 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MustGetwd must version of os.Getwd
|
||||||
|
func MustGetwd() string {
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
dtmcli.E2P(err)
|
||||||
|
return wd
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCurrentCodeDir 获取当前源代码的目录,主要用于测试时,查找相关文件
|
||||||
|
func GetCurrentCodeDir() string {
|
||||||
|
_, file, _, _ := runtime.Caller(1)
|
||||||
|
return filepath.Dir(file)
|
||||||
|
}
|
||||||
|
|
||||||
// InitConfig init config
|
// InitConfig init config
|
||||||
func InitConfig(dir string, config interface{}) {
|
func InitConfig(config interface{}) {
|
||||||
cont, err := ioutil.ReadFile(dir + "/conf.yml")
|
cont := []byte{}
|
||||||
if err != nil {
|
for d := MustGetwd(); d != ""; d = filepath.Dir(d) {
|
||||||
cont, err = ioutil.ReadFile(dir + "/conf.sample.yml")
|
cont1, err := ioutil.ReadFile(d + "/conf.yml")
|
||||||
|
if err != nil {
|
||||||
|
cont1, err = ioutil.ReadFile(d + "/conf.sample.yml")
|
||||||
|
}
|
||||||
|
if cont1 != nil {
|
||||||
|
cont = cont1
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cont == nil {
|
||||||
|
dtmcli.LogFatalf("no config file conf.yml/conf.sample.yml found in current and parent path: %s", MustGetwd())
|
||||||
}
|
}
|
||||||
dtmcli.Logf("cont is: \n%s", string(cont))
|
dtmcli.Logf("cont is: \n%s", string(cont))
|
||||||
dtmcli.E2P(err)
|
err := yaml.Unmarshal(cont, config)
|
||||||
err = yaml.Unmarshal(cont, config)
|
dtmcli.FatalIfError(err)
|
||||||
dtmcli.E2P(err)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -30,3 +30,12 @@ func TestGin(t *testing.T) {
|
|||||||
assert.Equal(t, "1", getResultString("/api/sample", nil))
|
assert.Equal(t, "1", getResultString("/api/sample", nil))
|
||||||
assert.Equal(t, "{\"code\":500,\"message\":\"err1\"}", getResultString("/api/error", strings.NewReader("{}")))
|
assert.Equal(t, "{\"code\":500,\"message\":\"err1\"}", getResultString("/api/error", strings.NewReader("{}")))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFuncs(t *testing.T) {
|
||||||
|
wd := MustGetwd()
|
||||||
|
assert.NotEqual(t, "", wd)
|
||||||
|
|
||||||
|
dir1 := GetCurrentCodeDir()
|
||||||
|
assert.Equal(t, true, strings.HasSuffix(dir1, "common"))
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
@ -1,13 +1,9 @@
|
|||||||
package dtmcli
|
package dtmcli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// M a short name
|
// M a short name
|
||||||
@ -26,94 +22,6 @@ func MustGenGid(server string) string {
|
|||||||
return res["gid"]
|
return res["gid"]
|
||||||
}
|
}
|
||||||
|
|
||||||
var sqlDbs = map[string]*sql.DB{}
|
|
||||||
|
|
||||||
// SdbGet get pooled sql.DB
|
|
||||||
func SdbGet(conf map[string]string) *sql.DB {
|
|
||||||
dsn := GetDsn(conf)
|
|
||||||
if sqlDbs[dsn] == nil {
|
|
||||||
sqlDbs[dsn] = SdbAlone(conf)
|
|
||||||
}
|
|
||||||
return sqlDbs[dsn]
|
|
||||||
}
|
|
||||||
|
|
||||||
// SdbAlone get a standalone db connection
|
|
||||||
func SdbAlone(conf map[string]string) *sql.DB {
|
|
||||||
dsn := GetDsn(conf)
|
|
||||||
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
|
|
||||||
mdb, err := sql.Open(conf["driver"], dsn)
|
|
||||||
E2P(err)
|
|
||||||
return mdb
|
|
||||||
}
|
|
||||||
|
|
||||||
// SdbExec use raw db to exec
|
|
||||||
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
|
|
||||||
r, rerr := db.Exec(sql, values...)
|
|
||||||
if rerr == nil {
|
|
||||||
affected, rerr = r.RowsAffected()
|
|
||||||
Logf("affected: %d for %s %v", affected, sql, values)
|
|
||||||
} else {
|
|
||||||
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// StxExec use raw tx to exec
|
|
||||||
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
|
|
||||||
r, rerr := tx.Exec(sql, values...)
|
|
||||||
if rerr == nil {
|
|
||||||
affected, rerr = r.RowsAffected()
|
|
||||||
Logf("affected: %d for %s %v", affected, sql, values)
|
|
||||||
} else {
|
|
||||||
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// StxQueryRow use raw tx to query row
|
|
||||||
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
|
|
||||||
Logf("querying: "+query, args...)
|
|
||||||
return tx.QueryRow(query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDsn get dsn from map config
|
|
||||||
func GetDsn(conf map[string]string) string {
|
|
||||||
conf["host"] = MayReplaceLocalhost(conf["host"])
|
|
||||||
driver := conf["driver"]
|
|
||||||
dsn := MS{
|
|
||||||
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
|
||||||
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
|
|
||||||
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
|
|
||||||
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
|
|
||||||
}[driver]
|
|
||||||
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
|
|
||||||
return dsn
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckResponse 检查Response,返回错误
|
|
||||||
func CheckResponse(resp *resty.Response, err error) error {
|
|
||||||
if err == nil && resp != nil {
|
|
||||||
if resp.IsError() {
|
|
||||||
return errors.New(resp.String())
|
|
||||||
} else if strings.Contains(resp.String(), "FAILURE") {
|
|
||||||
return ErrFailure
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckResult 检查Result,返回错误
|
|
||||||
func CheckResult(res interface{}, err error) error {
|
|
||||||
resp, ok := res.(*resty.Response)
|
|
||||||
if ok {
|
|
||||||
return CheckResponse(resp, err)
|
|
||||||
}
|
|
||||||
if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") {
|
|
||||||
return ErrFailure
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// IDGenerator used to generate a branch id
|
// IDGenerator used to generate a branch id
|
||||||
type IDGenerator struct {
|
type IDGenerator struct {
|
||||||
parentID string
|
parentID string
|
||||||
|
|||||||
@ -17,7 +17,12 @@ func TestTypes(t *testing.T) {
|
|||||||
idGen := IDGenerator{branchID: 99}
|
idGen := IDGenerator{branchID: 99}
|
||||||
idGen.NewBranchID()
|
idGen.NewBranchID()
|
||||||
})
|
})
|
||||||
|
err = CatchP(func() {
|
||||||
|
MustGenGid("http://localhost:8080/api/no")
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
_, err = TransInfoFromQuery(url.Values{})
|
_, err = TransInfoFromQuery(url.Values{})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
package dtmcli
|
package dtmcli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -126,6 +126,21 @@ func LogRedf(fmt string, args ...interface{}) {
|
|||||||
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
|
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LogFatalf 采用红色打印错误类信息, 并退出
|
||||||
|
func LogFatalf(fmt string, args ...interface{}) {
|
||||||
|
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FatalIfError 采用红色打印错误类信息, 并退出
|
||||||
|
func FatalIfError(err error) {
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Logf("\x1b[31m\nFatal error: %v\x1b[0m\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
// RestyClient the resty object
|
// RestyClient the resty object
|
||||||
var RestyClient = resty.New()
|
var RestyClient = resty.New()
|
||||||
|
|
||||||
@ -145,35 +160,6 @@ func init() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckRestySuccess panic if error or resp not success
|
|
||||||
func CheckRestySuccess(resp *resty.Response, err error) {
|
|
||||||
E2P(err)
|
|
||||||
if !strings.Contains(resp.String(), "SUCCESS") {
|
|
||||||
panic(fmt.Errorf("resty response not success: %s", resp.String()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustGetwd must version of os.Getwd
|
|
||||||
func MustGetwd() string {
|
|
||||||
wd, err := os.Getwd()
|
|
||||||
E2P(err)
|
|
||||||
return wd
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCurrentCodeDir name is clear
|
|
||||||
func GetCurrentCodeDir() string {
|
|
||||||
_, file, _, _ := runtime.Caller(1)
|
|
||||||
return filepath.Dir(file)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetProjectDir name is clear
|
|
||||||
func GetProjectDir() string {
|
|
||||||
_, file, _, _ := runtime.Caller(1)
|
|
||||||
for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) {
|
|
||||||
}
|
|
||||||
return file
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFuncName get current call func name
|
// GetFuncName get current call func name
|
||||||
func GetFuncName() string {
|
func GetFuncName() string {
|
||||||
pc, _, _, _ := runtime.Caller(1)
|
pc, _, _, _ := runtime.Caller(1)
|
||||||
@ -187,3 +173,93 @@ func MayReplaceLocalhost(host string) string {
|
|||||||
}
|
}
|
||||||
return host
|
return host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var sqlDbs = map[string]*sql.DB{}
|
||||||
|
|
||||||
|
// SdbGet get pooled sql.DB
|
||||||
|
func SdbGet(conf map[string]string) (*sql.DB, error) {
|
||||||
|
dsn := GetDsn(conf)
|
||||||
|
if sqlDbs[dsn] == nil {
|
||||||
|
db, err := SdbAlone(conf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sqlDbs[dsn] = db
|
||||||
|
}
|
||||||
|
return sqlDbs[dsn], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SdbAlone get a standalone db connection
|
||||||
|
func SdbAlone(conf map[string]string) (*sql.DB, error) {
|
||||||
|
dsn := GetDsn(conf)
|
||||||
|
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
|
||||||
|
return sql.Open(conf["driver"], dsn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SdbExec use raw db to exec
|
||||||
|
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
|
||||||
|
r, rerr := db.Exec(sql, values...)
|
||||||
|
if rerr == nil {
|
||||||
|
affected, rerr = r.RowsAffected()
|
||||||
|
Logf("affected: %d for %s %v", affected, sql, values)
|
||||||
|
} else {
|
||||||
|
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// StxExec use raw tx to exec
|
||||||
|
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
|
||||||
|
r, rerr := tx.Exec(sql, values...)
|
||||||
|
if rerr == nil {
|
||||||
|
affected, rerr = r.RowsAffected()
|
||||||
|
Logf("affected: %d for %s %v", affected, sql, values)
|
||||||
|
} else {
|
||||||
|
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// StxQueryRow use raw tx to query row
|
||||||
|
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
|
||||||
|
Logf("querying: "+query, args...)
|
||||||
|
return tx.QueryRow(query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDsn get dsn from map config
|
||||||
|
func GetDsn(conf map[string]string) string {
|
||||||
|
conf["host"] = MayReplaceLocalhost(conf["host"])
|
||||||
|
driver := conf["driver"]
|
||||||
|
dsn := MS{
|
||||||
|
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||||
|
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
|
||||||
|
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
|
||||||
|
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
|
||||||
|
}[driver]
|
||||||
|
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
|
||||||
|
return dsn
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckResponse 检查Response,返回错误
|
||||||
|
func CheckResponse(resp *resty.Response, err error) error {
|
||||||
|
if err == nil && resp != nil {
|
||||||
|
if resp.IsError() {
|
||||||
|
return errors.New(resp.String())
|
||||||
|
} else if strings.Contains(resp.String(), "FAILURE") {
|
||||||
|
return ErrFailure
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckResult 检查Result,返回错误
|
||||||
|
func CheckResult(res interface{}, err error) error {
|
||||||
|
resp, ok := res.(*resty.Response)
|
||||||
|
if ok {
|
||||||
|
return CheckResponse(resp, err)
|
||||||
|
}
|
||||||
|
if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") {
|
||||||
|
return ErrFailure
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
@ -65,11 +65,6 @@ func TestSome(t *testing.T) {
|
|||||||
MustAtoi("abc")
|
MustAtoi("abc")
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
wd := MustGetwd()
|
|
||||||
assert.NotEqual(t, "", wd)
|
|
||||||
|
|
||||||
dir1 := GetCurrentCodeDir()
|
|
||||||
assert.Equal(t, true, strings.HasSuffix(dir1, "dtmcli"))
|
|
||||||
|
|
||||||
func1 := GetFuncName()
|
func1 := GetFuncName()
|
||||||
assert.Equal(t, true, strings.HasSuffix(func1, "TestSome"))
|
assert.Equal(t, true, strings.HasSuffix(func1, "TestSome"))
|
||||||
|
|||||||
14
dtmcli/xa.go
14
dtmcli/xa.go
@ -8,8 +8,6 @@ import (
|
|||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
var e2p = E2P
|
|
||||||
|
|
||||||
// XaGlobalFunc type of xa global function
|
// XaGlobalFunc type of xa global function
|
||||||
type XaGlobalFunc func(xa *Xa) (*resty.Response, error)
|
type XaGlobalFunc func(xa *Xa) (*resty.Response, error)
|
||||||
|
|
||||||
@ -58,10 +56,13 @@ func NewXaClient(server string, mysqlConf map[string]string, callbackURL string,
|
|||||||
|
|
||||||
// HandleCallback 处理commit/rollback的回调
|
// HandleCallback 处理commit/rollback的回调
|
||||||
func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) {
|
func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) {
|
||||||
db := SdbAlone(xc.Conf)
|
db, err := SdbAlone(xc.Conf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
xaID := gid + "-" + branchID
|
xaID := gid + "-" + branchID
|
||||||
_, err := SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
|
_, err = SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
|
||||||
return ResultSuccess, err
|
return ResultSuccess, err
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -75,7 +76,10 @@ func (xc *XaClient) XaLocalTransaction(qs url.Values, xaFunc XaLocalFunc) (ret i
|
|||||||
xa.Dtm = xc.Server
|
xa.Dtm = xc.Server
|
||||||
branchID := xa.NewBranchID()
|
branchID := xa.NewBranchID()
|
||||||
xaBranch := xa.Gid + "-" + branchID
|
xaBranch := xa.Gid + "-" + branchID
|
||||||
db := SdbAlone(xc.Conf)
|
db, rerr := SdbAlone(xc.Conf)
|
||||||
|
if rerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
defer func() { db.Close() }()
|
defer func() { db.Close() }()
|
||||||
defer func() {
|
defer func() {
|
||||||
x := recover()
|
x := recover()
|
||||||
|
|||||||
@ -2,7 +2,6 @@ package dtmsvr
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/yedf/dtm/common"
|
"github.com/yedf/dtm/common"
|
||||||
"github.com/yedf/dtm/dtmcli"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type dtmsvrConfig struct {
|
type dtmsvrConfig struct {
|
||||||
@ -14,9 +13,6 @@ var config = &dtmsvrConfig{
|
|||||||
TransCronInterval: 10,
|
TransCronInterval: 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
var dbName = "dtm"
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
common.InitConfig(dtmcli.GetProjectDir(), &config)
|
common.InitConfig(&config)
|
||||||
config.DB["database"] = ""
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -42,7 +42,7 @@ func resetXaData() {
|
|||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
TransProcessedTestChan = make(chan string, 1)
|
TransProcessedTestChan = make(chan string, 1)
|
||||||
common.InitConfig(dtmcli.GetProjectDir(), &config)
|
common.InitConfig(&config)
|
||||||
PopulateDB(false)
|
PopulateDB(false)
|
||||||
examples.PopulateDB(false)
|
examples.PopulateDB(false)
|
||||||
// 启动组件
|
// 启动组件
|
||||||
@ -72,21 +72,6 @@ func TestCover(t *testing.T) {
|
|||||||
go sleepCronTime()
|
go sleepCronTime()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestType(t *testing.T) {
|
|
||||||
err := dtmcli.CatchP(func() {
|
|
||||||
dtmcli.MustGenGid("http://localhost:8080/api/no")
|
|
||||||
})
|
|
||||||
assert.Error(t, err)
|
|
||||||
err = dtmcli.CatchP(func() {
|
|
||||||
resp, err := dtmcli.RestyClient.R().SetBody(dtmcli.M{
|
|
||||||
"gid": "1",
|
|
||||||
"trans_type": "msg",
|
|
||||||
}).Get("http://localhost:8080/api/dtmsvr/abort")
|
|
||||||
dtmcli.CheckRestySuccess(resp, err)
|
|
||||||
})
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTransStatus(gid string) string {
|
func getTransStatus(gid string) string {
|
||||||
sm := TransGlobal{}
|
sm := TransGlobal{}
|
||||||
dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm)
|
dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm)
|
||||||
|
|||||||
@ -23,6 +23,6 @@ func StartSvr() {
|
|||||||
|
|
||||||
// PopulateDB setup mysql data
|
// PopulateDB setup mysql data
|
||||||
func PopulateDB(skipDrop bool) {
|
func PopulateDB(skipDrop bool) {
|
||||||
file := fmt.Sprintf("%s/dtmsvr.%s.sql", dtmcli.GetCurrentCodeDir(), config.DB["driver"])
|
file := fmt.Sprintf("%s/dtmsvr.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"])
|
||||||
examples.RunSQLScript(config.DB, file, skipDrop)
|
examples.RunSQLScript(config.DB, file, skipDrop)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,6 @@ package examples
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/yedf/dtm/common"
|
"github.com/yedf/dtm/common"
|
||||||
"github.com/yedf/dtm/dtmcli"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type exampleConfig struct {
|
type exampleConfig struct {
|
||||||
@ -12,5 +11,5 @@ type exampleConfig struct {
|
|||||||
var config = exampleConfig{}
|
var config = exampleConfig{}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
common.InitConfig(dtmcli.GetProjectDir(), &config)
|
common.InitConfig(&config)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,12 +5,14 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/yedf/dtm/common"
|
||||||
"github.com/yedf/dtm/dtmcli"
|
"github.com/yedf/dtm/dtmcli"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RunSQLScript 1
|
// RunSQLScript 1
|
||||||
func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
|
func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
|
||||||
con := dtmcli.SdbAlone(conf)
|
con, err := dtmcli.SdbAlone(conf)
|
||||||
|
e2p(err)
|
||||||
defer func() { con.Close() }()
|
defer func() { con.Close() }()
|
||||||
content, err := ioutil.ReadFile(script)
|
content, err := ioutil.ReadFile(script)
|
||||||
e2p(err)
|
e2p(err)
|
||||||
@ -27,6 +29,6 @@ func RunSQLScript(conf map[string]string, script string, skipDrop bool) {
|
|||||||
|
|
||||||
// PopulateDB populate example mysql data
|
// PopulateDB populate example mysql data
|
||||||
func PopulateDB(skipDrop bool) {
|
func PopulateDB(skipDrop bool) {
|
||||||
file := fmt.Sprintf("%s/examples.%s.sql", dtmcli.GetCurrentCodeDir(), config.DB["driver"])
|
file := fmt.Sprintf("%s/examples.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"])
|
||||||
RunSQLScript(config.DB, file, skipDrop)
|
RunSQLScript(config.DB, file, skipDrop)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -64,7 +64,9 @@ func dbGet() *common.DB {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func sdbGet() *sql.DB {
|
func sdbGet() *sql.DB {
|
||||||
return dtmcli.SdbGet(config.DB)
|
db, err := dtmcli.SdbGet(config.DB)
|
||||||
|
e2p(err)
|
||||||
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustGetTrans construct transaction info from request
|
// MustGetTrans construct transaction info from request
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user