diff --git a/fs/operations/multithread.go b/fs/operations/multithread.go index 5ab41b9a2..deb7c88ba 100644 --- a/fs/operations/multithread.go +++ b/fs/operations/multithread.go @@ -9,13 +9,13 @@ import ( "github.com/rclone/rclone/fs" "github.com/rclone/rclone/fs/accounting" + "github.com/rclone/rclone/lib/readers" "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" ) const ( - multithreadChunkSize = 64 << 10 - multithreadChunkSizeMask = multithreadChunkSize - 1 - multithreadReadBufferSize = 32 * 1024 + multithreadChunkSize = 64 << 10 ) // An offsetWriter maps writes at offset base to offset base+off in the underlying writer. @@ -60,7 +60,7 @@ func doMultiThreadCopy(ctx context.Context, f fs.Fs, src fs.Object) bool { } // ...destination doesn't support it dstFeatures := f.Features() - if dstFeatures.OpenWriterAt == nil { + if dstFeatures.OpenChunkWriter == nil && dstFeatures.OpenWriterAt == nil { return false } // ...if --multi-thread-streams not in use and source and @@ -73,21 +73,20 @@ func doMultiThreadCopy(ctx context.Context, f fs.Fs, src fs.Object) bool { // state for a multi-thread copy type multiThreadCopyState struct { - ctx context.Context - partSize int64 - size int64 - wc fs.WriterAtCloser - src fs.Object - acc *accounting.Account - streams int + ctx context.Context + partSize int64 + size int64 + src fs.Object + acc *accounting.Account + streams int + numChunks int } // Copy a single stream into place -func (mc *multiThreadCopyState) copyStream(ctx context.Context, stream int) (err error) { - ci := fs.GetConfig(ctx) +func (mc *multiThreadCopyState) copyStream(ctx context.Context, stream int, writer fs.ChunkWriter) (err error) { defer func() { if err != nil { - fs.Debugf(mc.src, "multi-thread copy: stream %d/%d failed: %v", stream+1, mc.streams, err) + fs.Debugf(mc.src, "multi-thread copy: stream %d/%d failed: %v", stream+1, mc.numChunks, err) } }() start := int64(stream) * mc.partSize @@ -99,7 +98,7 @@ func (mc *multiThreadCopyState) copyStream(ctx context.Context, stream int) (err end = mc.size } - fs.Debugf(mc.src, "multi-thread copy: stream %d/%d (%d-%d) size %v starting", stream+1, mc.streams, start, end, fs.SizeSuffix(end-start)) + fs.Debugf(mc.src, "multi-thread copy: stream %d/%d (%d-%d) size %v starting", stream+1, mc.numChunks, start, end, fs.SizeSuffix(end-start)) rc, err := Open(ctx, mc.src, &fs.RangeOption{Start: start, End: end - 1}) if err != nil { @@ -107,119 +106,99 @@ func (mc *multiThreadCopyState) copyStream(ctx context.Context, stream int) (err } defer fs.CheckClose(rc, &err) - var writer io.Writer = newOffsetWriter(mc.wc, start) - if ci.MultiThreadWriteBufferSize > 0 { - writer = bufio.NewWriterSize(writer, int(ci.MultiThreadWriteBufferSize)) - fs.Debugf(mc.src, "multi-thread copy: write buffer set to %v", ci.MultiThreadWriteBufferSize) + bytesWritten, err := writer.WriteChunk(stream, readers.NewRepeatableReader(rc)) + if err != nil { + return err } - // Copy the data - buf := make([]byte, multithreadReadBufferSize) - offset := start - for { - // Check if context cancelled and exit if so - if mc.ctx.Err() != nil { - return mc.ctx.Err() - } - nr, er := rc.Read(buf) - if nr > 0 { - err = mc.acc.AccountRead(nr) - if err != nil { - return fmt.Errorf("multipart copy: accounting failed: %w", err) - } - nw, ew := writer.Write(buf[0:nr]) - if nw > 0 { - offset += int64(nw) - } - if ew != nil { - return fmt.Errorf("multipart copy: write failed: %w", ew) - } - if nr != nw { - return fmt.Errorf("multipart copy: %w", io.ErrShortWrite) - } - } - if er != nil { - if er != io.EOF { - return fmt.Errorf("multipart copy: read failed: %w", er) - } - - // if we were buffering, flush do disk - switch w := writer.(type) { - case *bufio.Writer: - er2 := w.Flush() - if er2 != nil { - return fmt.Errorf("multipart copy: flush failed: %w", er2) - } - } - - break - } + // FIXME: Wrap ReadSeeker for Accounting + // However, to ensure reporting is correctly seeks have to be handled properly + errAccRead := mc.acc.AccountRead(int(bytesWritten)) + if errAccRead != nil { + return errAccRead } - if offset != end { - return fmt.Errorf("multipart copy: wrote %d bytes but expected to write %d", offset-start, end-start) - } - - fs.Debugf(mc.src, "multi-thread copy: stream %d/%d (%d-%d) size %v finished", stream+1, mc.streams, start, end, fs.SizeSuffix(end-start)) + fs.Debugf(mc.src, "multi-thread copy: stream %d/%d (%d-%d) size %v finished", stream+1, mc.numChunks, start, end, fs.SizeSuffix(bytesWritten)) return nil } -// Calculate the chunk sizes and updated number of streams -func (mc *multiThreadCopyState) calculateChunks() { - partSize := mc.size / int64(mc.streams) - // Round partition size up so partSize * streams >= size - if (mc.size % int64(mc.streams)) != 0 { - partSize++ - } - // round partSize up to nearest multithreadChunkSize boundary - mc.partSize = (partSize + multithreadChunkSizeMask) &^ multithreadChunkSizeMask - // recalculate number of streams - mc.streams = int(mc.size / mc.partSize) - // round streams up so partSize * streams >= size - if (mc.size % mc.partSize) != 0 { - mc.streams++ +// Given a file size and a chunkSize +// it returns the number of chunks, so that chunkSize * numChunks >= size +func calculateNumChunks(size int64, chunkSize int64) int { + numChunks := size / chunkSize + if size%chunkSize != 0 { + numChunks++ } + + return int(numChunks) } -// Copy src to (f, remote) using streams download threads and the OpenWriterAt feature +// Copy src to (f, remote) using streams download threads. It tries to use the OpenChunkWriter feature +// and if that's not available it creates an adapter using OpenWriterAt func multiThreadCopy(ctx context.Context, f fs.Fs, remote string, src fs.Object, streams int, tr *accounting.Transfer) (newDst fs.Object, err error) { - openWriterAt := f.Features().OpenWriterAt - if openWriterAt == nil { - return nil, errors.New("multi-thread copy: OpenWriterAt not supported") + openChunkWriter := f.Features().OpenChunkWriter + ci := fs.GetConfig(ctx) + if openChunkWriter == nil { + openWriterAt := f.Features().OpenWriterAt + if openWriterAt == nil { + return nil, errors.New("multi-part copy: neither OpenChunkWriter nor OpenWriterAt supported") + } + openChunkWriter = openChunkWriterFromOpenWriterAt(openWriterAt, int64(ci.MultiThreadChunkSize), int64(ci.MultiThreadWriteBufferSize), f) } + if src.Size() < 0 { - return nil, errors.New("multi-thread copy: can't copy unknown sized file") + return nil, fmt.Errorf("multi-thread copy: can't copy unknown sized file") } if src.Size() == 0 { - return nil, errors.New("multi-thread copy: can't copy zero sized file") + return nil, fmt.Errorf("multi-thread copy: can't copy zero sized file") } g, gCtx := errgroup.WithContext(ctx) - mc := &multiThreadCopyState{ - ctx: gCtx, - size: src.Size(), - src: src, - streams: streams, + chunkSize, chunkWriter, err := openChunkWriter(ctx, remote, src) + + if chunkSize > src.Size() { + fs.Debugf(src, "multi-thread copy: chunk size %v was bigger than source file size %v", fs.SizeSuffix(chunkSize), fs.SizeSuffix(src.Size())) + chunkSize = src.Size() + } + + numChunks := calculateNumChunks(src.Size(), chunkSize) + if streams > numChunks { + fs.Debugf(src, "multi-thread copy: number of streams '%d' was bigger than number of chunks '%d'", streams, numChunks) + streams = numChunks + } + + mc := &multiThreadCopyState{ + ctx: gCtx, + size: src.Size(), + src: src, + partSize: chunkSize, + streams: streams, + numChunks: numChunks, + } + + if err != nil { + return nil, fmt.Errorf("multipart copy: failed to open chunk writer: %w", err) } - mc.calculateChunks() // Make accounting mc.acc = tr.Account(ctx, nil) - // create write file handle - mc.wc, err = openWriterAt(gCtx, remote, mc.size) - if err != nil { - return nil, fmt.Errorf("multipart copy: failed to open destination: %w", err) - } - - fs.Debugf(src, "Starting multi-thread copy with %d parts of size %v", mc.streams, fs.SizeSuffix(mc.partSize)) - for stream := 0; stream < mc.streams; stream++ { - stream := stream + fs.Debugf(src, "Starting multi-thread copy with %d parts of size %v with %v parallel streams", mc.numChunks, fs.SizeSuffix(mc.partSize), mc.streams) + sem := semaphore.NewWeighted(int64(mc.streams)) + for chunk := 0; chunk < mc.numChunks; chunk++ { + fs.Debugf(src, "Acquiring semaphore...") + if err := sem.Acquire(ctx, 1); err != nil { + fs.Errorf(src, "Failed to acquire semaphore: %v", err) + break + } + currChunk := chunk g.Go(func() (err error) { - return mc.copyStream(gCtx, stream) + defer sem.Release(1) + return mc.copyStream(gCtx, currChunk, chunkWriter) }) } + err = g.Wait() - closeErr := mc.wc.Close() + closeErr := chunkWriter.Close() if err != nil { return nil, err } @@ -232,13 +211,94 @@ func multiThreadCopy(ctx context.Context, f fs.Fs, remote string, src fs.Object, return nil, fmt.Errorf("multi-thread copy: failed to find object after copy: %w", err) } - err = obj.SetModTime(ctx, src.ModTime(ctx)) - switch err { - case nil, fs.ErrorCantSetModTime, fs.ErrorCantSetModTimeWithoutDelete: - default: - return nil, fmt.Errorf("multi-thread copy: failed to set modification time: %w", err) + if f.Features().PartialUploads { + err = obj.SetModTime(ctx, src.ModTime(ctx)) + switch err { + case nil, fs.ErrorCantSetModTime, fs.ErrorCantSetModTimeWithoutDelete: + default: + return nil, fmt.Errorf("multi-thread copy: failed to set modification time: %w", err) + } } - fs.Debugf(src, "Finished multi-thread copy with %d parts of size %v", mc.streams, fs.SizeSuffix(mc.partSize)) + fs.Debugf(src, "Finished multi-thread copy with %d parts of size %v", mc.numChunks, fs.SizeSuffix(mc.partSize)) return obj, nil } + +type writerAtChunkWriter struct { + ctx context.Context + remote string + size int64 + writerAt fs.WriterAtCloser + chunkSize int64 + chunks int + writeBufferSize int64 + f fs.Fs +} + +func (w writerAtChunkWriter) WriteChunk(chunkNumber int, reader io.ReadSeeker) (int64, error) { + fs.Debugf(w.remote, "writing chunk %v", chunkNumber) + + bytesToWrite := w.chunkSize + if chunkNumber == (w.chunks-1) && w.size%w.chunkSize != 0 { + bytesToWrite = w.size % w.chunkSize + } + + var writer io.Writer = newOffsetWriter(w.writerAt, int64(chunkNumber)*w.chunkSize) + if w.writeBufferSize > 0 { + writer = bufio.NewWriterSize(writer, int(w.writeBufferSize)) + } + n, err := io.Copy(writer, reader) + if err != nil { + return -1, err + } + if n != bytesToWrite { + return -1, fmt.Errorf("expected to write %v bytes for chunk %v, but wrote %v bytes", bytesToWrite, chunkNumber, n) + } + // if we were buffering, flush do disk + switch w := writer.(type) { + case *bufio.Writer: + er2 := w.Flush() + if er2 != nil { + return -1, fmt.Errorf("multipart copy: flush failed: %w", err) + } + } + return n, nil +} + +func (w writerAtChunkWriter) Close() error { + return w.writerAt.Close() +} + +func (w writerAtChunkWriter) Abort() error { + obj, err := w.f.NewObject(w.ctx, w.remote) + if err != nil { + return fmt.Errorf("multi-thread copy: failed to find temp file when aborting chunk writer: %w", err) + } + return obj.Remove(w.ctx) +} + +func openChunkWriterFromOpenWriterAt(openWriterAt func(ctx context.Context, remote string, size int64) (fs.WriterAtCloser, error), chunkSize int64, writeBufferSize int64, f fs.Fs) func(ctx context.Context, remote string, src fs.ObjectInfo, options ...fs.OpenOption) (chunkSizeResult int64, writer fs.ChunkWriter, err error) { + return func(ctx context.Context, remote string, src fs.ObjectInfo, options ...fs.OpenOption) (chunkSizeResult int64, writer fs.ChunkWriter, err error) { + writerAt, err := openWriterAt(ctx, remote, src.Size()) + if err != nil { + return -1, nil, err + } + + if writeBufferSize > 0 { + fs.Debugf(src.Remote(), "multi-thread copy: write buffer set to %v", writeBufferSize) + } + + chunkWriter := &writerAtChunkWriter{ + ctx: ctx, + remote: remote, + size: src.Size(), + chunkSize: chunkSize, + chunks: calculateNumChunks(src.Size(), chunkSize), + writerAt: writerAt, + writeBufferSize: writeBufferSize, + f: f, + } + + return chunkSize, chunkWriter, nil + } +} diff --git a/fs/operations/multithread_test.go b/fs/operations/multithread_test.go index aef6f75af..cd5f430f8 100644 --- a/fs/operations/multithread_test.go +++ b/fs/operations/multithread_test.go @@ -86,27 +86,24 @@ func TestDoMultiThreadCopy(t *testing.T) { assert.True(t, doMultiThreadCopy(ctx, f, src)) } -func TestMultithreadCalculateChunks(t *testing.T) { +func TestMultithreadCalculateNumChunks(t *testing.T) { for _, test := range []struct { - size int64 - streams int - wantPartSize int64 - wantStreams int + size int64 + chunkSize int64 + wantNumChunks int }{ - {size: 1, streams: 10, wantPartSize: multithreadChunkSize, wantStreams: 1}, - {size: 1 << 20, streams: 1, wantPartSize: 1 << 20, wantStreams: 1}, - {size: 1 << 20, streams: 2, wantPartSize: 1 << 19, wantStreams: 2}, - {size: (1 << 20) + 1, streams: 2, wantPartSize: (1 << 19) + multithreadChunkSize, wantStreams: 2}, - {size: (1 << 20) - 1, streams: 2, wantPartSize: (1 << 19), wantStreams: 2}, + {size: 1, chunkSize: multithreadChunkSize, wantNumChunks: 1}, + {size: 1 << 20, chunkSize: 1, wantNumChunks: 1 << 20}, + {size: 1 << 20, chunkSize: 2, wantNumChunks: 1 << 19}, + {size: (1 << 20) + 1, chunkSize: 2, wantNumChunks: (1 << 19) + 1}, + {size: (1 << 20) - 1, chunkSize: 2, wantNumChunks: 1 << 19}, } { t.Run(fmt.Sprintf("%+v", test), func(t *testing.T) { mc := &multiThreadCopyState{ - size: test.size, - streams: test.streams, + size: test.size, } - mc.calculateChunks() - assert.Equal(t, test.wantPartSize, mc.partSize) - assert.Equal(t, test.wantStreams, mc.streams) + mc.numChunks = calculateNumChunks(test.size, test.chunkSize) + assert.Equal(t, test.wantNumChunks, mc.numChunks) }) } }