diff --git a/backend/webdav/webdav.go b/backend/webdav/webdav.go index 1dd7b6060..ea25fa276 100644 --- a/backend/webdav/webdav.go +++ b/backend/webdav/webdav.go @@ -20,6 +20,7 @@ import ( "path" "strconv" "strings" + "sync" "time" "github.com/pkg/errors" @@ -143,6 +144,7 @@ type Fs struct { retryWithZeroDepth bool // some vendors (sharepoint) won't list files when Depth is 1 (our default) hasMD5 bool // set if can use owncloud style checksums for MD5 hasSHA1 bool // set if can use owncloud style checksums for SHA1 + ntlmAuthMu sync.Mutex // mutex to serialize NTLM auth roundtrips } // Object describes a webdav object @@ -206,6 +208,22 @@ func (f *Fs) shouldRetry(resp *http.Response, err error) (bool, error) { return fserrors.ShouldRetry(err) || fserrors.ShouldRetryHTTP(resp, retryErrorCodes), err } +// safeRoundTripper is a wrapper for http.RoundTripper that serializes +// http roundtrips. NTLM authentication sequence can involve up to four +// rounds of negotiations and might fail due to concurrency. +// This wrapper allows to use ntlmssp.Negotiator safely with goroutines. +type safeRoundTripper struct { + fs *Fs + rt http.RoundTripper +} + +// RoundTrip guards wrapped RoundTripper by a mutex. +func (srt *safeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + srt.fs.ntlmAuthMu.Lock() + defer srt.fs.ntlmAuthMu.Unlock() + return srt.rt.RoundTrip(req) +} + // itemIsDir returns true if the item is a directory // // When a client sees a resourcetype it doesn't recognize it should @@ -365,6 +383,16 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e return nil, err } + f := &Fs{ + name: name, + root: root, + opt: *opt, + endpoint: u, + endpointURL: u.String(), + pacer: fs.NewPacer(ctx, pacer.NewDefault(pacer.MinSleep(minSleep), pacer.MaxSleep(maxSleep), pacer.DecayConstant(decayConstant))), + precision: fs.ModTimeNotSupported, + } + client := fshttp.NewClient(ctx) if opt.Vendor == "sharepoint-ntlm" { // Disable transparent HTTP/2 support as per https://golang.org/pkg/net/http/ , @@ -374,19 +402,15 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e t := fshttp.NewTransportCustom(ctx, func(t *http.Transport) { t.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{} }) + // Add NTLM layer - client.Transport = ntlmssp.Negotiator{RoundTripper: t} - } - f := &Fs{ - name: name, - root: root, - opt: *opt, - endpoint: u, - endpointURL: u.String(), - srv: rest.NewClient(client).SetRoot(u.String()), - pacer: fs.NewPacer(ctx, pacer.NewDefault(pacer.MinSleep(minSleep), pacer.MaxSleep(maxSleep), pacer.DecayConstant(decayConstant))), - precision: fs.ModTimeNotSupported, + client.Transport = &safeRoundTripper{ + fs: f, + rt: ntlmssp.Negotiator{RoundTripper: t}, + } } + f.srv = rest.NewClient(client).SetRoot(u.String()) + f.features = (&fs.Features{ CanHaveEmptyDirectories: true, }).Fill(ctx, f)