diff --git a/app/main.go b/app/main.go index 2e5a0b8..4d3dcf3 100644 --- a/app/main.go +++ b/app/main.go @@ -21,7 +21,7 @@ func wait() { func main() { onlyServer := len(os.Args) > 1 && os.Args[1] == "dtmsvr" if !onlyServer { // 实际线上运行,只启动dtmsvr,不准备table相关的数据 - dtmsvr.PopulateMysql(true) + dtmsvr.PopulateDB(true) } dtmsvr.StartSvr() // 启动dtmsvr的api服务 go dtmsvr.CronExpiredTrans(-1) // 启动dtmsvr的定时过期查询 @@ -38,7 +38,7 @@ func main() { } // 下面是各类的例子 - examples.PopulateMysql(true) + examples.PopulateDB(true) app := examples.BaseAppStartup() if os.Args[1] == "xa" { // 启动xa示例 examples.XaSetup(app) diff --git a/common/types_test.go b/common/types_test.go index b7e242e..21bc7ff 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -7,18 +7,18 @@ import ( ) type testConfig struct { - Mysql map[string]string `yaml:"Mysql"` + DB map[string]string `yaml:"DB"` } var config = testConfig{} func init() { InitConfig(GetProjectDir(), &config) - config.Mysql["database"] = "" + config.DB["database"] = "" } func TestDb(t *testing.T) { - db := DbGet(config.Mysql) + db := DbGet(config.DB) err := func() (rerr error) { defer P2E(&rerr) dbr := db.NoMust().Exec("select a") @@ -32,7 +32,7 @@ func TestDb(t *testing.T) { } func TestDbAlone(t *testing.T) { - db, con := DbAlone(config.Mysql) + db, con := DbAlone(config.DB) dbr := db.Exec("select 1") assert.Equal(t, nil, dbr.Error) con.Close() diff --git a/conf.sample.yml b/conf.sample.yml index fa5af2f..d17de37 100644 --- a/conf.sample.yml +++ b/conf.sample.yml @@ -1,4 +1,5 @@ -Mysql: +DB: + driver: 'mysql' host: 'localhost' user: 'root' password: '' diff --git a/dtmsvr/config.go b/dtmsvr/config.go index b6fa47f..a078691 100644 --- a/dtmsvr/config.go +++ b/dtmsvr/config.go @@ -4,7 +4,7 @@ import "github.com/yedf/dtm/common" type dtmsvrConfig struct { TransCronInterval int64 `yaml:"TransCronInterval"` // 单位秒 当事务等待这个时间之后,还没有变化,则进行一轮处理,包括prepared中的任务和committed的任务 - Mysql map[string]string `yaml:"Mysql"` + DB map[string]string `yaml:"DB"` } var config = &dtmsvrConfig{ @@ -15,5 +15,5 @@ var dbName = "dtm" func init() { common.InitConfig(common.GetProjectDir(), &config) - config.Mysql["database"] = "" + config.DB["database"] = "" } diff --git a/dtmsvr/dtmsvr.sql b/dtmsvr/dtmsvr.mysql.sql similarity index 100% rename from dtmsvr/dtmsvr.sql rename to dtmsvr/dtmsvr.mysql.sql diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index a5d569b..072d7e9 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -20,9 +20,9 @@ var app *gin.Engine func TestMain(m *testing.M) { TransProcessedTestChan = make(chan string, 1) common.InitConfig(common.GetProjectDir(), &config) - config.Mysql["database"] = dbName - PopulateMysql(false) - examples.PopulateMysql(false) + config.DB["database"] = dbName + PopulateDB(false) + examples.PopulateDB(false) // 启动组件 go StartSvr() app = examples.BaseAppStartup() @@ -130,7 +130,7 @@ func transQuery(t *testing.T, gid string) { func TestSqlDB(t *testing.T) { asserts := assert.New(t) - db := common.DbGet(config.Mysql) + db := common.DbGet(config.DB) transInfo := &dtmcli.TransInfo{ TransType: "saga", Gid: "gid2", diff --git a/dtmsvr/main.go b/dtmsvr/main.go index 1aefe42..5061beb 100644 --- a/dtmsvr/main.go +++ b/dtmsvr/main.go @@ -21,7 +21,8 @@ func StartSvr() { time.Sleep(100 * time.Millisecond) } -// PopulateMysql setup mysql data -func PopulateMysql(skipDrop bool) { - examples.RunSQLScript(config.Mysql, common.GetCurrentCodeDir()+"/dtmsvr.sql", skipDrop) +// PopulateDB setup mysql data +func PopulateDB(skipDrop bool) { + file := fmt.Sprintf("%s/dtmsvr.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"]) + examples.RunSQLScript(config.DB, file, skipDrop) } diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index 63b20ea..0e6ea8b 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -18,7 +18,7 @@ var p2e = common.P2E var e2p = common.E2P func dbGet() *common.DB { - return common.DbGet(config.Mysql) + return common.DbGet(config.DB) } func writeTransLog(gid string, action string, status string, branch string, detail string) { // if detail == "" { diff --git a/examples/config.go b/examples/config.go index 49ff2e3..86e619d 100644 --- a/examples/config.go +++ b/examples/config.go @@ -3,7 +3,7 @@ package examples import "github.com/yedf/dtm/common" type exampleConfig struct { - Mysql map[string]string `yaml:"Mysql"` + DB map[string]string `yaml:"DB"` } var config = exampleConfig{} @@ -12,5 +12,5 @@ var dbName = "dtm_busi" func init() { common.InitConfig(common.GetProjectDir(), &config) - config.Mysql["database"] = dbName + config.DB["database"] = dbName } diff --git a/examples/data.go b/examples/data.go index e7dcda6..d656c30 100644 --- a/examples/data.go +++ b/examples/data.go @@ -1,6 +1,7 @@ package examples import ( + "fmt" "io/ioutil" "strings" @@ -28,7 +29,8 @@ func RunSQLScript(mysql map[string]string, script string, skipDrop bool) { } } -// PopulateMysql populate example mysql data -func PopulateMysql(skipDrop bool) { - RunSQLScript(config.Mysql, common.GetCurrentCodeDir()+"/examples.sql", skipDrop) +// PopulateDB populate example mysql data +func PopulateDB(skipDrop bool) { + file := fmt.Sprintf("%s/examples.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"]) + RunSQLScript(config.DB, file, skipDrop) } diff --git a/examples/examples.sql b/examples/examples.mysql.sql similarity index 100% rename from examples/examples.sql rename to examples/examples.mysql.sql diff --git a/examples/main_xa.go b/examples/main_xa.go index daa2090..d7fe99f 100644 --- a/examples/main_xa.go +++ b/examples/main_xa.go @@ -33,16 +33,16 @@ type UserAccountTrading struct { func (u *UserAccountTrading) TableName() string { return "user_account_trading" } func dbGet() *common.DB { - return common.DbGet(config.Mysql) + return common.DbGet(config.DB) } // XaSetup 挂载http的api,创建XaClient func XaSetup(app *gin.Engine) { app.POST(BusiAPI+"/TransInXa", common.WrapHandler(xaTransIn)) app.POST(BusiAPI+"/TransOutXa", common.WrapHandler(xaTransOut)) - config.Mysql["database"] = "dtm_busi" + config.DB["database"] = "dtm_busi" var err error - XaClient, err = dtmcli.NewXaClient(DtmServer, config.Mysql, app, Busi+"/xa") + XaClient, err = dtmcli.NewXaClient(DtmServer, config.DB, app, Busi+"/xa") e2p(err) }