From 6c8148ef39ea51b1d4092128c598ac12c6fc1523 Mon Sep 17 00:00:00 2001 From: yuudi <26199752+yuudi@users.noreply.github.com> Date: Wed, 26 Jul 2023 05:15:54 -0400 Subject: [PATCH] http servers: allow CORS to be set with --allow-origin flag - fixes #5078 Some changes about test cases: Because MiddlewareCORS will return early on OPTIONS request, this middleware should only be used once at NewServer function. Test cases should pass AllowOrigin config instead of adding this middleware again. A new test case was added to test CORS preflight request with an authenticator. Preflight request should always return 200 OK regardless of autentications. Co-authored-by: yuudi --- fs/rc/rc.go | 31 +++++++------ fs/rc/rcflags/rcflags.go | 1 - fs/rc/rcserver/rcserver.go | 19 -------- fs/rc/rcserver/rcserver_test.go | 26 ----------- lib/http/middleware.go | 7 +++ lib/http/middleware_test.go | 80 ++++++++++++++++++++++++++++----- lib/http/server.go | 4 ++ lib/http/server_test.go | 2 - 8 files changed, 95 insertions(+), 75 deletions(-) diff --git a/fs/rc/rc.go b/fs/rc/rc.go index 3b150b5f4..b7db4e587 100644 --- a/fs/rc/rc.go +++ b/fs/rc/rc.go @@ -18,22 +18,21 @@ import ( // Options contains options for the remote control server type Options struct { - HTTP libhttp.Config - Auth libhttp.AuthConfig - Template libhttp.TemplateConfig - Enabled bool // set to enable the server - Serve bool // set to serve files from remotes - Files string // set to enable serving files locally - NoAuth bool // set to disable auth checks on AuthRequired methods - WebUI bool // set to launch the web ui - WebGUIUpdate bool // set to check new update - WebGUIForceUpdate bool // set to force download new update - WebGUINoOpenBrowser bool // set to disable auto opening browser - WebGUIFetchURL string // set the default url for fetching webgui - AccessControlAllowOrigin string // set the access control for CORS configuration - EnableMetrics bool // set to disable prometheus metrics on /metrics - JobExpireDuration time.Duration - JobExpireInterval time.Duration + HTTP libhttp.Config + Auth libhttp.AuthConfig + Template libhttp.TemplateConfig + Enabled bool // set to enable the server + Serve bool // set to serve files from remotes + Files string // set to enable serving files locally + NoAuth bool // set to disable auth checks on AuthRequired methods + WebUI bool // set to launch the web ui + WebGUIUpdate bool // set to check new update + WebGUIForceUpdate bool // set to force download new update + WebGUINoOpenBrowser bool // set to disable auto opening browser + WebGUIFetchURL string // set the default url for fetching webgui + EnableMetrics bool // set to disable prometheus metrics on /metrics + JobExpireDuration time.Duration + JobExpireInterval time.Duration } // DefaultOpt is the default values used for Options diff --git a/fs/rc/rcflags/rcflags.go b/fs/rc/rcflags/rcflags.go index f323806c8..7857c5189 100644 --- a/fs/rc/rcflags/rcflags.go +++ b/fs/rc/rcflags/rcflags.go @@ -27,7 +27,6 @@ func AddFlags(flagSet *pflag.FlagSet) { flags.BoolVarP(flagSet, &Opt.WebGUIForceUpdate, "rc-web-gui-force-update", "", false, "Force update to latest version of web gui") flags.BoolVarP(flagSet, &Opt.WebGUINoOpenBrowser, "rc-web-gui-no-open-browser", "", false, "Don't open the browser automatically") flags.StringVarP(flagSet, &Opt.WebGUIFetchURL, "rc-web-fetch-url", "", "https://api.github.com/repos/rclone/rclone-webui-react/releases/latest", "URL to fetch the releases for webgui") - flags.StringVarP(flagSet, &Opt.AccessControlAllowOrigin, "rc-allow-origin", "", "", "Set the allowed origin for CORS") flags.BoolVarP(flagSet, &Opt.EnableMetrics, "rc-enable-metrics", "", false, "Enable prometheus metrics on /metrics") flags.DurationVarP(flagSet, &Opt.JobExpireDuration, "rc-job-expire-duration", "", Opt.JobExpireDuration, "Expire finished async jobs older than this value") flags.DurationVarP(flagSet, &Opt.JobExpireInterval, "rc-job-expire-interval", "", Opt.JobExpireInterval, "Interval to check for expired async jobs") diff --git a/fs/rc/rcserver/rcserver.go b/fs/rc/rcserver/rcserver.go index f353c3ea1..8afd871ce 100644 --- a/fs/rc/rcserver/rcserver.go +++ b/fs/rc/rcserver/rcserver.go @@ -15,7 +15,6 @@ import ( "regexp" "sort" "strings" - "sync" "time" "github.com/go-chi/chi/v5/middleware" @@ -38,7 +37,6 @@ import ( ) var promHandler http.Handler -var onlyOnceWarningAllowOrigin sync.Once func init() { rcloneCollector := accounting.NewRcloneCollector(context.Background()) @@ -214,23 +212,6 @@ func writeError(path string, in rc.Params, w http.ResponseWriter, err error, sta func (s *Server) handler(w http.ResponseWriter, r *http.Request) { path := strings.TrimLeft(r.URL.Path, "/") - allowOrigin := rcflags.Opt.AccessControlAllowOrigin - if allowOrigin != "" { - onlyOnceWarningAllowOrigin.Do(func() { - if allowOrigin == "*" { - fs.Logf(nil, "Warning: Allow origin set to *. This can cause serious security problems.") - } - }) - w.Header().Add("Access-Control-Allow-Origin", allowOrigin) - } else { - urls := s.server.URLs() - if len(urls) == 1 { - w.Header().Add("Access-Control-Allow-Origin", urls[0]) - } else { - fs.Errorf(nil, "Warning, need exactly 1 URL for Access-Control-Allow-Origin, got %d %q", len(urls), urls) - } - } - // echo back access control headers client needs //reqAccessHeaders := r.Header.Get("Access-Control-Request-Headers") w.Header().Add("Access-Control-Request-Method", "POST, OPTIONS, GET, HEAD") diff --git a/fs/rc/rcserver/rcserver_test.go b/fs/rc/rcserver/rcserver_test.go index 278722297..05889ecc9 100644 --- a/fs/rc/rcserver/rcserver_test.go +++ b/fs/rc/rcserver/rcserver_test.go @@ -552,32 +552,6 @@ Unknown command testServer(t, tests, &opt) } -func TestMethods(t *testing.T) { - tests := []testRun{{ - Name: "options", - URL: "", - Method: "OPTIONS", - Status: http.StatusOK, - Expected: "", - Headers: map[string]string{ - "Access-Control-Allow-Origin": "testURL", - "Access-Control-Request-Method": "POST, OPTIONS, GET, HEAD", - "Access-Control-Allow-Headers": "authorization, Content-Type", - }, - }, { - Name: "bad", - URL: "", - Method: "POTATO", - Status: http.StatusMethodNotAllowed, - Expected: `Method Not Allowed -`, - }} - opt := newTestOpt() - opt.Serve = true - opt.Files = testFs - testServer(t, tests, &opt) -} - func TestMetrics(t *testing.T) { stats := accounting.GlobalStats() tests := makeMetricsTestCases(stats) diff --git a/lib/http/middleware.go b/lib/http/middleware.go index 2aa319c81..6deba60f6 100644 --- a/lib/http/middleware.go +++ b/lib/http/middleware.go @@ -181,6 +181,13 @@ func MiddlewareCORS(allowOrigin string) Middleware { w.Header().Add("Access-Control-Request-Method", "POST, OPTIONS, GET, HEAD") w.Header().Add("Access-Control-Allow-Headers", "authorization, Content-Type") + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + // Because CORS preflight OPTIONS requests are not authenticated, + // and require a 200 OK response, we will return early here. + } + next.ServeHTTP(w, r) }) } diff --git a/lib/http/middleware_test.go b/lib/http/middleware_test.go index 24cc97afb..f8ac196b9 100644 --- a/lib/http/middleware_test.go +++ b/lib/http/middleware_test.go @@ -329,23 +329,22 @@ var _testCORSHeaderKeys = []string{ func TestMiddlewareCORS(t *testing.T) { servers := []struct { - name string - http Config - origin string + name string + http Config }{ { name: "EmptyOrigin", http: Config{ - ListenAddr: []string{"127.0.0.1:0"}, + ListenAddr: []string{"127.0.0.1:0"}, + AllowOrigin: "", }, - origin: "", }, { name: "CustomOrigin", http: Config{ - ListenAddr: []string{"127.0.0.1:0"}, + ListenAddr: []string{"127.0.0.1:0"}, + AllowOrigin: "http://test.rclone.org", }, - origin: "http://test.rclone.org", }, } @@ -357,8 +356,6 @@ func TestMiddlewareCORS(t *testing.T) { require.NoError(t, s.Shutdown()) }() - s.Router().Use(MiddlewareCORS(ss.origin)) - expected := []byte("data") s.Router().Mount("/", testEchoHandler(expected)) s.Serve() @@ -384,8 +381,69 @@ func TestMiddlewareCORS(t *testing.T) { } expectedOrigin := url - if ss.origin != "" { - expectedOrigin = ss.origin + if ss.http.AllowOrigin != "" { + expectedOrigin = ss.http.AllowOrigin + } + require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match") + }) + } +} + +func TestMiddlewareCORSWithAuth(t *testing.T) { + authServers := []struct { + name string + http Config + auth AuthConfig + }{ + { + name: "ServerWithAuth", + http: Config{ + ListenAddr: []string{"127.0.0.1:0"}, + AllowOrigin: "http://test.rclone.org", + }, + auth: AuthConfig{ + Realm: "test", + BasicUser: "test_user", + BasicPass: "test_pass", + }, + }, + } + + for _, ss := range authServers { + t.Run(ss.name, func(t *testing.T) { + s, err := NewServer(context.Background(), WithConfig(ss.http)) + require.NoError(t, err) + defer func() { + require.NoError(t, s.Shutdown()) + }() + + expected := []byte("data") + s.Router().Mount("/", testEchoHandler(expected)) + s.Serve() + + url := testGetServerURL(t, s) + + client := &http.Client{} + req, err := http.NewRequest("OPTIONS", url, nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer func() { + _ = resp.Body.Close() + }() + + require.Equal(t, http.StatusOK, resp.StatusCode, "OPTIONS should return ok even if not authenticated") + + testExpectRespBody(t, resp, []byte{}) + + for _, key := range _testCORSHeaderKeys { + require.Contains(t, resp.Header, key, "CORS headers should be sent even if not authenticated") + } + + expectedOrigin := url + if ss.http.AllowOrigin != "" { + expectedOrigin = ss.http.AllowOrigin } require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match") }) diff --git a/lib/http/server.go b/lib/http/server.go index 3c26d5cd5..3a43d37d9 100644 --- a/lib/http/server.go +++ b/lib/http/server.go @@ -109,6 +109,7 @@ type Config struct { TLSKeyBody []byte // TLS PEM Private key body, ignores TLSKey ClientCA string // Client certificate authority to verify clients with MinTLSVersion string // MinTLSVersion contains the minimum TLS version that is acceptable. + AllowOrigin string // AllowOrigin sets the Access-Control-Allow-Origin header } // AddFlagsPrefix adds flags for the httplib @@ -122,6 +123,7 @@ func (cfg *Config) AddFlagsPrefix(flagSet *pflag.FlagSet, prefix string) { flags.StringVarP(flagSet, &cfg.ClientCA, prefix+"client-ca", "", cfg.ClientCA, "Client certificate authority to verify clients with") flags.StringVarP(flagSet, &cfg.BaseURL, prefix+"baseurl", "", cfg.BaseURL, "Prefix for URLs - leave blank for root") flags.StringVarP(flagSet, &cfg.MinTLSVersion, prefix+"min-tls-version", "", cfg.MinTLSVersion, "Minimum TLS version that is acceptable") + flags.StringVarP(flagSet, &cfg.AllowOrigin, prefix+"allow-origin", "", cfg.AllowOrigin, "Origin which cross-domain request (CORS) can be executed from") } // AddHTTPFlagsPrefix adds flags for the httplib @@ -236,6 +238,8 @@ func NewServer(ctx context.Context, options ...Option) (*Server, error) { return nil, err } + s.mux.Use(MiddlewareCORS(s.cfg.AllowOrigin)) + s.initAuth() for _, addr := range s.cfg.ListenAddr { diff --git a/lib/http/server_test.go b/lib/http/server_test.go index d3bf997d4..656c7702f 100644 --- a/lib/http/server_test.go +++ b/lib/http/server_test.go @@ -82,8 +82,6 @@ func TestNewServerUnix(t *testing.T) { require.Empty(t, s.URLs(), "unix socket should not appear in URLs") - s.Router().Use(MiddlewareCORS("")) - expected := []byte("hello world") s.Router().Mount("/", testEchoHandler(expected)) s.Serve()