package main /* Copyright 2025 Suyono Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ import ( "context" "fmt" "gitea.suyono.dev/suyono/netbounce/cmd/slicewriter" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/spf13/pflag" "github.com/spf13/viper" "golang.org/x/sys/unix" "net" "os/signal" "sync" "time" ) func main() { //zerolog.TimeFieldFormat = zerolog.TimeFormatUnix parseFlags() log.Debug().Msg("Starting server") ctx, cancel := signal.NotifyContext(context.Background(), unix.SIGINT, unix.SIGTERM) defer cancel() wg := &sync.WaitGroup{} openPorts(wg, ctx) wg.Wait() } func parseFlags() { pflag.String("name", "", "server name") pflag.Bool("debug", false, "run in debug mode") pflag.StringSlice("tcp", []string{}, "tcp listen address") pflag.StringSlice("udp", []string{}, "udp listen address") pflag.Parse() _ = viper.BindPFlags(pflag.CommandLine) if !viper.IsSet("name") { log.Fatal().Msg("server name is required") } zerolog.SetGlobalLevel(zerolog.InfoLevel) if viper.GetBool("debug") { zerolog.SetGlobalLevel(zerolog.DebugLevel) } } func openPorts(wg *sync.WaitGroup, ctx context.Context) { tcpPorts := viper.GetStringSlice("tcp") for _, tcpPort := range tcpPorts { wg.Add(1) go listen(ctx, wg, tcpPort) } udpPorts := viper.GetStringSlice("udp") for _, udpPort := range udpPorts { wg.Add(1) go bindUDP(ctx, wg, udpPort) } } func ClosePacket(ctx context.Context, conn net.PacketConn) { <-ctx.Done() _ = conn.Close() } func bindUDP(ctx context.Context, wg *sync.WaitGroup, address string) { defer wg.Done() var ( conn net.PacketConn err error buf, b []byte n int addr net.Addr ) log.Debug().Str("address", address).Msg("binding socket for UDP") if conn, err = net.ListenPacket("udp", address); err != nil { panic(fmt.Errorf("failed to bind udp address: %v", err)) } go ClosePacket(ctx, conn) buf = make([]byte, 4096) udpLoop: for { if n, addr, err = conn.ReadFrom(buf); err != nil && n == 0 { select { case <-ctx.Done(): break udpLoop default: } log.Error().Err(err).Msg("failed to read packet") continue udpLoop } log.Info().Str("client", addr.String()).Msgf("received message: %s", buf[:n]) select { case <-ctx.Done(): break udpLoop default: } sb := slicewriter.NewSliceWriter(buf) if _, err = fmt.Fprintf(sb, "server: %s | UDP | %v", viper.GetString("name"), time.Now()); err != nil { log.Error().Err(err).Msg("build server message") } b = sb.Bytes() if n, err = conn.WriteTo(b, addr); err != nil { select { case <-ctx.Done(): break udpLoop default: } log.Error().Err(err).Str("client", addr.String()).Msg("failed to write packet") continue udpLoop } if n != len(b) { log.Debug().Str("client", addr.String()).Msg("incomplete packet sent") } log.Info().Str("client", addr.String()).Msg("packet received and replied") } } func CloseListener(ctx context.Context, listener net.Listener) { <-ctx.Done() _ = listener.Close() } func listen(ctx context.Context, wg *sync.WaitGroup, address string) { defer wg.Done() var ( listener net.Listener err error conn net.Conn ) if listener, err = net.Listen("tcp", address); err != nil { log.Error().Err(err).Str("address", address).Msg("failed to listen") return } go CloseListener(ctx, listener) tcpIncoming: for { if conn, err = listener.Accept(); err != nil { select { case <-ctx.Done(): break tcpIncoming default: } log.Error().Err(err).Str("address", address).Msg("failed to accept connection") continue tcpIncoming } wg.Add(1) go handleTCP(ctx, wg, conn) } } func CloseConnection(ctx context.Context, conn net.Conn) { <-ctx.Done() _ = conn.Close() } func handleTCP(ctx context.Context, wg *sync.WaitGroup, conn net.Conn) { defer wg.Done() defer func() { _ = conn.Close() }() var ( buf, b []byte err error n int ) buf = make([]byte, 4096) addr := conn.RemoteAddr() cctx, cancel := context.WithCancel(ctx) go CloseConnection(cctx, conn) defer cancel() for { if n, err = conn.Read(buf); err != nil { log.Error().Err(err).Str("client", addr.String()).Msg("failed to read data from TCP connection") return } log.Info().Str("client", addr.String()).Msgf("received message: %s", buf[:n]) sb := slicewriter.NewSliceWriter(buf) if _, err = fmt.Fprintf(sb, "server: %s | TCP | %v", viper.GetString("name"), time.Now()); err != nil { log.Error().Err(err).Msg("build server message") } b = sb.Bytes() if _, err = conn.Write(b); err != nil { log.Error().Err(err).Str("client", addr.String()).Msg("failed to write data to TCP connection") return } } }