diff --git a/src/trafficcontroller/internal/proxy/websocket_handler_test.go b/src/trafficcontroller/internal/proxy/websocket_handler_test.go index 71979ba85..739cffe51 100644 --- a/src/trafficcontroller/internal/proxy/websocket_handler_test.go +++ b/src/trafficcontroller/internal/proxy/websocket_handler_test.go @@ -4,6 +4,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "sync" "time" "github.com/gorilla/websocket" @@ -23,12 +24,14 @@ var _ = Describe("WebsocketHandler", func() { handlerDone chan struct{} ts *httptest.Server conn *websocket.Conn + wc *websocketClient ) BeforeEach(func() { input = make(chan []byte, 10) keepAlive = 200 * time.Millisecond count = metricemitter.NewCounter("egress", "") + wc = newWebsocketClient() }) JustBeforeEach(func() { @@ -46,21 +49,15 @@ var _ = Describe("WebsocketHandler", func() { u.Scheme = "ws" c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) Expect(err).NotTo(HaveOccurred()) + go wc.Start(c) + conn = c - DeferCleanup(func() { - conn.Close() - }) + // DeferCleanup(func() { + // conn.Close() + // }) }) AfterEach(func() { - select { - case _, ok := <-handlerDone: - if ok { - close(handlerDone) - } - default: - close(handlerDone) - } select { case _, ok := <-input: if ok { @@ -69,6 +66,7 @@ var _ = Describe("WebsocketHandler", func() { default: close(input) } + <-wc.Done }) It("forwards byte arrays from the input channel to the websocket client", func() { @@ -78,14 +76,21 @@ var _ = Describe("WebsocketHandler", func() { } }() + type websocketResp struct { + messageType int + message string + } + expectedResp := websocketResp{messageType: websocket.BinaryMessage, message: "testing"} for i := 0; i < 10; i++ { - msgType, msg, err := conn.ReadMessage() - Expect(err).NotTo(HaveOccurred()) - Expect(msgType).To(Equal(websocket.BinaryMessage)) - Expect(string(msg)).To(Equal("testing")) + Eventually(func() (websocketResp, error) { + msgType, msg, ok := wc.Read() + if !ok { + err, _ := wc.ReadError() + return websocketResp{}, err + } + return websocketResp{messageType: msgType, message: msg}, nil + }).Should(Equal(expectedResp)) } - - close(input) }) Context("when the input channel is closed", func() { @@ -94,23 +99,19 @@ var _ = Describe("WebsocketHandler", func() { }) It("stops", func() { - Eventually(handlerDone, 100*time.Millisecond).Should(BeClosed()) + Eventually(handlerDone, keepAlive/2).Should(BeClosed()) }) - // On 12/14/2023 we upgraded to a new gorilla/websocket that caused this - // test to fail inconsistently, receiving a different error than - // expected: `write tcp 127.0.0.1:42598->127.0.0.1:36689: write: broken - // pipe`. - // - // It("closes the websocket normally", func() { - // Eventually(func() error { - // _, _, err := conn.ReadMessage() - // return err - // }).Should(MatchError(&websocket.CloseError{ - // Code: websocket.CloseNormalClosure, - // Text: "", - // })) - // }) + It("closes the connection", func() { + Eventually(wc.Done, keepAlive/2).Should(BeClosed()) + Eventually(func() error { + err, _ := wc.ReadError() + return err + }).Should(MatchError(&websocket.CloseError{ + Code: websocket.CloseNormalClosure, + Text: "", + })) + }) }) It("does not accept http requests", func() { @@ -125,7 +126,7 @@ var _ = Describe("WebsocketHandler", func() { }) It("stops", func() { - Eventually(handlerDone, 100*time.Millisecond).Should(BeClosed()) + Eventually(handlerDone, keepAlive/2).Should(BeClosed()) }) }) @@ -135,13 +136,13 @@ var _ = Describe("WebsocketHandler", func() { }) It("stops", func() { - timeout := keepAlive + 100*time.Millisecond - Eventually(handlerDone, timeout).Should(BeClosed()) + Eventually(handlerDone).Should(BeClosed()) }) It("closes the connection with a ClosePolicyViolation", func() { + Eventually(wc.Done).Should(BeClosed()) Eventually(func() error { - _, _, err := conn.ReadMessage() + err, _ := wc.ReadError() return err }).Should(MatchError(&websocket.CloseError{ Code: websocket.ClosePolicyViolation, @@ -150,63 +151,15 @@ var _ = Describe("WebsocketHandler", func() { }) }) - // On 12/14/2023 noticed that this test was written such that the - // description didn't match the test, which only checked that the connection - // continued through the keep-alive. Should come back and fix later. - // - // Context("when the client responds to - // pings", func() { - // JustBeforeEach(func() { - // conn.SetPingHandler(func(message string) error { - // err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second)) - // if err == websocket.ErrCloseSent { - // return nil - // } else if _, ok := err.(net.Error); ok { - // return nil - // } - // return err - // }) - // }) - - // It("continues", func() { - // timeout := keepAlive + time.Second - // Consistently(done, timeout).ShouldNot(BeClosed()) - // }) - // }) - - // On 12/14/2023 noticed that this test was written such that the - // description didn't match the test, which only checked that the connection - // continued through the keep-alive. Should come back and fix later. - // - // Context("when the client sends old style keepalives", func() { - // var finish chan struct{} - - // JustBeforeEach(func() { - // finish = make(chan struct{}) - // go func() { - // for { - // _ = conn.WriteMessage(websocket.TextMessage, []byte("I'm alive!")) - // time.Sleep(100 * time.Millisecond) - // select { - // case <-input: - // close(finish) - // return - // default: - // } - // } - // }() - // }) - - // JustAfterEach(func() { - // close(input) - // <-finish - // }) - - // It("continues", func() { - // timeout := keepAlive + time.Second - // Consistently(done, timeout).ShouldNot(BeClosed()) - // }) - // }) + Context("when the client responds to pings", func() { + It("does not stop", func() { + Consistently(handlerDone, keepAlive*2).ShouldNot(BeClosed()) + }) + + It("does not close the connection", func() { + Consistently(wc.Done, keepAlive*2).ShouldNot(BeClosed()) + }) + }) It("keeps a count of every time it writes an envelope", func() { Expect(count.GetDelta()).To(Equal(uint64(0))) @@ -218,3 +171,78 @@ var _ = Describe("WebsocketHandler", func() { func httpToWs(u string) string { return "ws" + u[len("http"):] } + +type websocketClient struct { + mu sync.Mutex + + Done chan struct{} + + readError []error + readMessageType []int + readMessage []string +} + +func newWebsocketClient() *websocketClient { + return &websocketClient{ + Done: make(chan struct{}), + } +} + +func (wc *websocketClient) Start(conn *websocket.Conn) { + defer conn.Close() + defer close(wc.Done) + pongWait := 60 * time.Second + conn.SetReadDeadline(time.Now().Add(pongWait)) + conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + messageType, message, err := conn.ReadMessage() + wc.mu.Lock() + if err != nil { + wc.readError = append(wc.readError, err) + wc.mu.Unlock() + return + } + wc.readMessageType = append(wc.readMessageType, messageType) + wc.readMessage = append(wc.readMessage, string(message)) + wc.mu.Unlock() + } +} + +func (wc *websocketClient) ReadError() (error, bool) { + wc.mu.Lock() + defer wc.mu.Unlock() + if len(wc.readError) == 0 { + return nil, false + } + err := wc.readError[0] + wc.readError = wc.readError[1:] + return err, true +} + +func (wc *websocketClient) Read() (messageType int, message string, ok bool) { + wc.mu.Lock() + defer wc.mu.Unlock() + if len(wc.readMessageType) == 0 { + return 0, "", false + } + ok = true + messageType = wc.readMessageType[0] + wc.readMessageType = wc.readMessageType[1:] + message = wc.readMessage[0] + wc.readMessage = wc.readMessage[1:] + return +} + +func (wc *websocketClient) Write() (messageType int, message string, ok bool) { + wc.mu.Lock() + defer wc.mu.Unlock() + if len(wc.readMessageType) == 0 { + return 0, "", false + } + ok = true + messageType = wc.readMessageType[0] + wc.readMessageType = wc.readMessageType[1:] + message = wc.readMessage[0] + wc.readMessage = wc.readMessage[1:] + return +}