183 lines
4.7 KiB
Go
183 lines
4.7 KiB
Go
package dtmcli
|
||
|
||
import (
|
||
"database/sql"
|
||
"errors"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/go-resty/resty/v2"
|
||
)
|
||
|
||
// M is a short alias for a string-keyed map whose values may be of any type.
type M = map[string]interface{}
|
||
|
||
// MS is a short alias for a map from string to string.
type MS = map[string]string
|
||
|
||
// MustGenGid generate a new gid
|
||
func MustGenGid(server string) string {
|
||
res := MS{}
|
||
resp, err := RestyClient.R().SetResult(&res).Get(server + "/newGid")
|
||
if err != nil || res["gid"] == "" {
|
||
panic(fmt.Errorf("newGid error: %v, resp: %s", err, resp))
|
||
}
|
||
return res["gid"]
|
||
}
|
||
|
||
// sqlDbs caches opened *sql.DB handles keyed by their DSN string.
// NOTE(review): the map is read and written without any locking in this
// file — confirm that callers never use it from multiple goroutines.
var sqlDbs = map[string]*sql.DB{}
|
||
|
||
// SdbGet get pooled sql.DB
|
||
func SdbGet(conf map[string]string) *sql.DB {
|
||
dsn := GetDsn(conf)
|
||
if sqlDbs[dsn] == nil {
|
||
sqlDbs[dsn] = SdbAlone(conf)
|
||
}
|
||
return sqlDbs[dsn]
|
||
}
|
||
|
||
// SdbAlone get a standalone db connection
|
||
func SdbAlone(conf map[string]string) *sql.DB {
|
||
dsn := GetDsn(conf)
|
||
Logf("opening alone %s: %s", conf["driver"], strings.Replace(dsn, conf["password"], "****", 1))
|
||
mdb, err := sql.Open(conf["driver"], dsn)
|
||
E2P(err)
|
||
return mdb
|
||
}
|
||
|
||
// SdbExec use raw db to exec
|
||
func SdbExec(db *sql.DB, sql string, values ...interface{}) (affected int64, rerr error) {
|
||
r, rerr := db.Exec(sql, values...)
|
||
if rerr == nil {
|
||
affected, rerr = r.RowsAffected()
|
||
Logf("affected: %d for %s %v", affected, sql, values)
|
||
} else {
|
||
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
||
}
|
||
return
|
||
}
|
||
|
||
// StxExec use raw tx to exec
|
||
func StxExec(tx *sql.Tx, sql string, values ...interface{}) (affected int64, rerr error) {
|
||
r, rerr := tx.Exec(sql, values...)
|
||
if rerr == nil {
|
||
affected, rerr = r.RowsAffected()
|
||
Logf("affected: %d for %s %v", affected, sql, values)
|
||
} else {
|
||
LogRedf("exec error: %v for %s %v", rerr, sql, values)
|
||
}
|
||
return
|
||
}
|
||
|
||
// StxQueryRow use raw tx to query row
|
||
func StxQueryRow(tx *sql.Tx, query string, args ...interface{}) *sql.Row {
|
||
Logf("querying: "+query, args...)
|
||
return tx.QueryRow(query, args...)
|
||
}
|
||
|
||
// GetDsn get dsn from map config
|
||
func GetDsn(conf map[string]string) string {
|
||
conf["host"] = MayReplaceLocalhost(conf["host"])
|
||
driver := conf["driver"]
|
||
dsn := MS{
|
||
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||
conf["user"], conf["password"], conf["host"], conf["port"], conf["database"]),
|
||
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%s sslmode=disable TimeZone=Asia/Shanghai",
|
||
conf["host"], conf["user"], conf["password"], conf["database"], conf["port"]),
|
||
}[driver]
|
||
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
|
||
return dsn
|
||
}
|
||
|
||
// CheckResponse 检查Response,返回错误
|
||
func CheckResponse(resp *resty.Response, err error) error {
|
||
if err == nil && resp != nil {
|
||
if resp.IsError() {
|
||
return errors.New(resp.String())
|
||
} else if strings.Contains(resp.String(), "FAILURE") {
|
||
return ErrFailure
|
||
}
|
||
}
|
||
return err
|
||
}
|
||
|
||
// CheckResult 检查Result,返回错误
|
||
func CheckResult(res interface{}, err error) error {
|
||
resp, ok := res.(*resty.Response)
|
||
if ok {
|
||
return CheckResponse(resp, err)
|
||
}
|
||
if res != nil && strings.Contains(MustMarshalString(res), "FAILURE") {
|
||
return ErrFailure
|
||
}
|
||
return err
|
||
}
|
||
|
||
// IDGenerator produces sequential branch ids underneath a parent id.
type IDGenerator struct {
	parentID string // prefix prepended to every generated branch id
	branchID int    // number of the most recently issued branch; 0 before the first call
}

// NewBranchID advances the internal counter and returns the parent id
// followed by the new counter value as two zero-padded digits.
//
// It panics when the counter has already reached 99, or when the parent id
// is 20 or more bytes long.
func (g *IDGenerator) NewBranchID() string {
	switch {
	case g.branchID >= 99:
		panic(fmt.Errorf("branch id is larger than 99"))
	case len(g.parentID) >= 20:
		panic(fmt.Errorf("total branch id is longer than 20"))
	}
	g.branchID++
	return fmt.Sprintf("%s%02d", g.parentID, g.branchID)
}
|
||
|
||
// TransResult is the result returned by dtm.
type TransResult struct {
	// DtmResult is decoded from the "dtm_result" JSON field.
	DtmResult string `json:"dtm_result"`
	// Message is free-form text accompanying the result.
	Message string
}
|
||
|
||
// TransBase is the base type for transactions.
type TransBase struct {
	// IDGenerator is embedded to supply branch id generation.
	IDGenerator
	// Dtm is the dtm server address used as the URL prefix for requests.
	Dtm string
	// WaitResult indicates whether to wait for the final result of the
	// global transaction.
	WaitResult bool
}
|
||
|
||
// TransBaseFromReq construct xa info from request
|
||
func TransBaseFromReq(c *gin.Context) *TransBase {
|
||
return &TransBase{
|
||
IDGenerator: IDGenerator{parentID: c.Query("branch_id")},
|
||
Dtm: c.Query("dtm"),
|
||
}
|
||
}
|
||
|
||
// CallDtm 调用dtm服务器,返回事务的状态
|
||
func (tb *TransBase) CallDtm(body interface{}, operation string) error {
|
||
params := MS{}
|
||
if tb.WaitResult {
|
||
params["wait_result"] = "1"
|
||
}
|
||
resp, err := RestyClient.R().SetQueryParams(params).
|
||
SetResult(&TransResult{}).SetBody(body).Post(fmt.Sprintf("%s/%s", tb.Dtm, operation))
|
||
if err != nil {
|
||
return err
|
||
}
|
||
tr := resp.Result().(*TransResult)
|
||
if tr.DtmResult == "FAILURE" {
|
||
return errors.New("FAILURE: " + tr.Message)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ErrFailure indicates that a failure was returned and a rollback is required.
var ErrFailure = errors.New("transaction FAILURE")

// ResultSuccess indicates success; the next step may proceed.
var ResultSuccess = M{"dtm_result": "SUCCESS"}

// ResultFailure indicates that a failure was returned and a rollback is required.
var ResultFailure = M{"dtm_result": "FAILURE"}
|