From 08252778baac4d2acd5cc54d439c63a84e1437e0 Mon Sep 17 00:00:00 2001 From: Suyono Date: Fri, 5 Feb 2021 18:30:21 +0700 Subject: [PATCH] fix(WaitAvailable): check reader close state --- multiread_test.go | 136 ++++++++++++++++++++++++++++++++++++++++++---- multireader.go | 4 ++ 2 files changed, 129 insertions(+), 11 deletions(-) diff --git a/multiread_test.go b/multiread_test.go index 9f0113f..ac3f53f 100644 --- a/multiread_test.go +++ b/multiread_test.go @@ -2,6 +2,7 @@ package multireader import ( "bytes" + "io" "math/rand" "sync" "testing" @@ -46,9 +47,10 @@ func TestReadWrite(t *testing.T) { writeFunc := func(t *testing.T, m *MultiReader, data []byte, wg *sync.WaitGroup) { var ( - n int - dl int - r int + n int + dl int + r int + loopFlag bool ) defer wg.Done() @@ -60,13 +62,21 @@ func TestReadWrite(t *testing.T) { r, _ = m.Write(data[n : n+512+dl]) } n += r + if rand.Intn(5) == 0 { + loopFlag = false + for !loopFlag { + if rand.Intn(100) == 0 { + loopFlag = true + } + } + } } _ = m.Close() } wg.Add(1) go writeFunc(t, m, src, wg) - readFunc := func(t *testing.T, r *Reader, data []byte, wg *sync.WaitGroup) { + readFunc := func(t *testing.T, r *Reader, data []byte, wg *sync.WaitGroup, split bool) { var ( offset int n int @@ -77,10 +87,30 @@ func TestReadWrite(t *testing.T) { offset = 0 for { if offset+1024 < len(data) { - n, err = r.Read(data[offset : offset+1024]) - if err != nil { - t.Logf("reader got error: %v", err) - break + if split { + _, err = r.WaitAvailable() + if err != nil { + t.Logf("reader got error on WaitAvailable: %v", err) + break + } + n, err = r.ReadAhead(data[offset : offset+1024]) + if err != nil { + t.Logf("reader got error on ReadAhead: %v", err) + break + } + + n -= 5 + _, err = r.Seek(int64(n), io.SeekCurrent) + if err != nil { + t.Logf("reader got error on Seek: %v", err) + break + } + } else { + n, err = r.Read(data[offset : offset+1024]) + if err != nil { + t.Logf("reader got error: %v", err) + break + } } } else { n, err = r.Read(data[offset:]) @@ -95,11 +125,11 @@ func TestReadWrite(t *testing.T) { } wg.Add(1) - go readFunc(t, r1, dst1, wg) + go readFunc(t, r1, dst1, wg, false) wg.Add(1) - go readFunc(t, r2, dst2, wg) + go readFunc(t, r2, dst2, wg, false) wg.Add(1) - go readFunc(t, r3, dst3, wg) + go readFunc(t, r3, dst3, wg, true) wg.Wait() if !bytes.Equal(src, dst1) { @@ -112,3 +142,87 @@ func TestReadWrite(t *testing.T) { t.Error("src and dst3 mismatch") } } + +func TestReaderClose(t *testing.T) { + var ( + m *MultiReader + r1 *Reader + r2 *Reader + r3 *Reader + src, dst, dst1, dst3 []byte + err error + found bool + ctr int + ) + + src = make([]byte, 16) + dst = make([]byte, 32) + dst1 = dst[:16] + dst3 = dst[16:] + m = NewMultiReader(0) + r1 = m.NewReader() + r2 = m.NewReader() + r3 = m.NewReader() + + _, _ = rand.Read(src) + _, err = m.Write(src) + if err != nil { + t.Fatalf("error setting up data: %v", err) + } + + found = false + for ctr = range m.pos { + if m.pos[ctr].id == r2.id { + found = true + break + } + } + if !found { + t.Error("unexpected behavior, r2 should exist in reader position list") + } + + _ = r2.Close() + _, err = r2.Read(dst) + if err == nil && err != ErrReaderClosed { + t.Error("unexpected behavior, r2 should be closed") + } + + found = false + for ctr = range m.pos { + if m.pos[ctr].id == r2.id { + found = true + break + } + } + if found { + t.Error("unexpected behavior, r2 should be removed from reader position list") + } + + _, err = r1.Read(dst1) + if err != nil { + t.Errorf("failing to read from r1") + } + if !bytes.Equal(dst1, src) { + t.Error("invalid result reading from r1") + } + + _ = m.Close() + _, err = r3.Read(dst3) + if err != nil { + t.Errorf("unexpected behavior, cannot read form r3: %v", err) + } + if !bytes.Equal(src, dst3) { + t.Error("invalid result reading from r3") + } + + _, err = r3.WaitAvailable() + if err == nil { + t.Error("unexpected behavior; WaitAvailable when readpos == lastWritePos should fail") + } + + _, err = r3.ReadAhead(dst3) + if err == nil { + t.Error("unexpected behavior; ReadAhead when readpos == lastWritePos should fail") + } + +} diff --git a/multireader.go b/multireader.go index c343ca5..c0c907a 100644 --- a/multireader.go +++ b/multireader.go @@ -210,6 +210,10 @@ func (r *Reader) ReadAhead(data []byte) (int, error) { } func (r *Reader) waitAvailable() (int, error) { + if r.closed { + return 0, ErrReaderClosed + } + for r.readPos >= r.multiReader.lastWritePos { if r.closed { return 0, ErrReaderClosed