diff --git a/listener.go b/listener.go index 16522e7..f1ee6cf 100644 --- a/listener.go +++ b/listener.go @@ -6,6 +6,7 @@ import ( "net" "os" "os/signal" + "runtime" "sync" "syscall" "time" @@ -302,7 +303,6 @@ func (l *Listener) Close() error { if conf.Debug.Verbose { log.Println("Listener is waiting for all tasks to finish...") } - // Stop consuming go func() { l.interrupts <- syscall.SIGINT @@ -313,7 +313,35 @@ func (l *Listener) Close() error { if conf.Debug.Verbose { log.Println("Listener is shutting down.") } + runtime.Gosched() + return nil +} +// CloseSync closes the Listener in a syncronous manner. +func (l *Listener) CloseSync() error { + if conf.Debug.Verbose { + log.Println("Listener is waiting for all tasks to finish...") + } + var err error + // Stop consuming + select { + case l.interrupts <- syscall.SIGINT: + break + default: + if conf.Debug.Verbose { + log.Println("Already closing listener.") + } + runtime.Gosched() + return err + } + l.wg.Wait() + for l.IsConsuming() { + runtime.Gosched() + } + if conf.Debug.Verbose { + log.Println("Listener is shutting down.") + } + runtime.Gosched() return nil } diff --git a/listener_test.go b/listener_test.go index 1c66b13..f3b9b4e 100644 --- a/listener_test.go +++ b/listener_test.go @@ -35,6 +35,25 @@ func TestListenerStop(t *testing.T) { listener.Close() } +func TestListenerSyncStop(t *testing.T) { + listener, _ := new(Listener).Init() + listener.NewEndpoint(testEndpoint, "stream-name") + + Convey("Given a running listener", t, func() { + go listener.Listen(func(msg []byte, wg *sync.WaitGroup) { + wg.Done() + }) + + Convey("It should stop listening if sent an interrupt signal", func() { + err := listener.CloseSync() + So(err, ShouldBeNil) + So(listener.IsListening(), ShouldEqual, false) + }) + }) + + listener.Close() +} + func TestListenerError(t *testing.T) { listener, _ := new(Listener).Init() listener.NewEndpoint(testEndpoint, "stream-name") diff --git a/producer.go b/producer.go index d8d19ba..de55677 100644 --- a/producer.go +++ b/producer.go @@ -5,6 +5,7 @@ import ( "log" "os" "os/signal" + "runtime" "strconv" "sync" "syscall" @@ -370,7 +371,35 @@ func (p *Producer) Close() error { if conf.Debug.Verbose { log.Println("Producer is shutting down.") } + runtime.Gosched() + return nil +} +// CloseSync closes the Producer in a syncronous manner. +func (p *Producer) CloseSync() error { + if conf.Debug.Verbose { + log.Println("Listener is waiting for all tasks to finish...") + } + var err error + // Stop consuming + select { + case p.interrupts <- syscall.SIGINT: + break + default: + if conf.Debug.Verbose { + log.Println("Already closing listener.") + } + runtime.Gosched() + return err + } + p.wg.Wait() + for p.IsProducing() { + runtime.Gosched() + } + if conf.Debug.Verbose { + log.Println("Listener is shutting down.") + } + runtime.Gosched() return nil } diff --git a/producer_test.go b/producer_test.go index 86ae8b4..4c545ec 100644 --- a/producer_test.go +++ b/producer_test.go @@ -36,6 +36,24 @@ func TestProducerStop(t *testing.T) { producer.Close() } +func TestSyncStop(t *testing.T) { + producer, _ := new(Producer).Init() + producer.NewEndpoint(testEndpoint, "stream-name") + + Convey("Given a running producer", t, func() { + go producer.produce() + runtime.Gosched() + Convey("It should stop producing if sent an interrupt signal", func() { + err := producer.CloseSync() + So(err, ShouldBeNil) + // Wait for it to stop + So(producer.IsProducing(), ShouldEqual, false) + }) + }) + + producer.Close() +} + func TestProducerError(t *testing.T) { producer, _ := new(Producer).Init() producer.NewEndpoint(testEndpoint, "stream-name")