diff --git a/common/utils_test.go b/common/utils_test.go index ecebe8e..96407b6 100644 --- a/common/utils_test.go +++ b/common/utils_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/gin-gonic/gin" - "github.com/go-playground/assert/v2" + "github.com/stretchr/testify/assert" ) func TestEP(t *testing.T) { @@ -83,6 +83,10 @@ func TestSome(t *testing.T) { n := MustAtoi("123") assert.Equal(t, 123, n) + err := CatchP(func() { + MustAtoi("abc") + }) + assert.Error(t, err) wd := MustGetwd() assert.NotEqual(t, "", wd) diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index db0f477..565e1db 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "net/url" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" @@ -26,18 +27,25 @@ func (t *TransInfo) String() string { return fmt.Sprintf("transInfo: %s %s %s %s", t.TransType, t.Gid, t.BranchID, t.BranchType) } -// TransInfoFromReq construct transaction info from request -func TransInfoFromReq(c *gin.Context) *TransInfo { +// MustGetTrans construct transaction info from request +func MustGetTrans(c *gin.Context) *TransInfo { + ti, err := TransInfoFromQuery(c.Request.URL.Query()) + e2p(err) + return ti +} + +// TransInfoFromQuery construct transaction info from request +func TransInfoFromQuery(qs url.Values) (*TransInfo, error) { ti := &TransInfo{ - TransType: c.Query("trans_type"), - Gid: c.Query("gid"), - BranchID: c.Query("branch_id"), - BranchType: c.Query("branch_type"), + TransType: qs.Get("trans_type"), + Gid: qs.Get("gid"), + BranchID: qs.Get("branch_id"), + BranchType: qs.Get("branch_type"), } if ti.TransType == "" || ti.Gid == "" || ti.BranchID == "" || ti.BranchType == "" { - panic(fmt.Errorf("invlid trans info: %v", ti)) + return nil, fmt.Errorf("invlid trans info: %v", ti) } - return ti + return ti, nil } // BarrierModel barrier model for gorm @@ -119,7 +127,7 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re return } if result.Valid { // 数据库里有上一次结果,返回上一次的结果 - res = json.Unmarshal([]byte(result.String), &res) + rerr = json.Unmarshal([]byte(result.String), &res) return } // 数据库里没有上次的结果,属于重复空补偿,直接返回成功 diff --git a/dtmcli/message.go b/dtmcli/message.go index 8d0ea58..f1acb03 100644 --- a/dtmcli/message.go +++ b/dtmcli/message.go @@ -3,7 +3,6 @@ package dtmcli import ( "fmt" - jsonitor "github.com/json-iterator/go" "github.com/sirupsen/logrus" "github.com/yedf/dtm/common" ) @@ -54,12 +53,7 @@ func (s *Msg) Add(action string, postData interface{}) *Msg { func (s *Msg) Submit() error { logrus.Printf("committing %s body: %v", s.Gid, &s.MsgData) resp, err := common.RestyClient.R().SetBody(&s.MsgData).Post(fmt.Sprintf("%s/submit", s.Server)) - rerr := CheckDtmResponse(resp, err) - if rerr != nil { - return rerr - } - s.Gid = jsonitor.Get(resp.Body(), "gid").ToString() - return nil + return CheckDtmResponse(resp, err) } // Prepare prepare the msg diff --git a/dtmcli/saga.go b/dtmcli/saga.go index 07d1596..c92d13d 100644 --- a/dtmcli/saga.go +++ b/dtmcli/saga.go @@ -3,7 +3,6 @@ package dtmcli import ( "fmt" - jsonitor "github.com/json-iterator/go" "github.com/sirupsen/logrus" "github.com/yedf/dtm/common" ) @@ -55,10 +54,5 @@ func (s *Saga) Add(action string, compensate string, postData interface{}) *Saga func (s *Saga) Submit() error { logrus.Printf("committing %s body: %v", s.Gid, &s.SagaData) resp, err := common.RestyClient.R().SetBody(&s.SagaData).Post(fmt.Sprintf("%s/submit", s.Server)) - rerr := CheckDtmResponse(resp, err) - if rerr != nil { - return rerr - } - s.Gid = jsonitor.Get(resp.Body(), "gid").ToString() - return nil + return CheckDtmResponse(resp, err) } diff --git a/dtmcli/tcc.go b/dtmcli/tcc.go index 7301417..7f89d93 100644 --- a/dtmcli/tcc.go +++ b/dtmcli/tcc.go @@ -81,13 +81,11 @@ func (t *Tcc) CallBranch(body interface{}, tryURL string, confirmURL string, can "cancel": cancelURL, }). Post(t.Dtm + "/registerTccBranch") + err = CheckDtmResponse(resp, err) if err != nil { return resp, err } - if !strings.Contains(resp.String(), "SUCCESS") { - return nil, fmt.Errorf("registerTccBranch failed: %s", resp.String()) - } - r, err := common.RestyClient.R(). + resp, err = common.RestyClient.R(). SetBody(body). SetQueryParams(common.MS{ "dtm": t.Dtm, @@ -97,8 +95,8 @@ func (t *Tcc) CallBranch(body interface{}, tryURL string, confirmURL string, can "branch_type": "try", }). Post(tryURL) - if err == nil && strings.Contains(r.String(), "FAILURE") { - return r, fmt.Errorf("branch return failure: %s", r.String()) + if err == nil && strings.Contains(resp.String(), "FAILURE") { + err = fmt.Errorf("branch return failure: %s", resp.String()) } - return r, err + return resp, err } diff --git a/dtmcli/types_test.go b/dtmcli/types_test.go index f72e6b1..97a3d88 100644 --- a/dtmcli/types_test.go +++ b/dtmcli/types_test.go @@ -1,6 +1,7 @@ package dtmcli import ( + "net/url" "testing" "github.com/stretchr/testify/assert" @@ -18,4 +19,6 @@ func TestTypes(t *testing.T) { idGen.NewBranchID() }) assert.Error(t, err) + _, err = TransInfoFromQuery(url.Values{}) + assert.Error(t, err) } diff --git a/dtmsvr/api.go b/dtmsvr/api.go index e9a4d5d..0ca9511 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -21,14 +21,14 @@ func addRoute(engine *gin.Engine) { } func newGid(c *gin.Context) (interface{}, error) { - return M{"gid": GenGid()}, nil + return M{"gid": GenGid(), "dtm_result": "SUCCESS"}, nil } func prepare(c *gin.Context) (interface{}, error) { t := TransFromContext(c) t.Status = "prepared" t.saveNew(dbGet()) - return M{"dtm_result": "SUCCESS", "gid": t.Gid}, nil + return M{"dtm_result": "SUCCESS"}, nil } func submit(c *gin.Context) (interface{}, error) { @@ -41,7 +41,7 @@ func submit(c *gin.Context) (interface{}, error) { t.Status = "submitted" t.saveNew(db) go t.Process(db) - return M{"dtm_result": "SUCCESS", "gid": t.Gid}, nil + return M{"dtm_result": "SUCCESS"}, nil } func abort(c *gin.Context) (interface{}, error) { diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index b0c8910..284c44b 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -56,6 +56,14 @@ func TestType(t *testing.T) { dtmcli.MustGenGid("http://localhost:8080/api/no") }) assert.Error(t, err) + err = common.CatchP(func() { + resp, err := common.RestyClient.R().SetBody(common.M{ + "gid": "1", + "trans_type": "msg", + }).Get("http://localhost:8080/api/dtmsvr/abort") + common.CheckRestySuccess(resp, err) + }) + assert.Error(t, err) } func getTransStatus(gid string) string { @@ -140,11 +148,17 @@ func TestSqlDB(t *testing.T) { asserts.Equal(dbr.RowsAffected, int64(1)) dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(0)) + gid2Res := common.M{"result": "first"} _, err = dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) { logrus.Printf("submit gid2") - return nil, nil + return gid2Res, nil }) asserts.Nil(err) dbr = db.Model(&dtmcli.BarrierModel{}).Where("gid=?", "gid2").Find(&[]dtmcli.BarrierModel{}) asserts.Equal(dbr.RowsAffected, int64(1)) + newResult, err := dtmcli.ThroughBarrierCall(db.ToSQLDB(), transInfo, func(db *sql.DB) (interface{}, error) { + logrus.Printf("submit gid2") + return common.MS{"result": "ignored"}, nil + }) + asserts.Equal(newResult, gid2Res) } diff --git a/examples/main_saga_barrier.go b/examples/main_saga_barrier.go index 3cd7ceb..8fd8e0e 100644 --- a/examples/main_saga_barrier.go +++ b/examples/main_saga_barrier.go @@ -44,13 +44,13 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) { if req.TransInResult != "" { return req.TransInResult, nil } - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) { return sagaBarrierAdjustBalance(sdb, 1, req.Amount) }) } func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) { return sagaBarrierAdjustBalance(sdb, 1, -reqFrom(c).Amount) }) } @@ -60,13 +60,13 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) { if req.TransInResult != "" { return req.TransInResult, nil } - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) { return sagaBarrierAdjustBalance(sdb, 2, -req.Amount) }) } func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) { return sagaBarrierAdjustBalance(sdb, 2, reqFrom(c).Amount) }) } diff --git a/examples/main_tcc_barrier.go b/examples/main_tcc_barrier.go index 761ff2b..189d9b8 100644 --- a/examples/main_tcc_barrier.go +++ b/examples/main_tcc_barrier.go @@ -70,19 +70,19 @@ func tccBarrierTransInTry(c *gin.Context) (interface{}, error) { if req.TransInResult != "" { return req.TransInResult, nil } - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) { return adjustTrading(sdb, transInUID, req.Amount) }) } func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) { return adjustBalance(sdb, transInUID, reqFrom(c).Amount) }) } func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) { return adjustTrading(sdb, transInUID, -reqFrom(c).Amount) }) } @@ -92,20 +92,20 @@ func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) { if req.TransInResult != "" { return req.TransInResult, nil } - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) { return adjustTrading(sdb, transOutUID, -req.Amount) }) } func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) { return adjustBalance(sdb, transOutUID, -reqFrom(c).Amount) }) } // TccBarrierTransOutCancel will be use in test func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { - return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.TransInfoFromReq(c), func(sdb *sql.DB) (interface{}, error) { + return dtmcli.ThroughBarrierCall(dbGet().ToSQLDB(), dtmcli.MustGetTrans(c), func(sdb *sql.DB) (interface{}, error) { return adjustTrading(sdb, transOutUID, reqFrom(c).Amount) }) }