wip: version cmd/flag use common functions

This commit is contained in:
Suyono 2024-01-11 11:47:50 +11:00
parent fe465ad031
commit 6a68209629
4 changed files with 133 additions and 100 deletions

38
cmd/cli/splitargs.go Normal file
View File

@ -0,0 +1,38 @@
package cli
import (
"errors"
"os"
)
func SplitArgs() ([]string, []string, error) {
var (
i int
arg string
selfArgs []string
childArgs []string
)
found := false
for i, arg = range os.Args {
if arg == "--" {
found = true
if i+1 == len(os.Args) {
return nil, nil, errors.New("invalid argument")
}
if len(os.Args[i+1:]) == 0 {
return nil, nil, errors.New("invalid argument")
}
selfArgs = os.Args[1:i]
childArgs = os.Args[i+1:]
break
}
if !found {
return nil, nil, errors.New("invalid argument")
}
}
return selfArgs, childArgs, nil
}

38
cmd/cli/version.go Normal file
View File

@ -0,0 +1,38 @@
package cli
import (
"fmt"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"os"
)
type Version string
const versionFlag = "version"
func (v Version) Print() {
fmt.Print(v)
os.Exit(0)
}
func (v Version) Cmd(cmd *cobra.Command) {
cmd.AddCommand(&cobra.Command{
Use: "version",
RunE: func(cmd *cobra.Command, args []string) error {
v.Print()
return nil
},
})
}
func (v Version) Flag(cmd *cobra.Command) {
cmd.PersistentFlags().Bool(versionFlag, false, "print version")
_ = viper.BindPFlag(versionFlag, cmd.PersistentFlags().Lookup(versionFlag))
}
func (v Version) FlagHook() {
if viper.GetBool(versionFlag) {
v.Print()
}
}

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"gitea.suyono.dev/suyono/wingmate" "gitea.suyono.dev/suyono/wingmate"
"gitea.suyono.dev/suyono/wingmate/cmd/cli"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@ -18,6 +19,7 @@ import (
type execApp struct { type execApp struct {
childArgs []string childArgs []string
err error err error
version cli.Version
} }
const ( const (
@ -36,15 +38,16 @@ var (
func main() { func main() {
var ( var (
selfArgs []string selfArgs []string
childArgs []string childArgs []string
app *execApp app *execApp
rootCmd *cobra.Command rootCmd *cobra.Command
versionCmd *cobra.Command err error
err error
) )
app = &execApp{} app = &execApp{
version: cli.Version(version),
}
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "wmexec", Use: "wmexec",
@ -52,19 +55,13 @@ func main() {
RunE: app.execCmd, RunE: app.execCmd,
} }
versionCmd = &cobra.Command{
Use: "version",
RunE: app.versionCmd,
}
rootCmd.PersistentFlags().BoolP(setsidFlag, "s", false, "set to true to run setsid() before exec") rootCmd.PersistentFlags().BoolP(setsidFlag, "s", false, "set to true to run setsid() before exec")
viper.BindPFlag(EnvSetsid, rootCmd.PersistentFlags().Lookup(setsidFlag)) viper.BindPFlag(EnvSetsid, rootCmd.PersistentFlags().Lookup(setsidFlag))
rootCmd.PersistentFlags().StringP(userFlag, "u", "", "\"user:[group]\"") rootCmd.PersistentFlags().StringP(userFlag, "u", "", "\"user:[group]\"")
viper.BindPFlag(EnvUser, rootCmd.PersistentFlags().Lookup(userFlag)) viper.BindPFlag(EnvUser, rootCmd.PersistentFlags().Lookup(userFlag))
rootCmd.PersistentFlags().Bool(versionFlag, false, "print version") app.version.Flag(rootCmd)
viper.BindPFlag(versionFlag, rootCmd.PersistentFlags().Lookup(versionFlag))
viper.SetEnvPrefix(wingmate.EnvPrefix) viper.SetEnvPrefix(wingmate.EnvPrefix)
viper.BindEnv(EnvUser) viper.BindEnv(EnvUser)
@ -72,11 +69,12 @@ func main() {
viper.SetDefault(EnvSetsid, false) viper.SetDefault(EnvSetsid, false)
viper.SetDefault(EnvUser, "") viper.SetDefault(EnvUser, "")
rootCmd.AddCommand(versionCmd) app.version.Cmd(rootCmd)
selfArgs, childArgs, err = argSplit() selfArgs, childArgs, err = cli.SplitArgs()
app.childArgs = childArgs app.childArgs = childArgs
app.err = err app.err = err
rootCmd.SetArgs(selfArgs) rootCmd.SetArgs(selfArgs)
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
log.Println(err) log.Println(err)
@ -84,51 +82,8 @@ func main() {
} }
} }
func argSplit() ([]string, []string, error) {
var (
i int
arg string
selfArgs []string
childArgs []string
)
found := false
for i, arg = range os.Args {
if arg == "--" {
found = true
if i+1 == len(os.Args) {
return nil, nil, errors.New("invalid argument")
}
if len(os.Args[i+1:]) == 0 {
return nil, nil, errors.New("invalid argument")
}
selfArgs = os.Args[1:i]
childArgs = os.Args[i+1:]
break
}
if !found {
return nil, nil, errors.New("invalid argument")
}
}
return selfArgs, childArgs, nil
}
func (e *execApp) versionCmd(cmd *cobra.Command, args []string) error {
e.printVersion()
return nil
}
func (e *execApp) printVersion() {
fmt.Print(version)
os.Exit(0)
}
func (e *execApp) execCmd(cmd *cobra.Command, args []string) error { func (e *execApp) execCmd(cmd *cobra.Command, args []string) error {
if viper.GetBool(versionFlag) { e.version.FlagHook()
e.printVersion()
}
if e.err != nil { if e.err != nil {
return e.err return e.err

View File

@ -3,6 +3,7 @@ package main
import ( import (
"bufio" "bufio"
"errors" "errors"
"gitea.suyono.dev/suyono/wingmate/cmd/cli"
"log" "log"
"os" "os"
"os/exec" "os/exec"
@ -15,8 +16,16 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
_ "embed"
) )
type pidProxyApp struct {
childArgs []string
err error
version cli.Version
}
const ( const (
pidFileFlag = "pid-file" pidFileFlag = "pid-file"
EnvStartSecs = "STARTSECS" EnvStartSecs = "STARTSECS"
@ -24,22 +33,30 @@ const (
) )
var ( var (
rootCmd = &cobra.Command{
Use: "wmpidproxy",
RunE: pidProxy,
}
childArgs []string //go:embed version.txt
version string
) )
func main() { func main() {
var ( var (
i int selfArgs []string
arg string childArgs []string
selfArgs []string err error
found bool app *pidProxyApp
rootCmd *cobra.Command
) )
app = &pidProxyApp{
version: cli.Version(version),
}
rootCmd = &cobra.Command{
Use: "wmpidproxy",
SilenceUsage: true,
RunE: app.pidProxy,
}
viper.SetEnvPrefix(wingmate.EnvPrefix) viper.SetEnvPrefix(wingmate.EnvPrefix)
viper.BindEnv(EnvStartSecs) viper.BindEnv(EnvStartSecs)
viper.SetDefault(EnvStartSecs, EnvDefaultStartSecs) viper.SetDefault(EnvStartSecs, EnvDefaultStartSecs)
@ -48,44 +65,29 @@ func main() {
rootCmd.MarkFlagRequired(pidFileFlag) rootCmd.MarkFlagRequired(pidFileFlag)
viper.BindPFlag(pidFileFlag, rootCmd.PersistentFlags().Lookup(pidFileFlag)) viper.BindPFlag(pidFileFlag, rootCmd.PersistentFlags().Lookup(pidFileFlag))
found = false app.version.Flag(rootCmd)
for i, arg = range os.Args { app.version.Cmd(rootCmd)
if arg == "--" {
found = true
if len(os.Args) <= i+1 {
log.Println("invalid argument")
os.Exit(1)
}
selfArgs = os.Args[1:i]
childArgs = os.Args[i+1:]
break
}
}
if !found {
log.Println("invalid argument")
os.Exit(1)
}
if len(childArgs) == 0 { selfArgs, childArgs, err = cli.SplitArgs()
log.Println("invalid argument") app.childArgs = childArgs
os.Exit(1) app.err = err
}
rootCmd.SetArgs(selfArgs) rootCmd.SetArgs(selfArgs)
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
log.Println(err) log.Println(err)
os.Exit(1) os.Exit(1)
} }
} }
func pidProxy(cmd *cobra.Command, args []string) error { func (p *pidProxyApp) pidProxy(cmd *cobra.Command, args []string) error {
p.version.FlagHook()
pidfile := viper.GetString(pidFileFlag) pidfile := viper.GetString(pidFileFlag)
log.Printf("%s %v", pidfile, childArgs) log.Printf("%s %v", pidfile, p.childArgs)
if len(childArgs) > 1 { if len(p.childArgs) > 1 {
go startProcess(childArgs[0], childArgs[1:]...) go p.startProcess(p.childArgs[0], p.childArgs[1:]...)
} else { } else {
go startProcess(childArgs[0]) go p.startProcess(p.childArgs[0])
} }
initialWait := viper.GetInt(EnvStartSecs) initialWait := viper.GetInt(EnvStartSecs)
time.Sleep(time.Second * time.Duration(initialWait)) time.Sleep(time.Second * time.Duration(initialWait))
@ -104,7 +106,7 @@ func pidProxy(cmd *cobra.Command, args []string) error {
check: check:
for { for {
if pid, err = readPid(pidfile); err != nil { if pid, err = p.readPid(pidfile); err != nil {
return err return err
} }
@ -115,7 +117,7 @@ check:
select { select {
case <-t.C: case <-t.C:
case <-sc: case <-sc:
if pid, err = readPid(pidfile); err != nil { if pid, err = p.readPid(pidfile); err != nil {
return err return err
} }
@ -128,7 +130,7 @@ check:
return nil return nil
} }
func readPid(pidFile string) (int, error) { func (p *pidProxyApp) readPid(pidFile string) (int, error) {
var ( var (
file *os.File file *os.File
err error err error
@ -153,7 +155,7 @@ func readPid(pidFile string) (int, error) {
} }
} }
func startProcess(arg0 string, args ...string) { func (p *pidProxyApp) startProcess(arg0 string, args ...string) {
if err := exec.Command(arg0, args...).Run(); err != nil { if err := exec.Command(arg0, args...).Run(); err != nil {
log.Println("exec:", err) log.Println("exec:", err)
return return