diff --git a/common/types.go b/common/types.go index b822e10..af77fa4 100644 --- a/common/types.go +++ b/common/types.go @@ -3,11 +3,15 @@ package common import ( "database/sql" "fmt" + "io/ioutil" + "os" + "path/filepath" "strings" "time" _ "github.com/go-sql-driver/mysql" "github.com/yedf/dtm/dtmcli" + "gopkg.in/yaml.v2" // _ "github.com/lib/pq" @@ -117,3 +121,40 @@ func DbGet(conf map[string]string) *DB { } return dbs[dsn] } + +type dtmConfigType struct { + TransCronInterval int64 `yaml:"TransCronInterval"` // 单位秒 当事务等待这个时间之后,还没有变化,则进行一轮处理,包括prepared中的任务和committed的任务 + DB map[string]string `yaml:"DB"` +} + +// DtmConfig 配置 +var DtmConfig = dtmConfigType{} + +func init() { + DtmConfig.TransCronInterval = int64(dtmcli.MustAtoi(dtmcli.OrString(os.Getenv("TRANS_CRON_INTERVAL"), "10"))) + DtmConfig.DB = map[string]string{ + "driver": os.Getenv("DB_DRIVER"), + "host": os.Getenv("DB_HOST"), + "port": os.Getenv("DB_PORT"), + "user": os.Getenv("DB_USER"), + "password": os.Getenv("DB_PASSWORD"), + } + cont := []byte{} + for d := MustGetwd(); d != ""; d = filepath.Dir(d) { + cont1, err := ioutil.ReadFile(d + "/conf.yml") + if err != nil { + cont1, err = ioutil.ReadFile(d + "/conf.sample.yml") + } + if cont1 != nil { + cont = cont1 + break + } + } + if cont != nil { + dtmcli.Logf("cont is: \n%s", string(cont)) + err := yaml.Unmarshal(cont, &DtmConfig) + dtmcli.FatalIfError(err) + } + dtmcli.LogIfFatalf(DtmConfig.DB["driver"] == "" || DtmConfig.DB["user"] == "", + "dtm config error: %v. check you env, and conf.yml/conf.sample.yml found in current and parent path: %s", DtmConfig, MustGetwd()) +} diff --git a/common/types_test.go b/common/types_test.go index 5f1e11d..bce7cfc 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -7,18 +7,8 @@ import ( "github.com/yedf/dtm/dtmcli" ) -type testConfig struct { - DB map[string]string `yaml:"DB"` -} - -var config = testConfig{} - -func init() { - InitConfig(&config) -} - func TestDb(t *testing.T) { - db := DbGet(config.DB) + db := DbGet(DtmConfig.DB) err := func() (rerr error) { defer dtmcli.P2E(&rerr) dbr := db.NoMust().Exec("select a") @@ -30,7 +20,7 @@ func TestDb(t *testing.T) { } func TestDbAlone(t *testing.T) { - db, err := dtmcli.SdbAlone(config.DB) + db, err := dtmcli.SdbAlone(DtmConfig.DB) assert.Nil(t, err) _, err = dtmcli.SdbExec(db, "select 1") assert.Equal(t, nil, err) diff --git a/common/utils.go b/common/utils.go index 90b62d4..1f86e05 100644 --- a/common/utils.go +++ b/common/utils.go @@ -12,7 +12,6 @@ import ( "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" "github.com/yedf/dtm/dtmcli" - yaml "gopkg.in/yaml.v2" ) // GetGinApp init and return gin @@ -74,22 +73,3 @@ func GetCurrentCodeDir() string { _, file, _, _ := runtime.Caller(1) return filepath.Dir(file) } - -// InitConfig init config -func InitConfig(config interface{}) { - cont := []byte{} - for d := MustGetwd(); d != ""; d = filepath.Dir(d) { - cont1, err := ioutil.ReadFile(d + "/conf.yml") - if err != nil { - cont1, err = ioutil.ReadFile(d + "/conf.sample.yml") - } - if cont1 != nil { - cont = cont1 - break - } - } - dtmcli.LogIfFatalf(cont == nil, "no config file conf.yml/conf.sample.yml found in current and parent path: %s", MustGetwd()) - dtmcli.Logf("cont is: \n%s", string(cont)) - err := yaml.Unmarshal(cont, config) - dtmcli.FatalIfError(err) -} diff --git a/dtmsvr/config.go b/dtmsvr/config.go deleted file mode 100644 index 79089aa..0000000 --- a/dtmsvr/config.go +++ /dev/null @@ -1,18 +0,0 @@ -package dtmsvr - -import ( - "github.com/yedf/dtm/common" -) - -type dtmsvrConfig struct { - TransCronInterval int64 `yaml:"TransCronInterval"` // 单位秒 当事务等待这个时间之后,还没有变化,则进行一轮处理,包括prepared中的任务和committed的任务 - DB map[string]string `yaml:"DB"` -} - -var config = &dtmsvrConfig{ - TransCronInterval: 10, -} - -func init() { - common.InitConfig(&config) -} diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index 2299993..44fd50e 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -42,7 +42,6 @@ func resetXaData() { func TestMain(m *testing.M) { TransProcessedTestChan = make(chan string, 1) - common.InitConfig(&config) PopulateDB(false) examples.PopulateDB(false) // 启动组件 diff --git a/dtmsvr/utils.go b/dtmsvr/utils.go index 782fa18..973581f 100644 --- a/dtmsvr/utils.go +++ b/dtmsvr/utils.go @@ -17,6 +17,8 @@ type M = map[string]interface{} var p2e = dtmcli.P2E var e2p = dtmcli.E2P +var config = common.DtmConfig + func dbGet() *common.DB { return common.DbGet(config.DB) } diff --git a/examples/config.go b/examples/config.go deleted file mode 100644 index fd831c5..0000000 --- a/examples/config.go +++ /dev/null @@ -1,15 +0,0 @@ -package examples - -import ( - "github.com/yedf/dtm/common" -) - -type exampleConfig struct { - DB map[string]string `yaml:"DB"` -} - -var config = exampleConfig{} - -func init() { - common.InitConfig(&config) -} diff --git a/examples/data.go b/examples/data.go index d2e8d61..a341777 100644 --- a/examples/data.go +++ b/examples/data.go @@ -9,6 +9,8 @@ import ( "github.com/yedf/dtm/dtmcli" ) +var config = common.DtmConfig + // RunSQLScript 1 func RunSQLScript(conf map[string]string, script string, skipDrop bool) { con, err := dtmcli.SdbAlone(conf) @@ -31,4 +33,6 @@ func RunSQLScript(conf map[string]string, script string, skipDrop bool) { func PopulateDB(skipDrop bool) { file := fmt.Sprintf("%s/examples.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"]) RunSQLScript(config.DB, file, skipDrop) + file = fmt.Sprintf("%s/../dtmcli/barrier.%s.sql", common.GetCurrentCodeDir(), config.DB["driver"]) + RunSQLScript(config.DB, file, skipDrop) } diff --git a/examples/examples.mysql.sql b/examples/examples.mysql.sql index c56b71a..4f77eab 100644 --- a/examples/examples.mysql.sql +++ b/examples/examples.mysql.sql @@ -1,5 +1,4 @@ CREATE DATABASE if not exists dtm_busi /*!40100 DEFAULT CHARACTER SET utf8mb4 */; -create database if not exists dtm_barrier /*!40100 DEFAULT CHARACTER SET utf8mb4 */; drop table if exists dtm_busi.user_account; create table if not exists dtm_busi.user_account( @@ -28,18 +27,3 @@ create table if not exists dtm_busi.user_account_trading( -- 表示交易中被 insert into dtm_busi.user_account_trading (user_id, trading_balance) values (1, 0), (2, 0) on DUPLICATE KEY UPDATE trading_balance=values (trading_balance); -drop table if exists dtm_barrier.barrier; -create table if not exists dtm_barrier.barrier( - id int(11) PRIMARY KEY AUTO_INCREMENT, - trans_type varchar(45) default '' , - gid varchar(128) default'', - branch_id varchar(128) default '', - branch_type varchar(45) default '', - reason varchar(45) default '' comment 'the branch type who insert this record', - result varchar(2047) default null comment 'the business result of this branch', - create_time datetime DEFAULT now(), - update_time datetime DEFAULT now(), - key(create_time), - key(update_time), - UNIQUE key(gid, branch_id, branch_type) -);