diff --git a/stack_gvisor.go b/stack_gvisor.go index 60af865..89983f1 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -5,7 +5,7 @@ package tun import ( "context" "net/netip" - "os" + "runtime" "time" "github.com/sagernet/gvisor/pkg/tcpip" @@ -17,6 +17,7 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" + "github.com/sagernet/gvisor/pkg/waiter" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -77,17 +78,35 @@ func (t *GVisor) Start() error { destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) pErr := t.handler.PrepareConnection(N.NetworkTCP, source, destination) if pErr != nil { - r.Complete(gWriteUnreachable(t.stack, r.Packet(), err) == os.ErrInvalid) + r.Complete(pErr != ErrDrop) return } - conn := &gLazyConn{ - parentCtx: t.ctx, - stack: t.stack, - request: r, - localAddr: source.TCPAddr(), - remoteAddr: destination.TCPAddr(), + var ( + wq waiter.Queue + endpoint tcpip.Endpoint + tErr tcpip.Error + ) + handshakeCtx, cancel := context.WithCancel(context.Background()) + go func() { + select { + case <-t.ctx.Done(): + wq.Notify(wq.Events()) + case <-handshakeCtx.Done(): + } + }() + endpoint, tErr = r.CreateEndpoint(&wq) + cancel() + if tErr != nil { + r.Complete(true) + return } - go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil) + r.Complete(false) + endpoint.SocketOptions().SetKeepAlive(true) + keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second) + endpoint.SetSockOpt(&keepAliveIdle) + keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second) + endpoint.SetSockOpt(&keepAliveInterval) + go t.handler.NewConnectionEx(t.ctx, gonet.NewTCPConn(&wq, endpoint), source, destination, nil) }) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) @@ -134,30 +153,47 @@ func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) { icmp.NewProtocol6, }, }) - tErr := ipStack.CreateNIC(defaultNIC, ep) - if tErr != nil { - return nil, E.New("create nic: ", gonet.TranslateNetstackError(tErr)) + err := ipStack.CreateNIC(defaultNIC, ep) + if err != nil { + return nil, gonet.TranslateNetstackError(err) } ipStack.SetRouteTable([]tcpip.Route{ {Destination: header.IPv4EmptySubnet, NIC: defaultNIC}, {Destination: header.IPv6EmptySubnet, NIC: defaultNIC}, }) - ipStack.SetSpoofing(defaultNIC, true) - ipStack.SetPromiscuousMode(defaultNIC, true) - bufSize := 20 * 1024 - ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{ - Min: 1, - Default: bufSize, - Max: bufSize, - }) - ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{ - Min: 1, - Default: bufSize, - Max: bufSize, - }) + err = ipStack.SetSpoofing(defaultNIC, true) + if err != nil { + return nil, gonet.TranslateNetstackError(err) + } + err = ipStack.SetPromiscuousMode(defaultNIC, true) + if err != nil { + return nil, gonet.TranslateNetstackError(err) + } sOpt := tcpip.TCPSACKEnabled(true) ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) mOpt := tcpip.TCPModerateReceiveBufferOption(true) ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt) + if runtime.GOOS == "windows" { + tcpRecoveryOpt := tcpip.TCPRecovery(0) + err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRecoveryOpt) + } + tcpRXBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: tcpRXBufMinSize, + Default: tcpRXBufDefSize, + Max: tcpRXBufMaxSize, + } + err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRXBufOpt) + if err != nil { + return nil, gonet.TranslateNetstackError(err) + } + tcpTXBufOpt := tcpip.TCPSendBufferSizeRangeOption{ + Min: tcpTXBufMinSize, + Default: tcpTXBufDefSize, + Max: tcpTXBufMaxSize, + } + err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTXBufOpt) + if err != nil { + return nil, gonet.TranslateNetstackError(err) + } return ipStack, nil } diff --git a/stack_gvisor_lazy.go b/stack_gvisor_lazy.go deleted file mode 100644 index 16abdac..0000000 --- a/stack_gvisor_lazy.go +++ /dev/null @@ -1,228 +0,0 @@ -//go:build with_gvisor - -package tun - -import ( - "context" - "errors" - "net" - "os" - "syscall" - "time" - - "github.com/sagernet/gvisor/pkg/tcpip" - "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" - "github.com/sagernet/gvisor/pkg/tcpip/header" - "github.com/sagernet/gvisor/pkg/tcpip/stack" - "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" - "github.com/sagernet/gvisor/pkg/waiter" -) - -type gLazyConn struct { - tcpConn *gonet.TCPConn - parentCtx context.Context - stack *stack.Stack - request *tcp.ForwarderRequest - localAddr net.Addr - remoteAddr net.Addr - handshakeDone bool - handshakeErr error -} - -func (c *gLazyConn) HandshakeContext(ctx context.Context) error { - if c.handshakeDone { - return nil - } - defer func() { - c.handshakeDone = true - }() - var ( - wq waiter.Queue - endpoint tcpip.Endpoint - ) - handshakeCtx, cancel := context.WithCancel(ctx) - go func() { - select { - case <-c.parentCtx.Done(): - wq.Notify(wq.Events()) - case <-handshakeCtx.Done(): - } - }() - endpoint, err := c.request.CreateEndpoint(&wq) - cancel() - if err != nil { - gErr := gonet.TranslateNetstackError(err) - c.handshakeErr = gErr - c.request.Complete(true) - return gErr - } - c.request.Complete(false) - endpoint.SocketOptions().SetKeepAlive(true) - keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second) - endpoint.SetSockOpt(&keepAliveIdle) - keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second) - endpoint.SetSockOpt(&keepAliveInterval) - tcpConn := gonet.NewTCPConn(&wq, endpoint) - c.tcpConn = tcpConn - return nil -} - -func (c *gLazyConn) HandshakeFailure(err error) error { - if c.handshakeDone { - return nil - } - c.request.Complete(gWriteUnreachable(c.stack, c.request.Packet(), err) == os.ErrInvalid) - c.handshakeDone = true - c.handshakeErr = err - return nil -} - -func (c *gLazyConn) HandshakeSuccess() error { - return c.HandshakeContext(context.Background()) -} - -func (c *gLazyConn) Read(b []byte) (n int, err error) { - if !c.handshakeDone { - err = c.HandshakeContext(context.Background()) - if err != nil { - return - } - } else if c.handshakeErr != nil { - return 0, c.handshakeErr - } - return c.tcpConn.Read(b) -} - -func (c *gLazyConn) Write(b []byte) (n int, err error) { - if !c.handshakeDone { - err = c.HandshakeContext(context.Background()) - if err != nil { - return - } - } else if c.handshakeErr != nil { - return 0, c.handshakeErr - } - return c.tcpConn.Write(b) -} - -func (c *gLazyConn) LocalAddr() net.Addr { - return c.localAddr -} - -func (c *gLazyConn) RemoteAddr() net.Addr { - return c.remoteAddr -} - -func (c *gLazyConn) SetDeadline(t time.Time) error { - if !c.handshakeDone { - err := c.HandshakeContext(context.Background()) - if err != nil { - return err - } - } else if c.handshakeErr != nil { - return c.handshakeErr - } - return c.tcpConn.SetDeadline(t) -} - -func (c *gLazyConn) SetReadDeadline(t time.Time) error { - if !c.handshakeDone { - err := c.HandshakeContext(context.Background()) - if err != nil { - return err - } - } else if c.handshakeErr != nil { - return c.handshakeErr - } - return c.tcpConn.SetReadDeadline(t) -} - -func (c *gLazyConn) SetWriteDeadline(t time.Time) error { - if !c.handshakeDone { - err := c.HandshakeContext(context.Background()) - if err != nil { - return err - } - } else if c.handshakeErr != nil { - return c.handshakeErr - } - return c.tcpConn.SetWriteDeadline(t) -} - -func (c *gLazyConn) Close() error { - if !c.handshakeDone { - c.request.Complete(true) - c.handshakeErr = net.ErrClosed - return nil - } else if c.handshakeErr != nil { - return nil - } - return c.tcpConn.Close() -} - -func (c *gLazyConn) CloseRead() error { - if !c.handshakeDone { - c.request.Complete(true) - c.handshakeErr = net.ErrClosed - return nil - } else if c.handshakeErr != nil { - return nil - } - return c.tcpConn.CloseRead() -} - -func (c *gLazyConn) CloseWrite() error { - if !c.handshakeDone { - c.request.Complete(true) - c.handshakeErr = net.ErrClosed - return nil - } else if c.handshakeErr != nil { - return nil - } - return c.tcpConn.CloseRead() -} - -func (c *gLazyConn) ReaderReplaceable() bool { - return c.handshakeDone && c.handshakeErr == nil -} - -func (c *gLazyConn) WriterReplaceable() bool { - return c.handshakeDone && c.handshakeErr == nil -} - -func (c *gLazyConn) Upstream() any { - return c.tcpConn -} - -func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) error { - if errors.Is(err, ErrDrop) { - return nil - } else if errors.Is(err, syscall.ENETUNREACH) { - if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { - return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable) - } else { - return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) - } - } else if errors.Is(err, syscall.EHOSTUNREACH) { - if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { - return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable) - } else { - return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPAddrUnreachable) - } - } else if errors.Is(err, syscall.ECONNREFUSED) { - if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { - return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable) - } else { - return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable) - } - } - return os.ErrInvalid -} - -func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error { - return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true)) -} - -func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error { - return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true)) -} diff --git a/stack_gvisor_tcpbuf_default.go b/stack_gvisor_tcpbuf_default.go new file mode 100644 index 0000000..f636d1a --- /dev/null +++ b/stack_gvisor_tcpbuf_default.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build with_gvisor && !ios + +package tun + +import "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" + +const ( + tcpRXBufMinSize = tcp.MinBufferSize + tcpRXBufDefSize = tcp.DefaultSendBufferSize + tcpRXBufMaxSize = 8 << 20 // 8MiB + + tcpTXBufMinSize = tcp.MinBufferSize + tcpTXBufDefSize = tcp.DefaultReceiveBufferSize + tcpTXBufMaxSize = 6 << 20 // 6MiB +) diff --git a/stack_gvisor_tcpbuf_ios.go b/stack_gvisor_tcpbuf_ios.go new file mode 100644 index 0000000..495e59b --- /dev/null +++ b/stack_gvisor_tcpbuf_ios.go @@ -0,0 +1,21 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build with_gvisor + +package tun + +import "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" + +const ( + // tcp{RX,TX}Buf{Min,Def,Max}Size mirror gVisor defaults. We leave these + // unchanged on iOS for now as to not increase pressure towards the + // NetworkExtension memory limit. + tcpRXBufMinSize = tcp.MinBufferSize + tcpRXBufDefSize = tcp.DefaultSendBufferSize + tcpRXBufMaxSize = tcp.MaxBufferSize + + tcpTXBufMinSize = tcp.MinBufferSize + tcpTXBufDefSize = tcp.DefaultReceiveBufferSize + tcpTXBufMaxSize = tcp.MaxBufferSize +) diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index dd0c8a0..22e7e09 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -59,7 +59,9 @@ func rangeIterate(r stack.Range, fn func(*buffer.View)) func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination) if pErr != nil { - gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr) + if pErr != ErrDrop { + gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr) + } return false, nil, nil, nil } var sourceNetwork tcpip.NetworkProtocolNumber @@ -147,3 +149,11 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock route.Stats().UDP.PacketsSent.Increment() return nil } + +func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) error { + if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { + return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, stack.RejectIPv4WithICMPPortUnreachable, true)) + } else { + return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, stack.RejectIPv6WithICMPPortUnreachable, true)) + } +} diff --git a/stack_system.go b/stack_system.go index 2baa0c6..a06329e 100644 --- a/stack_system.go +++ b/stack_system.go @@ -2,7 +2,6 @@ package tun import ( "context" - "errors" "net" "net/netip" "syscall" @@ -357,14 +356,8 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err } else { natPort, err := s.tcpNat.Lookup(source, destination, s.handler) if err != nil { - if errors.Is(err, ErrDrop) { + if err == ErrDrop { return false, nil - } else if errors.Is(err, syscall.ENETUNREACH) { - return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4NetUnreachable) - } else if errors.Is(err, syscall.EHOSTUNREACH) { - return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4HostUnreachable) - } else if errors.Is(err, syscall.ECONNREFUSED) { - return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable) } else { return false, s.resetIPv4TCP(ipHdr, tcpHdr) } @@ -450,14 +443,8 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err } else { natPort, err := s.tcpNat.Lookup(source, destination, s.handler) if err != nil { - if errors.Is(err, ErrDrop) { + if err == ErrDrop { return false, nil - } else if errors.Is(err, syscall.ENETUNREACH) { - return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6NetworkUnreachable) - } else if errors.Is(err, syscall.EHOSTUNREACH) { - return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6AddressUnreachable) - } else if errors.Is(err, syscall.ECONNREFUSED) { - return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6PortUnreachable) } else { return false, s.resetIPv6TCP(ipHdr, tcpHdr) } @@ -551,23 +538,12 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error { func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination) if pErr != nil { - if errors.Is(pErr, ErrDrop) { - } else if source.IsIPv4() { - ipHdr := userData.(header.IPv4) - if errors.Is(pErr, syscall.ENETUNREACH) { - s.rejectIPv4WithICMP(ipHdr, header.ICMPv4NetUnreachable) - } else if errors.Is(pErr, syscall.EHOSTUNREACH) { - s.rejectIPv4WithICMP(ipHdr, header.ICMPv4HostUnreachable) - } else { + if pErr != ErrDrop { + if source.IsIPv4() { + ipHdr := userData.(header.IPv4) s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable) - } - } else { - ipHdr := userData.(header.IPv6) - if errors.Is(pErr, syscall.ENETUNREACH) { - s.rejectIPv6WithICMP(ipHdr, header.ICMPv6NetworkUnreachable) - } else if errors.Is(pErr, syscall.EHOSTUNREACH) { - s.rejectIPv6WithICMP(ipHdr, header.ICMPv6AddressUnreachable) } else { + ipHdr := userData.(header.IPv6) s.rejectIPv6WithICMP(ipHdr, header.ICMPv6PortUnreachable) } } diff --git a/tun_linux.go b/tun_linux.go index 390d5b4..72e1620 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -224,7 +224,6 @@ func open(name string, vnetHdr bool) (int, error) { func (t *NativeTun) configure(tunLink netlink.Link) error { err := netlink.LinkSetMTU(tunLink, int(t.options.MTU)) if errors.Is(err, unix.EPERM) { - // unprivileged return nil } else if err != nil { return err @@ -294,13 +293,11 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { func (t *NativeTun) Start() error { tunLink, err := netlink.LinkByName(t.options.Name) - if err != nil { - return err + if err == nil { + err = netlink.LinkSetUp(tunLink) } - err = netlink.LinkSetUp(tunLink) if errors.Is(err, unix.EPERM) { - // unprivileged return nil } else if err != nil { return err