mirror of
https://github.com/fosrl/newt.git
synced 2025-12-05 19:17:38 -06:00
Compare commits
3 Commits
61b9615aea
...
1b1323b553
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1b1323b553 | ||
|
|
bb95d10e86 | ||
|
|
da04746781 |
24
clients.go
24
clients.go
@@ -29,19 +29,9 @@ func setupClients(client *websocket.Client) {
|
||||
|
||||
host = strings.TrimSuffix(host, "/")
|
||||
|
||||
if useNativeInterface {
|
||||
// setupClientsNative(client, host)
|
||||
} else {
|
||||
setupClientsNetstack(client, host)
|
||||
}
|
||||
|
||||
ready = true
|
||||
}
|
||||
|
||||
func setupClientsNetstack(client *websocket.Client, host string) {
|
||||
logger.Info("Setting up clients with netstack2...")
|
||||
// Create WireGuard service
|
||||
wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "9.9.9.9")
|
||||
wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "9.9.9.9", useNativeInterface)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||
}
|
||||
@@ -66,6 +56,8 @@ func setupClientsNetstack(client *websocket.Client, host string) {
|
||||
client.OnTokenUpdate(func(token string) {
|
||||
wgService.SetToken(token)
|
||||
})
|
||||
|
||||
ready = true
|
||||
}
|
||||
|
||||
func setDownstreamTNetstack(tnet *netstack.Net) {
|
||||
@@ -77,12 +69,10 @@ func setDownstreamTNetstack(tnet *netstack.Net) {
|
||||
func closeClients() {
|
||||
logger.Info("Closing clients...")
|
||||
if wgService != nil {
|
||||
wgService.Close(!keepInterface)
|
||||
wgService.Close()
|
||||
wgService = nil
|
||||
}
|
||||
|
||||
// closeWgServiceNative()
|
||||
|
||||
if wgTesterServer != nil {
|
||||
wgTesterServer.Stop()
|
||||
wgTesterServer = nil
|
||||
@@ -105,8 +95,6 @@ func clientsHandleNewtConnection(publicKey string, endpoint string) {
|
||||
if wgService != nil {
|
||||
wgService.StartHolepunch(publicKey, endpoint)
|
||||
}
|
||||
|
||||
// clientsHandleNewtConnectionNative(publicKey, endpoint)
|
||||
}
|
||||
|
||||
func clientsOnConnect() {
|
||||
@@ -116,8 +104,6 @@ func clientsOnConnect() {
|
||||
if wgService != nil {
|
||||
wgService.LoadRemoteConfig()
|
||||
}
|
||||
|
||||
// clientsOnConnectNative()
|
||||
}
|
||||
|
||||
func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) {
|
||||
@@ -129,6 +115,4 @@ func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) {
|
||||
if wgService != nil {
|
||||
pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port))
|
||||
}
|
||||
|
||||
// clientsAddProxyTargetNative(pm, tunnelIp)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -18,9 +19,11 @@ import (
|
||||
"github.com/fosrl/newt/holepunch"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/netstack2"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -37,6 +40,7 @@ type WgConfig struct {
|
||||
type Target struct {
|
||||
SourcePrefix string `json:"sourcePrefix"`
|
||||
DestPrefix string `json:"destPrefix"`
|
||||
RewriteTo string `json:"rewriteTo,omitempty"`
|
||||
PortRange []PortRange `json:"portRange,omitempty"`
|
||||
}
|
||||
|
||||
@@ -91,11 +95,12 @@ type WireGuardService struct {
|
||||
// Proxy manager for tunnel
|
||||
TunnelIP string
|
||||
// Shared bind and holepunch manager
|
||||
sharedBind *bind.SharedBind
|
||||
holePunchManager *holepunch.Manager
|
||||
sharedBind *bind.SharedBind
|
||||
holePunchManager *holepunch.Manager
|
||||
useNativeInterface bool
|
||||
}
|
||||
|
||||
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string) (*WireGuardService, error) {
|
||||
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
|
||||
var key wgtypes.Key
|
||||
var err error
|
||||
|
||||
@@ -158,17 +163,18 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
|
||||
dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)}
|
||||
|
||||
service := &WireGuardService{
|
||||
interfaceName: interfaceName,
|
||||
mtu: mtu,
|
||||
client: wsClient,
|
||||
key: key,
|
||||
keyFilePath: generateAndSaveKeyTo,
|
||||
newtId: newtId,
|
||||
host: host,
|
||||
lastReadings: make(map[string]PeerReading),
|
||||
Port: port,
|
||||
dns: dnsAddrs,
|
||||
sharedBind: sharedBind,
|
||||
interfaceName: interfaceName,
|
||||
mtu: mtu,
|
||||
client: wsClient,
|
||||
key: key,
|
||||
keyFilePath: generateAndSaveKeyTo,
|
||||
newtId: newtId,
|
||||
host: host,
|
||||
lastReadings: make(map[string]PeerReading),
|
||||
Port: port,
|
||||
dns: dnsAddrs,
|
||||
sharedBind: sharedBind,
|
||||
useNativeInterface: useNativeInterface,
|
||||
}
|
||||
|
||||
// Create the holepunch manager with ResolveDomain function
|
||||
@@ -199,7 +205,7 @@ func (s *WireGuardService) SetOthertnet(tnet *netstack.Net) {
|
||||
s.othertnet = tnet
|
||||
}
|
||||
|
||||
func (s *WireGuardService) Close(rm bool) {
|
||||
func (s *WireGuardService) Close() {
|
||||
if s.stopGetConfig != nil {
|
||||
s.stopGetConfig()
|
||||
s.stopGetConfig = nil
|
||||
@@ -355,11 +361,94 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
s.holePunchManager.Stop()
|
||||
}
|
||||
|
||||
// Parse the IP address from the config
|
||||
// tunnelIP := netip.MustParseAddr(wgconfig.IpAddress)
|
||||
var err error
|
||||
|
||||
if s.useNativeInterface {
|
||||
// Create native TUN device
|
||||
var interfaceName = s.interfaceName
|
||||
if runtime.GOOS == "darwin" {
|
||||
interfaceName, err = network.FindUnusedUTUN()
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("failed to find unused utun: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.tun, err = tun.CreateTUN(interfaceName, s.mtu)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("failed to create native TUN device: %v", err)
|
||||
}
|
||||
|
||||
// Get the real interface name (may differ on some platforms)
|
||||
if realName, err := s.tun.Name(); err == nil {
|
||||
interfaceName = realName
|
||||
}
|
||||
|
||||
s.TunnelIP = tunnelIP.String()
|
||||
// s.tnet is nil for native interface - proxy features not available
|
||||
s.tnet = nil
|
||||
|
||||
// Create WireGuard device using the shared bind
|
||||
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
|
||||
device.LogLevelSilent,
|
||||
"wireguard: ",
|
||||
))
|
||||
|
||||
fileUAPI, err := func() (*os.File, error) {
|
||||
return ipc.UAPIOpen(interfaceName)
|
||||
}()
|
||||
if err != nil {
|
||||
logger.Error("UAPI listen error: %v", err)
|
||||
}
|
||||
|
||||
uapiListener, err := ipc.UAPIListen(interfaceName, fileUAPI)
|
||||
if err != nil {
|
||||
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := uapiListener.Accept()
|
||||
if err != nil {
|
||||
|
||||
return
|
||||
}
|
||||
go s.device.IpcHandle(conn)
|
||||
}
|
||||
}()
|
||||
logger.Info("UAPI listener started")
|
||||
|
||||
// Configure WireGuard with private key
|
||||
config := fmt.Sprintf("private_key=%s", util.FixKey(s.key.String()))
|
||||
|
||||
err = s.device.IpcSet(config)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("failed to configure WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
// Bring up the device
|
||||
err = s.device.Up()
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("failed to bring up WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
// Configure the network interface with IP address
|
||||
if err := network.ConfigureInterface(interfaceName, wgconfig.IpAddress, s.mtu); err != nil {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("failed to configure interface: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("WireGuard native device created and configured on %s", interfaceName)
|
||||
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create TUN device and network stack using netstack
|
||||
var err error
|
||||
s.tun, s.tnet, err = netstack2.CreateNetTUNWithOptions(
|
||||
[]netip.Addr{tunnelIP},
|
||||
s.dns,
|
||||
@@ -382,8 +471,6 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
"wireguard: ",
|
||||
))
|
||||
|
||||
// logger.Info("Private key is %s", fixKey(s.key.String()))
|
||||
|
||||
// Configure WireGuard with private key
|
||||
config := fmt.Sprintf("private_key=%s", util.FixKey(s.key.String()))
|
||||
|
||||
@@ -458,7 +545,9 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||
|
||||
func (s *WireGuardService) ensureTargets(targets []Target) error {
|
||||
if s.tnet == nil {
|
||||
return fmt.Errorf("netstack not initialized")
|
||||
// Native interface mode - proxy features not available, skip silently
|
||||
logger.Debug("Skipping target configuration - using native interface (no proxy support)")
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, target := range targets {
|
||||
@@ -472,6 +561,15 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
|
||||
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
|
||||
}
|
||||
|
||||
var rewriteTo netip.Prefix
|
||||
if target.RewriteTo != "" {
|
||||
rewriteTo, err = netip.ParsePrefix(target.RewriteTo)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", target.RewriteTo, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
var portRanges []netstack2.PortRange
|
||||
for _, pr := range target.PortRange {
|
||||
portRanges = append(portRanges, netstack2.PortRange{
|
||||
@@ -480,7 +578,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
|
||||
})
|
||||
}
|
||||
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, portRanges)
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges)
|
||||
|
||||
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
|
||||
}
|
||||
@@ -839,7 +937,8 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
||||
}
|
||||
|
||||
if s.tnet == nil {
|
||||
logger.Info("Netstack not initialized")
|
||||
// Native interface mode - proxy features not available, skip silently
|
||||
logger.Debug("Skipping add target - using native interface (no proxy support)")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -864,6 +963,15 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
||||
continue
|
||||
}
|
||||
|
||||
var rewriteTo netip.Prefix
|
||||
if target.RewriteTo != "" {
|
||||
rewriteTo, err = netip.ParsePrefix(target.RewriteTo)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", target.RewriteTo, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
var portRanges []netstack2.PortRange
|
||||
for _, pr := range target.PortRange {
|
||||
portRanges = append(portRanges, netstack2.PortRange{
|
||||
@@ -872,7 +980,7 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
||||
})
|
||||
}
|
||||
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, portRanges)
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges)
|
||||
|
||||
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
|
||||
}
|
||||
@@ -889,7 +997,8 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) {
|
||||
}
|
||||
|
||||
if s.tnet == nil {
|
||||
logger.Info("Netstack not initialized")
|
||||
// Native interface mode - proxy features not available, skip silently
|
||||
logger.Debug("Skipping remove target - using native interface (no proxy support)")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -936,7 +1045,8 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
||||
}
|
||||
|
||||
if s.tnet == nil {
|
||||
logger.Info("Netstack not initialized")
|
||||
// Native interface mode - proxy features not available, skip silently
|
||||
logger.Debug("Skipping update target - using native interface (no proxy support)")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -979,6 +1089,15 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
||||
continue
|
||||
}
|
||||
|
||||
var rewriteTo netip.Prefix
|
||||
if target.RewriteTo != "" {
|
||||
rewriteTo, err = netip.ParsePrefix(target.RewriteTo)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", target.RewriteTo, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
var portRanges []netstack2.PortRange
|
||||
for _, pr := range target.PortRange {
|
||||
portRanges = append(portRanges, netstack2.PortRange{
|
||||
@@ -987,7 +1106,7 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
||||
})
|
||||
}
|
||||
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, portRanges)
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges)
|
||||
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
|
||||
}
|
||||
}
|
||||
|
||||
6
main.go
6
main.go
@@ -117,7 +117,6 @@ var (
|
||||
logLevel string
|
||||
interfaceName string
|
||||
generateAndSaveKeyTo string
|
||||
keepInterface bool
|
||||
acceptClients bool
|
||||
updownScript string
|
||||
dockerSocket string
|
||||
@@ -178,8 +177,6 @@ func main() {
|
||||
regionEnv := os.Getenv("NEWT_REGION")
|
||||
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
|
||||
|
||||
keepInterfaceEnv := os.Getenv("KEEP_INTERFACE")
|
||||
keepInterface = keepInterfaceEnv == "true"
|
||||
acceptClientsEnv := os.Getenv("ACCEPT_CLIENTS")
|
||||
acceptClients = acceptClientsEnv == "true"
|
||||
useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE")
|
||||
@@ -243,9 +240,6 @@ func main() {
|
||||
if generateAndSaveKeyTo == "" {
|
||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||
}
|
||||
if keepInterfaceEnv == "" {
|
||||
flag.BoolVar(&keepInterface, "keep-interface", false, "Keep the WireGuard interface")
|
||||
}
|
||||
if useNativeInterfaceEnv == "" {
|
||||
flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface (requires WireGuard kernel module) and linux")
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
@@ -24,9 +25,15 @@ type PortRange struct {
|
||||
}
|
||||
|
||||
// SubnetRule represents a subnet with optional port restrictions and source address
|
||||
// When RewriteTo is set, DNAT (Destination Network Address Translation) is performed:
|
||||
// - Incoming packets: destination IP is rewritten to RewriteTo.Addr()
|
||||
// - Outgoing packets: source IP is rewritten back to the original destination
|
||||
//
|
||||
// This allows transparent proxying where traffic appears to come from the rewritten address
|
||||
type SubnetRule struct {
|
||||
SourcePrefix netip.Prefix // Source IP prefix (who is sending)
|
||||
DestPrefix netip.Prefix // Destination IP prefix (where it's going)
|
||||
RewriteTo netip.Prefix // Optional rewrite address for DNAT (destination NAT)
|
||||
PortRanges []PortRange // empty slice means all ports allowed
|
||||
}
|
||||
|
||||
@@ -51,7 +58,7 @@ func NewSubnetLookup() *SubnetLookup {
|
||||
|
||||
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, portRanges []PortRange) {
|
||||
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix, rewriteTo netip.Prefix, portRanges []PortRange) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
|
||||
@@ -63,6 +70,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, portRan
|
||||
sl.rules[key] = &SubnetRule{
|
||||
SourcePrefix: sourcePrefix,
|
||||
DestPrefix: destPrefix,
|
||||
RewriteTo: rewriteTo,
|
||||
PortRanges: portRanges,
|
||||
}
|
||||
}
|
||||
@@ -81,13 +89,13 @@ func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) {
|
||||
}
|
||||
|
||||
// Match checks if a source IP, destination IP, and port match any subnet rule
|
||||
// Returns true if BOTH:
|
||||
// Returns the matched rule if BOTH:
|
||||
// - The source IP is in the rule's source prefix
|
||||
// - The destination IP is in the rule's destination prefix
|
||||
// - The port is in an allowed range (or no port restrictions exist)
|
||||
//
|
||||
// This implementation uses O(n) iteration but checks exact prefix matches first for common cases
|
||||
func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) bool {
|
||||
// Returns nil if no rule matches
|
||||
func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) *SubnetRule {
|
||||
sl.mu.RLock()
|
||||
defer sl.mu.RUnlock()
|
||||
|
||||
@@ -105,18 +113,33 @@ func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) bool {
|
||||
// Both IPs match - now check port restrictions
|
||||
// If no port ranges specified, all ports are allowed
|
||||
if len(rule.PortRanges) == 0 {
|
||||
return true
|
||||
return rule
|
||||
}
|
||||
|
||||
// Check if port is in any of the allowed ranges
|
||||
for _, pr := range rule.PortRanges {
|
||||
if port >= pr.Min && port <= pr.Max {
|
||||
return true
|
||||
return rule
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
// connKey uniquely identifies a connection for NAT tracking
|
||||
type connKey struct {
|
||||
srcIP string
|
||||
srcPort uint16
|
||||
dstIP string
|
||||
dstPort uint16
|
||||
proto uint8
|
||||
}
|
||||
|
||||
// natState tracks NAT translation state for reverse translation
|
||||
type natState struct {
|
||||
originalDst netip.Addr // Original destination before DNAT
|
||||
rewrittenTo netip.Addr // The address we rewrote to
|
||||
}
|
||||
|
||||
// ProxyHandler handles packet injection and extraction for promiscuous mode
|
||||
@@ -127,6 +150,8 @@ type ProxyHandler struct {
|
||||
tcpHandler *TCPHandler
|
||||
udpHandler *UDPHandler
|
||||
subnetLookup *SubnetLookup
|
||||
natTable map[connKey]*natState
|
||||
natMu sync.RWMutex
|
||||
enabled bool
|
||||
}
|
||||
|
||||
@@ -146,6 +171,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
||||
handler := &ProxyHandler{
|
||||
enabled: true,
|
||||
subnetLookup: NewSubnetLookup(),
|
||||
natTable: make(map[connKey]*natState),
|
||||
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
||||
proxyStack: stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
@@ -200,11 +226,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
||||
// sourcePrefix: The IP prefix of the peer sending the data
|
||||
// destPrefix: The IP prefix of the destination
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, portRanges []PortRange) {
|
||||
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix, rewriteTo netip.Prefix, portRanges []PortRange) {
|
||||
if p == nil || !p.enabled {
|
||||
return
|
||||
}
|
||||
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, portRanges)
|
||||
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges)
|
||||
}
|
||||
|
||||
// RemoveSubnetRule removes a subnet from the proxy handler
|
||||
@@ -305,7 +331,48 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
||||
}
|
||||
|
||||
// Check if the source IP, destination IP, and port match any subnet rule
|
||||
if p.subnetLookup.Match(srcAddr, dstAddr, dstPort) {
|
||||
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort)
|
||||
if matchedRule != nil {
|
||||
// Check if we need to perform DNAT
|
||||
if matchedRule.RewriteTo.IsValid() && matchedRule.RewriteTo.Addr().IsValid() {
|
||||
// Perform DNAT - rewrite destination IP
|
||||
originalDst := dstAddr
|
||||
newDst := matchedRule.RewriteTo.Addr()
|
||||
|
||||
// Create connection tracking key
|
||||
var srcPort uint16
|
||||
switch protocol {
|
||||
case header.TCPProtocolNumber:
|
||||
tcpHeader := header.TCP(packet[headerLen:])
|
||||
srcPort = tcpHeader.SourcePort()
|
||||
case header.UDPProtocolNumber:
|
||||
udpHeader := header.UDP(packet[headerLen:])
|
||||
srcPort = udpHeader.SourcePort()
|
||||
}
|
||||
|
||||
key := connKey{
|
||||
srcIP: srcAddr.String(),
|
||||
srcPort: srcPort,
|
||||
dstIP: newDst.String(),
|
||||
dstPort: dstPort,
|
||||
proto: uint8(protocol),
|
||||
}
|
||||
|
||||
// Store NAT state for reverse translation
|
||||
p.natMu.Lock()
|
||||
p.natTable[key] = &natState{
|
||||
originalDst: originalDst,
|
||||
rewrittenTo: newDst,
|
||||
}
|
||||
p.natMu.Unlock()
|
||||
|
||||
// Rewrite the packet
|
||||
packet = p.rewritePacketDestination(packet, newDst)
|
||||
if packet == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Inject into proxy stack
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
@@ -317,6 +384,118 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// rewritePacketDestination rewrites the destination IP in a packet and recalculates checksums
|
||||
func (p *ProxyHandler) rewritePacketDestination(packet []byte, newDst netip.Addr) []byte {
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make a copy to avoid modifying the original
|
||||
pkt := make([]byte, len(packet))
|
||||
copy(pkt, packet)
|
||||
|
||||
ipv4Header := header.IPv4(pkt)
|
||||
headerLen := int(ipv4Header.HeaderLength())
|
||||
|
||||
// Rewrite destination IP
|
||||
newDstBytes := newDst.As4()
|
||||
newDstAddr := tcpip.AddrFrom4(newDstBytes)
|
||||
ipv4Header.SetDestinationAddress(newDstAddr)
|
||||
|
||||
// Recalculate IP checksum
|
||||
ipv4Header.SetChecksum(0)
|
||||
ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
|
||||
|
||||
// Update transport layer checksum if needed
|
||||
protocol := ipv4Header.TransportProtocol()
|
||||
switch protocol {
|
||||
case header.TCPProtocolNumber:
|
||||
if len(pkt) >= headerLen+header.TCPMinimumSize {
|
||||
tcpHeader := header.TCP(pkt[headerLen:])
|
||||
tcpHeader.SetChecksum(0)
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.TCPProtocolNumber,
|
||||
ipv4Header.SourceAddress(),
|
||||
ipv4Header.DestinationAddress(),
|
||||
uint16(len(pkt)-headerLen),
|
||||
)
|
||||
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||
tcpHeader.SetChecksum(^xsum)
|
||||
}
|
||||
case header.UDPProtocolNumber:
|
||||
if len(pkt) >= headerLen+header.UDPMinimumSize {
|
||||
udpHeader := header.UDP(pkt[headerLen:])
|
||||
udpHeader.SetChecksum(0)
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.UDPProtocolNumber,
|
||||
ipv4Header.SourceAddress(),
|
||||
ipv4Header.DestinationAddress(),
|
||||
uint16(len(pkt)-headerLen),
|
||||
)
|
||||
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||
udpHeader.SetChecksum(^xsum)
|
||||
}
|
||||
}
|
||||
|
||||
return pkt
|
||||
}
|
||||
|
||||
// rewritePacketSource rewrites the source IP in a packet and recalculates checksums (for reverse NAT)
|
||||
func (p *ProxyHandler) rewritePacketSource(packet []byte, newSrc netip.Addr) []byte {
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make a copy to avoid modifying the original
|
||||
pkt := make([]byte, len(packet))
|
||||
copy(pkt, packet)
|
||||
|
||||
ipv4Header := header.IPv4(pkt)
|
||||
headerLen := int(ipv4Header.HeaderLength())
|
||||
|
||||
// Rewrite source IP
|
||||
newSrcBytes := newSrc.As4()
|
||||
newSrcAddr := tcpip.AddrFrom4(newSrcBytes)
|
||||
ipv4Header.SetSourceAddress(newSrcAddr)
|
||||
|
||||
// Recalculate IP checksum
|
||||
ipv4Header.SetChecksum(0)
|
||||
ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
|
||||
|
||||
// Update transport layer checksum if needed
|
||||
protocol := ipv4Header.TransportProtocol()
|
||||
switch protocol {
|
||||
case header.TCPProtocolNumber:
|
||||
if len(pkt) >= headerLen+header.TCPMinimumSize {
|
||||
tcpHeader := header.TCP(pkt[headerLen:])
|
||||
tcpHeader.SetChecksum(0)
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.TCPProtocolNumber,
|
||||
ipv4Header.SourceAddress(),
|
||||
ipv4Header.DestinationAddress(),
|
||||
uint16(len(pkt)-headerLen),
|
||||
)
|
||||
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||
tcpHeader.SetChecksum(^xsum)
|
||||
}
|
||||
case header.UDPProtocolNumber:
|
||||
if len(pkt) >= headerLen+header.UDPMinimumSize {
|
||||
udpHeader := header.UDP(pkt[headerLen:])
|
||||
udpHeader.SetChecksum(0)
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.UDPProtocolNumber,
|
||||
ipv4Header.SourceAddress(),
|
||||
ipv4Header.DestinationAddress(),
|
||||
uint16(len(pkt)-headerLen),
|
||||
)
|
||||
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||
udpHeader.SetChecksum(^xsum)
|
||||
}
|
||||
}
|
||||
|
||||
return pkt
|
||||
}
|
||||
|
||||
// ReadOutgoingPacket reads packets from the proxy stack that need to be
|
||||
// sent back through the tunnel
|
||||
func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
|
||||
@@ -328,6 +507,55 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
|
||||
if pkt != nil {
|
||||
view := pkt.ToView()
|
||||
pkt.DecRef()
|
||||
|
||||
// Check if we need to perform reverse NAT
|
||||
packet := view.AsSlice()
|
||||
if len(packet) >= header.IPv4MinimumSize && packet[0]>>4 == 4 {
|
||||
ipv4Header := header.IPv4(packet)
|
||||
srcIP := ipv4Header.SourceAddress()
|
||||
dstIP := ipv4Header.DestinationAddress()
|
||||
protocol := ipv4Header.TransportProtocol()
|
||||
headerLen := int(ipv4Header.HeaderLength())
|
||||
|
||||
// Extract ports
|
||||
var srcPort, dstPort uint16
|
||||
switch protocol {
|
||||
case header.TCPProtocolNumber:
|
||||
if len(packet) >= headerLen+header.TCPMinimumSize {
|
||||
tcpHeader := header.TCP(packet[headerLen:])
|
||||
srcPort = tcpHeader.SourcePort()
|
||||
dstPort = tcpHeader.DestinationPort()
|
||||
}
|
||||
case header.UDPProtocolNumber:
|
||||
if len(packet) >= headerLen+header.UDPMinimumSize {
|
||||
udpHeader := header.UDP(packet[headerLen:])
|
||||
srcPort = udpHeader.SourcePort()
|
||||
dstPort = udpHeader.DestinationPort()
|
||||
}
|
||||
}
|
||||
|
||||
// Look up NAT state (key is based on the request, so dst/src are swapped for replies)
|
||||
key := connKey{
|
||||
srcIP: dstIP.String(),
|
||||
srcPort: dstPort,
|
||||
dstIP: srcIP.String(),
|
||||
dstPort: srcPort,
|
||||
proto: uint8(protocol),
|
||||
}
|
||||
|
||||
p.natMu.RLock()
|
||||
natEntry, exists := p.natTable[key]
|
||||
p.natMu.RUnlock()
|
||||
|
||||
if exists {
|
||||
// Perform reverse NAT - rewrite source to original destination
|
||||
packet = p.rewritePacketSource(packet, natEntry.originalDst)
|
||||
if packet != nil {
|
||||
return buffer.NewViewWithData(packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return view
|
||||
}
|
||||
|
||||
|
||||
@@ -350,10 +350,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
||||
|
||||
// AddProxySubnetRule adds a subnet rule to the proxy handler
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, portRanges []PortRange) {
|
||||
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo netip.Prefix, portRanges []PortRange) {
|
||||
tun := (*netTun)(net)
|
||||
if tun.proxyHandler != nil {
|
||||
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, portRanges)
|
||||
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
165
network/interface.go
Normal file
165
network/interface.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// ConfigureInterface configures a network interface with an IP address and brings it up
|
||||
func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error {
|
||||
logger.Info("The tunnel IP is: %s", tunnelIp)
|
||||
|
||||
// Parse the IP address and network
|
||||
ip, ipNet, err := net.ParseCIDR(tunnelIp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid IP address: %v", err)
|
||||
}
|
||||
|
||||
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||
mask := net.IP(ipNet.Mask).String()
|
||||
destinationAddress := ip.String()
|
||||
|
||||
logger.Debug("The destination address is: %s", destinationAddress)
|
||||
|
||||
// network.SetTunnelRemoteAddress() // what does this do?
|
||||
SetIPv4Settings([]string{destinationAddress}, []string{mask})
|
||||
SetMTU(mtu)
|
||||
|
||||
if interfaceName == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
return configureLinux(interfaceName, ip, ipNet)
|
||||
case "darwin":
|
||||
return configureDarwin(interfaceName, ip, ipNet)
|
||||
case "windows":
|
||||
return configureWindows(interfaceName, ip, ipNet)
|
||||
default:
|
||||
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// waitForInterfaceUp polls the network interface until it's up or times out
|
||||
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
|
||||
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||
deadline := time.Now().Add(timeout)
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
// Check if interface exists and is up
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err == nil {
|
||||
// Check if interface is up
|
||||
if iface.Flags&net.FlagUp != 0 {
|
||||
// Check if it has the expected IP
|
||||
addrs, err := iface.Addrs()
|
||||
if err == nil {
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if ok && ipNet.IP.Equal(expectedIP) {
|
||||
logger.Info("Interface %s is up with correct IP", interfaceName)
|
||||
return nil // Interface is up with correct IP
|
||||
}
|
||||
}
|
||||
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Interface %s exists but is not up yet", interfaceName)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Interface %s not found yet: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
// Wait before next check
|
||||
time.Sleep(pollInterval)
|
||||
}
|
||||
|
||||
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||
}
|
||||
|
||||
func FindUnusedUTUN() (string, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to list interfaces: %v", err)
|
||||
}
|
||||
used := make(map[int]bool)
|
||||
re := regexp.MustCompile(`^utun(\d+)$`)
|
||||
for _, iface := range ifaces {
|
||||
if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 {
|
||||
if num, err := strconv.Atoi(matches[1]); err == nil {
|
||||
used[num] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// Try utun0 up to utun255.
|
||||
for i := 0; i < 256; i++ {
|
||||
if !used[i] {
|
||||
return fmt.Sprintf("utun%d", i), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no unused utun interface found")
|
||||
}
|
||||
|
||||
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
logger.Info("Configuring darwin interface: %s", interfaceName)
|
||||
|
||||
prefix, _ := ipNet.Mask.Size()
|
||||
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
|
||||
|
||||
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
cmd = exec.Command("ifconfig", interfaceName, "up")
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err = cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
// Get the interface
|
||||
link, err := netlink.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
// Create the IP address attributes
|
||||
addr := &netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: ip,
|
||||
Mask: ipNet.Mask,
|
||||
},
|
||||
}
|
||||
|
||||
// Add the IP address to the interface
|
||||
if err := netlink.AddrAdd(link, addr); err != nil {
|
||||
return fmt.Errorf("failed to add IP address: %v", err)
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
if err := netlink.LinkSetUp(link); err != nil {
|
||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
12
network/interface_notwindows.go
Normal file
12
network/interface_notwindows.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
return fmt.Errorf("configureWindows called on non-Windows platform")
|
||||
}
|
||||
63
network/interface_windows.go
Normal file
63
network/interface_windows.go
Normal file
@@ -0,0 +1,63 @@
|
||||
//go:build windows
|
||||
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
logger.Info("Configuring Windows interface: %s", interfaceName)
|
||||
|
||||
// Get the LUID for the interface
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
// Create the IP address prefix
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
|
||||
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||
var addr netip.Addr
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
// IPv4 address
|
||||
addr, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
// IPv6 address
|
||||
addr, _ = netip.AddrFromSlice(ip)
|
||||
}
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("failed to convert IP address")
|
||||
}
|
||||
prefix := netip.PrefixFrom(addr, maskBits)
|
||||
|
||||
// Add the IP address to the interface
|
||||
logger.Info("Adding IP address %s to interface %s", prefix.String(), interfaceName)
|
||||
err = luid.AddIPAddress(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add IP address: %v", err)
|
||||
}
|
||||
|
||||
// This was required when we were using the subprocess "netsh" command to bring up the interface.
|
||||
// With the winipcfg library, the interface should already be up after adding the IP so we dont
|
||||
// need this step anymore as far as I can tell.
|
||||
|
||||
// // Wait for the interface to be up and have the correct IP
|
||||
// err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("interface did not come up within timeout: %v", err)
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/net/bpf"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
udpProtocol = 17
|
||||
// EmptyUDPSize is the size of an empty UDP packet
|
||||
EmptyUDPSize = 28
|
||||
timeout = time.Second * 10
|
||||
)
|
||||
|
||||
// Server stores data relating to the server
|
||||
type Server struct {
|
||||
Hostname string
|
||||
Addr *net.IPAddr
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// PeerNet stores data about a peer's endpoint
|
||||
type PeerNet struct {
|
||||
Resolved bool
|
||||
IP net.IP
|
||||
Port uint16
|
||||
NewtID string
|
||||
}
|
||||
|
||||
// GetClientIP gets source ip address that will be used when sending data to dstIP
|
||||
func GetClientIP(dstIP net.IP) net.IP {
|
||||
routes, err := netlink.RouteGet(dstIP)
|
||||
if err != nil {
|
||||
log.Fatalln("Error getting route:", err)
|
||||
}
|
||||
return routes[0].Src
|
||||
}
|
||||
|
||||
// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr
|
||||
func HostToAddr(hostStr string) *net.IPAddr {
|
||||
remoteAddrs, err := net.LookupHost(hostStr)
|
||||
if err != nil {
|
||||
log.Fatalln("Error parsing remote address:", err)
|
||||
}
|
||||
|
||||
for _, addrStr := range remoteAddrs {
|
||||
if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil {
|
||||
return remoteAddr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering
|
||||
func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn {
|
||||
packetConn, err := net.ListenPacket("ip4:udp", client.IP.String())
|
||||
if err != nil {
|
||||
log.Fatalln("Error creating packetConn:", err)
|
||||
}
|
||||
|
||||
rawConn, err := ipv4.NewRawConn(packetConn)
|
||||
if err != nil {
|
||||
log.Fatalln("Error creating rawConn:", err)
|
||||
}
|
||||
|
||||
ApplyBPF(rawConn, server, client)
|
||||
|
||||
return rawConn
|
||||
}
|
||||
|
||||
// ApplyBPF constructs a BPF program and applies it to the RawConn
|
||||
func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) {
|
||||
const ipv4HeaderLen = 20
|
||||
const srcIPOffset = 12
|
||||
const srcPortOffset = ipv4HeaderLen + 0
|
||||
const dstPortOffset = ipv4HeaderLen + 2
|
||||
|
||||
ipArr := []byte(server.Addr.IP.To4())
|
||||
ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3])
|
||||
|
||||
bpfRaw, err := bpf.Assemble([]bpf.Instruction{
|
||||
bpf.LoadAbsolute{Off: srcIPOffset, Size: 4},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0},
|
||||
|
||||
bpf.LoadAbsolute{Off: srcPortOffset, Size: 2},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0},
|
||||
|
||||
bpf.LoadAbsolute{Off: dstPortOffset, Size: 2},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0},
|
||||
|
||||
bpf.RetConstant{Val: 1<<(8*4) - 1},
|
||||
bpf.RetConstant{Val: 0},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatalln("Error assembling BPF:", err)
|
||||
}
|
||||
|
||||
err = rawConn.SetBPF(bpfRaw)
|
||||
if err != nil {
|
||||
log.Fatalln("Error setting BPF:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// MakePacket constructs a request packet to send to the server
|
||||
func MakePacket(payload []byte, server *Server, client *PeerNet) []byte {
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
|
||||
opts := gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}
|
||||
|
||||
ipHeader := layers.IPv4{
|
||||
SrcIP: client.IP,
|
||||
DstIP: server.Addr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
|
||||
udpHeader := layers.UDP{
|
||||
SrcPort: layers.UDPPort(client.Port),
|
||||
DstPort: layers.UDPPort(server.Port),
|
||||
}
|
||||
|
||||
payloadLayer := gopacket.Payload(payload)
|
||||
|
||||
udpHeader.SetNetworkLayerForChecksum(&ipHeader)
|
||||
|
||||
gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// SendPacket sends packet to the Server
|
||||
func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
|
||||
fullPacket := MakePacket(packet, server, client)
|
||||
_, err := conn.WriteToIP(fullPacket, server.Addr)
|
||||
return err
|
||||
}
|
||||
|
||||
// SendDataPacket sends a JSON payload to the Server
|
||||
func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
return SendPacket(jsonData, conn, server, client)
|
||||
}
|
||||
|
||||
// RecvPacket receives a UDP packet from server
|
||||
func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) {
|
||||
err := conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
response := make([]byte, 4096)
|
||||
n, err := conn.Read(response)
|
||||
if err != nil {
|
||||
return nil, n, err
|
||||
}
|
||||
return response, n, nil
|
||||
}
|
||||
|
||||
// RecvDataPacket receives and unmarshals a JSON packet from server
|
||||
func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) {
|
||||
response, n, err := RecvPacket(conn, server, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract payload from UDP packet
|
||||
payload := response[EmptyUDPSize:n]
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// ParseResponse takes a response packet and parses it into an IP and port
|
||||
func ParseResponse(response []byte) (net.IP, uint16) {
|
||||
ip := net.IP(response[:4])
|
||||
port := binary.BigEndian.Uint16(response[4:6])
|
||||
return ip, port
|
||||
}
|
||||
282
network/route.go
Normal file
282
network/route.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
|
||||
if gateway != "" {
|
||||
// Route with specific gateway
|
||||
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway)
|
||||
} else if interfaceName != "" {
|
||||
// Route via interface
|
||||
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName)
|
||||
} else {
|
||||
return fmt.Errorf("either gateway or interface must be specified")
|
||||
}
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DarwinRemoveRoute(destination string) error {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination)
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse destination CIDR
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Create route
|
||||
route := &netlink.Route{
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if gateway != "" {
|
||||
// Route with specific gateway
|
||||
gw := net.ParseIP(gateway)
|
||||
if gw == nil {
|
||||
return fmt.Errorf("invalid gateway address: %s", gateway)
|
||||
}
|
||||
route.Gw = gw
|
||||
logger.Info("Adding route to %s via gateway %s", destination, gateway)
|
||||
} else if interfaceName != "" {
|
||||
// Route via interface
|
||||
link, err := netlink.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||
}
|
||||
route.LinkIndex = link.Attrs().Index
|
||||
logger.Info("Adding route to %s via interface %s", destination, interfaceName)
|
||||
} else {
|
||||
return fmt.Errorf("either gateway or interface must be specified")
|
||||
}
|
||||
|
||||
// Add the route
|
||||
if err := netlink.RouteAdd(route); err != nil {
|
||||
return fmt.Errorf("failed to add route: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func LinuxRemoveRoute(destination string) error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse destination CIDR
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Create route to delete
|
||||
route := &netlink.Route{
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
logger.Info("Removing route to %s", destination)
|
||||
|
||||
// Delete the route
|
||||
if err := netlink.RouteDel(route); err != nil {
|
||||
return fmt.Errorf("failed to delete route: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRouteForServerIP adds an OS-specific route for the server IP
|
||||
func AddRouteForServerIP(serverIP, interfaceName string) error {
|
||||
if err := AddRouteForNetworkConfig(serverIP); err != nil {
|
||||
return err
|
||||
}
|
||||
if interfaceName == "" {
|
||||
return nil
|
||||
}
|
||||
if runtime.GOOS == "darwin" {
|
||||
return DarwinAddRoute(serverIP, "", interfaceName)
|
||||
}
|
||||
// else if runtime.GOOS == "windows" {
|
||||
// return WindowsAddRoute(serverIP, "", interfaceName)
|
||||
// } else if runtime.GOOS == "linux" {
|
||||
// return LinuxAddRoute(serverIP, "", interfaceName)
|
||||
// }
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRouteForServerIP removes an OS-specific route for the server IP
|
||||
func RemoveRouteForServerIP(serverIP string, interfaceName string) error {
|
||||
if err := RemoveRouteForNetworkConfig(serverIP); err != nil {
|
||||
return err
|
||||
}
|
||||
if interfaceName == "" {
|
||||
return nil
|
||||
}
|
||||
if runtime.GOOS == "darwin" {
|
||||
return DarwinRemoveRoute(serverIP)
|
||||
}
|
||||
// else if runtime.GOOS == "windows" {
|
||||
// return WindowsRemoveRoute(serverIP)
|
||||
// } else if runtime.GOOS == "linux" {
|
||||
// return LinuxRemoveRoute(serverIP)
|
||||
// }
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddRouteForNetworkConfig(destination string) error {
|
||||
// Parse the subnet to extract IP and mask
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse subnet %s: %v", destination, err)
|
||||
}
|
||||
|
||||
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||
mask := net.IP(ipNet.Mask).String()
|
||||
destinationAddress := ipNet.IP.String()
|
||||
|
||||
AddIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveRouteForNetworkConfig(destination string) error {
|
||||
// Parse the subnet to extract IP and mask
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse subnet %s: %v", destination, err)
|
||||
}
|
||||
|
||||
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||
mask := net.IP(ipNet.Mask).String()
|
||||
destinationAddress := ipNet.IP.String()
|
||||
|
||||
RemoveIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRoutes adds routes for each subnet in RemoteSubnets
|
||||
func AddRoutes(remoteSubnets []string, interfaceName string) error {
|
||||
if len(remoteSubnets) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add routes for each subnet
|
||||
for _, subnet := range remoteSubnets {
|
||||
subnet = strings.TrimSpace(subnet)
|
||||
if subnet == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := AddRouteForNetworkConfig(subnet); err != nil {
|
||||
logger.Error("Failed to add network config for subnet %s: %v", subnet, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add route based on operating system
|
||||
if interfaceName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
|
||||
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
} else if runtime.GOOS == "windows" {
|
||||
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
|
||||
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
} else if runtime.GOOS == "linux" {
|
||||
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
|
||||
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Added route for remote subnet: %s", subnet)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets
|
||||
func RemoveRoutes(remoteSubnets []string) error {
|
||||
if len(remoteSubnets) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove routes for each subnet
|
||||
for _, subnet := range remoteSubnets {
|
||||
subnet = strings.TrimSpace(subnet)
|
||||
if subnet == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := RemoveRouteForNetworkConfig(subnet); err != nil {
|
||||
logger.Error("Failed to remove network config for subnet %s: %v", subnet, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove route based on operating system
|
||||
if runtime.GOOS == "darwin" {
|
||||
if err := DarwinRemoveRoute(subnet); err != nil {
|
||||
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
} else if runtime.GOOS == "windows" {
|
||||
if err := WindowsRemoveRoute(subnet); err != nil {
|
||||
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
} else if runtime.GOOS == "linux" {
|
||||
if err := LinuxRemoveRoute(subnet); err != nil {
|
||||
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Removed route for remote subnet: %s", subnet)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
11
network/route_notwindows.go
Normal file
11
network/route_notwindows.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !windows
|
||||
|
||||
package network
|
||||
|
||||
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func WindowsRemoveRoute(destination string) error {
|
||||
return nil
|
||||
}
|
||||
148
network/route_windows.go
Normal file
148
network/route_windows.go
Normal file
@@ -0,0 +1,148 @@
|
||||
//go:build windows
|
||||
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse destination CIDR
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Convert to netip.Prefix
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
|
||||
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||
var addr netip.Addr
|
||||
if ip4 := ipNet.IP.To4(); ip4 != nil {
|
||||
// IPv4 address
|
||||
addr, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
// IPv6 address
|
||||
addr, _ = netip.AddrFromSlice(ipNet.IP)
|
||||
}
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("failed to convert destination IP")
|
||||
}
|
||||
prefix := netip.PrefixFrom(addr, maskBits)
|
||||
|
||||
var luid winipcfg.LUID
|
||||
var nextHop netip.Addr
|
||||
|
||||
if interfaceName != "" {
|
||||
// Get the interface LUID - needed for both gateway and interface-only routes
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
luid, err = winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err)
|
||||
}
|
||||
}
|
||||
|
||||
if gateway != "" {
|
||||
// Route with specific gateway
|
||||
gwIP := net.ParseIP(gateway)
|
||||
if gwIP == nil {
|
||||
return fmt.Errorf("invalid gateway address: %s", gateway)
|
||||
}
|
||||
// Convert to correct IP version
|
||||
if ip4 := gwIP.To4(); ip4 != nil {
|
||||
nextHop, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
nextHop, _ = netip.AddrFromSlice(gwIP)
|
||||
}
|
||||
if !nextHop.IsValid() {
|
||||
return fmt.Errorf("failed to convert gateway IP")
|
||||
}
|
||||
logger.Info("Adding route to %s via gateway %s on interface %s", destination, gateway, interfaceName)
|
||||
} else if interfaceName != "" {
|
||||
// Route via interface only
|
||||
if addr.Is4() {
|
||||
nextHop = netip.IPv4Unspecified()
|
||||
} else {
|
||||
nextHop = netip.IPv6Unspecified()
|
||||
}
|
||||
logger.Info("Adding route to %s via interface %s", destination, interfaceName)
|
||||
} else {
|
||||
return fmt.Errorf("either gateway or interface must be specified")
|
||||
}
|
||||
|
||||
// Add the route using winipcfg
|
||||
err = luid.AddRoute(prefix, nextHop, 1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add route: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func WindowsRemoveRoute(destination string) error {
|
||||
// Parse destination CIDR
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Convert to netip.Prefix
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
|
||||
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||
var addr netip.Addr
|
||||
if ip4 := ipNet.IP.To4(); ip4 != nil {
|
||||
// IPv4 address
|
||||
addr, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
// IPv6 address
|
||||
addr, _ = netip.AddrFromSlice(ipNet.IP)
|
||||
}
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("failed to convert destination IP")
|
||||
}
|
||||
prefix := netip.PrefixFrom(addr, maskBits)
|
||||
|
||||
// Get all routes and find the one to delete
|
||||
// We need to get the LUID from the existing route
|
||||
var family winipcfg.AddressFamily
|
||||
if addr.Is4() {
|
||||
family = 2 // AF_INET
|
||||
} else {
|
||||
family = 23 // AF_INET6
|
||||
}
|
||||
|
||||
routes, err := winipcfg.GetIPForwardTable2(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get route table: %v", err)
|
||||
}
|
||||
|
||||
// Find and delete matching route
|
||||
for _, route := range routes {
|
||||
routePrefix := route.DestinationPrefix.Prefix()
|
||||
if routePrefix == prefix {
|
||||
logger.Info("Removing route to %s", destination)
|
||||
err = route.Delete()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete route: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("route to %s not found", destination)
|
||||
}
|
||||
190
network/settings.go
Normal file
190
network/settings.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// NetworkSettings represents the network configuration for the tunnel
|
||||
type NetworkSettings struct {
|
||||
TunnelRemoteAddress string `json:"tunnel_remote_address,omitempty"`
|
||||
MTU *int `json:"mtu,omitempty"`
|
||||
DNSServers []string `json:"dns_servers,omitempty"`
|
||||
IPv4Addresses []string `json:"ipv4_addresses,omitempty"`
|
||||
IPv4SubnetMasks []string `json:"ipv4_subnet_masks,omitempty"`
|
||||
IPv4IncludedRoutes []IPv4Route `json:"ipv4_included_routes,omitempty"`
|
||||
IPv4ExcludedRoutes []IPv4Route `json:"ipv4_excluded_routes,omitempty"`
|
||||
IPv6Addresses []string `json:"ipv6_addresses,omitempty"`
|
||||
IPv6NetworkPrefixes []string `json:"ipv6_network_prefixes,omitempty"`
|
||||
IPv6IncludedRoutes []IPv6Route `json:"ipv6_included_routes,omitempty"`
|
||||
IPv6ExcludedRoutes []IPv6Route `json:"ipv6_excluded_routes,omitempty"`
|
||||
}
|
||||
|
||||
// IPv4Route represents an IPv4 route
|
||||
type IPv4Route struct {
|
||||
DestinationAddress string `json:"destination_address"`
|
||||
SubnetMask string `json:"subnet_mask,omitempty"`
|
||||
GatewayAddress string `json:"gateway_address,omitempty"`
|
||||
IsDefault bool `json:"is_default,omitempty"`
|
||||
}
|
||||
|
||||
// IPv6Route represents an IPv6 route
|
||||
type IPv6Route struct {
|
||||
DestinationAddress string `json:"destination_address"`
|
||||
NetworkPrefixLength int `json:"network_prefix_length,omitempty"`
|
||||
GatewayAddress string `json:"gateway_address,omitempty"`
|
||||
IsDefault bool `json:"is_default,omitempty"`
|
||||
}
|
||||
|
||||
var (
|
||||
networkSettings NetworkSettings
|
||||
networkSettingsMutex sync.RWMutex
|
||||
incrementor int
|
||||
)
|
||||
|
||||
// SetTunnelRemoteAddress sets the tunnel remote address
|
||||
func SetTunnelRemoteAddress(address string) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.TunnelRemoteAddress = address
|
||||
incrementor++
|
||||
logger.Info("Set tunnel remote address: %s", address)
|
||||
}
|
||||
|
||||
// SetMTU sets the MTU value
|
||||
func SetMTU(mtu int) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.MTU = &mtu
|
||||
incrementor++
|
||||
logger.Info("Set MTU: %d", mtu)
|
||||
}
|
||||
|
||||
// SetDNSServers sets the DNS servers
|
||||
func SetDNSServers(servers []string) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.DNSServers = servers
|
||||
incrementor++
|
||||
logger.Info("Set DNS servers: %v", servers)
|
||||
}
|
||||
|
||||
// SetIPv4Settings sets IPv4 addresses and subnet masks
|
||||
func SetIPv4Settings(addresses []string, subnetMasks []string) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv4Addresses = addresses
|
||||
networkSettings.IPv4SubnetMasks = subnetMasks
|
||||
incrementor++
|
||||
logger.Info("Set IPv4 addresses: %v, subnet masks: %v", addresses, subnetMasks)
|
||||
}
|
||||
|
||||
// SetIPv4IncludedRoutes sets the included IPv4 routes
|
||||
func SetIPv4IncludedRoutes(routes []IPv4Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv4IncludedRoutes = routes
|
||||
incrementor++
|
||||
logger.Info("Set IPv4 included routes: %d routes", len(routes))
|
||||
}
|
||||
|
||||
func AddIPv4IncludedRoute(route IPv4Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
|
||||
// make sure it does not already exist
|
||||
for _, r := range networkSettings.IPv4IncludedRoutes {
|
||||
if r == route {
|
||||
logger.Info("IPv4 included route already exists: %+v", route)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
networkSettings.IPv4IncludedRoutes = append(networkSettings.IPv4IncludedRoutes, route)
|
||||
incrementor++
|
||||
logger.Info("Added IPv4 included route: %+v", route)
|
||||
}
|
||||
|
||||
func RemoveIPv4IncludedRoute(route IPv4Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
routes := networkSettings.IPv4IncludedRoutes
|
||||
for i, r := range routes {
|
||||
if r == route {
|
||||
networkSettings.IPv4IncludedRoutes = append(routes[:i], routes[i+1:]...)
|
||||
logger.Info("Removed IPv4 included route: %+v", route)
|
||||
return
|
||||
}
|
||||
}
|
||||
incrementor++
|
||||
logger.Info("IPv4 included route not found for removal: %+v", route)
|
||||
}
|
||||
|
||||
func SetIPv4ExcludedRoutes(routes []IPv4Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv4ExcludedRoutes = routes
|
||||
incrementor++
|
||||
logger.Info("Set IPv4 excluded routes: %d routes", len(routes))
|
||||
}
|
||||
|
||||
// SetIPv6Settings sets IPv6 addresses and network prefixes
|
||||
func SetIPv6Settings(addresses []string, networkPrefixes []string) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv6Addresses = addresses
|
||||
networkSettings.IPv6NetworkPrefixes = networkPrefixes
|
||||
incrementor++
|
||||
logger.Info("Set IPv6 addresses: %v, network prefixes: %v", addresses, networkPrefixes)
|
||||
}
|
||||
|
||||
// SetIPv6IncludedRoutes sets the included IPv6 routes
|
||||
func SetIPv6IncludedRoutes(routes []IPv6Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv6IncludedRoutes = routes
|
||||
incrementor++
|
||||
logger.Info("Set IPv6 included routes: %d routes", len(routes))
|
||||
}
|
||||
|
||||
// SetIPv6ExcludedRoutes sets the excluded IPv6 routes
|
||||
func SetIPv6ExcludedRoutes(routes []IPv6Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv6ExcludedRoutes = routes
|
||||
incrementor++
|
||||
logger.Info("Set IPv6 excluded routes: %d routes", len(routes))
|
||||
}
|
||||
|
||||
// ClearNetworkSettings clears all network settings
|
||||
func ClearNetworkSettings() {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings = NetworkSettings{}
|
||||
incrementor++
|
||||
logger.Info("Cleared all network settings")
|
||||
}
|
||||
|
||||
func GetJSON() (string, error) {
|
||||
networkSettingsMutex.RLock()
|
||||
defer networkSettingsMutex.RUnlock()
|
||||
data, err := json.MarshalIndent(networkSettings, "", " ")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func GetSettings() NetworkSettings {
|
||||
networkSettingsMutex.RLock()
|
||||
defer networkSettingsMutex.RUnlock()
|
||||
return networkSettings
|
||||
}
|
||||
|
||||
func GetIncrementor() int {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
return incrementor
|
||||
}
|
||||
Reference in New Issue
Block a user