-
Notifications
You must be signed in to change notification settings - Fork 286
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
199 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
package wireproxy | ||
|
||
import ( | ||
"fmt" | ||
"log" | ||
"net" | ||
"sync" | ||
"time" | ||
) | ||
|
||
// udpSession represents a UDP forwarding session, keyed by the local source address. | ||
// remoteConn is the UDP connection to the remote endpoint (on the WireGuard side). | ||
type udpSession struct { | ||
remoteConn net.Conn | ||
lastActive time.Time | ||
closeChan chan struct{} | ||
inactivityDur time.Duration | ||
} | ||
|
||
// SpawnRoutine implements the RoutineSpawner interface. | ||
// It starts listening on config.BindAddress, handling each unique source (client) address | ||
// with its own udpSession. If InactivityTimeout > 0, sessions automatically close after inactivity | ||
func (conf *UDPProxyTunnelConfig) SpawnRoutine(vt *VirtualTun) { | ||
addr, err := net.ResolveUDPAddr("udp", conf.BindAddress) | ||
if err != nil { | ||
log.Fatalf("UDPProxyTunnelConfig: could not resolve bind address %s: %v", conf.BindAddress, err) | ||
} | ||
|
||
listener, err := net.ListenUDP("udp", addr) | ||
if err != nil { | ||
log.Fatalf("UDPProxyTunnelConfig: could not listen on %s: %v", conf.BindAddress, err) | ||
} | ||
log.Printf("UDPProxyTunnel listening on %s, forwarding to %s", conf.BindAddress, conf.Target) | ||
|
||
inactivityDur := time.Duration(conf.InactivityTimeout) * time.Second | ||
sessions := make(map[string]*udpSession) | ||
var sessionMu sync.Mutex | ||
|
||
// Periodically clean up expired sessions if inactivity timeout is enabled | ||
if conf.InactivityTimeout > 0 { | ||
go func() { | ||
ticker := time.NewTicker(10 * time.Second) | ||
defer ticker.Stop() | ||
for range ticker.C { | ||
now := time.Now() | ||
sessionMu.Lock() | ||
for key, sess := range sessions { | ||
if now.Sub(sess.lastActive) >= inactivityDur { | ||
log.Printf("UDPProxyTunnel: closing inactive session for %s", key) | ||
close(sess.closeChan) | ||
delete(sessions, key) | ||
} | ||
} | ||
sessionMu.Unlock() | ||
} | ||
}() | ||
} | ||
|
||
// Create or get a UDP session based on the local source address | ||
getOrCreateSession := func(srcAddr string) (*udpSession, error) { | ||
sessionMu.Lock() | ||
defer sessionMu.Unlock() | ||
|
||
// return if session already exists | ||
if s, ok := sessions[srcAddr]; ok { | ||
s.lastActive = time.Now() | ||
return s, nil | ||
} | ||
|
||
// Create a new session | ||
remoteConn, err := vt.Tnet.Dial("udp", conf.Target) | ||
if err != nil { | ||
return nil, fmt.Errorf("UDPProxyTunnel: could not Dial(%s): %w", conf.Target, err) | ||
} | ||
|
||
s := &udpSession{ | ||
remoteConn: remoteConn, | ||
lastActive: time.Now(), | ||
closeChan: make(chan struct{}), | ||
inactivityDur: inactivityDur, | ||
} | ||
sessions[srcAddr] = s | ||
|
||
// Spin up a goroutine to handle traffic from remote -> local | ||
go conf.handleRemoteToLocal(listener, srcAddr, s) | ||
return s, nil | ||
} | ||
|
||
// Main loop to read from local client and forward to remote | ||
go func() { | ||
buf := make([]byte, 64*1024) // typical max UDP size | ||
for { | ||
n, src, err := listener.ReadFromUDP(buf) | ||
if err != nil { | ||
log.Printf("UDPProxyTunnel: error reading from UDP: %v", err) | ||
continue | ||
} | ||
|
||
srcKey := src.String() // identify session by the local client's IP:port | ||
s, err := getOrCreateSession(srcKey) | ||
if err != nil { | ||
errorLogger.Printf("UDPProxyTunnel: getOrCreateSession failed for %s: %v", srcKey, err) | ||
continue | ||
} | ||
|
||
s.lastActive = time.Now() | ||
_, err = s.remoteConn.Write(buf[:n]) | ||
if err != nil { | ||
errorLogger.Printf("UDPProxyTunnel: could not write to remote (%s): %v", conf.Target, err) | ||
} | ||
} | ||
}() | ||
} | ||
|
||
// handles data from the remote WireGuard side back to the local client | ||
// this function blocks until the session is closed | ||
func (conf *UDPProxyTunnelConfig) handleRemoteToLocal(listener *net.UDPConn, srcAddr string, s *udpSession) { | ||
defer func() { | ||
_ = s.remoteConn.Close() | ||
}() | ||
buf := make([]byte, 64*1024) | ||
|
||
for { | ||
select { | ||
case <-s.closeChan: | ||
return | ||
default: | ||
} | ||
|
||
_ = s.remoteConn.SetReadDeadline(time.Now().Add(5 * time.Second)) | ||
n, err := s.remoteConn.Read(buf) | ||
if err != nil { | ||
// If a timeout or temporary error, continue to see if the session is closed | ||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { | ||
select { | ||
case <-s.closeChan: | ||
return | ||
default: | ||
continue | ||
} | ||
} | ||
errorLogger.Printf("UDPProxyTunnel: read error from remote: %v", err) | ||
return | ||
} | ||
|
||
s.lastActive = time.Now() | ||
|
||
dstUDPAddr, err := net.ResolveUDPAddr("udp", srcAddr) | ||
if err != nil { | ||
errorLogger.Printf("UDPProxyTunnel: cannot resolve local address %s: %v", srcAddr, err) | ||
return | ||
} | ||
|
||
_, err = listener.WriteToUDP(buf[:n], dstUDPAddr) | ||
if err != nil { | ||
errorLogger.Printf("UDPProxyTunnel: cannot write to local %s: %v", srcAddr, err) | ||
return | ||
} | ||
} | ||
} |