test-client: good for test

This commit is contained in:
Suyono 2025-04-24 10:27:59 +10:00
parent 024e4cba4f
commit 4abc00f07c

View File

@ -1,17 +1,5 @@
package main package main
import (
"fmt"
"net"
"time"
"gitea.suyono.dev/suyono/netbounce/cmd/slicewriter"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/pflag"
"github.com/spf13/viper"
)
/* /*
Copyright 2025 Suyono <suyono3484@gmail.com> Copyright 2025 Suyono <suyono3484@gmail.com>
@ -28,25 +16,54 @@ import (
limitations under the License. limitations under the License.
*/ */
import (
"context"
"fmt"
"golang.org/x/sys/unix"
"net"
"os/signal"
"time"
"gitea.suyono.dev/suyono/netbounce/cmd/slicewriter"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/pflag"
"github.com/spf13/viper"
)
var gLimit *counter
func main() { func main() {
ctx, cancel := signal.NotifyContext(context.Background(), unix.SIGINT, unix.SIGTERM)
defer cancel()
parseFlags() parseFlags()
log.Debug().Msgf("Sending messages to %s server: %s", viper.GetString("protocol"), viper.GetString("server")) log.Debug().Msgf("Sending messages to %s server: %s", viper.GetString(PROTOCOL), viper.GetString(SERVER))
sendMessages() sendMessages(ctx)
} }
func sendMessages() { const (
switch viper.GetString("protocol") { SERVER = "server"
PROTOCOL = "protocol"
UDP = "udp"
TCP = "tcp"
NAME = "name"
MESSAGE = "message"
SLEEP = "sleep"
)
func sendMessages(ctx context.Context) {
switch viper.GetString(PROTOCOL) {
case "udp": case "udp":
sendUDP() sendUDP(ctx)
case "tcp": case "tcp":
sendTCP() sendTCP(ctx)
default: default:
log.Fatal().Str("protocol", viper.GetString("protocol")).Msg("Unknown protocol") log.Fatal().Str(PROTOCOL, viper.GetString(PROTOCOL)).Msg("Unknown protocol")
} }
} }
func sendTCP() { func sendTCP(ctx context.Context) {
var ( var (
conn net.Conn conn net.Conn
err error err error
@ -54,93 +71,116 @@ func sendTCP() {
n int n int
) )
if conn, err = net.Dial("tcp", viper.GetString("server")); err != nil { if conn, err = net.Dial(TCP, viper.GetString(SERVER)); err != nil {
log.Fatal().Err(err).Msg("Failed to connect to server") log.Fatal().Err(err).Msg("Failed to connect to server")
} }
defer func() { defer func() {
_ = conn.Close() _ = conn.Close()
}() }()
go func() {
<-ctx.Done()
_ = conn.Close()
}()
buf = make([]byte, 4096) buf = make([]byte, 4096)
for { for gLimit.isContinue(ctx) {
sb := slicewriter.NewSliceWriter(buf) sb := slicewriter.NewSliceWriter(buf)
if _, err = fmt.Fprintf(sb, "client %s | %v | %s", viper.GetString("name"), time.Now(), viper.GetString("message")); err != nil { if _, err = fmt.Fprintf(sb, "client %s | %v | %s", viper.GetString(NAME), time.Now(), viper.GetString(MESSAGE)); err != nil {
log.Fatal().Err(err).Msg("Failed to build client message") log.Fatal().Err(err).Msg("Failed to build client message")
} }
b = sb.Bytes() b = sb.Bytes()
if _, err = conn.Write(b); err != nil { if _, err = conn.Write(b); err != nil {
log.Fatal().Err(err).Msg("Failed to send client message") log.Fatal().Err(err).Str(PROTOCOL, TCP).Str(SERVER, viper.GetString(SERVER)).Msg("Failed to send client message")
} }
if n, err = conn.Read(buf); err != nil { if n, err = conn.Read(buf); err != nil {
log.Fatal().Err(err).Str("server", conn.RemoteAddr().String()).Msg("read from the server") log.Fatal().Err(err).Str(PROTOCOL, TCP).Str(SERVER, viper.GetString(SERVER)).Msg("read from the server")
} }
log.Info().Str("server", conn.RemoteAddr().String()).Msgf("%s", buf[:n]) log.Info().Str(PROTOCOL, TCP).Str(SERVER, viper.GetString(SERVER)).Msgf("%s", buf[:n])
time.Sleep(viper.GetDuration("sleep")) time.Sleep(viper.GetDuration(SLEEP))
} }
} }
func sendUDP() { func sendUDP(ctx context.Context) {
var ( var (
addr *net.UDPAddr addr *net.UDPAddr
conn *net.UDPConn conn *net.UDPConn
rAddr net.Addr
err error err error
buf, b []byte buf, b []byte
n int
) )
if addr, err = net.ResolveUDPAddr("udp", viper.GetString("server")); err != nil { if addr, err = net.ResolveUDPAddr(UDP, viper.GetString(SERVER)); err != nil {
log.Fatal().Err(err).Str("server", viper.GetString("server")).Msg("udp resolve address") log.Fatal().Err(err).Str(SERVER, viper.GetString(SERVER)).Msg("udp resolve address")
} }
if conn, err = net.DialUDP("udp", nil, addr); err != nil { if conn, err = net.DialUDP(UDP, nil, addr); err != nil {
log.Fatal().Err(err).Str("server", viper.GetString("server")).Msg("dial server udp") log.Fatal().Err(err).Str("server", viper.GetString(SERVER)).Msg("dial server udp")
} }
defer func() {
_ = conn.Close()
}()
go func() {
<-ctx.Done()
_ = conn.Close()
}()
buf = make([]byte, 4096) buf = make([]byte, 4096)
for { for gLimit.isContinue(ctx) {
sb := slicewriter.NewSliceWriter(buf) sb := slicewriter.NewSliceWriter(buf)
if _, err = fmt.Fprintf(sb, "client %s | %v | %s", viper.GetString("name"), time.Now(), viper.GetString("message")); err != nil { if _, err = fmt.Fprintf(sb, "client %s | %v | %s", viper.GetString(NAME), time.Now(), viper.GetString(MESSAGE)); err != nil {
log.Fatal().Err(err).Msg("Failed to build client message") log.Fatal().Err(err).Msg("Failed to build client message")
} }
b = sb.Bytes() b = sb.Bytes()
if _, err = conn.WriteTo(b, addr); err != nil {
log.Fatal().Err(err).Str(PROTOCOL, UDP).Str(SERVER, viper.GetString(SERVER)).Msg("Failed to send client message")
}
if n, rAddr, err = conn.ReadFrom(b); err != nil {
log.Fatal().Err(err).Str(PROTOCOL, UDP).Str(SERVER, viper.GetString(SERVER)).Msg("read from server")
}
log.Info().Str(PROTOCOL, UDP).Str(SERVER, rAddr.String()).Msgf("%s", buf[:n])
} }
//TODO: complete the implementation of sendUDP
} }
func parseFlags() { func parseFlags() {
pflag.String("name", "", "client name") pflag.String(NAME, "", "client name")
pflag.IntP("number", "n", 5, "number of messages to send; default 5; set to 0 for infinite") pflag.IntP("number", "n", 5, "number of messages to send; default 5; set to 0 for infinite")
pflag.DurationP("sleep", "s", 10*time.Millisecond, "sleep time between requests; default 10ms") pflag.DurationP(SLEEP, "s", 10*time.Millisecond, "sleep time between requests; default 10ms")
pflag.StringP("message", "m", "message from client", "message to send") pflag.StringP(MESSAGE, "m", "message from client", "message to send")
pflag.Bool("debug", false, "run in debug mode") pflag.Bool("debug", false, "run in debug mode")
pflag.String("tcp", "", "tcp server address") pflag.String(TCP, "", "tcp server address")
pflag.String("udp", "", "udp server address") pflag.String(UDP, "", "udp server address")
pflag.Parse() pflag.Parse()
_ = viper.BindPFlags(pflag.CommandLine) _ = viper.BindPFlags(pflag.CommandLine)
if !viper.IsSet("name") { gLimit = makeCounter(viper.GetInt("number"))
if !viper.IsSet(NAME) {
log.Fatal().Msg("server name is required") log.Fatal().Msg("server name is required")
} }
if viper.IsSet("tcp") && viper.IsSet("udp") { if viper.IsSet(TCP) && viper.IsSet(UDP) {
log.Fatal().Msg("cannot use tcp and udp at once") log.Fatal().Msg("cannot use tcp and udp at once")
} }
if !viper.IsSet("tcp") && !viper.IsSet("udp") { if !viper.IsSet(TCP) && !viper.IsSet(UDP) {
log.Fatal().Msg("server address is required") log.Fatal().Msg("server address is required")
} }
if viper.IsSet("tcp") { if viper.IsSet(TCP) {
viper.Set("protocol", "tcp") viper.Set(PROTOCOL, TCP)
viper.Set("server", viper.GetString("tcp")) viper.Set(SERVER, viper.GetString(TCP))
} }
if viper.IsSet("udp") { if viper.IsSet(UDP) {
viper.Set("protocol", "udp") viper.Set(PROTOCOL, UDP)
viper.Set("server", viper.GetString("udp")) viper.Set(SERVER, viper.GetString(UDP))
} }
zerolog.SetGlobalLevel(zerolog.InfoLevel) zerolog.SetGlobalLevel(zerolog.InfoLevel)
@ -148,3 +188,34 @@ func parseFlags() {
zerolog.SetGlobalLevel(zerolog.DebugLevel) zerolog.SetGlobalLevel(zerolog.DebugLevel)
} }
} }
type counter struct {
limit, tick int
}
func makeCounter(limit int) *counter {
if limit <= 0 {
log.Fatal().Msg("number must be > 0")
}
return &counter{limit: limit, tick: 0}
}
func (c *counter) isContinue(ctx context.Context) bool {
select {
case <-ctx.Done():
return false
default:
}
if c.limit == 0 {
return true
}
c.tick++
if c.tick < c.limit {
return true
}
return false
}