diff --git a/internal/pmtud/apple_ipv4.go b/internal/pmtud/apple_ipv4.go new file mode 100644 index 000000000..6b298d792 --- /dev/null +++ b/internal/pmtud/apple_ipv4.go @@ -0,0 +1,49 @@ +package pmtud + +import ( + "net" + "time" + + "golang.org/x/net/ipv4" +) + +var _ net.PacketConn = &ipv4Wrapper{} + +// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement +// the net.PacketConn interface. It's only used for Darwin or iOS. +type ipv4Wrapper struct { + ipv4Conn *ipv4.PacketConn +} + +func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper { + return &ipv4Wrapper{ipv4Conn: ipv4} +} + +func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, _, addr, err = i.ipv4Conn.ReadFrom(p) + return n, addr, err +} + +func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return i.ipv4Conn.WriteTo(p, nil, addr) +} + +func (i *ipv4Wrapper) Close() error { + return i.ipv4Conn.Close() +} + +func (i *ipv4Wrapper) LocalAddr() net.Addr { + return i.ipv4Conn.LocalAddr() +} + +func (i *ipv4Wrapper) SetDeadline(t time.Time) error { + return i.ipv4Conn.SetDeadline(t) +} + +func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error { + return i.ipv4Conn.SetReadDeadline(t) +} + +func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error { + return i.ipv4Conn.SetWriteDeadline(t) +} diff --git a/internal/pmtud/check.go b/internal/pmtud/check.go new file mode 100644 index 000000000..71f8ff1f5 --- /dev/null +++ b/internal/pmtud/check.go @@ -0,0 +1,83 @@ +package pmtud + +import ( + "bytes" + "errors" + "fmt" + + "golang.org/x/net/icmp" +) + +var ( + ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low") + ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high") +) + +func checkMTU(mtu, minMTU, physicalLinkMTU int) (err error) { + switch { + case mtu < minMTU: + return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu) + case mtu > physicalLinkMTU: + return fmt.Errorf("%w: %d is larger than physical link MTU %d", + ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU) + default: + return nil + } +} + +func checkInvokingReplyIDMatch(icmpProtocol int, received []byte, + outboundMessage *icmp.Message, +) (match bool, err error) { + inboundMessage, err := icmp.ParseMessage(icmpProtocol, received) + if err != nil { + return false, fmt.Errorf("parsing invoking packet: %w", err) + } + inboundBody, ok := inboundMessage.Body.(*icmp.Echo) + if !ok { + return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body) + } + outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert + return inboundBody.ID == outboundBody.ID, nil +} + +var ErrICMPIDMismatch = errors.New("ICMP id mismatch") + +func checkEchoReply(icmpProtocol int, received []byte, + outboundMessage *icmp.Message, truncatedBody bool, +) (err error) { + inboundMessage, err := icmp.ParseMessage(icmpProtocol, received) + if err != nil { + return fmt.Errorf("parsing invoking packet: %w", err) + } + inboundBody, ok := inboundMessage.Body.(*icmp.Echo) + if !ok { + return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body) + } + outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert + if inboundBody.ID != outboundBody.ID { + return fmt.Errorf("%w: sent id %d and received id %d", + ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID) + } + err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody) + if err != nil { + return fmt.Errorf("checking sent and received bodies: %w", err) + } + return nil +} + +var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch") + +func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) { + if len(received) > len(sent) { + return fmt.Errorf("%w: sent %d bytes and received %d bytes", + ErrICMPEchoDataMismatch, len(sent), len(received)) + } + if receivedTruncated { + sent = sent[:len(received)] + } + if !bytes.Equal(received, sent) { + return fmt.Errorf("%w: sent %x and received %x", + ErrICMPEchoDataMismatch, sent, received) + } + return nil +} diff --git a/internal/pmtud/df.go b/internal/pmtud/df.go new file mode 100644 index 000000000..9e6ee59da --- /dev/null +++ b/internal/pmtud/df.go @@ -0,0 +1,10 @@ +//go:build !linux && !windows + +package pmtud + +// setDontFragment for platforms other than Linux and Windows +// is not implemented, so we just return assuming the don't +// fragment flag is set on IP packets. +func setDontFragment(fd uintptr) (err error) { + return nil +} diff --git a/internal/pmtud/df_linux.go b/internal/pmtud/df_linux.go new file mode 100644 index 000000000..08c7979cb --- /dev/null +++ b/internal/pmtud/df_linux.go @@ -0,0 +1,10 @@ +package pmtud + +import ( + "syscall" +) + +func setDontFragment(fd uintptr) (err error) { + return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, + syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE) +} diff --git a/internal/pmtud/df_windows.go b/internal/pmtud/df_windows.go new file mode 100644 index 000000000..0e2616cd3 --- /dev/null +++ b/internal/pmtud/df_windows.go @@ -0,0 +1,11 @@ +package pmtud + +import ( + "syscall" +) + +func setDontFragment(fd uintptr) (err error) { + // https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip + // #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */ + return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1) +} diff --git a/internal/pmtud/errors.go b/internal/pmtud/errors.go new file mode 100644 index 000000000..095deb862 --- /dev/null +++ b/internal/pmtud/errors.go @@ -0,0 +1,10 @@ +package pmtud + +import ( + "errors" +) + +var ( + ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable") + ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported") +) diff --git a/internal/pmtud/interfaces.go b/internal/pmtud/interfaces.go new file mode 100644 index 000000000..ed7e60973 --- /dev/null +++ b/internal/pmtud/interfaces.go @@ -0,0 +1,5 @@ +package pmtud + +type Logger interface { + Debug(msg string, args ...any) +} diff --git a/internal/pmtud/ipv4.go b/internal/pmtud/ipv4.go new file mode 100644 index 000000000..9b080e0be --- /dev/null +++ b/internal/pmtud/ipv4.go @@ -0,0 +1,145 @@ +package pmtud + +import ( + "context" + "encoding/binary" + "fmt" + "net" + "net/netip" + "runtime" + "syscall" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" +) + +const ( + // see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media + minIPv4MTU = 68 + icmpv4Protocol = 1 +) + +func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) { + var listenConfig net.ListenConfig + listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error { + var setDFErr error + err := rawConn.Control(func(fd uintptr) { + setDFErr = setDontFragment(fd) // runs when calling ListenPacket + }) + if err == nil { + err = setDFErr + } + return err + } + + const listenAddress = "" + packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress) + if err != nil { + return nil, fmt.Errorf("listening for ICMP packets: %w", err) + } + + if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { + packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn)) + } + + return packetConn, nil +} + +func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr, + physicalLinkMTU int, pingTimeout time.Duration, logger Logger, +) (mtu int, err error) { + if ip.Is6() { + panic("IP address is not v4") + } + conn, err := listenICMPv4(ctx) + if err != nil { + return 0, fmt.Errorf("listening for ICMP packets: %w", err) + } + ctx, cancel := context.WithTimeout(ctx, pingTimeout) + defer cancel() + go func() { + <-ctx.Done() + conn.Close() + }() + + // First try to send a packet which is too big to get the maximum MTU + // directly. + outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU) + encodedMessage, err := outboundMessage.Marshal(nil) + if err != nil { + return 0, fmt.Errorf("encoding ICMP message: %w", err) + } + + _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) + if err != nil { + return 0, fmt.Errorf("writing ICMP message: %w", err) + } + + buffer := make([]byte, physicalLinkMTU) + + for { // for loop in case we read an echo reply for another ICMP request + // Note we need to read the whole packet in one call to ReadFrom, so the buffer + // must be large enough to read the entire reply packet. See: + // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J + bytesRead, _, err := conn.ReadFrom(buffer) + if err != nil { + return 0, fmt.Errorf("reading from ICMP connection: %w", err) + } + packetBytes := buffer[:bytesRead] + // Side note: echo reply should be at most the number of bytes + // sent, and can be lower, more precisely 576-ipHeader bytes, + // in case the next hop we are reaching replies with a destination + // unreachable and wants to ensure the response makes it way back + // by keeping a low packet size, see: + // https://datatracker.ietf.org/doc/html/rfc1122#page-59 + + inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes) + if err != nil { + return 0, fmt.Errorf("parsing message: %w", err) + } + + switch typedBody := inboundMessage.Body.(type) { + case *icmp.DstUnreach: + const fragmentationRequiredAndDFFlagSetCode = 4 + if inboundMessage.Code != fragmentationRequiredAndDFFlagSetCode { + return 0, fmt.Errorf("%w: code %d", + ErrICMPDestinationUnreachable, inboundMessage.Code) + } + + // See https://datatracker.ietf.org/doc/html/rfc1191#section-4 + // Note: the go library does not handle this NextHopMTU section. + nextHopMTU := packetBytes[6:8] + mtu = int(binary.BigEndian.Uint16(nextHopMTU)) + err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU) + if err != nil { + return 0, fmt.Errorf("checking next-hop-mtu found: %w", err) + } + + // The code below is really for sanity checks + packetBytes = packetBytes[8:] + header, err := ipv4.ParseHeader(packetBytes) + if err != nil { + return 0, fmt.Errorf("parsing IPv4 header: %w", err) + } + packetBytes = packetBytes[header.Len:] // truncated original datagram + + const truncated = true + err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated) + if err != nil { + return 0, fmt.Errorf("checking echo reply: %w", err) + } + return mtu, nil + case *icmp.Echo: + inboundID := uint16(typedBody.ID) //nolint:gosec + if inboundID == outboundID { + return physicalLinkMTU, nil + } + logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d", + inboundID, outboundID) + continue + default: + return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody) + } + } +} diff --git a/internal/pmtud/ipv6.go b/internal/pmtud/ipv6.go new file mode 100644 index 000000000..a9cc196e2 --- /dev/null +++ b/internal/pmtud/ipv6.go @@ -0,0 +1,116 @@ +package pmtud + +import ( + "context" + "fmt" + "net" + "net/netip" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv6" +) + +const ( + minIPv6MTU = 1280 + icmpv6Protocol = 58 +) + +func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) { + var listenConfig net.ListenConfig + const listenAddress = "" + packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress) + if err != nil { + return nil, fmt.Errorf("listening for ICMPv6 packets: %w", err) + } + return packetConn, nil +} + +func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr, + physicalLinkMTU int, pingTimeout time.Duration, logger Logger, +) (mtu int, err error) { + if ip.Is4() { + panic("IP address is not v6") + } + conn, err := listenICMPv6(ctx) + if err != nil { + return 0, fmt.Errorf("listening for ICMP packets: %w", err) + } + ctx, cancel := context.WithTimeout(ctx, pingTimeout) + defer cancel() + go func() { + <-ctx.Done() + conn.Close() + }() + + // First try to send a packet which is too big to get the maximum MTU + // directly. + outboundID, outboundMessage := buildMessageToSend("v6", physicalLinkMTU) + encodedMessage, err := outboundMessage.Marshal(nil) + if err != nil { + return 0, fmt.Errorf("encoding ICMP message: %w", err) + } + + _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()}) + if err != nil { + return 0, fmt.Errorf("writing ICMP message: %w", err) + } + + buffer := make([]byte, physicalLinkMTU) + + for { // for loop if we encounter another ICMP packet with an unknown id. + // Note we need to read the whole packet in one call to ReadFrom, so the buffer + // must be large enough to read the entire reply packet. See: + // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J + bytesRead, _, err := conn.ReadFrom(buffer) + if err != nil { + return 0, fmt.Errorf("reading from ICMP connection: %w", err) + } + packetBytes := buffer[:bytesRead] + + packetBytes = packetBytes[ipv6.HeaderLen:] + + inboundMessage, err := icmp.ParseMessage(icmpv6Protocol, packetBytes) + if err != nil { + return 0, fmt.Errorf("parsing message: %w", err) + } + + switch typedBody := inboundMessage.Body.(type) { + case *icmp.PacketTooBig: + // https://datatracker.ietf.org/doc/html/rfc1885#section-3.2 + mtu = typedBody.MTU + err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU) + if err != nil { + return 0, fmt.Errorf("checking MTU: %w", err) + } + + // Sanity checks + const truncatedBody = true + err = checkEchoReply(icmpv6Protocol, typedBody.Data, outboundMessage, truncatedBody) + if err != nil { + return 0, fmt.Errorf("checking invoking message: %w", err) + } + return typedBody.MTU, nil + case *icmp.DstUnreach: + // https://datatracker.ietf.org/doc/html/rfc1885#section-3.1 + idMatch, err := checkInvokingReplyIDMatch(icmpv6Protocol, packetBytes, outboundMessage) + if err != nil { + return 0, fmt.Errorf("checking invoking message id: %w", err) + } else if idMatch { + return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable) + } + logger.Debug("discarding received ICMP destination unreachable reply with an unknown id") + continue + case *icmp.Echo: + inboundID := uint16(typedBody.ID) //nolint:gosec + if inboundID == outboundID { + return physicalLinkMTU, nil + } + logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d", + inboundID, outboundID) + continue + default: + return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody) + } + } +} diff --git a/internal/pmtud/message.go b/internal/pmtud/message.go new file mode 100644 index 000000000..f04c7a897 --- /dev/null +++ b/internal/pmtud/message.go @@ -0,0 +1,58 @@ +package pmtud + +import ( + cryptorand "crypto/rand" + "encoding/binary" + "fmt" + "math/rand/v2" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +func buildMessageToSend(ipVersion string, mtu int) (id uint16, message *icmp.Message) { + var seed [32]byte + _, _ = cryptorand.Read(seed[:]) + randomSource := rand.NewChaCha8(seed) + + const uint16Bytes = 2 + idBytes := make([]byte, uint16Bytes) + _, _ = randomSource.Read(idBytes) + id = binary.BigEndian.Uint16(idBytes) + + var ipHeaderLength int + var icmpType icmp.Type + switch ipVersion { + case "v4": + ipHeaderLength = ipv4.HeaderLen + icmpType = ipv4.ICMPTypeEcho + case "v6": + ipHeaderLength = ipv6.HeaderLen + icmpType = ipv6.ICMPTypeEchoRequest + default: + panic(fmt.Sprintf("IP version %q not supported", ipVersion)) + } + const pingHeaderLength = 0 + + 1 + // type + 1 + // code + 2 + // checksum + 2 + // identifier + 2 // sequence number + pingBodyDataSize := mtu - ipHeaderLength - pingHeaderLength + messageBodyData := make([]byte, pingBodyDataSize) + _, _ = randomSource.Read(messageBodyData) + + // See https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-types + message = &icmp.Message{ + Type: icmpType, // echo request + Code: 0, // no code + Checksum: 0, // calculated at encoding (ipv4) or sending (ipv6) + Body: &icmp.Echo{ + ID: int(id), + Seq: 0, // only one packet + Data: messageBodyData, + }, + } + return id, message +} diff --git a/internal/pmtud/nooplogger.go b/internal/pmtud/nooplogger.go new file mode 100644 index 000000000..d899296c0 --- /dev/null +++ b/internal/pmtud/nooplogger.go @@ -0,0 +1,5 @@ +package pmtud + +type noopLogger struct{} + +func (noopLogger) Debug(_ string, _ ...any) {} diff --git a/internal/pmtud/pmtud.go b/internal/pmtud/pmtud.go new file mode 100644 index 000000000..3a12064f4 --- /dev/null +++ b/internal/pmtud/pmtud.go @@ -0,0 +1,269 @@ +package pmtud + +import ( + "context" + "errors" + "fmt" + "math" + "net" + "net/netip" + "time" + + "golang.org/x/net/icmp" +) + +// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU. +// If the pingTimeout is zero, it defaults to 1 second. +// If the logger is nil, a no-op logger is used. +func PathMTUDiscover(ctx context.Context, ip netip.Addr, + physicalLinkMTU int, pingTimeout time.Duration, logger Logger) ( + mtu int, err error, +) { + if physicalLinkMTU == 0 { + const ethernetStandardMTU = 1500 + physicalLinkMTU = ethernetStandardMTU + } + if pingTimeout == 0 { + pingTimeout = time.Second + } + if logger == nil { + logger = &noopLogger{} + } + + if ip.Is4() { + mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger) + switch { + case err == nil: + return mtu, nil + case errors.Is(err, net.ErrClosed): // blackhole + default: + return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err) + } + } else { + mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger) + switch { + case err == nil: + return mtu, nil + case errors.Is(err, net.ErrClosed): // blackhole + default: + return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err) + } + } + + // Fall back method: send echo requests with different packet + // sizes and check which ones succeed to find the maximum MTU. + minMTU := minIPv4MTU + if ip.Is6() { + minMTU = minIPv6MTU + } + return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger) +} + +type pmtudTestUnit struct { + mtu int + echoID uint16 + sentBytes int + ok bool +} + +func pmtudMultiSizes(ctx context.Context, ip netip.Addr, + minMTU, maxPossibleMTU int, pingTimeout time.Duration, + logger Logger, +) (maxMTU int, err error) { + var ipVersion string + var conn net.PacketConn + if ip.Is4() { + ipVersion = "v4" + conn, err = listenICMPv4(ctx) + } else { + ipVersion = "v6" + conn, err = listenICMPv6(ctx) + } + if err != nil { + return 0, fmt.Errorf("listening for ICMP packets: %w", err) + } + + mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU) + if len(mtusToTest) == 0 { + return minMTU, nil + } + logger.Debug("testing the following MTUs: %v", mtusToTest) + + tests := make([]pmtudTestUnit, len(mtusToTest)) + for i := range mtusToTest { + tests[i] = pmtudTestUnit{mtu: mtusToTest[i]} + } + + timedCtx, cancel := context.WithTimeout(ctx, pingTimeout) + defer cancel() + go func() { + <-timedCtx.Done() + conn.Close() + }() + + for i := range tests { + id, message := buildMessageToSend(ipVersion, tests[i].mtu) + tests[i].echoID = id + + encodedMessage, err := message.Marshal(nil) + if err != nil { + return 0, fmt.Errorf("encoding ICMP message: %w", err) + } + tests[i].sentBytes = len(encodedMessage) + + _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) + if err != nil { + return 0, fmt.Errorf("writing ICMP message: %w", err) + } + } + + err = collectReplies(conn, ipVersion, tests) + switch { + case err == nil: // max possible MTU is working + return tests[len(tests)-1].mtu, nil + case err != nil && errors.Is(err, net.ErrClosed): + // we have timeouts (IPv4 testing or IPv6 PMTUD blackholes) + // so find the highest MTU which worked. + // Note we start from index len(tests) - 2 since the max MTU + // cannot be working if we had a timeout. + for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd + if tests[i].ok { + return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu, + pingTimeout, logger) + } + } + + // All MTUs failed. + if tests[0].mtu == minMTU+1 { + return minMTU, nil + } + // Re-test with MTUs between the minimum MTU + // and the smallest next MTU we tested. + return pmtudMultiSizes(ctx, ip, minMTU, tests[0].mtu, + pingTimeout, logger) + case err != nil: + return 0, fmt.Errorf("collecting ICMP echo replies: %w", err) + default: + panic("unreachable") + } +} + +// Create the MTU slice of length 8 such that: +// - the first element is the minMTU plus the step +// - the last element is the maxMTU +// - elements in-between are separated as close to each other +// - Don't make the minMTU part of the MTUs to test since +// it's assumed it's already working. +func makeMTUsToTest(minMTU, maxMTU int) (mtus []int) { + const mtusLength = 8 + diff := maxMTU - minMTU + switch { + case minMTU > maxMTU: + panic("minMTU > maxMTU") + case diff <= mtusLength: + mtus = make([]int, 0, diff) + for mtu := minMTU + 1; mtu <= maxMTU; mtu++ { + mtus = append(mtus, mtu) + } + case diff < mtusLength*2: + step := float64(diff) / float64(mtusLength) + mtus = make([]int, 0, mtusLength) + for mtu := float64(minMTU) + step; len(mtus) < mtusLength-1; mtu += step { + mtus = append(mtus, int(math.Round(mtu))) + } + mtus = append(mtus, maxMTU) // last element is the maxMTU + default: + step := diff / mtusLength + mtus = make([]int, 0, mtusLength) + for mtu := minMTU + step; len(mtus) < mtusLength-1; mtu += step { + mtus = append(mtus, mtu) + } + mtus = append(mtus, maxMTU) // last element is the maxMTU + } + + return mtus +} + +func collectReplies(conn net.PacketConn, ipVersion string, + tests []pmtudTestUnit, +) (err error) { + echoIDToTestIndex := make(map[uint16]int, len(tests)) + for i, test := range tests { + echoIDToTestIndex[test.echoID] = i + } + + // The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would + // create huge buffers which we don't really want to support anyway. + // The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with + // a conventional maximum of 9000 bytes. However, some manufacturers support up + // 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to + // match eventual Jumbo frames. More information at: + // https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media + const maxPossibleMTU = 9196 + buffer := make([]byte, maxPossibleMTU) + + idsFound := 0 + for idsFound < len(tests) { + // Note we need to read the whole packet in one call to ReadFrom, so the buffer + // must be large enough to read the entire reply packet. See: + // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J + bytesRead, _, err := conn.ReadFrom(buffer) + if err != nil { + return fmt.Errorf("reading from ICMP connection: %w", err) + } + packetBytes := buffer[:bytesRead] + + ipPacketLength := len(packetBytes) + + var icmpProtocol int + switch ipVersion { + case "v4": + icmpProtocol = icmpv4Protocol + case "v6": + icmpProtocol = icmpv6Protocol + default: + panic(fmt.Sprintf("unknown IP version: %s", ipVersion)) + } + + // Parse the ICMP message + // Note: this parsing works for a truncated 556 bytes ICMP reply packet. + message, err := icmp.ParseMessage(icmpProtocol, packetBytes) + if err != nil { + return fmt.Errorf("parsing message: %w", err) + } + + echoBody, ok := message.Body.(*icmp.Echo) + if !ok { + return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body) + } + + id := uint16(echoBody.ID) //nolint:gosec + testIndex, testing := echoIDToTestIndex[id] + if !testing { // not an id we expected so ignore it + // TODO log warning + fmt.Println("ID not sent:", echoBody.ID, message.Code, message.Type, ipPacketLength) + continue + } + fmt.Println("ID ok:", echoBody.ID, message.Code, message.Type, ipPacketLength) + idsFound++ + sentBytes := tests[testIndex].sentBytes + + // echo reply should be at most the number of bytes sent, + // and can be lower, more precisely 556 bytes, in case + // the host we are reaching wants to stay out of trouble + // and ensure its echo reply goes through without + // fragmentation, see the following page: + // https://datatracker.ietf.org/doc/html/rfc1122#page-59 + const conservativeReplyLength = 556 + truncated := ipPacketLength < sentBytes && + ipPacketLength == conservativeReplyLength + // Check the packet size is the same if the reply is not truncated + if !truncated && sentBytes != ipPacketLength { + return fmt.Errorf("%w: sent %dB and received %dB", + ErrICMPEchoDataMismatch, sentBytes, ipPacketLength) + } + // Truncated reply or matching reply size + tests[testIndex].ok = true + } + return nil +} diff --git a/internal/pmtud/pmtud_integration_test.go b/internal/pmtud/pmtud_integration_test.go new file mode 100644 index 000000000..468da1e69 --- /dev/null +++ b/internal/pmtud/pmtud_integration_test.go @@ -0,0 +1,22 @@ +//go:build integration + +package pmtud + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func Test_PathMTUDiscover(t *testing.T) { + t.Parallel() + const physicalLinkMTU = 1500 + const timeout = time.Second + mtu, err := PathMTUDiscover(context.Background(), netip.MustParseAddr("1.1.1.1"), + physicalLinkMTU, timeout, nil) + require.NoError(t, err) + t.Log("MTU found:", mtu) +} diff --git a/internal/pmtud/pmtud_test.go b/internal/pmtud/pmtud_test.go new file mode 100644 index 000000000..45879a928 --- /dev/null +++ b/internal/pmtud/pmtud_test.go @@ -0,0 +1,55 @@ +package pmtud + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_makeMTUsToTest(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + minMTU int + maxMTU int + mtus []int + }{ + "0_0": { + mtus: []int{}, + }, + "0_1": { + maxMTU: 1, + mtus: []int{1}, + }, + "0_8": { + maxMTU: 8, + mtus: []int{1, 2, 3, 4, 5, 6, 7, 8}, + }, + "0_12": { + maxMTU: 12, + mtus: []int{2, 3, 5, 6, 8, 9, 11, 12}, + }, + "0_80": { + maxMTU: 80, + mtus: []int{10, 20, 30, 40, 50, 60, 70, 80}, + }, + "0_100": { + maxMTU: 100, + mtus: []int{12, 24, 36, 48, 60, 72, 84, 100}, + }, + "1280_1500": { + minMTU: 1280, + maxMTU: 1500, + mtus: []int{1307, 1334, 1361, 1388, 1415, 1442, 1469, 1500}, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU) + assert.Equal(t, testCase.mtus, mtus) + }) + } +}