diff --git a/dtmcli/types.go b/dtmcli/types.go index d7a9044..b29acd2 100644 --- a/dtmcli/types.go +++ b/dtmcli/types.go @@ -34,8 +34,8 @@ func PanicIfFailure(res interface{}, err error) { } } -// CheckUserResponse 检查Response,返回错误 -func CheckUserResponse(resp *resty.Response, err error) error { +// CheckResponse 检查Response,返回错误 +func CheckResponse(resp *resty.Response, err error) error { if err == nil && resp != nil { if resp.IsError() { return errors.New(resp.String()) @@ -46,6 +46,18 @@ func CheckUserResponse(resp *resty.Response, err error) error { return err } +// CheckResult 检查Result,返回错误 +func CheckResult(res interface{}, err error) error { + resp, ok := res.(*resty.Response) + if ok { + return CheckResponse(resp, err) + } + if res != nil && strings.Contains(common.MustMarshalString(res), "FAILURE") { + return ErrFailure + } + return err +} + // CheckDtmResponse check the response of dtm, if not ok ,generate error func CheckDtmResponse(resp *resty.Response, err error) error { if err != nil { @@ -135,3 +147,6 @@ var ErrFailure = errors.New("transaction FAILURE") // ResultSuccess 表示返回成功,可以进行下一步 var ResultSuccess = common.M{"dtm_result": "SUCCESS"} + +// ResultFailure 表示返回失败,要求回滚 +var ResultFailure = common.M{"dtm_result": "FAILURE"} diff --git a/dtmcli/xa.go b/dtmcli/xa.go index d751799..f8520b2 100644 --- a/dtmcli/xa.go +++ b/dtmcli/xa.go @@ -95,6 +95,7 @@ func (xc *XaClient) XaLocalTransaction(c *gin.Context, xaFunc XaLocalFunc) (ret return } ret, rerr = xaFunc(db, xa) + rerr = CheckResult(ret, rerr) if rerr != nil { return } @@ -143,5 +144,5 @@ func (x *Xa) CallBranch(body interface{}, url string) (*resty.Response, error) { "branch_type": "action", }). Post(url) - return resp, CheckUserResponse(resp, err) + return resp, CheckResponse(resp, err) } diff --git a/dtmsvr/trans_xa_test.go b/dtmsvr/trans_xa_test.go index 3eb7ba4..c848d1c 100644 --- a/dtmsvr/trans_xa_test.go +++ b/dtmsvr/trans_xa_test.go @@ -56,6 +56,6 @@ func xaRollback(t *testing.T) { }) assert.Error(t, err) WaitTransProcessed(gid) - assert.Equal(t, []string{"succeed", "prepared", "succeed", "prepared"}, getBranchesStatus(gid)) + assert.Equal(t, []string{"succeed", "prepared"}, getBranchesStatus(gid)) assert.Equal(t, "failed", getTransStatus(gid)) } diff --git a/examples/quick_start.go b/examples/quick_start.go index 3b3810d..409e0b3 100644 --- a/examples/quick_start.go +++ b/examples/quick_start.go @@ -44,7 +44,7 @@ func QsFireRequest() string { func qsAdjustBalance(uid int, amount int) (interface{}, error) { _, err := common.SdbExec(sdbGet(), "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) - return M{"dtm_result": "SUCCESS"}, err + return dtmcli.ResultSuccess, err } func qsAddRoute(app *gin.Engine) {