diff --git a/config.go b/config.go index b1aba15..a58eee3 100644 --- a/config.go +++ b/config.go @@ -33,6 +33,12 @@ type DeviceConfig struct { CheckAliveInterval int } +type UDPProxyTunnelConfig struct { + BindAddress string + Target string + InactivityTimeout int +} + type TCPClientTunnelConfig struct { BindAddress *net.TCPAddr Target string @@ -434,6 +440,34 @@ func parseHTTPConfig(section *ini.Section) (RoutineSpawner, error) { return config, nil } +func parseUDPProxyTunnelConfig(section *ini.Section) (RoutineSpawner, error) { + config := &UDPProxyTunnelConfig{} + + bindAddress, err := parseString(section, "BindAddress") + if err != nil { + return nil, err + } + config.BindAddress = bindAddress + + target, err := parseString(section, "Target") + if err != nil { + return nil, err + } + config.Target = target + + inactivityTimeout := 0 + if sectionKey, err := section.GetKey("InactivityTimeout"); err == nil { + timeoutVal, err := sectionKey.Int() + if err != nil { + return nil, err + } + inactivityTimeout = timeoutVal + } + config.InactivityTimeout = inactivityTimeout + + return config, nil +} + // Takes a function that parses an individual section into a config, and apply it on all // specified sections func parseRoutinesConfig(routines *[]RoutineSpawner, cfg *ini.File, sectionName string, f func(*ini.Section) (RoutineSpawner, error)) error { @@ -518,6 +552,11 @@ func ParseConfig(path string) (*Configuration, error) { return nil, err } + err = parseRoutinesConfig(&routinesSpawners, cfg, "UDPProxyTunnel", parseUDPProxyTunnelConfig) + if err != nil { + return nil, err + } + return &Configuration{ Device: device, Routines: routinesSpawners, diff --git a/udp_proxy.go b/udp_proxy.go new file mode 100644 index 0000000..542bdbf --- /dev/null +++ b/udp_proxy.go @@ -0,0 +1,160 @@ +package wireproxy + +import ( + "fmt" + "log" + "net" + "sync" + "time" +) + +// udpSession represents a UDP forwarding session, keyed by the local source address. +// remoteConn is the UDP connection to the remote endpoint (on the WireGuard side). +type udpSession struct { + remoteConn net.Conn + lastActive time.Time + closeChan chan struct{} + inactivityDur time.Duration +} + +// SpawnRoutine implements the RoutineSpawner interface. +// It starts listening on config.BindAddress, handling each unique source (client) address +// with its own udpSession. If InactivityTimeout > 0, sessions automatically close after inactivity +func (conf *UDPProxyTunnelConfig) SpawnRoutine(vt *VirtualTun) { + addr, err := net.ResolveUDPAddr("udp", conf.BindAddress) + if err != nil { + log.Fatalf("UDPProxyTunnelConfig: could not resolve bind address %s: %v", conf.BindAddress, err) + } + + listener, err := net.ListenUDP("udp", addr) + if err != nil { + log.Fatalf("UDPProxyTunnelConfig: could not listen on %s: %v", conf.BindAddress, err) + } + log.Printf("UDPProxyTunnel listening on %s, forwarding to %s", conf.BindAddress, conf.Target) + + inactivityDur := time.Duration(conf.InactivityTimeout) * time.Second + sessions := make(map[string]*udpSession) + var sessionMu sync.Mutex + + // Periodically clean up expired sessions if inactivity timeout is enabled + if conf.InactivityTimeout > 0 { + go func() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + for range ticker.C { + now := time.Now() + sessionMu.Lock() + for key, sess := range sessions { + if now.Sub(sess.lastActive) >= inactivityDur { + log.Printf("UDPProxyTunnel: closing inactive session for %s", key) + close(sess.closeChan) + delete(sessions, key) + } + } + sessionMu.Unlock() + } + }() + } + + // Create or get a UDP session based on the local source address + getOrCreateSession := func(srcAddr string) (*udpSession, error) { + sessionMu.Lock() + defer sessionMu.Unlock() + + // return if session already exists + if s, ok := sessions[srcAddr]; ok { + s.lastActive = time.Now() + return s, nil + } + + // Create a new session + remoteConn, err := vt.Tnet.Dial("udp", conf.Target) + if err != nil { + return nil, fmt.Errorf("UDPProxyTunnel: could not Dial(%s): %w", conf.Target, err) + } + + s := &udpSession{ + remoteConn: remoteConn, + lastActive: time.Now(), + closeChan: make(chan struct{}), + inactivityDur: inactivityDur, + } + sessions[srcAddr] = s + + // Spin up a goroutine to handle traffic from remote -> local + go conf.handleRemoteToLocal(listener, srcAddr, s) + return s, nil + } + + // Main loop to read from local client and forward to remote + go func() { + buf := make([]byte, 64*1024) // typical max UDP size + for { + n, src, err := listener.ReadFromUDP(buf) + if err != nil { + log.Printf("UDPProxyTunnel: error reading from UDP: %v", err) + continue + } + + srcKey := src.String() // identify session by the local client's IP:port + s, err := getOrCreateSession(srcKey) + if err != nil { + errorLogger.Printf("UDPProxyTunnel: getOrCreateSession failed for %s: %v", srcKey, err) + continue + } + + s.lastActive = time.Now() + _, err = s.remoteConn.Write(buf[:n]) + if err != nil { + errorLogger.Printf("UDPProxyTunnel: could not write to remote (%s): %v", conf.Target, err) + } + } + }() +} + +// handles data from the remote WireGuard side back to the local client +// this function blocks until the session is closed +func (conf *UDPProxyTunnelConfig) handleRemoteToLocal(listener *net.UDPConn, srcAddr string, s *udpSession) { + defer func() { + _ = s.remoteConn.Close() + }() + buf := make([]byte, 64*1024) + + for { + select { + case <-s.closeChan: + return + default: + } + + _ = s.remoteConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + n, err := s.remoteConn.Read(buf) + if err != nil { + // If a timeout or temporary error, continue to see if the session is closed + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + select { + case <-s.closeChan: + return + default: + continue + } + } + errorLogger.Printf("UDPProxyTunnel: read error from remote: %v", err) + return + } + + s.lastActive = time.Now() + + dstUDPAddr, err := net.ResolveUDPAddr("udp", srcAddr) + if err != nil { + errorLogger.Printf("UDPProxyTunnel: cannot resolve local address %s: %v", srcAddr, err) + return + } + + _, err = listener.WriteToUDP(buf[:n], dstUDPAddr) + if err != nil { + errorLogger.Printf("UDPProxyTunnel: cannot write to local %s: %v", srcAddr, err) + return + } + } +}