mirror of
https://github.com/fosrl/gerbil.git
synced 2026-03-26 08:02:45 -05:00
Compare commits
7 Commits
proxy-cont
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40da38708c | ||
|
|
3af64d8bd3 | ||
|
|
fcead8cc15 | ||
|
|
20dad7bb8e | ||
|
|
a955aa6169 | ||
|
|
b118fef265 | ||
|
|
7985f97eb6 |
@@ -18,7 +18,6 @@ import (
|
|||||||
|
|
||||||
"github.com/fosrl/gerbil/logger"
|
"github.com/fosrl/gerbil/logger"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RouteRecord represents a routing configuration
|
// RouteRecord represents a routing configuration
|
||||||
@@ -73,9 +72,7 @@ type SNIProxy struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type activeTunnel struct {
|
type activeTunnel struct {
|
||||||
ctx context.Context
|
conns []net.Conn
|
||||||
cancel context.CancelFunc
|
|
||||||
count int // protected by activeTunnelsLock
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// readOnlyConn is a wrapper for io.Reader that implements net.Conn
|
// readOnlyConn is a wrapper for io.Reader that implements net.Conn
|
||||||
@@ -591,32 +588,37 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track this tunnel by SNI using context for cancellation
|
// Track this tunnel by SNI
|
||||||
p.activeTunnelsLock.Lock()
|
p.activeTunnelsLock.Lock()
|
||||||
tunnel, ok := p.activeTunnels[hostname]
|
tunnel, ok := p.activeTunnels[hostname]
|
||||||
if !ok {
|
if !ok {
|
||||||
ctx, cancel := context.WithCancel(p.ctx)
|
tunnel = &activeTunnel{}
|
||||||
tunnel = &activeTunnel{ctx: ctx, cancel: cancel}
|
|
||||||
p.activeTunnels[hostname] = tunnel
|
p.activeTunnels[hostname] = tunnel
|
||||||
}
|
}
|
||||||
tunnel.count++
|
tunnel.conns = append(tunnel.conns, actualClientConn)
|
||||||
tunnelCtx := tunnel.ctx
|
|
||||||
p.activeTunnelsLock.Unlock()
|
p.activeTunnelsLock.Unlock()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
// Remove this conn from active tunnels
|
||||||
p.activeTunnelsLock.Lock()
|
p.activeTunnelsLock.Lock()
|
||||||
tunnel.count--
|
if tunnel, ok := p.activeTunnels[hostname]; ok {
|
||||||
if tunnel.count == 0 {
|
newConns := make([]net.Conn, 0, len(tunnel.conns))
|
||||||
tunnel.cancel()
|
for _, c := range tunnel.conns {
|
||||||
if p.activeTunnels[hostname] == tunnel {
|
if c != actualClientConn {
|
||||||
|
newConns = append(newConns, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(newConns) == 0 {
|
||||||
delete(p.activeTunnels, hostname)
|
delete(p.activeTunnels, hostname)
|
||||||
|
} else {
|
||||||
|
tunnel.conns = newConns
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p.activeTunnelsLock.Unlock()
|
p.activeTunnelsLock.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Start bidirectional data transfer with tunnel context
|
// Start bidirectional data transfer
|
||||||
p.pipe(tunnelCtx, actualClientConn, targetConn, clientReader)
|
p.pipe(actualClientConn, targetConn, clientReader)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRoute retrieves routing information for a hostname
|
// getRoute retrieves routing information for a hostname
|
||||||
@@ -752,36 +754,47 @@ func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// pipe handles bidirectional data transfer between connections
|
// pipe handles bidirectional data transfer between connections
|
||||||
func (p *SNIProxy) pipe(ctx context.Context, clientConn, targetConn net.Conn, clientReader io.Reader) {
|
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) {
|
||||||
g, gCtx := errgroup.WithContext(ctx)
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
// Close connections when context cancels to unblock io.Copy operations
|
// closeOnce ensures we only close connections once
|
||||||
context.AfterFunc(gCtx, func() {
|
var closeOnce sync.Once
|
||||||
clientConn.Close()
|
closeConns := func() {
|
||||||
targetConn.Close()
|
closeOnce.Do(func() {
|
||||||
})
|
// Close both connections to unblock any pending reads
|
||||||
|
clientConn.Close()
|
||||||
|
targetConn.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Copy data from client to target
|
// Copy data from client to target (using the buffered reader)
|
||||||
g.Go(func() error {
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
defer closeConns()
|
||||||
|
|
||||||
|
// Use a large buffer for better performance
|
||||||
buf := make([]byte, 32*1024)
|
buf := make([]byte, 32*1024)
|
||||||
_, err := io.CopyBuffer(targetConn, clientReader, buf)
|
_, err := io.CopyBuffer(targetConn, clientReader, buf)
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
logger.Debug("Copy client->target error: %v", err)
|
logger.Debug("Copy client->target error: %v", err)
|
||||||
}
|
}
|
||||||
return err
|
}()
|
||||||
})
|
|
||||||
|
|
||||||
// Copy data from target to client
|
// Copy data from target to client
|
||||||
g.Go(func() error {
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
defer closeConns()
|
||||||
|
|
||||||
|
// Use a large buffer for better performance
|
||||||
buf := make([]byte, 32*1024)
|
buf := make([]byte, 32*1024)
|
||||||
_, err := io.CopyBuffer(clientConn, targetConn, buf)
|
_, err := io.CopyBuffer(clientConn, targetConn, buf)
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
logger.Debug("Copy target->client error: %v", err)
|
logger.Debug("Copy target->client error: %v", err)
|
||||||
}
|
}
|
||||||
return err
|
}()
|
||||||
})
|
|
||||||
|
|
||||||
g.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCacheStats returns cache statistics
|
// GetCacheStats returns cache statistics
|
||||||
@@ -817,14 +830,16 @@ func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) {
|
|||||||
|
|
||||||
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
|
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
|
||||||
|
|
||||||
// Terminate tunnels for removed SNIs via context cancellation
|
// Terminate tunnels for removed SNIs
|
||||||
if len(removed) > 0 {
|
if len(removed) > 0 {
|
||||||
p.activeTunnelsLock.Lock()
|
p.activeTunnelsLock.Lock()
|
||||||
for _, sni := range removed {
|
for _, sni := range removed {
|
||||||
if tunnel, ok := p.activeTunnels[sni]; ok {
|
if tunnels, ok := p.activeTunnels[sni]; ok {
|
||||||
tunnel.cancel()
|
for _, conn := range tunnels.conns {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
delete(p.activeTunnels, sni)
|
delete(p.activeTunnels, sni)
|
||||||
logger.Debug("Cancelled tunnel context for SNI target change: %s", sni)
|
logger.Debug("Closed tunnels for SNI target change: %s", sni)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p.activeTunnelsLock.Unlock()
|
p.activeTunnelsLock.Unlock()
|
||||||
|
|||||||
119
relay/relay.go
119
relay/relay.go
@@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -118,6 +119,13 @@ type Packet struct {
|
|||||||
n int
|
n int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// holePunchRateLimitEntry tracks hole punch message counts within a sliding 1-second window.
|
||||||
|
type holePunchRateLimitEntry struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
count int
|
||||||
|
windowStart time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// WireGuard message types
|
// WireGuard message types
|
||||||
const (
|
const (
|
||||||
WireGuardMessageTypeHandshakeInitiation = 1
|
WireGuardMessageTypeHandshakeInitiation = 1
|
||||||
@@ -153,6 +161,11 @@ type UDPProxyServer struct {
|
|||||||
// Communication pattern tracking for rebuilding sessions
|
// Communication pattern tracking for rebuilding sessions
|
||||||
// Key format: "clientIP:clientPort-destIP:destPort"
|
// Key format: "clientIP:clientPort-destIP:destPort"
|
||||||
commPatterns sync.Map
|
commPatterns sync.Map
|
||||||
|
// Rate limiter for encrypted hole punch messages, keyed by "ip:port"
|
||||||
|
holePunchRateLimiter sync.Map
|
||||||
|
// Cache for resolved UDP addresses to avoid per-packet DNS lookups
|
||||||
|
// Key: "ip:port" string, Value: *net.UDPAddr
|
||||||
|
addrCache sync.Map
|
||||||
// ReachableAt is the URL where this server can be reached
|
// ReachableAt is the URL where this server can be reached
|
||||||
ReachableAt string
|
ReachableAt string
|
||||||
}
|
}
|
||||||
@@ -164,7 +177,7 @@ func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privat
|
|||||||
addr: addr,
|
addr: addr,
|
||||||
serverURL: serverURL,
|
serverURL: serverURL,
|
||||||
privateKey: privateKey,
|
privateKey: privateKey,
|
||||||
packetChan: make(chan Packet, 1000),
|
packetChan: make(chan Packet, 50000), // Increased from 1000 to handle high throughput
|
||||||
ReachableAt: reachableAt,
|
ReachableAt: reachableAt,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
@@ -189,8 +202,13 @@ func (s *UDPProxyServer) Start() error {
|
|||||||
s.conn = conn
|
s.conn = conn
|
||||||
logger.Info("UDP server listening on %s", s.addr)
|
logger.Info("UDP server listening on %s", s.addr)
|
||||||
|
|
||||||
// Start a fixed number of worker goroutines.
|
// Start worker goroutines based on CPU cores for better parallelism
|
||||||
workerCount := 10 // TODO: Make this configurable or pick it better!
|
// At high throughput (160+ Mbps), we need many workers to avoid bottlenecks
|
||||||
|
workerCount := runtime.NumCPU() * 10
|
||||||
|
if workerCount < 20 {
|
||||||
|
workerCount = 20 // Minimum 20 workers
|
||||||
|
}
|
||||||
|
logger.Info("Starting %d packet workers (CPUs: %d)", workerCount, runtime.NumCPU())
|
||||||
for i := 0; i < workerCount; i++ {
|
for i := 0; i < workerCount; i++ {
|
||||||
go s.packetWorker()
|
go s.packetWorker()
|
||||||
}
|
}
|
||||||
@@ -210,6 +228,9 @@ func (s *UDPProxyServer) Start() error {
|
|||||||
// Start the communication pattern cleanup routine
|
// Start the communication pattern cleanup routine
|
||||||
go s.cleanupIdleCommunicationPatterns()
|
go s.cleanupIdleCommunicationPatterns()
|
||||||
|
|
||||||
|
// Start the hole punch rate limiter cleanup routine
|
||||||
|
go s.cleanupHolePunchRateLimiter()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -272,6 +293,27 @@ func (s *UDPProxyServer) packetWorker() {
|
|||||||
// Process as a WireGuard packet.
|
// Process as a WireGuard packet.
|
||||||
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
|
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
|
||||||
} else {
|
} else {
|
||||||
|
// Rate limit: allow at most 2 hole punch messages per IP:Port per second
|
||||||
|
rateLimitKey := packet.remoteAddr.String()
|
||||||
|
entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{
|
||||||
|
windowStart: time.Now(),
|
||||||
|
})
|
||||||
|
rlEntry := entryVal.(*holePunchRateLimitEntry)
|
||||||
|
rlEntry.mu.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
if now.Sub(rlEntry.windowStart) >= time.Second {
|
||||||
|
rlEntry.count = 0
|
||||||
|
rlEntry.windowStart = now
|
||||||
|
}
|
||||||
|
rlEntry.count++
|
||||||
|
allowed := rlEntry.count <= 2
|
||||||
|
rlEntry.mu.Unlock()
|
||||||
|
if !allowed {
|
||||||
|
// logger.Debug("Rate limiting hole punch message from %s", rateLimitKey)
|
||||||
|
bufferPool.Put(packet.data[:1500])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Process as an encrypted hole punch message
|
// Process as an encrypted hole punch message
|
||||||
var encMsg EncryptedHolePunchMessage
|
var encMsg EncryptedHolePunchMessage
|
||||||
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
|
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
|
||||||
@@ -291,7 +333,7 @@ func (s *UDPProxyServer) packetWorker() {
|
|||||||
// This appears to be an encrypted message
|
// This appears to be an encrypted message
|
||||||
decryptedData, err := s.decryptMessage(encMsg)
|
decryptedData, err := s.decryptMessage(encMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to decrypt message: %v", err)
|
// logger.Error("Failed to decrypt message: %v", err)
|
||||||
// Return the buffer to the pool for reuse and continue with next packet
|
// Return the buffer to the pool for reuse and continue with next packet
|
||||||
bufferPool.Put(packet.data[:1500])
|
bufferPool.Put(packet.data[:1500])
|
||||||
continue
|
continue
|
||||||
@@ -416,6 +458,43 @@ func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) {
|
|||||||
return 0, 0, false
|
return 0, 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cachedAddr holds a resolved UDP address with TTL
|
||||||
|
type cachedAddr struct {
|
||||||
|
addr *net.UDPAddr
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// addrCacheTTL is how long resolved addresses are cached before re-resolving
|
||||||
|
const addrCacheTTL = 5 * time.Minute
|
||||||
|
|
||||||
|
// getCachedAddr returns a cached UDP address or resolves and caches it.
|
||||||
|
// This avoids per-packet DNS lookups which are a major throughput bottleneck.
|
||||||
|
func (s *UDPProxyServer) getCachedAddr(ip string, port int) (*net.UDPAddr, error) {
|
||||||
|
key := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
|
||||||
|
// Check cache first
|
||||||
|
if cached, ok := s.addrCache.Load(key); ok {
|
||||||
|
entry := cached.(*cachedAddr)
|
||||||
|
if time.Now().Before(entry.expiresAt) {
|
||||||
|
return entry.addr, nil
|
||||||
|
}
|
||||||
|
// Cache expired, delete and re-resolve
|
||||||
|
s.addrCache.Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve and cache
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.addrCache.Store(key, &cachedAddr{
|
||||||
|
addr: addr,
|
||||||
|
expiresAt: time.Now().Add(addrCacheTTL),
|
||||||
|
})
|
||||||
|
return addr, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Updated to handle multi-peer WireGuard communication
|
// Updated to handle multi-peer WireGuard communication
|
||||||
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
|
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
|
||||||
if len(packet) == 0 {
|
if len(packet) == 0 {
|
||||||
@@ -450,7 +529,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
|||||||
logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations)
|
logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations)
|
||||||
|
|
||||||
for _, dest := range proxyMapping.Destinations {
|
for _, dest := range proxyMapping.Destinations {
|
||||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to resolve destination address: %v", err)
|
logger.Error("Failed to resolve destination address: %v", err)
|
||||||
continue
|
continue
|
||||||
@@ -486,7 +565,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
|||||||
|
|
||||||
// Forward the response to the original sender
|
// Forward the response to the original sender
|
||||||
for _, dest := range proxyMapping.Destinations {
|
for _, dest := range proxyMapping.Destinations {
|
||||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to resolve destination address: %v", err)
|
logger.Error("Failed to resolve destination address: %v", err)
|
||||||
continue
|
continue
|
||||||
@@ -543,7 +622,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
|||||||
// No known session, fall back to forwarding to all peers
|
// No known session, fall back to forwarding to all peers
|
||||||
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
|
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
|
||||||
for _, dest := range proxyMapping.Destinations {
|
for _, dest := range proxyMapping.Destinations {
|
||||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to resolve destination address: %v", err)
|
logger.Error("Failed to resolve destination address: %v", err)
|
||||||
continue
|
continue
|
||||||
@@ -571,7 +650,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
|||||||
|
|
||||||
// Forward to all peers
|
// Forward to all peers
|
||||||
for _, dest := range proxyMapping.Destinations {
|
for _, dest := range proxyMapping.Destinations {
|
||||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to resolve destination address: %v", err)
|
logger.Error("Failed to resolve destination address: %v", err)
|
||||||
continue
|
continue
|
||||||
@@ -1030,6 +1109,30 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
|
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
|
||||||
|
// cleanupHolePunchRateLimiter periodically evicts stale rate limit entries to prevent unbounded growth.
|
||||||
|
func (s *UDPProxyServer) cleanupHolePunchRateLimiter() {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
now := time.Now()
|
||||||
|
s.holePunchRateLimiter.Range(func(key, value interface{}) bool {
|
||||||
|
rlEntry := value.(*holePunchRateLimitEntry)
|
||||||
|
rlEntry.mu.Lock()
|
||||||
|
stale := now.Sub(rlEntry.windowStart) > 10*time.Second
|
||||||
|
rlEntry.mu.Unlock()
|
||||||
|
if stale {
|
||||||
|
s.holePunchRateLimiter.Delete(key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
|
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
|
||||||
ticker := time.NewTicker(10 * time.Minute)
|
ticker := time.NewTicker(10 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|||||||
Reference in New Issue
Block a user