-
-
Notifications
You must be signed in to change notification settings - Fork 391
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
848 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package pmtud | ||
|
||
import ( | ||
"net" | ||
"time" | ||
|
||
"golang.org/x/net/ipv4" | ||
) | ||
|
||
var _ net.PacketConn = &ipv4Wrapper{} | ||
|
||
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement | ||
// the net.PacketConn interface. It's only used for Darwin or iOS. | ||
type ipv4Wrapper struct { | ||
ipv4Conn *ipv4.PacketConn | ||
} | ||
|
||
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper { | ||
return &ipv4Wrapper{ipv4Conn: ipv4} | ||
} | ||
|
||
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) { | ||
n, _, addr, err = i.ipv4Conn.ReadFrom(p) | ||
return n, addr, err | ||
} | ||
|
||
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) { | ||
return i.ipv4Conn.WriteTo(p, nil, addr) | ||
} | ||
|
||
func (i *ipv4Wrapper) Close() error { | ||
return i.ipv4Conn.Close() | ||
} | ||
|
||
func (i *ipv4Wrapper) LocalAddr() net.Addr { | ||
return i.ipv4Conn.LocalAddr() | ||
} | ||
|
||
func (i *ipv4Wrapper) SetDeadline(t time.Time) error { | ||
return i.ipv4Conn.SetDeadline(t) | ||
} | ||
|
||
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error { | ||
return i.ipv4Conn.SetReadDeadline(t) | ||
} | ||
|
||
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error { | ||
return i.ipv4Conn.SetWriteDeadline(t) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package pmtud | ||
|
||
import ( | ||
"bytes" | ||
"errors" | ||
"fmt" | ||
|
||
"golang.org/x/net/icmp" | ||
) | ||
|
||
// ErrICMPNextHopMTUTooLow is returned by checkMTU when the MTU is below
// the minimum MTU given.
var ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")

// ErrICMPNextHopMTUTooHigh is returned by checkMTU when the MTU is above
// the physical link MTU given.
var ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")

// checkMTU verifies mtu lies within [minMTU, physicalLinkMTU].
// It returns an error wrapping ErrICMPNextHopMTUTooLow or
// ErrICMPNextHopMTUTooHigh when mtu is out of bounds, and nil otherwise.
// The lower bound is checked first.
func checkMTU(mtu, minMTU, physicalLinkMTU int) (err error) {
	if mtu < minMTU {
		return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
	}
	if mtu > physicalLinkMTU {
		return fmt.Errorf("%w: %d is larger than physical link MTU %d",
			ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
	}
	return nil
}
|
||
func checkInvokingReplyIDMatch(icmpProtocol int, received []byte, | ||
outboundMessage *icmp.Message, | ||
) (match bool, err error) { | ||
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received) | ||
if err != nil { | ||
return false, fmt.Errorf("parsing invoking packet: %w", err) | ||
} | ||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo) | ||
if !ok { | ||
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body) | ||
} | ||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert | ||
return inboundBody.ID == outboundBody.ID, nil | ||
} | ||
|
||
var ErrICMPIDMismatch = errors.New("ICMP id mismatch") | ||
|
||
func checkEchoReply(icmpProtocol int, received []byte, | ||
outboundMessage *icmp.Message, truncatedBody bool, | ||
) (err error) { | ||
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received) | ||
if err != nil { | ||
return fmt.Errorf("parsing invoking packet: %w", err) | ||
} | ||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo) | ||
if !ok { | ||
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body) | ||
} | ||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert | ||
if inboundBody.ID != outboundBody.ID { | ||
return fmt.Errorf("%w: sent id %d and received id %d", | ||
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID) | ||
} | ||
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody) | ||
if err != nil { | ||
return fmt.Errorf("checking sent and received bodies: %w", err) | ||
} | ||
return nil | ||
} | ||
|
||
// ErrICMPEchoDataMismatch is returned when the echo data received does
// not correspond to the echo data sent.
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")

// checkEchoBodies compares the echo data sent with the echo data received.
// The received data must never be longer than the data sent. If
// receivedTruncated is true, received only has to equal the first
// len(received) bytes of sent; otherwise both must be equal.
// Any mismatch is reported as an error wrapping ErrICMPEchoDataMismatch.
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
	sentLength, receivedLength := len(sent), len(received)
	if receivedLength > sentLength {
		return fmt.Errorf("%w: sent %d bytes and received %d bytes",
			ErrICMPEchoDataMismatch, sentLength, receivedLength)
	}

	expected := sent
	if receivedTruncated {
		// Safe to slice: received is at most as long as sent here.
		expected = sent[:receivedLength]
	}

	if bytes.Equal(expected, received) {
		return nil
	}
	return fmt.Errorf("%w: sent %x and received %x",
		ErrICMPEchoDataMismatch, expected, received)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
//go:build !linux && !windows | ||
|
||
package pmtud | ||
|
||
// setDontFragment is a no-op on platforms other than Linux and Windows,
// where setting the don't fragment flag is not implemented. It always
// returns nil, assuming the don't fragment flag is set on IP packets.
func setDontFragment(fd uintptr) (err error) {
	_ = fd // the file descriptor is intentionally left untouched
	return nil
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package pmtud | ||
|
||
import ( | ||
"syscall" | ||
) | ||
|
||
// setDontFragment sets the IP_MTU_DISCOVER option of the socket with file
// descriptor fd to IP_PMTUDISC_PROBE, and returns any error from the
// setsockopt call.
func setDontFragment(fd uintptr) (err error) {
	const (
		level  = syscall.IPPROTO_IP
		option = syscall.IP_MTU_DISCOVER
		value  = syscall.IP_PMTUDISC_PROBE
	)
	socket := int(fd)
	return syscall.SetsockoptInt(socket, level, option, value)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
package pmtud | ||
|
||
import ( | ||
"syscall" | ||
) | ||
|
||
func setDontFragment(fd uintptr) (err error) { | ||
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip | ||
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */ | ||
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package pmtud | ||
|
||
import ( | ||
"errors" | ||
) | ||
|
||
// ErrICMPDestinationUnreachable indicates an ICMP destination unreachable
// message was received.
var ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")

// ErrICMPBodyUnsupported indicates the body type of an ICMP message
// received is not one this package handles.
var ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
package pmtud | ||
|
||
// Logger is the logging dependency required by this package.
// Only debug level logging is needed.
type Logger interface {
	// Debug logs msg at the debug level together with args.
	// NOTE(review): callers in this package pass printf-style verbs in
	// msg with matching args - confirm implementations format them.
	Debug(msg string, args ...any)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
package pmtud | ||
|
||
import ( | ||
"context" | ||
"encoding/binary" | ||
"fmt" | ||
"net" | ||
"net/netip" | ||
"runtime" | ||
"syscall" | ||
"time" | ||
|
||
"golang.org/x/net/icmp" | ||
"golang.org/x/net/ipv4" | ||
) | ||
|
||
// minIPv4MTU is the lowest MTU accepted for an IPv4 next hop, see
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
const minIPv4MTU = 68

// icmpv4Protocol is the protocol number given to icmp.ParseMessage
// to parse ICMP for IPv4 messages.
const icmpv4Protocol = 1
|
||
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) { | ||
var listenConfig net.ListenConfig | ||
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error { | ||
var setDFErr error | ||
err := rawConn.Control(func(fd uintptr) { | ||
setDFErr = setDontFragment(fd) // runs when calling ListenPacket | ||
}) | ||
if err == nil { | ||
err = setDFErr | ||
} | ||
return err | ||
} | ||
|
||
const listenAddress = "" | ||
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress) | ||
if err != nil { | ||
return nil, fmt.Errorf("listening for ICMP packets: %w", err) | ||
} | ||
|
||
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { | ||
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn)) | ||
} | ||
|
||
return packetConn, nil | ||
} | ||
|
||
func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr, | ||
physicalLinkMTU int, pingTimeout time.Duration, logger Logger, | ||
) (mtu int, err error) { | ||
if ip.Is6() { | ||
panic("IP address is not v4") | ||
} | ||
conn, err := listenICMPv4(ctx) | ||
if err != nil { | ||
return 0, fmt.Errorf("listening for ICMP packets: %w", err) | ||
} | ||
ctx, cancel := context.WithTimeout(ctx, pingTimeout) | ||
defer cancel() | ||
go func() { | ||
<-ctx.Done() | ||
conn.Close() | ||
}() | ||
|
||
// First try to send a packet which is too big to get the maximum MTU | ||
// directly. | ||
outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU) | ||
encodedMessage, err := outboundMessage.Marshal(nil) | ||
if err != nil { | ||
return 0, fmt.Errorf("encoding ICMP message: %w", err) | ||
} | ||
|
||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) | ||
if err != nil { | ||
return 0, fmt.Errorf("writing ICMP message: %w", err) | ||
} | ||
|
||
buffer := make([]byte, physicalLinkMTU) | ||
|
||
for { // for loop in case we read an echo reply for another ICMP request | ||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer | ||
// must be large enough to read the entire reply packet. See: | ||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J | ||
bytesRead, _, err := conn.ReadFrom(buffer) | ||
if err != nil { | ||
return 0, fmt.Errorf("reading from ICMP connection: %w", err) | ||
} | ||
packetBytes := buffer[:bytesRead] | ||
// Side note: echo reply should be at most the number of bytes | ||
// sent, and can be lower, more precisely 576-ipHeader bytes, | ||
// in case the next hop we are reaching replies with a destination | ||
// unreachable and wants to ensure the response makes it way back | ||
// by keeping a low packet size, see: | ||
// https://datatracker.ietf.org/doc/html/rfc1122#page-59 | ||
|
||
inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes) | ||
if err != nil { | ||
return 0, fmt.Errorf("parsing message: %w", err) | ||
} | ||
|
||
switch typedBody := inboundMessage.Body.(type) { | ||
case *icmp.DstUnreach: | ||
const fragmentationRequiredAndDFFlagSetCode = 4 | ||
if inboundMessage.Code != fragmentationRequiredAndDFFlagSetCode { | ||
return 0, fmt.Errorf("%w: code %d", | ||
ErrICMPDestinationUnreachable, inboundMessage.Code) | ||
} | ||
|
||
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4 | ||
// Note: the go library does not handle this NextHopMTU section. | ||
nextHopMTU := packetBytes[6:8] | ||
mtu = int(binary.BigEndian.Uint16(nextHopMTU)) | ||
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU) | ||
if err != nil { | ||
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err) | ||
} | ||
|
||
// The code below is really for sanity checks | ||
packetBytes = packetBytes[8:] | ||
header, err := ipv4.ParseHeader(packetBytes) | ||
if err != nil { | ||
return 0, fmt.Errorf("parsing IPv4 header: %w", err) | ||
} | ||
packetBytes = packetBytes[header.Len:] // truncated original datagram | ||
|
||
const truncated = true | ||
err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated) | ||
if err != nil { | ||
return 0, fmt.Errorf("checking echo reply: %w", err) | ||
} | ||
return mtu, nil | ||
case *icmp.Echo: | ||
inboundID := uint16(typedBody.ID) //nolint:gosec | ||
if inboundID == outboundID { | ||
return physicalLinkMTU, nil | ||
} | ||
logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d", | ||
inboundID, outboundID) | ||
continue | ||
default: | ||
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody) | ||
} | ||
} | ||
} |
Oops, something went wrong.