diff --git a/server.go b/server.go index 6c20ec38..609f125d 100644 --- a/server.go +++ b/server.go @@ -367,10 +367,17 @@ func (srv *Server) spawnWorker(w *response) { } } +func unlockOnce(l sync.Locker) func() { + var once sync.Once + return func() { once.Do(l.Unlock) } +} + // ListenAndServe starts a nameserver on the configured address in *Server. func (srv *Server) ListenAndServe() error { + unlock := unlockOnce(&srv.lock) srv.lock.Lock() - defer srv.lock.Unlock() + defer unlock() + if srv.started { return &Error{err: "server already started"} } @@ -396,10 +403,8 @@ func (srv *Server) ListenAndServe() error { } srv.Listener = l srv.started = true - srv.lock.Unlock() - err = srv.serveTCP(l) - srv.lock.Lock() // to satisfy the defer at the top - return err + unlock() + return srv.serveTCP(l) case "tcp-tls", "tcp4-tls", "tcp6-tls": network := "tcp" if srv.Net == "tcp4-tls" { @@ -414,10 +419,8 @@ func (srv *Server) ListenAndServe() error { } srv.Listener = l srv.started = true - srv.lock.Unlock() - err = srv.serveTCP(l) - srv.lock.Lock() // to satisfy the defer at the top - return err + unlock() + return srv.serveTCP(l) case "udp", "udp4", "udp6": a, err := net.ResolveUDPAddr(srv.Net, addr) if err != nil { @@ -432,10 +435,8 @@ func (srv *Server) ListenAndServe() error { } srv.PacketConn = l srv.started = true - srv.lock.Unlock() - err = srv.serveUDP(l) - srv.lock.Lock() // to satisfy the defer at the top - return err + unlock() + return srv.serveUDP(l) } return &Error{err: "bad network"} } @@ -443,8 +444,10 @@ func (srv *Server) ListenAndServe() error { // ActivateAndServe starts a nameserver with the PacketConn or Listener // configured in *Server. Its main use is to start a server from systemd. func (srv *Server) ActivateAndServe() error { + unlock := unlockOnce(&srv.lock) srv.lock.Lock() - defer srv.lock.Unlock() + defer unlock() + if srv.started { return &Error{err: "server already started"} } @@ -464,18 +467,14 @@ func (srv *Server) ActivateAndServe() error { return e } srv.started = true - srv.lock.Unlock() - e := srv.serveUDP(t) - srv.lock.Lock() // to satisfy the defer at the top - return e + unlock() + return srv.serveUDP(t) } } if l != nil { srv.started = true - srv.lock.Unlock() - e := srv.serveTCP(l) - srv.lock.Lock() // to satisfy the defer at the top - return e + unlock() + return srv.serveTCP(l) } return &Error{err: "bad listeners"} }