Compare commits

..

4 Commits

Author SHA1 Message Date
Owen Schwartz
b398f531f0 Merge pull request #279 from fosrl/dev
1.10.3
2026-03-16 16:47:39 -07:00
Owen
ef03b4566d Allow passing public dns into resolve 2026-03-12 16:41:41 -07:00
Owen
44ca592a5c Set newt version in dockerfile 2026-03-08 11:28:56 -07:00
Owen
e1edbcea07 Make sure to set version and fix prepare issue 2026-03-07 12:34:55 -08:00
13 changed files with 171 additions and 1547 deletions

View File

@@ -43,7 +43,6 @@ type Target struct {
RewriteTo string `json:"rewriteTo,omitempty"`
DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"`
ResourceId int `json:"resourceId,omitempty"`
}
type PortRange struct {
@@ -162,9 +161,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
useNativeInterface: useNativeInterface,
}
// Create the holepunch manager with ResolveDomain function
// We'll need to pass a domain resolver function
service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String())
// Create the holepunch manager
service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String(), nil)
// Register websocket handlers
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
@@ -197,15 +195,6 @@ func (s *WireGuardService) Close() {
s.stopGetConfig = nil
}
// Flush access logs before tearing down the tunnel
if s.tnet != nil {
if ph := s.tnet.GetProxyHandler(); ph != nil {
if al := ph.GetAccessLogger(); al != nil {
al.Close()
}
}
}
// Stop the direct UDP relay first
s.StopDirectUDPRelay()
@@ -673,7 +662,7 @@ func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
}
}
@@ -804,13 +793,6 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.TunnelIP = tunnelIP.String()
// Configure the access log sender to ship compressed session logs via websocket
s.tnet.SetAccessLogSender(func(data string) error {
return s.client.SendMessageNoLog("newt/access-log", map[string]interface{}{
"compressed": data,
})
})
// Create WireGuard device using the shared bind
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
device.LogLevelSilent, // Use silent logging by default - could be made configurable
@@ -931,7 +913,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
@@ -1324,7 +1306,7 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
@@ -1442,7 +1424,7 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}

View File

@@ -5,9 +5,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
@@ -367,12 +365,11 @@ func (m *Monitor) performHealthCheck(target *Target) {
target.LastCheck = time.Now()
target.LastError = ""
// Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports)
host := target.Config.Hostname
// Build URL
url := fmt.Sprintf("%s://%s", target.Config.Scheme, target.Config.Hostname)
if target.Config.Port > 0 {
host = net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port))
url = fmt.Sprintf("%s:%d", url, target.Config.Port)
}
url := fmt.Sprintf("%s://%s", target.Config.Scheme, host)
if target.Config.Path != "" {
if !strings.HasPrefix(target.Config.Path, "/") {
url += "/"

View File

@@ -27,16 +27,17 @@ type ExitNode struct {
// Manager handles UDP hole punching operations
type Manager struct {
mu sync.Mutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
ID string
token string
publicKey string
clientType string
exitNodes map[string]ExitNode // key is endpoint
updateChan chan struct{} // signals the goroutine to refresh exit nodes
mu sync.Mutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
ID string
token string
publicKey string
clientType string
exitNodes map[string]ExitNode // key is endpoint
updateChan chan struct{} // signals the goroutine to refresh exit nodes
publicDNS []string
sendHolepunchInterval time.Duration
sendHolepunchIntervalMin time.Duration
@@ -49,12 +50,13 @@ const defaultSendHolepunchIntervalMax = 60 * time.Second
const defaultSendHolepunchIntervalMin = 1 * time.Second
// NewManager creates a new hole punch manager
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager {
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string, publicDNS []string) *Manager {
return &Manager{
sharedBind: sharedBind,
ID: ID,
clientType: clientType,
publicKey: publicKey,
publicDNS: publicDNS,
exitNodes: make(map[string]ExitNode),
sendHolepunchInterval: defaultSendHolepunchIntervalMin,
sendHolepunchIntervalMin: defaultSendHolepunchIntervalMin,
@@ -281,7 +283,13 @@ func (m *Manager) TriggerHolePunch() error {
// Send hole punch to all exit nodes
successCount := 0
for _, exitNode := range currentExitNodes {
host, err := util.ResolveDomain(exitNode.Endpoint)
var host string
var err error
if len(m.publicDNS) > 0 {
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
} else {
host, err = util.ResolveDomain(exitNode.Endpoint)
}
if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
@@ -392,7 +400,13 @@ func (m *Manager) runMultipleExitNodes() {
var resolvedNodes []resolvedExitNode
for _, exitNode := range currentExitNodes {
host, err := util.ResolveDomain(exitNode.Endpoint)
var host string
var err error
if len(m.publicDNS) > 0 {
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
} else {
host, err = util.ResolveDomain(exitNode.Endpoint)
}
if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue

View File

@@ -49,10 +49,11 @@ type cachedAddr struct {
// HolepunchTester monitors holepunch connectivity using magic packets
type HolepunchTester struct {
sharedBind *bind.SharedBind
mu sync.RWMutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
publicDNS []string
mu sync.RWMutex
running bool
stopChan chan struct{}
// Pending requests waiting for responses (key: echo data as string)
pendingRequests sync.Map // map[string]*pendingRequest
@@ -84,9 +85,10 @@ type pendingRequest struct {
}
// NewHolepunchTester creates a new holepunch tester using the given SharedBind
func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester {
func NewHolepunchTester(sharedBind *bind.SharedBind, publicDNS []string) *HolepunchTester {
return &HolepunchTester{
sharedBind: sharedBind,
publicDNS: publicDNS,
addrCache: make(map[string]*cachedAddr),
addrCacheTTL: 5 * time.Minute, // Cache addresses for 5 minutes
}
@@ -169,7 +171,13 @@ func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error)
}
// Resolve the endpoint
host, err := util.ResolveDomain(endpoint)
var host string
var err error
if len(t.publicDNS) > 0 {
host, err = util.ResolveDomainUpstream(endpoint, t.publicDNS)
} else {
host, err = util.ResolveDomain(endpoint)
}
if err != nil {
host = endpoint
}

19
main.go
View File

@@ -10,7 +10,6 @@ import (
"fmt"
"net"
"net/http"
"net/http/pprof"
"net/netip"
"os"
"os/signal"
@@ -148,7 +147,6 @@ var (
adminAddr string
region string
metricsAsyncBytes bool
pprofEnabled bool
blueprintFile string
noCloud bool
@@ -227,7 +225,6 @@ func runNewtMain(ctx context.Context) {
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
regionEnv := os.Getenv("NEWT_REGION")
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
pprofEnabledEnv := os.Getenv("NEWT_PPROF_ENABLED")
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
disableClients = disableClientsEnv == "true"
@@ -393,14 +390,6 @@ func runNewtMain(ctx context.Context) {
metricsAsyncBytes = v
}
}
// pprof debug endpoint toggle
if pprofEnabledEnv == "" {
flag.BoolVar(&pprofEnabled, "pprof", false, "Enable pprof debug endpoints on admin server")
} else {
if v, err := strconv.ParseBool(pprofEnabledEnv); err == nil {
pprofEnabled = v
}
}
// Optional region flag (resource attribute)
if regionEnv == "" {
flag.StringVar(&region, "region", "", "Optional region resource attribute (also NEWT_REGION)")
@@ -496,14 +485,6 @@ func runNewtMain(ctx context.Context) {
if tel.PrometheusHandler != nil {
mux.Handle("/metrics", tel.PrometheusHandler)
}
if pprofEnabled {
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
logger.Info("pprof debugging enabled on %s/debug/pprof/", tcfg.AdminAddr)
}
admin := &http.Server{
Addr: tcfg.AdminAddr,
Handler: otelhttp.NewHandler(mux, "newt-admin"),

View File

@@ -1,514 +0,0 @@
package netstack2
import (
"bytes"
"compress/zlib"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"net"
"sort"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
const (
// flushInterval is how often the access logger flushes completed sessions to the server
flushInterval = 60 * time.Second
// maxBufferedSessions is the max number of completed sessions to buffer before forcing a flush
maxBufferedSessions = 100
// sessionGapThreshold is the maximum gap between the end of one connection
// and the start of the next for them to be considered part of the same session.
// If the gap exceeds this, a new consolidated session is created.
sessionGapThreshold = 5 * time.Second
// minConnectionsToConsolidate is the minimum number of connections in a group
// before we bother consolidating. Groups smaller than this are sent as-is.
minConnectionsToConsolidate = 2
)
// SendFunc is a callback that sends compressed access log data to the server.
// The data is a base64-encoded zlib-compressed JSON array of AccessSession objects.
type SendFunc func(data string) error
// AccessSession represents a tracked access session through the proxy
type AccessSession struct {
SessionID string `json:"sessionId"`
ResourceID int `json:"resourceId"`
SourceAddr string `json:"sourceAddr"`
DestAddr string `json:"destAddr"`
Protocol string `json:"protocol"`
StartedAt time.Time `json:"startedAt"`
EndedAt time.Time `json:"endedAt,omitempty"`
BytesTx int64 `json:"bytesTx"`
BytesRx int64 `json:"bytesRx"`
ConnectionCount int `json:"connectionCount,omitempty"` // number of raw connections merged into this session (0 or 1 = single)
}
// udpSessionKey identifies a unique UDP "session" by src -> dst
type udpSessionKey struct {
srcAddr string
dstAddr string
protocol string
}
// consolidationKey groups connections that may be part of the same logical session.
// Source port is intentionally excluded so that many ephemeral-port connections
// from the same source IP to the same destination are grouped together.
type consolidationKey struct {
sourceIP string // IP only, no port
destAddr string // full host:port of the destination
protocol string
resourceID int
}
// AccessLogger tracks access sessions for resources and periodically
// flushes completed sessions to the server via a configurable SendFunc.
type AccessLogger struct {
mu sync.Mutex
sessions map[string]*AccessSession // active sessions: sessionID -> session
udpSessions map[udpSessionKey]*AccessSession // active UDP sessions for dedup
completedSessions []*AccessSession // completed sessions waiting to be flushed
udpTimeout time.Duration
sendFn SendFunc
stopCh chan struct{}
flushDone chan struct{} // closed after the flush goroutine exits
}
// NewAccessLogger creates a new access logger.
// udpTimeout controls how long a UDP session is kept alive without traffic before being ended.
func NewAccessLogger(udpTimeout time.Duration) *AccessLogger {
al := &AccessLogger{
sessions: make(map[string]*AccessSession),
udpSessions: make(map[udpSessionKey]*AccessSession),
completedSessions: make([]*AccessSession, 0),
udpTimeout: udpTimeout,
stopCh: make(chan struct{}),
flushDone: make(chan struct{}),
}
go al.backgroundLoop()
return al
}
// SetSendFunc sets the callback used to send compressed access log batches
// to the server. This can be called after construction once the websocket
// client is available.
func (al *AccessLogger) SetSendFunc(fn SendFunc) {
al.mu.Lock()
defer al.mu.Unlock()
al.sendFn = fn
}
// generateSessionID creates a random session identifier
func generateSessionID() string {
b := make([]byte, 8)
rand.Read(b)
return hex.EncodeToString(b)
}
// StartTCPSession logs the start of a TCP session and returns a session ID.
func (al *AccessLogger) StartTCPSession(resourceID int, srcAddr, dstAddr string) string {
sessionID := generateSessionID()
now := time.Now()
session := &AccessSession{
SessionID: sessionID,
ResourceID: resourceID,
SourceAddr: srcAddr,
DestAddr: dstAddr,
Protocol: "tcp",
StartedAt: now,
}
al.mu.Lock()
al.sessions[sessionID] = session
al.mu.Unlock()
logger.Info("ACCESS START session=%s resource=%d proto=tcp src=%s dst=%s time=%s",
sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
return sessionID
}
// EndTCPSession logs the end of a TCP session and queues it for sending.
func (al *AccessLogger) EndTCPSession(sessionID string) {
now := time.Now()
al.mu.Lock()
session, ok := al.sessions[sessionID]
if ok {
session.EndedAt = now
delete(al.sessions, sessionID)
al.completedSessions = append(al.completedSessions, session)
}
shouldFlush := len(al.completedSessions) >= maxBufferedSessions
al.mu.Unlock()
if ok {
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END session=%s resource=%d proto=tcp src=%s dst=%s started=%s ended=%s duration=%s",
sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
}
if shouldFlush {
al.flush()
}
}
// TrackUDPSession starts or returns an existing UDP session. Returns the session ID.
func (al *AccessLogger) TrackUDPSession(resourceID int, srcAddr, dstAddr string) string {
key := udpSessionKey{
srcAddr: srcAddr,
dstAddr: dstAddr,
protocol: "udp",
}
al.mu.Lock()
defer al.mu.Unlock()
if existing, ok := al.udpSessions[key]; ok {
return existing.SessionID
}
sessionID := generateSessionID()
now := time.Now()
session := &AccessSession{
SessionID: sessionID,
ResourceID: resourceID,
SourceAddr: srcAddr,
DestAddr: dstAddr,
Protocol: "udp",
StartedAt: now,
}
al.sessions[sessionID] = session
al.udpSessions[key] = session
logger.Info("ACCESS START session=%s resource=%d proto=udp src=%s dst=%s time=%s",
sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
return sessionID
}
// EndUDPSession ends a UDP session and queues it for sending.
func (al *AccessLogger) EndUDPSession(sessionID string) {
now := time.Now()
al.mu.Lock()
session, ok := al.sessions[sessionID]
if ok {
session.EndedAt = now
delete(al.sessions, sessionID)
key := udpSessionKey{
srcAddr: session.SourceAddr,
dstAddr: session.DestAddr,
protocol: "udp",
}
delete(al.udpSessions, key)
al.completedSessions = append(al.completedSessions, session)
}
shouldFlush := len(al.completedSessions) >= maxBufferedSessions
al.mu.Unlock()
if ok {
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
}
if shouldFlush {
al.flush()
}
}
// backgroundLoop handles periodic flushing and stale session reaping.
func (al *AccessLogger) backgroundLoop() {
defer close(al.flushDone)
flushTicker := time.NewTicker(flushInterval)
defer flushTicker.Stop()
reapTicker := time.NewTicker(30 * time.Second)
defer reapTicker.Stop()
for {
select {
case <-al.stopCh:
return
case <-flushTicker.C:
al.flush()
case <-reapTicker.C:
al.reapStaleSessions()
}
}
}
// reapStaleSessions cleans up UDP sessions that were not properly ended.
func (al *AccessLogger) reapStaleSessions() {
al.mu.Lock()
defer al.mu.Unlock()
staleThreshold := time.Now().Add(-5 * time.Minute)
for key, session := range al.udpSessions {
if session.StartedAt.Before(staleThreshold) && session.EndedAt.IsZero() {
now := time.Now()
session.EndedAt = now
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END (reaped) session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
session.SessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
al.completedSessions = append(al.completedSessions, session)
delete(al.sessions, session.SessionID)
delete(al.udpSessions, key)
}
}
}
// extractIP strips the port from an address string and returns just the IP.
// If the address has no port component it is returned as-is.
func extractIP(addr string) string {
host, _, err := net.SplitHostPort(addr)
if err != nil {
// Might already be a bare IP
return addr
}
return host
}
// consolidateSessions takes a slice of completed sessions and merges bursts of
// short-lived connections from the same source IP to the same destination into
// single higher-level session entries.
//
// The algorithm:
// 1. Group sessions by (sourceIP, destAddr, protocol, resourceID).
// 2. Within each group, sort by StartedAt.
// 3. Walk through the sorted list and merge consecutive sessions whose gap
// (previous EndedAt → next StartedAt) is ≤ sessionGapThreshold.
// 4. For merged sessions the earliest StartedAt and latest EndedAt are kept,
// bytes are summed, and ConnectionCount records how many raw connections
// were folded in. If the merged connections used more than one source port,
// SourceAddr is set to just the IP (port omitted).
// 5. Groups with fewer than minConnectionsToConsolidate members are passed
// through unmodified.
func consolidateSessions(sessions []*AccessSession) []*AccessSession {
if len(sessions) <= 1 {
return sessions
}
// Group sessions by consolidation key
groups := make(map[consolidationKey][]*AccessSession)
for _, s := range sessions {
key := consolidationKey{
sourceIP: extractIP(s.SourceAddr),
destAddr: s.DestAddr,
protocol: s.Protocol,
resourceID: s.ResourceID,
}
groups[key] = append(groups[key], s)
}
result := make([]*AccessSession, 0, len(sessions))
for key, group := range groups {
// Small groups don't need consolidation
if len(group) < minConnectionsToConsolidate {
result = append(result, group...)
continue
}
// Sort the group by start time so we can detect gaps
sort.Slice(group, func(i, j int) bool {
return group[i].StartedAt.Before(group[j].StartedAt)
})
// Walk through and merge runs that are within the gap threshold
var merged []*AccessSession
cur := cloneSession(group[0])
cur.ConnectionCount = 1
sourcePorts := make(map[string]struct{})
sourcePorts[cur.SourceAddr] = struct{}{}
for i := 1; i < len(group); i++ {
s := group[i]
// Determine the gap: from the latest end time we've seen so far to the
// start of the next connection.
gapRef := cur.EndedAt
if gapRef.IsZero() {
gapRef = cur.StartedAt
}
gap := s.StartedAt.Sub(gapRef)
if gap <= sessionGapThreshold {
// Merge into the current consolidated session
cur.ConnectionCount++
cur.BytesTx += s.BytesTx
cur.BytesRx += s.BytesRx
sourcePorts[s.SourceAddr] = struct{}{}
// Extend EndedAt to the latest time
endTime := s.EndedAt
if endTime.IsZero() {
endTime = s.StartedAt
}
if endTime.After(cur.EndedAt) {
cur.EndedAt = endTime
}
} else {
// Gap exceeded — finalize the current session and start a new one
finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts)
merged = append(merged, cur)
cur = cloneSession(s)
cur.ConnectionCount = 1
sourcePorts = make(map[string]struct{})
sourcePorts[s.SourceAddr] = struct{}{}
}
}
// Finalize the last accumulated session
finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts)
merged = append(merged, cur)
result = append(result, merged...)
}
return result
}
// cloneSession creates a shallow copy of an AccessSession.
func cloneSession(s *AccessSession) *AccessSession {
cp := *s
return &cp
}
// finalizeMergedSourceAddr sets the SourceAddr on a consolidated session.
// If multiple distinct source addresses (ports) were seen, the port is
// stripped and only the IP is kept so the log isn't misleading.
func finalizeMergedSourceAddr(s *AccessSession, sourceIP string, ports map[string]struct{}) {
if len(ports) > 1 {
// Multiple source ports — just report the IP
s.SourceAddr = sourceIP
}
// Otherwise keep the original SourceAddr which already has ip:port
}
// flush drains the completed sessions buffer, consolidates bursts of
// short-lived connections, compresses with zlib, and sends via the SendFunc.
func (al *AccessLogger) flush() {
al.mu.Lock()
if len(al.completedSessions) == 0 {
al.mu.Unlock()
return
}
batch := al.completedSessions
al.completedSessions = make([]*AccessSession, 0)
sendFn := al.sendFn
al.mu.Unlock()
if sendFn == nil {
logger.Debug("Access logger: no send function configured, discarding %d sessions", len(batch))
return
}
// Consolidate bursts of short-lived connections into higher-level sessions
originalCount := len(batch)
batch = consolidateSessions(batch)
if len(batch) != originalCount {
logger.Info("Access logger: consolidated %d raw connections into %d sessions", originalCount, len(batch))
}
compressed, err := compressSessions(batch)
if err != nil {
logger.Error("Access logger: failed to compress %d sessions: %v", len(batch), err)
return
}
if err := sendFn(compressed); err != nil {
logger.Error("Access logger: failed to send %d sessions: %v", len(batch), err)
// Re-queue the batch so we don't lose data
al.mu.Lock()
al.completedSessions = append(batch, al.completedSessions...)
// Cap re-queued data to prevent unbounded growth if server is unreachable
if len(al.completedSessions) > maxBufferedSessions*5 {
dropped := len(al.completedSessions) - maxBufferedSessions*5
al.completedSessions = al.completedSessions[:maxBufferedSessions*5]
logger.Warn("Access logger: buffer overflow, dropped %d oldest sessions", dropped)
}
al.mu.Unlock()
return
}
logger.Info("Access logger: sent %d sessions to server", len(batch))
}
// compressSessions JSON-encodes the sessions, compresses with zlib, and returns
// a base64-encoded string suitable for embedding in a JSON message.
func compressSessions(sessions []*AccessSession) (string, error) {
jsonData, err := json.Marshal(sessions)
if err != nil {
return "", err
}
var buf bytes.Buffer
w, err := zlib.NewWriterLevel(&buf, zlib.BestCompression)
if err != nil {
return "", err
}
if _, err := w.Write(jsonData); err != nil {
w.Close()
return "", err
}
if err := w.Close(); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}
// Close shuts down the background loop, ends all active sessions,
// and performs one final flush to send everything to the server.
func (al *AccessLogger) Close() {
// Signal the background loop to stop
select {
case <-al.stopCh:
// Already closed
return
default:
close(al.stopCh)
}
// Wait for the background loop to exit so we don't race on flush
<-al.flushDone
al.mu.Lock()
now := time.Now()
// End all active sessions and move them to the completed buffer
for _, session := range al.sessions {
if session.EndedAt.IsZero() {
session.EndedAt = now
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END (shutdown) session=%s resource=%d proto=%s src=%s dst=%s started=%s ended=%s duration=%s",
session.SessionID, session.ResourceID, session.Protocol, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
al.completedSessions = append(al.completedSessions, session)
}
}
al.sessions = make(map[string]*AccessSession)
al.udpSessions = make(map[udpSessionKey]*AccessSession)
al.mu.Unlock()
// Final flush to send all remaining sessions to the server
al.flush()
}

View File

@@ -1,811 +0,0 @@
package netstack2
import (
"testing"
"time"
)
func TestExtractIP(t *testing.T) {
tests := []struct {
name string
addr string
expected string
}{
{"ipv4 with port", "192.168.1.1:12345", "192.168.1.1"},
{"ipv4 without port", "192.168.1.1", "192.168.1.1"},
{"ipv6 with port", "[::1]:12345", "::1"},
{"ipv6 without port", "::1", "::1"},
{"empty string", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractIP(tt.addr)
if result != tt.expected {
t.Errorf("extractIP(%q) = %q, want %q", tt.addr, result, tt.expected)
}
})
}
}
func TestConsolidateSessions_Empty(t *testing.T) {
result := consolidateSessions(nil)
if result != nil {
t.Errorf("expected nil, got %v", result)
}
result = consolidateSessions([]*AccessSession{})
if len(result) != 0 {
t.Errorf("expected empty slice, got %d items", len(result))
}
}
func TestConsolidateSessions_SingleSession(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "abc123",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(1 * time.Second),
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 session, got %d", len(result))
}
if result[0].SourceAddr != "10.0.0.1:5000" {
t.Errorf("expected source addr preserved, got %q", result[0].SourceAddr)
}
}
func TestConsolidateSessions_MergesBurstFromSameSourceIP(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
BytesTx: 100,
BytesRx: 200,
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
BytesTx: 150,
BytesRx: 250,
},
{
SessionID: "s3",
ResourceID: 1,
SourceAddr: "10.0.0.1:5002",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(400 * time.Millisecond),
EndedAt: now.Add(500 * time.Millisecond),
BytesTx: 50,
BytesRx: 75,
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 consolidated session, got %d", len(result))
}
s := result[0]
if s.ConnectionCount != 3 {
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
}
if s.SourceAddr != "10.0.0.1" {
t.Errorf("expected source addr to be IP only (multiple ports), got %q", s.SourceAddr)
}
if s.DestAddr != "192.168.1.100:443" {
t.Errorf("expected dest addr preserved, got %q", s.DestAddr)
}
if s.StartedAt != now {
t.Errorf("expected StartedAt to be earliest time")
}
if s.EndedAt != now.Add(500*time.Millisecond) {
t.Errorf("expected EndedAt to be latest time")
}
expectedTx := int64(300)
expectedRx := int64(525)
if s.BytesTx != expectedTx {
t.Errorf("expected BytesTx=%d, got %d", expectedTx, s.BytesTx)
}
if s.BytesRx != expectedRx {
t.Errorf("expected BytesRx=%d, got %d", expectedRx, s.BytesRx)
}
}
func TestConsolidateSessions_SameSourcePortPreserved(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 session, got %d", len(result))
}
if result[0].SourceAddr != "10.0.0.1:5000" {
t.Errorf("expected source addr with port preserved when all ports are the same, got %q", result[0].SourceAddr)
}
if result[0].ConnectionCount != 2 {
t.Errorf("expected ConnectionCount=2, got %d", result[0].ConnectionCount)
}
}
func TestConsolidateSessions_GapSplitsSessions(t *testing.T) {
now := time.Now()
// First burst
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
// Big gap here (10 seconds)
{
SessionID: "s3",
ResourceID: 1,
SourceAddr: "10.0.0.1:5002",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(10 * time.Second),
EndedAt: now.Add(10*time.Second + 100*time.Millisecond),
},
{
SessionID: "s4",
ResourceID: 1,
SourceAddr: "10.0.0.1:5003",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(10*time.Second + 200*time.Millisecond),
EndedAt: now.Add(10*time.Second + 300*time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 consolidated sessions (gap split), got %d", len(result))
}
// Find the sessions by their start time
var first, second *AccessSession
for _, s := range result {
if s.StartedAt.Equal(now) {
first = s
} else {
second = s
}
}
if first == nil || second == nil {
t.Fatal("could not find both consolidated sessions")
}
if first.ConnectionCount != 2 {
t.Errorf("first burst: expected ConnectionCount=2, got %d", first.ConnectionCount)
}
if second.ConnectionCount != 2 {
t.Errorf("second burst: expected ConnectionCount=2, got %d", second.ConnectionCount)
}
}
func TestConsolidateSessions_DifferentDestinationsNotMerged(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:8080",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
}
result := consolidateSessions(sessions)
// Each goes to a different dest port so they should not be merged
if len(result) != 2 {
t.Fatalf("expected 2 sessions (different destinations), got %d", len(result))
}
}
func TestConsolidateSessions_DifferentProtocolsNotMerged(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "udp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 sessions (different protocols), got %d", len(result))
}
}
func TestConsolidateSessions_DifferentResourceIDsNotMerged(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 2,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 sessions (different resource IDs), got %d", len(result))
}
}
func TestConsolidateSessions_DifferentSourceIPsNotMerged(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.2:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 sessions (different source IPs), got %d", len(result))
}
}
func TestConsolidateSessions_OutOfOrderInput(t *testing.T) {
now := time.Now()
// Provide sessions out of chronological order to verify sorting
sessions := []*AccessSession{
{
SessionID: "s3",
ResourceID: 1,
SourceAddr: "10.0.0.1:5002",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(400 * time.Millisecond),
EndedAt: now.Add(500 * time.Millisecond),
BytesTx: 30,
},
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
BytesTx: 10,
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
BytesTx: 20,
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 consolidated session, got %d", len(result))
}
s := result[0]
if s.ConnectionCount != 3 {
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
}
if s.StartedAt != now {
t.Errorf("expected StartedAt to be earliest time")
}
if s.EndedAt != now.Add(500*time.Millisecond) {
t.Errorf("expected EndedAt to be latest time")
}
if s.BytesTx != 60 {
t.Errorf("expected BytesTx=60, got %d", s.BytesTx)
}
}
func TestConsolidateSessions_ExactlyAtGapThreshold(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
// Starts exactly sessionGapThreshold after s1 ends — should still merge
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(100*time.Millisecond + sessionGapThreshold),
EndedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 50*time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 session (gap exactly at threshold merges), got %d", len(result))
}
if result[0].ConnectionCount != 2 {
t.Errorf("expected ConnectionCount=2, got %d", result[0].ConnectionCount)
}
}
func TestConsolidateSessions_JustOverGapThreshold(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
// Starts 1ms over the gap threshold after s1 ends — should split
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 1*time.Millisecond),
EndedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 50*time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 sessions (gap just over threshold splits), got %d", len(result))
}
}
func TestConsolidateSessions_UDPSessions(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "u1",
ResourceID: 5,
SourceAddr: "10.0.0.1:6000",
DestAddr: "192.168.1.100:53",
Protocol: "udp",
StartedAt: now,
EndedAt: now.Add(50 * time.Millisecond),
BytesTx: 64,
BytesRx: 512,
},
{
SessionID: "u2",
ResourceID: 5,
SourceAddr: "10.0.0.1:6001",
DestAddr: "192.168.1.100:53",
Protocol: "udp",
StartedAt: now.Add(100 * time.Millisecond),
EndedAt: now.Add(150 * time.Millisecond),
BytesTx: 64,
BytesRx: 256,
},
{
SessionID: "u3",
ResourceID: 5,
SourceAddr: "10.0.0.1:6002",
DestAddr: "192.168.1.100:53",
Protocol: "udp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(250 * time.Millisecond),
BytesTx: 64,
BytesRx: 128,
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 consolidated UDP session, got %d", len(result))
}
s := result[0]
if s.Protocol != "udp" {
t.Errorf("expected protocol=udp, got %q", s.Protocol)
}
if s.ConnectionCount != 3 {
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
}
if s.SourceAddr != "10.0.0.1" {
t.Errorf("expected source addr to be IP only, got %q", s.SourceAddr)
}
if s.BytesTx != 192 {
t.Errorf("expected BytesTx=192, got %d", s.BytesTx)
}
if s.BytesRx != 896 {
t.Errorf("expected BytesRx=896, got %d", s.BytesRx)
}
}
func TestConsolidateSessions_MixedGroupsSomeConsolidatedSomeNot(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
// Group 1: 3 connections to :443 from same IP — should consolidate
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
{
SessionID: "s3",
ResourceID: 1,
SourceAddr: "10.0.0.1:5002",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(400 * time.Millisecond),
EndedAt: now.Add(500 * time.Millisecond),
},
// Group 2: 1 connection to :8080 from different IP — should pass through
{
SessionID: "s4",
ResourceID: 2,
SourceAddr: "10.0.0.2:6000",
DestAddr: "192.168.1.200:8080",
Protocol: "tcp",
StartedAt: now.Add(1 * time.Second),
EndedAt: now.Add(2 * time.Second),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 sessions total, got %d", len(result))
}
var consolidated, passthrough *AccessSession
for _, s := range result {
if s.ConnectionCount > 1 {
consolidated = s
} else {
passthrough = s
}
}
if consolidated == nil {
t.Fatal("expected a consolidated session")
}
if consolidated.ConnectionCount != 3 {
t.Errorf("consolidated: expected ConnectionCount=3, got %d", consolidated.ConnectionCount)
}
if passthrough == nil {
t.Fatal("expected a passthrough session")
}
if passthrough.SessionID != "s4" {
t.Errorf("passthrough: expected session s4, got %s", passthrough.SessionID)
}
}
func TestConsolidateSessions_OverlappingConnections(t *testing.T) {
now := time.Now()
// Connections that overlap in time (not sequential)
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(5 * time.Second),
BytesTx: 100,
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(1 * time.Second),
EndedAt: now.Add(3 * time.Second),
BytesTx: 200,
},
{
SessionID: "s3",
ResourceID: 1,
SourceAddr: "10.0.0.1:5002",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(2 * time.Second),
EndedAt: now.Add(6 * time.Second),
BytesTx: 300,
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 consolidated session, got %d", len(result))
}
s := result[0]
if s.ConnectionCount != 3 {
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
}
if s.StartedAt != now {
t.Error("expected StartedAt to be earliest")
}
if s.EndedAt != now.Add(6*time.Second) {
t.Error("expected EndedAt to be the latest end time")
}
if s.BytesTx != 600 {
t.Errorf("expected BytesTx=600, got %d", s.BytesTx)
}
}
func TestConsolidateSessions_DoesNotMutateOriginals(t *testing.T) {
now := time.Now()
s1 := &AccessSession{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
BytesTx: 100,
}
s2 := &AccessSession{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
BytesTx: 200,
}
// Save original values
origS1Addr := s1.SourceAddr
origS1Bytes := s1.BytesTx
origS2Addr := s2.SourceAddr
origS2Bytes := s2.BytesTx
_ = consolidateSessions([]*AccessSession{s1, s2})
if s1.SourceAddr != origS1Addr {
t.Errorf("s1.SourceAddr was mutated: %q -> %q", origS1Addr, s1.SourceAddr)
}
if s1.BytesTx != origS1Bytes {
t.Errorf("s1.BytesTx was mutated: %d -> %d", origS1Bytes, s1.BytesTx)
}
if s2.SourceAddr != origS2Addr {
t.Errorf("s2.SourceAddr was mutated: %q -> %q", origS2Addr, s2.SourceAddr)
}
if s2.BytesTx != origS2Bytes {
t.Errorf("s2.BytesTx was mutated: %d -> %d", origS2Bytes, s2.BytesTx)
}
}
func TestConsolidateSessions_ThreeBurstsWithGaps(t *testing.T) {
now := time.Now()
sessions := make([]*AccessSession, 0, 9)
// Burst 1: 3 connections at t=0
for i := 0; i < 3; i++ {
sessions = append(sessions, &AccessSession{
SessionID: generateSessionID(),
ResourceID: 1,
SourceAddr: "10.0.0.1:" + string(rune('A'+i)),
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(time.Duration(i*100) * time.Millisecond),
EndedAt: now.Add(time.Duration(i*100+50) * time.Millisecond),
})
}
// Burst 2: 3 connections at t=20s (well past the 5s gap)
for i := 0; i < 3; i++ {
sessions = append(sessions, &AccessSession{
SessionID: generateSessionID(),
ResourceID: 1,
SourceAddr: "10.0.0.1:" + string(rune('D'+i)),
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(20*time.Second + time.Duration(i*100)*time.Millisecond),
EndedAt: now.Add(20*time.Second + time.Duration(i*100+50)*time.Millisecond),
})
}
// Burst 3: 3 connections at t=40s
for i := 0; i < 3; i++ {
sessions = append(sessions, &AccessSession{
SessionID: generateSessionID(),
ResourceID: 1,
SourceAddr: "10.0.0.1:" + string(rune('G'+i)),
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(40*time.Second + time.Duration(i*100)*time.Millisecond),
EndedAt: now.Add(40*time.Second + time.Duration(i*100+50)*time.Millisecond),
})
}
result := consolidateSessions(sessions)
if len(result) != 3 {
t.Fatalf("expected 3 consolidated sessions (3 bursts), got %d", len(result))
}
for _, s := range result {
if s.ConnectionCount != 3 {
t.Errorf("expected each burst to have ConnectionCount=3, got %d (started=%v)", s.ConnectionCount, s.StartedAt)
}
}
}
func TestFinalizeMergedSourceAddr(t *testing.T) {
s := &AccessSession{SourceAddr: "10.0.0.1:5000"}
ports := map[string]struct{}{"10.0.0.1:5000": {}}
finalizeMergedSourceAddr(s, "10.0.0.1", ports)
if s.SourceAddr != "10.0.0.1:5000" {
t.Errorf("single port: expected addr preserved, got %q", s.SourceAddr)
}
s2 := &AccessSession{SourceAddr: "10.0.0.1:5000"}
ports2 := map[string]struct{}{"10.0.0.1:5000": {}, "10.0.0.1:5001": {}}
finalizeMergedSourceAddr(s2, "10.0.0.1", ports2)
if s2.SourceAddr != "10.0.0.1" {
t.Errorf("multiple ports: expected IP only, got %q", s2.SourceAddr)
}
}
func TestCloneSession(t *testing.T) {
original := &AccessSession{
SessionID: "test",
ResourceID: 42,
SourceAddr: "1.2.3.4:100",
DestAddr: "5.6.7.8:443",
Protocol: "tcp",
BytesTx: 999,
}
clone := cloneSession(original)
if clone == original {
t.Error("clone should be a different pointer")
}
if clone.SessionID != original.SessionID {
t.Error("clone should have same SessionID")
}
// Mutating clone should not affect original
clone.BytesTx = 0
clone.SourceAddr = "changed"
if original.BytesTx != 999 {
t.Error("mutating clone affected original BytesTx")
}
if original.SourceAddr != "1.2.3.4:100" {
t.Error("mutating clone affected original SourceAddr")
}
}

View File

@@ -158,18 +158,6 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Look up resource ID and start access session if applicable
var accessSessionID string
if h.proxyHandler != nil {
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber))
if resourceId != 0 {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
accessSessionID = al.StartTCPSession(resourceId, srcAddr, targetAddr)
}
}
}
// Create context with timeout for connection establishment
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
defer cancel()
@@ -179,26 +167,11 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
targetConn, err := d.DialContext(ctx, "tcp", targetAddr)
if err != nil {
logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err)
// End access session on connection failure
if accessSessionID != "" {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndTCPSession(accessSessionID)
}
}
// Connection failed, netstack will handle RST
return
}
defer targetConn.Close()
// End access session when connection closes
if accessSessionID != "" {
defer func() {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndTCPSession(accessSessionID)
}
}()
}
logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr)
// Bidirectional copy between netstack and target
@@ -307,27 +280,6 @@ func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.Transpo
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Look up resource ID and start access session if applicable
var accessSessionID string
if h.proxyHandler != nil {
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber))
if resourceId != 0 {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
accessSessionID = al.TrackUDPSession(resourceId, srcAddr, targetAddr)
}
}
}
// End access session when UDP handler returns (timeout or error)
if accessSessionID != "" {
defer func() {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndUDPSession(accessSessionID)
}
}()
}
// Resolve target address
remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil {

View File

@@ -22,12 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
// udpAccessSessionTimeout is how long a UDP access session stays alive without traffic
// before being considered ended by the access logger
udpAccessSessionTimeout = 120 * time.Second
)
// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering
// Protocol can be "tcp", "udp", or "" (empty string means both protocols)
type PortRange struct {
@@ -52,7 +46,6 @@ type SubnetRule struct {
DisableIcmp bool // If true, ICMP traffic is blocked for this subnet
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
PortRanges []PortRange // empty slice means all ports allowed
ResourceId int // Optional resource ID from the server for access logging
}
// GetAllRules returns a copy of all subnet rules
@@ -118,12 +111,10 @@ type ProxyHandler struct {
natTable map[connKey]*natState
reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
resourceTable map[destKey]int // Maps connection key to resource ID for access logging
natMu sync.RWMutex
enabled bool
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
notifiable channel.Notification // Notification handler for triggering reads
accessLogger *AccessLogger // Access logger for tracking sessions
}
// ProxyHandlerOptions configures the proxy handler
@@ -146,9 +137,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
natTable: make(map[connKey]*natState),
reverseNatTable: make(map[reverseConnKey]*natState),
destRewriteTable: make(map[destKey]netip.Addr),
resourceTable: make(map[destKey]int),
icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets
accessLogger: NewAccessLogger(udpAccessSessionTimeout),
proxyEp: channel.New(1024, uint32(options.MTU), ""),
proxyStack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
@@ -213,11 +202,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
// destPrefix: The IP prefix of the destination
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
// If portRanges is nil or empty, all ports are allowed for this subnet
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
if p == nil || !p.enabled {
return
}
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
}
// RemoveSubnetRule removes a subnet from the proxy handler
@@ -236,43 +225,6 @@ func (p *ProxyHandler) GetAllRules() []SubnetRule {
return p.subnetLookup.GetAllRules()
}
// LookupResourceId looks up the resource ID for a connection
// Returns 0 if no resource ID is associated with this connection
func (p *ProxyHandler) LookupResourceId(srcIP, dstIP string, dstPort uint16, proto uint8) int {
if p == nil || !p.enabled {
return 0
}
key := destKey{
srcIP: srcIP,
dstIP: dstIP,
dstPort: dstPort,
proto: proto,
}
p.natMu.RLock()
defer p.natMu.RUnlock()
return p.resourceTable[key]
}
// GetAccessLogger returns the access logger for session tracking
func (p *ProxyHandler) GetAccessLogger() *AccessLogger {
if p == nil {
return nil
}
return p.accessLogger
}
// SetAccessLogSender configures the function used to send compressed access log
// batches to the server. This should be called once the websocket client is available.
func (p *ProxyHandler) SetAccessLogSender(fn SendFunc) {
if p == nil || !p.enabled || p.accessLogger == nil {
return
}
p.accessLogger.SetSendFunc(fn)
}
// LookupDestinationRewrite looks up the rewritten destination for a connection
// This is used by TCP/UDP handlers to find the actual target address
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
@@ -435,22 +387,8 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
// Check if the source IP, destination IP, port, and protocol match any subnet rule
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol)
if matchedRule != nil {
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d, resourceId=%d)",
srcAddr, dstAddr, protocol, dstPort, matchedRule.ResourceId)
// Store resource ID for connections without DNAT as well
if matchedRule.ResourceId != 0 && matchedRule.RewriteTo == "" {
dKey := destKey{
srcIP: srcAddr.String(),
dstIP: dstAddr.String(),
dstPort: dstPort,
proto: uint8(protocol),
}
p.natMu.Lock()
p.resourceTable[dKey] = matchedRule.ResourceId
p.natMu.Unlock()
}
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)",
srcAddr, dstAddr, protocol, dstPort)
// Check if we need to perform DNAT
if matchedRule.RewriteTo != "" {
// Create connection tracking key using original destination
@@ -482,13 +420,6 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
proto: uint8(protocol),
}
// Store resource ID for access logging if present
if matchedRule.ResourceId != 0 {
p.natMu.Lock()
p.resourceTable[dKey] = matchedRule.ResourceId
p.natMu.Unlock()
}
// Check if we already have a NAT entry for this connection
p.natMu.RLock()
existingEntry, exists := p.natTable[key]
@@ -789,11 +720,6 @@ func (p *ProxyHandler) Close() error {
return nil
}
// Shut down access logger
if p.accessLogger != nil {
p.accessLogger.Close()
}
// Close ICMP replies channel
if p.icmpReplies != nil {
close(p.icmpReplies)

View File

@@ -47,7 +47,7 @@ func prefixEqual(a, b netip.Prefix) bool {
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
// If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
sl.mu.Lock()
defer sl.mu.Unlock()
@@ -57,7 +57,6 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite
DisableIcmp: disableIcmp,
RewriteTo: rewriteTo,
PortRanges: portRanges,
ResourceId: resourceId,
}
// Canonicalize source prefix to handle host bits correctly

View File

@@ -354,10 +354,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
// AddProxySubnetRule adds a subnet rule to the proxy handler
// If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
}
}
@@ -385,15 +385,6 @@ func (net *Net) GetProxyHandler() *ProxyHandler {
return tun.proxyHandler
}
// SetAccessLogSender configures the function used to send compressed access log
// batches to the server. This should be called once the websocket client is available.
func (net *Net) SetAccessLogSender(fn SendFunc) {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
tun.proxyHandler.SetAccessLogSender(fn)
}
}
type PingConn struct {
laddr PingAddr
raddr PingAddr

View File

@@ -21,10 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
)
const (
errUnsupportedProtoFmt = "unsupported protocol: %s"
maxUDPPacketSize = 65507
)
const errUnsupportedProtoFmt = "unsupported protocol: %s"
// Target represents a proxy target with its address and port
type Target struct {
@@ -108,9 +105,13 @@ func classifyProxyError(err error) string {
if errors.Is(err, net.ErrClosed) {
return "closed"
}
var ne net.Error
if errors.As(err, &ne) && ne.Timeout() {
return "timeout"
if ne, ok := err.(net.Error); ok {
if ne.Timeout() {
return "timeout"
}
if ne.Temporary() {
return "temporary"
}
}
msg := strings.ToLower(err.Error())
switch {
@@ -436,6 +437,14 @@ func (pm *ProxyManager) Stop() error {
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
}
// // Clear the target maps
// for k := range pm.tcpTargets {
// delete(pm.tcpTargets, k)
// }
// for k := range pm.udpTargets {
// delete(pm.udpTargets, k)
// }
// Give active connections a chance to close gracefully
time.Sleep(100 * time.Millisecond)
@@ -489,7 +498,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
if !pm.running {
return
}
if errors.Is(err, net.ErrClosed) {
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
return
}
@@ -555,7 +564,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
}
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size
buffer := make([]byte, 65507) // Max UDP packet size
clientConns := make(map[string]*net.UDPConn)
var clientsMutex sync.RWMutex
@@ -574,7 +583,7 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
}
// Check for connection closed conditions
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("UDP connection closed, stopping proxy handler")
// Clean up existing client connections
@@ -653,14 +662,10 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
}()
buffer := make([]byte, maxUDPPacketSize)
buffer := make([]byte, 65507)
for {
n, _, err := targetConn.ReadFromUDP(buffer)
if err != nil {
// Connection closed is normal during cleanup
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
return // defer will handle cleanup, result stays "success"
}
logger.Error("Error reading from target: %v", err)
result = "failure"
return // defer will handle cleanup

View File

@@ -1,6 +1,7 @@
package util
import (
"context"
"encoding/base64"
"encoding/binary"
"encoding/hex"
@@ -14,6 +15,99 @@ import (
"golang.zx2c4.com/wireguard/device"
)
func ResolveDomainUpstream(domain string, publicDNS []string) (string, error) {
// trim whitespace
domain = strings.TrimSpace(domain)
// Remove any protocol prefix if present (do this first, before splitting host/port)
domain = strings.TrimPrefix(domain, "http://")
domain = strings.TrimPrefix(domain, "https://")
// if there are any trailing slashes, remove them
domain = strings.TrimSuffix(domain, "/")
// Check if there's a port in the domain
host, port, err := net.SplitHostPort(domain)
if err != nil {
// No port found, use the domain as is
host = domain
port = ""
}
// Check if host is already an IP address (IPv4 or IPv6)
// For IPv6, the host from SplitHostPort will already have brackets stripped
// but if there was no port, we need to handle bracketed IPv6 addresses
cleanHost := strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
if ip := net.ParseIP(cleanHost); ip != nil {
// It's already an IP address, no need to resolve
ipAddr := ip.String()
if port != "" {
return net.JoinHostPort(ipAddr, port), nil
}
return ipAddr, nil
}
// Lookup IP addresses using the upstream DNS servers if provided
var ips []net.IP
if len(publicDNS) > 0 {
var lastErr error
for _, server := range publicDNS {
// Ensure the upstream DNS address has a port
dnsAddr := server
if _, _, err := net.SplitHostPort(dnsAddr); err != nil {
// No port specified, default to 53
dnsAddr = net.JoinHostPort(server, "53")
}
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, "udp", dnsAddr)
},
}
ips, lastErr = resolver.LookupIP(context.Background(), "ip", host)
if lastErr == nil {
break
}
}
if lastErr != nil {
return "", fmt.Errorf("DNS lookup failed using all upstream servers: %v", lastErr)
}
} else {
ips, err = net.LookupIP(host)
if err != nil {
return "", fmt.Errorf("DNS lookup failed: %v", err)
}
}
if len(ips) == 0 {
return "", fmt.Errorf("no IP addresses found for domain %s", host)
}
// Get the first IPv4 address if available
var ipAddr string
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ipAddr = ipv4.String()
break
}
}
// If no IPv4 found, use the first IP (might be IPv6)
if ipAddr == "" {
ipAddr = ips[0].String()
}
// Add port back if it existed
if port != "" {
ipAddr = net.JoinHostPort(ipAddr, port)
}
return ipAddr, nil
}
func ResolveDomain(domain string) (string, error) {
// trim whitespace
domain = strings.TrimSpace(domain)