package multireader import ( "bytes" "io" "math/rand" "sync" "testing" "time" ) func TestMultiRead(t *testing.T) { t.Run("positive", func(t *testing.T) { _ = NewMultiReader(64) }) t.Run("zero", func(t *testing.T) { _ = NewMultiReader(0) }) t.Run("negative", func(t *testing.T) { _ = NewMultiReader(-64) }) } func TestReadWrite(t *testing.T) { var ( src []byte dst1 []byte dst2 []byte dst3 []byte wg *sync.WaitGroup ) src = make([]byte, 20480) dst1 = make([]byte, 20480) dst2 = make([]byte, 20480) dst3 = make([]byte, 20480) n, err := rand.Read(src) if n < len(src) || err != nil { t.Fatalf("cannot initiate test data: %d; %v", n, err) } m := NewMultiReader(512) r1 := m.NewReader() r2 := m.NewReader() r3 := m.NewReader() wg = &sync.WaitGroup{} writeFunc := func(t *testing.T, m *MultiReader, data []byte, wg *sync.WaitGroup) { var ( n int dl int r int loopFlag bool ) defer wg.Done() for n < len(data) { dl = rand.Intn(512) if n+512+dl > len(data) { r, _ = m.Write(data[n:]) } else { r, _ = m.Write(data[n : n+512+dl]) } n += r if rand.Intn(5) == 0 { loopFlag = false for !loopFlag { if rand.Intn(100) == 0 { loopFlag = true } } } } _ = m.Close() } wg.Add(1) go writeFunc(t, m, src, wg) readFunc := func(t *testing.T, r *Reader, data []byte, wg *sync.WaitGroup, split bool) { var ( offset int n int err error ) defer wg.Done() offset = 0 for { if offset+1024 < len(data) { if split { _, err = r.WaitAvailable() if err != nil { t.Logf("reader got error on WaitAvailable: %v", err) break } n, err = r.ReadAhead(data[offset : offset+1024]) if err != nil { t.Logf("reader got error on ReadAhead: %v", err) break } n -= 5 _, err = r.Seek(int64(n), io.SeekCurrent) if err != nil { t.Logf("reader got error on Seek: %v", err) break } } else { n, err = r.Read(data[offset : offset+1024]) if err != nil { t.Logf("reader got error: %v", err) break } } } else { n, err = r.Read(data[offset:]) if err != nil { t.Logf("reader got error: %v", err) break } } offset += n } } wg.Add(1) go readFunc(t, r1, dst1, wg, false) wg.Add(1) go readFunc(t, r2, dst2, wg, false) wg.Add(1) go readFunc(t, r3, dst3, wg, true) wg.Wait() if !bytes.Equal(src, dst1) { t.Error("src and dst1 mismatch") } if !bytes.Equal(src, dst2) { t.Error("src and dst2 mismatch") } if !bytes.Equal(src, dst3) { t.Error("src and dst3 mismatch") } } func TestReaderClose(t *testing.T) { var ( m *MultiReader r1 *Reader r2 *Reader r3 *Reader src, dst, dst1, dst3 []byte err error found bool ctr int ) src = make([]byte, 16) dst = make([]byte, 32) dst1 = dst[:16] dst3 = dst[16:] m = NewMultiReader(0) r1 = m.NewReader() r2 = m.NewReader() r3 = m.NewReader() _, _ = rand.Read(src) _, err = m.Write(src) if err != nil { t.Fatalf("error setting up data: %v", err) } found = false for ctr = range m.pos { if m.pos[ctr].id == r2.id { found = true break } } if !found { t.Error("unexpected behavior, r2 should exist in reader position list") } _ = r2.Close() _, err = r2.Read(dst) if err == nil && err != ErrReaderClosed { t.Error("unexpected behavior, r2 should be closed") } found = false for ctr = range m.pos { if m.pos[ctr].id == r2.id { found = true break } } if found { t.Error("unexpected behavior, r2 should be removed from reader position list") } _, err = r1.Read(dst1) if err != nil { t.Errorf("failing to read from r1") } if !bytes.Equal(dst1, src) { t.Error("invalid result reading from r1") } _ = m.Close() _, err = r3.Read(dst3) if err != nil { t.Errorf("unexpected behavior, cannot read form r3: %v", err) } if !bytes.Equal(src, dst3) { t.Error("invalid result reading from r3") } _, err = r3.WaitAvailable() if err == nil { t.Error("unexpected behavior; WaitAvailable when readpos == lastWritePos should fail") } _, err = r3.ReadAhead(dst3) if err == nil { t.Error("unexpected behavior; ReadAhead when readpos == lastWritePos should fail") } } func TestNotify(t *testing.T) { src := make([]byte, 16) _, _ = rand.Read(src) wg := &sync.WaitGroup{} cr := make(chan int, 3) m := NewMultiReader(0) r := make([]*Reader, 0) r = append(r, m.NewReader(), m.NewReader(), m.NewReader()) for i := range r { r[i].NotifyFunc(func(ix int) { wg.Add(1) defer wg.Done() cr <- ix }) } _, _ = m.Write(src) time.Sleep(time.Millisecond) wg.Wait() nctr := 0 stop := false for !stop { select { case rd := <-cr: nctr++ if rd != 16 { t.Error("unexpected notify value") } default: stop = true } } if nctr != 3 { t.Error("unexpected number of notify event") } _, _ = r[0].Seek(16, io.SeekCurrent) _, _ = r[1].Seek(8, io.SeekCurrent) _, _ = r[2].Seek(16, io.SeekCurrent) _, _ = m.Write(src[:8]) time.Sleep(time.Millisecond) wg.Wait() stop = false nctr = 0 for !stop { select { case rd := <-cr: nctr++ if rd != 8 { t.Error("unexpected notify value") } default: stop = true } } if nctr != 2 { t.Error("unexpected number of notify event") } }