diff --git a/app/main.go b/app/main.go index cc349d4..92fd312 100644 --- a/app/main.go +++ b/app/main.go @@ -1,7 +1,9 @@ package main import ( + "fmt" "os" + "strings" "time" "github.com/yedf/dtm/dtmcli" @@ -19,6 +21,25 @@ func wait() { } func main() { + if len(os.Args) == 1 { + for _, ln := range []string{ + "dtm is a lightweight distributed transaction manager.", + "usage:", + "dtm [command]", + "", + "Available Commands:", + "dtmsvr run dtm as a server. ", + "", + "quick_start run quick start example. dtm will create all needed table", + "qs same as quick_start", + } { + fmt.Print(ln + "\n") + } + for name := range examples.Samples { + fmt.Printf("%-18srun a sample includes %s\n", name, strings.Replace(name, "_", " ", 100)) + } + return + } onlyServer := len(os.Args) > 1 && os.Args[1] == "dtmsvr" if !onlyServer { // 实际线上运行,只启动dtmsvr,不准备table相关的数据 dtmsvr.PopulateDB(true) @@ -43,8 +64,8 @@ func main() { examples.GrpcStartup() examples.BaseAppStartup() - fn := examples.Samples[os.Args[1]] - dtmcli.LogIfFatalf(fn == nil, "no sample name for %s", os.Args[1]) - fn() + sample := examples.Samples[os.Args[1]] + dtmcli.LogIfFatalf(sample == nil, "no sample name for %s", os.Args[1]) + sample.Action() wait() } diff --git a/common/types.go b/common/types.go index 65bf92b..6e26204 100644 --- a/common/types.go +++ b/common/types.go @@ -131,6 +131,9 @@ type dtmConfigType struct { var DtmConfig = dtmConfigType{} func init() { + if len(os.Args) == 1 { + return + } DtmConfig.TransCronInterval = int64(dtmcli.MustAtoi(dtmcli.OrString(os.Getenv("TRANS_CRON_INTERVAL"), "10"))) DtmConfig.DB = map[string]string{ "driver": dtmcli.OrString(os.Getenv("DB_DRIVER"), "mysql"), diff --git a/examples/data.go b/examples/data.go index 9313e81..ac62da1 100644 --- a/examples/data.go +++ b/examples/data.go @@ -37,12 +37,16 @@ func PopulateDB(skipDrop bool) { RunSQLScript(config.DB, file, skipDrop) } +type sampleInfo struct { + Arg string + Action func() string + Desc string +} + // Samples 所有的示例都会注册到这里 -var Samples = map[string]func() string{} +var Samples = map[string]*sampleInfo{} func addSample(name string, fn func() string) { - if Samples[name] != nil { - dtmcli.LogFatalf("sample %s duplicated", name) - } - Samples[name] = fn + dtmcli.LogIfFatalf(Samples[name] != nil, "%s already exists") + Samples[name] = &sampleInfo{Arg: name, Action: fn} } diff --git a/test/examples_test.go b/test/examples_test.go index 7b3e044..059f803 100644 --- a/test/examples_test.go +++ b/test/examples_test.go @@ -9,7 +9,7 @@ import ( func TestExamples(t *testing.T) { // for coverage examples.QsStartSvr() - for _, fn := range examples.Samples { - assertSucceed(t, fn()) + for _, s := range examples.Samples { + assertSucceed(t, s.Action()) } }