diff --git a/internal/aggregator/ingress_proxy.go b/internal/aggregator/ingress_proxy.go index 4bdca4ca7..2dbfb04fd 100644 --- a/internal/aggregator/ingress_proxy.go +++ b/internal/aggregator/ingress_proxy.go @@ -416,15 +416,15 @@ func (p *proxyConn) run() { if firstReq.tip == rpcInvokeReqHeaderTLTag { break } - log.Printf("Client skip #%d looking for invoke request, addr %v\n", firstReq.tip, p.clientConn.RemoteAddr()) + p.rareLog("Client skip #%d looking for invoke request, addr %v\n", firstReq.tip, p.clientConn.RemoteAddr()) } shardReplica := firstReq.shardReplica(p) upstreamAddr := p.agent.GetConfigResult.Addresses[shardReplica] - log.Printf("Connect shard replica %d, addr %s < %s\n", shardReplica, p.clientConn.LocalAddr(), p.clientConn.RemoteAddr()) + p.rareLog("Connect shard replica %d, addr %s < %s\n", shardReplica, p.clientConn.LocalAddr(), p.clientConn.RemoteAddr()) // connect upstream upstreamConn, err := net.DialTimeout("tcp", upstreamAddr, rpc.DefaultPacketTimeout) if err != nil { - log.Printf("error connect upstream addr %s < %s: %v\n", upstreamAddr, p.clientConn.RemoteAddr(), err) + p.rareLog("error connect upstream addr %s < %s: %v\n", upstreamAddr, p.clientConn.RemoteAddr(), err) _ = firstReq.WriteReponseAndFlush(p.clientConn, err) return } @@ -469,7 +469,7 @@ func (p *proxyConn) run() { // no timeout for connection graceful shutdown (has server level shutdown timeout) ctx = context.Background() } - log.Printf("Disconnect shard replica %d, addr %v < %v, graceful %t, request %s, response %s\n", shardReplica, upstreamAddr, p.clientConn.RemoteAddr(), gracefulShutdown, reqLoopRes.String(), respLoopRes.String()) + p.rareLog("Disconnect shard replica %d, addr %v < %v, graceful %t, request %s, response %s\n", shardReplica, upstreamAddr, p.clientConn.RemoteAddr(), gracefulShutdown, reqLoopRes.String(), respLoopRes.String()) } func (p *proxyConn) requestLoop(ctx context.Context) (res rpc.ForwardPacketsResult) { @@ -573,7 +573,7 @@ func (p *proxyConn) logClientError(tag string, err error, lastPackets rpc.Packet if p.clientConn != nil { addr = p.clientConn.RemoteAddr() } - log.Printf("error %s, client addr %s, version %d, key 0x%X: %v, %s\n", tag, addr, p.clientProtocolVersion, p.clientCryptoKeyID, err, lastPackets.String()) + p.rareLog("error %s, client addr %s, version %d, key 0x%X: %v, %s\n", tag, addr, p.clientProtocolVersion, p.clientCryptoKeyID, err, lastPackets.String()) } func (p *proxyConn) logUpstreamError(tag string, err error, lastPackets rpc.PacketHeaderCircularBuffer) { @@ -584,7 +584,7 @@ func (p *proxyConn) logUpstreamError(tag string, err error, lastPackets rpc.Pack if p.upstreamConn != nil { addr = p.upstreamConn.RemoteAddr() } - log.Printf("error %s, upstream addr %s: %v, %s\n", tag, addr, err, lastPackets.String()) + p.rareLog("error %s, upstream addr %s: %v, %s\n", tag, addr, err, lastPackets.String()) } func (req *proxyRequest) process(p *proxyConn) (res rpc.ForwardPacketsResult) { diff --git a/internal/vkgo/rpc/statshouse.go b/internal/vkgo/rpc/statshouse.go index 41ae486db..2537646fc 100644 --- a/internal/vkgo/rpc/statshouse.go +++ b/internal/vkgo/rpc/statshouse.go @@ -13,6 +13,7 @@ import ( ) var errZeroRead = fmt.Errorf("read returned zero bytes without an error") +var errNonZeroPadding = fmt.Errorf("non-zero padding") type ReadWriteError struct { ReadErr error @@ -163,7 +164,7 @@ func forwardPacket(dst, src *PacketConn, header *packetHeader) (res ReadWriteErr return res } // write body - if res = copyBodyCheckedSkipCryptoPadding(dst, src, header); res.Error() != nil { + if res = forwardPacketBody(dst, src, header); res.Error() != nil { return res } if 0 < legacyWriteAlignTo4 && legacyWriteAlignTo4 < 4 { @@ -178,7 +179,7 @@ func forwardPacket(dst, src *PacketConn, header *packetHeader) (res ReadWriteErr return res } -func copyBodyCheckedSkipCryptoPadding(dst, src *PacketConn, header *packetHeader) (res ReadWriteError) { +func forwardPacketBody(dst, src *PacketConn, header *packetHeader) (res ReadWriteError) { src.readMu.Lock() defer src.readMu.Unlock() // copy body @@ -202,9 +203,18 @@ func copyBodyCheckedSkipCryptoPadding(dst, src *PacketConn, header *packetHeader } return res } - // skip crypto padding - if src.w.isEncrypted() { - res.ReadErr = src.r.discard(int(-header.length & 3)) + // skip padding + if n := int(-header.length & 3); n != 0 { + if _, err := io.ReadFull(src.r, src.headerReadBuf[:n]); err != nil { + res.ReadErr = err + return res + } + for i := 0; i < n; i++ { + if src.headerReadBuf[i] != 0 { + res.ReadErr = errNonZeroPadding + return res + } + } } return res } @@ -330,64 +340,3 @@ func cryptoCopy(dst *cryptoWriter, src *cryptoReader, n int, readCRC uint32, tab } } } - -func (src *cryptoReader) discard(n int) error { - if n == 0 { - return nil - } - // discard decrypted - if m := src.end - src.begin; m > 0 { - if m > n { - m = n - } - src.begin += m - n -= m - if n == 0 { - return nil - } - } - // discard encrypted - buf := src.buf[:cap(src.buf)] - m := len(src.buf) - src.end - src.begin, src.end = 0, 0 - if m > 0 { - if m > n { - m = n - } - src.buf = buf[:copy(buf, src.buf[src.end+m:])] - n -= m - if n == 0 { - return nil - } - } - // read and discard - var err error - for { - var read int - if read, err = src.r.Read(buf); read < n { - if err != nil { - buf = buf[:read] - break // read error - } - if read <= 0 { - buf = buf[:0] - err = errZeroRead - break // infinite loop - } - n -= read - } else { - buf = buf[:copy(buf, buf[n:read])] - break // success - } - } - // restore invariant by decrypting the read buffer - src.buf = buf - if src.enc != nil { - decrypt := roundDownPow2(len(buf), src.blockSize) - src.enc.CryptBlocks(buf[:decrypt], buf[:decrypt]) - src.end = decrypt - } else { - src.end = len(buf) - } - return err -} diff --git a/internal/vkgo/rpc/statshouse_test.go b/internal/vkgo/rpc/statshouse_test.go index 283b9a4ec..1aab40180 100644 --- a/internal/vkgo/rpc/statshouse_test.go +++ b/internal/vkgo/rpc/statshouse_test.go @@ -2,18 +2,10 @@ package rpc import ( "bytes" - "encoding/binary" - "fmt" - "io" - "log" "net" - "os" "testing" "time" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - "github.com/google/gopacket/pcap" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "pgregory.net/rapid" @@ -41,14 +33,6 @@ func (c *cryptoPipelineMachine) Write(t *rapid.T) { } } -func (c *cryptoPipelineMachine) Discard(t *rapid.T) { - n := rapid.IntRange(0, c.rb.Len()).Draw(t, "n") - if err := c.r.discard(n); err != nil { - c.fatalf("discard failed: %v", err) - } - c.expected = append(c.expected[:c.offset], c.expected[c.offset+n:]...) -} - func (c *cryptoPipelineMachine) ReadDiscard(t *rapid.T) { n := rapid.IntRange(0, c.rb.Len()).Draw(t, "n") m, err := c.r.Read(make([]byte, n)) @@ -220,133 +204,3 @@ func TestForwardPacket(t *testing.T) { shutdown() }) } - -type pcapEndpoint struct { - host string - port layers.TCPPort -} - -func (e pcapEndpoint) Network() string { - return "ip" -} - -func (e pcapEndpoint) String() string { - return fmt.Sprintf("%s:%d", e.host, e.port) -} - -type testConn struct { - localAddr net.Addr - remoteAddr net.Addr - buffer []byte - offset int -} - -func (c *testConn) Read(b []byte) (n int, err error) { - if c.offset == len(c.buffer) { - return 0, io.EOF - } - n = copy(b, c.buffer[c.offset:]) - c.offset += n - return n, nil -} - -func (c *testConn) Write(b []byte) (int, error) { - return len(b), nil // nop -} - -func (c *testConn) Close() error { - c.offset = len(c.buffer) - return nil -} - -func (c *testConn) LocalAddr() net.Addr { - return c.localAddr -} - -func (c *testConn) RemoteAddr() net.Addr { - return c.remoteAddr -} - -func (c *testConn) SetDeadline(_ time.Time) error { - return nil -} - -func (c *testConn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (c *testConn) SetWriteDeadline(_ time.Time) error { - return nil -} - -// NB! remove "received unexpected pong" assertion fot test to pass -func TestPlayPcap(t *testing.T) { - path := os.Getenv("STATSHOUSE_TEST_PLAY_PCAP_FILE_PATH") - if path == "" { - return - } - log.Println("PCAP play", path) - for k, v := range readPCAP(t, path, "") { - playPcap(t, k, v) - } -} - -func playPcap(t *testing.T, k [2]pcapEndpoint, v []byte) { - srcConn := &testConn{ - buffer: v, - localAddr: k[0], - remoteAddr: k[1], - } - src := &PacketConn{ - conn: srcConn, - timeoutAccuracy: DefaultConnTimeoutAccuracy, - r: newCryptoReader(srcConn, DefaultServerRequestBufSize), - w: newCryptoWriter(srcConn, DefaultServerResponseBufSize), - readSeqNum: int64(binary.LittleEndian.Uint32(v[4:])), - } - dstConn := &testConn{ - localAddr: k[0], - remoteAddr: k[1], - } - dst := &PacketConn{ - conn: dstConn, - timeoutAccuracy: DefaultConnTimeoutAccuracy, - r: newCryptoReader(dstConn, DefaultServerRequestBufSize), - w: newCryptoWriter(dstConn, DefaultServerResponseBufSize), - table: castagnoliTable, - } - var buf PacketHeaderCircularBuffer - for { - res := ForwardPacket(dst, src, forwardPacketOptions{testEnv: true}) - buf.add(res.packetHeader) - if res.Error() != nil { - require.ErrorIsf(t, res.ReadErr, io.EOF, "%v %s", k, buf.String()) - require.NoError(t, res.WriteErr) - break - } - } -} - -func readPCAP(t *testing.T, path string, dstHost string) map[[2]pcapEndpoint][]byte { - handle, err := pcap.OpenOffline(path) - require.NoError(t, err) - packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) - m := map[[2]pcapEndpoint][]byte{} - for p := range packetSource.Packets() { - var src, dst pcapEndpoint - ip := p.Layer(layers.LayerTypeIPv4).(*layers.IPv4) - src.host = ip.SrcIP.String() - dst.host = ip.DstIP.String() - if dstHost != "" && dst.host != dstHost { - continue - } - tcp := p.Layer(layers.LayerTypeTCP).(*layers.TCP) - src.port = tcp.SrcPort - dst.port = tcp.DstPort - if appLayer := p.ApplicationLayer(); appLayer != nil { - k := [2]pcapEndpoint{src, dst} - m[k] = append(m[k], appLayer.Payload()...) - } - } - return m -} diff --git a/internal/vkgo/rpc/statshouse_test_pcap.go b/internal/vkgo/rpc/statshouse_test_pcap.go new file mode 100644 index 000000000..7bb8980ea --- /dev/null +++ b/internal/vkgo/rpc/statshouse_test_pcap.go @@ -0,0 +1,149 @@ +//go:build ignore + +package rpc + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "net" + "os" + "testing" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" + "github.com/stretchr/testify/require" +) + +type pcapEndpoint struct { + host string + port layers.TCPPort +} + +func (e pcapEndpoint) Network() string { + return "ip" +} + +func (e pcapEndpoint) String() string { + return fmt.Sprintf("%s:%d", e.host, e.port) +} + +type testConn struct { + localAddr net.Addr + remoteAddr net.Addr + buffer []byte + offset int +} + +func (c *testConn) Read(b []byte) (n int, err error) { + if c.offset == len(c.buffer) { + return 0, io.EOF + } + n = copy(b, c.buffer[c.offset:]) + c.offset += n + return n, nil +} + +func (c *testConn) Write(b []byte) (int, error) { + return len(b), nil // nop +} + +func (c *testConn) Close() error { + c.offset = len(c.buffer) + return nil +} + +func (c *testConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *testConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *testConn) SetDeadline(_ time.Time) error { + return nil +} + +func (c *testConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (c *testConn) SetWriteDeadline(_ time.Time) error { + return nil +} + +// NB! remove "received unexpected pong" assertion fot test to pass +func TestPlayPcap(t *testing.T) { + path := os.Getenv("STATSHOUSE_TEST_PLAY_PCAP_FILE_PATH") + if path == "" { + return + } + log.Println("PCAP play", path) + for k, v := range readPCAP(t, path, "") { + playPcap(t, k, v) + } +} + +func playPcap(t *testing.T, k [2]pcapEndpoint, v []byte) { + srcConn := &testConn{ + buffer: v, + localAddr: k[0], + remoteAddr: k[1], + } + src := &PacketConn{ + conn: srcConn, + timeoutAccuracy: DefaultConnTimeoutAccuracy, + r: newCryptoReader(srcConn, DefaultServerRequestBufSize), + w: newCryptoWriter(srcConn, DefaultServerResponseBufSize), + readSeqNum: int64(binary.LittleEndian.Uint32(v[4:])), + } + dstConn := &testConn{ + localAddr: k[0], + remoteAddr: k[1], + } + dst := &PacketConn{ + conn: dstConn, + timeoutAccuracy: DefaultConnTimeoutAccuracy, + r: newCryptoReader(dstConn, DefaultServerRequestBufSize), + w: newCryptoWriter(dstConn, DefaultServerResponseBufSize), + table: castagnoliTable, + } + var buf PacketHeaderCircularBuffer + for { + res := ForwardPacket(dst, src, forwardPacketOptions{testEnv: true}) + buf.add(res.packetHeader) + if res.Error() != nil { + require.ErrorIsf(t, res.ReadErr, io.EOF, "%v %s", k, buf.String()) + require.NoError(t, res.WriteErr) + break + } + } +} + +func readPCAP(t *testing.T, path string, dstHost string) map[[2]pcapEndpoint][]byte { + handle, err := pcap.OpenOffline(path) + require.NoError(t, err) + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + m := map[[2]pcapEndpoint][]byte{} + for p := range packetSource.Packets() { + var src, dst pcapEndpoint + ip := p.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + src.host = ip.SrcIP.String() + dst.host = ip.DstIP.String() + if dstHost != "" && dst.host != dstHost { + continue + } + tcp := p.Layer(layers.LayerTypeTCP).(*layers.TCP) + src.port = tcp.SrcPort + dst.port = tcp.DstPort + if appLayer := p.ApplicationLayer(); appLayer != nil { + k := [2]pcapEndpoint{src, dst} + m[k] = append(m[k], appLayer.Payload()...) + } + } + return m +}