diff --git a/lib/readers/pattern_reader.go b/lib/readers/pattern_reader.go index a4e534baa..a480dc185 100644 --- a/lib/readers/pattern_reader.go +++ b/lib/readers/pattern_reader.go @@ -1,29 +1,59 @@ package readers -import "io" +import ( + "io" + + "github.com/pkg/errors" +) + +// This is the smallest prime less than 256 +// +// Using a prime here means we are less likely to hit repeating patterns +const patternReaderModulo = 251 // NewPatternReader creates a reader, that returns a deterministic byte pattern. // After length bytes are read -func NewPatternReader(length int64) io.Reader { +func NewPatternReader(length int64) io.ReadSeeker { return &patternReader{ length: length, } } type patternReader struct { + offset int64 length int64 c byte } func (r *patternReader) Read(p []byte) (n int, err error) { for i := range p { - if r.length <= 0 { + if r.offset >= r.length { return n, io.EOF } p[i] = r.c - r.c = (r.c + 1) % 253 - r.length-- + r.c = (r.c + 1) % patternReaderModulo + r.offset++ n++ } return } + +// Seek implements the io.Seeker interface. +func (r *patternReader) Seek(offset int64, whence int) (abs int64, err error) { + switch whence { + case io.SeekStart: + abs = offset + case io.SeekCurrent: + abs = r.offset + offset + case io.SeekEnd: + abs = r.length + offset + default: + return 0, errors.New("patternReader: invalid whence") + } + if abs < 0 { + return 0, errors.New("patternReader: negative position") + } + r.offset = abs + r.c = byte(abs % patternReaderModulo) + return abs, nil +} diff --git a/lib/readers/pattern_reader_test.go b/lib/readers/pattern_reader_test.go index ed10887cc..10c21d064 100644 --- a/lib/readers/pattern_reader_test.go +++ b/lib/readers/pattern_reader_test.go @@ -28,3 +28,61 @@ func TestPatternReader(t *testing.T) { require.Equal(t, io.EOF, err) require.Equal(t, 0, n) } + +func TestPatternReaderSeek(t *testing.T) { + r := NewPatternReader(1024) + b, err := ioutil.ReadAll(r) + require.NoError(t, err) + + for i := range b { + assert.Equal(t, byte(i%251), b[i]) + } + + n, err := r.Seek(1, io.SeekStart) + require.NoError(t, err) + assert.Equal(t, int64(1), n) + + // pos 1 + + b2 := make([]byte, 10) + nn, err := r.Read(b2) + require.NoError(t, err) + assert.Equal(t, 10, nn) + assert.Equal(t, b[1:11], b2) + + // pos 11 + + n, err = r.Seek(9, io.SeekCurrent) + require.NoError(t, err) + assert.Equal(t, int64(20), n) + + // pos 20 + + nn, err = r.Read(b2) + require.NoError(t, err) + assert.Equal(t, 10, nn) + assert.Equal(t, b[20:30], b2) + + n, err = r.Seek(-24, io.SeekEnd) + require.NoError(t, err) + assert.Equal(t, int64(1000), n) + + // pos 1000 + + nn, err = r.Read(b2) + require.NoError(t, err) + assert.Equal(t, 10, nn) + assert.Equal(t, b[1000:1010], b2) + + // Now test errors + + n, err = r.Seek(1, 400) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid whence") + assert.Equal(t, int64(0), n) + + n, err = r.Seek(-1, io.SeekStart) + require.Error(t, err) + assert.Contains(t, err.Error(), "negative position") + assert.Equal(t, int64(0), n) +}