diff --git a/.simplecov b/.simplecov deleted file mode 100644 index a8c24e0..0000000 --- a/.simplecov +++ /dev/null @@ -1,8 +0,0 @@ -require 'simplecov' -require 'coveralls' - -SimpleCov.formatter = Coveralls::SimpleCov::Formatter -SimpleCov.start do - add_filter 'app' - add_filter 'examples' -end \ No newline at end of file diff --git a/common/types.go b/common/types.go index 7c6d657..3c3de7f 100644 --- a/common/types.go +++ b/common/types.go @@ -3,11 +3,15 @@ package common import ( "database/sql" "fmt" - "regexp" + "strings" "time" + _ "github.com/go-sql-driver/mysql" + // _ "github.com/lib/pq" "github.com/sirupsen/logrus" "gorm.io/driver/mysql" + + // "gorm.io/driver/postgres" "gorm.io/gorm" ) @@ -97,22 +101,32 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) { // GetDsn get dsn from map config func GetDsn(conf map[string]string) string { conf["host"] = MayReplaceLocalhost(conf["host"]) - // logrus.Printf("is docker: %t IS_DOCKER_COMPOSE: %s and conf host: %s", IsDockerCompose(), os.Getenv("IS_DOCKER_COMPOSE"), conf["host"]) - return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]) + driver := conf["driver"] + dsn := MS{ + "mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", + conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]), + "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai", + conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]), + }[driver] + PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver)) + return dsn } -// ReplaceDsnPassword replace password for log output -func ReplaceDsnPassword(dsn string) string { - reg := regexp.MustCompile(`:.*@`) - return reg.ReplaceAllString(dsn, ":****@") +func getGormDialator(driver string, dsn string) gorm.Dialector { + if driver == "mysql" { + return mysql.Open(dsn) + // } else if driver == "postgres" { + // return postgres.Open(dsn) + } + panic(fmt.Errorf("unkown driver: %s", driver)) } // DbGet get db connection for specified conf func DbGet(conf map[string]string) *DB { dsn := GetDsn(conf) if dbs[dsn] == nil { - logrus.Printf("connecting %s", ReplaceDsnPassword(dsn)) - db1, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ + logrus.Printf("connecting %s", strings.Replace(dsn, conf["password"], "****", 1)) + db1, err := gorm.Open(getGormDialator(conf["driver"], dsn), &gorm.Config{ SkipDefaultTransaction: true, }) E2P(err) @@ -132,28 +146,21 @@ func SQLDB2DB(sdb *sql.DB) *DB { return &DB{DB: db} } -// MyConn for xa alone connection -type MyConn struct { - Conn *sql.DB - Dsn string -} - -// Close name is clear -func (conn *MyConn) Close() { - logrus.Printf("closing alone mysql: %s", ReplaceDsnPassword(conn.Dsn)) - conn.Conn.Close() -} - // DbAlone get a standalone db connection -func DbAlone(conf map[string]string) (*DB, *MyConn) { +func DbAlone(conf map[string]string) *sql.DB { dsn := GetDsn(conf) - logrus.Printf("opening alone mysql: %s", ReplaceDsnPassword(dsn)) - mdb, err := sql.Open("mysql", dsn) + logrus.Printf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1)) + mdb, err := sql.Open(conf["driver"], dsn) E2P(err) - gormDB, err := gorm.Open(mysql.New(mysql.Config{ - Conn: mdb, - }), &gorm.Config{}) - E2P(err) - gormDB.Use(&tracePlugin{}) - return &DB{DB: gormDB}, &MyConn{Conn: mdb, Dsn: dsn} + return mdb +} + +// DbExec use raw db to exec +func DbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) { + r, rerr := db.Exec(sql, values...) + if rerr == nil { + affected, rerr = r.RowsAffected() + } + logrus.Printf("affected: %d error: %v for %s %v", affected, rerr, sql, values) + return } diff --git a/common/types_test.go b/common/types_test.go index 21bc7ff..e89f7cb 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -32,10 +32,10 @@ func TestDb(t *testing.T) { } func TestDbAlone(t *testing.T) { - db, con := DbAlone(config.DB) - dbr := db.Exec("select 1") - assert.Equal(t, nil, dbr.Error) - con.Close() - dbr = db.Exec("select 1") - assert.NotEqual(t, nil, dbr.Error) + db := DbAlone(config.DB) + _, err := DbExec(db, "select 1") + assert.Equal(t, nil, err) + db.Close() + _, err = DbExec(db, "select 1") + assert.NotEqual(t, nil, err) } diff --git a/conf.sample.yml b/conf.sample.yml index d17de37..55d8e63 100644 --- a/conf.sample.yml +++ b/conf.sample.yml @@ -4,5 +4,9 @@ DB: user: 'root' password: '' port: '3306' - + # driver: 'postgres' + # host: 'localhost' + # user: 'postgres' + # password: 'mysecretpassword' + # port: '5432' TransCronInterval: 10 # 单位秒 当事务等待这个时间之后,还没有变化,则进行一轮重试处理,包括prepared中的任务和commited的任务 diff --git a/dtmcli/xa.go b/dtmcli/xa.go index 13f10b6..416f3ad 100644 --- a/dtmcli/xa.go +++ b/dtmcli/xa.go @@ -1,6 +1,7 @@ package dtmcli import ( + "database/sql" "fmt" "net/url" "strings" @@ -20,7 +21,7 @@ var e2p = common.E2P type XaGlobalFunc func(xa *Xa) error // XaLocalFunc type of xa local function -type XaLocalFunc func(db *common.DB, xa *Xa) error +type XaLocalFunc func(db *sql.DB, xa *Xa) error // XaClient xa client type XaClient struct { @@ -66,13 +67,15 @@ func NewXaClient(server string, mysqlConf map[string]string, app *gin.Engine, ca return nil, err } common.MustUnmarshal(b, &req) - tx, my := common.DbAlone(xa.Conf) - defer my.Close() + db := common.DbAlone(xa.Conf) + defer db.Close() branchID := req.Gid + "-" + req.BranchID if req.Action == "commit" { - tx.Must().Exec(fmt.Sprintf("xa commit '%s'", branchID)) + _, err := common.DbExec(db, fmt.Sprintf("xa commit '%s'", branchID)) + e2p(err) } else if req.Action == "rollback" { - tx.Must().Exec(fmt.Sprintf("xa rollback '%s'", branchID)) + _, err := common.DbExec(db, fmt.Sprintf("xa rollback '%s'", branchID)) + e2p(err) } else { panic(fmt.Errorf("unknown action: %s", req.Action)) } @@ -87,10 +90,11 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, transFunc XaLocalFunc) (r xa := XaFromReq(c) branchID := xa.NewBranchID() xaBranch := xa.Gid + "-" + branchID - tx, my := common.DbAlone(xc.Conf) - defer func() { my.Close() }() - tx.Must().Exec(fmt.Sprintf("XA start '%s'", xaBranch)) - err := transFunc(tx, xa) + db := common.DbAlone(xc.Conf) + defer func() { db.Close() }() + _, err := common.DbExec(db, fmt.Sprintf("XA start '%s'", xaBranch)) + e2p(err) + err = transFunc(db, xa) e2p(err) resp, err := common.RestyClient.R(). SetBody(&M{"gid": xa.Gid, "branch_id": branchID, "trans_type": "xa", "status": "prepared", "url": xc.CallbackURL}). @@ -99,8 +103,10 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, transFunc XaLocalFunc) (r if !strings.Contains(resp.String(), "SUCCESS") { e2p(fmt.Errorf("unknown server response: %s", resp.String())) } - tx.Must().Exec(fmt.Sprintf("XA end '%s'", xaBranch)) - tx.Must().Exec(fmt.Sprintf("XA prepare '%s'", xaBranch)) + _, err = common.DbExec(db, fmt.Sprintf("XA end '%s'", xaBranch)) + e2p(err) + _, err = common.DbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch)) + e2p(err) return nil } diff --git a/dtmsvr/dtmsvr.postgres.sql b/dtmsvr/dtmsvr.postgres.sql index e69de29..29f92bf 100644 --- a/dtmsvr/dtmsvr.postgres.sql +++ b/dtmsvr/dtmsvr.postgres.sql @@ -0,0 +1,72 @@ +CREATE SCHEMA if not EXISTS dtm /* SQLINES DEMO *** RACTER SET utf8mb4 */; + +drop table IF EXISTS dtm.trans_global; +-- SQLINES LICENSE FOR EVALUATION USE ONLY +CREATE SEQUENCE if not EXISTS dtm.trans_global_seq; + +CREATE TABLE if not EXISTS dtm.trans_global ( + id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_global_seq'), + gid varchar(128) NOT NULL , + trans_type varchar(45) not null , + data TEXT , + status varchar(45) NOT NULL , + query_prepared varchar(128) NOT NULL , + create_time timestamp(0) DEFAULT NULL, + update_time timestamp(0) DEFAULT NULL, + commit_time timestamp(0) DEFAULT NULL, + finish_time timestamp(0) DEFAULT NULL, + rollback_time timestamp(0) DEFAULT NULL, + next_cron_interval int default null , + next_cron_time timestamp(0) default null , + owner varchar(128) not null default '' , + PRIMARY KEY (id), + CONSTRAINT gid UNIQUE (gid) +) ; + +create index if not EXISTS owner on dtm.trans_global(owner); +CREATE INDEX if not EXISTS create_time ON dtm.trans_global (create_time); +CREATE INDEX if not EXISTS update_time ON dtm.trans_global (update_time); +create index if not EXISTS next_cron_time on dtm.trans_global (next_cron_time); + +drop table IF EXISTS dtm.trans_branch; +-- SQLINES LICENSE FOR EVALUATION USE ONLY +CREATE SEQUENCE if not EXISTS dtm.trans_branch_seq; + +CREATE TABLE IF NOT EXISTS dtm.trans_branch ( + id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_branch_seq'), + gid varchar(128) NOT NULL , + url varchar(128) NOT NULL , + data TEXT , + branch_id VARCHAR(128) NOT NULL , + branch_type varchar(45) NOT NULL , + status varchar(45) NOT NULL , + finish_time timestamp(0) DEFAULT NULL, + rollback_time timestamp(0) DEFAULT NULL, + create_time timestamp(0) DEFAULT NULL, + update_time timestamp(0) DEFAULT NULL, + PRIMARY KEY (id), + CONSTRAINT gid_uniq UNIQUE (gid,branch_id, branch_type) +) ; + +CREATE INDEX if not EXISTS create_time ON dtm.trans_branch (create_time); +CREATE INDEX if not EXISTS update_time ON dtm.trans_branch (update_time); + +drop table IF EXISTS dtm.trans_log; +-- SQLINES LICENSE FOR EVALUATION USE ONLY +CREATE SEQUENCE if not EXISTS dtm.trans_log_seq; + +CREATE TABLE IF NOT EXISTS dtm.trans_log ( + id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_log_seq'), + gid varchar(128) NOT NULL , + branch_id varchar(128) DEFAULT NULL , + action varchar(45) DEFAULT NULL , + old_status varchar(45) NOT NULL DEFAULT '' , + new_status varchar(45) NOT NULL , + detail TEXT NOT NULL , + create_time timestamp(0) DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id) +) ; + +CREATE INDEX if not EXISTS gid ON dtm.trans_log (gid); +CREATE INDEX if not EXISTS create_time ON dtm.trans_log (create_time); + diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index 072d7e9..ddabb7e 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -20,7 +20,6 @@ var app *gin.Engine func TestMain(m *testing.M) { TransProcessedTestChan = make(chan string, 1) common.InitConfig(common.GetProjectDir(), &config) - config.DB["database"] = dbName PopulateDB(false) examples.PopulateDB(false) // 启动组件 diff --git a/dtmsvr/trans_xa_test.go b/dtmsvr/trans_xa_test.go index 9f2d0ae..79361a6 100644 --- a/dtmsvr/trans_xa_test.go +++ b/dtmsvr/trans_xa_test.go @@ -12,6 +12,9 @@ import ( ) func TestXa(t *testing.T) { + if config.DB["driver"] != "mysql" { + return + } xaLocalError(t) xaNormal(t) xaRollback(t) diff --git a/examples/config.go b/examples/config.go index 86e619d..cba0051 100644 --- a/examples/config.go +++ b/examples/config.go @@ -8,9 +8,6 @@ type exampleConfig struct { var config = exampleConfig{} -var dbName = "dtm_busi" - func init() { common.InitConfig(common.GetProjectDir(), &config) - config.DB["database"] = dbName } diff --git a/examples/data.go b/examples/data.go index d656c30..ed45c96 100644 --- a/examples/data.go +++ b/examples/data.go @@ -5,16 +5,12 @@ import ( "io/ioutil" "strings" - "github.com/sirupsen/logrus" "github.com/yedf/dtm/common" ) // RunSQLScript 1 -func RunSQLScript(mysql map[string]string, script string, skipDrop bool) { - conf := map[string]string{} - common.MustRemarshal(mysql, &conf) - conf["database"] = "" - db, con := common.DbAlone(conf) +func RunSQLScript(conf map[string]string, script string, skipDrop bool) { + con := common.DbAlone(conf) defer func() { con.Close() }() content, err := ioutil.ReadFile(script) e2p(err) @@ -24,8 +20,8 @@ func RunSQLScript(mysql map[string]string, script string, skipDrop bool) { if s == "" || skipDrop && strings.Contains(s, "drop") { continue } - logrus.Printf("executing: '%s'", s) - db.Must().Exec(s) + _, err = common.DbExec(con, s) + e2p(err) } } diff --git a/examples/examples. postgres.sql b/examples/examples. postgres.sql deleted file mode 100644 index e69de29..0000000 diff --git a/examples/examples.postgres.sql b/examples/examples.postgres.sql new file mode 100644 index 0000000..0730c69 --- /dev/null +++ b/examples/examples.postgres.sql @@ -0,0 +1,62 @@ +CREATE SCHEMA if not exists dtm_busi /* SQLINES DEMO *** RACTER SET utf8mb4 */; +create SCHEMA if not exists dtm_barrier /* SQLINES DEMO *** RACTER SET utf8mb4 */; + +drop table if exists dtm_busi.user_account; +-- SQLINES LICENSE FOR EVALUATION USE ONLY +create sequence if not exists dtm_busi.user_account_seq; + +create table if not exists dtm_busi.user_account( + id int PRIMARY KEY DEFAULT NEXTVAL ('dtm_busi.user_account_seq'), + user_id int UNIQUE , + balance DECIMAL(10, 2) not null default '0', + create_time timestamp(0) DEFAULT now(), + update_time timestamp(0) DEFAULT now() +); +-- SQLINES LICENSE FOR EVALUATION USE ONLY +create index if not exists create_idx on dtm_busi.user_account(create_time); +-- SQLINES LICENSE FOR EVALUATION USE ONLY +create index if not exists update_idx on dtm_busi.user_account(update_time); + +TRUNCATE dtm_busi.user_account +insert into dtm_busi.user_account (user_id, balance) values (1, 10000), (2, 10000); + +drop table if exists dtm_busi.user_account_trading; +-- SQLINES LICENSE FOR EVALUATION USE ONLY +create sequence if not exists dtm_busi.user_account_trading_seq; + +create table if not exists dtm_busi.user_account_trading( -- SQLINES DEMO *** �冻结的金额 + id int PRIMARY KEY DEFAULT NEXTVAL ('dtm_busi.user_account_trading_seq'), + user_id int UNIQUE , + trading_balance DECIMAL(10, 2) not null default '0', + create_time timestamp(0) DEFAULT now(), + update_time timestamp(0) DEFAULT now() +); +-- SQLINES LICENSE FOR EVALUATION USE ONLY +create index if not exists create_idx on dtm_busi.user_account_trading(create_time); +-- SQLINES LICENSE FOR EVALUATION USE ONLY +create index if not exists update_idx on dtm_busi.user_account_trading(update_time); + +TRUNCATE dtm_busi.user_account_trading; +insert into dtm_busi.user_account_trading (user_id, trading_balance) values (1, 0), (2, 0); + + +drop table if exists dtm_busi.barrier; +-- SQLINES LICENSE FOR EVALUATION USE ONLY +create sequence if not exists dtm_busi.barrier_seq; + +create table if not exists dtm_busi.barrier( + id int PRIMARY KEY DEFAULT NEXTVAL ('dtm_busi.barrier_seq'), + trans_type varchar(45) default '' , + gid varchar(128) default'', + branch_id varchar(128) default '', + branch_type varchar(45) default '', + reason varchar(45) default '' , + result varchar(2047) default null , + create_time timestamp(0) DEFAULT now(), + update_time timestamp(0) DEFAULT now(), + UNIQUE (gid, branch_id, branch_type) +); +-- SQLINES LICENSE FOR EVALUATION USE ONLY +create index if not exists create_idx on dtm_busi.barrier(create_time); +-- SQLINES LICENSE FOR EVALUATION USE ONLY +create index if not exists update_idx on dtm_busi.barrier(update_time); diff --git a/examples/main_xa.go b/examples/main_xa.go index d7fe99f..331bf6b 100644 --- a/examples/main_xa.go +++ b/examples/main_xa.go @@ -1,6 +1,7 @@ package examples import ( + "database/sql" "fmt" "strings" @@ -20,7 +21,7 @@ type UserAccount struct { } // TableName gorm table name -func (u *UserAccount) TableName() string { return "user_account" } +func (u *UserAccount) TableName() string { return "dtm_busi.user_account" } // UserAccountTrading freeze user account table type UserAccountTrading struct { @@ -30,7 +31,7 @@ type UserAccountTrading struct { } // TableName gorm table name -func (u *UserAccountTrading) TableName() string { return "user_account_trading" } +func (u *UserAccountTrading) TableName() string { return "dtm_busi.user_account_trading" } func dbGet() *common.DB { return common.DbGet(config.DB) @@ -40,7 +41,6 @@ func dbGet() *common.DB { func XaSetup(app *gin.Engine) { app.POST(BusiAPI+"/TransInXa", common.WrapHandler(xaTransIn)) app.POST(BusiAPI+"/TransOutXa", common.WrapHandler(xaTransOut)) - config.DB["database"] = "dtm_busi" var err error XaClient, err = dtmcli.NewXaClient(DtmServer, config.DB, app, Busi+"/xa") e2p(err) @@ -63,13 +63,13 @@ func XaFireRequest() string { } func xaTransIn(c *gin.Context) (interface{}, error) { - err := XaClient.XaLocalTransaction(c, func(db *common.DB, xa *dtmcli.Xa) (rerr error) { + err := XaClient.XaLocalTransaction(c, func(db *sql.DB, xa *dtmcli.Xa) (rerr error) { req := reqFrom(c) if req.TransInResult == "FAILURE" { return fmt.Errorf("tranIn FAILURE") } - dbr := db.Exec("update user_account set balance=balance+? where user_id=?", req.Amount, 2) - return dbr.Error + _, rerr = common.DbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", req.Amount, 2) + return }) if err != nil && strings.Contains(err.Error(), "FAILURE") { return M{"dtm_result": "FAILURE"}, nil @@ -79,13 +79,13 @@ func xaTransIn(c *gin.Context) (interface{}, error) { } func xaTransOut(c *gin.Context) (interface{}, error) { - err := XaClient.XaLocalTransaction(c, func(db *common.DB, xa *dtmcli.Xa) (rerr error) { + err := XaClient.XaLocalTransaction(c, func(db *sql.DB, xa *dtmcli.Xa) (rerr error) { req := reqFrom(c) if req.TransOutResult == "FAILURE" { return fmt.Errorf("tranOut failed") } - dbr := db.Exec("update user_account set balance=balance-? where user_id=?", req.Amount, 1) - return dbr.Error + _, rerr = common.DbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", req.Amount, 1) + return }) e2p(err) return M{"dtm_result": "SUCCESS"}, nil @@ -93,9 +93,10 @@ func xaTransOut(c *gin.Context) (interface{}, error) { // ResetXaData 1 func ResetXaData() { + if config.DB["driver"] != "mysql" { + return + } db := dbGet() - db.Must().Exec("truncate user_account") - db.Must().Exec("insert into user_account (user_id, balance) values (1, 10000), (2, 10000)") type XaRow struct { Data string } diff --git a/go.mod b/go.mod index 90e21d0..9401013 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gin-gonic/gin v1.6.3 github.com/go-playground/assert/v2 v2.0.1 github.com/go-resty/resty/v2 v2.6.0 + github.com/go-sql-driver/mysql v1.5.0 github.com/json-iterator/go v1.1.10 // indirect github.com/kr/pretty v0.1.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect diff --git a/go.sum b/go.sum index cad09b1..6ee955e 100644 --- a/go.sum +++ b/go.sum @@ -93,5 +93,3 @@ gorm.io/driver/mysql v1.0.3/go.mod h1:twGxftLBlFgNVNakL7F+P/x9oYqoymG3YYT8cAfI9o gorm.io/gorm v1.20.4/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= gorm.io/gorm v1.21.12 h1:3fQM0Eiz7jcJEhPggHEpoYnsGZqynMzverL77DV40RM= gorm.io/gorm v1.21.12/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0= - -