Compare commits

...

4 Commits

Author SHA1 Message Date
Laurence
5eacbb7239 fix(proxy): prevent deleting wrong tunnel in defer cleanup
Add pointer check before delete to handle race where UpdateLocalSNIs
removes our tunnel and a new one is created for the same hostname.
2026-03-13 16:43:16 +00:00
Laurence
d21c09c84f refactor(proxy): simplify tunnel tracking with mutex-only approach
Remove atomic counter in favor of simple int protected by mutex.
Eliminates race condition complexity and recheck logic.
2026-03-13 16:36:56 +00:00
Laurence
28c65b950c fix(proxy): avoid shadowing ctx variable in pipe() 2026-03-13 15:51:23 +00:00
Laurence
1643d71905 refactor(proxy): use context cancellation for tunnel tracking
- Replace []net.Conn slice with context + atomic counter in activeTunnel
- Use errgroup.WithContext for pipe() to handle goroutine lifecycle
- Use context.AfterFunc to close connections on cancellation
- Fix race condition by comparing tunnel pointers instead of map lookup
- UpdateLocalSNIs now cancels tunnel context instead of iterating conns

This eliminates O(n) connection removal, prevents goroutine leaks,
and provides cleaner cancellation semantics.
2026-03-13 15:47:52 +00:00

View File

@@ -18,6 +18,7 @@ import (
"github.com/fosrl/gerbil/logger" "github.com/fosrl/gerbil/logger"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"golang.org/x/sync/errgroup"
) )
// RouteRecord represents a routing configuration // RouteRecord represents a routing configuration
@@ -72,7 +73,9 @@ type SNIProxy struct {
} }
type activeTunnel struct { type activeTunnel struct {
conns []net.Conn ctx context.Context
cancel context.CancelFunc
count int // protected by activeTunnelsLock
} }
// readOnlyConn is a wrapper for io.Reader that implements net.Conn // readOnlyConn is a wrapper for io.Reader that implements net.Conn
@@ -588,37 +591,32 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
} }
} }
// Track this tunnel by SNI // Track this tunnel by SNI using context for cancellation
p.activeTunnelsLock.Lock() p.activeTunnelsLock.Lock()
tunnel, ok := p.activeTunnels[hostname] tunnel, ok := p.activeTunnels[hostname]
if !ok { if !ok {
tunnel = &activeTunnel{} ctx, cancel := context.WithCancel(p.ctx)
tunnel = &activeTunnel{ctx: ctx, cancel: cancel}
p.activeTunnels[hostname] = tunnel p.activeTunnels[hostname] = tunnel
} }
tunnel.conns = append(tunnel.conns, actualClientConn) tunnel.count++
tunnelCtx := tunnel.ctx
p.activeTunnelsLock.Unlock() p.activeTunnelsLock.Unlock()
defer func() { defer func() {
// Remove this conn from active tunnels
p.activeTunnelsLock.Lock() p.activeTunnelsLock.Lock()
if tunnel, ok := p.activeTunnels[hostname]; ok { tunnel.count--
newConns := make([]net.Conn, 0, len(tunnel.conns)) if tunnel.count == 0 {
for _, c := range tunnel.conns { tunnel.cancel()
if c != actualClientConn { if p.activeTunnels[hostname] == tunnel {
newConns = append(newConns, c)
}
}
if len(newConns) == 0 {
delete(p.activeTunnels, hostname) delete(p.activeTunnels, hostname)
} else {
tunnel.conns = newConns
} }
} }
p.activeTunnelsLock.Unlock() p.activeTunnelsLock.Unlock()
}() }()
// Start bidirectional data transfer // Start bidirectional data transfer with tunnel context
p.pipe(actualClientConn, targetConn, clientReader) p.pipe(tunnelCtx, actualClientConn, targetConn, clientReader)
} }
// getRoute retrieves routing information for a hostname // getRoute retrieves routing information for a hostname
@@ -754,47 +752,36 @@ func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) s
} }
// pipe handles bidirectional data transfer between connections // pipe handles bidirectional data transfer between connections
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) { func (p *SNIProxy) pipe(ctx context.Context, clientConn, targetConn net.Conn, clientReader io.Reader) {
var wg sync.WaitGroup g, gCtx := errgroup.WithContext(ctx)
wg.Add(2)
// closeOnce ensures we only close connections once // Close connections when context cancels to unblock io.Copy operations
var closeOnce sync.Once context.AfterFunc(gCtx, func() {
closeConns := func() { clientConn.Close()
closeOnce.Do(func() { targetConn.Close()
// Close both connections to unblock any pending reads })
clientConn.Close()
targetConn.Close()
})
}
// Copy data from client to target (using the buffered reader) // Copy data from client to target
go func() { g.Go(func() error {
defer wg.Done()
defer closeConns()
// Use a large buffer for better performance
buf := make([]byte, 32*1024) buf := make([]byte, 32*1024)
_, err := io.CopyBuffer(targetConn, clientReader, buf) _, err := io.CopyBuffer(targetConn, clientReader, buf)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
logger.Debug("Copy client->target error: %v", err) logger.Debug("Copy client->target error: %v", err)
} }
}() return err
})
// Copy data from target to client // Copy data from target to client
go func() { g.Go(func() error {
defer wg.Done()
defer closeConns()
// Use a large buffer for better performance
buf := make([]byte, 32*1024) buf := make([]byte, 32*1024)
_, err := io.CopyBuffer(clientConn, targetConn, buf) _, err := io.CopyBuffer(clientConn, targetConn, buf)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
logger.Debug("Copy target->client error: %v", err) logger.Debug("Copy target->client error: %v", err)
} }
}() return err
})
wg.Wait() g.Wait()
} }
// GetCacheStats returns cache statistics // GetCacheStats returns cache statistics
@@ -830,16 +817,14 @@ func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) {
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed)) logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
// Terminate tunnels for removed SNIs // Terminate tunnels for removed SNIs via context cancellation
if len(removed) > 0 { if len(removed) > 0 {
p.activeTunnelsLock.Lock() p.activeTunnelsLock.Lock()
for _, sni := range removed { for _, sni := range removed {
if tunnels, ok := p.activeTunnels[sni]; ok { if tunnel, ok := p.activeTunnels[sni]; ok {
for _, conn := range tunnels.conns { tunnel.cancel()
conn.Close()
}
delete(p.activeTunnels, sni) delete(p.activeTunnels, sni)
logger.Debug("Closed tunnels for SNI target change: %s", sni) logger.Debug("Cancelled tunnel context for SNI target change: %s", sni)
} }
} }
p.activeTunnelsLock.Unlock() p.activeTunnelsLock.Unlock()