Compare commits

..

19 Commits

Author SHA1 Message Date
Owen
745d2dbc7e Merge branch 'dev' into msg-opt 2026-03-13 17:10:49 -07:00
Owen
c7b01288e0 Clean up previous logging 2026-03-13 11:45:36 -07:00
Owen
539e595c48 Add optional compression 2026-03-12 17:49:05 -07:00
Owen
a1df3d7ff0 Merge branch 'dev' of github.com:fosrl/newt into dev 2026-03-11 17:28:16 -07:00
Laurence
d68a13ea1f feat(installer): prefer /usr/local/bin and improve POSIX compatibility
- Always install to /usr/local/bin instead of ~/.local/bin
  - Use sudo automatically when write access is needed
  - Replace bash-specific syntax with POSIX equivalents:
    - Change shebang from #!/bin/bash to #!/bin/sh
    - Replace [[ == *pattern* ]] with case statements
    - Replace echo -e with printf for colored output
  - Script now works with dash, ash, busybox sh, and bash
2026-03-10 10:01:28 -07:00
Owen
accac75a53 Set newt version in dockerfile 2026-03-08 11:26:35 -07:00
Laurence
768415f90b Parse target strings with IPv6 support and strict validation
Add parseTargetString() for listenPort:host:targetPort using net.SplitHostPort/JoinHostPort. Replace manual split in updateTargets; fix err shadowing on remove. Validate listen port 1–65535 and reject empty host/port; use %w for errors. Add tests for IPv4, IPv6, hostnames, and invalid cases.
2026-03-07 21:32:36 -08:00
Owen
da9825d030 Merge branch 'main' into dev 2026-03-07 12:34:45 -08:00
Owen
afdb1fc977 Make sure to set version and fix prepare issue 2026-03-07 12:32:49 -08:00
Owen
1bd1133ac2 Make sure to skip prepare 2026-03-07 10:36:18 -08:00
Owen
fac0f5b197 Build full arn 2026-03-07 10:17:14 -08:00
Owen
e68b65683f Temp lets ignore the sync messages 2026-03-06 15:14:48 -08:00
Owen
7d6825132b Merge branch 'dev' into msg-opt 2026-03-03 16:56:41 -08:00
Owen
6371e980d2 Update the get all rules 2026-03-03 16:11:32 -08:00
Owen
15ea631b96 Mutex on handlers, slight change to ping message and handler 2026-03-02 20:56:36 -08:00
Owen
4e854b5f96 Working on message versioning 2026-03-02 20:56:18 -08:00
Owen
287eef0f44 Add version and send it down 2026-03-02 18:27:26 -08:00
Owen
f982e6b629 Merge branch 'dev' into msg-opt 2026-03-02 18:13:55 -08:00
Owen
039ae07b7b Support prefixes sent from server 2026-03-02 18:11:20 -08:00
17 changed files with 990 additions and 275 deletions

View File

@@ -37,11 +37,12 @@ type WgConfig struct {
}
type Target struct {
SourcePrefix string `json:"sourcePrefix"`
DestPrefix string `json:"destPrefix"`
RewriteTo string `json:"rewriteTo,omitempty"`
DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"`
SourcePrefix string `json:"sourcePrefix"`
SourcePrefixes []string `json:"sourcePrefixes"`
DestPrefix string `json:"destPrefix"`
RewriteTo string `json:"rewriteTo,omitempty"`
DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"`
}
type PortRange struct {
@@ -112,8 +113,6 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
return nil, fmt.Errorf("failed to generate private key: %v", err)
}
logger.Debug("+++++++++++++++++++++++++++++++= the port is %d", port)
if port == 0 {
// Find an available port
portRandom, err := util.FindAvailableUDPPort(49152, 65535)
@@ -162,8 +161,9 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
useNativeInterface: useNativeInterface,
}
// Create the holepunch manager
service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String(), nil)
// Create the holepunch manager with ResolveDomain function
// We'll need to pass a domain resolver function
service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String())
// Register websocket handlers
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
@@ -173,6 +173,7 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget)
wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget)
wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget)
wsClient.RegisterHandler("newt/wg/sync", service.handleSyncConfig)
return service, nil
}
@@ -278,7 +279,7 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string, rel
}
if relayPort == 0 {
relayPort = 21820
relayPort = 21820
}
// Convert websocket.ExitNode to holepunch.ExitNode
@@ -493,6 +494,183 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
logger.Info("Client connectivity setup. Ready to accept connections from clients!")
}
// SyncConfig represents the configuration sent from server for syncing
type SyncConfig struct {
Targets []Target `json:"targets"`
Peers []Peer `json:"peers"`
}
func (s *WireGuardService) handleSyncConfig(msg websocket.WSMessage) {
var syncConfig SyncConfig
logger.Debug("Received sync message: %v", msg)
logger.Info("Received sync configuration from remote server")
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &syncConfig); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
// Sync peers
if err := s.syncPeers(syncConfig.Peers); err != nil {
logger.Error("Failed to sync peers: %v", err)
}
// Sync targets
if err := s.syncTargets(syncConfig.Targets); err != nil {
logger.Error("Failed to sync targets: %v", err)
}
}
// syncPeers synchronizes the current peers with the desired state
// It removes peers not in the desired list and adds missing ones
func (s *WireGuardService) syncPeers(desiredPeers []Peer) error {
if s.device == nil {
return fmt.Errorf("WireGuard device is not initialized")
}
// Get current peers from the device
currentConfig, err := s.device.IpcGet()
if err != nil {
return fmt.Errorf("failed to get current device config: %v", err)
}
// Parse current peer public keys
lines := strings.Split(currentConfig, "\n")
currentPeerKeys := make(map[string]bool)
for _, line := range lines {
if strings.HasPrefix(line, "public_key=") {
pubKey := strings.TrimPrefix(line, "public_key=")
currentPeerKeys[pubKey] = true
}
}
// Build a map of desired peers by their public key (normalized)
desiredPeerMap := make(map[string]Peer)
for _, peer := range desiredPeers {
// Normalize the public key for comparison
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
logger.Warn("Invalid public key in desired peers: %s", peer.PublicKey)
continue
}
normalizedKey := util.FixKey(pubKey.String())
desiredPeerMap[normalizedKey] = peer
}
// Remove peers that are not in the desired list
for currentKey := range currentPeerKeys {
if _, exists := desiredPeerMap[currentKey]; !exists {
// Parse the key back to get the original format for removal
removeConfig := fmt.Sprintf("public_key=%s\nremove=true", currentKey)
if err := s.device.IpcSet(removeConfig); err != nil {
logger.Warn("Failed to remove peer %s during sync: %v", currentKey, err)
} else {
logger.Info("Removed peer %s during sync", currentKey)
}
}
}
// Add peers that are missing
for normalizedKey, peer := range desiredPeerMap {
if _, exists := currentPeerKeys[normalizedKey]; !exists {
if err := s.addPeerToDevice(peer); err != nil {
logger.Warn("Failed to add peer %s during sync: %v", peer.PublicKey, err)
} else {
logger.Info("Added peer %s during sync", peer.PublicKey)
}
}
}
return nil
}
// syncTargets synchronizes the current targets with the desired state
// It removes targets not in the desired list and adds missing ones
func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently
logger.Debug("Skipping target sync - using native interface (no proxy support)")
return nil
}
// Get current rules from the proxy handler
currentRules := s.tnet.GetProxySubnetRules()
// Build a map of current rules by source+dest prefix
type ruleKey struct {
sourcePrefix string
destPrefix string
}
currentRuleMap := make(map[ruleKey]bool)
for _, rule := range currentRules {
key := ruleKey{
sourcePrefix: rule.SourcePrefix.String(),
destPrefix: rule.DestPrefix.String(),
}
currentRuleMap[key] = true
}
// Build a map of desired targets
desiredTargetMap := make(map[ruleKey]Target)
for _, target := range desiredTargets {
key := ruleKey{
sourcePrefix: target.SourcePrefix,
destPrefix: target.DestPrefix,
}
desiredTargetMap[key] = target
}
// Remove targets that are not in the desired list
for _, rule := range currentRules {
key := ruleKey{
sourcePrefix: rule.SourcePrefix.String(),
destPrefix: rule.DestPrefix.String(),
}
if _, exists := desiredTargetMap[key]; !exists {
s.tnet.RemoveProxySubnetRule(rule.SourcePrefix, rule.DestPrefix)
logger.Info("Removed target %s -> %s during sync", rule.SourcePrefix.String(), rule.DestPrefix.String())
}
}
// Add targets that are missing
for key, target := range desiredTargetMap {
if _, exists := currentRuleMap[key]; !exists {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Warn("Invalid source prefix %s during sync: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Warn("Invalid dest prefix %s during sync: %v", target.DestPrefix, err)
continue
}
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
}
}
return nil
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.mu.Lock()
@@ -696,6 +874,19 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
return nil
}
// resolveSourcePrefixes returns the effective list of source prefixes for a target,
// supporting both the legacy single SourcePrefix field and the new SourcePrefixes array.
// If SourcePrefixes is non-empty it takes precedence; otherwise SourcePrefix is used.
func resolveSourcePrefixes(target Target) []string {
if len(target.SourcePrefixes) > 0 {
return target.SourcePrefixes
}
if target.SourcePrefix != "" {
return []string{target.SourcePrefix}
}
return nil
}
func (s *WireGuardService) ensureTargets(targets []Target) error {
if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently
@@ -704,11 +895,6 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
}
for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err)
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
@@ -723,9 +909,14 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
return nil
@@ -1044,7 +1235,7 @@ func (s *WireGuardService) processPeerBandwidth(publicKey string, rxBytes, txByt
BytesOut: bytesOutMB,
}
}
return nil
}
}
@@ -1095,12 +1286,6 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
// Process all targets
for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1110,15 +1295,21 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
}
@@ -1147,21 +1338,21 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) {
// Process all targets
for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
}
}
}
@@ -1195,30 +1386,24 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
// Process all update requests
for _, target := range requests.OldTargets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
}
}
for _, target := range requests.NewTargets {
// Now add the new target
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1228,14 +1413,21 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
}

View File

@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"os/exec"
"strings"
@@ -363,27 +364,62 @@ func parseTargetData(data interface{}) (TargetData, error) {
return targetData, nil
}
// parseTargetString parses a target string in the format "listenPort:host:targetPort"
// It properly handles IPv6 addresses which must be in brackets: "listenPort:[ipv6]:targetPort"
// Examples:
// - IPv4: "3001:192.168.1.1:80"
// - IPv6: "3001:[::1]:8080" or "3001:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:80"
//
// Returns listenPort, targetAddress (in host:port format suitable for net.Dial), and error
func parseTargetString(target string) (int, string, error) {
// Find the first colon to extract the listen port
firstColon := strings.Index(target, ":")
if firstColon == -1 {
return 0, "", fmt.Errorf("invalid target format, no colon found: %s", target)
}
listenPortStr := target[:firstColon]
var listenPort int
_, err := fmt.Sscanf(listenPortStr, "%d", &listenPort)
if err != nil {
return 0, "", fmt.Errorf("invalid listen port: %s", listenPortStr)
}
if listenPort <= 0 || listenPort > 65535 {
return 0, "", fmt.Errorf("listen port out of range: %d", listenPort)
}
// The remainder is host:targetPort - use net.SplitHostPort which handles IPv6 brackets
remainder := target[firstColon+1:]
host, targetPort, err := net.SplitHostPort(remainder)
if err != nil {
return 0, "", fmt.Errorf("invalid host:port format '%s': %w", remainder, err)
}
// Reject empty host or target port
if host == "" {
return 0, "", fmt.Errorf("empty host in target: %s", target)
}
if targetPort == "" {
return 0, "", fmt.Errorf("empty target port in target: %s", target)
}
// Reconstruct the target address using JoinHostPort (handles IPv6 properly)
targetAddr := net.JoinHostPort(host, targetPort)
return listenPort, targetAddr, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
// Parse the target string, handling both IPv4 and IPv6 addresses
port, target, err := parseTargetString(t)
if err != nil {
logger.Info("Invalid port: %s", parts[0])
logger.Info("Invalid target format: %s (%v)", t, err)
continue
}
switch action {
case "add":
target := parts[1] + ":" + parts[2]
// Call updown script if provided
processedTarget := target
if updownScript != "" {
@@ -410,8 +446,6 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
case "remove":
logger.Info("Removing target with port %d", port)
target := parts[1] + ":" + parts[2]
// Call updown script if provided
if updownScript != "" {
_, err := executeUpdownScript(action, proto, target)
@@ -420,7 +454,7 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
}
}
err := pm.RemoveTarget(proto, tunnelIP, port)
err = pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
logger.Error("Failed to remove target: %v", err)
return err

212
common_test.go Normal file
View File

@@ -0,0 +1,212 @@
package main
import (
"net"
"testing"
)
func TestParseTargetString(t *testing.T) {
tests := []struct {
name string
input string
wantListenPort int
wantTargetAddr string
wantErr bool
}{
// IPv4 test cases
{
name: "valid IPv4 basic",
input: "3001:192.168.1.1:80",
wantListenPort: 3001,
wantTargetAddr: "192.168.1.1:80",
wantErr: false,
},
{
name: "valid IPv4 localhost",
input: "8080:127.0.0.1:3000",
wantListenPort: 8080,
wantTargetAddr: "127.0.0.1:3000",
wantErr: false,
},
{
name: "valid IPv4 same ports",
input: "443:10.0.0.1:443",
wantListenPort: 443,
wantTargetAddr: "10.0.0.1:443",
wantErr: false,
},
// IPv6 test cases
{
name: "valid IPv6 loopback",
input: "3001:[::1]:8080",
wantListenPort: 3001,
wantTargetAddr: "[::1]:8080",
wantErr: false,
},
{
name: "valid IPv6 full address",
input: "80:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
wantListenPort: 80,
wantTargetAddr: "[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
wantErr: false,
},
{
name: "valid IPv6 link-local",
input: "443:[fe80::1]:443",
wantListenPort: 443,
wantTargetAddr: "[fe80::1]:443",
wantErr: false,
},
{
name: "valid IPv6 all zeros compressed",
input: "8000:[::]:9000",
wantListenPort: 8000,
wantTargetAddr: "[::]:9000",
wantErr: false,
},
{
name: "valid IPv6 mixed notation",
input: "5000:[::ffff:192.168.1.1]:6000",
wantListenPort: 5000,
wantTargetAddr: "[::ffff:192.168.1.1]:6000",
wantErr: false,
},
// Hostname test cases
{
name: "valid hostname",
input: "8080:example.com:80",
wantListenPort: 8080,
wantTargetAddr: "example.com:80",
wantErr: false,
},
{
name: "valid hostname with subdomain",
input: "443:api.example.com:8443",
wantListenPort: 443,
wantTargetAddr: "api.example.com:8443",
wantErr: false,
},
{
name: "valid localhost hostname",
input: "3000:localhost:3000",
wantListenPort: 3000,
wantTargetAddr: "localhost:3000",
wantErr: false,
},
// Error cases
{
name: "invalid - no colons",
input: "invalid",
wantErr: true,
},
{
name: "invalid - empty string",
input: "",
wantErr: true,
},
{
name: "invalid - non-numeric listen port",
input: "abc:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - missing target port",
input: "3001:192.168.1.1",
wantErr: true,
},
{
name: "invalid - IPv6 without brackets",
input: "3001:fd70:1452:b736:4dd5:caca:7db9:c588:f5b3:80",
wantErr: true,
},
{
name: "invalid - only listen port",
input: "3001:",
wantErr: true,
},
{
name: "invalid - missing host",
input: "3001::80",
wantErr: true,
},
{
name: "invalid - IPv6 unclosed bracket",
input: "3001:[::1:80",
wantErr: true,
},
{
name: "invalid - listen port zero",
input: "0:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - listen port negative",
input: "-1:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - listen port out of range",
input: "70000:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - empty target port",
input: "3001:192.168.1.1:",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
listenPort, targetAddr, err := parseTargetString(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("parseTargetString(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
return
}
if tt.wantErr {
return // Don't check other values if we expected an error
}
if listenPort != tt.wantListenPort {
t.Errorf("parseTargetString(%q) listenPort = %d, want %d", tt.input, listenPort, tt.wantListenPort)
}
if targetAddr != tt.wantTargetAddr {
t.Errorf("parseTargetString(%q) targetAddr = %q, want %q", tt.input, targetAddr, tt.wantTargetAddr)
}
})
}
}
// TestParseTargetStringNetDialCompatibility verifies that the output is compatible with net.Dial
func TestParseTargetStringNetDialCompatibility(t *testing.T) {
tests := []struct {
name string
input string
}{
{"IPv4", "8080:127.0.0.1:80"},
{"IPv6 loopback", "8080:[::1]:80"},
{"IPv6 full", "8080:[2001:db8::1]:80"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, targetAddr, err := parseTargetString(tt.input)
if err != nil {
t.Fatalf("parseTargetString(%q) unexpected error: %v", tt.input, err)
}
// Verify the format is valid for net.Dial by checking it can be split back
// This doesn't actually dial, just validates the format
_, _, err = net.SplitHostPort(targetAddr)
if err != nil {
t.Errorf("parseTargetString(%q) produced invalid net.Dial format %q: %v", tt.input, targetAddr, err)
}
})
}
}

View File

@@ -35,7 +35,7 @@
inherit version;
src = pkgs.nix-gitignore.gitignoreSource [ ] ./.;
vendorHash = "sha256-u7iQCKF8Jh1o0OQoPK4jSmO5pKMl9yT5Sj4GD2UuTU8=";
vendorHash = "sha256-kmQM8Yy5TuOiNpMpUme/2gfE+vrhUK+0AphN+p71wGs=";
nativeInstallCheckInputs = [ pkgs.versionCheckHook ];

View File

@@ -1,7 +1,7 @@
#!/bin/bash
#!/bin/sh
# Get Newt - Cross-platform installation script
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | bash
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | sh
set -e
@@ -17,15 +17,15 @@ GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
# Function to print colored output
print_status() {
echo -e "${GREEN}[INFO]${NC} $1"
printf '%b[INFO]%b %s\n' "${GREEN}" "${NC}" "$1"
}
print_warning() {
echo -e "${YELLOW}[WARN]${NC} $1"
printf '%b[WARN]%b %s\n' "${YELLOW}" "${NC}" "$1"
}
print_error() {
echo -e "${RED}[ERROR]${NC} $1"
printf '%b[ERROR]%b %s\n' "${RED}" "${NC}" "$1"
}
# Function to get latest version from GitHub API
@@ -113,16 +113,34 @@ get_install_dir() {
if [ "$OS" = "windows" ]; then
echo "$HOME/bin"
else
# Try to use a directory in PATH, fallback to ~/.local/bin
if echo "$PATH" | grep -q "/usr/local/bin"; then
if [ -w "/usr/local/bin" ] 2>/dev/null; then
echo "/usr/local/bin"
else
echo "$HOME/.local/bin"
fi
# Prefer /usr/local/bin for system-wide installation
echo "/usr/local/bin"
fi
}
# Check if we need sudo for installation
needs_sudo() {
local install_dir="$1"
if [ -w "$install_dir" ] 2>/dev/null; then
return 1 # No sudo needed
else
return 0 # Sudo needed
fi
}
# Get the appropriate command prefix (sudo or empty)
get_sudo_cmd() {
local install_dir="$1"
if needs_sudo "$install_dir"; then
if command -v sudo >/dev/null 2>&1; then
echo "sudo"
else
echo "$HOME/.local/bin"
print_error "Cannot write to ${install_dir} and sudo is not available."
print_error "Please run this script as root or install sudo."
exit 1
fi
else
echo ""
fi
}
@@ -130,21 +148,24 @@ get_install_dir() {
install_newt() {
local platform="$1"
local install_dir="$2"
local sudo_cmd="$3"
local binary_name="newt_${platform}"
local exe_suffix=""
# Add .exe suffix for Windows
if [[ "$platform" == *"windows"* ]]; then
binary_name="${binary_name}.exe"
exe_suffix=".exe"
fi
case "$platform" in
*windows*)
binary_name="${binary_name}.exe"
exe_suffix=".exe"
;;
esac
local download_url="${BASE_URL}/${binary_name}"
local temp_file="/tmp/newt${exe_suffix}"
local final_path="${install_dir}/newt${exe_suffix}"
print_status "Downloading newt from ${download_url}"
# Download the binary
if command -v curl >/dev/null 2>&1; then
curl -fsSL "$download_url" -o "$temp_file"
@@ -154,18 +175,22 @@ install_newt() {
print_error "Neither curl nor wget is available. Please install one of them."
exit 1
fi
# Make executable before moving
chmod +x "$temp_file"
# Create install directory if it doesn't exist
mkdir -p "$install_dir"
# Move binary to install directory
mv "$temp_file" "$final_path"
# Make executable (not needed on Windows, but doesn't hurt)
chmod +x "$final_path"
if [ -n "$sudo_cmd" ]; then
$sudo_cmd mkdir -p "$install_dir"
print_status "Using sudo to install to ${install_dir}"
$sudo_cmd mv "$temp_file" "$final_path"
else
mkdir -p "$install_dir"
mv "$temp_file" "$final_path"
fi
print_status "newt installed to ${final_path}"
# Check if install directory is in PATH
if ! echo "$PATH" | grep -q "$install_dir"; then
print_warning "Install directory ${install_dir} is not in your PATH."
@@ -179,9 +204,9 @@ verify_installation() {
local install_dir="$1"
local exe_suffix=""
if [[ "$PLATFORM" == *"windows"* ]]; then
exe_suffix=".exe"
fi
case "$PLATFORM" in
*windows*) exe_suffix=".exe" ;;
esac
local newt_path="${install_dir}/newt${exe_suffix}"
@@ -198,34 +223,36 @@ verify_installation() {
# Main installation process
main() {
print_status "Installing latest version of newt..."
# Get latest version
print_status "Fetching latest version from GitHub..."
VERSION=$(get_latest_version)
print_status "Latest version: v${VERSION}"
# Set base URL with the fetched version
BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}"
# Detect platform
PLATFORM=$(detect_platform)
print_status "Detected platform: ${PLATFORM}"
# Get install directory
INSTALL_DIR=$(get_install_dir)
print_status "Install directory: ${INSTALL_DIR}"
# Check if we need sudo
SUDO_CMD=$(get_sudo_cmd "$INSTALL_DIR")
if [ -n "$SUDO_CMD" ]; then
print_status "Root privileges required for installation to ${INSTALL_DIR}"
fi
# Install newt
install_newt "$PLATFORM" "$INSTALL_DIR"
install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD"
# Verify installation
if verify_installation "$INSTALL_DIR"; then
print_status "newt is ready to use!"
if [[ "$PLATFORM" == *"windows"* ]]; then
print_status "Run 'newt --help' to get started"
else
print_status "Run 'newt --help' to get started"
fi
print_status "Run 'newt --help' to get started"
else
exit 1
fi

4
go.mod
View File

@@ -4,7 +4,7 @@ go 1.25.0
require (
github.com/docker/docker v28.5.2+incompatible
github.com/gaissmai/bart v0.26.1
github.com/gaissmai/bart v0.26.0
github.com/gorilla/websocket v1.5.3
github.com/prometheus/client_golang v1.23.2
github.com/vishvananda/netlink v1.3.1
@@ -24,7 +24,7 @@ require (
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.79.2
google.golang.org/grpc v1.79.1
gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
software.sslmate.com/src/go-pkcs12 v0.7.0

8
go.sum
View File

@@ -26,8 +26,8 @@ github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/gaissmai/bart v0.26.1 h1:+w4rnLGNlA2GDVn382Tfe3jOsK5vOr5n4KmigJ9lbTo=
github.com/gaissmai/bart v0.26.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -159,8 +159,8 @@ google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 h1:
google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:kSJwQxqmFXeo79zOmbrALdflXQeAYcUbgS7PbpMknCY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 h1:mWPCjDEyshlQYzBpMNHaEof6UX1PmHcaUODUywQ0uac=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU=
google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY=
google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -521,3 +521,82 @@ func (m *Monitor) DisableTarget(id int) error {
return nil
}
// GetTargetIDs returns a slice of all current target IDs
func (m *Monitor) GetTargetIDs() []int {
m.mutex.RLock()
defer m.mutex.RUnlock()
ids := make([]int, 0, len(m.targets))
for id := range m.targets {
ids = append(ids, id)
}
return ids
}
// SyncTargets synchronizes the current targets to match the desired set.
// It removes targets not in the desired set and adds targets that are missing.
func (m *Monitor) SyncTargets(desiredConfigs []Config) error {
m.mutex.Lock()
defer m.mutex.Unlock()
logger.Info("Syncing health check targets: %d desired targets", len(desiredConfigs))
// Build a set of desired target IDs
desiredIDs := make(map[int]Config)
for _, config := range desiredConfigs {
desiredIDs[config.ID] = config
}
// Find targets to remove (exist but not in desired set)
var toRemove []int
for id := range m.targets {
if _, exists := desiredIDs[id]; !exists {
toRemove = append(toRemove, id)
}
}
// Remove targets that are not in the desired set
for _, id := range toRemove {
logger.Info("Sync: removing health check target %d", id)
if target, exists := m.targets[id]; exists {
target.cancel()
delete(m.targets, id)
}
}
// Add or update targets from the desired set
var addedCount, updatedCount int
for id, config := range desiredIDs {
if existing, exists := m.targets[id]; exists {
// Target exists - check if config changed and update if needed
// For now, we'll replace it to ensure config is up to date
logger.Debug("Sync: updating health check target %d", id)
existing.cancel()
delete(m.targets, id)
if err := m.addTargetUnsafe(config); err != nil {
logger.Error("Sync: failed to update target %d: %v", id, err)
return fmt.Errorf("failed to update target %d: %v", id, err)
}
updatedCount++
} else {
// Target doesn't exist - add it
logger.Debug("Sync: adding health check target %d", id)
if err := m.addTargetUnsafe(config); err != nil {
logger.Error("Sync: failed to add target %d: %v", id, err)
return fmt.Errorf("failed to add target %d: %v", id, err)
}
addedCount++
}
}
logger.Info("Sync complete: removed %d, added %d, updated %d targets",
len(toRemove), addedCount, updatedCount)
// Notify callback if any changes were made
if (len(toRemove) > 0 || addedCount > 0 || updatedCount > 0) && m.callback != nil {
go m.callback(m.getAllTargetsUnsafe())
}
return nil
}

View File

@@ -27,17 +27,16 @@ type ExitNode struct {
// Manager handles UDP hole punching operations
type Manager struct {
mu sync.Mutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
ID string
token string
publicKey string
clientType string
exitNodes map[string]ExitNode // key is endpoint
updateChan chan struct{} // signals the goroutine to refresh exit nodes
publicDNS []string
mu sync.Mutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
ID string
token string
publicKey string
clientType string
exitNodes map[string]ExitNode // key is endpoint
updateChan chan struct{} // signals the goroutine to refresh exit nodes
sendHolepunchInterval time.Duration
sendHolepunchIntervalMin time.Duration
@@ -50,13 +49,12 @@ const defaultSendHolepunchIntervalMax = 60 * time.Second
const defaultSendHolepunchIntervalMin = 1 * time.Second
// NewManager creates a new hole punch manager
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string, publicDNS []string) *Manager {
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager {
return &Manager{
sharedBind: sharedBind,
ID: ID,
clientType: clientType,
publicKey: publicKey,
publicDNS: publicDNS,
exitNodes: make(map[string]ExitNode),
sendHolepunchInterval: defaultSendHolepunchIntervalMin,
sendHolepunchIntervalMin: defaultSendHolepunchIntervalMin,
@@ -283,13 +281,7 @@ func (m *Manager) TriggerHolePunch() error {
// Send hole punch to all exit nodes
successCount := 0
for _, exitNode := range currentExitNodes {
var host string
var err error
if len(m.publicDNS) > 0 {
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
} else {
host, err = util.ResolveDomain(exitNode.Endpoint)
}
host, err := util.ResolveDomain(exitNode.Endpoint)
if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
@@ -400,13 +392,7 @@ func (m *Manager) runMultipleExitNodes() {
var resolvedNodes []resolvedExitNode
for _, exitNode := range currentExitNodes {
var host string
var err error
if len(m.publicDNS) > 0 {
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
} else {
host, err = util.ResolveDomain(exitNode.Endpoint)
}
host, err := util.ResolveDomain(exitNode.Endpoint)
if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue

View File

@@ -49,11 +49,10 @@ type cachedAddr struct {
// HolepunchTester monitors holepunch connectivity using magic packets
type HolepunchTester struct {
sharedBind *bind.SharedBind
publicDNS []string
mu sync.RWMutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
mu sync.RWMutex
running bool
stopChan chan struct{}
// Pending requests waiting for responses (key: echo data as string)
pendingRequests sync.Map // map[string]*pendingRequest
@@ -85,10 +84,9 @@ type pendingRequest struct {
}
// NewHolepunchTester creates a new holepunch tester using the given SharedBind
func NewHolepunchTester(sharedBind *bind.SharedBind, publicDNS []string) *HolepunchTester {
func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester {
return &HolepunchTester{
sharedBind: sharedBind,
publicDNS: publicDNS,
addrCache: make(map[string]*cachedAddr),
addrCacheTTL: 5 * time.Minute, // Cache addresses for 5 minutes
}
@@ -171,13 +169,7 @@ func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error)
}
// Resolve the endpoint
var host string
var err error
if len(t.publicDNS) > 0 {
host, err = util.ResolveDomainUpstream(endpoint, t.publicDNS)
} else {
host, err = util.ResolveDomain(endpoint)
}
host, err := util.ResolveDomain(endpoint)
if err != nil {
host = endpoint
}

155
main.go
View File

@@ -565,7 +565,7 @@ func runNewtMain(ctx context.Context) {
id, // CLI arg takes precedence
secret, // CLI arg takes precedence
endpoint,
pingInterval,
30*time.Second,
pingTimeout,
opt,
)
@@ -619,8 +619,6 @@ func runNewtMain(ctx context.Context) {
var wgData WgData
var dockerEventMonitor *docker.EventMonitor
logger.Debug("++++++++++++++++++++++ the port is %d", port)
if !disableClients {
setupClients(client)
}
@@ -959,7 +957,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"publicKey": publicKey.String(),
"pingResults": pingResults,
"newtVersion": newtVersion,
}, 1*time.Second)
}, 2*time.Second)
return
}
@@ -1062,7 +1060,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"publicKey": publicKey.String(),
"pingResults": pingResults,
"newtVersion": newtVersion,
}, 1*time.Second)
}, 2*time.Second)
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
})
@@ -1167,6 +1165,153 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
}
})
// Register handler for syncing targets (TCP, UDP, and health checks)
client.RegisterHandler("newt/sync", func(msg websocket.WSMessage) {
logger.Info("Received sync message")
// if there is no wgData or pm, we can't sync targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info(msgNoTunnelOrProxy)
return
}
// Define the sync data structure
type SyncData struct {
Targets TargetsByType `json:"targets"`
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
}
var syncData SyncData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &syncData); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
logger.Debug("Sync data received: TCP targets=%d, UDP targets=%d, health check targets=%d",
len(syncData.Targets.TCP), len(syncData.Targets.UDP), len(syncData.HealthCheckTargets))
//TODO: TEST AND IMPLEMENT THIS
// // Build sets of desired targets (port -> target string)
// desiredTCP := make(map[int]string)
// for _, t := range syncData.Targets.TCP {
// parts := strings.Split(t, ":")
// if len(parts) != 3 {
// logger.Warn("Invalid TCP target format: %s", t)
// continue
// }
// port := 0
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
// logger.Warn("Invalid port in TCP target: %s", parts[0])
// continue
// }
// desiredTCP[port] = parts[1] + ":" + parts[2]
// }
// desiredUDP := make(map[int]string)
// for _, t := range syncData.Targets.UDP {
// parts := strings.Split(t, ":")
// if len(parts) != 3 {
// logger.Warn("Invalid UDP target format: %s", t)
// continue
// }
// port := 0
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
// logger.Warn("Invalid port in UDP target: %s", parts[0])
// continue
// }
// desiredUDP[port] = parts[1] + ":" + parts[2]
// }
// // Get current targets from proxy manager
// currentTCP, currentUDP := pm.GetTargets()
// // Sync TCP targets
// // Remove TCP targets not in desired set
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
// for port := range tcpForIP {
// if _, exists := desiredTCP[port]; !exists {
// logger.Info("Sync: removing TCP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, tcpForIP[port])
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// // Add TCP targets that are missing
// for port, target := range desiredTCP {
// needsAdd := true
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
// if currentTarget, exists := tcpForIP[port]; exists {
// // Check if target address changed
// if currentTarget == target {
// needsAdd = false
// } else {
// // Target changed, remove old one first
// logger.Info("Sync: updating TCP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// if needsAdd {
// logger.Info("Sync: adding TCP target on port %d -> %s", port, target)
// targetStr := fmt.Sprintf("%d:%s", port, target)
// updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// // Sync UDP targets
// // Remove UDP targets not in desired set
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
// for port := range udpForIP {
// if _, exists := desiredUDP[port]; !exists {
// logger.Info("Sync: removing UDP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, udpForIP[port])
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// // Add UDP targets that are missing
// for port, target := range desiredUDP {
// needsAdd := true
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
// if currentTarget, exists := udpForIP[port]; exists {
// // Check if target address changed
// if currentTarget == target {
// needsAdd = false
// } else {
// // Target changed, remove old one first
// logger.Info("Sync: updating UDP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// if needsAdd {
// logger.Info("Sync: adding UDP target on port %d -> %s", port, target)
// targetStr := fmt.Sprintf("%d:%s", port, target)
// updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// // Sync health check targets
// if err := healthMonitor.SyncTargets(syncData.HealthCheckTargets); err != nil {
// logger.Error("Failed to sync health check targets: %v", err)
// } else {
// logger.Info("Successfully synced health check targets")
// }
logger.Info("Sync complete")
})
// Register handler for Docker socket check
client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
logger.Debug("Received Docker socket check request")

View File

@@ -48,6 +48,23 @@ type SubnetRule struct {
PortRanges []PortRange // empty slice means all ports allowed
}
// GetAllRules returns a copy of all subnet rules
func (sl *SubnetLookup) GetAllRules() []SubnetRule {
sl.mu.RLock()
defer sl.mu.RUnlock()
var rules []SubnetRule
for _, destTriePtr := range sl.sourceTrie.All() {
if destTriePtr == nil {
continue
}
for _, rule := range destTriePtr.rules {
rules = append(rules, *rule)
}
}
return rules
}
// connKey uniquely identifies a connection for NAT tracking
type connKey struct {
srcIP string
@@ -200,6 +217,14 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
}
// GetAllRules returns all subnet rules from the proxy handler
func (p *ProxyHandler) GetAllRules() []SubnetRule {
if p == nil || !p.enabled {
return nil
}
return p.subnetLookup.GetAllRules()
}
// LookupDestinationRewrite looks up the rewritten destination for a connection
// This is used by TCP/UDP handlers to find the actual target address
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {

View File

@@ -369,6 +369,15 @@ func (net *Net) RemoveProxySubnetRule(sourcePrefix, destPrefix netip.Prefix) {
}
}
// GetProxySubnetRules returns all subnet rules from the proxy handler
func (net *Net) GetProxySubnetRules() []SubnetRule {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
return tun.proxyHandler.GetAllRules()
}
return nil
}
// GetProxyHandler returns the proxy handler (for advanced use cases)
// Returns nil if proxy is not enabled
func (net *Net) GetProxyHandler() *ProxyHandler {

View File

@@ -736,3 +736,28 @@ func (pm *ProxyManager) PrintTargets() {
}
}
}
// GetTargets returns a copy of the current TCP and UDP targets
// Returns map[listenIP]map[port]targetAddress for both TCP and UDP
func (pm *ProxyManager) GetTargets() (tcpTargets map[string]map[int]string, udpTargets map[string]map[int]string) {
pm.mutex.RLock()
defer pm.mutex.RUnlock()
tcpTargets = make(map[string]map[int]string)
for listenIP, targets := range pm.tcpTargets {
tcpTargets[listenIP] = make(map[int]string)
for port, targetAddr := range targets {
tcpTargets[listenIP][port] = targetAddr
}
}
udpTargets = make(map[string]map[int]string)
for listenIP, targets := range pm.udpTargets {
udpTargets[listenIP] = make(map[int]string)
for port, targetAddr := range targets {
udpTargets[listenIP][port] = targetAddr
}
}
return tcpTargets, udpTargets
}

View File

@@ -1,7 +1,6 @@
package util
import (
"context"
"encoding/base64"
"encoding/binary"
"encoding/hex"
@@ -15,99 +14,6 @@ import (
"golang.zx2c4.com/wireguard/device"
)
func ResolveDomainUpstream(domain string, publicDNS []string) (string, error) {
// trim whitespace
domain = strings.TrimSpace(domain)
// Remove any protocol prefix if present (do this first, before splitting host/port)
domain = strings.TrimPrefix(domain, "http://")
domain = strings.TrimPrefix(domain, "https://")
// if there are any trailing slashes, remove them
domain = strings.TrimSuffix(domain, "/")
// Check if there's a port in the domain
host, port, err := net.SplitHostPort(domain)
if err != nil {
// No port found, use the domain as is
host = domain
port = ""
}
// Check if host is already an IP address (IPv4 or IPv6)
// For IPv6, the host from SplitHostPort will already have brackets stripped
// but if there was no port, we need to handle bracketed IPv6 addresses
cleanHost := strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
if ip := net.ParseIP(cleanHost); ip != nil {
// It's already an IP address, no need to resolve
ipAddr := ip.String()
if port != "" {
return net.JoinHostPort(ipAddr, port), nil
}
return ipAddr, nil
}
// Lookup IP addresses using the upstream DNS servers if provided
var ips []net.IP
if len(publicDNS) > 0 {
var lastErr error
for _, server := range publicDNS {
// Ensure the upstream DNS address has a port
dnsAddr := server
if _, _, err := net.SplitHostPort(dnsAddr); err != nil {
// No port specified, default to 53
dnsAddr = net.JoinHostPort(server, "53")
}
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, "udp", dnsAddr)
},
}
ips, lastErr = resolver.LookupIP(context.Background(), "ip", host)
if lastErr == nil {
break
}
}
if lastErr != nil {
return "", fmt.Errorf("DNS lookup failed using all upstream servers: %v", lastErr)
}
} else {
ips, err = net.LookupIP(host)
if err != nil {
return "", fmt.Errorf("DNS lookup failed: %v", err)
}
}
if len(ips) == 0 {
return "", fmt.Errorf("no IP addresses found for domain %s", host)
}
// Get the first IPv4 address if available
var ipAddr string
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ipAddr = ipv4.String()
break
}
}
// If no IPv4 found, use the first IP (might be IPv6)
if ipAddr == "" {
ipAddr = ips[0].String()
}
// Add port back if it existed
if port != "" {
ipAddr = net.JoinHostPort(ipAddr, port)
}
return ipAddr, nil
}
func ResolveDomain(domain string) (string, error) {
// trim whitespace
domain = strings.TrimSpace(domain)

View File

@@ -2,6 +2,7 @@ package websocket
import (
"bytes"
"compress/gzip"
"crypto/tls"
"crypto/x509"
"encoding/json"
@@ -47,6 +48,11 @@ type Client struct {
metricsCtx context.Context
configNeedsSave bool // Flag to track if config needs to be saved
serverVersion string
configVersion int64 // Latest config version received from server
configVersionMux sync.RWMutex
processingMessage bool // Flag to track if a message is currently being processed
processingMux sync.RWMutex // Protects processingMessage
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
}
type ClientOption func(*Client)
@@ -154,6 +160,20 @@ func (c *Client) GetServerVersion() string {
return c.serverVersion
}
// GetConfigVersion returns the latest config version received from server
func (c *Client) GetConfigVersion() int64 {
c.configVersionMux.RLock()
defer c.configVersionMux.RUnlock()
return c.configVersion
}
// setConfigVersion updates the config version
func (c *Client) setConfigVersion(version int64) {
c.configVersionMux.Lock()
defer c.configVersionMux.Unlock()
c.configVersion = version
}
// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
go c.connectWithRetry()
@@ -653,12 +673,33 @@ func (c *Client) pingMonitor() {
if c.conn == nil {
return
}
// Skip ping if a message is currently being processed
c.processingMux.RLock()
isProcessing := c.processingMessage
c.processingMux.RUnlock()
if isProcessing {
logger.Debug("Skipping ping, message is being processed")
continue
}
c.configVersionMux.RLock()
configVersion := c.configVersion
c.configVersionMux.RUnlock()
pingMsg := WSMessage{
Type: "newt/ping",
Data: map[string]interface{}{},
ConfigVersion: configVersion,
}
c.writeMux.Lock()
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
err := c.conn.WriteJSON(pingMsg)
if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
}
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
@@ -709,10 +750,13 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
disconnectResult = "success"
return
default:
var msg WSMessage
err := c.conn.ReadJSON(&msg)
msgType, p, err := c.conn.ReadMessage()
if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "in", "text")
if msgType == websocket.BinaryMessage {
telemetry.IncWSMessage(c.metricsContext(), "in", "binary")
} else {
telemetry.IncWSMessage(c.metricsContext(), "in", "text")
}
}
if err != nil {
// Check if we're shutting down before logging error
@@ -737,9 +781,47 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
}
}
// Update config version from incoming message
var data []byte
if msgType == websocket.BinaryMessage {
gr, err := gzip.NewReader(bytes.NewReader(p))
if err != nil {
logger.Error("WebSocket failed to create gzip reader: %v", err)
continue
}
data, err = io.ReadAll(gr)
gr.Close()
if err != nil {
logger.Error("WebSocket failed to decompress message: %v", err)
continue
}
} else {
data = p
}
var msg WSMessage
if err = json.Unmarshal(data, &msg); err != nil {
logger.Error("WebSocket failed to parse message: %v", err)
continue
}
c.setConfigVersion(msg.ConfigVersion)
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
// Mark that we're processing a message
c.processingMux.Lock()
c.processingMessage = true
c.processingMux.Unlock()
c.processingWg.Add(1)
handler(msg)
// Mark that we're done processing
c.processingWg.Done()
c.processingMux.Lock()
c.processingMessage = false
c.processingMux.Unlock()
}
c.handlersMux.RUnlock()
}

View File

@@ -17,6 +17,7 @@ type TokenResponse struct {
}
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
Type string `json:"type"`
Data interface{} `json:"data"`
ConfigVersion int64 `json:"configVersion,omitempty"`
}