From b5ae9403bbed1cb84d20bc84b39cbf04c6a25ef3 Mon Sep 17 00:00:00 2001 From: yedf2 <120050102@qq.com> Date: Fri, 23 Jul 2021 13:49:56 +0800 Subject: [PATCH] GenGid 2 MustGenGid --- dtmcli/barrier.go | 6 ++--- dtmcli/message.go | 17 +++++--------- dtmcli/saga.go | 9 +++----- dtmcli/tcc.go | 17 +++++++------- dtmcli/types.go | 23 +++++++++++++++---- dtmcli/types_test.go | 21 +++++++++++++++++ dtmcli/xa.go | 43 ++++++++++++++--------------------- dtmsvr/dtmsvr_test.go | 7 ++++++ dtmsvr/trans_xa_test.go | 9 +++++++- examples/main_msg.go | 2 +- examples/main_saga.go | 2 +- examples/main_saga_barrier.go | 2 +- examples/main_tcc.go | 19 ++++------------ examples/main_tcc_barrier.go | 16 +++---------- examples/main_xa.go | 12 +++++++--- examples/quick_start.go | 2 +- 16 files changed, 113 insertions(+), 94 deletions(-) create mode 100644 dtmcli/types_test.go diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index 8051dfc..db0f477 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -121,10 +121,10 @@ func ThroughBarrierCall(db *sql.DB, transInfo *TransInfo, busiCall BusiFunc) (re if result.Valid { // 数据库里有上一次结果,返回上一次的结果 res = json.Unmarshal([]byte(result.String), &res) return - } else { // 数据库里没有上次的结果,属于重复空补偿,直接返回成功 - res = common.MS{"dtm_result": "SUCCESS"} - return } + // 数据库里没有上次的结果,属于重复空补偿,直接返回成功 + res = common.MS{"dtm_result": "SUCCESS"} + return } res, rerr = busiCall(db) if rerr == nil { // 正确返回了,需要将结果保存到数据库 diff --git a/dtmcli/message.go b/dtmcli/message.go index d6baf84..8d0ea58 100644 --- a/dtmcli/message.go +++ b/dtmcli/message.go @@ -2,7 +2,6 @@ package dtmcli import ( "fmt" - "strings" jsonitor "github.com/json-iterator/go" "github.com/sirupsen/logrus" @@ -55,11 +54,9 @@ func (s *Msg) Add(action string, postData interface{}) *Msg { func (s *Msg) Submit() error { logrus.Printf("committing %s body: %v", s.Gid, &s.MsgData) resp, err := common.RestyClient.R().SetBody(&s.MsgData).Post(fmt.Sprintf("%s/submit", s.Server)) - if err != nil { - return err - } - if !strings.Contains(resp.String(), "SUCCESS") { - return fmt.Errorf("submit failed: %v", resp.Body()) + rerr := CheckDtmResponse(resp, err) + if rerr != nil { + return rerr } s.Gid = jsonitor.Get(resp.Body(), "gid").ToString() return nil @@ -70,11 +67,9 @@ func (s *Msg) Prepare(queryPrepared string) error { s.QueryPrepared = common.OrString(queryPrepared, s.QueryPrepared) logrus.Printf("preparing %s body: %v", s.Gid, &s.MsgData) resp, err := common.RestyClient.R().SetBody(&s.MsgData).Post(fmt.Sprintf("%s/prepare", s.Server)) - if err != nil { - return err - } - if !strings.Contains(resp.String(), "SUCCESS") { - return fmt.Errorf("prepare failed: %v", resp.Body()) + rerr := CheckDtmResponse(resp, err) + if rerr != nil { + return rerr } return nil } diff --git a/dtmcli/saga.go b/dtmcli/saga.go index 20f4843..07d1596 100644 --- a/dtmcli/saga.go +++ b/dtmcli/saga.go @@ -2,7 +2,6 @@ package dtmcli import ( "fmt" - "strings" jsonitor "github.com/json-iterator/go" "github.com/sirupsen/logrus" @@ -56,11 +55,9 @@ func (s *Saga) Add(action string, compensate string, postData interface{}) *Saga func (s *Saga) Submit() error { logrus.Printf("committing %s body: %v", s.Gid, &s.SagaData) resp, err := common.RestyClient.R().SetBody(&s.SagaData).Post(fmt.Sprintf("%s/submit", s.Server)) - if err != nil { - return err - } - if !strings.Contains(resp.String(), "SUCCESS") { - return fmt.Errorf("submit failed: %v", resp.Body()) + rerr := CheckDtmResponse(resp, err) + if rerr != nil { + return rerr } s.Gid = jsonitor.Get(resp.Body(), "gid").ToString() return nil diff --git a/dtmcli/tcc.go b/dtmcli/tcc.go index 0fb00e0..7301417 100644 --- a/dtmcli/tcc.go +++ b/dtmcli/tcc.go @@ -27,29 +27,28 @@ func TccGlobalTransaction(dtm string, gid string, tccFunc TccGlobalFunc) (rerr e "trans_type": "tcc", } defer func() { + var resp *resty.Response var err error var x interface{} if x = recover(); x != nil || rerr != nil { - _, err = common.RestyClient.R().SetBody(data).Post(dtm + "/abort") + resp, err = common.RestyClient.R().SetBody(data).Post(dtm + "/abort") } else { - _, err = common.RestyClient.R().SetBody(data).Post(dtm + "/submit") + resp, err = common.RestyClient.R().SetBody(data).Post(dtm + "/submit") } - if err != nil { - logrus.Errorf("submitting or abort global transaction error: %v", err) + err2 := CheckDtmResponse(resp, err) + if err2 != nil { + logrus.Errorf("submitting or abort global transaction error: %v", err2) } if x != nil { panic(x) } }() tcc := &Tcc{Dtm: dtm, Gid: gid} - resp, rerr := common.RestyClient.R().SetBody(data).Post(tcc.Dtm + "/prepare") + resp, err := common.RestyClient.R().SetBody(data).Post(tcc.Dtm + "/prepare") + rerr = CheckDtmResponse(resp, err) if rerr != nil { return } - if !strings.Contains(resp.String(), "SUCCESS") { - rerr = fmt.Errorf("bad response: %s", resp.String()) - return - } rerr = tccFunc(tcc) return } diff --git a/dtmcli/types.go b/dtmcli/types.go index d6766f9..e443fca 100644 --- a/dtmcli/types.go +++ b/dtmcli/types.go @@ -2,18 +2,33 @@ package dtmcli import ( "fmt" + "strings" + "github.com/go-resty/resty/v2" "github.com/yedf/dtm/common" ) -// GenGid generate a new gid -func GenGid(server string) string { +// MustGenGid generate a new gid +func MustGenGid(server string) string { res := common.MS{} - _, err := common.RestyClient.R().SetResult(&res).Get(server + "/newGid") - e2p(err) + resp, err := common.RestyClient.R().SetResult(&res).Get(server + "/newGid") + if err != nil || res["gid"] == "" { + panic(fmt.Errorf("newGid error: %v, resp: %s", err, resp)) + } return res["gid"] } +// CheckDtmResponse check the response of dtm, if not ok ,generate error +func CheckDtmResponse(resp *resty.Response, err error) error { + if err != nil { + return err + } + if !strings.Contains(resp.String(), "SUCCESS") { + return fmt.Errorf("dtm response failed: %s", resp.String()) + } + return nil +} + // IDGenerator used to generate a branch id type IDGenerator struct { parentID string diff --git a/dtmcli/types_test.go b/dtmcli/types_test.go new file mode 100644 index 0000000..f72e6b1 --- /dev/null +++ b/dtmcli/types_test.go @@ -0,0 +1,21 @@ +package dtmcli + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/yedf/dtm/common" +) + +func TestTypes(t *testing.T) { + err := common.CatchP(func() { + idGen := IDGenerator{parentID: "12345678901234567890123"} + idGen.NewBranchID() + }) + assert.Error(t, err) + err = common.CatchP(func() { + idGen := IDGenerator{branchID: 99} + idGen.NewBranchID() + }) + assert.Error(t, err) +} diff --git a/dtmcli/xa.go b/dtmcli/xa.go index 52c650f..13f10b6 100644 --- a/dtmcli/xa.go +++ b/dtmcli/xa.go @@ -35,16 +35,6 @@ type Xa struct { Gid string } -// GetParams get xa params map -func (x *Xa) GetParams(branchID string) common.MS { - return common.MS{ - "gid": x.Gid, - "trans_type": "xa", - "branch_id": branchID, - "branch_type": "action", - } -} - // XaFromReq construct xa info from request func XaFromReq(c *gin.Context) *Xa { return &Xa{ @@ -53,20 +43,17 @@ func XaFromReq(c *gin.Context) *Xa { } } -// NewXaBranchID generate a xa branch id -func (x *Xa) NewXaBranchID() string { - return x.Gid + "-" + x.NewBranchID() -} - // NewXaClient construct a xa client -func NewXaClient(server string, mysqlConf map[string]string, app *gin.Engine, callbackURL string) *XaClient { +func NewXaClient(server string, mysqlConf map[string]string, app *gin.Engine, callbackURL string) (*XaClient, error) { xa := &XaClient{ Server: server, Conf: mysqlConf, CallbackURL: callbackURL, } u, err := url.Parse(callbackURL) - e2p(err) + if err != nil { + return nil, err + } app.POST(u.Path, common.WrapHandler(func(c *gin.Context) (interface{}, error) { type CallbackReq struct { Gid string `json:"gid"` @@ -75,7 +62,9 @@ func NewXaClient(server string, mysqlConf map[string]string, app *gin.Engine, ca } req := CallbackReq{} b, err := c.GetRawData() - e2p(err) + if err != nil { + return nil, err + } common.MustUnmarshal(b, &req) tx, my := common.DbAlone(xa.Conf) defer my.Close() @@ -89,7 +78,7 @@ func NewXaClient(server string, mysqlConf map[string]string, app *gin.Engine, ca } return M{"dtm_result": "SUCCESS"}, nil })) - return xa + return xa, nil } // XaLocalTransaction start a xa local transaction @@ -131,17 +120,19 @@ func (xc *XaClient) XaGlobalTransaction(gid string, transFunc XaGlobalFunc) erro } } }() - resp, rerr := common.RestyClient.R().SetBody(data).Post(xc.Server + "/prepare") - if !strings.Contains(resp.String(), "SUCCESS") { - return fmt.Errorf("unexpected result: %s", resp.String()) + resp, err := common.RestyClient.R().SetBody(data).Post(xc.Server + "/prepare") + rerr := CheckDtmResponse(resp, err) + if rerr != nil { + return rerr } rerr = transFunc(&xa) if rerr != nil { return rerr } - resp, rerr = common.RestyClient.R().SetBody(data).Post(xc.Server + "/submit") - if !strings.Contains(resp.String(), "SUCCESS") { - return fmt.Errorf("unexpected result: %s err: %v", resp.String(), rerr) + resp, err = common.RestyClient.R().SetBody(data).Post(xc.Server + "/submit") + rerr = CheckDtmResponse(resp, err) + if rerr != nil { + return rerr } return nil } @@ -159,7 +150,7 @@ func (x *Xa) CallBranch(body interface{}, url string) (*resty.Response, error) { }). Post(url) if strings.Contains(resp.String(), "FAILURE") { - return resp, fmt.Errorf("unexpected result: %s err: %v", resp.String(), err) + return resp, fmt.Errorf("FAILURE result: %s err: %v", resp.String(), err) } return resp, err } diff --git a/dtmsvr/dtmsvr_test.go b/dtmsvr/dtmsvr_test.go index 10957d5..b0c8910 100644 --- a/dtmsvr/dtmsvr_test.go +++ b/dtmsvr/dtmsvr_test.go @@ -51,6 +51,13 @@ func TestCover(t *testing.T) { go CronExpiredTrans(1) } +func TestType(t *testing.T) { + err := common.CatchP(func() { + dtmcli.MustGenGid("http://localhost:8080/api/no") + }) + assert.Error(t, err) +} + func getTransStatus(gid string) string { sm := TransGlobal{} dbr := dbGet().Model(&sm).Where("gid=?", gid).First(&sm) diff --git a/dtmsvr/trans_xa_test.go b/dtmsvr/trans_xa_test.go index 7d8dd94..851f9c6 100644 --- a/dtmsvr/trans_xa_test.go +++ b/dtmsvr/trans_xa_test.go @@ -1,6 +1,7 @@ package dtmsvr import ( + "fmt" "testing" "github.com/sirupsen/logrus" @@ -11,11 +12,17 @@ import ( ) func TestXa(t *testing.T) { - + xaLocalError(t) xaNormal(t) xaRollback(t) } +func xaLocalError(t *testing.T) { + err := examples.XaClient.XaGlobalTransaction("xaLocalError", func(xa *dtmcli.Xa) error { + return fmt.Errorf("an error") + }) + assert.Error(t, err, fmt.Errorf("an error")) +} func xaNormal(t *testing.T) { xc := examples.XaClient gid := "xaNormal" diff --git a/examples/main_msg.go b/examples/main_msg.go index 9aaab48..81c2ee5 100644 --- a/examples/main_msg.go +++ b/examples/main_msg.go @@ -18,7 +18,7 @@ func MsgFireRequest() string { TransInResult: "SUCCESS", TransOutResult: "SUCCESS", } - msg := dtmcli.NewMsg(DtmServer, dtmcli.GenGid(DtmServer)). + msg := dtmcli.NewMsg(DtmServer, dtmcli.MustGenGid(DtmServer)). Add(Busi+"/TransOut", req). Add(Busi+"/TransIn", req) err := msg.Prepare(Busi + "/TransQuery") diff --git a/examples/main_saga.go b/examples/main_saga.go index b28723b..36c349e 100644 --- a/examples/main_saga.go +++ b/examples/main_saga.go @@ -18,7 +18,7 @@ func SagaFireRequest() string { TransInResult: "SUCCESS", TransOutResult: "SUCCESS", } - saga := dtmcli.NewSaga(DtmServer, dtmcli.GenGid(DtmServer)). + saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)). Add(Busi+"/TransOut", Busi+"/TransOutRevert", req). Add(Busi+"/TransIn", Busi+"/TransInRevert", req) logrus.Printf("saga busi trans submit") diff --git a/examples/main_saga_barrier.go b/examples/main_saga_barrier.go index 9d4bd01..3cd7ceb 100644 --- a/examples/main_saga_barrier.go +++ b/examples/main_saga_barrier.go @@ -14,7 +14,7 @@ import ( func SagaBarrierFireRequest() string { logrus.Printf("a busi transaction begin") req := &TransReq{Amount: 30} - saga := dtmcli.NewSaga(DtmServer, dtmcli.GenGid(DtmServer)). + saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)). Add(Busi+"/SagaBTransOut", Busi+"/SagaBTransOutCompensate", req). Add(Busi+"/SagaBTransIn", Busi+"/SagaBTransInCompensate", req) logrus.Printf("busi trans submit") diff --git a/examples/main_tcc.go b/examples/main_tcc.go index 6ac4b47..fa2e310 100644 --- a/examples/main_tcc.go +++ b/examples/main_tcc.go @@ -11,16 +11,11 @@ import ( func TccSetup(app *gin.Engine) { app.POST(BusiAPI+"/TransInTcc", common.WrapHandler(func(c *gin.Context) (interface{}, error) { tcc, err := dtmcli.TccFromReq(c) - if err != nil { - return nil, err - } + e2p(err) req := reqFrom(c) logrus.Printf("Trans in %d here, and Trans in another %d in call2 ", req.Amount/2, req.Amount/2) _, rerr := tcc.CallBranch(&TransReq{Amount: req.Amount / 2}, Busi+"/TransIn", Busi+"/TransInConfirm", Busi+"/TransInRevert") - if rerr != nil { - return nil, rerr - } - + e2p(rerr) return M{"dtm_result": "SUCCESS"}, nil })) @@ -29,16 +24,12 @@ func TccSetup(app *gin.Engine) { // TccFireRequest 1 func TccFireRequest() string { logrus.Printf("tcc transaction begin") - gid := dtmcli.GenGid(DtmServer) + gid := dtmcli.MustGenGid(DtmServer) err := dtmcli.TccGlobalTransaction(DtmServer, gid, func(tcc *dtmcli.Tcc) (rerr error) { res1, rerr := tcc.CallBranch(&TransReq{Amount: 30}, Busi+"/TransOut", Busi+"/TransOutConfirm", Busi+"/TransOutRevert") - if rerr != nil { - return - } + e2p(rerr) res2, rerr := tcc.CallBranch(&TransReq{Amount: 30}, Busi+"/TransInTcc", Busi+"/TransInConfirm", Busi+"/TransInRevert") - if rerr != nil { - return - } + e2p(rerr) logrus.Printf("tcc returns: %s, %s", res1.String(), res2.String()) return }) diff --git a/examples/main_tcc_barrier.go b/examples/main_tcc_barrier.go index d7c0390..761ff2b 100644 --- a/examples/main_tcc_barrier.go +++ b/examples/main_tcc_barrier.go @@ -13,22 +13,12 @@ import ( // TccBarrierFireRequest 1 func TccBarrierFireRequest() string { logrus.Printf("tcc transaction begin") - gid := dtmcli.GenGid(DtmServer) + gid := dtmcli.MustGenGid(DtmServer) err := dtmcli.TccGlobalTransaction(DtmServer, gid, func(tcc *dtmcli.Tcc) (rerr error) { res1, rerr := tcc.CallBranch(&TransReq{Amount: 30}, Busi+"/TccBTransOutTry", Busi+"/TccBTransOutConfirm", Busi+"/TccBTransOutCancel") - if rerr != nil { - return - } - if res1.StatusCode() != 200 { - return fmt.Errorf("bad status code: %d", res1.StatusCode()) - } + common.CheckRestySuccess(res1, rerr) res2, rerr := tcc.CallBranch(&TransReq{Amount: 30}, Busi+"/TccBTransInTry", Busi+"/TccBTransInConfirm", Busi+"/TccBTransInCancel") - if rerr != nil { - return - } - if res2.StatusCode() != 200 { - return fmt.Errorf("bad status code: %d", res2.StatusCode()) - } + common.CheckRestySuccess(res1, rerr) logrus.Printf("tcc returns: %s, %s", res1.String(), res2.String()) return }) diff --git a/examples/main_xa.go b/examples/main_xa.go index e9fc73d..b55450b 100644 --- a/examples/main_xa.go +++ b/examples/main_xa.go @@ -2,6 +2,7 @@ package examples import ( "fmt" + "strings" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" @@ -38,7 +39,7 @@ func dbGet() *common.DB { // XaFireRequest 1 func XaFireRequest() string { - gid := dtmcli.GenGid(DtmServer) + gid := dtmcli.MustGenGid(DtmServer) err := XaClient.XaGlobalTransaction(gid, func(xa *dtmcli.Xa) (rerr error) { defer common.P2E(&rerr) req := GenTransReq(30, false, false) @@ -57,18 +58,23 @@ func XaSetup(app *gin.Engine) { app.POST(BusiAPI+"/TransInXa", common.WrapHandler(xaTransIn)) app.POST(BusiAPI+"/TransOutXa", common.WrapHandler(xaTransOut)) config.Mysql["database"] = "dtm_busi" - XaClient = dtmcli.NewXaClient(DtmServer, config.Mysql, app, Busi+"/xa") + var err error + XaClient, err = dtmcli.NewXaClient(DtmServer, config.Mysql, app, Busi+"/xa") + e2p(err) } func xaTransIn(c *gin.Context) (interface{}, error) { err := XaClient.XaLocalTransaction(c, func(db *common.DB, xa *dtmcli.Xa) (rerr error) { req := reqFrom(c) if req.TransInResult != "SUCCESS" { - return fmt.Errorf("tranIn failed") + return fmt.Errorf("tranIn FAILURE") } dbr := db.Exec("update user_account set balance=balance+? where user_id=?", req.Amount, 2) return dbr.Error }) + if err != nil && strings.Contains(err.Error(), "FAILURE") { + return M{"dtm_result": "FAILURE"}, nil + } e2p(err) return M{"dtm_result": "SUCCESS"}, nil } diff --git a/examples/quick_start.go b/examples/quick_start.go index 57ac94d..a2f268a 100644 --- a/examples/quick_start.go +++ b/examples/quick_start.go @@ -31,7 +31,7 @@ func QsStartSvr() { func QsFireRequest() string { req := &gin.H{"amount": 30} // 微服务的载荷 // DtmServer为DTM服务的地址 - saga := dtmcli.NewSaga(DtmServer, dtmcli.GenGid(DtmServer)). + saga := dtmcli.NewSaga(DtmServer, dtmcli.MustGenGid(DtmServer)). // 添加一个TransOut的子事务,正向操作为url: qsBusi+"/TransOut", 逆向操作为url: qsBusi+"/TransOutCompensate" Add(qsBusi+"/TransOut", qsBusi+"/TransOutCompensate", req). // 添加一个TransIn的子事务,正向操作为url: qsBusi+"/TransOut", 逆向操作为url: qsBusi+"/TransInCompensate"