From ec7cc2b3c3b1049b61dbb64c6152c2e1e0dc44f4 Mon Sep 17 00:00:00 2001 From: Tom Mombourquette Date: Wed, 23 Nov 2022 05:44:53 -0400 Subject: [PATCH] lib/http: Simplify server.go to export an http server rather than an interface This also makes the implementation public. --- cmd/serve/http/http.go | 2 +- lib/http/serve/dir_test.go | 6 ++--- lib/http/serve/serve_test.go | 12 ++++----- lib/http/server.go | 48 +++++++++++++++++------------------- lib/http/server_test.go | 2 +- 5 files changed, 33 insertions(+), 37 deletions(-) diff --git a/cmd/serve/http/http.go b/cmd/serve/http/http.go index d88714ddd..78f6870f0 100644 --- a/cmd/serve/http/http.go +++ b/cmd/serve/http/http.go @@ -103,7 +103,7 @@ control the stats printing. type serveCmd struct { f fs.Fs vfs *vfs.VFS - server libhttp.Server + server *libhttp.Server } func run(ctx context.Context, f fs.Fs, opt Options) (*serveCmd, error) { diff --git a/lib/http/serve/dir_test.go b/lib/http/serve/dir_test.go index bdcf3f948..be128ae33 100644 --- a/lib/http/serve/dir_test.go +++ b/lib/http/serve/dir_test.go @@ -3,7 +3,7 @@ package serve import ( "errors" "html/template" - "io/ioutil" + "io" "net/http" "net/http/httptest" "net/url" @@ -94,7 +94,7 @@ func TestError(t *testing.T) { Error("potato", w, "sausage", err) resp := w.Result() assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) assert.Equal(t, "sausage.\n", string(body)) } @@ -108,7 +108,7 @@ func TestServe(t *testing.T) { d.Serve(w, r) resp := w.Result() assert.Equal(t, http.StatusOK, resp.StatusCode) - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) assert.Equal(t, ` diff --git a/lib/http/serve/serve_test.go b/lib/http/serve/serve_test.go index 934fabe65..b6cc70976 100644 --- a/lib/http/serve/serve_test.go +++ b/lib/http/serve/serve_test.go @@ -1,7 +1,7 @@ package serve import ( - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -17,7 +17,7 @@ func TestObjectBadMethod(t *testing.T) { Object(w, r, o) resp := w.Result() assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) assert.Equal(t, "Method Not Allowed\n", string(body)) } @@ -30,7 +30,7 @@ func TestObjectHEAD(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, "5", resp.Header.Get("Content-Length")) assert.Equal(t, "bytes", resp.Header.Get("Accept-Ranges")) - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) assert.Equal(t, "", string(body)) } @@ -43,7 +43,7 @@ func TestObjectGET(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, "5", resp.Header.Get("Content-Length")) assert.Equal(t, "bytes", resp.Header.Get("Accept-Ranges")) - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) assert.Equal(t, "hello", string(body)) } @@ -58,7 +58,7 @@ func TestObjectRange(t *testing.T) { assert.Equal(t, "3", resp.Header.Get("Content-Length")) assert.Equal(t, "bytes", resp.Header.Get("Accept-Ranges")) assert.Equal(t, "bytes 3-5/10", resp.Header.Get("Content-Range")) - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) assert.Equal(t, "345", string(body)) } @@ -71,6 +71,6 @@ func TestObjectBadRange(t *testing.T) { resp := w.Result() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, "10", resp.Header.Get("Content-Length")) - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) assert.Equal(t, "Bad Request\n", string(body)) } diff --git a/lib/http/server.go b/lib/http/server.go index 4dbc31c8e..9d69c6585 100644 --- a/lib/http/server.go +++ b/lib/http/server.go @@ -121,16 +121,6 @@ func DefaultCfg() Config { } } -// Server interface of http server -type Server interface { - Router() chi.Router - Serve() - Shutdown() error - HTMLTemplate() *template.Template - URLs() []string - Wait() -} - type instance struct { url string listener net.Listener @@ -145,7 +135,8 @@ func (s instance) serve(wg *sync.WaitGroup) { } } -type server struct { +// Server contains info about the running http server +type Server struct { wg sync.WaitGroup mux chi.Router tlsConfig *tls.Config @@ -157,25 +148,25 @@ type server struct { } // Option allows customizing the server -type Option func(*server) +type Option func(*Server) // WithAuth option initializes the appropriate auth middleware func WithAuth(cfg AuthConfig) Option { - return func(s *server) { + return func(s *Server) { s.auth = cfg } } // WithConfig option applies the Config to the server, overriding defaults func WithConfig(cfg Config) Option { - return func(s *server) { + return func(s *Server) { s.cfg = cfg } } // WithTemplate option allows the parsing of a template func WithTemplate(cfg TemplateConfig) Option { - return func(s *server) { + return func(s *Server) { s.template = &cfg } } @@ -184,12 +175,17 @@ func WithTemplate(cfg TemplateConfig) Option { // This function is provided if the default http server does not meet a services requirements and should not generally be used // A http server can listen using multiple listeners. For example, a listener for port 80, and a listener for port 443. // tlsListeners are ignored if opt.TLSKey is not provided -func NewServer(ctx context.Context, options ...Option) (*server, error) { - s := &server{ +func NewServer(ctx context.Context, options ...Option) (*Server, error) { + s := &Server{ mux: chi.NewRouter(), cfg: DefaultCfg(), } + // Make sure default logger is logging where everything else is + // middleware.DefaultLogger = middleware.RequestLogger(&middleware.DefaultLogFormatter{Logger: log.Default(), NoColor: true}) + // Log requests + // s.mux.Use(middleware.Logger) + for _, opt := range options { opt(s) } @@ -275,7 +271,7 @@ func NewServer(ctx context.Context, options ...Option) (*server, error) { return s, nil } -func (s *server) initAuth() { +func (s *Server) initAuth() { if s.auth.CustomAuthFn != nil { s.mux.Use(MiddlewareAuthCustom(s.auth.CustomAuthFn, s.auth.Realm)) return @@ -292,7 +288,7 @@ func (s *server) initAuth() { } } -func (s *server) initTemplate() error { +func (s *Server) initTemplate() error { if s.template == nil { return nil } @@ -317,7 +313,7 @@ var ( ErrTLSParseCA = errors.New("unable to parse client certificate authority") ) -func (s *server) initTLS() error { +func (s *Server) initTLS() error { if s.cfg.TLSCert == "" && s.cfg.TLSKey == "" && len(s.cfg.TLSCertBody) == 0 && len(s.cfg.TLSKeyBody) == 0 { return nil } @@ -383,7 +379,7 @@ func (s *server) initTLS() error { } // Serve starts the HTTP server on each listener -func (s *server) Serve() { +func (s *Server) Serve() { s.wg.Add(len(s.instances)) for _, ii := range s.instances { // TODO: decide how/when to log listening url @@ -393,17 +389,17 @@ func (s *server) Serve() { } // Wait blocks while the server is serving requests -func (s *server) Wait() { +func (s *Server) Wait() { s.wg.Wait() } // Router returns the server base router -func (s *server) Router() chi.Router { +func (s *Server) Router() chi.Router { return s.mux } // Shutdown gracefully shuts down the server -func (s *server) Shutdown() error { +func (s *Server) Shutdown() error { ctx := context.Background() for _, ii := range s.instances { if err := ii.httpServer.Shutdown(ctx); err != nil { @@ -416,12 +412,12 @@ func (s *server) Shutdown() error { } // HTMLTemplate returns the parsed template, if WithTemplate option was passed. -func (s *server) HTMLTemplate() *template.Template { +func (s *Server) HTMLTemplate() *template.Template { return s.htmlTemplate } // URLs returns all configured URLS -func (s *server) URLs() []string { +func (s *Server) URLs() []string { var out []string for _, ii := range s.instances { if ii.listener.Addr().Network() == "unix" { diff --git a/lib/http/server_test.go b/lib/http/server_test.go index 42c630ea4..506a78d51 100644 --- a/lib/http/server_test.go +++ b/lib/http/server_test.go @@ -26,7 +26,7 @@ func testExpectRespBody(t *testing.T, resp *http.Response, expected []byte) { require.Equal(t, expected, body) } -func testGetServerURL(t *testing.T, s Server) string { +func testGetServerURL(t *testing.T, s *Server) string { urls := s.URLs() require.GreaterOrEqual(t, len(urls), 1, "server should return at least one url") return urls[0]