diff --git a/lib/readers/context.go b/lib/readers/context.go new file mode 100644 index 000000000..6af6efbaf --- /dev/null +++ b/lib/readers/context.go @@ -0,0 +1,28 @@ +package readers + +import ( + "context" + "io" +) + +// NewContextReader creates a reader, that returns any errors that ctx gives +func NewContextReader(ctx context.Context, r io.Reader) io.Reader { + return &contextReader{ + ctx: ctx, + r: r, + } +} + +type contextReader struct { + ctx context.Context + r io.Reader +} + +// Read bytes as per io.Reader interface +func (cr *contextReader) Read(p []byte) (n int, err error) { + err = cr.ctx.Err() + if err != nil { + return 0, err + } + return cr.r.Read(p) +} diff --git a/lib/readers/context_test.go b/lib/readers/context_test.go new file mode 100644 index 000000000..3bd70e4d5 --- /dev/null +++ b/lib/readers/context_test.go @@ -0,0 +1,28 @@ +package readers + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestContextReader(t *testing.T) { + r := NewPatternReader(100) + ctx, cancel := context.WithCancel(context.Background()) + cr := NewContextReader(ctx, r) + + var buf = make([]byte, 3) + + n, err := cr.Read(buf) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, []byte{0, 1, 2}, buf) + + cancel() + + n, err = cr.Read(buf) + assert.Equal(t, context.Canceled, err) + assert.Equal(t, 0, n) +}