From 6a5b7664f7d006f010c1d5c6e3bebda4a4cc2963 Mon Sep 17 00:00:00 2001 From: Arnie97 Date: Sun, 20 Mar 2022 18:12:48 +0800 Subject: [PATCH] backend/http: support content-range response header --- backend/http/http.go | 26 ++-------- backend/http/http_internal_test.go | 77 ++++++++++++++++++++++-------- backend/s3/s3.go | 21 ++------ lib/rest/headers.go | 39 +++++++++++++++ lib/rest/headers_test.go | 41 ++++++++++++++++ 5 files changed, 146 insertions(+), 58 deletions(-) create mode 100644 lib/rest/headers.go create mode 100644 lib/rest/headers_test.go diff --git a/backend/http/http.go b/backend/http/http.go index 498d92e21..867c59371 100644 --- a/backend/http/http.go +++ b/backend/http/http.go @@ -13,7 +13,6 @@ import ( "net/http" "net/url" "path" - "strconv" "strings" "sync" "time" @@ -317,15 +316,6 @@ func (f *Fs) url(remote string) string { return f.endpointURL + rest.URLPathEscape(remote) } -// parse s into an int64, on failure return def -func parseInt64(s string, def int64) int64 { - n, e := strconv.ParseInt(s, 10, 64) - if e != nil { - return def - } - return n -} - // Errors returned by parseName var ( errURLJoinFailed = errors.New("URLJoin failed") @@ -601,23 +591,18 @@ func (o *Object) head(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to stat: %w", err) } - return o.stat(ctx, res, true) + return o.stat(ctx, res) } // stat updates info fields in the Object according to HTTP response headers -func (o *Object) stat(ctx context.Context, res *http.Response, isRangeRequest bool) error { +func (o *Object) stat(ctx context.Context, res *http.Response) error { t, err := http.ParseTime(res.Header.Get("Last-Modified")) if err != nil { t = timeUnset } o.modTime = t - - // TODO: parse Content-Range for total size - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests - if !isRangeRequest { - o.size = parseInt64(res.Header.Get("Content-Length"), -1) - o.contentType = res.Header.Get("Content-Type") - } + o.contentType = res.Header.Get("Content-Type") + o.size = rest.ParseSizeFromHeaders(res.Header) // If NoSlash is set then check ContentType to see if it is a directory if o.fs.opt.NoSlash { @@ -666,8 +651,7 @@ func (o *Object) Open(ctx context.Context, options ...fs.OpenOption) (in io.Read } if o.fs.opt.NoHead { - isRangeRequest := len(req.Header.Get("Range")) > 0 - if err = o.stat(ctx, res, isRangeRequest); err != nil { + if err = o.stat(ctx, res); err != nil { return nil, fmt.Errorf("Stat failed: %w", err) } } diff --git a/backend/http/http_internal_test.go b/backend/http/http_internal_test.go index 8bf6be6b6..d0aba666e 100644 --- a/backend/http/http_internal_test.go +++ b/backend/http/http_internal_test.go @@ -194,31 +194,66 @@ func TestNewObject(t *testing.T) { } func TestOpen(t *testing.T) { - f, tidy := prepare(t) + m, tidy := prepareServer(t) defer tidy() - o, err := f.NewObject(context.Background(), "four/under four.txt") - require.NoError(t, err) + for _, head := range []bool{false, true} { + if !head { + m.Set("no_head", "true") + } + f, err := NewFs(context.Background(), remoteName, "", m) + require.NoError(t, err) - // Test normal read - fd, err := o.Open(context.Background()) - require.NoError(t, err) - data, err := io.ReadAll(fd) - require.NoError(t, err) - require.NoError(t, fd.Close()) - if lineEndSize == 2 { - assert.Equal(t, "beetroot\r\n", string(data)) - } else { - assert.Equal(t, "beetroot\n", string(data)) + for _, rangeRead := range []bool{false, true} { + o, err := f.NewObject(context.Background(), "four/under four.txt") + require.NoError(t, err) + + if !head { + // Test mod time is still indeterminate + tObj := o.ModTime(context.Background()) + assert.Equal(t, time.Duration(0), time.Unix(0, 0).Sub(tObj)) + + // Test file size is still indeterminate + assert.Equal(t, int64(-1), o.Size()) + } + + var data []byte + if !rangeRead { + // Test normal read + fd, err := o.Open(context.Background()) + require.NoError(t, err) + data, err = io.ReadAll(fd) + require.NoError(t, err) + require.NoError(t, fd.Close()) + if lineEndSize == 2 { + assert.Equal(t, "beetroot\r\n", string(data)) + } else { + assert.Equal(t, "beetroot\n", string(data)) + } + } else { + // Test with range request + fd, err := o.Open(context.Background(), &fs.RangeOption{Start: 1, End: 5}) + require.NoError(t, err) + data, err = io.ReadAll(fd) + require.NoError(t, err) + require.NoError(t, fd.Close()) + assert.Equal(t, "eetro", string(data)) + } + + fi, err := os.Stat(filepath.Join(filesPath, "four", "under four.txt")) + require.NoError(t, err) + tFile := fi.ModTime() + + // Test the time is always correct on the object after file open + tObj := o.ModTime(context.Background()) + fstest.AssertTimeEqualWithPrecision(t, o.Remote(), tFile, tObj, time.Second) + + if !rangeRead { + // Test the file size + assert.Equal(t, int64(len(data)), o.Size()) + } + } } - - // Test with range request - fd, err = o.Open(context.Background(), &fs.RangeOption{Start: 1, End: 5}) - require.NoError(t, err) - data, err = io.ReadAll(fd) - require.NoError(t, err) - require.NoError(t, fd.Close()) - assert.Equal(t, "eetro", string(data)) } func TestMimeType(t *testing.T) { diff --git a/backend/s3/s3.go b/backend/s3/s3.go index efea72410..a1a2fdb3b 100644 --- a/backend/s3/s3.go +++ b/backend/s3/s3.go @@ -4761,23 +4761,12 @@ func (o *Object) downloadFromURL(ctx context.Context, bucketPath string, options return nil, err } - contentLength := &resp.ContentLength - if resp.Header.Get("Content-Range") != "" { - var contentRange = resp.Header.Get("Content-Range") - slash := strings.IndexRune(contentRange, '/') - if slash >= 0 { - i, err := strconv.ParseInt(contentRange[slash+1:], 10, 64) - if err == nil { - contentLength = &i - } else { - fs.Debugf(o, "Failed to find parse integer from in %q: %v", contentRange, err) - } - } else { - fs.Debugf(o, "Failed to find length in %q", contentRange) - } + contentLength := rest.ParseSizeFromHeaders(resp.Header) + if contentLength < 0 { + fs.Debugf(o, "Failed to parse file size from headers") } - lastModified, err := time.Parse(time.RFC1123, resp.Header.Get("Last-Modified")) + lastModified, err := http.ParseTime(resp.Header.Get("Last-Modified")) if err != nil { fs.Debugf(o, "Failed to parse last modified from string %s, %v", resp.Header.Get("Last-Modified"), err) } @@ -4801,7 +4790,7 @@ func (o *Object) downloadFromURL(ctx context.Context, bucketPath string, options var head = s3.HeadObjectOutput{ ETag: header("Etag"), - ContentLength: contentLength, + ContentLength: &contentLength, LastModified: &lastModified, Metadata: metaData, CacheControl: header("Cache-Control"), diff --git a/lib/rest/headers.go b/lib/rest/headers.go new file mode 100644 index 000000000..3e52fbdaa --- /dev/null +++ b/lib/rest/headers.go @@ -0,0 +1,39 @@ +package rest + +import ( + "net/http" + "strconv" + "strings" +) + +// ParseSizeFromHeaders parses HTTP response headers to get the full file size. +// Returns -1 if the headers did not exist or were invalid. +func ParseSizeFromHeaders(headers http.Header) (size int64) { + size = -1 + + var contentLength = headers.Get("Content-Length") + if len(contentLength) != 0 { + var err error + if size, err = strconv.ParseInt(contentLength, 10, 64); err != nil { + return -1 + } + } + + var contentRange = headers.Get("Content-Range") + if len(contentRange) == 0 { + return size + } + + if !strings.HasPrefix(contentRange, "bytes ") { + return -1 + } + slash := strings.IndexRune(contentRange, '/') + if slash < 0 { + return -1 + } + ret, err := strconv.ParseInt(contentRange[slash+1:], 10, 64) + if err != nil { + return -1 + } + return ret +} diff --git a/lib/rest/headers_test.go b/lib/rest/headers_test.go new file mode 100644 index 000000000..23400062d --- /dev/null +++ b/lib/rest/headers_test.go @@ -0,0 +1,41 @@ +package rest + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseSizeFromHeaders(t *testing.T) { + testCases := []struct { + ContentLength, ContentRange string + Size int64 + }{{ + "", "", -1, + }, { + "42", "", 42, + }, { + "42", "invalid", -1, + }, { + "", "bytes 22-33/42", 42, + }, { + "12", "bytes 22-33/42", 42, + }, { + "12", "otherUnit 22-33/42", -1, + }, { + "12", "bytes 22-33/*", -1, + }, { + "0", "bytes */42", 42, + }} + for _, testCase := range testCases { + headers := make(http.Header, 2) + if len(testCase.ContentLength) > 0 { + headers.Set("Content-Length", testCase.ContentLength) + } + if len(testCase.ContentRange) > 0 { + headers.Set("Content-Range", testCase.ContentRange) + } + assert.Equalf(t, testCase.Size, ParseSizeFromHeaders(headers), "%+v", testCase) + } +}