Compare commits

...

6 Commits

Author SHA1 Message Date
Owen
02949be245 Support connection testing in native 2025-12-04 21:48:32 -05:00
Owen
6d51cbf0c0 Check permissions 2025-12-04 21:39:32 -05:00
Owen
4dbf200cca Change DNS lookup to conntrack 2025-12-04 20:13:48 -05:00
Owen
d8b4fb4acb Change to disable clients 2025-12-04 20:13:35 -05:00
Owen
5dd5a56379 Add caching to the dns requests - is this good enough? 2025-12-03 22:00:23 -05:00
Owen
8c4d6e2e0a Working on more hp 2025-12-03 20:49:46 -05:00
15 changed files with 391 additions and 206 deletions

View File

@@ -11,6 +11,7 @@ import (
"sync"
"sync/atomic"
"github.com/fosrl/newt/logger"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
@@ -522,6 +523,7 @@ func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes [
func (b *SharedBind) handleMagicPacket(data []byte, addr *net.UDPAddr) bool {
// Check if this is a test request packet
if len(data) >= MagicTestRequestLen && bytes.HasPrefix(data, MagicTestRequest) {
logger.Debug("Received magic test REQUEST from %s, sending response", addr.String())
// Extract the random data portion to echo back
echoData := data[len(MagicTestRequest) : len(MagicTestRequest)+MagicPacketDataLen]
@@ -544,6 +546,7 @@ func (b *SharedBind) handleMagicPacket(data []byte, addr *net.UDPAddr) bool {
// Check if this is a test response packet
if len(data) >= MagicTestResponseLen && bytes.HasPrefix(data, MagicTestResponse) {
logger.Debug("Received magic test RESPONSE from %s", addr.String())
// Extract the echoed data
echoData := data[len(MagicTestResponse) : len(MagicTestResponse)+MagicPacketDataLen]
@@ -557,6 +560,8 @@ func (b *SharedBind) handleMagicPacket(data []byte, addr *net.UDPAddr) bool {
addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())
}
callback(addrPort, echoData)
} else {
logger.Debug("Magic response received but no callback registered")
}
return true

View File

@@ -5,16 +5,13 @@ import (
"github.com/fosrl/newt/clients"
wgnetstack "github.com/fosrl/newt/clients"
"github.com/fosrl/newt/clients/permissions"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/netstack2"
"github.com/fosrl/newt/websocket"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/fosrl/newt/wgtester"
)
var wgService *clients.WireGuardService
var wgTesterServer *wgtester.Server
var ready bool
func setupClients(client *websocket.Client) {
@@ -28,29 +25,23 @@ func setupClients(client *websocket.Client) {
host = strings.TrimSuffix(host, "/")
logger.Info("Setting up clients with netstack2...")
// if useNativeInterface is true make sure we have permission to use native interface
if useNativeInterface {
logger.Debug("Checking permissions for native interface")
err := permissions.CheckNativeInterfacePermissions()
if err != nil {
logger.Fatal("Insufficient permissions to create native TUN interface: %v", err)
return
}
}
// Create WireGuard service
wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, host, id, client, dns, useNativeInterface)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
// // Set up callback to restart wgtester with netstack when WireGuard is ready
wgService.SetOnNetstackReady(func(tnet *netstack2.Net) {
wgTesterServer = wgtester.NewServerWithNetstack("0.0.0.0", wgService.Port, id, tnet) // TODO: maybe make this the same ip of the wg server?
err := wgTesterServer.Start()
if err != nil {
logger.Error("Failed to start WireGuard tester server: %v", err)
}
})
wgService.SetOnNetstackClose(func() {
if wgTesterServer != nil {
wgTesterServer.Stop()
wgTesterServer = nil
}
})
client.OnTokenUpdate(func(token string) {
wgService.SetToken(token)
})
@@ -70,11 +61,6 @@ func closeClients() {
wgService.Close()
wgService = nil
}
if wgTesterServer != nil {
wgTesterServer.Stop()
wgTesterServer = nil
}
}
func clientsHandleNewtConnection(publicKey string, endpoint string) {

View File

@@ -2,8 +2,6 @@ package clients
import (
"context"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net"
@@ -22,6 +20,7 @@ import (
"github.com/fosrl/newt/network"
"github.com/fosrl/newt/util"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/newt/wgtester"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
@@ -73,7 +72,6 @@ type WireGuardService struct {
client *websocket.Client
config WgConfig
key wgtypes.Key
keyFilePath string
newtId string
lastReadings map[string]PeerReading
mu sync.Mutex
@@ -103,6 +101,7 @@ type WireGuardService struct {
directRelayWg sync.WaitGroup
netstackListener net.PacketConn
netstackListenerMu sync.Mutex
wgTesterServer *wgtester.Server
}
func NewWireGuardService(interfaceName string, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
@@ -224,6 +223,11 @@ func (s *WireGuardService) Close() {
s.sharedBind = nil
logger.Info("Released shared UDP bind")
}
if s.wgTesterServer != nil {
s.wgTesterServer.Stop()
s.wgTesterServer = nil
}
}
func (s *WireGuardService) SetToken(token string) {
@@ -268,10 +272,20 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) {
return
}
logger.Info("Starting hole punch to %s with public key: %s", endpoint, publicKey)
if err := s.holePunchManager.StartSingleEndpoint(endpoint, publicKey); err != nil {
// Convert websocket.ExitNode to holepunch.ExitNode
hpExitNodes := []holepunch.ExitNode{
{
Endpoint: endpoint,
PublicKey: publicKey,
},
}
// Start hole punching using the manager
if err := s.holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
logger.Info("Starting hole punch to %s with public key: %s", endpoint, publicKey)
}
// StartDirectUDPRelay starts a direct UDP relay from the main tunnel netstack to the clients' WireGuard.
@@ -386,7 +400,7 @@ func (s *WireGuardService) runDirectUDPRelay(listener net.PacketConn) {
continue
}
logger.Debug("Relayed %d bytes from %s into WireGuard", n, srcAddrPort.String())
// logger.Debug("Relayed %d bytes from %s into WireGuard", n, srcAddrPort.String())
}
}
@@ -477,11 +491,6 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Parse the IP address and CIDR mask
tunnelIP := netip.MustParseAddr(parts[0])
// Stop any ongoing hole punch operations
if s.holePunchManager != nil {
s.holePunchManager.Stop()
}
var err error
if s.useNativeInterface {
@@ -563,6 +572,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
return fmt.Errorf("failed to configure interface: %v", err)
}
s.wgTesterServer = wgtester.NewServer("0.0.0.0", s.Port, s.newtId) // TODO: maybe make this the same ip of the wg server?
err = s.wgTesterServer.Start()
if err != nil {
logger.Error("Failed to start WireGuard tester server: %v", err)
}
logger.Info("WireGuard native device created and configured on %s", interfaceName)
s.mu.Unlock()
@@ -610,16 +625,13 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
logger.Info("WireGuard netstack device created and configured")
// Store callback and tnet reference before releasing mutex
callback := s.onNetstackReady
tnet := s.tnet
// Release the mutex before calling the callback
s.mu.Unlock()
// Call the callback if it's set to notify that netstack is ready
if callback != nil {
callback(tnet)
s.wgTesterServer = wgtester.NewServerWithNetstack("0.0.0.0", s.Port, s.newtId, s.tnet) // TODO: maybe make this the same ip of the wg server?
err = s.wgTesterServer.Start()
if err != nil {
logger.Error("Failed to start WireGuard tester server: %v", err)
}
// Note: we already unlocked above, so don't use defer unlock
@@ -682,15 +694,6 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
}
var rewriteTo netip.Prefix
if target.RewriteTo != "" {
rewriteTo, err = netip.ParsePrefix(target.RewriteTo)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.RewriteTo, err)
continue
}
}
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
@@ -699,9 +702,9 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges)
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
}
return nil
@@ -759,6 +762,8 @@ func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
return
}
s.holePunchManager.TriggerHolePunch()
err = s.addPeerToDevice(peer)
if err != nil {
logger.Info("Error adding peer: %v", err)
@@ -836,6 +841,8 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
return
}
s.holePunchManager.TriggerHolePunch()
// Parse the public key
pubKey, err := wgtypes.ParseKey(request.PublicKey)
if err != nil {
@@ -970,13 +977,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
// parse the public keys and have them as base64 in the opposite order to fixKey
for i := range peerBandwidths {
pubKeyBytes, err := base64.StdEncoding.DecodeString(peerBandwidths[i].PublicKey)
if err != nil {
logger.Info("Failed to decode public key %s: %v", peerBandwidths[i].PublicKey, err)
continue
}
// Convert to hex
peerBandwidths[i].PublicKey = hex.EncodeToString(pubKeyBytes)
peerBandwidths[i].PublicKey = util.UnfixKey(peerBandwidths[i].PublicKey) // its in the long form but we need base64
}
return peerBandwidths, nil
@@ -1037,7 +1038,7 @@ func (s *WireGuardService) reportPeerBandwidth() error {
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
}
err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{
err = s.client.SendMessageNoLog("newt/receive-bandwidth", map[string]interface{}{
"bandwidthData": bandwidths,
})
if err != nil {
@@ -1084,15 +1085,6 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
continue
}
var rewriteTo netip.Prefix
if target.RewriteTo != "" {
rewriteTo, err = netip.ParsePrefix(target.RewriteTo)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.RewriteTo, err)
continue
}
}
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
@@ -1101,9 +1093,9 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges)
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
@@ -1210,15 +1202,6 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
continue
}
var rewriteTo netip.Prefix
if target.RewriteTo != "" {
rewriteTo, err = netip.ParsePrefix(target.RewriteTo)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.RewriteTo, err)
continue
}
}
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
@@ -1227,8 +1210,8 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges)
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}

View File

@@ -0,0 +1,18 @@
//go:build darwin
package permissions
import (
"fmt"
"os"
)
// CheckNativeInterfacePermissions checks if the process has sufficient
// permissions to create a native TUN interface on macOS.
// This typically requires root privileges.
func CheckNativeInterfacePermissions() error {
if os.Geteuid() == 0 {
return nil
}
return fmt.Errorf("insufficient permissions: need root to create TUN interface on macOS")
}

View File

@@ -0,0 +1,96 @@
//go:build linux
package permissions
import (
"fmt"
"os"
"unsafe"
"github.com/fosrl/newt/logger"
"golang.org/x/sys/unix"
)
const (
// TUN device constants
tunDevice = "/dev/net/tun"
ifnamsiz = 16
iffTun = 0x0001
iffNoPi = 0x1000
tunSetIff = 0x400454ca
)
// ifReq is the structure for TUNSETIFF ioctl
type ifReq struct {
Name [ifnamsiz]byte
Flags uint16
_ [22]byte // padding to match kernel structure
}
// CheckNativeInterfacePermissions checks if the process has sufficient
// permissions to create a native TUN interface on Linux.
// This requires either root privileges (UID 0) or CAP_NET_ADMIN capability.
func CheckNativeInterfacePermissions() error {
logger.Debug("Checking native interface permissions on Linux")
// Check if running as root
if os.Geteuid() == 0 {
logger.Debug("Running as root, sufficient permissions for native TUN interface")
return nil
}
// Check for CAP_NET_ADMIN capability
caps := unix.CapUserHeader{
Version: unix.LINUX_CAPABILITY_VERSION_3,
Pid: 0, // 0 means current process
}
var data [2]unix.CapUserData
if err := unix.Capget(&caps, &data[0]); err != nil {
logger.Debug("Failed to get capabilities: %v, will try creating test TUN", err)
} else {
// CAP_NET_ADMIN is capability bit 12
const CAP_NET_ADMIN = 12
if data[0].Effective&(1<<CAP_NET_ADMIN) != 0 {
logger.Debug("Process has CAP_NET_ADMIN capability, sufficient permissions for native TUN interface")
return nil
}
logger.Debug("Process does not have CAP_NET_ADMIN capability in effective set")
}
// Actually try to create a TUN interface to verify permissions
// This is the most reliable check as it tests the actual operation
return tryCreateTestTun()
}
// tryCreateTestTun attempts to create a temporary TUN interface to verify
// we have the necessary permissions. This tests the actual ioctl call that
// will be used when creating the real interface.
func tryCreateTestTun() error {
f, err := os.OpenFile(tunDevice, os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("cannot open %s: %v (need root or CAP_NET_ADMIN capability)", tunDevice, err)
}
defer f.Close()
// Try to create a TUN interface with a test name
// Using a random-ish name to avoid conflicts
var req ifReq
copy(req.Name[:], "tuntest0")
req.Flags = iffTun | iffNoPi
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
f.Fd(),
uintptr(tunSetIff),
uintptr(unsafe.Pointer(&req)),
)
if errno != 0 {
return fmt.Errorf("cannot create TUN interface (ioctl TUNSETIFF failed): %v (need root or CAP_NET_ADMIN capability)", errno)
}
// Success - the interface will be automatically destroyed when we close the fd
logger.Debug("Successfully created test TUN interface, sufficient permissions for native TUN interface")
return nil
}

View File

@@ -0,0 +1,38 @@
//go:build windows
package permissions
import (
"fmt"
"golang.org/x/sys/windows"
)
// CheckNativeInterfacePermissions checks if the process has sufficient
// permissions to create a native TUN interface on Windows.
// This requires Administrator privileges.
func CheckNativeInterfacePermissions() error {
var sid *windows.SID
err := windows.AllocateAndInitializeSid(
&windows.SECURITY_NT_AUTHORITY,
2,
windows.SECURITY_BUILTIN_DOMAIN_RID,
windows.DOMAIN_ALIAS_RID_ADMINS,
0, 0, 0, 0, 0, 0,
&sid)
if err != nil {
return fmt.Errorf("failed to initialize SID: %v", err)
}
defer windows.FreeSid(sid)
token := windows.Token(0)
member, err := token.IsMember(sid)
if err != nil {
return fmt.Errorf("failed to check admin group membership: %v", err)
}
if !member {
return fmt.Errorf("insufficient permissions: need Administrator to create TUN interface on Windows")
}
return nil
}

View File

@@ -25,7 +25,7 @@ import (
const msgHealthFileWriteFailed = "Failed to write health file: %v"
func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration, error) {
logger.Debug("Pinging %s", dst)
// logger.Debug("Pinging %s", dst)
socket, err := tnet.Dial("ping4", dst)
if err != nil {
return 0, fmt.Errorf("failed to create ICMP socket: %w", err)
@@ -84,7 +84,7 @@ func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration,
latency := time.Since(start)
logger.Debug("Ping to %s successful, latency: %v", dst, latency)
// logger.Debug("Ping to %s successful, latency: %v", dst, latency)
return latency, nil
}
@@ -122,7 +122,7 @@ func reliablePing(tnet *netstack.Net, dst string, baseTimeout time.Duration, max
// If we get at least one success, we can return early for health checks
if successCount > 0 {
avgLatency := totalLatency / time.Duration(successCount)
logger.Debug("Reliable ping succeeded after %d attempts, avg latency: %v", attempt, avgLatency)
// logger.Debug("Reliable ping succeeded after %d attempts, avg latency: %v", attempt, avgLatency)
return avgLatency, nil
}
}

View File

@@ -152,6 +152,28 @@ func (m *Manager) GetExitNodes() []ExitNode {
return nodes
}
// ResetInterval resets the hole punch interval back to the minimum value,
// allowing it to climb back up through exponential backoff.
// This is useful when network conditions change or connectivity is restored.
func (m *Manager) ResetInterval() {
m.mu.Lock()
defer m.mu.Unlock()
if m.sendHolepunchInterval != sendHolepunchIntervalMin {
m.sendHolepunchInterval = sendHolepunchIntervalMin
logger.Info("Reset hole punch interval to minimum (%v)", sendHolepunchIntervalMin)
}
// Signal the goroutine to apply the new interval if running
if m.running && m.updateChan != nil {
select {
case m.updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// TriggerHolePunch sends an immediate hole punch packet to all configured exit nodes
// This is useful for triggering hole punching on demand without waiting for the interval
func (m *Manager) TriggerHolePunch() error {
@@ -266,27 +288,6 @@ func (m *Manager) Start() error {
return nil
}
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
go m.runSingleEndpoint(endpoint, serverPubKey)
return nil
}
// runMultipleExitNodes performs hole punching to multiple exit nodes
func (m *Manager) runMultipleExitNodes() {
defer func() {
@@ -404,67 +405,6 @@ func (m *Manager) runMultipleExitNodes() {
}
}
// runSingleEndpoint performs hole punching to a single endpoint
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
}()
host, err := util.ResolveDomain(endpoint)
if err != nil {
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
return
}
// Execute once immediately before starting the loop
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Warn("Failed to send initial hole punch: %v", err)
}
// Start with minimum interval
m.mu.Lock()
m.sendHolepunchInterval = sendHolepunchIntervalMin
m.mu.Unlock()
ticker := time.NewTicker(m.sendHolepunchInterval)
defer ticker.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-ticker.C:
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Debug("Failed to send hole punch: %v", err)
}
// Exponential backoff: double the interval up to max
m.mu.Lock()
newInterval := m.sendHolepunchInterval * 2
if newInterval > sendHolepunchIntervalMax {
newInterval = sendHolepunchIntervalMax
}
if newInterval != m.sendHolepunchInterval {
m.sendHolepunchInterval = newInterval
ticker.Reset(m.sendHolepunchInterval)
logger.Debug("Increased hole punch interval to %v", m.sendHolepunchInterval)
}
m.mu.Unlock()
}
}
}
// sendHolePunch sends an encrypted hole punch packet using the shared bind
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
m.mu.Lock()

View File

@@ -140,16 +140,19 @@ func (t *HolepunchTester) Stop() {
// handleResponse is called by SharedBind when a magic response is received
func (t *HolepunchTester) handleResponse(addr netip.AddrPort, echoData []byte) {
logger.Debug("Received magic response from %s", addr.String())
key := string(echoData)
value, ok := t.pendingRequests.LoadAndDelete(key)
if !ok {
// No matching request found
logger.Debug("No pending request found for magic response from %s", addr.String())
return
}
req := value.(*pendingRequest)
rtt := time.Since(req.sentAt)
logger.Debug("Magic response matched pending request for %s (RTT: %v)", req.endpoint, rtt)
// Send RTT to the waiting goroutine (non-blocking)
select {

View File

@@ -3,6 +3,7 @@ package logger
import (
"fmt"
"os"
"strings"
"sync"
"time"
)
@@ -139,6 +140,10 @@ type WireGuardLogger struct {
func (l *Logger) GetWireGuardLogger(prepend string) *WireGuardLogger {
return &WireGuardLogger{
Verbosef: func(format string, args ...any) {
// if the format string contains "Sending keepalive packet", skip debug logging to reduce noise
if strings.Contains(format, "Sending keepalive packet") {
return
}
l.Debug(prepend+format, args...)
},
Errorf: func(format string, args ...any) {

14
main.go
View File

@@ -116,7 +116,7 @@ var (
err error
logLevel string
interfaceName string
acceptClients bool
disableClients bool
updownScript string
dockerSocket string
dockerEnforceNetworkValidation string
@@ -175,8 +175,8 @@ func main() {
regionEnv := os.Getenv("NEWT_REGION")
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
acceptClientsEnv := os.Getenv("ACCEPT_CLIENTS")
acceptClients = acceptClientsEnv == "true"
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
disableClients = disableClientsEnv == "true"
useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE")
useNativeInterface = useNativeInterfaceEnv == "true"
enforceHealthcheckCertEnv := os.Getenv("ENFORCE_HC_CERT")
@@ -236,10 +236,10 @@ func main() {
flag.StringVar(&interfaceName, "interface", "newt", "Name of the WireGuard interface")
}
if useNativeInterfaceEnv == "" {
flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface (requires WireGuard kernel module) and linux")
flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface")
}
if acceptClientsEnv == "" {
flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface")
if disableClientsEnv == "" {
flag.BoolVar(&disableClients, "disable-clients", false, "Disable clients on the WireGuard interface")
}
if enforceHealthcheckCertEnv == "" {
flag.BoolVar(&enforceHealthcheckCert, "enforce-hc-cert", false, "Enforce certificate validation for health checks (default: false, accepts any cert)")
@@ -528,7 +528,7 @@ func main() {
var wgData WgData
var dockerEventMonitor *docker.EventMonitor
if acceptClients {
if !disableClients {
setupClients(client)
}

View File

@@ -1,10 +1,14 @@
package netstack2
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/fosrl/newt/logger"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
@@ -26,14 +30,18 @@ type PortRange struct {
// SubnetRule represents a subnet with optional port restrictions and source address
// When RewriteTo is set, DNAT (Destination Network Address Translation) is performed:
// - Incoming packets: destination IP is rewritten to RewriteTo.Addr()
// - Incoming packets: destination IP is rewritten to the resolved RewriteTo address
// - Outgoing packets: source IP is rewritten back to the original destination
//
// RewriteTo can be either:
// - An IP address with CIDR notation (e.g., "192.168.1.1/32")
// - A domain name (e.g., "example.com") which will be resolved at request time
//
// This allows transparent proxying where traffic appears to come from the rewritten address
type SubnetRule struct {
SourcePrefix netip.Prefix // Source IP prefix (who is sending)
DestPrefix netip.Prefix // Destination IP prefix (where it's going)
RewriteTo netip.Prefix // Optional rewrite address for DNAT (destination NAT)
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
PortRanges []PortRange // empty slice means all ports allowed
}
@@ -58,7 +66,8 @@ func NewSubnetLookup() *SubnetLookup {
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
// If portRanges is nil or empty, all ports are allowed for this subnet
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix, rewriteTo netip.Prefix, portRanges []PortRange) {
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) {
sl.mu.Lock()
defer sl.mu.Unlock()
@@ -225,8 +234,9 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
// AddSubnetRule adds a subnet with optional port restrictions to the proxy handler
// sourcePrefix: The IP prefix of the peer sending the data
// destPrefix: The IP prefix of the destination
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
// If portRanges is nil or empty, all ports are allowed for this subnet
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix, rewriteTo netip.Prefix, portRanges []PortRange) {
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) {
if p == nil || !p.enabled {
return
}
@@ -241,6 +251,47 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
}
// resolveRewriteAddress resolves a rewrite address which can be either:
// - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly
// - A plain IP address (e.g., "192.168.1.1") - returns the IP directly
// - A domain name (e.g., "example.com") - performs DNS lookup
func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) {
logger.Debug("Resolving rewrite address: %s", rewriteTo)
// First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32")
if prefix, err := netip.ParsePrefix(rewriteTo); err == nil {
return prefix.Addr(), nil
}
// Try to parse as a plain IP address (e.g., "192.168.1.1")
if addr, err := netip.ParseAddr(rewriteTo); err == nil {
return addr, nil
}
// Not an IP address, treat as domain name - perform DNS lookup
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ips, err := net.DefaultResolver.LookupIP(ctx, "ip4", rewriteTo)
if err != nil {
return netip.Addr{}, fmt.Errorf("failed to resolve domain %s: %w", rewriteTo, err)
}
if len(ips) == 0 {
return netip.Addr{}, fmt.Errorf("no IP addresses found for domain %s", rewriteTo)
}
// Use the first resolved IP address
ip := ips[0]
if ip4 := ip.To4(); ip4 != nil {
addr := netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]})
logger.Debug("Resolved %s to %s", rewriteTo, addr)
return addr, nil
}
return netip.Addr{}, fmt.Errorf("no IPv4 address found for domain %s", rewriteTo)
}
// Initialize sets up the promiscuous NIC with the netTun's notification system
func (p *ProxyHandler) Initialize(notifiable channel.Notification) error {
if p == nil || !p.enabled {
@@ -334,12 +385,9 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort)
if matchedRule != nil {
// Check if we need to perform DNAT
if matchedRule.RewriteTo.IsValid() && matchedRule.RewriteTo.Addr().IsValid() {
// Perform DNAT - rewrite destination IP
originalDst := dstAddr
newDst := matchedRule.RewriteTo.Addr()
// Create connection tracking key
if matchedRule.RewriteTo != "" {
// Create connection tracking key using original destination
// This allows us to check if we've already resolved for this connection
var srcPort uint16
switch protocol {
case header.TCPProtocolNumber:
@@ -350,21 +398,48 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
srcPort = udpHeader.SourcePort()
}
// Key using original destination to track the connection
key := connKey{
srcIP: srcAddr.String(),
srcPort: srcPort,
dstIP: newDst.String(),
dstIP: dstAddr.String(),
dstPort: dstPort,
proto: uint8(protocol),
}
// Store NAT state for reverse translation
p.natMu.Lock()
p.natTable[key] = &natState{
originalDst: originalDst,
rewrittenTo: newDst,
// Check if we already have a NAT entry for this connection
p.natMu.RLock()
existingEntry, exists := p.natTable[key]
p.natMu.RUnlock()
var newDst netip.Addr
if exists {
// Use the previously resolved address for this connection
newDst = existingEntry.rewrittenTo
logger.Debug("Using existing NAT entry for connection: %s -> %s", dstAddr, newDst)
} else {
// New connection - resolve the rewrite address
var err error
newDst, err = p.resolveRewriteAddress(matchedRule.RewriteTo)
if err != nil {
// Failed to resolve, skip DNAT but still proxy the packet
logger.Debug("Failed to resolve rewrite address: %v", err)
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
return true
}
// Store NAT state for this connection
p.natMu.Lock()
p.natTable[key] = &natState{
originalDst: dstAddr,
rewrittenTo: newDst,
}
p.natMu.Unlock()
logger.Debug("New NAT entry for connection: %s -> %s", dstAddr, newDst)
}
p.natMu.Unlock()
// Rewrite the packet
packet = p.rewritePacketDestination(packet, newDst)
@@ -534,20 +609,23 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
}
}
// Look up NAT state (key is based on the request, so dst/src are swapped for replies)
key := connKey{
srcIP: dstIP.String(),
srcPort: dstPort,
dstIP: srcIP.String(),
dstPort: srcPort,
proto: uint8(protocol),
}
// Look up NAT state for reverse translation
// The key uses the original dst (before rewrite), so for replies we need to
// find the entry where the rewritten address matches the current source
p.natMu.RLock()
natEntry, exists := p.natTable[key]
var natEntry *natState
for k, entry := range p.natTable {
// Match: reply's dst should be original src, reply's src should be rewritten dst
if k.srcIP == dstIP.String() && k.srcPort == dstPort &&
entry.rewrittenTo.String() == srcIP.String() && k.dstPort == srcPort &&
k.proto == uint8(protocol) {
natEntry = entry
break
}
}
p.natMu.RUnlock()
if exists {
if natEntry != nil {
// Perform reverse NAT - rewrite source to original destination
packet = p.rewritePacketSource(packet, natEntry.originalDst)
if packet != nil {

View File

@@ -350,7 +350,8 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
// AddProxySubnetRule adds a subnet rule to the proxy handler
// If portRanges is nil or empty, all ports are allowed for this subnet
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo netip.Prefix, portRanges []PortRange) {
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges)

View File

@@ -139,6 +139,18 @@ func FixKey(key string) string {
return hex.EncodeToString(decoded)
}
// this is the opposite of FixKey
func UnfixKey(hexKey string) string {
// Decode from hex
decoded, err := hex.DecodeString(hexKey)
if err != nil {
logger.Fatal("Error decoding hex: %v", err)
}
// Convert to base64
return base64.StdEncoding.EncodeToString(decoded)
}
func MapToWireGuardLogLevel(level logger.LogLevel) int {
switch level {
case logger.DEBUG:

View File

@@ -206,6 +206,26 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
return nil
}
// SendMessage sends a message through the WebSocket connection
func (c *Client) SendMessageNoLog(messageType string, data interface{}) error {
if c.conn == nil {
return fmt.Errorf("not connected")
}
msg := WSMessage{
Type: messageType,
Data: data,
}
c.writeMux.Lock()
defer c.writeMux.Unlock()
if err := c.conn.WriteJSON(msg); err != nil {
return err
}
telemetry.IncWSMessage(c.metricsContext(), "out", "text")
return nil
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
stopChan := make(chan struct{})
go func() {