mirror of
https://github.com/fosrl/gerbil.git
synced 2026-03-22 21:29:33 -05:00
Compare commits
4 Commits
dev
...
proxy-cont
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5eacbb7239 | ||
|
|
d21c09c84f | ||
|
|
28c65b950c | ||
|
|
1643d71905 |
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fosrl/gerbil/logger"
|
"github.com/fosrl/gerbil/logger"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RouteRecord represents a routing configuration
|
// RouteRecord represents a routing configuration
|
||||||
@@ -72,7 +73,9 @@ type SNIProxy struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type activeTunnel struct {
|
type activeTunnel struct {
|
||||||
conns []net.Conn
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
count int // protected by activeTunnelsLock
|
||||||
}
|
}
|
||||||
|
|
||||||
// readOnlyConn is a wrapper for io.Reader that implements net.Conn
|
// readOnlyConn is a wrapper for io.Reader that implements net.Conn
|
||||||
@@ -588,37 +591,32 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track this tunnel by SNI
|
// Track this tunnel by SNI using context for cancellation
|
||||||
p.activeTunnelsLock.Lock()
|
p.activeTunnelsLock.Lock()
|
||||||
tunnel, ok := p.activeTunnels[hostname]
|
tunnel, ok := p.activeTunnels[hostname]
|
||||||
if !ok {
|
if !ok {
|
||||||
tunnel = &activeTunnel{}
|
ctx, cancel := context.WithCancel(p.ctx)
|
||||||
|
tunnel = &activeTunnel{ctx: ctx, cancel: cancel}
|
||||||
p.activeTunnels[hostname] = tunnel
|
p.activeTunnels[hostname] = tunnel
|
||||||
}
|
}
|
||||||
tunnel.conns = append(tunnel.conns, actualClientConn)
|
tunnel.count++
|
||||||
|
tunnelCtx := tunnel.ctx
|
||||||
p.activeTunnelsLock.Unlock()
|
p.activeTunnelsLock.Unlock()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// Remove this conn from active tunnels
|
|
||||||
p.activeTunnelsLock.Lock()
|
p.activeTunnelsLock.Lock()
|
||||||
if tunnel, ok := p.activeTunnels[hostname]; ok {
|
tunnel.count--
|
||||||
newConns := make([]net.Conn, 0, len(tunnel.conns))
|
if tunnel.count == 0 {
|
||||||
for _, c := range tunnel.conns {
|
tunnel.cancel()
|
||||||
if c != actualClientConn {
|
if p.activeTunnels[hostname] == tunnel {
|
||||||
newConns = append(newConns, c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(newConns) == 0 {
|
|
||||||
delete(p.activeTunnels, hostname)
|
delete(p.activeTunnels, hostname)
|
||||||
} else {
|
|
||||||
tunnel.conns = newConns
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p.activeTunnelsLock.Unlock()
|
p.activeTunnelsLock.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Start bidirectional data transfer
|
// Start bidirectional data transfer with tunnel context
|
||||||
p.pipe(actualClientConn, targetConn, clientReader)
|
p.pipe(tunnelCtx, actualClientConn, targetConn, clientReader)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRoute retrieves routing information for a hostname
|
// getRoute retrieves routing information for a hostname
|
||||||
@@ -754,47 +752,36 @@ func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// pipe handles bidirectional data transfer between connections
|
// pipe handles bidirectional data transfer between connections
|
||||||
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) {
|
func (p *SNIProxy) pipe(ctx context.Context, clientConn, targetConn net.Conn, clientReader io.Reader) {
|
||||||
var wg sync.WaitGroup
|
g, gCtx := errgroup.WithContext(ctx)
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
// closeOnce ensures we only close connections once
|
// Close connections when context cancels to unblock io.Copy operations
|
||||||
var closeOnce sync.Once
|
context.AfterFunc(gCtx, func() {
|
||||||
closeConns := func() {
|
clientConn.Close()
|
||||||
closeOnce.Do(func() {
|
targetConn.Close()
|
||||||
// Close both connections to unblock any pending reads
|
})
|
||||||
clientConn.Close()
|
|
||||||
targetConn.Close()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy data from client to target (using the buffered reader)
|
// Copy data from client to target
|
||||||
go func() {
|
g.Go(func() error {
|
||||||
defer wg.Done()
|
|
||||||
defer closeConns()
|
|
||||||
|
|
||||||
// Use a large buffer for better performance
|
|
||||||
buf := make([]byte, 32*1024)
|
buf := make([]byte, 32*1024)
|
||||||
_, err := io.CopyBuffer(targetConn, clientReader, buf)
|
_, err := io.CopyBuffer(targetConn, clientReader, buf)
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
logger.Debug("Copy client->target error: %v", err)
|
logger.Debug("Copy client->target error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
// Copy data from target to client
|
// Copy data from target to client
|
||||||
go func() {
|
g.Go(func() error {
|
||||||
defer wg.Done()
|
|
||||||
defer closeConns()
|
|
||||||
|
|
||||||
// Use a large buffer for better performance
|
|
||||||
buf := make([]byte, 32*1024)
|
buf := make([]byte, 32*1024)
|
||||||
_, err := io.CopyBuffer(clientConn, targetConn, buf)
|
_, err := io.CopyBuffer(clientConn, targetConn, buf)
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
logger.Debug("Copy target->client error: %v", err)
|
logger.Debug("Copy target->client error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
wg.Wait()
|
g.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCacheStats returns cache statistics
|
// GetCacheStats returns cache statistics
|
||||||
@@ -830,16 +817,14 @@ func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) {
|
|||||||
|
|
||||||
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
|
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
|
||||||
|
|
||||||
// Terminate tunnels for removed SNIs
|
// Terminate tunnels for removed SNIs via context cancellation
|
||||||
if len(removed) > 0 {
|
if len(removed) > 0 {
|
||||||
p.activeTunnelsLock.Lock()
|
p.activeTunnelsLock.Lock()
|
||||||
for _, sni := range removed {
|
for _, sni := range removed {
|
||||||
if tunnels, ok := p.activeTunnels[sni]; ok {
|
if tunnel, ok := p.activeTunnels[sni]; ok {
|
||||||
for _, conn := range tunnels.conns {
|
tunnel.cancel()
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
delete(p.activeTunnels, sni)
|
delete(p.activeTunnels, sni)
|
||||||
logger.Debug("Closed tunnels for SNI target change: %s", sni)
|
logger.Debug("Cancelled tunnel context for SNI target change: %s", sni)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p.activeTunnelsLock.Unlock()
|
p.activeTunnelsLock.Unlock()
|
||||||
|
|||||||
Reference in New Issue
Block a user