diff --git a/common/types.go b/common/types.go index cf564cc..7c6d657 100644 --- a/common/types.go +++ b/common/types.go @@ -4,7 +4,6 @@ import ( "database/sql" "fmt" "regexp" - "strings" "time" "github.com/sirupsen/logrus" @@ -97,9 +96,7 @@ func (op *tracePlugin) Initialize(db *gorm.DB) (err error) { // GetDsn get dsn from map config func GetDsn(conf map[string]string) string { - if IsDockerCompose() { - conf["host"] = strings.Replace(conf["host"], "localhost", "host.docker.internal", 1) - } + conf["host"] = MayReplaceLocalhost(conf["host"]) // logrus.Printf("is docker: %t IS_DOCKER_COMPOSE: %s and conf host: %s", IsDockerCompose(), os.Getenv("IS_DOCKER_COMPOSE"), conf["host"]) return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]) } diff --git a/common/utils.go b/common/utils.go index b059786..1389eea 100644 --- a/common/utils.go +++ b/common/utils.go @@ -163,9 +163,7 @@ func init() { // RestyClient.SetRetryCount(2) // RestyClient.SetRetryWaitTime(1 * time.Second) RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error { - if IsDockerCompose() { - r.URL = strings.Replace(r.URL, "localhost", "host.docker.internal", 1) - } + r.URL = MayReplaceLocalhost(r.URL) logrus.Printf("requesting: %s %s %v %v", r.Method, r.URL, r.Body, r.QueryParam) return nil }) @@ -247,7 +245,10 @@ func GetFuncName() string { return runtime.FuncForPC(pc).Name() } -// IsDockerCompose name is clear -func IsDockerCompose() bool { - return os.Getenv("IS_DOCKER_COMPOSE") != "" +// MayReplaceLocalhost when run in docker compose, change localhost to host.docker.internal for accessing host network +func MayReplaceLocalhost(host string) string { + if os.Getenv("IS_DOCKER_COMPOSE") != "" { + return strings.Replace(host, "localhost", "host.docker.internal", 1) + } + return host } diff --git a/common/utils_test.go b/common/utils_test.go index 96407b6..f50e614 100644 --- a/common/utils_test.go +++ b/common/utils_test.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "net/http/httptest" + "os" "strings" "testing" @@ -95,4 +96,11 @@ func TestSome(t *testing.T) { func1 := GetFuncName() assert.Equal(t, true, strings.HasSuffix(func1, "TestSome")) + + os.Setenv("IS_DOCKER_COMPOSE", "1") + s := MayReplaceLocalhost("http://localhost") + assert.Equal(t, "http://host.docker.internal", s) + os.Setenv("IS_DOCKER_COMPOSE", "") + s2 := MayReplaceLocalhost("http://localhost") + assert.Equal(t, "http://localhost", s2) } diff --git a/dtmcli/types_test.go b/dtmcli/types_test.go index 97a3d88..d47e232 100644 --- a/dtmcli/types_test.go +++ b/dtmcli/types_test.go @@ -1,6 +1,7 @@ package dtmcli import ( + "fmt" "net/url" "testing" @@ -21,4 +22,8 @@ func TestTypes(t *testing.T) { assert.Error(t, err) _, err = TransInfoFromQuery(url.Values{}) assert.Error(t, err) + + err2 := fmt.Errorf("an error") + err3 := CheckDtmResponse(nil, err2) + assert.Error(t, err2, err3) } diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index 3d5b538..26da56c 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -44,7 +44,8 @@ func TestCover(t *testing.T) { defer handlePanic() checkAffected(db.DB) - go CronExpiredTrans(1) + CronExpiredTrans(1) + go sleepCronTime() } func TestType(t *testing.T) { diff --git a/dtmsvr/trans.go b/dtmsvr/trans.go index 568a179..5d2beee 100644 --- a/dtmsvr/trans.go +++ b/dtmsvr/trans.go @@ -116,6 +116,7 @@ func (t *TransGlobal) Process(db *common.DB) { if TransProcessedTestChan != nil { logrus.Printf("processed: %s", t.Gid) TransProcessedTestChan <- t.Gid + logrus.Printf("notified: %s", t.Gid) } }() logrus.Printf("processing: %s status: %s", t.Gid, t.Status) diff --git a/dtmsvr/trans_tcc.go b/dtmsvr/trans_tcc.go index 0b8c597..8a53fb6 100644 --- a/dtmsvr/trans_tcc.go +++ b/dtmsvr/trans_tcc.go @@ -26,9 +26,6 @@ func (t *transTccProcessor) ExecBranch(db *common.DB, branch *TransBranch) { if strings.Contains(body, "SUCCESS") { t.touch(db, config.TransCronInterval) branch.changeStatus(db, "succeed") - } else if branch.BranchType == "try" && strings.Contains(body, "FAILURE") { - t.touch(db, config.TransCronInterval) - branch.changeStatus(db, "failed") } else { panic(fmt.Errorf("unknown response: %s, will be retried", body)) } diff --git a/dtmsvr/trans_tcc_test.go b/dtmsvr/trans_tcc_test.go index 741026c..8383b05 100644 --- a/dtmsvr/trans_tcc_test.go +++ b/dtmsvr/trans_tcc_test.go @@ -2,6 +2,7 @@ package dtmsvr import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/yedf/dtm/dtmcli" @@ -33,9 +34,14 @@ func tccRollback(t *testing.T) { err := dtmcli.TccGlobalTransaction(examples.DtmServer, gid, func(tcc *dtmcli.Tcc) (rerr error) { resp, rerr := tcc.CallBranch(data, Busi+"/TransOut", Busi+"/TransOutConfirm", Busi+"/TransOutRevert") assert.Contains(t, resp.String(), "SUCCESS") + examples.MainSwitch.TransOutRevertResult.SetOnce("PENDING") _, rerr = tcc.CallBranch(data, Busi+"/TransIn", Busi+"/TransInConfirm", Busi+"/TransInRevert") assert.Error(t, rerr) return }) assert.Error(t, err) + WaitTransProcessed(gid) + assert.Equal(t, "aborting", getTransStatus(gid)) + CronTransOnce(60 * time.Second) + assert.Equal(t, "failed", getTransStatus(gid)) }