diff --git a/common/types_test.go b/common/types_test.go index 44e177b..5f1e11d 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -3,7 +3,7 @@ package common import ( "testing" - "github.com/go-playground/assert/v2" + "github.com/stretchr/testify/assert" "github.com/yedf/dtm/dtmcli" ) @@ -14,8 +14,7 @@ type testConfig struct { var config = testConfig{} func init() { - InitConfig(dtmcli.GetProjectDir(), &config) - config.DB["database"] = "" + InitConfig(&config) } func TestDb(t *testing.T) { @@ -31,8 +30,9 @@ func TestDb(t *testing.T) { } func TestDbAlone(t *testing.T) { - db := dtmcli.SdbAlone(config.DB) - _, err := dtmcli.SdbExec(db, "select 1") + db, err := dtmcli.SdbAlone(config.DB) + assert.Nil(t, err) + _, err = dtmcli.SdbExec(db, "select 1") assert.Equal(t, nil, err) db.Close() _, err = dtmcli.SdbExec(db, "select 1") diff --git a/common/utils.go b/common/utils.go index 4b00014..0ea402d 100644 --- a/common/utils.go +++ b/common/utils.go @@ -4,6 +4,9 @@ import ( "bytes" "encoding/json" "io/ioutil" + "os" + "path/filepath" + "runtime" "time" "github.com/gin-gonic/gin" @@ -59,14 +62,36 @@ func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc { } } +// MustGetwd must version of os.Getwd +func MustGetwd() string { + wd, err := os.Getwd() + dtmcli.E2P(err) + return wd +} + +// GetCurrentCodeDir 获取当前源代码的目录,主要用于测试时,查找相关文件 +func GetCurrentCodeDir() string { + _, file, _, _ := runtime.Caller(1) + return filepath.Dir(file) +} + // InitConfig init config -func InitConfig(dir string, config interface{}) { - cont, err := ioutil.ReadFile(dir + "/conf.yml") - if err != nil { - cont, err = ioutil.ReadFile(dir + "/conf.sample.yml") +func InitConfig(config interface{}) { + cont := []byte{} + for d := MustGetwd(); d != ""; d = filepath.Dir(d) { + cont1, err := ioutil.ReadFile(d + "/conf.yml") + if err != nil { + cont1, err = ioutil.ReadFile(d + "/conf.sample.yml") + } + if cont1 != nil { + cont = cont1 + break + } + } + if cont == nil { + dtmcli.LogFatalf("no config file conf.yml/conf.sample.yml found in current and parent path: %s", MustGetwd()) } dtmcli.Logf("cont is: \n%s", string(cont)) - dtmcli.E2P(err) - err = yaml.Unmarshal(cont, config) - dtmcli.E2P(err) + err := yaml.Unmarshal(cont, config) + dtmcli.FatalIfError(err) } diff --git a/common/utils_test.go b/common/utils_test.go index ff9b6ed..35b1461 100644 --- a/common/utils_test.go +++ b/common/utils_test.go @@ -30,3 +30,12 @@ func TestGin(t *testing.T) { assert.Equal(t, "1", getResultString("/api/sample", nil)) assert.Equal(t, "{\"code\":500,\"message\":\"err1\"}", getResultString("/api/error", strings.NewReader("{}"))) } + +func TestFuncs(t *testing.T) { + wd := MustGetwd() + assert.NotEqual(t, "", wd) + + dir1 := GetCurrentCodeDir() + assert.Equal(t, true, strings.HasSuffix(dir1, "common")) + +} diff --git a/dtmcli/types.go b/dtmcli/types.go index a8dc62e..1d95054 100644 --- a/dtmcli/types.go +++ b/dtmcli/types.go @@ -1,13 +1,9 @@ package dtmcli import ( - "database/sql" "errors" "fmt" "net/url" - "strings" - - "github.com/go-resty/resty/v2" ) // M a short name @@ -26,94 +22,6 @@ func MustGenGid(server string) string { return res["gid"] } -var sqlDbs = map[string]*sql.DB{} - -// SdbGet get pooled sql.DB -func SdbGet(conf map[string]string) *sql.DB { - dsn := GetDsn(conf) - if sqlDbs[dsn] == nil { - sqlDbs[dsn] = SdbAlone(conf) - } - return sqlDbs[dsn] -} - -// SdbAlone get a standalone db connection -func SdbAlone(conf map[string]string) *sql.DB { - dsn := GetDsn(conf) - Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1)) - mdb, err := sql.Open(conf["driver"], dsn) - E2P(err) - return mdb -} - -// SdbExec use raw db to exec -func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) { - r, rerr := db.Exec(sql, values...) - if rerr == nil { - affected, rerr = r.RowsAffected() - Logf("affected: %d for %s %v", affected, sql, values) - } else { - LogRedf("exec error: %v for %s %v", rerr, sql, values) - } - return -} - -// StxExec use raw tx to exec -func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) { - r, rerr := tx.Exec(sql, values...) - if rerr == nil { - affected, rerr = r.RowsAffected() - Logf("affected: %d for %s %v", affected, sql, values) - } else { - LogRedf("exec error: %v for %s %v", rerr, sql, values) - } - return -} - -// StxQueryRow use raw tx to query row -func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row { - Logf("querying: "+query, args...) - return tx.QueryRow(query, args...) -} - -// GetDsn get dsn from map config -func GetDsn(conf map[string]string) string { - conf["host"] = MayReplaceLocalhost(conf["host"]) - driver := conf["driver"] - dsn := MS{ - "mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", - conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]), - "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai", - conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]), - }[driver] - PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver)) - return dsn -} - -// CheckResponse 检查Response,返回错误 -func CheckResponse(resp *resty.Response, err error) error { - if err == nil && resp != nil { - if resp.IsError() { - return errors.New(resp.String()) - } else if strings.Contains(resp.String(), "FAILURE") { - return ErrFailure - } - } - return err -} - -// CheckResult 检查Result,返回错误 -func CheckResult(res interface{}, err error) error { - resp, ok := res.(*resty.Response) - if ok { - return CheckResponse(resp, err) - } - if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") { - return ErrFailure - } - return err -} - // IDGenerator used to generate a branch id type IDGenerator struct { parentID string diff --git a/dtmcli/types_test.go b/dtmcli/types_test.go index bc87bbc..5cdb5a1 100644 --- a/dtmcli/types_test.go +++ b/dtmcli/types_test.go @@ -17,7 +17,12 @@ func TestTypes(t *testing.T) { idGen := IDGenerator{branchID: 99} idGen.NewBranchID() }) + err = CatchP(func() { + MustGenGid("http://localhost:8080/api/no") + }) + assert.Error(t, err) assert.Error(t, err) _, err = TransInfoFromQuery(url.Values{}) assert.Error(t, err) + } diff --git a/dtmcli/common.go b/dtmcli/utils.go similarity index 53% rename from dtmcli/common.go rename to dtmcli/utils.go index aed8c0e..7815756 100644 --- a/dtmcli/common.go +++ b/dtmcli/utils.go @@ -1,12 +1,12 @@ package dtmcli import ( + "database/sql" "encoding/json" "errors" "fmt" "os" "path" - "path/filepath" "runtime" "strconv" "strings" @@ -126,6 +126,21 @@ func LogRedf(fmt string, args ...interface{}) { Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...) } +// LogFatalf 采用红色打印错误类信息, 并退出 +func LogFatalf(fmt string, args ...interface{}) { + Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...) + os.Exit(1) +} + +// FatalIfError 采用红色打印错误类信息, 并退出 +func FatalIfError(err error) { + if err == nil { + return + } + Logf("\x1b[31m\nFatal error: %v\x1b[0m\n", err) + os.Exit(1) +} + // RestyClient the resty object var RestyClient = resty.New() @@ -145,35 +160,6 @@ func init() { }) } -// CheckRestySuccess panic if error or resp not success -func CheckRestySuccess(resp *resty.Response, err error) { - E2P(err) - if !strings.Contains(resp.String(), "SUCCESS") { - panic(fmt.Errorf("resty response not success: %s", resp.String())) - } -} - -// MustGetwd must version of os.Getwd -func MustGetwd() string { - wd, err := os.Getwd() - E2P(err) - return wd -} - -// GetCurrentCodeDir name is clear -func GetCurrentCodeDir() string { - _, file, _, _ := runtime.Caller(1) - return filepath.Dir(file) -} - -// GetProjectDir name is clear -func GetProjectDir() string { - _, file, _, _ := runtime.Caller(1) - for ; !strings.HasSuffix(file, "/dtm"); file = filepath.Dir(file) { - } - return file -} - // GetFuncName get current call func name func GetFuncName() string { pc, _, _, _ := runtime.Caller(1) @@ -187,3 +173,93 @@ func MayReplaceLocalhost(host string) string { } return host } + +var sqlDbs = map[string]*sql.DB{} + +// SdbGet get pooled sql.DB +func SdbGet(conf map[string]string) (*sql.DB, error) { + dsn := GetDsn(conf) + if sqlDbs[dsn] == nil { + db, err := SdbAlone(conf) + if err != nil { + return nil, err + } + sqlDbs[dsn] = db + } + return sqlDbs[dsn], nil +} + +// SdbAlone get a standalone db connection +func SdbAlone(conf map[string]string) (*sql.DB, error) { + dsn := GetDsn(conf) + Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1)) + return sql.Open(conf["driver"], dsn) +} + +// SdbExec use raw db to exec +func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) { + r, rerr := db.Exec(sql, values...) + if rerr == nil { + affected, rerr = r.RowsAffected() + Logf("affected: %d for %s %v", affected, sql, values) + } else { + LogRedf("exec error: %v for %s %v", rerr, sql, values) + } + return +} + +// StxExec use raw tx to exec +func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) { + r, rerr := tx.Exec(sql, values...) + if rerr == nil { + affected, rerr = r.RowsAffected() + Logf("affected: %d for %s %v", affected, sql, values) + } else { + LogRedf("exec error: %v for %s %v", rerr, sql, values) + } + return +} + +// StxQueryRow use raw tx to query row +func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row { + Logf("querying: "+query, args...) + return tx.QueryRow(query, args...) +} + +// GetDsn get dsn from map config +func GetDsn(conf map[string]string) string { + conf["host"] = MayReplaceLocalhost(conf["host"]) + driver := conf["driver"] + dsn := MS{ + "mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", + conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]), + "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai", + conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]), + }[driver] + PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver)) + return dsn +} + +// CheckResponse 检查Response,返回错误 +func CheckResponse(resp *resty.Response, err error) error { + if err == nil && resp != nil { + if resp.IsError() { + return errors.New(resp.String()) + } else if strings.Contains(resp.String(), "FAILURE") { + return ErrFailure + } + } + return err +} + +// CheckResult 检查Result,返回错误 +func CheckResult(res interface{}, err error) error { + resp, ok := res.(*resty.Response) + if ok { + return CheckResponse(resp, err) + } + if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") { + return ErrFailure + } + return err +} diff --git a/dtmcli/utils_test.go b/dtmcli/utils_test.go index 5a13445..2ed3320 100644 --- a/dtmcli/utils_test.go +++ b/dtmcli/utils_test.go @@ -65,11 +65,6 @@ func TestSome(t *testing.T) { MustAtoi("abc") }) assert.Error(t, err) - wd := MustGetwd() - assert.NotEqual(t, "", wd) - - dir1 := GetCurrentCodeDir() - assert.Equal(t, true, strings.HasSuffix(dir1, "dtmcli")) func1 := GetFuncName() assert.Equal(t, true, strings.HasSuffix(func1, "TestSome")) diff --git a/dtmcli/xa.go b/dtmcli/xa.go index 35928da..e590368 100644 --- a/dtmcli/xa.go +++ b/dtmcli/xa.go @@ -8,8 +8,6 @@ import ( "github.com/go-resty/resty/v2" ) -var e2p = E2P - // XaGlobalFunc type of xa global function type XaGlobalFunc func(xa *Xa) (*resty.Response, error) @@ -58,10 +56,13 @@ func NewXaClient(server string, mysqlConf map[string]string, callbackURL string, // HandleCallback 处理commit/rollback的回调 func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) { - db := SdbAlone(xc.Conf) + db, err := SdbAlone(xc.Conf) + if err != nil { + return nil, err + } defer db.Close() xaID := gid + "-" + branchID - _, err := SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID)) + _, err = SdbExec(db, fmt.Sprintf("xa %s '%s'", action, xaID)) return ResultSuccess, err } @@ -75,7 +76,10 @@ func (xc *XaClient) XaLocalTransaction(qs url.Values, xaFunc XaLocalFunc) (ret i xa.Dtm = xc.Server branchID := xa.NewBranchID() xaBranch := xa.Gid + "-" + branchID - db := SdbAlone(xc.Conf) + db, rerr := SdbAlone(xc.Conf) + if rerr != nil { + return + } defer func() { db.Close() }() defer func() { x := recover() diff --git a/dtmsvr/config.go b/dtmsvr/config.go index 258ff5a..79089aa 100644 --- a/dtmsvr/config.go +++ b/dtmsvr/config.go @@ -2,7 +2,6 @@ package dtmsvr import ( "github.com/yedf/dtm/common" - "github.com/yedf/dtm/dtmcli" ) type dtmsvrConfig struct { @@ -14,9 +13,6 @@ var config = &dtmsvrConfig{ TransCronInterval: 10, } -var dbName = "dtm" - func init() { - common.InitConfig(dtmcli.GetProjectDir(), &config) - config.DB["database"] = "" + common.InitConfig(&config) } diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index 1cfa263..2299993 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -42,7 +42,7 @@ func resetXaData() { func TestMain(m *testing.M) { TransProcessedTestChan = make(chan string, 1) - common.InitConfig(dtmcli.GetProjectDir(), &config) + common.InitConfig(&config) PopulateDB(false) examples.PopulateDB(false) // 启动组件 @@ -72,21 +72,6 @@ func TestCover(t *testing.T) { go sleepCronTime() } -func TestType(t *testing.T) { - err := dtmcli.CatchP(func() { - dtmcli.MustGenGid("http://localhost:8080/api/no") - }) - assert.Error(t, err) - err = dtmcli.CatchP(func() { - resp, err := dtmcli.RestyClient.R().SetBody(dtmcli.M{ - "gid": "1", - "trans_type": "msg", - }).Get("http://localhost:8080/api/dtmsvr/abort") - dtmcli.CheckRestySuccess(resp, err) - }) - assert.Error(t, err) -} - func getTransStatus(gid string) string { sm := TransGlobal{} dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm) diff --git a/dtmsvr/main.go b/dtmsvr/main.go index a3b24c7..6324cc4 100644 --- a/dtmsvr/main.go +++ b/dtmsvr/main.go @@ -23,6 +23,6 @@ func StartSvr() { // PopulateDB setup mysql data func PopulateDB(skipDrop bool) { - file := fmt.Sprintf("%s/dtmsvr.%s.sql", dtmcli.GetCurrentCodeDir(), config.DB["driver"]) + file := fmt.Sprintf("%s/dtmsvr.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"]) examples.RunSQLScript(config.DB, file, skipDrop) } diff --git a/examples/config.go b/examples/config.go index 6368939..fd831c5 100644 --- a/examples/config.go +++ b/examples/config.go @@ -2,7 +2,6 @@ package examples import ( "github.com/yedf/dtm/common" - "github.com/yedf/dtm/dtmcli" ) type exampleConfig struct { @@ -12,5 +11,5 @@ type exampleConfig struct { var config = exampleConfig{} func init() { - common.InitConfig(dtmcli.GetProjectDir(), &config) + common.InitConfig(&config) } diff --git a/examples/data.go b/examples/data.go index 468a479..d2e8d61 100644 --- a/examples/data.go +++ b/examples/data.go @@ -5,12 +5,14 @@ import ( "io/ioutil" "strings" + "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" ) // RunSQLScript 1 func RunSQLScript(conf map[string]string, script string, skipDrop bool) { - con := dtmcli.SdbAlone(conf) + con, err := dtmcli.SdbAlone(conf) + e2p(err) defer func() { con.Close() }() content, err := ioutil.ReadFile(script) e2p(err) @@ -27,6 +29,6 @@ func RunSQLScript(conf map[string]string, script string, skipDrop bool) { // PopulateDB populate example mysql data func PopulateDB(skipDrop bool) { - file := fmt.Sprintf("%s/examples.%s.sql", dtmcli.GetCurrentCodeDir(), config.DB["driver"]) + file := fmt.Sprintf("%s/examples.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"]) RunSQLScript(config.DB, file, skipDrop) } diff --git a/examples/types.go b/examples/types.go index f3cd218..70d08db 100644 --- a/examples/types.go +++ b/examples/types.go @@ -64,7 +64,9 @@ func dbGet() *common.DB { } func sdbGet() *sql.DB { - return dtmcli.SdbGet(config.DB) + db, err := dtmcli.SdbGet(config.DB) + e2p(err) + return db } // MustGetTrans construct transaction info from request