Compare commits

..

1 Commits

Author SHA1 Message Date
Laurence
14a3e7c531 Optimize lock usage in proxy connection handling
- Change activeTunnel.conns from slice to map for O(1) add/remove
- Improve lock scoping in UpdateLocalSNIs: use read lock for diff
  computation, minimize write lock hold time
- Move cache invalidation outside lock (go-cache is thread-safe)
2026-03-13 15:23:30 +00:00
2 changed files with 32 additions and 132 deletions

View File

@@ -72,7 +72,7 @@ type SNIProxy struct {
}
type activeTunnel struct {
conns []net.Conn
conns map[net.Conn]struct{}
}
// readOnlyConn is a wrapper for io.Reader that implements net.Conn
@@ -592,26 +592,19 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
p.activeTunnelsLock.Lock()
tunnel, ok := p.activeTunnels[hostname]
if !ok {
tunnel = &activeTunnel{}
tunnel = &activeTunnel{conns: make(map[net.Conn]struct{})}
p.activeTunnels[hostname] = tunnel
}
tunnel.conns = append(tunnel.conns, actualClientConn)
tunnel.conns[actualClientConn] = struct{}{}
p.activeTunnelsLock.Unlock()
defer func() {
// Remove this conn from active tunnels
// Remove this conn from active tunnels - O(1) with map
p.activeTunnelsLock.Lock()
if tunnel, ok := p.activeTunnels[hostname]; ok {
newConns := make([]net.Conn, 0, len(tunnel.conns))
for _, c := range tunnel.conns {
if c != actualClientConn {
newConns = append(newConns, c)
}
}
if len(newConns) == 0 {
delete(tunnel.conns, actualClientConn)
if len(tunnel.conns) == 0 {
delete(p.activeTunnels, hostname)
} else {
tunnel.conns = newConns
}
}
p.activeTunnelsLock.Unlock()
@@ -810,32 +803,42 @@ func (p *SNIProxy) ClearCache() {
// UpdateLocalSNIs updates the local SNIs and invalidates cache for changed domains
func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) {
newSNIs := make(map[string]struct{})
newSNIs := make(map[string]struct{}, len(fullDomains))
for _, domain := range fullDomains {
newSNIs[domain] = struct{}{}
// Invalidate any cached route for this domain
p.cache.Delete(domain)
}
// Update localSNIs
p.localSNIsLock.Lock()
// Get old SNIs with read lock to compute diff outside write lock
p.localSNIsLock.RLock()
oldSNIs := p.localSNIs
p.localSNIsLock.RUnlock()
// Compute removed SNIs outside the lock
removed := make([]string, 0)
for sni := range p.localSNIs {
for sni := range oldSNIs {
if _, stillLocal := newSNIs[sni]; !stillLocal {
removed = append(removed, sni)
}
}
// Swap with minimal write lock hold time
p.localSNIsLock.Lock()
p.localSNIs = newSNIs
p.localSNIsLock.Unlock()
// Invalidate cache for new domains (cache is thread-safe)
for domain := range newSNIs {
p.cache.Delete(domain)
}
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
// Terminate tunnels for removed SNIs
if len(removed) > 0 {
p.activeTunnelsLock.Lock()
for _, sni := range removed {
if tunnels, ok := p.activeTunnels[sni]; ok {
for _, conn := range tunnels.conns {
if tunnel, ok := p.activeTunnels[sni]; ok {
for conn := range tunnel.conns {
conn.Close()
}
delete(p.activeTunnels, sni)

View File

@@ -9,7 +9,6 @@ import (
"io"
"net"
"net/http"
"runtime"
"sync"
"time"
@@ -119,13 +118,6 @@ type Packet struct {
n int
}
// holePunchRateLimitEntry tracks hole punch message counts within a sliding 1-second window.
// One entry exists per remote "ip:port" key; entries are created lazily on first message.
type holePunchRateLimitEntry struct {
mu sync.Mutex // guards count and windowStart against concurrent packet workers
count int // messages observed in the current window
windowStart time.Time // start of the current 1-second window; reset when the window rolls over
}
// WireGuard message types
const (
WireGuardMessageTypeHandshakeInitiation = 1
@@ -161,11 +153,6 @@ type UDPProxyServer struct {
// Communication pattern tracking for rebuilding sessions
// Key format: "clientIP:clientPort-destIP:destPort"
commPatterns sync.Map
// Rate limiter for encrypted hole punch messages, keyed by "ip:port"
holePunchRateLimiter sync.Map
// Cache for resolved UDP addresses to avoid per-packet DNS lookups
// Key: "ip:port" string, Value: *net.UDPAddr
addrCache sync.Map
// ReachableAt is the URL where this server can be reached
ReachableAt string
}
@@ -177,7 +164,7 @@ func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privat
addr: addr,
serverURL: serverURL,
privateKey: privateKey,
packetChan: make(chan Packet, 50000), // Increased from 1000 to handle high throughput
packetChan: make(chan Packet, 1000),
ReachableAt: reachableAt,
ctx: ctx,
cancel: cancel,
@@ -202,13 +189,8 @@ func (s *UDPProxyServer) Start() error {
s.conn = conn
logger.Info("UDP server listening on %s", s.addr)
// Start worker goroutines based on CPU cores for better parallelism
// At high throughput (160+ Mbps), we need many workers to avoid bottlenecks
workerCount := runtime.NumCPU() * 10
if workerCount < 20 {
workerCount = 20 // Minimum 20 workers
}
logger.Info("Starting %d packet workers (CPUs: %d)", workerCount, runtime.NumCPU())
// Start a fixed number of worker goroutines.
workerCount := 10 // TODO: Make this configurable or pick it better!
for i := 0; i < workerCount; i++ {
go s.packetWorker()
}
@@ -228,9 +210,6 @@ func (s *UDPProxyServer) Start() error {
// Start the communication pattern cleanup routine
go s.cleanupIdleCommunicationPatterns()
// Start the hole punch rate limiter cleanup routine
go s.cleanupHolePunchRateLimiter()
return nil
}
@@ -293,27 +272,6 @@ func (s *UDPProxyServer) packetWorker() {
// Process as a WireGuard packet.
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
} else {
// Rate limit: allow at most 2 hole punch messages per IP:Port per second
rateLimitKey := packet.remoteAddr.String()
entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{
windowStart: time.Now(),
})
rlEntry := entryVal.(*holePunchRateLimitEntry)
rlEntry.mu.Lock()
now := time.Now()
if now.Sub(rlEntry.windowStart) >= time.Second {
rlEntry.count = 0
rlEntry.windowStart = now
}
rlEntry.count++
allowed := rlEntry.count <= 2
rlEntry.mu.Unlock()
if !allowed {
// logger.Debug("Rate limiting hole punch message from %s", rateLimitKey)
bufferPool.Put(packet.data[:1500])
continue
}
// Process as an encrypted hole punch message
var encMsg EncryptedHolePunchMessage
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
@@ -333,7 +291,7 @@ func (s *UDPProxyServer) packetWorker() {
// This appears to be an encrypted message
decryptedData, err := s.decryptMessage(encMsg)
if err != nil {
// logger.Error("Failed to decrypt message: %v", err)
logger.Error("Failed to decrypt message: %v", err)
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
@@ -458,43 +416,6 @@ func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) {
return 0, 0, false
}
// cachedAddr holds a resolved UDP address with TTL
// Stored in UDPProxyServer.addrCache keyed by the "ip:port" string.
type cachedAddr struct {
addr *net.UDPAddr // the resolved address handed back to callers
expiresAt time.Time // after this instant the entry is discarded and re-resolved
}
// addrCacheTTL is how long resolved addresses are cached before re-resolving
const addrCacheTTL = 5 * time.Minute
// getCachedAddr returns a cached UDP address or resolves and caches it.
// This avoids per-packet DNS lookups which are a major throughput bottleneck.
//
// The host:port key is built with net.JoinHostPort so that IPv6 literals are
// bracketed correctly — fmt.Sprintf("%s:%d", ip, port) would yield e.g.
// "::1:51820", which net.ResolveUDPAddr cannot parse.
//
// NOTE(review): two goroutines may race to resolve the same expired key and
// both Store a fresh entry; that is harmless (last write wins, both are valid).
func (s *UDPProxyServer) getCachedAddr(ip string, port int) (*net.UDPAddr, error) {
	// JoinHostPort handles IPv6 bracketing; the same string doubles as the cache key.
	key := net.JoinHostPort(ip, fmt.Sprintf("%d", port))
	// Fast path: return a still-fresh cached entry.
	if cached, ok := s.addrCache.Load(key); ok {
		entry := cached.(*cachedAddr)
		if time.Now().Before(entry.expiresAt) {
			return entry.addr, nil
		}
		// Entry expired: drop it and fall through to re-resolve.
		s.addrCache.Delete(key)
	}
	// Slow path: resolve and cache for addrCacheTTL.
	addr, err := net.ResolveUDPAddr("udp", key)
	if err != nil {
		return nil, err
	}
	s.addrCache.Store(key, &cachedAddr{
		addr:      addr,
		expiresAt: time.Now().Add(addrCacheTTL),
	})
	return addr, nil
}
// Updated to handle multi-peer WireGuard communication
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
if len(packet) == 0 {
@@ -529,7 +450,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations)
for _, dest := range proxyMapping.Destinations {
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
@@ -565,7 +486,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
// Forward the response to the original sender
for _, dest := range proxyMapping.Destinations {
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
@@ -622,7 +543,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
// No known session, fall back to forwarding to all peers
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
for _, dest := range proxyMapping.Destinations {
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
@@ -650,7 +571,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
// Forward to all peers
for _, dest := range proxyMapping.Destinations {
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
@@ -1109,30 +1030,6 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
}
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
// cleanupHolePunchRateLimiter runs until the server context is cancelled,
// sweeping the hole punch rate limiter every 30 seconds and deleting entries
// whose window has not advanced in over 10 seconds, keeping the map bounded.
func (s *UDPProxyServer) cleanupHolePunchRateLimiter() {
	sweep := time.NewTicker(30 * time.Second)
	defer sweep.Stop()
	for {
		select {
		case <-s.ctx.Done():
			return
		case <-sweep.C:
			cutoff := time.Now()
			s.holePunchRateLimiter.Range(func(k, v interface{}) bool {
				e := v.(*holePunchRateLimitEntry)
				// Lock only long enough to read windowStart; Delete happens
				// outside the entry lock to avoid holding it during map ops.
				e.mu.Lock()
				expired := cutoff.Sub(e.windowStart) > 10*time.Second
				e.mu.Unlock()
				if expired {
					s.holePunchRateLimiter.Delete(k)
				}
				return true
			})
		}
	}
}
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()