From 50a0c3482df5227533eba342b8be8d82302bd294 Mon Sep 17 00:00:00 2001 From: Nick Craig-Wood Date: Tue, 3 May 2022 16:37:55 +0100 Subject: [PATCH] lib/readers: add FakeSeeker to adapt io.Reader to io.ReadSeeker #5422 --- lib/readers/fakeseeker.go | 72 ++++++++++++++++++++++++++++ lib/readers/fakeseeker_test.go | 87 ++++++++++++++++++++++++++++++++++ 2 files changed, 159 insertions(+) create mode 100644 lib/readers/fakeseeker.go create mode 100644 lib/readers/fakeseeker_test.go diff --git a/lib/readers/fakeseeker.go b/lib/readers/fakeseeker.go new file mode 100644 index 000000000..844fc0422 --- /dev/null +++ b/lib/readers/fakeseeker.go @@ -0,0 +1,72 @@ +package readers + +import ( + "errors" + "fmt" + "io" +) + +// FakeSeeker adapts an io.Seeker into an io.ReadSeeker +type FakeSeeker struct { + in io.Reader + readErr error + length int64 + offset int64 + read bool +} + +// NewFakeSeeker creates a fake io.ReadSeeker from an io.Reader +// +// This can be seeked before reading to discover the length passed in. +func NewFakeSeeker(in io.Reader, length int64) io.ReadSeeker { + if rs, ok := in.(io.ReadSeeker); ok { + return rs + } + return &FakeSeeker{ + in: in, + length: length, + } +} + +// Seek the stream - possible only before reading +func (r *FakeSeeker) Seek(offset int64, whence int) (abs int64, err error) { + if r.readErr != nil { + return 0, r.readErr + } + if r.read { + return 0, fmt.Errorf("FakeSeeker: can't Seek(%d, %d) after reading", offset, whence) + } + switch whence { + case io.SeekStart: + abs = offset + case io.SeekCurrent: + abs = r.offset + offset + case io.SeekEnd: + abs = r.length + offset + default: + return 0, errors.New("FakeSeeker: invalid whence") + } + if abs < 0 { + return 0, errors.New("FakeSeeker: negative position") + } + r.offset = abs + return abs, nil +} + +// Read data from the stream. Will give an error if seeked. +func (r *FakeSeeker) Read(p []byte) (n int, err error) { + if r.readErr != nil { + return 0, r.readErr + } + if !r.read && r.offset != 0 { + return 0, errors.New("FakeSeeker: not at start: can't read") + } + n, err = r.in.Read(p) + if n != 0 { + r.read = true + } + if err != nil { + r.readErr = err + } + return n, err +} diff --git a/lib/readers/fakeseeker_test.go b/lib/readers/fakeseeker_test.go new file mode 100644 index 000000000..0702b7cbb --- /dev/null +++ b/lib/readers/fakeseeker_test.go @@ -0,0 +1,87 @@ +package readers + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Check interface +var _ io.ReadSeeker = &FakeSeeker{} + +func TestFakeSeeker(t *testing.T) { + // Test that passing in an io.ReadSeeker just passes it through + bufReader := bytes.NewReader([]byte{1}) + r := NewFakeSeeker(bufReader, 5) + assert.Equal(t, r, bufReader) + + in := bytes.NewBufferString("hello") + buf := make([]byte, 16) + r = NewFakeSeeker(in, 5) + assert.NotEqual(t, r, in) + + // check the seek offset is as passed in + checkPos := func(pos int64) { + abs, err := r.Seek(0, io.SeekCurrent) + require.NoError(t, err) + assert.Equal(t, pos, abs) + } + + // Test some seeking + checkPos(0) + + abs, err := r.Seek(2, io.SeekStart) + require.NoError(t, err) + assert.Equal(t, int64(2), abs) + checkPos(2) + + abs, err = r.Seek(-1, io.SeekEnd) + require.NoError(t, err) + assert.Equal(t, int64(4), abs) + checkPos(4) + + // Check can't read if not at start + _, err = r.Read(buf) + require.ErrorContains(t, err, "not at start") + + // Seek back to start + abs, err = r.Seek(-4, io.SeekCurrent) + require.NoError(t, err) + assert.Equal(t, int64(0), abs) + checkPos(0) + + _, err = r.Seek(42, 17) + require.ErrorContains(t, err, "invalid whence") + + _, err = r.Seek(-1, io.SeekStart) + require.ErrorContains(t, err, "negative position") + + // Test reading now seeked back to the start + n, err := r.Read(buf) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, []byte("hello"), buf[:5]) + + // Seeking should give an error now + _, err = r.Seek(-1, io.SeekEnd) + require.ErrorContains(t, err, "after reading") +} + +func TestFakeSeekerError(t *testing.T) { + in := bytes.NewBufferString("hello") + r := NewFakeSeeker(in, 5) + assert.NotEqual(t, r, in) + + buf, err := io.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, []byte("hello"), buf) + + _, err = r.Read(buf) + assert.Equal(t, io.EOF, err) + + _, err = r.Seek(0, io.SeekStart) + assert.Equal(t, io.EOF, err) +}