package main

import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"net"
	"net/http"
	"net/netip"
	"os"
	"os/signal"
	"path/filepath"
	"strconv"
	"strings"
	"syscall"
	"time"

	"github.com/fosrl/newt/docker"
	"github.com/fosrl/newt/healthcheck"
	"github.com/fosrl/newt/logger"
	"github.com/fosrl/newt/proxy"
	"github.com/fosrl/newt/updates"
	"github.com/fosrl/newt/websocket"

	"github.com/fosrl/newt/internal/state"
	"github.com/fosrl/newt/internal/telemetry"
	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
	"golang.zx2c4.com/wireguard/conn"
	"golang.zx2c4.com/wireguard/device"
	"golang.zx2c4.com/wireguard/tun"
	"golang.zx2c4.com/wireguard/tun/netstack"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
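
// WgData is the payload of a "newt/wg/connect" message: the WireGuard peer
// to dial plus the initial proxy and health-check targets to bring up.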
type WgData struct {
	Endpoint           string               `json:"endpoint"`
	PublicKey          string               `json:"publicKey"`
	ServerIP           string               `json:"serverIP"`
	TunnelIP           string               `json:"tunnelIP"`
	Targets            TargetsByType        `json:"targets"`
	HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
}

type TargetsByType struct {
	UDP []string `json:"udp"`
	TCP []string `json:"tcp"`
}

type TargetData struct {
	Targets []string `json:"targets"`
}

type ExitNodeData struct {
	ExitNodes []ExitNode `json:"exitNodes"`
}

type SSHPublicKeyData struct {
	PublicKey string `json:"publicKey"`
}

// ExitNode represents an exit node with an ID, endpoint, and weight.
type ExitNode struct {
	ID                     int     `json:"exitNodeId"`
	Name                   string  `json:"exitNodeName"`
	Endpoint               string  `json:"endpoint"`
	Weight                 float64 `json:"weight"`
	WasPreviouslyConnected bool    `json:"wasPreviouslyConnected"`
}

type ExitNodePingResult struct {
	ExitNodeID             int     `json:"exitNodeId"`
	LatencyMs              int64   `json:"latencyMs"`
	Weight                 float64 `json:"weight"`
	Error                  string  `json:"error,omitempty"`
	Name                   string  `json:"exitNodeName"`
	Endpoint               string  `json:"endpoint"`
	WasPreviouslyConnected bool    `json:"wasPreviouslyConnected"`
}

type BlueprintResult struct {
	Success bool   `json:"success"`
	Message string `json:"message,omitempty"`
}

// stringSlice is a custom flag type for passing multiple CA files.
type stringSlice []string

func (s *stringSlice) String() string {
	return strings.Join(*s, ",")
}

func (s *stringSlice) Set(value string) error {
	*s = append(*s, value)
	return nil
}

const (
	fmtErrMarshaling        = "Error marshaling data: %v"
	fmtReceivedMsg          = "Received: %+v"
	topicWGRegister         = "newt/wg/register"
	msgNoTunnelOrProxy      = "No tunnel IP or proxy manager available"
	fmtErrParsingTargetData = "Error parsing target data: %v"
)
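
// Runtime configuration. Each value can come from an environment variable or
// a command-line flag; a flag is only registered when the matching variable
// is unset (see main), so the two sources share these globals.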
var (
	endpoint                           string
	id                                 string
	secret                             string
	mtu                                string
	mtuInt                             int
	dns                                string
	privateKey                         wgtypes.Key
	err                                error
	logLevel                           string
	interfaceName                      string
	generateAndSaveKeyTo               string
	keepInterface                      bool
	acceptClients                      bool
	updownScript                       string
	dockerSocket                       string
	dockerEnforceNetworkValidation     string
	dockerEnforceNetworkValidationBool bool
	pingInterval                       time.Duration
	pingTimeout                        time.Duration
	publicKey                          wgtypes.Key
	pingStopChan                       chan struct{}
	stopFunc                           func()
	healthFile                         string
	useNativeInterface                 bool
	authorizedKeysFile                 string
	preferEndpoint                     string
	healthMonitor                      *healthcheck.Monitor
	enforceHealthcheckCert             bool

	// Build/version (can be overridden via -ldflags "-X main.newtVersion=...")
	newtVersion = "version_replaceme"

	// Observability/metrics flags
	metricsEnabled    bool
	otlpEnabled       bool
	adminAddr         string
	region            string
	metricsAsyncBytes bool
	blueprintFile     string
	noCloud           bool

	// New mTLS configuration variables
	tlsClientCert string
	tlsClientKey  string
	tlsClientCAs  []string

	// Legacy PKCS12 support (deprecated)
	tlsPrivateKey string
)

func main() {
	// Prepare context for graceful shutdown and signal handling
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
	endpoint = os.Getenv("PANGOLIN_ENDPOINT")
	id = os.Getenv("NEWT_ID")
	secret = os.Getenv("NEWT_SECRET")
	mtu = os.Getenv("MTU")
	dns = os.Getenv("DNS")
	logLevel = os.Getenv("LOG_LEVEL")
	updownScript = os.Getenv("UPDOWN_SCRIPT")
	interfaceName = os.Getenv("INTERFACE")
	generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")

	// Metrics/observability env mirrors
	metricsEnabledEnv := os.Getenv("NEWT_METRICS_PROMETHEUS_ENABLED")
	otlpEnabledEnv := os.Getenv("NEWT_METRICS_OTLP_ENABLED")
	adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
	regionEnv := os.Getenv("NEWT_REGION")
	asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")

	keepInterfaceEnv := os.Getenv("KEEP_INTERFACE")
	keepInterface = keepInterfaceEnv == "true"
	acceptClientsEnv := os.Getenv("ACCEPT_CLIENTS")
	acceptClients = acceptClientsEnv == "true"
	useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE")
	useNativeInterface = useNativeInterfaceEnv == "true"
	enforceHealthcheckCertEnv := os.Getenv("ENFORCE_HC_CERT")
	enforceHealthcheckCert = enforceHealthcheckCertEnv == "true"
	dockerSocket = os.Getenv("DOCKER_SOCKET")
	pingIntervalStr := os.Getenv("PING_INTERVAL")
	pingTimeoutStr := os.Getenv("PING_TIMEOUT")
	dockerEnforceNetworkValidation = os.Getenv("DOCKER_ENFORCE_NETWORK_VALIDATION")
	healthFile = os.Getenv("HEALTH_FILE")
	// authorizedKeysFile = os.Getenv("AUTHORIZED_KEYS_FILE")
	authorizedKeysFile = ""

	// Read new mTLS environment variables
	tlsClientCert = os.Getenv("TLS_CLIENT_CERT")
	tlsClientKey = os.Getenv("TLS_CLIENT_KEY")
	tlsClientCAsEnv := os.Getenv("TLS_CLIENT_CAS")
	if tlsClientCAsEnv != "" {
		tlsClientCAs = strings.Split(tlsClientCAsEnv, ",")
		// Trim spaces from each CA file path
		for i, ca := range tlsClientCAs {
			tlsClientCAs[i] = strings.TrimSpace(ca)
		}
	}

	// Legacy PKCS12 support (deprecated)
	tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT_PKCS12")
	// Keep backward compatibility with the old environment variable name
	if tlsPrivateKey == "" && tlsClientKey == "" && len(tlsClientCAs) == 0 {
		tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
	}
	blueprintFile = os.Getenv("BLUEPRINT_FILE")
	noCloudEnv := os.Getenv("NO_CLOUD")
	noCloud = noCloudEnv == "true"
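
	// The flags below are registered only when the matching environment
	// variable is unset, so an env value survives flag.Parse without being
	// clobbered by the flag default. Example invocation (hypothetical values):
	//
	//   PANGOLIN_ENDPOINT=https://pangolin.example.com \
	//   NEWT_ID=abc123 NEWT_SECRET=s3cret ./newt -log-level DEBUG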
	if endpoint == "" {
		flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
	}
	if id == "" {
		flag.StringVar(&id, "id", "", "Newt ID")
	}
	if secret == "" {
		flag.StringVar(&secret, "secret", "", "Newt secret")
	}
	if mtu == "" {
		flag.StringVar(&mtu, "mtu", "1280", "MTU to use")
	}
	if dns == "" {
		flag.StringVar(&dns, "dns", "9.9.9.9", "DNS server to use")
	}
	if logLevel == "" {
		flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
	}
	if updownScript == "" {
		flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
	}
	if interfaceName == "" {
		flag.StringVar(&interfaceName, "interface", "newt", "Name of the WireGuard interface")
	}
	if generateAndSaveKeyTo == "" {
		flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
	}
	if keepInterfaceEnv == "" {
		flag.BoolVar(&keepInterface, "keep-interface", false, "Keep the WireGuard interface")
	}
	if useNativeInterfaceEnv == "" {
		flag.BoolVar(&useNativeInterface, "native", false, "Use the native WireGuard interface (Linux only; requires the WireGuard kernel module)")
	}
	if acceptClientsEnv == "" {
		flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface")
	}
	if enforceHealthcheckCertEnv == "" {
		flag.BoolVar(&enforceHealthcheckCert, "enforce-hc-cert", false, "Enforce certificate validation for health checks (default: false, accepts any cert)")
	}
	if dockerSocket == "" {
		flag.StringVar(&dockerSocket, "docker-socket", "", "Path or address to Docker socket (typically unix:///var/run/docker.sock)")
	}
	if pingIntervalStr == "" {
		flag.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)")
	}
	if pingTimeoutStr == "" {
		flag.StringVar(&pingTimeoutStr, "ping-timeout", "5s", "Timeout for each ping (default 5s)")
	}
	// preferEndpoint is only available as a flag
	flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")

	// if authorizedKeysFile == "" {
	// 	flag.StringVar(&authorizedKeysFile, "authorized-keys-file", "~/.ssh/authorized_keys", "Path to authorized keys file (if unset, no keys will be authorized)")
	// }

	// Add new mTLS flags
	if tlsClientCert == "" {
		flag.StringVar(&tlsClientCert, "tls-client-cert-file", "", "Path to client certificate file (PEM/DER format)")
	}
	if tlsClientKey == "" {
		flag.StringVar(&tlsClientKey, "tls-client-key", "", "Path to client private key file (PEM/DER format)")
	}

	// Handle multiple CA files
	var tlsClientCAsFlag stringSlice
	flag.Var(&tlsClientCAsFlag, "tls-client-ca", "Path to CA certificate file for validating remote certificates (can be specified multiple times)")

	// Legacy PKCS12 flag (deprecated)
	if tlsPrivateKey == "" {
		flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate (PKCS12 format) - DEPRECATED: use --tls-client-cert-file and --tls-client-key instead")
	}
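
	// time.ParseDuration below accepts Go duration strings such as "500ms",
	// "3s", or "1m30s", e.g. -ping-interval=10s -ping-timeout=2s.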
	if pingIntervalStr != "" {
		pingInterval, err = time.ParseDuration(pingIntervalStr)
		if err != nil {
			fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr)
			pingInterval = 3 * time.Second
		}
	} else {
		pingInterval = 3 * time.Second
	}

	if pingTimeoutStr != "" {
		pingTimeout, err = time.ParseDuration(pingTimeoutStr)
		if err != nil {
			fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr)
			pingTimeout = 5 * time.Second
		}
	} else {
		pingTimeout = 5 * time.Second
	}

	if dockerEnforceNetworkValidation == "" {
		flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation that containers are on the newt network (true or false)")
	}
	if healthFile == "" {
		flag.StringVar(&healthFile, "health-file", "", "Path to health file (if unset, health file won't be written)")
	}
	if blueprintFile == "" {
		flag.StringVar(&blueprintFile, "blueprint-file", "", "Path to blueprint file (if unset, no blueprint will be applied)")
	}
	if noCloudEnv == "" {
		flag.BoolVar(&noCloud, "no-cloud", false, "Disable cloud failover")
	}

	// Metrics/observability flags (mirror ENV if unset)
	if metricsEnabledEnv == "" {
		flag.BoolVar(&metricsEnabled, "metrics", true, "Enable Prometheus /metrics exporter")
	} else {
		if v, err := strconv.ParseBool(metricsEnabledEnv); err == nil {
			metricsEnabled = v
		} else {
			metricsEnabled = true
		}
	}
	if otlpEnabledEnv == "" {
		flag.BoolVar(&otlpEnabled, "otlp", false, "Enable OTLP exporters (metrics/traces) to OTEL_EXPORTER_OTLP_ENDPOINT")
	} else {
		if v, err := strconv.ParseBool(otlpEnabledEnv); err == nil {
			otlpEnabled = v
		}
	}
	if adminAddrEnv == "" {
		flag.StringVar(&adminAddr, "metrics-admin-addr", "127.0.0.1:2112", "Admin/metrics bind address")
	} else {
		adminAddr = adminAddrEnv
	}
	// Async bytes toggle
	if asyncBytesEnv == "" {
		flag.BoolVar(&metricsAsyncBytes, "metrics-async-bytes", false, "Enable async bytes counting (background flush; lower hot path overhead)")
	} else {
		if v, err := strconv.ParseBool(asyncBytesEnv); err == nil {
			metricsAsyncBytes = v
		}
	}
	// Optional region flag (resource attribute)
	if regionEnv == "" {
		flag.StringVar(&region, "region", "", "Optional region resource attribute (also NEWT_REGION)")
	} else {
		region = regionEnv
	}

	// do a --version check
	version := flag.Bool("version", false, "Print the version")

	flag.Parse()

	// Merge command-line CA flags with environment variable CAs
	if len(tlsClientCAsFlag) > 0 {
		tlsClientCAs = append(tlsClientCAs, tlsClientCAsFlag...)
	}

	logger.Init()
	loggerLevel := parseLogLevel(logLevel)
	logger.GetLogger().SetLevel(loggerLevel)

	// Initialize telemetry after flags are parsed (so flags override env)
	tcfg := telemetry.FromEnv()
	tcfg.PromEnabled = metricsEnabled
	tcfg.OTLPEnabled = otlpEnabled
	if adminAddr != "" {
		tcfg.AdminAddr = adminAddr
	}
	// Resource attributes (if available)
	tcfg.SiteID = id
	tcfg.Region = region
	// Build info
	tcfg.BuildVersion = newtVersion
	tcfg.BuildCommit = os.Getenv("NEWT_COMMIT")

	tel, telErr := telemetry.Init(ctx, tcfg)
	if telErr != nil {
		logger.Warn("Telemetry init failed: %v", telErr)
	}
	if tel != nil {
		// Admin HTTP server (exposes /metrics when the Prometheus exporter is enabled)
		logger.Info("Starting metrics server on %s", tcfg.AdminAddr)
		mux := http.NewServeMux()
		mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) })
		if tel.PrometheusHandler != nil {
			mux.Handle("/metrics", tel.PrometheusHandler)
		}
		admin := &http.Server{
			Addr:              tcfg.AdminAddr,
			Handler:           otelhttp.NewHandler(mux, "newt-admin"),
			ReadTimeout:       5 * time.Second,
			WriteTimeout:      10 * time.Second,
			ReadHeaderTimeout: 5 * time.Second,
			IdleTimeout:       30 * time.Second,
		}
		go func() {
			if err := admin.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
				logger.Warn("admin http error: %v", err)
			}
		}()
		defer func() {
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_ = admin.Shutdown(ctx)
		}()
		defer func() { _ = tel.Shutdown(context.Background()) }()
	}

	if *version {
		fmt.Println("Newt version " + newtVersion)
		os.Exit(0)
	} else {
		logger.Info("Newt version %s", newtVersion)
	}

	if err := updates.CheckForUpdate("fosrl", "newt", newtVersion); err != nil {
		logger.Error("Error checking for updates: %v\n", err)
	}

	// parse the mtu string into an int
	mtuInt, err = strconv.Atoi(mtu)
	if err != nil {
		logger.Fatal("Failed to parse MTU: %v", err)
	}

	// parse whether we want to enforce container network validation
	dockerEnforceNetworkValidationBool, err = strconv.ParseBool(dockerEnforceNetworkValidation)
	if err != nil {
		logger.Info("Docker enforce network validation cannot be parsed. Defaulting to 'false'")
		dockerEnforceNetworkValidationBool = false
	}
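
	// strconv.ParseBool accepts exactly "1", "t", "T", "TRUE", "true", "True",
	// "0", "f", "F", "FALSE", "false", and "False"; anything else takes the
	// error branch above and disables validation.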

	// Validate the TLS configuration
	if err := validateTLSConfig(); err != nil {
		logger.Fatal("TLS configuration error: %v", err)
	}

	// Show a deprecation warning if using PKCS12
	if tlsPrivateKey != "" {
		logger.Warn("Using deprecated PKCS12 format for mTLS. Consider migrating to separate certificate files using --tls-client-cert-file, --tls-client-key, and --tls-client-ca")
	}
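
	// A fresh WireGuard keypair is generated on every start; the public key is
	// sent to the server during registration (see the OnConnect handler below).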
	privateKey, err = wgtypes.GeneratePrivateKey()
	if err != nil {
		logger.Fatal("Failed to generate private key: %v", err)
	}

	// Create client option based on TLS configuration
	var opt websocket.ClientOption
	if tlsClientCert != "" && tlsClientKey != "" {
		// Use new separate certificate configuration
		opt = websocket.WithTLSConfig(websocket.TLSConfig{
			ClientCertFile: tlsClientCert,
			ClientKeyFile:  tlsClientKey,
			CAFiles:        tlsClientCAs,
		})
		logger.Debug("Using separate certificate files for mTLS")
		logger.Debug("Client cert: %s", tlsClientCert)
		logger.Debug("Client key: %s", tlsClientKey)
		logger.Debug("CA files: %v", tlsClientCAs)
	} else if tlsPrivateKey != "" {
		// Use existing PKCS12 configuration for backward compatibility
		opt = websocket.WithTLSConfig(websocket.TLSConfig{
			PKCS12File: tlsPrivateKey,
		})
		logger.Debug("Using PKCS12 file for mTLS: %s", tlsPrivateKey)
	}
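
	// Separate PEM files take precedence over the legacy PKCS12 bundle; if
	// neither is configured, presumably no client certificate is presented.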

	// Create a new client
	client, err := websocket.NewClient(
		"newt",
		id,     // CLI arg takes precedence
		secret, // CLI arg takes precedence
		endpoint,
		pingInterval,
		pingTimeout,
		opt,
	)
	if err != nil {
		logger.Fatal("Failed to create client: %v", err)
	}
	endpoint = client.GetConfig().Endpoint // Update endpoint from config
	id = client.GetConfig().ID             // Update ID from config
	// Update site labels for metrics with the resolved ID
	telemetry.UpdateSiteInfo(id, region)

	// Log env var values if set
	logger.Debug("Endpoint: %v", endpoint)
	logger.Debug("Log Level: %v", logLevel)
	logger.Debug("Docker Network Validation Enabled: %v", dockerEnforceNetworkValidationBool)
	logger.Debug("Health Check Certificate Enforcement: %v", enforceHealthcheckCert)

	// New TLS debug logging
	if tlsClientCert != "" {
		logger.Debug("TLS Client Cert File: %v", tlsClientCert)
	}
	if tlsClientKey != "" {
		logger.Debug("TLS Client Key File: %v", tlsClientKey)
	}
	if len(tlsClientCAs) > 0 {
		logger.Debug("TLS CA Files: %v", tlsClientCAs)
	}
	if tlsPrivateKey != "" {
		logger.Debug("TLS PKCS12 File: %v", tlsPrivateKey)
	}

	if dns != "" {
		logger.Debug("DNS: %v", dns)
	}
	if dockerSocket != "" {
		logger.Debug("Docker Socket: %v", dockerSocket)
	}
	if mtu != "" {
		logger.Debug("MTU: %v", mtu)
	}
	if updownScript != "" {
		logger.Debug("Up Down Script: %v", updownScript)
	}

	// Userspace WireGuard TUN device and network stack (created on connect)
	var tun tun.Device
	var tnet *netstack.Net
	var dev *device.Device
	var pm *proxy.ProxyManager
	var connected bool
	var wgData WgData
	var dockerEventMonitor *docker.EventMonitor

	if acceptClients {
		setupClients(client)
	}

	// Initialize health check monitor with status change callback
	healthMonitor = healthcheck.NewMonitor(func(targets map[int]*healthcheck.Target) {
		logger.Debug("Health check status update for %d targets", len(targets))

		// Send health status update to the server
		healthStatuses := make(map[int]interface{})
		for id, target := range targets {
			healthStatuses[id] = map[string]interface{}{
				"status":     target.Status.String(),
				"lastCheck":  target.LastCheck.Format(time.RFC3339),
				"checkCount": target.CheckCount,
				"lastError":  target.LastError,
				"config":     target.Config,
			}
		}

		// print the status of the targets
		logger.Debug("Health check status: %+v", healthStatuses)

		err := client.SendMessage("newt/healthcheck/status", map[string]interface{}{
			"targets": healthStatuses,
		})
		if err != nil {
			logger.Error("Failed to send health check status update: %v", err)
		}
	}, enforceHealthcheckCert)

	var pingWithRetryStopChan chan struct{}

	closeWgTunnel := func() {
		if pingStopChan != nil {
			// Stop the ping check
			close(pingStopChan)
			pingStopChan = nil
		}

		// Stop proxy manager if running
		if pm != nil {
			pm.Stop()
			pm = nil
		}

		// Close WireGuard device first - this will automatically close the TUN device
		if dev != nil {
			dev.Close()
			dev = nil
		}

		// Clear references but don't manually close since dev.Close() already did it
		if tnet != nil {
			tnet = nil
		}
		if tun != nil {
			tun = nil // Don't call tun.Close() here since dev.Close() already closed it
		}
	}

	// Register handlers for different message types
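	// "newt/wg/connect": the server has picked an exit node and is handing us
	// the WireGuard peer configuration; (re)build the userspace tunnel, the
	// proxy manager, and the health-check targets from the WgData payload.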
	client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
		logger.Debug("Received registration message")
		regResult := "success"
		defer func() {
			telemetry.IncSiteRegistration(ctx, regResult)
		}()
		if stopFunc != nil {
			stopFunc()     // stop the ws from sending more requests
			stopFunc = nil // reset stopFunc to nil to avoid double stopping
		}

		if connected {
			// Mark as disconnected
			closeWgTunnel()
			connected = false
		}

		// print out the data
		logger.Debug("Received registration message data: %+v", msg.Data)

		jsonData, err := json.Marshal(msg.Data)
		if err != nil {
			logger.Info(fmtErrMarshaling, err)
			regResult = "failure"
			return
		}

		if err := json.Unmarshal(jsonData, &wgData); err != nil {
			logger.Info("Error unmarshaling target data: %v", err)
			regResult = "failure"
			return
		}

		logger.Debug(fmtReceivedMsg, msg)
		tun, tnet, err = netstack.CreateNetTUN(
			[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
			[]netip.Addr{netip.MustParseAddr(dns)},
			mtuInt)
		if err != nil {
			logger.Error("Failed to create TUN device: %v", err)
			regResult = "failure"
			return
		}

		setDownstreamTNetstack(tnet)

		// Create WireGuard device
		dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
			mapToWireGuardLogLevel(loggerLevel),
			"wireguard: ",
		))

		host, _, err := net.SplitHostPort(wgData.Endpoint)
		if err != nil {
			logger.Error("Failed to split endpoint: %v", err)
			regResult = "failure"
			return
		}

		logger.Info("Connecting to endpoint: %s", host)

		endpoint, err := resolveDomain(wgData.Endpoint)
		if err != nil {
			logger.Error("Failed to resolve endpoint: %v", err)
			regResult = "failure"
			return
		}

		clientsHandleNewtConnection(wgData.PublicKey, endpoint)

		// Configure WireGuard
		config := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=%s
persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
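
		// The config above is WireGuard's UAPI format, as parsed by IpcSet:
		// lowercase key=value pairs, one per line, flush against the left
		// margin; indenting them inside the raw string would break parsing.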
		err = dev.IpcSet(config)
		if err != nil {
			logger.Error("Failed to configure WireGuard device: %v", err)
			regResult = "failure"
		}

		// Bring up the device
		err = dev.Up()
		if err != nil {
			logger.Error("Failed to bring up WireGuard device: %v", err)
			regResult = "failure"
		}

		logger.Debug("WireGuard device created. Let's ping the server now...")

		// Even if pingWithRetry returns an error, it will continue trying in the background
		if pingWithRetryStopChan != nil {
			// Stop the previous pingWithRetry if it exists
			close(pingWithRetryStopChan)
			pingWithRetryStopChan = nil
		}
		// Use reliable ping for the initial connection test
		logger.Debug("Testing initial connection with reliable ping...")
		lat, err := reliablePing(tnet, wgData.ServerIP, pingTimeout, 5)
		if err == nil && wgData.PublicKey != "" {
			telemetry.ObserveTunnelLatency(ctx, wgData.PublicKey, "wireguard", lat.Seconds())
		}
		if err != nil {
			logger.Warn("Initial reliable ping failed, but continuing: %v", err)
			regResult = "failure"
		} else {
			logger.Debug("Initial connection test successful")
		}

		pingWithRetryStopChan, _ = pingWithRetry(tnet, wgData.ServerIP, pingTimeout)

		// Always mark as connected and start the proxy manager regardless of the
		// initial ping result, as the pings will continue in the background
		if !connected {
			logger.Debug("Starting ping check")
			pingStopChan = startPingCheck(tnet, wgData.ServerIP, client, wgData.PublicKey)
		}

		// Create proxy manager
		pm = proxy.NewProxyManager(tnet)
		pm.SetAsyncBytes(metricsAsyncBytes)
		// Set tunnel_id for metrics (WireGuard peer public key)
		pm.SetTunnelID(wgData.PublicKey)

		connected = true

		// Add the targets if there are any
		if len(wgData.Targets.TCP) > 0 {
			updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
			// Also update wgnetstack proxy manager
			// if wgService != nil {
			// 	updateTargets(wgService.GetProxyManager(), "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
			// }
		}

		if len(wgData.Targets.UDP) > 0 {
			updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
			// Also update wgnetstack proxy manager
			// if wgService != nil {
			// 	updateTargets(wgService.GetProxyManager(), "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
			// }
		}

		clientsAddProxyTarget(pm, wgData.TunnelIP)

		if err := healthMonitor.AddTargets(wgData.HealthCheckTargets); err != nil {
			logger.Error("Failed to bulk add health check targets: %v", err)
		} else {
			logger.Debug("Successfully added %d health check targets", len(wgData.HealthCheckTargets))
		}

		err = pm.Start()
		if err != nil {
			logger.Error("Failed to start proxy manager: %v", err)
		}
	})

	client.RegisterHandler("newt/wg/reconnect", func(msg websocket.WSMessage) {
		logger.Info("Received reconnect message")
		if wgData.PublicKey != "" {
			telemetry.IncReconnect(ctx, wgData.PublicKey, "server", telemetry.ReasonServerRequest)
		}

		// Close the WireGuard device and TUN
		closeWgTunnel()

		// Clear metrics attrs and sessions for the tunnel
		if pm != nil {
			pm.ClearTunnelID()
			state.Global().ClearTunnel(wgData.PublicKey)
		}

		// Mark as disconnected
		connected = false

		if stopFunc != nil {
			stopFunc()     // stop the ws from sending more requests
			stopFunc = nil // reset stopFunc to nil to avoid double stopping
		}

		// Request exit nodes from the server
		stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
			"noCloud": noCloud,
		}, 3*time.Second)

		logger.Info("Tunnel destroyed, ready for reconnection")
	})

	client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) {
		logger.Info("Received termination message")
		if wgData.PublicKey != "" {
			telemetry.IncReconnect(ctx, wgData.PublicKey, "server", telemetry.ReasonServerRequest)
		}

		// Close the WireGuard device and TUN
		closeWgTunnel()

		if stopFunc != nil {
			stopFunc()     // stop the ws from sending more requests
			stopFunc = nil // reset stopFunc to nil to avoid double stopping
		}

		// Mark as disconnected
		connected = false

		logger.Info("Tunnel destroyed")
	})

	client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) {
		logger.Debug("Received ping message")
		if stopFunc != nil {
			stopFunc()     // stop the ws from sending more requests
			stopFunc = nil // reset stopFunc to nil to avoid double stopping
		}

		// Parse the incoming list of exit nodes
		var exitNodeData ExitNodeData

		jsonData, err := json.Marshal(msg.Data)
		if err != nil {
			logger.Info(fmtErrMarshaling, err)
			return
		}
		if err := json.Unmarshal(jsonData, &exitNodeData); err != nil {
			logger.Info("Error unmarshaling exit node data: %v", err)
			return
		}
		exitNodes := exitNodeData.ExitNodes

		if len(exitNodes) == 0 {
			logger.Info("No exit nodes provided")
			return
		}

		// If there is just one exit node, or a preferred endpoint is set, we can
		// skip pinging and use that node directly
		if len(exitNodes) == 1 || preferEndpoint != "" {
			// If preferEndpoint is set, find the exit node with that endpoint in
			// the list and use it instead of the first node
			if preferEndpoint != "" {
				for _, node := range exitNodes {
					if node.Endpoint == preferEndpoint {
						exitNodes[0] = node
						break
					}
				}
			}

			logger.Debug("Skipping exit node pings, using %s directly", exitNodes[0].Endpoint)

			// Prepare data to send to the cloud for selection
			pingResults := []ExitNodePingResult{
				{
					ExitNodeID:             exitNodes[0].ID,
					LatencyMs:              0, // No ping latency since we are using it directly
					Weight:                 exitNodes[0].Weight,
					Error:                  "",
					Name:                   exitNodes[0].Name,
					Endpoint:               exitNodes[0].Endpoint,
					WasPreviouslyConnected: exitNodes[0].WasPreviouslyConnected,
				},
			}

			stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
				"publicKey":   publicKey.String(),
				"pingResults": pingResults,
				"newtVersion": newtVersion,
			}, 1*time.Second)

			return
		}

		type nodeResult struct {
			Node    ExitNode
			Latency time.Duration
			Err     error
		}
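
		// Probe each exit node's /ping endpoint a few times over HTTP and
		// average the successful round-trip times; the raw results (including
		// failures) are sent to the cloud, which makes the final selection.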
		results := make([]nodeResult, len(exitNodes))
		const pingAttempts = 3
		for i, node := range exitNodes {
			var totalLatency time.Duration
			var lastErr error
			successes := 0
			client := &http.Client{
				Timeout: 5 * time.Second,
			}
			url := node.Endpoint
			if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
				url = "http://" + url
			}
			if !strings.HasSuffix(url, "/ping") {
				url = strings.TrimRight(url, "/") + "/ping"
			}
			for j := 0; j < pingAttempts; j++ {
				start := time.Now()
				resp, err := client.Get(url)
				latency := time.Since(start)
				if err != nil {
					lastErr = err
					logger.Warn("Failed to ping exit node %d (%s) attempt %d: %v", node.ID, url, j+1, err)
					continue
				}
				resp.Body.Close()
				totalLatency += latency
				successes++
			}
			var avgLatency time.Duration
			if successes > 0 {
				avgLatency = totalLatency / time.Duration(successes)
			}
			if successes == 0 {
				results[i] = nodeResult{Node: node, Latency: 0, Err: lastErr}
			} else {
				results[i] = nodeResult{Node: node, Latency: avgLatency, Err: nil}
			}
		}

		// Prepare data to send to the cloud for selection
		var pingResults []ExitNodePingResult
		for _, res := range results {
			errMsg := ""
			if res.Err != nil {
				errMsg = res.Err.Error()
			}
			pingResults = append(pingResults, ExitNodePingResult{
				ExitNodeID:             res.Node.ID,
				LatencyMs:              res.Latency.Milliseconds(),
				Weight:                 res.Node.Weight,
				Error:                  errMsg,
				Name:                   res.Node.Name,
				Endpoint:               res.Node.Endpoint,
				WasPreviouslyConnected: res.Node.WasPreviouslyConnected,
			})
		}

		// If we were previously connected and there is at least one other good
		// node, exclude the previously connected node from the results sent to
		// the cloud so we don't try to reconnect to it; the previously connected
		// node might be down or unreachable.
		if connected {
			var filteredPingResults []ExitNodePingResult
			previouslyConnectedNodeIdx := -1
			for i, res := range pingResults {
				if res.WasPreviouslyConnected {
					previouslyConnectedNodeIdx = i
				}
			}
			// Count good nodes (latency > 0, no error, not previously connected)
			goodNodeCount := 0
			for i, res := range pingResults {
				if i != previouslyConnectedNodeIdx && res.LatencyMs > 0 && res.Error == "" {
					goodNodeCount++
				}
			}
			if previouslyConnectedNodeIdx != -1 && goodNodeCount > 0 {
				for i, res := range pingResults {
					if i != previouslyConnectedNodeIdx {
						filteredPingResults = append(filteredPingResults, res)
					}
				}
				pingResults = filteredPingResults
				logger.Info("Excluding previously connected exit node from ping results due to other available nodes")
			}
		}

		// Send the ping results to the cloud for selection
		stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
			"publicKey":   publicKey.String(),
			"pingResults": pingResults,
			"newtVersion": newtVersion,
		}, 1*time.Second)

		logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
	})
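
	// The four handlers below add or remove TCP/UDP forwarding targets on the
	// live proxy manager; they are no-ops until a tunnel has been established.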
	client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
		logger.Debug(fmtReceivedMsg, msg)

		// if there is no wgData or pm, we can't add targets
		if wgData.TunnelIP == "" || pm == nil {
			logger.Info(msgNoTunnelOrProxy)
			return
		}

		targetData, err := parseTargetData(msg.Data)
		if err != nil {
			logger.Info(fmtErrParsingTargetData, err)
			return
		}

		if len(targetData.Targets) > 0 {
			updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)

			// Also update wgnetstack proxy manager
			// if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil {
			// 	updateTargets(wgService.GetProxyManager(), "add", wgData.TunnelIP, "tcp", targetData)
			// }
		}
	})

	client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
		logger.Info(fmtReceivedMsg, msg)

		// if there is no wgData or pm, we can't add targets
		if wgData.TunnelIP == "" || pm == nil {
			logger.Info(msgNoTunnelOrProxy)
			return
		}

		targetData, err := parseTargetData(msg.Data)
		if err != nil {
			logger.Info(fmtErrParsingTargetData, err)
			return
		}

		if len(targetData.Targets) > 0 {
			updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)

			// Also update wgnetstack proxy manager
			// if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil {
			// 	updateTargets(wgService.GetProxyManager(), "add", wgData.TunnelIP, "udp", targetData)
			// }
		}
	})

	client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
		logger.Info(fmtReceivedMsg, msg)

		// if there is no wgData or pm, we can't remove targets
		if wgData.TunnelIP == "" || pm == nil {
			logger.Info(msgNoTunnelOrProxy)
			return
		}

		targetData, err := parseTargetData(msg.Data)
		if err != nil {
			logger.Info(fmtErrParsingTargetData, err)
			return
		}

		if len(targetData.Targets) > 0 {
			updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)

			// Also update wgnetstack proxy manager
			// if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil {
			// 	updateTargets(wgService.GetProxyManager(), "remove", wgData.TunnelIP, "udp", targetData)
			// }
		}
	})

	client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
		logger.Info(fmtReceivedMsg, msg)

		// if there is no wgData or pm, we can't remove targets
		if wgData.TunnelIP == "" || pm == nil {
			logger.Info(msgNoTunnelOrProxy)
			return
		}

		targetData, err := parseTargetData(msg.Data)
		if err != nil {
			logger.Info(fmtErrParsingTargetData, err)
			return
		}

		if len(targetData.Targets) > 0 {
			updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)

			// Also update wgnetstack proxy manager
			// if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil {
			// 	updateTargets(wgService.GetProxyManager(), "remove", wgData.TunnelIP, "tcp", targetData)
			// }
		}
	})

	// Register handler for Docker socket check
	client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
		logger.Debug("Received Docker socket check request")

		if dockerSocket == "" {
			logger.Debug("Docker socket path is not set")
			err := client.SendMessage("newt/socket/status", map[string]interface{}{
				"available":  false,
				"socketPath": dockerSocket,
			})
			if err != nil {
				logger.Error("Failed to send Docker socket check response: %v", err)
			}
			return
		}

		// Check if the Docker socket is available
		isAvailable := docker.CheckSocket(dockerSocket)

		// Send the response back to the server
		err := client.SendMessage("newt/socket/status", map[string]interface{}{
			"available":  isAvailable,
			"socketPath": dockerSocket,
		})
		if err != nil {
			logger.Error("Failed to send Docker socket check response: %v", err)
		} else {
			logger.Debug("Docker socket check response sent: available=%t", isAvailable)
		}
	})

	// Register handler for Docker container listing
	client.RegisterHandler("newt/socket/fetch", func(msg websocket.WSMessage) {
		logger.Debug("Received Docker container fetch request")

		if dockerSocket == "" {
			logger.Debug("Docker socket path is not set")
			return
		}

		// List Docker containers
		containers, err := docker.ListContainers(dockerSocket, dockerEnforceNetworkValidationBool)
		if err != nil {
			logger.Error("Failed to list Docker containers: %v", err)
			return
		}

		// Send the container list back to the server
		err = client.SendMessage("newt/socket/containers", map[string]interface{}{
			"containers": containers,
		})
		if err != nil {
			logger.Error("Failed to send Docker container list: %v", err)
		} else {
			logger.Debug("Docker container list sent, count: %d", len(containers))
		}
	})

	// EXPERIMENTAL: WHAT SHOULD WE DO ABOUT SECURITY?
	client.RegisterHandler("newt/send/ssh/publicKey", func(msg websocket.WSMessage) {
		logger.Debug("Received SSH public key request")

		var sshPublicKeyData SSHPublicKeyData

		jsonData, err := json.Marshal(msg.Data)
		if err != nil {
			logger.Info(fmtErrMarshaling, err)
			return
		}
		if err := json.Unmarshal(jsonData, &sshPublicKeyData); err != nil {
			logger.Info("Error unmarshaling SSH public key data: %v", err)
			return
		}

		sshPublicKey := sshPublicKeyData.PublicKey

		if authorizedKeysFile == "" {
			logger.Debug("No authorized keys file set, skipping public key response")
			return
		}

		// Expand a leading tilde to the home directory if present
		expandedPath := authorizedKeysFile
		if strings.HasPrefix(authorizedKeysFile, "~/") {
			homeDir, err := os.UserHomeDir()
			if err != nil {
				logger.Error("Failed to get user home directory: %v", err)
				return
			}
			expandedPath = filepath.Join(homeDir, authorizedKeysFile[2:])
		}

		// If the path is set but the file does not exist, create it
		if _, err := os.Stat(expandedPath); os.IsNotExist(err) {
			logger.Debug("Authorized keys file does not exist, creating it: %s", expandedPath)
			if err := os.MkdirAll(filepath.Dir(expandedPath), 0755); err != nil {
				logger.Error("Failed to create directory for authorized keys file: %v", err)
				return
			}
			f, err := os.Create(expandedPath)
			if err != nil {
				logger.Error("Failed to create authorized keys file: %v", err)
				return
			}
			f.Close()
		}

		// Read the file so we can check whether the public key already exists
		fileContent, err := os.ReadFile(expandedPath)
		if err != nil {
			logger.Error("Failed to read authorized keys file: %v", err)
			return
		}

		// Check if the key already exists (trim whitespace for comparison)
		existingKeys := strings.Split(string(fileContent), "\n")
		keyAlreadyExists := false
		trimmedNewKey := strings.TrimSpace(sshPublicKey)

		for _, existingKey := range existingKeys {
			if strings.TrimSpace(existingKey) == trimmedNewKey && trimmedNewKey != "" {
				keyAlreadyExists = true
				break
			}
		}

		if keyAlreadyExists {
			logger.Info("SSH public key already exists in authorized keys file, skipping")
			return
		}

		// Append the public key to the authorized keys file
		logger.Debug("Appending public key to authorized keys file: %s", sshPublicKey)
		file, err := os.OpenFile(expandedPath, os.O_APPEND|os.O_WRONLY, 0644)
		if err != nil {
			logger.Error("Failed to open authorized keys file: %v", err)
			return
		}
		defer file.Close()

		if _, err := file.WriteString(sshPublicKey + "\n"); err != nil {
			logger.Error("Failed to write public key to authorized keys file: %v", err)
			return
		}

		logger.Info("SSH public key appended to authorized keys file")
	})

	// Register handler for adding health check targets
	client.RegisterHandler("newt/healthcheck/add", func(msg websocket.WSMessage) {
		logger.Debug("Received health check add request: %+v", msg)

		type HealthCheckConfig struct {
			Targets []healthcheck.Config `json:"targets"`
		}

		var config HealthCheckConfig
		// Add a batch of targets at once
		jsonData, err := json.Marshal(msg.Data)
		if err != nil {
			logger.Error("Error marshaling health check data: %v", err)
			return
		}

		if err := json.Unmarshal(jsonData, &config); err != nil {
			logger.Error("Error unmarshaling health check config: %v", err)
			return
		}

		if err := healthMonitor.AddTargets(config.Targets); err != nil {
			logger.Error("Failed to add health check targets: %v", err)
		} else {
			logger.Debug("Added %d health check targets", len(config.Targets))
		}

		logger.Debug("Health check targets added: %+v", config.Targets)
	})

	// Register handler for removing health check targets
	client.RegisterHandler("newt/healthcheck/remove", func(msg websocket.WSMessage) {
		logger.Debug("Received health check remove request: %+v", msg)

		type HealthCheckConfig struct {
			IDs []int `json:"ids"`
		}

		var requestData HealthCheckConfig
		jsonData, err := json.Marshal(msg.Data)
		if err != nil {
			logger.Error("Error marshaling health check remove data: %v", err)
			return
		}

		if err := json.Unmarshal(jsonData, &requestData); err != nil {
			logger.Error("Error unmarshaling health check remove request: %v", err)
			return
		}

		// Multiple target removal
		if err := healthMonitor.RemoveTargets(requestData.IDs); err != nil {
			logger.Error("Failed to remove health check targets %v: %v", requestData.IDs, err)
		} else {
			logger.Info("Removed %d health check targets: %v", len(requestData.IDs), requestData.IDs)
		}
	})

	// Register handler for enabling health check targets
	client.RegisterHandler("newt/healthcheck/enable", func(msg websocket.WSMessage) {
		logger.Debug("Received health check enable request: %+v", msg)

		var requestData struct {
			ID int `json:"id"`
		}
		jsonData, err := json.Marshal(msg.Data)
		if err != nil {
			logger.Error("Error marshaling health check enable data: %v", err)
			return
		}

		if err := json.Unmarshal(jsonData, &requestData); err != nil {
			logger.Error("Error unmarshaling health check enable request: %v", err)
			return
		}

		if err := healthMonitor.EnableTarget(requestData.ID); err != nil {
			logger.Error("Failed to enable health check target %d: %v", requestData.ID, err)
		} else {
			logger.Info("Enabled health check target: %d", requestData.ID)
		}
	})

	// Register handler for disabling health check targets
	client.RegisterHandler("newt/healthcheck/disable", func(msg websocket.WSMessage) {
		logger.Debug("Received health check disable request: %+v", msg)

		var requestData struct {
			ID int `json:"id"`
		}
		jsonData, err := json.Marshal(msg.Data)
		if err != nil {
			logger.Error("Error marshaling health check disable data: %v", err)
			return
		}

		if err := json.Unmarshal(jsonData, &requestData); err != nil {
			logger.Error("Error unmarshaling health check disable request: %v", err)
			return
		}

		if err := healthMonitor.DisableTarget(requestData.ID); err != nil {
			logger.Error("Failed to disable health check target %d: %v", requestData.ID, err)
		} else {
			logger.Info("Disabled health check target: %d", requestData.ID)
		}
	})

	// Register handler for reporting health check status on request
	client.RegisterHandler("newt/healthcheck/status/request", func(msg websocket.WSMessage) {
		logger.Debug("Received health check status request")

		targets := healthMonitor.GetTargets()
		healthStatuses := make(map[int]interface{})
		for id, target := range targets {
			healthStatuses[id] = map[string]interface{}{
				"status":     target.Status.String(),
				"lastCheck":  target.LastCheck.Format(time.RFC3339),
				"checkCount": target.CheckCount,
				"lastError":  target.LastError,
				"config":     target.Config,
			}
		}

		err := client.SendMessage("newt/healthcheck/status", map[string]interface{}{
			"targets": healthStatuses,
		})
		if err != nil {
			logger.Error("Failed to send health check status response: %v", err)
		}
	})

	// Register handler for blueprint application results
	client.RegisterHandler("newt/blueprint/results", func(msg websocket.WSMessage) {
		logger.Debug("Received blueprint results message")

		var blueprintResult BlueprintResult

		jsonData, err := json.Marshal(msg.Data)
		if err != nil {
			logger.Info(fmtErrMarshaling, err)
			return
		}
		if err := json.Unmarshal(jsonData, &blueprintResult); err != nil {
			logger.Info("Error unmarshaling blueprint results data: %v", err)
			return
		}

		if blueprintResult.Success {
			logger.Debug("Blueprint applied successfully!")
		} else {
			logger.Warn("Blueprint application failed: %s", blueprintResult.Message)
		}
	})
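
	// On every (re)connect, request candidate exit nodes from the server and
	// send a registration message carrying our WireGuard public key.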
	client.OnConnect(func() error {
		publicKey = privateKey.PublicKey()
		logger.Debug("Public key: %s", publicKey)
		logger.Info("Websocket connected")

		if !connected {
			// make sure the stop function is called
			if stopFunc != nil {
				stopFunc()
			}
			// request from the server the list of nodes to ping
			stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
				"noCloud": noCloud,
			}, 3*time.Second)
			logger.Debug("Requesting exit nodes from server")
			clientsOnConnect()
		}

		// Send registration message to the server for backward compatibility
		err := client.SendMessage(topicWGRegister, map[string]interface{}{
			"publicKey":           publicKey.String(),
			"newtVersion":         newtVersion,
			"backwardsCompatible": true,
		})

		sendBlueprint(client)

		if err != nil {
			logger.Error("Failed to send registration message: %v", err)
			return err
		}

		return nil
	})

	// Connect to the WebSocket server
	if err := client.Connect(); err != nil {
		logger.Fatal("Failed to connect to server: %v", err)
	}
	defer client.Close()

	// Initialize Docker event monitoring if a Docker socket is configured
	if dockerSocket != "" {
		logger.Debug("Initializing Docker event monitoring")
		dockerEventMonitor, err = docker.NewEventMonitor(dockerSocket, dockerEnforceNetworkValidationBool, func(containers []docker.Container) {
			// Send the updated container list via websocket when Docker events occur
			logger.Debug("Docker event detected, sending updated container list (%d containers)", len(containers))
			err := client.SendMessage("newt/socket/containers", map[string]interface{}{
				"containers": containers,
			})
			if err != nil {
				logger.Error("Failed to send updated container list after Docker event: %v", err)
			} else {
				logger.Debug("Updated container list sent successfully")
			}
		})

		if err != nil {
			logger.Error("Failed to create Docker event monitor: %v", err)
		} else {
			err = dockerEventMonitor.Start()
			if err != nil {
				logger.Error("Failed to start Docker event monitoring: %v", err)
			} else {
				logger.Debug("Docker event monitoring started successfully")
			}
		}
	}

	// Wait for an interrupt signal
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	<-sigCh

	// Close clients first (including WGTester)
	closeClients()

	if dockerEventMonitor != nil {
		dockerEventMonitor.Stop()
	}

	if healthMonitor != nil {
		healthMonitor.Stop()
	}

	if dev != nil {
		dev.Close()
	}

	if pm != nil {
		pm.Stop()
	}

	if client != nil {
		client.Close()
	}
	logger.Info("Exiting...")
	os.Exit(0)
}

// validateTLSConfig validates the TLS configuration
func validateTLSConfig() error {
	// Check for conflicting configurations
	pkcs12Specified := tlsPrivateKey != ""
	separateFilesSpecified := tlsClientCert != "" || tlsClientKey != "" || len(tlsClientCAs) > 0

	if pkcs12Specified && separateFilesSpecified {
		return fmt.Errorf("cannot use both PKCS12 format (--tls-client-cert) and separate certificate files (--tls-client-cert-file, --tls-client-key, --tls-client-ca)")
	}

	// If using separate files, both the cert and the key are required
	if (tlsClientCert != "" && tlsClientKey == "") || (tlsClientCert == "" && tlsClientKey != "") {
		return fmt.Errorf("both --tls-client-cert-file and --tls-client-key must be specified together")
	}

	// Validate that the certificate files exist
	if tlsClientCert != "" {
		if _, err := os.Stat(tlsClientCert); os.IsNotExist(err) {
			return fmt.Errorf("client certificate file does not exist: %s", tlsClientCert)
		}
	}

	if tlsClientKey != "" {
		if _, err := os.Stat(tlsClientKey); os.IsNotExist(err) {
			return fmt.Errorf("client key file does not exist: %s", tlsClientKey)
		}
	}

	// Validate that the CA files exist
	for _, caFile := range tlsClientCAs {
		if _, err := os.Stat(caFile); os.IsNotExist(err) {
			return fmt.Errorf("CA certificate file does not exist: %s", caFile)
		}
	}

	// Validate that the PKCS12 file exists if specified
	if tlsPrivateKey != "" {
		if _, err := os.Stat(tlsPrivateKey); os.IsNotExist(err) {
			return fmt.Errorf("PKCS12 certificate file does not exist: %s", tlsPrivateKey)
		}
	}

	return nil
}