2025-04-21 18:37:11 +10:00

229 lines
5.1 KiB
Go

package main
/*
Copyright 2025 Suyono <suyono3484@gmail.com>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"context"
"fmt"
"gitea.suyono.dev/suyono/netbounce/cmd/slicewriter"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"golang.org/x/sys/unix"
"net"
"os/signal"
"sync"
"time"
)
// main parses the command line, starts every configured TCP/UDP endpoint and
// blocks until all of them have finished. SIGINT or SIGTERM cancels the shared
// context, which tells each endpoint goroutine to shut down. If no endpoint is
// configured, the wait returns immediately and the process exits.
func main() {
	//zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
	parseFlags()
	log.Debug().Msg("Starting server")

	var wg sync.WaitGroup
	ctx, stop := signal.NotifyContext(context.Background(), unix.SIGINT, unix.SIGTERM)
	defer stop()

	openPorts(&wg, ctx)
	wg.Wait()
}
// parseFlags registers and parses the command-line flags, binds them into
// viper and configures the global zerolog level.
//
// Flags:
//
//	--name   server name embedded in every reply (required, must be non-empty)
//	--debug  switch the global log level from info to debug
//	--tcp    TCP listen address(es); string slice flag
//	--udp    UDP listen address(es); string slice flag
//
// The process exits via log.Fatal if the flags cannot be bound or if the
// server name is missing or empty.
func parseFlags() {
	pflag.String("name", "", "server name")
	pflag.Bool("debug", false, "run in debug mode")
	pflag.StringSlice("tcp", []string{}, "tcp listen address")
	pflag.StringSlice("udp", []string{}, "udp listen address")
	pflag.Parse()

	if err := viper.BindPFlags(pflag.CommandLine); err != nil {
		log.Fatal().Err(err).Msg("failed to bind command-line flags")
	}

	// viper.IsSet alone would accept an explicit `--name=""`. An empty name is
	// as useless in the replies as a missing one, so reject both.
	if viper.GetString("name") == "" {
		log.Fatal().Msg("server name is required")
	}

	zerolog.SetGlobalLevel(zerolog.InfoLevel)
	if viper.GetBool("debug") {
		zerolog.SetGlobalLevel(zerolog.DebugLevel)
	}
}
// openPorts launches one goroutine per configured endpoint:
//   - listen for every address in the "tcp" setting;
//   - bindUDP for every address in the "udp" setting.
//
// Each goroutine is registered on wg before it starts and calls wg.Done when
// it exits. ctx is handed to each of them so they can observe shutdown.
func openPorts(wg *sync.WaitGroup, ctx context.Context) {
	start := func(key string, serve func(context.Context, *sync.WaitGroup, string)) {
		for _, address := range viper.GetStringSlice(key) {
			wg.Add(1)
			go serve(ctx, wg, address)
		}
	}

	start("tcp", listen)
	start("udp", bindUDP)
}
// ClosePacket blocks until ctx is done and then closes conn. Per the
// net.PacketConn documentation, closing unblocks any pending ReadFrom/WriteTo
// on the connection. The Close error is dropped because the socket is being
// torn down anyway.
func ClosePacket(ctx context.Context, conn net.PacketConn) {
	shutdown := func() { _ = conn.Close() }
	<-ctx.Done()
	shutdown()
}
// bindUDP binds a UDP socket on address and serves it until ctx is cancelled.
// For every datagram received it logs the payload and replies to the sender
// with "server: <name> | UDP | <timestamp>".
//
// wg.Done is called when the function returns. If the bind fails, the error is
// logged and the function returns, as listen does for TCP. Panicking here
// would crash the whole process, including every other listener.
func bindUDP(ctx context.Context, wg *sync.WaitGroup, address string) {
	defer wg.Done()

	log.Debug().Str("address", address).Msg("binding socket for UDP")
	conn, err := net.ListenPacket("udp", address)
	if err != nil {
		log.Error().Err(err).Str("address", address).Msg("failed to bind udp address")
		return
	}
	// Closing conn on cancellation is what unblocks ReadFrom/WriteTo below.
	go ClosePacket(ctx, conn)

	buf := make([]byte, 4096)
	for {
		n, addr, err := conn.ReadFrom(buf)
		// A read may report n > 0 together with an error. Only drop the
		// datagram when there is nothing usable: no payload, or no sender
		// address to reply to (a nil addr would panic on addr.String()).
		if err != nil && (n == 0 || addr == nil) {
			if ctx.Err() != nil {
				return
			}
			log.Error().Err(err).Msg("failed to read packet")
			continue
		}

		client := addr.String()
		log.Info().Str("client", client).Msgf("received message: %s", buf[:n])
		if ctx.Err() != nil {
			return
		}

		// NOTE(review): the reply writer is handed buf; presumably it reuses
		// that storage, which is fine because the request was already logged.
		sb := slicewriter.NewSliceWriter(buf)
		if _, err = fmt.Fprintf(sb, "server: %s | UDP | %v", viper.GetString("name"), time.Now()); err != nil {
			log.Error().Err(err).Msg("build server message")
		}
		reply := sb.Bytes()

		written, err := conn.WriteTo(reply, addr)
		if err != nil {
			if ctx.Err() != nil {
				return
			}
			log.Error().Err(err).Str("client", client).Msg("failed to write packet")
			continue
		}
		if written != len(reply) {
			log.Debug().Str("client", client).Msg("incomplete packet sent")
		}
		log.Info().Str("client", client).Msg("packet received and replied")
	}
}
// CloseListener waits for ctx to be done and then closes listener. Per the
// net.Listener documentation, closing makes a blocked Accept return an error,
// which lets the accept loop exit. The Close error is intentionally ignored.
func CloseListener(ctx context.Context, listener net.Listener) {
	defer func() {
		_ = listener.Close()
	}()
	<-ctx.Done()
}
// listen opens a TCP listener on address and accepts connections until ctx is
// cancelled. Each accepted connection is handed to handleTCP in its own
// goroutine, registered on wg. wg.Done is called for this goroutine on return.
// If the listen itself fails, the error is logged and the function returns.
//
// Accept errors that are not caused by shutdown (for example, running out of
// file descriptors) are retried with an exponential back-off. The back-off
// starts at 5ms and is capped at 1s. Without it, a persistent failure turns
// the loop into a busy spin that floods the log. The back-off wait is
// abandoned as soon as ctx is cancelled.
func listen(ctx context.Context, wg *sync.WaitGroup, address string) {
	defer wg.Done()

	listener, err := net.Listen("tcp", address)
	if err != nil {
		log.Error().Err(err).Str("address", address).Msg("failed to listen")
		return
	}
	// Closing the listener on cancellation is what unblocks Accept below.
	go CloseListener(ctx, listener)

	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)
	var backoff time.Duration

	for {
		conn, err := listener.Accept()
		if err != nil {
			if ctx.Err() != nil {
				return
			}

			backoff *= 2
			if backoff == 0 {
				backoff = minBackoff
			}
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			log.Error().Err(err).Str("address", address).Dur("retry_in", backoff).Msg("failed to accept connection")

			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return
			case <-timer.C:
			}
			continue
		}

		backoff = 0
		wg.Add(1)
		go handleTCP(ctx, wg, conn)
	}
}
// CloseConnection blocks until ctx is done and then closes conn. Per the
// net.Conn documentation, closing unblocks any Read or Write in progress.
// handleTCP also closes conn when it returns, so this may be a second Close.
// Its error carries nothing useful and is dropped.
func CloseConnection(ctx context.Context, conn net.Conn) {
	cancelled := ctx.Done()
	<-cancelled
	_ = conn.Close()
}
// handleTCP serves a single accepted TCP connection. Every chunk read from the
// client is logged and answered on the same connection with
// "server: <name> | TCP | <timestamp>".
//
// It returns when a read or write fails (including the client closing the
// connection) or when ctx is cancelled. On return the connection is closed and
// wg.Done is called. Errors that are only a consequence of ctx cancellation
// are not logged, matching listen and bindUDP.
//
// NOTE(review): an orderly client disconnect (io.EOF) is still reported at
// error level.
func handleTCP(ctx context.Context, wg *sync.WaitGroup, conn net.Conn) {
	defer wg.Done()
	defer func() {
		_ = conn.Close()
	}()

	client := conn.RemoteAddr().String()

	// CloseConnection runs on a child context, so it closes conn in either of
	// two cases: the server shuts down (unblocking a pending Read), or this
	// handler returns (the deferred cancel). Either way that goroutine exits.
	cctx, cancel := context.WithCancel(ctx)
	defer cancel()
	go CloseConnection(cctx, conn)

	buf := make([]byte, 4096)
	for {
		n, err := conn.Read(buf)

		// io.Reader contract: handle any returned bytes before the error.
		if n > 0 {
			log.Info().Str("client", client).Msgf("received message: %s", buf[:n])

			// NOTE(review): the reply writer is handed buf; presumably it
			// reuses that storage, which is fine because the request was
			// already logged.
			sb := slicewriter.NewSliceWriter(buf)
			if _, ferr := fmt.Fprintf(sb, "server: %s | TCP | %v", viper.GetString("name"), time.Now()); ferr != nil {
				log.Error().Err(ferr).Msg("build server message")
			}
			if _, werr := conn.Write(sb.Bytes()); werr != nil {
				if ctx.Err() == nil {
					log.Error().Err(werr).Str("client", client).Msg("failed to write data to TCP connection")
				}
				return
			}
		}

		if err != nil {
			// After cancellation CloseConnection has closed the socket, so
			// this error is expected and not worth reporting.
			if ctx.Err() == nil {
				log.Error().Err(err).Str("client", client).Msg("failed to read data from TCP connection")
			}
			return
		}
	}
}