mirror of
https://github.com/fosrl/newt.git
synced 2026-03-14 10:55:07 -05:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef03b4566d | ||
|
|
44ca592a5c | ||
|
|
e1edbcea07 | ||
|
|
392e4c83bf | ||
|
|
a85454e770 | ||
|
|
068145c539 | ||
|
|
91a035f4ab |
1038
.github/workflows/cicd.yml
vendored
1038
.github/workflows/cicd.yml
vendored
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,8 @@ RUN go mod download
|
||||
COPY . .
|
||||
|
||||
# Build the application
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /newt
|
||||
ARG VERSION=dev
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X main.newtVersion=${VERSION}" -o /newt
|
||||
|
||||
FROM public.ecr.aws/docker/library/alpine:3.23 AS runner
|
||||
|
||||
|
||||
23
Makefile
23
Makefile
@@ -2,6 +2,9 @@
|
||||
|
||||
all: local
|
||||
|
||||
VERSION ?= dev
|
||||
LDFLAGS = -X main.newtVersion=$(VERSION)
|
||||
|
||||
local:
|
||||
CGO_ENABLED=0 go build -o ./bin/newt
|
||||
|
||||
@@ -40,31 +43,31 @@ go-build-release: \
|
||||
go-build-release-freebsd-arm64
|
||||
|
||||
go-build-release-linux-arm64:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/newt_linux_arm64
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm64
|
||||
|
||||
go-build-release-linux-arm32-v7:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/newt_linux_arm32
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm32
|
||||
|
||||
go-build-release-linux-arm32-v6:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -o bin/newt_linux_arm32v6
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm32v6
|
||||
|
||||
go-build-release-linux-amd64:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/newt_linux_amd64
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_amd64
|
||||
|
||||
go-build-release-linux-riscv64:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -o bin/newt_linux_riscv64
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_riscv64
|
||||
|
||||
go-build-release-darwin-arm64:
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/newt_darwin_arm64
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_darwin_arm64
|
||||
|
||||
go-build-release-darwin-amd64:
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/newt_darwin_amd64
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_darwin_amd64
|
||||
|
||||
go-build-release-windows-amd64:
|
||||
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/newt_windows_amd64.exe
|
||||
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_windows_amd64.exe
|
||||
|
||||
go-build-release-freebsd-amd64:
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -o bin/newt_freebsd_amd64
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_freebsd_amd64
|
||||
|
||||
go-build-release-freebsd-arm64:
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -o bin/newt_freebsd_arm64
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_freebsd_arm64
|
||||
@@ -37,12 +37,11 @@ type WgConfig struct {
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
SourcePrefix string `json:"sourcePrefix"`
|
||||
SourcePrefixes []string `json:"sourcePrefixes"`
|
||||
DestPrefix string `json:"destPrefix"`
|
||||
RewriteTo string `json:"rewriteTo,omitempty"`
|
||||
DisableIcmp bool `json:"disableIcmp,omitempty"`
|
||||
PortRange []PortRange `json:"portRange,omitempty"`
|
||||
SourcePrefix string `json:"sourcePrefix"`
|
||||
DestPrefix string `json:"destPrefix"`
|
||||
RewriteTo string `json:"rewriteTo,omitempty"`
|
||||
DisableIcmp bool `json:"disableIcmp,omitempty"`
|
||||
PortRange []PortRange `json:"portRange,omitempty"`
|
||||
}
|
||||
|
||||
type PortRange struct {
|
||||
@@ -113,6 +112,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
|
||||
return nil, fmt.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
logger.Debug("+++++++++++++++++++++++++++++++= the port is %d", port)
|
||||
|
||||
if port == 0 {
|
||||
// Find an available port
|
||||
portRandom, err := util.FindAvailableUDPPort(49152, 65535)
|
||||
@@ -161,9 +162,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
|
||||
useNativeInterface: useNativeInterface,
|
||||
}
|
||||
|
||||
// Create the holepunch manager with ResolveDomain function
|
||||
// We'll need to pass a domain resolver function
|
||||
service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String())
|
||||
// Create the holepunch manager
|
||||
service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String(), nil)
|
||||
|
||||
// Register websocket handlers
|
||||
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
|
||||
@@ -173,7 +173,6 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
|
||||
wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget)
|
||||
wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget)
|
||||
wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget)
|
||||
wsClient.RegisterHandler("newt/wg/sync", service.handleSyncConfig)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
@@ -279,7 +278,7 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string, rel
|
||||
}
|
||||
|
||||
if relayPort == 0 {
|
||||
relayPort = 21820
|
||||
relayPort = 21820
|
||||
}
|
||||
|
||||
// Convert websocket.ExitNode to holepunch.ExitNode
|
||||
@@ -494,183 +493,6 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
logger.Info("Client connectivity setup. Ready to accept connections from clients!")
|
||||
}
|
||||
|
||||
// SyncConfig represents the configuration sent from server for syncing
|
||||
type SyncConfig struct {
|
||||
Targets []Target `json:"targets"`
|
||||
Peers []Peer `json:"peers"`
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleSyncConfig(msg websocket.WSMessage) {
|
||||
var syncConfig SyncConfig
|
||||
|
||||
logger.Debug("Received sync message: %v", msg)
|
||||
logger.Info("Received sync configuration from remote server")
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling sync data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &syncConfig); err != nil {
|
||||
logger.Error("Error unmarshaling sync data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Sync peers
|
||||
if err := s.syncPeers(syncConfig.Peers); err != nil {
|
||||
logger.Error("Failed to sync peers: %v", err)
|
||||
}
|
||||
|
||||
// Sync targets
|
||||
if err := s.syncTargets(syncConfig.Targets); err != nil {
|
||||
logger.Error("Failed to sync targets: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// syncPeers synchronizes the current peers with the desired state
|
||||
// It removes peers not in the desired list and adds missing ones
|
||||
func (s *WireGuardService) syncPeers(desiredPeers []Peer) error {
|
||||
if s.device == nil {
|
||||
return fmt.Errorf("WireGuard device is not initialized")
|
||||
}
|
||||
|
||||
// Get current peers from the device
|
||||
currentConfig, err := s.device.IpcGet()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current device config: %v", err)
|
||||
}
|
||||
|
||||
// Parse current peer public keys
|
||||
lines := strings.Split(currentConfig, "\n")
|
||||
currentPeerKeys := make(map[string]bool)
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "public_key=") {
|
||||
pubKey := strings.TrimPrefix(line, "public_key=")
|
||||
currentPeerKeys[pubKey] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Build a map of desired peers by their public key (normalized)
|
||||
desiredPeerMap := make(map[string]Peer)
|
||||
for _, peer := range desiredPeers {
|
||||
// Normalize the public key for comparison
|
||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
logger.Warn("Invalid public key in desired peers: %s", peer.PublicKey)
|
||||
continue
|
||||
}
|
||||
normalizedKey := util.FixKey(pubKey.String())
|
||||
desiredPeerMap[normalizedKey] = peer
|
||||
}
|
||||
|
||||
// Remove peers that are not in the desired list
|
||||
for currentKey := range currentPeerKeys {
|
||||
if _, exists := desiredPeerMap[currentKey]; !exists {
|
||||
// Parse the key back to get the original format for removal
|
||||
removeConfig := fmt.Sprintf("public_key=%s\nremove=true", currentKey)
|
||||
if err := s.device.IpcSet(removeConfig); err != nil {
|
||||
logger.Warn("Failed to remove peer %s during sync: %v", currentKey, err)
|
||||
} else {
|
||||
logger.Info("Removed peer %s during sync", currentKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add peers that are missing
|
||||
for normalizedKey, peer := range desiredPeerMap {
|
||||
if _, exists := currentPeerKeys[normalizedKey]; !exists {
|
||||
if err := s.addPeerToDevice(peer); err != nil {
|
||||
logger.Warn("Failed to add peer %s during sync: %v", peer.PublicKey, err)
|
||||
} else {
|
||||
logger.Info("Added peer %s during sync", peer.PublicKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncTargets synchronizes the current targets with the desired state
|
||||
// It removes targets not in the desired list and adds missing ones
|
||||
func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
|
||||
if s.tnet == nil {
|
||||
// Native interface mode - proxy features not available, skip silently
|
||||
logger.Debug("Skipping target sync - using native interface (no proxy support)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get current rules from the proxy handler
|
||||
currentRules := s.tnet.GetProxySubnetRules()
|
||||
|
||||
// Build a map of current rules by source+dest prefix
|
||||
type ruleKey struct {
|
||||
sourcePrefix string
|
||||
destPrefix string
|
||||
}
|
||||
currentRuleMap := make(map[ruleKey]bool)
|
||||
for _, rule := range currentRules {
|
||||
key := ruleKey{
|
||||
sourcePrefix: rule.SourcePrefix.String(),
|
||||
destPrefix: rule.DestPrefix.String(),
|
||||
}
|
||||
currentRuleMap[key] = true
|
||||
}
|
||||
|
||||
// Build a map of desired targets
|
||||
desiredTargetMap := make(map[ruleKey]Target)
|
||||
for _, target := range desiredTargets {
|
||||
key := ruleKey{
|
||||
sourcePrefix: target.SourcePrefix,
|
||||
destPrefix: target.DestPrefix,
|
||||
}
|
||||
desiredTargetMap[key] = target
|
||||
}
|
||||
|
||||
// Remove targets that are not in the desired list
|
||||
for _, rule := range currentRules {
|
||||
key := ruleKey{
|
||||
sourcePrefix: rule.SourcePrefix.String(),
|
||||
destPrefix: rule.DestPrefix.String(),
|
||||
}
|
||||
if _, exists := desiredTargetMap[key]; !exists {
|
||||
s.tnet.RemoveProxySubnetRule(rule.SourcePrefix, rule.DestPrefix)
|
||||
logger.Info("Removed target %s -> %s during sync", rule.SourcePrefix.String(), rule.DestPrefix.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Add targets that are missing
|
||||
for key, target := range desiredTargetMap {
|
||||
if _, exists := currentRuleMap[key]; !exists {
|
||||
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
||||
if err != nil {
|
||||
logger.Warn("Invalid source prefix %s during sync: %v", target.SourcePrefix, err)
|
||||
continue
|
||||
}
|
||||
|
||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
||||
if err != nil {
|
||||
logger.Warn("Invalid dest prefix %s during sync: %v", target.DestPrefix, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var portRanges []netstack2.PortRange
|
||||
for _, pr := range target.PortRange {
|
||||
portRanges = append(portRanges, netstack2.PortRange{
|
||||
Min: pr.Min,
|
||||
Max: pr.Max,
|
||||
Protocol: pr.Protocol,
|
||||
})
|
||||
}
|
||||
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
s.mu.Lock()
|
||||
|
||||
@@ -874,19 +696,6 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveSourcePrefixes returns the effective list of source prefixes for a target,
|
||||
// supporting both the legacy single SourcePrefix field and the new SourcePrefixes array.
|
||||
// If SourcePrefixes is non-empty it takes precedence; otherwise SourcePrefix is used.
|
||||
func resolveSourcePrefixes(target Target) []string {
|
||||
if len(target.SourcePrefixes) > 0 {
|
||||
return target.SourcePrefixes
|
||||
}
|
||||
if target.SourcePrefix != "" {
|
||||
return []string{target.SourcePrefix}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureTargets(targets []Target) error {
|
||||
if s.tnet == nil {
|
||||
// Native interface mode - proxy features not available, skip silently
|
||||
@@ -895,6 +704,11 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
|
||||
}
|
||||
|
||||
for _, target := range targets {
|
||||
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err)
|
||||
}
|
||||
|
||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
|
||||
@@ -909,14 +723,9 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
|
||||
})
|
||||
}
|
||||
|
||||
for _, sp := range resolveSourcePrefixes(target) {
|
||||
sourcePrefix, err := netip.ParsePrefix(sp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1235,7 +1044,7 @@ func (s *WireGuardService) processPeerBandwidth(publicKey string, rxBytes, txByt
|
||||
BytesOut: bytesOutMB,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -1286,6 +1095,12 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
||||
|
||||
// Process all targets
|
||||
for _, target := range targets {
|
||||
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
|
||||
continue
|
||||
}
|
||||
|
||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
||||
@@ -1295,21 +1110,15 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
||||
var portRanges []netstack2.PortRange
|
||||
for _, pr := range target.PortRange {
|
||||
portRanges = append(portRanges, netstack2.PortRange{
|
||||
Min: pr.Min,
|
||||
Max: pr.Max,
|
||||
Protocol: pr.Protocol,
|
||||
Min: pr.Min,
|
||||
Max: pr.Max,
|
||||
Protocol: pr.Protocol,
|
||||
})
|
||||
}
|
||||
|
||||
for _, sp := range resolveSourcePrefixes(target) {
|
||||
sourcePrefix, err := netip.ParsePrefix(sp)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
||||
continue
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1338,21 +1147,21 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) {
|
||||
|
||||
// Process all targets
|
||||
for _, target := range targets {
|
||||
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
|
||||
continue
|
||||
}
|
||||
|
||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sp := range resolveSourcePrefixes(target) {
|
||||
sourcePrefix, err := netip.ParsePrefix(sp)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
||||
continue
|
||||
}
|
||||
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
|
||||
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
|
||||
}
|
||||
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
|
||||
|
||||
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1386,24 +1195,30 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
||||
|
||||
// Process all update requests
|
||||
for _, target := range requests.OldTargets {
|
||||
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
|
||||
continue
|
||||
}
|
||||
|
||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sp := range resolveSourcePrefixes(target) {
|
||||
sourcePrefix, err := netip.ParsePrefix(sp)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
||||
continue
|
||||
}
|
||||
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
|
||||
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
|
||||
}
|
||||
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
|
||||
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
|
||||
}
|
||||
|
||||
for _, target := range requests.NewTargets {
|
||||
// Now add the new target
|
||||
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
|
||||
continue
|
||||
}
|
||||
|
||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
||||
@@ -1413,21 +1228,14 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
||||
var portRanges []netstack2.PortRange
|
||||
for _, pr := range target.PortRange {
|
||||
portRanges = append(portRanges, netstack2.PortRange{
|
||||
Min: pr.Min,
|
||||
Max: pr.Max,
|
||||
Protocol: pr.Protocol,
|
||||
Min: pr.Min,
|
||||
Max: pr.Max,
|
||||
Protocol: pr.Protocol,
|
||||
})
|
||||
}
|
||||
|
||||
for _, sp := range resolveSourcePrefixes(target) {
|
||||
sourcePrefix, err := netip.ParsePrefix(sp)
|
||||
if err != nil {
|
||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
||||
continue
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -521,82 +521,3 @@ func (m *Monitor) DisableTarget(id int) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTargetIDs returns a slice of all current target IDs
|
||||
func (m *Monitor) GetTargetIDs() []int {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
ids := make([]int, 0, len(m.targets))
|
||||
for id := range m.targets {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// SyncTargets synchronizes the current targets to match the desired set.
|
||||
// It removes targets not in the desired set and adds targets that are missing.
|
||||
func (m *Monitor) SyncTargets(desiredConfigs []Config) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
logger.Info("Syncing health check targets: %d desired targets", len(desiredConfigs))
|
||||
|
||||
// Build a set of desired target IDs
|
||||
desiredIDs := make(map[int]Config)
|
||||
for _, config := range desiredConfigs {
|
||||
desiredIDs[config.ID] = config
|
||||
}
|
||||
|
||||
// Find targets to remove (exist but not in desired set)
|
||||
var toRemove []int
|
||||
for id := range m.targets {
|
||||
if _, exists := desiredIDs[id]; !exists {
|
||||
toRemove = append(toRemove, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove targets that are not in the desired set
|
||||
for _, id := range toRemove {
|
||||
logger.Info("Sync: removing health check target %d", id)
|
||||
if target, exists := m.targets[id]; exists {
|
||||
target.cancel()
|
||||
delete(m.targets, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Add or update targets from the desired set
|
||||
var addedCount, updatedCount int
|
||||
for id, config := range desiredIDs {
|
||||
if existing, exists := m.targets[id]; exists {
|
||||
// Target exists - check if config changed and update if needed
|
||||
// For now, we'll replace it to ensure config is up to date
|
||||
logger.Debug("Sync: updating health check target %d", id)
|
||||
existing.cancel()
|
||||
delete(m.targets, id)
|
||||
if err := m.addTargetUnsafe(config); err != nil {
|
||||
logger.Error("Sync: failed to update target %d: %v", id, err)
|
||||
return fmt.Errorf("failed to update target %d: %v", id, err)
|
||||
}
|
||||
updatedCount++
|
||||
} else {
|
||||
// Target doesn't exist - add it
|
||||
logger.Debug("Sync: adding health check target %d", id)
|
||||
if err := m.addTargetUnsafe(config); err != nil {
|
||||
logger.Error("Sync: failed to add target %d: %v", id, err)
|
||||
return fmt.Errorf("failed to add target %d: %v", id, err)
|
||||
}
|
||||
addedCount++
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Sync complete: removed %d, added %d, updated %d targets",
|
||||
len(toRemove), addedCount, updatedCount)
|
||||
|
||||
// Notify callback if any changes were made
|
||||
if (len(toRemove) > 0 || addedCount > 0 || updatedCount > 0) && m.callback != nil {
|
||||
go m.callback(m.getAllTargetsUnsafe())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -27,16 +27,17 @@ type ExitNode struct {
|
||||
|
||||
// Manager handles UDP hole punching operations
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
sharedBind *bind.SharedBind
|
||||
ID string
|
||||
token string
|
||||
publicKey string
|
||||
clientType string
|
||||
exitNodes map[string]ExitNode // key is endpoint
|
||||
updateChan chan struct{} // signals the goroutine to refresh exit nodes
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
sharedBind *bind.SharedBind
|
||||
ID string
|
||||
token string
|
||||
publicKey string
|
||||
clientType string
|
||||
exitNodes map[string]ExitNode // key is endpoint
|
||||
updateChan chan struct{} // signals the goroutine to refresh exit nodes
|
||||
publicDNS []string
|
||||
|
||||
sendHolepunchInterval time.Duration
|
||||
sendHolepunchIntervalMin time.Duration
|
||||
@@ -49,12 +50,13 @@ const defaultSendHolepunchIntervalMax = 60 * time.Second
|
||||
const defaultSendHolepunchIntervalMin = 1 * time.Second
|
||||
|
||||
// NewManager creates a new hole punch manager
|
||||
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager {
|
||||
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string, publicDNS []string) *Manager {
|
||||
return &Manager{
|
||||
sharedBind: sharedBind,
|
||||
ID: ID,
|
||||
clientType: clientType,
|
||||
publicKey: publicKey,
|
||||
publicDNS: publicDNS,
|
||||
exitNodes: make(map[string]ExitNode),
|
||||
sendHolepunchInterval: defaultSendHolepunchIntervalMin,
|
||||
sendHolepunchIntervalMin: defaultSendHolepunchIntervalMin,
|
||||
@@ -281,7 +283,13 @@ func (m *Manager) TriggerHolePunch() error {
|
||||
// Send hole punch to all exit nodes
|
||||
successCount := 0
|
||||
for _, exitNode := range currentExitNodes {
|
||||
host, err := util.ResolveDomain(exitNode.Endpoint)
|
||||
var host string
|
||||
var err error
|
||||
if len(m.publicDNS) > 0 {
|
||||
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
|
||||
} else {
|
||||
host, err = util.ResolveDomain(exitNode.Endpoint)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
@@ -392,7 +400,13 @@ func (m *Manager) runMultipleExitNodes() {
|
||||
|
||||
var resolvedNodes []resolvedExitNode
|
||||
for _, exitNode := range currentExitNodes {
|
||||
host, err := util.ResolveDomain(exitNode.Endpoint)
|
||||
var host string
|
||||
var err error
|
||||
if len(m.publicDNS) > 0 {
|
||||
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
|
||||
} else {
|
||||
host, err = util.ResolveDomain(exitNode.Endpoint)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
|
||||
@@ -49,10 +49,11 @@ type cachedAddr struct {
|
||||
|
||||
// HolepunchTester monitors holepunch connectivity using magic packets
|
||||
type HolepunchTester struct {
|
||||
sharedBind *bind.SharedBind
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
sharedBind *bind.SharedBind
|
||||
publicDNS []string
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
|
||||
// Pending requests waiting for responses (key: echo data as string)
|
||||
pendingRequests sync.Map // map[string]*pendingRequest
|
||||
@@ -84,9 +85,10 @@ type pendingRequest struct {
|
||||
}
|
||||
|
||||
// NewHolepunchTester creates a new holepunch tester using the given SharedBind
|
||||
func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester {
|
||||
func NewHolepunchTester(sharedBind *bind.SharedBind, publicDNS []string) *HolepunchTester {
|
||||
return &HolepunchTester{
|
||||
sharedBind: sharedBind,
|
||||
publicDNS: publicDNS,
|
||||
addrCache: make(map[string]*cachedAddr),
|
||||
addrCacheTTL: 5 * time.Minute, // Cache addresses for 5 minutes
|
||||
}
|
||||
@@ -169,7 +171,13 @@ func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error)
|
||||
}
|
||||
|
||||
// Resolve the endpoint
|
||||
host, err := util.ResolveDomain(endpoint)
|
||||
var host string
|
||||
var err error
|
||||
if len(t.publicDNS) > 0 {
|
||||
host, err = util.ResolveDomainUpstream(endpoint, t.publicDNS)
|
||||
} else {
|
||||
host, err = util.ResolveDomain(endpoint)
|
||||
}
|
||||
if err != nil {
|
||||
host = endpoint
|
||||
}
|
||||
|
||||
153
main.go
153
main.go
@@ -565,7 +565,7 @@ func runNewtMain(ctx context.Context) {
|
||||
id, // CLI arg takes precedence
|
||||
secret, // CLI arg takes precedence
|
||||
endpoint,
|
||||
30*time.Second,
|
||||
pingInterval,
|
||||
pingTimeout,
|
||||
opt,
|
||||
)
|
||||
@@ -959,7 +959,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
"publicKey": publicKey.String(),
|
||||
"pingResults": pingResults,
|
||||
"newtVersion": newtVersion,
|
||||
}, 2*time.Second)
|
||||
}, 1*time.Second)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -1062,7 +1062,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
"publicKey": publicKey.String(),
|
||||
"pingResults": pingResults,
|
||||
"newtVersion": newtVersion,
|
||||
}, 2*time.Second)
|
||||
}, 1*time.Second)
|
||||
|
||||
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
|
||||
})
|
||||
@@ -1167,153 +1167,6 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
}
|
||||
})
|
||||
|
||||
// Register handler for syncing targets (TCP, UDP, and health checks)
|
||||
client.RegisterHandler("newt/sync", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received sync message")
|
||||
|
||||
// if there is no wgData or pm, we can't sync targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info(msgNoTunnelOrProxy)
|
||||
return
|
||||
}
|
||||
|
||||
// Define the sync data structure
|
||||
type SyncData struct {
|
||||
Targets TargetsByType `json:"targets"`
|
||||
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
|
||||
}
|
||||
|
||||
var syncData SyncData
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling sync data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &syncData); err != nil {
|
||||
logger.Error("Error unmarshaling sync data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("Sync data received: TCP targets=%d, UDP targets=%d, health check targets=%d",
|
||||
len(syncData.Targets.TCP), len(syncData.Targets.UDP), len(syncData.HealthCheckTargets))
|
||||
|
||||
//TODO: TEST AND IMPLEMENT THIS
|
||||
|
||||
// // Build sets of desired targets (port -> target string)
|
||||
// desiredTCP := make(map[int]string)
|
||||
// for _, t := range syncData.Targets.TCP {
|
||||
// parts := strings.Split(t, ":")
|
||||
// if len(parts) != 3 {
|
||||
// logger.Warn("Invalid TCP target format: %s", t)
|
||||
// continue
|
||||
// }
|
||||
// port := 0
|
||||
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
|
||||
// logger.Warn("Invalid port in TCP target: %s", parts[0])
|
||||
// continue
|
||||
// }
|
||||
// desiredTCP[port] = parts[1] + ":" + parts[2]
|
||||
// }
|
||||
|
||||
// desiredUDP := make(map[int]string)
|
||||
// for _, t := range syncData.Targets.UDP {
|
||||
// parts := strings.Split(t, ":")
|
||||
// if len(parts) != 3 {
|
||||
// logger.Warn("Invalid UDP target format: %s", t)
|
||||
// continue
|
||||
// }
|
||||
// port := 0
|
||||
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
|
||||
// logger.Warn("Invalid port in UDP target: %s", parts[0])
|
||||
// continue
|
||||
// }
|
||||
// desiredUDP[port] = parts[1] + ":" + parts[2]
|
||||
// }
|
||||
|
||||
// // Get current targets from proxy manager
|
||||
// currentTCP, currentUDP := pm.GetTargets()
|
||||
|
||||
// // Sync TCP targets
|
||||
// // Remove TCP targets not in desired set
|
||||
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
|
||||
// for port := range tcpForIP {
|
||||
// if _, exists := desiredTCP[port]; !exists {
|
||||
// logger.Info("Sync: removing TCP target on port %d", port)
|
||||
// targetStr := fmt.Sprintf("%d:%s", port, tcpForIP[port])
|
||||
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Add TCP targets that are missing
|
||||
// for port, target := range desiredTCP {
|
||||
// needsAdd := true
|
||||
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
|
||||
// if currentTarget, exists := tcpForIP[port]; exists {
|
||||
// // Check if target address changed
|
||||
// if currentTarget == target {
|
||||
// needsAdd = false
|
||||
// } else {
|
||||
// // Target changed, remove old one first
|
||||
// logger.Info("Sync: updating TCP target on port %d", port)
|
||||
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
|
||||
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if needsAdd {
|
||||
// logger.Info("Sync: adding TCP target on port %d -> %s", port, target)
|
||||
// targetStr := fmt.Sprintf("%d:%s", port, target)
|
||||
// updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Sync UDP targets
|
||||
// // Remove UDP targets not in desired set
|
||||
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
|
||||
// for port := range udpForIP {
|
||||
// if _, exists := desiredUDP[port]; !exists {
|
||||
// logger.Info("Sync: removing UDP target on port %d", port)
|
||||
// targetStr := fmt.Sprintf("%d:%s", port, udpForIP[port])
|
||||
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Add UDP targets that are missing
|
||||
// for port, target := range desiredUDP {
|
||||
// needsAdd := true
|
||||
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
|
||||
// if currentTarget, exists := udpForIP[port]; exists {
|
||||
// // Check if target address changed
|
||||
// if currentTarget == target {
|
||||
// needsAdd = false
|
||||
// } else {
|
||||
// // Target changed, remove old one first
|
||||
// logger.Info("Sync: updating UDP target on port %d", port)
|
||||
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
|
||||
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if needsAdd {
|
||||
// logger.Info("Sync: adding UDP target on port %d -> %s", port, target)
|
||||
// targetStr := fmt.Sprintf("%d:%s", port, target)
|
||||
// updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Sync health check targets
|
||||
// if err := healthMonitor.SyncTargets(syncData.HealthCheckTargets); err != nil {
|
||||
// logger.Error("Failed to sync health check targets: %v", err)
|
||||
// } else {
|
||||
// logger.Info("Successfully synced health check targets")
|
||||
// }
|
||||
|
||||
logger.Info("Sync complete")
|
||||
})
|
||||
|
||||
// Register handler for Docker socket check
|
||||
client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received Docker socket check request")
|
||||
|
||||
@@ -48,23 +48,6 @@ type SubnetRule struct {
|
||||
PortRanges []PortRange // empty slice means all ports allowed
|
||||
}
|
||||
|
||||
// GetAllRules returns a copy of all subnet rules
|
||||
func (sl *SubnetLookup) GetAllRules() []SubnetRule {
|
||||
sl.mu.RLock()
|
||||
defer sl.mu.RUnlock()
|
||||
|
||||
var rules []SubnetRule
|
||||
for _, destTriePtr := range sl.sourceTrie.All() {
|
||||
if destTriePtr == nil {
|
||||
continue
|
||||
}
|
||||
for _, rule := range destTriePtr.rules {
|
||||
rules = append(rules, *rule)
|
||||
}
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
// connKey uniquely identifies a connection for NAT tracking
|
||||
type connKey struct {
|
||||
srcIP string
|
||||
@@ -217,14 +200,6 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
|
||||
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
|
||||
}
|
||||
|
||||
// GetAllRules returns all subnet rules from the proxy handler
|
||||
func (p *ProxyHandler) GetAllRules() []SubnetRule {
|
||||
if p == nil || !p.enabled {
|
||||
return nil
|
||||
}
|
||||
return p.subnetLookup.GetAllRules()
|
||||
}
|
||||
|
||||
// LookupDestinationRewrite looks up the rewritten destination for a connection
|
||||
// This is used by TCP/UDP handlers to find the actual target address
|
||||
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
|
||||
|
||||
@@ -369,15 +369,6 @@ func (net *Net) RemoveProxySubnetRule(sourcePrefix, destPrefix netip.Prefix) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetProxySubnetRules returns all subnet rules from the proxy handler
|
||||
func (net *Net) GetProxySubnetRules() []SubnetRule {
|
||||
tun := (*netTun)(net)
|
||||
if tun.proxyHandler != nil {
|
||||
return tun.proxyHandler.GetAllRules()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProxyHandler returns the proxy handler (for advanced use cases)
|
||||
// Returns nil if proxy is not enabled
|
||||
func (net *Net) GetProxyHandler() *ProxyHandler {
|
||||
|
||||
@@ -736,28 +736,3 @@ func (pm *ProxyManager) PrintTargets() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetTargets returns a copy of the current TCP and UDP targets
|
||||
// Returns map[listenIP]map[port]targetAddress for both TCP and UDP
|
||||
func (pm *ProxyManager) GetTargets() (tcpTargets map[string]map[int]string, udpTargets map[string]map[int]string) {
|
||||
pm.mutex.RLock()
|
||||
defer pm.mutex.RUnlock()
|
||||
|
||||
tcpTargets = make(map[string]map[int]string)
|
||||
for listenIP, targets := range pm.tcpTargets {
|
||||
tcpTargets[listenIP] = make(map[int]string)
|
||||
for port, targetAddr := range targets {
|
||||
tcpTargets[listenIP][port] = targetAddr
|
||||
}
|
||||
}
|
||||
|
||||
udpTargets = make(map[string]map[int]string)
|
||||
for listenIP, targets := range pm.udpTargets {
|
||||
udpTargets[listenIP] = make(map[int]string)
|
||||
for port, targetAddr := range targets {
|
||||
udpTargets[listenIP][port] = targetAddr
|
||||
}
|
||||
}
|
||||
|
||||
return tcpTargets, udpTargets
|
||||
}
|
||||
|
||||
94
util/util.go
94
util/util.go
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
@@ -14,6 +15,99 @@ import (
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
func ResolveDomainUpstream(domain string, publicDNS []string) (string, error) {
|
||||
// trim whitespace
|
||||
domain = strings.TrimSpace(domain)
|
||||
|
||||
// Remove any protocol prefix if present (do this first, before splitting host/port)
|
||||
domain = strings.TrimPrefix(domain, "http://")
|
||||
domain = strings.TrimPrefix(domain, "https://")
|
||||
|
||||
// if there are any trailing slashes, remove them
|
||||
domain = strings.TrimSuffix(domain, "/")
|
||||
|
||||
// Check if there's a port in the domain
|
||||
host, port, err := net.SplitHostPort(domain)
|
||||
if err != nil {
|
||||
// No port found, use the domain as is
|
||||
host = domain
|
||||
port = ""
|
||||
}
|
||||
|
||||
// Check if host is already an IP address (IPv4 or IPv6)
|
||||
// For IPv6, the host from SplitHostPort will already have brackets stripped
|
||||
// but if there was no port, we need to handle bracketed IPv6 addresses
|
||||
cleanHost := strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
|
||||
if ip := net.ParseIP(cleanHost); ip != nil {
|
||||
// It's already an IP address, no need to resolve
|
||||
ipAddr := ip.String()
|
||||
if port != "" {
|
||||
return net.JoinHostPort(ipAddr, port), nil
|
||||
}
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
// Lookup IP addresses using the upstream DNS servers if provided
|
||||
var ips []net.IP
|
||||
if len(publicDNS) > 0 {
|
||||
var lastErr error
|
||||
for _, server := range publicDNS {
|
||||
// Ensure the upstream DNS address has a port
|
||||
dnsAddr := server
|
||||
if _, _, err := net.SplitHostPort(dnsAddr); err != nil {
|
||||
// No port specified, default to 53
|
||||
dnsAddr = net.JoinHostPort(server, "53")
|
||||
}
|
||||
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{}
|
||||
return d.DialContext(ctx, "udp", dnsAddr)
|
||||
},
|
||||
}
|
||||
ips, lastErr = resolver.LookupIP(context.Background(), "ip", host)
|
||||
if lastErr == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return "", fmt.Errorf("DNS lookup failed using all upstream servers: %v", lastErr)
|
||||
}
|
||||
} else {
|
||||
ips, err = net.LookupIP(host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||
}
|
||||
|
||||
// Get the first IPv4 address if available
|
||||
var ipAddr string
|
||||
for _, ip := range ips {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
ipAddr = ipv4.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no IPv4 found, use the first IP (might be IPv6)
|
||||
if ipAddr == "" {
|
||||
ipAddr = ips[0].String()
|
||||
}
|
||||
|
||||
// Add port back if it existed
|
||||
if port != "" {
|
||||
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||
}
|
||||
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
|
||||
func ResolveDomain(domain string) (string, error) {
|
||||
// trim whitespace
|
||||
domain = strings.TrimSpace(domain)
|
||||
|
||||
@@ -47,11 +47,6 @@ type Client struct {
|
||||
metricsCtx context.Context
|
||||
configNeedsSave bool // Flag to track if config needs to be saved
|
||||
serverVersion string
|
||||
configVersion int64 // Latest config version received from server
|
||||
configVersionMux sync.RWMutex
|
||||
processingMessage bool // Flag to track if a message is currently being processed
|
||||
processingMux sync.RWMutex // Protects processingMessage
|
||||
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
|
||||
}
|
||||
|
||||
type ClientOption func(*Client)
|
||||
@@ -159,20 +154,6 @@ func (c *Client) GetServerVersion() string {
|
||||
return c.serverVersion
|
||||
}
|
||||
|
||||
// GetConfigVersion returns the latest config version received from server
|
||||
func (c *Client) GetConfigVersion() int64 {
|
||||
c.configVersionMux.RLock()
|
||||
defer c.configVersionMux.RUnlock()
|
||||
return c.configVersion
|
||||
}
|
||||
|
||||
// setConfigVersion updates the config version
|
||||
func (c *Client) setConfigVersion(version int64) {
|
||||
c.configVersionMux.Lock()
|
||||
defer c.configVersionMux.Unlock()
|
||||
c.configVersion = version
|
||||
}
|
||||
|
||||
// Connect establishes the WebSocket connection
|
||||
func (c *Client) Connect() error {
|
||||
go c.connectWithRetry()
|
||||
@@ -672,33 +653,12 @@ func (c *Client) pingMonitor() {
|
||||
if c.conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip ping if a message is currently being processed
|
||||
c.processingMux.RLock()
|
||||
isProcessing := c.processingMessage
|
||||
c.processingMux.RUnlock()
|
||||
if isProcessing {
|
||||
logger.Debug("Skipping ping, message is being processed")
|
||||
continue
|
||||
}
|
||||
|
||||
c.configVersionMux.RLock()
|
||||
configVersion := c.configVersion
|
||||
c.configVersionMux.RUnlock()
|
||||
|
||||
pingMsg := WSMessage{
|
||||
Type: "newt/ping",
|
||||
Data: map[string]interface{}{},
|
||||
ConfigVersion: configVersion,
|
||||
}
|
||||
|
||||
c.writeMux.Lock()
|
||||
err := c.conn.WriteJSON(pingMsg)
|
||||
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
|
||||
if err == nil {
|
||||
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
|
||||
}
|
||||
c.writeMux.Unlock()
|
||||
|
||||
if err != nil {
|
||||
// Check if we're shutting down before logging error and reconnecting
|
||||
select {
|
||||
@@ -777,24 +737,9 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
|
||||
}
|
||||
}
|
||||
|
||||
// Update config version from incoming message
|
||||
c.setConfigVersion(msg.ConfigVersion)
|
||||
|
||||
c.handlersMux.RLock()
|
||||
if handler, ok := c.handlers[msg.Type]; ok {
|
||||
// Mark that we're processing a message
|
||||
c.processingMux.Lock()
|
||||
c.processingMessage = true
|
||||
c.processingMux.Unlock()
|
||||
c.processingWg.Add(1)
|
||||
|
||||
handler(msg)
|
||||
|
||||
// Mark that we're done processing
|
||||
c.processingWg.Done()
|
||||
c.processingMux.Lock()
|
||||
c.processingMessage = false
|
||||
c.processingMux.Unlock()
|
||||
}
|
||||
c.handlersMux.RUnlock()
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ type TokenResponse struct {
|
||||
}
|
||||
|
||||
type WSMessage struct {
|
||||
Type string `json:"type"`
|
||||
Data interface{} `json:"data"`
|
||||
ConfigVersion int64 `json:"configVersion,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user