diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index c714999..a898375 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -17,6 +17,21 @@ var DtmServer = examples.DtmServer var Busi = examples.Busi var app *gin.Engine +func resetXaData() { + if config.DB["driver"] != "mysql" { + return + } + db := dbGet() + type XaRow struct { + Data string + } + xas := []XaRow{} + db.Must().Raw("xa recover").Scan(&xas) + for _, xa := range xas { + db.Must().Exec(fmt.Sprintf("xa rollback '%s'", xa.Data)) + } +} + func TestMain(m *testing.M) { TransProcessedTestChan = make(chan string, 1) common.InitConfig(common.GetProjectDir(), &config) @@ -32,7 +47,7 @@ func TestMain(m *testing.M) { examples.TccBarrierAddRoute(app) examples.SagaBarrierAddRoute(app) - examples.ResetXaData() + resetXaData() m.Run() } diff --git a/examples/main_xa.go b/examples/main_xa.go index 8bfbf49..2f1b5b7 100644 --- a/examples/main_xa.go +++ b/examples/main_xa.go @@ -13,14 +13,6 @@ import ( // XaClient XA client connection var XaClient *dtmcli.XaClient = nil -func dbGet() *common.DB { - return common.DbGet(config.DB) -} - -func sdbGet() *sql.DB { - return common.SdbGet(config.DB) -} - // XaSetup 挂载http的api,创建XaClient func XaSetup(app *gin.Engine) { app.POST(BusiAPI+"/TransInXa", common.WrapHandler(xaTransIn)) @@ -74,19 +66,3 @@ func xaTransOut(c *gin.Context) (interface{}, error) { e2p(err) return M{"dtm_result": "SUCCESS"}, nil } - -// ResetXaData 1 -func ResetXaData() { - if config.DB["driver"] != "mysql" { - return - } - db := dbGet() - type XaRow struct { - Data string - } - xas := []XaRow{} - db.Must().Raw("xa recover").Scan(&xas) - for _, xa := range xas { - db.Must().Exec(fmt.Sprintf("xa rollback '%s'", xa.Data)) - } -} diff --git a/examples/types.go b/examples/types.go index fe7562d..0e0621b 100644 --- a/examples/types.go +++ b/examples/types.go @@ -1,6 +1,7 @@ package examples import ( + "database/sql" "fmt" "github.com/gin-gonic/gin" @@ -52,3 +53,11 @@ func infoFromContext(c *gin.Context) *dtmcli.TransInfo { } return &info } + +func dbGet() *common.DB { + return common.DbGet(config.DB) +} + +func sdbGet() *sql.DB { + return common.SdbGet(config.DB) +}