Skip to content

Commit

Permalink
Add UDPProxyTunnel
Browse files Browse the repository at this point in the history
  • Loading branch information
VastBlast committed Jan 1, 2025
1 parent d710683 commit 8ae43d6
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 0 deletions.
39 changes: 39 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ type DeviceConfig struct {
CheckAliveInterval int
}

type UDPProxyTunnelConfig struct {
BindAddress string
Target string
InactivityTimeout int
}

type TCPClientTunnelConfig struct {
BindAddress *net.TCPAddr
Target string
Expand Down Expand Up @@ -434,6 +440,34 @@ func parseHTTPConfig(section *ini.Section) (RoutineSpawner, error) {
return config, nil
}

func parseUDPProxyTunnelConfig(section *ini.Section) (RoutineSpawner, error) {
config := &UDPProxyTunnelConfig{}

bindAddress, err := parseString(section, "BindAddress")
if err != nil {
return nil, err
}
config.BindAddress = bindAddress

target, err := parseString(section, "Target")
if err != nil {
return nil, err
}
config.Target = target

inactivityTimeout := 0
if sectionKey, err := section.GetKey("InactivityTimeout"); err == nil {
timeoutVal, err := sectionKey.Int()
if err != nil {
return nil, err
}
inactivityTimeout = timeoutVal
}
config.InactivityTimeout = inactivityTimeout

return config, nil
}

// Takes a function that parses an individual section into a config, and apply it on all
// specified sections
func parseRoutinesConfig(routines *[]RoutineSpawner, cfg *ini.File, sectionName string, f func(*ini.Section) (RoutineSpawner, error)) error {
Expand Down Expand Up @@ -518,6 +552,11 @@ func ParseConfig(path string) (*Configuration, error) {
return nil, err
}

err = parseRoutinesConfig(&routinesSpawners, cfg, "UDPProxyTunnel", parseUDPProxyTunnelConfig)
if err != nil {
return nil, err
}

return &Configuration{
Device: device,
Routines: routinesSpawners,
Expand Down
160 changes: 160 additions & 0 deletions udp_proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package wireproxy

import (
"fmt"
"log"
"net"
"sync"
"time"
)

// udpSession represents a UDP forwarding session, keyed by the local source address.
// remoteConn is the UDP connection to the remote endpoint (on the WireGuard side).
type udpSession struct {
remoteConn net.Conn
lastActive time.Time
closeChan chan struct{}
inactivityDur time.Duration
}

// SpawnRoutine implements the RoutineSpawner interface.
// It starts listening on config.BindAddress, handling each unique source (client) address
// with its own udpSession. If InactivityTimeout > 0, sessions automatically close after inactivity
func (conf *UDPProxyTunnelConfig) SpawnRoutine(vt *VirtualTun) {
addr, err := net.ResolveUDPAddr("udp", conf.BindAddress)
if err != nil {
log.Fatalf("UDPProxyTunnelConfig: could not resolve bind address %s: %v", conf.BindAddress, err)
}

listener, err := net.ListenUDP("udp", addr)
if err != nil {
log.Fatalf("UDPProxyTunnelConfig: could not listen on %s: %v", conf.BindAddress, err)
}
log.Printf("UDPProxyTunnel listening on %s, forwarding to %s", conf.BindAddress, conf.Target)

inactivityDur := time.Duration(conf.InactivityTimeout) * time.Second
sessions := make(map[string]*udpSession)
var sessionMu sync.Mutex

// Periodically clean up expired sessions if inactivity timeout is enabled
if conf.InactivityTimeout > 0 {
go func() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
sessionMu.Lock()
for key, sess := range sessions {
if now.Sub(sess.lastActive) >= inactivityDur {
log.Printf("UDPProxyTunnel: closing inactive session for %s", key)
close(sess.closeChan)
delete(sessions, key)
}
}
sessionMu.Unlock()
}
}()
}

// Create or get a UDP session based on the local source address
getOrCreateSession := func(srcAddr string) (*udpSession, error) {
sessionMu.Lock()
defer sessionMu.Unlock()

// return if session already exists
if s, ok := sessions[srcAddr]; ok {
s.lastActive = time.Now()
return s, nil
}

// Create a new session
remoteConn, err := vt.Tnet.Dial("udp", conf.Target)
if err != nil {
return nil, fmt.Errorf("UDPProxyTunnel: could not Dial(%s): %w", conf.Target, err)
}

s := &udpSession{
remoteConn: remoteConn,
lastActive: time.Now(),
closeChan: make(chan struct{}),
inactivityDur: inactivityDur,
}
sessions[srcAddr] = s

// Spin up a goroutine to handle traffic from remote -> local
go conf.handleRemoteToLocal(listener, srcAddr, s)
return s, nil
}

// Main loop to read from local client and forward to remote
go func() {
buf := make([]byte, 64*1024) // typical max UDP size
for {
n, src, err := listener.ReadFromUDP(buf)
if err != nil {
log.Printf("UDPProxyTunnel: error reading from UDP: %v", err)
continue
}

srcKey := src.String() // identify session by the local client's IP:port
s, err := getOrCreateSession(srcKey)
if err != nil {
errorLogger.Printf("UDPProxyTunnel: getOrCreateSession failed for %s: %v", srcKey, err)
continue
}

s.lastActive = time.Now()
_, err = s.remoteConn.Write(buf[:n])
if err != nil {
errorLogger.Printf("UDPProxyTunnel: could not write to remote (%s): %v", conf.Target, err)
}
}
}()
}

// handles data from the remote WireGuard side back to the local client
// this function blocks until the session is closed
func (conf *UDPProxyTunnelConfig) handleRemoteToLocal(listener *net.UDPConn, srcAddr string, s *udpSession) {
defer func() {
_ = s.remoteConn.Close()
}()
buf := make([]byte, 64*1024)

for {
select {
case <-s.closeChan:
return
default:
}

_ = s.remoteConn.SetReadDeadline(time.Now().Add(5 * time.Second))
n, err := s.remoteConn.Read(buf)
if err != nil {
// If a timeout or temporary error, continue to see if the session is closed
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
select {
case <-s.closeChan:
return
default:
continue
}
}
errorLogger.Printf("UDPProxyTunnel: read error from remote: %v", err)
return
}

s.lastActive = time.Now()

dstUDPAddr, err := net.ResolveUDPAddr("udp", srcAddr)
if err != nil {
errorLogger.Printf("UDPProxyTunnel: cannot resolve local address %s: %v", srcAddr, err)
return
}

_, err = listener.WriteToUDP(buf[:n], dstUDPAddr)
if err != nil {
errorLogger.Printf("UDPProxyTunnel: cannot write to local %s: %v", srcAddr, err)
return
}
}
}

0 comments on commit 8ae43d6

Please sign in to comment.