mirror of
https://github.com/fosrl/gerbil.git
synced 2026-03-22 12:54:30 -05:00
Compare commits
1 Commits
dev
...
proxy-perf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7d9c72f29 |
@@ -69,6 +69,12 @@ type SNIProxy struct {
|
||||
|
||||
// Trusted upstream proxies that can send PROXY protocol
|
||||
trustedUpstreams map[string]struct{}
|
||||
|
||||
// Reusable HTTP client for API requests
|
||||
httpClient *http.Client
|
||||
|
||||
// Buffer pool for connection piping
|
||||
bufferPool *sync.Pool
|
||||
}
|
||||
|
||||
type activeTunnel struct {
|
||||
@@ -374,6 +380,20 @@ func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, lo
|
||||
localOverrides: overridesMap,
|
||||
activeTunnels: make(map[string]*activeTunnel),
|
||||
trustedUpstreams: trustedMap,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
},
|
||||
},
|
||||
bufferPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
@@ -681,9 +701,8 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Make HTTP request
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
// Make HTTP request using reusable client
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("API request failed: %w", err)
|
||||
}
|
||||
@@ -773,9 +792,15 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader)
|
||||
defer wg.Done()
|
||||
defer closeConns()
|
||||
|
||||
// Use a large buffer for better performance
|
||||
buf := make([]byte, 32*1024)
|
||||
_, err := io.CopyBuffer(targetConn, clientReader, buf)
|
||||
// Get buffer from pool and return when done
|
||||
bufPtr := p.bufferPool.Get().(*[]byte)
|
||||
defer func() {
|
||||
// Clear buffer before returning to pool to prevent data leakage
|
||||
clear(*bufPtr)
|
||||
p.bufferPool.Put(bufPtr)
|
||||
}()
|
||||
|
||||
_, err := io.CopyBuffer(targetConn, clientReader, *bufPtr)
|
||||
if err != nil && err != io.EOF {
|
||||
logger.Debug("Copy client->target error: %v", err)
|
||||
}
|
||||
@@ -786,9 +811,15 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader)
|
||||
defer wg.Done()
|
||||
defer closeConns()
|
||||
|
||||
// Use a large buffer for better performance
|
||||
buf := make([]byte, 32*1024)
|
||||
_, err := io.CopyBuffer(clientConn, targetConn, buf)
|
||||
// Get buffer from pool and return when done
|
||||
bufPtr := p.bufferPool.Get().(*[]byte)
|
||||
defer func() {
|
||||
// Clear buffer before returning to pool to prevent data leakage
|
||||
clear(*bufPtr)
|
||||
p.bufferPool.Put(bufPtr)
|
||||
}()
|
||||
|
||||
_, err := io.CopyBuffer(clientConn, targetConn, *bufPtr)
|
||||
if err != nil && err != io.EOF {
|
||||
logger.Debug("Copy target->client error: %v", err)
|
||||
}
|
||||
|
||||
119
relay/relay.go
119
relay/relay.go
@@ -9,7 +9,6 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -119,13 +118,6 @@ type Packet struct {
|
||||
n int
|
||||
}
|
||||
|
||||
// holePunchRateLimitEntry tracks hole punch message counts within a sliding 1-second window.
|
||||
type holePunchRateLimitEntry struct {
|
||||
mu sync.Mutex
|
||||
count int
|
||||
windowStart time.Time
|
||||
}
|
||||
|
||||
// WireGuard message types
|
||||
const (
|
||||
WireGuardMessageTypeHandshakeInitiation = 1
|
||||
@@ -161,11 +153,6 @@ type UDPProxyServer struct {
|
||||
// Communication pattern tracking for rebuilding sessions
|
||||
// Key format: "clientIP:clientPort-destIP:destPort"
|
||||
commPatterns sync.Map
|
||||
// Rate limiter for encrypted hole punch messages, keyed by "ip:port"
|
||||
holePunchRateLimiter sync.Map
|
||||
// Cache for resolved UDP addresses to avoid per-packet DNS lookups
|
||||
// Key: "ip:port" string, Value: *net.UDPAddr
|
||||
addrCache sync.Map
|
||||
// ReachableAt is the URL where this server can be reached
|
||||
ReachableAt string
|
||||
}
|
||||
@@ -177,7 +164,7 @@ func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privat
|
||||
addr: addr,
|
||||
serverURL: serverURL,
|
||||
privateKey: privateKey,
|
||||
packetChan: make(chan Packet, 50000), // Increased from 1000 to handle high throughput
|
||||
packetChan: make(chan Packet, 1000),
|
||||
ReachableAt: reachableAt,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
@@ -202,13 +189,8 @@ func (s *UDPProxyServer) Start() error {
|
||||
s.conn = conn
|
||||
logger.Info("UDP server listening on %s", s.addr)
|
||||
|
||||
// Start worker goroutines based on CPU cores for better parallelism
|
||||
// At high throughput (160+ Mbps), we need many workers to avoid bottlenecks
|
||||
workerCount := runtime.NumCPU() * 10
|
||||
if workerCount < 20 {
|
||||
workerCount = 20 // Minimum 20 workers
|
||||
}
|
||||
logger.Info("Starting %d packet workers (CPUs: %d)", workerCount, runtime.NumCPU())
|
||||
// Start a fixed number of worker goroutines.
|
||||
workerCount := 10 // TODO: Make this configurable or pick it better!
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go s.packetWorker()
|
||||
}
|
||||
@@ -228,9 +210,6 @@ func (s *UDPProxyServer) Start() error {
|
||||
// Start the communication pattern cleanup routine
|
||||
go s.cleanupIdleCommunicationPatterns()
|
||||
|
||||
// Start the hole punch rate limiter cleanup routine
|
||||
go s.cleanupHolePunchRateLimiter()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -293,27 +272,6 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
// Process as a WireGuard packet.
|
||||
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
|
||||
} else {
|
||||
// Rate limit: allow at most 2 hole punch messages per IP:Port per second
|
||||
rateLimitKey := packet.remoteAddr.String()
|
||||
entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{
|
||||
windowStart: time.Now(),
|
||||
})
|
||||
rlEntry := entryVal.(*holePunchRateLimitEntry)
|
||||
rlEntry.mu.Lock()
|
||||
now := time.Now()
|
||||
if now.Sub(rlEntry.windowStart) >= time.Second {
|
||||
rlEntry.count = 0
|
||||
rlEntry.windowStart = now
|
||||
}
|
||||
rlEntry.count++
|
||||
allowed := rlEntry.count <= 2
|
||||
rlEntry.mu.Unlock()
|
||||
if !allowed {
|
||||
// logger.Debug("Rate limiting hole punch message from %s", rateLimitKey)
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
}
|
||||
|
||||
// Process as an encrypted hole punch message
|
||||
var encMsg EncryptedHolePunchMessage
|
||||
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
|
||||
@@ -333,7 +291,7 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
// This appears to be an encrypted message
|
||||
decryptedData, err := s.decryptMessage(encMsg)
|
||||
if err != nil {
|
||||
// logger.Error("Failed to decrypt message: %v", err)
|
||||
logger.Error("Failed to decrypt message: %v", err)
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
@@ -458,43 +416,6 @@ func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// cachedAddr holds a resolved UDP address with TTL
|
||||
type cachedAddr struct {
|
||||
addr *net.UDPAddr
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// addrCacheTTL is how long resolved addresses are cached before re-resolving
|
||||
const addrCacheTTL = 5 * time.Minute
|
||||
|
||||
// getCachedAddr returns a cached UDP address or resolves and caches it.
|
||||
// This avoids per-packet DNS lookups which are a major throughput bottleneck.
|
||||
func (s *UDPProxyServer) getCachedAddr(ip string, port int) (*net.UDPAddr, error) {
|
||||
key := fmt.Sprintf("%s:%d", ip, port)
|
||||
|
||||
// Check cache first
|
||||
if cached, ok := s.addrCache.Load(key); ok {
|
||||
entry := cached.(*cachedAddr)
|
||||
if time.Now().Before(entry.expiresAt) {
|
||||
return entry.addr, nil
|
||||
}
|
||||
// Cache expired, delete and re-resolve
|
||||
s.addrCache.Delete(key)
|
||||
}
|
||||
|
||||
// Resolve and cache
|
||||
addr, err := net.ResolveUDPAddr("udp", key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.addrCache.Store(key, &cachedAddr{
|
||||
addr: addr,
|
||||
expiresAt: time.Now().Add(addrCacheTTL),
|
||||
})
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
// Updated to handle multi-peer WireGuard communication
|
||||
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
|
||||
if len(packet) == 0 {
|
||||
@@ -529,7 +450,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations)
|
||||
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
@@ -565,7 +486,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
|
||||
// Forward the response to the original sender
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
@@ -622,7 +543,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
// No known session, fall back to forwarding to all peers
|
||||
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
@@ -650,7 +571,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
|
||||
// Forward to all peers
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
@@ -1109,30 +1030,6 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
|
||||
}
|
||||
|
||||
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
|
||||
// cleanupHolePunchRateLimiter periodically evicts stale rate limit entries to prevent unbounded growth.
|
||||
func (s *UDPProxyServer) cleanupHolePunchRateLimiter() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.holePunchRateLimiter.Range(func(key, value interface{}) bool {
|
||||
rlEntry := value.(*holePunchRateLimitEntry)
|
||||
rlEntry.mu.Lock()
|
||||
stale := now.Sub(rlEntry.windowStart) > 10*time.Second
|
||||
rlEntry.mu.Unlock()
|
||||
if stale {
|
||||
s.holePunchRateLimiter.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
Reference in New Issue
Block a user