diff --git a/multiread_test.go b/multiread_test.go index ac3f53f..05ca206 100644 --- a/multiread_test.go +++ b/multiread_test.go @@ -6,6 +6,7 @@ import ( "math/rand" "sync" "testing" + "time" ) func TestMultiRead(t *testing.T) { @@ -226,3 +227,65 @@ func TestReaderClose(t *testing.T) { } } + +func TestNotify(t *testing.T) { + src := make([]byte, 16) + _, _ = rand.Read(src) + + wg := &sync.WaitGroup{} + cr := make(chan int, 3) + m := NewMultiReader(0) + r := make([]*Reader, 0) + r = append(r, m.NewReader(), m.NewReader(), m.NewReader()) + for i := range r { + r[i].NotifyFunc(func(ix int) { + wg.Add(1) + defer wg.Done() + cr <- ix + }) + } + _, _ = m.Write(src) + time.Sleep(time.Millisecond) + wg.Wait() + + nctr := 0 + stop := false + for !stop { + select { + case rd := <-cr: + nctr++ + if rd != 16 { + t.Error("unexpected notify value") + } + default: + stop = true + } + } + if nctr != 3 { + t.Error("unexpected number of notify event") + } + + _, _ = r[0].Seek(16, io.SeekCurrent) + _, _ = r[1].Seek(8, io.SeekCurrent) + _, _ = r[2].Seek(16, io.SeekCurrent) + + _, _ = m.Write(src[:8]) + time.Sleep(time.Millisecond) + wg.Wait() + stop = false + nctr = 0 + for !stop { + select { + case rd := <-cr: + nctr++ + if rd != 8 { + t.Error("unexpected notify value") + } + default: + stop = true + } + } + if nctr != 2 { + t.Error("unexpected number of notify event") + } +} diff --git a/multireader.go b/multireader.go index 49ce804..cab23b1 100644 --- a/multireader.go +++ b/multireader.go @@ -1,4 +1,3 @@ -//Package multireader provides capability for single writer to multiple reader package multireader import ( @@ -21,6 +20,7 @@ type MultiReader struct { buffer []byte readers []*Reader pos []*Reader + notif []*Reader lastWritePos int closed bool } @@ -29,6 +29,7 @@ type Reader struct { multiReader *MultiReader mtx *sync.Mutex cond *sync.Cond + notify func(int) id int readPos int closed bool @@ -48,24 +49,39 @@ func NewMultiReader(len int) *MultiReader { buffer: buff, readers: make([]*Reader, 0), pos: make([]*Reader, 0), + notif: make([]*Reader, 0), lastWritePos: 0, closed: false, } } +func (m *MultiReader) notify(dataLen int) { + m.cond.Broadcast() + for _, reader := range m.notif { + if reader.notify != nil && reader.readPos >= m.lastWritePos-dataLen && !reader.closed { + go func() { + reader.notify(dataLen) + }() + } + } +} + // Write implements Writer interface. The function will not return error. Error return only satisfy the interface. func (m *MultiReader) Write(data []byte) (int, error) { m.mtx.Lock() defer m.mtx.Unlock() - defer m.cond.Broadcast() var ( tmpBuff []byte + dataLen int ) if m.closed { return 0, ErrMultiReaderClosed } + defer func(ip *int) { + m.notify(*ip) + }(&dataLen) if m.lastWritePos+len(data) > len(m.buffer) { tmpBuff = make([]byte, m.lastWritePos+len(data)-len(m.buffer)) @@ -73,6 +89,8 @@ func (m *MultiReader) Write(data []byte) (int, error) { } copy(m.buffer[m.lastWritePos:], data) m.lastWritePos += len(data) + dataLen = len(data) + //_ = dataLen //prevents lint complaining about ineffectual assignment return len(data), nil } @@ -146,12 +164,44 @@ func (m *MultiReader) NewReader() *Reader { return reader } +// NotifyFunc sets up a callback function notify. The notify function will be called when there is a Write operation. +// Notify function passed to the NotifyFunc will executed in a separate go routine for each notify function. +// It's the caller responsibility to setup any synchronization to prevent data race. +func (r *Reader) NotifyFunc(notify func(int)) { + r.mtx.Lock() + defer r.mtx.Unlock() + + r.notify = notify + + var x int + switch len(r.multiReader.notif) { + case 0: + r.multiReader.notif = append(r.multiReader.notif, r) + case 1: + if r.multiReader.notif[0].id != r.id { + goto notifyFuncAppend2 + } + default: + goto notifyFuncAppend1 + } + return + +notifyFuncAppend1: + x = sort.Search(len(r.multiReader.notif), func(i int) bool { return r.multiReader.notif[i].id >= r.id }) + if x < len(r.multiReader.notif) && r.id == x { + return + } + +notifyFuncAppend2: + r.multiReader.notif = append(r.multiReader.notif, r) + sort.Slice(r.multiReader.notif, func(i, j int) bool { return r.multiReader.notif[i].id < r.multiReader.notif[j].id }) +} + // Read implements Reader interface. If the reader reached the same position with the writer. It will wait. // Calling Read() is the same with calling WaitAvailable(), ReadAhead(), and Seek() func (r *Reader) Read(data []byte) (int, error) { r.mtx.Lock() defer r.mtx.Unlock() - defer r.cond.Broadcast() var ( n int @@ -225,9 +275,6 @@ func (r *Reader) waitAvailable() (int, error) { } for r.readPos >= r.multiReader.lastWritePos { - if r.closed { - return 0, ErrReaderClosed - } if r.multiReader.closed { return 0, ErrMultiReaderClosed }