xa test past

This commit is contained in:
yedf2 2021-08-01 15:20:52 +08:00
parent f4796ec474
commit db4b6d59f0
3 changed files with 83 additions and 95 deletions

View File

@ -4,11 +4,9 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"net/url" "net/url"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
) )
@ -18,10 +16,10 @@ type M = map[string]interface{}
var e2p = common.E2P var e2p = common.E2P
// XaGlobalFunc type of xa global function // XaGlobalFunc type of xa global function
type XaGlobalFunc func(xa *Xa) error type XaGlobalFunc func(xa *Xa) (interface{}, error)
// XaLocalFunc type of xa local function // XaLocalFunc type of xa local function
type XaLocalFunc func(db *sql.DB, xa *Xa) error type XaLocalFunc func(db *sql.DB, xa *Xa) (interface{}, error)
// XaClient xa client // XaClient xa client
type XaClient struct { type XaClient struct {
@ -71,82 +69,89 @@ func NewXaClient(server string, mysqlConf map[string]string, app *gin.Engine, ca
defer db.Close() defer db.Close()
branchID := req.Gid + "-" + req.BranchID branchID := req.Gid + "-" + req.BranchID
if req.Action == "commit" { if req.Action == "commit" {
_, err := common.SdbExec(db, fmt.Sprintf("xa commit '%s'", branchID)) _, err = common.SdbExec(db, fmt.Sprintf("xa commit '%s'", branchID))
e2p(err)
} else if req.Action == "rollback" { } else if req.Action == "rollback" {
_, err := common.SdbExec(db, fmt.Sprintf("xa rollback '%s'", branchID)) _, err = common.SdbExec(db, fmt.Sprintf("xa rollback '%s'", branchID))
e2p(err)
} else { } else {
panic(fmt.Errorf("unknown action: %s", req.Action)) panic(fmt.Errorf("unknown action: %s", req.Action))
} }
return M{"dtm_result": "SUCCESS"}, nil return M{"dtm_result": "SUCCESS"}, err
})) }))
return xa, nil return xa, nil
} }
// XaLocalTransaction start a xa local transaction // XaLocalTransaction start a xa local transaction
func (xc *XaClient) XaLocalTransaction(c *gin.Context, transFunc XaLocalFunc) (rerr error) { func (xc *XaClient) XaLocalTransaction(c *gin.Context, xaFunc XaLocalFunc) (ret interface{}, rerr error) {
defer common.P2E(&rerr)
xa := XaFromReq(c) xa := XaFromReq(c)
branchID := xa.NewBranchID() branchID := xa.NewBranchID()
xaBranch := xa.Gid + "-" + branchID xaBranch := xa.Gid + "-" + branchID
db := common.SdbAlone(xc.Conf) db := common.SdbAlone(xc.Conf)
defer func() { db.Close() }() defer func() { db.Close() }()
_, err := common.SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch)) defer func() {
e2p(err) var x interface{}
err = transFunc(db, xa) _, err := common.SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
e2p(err) if err != nil {
resp, err := common.RestyClient.R(). common.RedLogf("sql db exec error: %v", err)
}
if x = recover(); x != nil || IsFailure(ret, rerr) {
} else {
_, err = common.SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
}
if err != nil {
common.RedLogf("sql db exec error: %v", err)
}
if x != nil {
panic(x)
}
}()
_, rerr = common.SdbExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
if rerr != nil {
return
}
ret, rerr = xaFunc(db, xa)
if IsFailure(ret, rerr) {
return
}
ret, rerr = common.RestyClient.R().
SetBody(&M{"gid": xa.Gid, "branch_id": branchID, "trans_type": "xa", "status": "prepared", "url": xc.CallbackURL}). SetBody(&M{"gid": xa.Gid, "branch_id": branchID, "trans_type": "xa", "status": "prepared", "url": xc.CallbackURL}).
Post(xc.Server + "/registerXaBranch") Post(xc.Server + "/registerXaBranch")
e2p(err) return
if !strings.Contains(resp.String(), "SUCCESS") {
e2p(fmt.Errorf("unknown server response: %s", resp.String()))
}
_, err = common.SdbExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
e2p(err)
_, err = common.SdbExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
e2p(err)
return nil
} }
// XaGlobalTransaction start a xa global transaction // XaGlobalTransaction start a xa global transaction
func (xc *XaClient) XaGlobalTransaction(gid string, transFunc XaGlobalFunc) error { func (xc *XaClient) XaGlobalTransaction(gid string, xaFunc XaGlobalFunc) (ret interface{}, rerr error) {
xa := Xa{IDGenerator: IDGenerator{}, Gid: gid} xa := Xa{IDGenerator: IDGenerator{}, Gid: gid}
data := &M{ data := &M{
"gid": gid, "gid": gid,
"trans_type": "xa", "trans_type": "xa",
} }
defer func() { resp, err := common.RestyClient.R().SetBody(data).Post(xc.Server + "/prepare")
x := recover() if IsFailure(resp, err) {
if x != nil { return resp, err
r, err := common.RestyClient.R().SetBody(data).Post(xc.Server + "/abort")
if !strings.Contains(r.String(), "SUCCESS") {
logrus.Errorf("abort xa error: resp: %s err: %v", r.String(), err)
} }
// 小概率情况下prepare成功了但是由于网络状况导致上面Failure那么不执行下面defer的内容等待超时后再回滚标记事务失败也没有问题
defer func() {
var x interface{}
if x = recover(); x != nil || IsFailure(ret, rerr) {
resp, err = common.RestyClient.R().SetBody(data).Post(xc.Server + "/abort")
} else {
resp, err = common.RestyClient.R().SetBody(data).Post(xc.Server + "/submit")
}
if IsFailure(resp, err) {
common.RedLogf("submitting or abort global transaction error: %v resp: %s", err, resp.String())
}
if x != nil {
panic(x)
} }
}() }()
resp, err := common.RestyClient.R().SetBody(data).Post(xc.Server + "/prepare") ret, rerr = xaFunc(&xa)
rerr := CheckDtmResponse(resp, err) return
if rerr != nil {
return rerr
}
rerr = transFunc(&xa)
if rerr != nil {
return rerr
}
resp, err = common.RestyClient.R().SetBody(data).Post(xc.Server + "/submit")
rerr = CheckDtmResponse(resp, err)
if rerr != nil {
return rerr
}
return nil
} }
// CallBranch call a xa branch // CallBranch call a xa branch
func (x *Xa) CallBranch(body interface{}, url string) (*resty.Response, error) { func (x *Xa) CallBranch(body interface{}, url string) (*resty.Response, error) {
branchID := x.NewBranchID() branchID := x.NewBranchID()
resp, err := common.RestyClient.R(). return common.RestyClient.R().
SetBody(body). SetBody(body).
SetQueryParams(common.MS{ SetQueryParams(common.MS{
"gid": x.Gid, "gid": x.Gid,
@ -155,8 +160,4 @@ func (x *Xa) CallBranch(body interface{}, url string) (*resty.Response, error) {
"branch_type": "action", "branch_type": "action",
}). }).
Post(url) Post(url)
if strings.Contains(resp.String(), "FAILURE") {
return resp, fmt.Errorf("FAILURE result: %s err: %v", resp.String(), err)
}
return resp, err
} }

View File

@ -4,9 +4,7 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli" "github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/examples" "github.com/yedf/dtm/examples"
) )
@ -15,29 +13,30 @@ func TestXa(t *testing.T) {
if config.DB["driver"] != "mysql" { if config.DB["driver"] != "mysql" {
return return
} }
xaLocalError(t) // xaLocalError(t)
xaNormal(t) xaNormal(t)
xaRollback(t) // xaRollback(t)
} }
func xaLocalError(t *testing.T) { func xaLocalError(t *testing.T) {
err := examples.XaClient.XaGlobalTransaction("xaLocalError", func(xa *dtmcli.Xa) error { _, err := examples.XaClient.XaGlobalTransaction("xaLocalError", func(xa *dtmcli.Xa) (interface{}, error) {
return fmt.Errorf("an error") return nil, fmt.Errorf("an error")
}) })
assert.Error(t, err, fmt.Errorf("an error")) assert.Error(t, err, fmt.Errorf("an error"))
} }
func xaNormal(t *testing.T) { func xaNormal(t *testing.T) {
xc := examples.XaClient xc := examples.XaClient
gid := "xaNormal" gid := "xaNormal"
err := xc.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) error { res, err := xc.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (interface{}, error) {
req := examples.GenTransReq(30, false, false) req := examples.GenTransReq(30, false, false)
resp, err := xa.CallBranch(req, examples.Busi+"/TransOutXa") resp, err := xa.CallBranch(req, examples.Busi+"/TransOutXa")
common.CheckRestySuccess(resp, err) if dtmcli.IsFailure(resp, err) {
resp, err = xa.CallBranch(req, examples.Busi+"/TransInXa") return resp, err
common.CheckRestySuccess(resp, err) }
return nil return xa.CallBranch(req, examples.Busi+"/TransInXa")
}) })
e2p(err) dtmcli.PanicIfFailure(res, err)
WaitTransProcessed(gid) WaitTransProcessed(gid)
assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(gid)) assert.Equal(t, []string{"prepared", "succeed", "prepared", "succeed"}, getBranchesStatus(gid))
} }
@ -45,17 +44,15 @@ func xaNormal(t *testing.T) {
func xaRollback(t *testing.T) { func xaRollback(t *testing.T) {
xc := examples.XaClient xc := examples.XaClient
gid := "xaRollback" gid := "xaRollback"
err := xc.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) error { res, err := xc.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (interface{}, error) {
req := &examples.TransReq{Amount: 30, TransInResult: "FAILURE"} req := &examples.TransReq{Amount: 30, TransInResult: "FAILURE"}
resp, err := xa.CallBranch(req, examples.Busi+"/TransOutXa") resp, err := xa.CallBranch(req, examples.Busi+"/TransOutXa")
common.CheckRestySuccess(resp, err) if dtmcli.IsFailure(resp, err) {
resp, err = xa.CallBranch(req, examples.Busi+"/TransInXa") return resp, err
common.CheckRestySuccess(resp, err)
return nil
})
if err != nil {
logrus.Errorf("global transaction failed, so rollback")
} }
return xa.CallBranch(req, examples.Busi+"/TransInXa")
})
assert.True(t, dtmcli.IsFailure(res, err))
WaitTransProcessed(gid) WaitTransProcessed(gid)
assert.Equal(t, []string{"succeed", "prepared"}, getBranchesStatus(gid)) assert.Equal(t, []string{"succeed", "prepared"}, getBranchesStatus(gid))
assert.Equal(t, "failed", getTransStatus(gid)) assert.Equal(t, "failed", getTransStatus(gid))

View File

@ -2,8 +2,6 @@ package examples
import ( import (
"database/sql" "database/sql"
"fmt"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/yedf/dtm/common" "github.com/yedf/dtm/common"
@ -25,42 +23,34 @@ func XaSetup(app *gin.Engine) {
// XaFireRequest 注册全局XA事务调用XA的分支 // XaFireRequest 注册全局XA事务调用XA的分支
func XaFireRequest() string { func XaFireRequest() string {
gid := dtmcli.MustGenGid(DtmServer) gid := dtmcli.MustGenGid(DtmServer)
err := XaClient.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (rerr error) { res, err := XaClient.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (interface{}, error) {
defer common.P2E(&rerr)
req := &TransReq{Amount: 30} req := &TransReq{Amount: 30}
resp, err := xa.CallBranch(req, Busi+"/TransOutXa") resp, err := xa.CallBranch(req, Busi+"/TransOutXa")
common.CheckRestySuccess(resp, err) if dtmcli.IsFailure(resp, err) {
resp, err = xa.CallBranch(req, Busi+"/TransInXa") return resp, err
common.CheckRestySuccess(resp, err) }
return nil return xa.CallBranch(req, Busi+"/TransInXa")
}) })
e2p(err) dtmcli.PanicIfFailure(res, err)
return gid return gid
} }
func xaTransIn(c *gin.Context) (interface{}, error) { func xaTransIn(c *gin.Context) (interface{}, error) {
err := XaClient.XaLocalTransaction(c, func(db *sql.DB, xa *dtmcli.Xa) (rerr error) { return XaClient.XaLocalTransaction(c, func(db *sql.DB, xa *dtmcli.Xa) (interface{}, error) {
if reqFrom(c).TransInResult == "FAILURE" { if reqFrom(c).TransInResult == "FAILURE" {
return fmt.Errorf("tranIn FAILURE")
}
_, rerr = common.SdbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", reqFrom(c).Amount, 2)
return
})
if err != nil && strings.Contains(err.Error(), "FAILURE") {
return M{"dtm_result": "FAILURE"}, nil return M{"dtm_result": "FAILURE"}, nil
} }
e2p(err) _, err := common.SdbExec(db, "update dtm_busi.user_account set balance=balance+? where user_id=?", reqFrom(c).Amount, 2)
return M{"dtm_result": "SUCCESS"}, nil return M{"dtm_result": "SUCCESS"}, err
})
} }
func xaTransOut(c *gin.Context) (interface{}, error) { func xaTransOut(c *gin.Context) (interface{}, error) {
err := XaClient.XaLocalTransaction(c, func(db *sql.DB, xa *dtmcli.Xa) (rerr error) { return XaClient.XaLocalTransaction(c, func(db *sql.DB, xa *dtmcli.Xa) (interface{}, error) {
if reqFrom(c).TransOutResult == "FAILURE" { if reqFrom(c).TransOutResult == "FAILURE" {
return fmt.Errorf("tranOut failed") return M{"dtm_result": "FAILURE"}, nil
} }
_, rerr = common.SdbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, 1) _, err := common.SdbExec(db, "update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, 1)
return return M{"dtm_result": "SUCCESS"}, err
}) })
e2p(err)
return M{"dtm_result": "SUCCESS"}, nil
} }