diff --git a/dtmcli/xa.go b/dtmcli/xa.go index b141e2d..50bdd81 100644 --- a/dtmcli/xa.go +++ b/dtmcli/xa.go @@ -4,11 +4,9 @@ import ( "database/sql" "fmt" "net/url" - "strings" "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" - "github.com/sirupsen/logrus" "github.com/yedf/dtm/common" ) @@ -18,10 +16,10 @@ type M = map[string]interface{} var e2p = common.E2P // XaGlobalFunc type of xa global function -type XaGlobalFunc func(xa *Xa) error +type XaGlobalFunc func(xa *Xa) (interface{}, error) // XaLocalFunc type of xa local function -type XaLocalFunc func(db *sql.DB, xa *Xa) error +type XaLocalFunc func(db *sql.DB, xa *Xa) (interface{}, error) // XaClient xa client type XaClient struct { @@ -71,82 +69,89 @@ func NewXaClient(server string, mysqlConf map[string]string, app *gin.Engine, ca defer db.Close() branchID := req.Gid + "-" + req.BranchID if req.Action == "commit" { - _, err := common.SdbExec(db, fmt.Sprintf("xa commit '%s'", branchID)) - e2p(err) + _, err = common.SdbExec(db, fmt.Sprintf("xa commit '%s'", branchID)) } else if req.Action == "rollback" { - _, err := common.SdbExec(db, fmt.Sprintf("xa rollback '%s'", branchID)) - e2p(err) + _, err = common.SdbExec(db, fmt.Sprintf("xa rollback '%s'", branchID)) } else { panic(fmt.Errorf("unknown action: %s", req.Action)) } - return M{"dtm_result": "SUCCESS"}, nil + return M{"dtm_result": "SUCCESS"}, err })) return xa, nil } // XaLocalTransaction start a xa local transaction -func (xc *XaClient) XaLocalTransaction(c *gin.Context, transFunc XaLocalFunc) (rerr error) { - defer common.P2E(&rerr) +func (xc *XaClient) XaLocalTransaction(c *gin.Context, xaFunc XaLocalFunc) (ret interface{}, rerr error) { xa := XaFromReq(c) branchID := xa.NewBranchID() xaBranch := xa.Gid + "-" + branchID db := common.SdbAlone(xc.Conf) defer func() { db.Close() }() - _, err := common.SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch)) - e2p(err) - err = transFunc(db, xa) - e2p(err) - resp, err := common.RestyClient.R(). + defer func() { + var x interface{} + _, err := common.SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch)) + if err != nil { + common.RedLogf("sql db exec error: %v", err) + } + if x = recover(); x != nil || IsFailure(ret, rerr) { + } else { + _, err = common.SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch)) + } + if err != nil { + common.RedLogf("sql db exec error: %v", err) + } + if x != nil { + panic(x) + } + }() + _, rerr = common.SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch)) + if rerr != nil { + return + } + ret, rerr = xaFunc(db, xa) + if IsFailure(ret, rerr) { + return + } + ret, rerr = common.RestyClient.R(). SetBody(&M{"gid": xa.Gid, "branch_id": branchID, "trans_type": "xa", "status": "prepared", "url": xc.CallbackURL}). Post(xc.Server + "/registerXaBranch") - e2p(err) - if !strings.Contains(resp.String(), "SUCCESS") { - e2p(fmt.Errorf("unknown server response: %s", resp.String())) - } - _, err = common.SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch)) - e2p(err) - _, err = common.SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch)) - e2p(err) - return nil + return } // XaGlobalTransaction start a xa global transaction -func (xc *XaClient) XaGlobalTransaction(gid string, transFunc XaGlobalFunc) error { +func (xc *XaClient) XaGlobalTransaction(gid string, xaFunc XaGlobalFunc) (ret interface{}, rerr error) { xa := Xa{IDGenerator: IDGenerator{}, Gid: gid} data := &M{ "gid": gid, "trans_type": "xa", } + resp, err := common.RestyClient.R().SetBody(data).Post(xc.Server + "/prepare") + if IsFailure(resp, err) { + return resp, err + } + // 小概率情况下,prepare成功了,但是由于网络状况导致上面Failure,那么不执行下面defer的内容,等待超时后再回滚标记事务失败,也没有问题 defer func() { - x := recover() + var x interface{} + if x = recover(); x != nil || IsFailure(ret, rerr) { + resp, err = common.RestyClient.R().SetBody(data).Post(xc.Server + "/abort") + } else { + resp, err = common.RestyClient.R().SetBody(data).Post(xc.Server + "/submit") + } + if IsFailure(resp, err) { + common.RedLogf("submitting or abort global transaction error: %v resp: %s", err, resp.String()) + } if x != nil { - r, err := common.RestyClient.R().SetBody(data).Post(xc.Server + "/abort") - if !strings.Contains(r.String(), "SUCCESS") { - logrus.Errorf("abort xa error: resp: %s err: %v", r.String(), err) - } + panic(x) } }() - resp, err := common.RestyClient.R().SetBody(data).Post(xc.Server + "/prepare") - rerr := CheckDtmResponse(resp, err) - if rerr != nil { - return rerr - } - rerr = transFunc(&xa) - if rerr != nil { - return rerr - } - resp, err = common.RestyClient.R().SetBody(data).Post(xc.Server + "/submit") - rerr = CheckDtmResponse(resp, err) - if rerr != nil { - return rerr - } - return nil + ret, rerr = xaFunc(&xa) + return } // CallBranch call a xa branch func (x *Xa) CallBranch(body interface{}, url string) (*resty.Response, error) { branchID := x.NewBranchID() - resp, err := common.RestyClient.R(). + return common.RestyClient.R(). SetBody(body). SetQueryParams(common.MS{ "gid": x.Gid, @@ -155,8 +160,4 @@ func (x *Xa) CallBranch(body interface{}, url string) (*resty.Response, error) { "branch_type": "action", }). Post(url) - if strings.Contains(resp.String(), "FAILURE") { - return resp, fmt.Errorf("FAILURE result: %s err: %v", resp.String(), err) - } - return resp, err } diff --git a/dtmsvr/trans_xa_test.go b/dtmsvr/trans_xa_test.go index 79361a6..132fdf8 100644 --- a/dtmsvr/trans_xa_test.go +++ b/dtmsvr/trans_xa_test.go @@ -4,9 +4,7 @@ import ( "fmt" "testing" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "github.com/yedf/dtm/common" "github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/examples" ) @@ -15,29 +13,30 @@ func TestXa(t *testing.T) { if config.DB["driver"] != "mysql" { return } - xaLocalError(t) + // xaLocalError(t) xaNormal(t) - xaRollback(t) + // xaRollback(t) } func xaLocalError(t *testing.T) { - err := examples.XaClient.XaGlobalTransaction("xaLocalError", func(xa *dtmcli.Xa) error { - return fmt.Errorf("an error") + _, err := examples.XaClient.XaGlobalTransaction("xaLocalError", func(xa *dtmcli.Xa) (interface{}, error) { + return nil, fmt.Errorf("an error") }) assert.Error(t, err, fmt.Errorf("an error")) } + func xaNormal(t *testing.T) { xc := examples.XaClient gid := "xaNormal" - err := xc.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) error { + res, err := xc.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (interface{}, error) { req := examples.GenTransReq(30, false, false) resp, err := xa.CallBranch(req, examples.Busi+"/TransOutXa") - common.CheckRestySuccess(resp, err) - resp, err = xa.CallBranch(req, examples.Busi+"/TransInXa") - common.CheckRestySuccess(resp, err) - return nil + if dtmcli.IsFailure(resp, err) { + return resp, err + } + return xa.CallBranch(req, examples.Busi+"/TransInXa") }) - e2p(err) + dtmcli.PanicIfFailure(res, err) WaitTransProcessed(gid) assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(gid)) } @@ -45,17 +44,15 @@ func xaNormal(t *testing.T) { func xaRollback(t *testing.T) { xc := examples.XaClient gid := "xaRollback" - err := xc.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) error { + res, err := xc.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (interface{}, error) { req := &examples.TransReq{Amount: 30, TransInResult: "FAILURE"} resp, err := xa.CallBranch(req, examples.Busi+"/TransOutXa") - common.CheckRestySuccess(resp, err) - resp, err = xa.CallBranch(req, examples.Busi+"/TransInXa") - common.CheckRestySuccess(resp, err) - return nil + if dtmcli.IsFailure(resp, err) { + return resp, err + } + return xa.CallBranch(req, examples.Busi+"/TransInXa") }) - if err != nil { - logrus.Errorf("global transaction failed, so rollback") - } + assert.True(t, dtmcli.IsFailure(res, err)) WaitTransProcessed(gid) assert.Equal(t, []string{"succeed", "prepared"}, getBranchesStatus(gid)) assert.Equal(t, "failed", getTransStatus(gid)) diff --git a/examples/main_xa.go b/examples/main_xa.go index a7919d7..fd7c378 100644 --- a/examples/main_xa.go +++ b/examples/main_xa.go @@ -2,8 +2,6 @@ package examples import ( "database/sql" - "fmt" - "strings" "github.com/gin-gonic/gin" "github.com/yedf/dtm/common" @@ -25,42 +23,34 @@ func XaSetup(app *gin.Engine) { // XaFireRequest 注册全局XA事务,调用XA的分支 func XaFireRequest() string { gid := dtmcli.MustGenGid(DtmServer) - err := XaClient.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (rerr error) { - defer common.P2E(&rerr) + res, err := XaClient.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (interface{}, error) { req := &TransReq{Amount: 30} resp, err := xa.CallBranch(req, Busi+"/TransOutXa") - common.CheckRestySuccess(resp, err) - resp, err = xa.CallBranch(req, Busi+"/TransInXa") - common.CheckRestySuccess(resp, err) - return nil + if dtmcli.IsFailure(resp, err) { + return resp, err + } + return xa.CallBranch(req, Busi+"/TransInXa") }) - e2p(err) + dtmcli.PanicIfFailure(res, err) return gid } func xaTransIn(c *gin.Context) (interface{}, error) { - err := XaClient.XaLocalTransaction(c, func(db *sql.DB, xa *dtmcli.Xa) (rerr error) { + return XaClient.XaLocalTransaction(c, func(db *sql.DB, xa *dtmcli.Xa) (interface{}, error) { if reqFrom(c).TransInResult == "FAILURE" { - return fmt.Errorf("tranIn FAILURE") + return M{"dtm_result": "FAILURE"}, nil } - _, rerr = common.SdbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", reqFrom(c).Amount, 2) - return + _, err := common.SdbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", reqFrom(c).Amount, 2) + return M{"dtm_result": "SUCCESS"}, err }) - if err != nil && strings.Contains(err.Error(), "FAILURE") { - return M{"dtm_result": "FAILURE"}, nil - } - e2p(err) - return M{"dtm_result": "SUCCESS"}, nil } func xaTransOut(c *gin.Context) (interface{}, error) { - err := XaClient.XaLocalTransaction(c, func(db *sql.DB, xa *dtmcli.Xa) (rerr error) { + return XaClient.XaLocalTransaction(c, func(db *sql.DB, xa *dtmcli.Xa) (interface{}, error) { if reqFrom(c).TransOutResult == "FAILURE" { - return fmt.Errorf("tranOut failed") + return M{"dtm_result": "FAILURE"}, nil } - _, rerr = common.SdbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, 1) - return + _, err := common.SdbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, 1) + return M{"dtm_result": "SUCCESS"}, err }) - e2p(err) - return M{"dtm_result": "SUCCESS"}, nil }