Skip to content

Commit

Permalink
Fixed RPC ForwardPacket
Browse files Browse the repository at this point in the history
packet is always padded to 4 bytes
  • Loading branch information
alpinskiy committed Jan 13, 2025
1 parent e56ed22 commit e76f739
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 218 deletions.
12 changes: 6 additions & 6 deletions internal/aggregator/ingress_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,15 +416,15 @@ func (p *proxyConn) run() {
if firstReq.tip == rpcInvokeReqHeaderTLTag {
break
}
log.Printf("Client skip #%d looking for invoke request, addr %v\n", firstReq.tip, p.clientConn.RemoteAddr())
p.rareLog("Client skip #%d looking for invoke request, addr %v\n", firstReq.tip, p.clientConn.RemoteAddr())
}
shardReplica := firstReq.shardReplica(p)
upstreamAddr := p.agent.GetConfigResult.Addresses[shardReplica]
log.Printf("Connect shard replica %d, addr %s < %s\n", shardReplica, p.clientConn.LocalAddr(), p.clientConn.RemoteAddr())
p.rareLog("Connect shard replica %d, addr %s < %s\n", shardReplica, p.clientConn.LocalAddr(), p.clientConn.RemoteAddr())
// connect upstream
upstreamConn, err := net.DialTimeout("tcp", upstreamAddr, rpc.DefaultPacketTimeout)
if err != nil {
log.Printf("error connect upstream addr %s < %s: %v\n", upstreamAddr, p.clientConn.RemoteAddr(), err)
p.rareLog("error connect upstream addr %s < %s: %v\n", upstreamAddr, p.clientConn.RemoteAddr(), err)
_ = firstReq.WriteReponseAndFlush(p.clientConn, err)
return
}
Expand Down Expand Up @@ -469,7 +469,7 @@ func (p *proxyConn) run() {
// no timeout for connection graceful shutdown (has server level shutdown timeout)
ctx = context.Background()
}
log.Printf("Disconnect shard replica %d, addr %v < %v, graceful %t, request %s, response %s\n", shardReplica, upstreamAddr, p.clientConn.RemoteAddr(), gracefulShutdown, reqLoopRes.String(), respLoopRes.String())
p.rareLog("Disconnect shard replica %d, addr %v < %v, graceful %t, request %s, response %s\n", shardReplica, upstreamAddr, p.clientConn.RemoteAddr(), gracefulShutdown, reqLoopRes.String(), respLoopRes.String())
}

func (p *proxyConn) requestLoop(ctx context.Context) (res rpc.ForwardPacketsResult) {
Expand Down Expand Up @@ -573,7 +573,7 @@ func (p *proxyConn) logClientError(tag string, err error, lastPackets rpc.Packet
if p.clientConn != nil {
addr = p.clientConn.RemoteAddr()
}
log.Printf("error %s, client addr %s, version %d, key 0x%X: %v, %s\n", tag, addr, p.clientProtocolVersion, p.clientCryptoKeyID, err, lastPackets.String())
p.rareLog("error %s, client addr %s, version %d, key 0x%X: %v, %s\n", tag, addr, p.clientProtocolVersion, p.clientCryptoKeyID, err, lastPackets.String())
}

func (p *proxyConn) logUpstreamError(tag string, err error, lastPackets rpc.PacketHeaderCircularBuffer) {
Expand All @@ -584,7 +584,7 @@ func (p *proxyConn) logUpstreamError(tag string, err error, lastPackets rpc.Pack
if p.upstreamConn != nil {
addr = p.upstreamConn.RemoteAddr()
}
log.Printf("error %s, upstream addr %s: %v, %s\n", tag, addr, err, lastPackets.String())
p.rareLog("error %s, upstream addr %s: %v, %s\n", tag, addr, err, lastPackets.String())
}

func (req *proxyRequest) process(p *proxyConn) (res rpc.ForwardPacketsResult) {
Expand Down
81 changes: 15 additions & 66 deletions internal/vkgo/rpc/statshouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
)

var errZeroRead = fmt.Errorf("read returned zero bytes without an error")
var errNonZeroPadding = fmt.Errorf("non-zero padding")

type ReadWriteError struct {
ReadErr error
Expand Down Expand Up @@ -163,7 +164,7 @@ func forwardPacket(dst, src *PacketConn, header *packetHeader) (res ReadWriteErr
return res
}
// write body
if res = copyBodyCheckedSkipCryptoPadding(dst, src, header); res.Error() != nil {
if res = forwardPacketBody(dst, src, header); res.Error() != nil {
return res
}
if 0 < legacyWriteAlignTo4 && legacyWriteAlignTo4 < 4 {
Expand All @@ -178,7 +179,7 @@ func forwardPacket(dst, src *PacketConn, header *packetHeader) (res ReadWriteErr
return res
}

func copyBodyCheckedSkipCryptoPadding(dst, src *PacketConn, header *packetHeader) (res ReadWriteError) {
func forwardPacketBody(dst, src *PacketConn, header *packetHeader) (res ReadWriteError) {
src.readMu.Lock()
defer src.readMu.Unlock()
// copy body
Expand All @@ -202,9 +203,18 @@ func copyBodyCheckedSkipCryptoPadding(dst, src *PacketConn, header *packetHeader
}
return res
}
// skip crypto padding
if src.w.isEncrypted() {
res.ReadErr = src.r.discard(int(-header.length & 3))
// skip padding
if n := int(-header.length & 3); n != 0 {
if _, err := io.ReadFull(src.r, src.headerReadBuf[:n]); err != nil {
res.ReadErr = err
return res
}
for i := 0; i < n; i++ {
if src.headerReadBuf[i] != 0 {
res.ReadErr = errNonZeroPadding
return res
}
}
}
return res
}
Expand Down Expand Up @@ -330,64 +340,3 @@ func cryptoCopy(dst *cryptoWriter, src *cryptoReader, n int, readCRC uint32, tab
}
}
}

func (src *cryptoReader) discard(n int) error {
if n == 0 {
return nil
}
// discard decrypted
if m := src.end - src.begin; m > 0 {
if m > n {
m = n
}
src.begin += m
n -= m
if n == 0 {
return nil
}
}
// discard encrypted
buf := src.buf[:cap(src.buf)]
m := len(src.buf) - src.end
src.begin, src.end = 0, 0
if m > 0 {
if m > n {
m = n
}
src.buf = buf[:copy(buf, src.buf[src.end+m:])]
n -= m
if n == 0 {
return nil
}
}
// read and discard
var err error
for {
var read int
if read, err = src.r.Read(buf); read < n {
if err != nil {
buf = buf[:read]
break // read error
}
if read <= 0 {
buf = buf[:0]
err = errZeroRead
break // infinite loop
}
n -= read
} else {
buf = buf[:copy(buf, buf[n:read])]
break // success
}
}
// restore invariant by decrypting the read buffer
src.buf = buf
if src.enc != nil {
decrypt := roundDownPow2(len(buf), src.blockSize)
src.enc.CryptBlocks(buf[:decrypt], buf[:decrypt])
src.end = decrypt
} else {
src.end = len(buf)
}
return err
}
146 changes: 0 additions & 146 deletions internal/vkgo/rpc/statshouse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,10 @@ package rpc

import (
"bytes"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"os"
"testing"
"time"

"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/pcap"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"pgregory.net/rapid"
Expand Down Expand Up @@ -41,14 +33,6 @@ func (c *cryptoPipelineMachine) Write(t *rapid.T) {
}
}

func (c *cryptoPipelineMachine) Discard(t *rapid.T) {
n := rapid.IntRange(0, c.rb.Len()).Draw(t, "n")
if err := c.r.discard(n); err != nil {
c.fatalf("discard failed: %v", err)
}
c.expected = append(c.expected[:c.offset], c.expected[c.offset+n:]...)
}

func (c *cryptoPipelineMachine) ReadDiscard(t *rapid.T) {
n := rapid.IntRange(0, c.rb.Len()).Draw(t, "n")
m, err := c.r.Read(make([]byte, n))
Expand Down Expand Up @@ -220,133 +204,3 @@ func TestForwardPacket(t *testing.T) {
shutdown()
})
}

type pcapEndpoint struct {
host string
port layers.TCPPort
}

func (e pcapEndpoint) Network() string {
return "ip"
}

func (e pcapEndpoint) String() string {
return fmt.Sprintf("%s:%d", e.host, e.port)
}

type testConn struct {
localAddr net.Addr
remoteAddr net.Addr
buffer []byte
offset int
}

func (c *testConn) Read(b []byte) (n int, err error) {
if c.offset == len(c.buffer) {
return 0, io.EOF
}
n = copy(b, c.buffer[c.offset:])
c.offset += n
return n, nil
}

func (c *testConn) Write(b []byte) (int, error) {
return len(b), nil // nop
}

func (c *testConn) Close() error {
c.offset = len(c.buffer)
return nil
}

func (c *testConn) LocalAddr() net.Addr {
return c.localAddr
}

func (c *testConn) RemoteAddr() net.Addr {
return c.remoteAddr
}

func (c *testConn) SetDeadline(_ time.Time) error {
return nil
}

func (c *testConn) SetReadDeadline(_ time.Time) error {
return nil
}

func (c *testConn) SetWriteDeadline(_ time.Time) error {
return nil
}

// NB! remove "received unexpected pong" assertion fot test to pass
func TestPlayPcap(t *testing.T) {
path := os.Getenv("STATSHOUSE_TEST_PLAY_PCAP_FILE_PATH")
if path == "" {
return
}
log.Println("PCAP play", path)
for k, v := range readPCAP(t, path, "") {
playPcap(t, k, v)
}
}

func playPcap(t *testing.T, k [2]pcapEndpoint, v []byte) {
srcConn := &testConn{
buffer: v,
localAddr: k[0],
remoteAddr: k[1],
}
src := &PacketConn{
conn: srcConn,
timeoutAccuracy: DefaultConnTimeoutAccuracy,
r: newCryptoReader(srcConn, DefaultServerRequestBufSize),
w: newCryptoWriter(srcConn, DefaultServerResponseBufSize),
readSeqNum: int64(binary.LittleEndian.Uint32(v[4:])),
}
dstConn := &testConn{
localAddr: k[0],
remoteAddr: k[1],
}
dst := &PacketConn{
conn: dstConn,
timeoutAccuracy: DefaultConnTimeoutAccuracy,
r: newCryptoReader(dstConn, DefaultServerRequestBufSize),
w: newCryptoWriter(dstConn, DefaultServerResponseBufSize),
table: castagnoliTable,
}
var buf PacketHeaderCircularBuffer
for {
res := ForwardPacket(dst, src, forwardPacketOptions{testEnv: true})
buf.add(res.packetHeader)
if res.Error() != nil {
require.ErrorIsf(t, res.ReadErr, io.EOF, "%v %s", k, buf.String())
require.NoError(t, res.WriteErr)
break
}
}
}

func readPCAP(t *testing.T, path string, dstHost string) map[[2]pcapEndpoint][]byte {
handle, err := pcap.OpenOffline(path)
require.NoError(t, err)
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
m := map[[2]pcapEndpoint][]byte{}
for p := range packetSource.Packets() {
var src, dst pcapEndpoint
ip := p.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
src.host = ip.SrcIP.String()
dst.host = ip.DstIP.String()
if dstHost != "" && dst.host != dstHost {
continue
}
tcp := p.Layer(layers.LayerTypeTCP).(*layers.TCP)
src.port = tcp.SrcPort
dst.port = tcp.DstPort
if appLayer := p.ApplicationLayer(); appLayer != nil {
k := [2]pcapEndpoint{src, dst}
m[k] = append(m[k], appLayer.Payload()...)
}
}
return m
}
Loading

0 comments on commit e76f739

Please sign in to comment.