Compare commits

...

30 Commits

Author SHA1 Message Date
Owen
0f57985b6f Saving and sending access logs pass 1 2026-03-23 16:39:01 -07:00
Owen Schwartz
a2683eb385 Merge pull request #274 from LaurenceJJones/refactor/proxy-cleanup-basics
refactor(proxy): cleanup basics - constants, remove dead code, fix de…
2026-03-18 15:39:43 -07:00
Owen Schwartz
d3722c2519 Merge pull request #280 from LaurenceJJones/fix/healthcheck-ipv6
fix(healthcheck): Support ipv6 healthchecks
2026-03-18 15:38:15 -07:00
Laurence
8fda35db4f fix(healthcheck): Support ipv6 healthchecks
Currently we are doing fmt.sprintf on hostname and port which will not properly handle ipv6 addresses, instead of changing pangolin to send bracketed address a simply net.join can do this for us since we dont need to parse a formatted string
2026-03-18 13:37:31 +00:00
Owen Schwartz
de4353f2e6 Merge pull request #269 from LaurenceJJones/feature/pprof-endpoint
feat(admin): Add pprof endpoints
2026-03-17 11:42:08 -07:00
Owen
8161fa6626 Bump ping interval up 2026-03-16 14:33:40 -07:00
Owen
24dfb3a8a2 Remove redundant info 2026-03-16 13:50:45 -07:00
Laurence
13448f76aa refactor(proxy): cleanup basics - constants, remove dead code, fix deprecated calls
- Add maxUDPPacketSize constant to replace magic number 65507
- Remove commented-out code in Stop()
- Replace deprecated ne.Temporary() with errors.Is(err, net.ErrClosed)
- Use errors.As instead of type assertion for net.Error
- Use errors.Is for closed connection checks instead of string matching
- Handle closed connection gracefully when reading from UDP target
2026-03-16 14:11:14 +00:00
Owen
d4ebb3e2af Send disconnecting message 2026-03-15 17:42:03 -07:00
Owen
bf029b7bb2 Clean up to match olm 2026-03-14 11:57:37 -07:00
Owen
745d2dbc7e Merge branch 'dev' into msg-opt 2026-03-13 17:10:49 -07:00
Owen
c7b01288e0 Clean up previous logging 2026-03-13 11:45:36 -07:00
Owen
539e595c48 Add optional compression 2026-03-12 17:49:05 -07:00
Laurence
836144aebf feat(admin): Add pprof endpoints
To aid us in debugging user issues with memory or leaks we need to be able for the user to configure pprof, wait and then provide us the output files to see where memory/leaks occur in actual runtimes
2026-03-12 09:22:50 +00:00
Owen
a1df3d7ff0 Merge branch 'dev' of github.com:fosrl/newt into dev 2026-03-11 17:28:16 -07:00
Laurence
d68a13ea1f feat(installer): prefer /usr/local/bin and improve POSIX compatibility
- Always install to /usr/local/bin instead of ~/.local/bin
  - Use sudo automatically when write access is needed
  - Replace bash-specific syntax with POSIX equivalents:
    - Change shebang from #!/bin/bash to #!/bin/sh
    - Replace [[ == *pattern* ]] with case statements
    - Replace echo -e with printf for colored output
  - Script now works with dash, ash, busybox sh, and bash
2026-03-10 10:01:28 -07:00
Owen
accac75a53 Set newt version in dockerfile 2026-03-08 11:26:35 -07:00
Laurence
768415f90b Parse target strings with IPv6 support and strict validation
Add parseTargetString() for listenPort:host:targetPort using net.SplitHostPort/JoinHostPort. Replace manual split in updateTargets; fix err shadowing on remove. Validate listen port 1–65535 and reject empty host/port; use %w for errors. Add tests for IPv4, IPv6, hostnames, and invalid cases.
2026-03-07 21:32:36 -08:00
Owen
da9825d030 Merge branch 'main' into dev 2026-03-07 12:34:45 -08:00
Owen
afdb1fc977 Make sure to set version and fix prepare issue 2026-03-07 12:32:49 -08:00
Owen
1bd1133ac2 Make sure to skip prepare 2026-03-07 10:36:18 -08:00
Owen
fac0f5b197 Build full arn 2026-03-07 10:17:14 -08:00
Owen
e68b65683f Temp lets ignore the sync messages 2026-03-06 15:14:48 -08:00
Owen
7d6825132b Merge branch 'dev' into msg-opt 2026-03-03 16:56:41 -08:00
Owen
6371e980d2 Update the get all rules 2026-03-03 16:11:32 -08:00
Owen
15ea631b96 Mutex on handlers, slight change to ping message and handler 2026-03-02 20:56:36 -08:00
Owen
4e854b5f96 Working on message versioning 2026-03-02 20:56:18 -08:00
Owen
287eef0f44 Add version and send it down 2026-03-02 18:27:26 -08:00
Owen
f982e6b629 Merge branch 'dev' into msg-opt 2026-03-02 18:13:55 -08:00
Owen
039ae07b7b Support prefixes sent from server 2026-03-02 18:11:20 -08:00
17 changed files with 1574 additions and 210 deletions

View File

@@ -136,7 +136,7 @@ jobs:
build-amd:
name: Build image (linux/amd64)
needs: [pre-run, prepare]
if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }}
if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]' && needs.prepare.result == 'skipped') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }}
runs-on: [self-hosted, linux, x64]
timeout-minutes: 120
env:
@@ -269,6 +269,7 @@ jobs:
context: .
push: true
platforms: linux/amd64
build-args: VERSION=${{ env.TAG }}
tags: |
${{ env.GHCR_IMAGE }}:amd64-${{ env.TAG }}
${{ env.DOCKERHUB_IMAGE }}:amd64-${{ env.TAG }}
@@ -293,7 +294,7 @@ jobs:
build-arm:
name: Build image (linux/arm64)
needs: [pre-run, prepare]
if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }}
if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]' && needs.prepare.result == 'skipped') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }}
runs-on: [self-hosted, linux, arm64] # NOTE: ensure label exists on runner
timeout-minutes: 120
env:
@@ -393,6 +394,7 @@ jobs:
context: .
push: true
platforms: linux/arm64
build-args: VERSION=${{ env.TAG }}
tags: |
${{ env.GHCR_IMAGE }}:arm64-${{ env.TAG }}
${{ env.DOCKERHUB_IMAGE }}:arm64-${{ env.TAG }}
@@ -417,7 +419,7 @@ jobs:
build-armv7:
name: Build image (linux/arm/v7)
needs: [pre-run, prepare]
if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }}
if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]' && needs.prepare.result == 'skipped') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }}
runs-on: [self-hosted, linux, arm64]
timeout-minutes: 120
env:
@@ -509,6 +511,7 @@ jobs:
context: .
push: true
platforms: linux/arm/v7
build-args: VERSION=${{ env.TAG }}
tags: |
${{ env.GHCR_IMAGE }}:armv7-${{ env.TAG }}
${{ env.DOCKERHUB_IMAGE }}:armv7-${{ env.TAG }}
@@ -887,7 +890,7 @@ jobs:
shell: bash
run: |
set -euo pipefail
make -j 10 go-build-release tag="${TAG}"
make -j 10 go-build-release VERSION="${TAG}"
- name: Create GitHub Release (draft)
uses: softprops/action-gh-release@5be0e66d93ac7ed76da52eca8bb058f665c3a5fe # v2.4.2

View File

@@ -17,7 +17,8 @@ RUN go mod download
COPY . .
# Build the application
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /newt
ARG VERSION=dev
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X main.newtVersion=${VERSION}" -o /newt
FROM public.ecr.aws/docker/library/alpine:3.23 AS runner

View File

@@ -2,6 +2,9 @@
all: local
VERSION ?= dev
LDFLAGS = -X main.newtVersion=$(VERSION)
local:
CGO_ENABLED=0 go build -o ./bin/newt
@@ -40,31 +43,31 @@ go-build-release: \
go-build-release-freebsd-arm64
go-build-release-linux-arm64:
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/newt_linux_arm64
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm64
go-build-release-linux-arm32-v7:
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/newt_linux_arm32
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm32
go-build-release-linux-arm32-v6:
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -o bin/newt_linux_arm32v6
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm32v6
go-build-release-linux-amd64:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/newt_linux_amd64
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_amd64
go-build-release-linux-riscv64:
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -o bin/newt_linux_riscv64
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_riscv64
go-build-release-darwin-arm64:
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/newt_darwin_arm64
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_darwin_arm64
go-build-release-darwin-amd64:
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/newt_darwin_amd64
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_darwin_amd64
go-build-release-windows-amd64:
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/newt_windows_amd64.exe
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_windows_amd64.exe
go-build-release-freebsd-amd64:
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -o bin/newt_freebsd_amd64
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_freebsd_amd64
go-build-release-freebsd-arm64:
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -o bin/newt_freebsd_arm64
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_freebsd_arm64

View File

@@ -37,11 +37,13 @@ type WgConfig struct {
}
type Target struct {
SourcePrefix string `json:"sourcePrefix"`
DestPrefix string `json:"destPrefix"`
RewriteTo string `json:"rewriteTo,omitempty"`
DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"`
SourcePrefix string `json:"sourcePrefix"`
SourcePrefixes []string `json:"sourcePrefixes"`
DestPrefix string `json:"destPrefix"`
RewriteTo string `json:"rewriteTo,omitempty"`
DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"`
ResourceId int `json:"resourceId,omitempty"`
}
type PortRange struct {
@@ -112,8 +114,6 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
return nil, fmt.Errorf("failed to generate private key: %v", err)
}
logger.Debug("+++++++++++++++++++++++++++++++= the port is %d", port)
if port == 0 {
// Find an available port
portRandom, err := util.FindAvailableUDPPort(49152, 65535)
@@ -174,6 +174,7 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget)
wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget)
wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget)
wsClient.RegisterHandler("newt/wg/sync", service.handleSyncConfig)
return service, nil
}
@@ -196,6 +197,15 @@ func (s *WireGuardService) Close() {
s.stopGetConfig = nil
}
// Flush access logs before tearing down the tunnel
if s.tnet != nil {
if ph := s.tnet.GetProxyHandler(); ph != nil {
if al := ph.GetAccessLogger(); al != nil {
al.Close()
}
}
}
// Stop the direct UDP relay first
s.StopDirectUDPRelay()
@@ -279,7 +289,7 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string, rel
}
if relayPort == 0 {
relayPort = 21820
relayPort = 21820
}
// Convert websocket.ExitNode to holepunch.ExitNode
@@ -494,6 +504,183 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
logger.Info("Client connectivity setup. Ready to accept connections from clients!")
}
// SyncConfig represents the configuration sent from server for syncing
type SyncConfig struct {
Targets []Target `json:"targets"`
Peers []Peer `json:"peers"`
}
func (s *WireGuardService) handleSyncConfig(msg websocket.WSMessage) {
var syncConfig SyncConfig
logger.Debug("Received sync message: %v", msg)
logger.Info("Received sync configuration from remote server")
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &syncConfig); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
// Sync peers
if err := s.syncPeers(syncConfig.Peers); err != nil {
logger.Error("Failed to sync peers: %v", err)
}
// Sync targets
if err := s.syncTargets(syncConfig.Targets); err != nil {
logger.Error("Failed to sync targets: %v", err)
}
}
// syncPeers synchronizes the current peers with the desired state
// It removes peers not in the desired list and adds missing ones
func (s *WireGuardService) syncPeers(desiredPeers []Peer) error {
if s.device == nil {
return fmt.Errorf("WireGuard device is not initialized")
}
// Get current peers from the device
currentConfig, err := s.device.IpcGet()
if err != nil {
return fmt.Errorf("failed to get current device config: %v", err)
}
// Parse current peer public keys
lines := strings.Split(currentConfig, "\n")
currentPeerKeys := make(map[string]bool)
for _, line := range lines {
if strings.HasPrefix(line, "public_key=") {
pubKey := strings.TrimPrefix(line, "public_key=")
currentPeerKeys[pubKey] = true
}
}
// Build a map of desired peers by their public key (normalized)
desiredPeerMap := make(map[string]Peer)
for _, peer := range desiredPeers {
// Normalize the public key for comparison
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
logger.Warn("Invalid public key in desired peers: %s", peer.PublicKey)
continue
}
normalizedKey := util.FixKey(pubKey.String())
desiredPeerMap[normalizedKey] = peer
}
// Remove peers that are not in the desired list
for currentKey := range currentPeerKeys {
if _, exists := desiredPeerMap[currentKey]; !exists {
// Parse the key back to get the original format for removal
removeConfig := fmt.Sprintf("public_key=%s\nremove=true", currentKey)
if err := s.device.IpcSet(removeConfig); err != nil {
logger.Warn("Failed to remove peer %s during sync: %v", currentKey, err)
} else {
logger.Info("Removed peer %s during sync", currentKey)
}
}
}
// Add peers that are missing
for normalizedKey, peer := range desiredPeerMap {
if _, exists := currentPeerKeys[normalizedKey]; !exists {
if err := s.addPeerToDevice(peer); err != nil {
logger.Warn("Failed to add peer %s during sync: %v", peer.PublicKey, err)
} else {
logger.Info("Added peer %s during sync", peer.PublicKey)
}
}
}
return nil
}
// syncTargets synchronizes the current targets with the desired state
// It removes targets not in the desired list and adds missing ones
func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently
logger.Debug("Skipping target sync - using native interface (no proxy support)")
return nil
}
// Get current rules from the proxy handler
currentRules := s.tnet.GetProxySubnetRules()
// Build a map of current rules by source+dest prefix
type ruleKey struct {
sourcePrefix string
destPrefix string
}
currentRuleMap := make(map[ruleKey]bool)
for _, rule := range currentRules {
key := ruleKey{
sourcePrefix: rule.SourcePrefix.String(),
destPrefix: rule.DestPrefix.String(),
}
currentRuleMap[key] = true
}
// Build a map of desired targets
desiredTargetMap := make(map[ruleKey]Target)
for _, target := range desiredTargets {
key := ruleKey{
sourcePrefix: target.SourcePrefix,
destPrefix: target.DestPrefix,
}
desiredTargetMap[key] = target
}
// Remove targets that are not in the desired list
for _, rule := range currentRules {
key := ruleKey{
sourcePrefix: rule.SourcePrefix.String(),
destPrefix: rule.DestPrefix.String(),
}
if _, exists := desiredTargetMap[key]; !exists {
s.tnet.RemoveProxySubnetRule(rule.SourcePrefix, rule.DestPrefix)
logger.Info("Removed target %s -> %s during sync", rule.SourcePrefix.String(), rule.DestPrefix.String())
}
}
// Add targets that are missing
for key, target := range desiredTargetMap {
if _, exists := currentRuleMap[key]; !exists {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Warn("Invalid source prefix %s during sync: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Warn("Invalid dest prefix %s during sync: %v", target.DestPrefix, err)
continue
}
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
}
}
return nil
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.mu.Lock()
@@ -617,6 +804,13 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.TunnelIP = tunnelIP.String()
// Configure the access log sender to ship compressed session logs via websocket
s.tnet.SetAccessLogSender(func(data string) error {
return s.client.SendMessageNoLog("newt/access-log", map[string]interface{}{
"compressed": data,
})
})
// Create WireGuard device using the shared bind
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
device.LogLevelSilent, // Use silent logging by default - could be made configurable
@@ -697,6 +891,19 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
return nil
}
// resolveSourcePrefixes returns the effective list of source prefixes for a target,
// supporting both the legacy single SourcePrefix field and the new SourcePrefixes array.
// If SourcePrefixes is non-empty it takes precedence; otherwise SourcePrefix is used.
func resolveSourcePrefixes(target Target) []string {
if len(target.SourcePrefixes) > 0 {
return target.SourcePrefixes
}
if target.SourcePrefix != "" {
return []string{target.SourcePrefix}
}
return nil
}
func (s *WireGuardService) ensureTargets(targets []Target) error {
if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently
@@ -705,11 +912,6 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
}
for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err)
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
@@ -724,9 +926,14 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
return nil
@@ -1045,7 +1252,7 @@ func (s *WireGuardService) processPeerBandwidth(publicKey string, rxBytes, txByt
BytesOut: bytesOutMB,
}
}
return nil
}
}
@@ -1096,12 +1303,6 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
// Process all targets
for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1111,15 +1312,21 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
}
@@ -1148,21 +1355,21 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) {
// Process all targets
for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
}
}
}
@@ -1196,30 +1403,24 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
// Process all update requests
for _, target := range requests.OldTargets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
}
}
for _, target := range requests.NewTargets {
// Now add the new target
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1229,14 +1430,21 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
}

View File

@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"os/exec"
"strings"
@@ -363,27 +364,62 @@ func parseTargetData(data interface{}) (TargetData, error) {
return targetData, nil
}
// parseTargetString parses a target string in the format "listenPort:host:targetPort"
// It properly handles IPv6 addresses which must be in brackets: "listenPort:[ipv6]:targetPort"
// Examples:
// - IPv4: "3001:192.168.1.1:80"
// - IPv6: "3001:[::1]:8080" or "3001:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:80"
//
// Returns listenPort, targetAddress (in host:port format suitable for net.Dial), and error
func parseTargetString(target string) (int, string, error) {
// Find the first colon to extract the listen port
firstColon := strings.Index(target, ":")
if firstColon == -1 {
return 0, "", fmt.Errorf("invalid target format, no colon found: %s", target)
}
listenPortStr := target[:firstColon]
var listenPort int
_, err := fmt.Sscanf(listenPortStr, "%d", &listenPort)
if err != nil {
return 0, "", fmt.Errorf("invalid listen port: %s", listenPortStr)
}
if listenPort <= 0 || listenPort > 65535 {
return 0, "", fmt.Errorf("listen port out of range: %d", listenPort)
}
// The remainder is host:targetPort - use net.SplitHostPort which handles IPv6 brackets
remainder := target[firstColon+1:]
host, targetPort, err := net.SplitHostPort(remainder)
if err != nil {
return 0, "", fmt.Errorf("invalid host:port format '%s': %w", remainder, err)
}
// Reject empty host or target port
if host == "" {
return 0, "", fmt.Errorf("empty host in target: %s", target)
}
if targetPort == "" {
return 0, "", fmt.Errorf("empty target port in target: %s", target)
}
// Reconstruct the target address using JoinHostPort (handles IPv6 properly)
targetAddr := net.JoinHostPort(host, targetPort)
return listenPort, targetAddr, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
// Parse the target string, handling both IPv4 and IPv6 addresses
port, target, err := parseTargetString(t)
if err != nil {
logger.Info("Invalid port: %s", parts[0])
logger.Info("Invalid target format: %s (%v)", t, err)
continue
}
switch action {
case "add":
target := parts[1] + ":" + parts[2]
// Call updown script if provided
processedTarget := target
if updownScript != "" {
@@ -410,8 +446,6 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
case "remove":
logger.Info("Removing target with port %d", port)
target := parts[1] + ":" + parts[2]
// Call updown script if provided
if updownScript != "" {
_, err := executeUpdownScript(action, proto, target)
@@ -420,7 +454,7 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
}
}
err := pm.RemoveTarget(proto, tunnelIP, port)
err = pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
logger.Error("Failed to remove target: %v", err)
return err

212
common_test.go Normal file
View File

@@ -0,0 +1,212 @@
package main
import (
"net"
"testing"
)
func TestParseTargetString(t *testing.T) {
tests := []struct {
name string
input string
wantListenPort int
wantTargetAddr string
wantErr bool
}{
// IPv4 test cases
{
name: "valid IPv4 basic",
input: "3001:192.168.1.1:80",
wantListenPort: 3001,
wantTargetAddr: "192.168.1.1:80",
wantErr: false,
},
{
name: "valid IPv4 localhost",
input: "8080:127.0.0.1:3000",
wantListenPort: 8080,
wantTargetAddr: "127.0.0.1:3000",
wantErr: false,
},
{
name: "valid IPv4 same ports",
input: "443:10.0.0.1:443",
wantListenPort: 443,
wantTargetAddr: "10.0.0.1:443",
wantErr: false,
},
// IPv6 test cases
{
name: "valid IPv6 loopback",
input: "3001:[::1]:8080",
wantListenPort: 3001,
wantTargetAddr: "[::1]:8080",
wantErr: false,
},
{
name: "valid IPv6 full address",
input: "80:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
wantListenPort: 80,
wantTargetAddr: "[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
wantErr: false,
},
{
name: "valid IPv6 link-local",
input: "443:[fe80::1]:443",
wantListenPort: 443,
wantTargetAddr: "[fe80::1]:443",
wantErr: false,
},
{
name: "valid IPv6 all zeros compressed",
input: "8000:[::]:9000",
wantListenPort: 8000,
wantTargetAddr: "[::]:9000",
wantErr: false,
},
{
name: "valid IPv6 mixed notation",
input: "5000:[::ffff:192.168.1.1]:6000",
wantListenPort: 5000,
wantTargetAddr: "[::ffff:192.168.1.1]:6000",
wantErr: false,
},
// Hostname test cases
{
name: "valid hostname",
input: "8080:example.com:80",
wantListenPort: 8080,
wantTargetAddr: "example.com:80",
wantErr: false,
},
{
name: "valid hostname with subdomain",
input: "443:api.example.com:8443",
wantListenPort: 443,
wantTargetAddr: "api.example.com:8443",
wantErr: false,
},
{
name: "valid localhost hostname",
input: "3000:localhost:3000",
wantListenPort: 3000,
wantTargetAddr: "localhost:3000",
wantErr: false,
},
// Error cases
{
name: "invalid - no colons",
input: "invalid",
wantErr: true,
},
{
name: "invalid - empty string",
input: "",
wantErr: true,
},
{
name: "invalid - non-numeric listen port",
input: "abc:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - missing target port",
input: "3001:192.168.1.1",
wantErr: true,
},
{
name: "invalid - IPv6 without brackets",
input: "3001:fd70:1452:b736:4dd5:caca:7db9:c588:f5b3:80",
wantErr: true,
},
{
name: "invalid - only listen port",
input: "3001:",
wantErr: true,
},
{
name: "invalid - missing host",
input: "3001::80",
wantErr: true,
},
{
name: "invalid - IPv6 unclosed bracket",
input: "3001:[::1:80",
wantErr: true,
},
{
name: "invalid - listen port zero",
input: "0:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - listen port negative",
input: "-1:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - listen port out of range",
input: "70000:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - empty target port",
input: "3001:192.168.1.1:",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
listenPort, targetAddr, err := parseTargetString(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("parseTargetString(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
return
}
if tt.wantErr {
return // Don't check other values if we expected an error
}
if listenPort != tt.wantListenPort {
t.Errorf("parseTargetString(%q) listenPort = %d, want %d", tt.input, listenPort, tt.wantListenPort)
}
if targetAddr != tt.wantTargetAddr {
t.Errorf("parseTargetString(%q) targetAddr = %q, want %q", tt.input, targetAddr, tt.wantTargetAddr)
}
})
}
}
// TestParseTargetStringNetDialCompatibility verifies that the output is compatible with net.Dial
func TestParseTargetStringNetDialCompatibility(t *testing.T) {
tests := []struct {
name string
input string
}{
{"IPv4", "8080:127.0.0.1:80"},
{"IPv6 loopback", "8080:[::1]:80"},
{"IPv6 full", "8080:[2001:db8::1]:80"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, targetAddr, err := parseTargetString(tt.input)
if err != nil {
t.Fatalf("parseTargetString(%q) unexpected error: %v", tt.input, err)
}
// Verify the format is valid for net.Dial by checking it can be split back
// This doesn't actually dial, just validates the format
_, _, err = net.SplitHostPort(targetAddr)
if err != nil {
t.Errorf("parseTargetString(%q) produced invalid net.Dial format %q: %v", tt.input, targetAddr, err)
}
})
}
}

View File

@@ -1,7 +1,7 @@
#!/bin/bash
#!/bin/sh
# Get Newt - Cross-platform installation script
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | bash
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | sh
set -e
@@ -17,15 +17,15 @@ GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
# Function to print colored output
print_status() {
echo -e "${GREEN}[INFO]${NC} $1"
printf '%b[INFO]%b %s\n' "${GREEN}" "${NC}" "$1"
}
print_warning() {
echo -e "${YELLOW}[WARN]${NC} $1"
printf '%b[WARN]%b %s\n' "${YELLOW}" "${NC}" "$1"
}
print_error() {
echo -e "${RED}[ERROR]${NC} $1"
printf '%b[ERROR]%b %s\n' "${RED}" "${NC}" "$1"
}
# Function to get latest version from GitHub API
@@ -113,16 +113,34 @@ get_install_dir() {
if [ "$OS" = "windows" ]; then
echo "$HOME/bin"
else
# Try to use a directory in PATH, fallback to ~/.local/bin
if echo "$PATH" | grep -q "/usr/local/bin"; then
if [ -w "/usr/local/bin" ] 2>/dev/null; then
echo "/usr/local/bin"
else
echo "$HOME/.local/bin"
fi
# Prefer /usr/local/bin for system-wide installation
echo "/usr/local/bin"
fi
}
# Check if we need sudo for installation
needs_sudo() {
local install_dir="$1"
if [ -w "$install_dir" ] 2>/dev/null; then
return 1 # No sudo needed
else
return 0 # Sudo needed
fi
}
# Get the appropriate command prefix (sudo or empty)
get_sudo_cmd() {
local install_dir="$1"
if needs_sudo "$install_dir"; then
if command -v sudo >/dev/null 2>&1; then
echo "sudo"
else
echo "$HOME/.local/bin"
print_error "Cannot write to ${install_dir} and sudo is not available."
print_error "Please run this script as root or install sudo."
exit 1
fi
else
echo ""
fi
}
@@ -130,21 +148,24 @@ get_install_dir() {
install_newt() {
local platform="$1"
local install_dir="$2"
local sudo_cmd="$3"
local binary_name="newt_${platform}"
local exe_suffix=""
# Add .exe suffix for Windows
if [[ "$platform" == *"windows"* ]]; then
binary_name="${binary_name}.exe"
exe_suffix=".exe"
fi
case "$platform" in
*windows*)
binary_name="${binary_name}.exe"
exe_suffix=".exe"
;;
esac
local download_url="${BASE_URL}/${binary_name}"
local temp_file="/tmp/newt${exe_suffix}"
local final_path="${install_dir}/newt${exe_suffix}"
print_status "Downloading newt from ${download_url}"
# Download the binary
if command -v curl >/dev/null 2>&1; then
curl -fsSL "$download_url" -o "$temp_file"
@@ -154,18 +175,22 @@ install_newt() {
print_error "Neither curl nor wget is available. Please install one of them."
exit 1
fi
# Make executable before moving
chmod +x "$temp_file"
# Create install directory if it doesn't exist
mkdir -p "$install_dir"
# Move binary to install directory
mv "$temp_file" "$final_path"
# Make executable (not needed on Windows, but doesn't hurt)
chmod +x "$final_path"
if [ -n "$sudo_cmd" ]; then
$sudo_cmd mkdir -p "$install_dir"
print_status "Using sudo to install to ${install_dir}"
$sudo_cmd mv "$temp_file" "$final_path"
else
mkdir -p "$install_dir"
mv "$temp_file" "$final_path"
fi
print_status "newt installed to ${final_path}"
# Check if install directory is in PATH
if ! echo "$PATH" | grep -q "$install_dir"; then
print_warning "Install directory ${install_dir} is not in your PATH."
@@ -179,9 +204,9 @@ verify_installation() {
local install_dir="$1"
local exe_suffix=""
if [[ "$PLATFORM" == *"windows"* ]]; then
exe_suffix=".exe"
fi
case "$PLATFORM" in
*windows*) exe_suffix=".exe" ;;
esac
local newt_path="${install_dir}/newt${exe_suffix}"
@@ -198,34 +223,36 @@ verify_installation() {
# Main installation process
main() {
print_status "Installing latest version of newt..."
# Get latest version
print_status "Fetching latest version from GitHub..."
VERSION=$(get_latest_version)
print_status "Latest version: v${VERSION}"
# Set base URL with the fetched version
BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}"
# Detect platform
PLATFORM=$(detect_platform)
print_status "Detected platform: ${PLATFORM}"
# Get install directory
INSTALL_DIR=$(get_install_dir)
print_status "Install directory: ${INSTALL_DIR}"
# Check if we need sudo
SUDO_CMD=$(get_sudo_cmd "$INSTALL_DIR")
if [ -n "$SUDO_CMD" ]; then
print_status "Root privileges required for installation to ${INSTALL_DIR}"
fi
# Install newt
install_newt "$PLATFORM" "$INSTALL_DIR"
install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD"
# Verify installation
if verify_installation "$INSTALL_DIR"; then
print_status "newt is ready to use!"
if [[ "$PLATFORM" == *"windows"* ]]; then
print_status "Run 'newt --help' to get started"
else
print_status "Run 'newt --help' to get started"
fi
print_status "Run 'newt --help' to get started"
else
exit 1
fi

View File

@@ -5,7 +5,9 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
@@ -365,11 +367,12 @@ func (m *Monitor) performHealthCheck(target *Target) {
target.LastCheck = time.Now()
target.LastError = ""
// Build URL
url := fmt.Sprintf("%s://%s", target.Config.Scheme, target.Config.Hostname)
// Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports)
host := target.Config.Hostname
if target.Config.Port > 0 {
url = fmt.Sprintf("%s:%d", url, target.Config.Port)
host = net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port))
}
url := fmt.Sprintf("%s://%s", target.Config.Scheme, host)
if target.Config.Path != "" {
if !strings.HasPrefix(target.Config.Path, "/") {
url += "/"
@@ -521,3 +524,82 @@ func (m *Monitor) DisableTarget(id int) error {
return nil
}
// GetTargetIDs returns a slice of all current target IDs
func (m *Monitor) GetTargetIDs() []int {
m.mutex.RLock()
defer m.mutex.RUnlock()
ids := make([]int, 0, len(m.targets))
for id := range m.targets {
ids = append(ids, id)
}
return ids
}
// SyncTargets synchronizes the current targets to match the desired set.
// It removes targets not in the desired set and adds targets that are missing.
func (m *Monitor) SyncTargets(desiredConfigs []Config) error {
m.mutex.Lock()
defer m.mutex.Unlock()
logger.Info("Syncing health check targets: %d desired targets", len(desiredConfigs))
// Build a set of desired target IDs
desiredIDs := make(map[int]Config)
for _, config := range desiredConfigs {
desiredIDs[config.ID] = config
}
// Find targets to remove (exist but not in desired set)
var toRemove []int
for id := range m.targets {
if _, exists := desiredIDs[id]; !exists {
toRemove = append(toRemove, id)
}
}
// Remove targets that are not in the desired set
for _, id := range toRemove {
logger.Info("Sync: removing health check target %d", id)
if target, exists := m.targets[id]; exists {
target.cancel()
delete(m.targets, id)
}
}
// Add or update targets from the desired set
var addedCount, updatedCount int
for id, config := range desiredIDs {
if existing, exists := m.targets[id]; exists {
// Target exists - check if config changed and update if needed
// For now, we'll replace it to ensure config is up to date
logger.Debug("Sync: updating health check target %d", id)
existing.cancel()
delete(m.targets, id)
if err := m.addTargetUnsafe(config); err != nil {
logger.Error("Sync: failed to update target %d: %v", id, err)
return fmt.Errorf("failed to update target %d: %v", id, err)
}
updatedCount++
} else {
// Target doesn't exist - add it
logger.Debug("Sync: adding health check target %d", id)
if err := m.addTargetUnsafe(config); err != nil {
logger.Error("Sync: failed to add target %d: %v", id, err)
return fmt.Errorf("failed to add target %d: %v", id, err)
}
addedCount++
}
}
logger.Info("Sync complete: removed %d, added %d, updated %d targets",
len(toRemove), addedCount, updatedCount)
// Notify callback if any changes were made
if (len(toRemove) > 0 || addedCount > 0 || updatedCount > 0) && m.callback != nil {
go m.callback(m.getAllTargetsUnsafe())
}
return nil
}

193
main.go
View File

@@ -10,6 +10,7 @@ import (
"fmt"
"net"
"net/http"
"net/http/pprof"
"net/netip"
"os"
"os/signal"
@@ -147,6 +148,7 @@ var (
adminAddr string
region string
metricsAsyncBytes bool
pprofEnabled bool
blueprintFile string
noCloud bool
@@ -225,6 +227,7 @@ func runNewtMain(ctx context.Context) {
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
regionEnv := os.Getenv("NEWT_REGION")
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
pprofEnabledEnv := os.Getenv("NEWT_PPROF_ENABLED")
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
disableClients = disableClientsEnv == "true"
@@ -302,10 +305,10 @@ func runNewtMain(ctx context.Context) {
flag.StringVar(&dockerSocket, "docker-socket", "", "Path or address to Docker socket (typically unix:///var/run/docker.sock)")
}
if pingIntervalStr == "" {
flag.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)")
flag.StringVar(&pingIntervalStr, "ping-interval", "15s", "Interval for pinging the server (default 15s)")
}
if pingTimeoutStr == "" {
flag.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 5s)")
flag.StringVar(&pingTimeoutStr, "ping-timeout", "7s", " Timeout for each ping (default 7s)")
}
// load the prefer endpoint just as a flag
flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")
@@ -330,21 +333,21 @@ func runNewtMain(ctx context.Context) {
if pingIntervalStr != "" {
pingInterval, err = time.ParseDuration(pingIntervalStr)
if err != nil {
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr)
pingInterval = 3 * time.Second
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 15 seconds\n", pingIntervalStr)
pingInterval = 15 * time.Second
}
} else {
pingInterval = 3 * time.Second
pingInterval = 15 * time.Second
}
if pingTimeoutStr != "" {
pingTimeout, err = time.ParseDuration(pingTimeoutStr)
if err != nil {
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr)
pingTimeout = 5 * time.Second
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 7 seconds\n", pingTimeoutStr)
pingTimeout = 7 * time.Second
}
} else {
pingTimeout = 5 * time.Second
pingTimeout = 7 * time.Second
}
if dockerEnforceNetworkValidation == "" {
@@ -390,6 +393,14 @@ func runNewtMain(ctx context.Context) {
metricsAsyncBytes = v
}
}
// pprof debug endpoint toggle
if pprofEnabledEnv == "" {
flag.BoolVar(&pprofEnabled, "pprof", false, "Enable pprof debug endpoints on admin server")
} else {
if v, err := strconv.ParseBool(pprofEnabledEnv); err == nil {
pprofEnabled = v
}
}
// Optional region flag (resource attribute)
if regionEnv == "" {
flag.StringVar(&region, "region", "", "Optional region resource attribute (also NEWT_REGION)")
@@ -485,6 +496,14 @@ func runNewtMain(ctx context.Context) {
if tel.PrometheusHandler != nil {
mux.Handle("/metrics", tel.PrometheusHandler)
}
if pprofEnabled {
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
logger.Info("pprof debugging enabled on %s/debug/pprof/", tcfg.AdminAddr)
}
admin := &http.Server{
Addr: tcfg.AdminAddr,
Handler: otelhttp.NewHandler(mux, "newt-admin"),
@@ -565,8 +584,7 @@ func runNewtMain(ctx context.Context) {
id, // CLI arg takes precedence
secret, // CLI arg takes precedence
endpoint,
pingInterval,
pingTimeout,
30*time.Second,
opt,
)
if err != nil {
@@ -618,8 +636,6 @@ func runNewtMain(ctx context.Context) {
var connected bool
var wgData WgData
var dockerEventMonitor *docker.EventMonitor
logger.Debug("++++++++++++++++++++++ the port is %d", port)
if !disableClients {
setupClients(client)
@@ -959,7 +975,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"publicKey": publicKey.String(),
"pingResults": pingResults,
"newtVersion": newtVersion,
}, 1*time.Second)
}, 2*time.Second)
return
}
@@ -1062,7 +1078,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"publicKey": publicKey.String(),
"pingResults": pingResults,
"newtVersion": newtVersion,
}, 1*time.Second)
}, 2*time.Second)
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
})
@@ -1167,6 +1183,153 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
}
})
// Register handler for syncing targets (TCP, UDP, and health checks)
client.RegisterHandler("newt/sync", func(msg websocket.WSMessage) {
logger.Info("Received sync message")
// if there is no wgData or pm, we can't sync targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info(msgNoTunnelOrProxy)
return
}
// Define the sync data structure
type SyncData struct {
Targets TargetsByType `json:"targets"`
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
}
var syncData SyncData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &syncData); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
logger.Debug("Sync data received: TCP targets=%d, UDP targets=%d, health check targets=%d",
len(syncData.Targets.TCP), len(syncData.Targets.UDP), len(syncData.HealthCheckTargets))
//TODO: TEST AND IMPLEMENT THIS
// // Build sets of desired targets (port -> target string)
// desiredTCP := make(map[int]string)
// for _, t := range syncData.Targets.TCP {
// parts := strings.Split(t, ":")
// if len(parts) != 3 {
// logger.Warn("Invalid TCP target format: %s", t)
// continue
// }
// port := 0
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
// logger.Warn("Invalid port in TCP target: %s", parts[0])
// continue
// }
// desiredTCP[port] = parts[1] + ":" + parts[2]
// }
// desiredUDP := make(map[int]string)
// for _, t := range syncData.Targets.UDP {
// parts := strings.Split(t, ":")
// if len(parts) != 3 {
// logger.Warn("Invalid UDP target format: %s", t)
// continue
// }
// port := 0
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
// logger.Warn("Invalid port in UDP target: %s", parts[0])
// continue
// }
// desiredUDP[port] = parts[1] + ":" + parts[2]
// }
// // Get current targets from proxy manager
// currentTCP, currentUDP := pm.GetTargets()
// // Sync TCP targets
// // Remove TCP targets not in desired set
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
// for port := range tcpForIP {
// if _, exists := desiredTCP[port]; !exists {
// logger.Info("Sync: removing TCP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, tcpForIP[port])
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// // Add TCP targets that are missing
// for port, target := range desiredTCP {
// needsAdd := true
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
// if currentTarget, exists := tcpForIP[port]; exists {
// // Check if target address changed
// if currentTarget == target {
// needsAdd = false
// } else {
// // Target changed, remove old one first
// logger.Info("Sync: updating TCP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// if needsAdd {
// logger.Info("Sync: adding TCP target on port %d -> %s", port, target)
// targetStr := fmt.Sprintf("%d:%s", port, target)
// updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// // Sync UDP targets
// // Remove UDP targets not in desired set
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
// for port := range udpForIP {
// if _, exists := desiredUDP[port]; !exists {
// logger.Info("Sync: removing UDP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, udpForIP[port])
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// // Add UDP targets that are missing
// for port, target := range desiredUDP {
// needsAdd := true
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
// if currentTarget, exists := udpForIP[port]; exists {
// // Check if target address changed
// if currentTarget == target {
// needsAdd = false
// } else {
// // Target changed, remove old one first
// logger.Info("Sync: updating UDP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// if needsAdd {
// logger.Info("Sync: adding UDP target on port %d -> %s", port, target)
// targetStr := fmt.Sprintf("%d:%s", port, target)
// updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// // Sync health check targets
// if err := healthMonitor.SyncTargets(syncData.HealthCheckTargets); err != nil {
// logger.Error("Failed to sync health check targets: %v", err)
// } else {
// logger.Info("Successfully synced health check targets")
// }
logger.Info("Sync complete")
})
// Register handler for Docker socket check
client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
logger.Debug("Received Docker socket check request")
@@ -1649,6 +1812,8 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
pm.Stop()
}
client.SendMessage("newt/disconnecting", map[string]any{})
if client != nil {
client.Close()
}

355
netstack2/access_log.go Normal file
View File

@@ -0,0 +1,355 @@
package netstack2
import (
"bytes"
"compress/zlib"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
const (
// flushInterval is how often the access logger flushes completed sessions to the server
flushInterval = 60 * time.Second
// maxBufferedSessions is the max number of completed sessions to buffer before forcing a flush
maxBufferedSessions = 100
)
// SendFunc is a callback that sends compressed access log data to the server.
// The data is a base64-encoded zlib-compressed JSON array of AccessSession objects.
type SendFunc func(data string) error
// AccessSession represents a tracked access session through the proxy
type AccessSession struct {
SessionID string `json:"sessionId"`
ResourceID int `json:"resourceId"`
SourceAddr string `json:"sourceAddr"`
DestAddr string `json:"destAddr"`
Protocol string `json:"protocol"`
StartedAt time.Time `json:"startedAt"`
EndedAt time.Time `json:"endedAt,omitempty"`
BytesTx int64 `json:"bytesTx"`
BytesRx int64 `json:"bytesRx"`
}
// udpSessionKey identifies a unique UDP "session" by src -> dst
type udpSessionKey struct {
srcAddr string
dstAddr string
protocol string
}
// AccessLogger tracks access sessions for resources and periodically
// flushes completed sessions to the server via a configurable SendFunc.
type AccessLogger struct {
mu sync.Mutex
sessions map[string]*AccessSession // active sessions: sessionID -> session
udpSessions map[udpSessionKey]*AccessSession // active UDP sessions for dedup
completedSessions []*AccessSession // completed sessions waiting to be flushed
udpTimeout time.Duration
sendFn SendFunc
stopCh chan struct{}
flushDone chan struct{} // closed after the flush goroutine exits
}
// NewAccessLogger creates a new access logger.
// udpTimeout controls how long a UDP session is kept alive without traffic before being ended.
func NewAccessLogger(udpTimeout time.Duration) *AccessLogger {
al := &AccessLogger{
sessions: make(map[string]*AccessSession),
udpSessions: make(map[udpSessionKey]*AccessSession),
completedSessions: make([]*AccessSession, 0),
udpTimeout: udpTimeout,
stopCh: make(chan struct{}),
flushDone: make(chan struct{}),
}
go al.backgroundLoop()
return al
}
// SetSendFunc sets the callback used to send compressed access log batches
// to the server. This can be called after construction once the websocket
// client is available.
func (al *AccessLogger) SetSendFunc(fn SendFunc) {
al.mu.Lock()
defer al.mu.Unlock()
al.sendFn = fn
}
// generateSessionID creates a random session identifier
func generateSessionID() string {
b := make([]byte, 8)
rand.Read(b)
return hex.EncodeToString(b)
}
// StartTCPSession logs the start of a TCP session and returns a session ID.
func (al *AccessLogger) StartTCPSession(resourceID int, srcAddr, dstAddr string) string {
sessionID := generateSessionID()
now := time.Now()
session := &AccessSession{
SessionID: sessionID,
ResourceID: resourceID,
SourceAddr: srcAddr,
DestAddr: dstAddr,
Protocol: "tcp",
StartedAt: now,
}
al.mu.Lock()
al.sessions[sessionID] = session
al.mu.Unlock()
logger.Info("ACCESS START session=%s resource=%d proto=tcp src=%s dst=%s time=%s",
sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
return sessionID
}
// EndTCPSession logs the end of a TCP session and queues it for sending.
func (al *AccessLogger) EndTCPSession(sessionID string) {
now := time.Now()
al.mu.Lock()
session, ok := al.sessions[sessionID]
if ok {
session.EndedAt = now
delete(al.sessions, sessionID)
al.completedSessions = append(al.completedSessions, session)
}
shouldFlush := len(al.completedSessions) >= maxBufferedSessions
al.mu.Unlock()
if ok {
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END session=%s resource=%d proto=tcp src=%s dst=%s started=%s ended=%s duration=%s",
sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
}
if shouldFlush {
al.flush()
}
}
// TrackUDPSession starts or returns an existing UDP session. Returns the session ID.
func (al *AccessLogger) TrackUDPSession(resourceID int, srcAddr, dstAddr string) string {
key := udpSessionKey{
srcAddr: srcAddr,
dstAddr: dstAddr,
protocol: "udp",
}
al.mu.Lock()
defer al.mu.Unlock()
if existing, ok := al.udpSessions[key]; ok {
return existing.SessionID
}
sessionID := generateSessionID()
now := time.Now()
session := &AccessSession{
SessionID: sessionID,
ResourceID: resourceID,
SourceAddr: srcAddr,
DestAddr: dstAddr,
Protocol: "udp",
StartedAt: now,
}
al.sessions[sessionID] = session
al.udpSessions[key] = session
logger.Info("ACCESS START session=%s resource=%d proto=udp src=%s dst=%s time=%s",
sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
return sessionID
}
// EndUDPSession ends a UDP session and queues it for sending.
func (al *AccessLogger) EndUDPSession(sessionID string) {
now := time.Now()
al.mu.Lock()
session, ok := al.sessions[sessionID]
if ok {
session.EndedAt = now
delete(al.sessions, sessionID)
key := udpSessionKey{
srcAddr: session.SourceAddr,
dstAddr: session.DestAddr,
protocol: "udp",
}
delete(al.udpSessions, key)
al.completedSessions = append(al.completedSessions, session)
}
shouldFlush := len(al.completedSessions) >= maxBufferedSessions
al.mu.Unlock()
if ok {
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
}
if shouldFlush {
al.flush()
}
}
// backgroundLoop handles periodic flushing and stale session reaping.
func (al *AccessLogger) backgroundLoop() {
defer close(al.flushDone)
flushTicker := time.NewTicker(flushInterval)
defer flushTicker.Stop()
reapTicker := time.NewTicker(30 * time.Second)
defer reapTicker.Stop()
for {
select {
case <-al.stopCh:
return
case <-flushTicker.C:
al.flush()
case <-reapTicker.C:
al.reapStaleSessions()
}
}
}
// reapStaleSessions cleans up UDP sessions that were not properly ended.
func (al *AccessLogger) reapStaleSessions() {
al.mu.Lock()
defer al.mu.Unlock()
staleThreshold := time.Now().Add(-5 * time.Minute)
for key, session := range al.udpSessions {
if session.StartedAt.Before(staleThreshold) && session.EndedAt.IsZero() {
now := time.Now()
session.EndedAt = now
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END (reaped) session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
session.SessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
al.completedSessions = append(al.completedSessions, session)
delete(al.sessions, session.SessionID)
delete(al.udpSessions, key)
}
}
}
// flush drains the completed sessions buffer, compresses with zlib, and sends via the SendFunc.
func (al *AccessLogger) flush() {
al.mu.Lock()
if len(al.completedSessions) == 0 {
al.mu.Unlock()
return
}
batch := al.completedSessions
al.completedSessions = make([]*AccessSession, 0)
sendFn := al.sendFn
al.mu.Unlock()
if sendFn == nil {
logger.Debug("Access logger: no send function configured, discarding %d sessions", len(batch))
return
}
compressed, err := compressSessions(batch)
if err != nil {
logger.Error("Access logger: failed to compress %d sessions: %v", len(batch), err)
return
}
if err := sendFn(compressed); err != nil {
logger.Error("Access logger: failed to send %d sessions: %v", len(batch), err)
// Re-queue the batch so we don't lose data
al.mu.Lock()
al.completedSessions = append(batch, al.completedSessions...)
// Cap re-queued data to prevent unbounded growth if server is unreachable
if len(al.completedSessions) > maxBufferedSessions*5 {
dropped := len(al.completedSessions) - maxBufferedSessions*5
al.completedSessions = al.completedSessions[:maxBufferedSessions*5]
logger.Warn("Access logger: buffer overflow, dropped %d oldest sessions", dropped)
}
al.mu.Unlock()
return
}
logger.Info("Access logger: sent %d sessions to server", len(batch))
}
// compressSessions JSON-encodes the sessions, compresses with zlib, and returns
// a base64-encoded string suitable for embedding in a JSON message.
func compressSessions(sessions []*AccessSession) (string, error) {
jsonData, err := json.Marshal(sessions)
if err != nil {
return "", err
}
var buf bytes.Buffer
w, err := zlib.NewWriterLevel(&buf, zlib.BestCompression)
if err != nil {
return "", err
}
if _, err := w.Write(jsonData); err != nil {
w.Close()
return "", err
}
if err := w.Close(); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}
// Close shuts down the background loop, ends all active sessions,
// and performs one final flush to send everything to the server.
func (al *AccessLogger) Close() {
// Signal the background loop to stop
select {
case <-al.stopCh:
// Already closed
return
default:
close(al.stopCh)
}
// Wait for the background loop to exit so we don't race on flush
<-al.flushDone
al.mu.Lock()
now := time.Now()
// End all active sessions and move them to the completed buffer
for _, session := range al.sessions {
if session.EndedAt.IsZero() {
session.EndedAt = now
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END (shutdown) session=%s resource=%d proto=%s src=%s dst=%s started=%s ended=%s duration=%s",
session.SessionID, session.ResourceID, session.Protocol, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
al.completedSessions = append(al.completedSessions, session)
}
}
al.sessions = make(map[string]*AccessSession)
al.udpSessions = make(map[udpSessionKey]*AccessSession)
al.mu.Unlock()
// Final flush to send all remaining sessions to the server
al.flush()
}

View File

@@ -158,6 +158,18 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Look up resource ID and start access session if applicable
var accessSessionID string
if h.proxyHandler != nil {
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber))
if resourceId != 0 {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
accessSessionID = al.StartTCPSession(resourceId, srcAddr, targetAddr)
}
}
}
// Create context with timeout for connection establishment
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
defer cancel()
@@ -167,11 +179,26 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
targetConn, err := d.DialContext(ctx, "tcp", targetAddr)
if err != nil {
logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err)
// End access session on connection failure
if accessSessionID != "" {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndTCPSession(accessSessionID)
}
}
// Connection failed, netstack will handle RST
return
}
defer targetConn.Close()
// End access session when connection closes
if accessSessionID != "" {
defer func() {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndTCPSession(accessSessionID)
}
}()
}
logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr)
// Bidirectional copy between netstack and target
@@ -280,6 +307,27 @@ func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.Transpo
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Look up resource ID and start access session if applicable
var accessSessionID string
if h.proxyHandler != nil {
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber))
if resourceId != 0 {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
accessSessionID = al.TrackUDPSession(resourceId, srcAddr, targetAddr)
}
}
}
// End access session when UDP handler returns (timeout or error)
if accessSessionID != "" {
defer func() {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndUDPSession(accessSessionID)
}
}()
}
// Resolve target address
remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil {

View File

@@ -22,6 +22,12 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
// udpAccessSessionTimeout is how long a UDP access session stays alive without traffic
// before being considered ended by the access logger
udpAccessSessionTimeout = 120 * time.Second
)
// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering
// Protocol can be "tcp", "udp", or "" (empty string means both protocols)
type PortRange struct {
@@ -46,6 +52,24 @@ type SubnetRule struct {
DisableIcmp bool // If true, ICMP traffic is blocked for this subnet
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
PortRanges []PortRange // empty slice means all ports allowed
ResourceId int // Optional resource ID from the server for access logging
}
// GetAllRules returns a copy of all subnet rules
func (sl *SubnetLookup) GetAllRules() []SubnetRule {
sl.mu.RLock()
defer sl.mu.RUnlock()
var rules []SubnetRule
for _, destTriePtr := range sl.sourceTrie.All() {
if destTriePtr == nil {
continue
}
for _, rule := range destTriePtr.rules {
rules = append(rules, *rule)
}
}
return rules
}
// connKey uniquely identifies a connection for NAT tracking
@@ -94,10 +118,12 @@ type ProxyHandler struct {
natTable map[connKey]*natState
reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
resourceTable map[destKey]int // Maps connection key to resource ID for access logging
natMu sync.RWMutex
enabled bool
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
notifiable channel.Notification // Notification handler for triggering reads
accessLogger *AccessLogger // Access logger for tracking sessions
}
// ProxyHandlerOptions configures the proxy handler
@@ -120,7 +146,9 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
natTable: make(map[connKey]*natState),
reverseNatTable: make(map[reverseConnKey]*natState),
destRewriteTable: make(map[destKey]netip.Addr),
resourceTable: make(map[destKey]int),
icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets
accessLogger: NewAccessLogger(udpAccessSessionTimeout),
proxyEp: channel.New(1024, uint32(options.MTU), ""),
proxyStack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
@@ -185,11 +213,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
// destPrefix: The IP prefix of the destination
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
// If portRanges is nil or empty, all ports are allowed for this subnet
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
if p == nil || !p.enabled {
return
}
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
}
// RemoveSubnetRule removes a subnet from the proxy handler
@@ -200,6 +228,51 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
}
// GetAllRules returns all subnet rules from the proxy handler
func (p *ProxyHandler) GetAllRules() []SubnetRule {
if p == nil || !p.enabled {
return nil
}
return p.subnetLookup.GetAllRules()
}
// LookupResourceId looks up the resource ID for a connection
// Returns 0 if no resource ID is associated with this connection
func (p *ProxyHandler) LookupResourceId(srcIP, dstIP string, dstPort uint16, proto uint8) int {
if p == nil || !p.enabled {
return 0
}
key := destKey{
srcIP: srcIP,
dstIP: dstIP,
dstPort: dstPort,
proto: proto,
}
p.natMu.RLock()
defer p.natMu.RUnlock()
return p.resourceTable[key]
}
// GetAccessLogger returns the access logger for session tracking
func (p *ProxyHandler) GetAccessLogger() *AccessLogger {
if p == nil {
return nil
}
return p.accessLogger
}
// SetAccessLogSender configures the function used to send compressed access log
// batches to the server. This should be called once the websocket client is available.
func (p *ProxyHandler) SetAccessLogSender(fn SendFunc) {
if p == nil || !p.enabled || p.accessLogger == nil {
return
}
p.accessLogger.SetSendFunc(fn)
}
// LookupDestinationRewrite looks up the rewritten destination for a connection
// This is used by TCP/UDP handlers to find the actual target address
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
@@ -362,8 +435,22 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
// Check if the source IP, destination IP, port, and protocol match any subnet rule
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol)
if matchedRule != nil {
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)",
srcAddr, dstAddr, protocol, dstPort)
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d, resourceId=%d)",
srcAddr, dstAddr, protocol, dstPort, matchedRule.ResourceId)
// Store resource ID for connections without DNAT as well
if matchedRule.ResourceId != 0 && matchedRule.RewriteTo == "" {
dKey := destKey{
srcIP: srcAddr.String(),
dstIP: dstAddr.String(),
dstPort: dstPort,
proto: uint8(protocol),
}
p.natMu.Lock()
p.resourceTable[dKey] = matchedRule.ResourceId
p.natMu.Unlock()
}
// Check if we need to perform DNAT
if matchedRule.RewriteTo != "" {
// Create connection tracking key using original destination
@@ -395,6 +482,13 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
proto: uint8(protocol),
}
// Store resource ID for access logging if present
if matchedRule.ResourceId != 0 {
p.natMu.Lock()
p.resourceTable[dKey] = matchedRule.ResourceId
p.natMu.Unlock()
}
// Check if we already have a NAT entry for this connection
p.natMu.RLock()
existingEntry, exists := p.natTable[key]
@@ -695,6 +789,11 @@ func (p *ProxyHandler) Close() error {
return nil
}
// Shut down access logger
if p.accessLogger != nil {
p.accessLogger.Close()
}
// Close ICMP replies channel
if p.icmpReplies != nil {
close(p.icmpReplies)

View File

@@ -47,7 +47,7 @@ func prefixEqual(a, b netip.Prefix) bool {
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
// If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
sl.mu.Lock()
defer sl.mu.Unlock()
@@ -57,6 +57,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite
DisableIcmp: disableIcmp,
RewriteTo: rewriteTo,
PortRanges: portRanges,
ResourceId: resourceId,
}
// Canonicalize source prefix to handle host bits correctly

View File

@@ -354,10 +354,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
// AddProxySubnetRule adds a subnet rule to the proxy handler
// If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
}
}
@@ -369,6 +369,15 @@ func (net *Net) RemoveProxySubnetRule(sourcePrefix, destPrefix netip.Prefix) {
}
}
// GetProxySubnetRules returns all subnet rules from the proxy handler
func (net *Net) GetProxySubnetRules() []SubnetRule {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
return tun.proxyHandler.GetAllRules()
}
return nil
}
// GetProxyHandler returns the proxy handler (for advanced use cases)
// Returns nil if proxy is not enabled
func (net *Net) GetProxyHandler() *ProxyHandler {
@@ -376,6 +385,15 @@ func (net *Net) GetProxyHandler() *ProxyHandler {
return tun.proxyHandler
}
// SetAccessLogSender configures the function used to send compressed access log
// batches to the server. This should be called once the websocket client is available.
func (net *Net) SetAccessLogSender(fn SendFunc) {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
tun.proxyHandler.SetAccessLogSender(fn)
}
}
type PingConn struct {
laddr PingAddr
raddr PingAddr

View File

@@ -21,7 +21,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
)
const errUnsupportedProtoFmt = "unsupported protocol: %s"
const (
errUnsupportedProtoFmt = "unsupported protocol: %s"
maxUDPPacketSize = 65507
)
// Target represents a proxy target with its address and port
type Target struct {
@@ -105,13 +108,9 @@ func classifyProxyError(err error) string {
if errors.Is(err, net.ErrClosed) {
return "closed"
}
if ne, ok := err.(net.Error); ok {
if ne.Timeout() {
return "timeout"
}
if ne.Temporary() {
return "temporary"
}
var ne net.Error
if errors.As(err, &ne) && ne.Timeout() {
return "timeout"
}
msg := strings.ToLower(err.Error())
switch {
@@ -437,14 +436,6 @@ func (pm *ProxyManager) Stop() error {
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
}
// // Clear the target maps
// for k := range pm.tcpTargets {
// delete(pm.tcpTargets, k)
// }
// for k := range pm.udpTargets {
// delete(pm.udpTargets, k)
// }
// Give active connections a chance to close gracefully
time.Sleep(100 * time.Millisecond)
@@ -498,7 +489,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
if !pm.running {
return
}
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
if errors.Is(err, net.ErrClosed) {
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
return
}
@@ -564,7 +555,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
}
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
buffer := make([]byte, 65507) // Max UDP packet size
buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size
clientConns := make(map[string]*net.UDPConn)
var clientsMutex sync.RWMutex
@@ -583,7 +574,7 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
}
// Check for connection closed conditions
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
logger.Info("UDP connection closed, stopping proxy handler")
// Clean up existing client connections
@@ -662,10 +653,14 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
}()
buffer := make([]byte, 65507)
buffer := make([]byte, maxUDPPacketSize)
for {
n, _, err := targetConn.ReadFromUDP(buffer)
if err != nil {
// Connection closed is normal during cleanup
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
return // defer will handle cleanup, result stays "success"
}
logger.Error("Error reading from target: %v", err)
result = "failure"
return // defer will handle cleanup
@@ -736,3 +731,28 @@ func (pm *ProxyManager) PrintTargets() {
}
}
}
// GetTargets returns a copy of the current TCP and UDP targets
// Returns map[listenIP]map[port]targetAddress for both TCP and UDP
func (pm *ProxyManager) GetTargets() (tcpTargets map[string]map[int]string, udpTargets map[string]map[int]string) {
pm.mutex.RLock()
defer pm.mutex.RUnlock()
tcpTargets = make(map[string]map[int]string)
for listenIP, targets := range pm.tcpTargets {
tcpTargets[listenIP] = make(map[int]string)
for port, targetAddr := range targets {
tcpTargets[listenIP][port] = targetAddr
}
}
udpTargets = make(map[string]map[int]string)
for listenIP, targets := range pm.udpTargets {
udpTargets[listenIP] = make(map[int]string)
for port, targetAddr := range targets {
udpTargets[listenIP][port] = targetAddr
}
}
return tcpTargets, udpTargets
}

View File

@@ -2,6 +2,7 @@ package websocket
import (
"bytes"
"compress/gzip"
"crypto/tls"
"crypto/x509"
"encoding/json"
@@ -37,7 +38,6 @@ type Client struct {
isConnected bool
reconnectMux sync.RWMutex
pingInterval time.Duration
pingTimeout time.Duration
onConnect func() error
onTokenUpdate func(token string)
writeMux sync.Mutex
@@ -47,6 +47,11 @@ type Client struct {
metricsCtx context.Context
configNeedsSave bool // Flag to track if config needs to be saved
serverVersion string
configVersion int64 // Latest config version received from server
configVersionMux sync.RWMutex
processingMessage bool // Flag to track if a message is currently being processed
processingMux sync.RWMutex // Protects processingMessage
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
}
type ClientOption func(*Client)
@@ -111,7 +116,7 @@ func (c *Client) MetricsContext() context.Context {
}
// NewClient creates a new websocket client
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{
ID: ID,
Secret: secret,
@@ -126,7 +131,6 @@ func NewClient(clientType string, ID, secret string, endpoint string, pingInterv
reconnectInterval: 3 * time.Second,
isConnected: false,
pingInterval: pingInterval,
pingTimeout: pingTimeout,
clientType: clientType,
}
@@ -154,6 +158,20 @@ func (c *Client) GetServerVersion() string {
return c.serverVersion
}
// GetConfigVersion returns the latest config version received from server
func (c *Client) GetConfigVersion() int64 {
c.configVersionMux.RLock()
defer c.configVersionMux.RUnlock()
return c.configVersion
}
// setConfigVersion updates the config version
func (c *Client) setConfigVersion(version int64) {
c.configVersionMux.Lock()
defer c.configVersionMux.Unlock()
c.configVersion = version
}
// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
go c.connectWithRetry()
@@ -641,7 +659,57 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
}
// pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) sendPing() {
if c.conn == nil {
return
}
// Skip ping if a message is currently being processed
c.processingMux.RLock()
isProcessing := c.processingMessage
c.processingMux.RUnlock()
if isProcessing {
logger.Debug("Skipping ping, message is being processed")
return
}
c.configVersionMux.RLock()
configVersion := c.configVersion
c.configVersionMux.RUnlock()
pingMsg := WSMessage{
Type: "newt/ping",
Data: map[string]interface{}{},
ConfigVersion: configVersion,
}
c.writeMux.Lock()
err := c.conn.WriteJSON(pingMsg)
if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
}
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("Ping failed: %v", err)
telemetry.IncWSKeepaliveFailure(c.metricsContext(), "ping_write")
telemetry.IncWSReconnect(c.metricsContext(), "ping_write")
c.reconnect()
return
}
}
}
func (c *Client) pingMonitor() {
// Send an immediate ping as soon as we connect
c.sendPing()
ticker := time.NewTicker(c.pingInterval)
defer ticker.Stop()
@@ -650,29 +718,7 @@ func (c *Client) pingMonitor() {
case <-c.done:
return
case <-ticker.C:
if c.conn == nil {
return
}
c.writeMux.Lock()
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
}
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("Ping failed: %v", err)
telemetry.IncWSKeepaliveFailure(c.metricsContext(), "ping_write")
telemetry.IncWSReconnect(c.metricsContext(), "ping_write")
c.reconnect()
return
}
}
c.sendPing()
}
}
}
@@ -709,10 +755,13 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
disconnectResult = "success"
return
default:
var msg WSMessage
err := c.conn.ReadJSON(&msg)
msgType, p, err := c.conn.ReadMessage()
if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "in", "text")
if msgType == websocket.BinaryMessage {
telemetry.IncWSMessage(c.metricsContext(), "in", "binary")
} else {
telemetry.IncWSMessage(c.metricsContext(), "in", "text")
}
}
if err != nil {
// Check if we're shutting down before logging error
@@ -737,9 +786,47 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
}
}
// Update config version from incoming message
var data []byte
if msgType == websocket.BinaryMessage {
gr, err := gzip.NewReader(bytes.NewReader(p))
if err != nil {
logger.Error("WebSocket failed to create gzip reader: %v", err)
continue
}
data, err = io.ReadAll(gr)
gr.Close()
if err != nil {
logger.Error("WebSocket failed to decompress message: %v", err)
continue
}
} else {
data = p
}
var msg WSMessage
if err = json.Unmarshal(data, &msg); err != nil {
logger.Error("WebSocket failed to parse message: %v", err)
continue
}
c.setConfigVersion(msg.ConfigVersion)
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
// Mark that we're processing a message
c.processingMux.Lock()
c.processingMessage = true
c.processingMux.Unlock()
c.processingWg.Add(1)
handler(msg)
// Mark that we're done processing
c.processingWg.Done()
c.processingMux.Lock()
c.processingMessage = false
c.processingMux.Unlock()
}
c.handlersMux.RUnlock()
}

View File

@@ -17,6 +17,7 @@ type TokenResponse struct {
}
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
Type string `json:"type"`
Data interface{} `json:"data"`
ConfigVersion int64 `json:"configVersion,omitempty"`
}