Skip to content

Commit

Permalink
pmtud package
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Nov 14, 2024
1 parent f9bdb21 commit 0bd15cf
Show file tree
Hide file tree
Showing 14 changed files with 848 additions and 0 deletions.
49 changes: 49 additions & 0 deletions internal/pmtud/apple_ipv4.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package pmtud

import (
"net"
"time"

"golang.org/x/net/ipv4"
)

var _ net.PacketConn = &ipv4Wrapper{}

// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
// the net.PacketConn interface. It's only used for Darwin or iOS.
type ipv4Wrapper struct {
ipv4Conn *ipv4.PacketConn
}

func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
return &ipv4Wrapper{ipv4Conn: ipv4}
}

func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
return n, addr, err
}

func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return i.ipv4Conn.WriteTo(p, nil, addr)
}

func (i *ipv4Wrapper) Close() error {
return i.ipv4Conn.Close()
}

func (i *ipv4Wrapper) LocalAddr() net.Addr {
return i.ipv4Conn.LocalAddr()
}

func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
return i.ipv4Conn.SetDeadline(t)
}

func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
return i.ipv4Conn.SetReadDeadline(t)
}

func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
return i.ipv4Conn.SetWriteDeadline(t)
}
83 changes: 83 additions & 0 deletions internal/pmtud/check.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package pmtud

import (
"bytes"
"errors"
"fmt"

"golang.org/x/net/icmp"
)

var (
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)

func checkMTU(mtu, minMTU, physicalLinkMTU int) (err error) {
switch {
case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
default:
return nil
}
}

func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
outboundMessage *icmp.Message,
) (match bool, err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return false, fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil
}

var ErrICMPIDMismatch = errors.New("ICMP id mismatch")

func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool,
) (err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d",
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
}
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil {
return fmt.Errorf("checking sent and received bodies: %w", err)
}
return nil
}

var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")

func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
ErrICMPEchoDataMismatch, len(sent), len(received))
}
if receivedTruncated {
sent = sent[:len(received)]
}
if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x",
ErrICMPEchoDataMismatch, sent, received)
}
return nil
}
10 changes: 10 additions & 0 deletions internal/pmtud/df.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//go:build !linux && !windows

package pmtud

// setDontFragment for platforms other than Linux and Windows
// is not implemented, so we just return assuming the don't
// fragment flag is set on IP packets.
func setDontFragment(fd uintptr) (err error) {
return nil
}
10 changes: 10 additions & 0 deletions internal/pmtud/df_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package pmtud

import (
"syscall"
)

func setDontFragment(fd uintptr) (err error) {
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP,
syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
}
11 changes: 11 additions & 0 deletions internal/pmtud/df_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package pmtud

import (
"syscall"
)

func setDontFragment(fd uintptr) (err error) {
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1)
}
10 changes: 10 additions & 0 deletions internal/pmtud/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package pmtud

import (
"errors"
)

var (
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
)
5 changes: 5 additions & 0 deletions internal/pmtud/interfaces.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package pmtud

type Logger interface {
Debug(msg string, args ...any)
}
145 changes: 145 additions & 0 deletions internal/pmtud/ipv4.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package pmtud

import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
"runtime"
"syscall"
"time"

"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)

const (
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
minIPv4MTU = 68
icmpv4Protocol = 1
)

func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
var setDFErr error
err := rawConn.Control(func(fd uintptr) {
setDFErr = setDontFragment(fd) // runs when calling ListenPacket
})
if err == nil {
err = setDFErr
}
return err
}

const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
if err != nil {
return nil, fmt.Errorf("listening for ICMP packets: %w", err)
}

if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn))
}

return packetConn, nil
}

func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
physicalLinkMTU int, pingTimeout time.Duration, logger Logger,
) (mtu int, err error) {
if ip.Is6() {
panic("IP address is not v4")
}
conn, err := listenICMPv4(ctx)
if err != nil {
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-ctx.Done()
conn.Close()
}()

// First try to send a packet which is too big to get the maximum MTU
// directly.
outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU)
encodedMessage, err := outboundMessage.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}

_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
return 0, fmt.Errorf("writing ICMP message: %w", err)
}

buffer := make([]byte, physicalLinkMTU)

for { // for loop in case we read an echo reply for another ICMP request
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
// Side note: echo reply should be at most the number of bytes
// sent, and can be lower, more precisely 576-ipHeader bytes,
// in case the next hop we are reaching replies with a destination
// unreachable and wants to ensure the response makes it way back
// by keeping a low packet size, see:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59

inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing message: %w", err)
}

switch typedBody := inboundMessage.Body.(type) {
case *icmp.DstUnreach:
const fragmentationRequiredAndDFFlagSetCode = 4
if inboundMessage.Code != fragmentationRequiredAndDFFlagSetCode {
return 0, fmt.Errorf("%w: code %d",
ErrICMPDestinationUnreachable, inboundMessage.Code)
}

// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
// Note: the go library does not handle this NextHopMTU section.
nextHopMTU := packetBytes[6:8]
mtu = int(binary.BigEndian.Uint16(nextHopMTU))
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
}

// The code below is really for sanity checks
packetBytes = packetBytes[8:]
header, err := ipv4.ParseHeader(packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing IPv4 header: %w", err)
}
packetBytes = packetBytes[header.Len:] // truncated original datagram

const truncated = true
err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated)
if err != nil {
return 0, fmt.Errorf("checking echo reply: %w", err)
}
return mtu, nil
case *icmp.Echo:
inboundID := uint16(typedBody.ID) //nolint:gosec
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
}
}
}
Loading

0 comments on commit 0bd15cf

Please sign in to comment.