Files
gerbil/main.go
2026-04-03 15:57:53 +02:00

1872 lines
59 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bytes"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
_ "net/http/pprof"
"os"
"os/exec"
"os/signal"
"runtime"
"runtime/pprof"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/fosrl/gerbil/internal/metrics"
"github.com/fosrl/gerbil/logger"
"github.com/fosrl/gerbil/proxy"
"github.com/fosrl/gerbil/relay"
"github.com/vishvananda/netlink"
"golang.org/x/sync/errgroup"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var (
interfaceName string
listenAddr string
mtuInt int
lastReadings = make(map[string]PeerReading)
mu sync.Mutex
wgMu sync.Mutex // Protects WireGuard operations
notifyURL string
proxyRelay *relay.UDPProxyServer
proxySNI *proxy.SNIProxy
doTrafficShaping bool
bandwidthLimit string
ifbName string // IFB device name for ingress traffic shaping
)
type WgConfig struct {
PrivateKey string `json:"privateKey"`
ListenPort int `json:"listenPort"`
RelayPort int `json:"relayPort"`
IpAddress string `json:"ipAddress"`
Peers []Peer `json:"peers"`
}
type Peer struct {
PublicKey string `json:"publicKey"`
AllowedIPs []string `json:"allowedIps"`
}
type PeerBandwidth struct {
PublicKey string `json:"publicKey"`
BytesIn float64 `json:"bytesIn"`
BytesOut float64 `json:"bytesOut"`
}
type PeerReading struct {
BytesReceived int64
BytesTransmitted int64
LastChecked time.Time
}
var (
wgClient *wgctrl.Client
)
// Add this new type at the top with other type definitions
type ClientEndpoint struct {
OlmID string `json:"olmId"`
NewtID string `json:"newtId"`
IP string `json:"ip"`
Port int `json:"port"`
Timestamp int64 `json:"timestamp"`
}
type HolePunchMessage struct {
OlmID string `json:"olmId"`
NewtID string `json:"newtId"`
}
type ProxyMappingUpdate struct {
OldDestination relay.PeerDestination `json:"oldDestination"`
NewDestination relay.PeerDestination `json:"newDestination"`
}
type UpdateDestinationsRequest struct {
SourceIP string `json:"sourceIp"`
SourcePort int `json:"sourcePort"`
Destinations []relay.PeerDestination `json:"destinations"`
}
// httpMetricsMiddleware wraps HTTP handlers with metrics tracking
func httpMetricsMiddleware(endpoint string, handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
// Create a response writer wrapper to capture status code
ww := &responseWriterWrapper{ResponseWriter: w, statusCode: http.StatusOK}
// Call the actual handler
handler(ww, r)
// Record metrics
duration := time.Since(startTime).Seconds()
metrics.RecordHTTPRequest(endpoint, r.Method, fmt.Sprintf("%d", ww.statusCode))
metrics.RecordHTTPRequestDuration(endpoint, r.Method, duration)
}
}
// responseWriterWrapper wraps http.ResponseWriter to capture status code
type responseWriterWrapper struct {
http.ResponseWriter
statusCode int
}
func (w *responseWriterWrapper) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
func parseLogLevel(level string) logger.LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":
return logger.DEBUG
case "INFO":
return logger.INFO
case "WARN":
return logger.WARN
case "ERROR":
return logger.ERROR
case "FATAL":
return logger.FATAL
default:
return logger.INFO // default to INFO if invalid level provided
}
}
func main() {
go monitorMemory(1024 * 1024 * 512) // trigger if memory usage exceeds 512MB
var (
err error
wgconfig WgConfig
configFile string
remoteConfigURL string
generateAndSaveKeyTo string
reachableAt string
logLevel string
mtu string
sniProxyPort int
localProxyAddr string
localProxyPort int
localOverridesStr string
trustedUpstreamsStr string
proxyProtocol bool
// Metrics configuration variables (set from env, then overridden by CLI flags)
metricsEnabled bool
metricsBackend string
metricsPath string
otelMetricsProtocol string
otelMetricsEndpoint string
otelMetricsInsecure bool
otelMetricsExportInterval time.Duration
)
interfaceName = os.Getenv("INTERFACE")
configFile = os.Getenv("CONFIG")
remoteConfigURL = os.Getenv("REMOTE_CONFIG")
listenAddr = os.Getenv("LISTEN")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
reachableAt = os.Getenv("REACHABLE_AT")
logLevel = os.Getenv("LOG_LEVEL")
mtu = os.Getenv("MTU")
notifyURL = os.Getenv("NOTIFY_URL")
sniProxyPortStr := os.Getenv("SNI_PORT")
localProxyAddr = os.Getenv("LOCAL_PROXY")
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS")
proxyProtocolStr := os.Getenv("PROXY_PROTOCOL")
doTrafficShapingStr := os.Getenv("DO_TRAFFIC_SHAPING")
bandwidthLimitStr := os.Getenv("BANDWIDTH_LIMIT")
// Read metrics env vars (defaults applied by DefaultMetricsConfig; these override defaults).
metricsEnabled = true // default
if v := os.Getenv("METRICS_ENABLED"); v != "" {
metricsEnabled = strings.ToLower(v) == "true"
}
metricsBackend = "prometheus" // default
if v := os.Getenv("METRICS_BACKEND"); v != "" {
metricsBackend = v
}
metricsPath = "/metrics" // default
if v := os.Getenv("METRICS_PATH"); v != "" {
metricsPath = v
}
otelMetricsProtocol = "grpc" // default
if v := os.Getenv("OTEL_METRICS_PROTOCOL"); v != "" {
otelMetricsProtocol = v
}
otelMetricsEndpoint = "localhost:4317" // default
if v := os.Getenv("OTEL_METRICS_ENDPOINT"); v != "" {
otelMetricsEndpoint = v
}
otelMetricsInsecure = true // default
if v := os.Getenv("OTEL_METRICS_INSECURE"); v != "" {
otelMetricsInsecure = strings.ToLower(v) == "true"
}
otelMetricsExportInterval = 60 * time.Second // default
if v := os.Getenv("OTEL_METRICS_EXPORT_INTERVAL"); v != "" {
if d, err2 := time.ParseDuration(v); err2 == nil {
otelMetricsExportInterval = d
} else {
log.Printf("WARN: invalid OTEL_METRICS_EXPORT_INTERVAL=%q: %v", v, err2)
}
}
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
}
if configFile == "" {
flag.StringVar(&configFile, "config", "", "Path to local configuration file")
}
if remoteConfigURL == "" {
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL of the Pangolin server")
}
if listenAddr == "" {
flag.StringVar(&listenAddr, "listen", "", "DEPRECATED (overridden by reachableAt): Address to listen on")
}
// DEPRECATED AND UNSED: reportBandwidthTo
// allow reportBandwidthTo to be passed but dont do anything with it just thow it away
reportBandwidthTo := ""
flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "DEPRECATED: Use remoteConfig instead")
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
if reachableAt == "" {
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
}
if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if mtu == "" {
flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface")
}
if notifyURL == "" {
flag.StringVar(&notifyURL, "notify", "", "URL to notify on peer changes")
}
if sniProxyPortStr != "" {
if port, err := strconv.Atoi(sniProxyPortStr); err == nil {
sniProxyPort = port
}
}
if sniProxyPortStr == "" {
flag.IntVar(&sniProxyPort, "sni-port", 8443, "Port to listen on")
}
if localProxyAddr == "" {
flag.StringVar(&localProxyAddr, "local-proxy", "localhost", "Local proxy address")
}
if localProxyPortStr != "" {
if port, err := strconv.Atoi(localProxyPortStr); err == nil {
localProxyPort = port
}
}
if localProxyPortStr == "" {
flag.IntVar(&localProxyPort, "local-proxy-port", 443, "Local proxy port")
}
if localOverridesStr != "" {
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
}
if trustedUpstreamsStr == "" {
flag.StringVar(&trustedUpstreamsStr, "trusted-upstreams", "", "Comma-separated list of trusted upstream proxy domain names/IPs that can send PROXY protocol")
}
if proxyProtocolStr != "" {
proxyProtocol = strings.ToLower(proxyProtocolStr) == "true"
}
if proxyProtocolStr == "" {
flag.BoolVar(&proxyProtocol, "proxy-protocol", true, "Enable PROXY protocol v1 for preserving client IP")
}
if doTrafficShapingStr != "" {
doTrafficShaping = strings.ToLower(doTrafficShapingStr) == "true"
}
if doTrafficShapingStr == "" {
flag.BoolVar(&doTrafficShaping, "do-traffic-shaping", false, "Whether to set up traffic shaping rules for peers (requires tc command and root privileges)")
}
if bandwidthLimitStr != "" {
bandwidthLimit = bandwidthLimitStr
}
if bandwidthLimitStr == "" {
flag.StringVar(&bandwidthLimit, "bandwidth-limit", "50mbit", "Bandwidth limit per peer for traffic shaping (e.g. 50mbit, 1gbit)")
}
// Metrics CLI flags always registered so that CLI overrides env/defaults.
flag.BoolVar(&metricsEnabled, "metrics-enabled", metricsEnabled, "Enable metrics collection (default: true)")
flag.StringVar(&metricsBackend, "metrics-backend", metricsBackend, "Metrics backend: prometheus, otel, or none")
flag.StringVar(&metricsPath, "metrics-path", metricsPath, "HTTP path for Prometheus /metrics endpoint")
flag.StringVar(&otelMetricsProtocol, "otel-metrics-protocol", otelMetricsProtocol, "OTLP transport protocol: grpc or http")
flag.StringVar(&otelMetricsEndpoint, "otel-metrics-endpoint", otelMetricsEndpoint, "OTLP collector endpoint (e.g. localhost:4317)")
flag.BoolVar(&otelMetricsInsecure, "otel-metrics-insecure", otelMetricsInsecure, "Disable TLS for OTLP connection")
flag.DurationVar(&otelMetricsExportInterval, "otel-metrics-export-interval", otelMetricsExportInterval, "Interval between OTLP metric pushes")
flag.Parse()
// Derive IFB device name from the WireGuard interface name (Linux limit: 15 chars)
ifbName = "ifb_" + interfaceName
if len(ifbName) > 15 {
ifbName = ifbName[:15]
}
logger.Init()
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
// Initialize metrics with the selected backend.
// Config precedence: CLI flags > env vars > defaults (already applied above).
metricsHandler, err := metrics.Initialize(metrics.Config{
Enabled: metricsEnabled,
Backend: metricsBackend,
Prometheus: metrics.PrometheusConfig{
Path: metricsPath,
},
OTel: metrics.OTelConfig{
Protocol: otelMetricsProtocol,
Endpoint: otelMetricsEndpoint,
Insecure: otelMetricsInsecure,
ExportInterval: otelMetricsExportInterval,
},
ServiceName: "gerbil",
ServiceVersion: "1.0.0",
DeploymentEnvironment: os.Getenv("DEPLOYMENT_ENVIRONMENT"),
})
if err != nil {
logger.Fatal("Failed to initialize metrics: %v", err)
}
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := metrics.Shutdown(shutdownCtx); err != nil {
logger.Error("Failed to shutdown metrics: %v", err)
}
}()
// Record restart metric
metrics.RecordRestart()
// Base context for the application; cancel on SIGINT/SIGTERM
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
// try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
if reachableAt != "" && listenAddr == "" {
if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
parts := strings.Split(reachableAt, ":")
if len(parts) == 3 {
port := parts[2]
if strings.Contains(port, "/") {
port = strings.Split(port, "/")[0]
}
listenAddr = ":" + port
}
}
} else if listenAddr == "" {
listenAddr = ":3003"
}
mtuInt, err = strconv.Atoi(mtu)
if err != nil {
logger.Fatal("Failed to parse MTU: %v", err)
}
// are they missing either the config file or the remote config URL?
if configFile == "" && remoteConfigURL == "" {
logger.Fatal("You must provide either a config file or a remote config URL")
}
// do they have both the config file and the remote config URL?
if configFile != "" && remoteConfigURL != "" {
logger.Fatal("You must provide either a config file or a remote config URL, not both")
}
// clean up the reomte config URL for backwards compatibility
remoteConfigURL = strings.TrimSuffix(remoteConfigURL, "/gerbil/get-config")
remoteConfigURL = strings.TrimSuffix(remoteConfigURL, "/")
var key wgtypes.Key
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
if generateAndSaveKeyTo != "" {
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
// generate a new private key
key, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
// save the key to the file
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
if err != nil {
logger.Fatal("Failed to save private key: %v", err)
}
} else {
keyData, err := os.ReadFile(generateAndSaveKeyTo)
if err != nil {
logger.Fatal("Failed to read private key: %v", err)
}
key, err = wgtypes.ParseKey(string(keyData))
if err != nil {
logger.Fatal("Failed to parse private key: %v", err)
}
}
} else {
// if no generateAndSaveKeyTo is provided, ensure that the private key is provided
if wgconfig.PrivateKey == "" {
// generate a new one
key, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
}
}
// Load configuration based on provided argument
if configFile != "" {
wgconfig, err = loadConfig(configFile)
if err != nil {
logger.Fatal("Failed to load configuration: %v", err)
}
if wgconfig.PrivateKey == "" {
wgconfig.PrivateKey = key.String()
}
} else {
// loop until we get the config
for wgconfig.PrivateKey == "" {
logger.Info("Fetching remote config from %s", remoteConfigURL+"/gerbil/get-config")
wgconfig, err = loadRemoteConfig(remoteConfigURL+"/gerbil/get-config", key, reachableAt)
if err != nil {
logger.Error("Failed to load configuration: %v", err)
time.Sleep(5 * time.Second)
continue
}
wgconfig.PrivateKey = key.String()
}
}
wgClient, err = wgctrl.New()
if err != nil {
logger.Fatal("Failed to create WireGuard client: %v", err)
}
defer wgClient.Close()
// Ensure the WireGuard interface exists and is configured
if err := ensureWireguardInterface(wgconfig); err != nil {
logger.Fatal("Failed to ensure WireGuard interface: %v", err)
}
// Set up IFB device for bidirectional ingress/egress traffic shaping if enabled
if doTrafficShaping {
if err := ensureIFBDevice(); err != nil {
logger.Fatal("Failed to ensure IFB device for traffic shaping: %v", err)
}
}
// Ensure the WireGuard peers exist
ensureWireguardPeers(wgconfig.Peers)
// Child error group derived from base context
group, groupCtx := errgroup.WithContext(ctx)
// Periodic bandwidth reporting
group.Go(func() error {
return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
})
// Start the UDP proxy server
relayPort := wgconfig.RelayPort
if relayPort == 0 {
relayPort = 21820 // in case there is no relay port set, use 21820
}
proxyRelay = relay.NewUDPProxyServer(groupCtx, fmt.Sprintf(":%d", relayPort), remoteConfigURL, key, reachableAt)
err = proxyRelay.Start()
if err != nil {
logger.Fatal("Failed to start UDP proxy server: %v", err)
}
defer proxyRelay.Stop()
// TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING
// SO YOU DON'T NEED TO SET THIS SEPARATELY
// Parse local overrides
var localOverrides []string
if localOverridesStr != "" {
localOverrides = strings.Split(localOverridesStr, ",")
for i, domain := range localOverrides {
localOverrides[i] = strings.TrimSpace(domain)
}
logger.Info("Local overrides configured: %v", localOverrides)
}
var trustedUpstreams []string
if trustedUpstreamsStr != "" {
trustedUpstreams = strings.Split(trustedUpstreamsStr, ",")
for i, upstream := range trustedUpstreams {
trustedUpstreams[i] = strings.TrimSpace(upstream)
}
logger.Info("Trusted upstreams configured: %v", trustedUpstreams)
}
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol, trustedUpstreams)
if err != nil {
logger.Fatal("Failed to create proxy: %v", err)
}
if err := proxySNI.Start(); err != nil {
logger.Fatal("Failed to start proxy: %v", err)
}
// Set up HTTP server with metrics middleware
http.HandleFunc("/peer", httpMetricsMiddleware("peer", handlePeer))
http.HandleFunc("/update-proxy-mapping", httpMetricsMiddleware("update_proxy_mapping", handleUpdateProxyMapping))
http.HandleFunc("/update-destinations", httpMetricsMiddleware("update_destinations", handleUpdateDestinations))
http.HandleFunc("/update-local-snis", httpMetricsMiddleware("update_local_snis", handleUpdateLocalSNIs))
http.HandleFunc("/healthz", httpMetricsMiddleware("healthz", handleHealthz))
// Register metrics endpoint only for Prometheus backend.
// OTel backend pushes to a collector; no /metrics endpoint needed.
if metricsHandler != nil {
http.Handle(metricsPath, metricsHandler)
logger.Info("Metrics endpoint enabled at %s", metricsPath)
}
logger.Info("Starting HTTP server on %s", listenAddr)
// HTTP server with graceful shutdown on context cancel
server := &http.Server{
Addr: listenAddr,
Handler: nil,
ReadHeaderTimeout: 3 * time.Second,
}
group.Go(func() error {
// http.ErrServerClosed is returned on graceful shutdown; not an error for us
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
})
group.Go(func() error {
<-groupCtx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = server.Shutdown(shutdownCtx)
// Stop background components as the context is canceled
if proxySNI != nil {
_ = proxySNI.Stop()
}
if proxyRelay != nil {
proxyRelay.Stop()
}
return nil
})
// Wait for all goroutines to finish
if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
logger.Error("Service exited with error: %v", err)
} else if errors.Is(err, context.Canceled) {
logger.Info("Context cancelled, shutting down")
}
}
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
var body *bytes.Buffer
if reachableAt == "" {
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": %q}`, key.PublicKey().String())))
} else {
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": %q, "reachableAt": %q}`, key.PublicKey().String(), reachableAt)))
}
resp, err := http.Post(url, "application/json", body)
if err != nil {
// print the error
logger.Error("Error fetching remote config %s: %v", url, err)
// Record remote config fetch error
metrics.RecordRemoteConfigFetch("error")
return WgConfig{}, err
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
metrics.RecordRemoteConfigFetch("error")
return WgConfig{}, err
}
var config WgConfig
err = json.Unmarshal(data, &config)
if err != nil {
metrics.RecordRemoteConfigFetch("error")
return config, err
}
// Record successful remote config fetch
metrics.RecordRemoteConfigFetch("success")
return config, err
}
func loadConfig(filename string) (WgConfig, error) {
// Open the JSON file
file, err := os.Open(filename)
if err != nil {
logger.Error("Error opening file %s: %v", filename, err)
return WgConfig{}, err
}
defer file.Close()
// Read the file contents
byteValue, err := io.ReadAll(file)
if err != nil {
logger.Error("Error reading file %s: %v", filename, err)
return WgConfig{}, err
}
// Create a variable of the appropriate type to hold the unmarshaled data
var wgconfig WgConfig
// Unmarshal the JSON data into the struct
err = json.Unmarshal(byteValue, &wgconfig)
if err != nil {
logger.Error("Error unmarshaling JSON data: %v", err)
return WgConfig{}, err
}
return wgconfig, nil
}
func ensureWireguardInterface(wgconfig WgConfig) error {
// Check if the WireGuard interface exists
_, err := netlink.LinkByName(interfaceName)
if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); ok {
// Interface doesn't exist, so create it
err = createWireGuardInterface()
if err != nil {
logger.Fatal("Failed to create WireGuard interface: %v", err)
}
logger.Info("Created WireGuard interface %s\n", interfaceName)
} else {
logger.Fatal("Error checking for WireGuard interface: %v", err)
}
} else {
logger.Info("WireGuard interface %s already exists\n", interfaceName)
return nil
}
// Assign IP address to the interface
err = assignIPAddress(wgconfig.IpAddress)
if err != nil {
logger.Fatal("Failed to assign IP address: %v", err)
}
logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName)
// Check if the interface already exists
_, err = wgClient.Device(interfaceName)
if err != nil {
return fmt.Errorf("interface %s does not exist", interfaceName)
}
// Parse the private key
key, err := wgtypes.ParseKey(wgconfig.PrivateKey)
if err != nil {
return fmt.Errorf("failed to parse private key: %v", err)
}
// Create a new WireGuard configuration
config := wgtypes.Config{
PrivateKey: &key,
ListenPort: new(int),
}
*config.ListenPort = wgconfig.ListenPort
// Create and configure the WireGuard interface
err = wgClient.ConfigureDevice(interfaceName, config)
if err != nil {
return fmt.Errorf("failed to configure WireGuard device: %v", err)
}
// bring up the interface
link, err := netlink.LinkByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
if err := netlink.LinkSetMTU(link, mtuInt); err != nil {
return fmt.Errorf("failed to set MTU: %v", err)
}
if err := netlink.LinkSetUp(link); err != nil {
return fmt.Errorf("failed to bring up interface: %v", err)
}
if err := ensureMSSClamping(); err != nil {
logger.Warn("Failed to ensure MSS clamping: %v", err)
}
if err := ensureWireguardFirewall(); err != nil {
logger.Warn("Failed to ensure WireGuard firewall rules: %v", err)
}
logger.Info("WireGuard interface %s created and configured", interfaceName)
// Record interface state metric
hostname, _ := os.Hostname()
metrics.RecordInterfaceUp(interfaceName, hostname, true)
return nil
}
func createWireGuardInterface() error {
wgLink := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
LinkType: "wireguard",
}
return netlink.LinkAdd(wgLink)
}
func assignIPAddress(ipAddress string) error {
link, err := netlink.LinkByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
addr, err := netlink.ParseAddr(ipAddress)
if err != nil {
return fmt.Errorf("failed to parse IP address: %v", err)
}
return netlink.AddrAdd(link, addr)
}
func ensureWireguardPeers(peers []Peer) error {
wgMu.Lock()
defer wgMu.Unlock()
// get the current peers
device, err := wgClient.Device(interfaceName)
if err != nil {
return fmt.Errorf("failed to get device: %v", err)
}
// get the peer public keys
var currentPeers []string
for _, peer := range device.Peers {
currentPeers = append(currentPeers, peer.PublicKey.String())
}
// remove any peers that are not in the config
for _, peer := range currentPeers {
found := false
for _, configPeer := range peers {
if peer == configPeer.PublicKey {
found = true
break
}
}
if !found {
// Note: We need to call the internal removal logic without re-acquiring the lock
if err := removePeerInternal(peer); err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
}
}
// add any peers that are in the config but not in the current peers
for _, configPeer := range peers {
found := false
for _, peer := range currentPeers {
if configPeer.PublicKey == peer {
found = true
break
}
}
if !found {
// Note: We need to call the internal addition logic without re-acquiring the lock
if err := addPeerInternal(configPeer); err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
}
}
return nil
}
func ensureMSSClamping() error {
// Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
mssValue := mtuInt - 40
// Rules to be managed - just the chains, we'll construct the full command separately
chains := []string{"INPUT", "OUTPUT", "FORWARD"}
// First, try to delete any existing rules
for _, chain := range chains {
deleteCmd := exec.Command("/usr/sbin/iptables",
"-t", "mangle",
"-D", chain,
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mssValue))
logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
// Try deletion multiple times to handle multiple existing rules
for i := 0; i < 3; i++ {
out, err := deleteCmd.CombinedOutput()
if err != nil {
// Convert exit status 1 to string for better logging
if exitErr, ok := err.(*exec.ExitError); ok {
logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
chain, exitErr.String(), string(out))
}
break // No more rules to delete
}
logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
}
}
// Then add the new rules
var errors []error
for _, chain := range chains {
addCmd := exec.Command("/usr/sbin/iptables",
"-t", "mangle",
"-A", chain,
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mssValue))
logger.Info("Adding MSS clamping rule for chain %s", chain)
if out, err := addCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
// Verify the rule was added
checkCmd := exec.Command("/usr/sbin/iptables",
"-t", "mangle",
"-C", chain,
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mssValue))
if out, err := checkCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
}
// If we encountered any errors, return them combined
if len(errors) > 0 {
var errMsgs []string
for _, err := range errors {
errMsgs = append(errMsgs, err.Error())
}
return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
strings.Join(errMsgs, "\n"))
}
return nil
}
func ensureWireguardFirewall() error {
// Rules to enforce:
// 1. Allow established/related connections (responses to our outbound traffic)
// 2. Allow ICMP ping packets
// 3. Drop all other inbound traffic from peers
// Define the rules we want to ensure exist
rules := [][]string{
// Allow established and related connections (responses to outbound traffic)
{
"-A", "INPUT",
"-i", interfaceName,
"-m", "conntrack",
"--ctstate", "ESTABLISHED,RELATED",
"-j", "ACCEPT",
},
// Allow ICMP ping requests
{
"-A", "INPUT",
"-i", interfaceName,
"-p", "icmp",
"--icmp-type", "8",
"-j", "ACCEPT",
},
// Drop all other inbound traffic from WireGuard interface
{
"-A", "INPUT",
"-i", interfaceName,
"-j", "DROP",
},
}
// First, try to delete any existing rules for this interface
for _, rule := range rules {
deleteArgs := make([]string, len(rule))
copy(deleteArgs, rule)
// Change -A to -D for deletion
for i, arg := range deleteArgs {
if arg == "-A" {
deleteArgs[i] = "-D"
break
}
}
deleteCmd := exec.Command("/usr/sbin/iptables", deleteArgs...)
logger.Debug("Attempting to delete existing firewall rule: %v", deleteArgs)
// Try deletion multiple times to handle multiple existing rules
for i := 0; i < 5; i++ {
out, err := deleteCmd.CombinedOutput()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
logger.Debug("Deletion stopped: %v (output: %s)", exitErr.String(), string(out))
}
break // No more rules to delete
}
logger.Info("Deleted existing firewall rule (attempt %d)", i+1)
}
}
// Now add the rules
var errors []error
for i, rule := range rules {
addCmd := exec.Command("/usr/sbin/iptables", rule...)
logger.Info("Adding WireGuard firewall rule %d: %v", i+1, rule)
if out, err := addCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Failed to add firewall rule %d: %v (output: %s)", i+1, err, string(out))
logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
// Verify the rule was added by checking
checkArgs := make([]string, len(rule))
copy(checkArgs, rule)
// Change -A to -C for check
for j, arg := range checkArgs {
if arg == "-A" {
checkArgs[j] = "-C"
break
}
}
checkCmd := exec.Command("/usr/sbin/iptables", checkArgs...)
if out, err := checkCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Rule verification failed for rule %d: %v (output: %s)", i+1, err, string(out))
logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
logger.Info("Successfully added and verified WireGuard firewall rule %d", i+1)
}
if len(errors) > 0 {
var errMsgs []string
for _, err := range errors {
errMsgs = append(errMsgs, err.Error())
}
return fmt.Errorf("WireGuard firewall setup encountered errors:\n%s", strings.Join(errMsgs, "\n"))
}
logger.Info("WireGuard firewall rules successfully configured for interface %s", interfaceName)
return nil
}
func handlePeer(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
handleAddPeer(w, r)
case http.MethodDelete:
handleRemovePeer(w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
}
func handleHealthz(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}
func handleAddPeer(w http.ResponseWriter, r *http.Request) {
var peer Peer
if err := json.NewDecoder(r.Body).Decode(&peer); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
// Record peer add error
metrics.RecordPeerOperation("add", "error")
return
}
err := addPeer(peer)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
// Record peer add error
metrics.RecordPeerOperation("add", "error")
return
}
// Record peer add success
metrics.RecordPeerOperation("add", "success")
// Notify if notifyURL is set
go notifyPeerChange("add", peer.PublicKey)
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]string{"status": "Peer added successfully"})
}
func addPeer(peer Peer) error {
wgMu.Lock()
defer wgMu.Unlock()
return addPeerInternal(peer)
}
func addPeerInternal(peer Peer) error {
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
}
logger.Debug("Adding peer %s with AllowedIPs: %v", peer.PublicKey, peer.AllowedIPs)
// parse allowed IPs into array of net.IPNet
var allowedIPs []net.IPNet
var wgIPs []string
for _, ipStr := range peer.AllowedIPs {
logger.Debug("Parsing AllowedIP: %s", ipStr)
_, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
logger.Warn("Failed to parse allowed IP '%s' for peer %s: %v", ipStr, peer.PublicKey, err)
return fmt.Errorf("failed to parse allowed IP: %v", err)
}
allowedIPs = append(allowedIPs, *ipNet)
// Extract the IP address from the CIDR for relay cleanup
extractedIP := ipNet.IP.String()
wgIPs = append(wgIPs, extractedIP)
logger.Debug("Extracted IP %s from AllowedIP %s", extractedIP, ipStr)
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
AllowedIPs: allowedIPs,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peerConfig},
}
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
// Setup bandwidth limiting for each peer IP
if doTrafficShaping {
logger.Debug("doTrafficShaping is true, setting up bandwidth limits for %d IPs", len(wgIPs))
for _, wgIP := range wgIPs {
if err := setupPeerBandwidthLimit(wgIP); err != nil {
logger.Warn("Failed to setup bandwidth limit for peer IP %s: %v", wgIP, err)
}
}
} else {
logger.Debug("doTrafficShaping is false, skipping bandwidth limit setup")
}
// Clear relay connections for the peer's WireGuard IPs
if proxyRelay != nil {
for _, wgIP := range wgIPs {
proxyRelay.OnPeerAdded(wgIP)
}
}
logger.Info("Peer %s added successfully", peer.PublicKey)
// Record metrics
metrics.RecordPeersTotal(interfaceName, 1)
metrics.RecordAllowedIPsCount(interfaceName, peer.PublicKey, int64(len(peer.AllowedIPs)))
return nil
}
func handleRemovePeer(w http.ResponseWriter, r *http.Request) {
publicKey := r.URL.Query().Get("public_key")
if publicKey == "" {
http.Error(w, "Missing public_key query parameter", http.StatusBadRequest)
// Record peer remove error
metrics.RecordPeerOperation("remove", "error")
return
}
err := removePeer(publicKey)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
// Record peer remove error
metrics.RecordPeerOperation("remove", "error")
return
}
// Record peer remove success
metrics.RecordPeerOperation("remove", "success")
// Notify if notifyURL is set
go notifyPeerChange("remove", publicKey)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "Peer removed successfully"})
}
func removePeer(publicKey string) error {
wgMu.Lock()
defer wgMu.Unlock()
return removePeerInternal(publicKey)
}
func removePeerInternal(publicKey string) error {
pubKey, err := wgtypes.ParseKey(publicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
}
// Get current peer info before removing to clear relay connections and bandwidth limits
var wgIPs []string
device, err := wgClient.Device(interfaceName)
if err == nil {
for _, peer := range device.Peers {
if peer.PublicKey.String() == publicKey {
// Extract WireGuard IPs from this peer's allowed IPs
for _, allowedIP := range peer.AllowedIPs {
wgIPs = append(wgIPs, allowedIP.IP.String())
}
break
}
}
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peerConfig},
}
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
// Remove bandwidth limits for each peer IP
if doTrafficShaping {
for _, wgIP := range wgIPs {
if err := removePeerBandwidthLimit(wgIP); err != nil {
logger.Warn("Failed to remove bandwidth limit for peer IP %s: %v", wgIP, err)
}
}
}
// Clear relay connections for the peer's WireGuard IPs
if proxyRelay != nil {
for _, wgIP := range wgIPs {
proxyRelay.OnPeerRemoved(wgIP)
}
}
logger.Info("Peer %s removed successfully", publicKey)
// Record metrics
metrics.RecordPeersTotal(interfaceName, -1)
metrics.RecordAllowedIPsCount(interfaceName, publicKey, -int64(len(wgIPs)))
return nil
}
func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
logger.Error("Invalid method: %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var update ProxyMappingUpdate
if err := json.NewDecoder(r.Body).Decode(&update); err != nil {
logger.Error("Failed to decode request body: %v", err)
http.Error(w, fmt.Sprintf("Failed to decode request body: %v", err), http.StatusBadRequest)
return
}
// Validate the update request
if update.OldDestination.DestinationIP == "" || update.NewDestination.DestinationIP == "" {
logger.Error("Both old and new destination IP addresses are required")
http.Error(w, "Both old and new destination IP addresses are required", http.StatusBadRequest)
return
}
if update.OldDestination.DestinationPort <= 0 || update.NewDestination.DestinationPort <= 0 {
logger.Error("Both old and new destination ports must be positive integers")
http.Error(w, "Both old and new destination ports must be positive integers", http.StatusBadRequest)
return
}
// Update the proxy mappings in the relay server
if proxyRelay == nil {
logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
// Record error
metrics.RecordProxyMappingUpdateRequest("error")
return
}
updatedCount := proxyRelay.UpdateDestinationInMappings(update.OldDestination, update.NewDestination)
logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d",
updatedCount,
update.OldDestination.DestinationIP, update.OldDestination.DestinationPort,
update.NewDestination.DestinationIP, update.NewDestination.DestinationPort)
// Record success
metrics.RecordProxyMappingUpdateRequest("success")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"status": "Proxy mappings updated successfully",
"updatedCount": updatedCount,
"oldDestination": update.OldDestination,
"newDestination": update.NewDestination,
})
}
func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
logger.Error("Invalid method: %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var request UpdateDestinationsRequest
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
logger.Error("Failed to decode request body: %v", err)
http.Error(w, fmt.Sprintf("Failed to decode request body: %v", err), http.StatusBadRequest)
return
}
// Validate the request
if request.SourceIP == "" {
logger.Error("Source IP address is required")
http.Error(w, "Source IP address is required", http.StatusBadRequest)
return
}
if request.SourcePort <= 0 {
logger.Error("Source port must be a positive integer")
http.Error(w, "Source port must be a positive integer", http.StatusBadRequest)
return
}
if len(request.Destinations) == 0 {
logger.Error("At least one destination is required")
http.Error(w, "At least one destination is required", http.StatusBadRequest)
return
}
// Validate each destination
for i, dest := range request.Destinations {
if dest.DestinationIP == "" {
logger.Error("Destination IP is required for destination %d", i)
http.Error(w, fmt.Sprintf("Destination IP is required for destination %d", i), http.StatusBadRequest)
return
}
if dest.DestinationPort <= 0 {
logger.Error("Destination port must be a positive integer for destination %d", i)
http.Error(w, fmt.Sprintf("Destination port must be a positive integer for destination %d", i), http.StatusBadRequest)
return
}
}
// Update the proxy mappings in the relay server
if proxyRelay == nil {
logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
// Record error
metrics.RecordDestinationsUpdateRequest("error")
return
}
proxyRelay.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations)
logger.Info("Updated proxy mapping for %s:%d with %d destinations",
request.SourceIP, request.SourcePort, len(request.Destinations))
// Record success
metrics.RecordDestinationsUpdateRequest("success")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"status": "Destinations updated successfully",
"sourceIP": request.SourceIP,
"sourcePort": request.SourcePort,
"destinationCount": len(request.Destinations),
"destinations": request.Destinations,
})
}
// UpdateLocalSNIsRequest represents the JSON payload for updating local SNIs
type UpdateLocalSNIsRequest struct {
FullDomains []string `json:"fullDomains"`
}
func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
logger.Error("Invalid method: %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req UpdateLocalSNIsRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
return
}
proxySNI.UpdateLocalSNIs(req.FullDomains)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"status": "Local SNIs updated successfully",
})
}
func periodicBandwidthCheck(ctx context.Context, endpoint string) error {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := reportPeerBandwidth(endpoint); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
}
case <-ctx.Done():
return ctx.Err()
}
}
}
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
wgMu.Lock()
device, err := wgClient.Device(interfaceName)
wgMu.Unlock()
if err != nil {
return nil, fmt.Errorf("failed to get device: %v", err)
}
var peerBandwidths []PeerBandwidth
now := time.Now()
mu.Lock()
defer mu.Unlock()
// Track the set of peers currently present on the device to prune stale readings efficiently
currentPeerKeys := make(map[string]struct{}, len(device.Peers))
for _, peer := range device.Peers {
publicKey := peer.PublicKey.String()
currentPeerKeys[publicKey] = struct{}{}
currentReading := PeerReading{
BytesReceived: peer.ReceiveBytes,
BytesTransmitted: peer.TransmitBytes,
LastChecked: now,
}
var bytesInDiff, bytesOutDiff float64
lastReading, exists := lastReadings[publicKey]
if exists {
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
if timeDiff > 0 {
// Calculate bytes transferred since last reading
bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived)
bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted)
// Handle counter wraparound (if the counter resets or overflows)
if bytesInDiff < 0 {
bytesInDiff = float64(currentReading.BytesReceived)
}
if bytesOutDiff < 0 {
bytesOutDiff = float64(currentReading.BytesTransmitted)
}
// Convert to MB
bytesInMB := bytesInDiff / (1024 * 1024)
bytesOutMB := bytesOutDiff / (1024 * 1024)
// Record metrics (in bytes)
if bytesInDiff > 0 {
metrics.RecordBytesReceived(interfaceName, publicKey, int64(bytesInDiff))
}
if bytesOutDiff > 0 {
metrics.RecordBytesTransmitted(interfaceName, publicKey, int64(bytesOutDiff))
}
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: bytesInMB,
BytesOut: bytesOutMB,
})
} else {
// If readings are too close together or time hasn't passed, report 0
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: 0,
BytesOut: 0,
})
}
} else {
// For first reading of a peer, report 0 to establish baseline
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: 0,
BytesOut: 0,
})
}
// Update the last reading
lastReadings[publicKey] = currentReading
}
// Clean up old peers
for publicKey := range lastReadings {
if _, exists := currentPeerKeys[publicKey]; !exists {
delete(lastReadings, publicKey)
}
}
return peerBandwidths, nil
}
func reportPeerBandwidth(apiURL string) error {
bandwidths, err := calculatePeerBandwidth()
if err != nil {
// Record bandwidth report error
metrics.RecordBandwidthReport("error")
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
}
jsonData, err := json.Marshal(bandwidths)
if err != nil {
metrics.RecordBandwidthReport("error")
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
}
resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
metrics.RecordBandwidthReport("error")
return fmt.Errorf("failed to send bandwidth data: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
metrics.RecordBandwidthReport("error")
return fmt.Errorf("API returned non-OK status: %s", resp.Status)
}
// Record successful bandwidth report
metrics.RecordBandwidthReport("success")
return nil
}
// notifyPeerChange sends a POST request to notifyURL with the action and public key.
func notifyPeerChange(action, publicKey string) {
if notifyURL == "" {
return
}
payload := map[string]string{
"action": action,
"publicKey": publicKey,
}
data, err := json.Marshal(payload)
if err != nil {
logger.Warn("Failed to marshal notify payload: %v", err)
return
}
resp, err := http.Post(notifyURL, "application/json", bytes.NewBuffer(data))
if err != nil {
logger.Warn("Failed to notify peer change: %v", err)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.Warn("Notify server returned non-OK: %s", resp.Status)
}
}
func monitorMemory(limit uint64) {
var m runtime.MemStats
for {
runtime.ReadMemStats(&m)
if m.Alloc > limit {
// Determine severity based on how much over the limit
severity := "warning"
if m.Alloc > limit*2 {
severity = "critical"
}
fmt.Printf("Memory spike detected (%d bytes). Dumping profile...\n", m.Alloc)
// Record memory spike metric
metrics.RecordMemorySpike(severity)
f, err := os.Create(fmt.Sprintf("/var/config/heap/heap-spike-%d.pprof", time.Now().Unix()))
if err != nil {
log.Println("could not create profile:", err)
} else {
pprof.WriteHeapProfile(f)
f.Close()
// Record heap profile written metric
metrics.RecordHeapProfileWritten()
}
// Wait a while before checking again to avoid spamming profiles
time.Sleep(5 * time.Minute)
}
time.Sleep(5 * time.Second)
}
}
// ensureIFBDevice creates and configures the IFB (Intermediate Functional Block) device used to
// shape ingress traffic on the WireGuard interface. Linux TC qdiscs only control egress by default;
// the IFB trick redirects all ingress packets to a virtual device so HTB shaping can be applied
// there, and the packets are transparently re-injected into the kernel network stack afterwards.
// This is completely invisible to sockets/applications (including a reverse proxy on the host).
func ensureIFBDevice() error {
// Check if the ifb kernel module is loaded (works inside containers too)
if _, err := os.Stat("/sys/module/ifb"); os.IsNotExist(err) {
logger.Warn("IFB module not loaded, skipping IFB setup and ingress traffic shaping")
return nil
}
// Create the IFB device if it does not already exist
_, err := netlink.LinkByName(ifbName)
if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); ok {
cmd := exec.Command("ip", "link", "add", ifbName, "type", "ifb")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to create IFB device %s: %v, output: %s", ifbName, err, string(out))
}
logger.Info("Created IFB device %s", ifbName)
} else {
return fmt.Errorf("failed to look up IFB device %s: %v", ifbName, err)
}
} else {
logger.Info("IFB device %s already exists", ifbName)
}
// Bring the IFB device up
cmd := exec.Command("ip", "link", "set", "dev", ifbName, "up")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to bring up IFB device %s: %v, output: %s", ifbName, err, string(out))
}
// Attach an ingress qdisc to the WireGuard interface if one is not already present
cmd = exec.Command("tc", "qdisc", "show", "dev", interfaceName)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to query qdiscs on %s: %v", interfaceName, err)
}
if !strings.Contains(string(out), "ingress") {
cmd = exec.Command("tc", "qdisc", "add", "dev", interfaceName, "handle", "ffff:", "ingress")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to add ingress qdisc to %s: %v, output: %s", interfaceName, err, string(out))
}
logger.Info("Added ingress qdisc to %s", interfaceName)
}
// Add a catch-all filter that redirects every ingress packet from wg0 to the IFB device.
// Per-peer rate limiting then happens on ifb0's egress HTB qdisc (handle 2:).
cmd = exec.Command("tc", "filter", "show", "dev", interfaceName, "parent", "ffff:")
out, err = cmd.CombinedOutput()
if err != nil || !strings.Contains(string(out), ifbName) {
cmd = exec.Command("tc", "filter", "add", "dev", interfaceName,
"parent", "ffff:", "protocol", "ip",
"u32", "match", "u32", "0", "0",
"action", "mirred", "egress", "redirect", "dev", ifbName)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to add ingress redirect filter on %s: %v, output: %s", interfaceName, err, string(out))
}
logger.Info("Added ingress redirect filter: %s -> %s", interfaceName, ifbName)
}
// Ensure an HTB root qdisc exists on the IFB device (handle 2:) for per-peer shaping
cmd = exec.Command("tc", "qdisc", "show", "dev", ifbName)
out, err = cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to query qdiscs on %s: %v", ifbName, err)
}
if !strings.Contains(string(out), "htb") {
cmd = exec.Command("tc", "qdisc", "add", "dev", ifbName, "root", "handle", "2:", "htb", "default", "9999")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to add HTB qdisc to %s: %v, output: %s", ifbName, err, string(out))
}
logger.Info("Added HTB root qdisc (handle 2:) to IFB device %s", ifbName)
}
logger.Info("IFB device %s ready for ingress traffic shaping", ifbName)
return nil
}
// setupPeerBandwidthLimit sets up TC (Traffic Control) to limit bandwidth for a specific peer IP
// Bandwidth limit is configurable via the --bandwidth-limit flag or BANDWIDTH_LIMIT env var (default: 50mbit)
func setupPeerBandwidthLimit(peerIP string) error {
logger.Debug("setupPeerBandwidthLimit called for peer IP: %s", peerIP)
// Parse the IP to get just the IP address (strip any CIDR notation if present)
ip := peerIP
if strings.Contains(peerIP, "/") {
parsedIP, _, err := net.ParseCIDR(peerIP)
if err != nil {
return fmt.Errorf("failed to parse peer IP: %v", err)
}
ip = parsedIP.String()
}
// First, ensure we have a root qdisc on the interface (HTB - Hierarchical Token Bucket)
// Check if qdisc already exists
cmd := exec.Command("tc", "qdisc", "show", "dev", interfaceName)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to check qdisc: %v, output: %s", err, string(output))
}
// If no HTB qdisc exists, create one
if !strings.Contains(string(output), "htb") {
cmd = exec.Command("tc", "qdisc", "add", "dev", interfaceName, "root", "handle", "1:", "htb", "default", "9999")
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to add root qdisc: %v, output: %s", err, string(output))
}
logger.Info("Created HTB root qdisc on %s", interfaceName)
}
// Generate a unique class ID based on the IP address
// We'll use the last octet of the IP as part of the class ID
ipParts := strings.Split(ip, ".")
if len(ipParts) != 4 {
return fmt.Errorf("invalid IPv4 address: %s", ip)
}
lastOctet := ipParts[3]
classID := fmt.Sprintf("1:%s", lastOctet)
logger.Debug("Generated class ID %s for peer IP %s", classID, ip)
// Create a class for this peer with bandwidth limit
cmd = exec.Command("tc", "class", "add", "dev", interfaceName, "parent", "1:", "classid", classID,
"htb", "rate", bandwidthLimit, "ceil", bandwidthLimit)
if output, err := cmd.CombinedOutput(); err != nil {
logger.Debug("tc class add failed for %s: %v, output: %s", ip, err, string(output))
// If class already exists, try to replace it
if strings.Contains(string(output), "File exists") {
cmd = exec.Command("tc", "class", "replace", "dev", interfaceName, "parent", "1:", "classid", classID,
"htb", "rate", bandwidthLimit, "ceil", bandwidthLimit)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to replace class: %v, output: %s", err, string(output))
}
logger.Debug("Successfully replaced existing class %s for peer IP %s", classID, ip)
} else {
return fmt.Errorf("failed to add class: %v, output: %s", err, string(output))
}
} else {
logger.Debug("Successfully added new class %s for peer IP %s", classID, ip)
}
// Add a filter to match traffic to this peer IP on wg0 egress (peer's download)
cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, "protocol", "ip", "parent", "1:",
"prio", "1", "u32", "match", "ip", "dst", ip, "flowid", classID)
if output, err := cmd.CombinedOutput(); err != nil {
logger.Warn("Failed to add egress filter for peer IP %s: %v, output: %s", ip, err, string(output))
}
// Set up ingress shaping on the IFB device (peer's upload / ingress on wg0).
// All wg0 ingress is redirected to ifb0 by ensureIFBDevice; we add a per-peer
// class + src filter here so each peer gets its own independent rate limit.
ifbClassID := fmt.Sprintf("2:%s", lastOctet)
// Check if the ifb kernel module is loaded (works inside containers too)
if _, err := os.Stat("/sys/module/ifb"); os.IsNotExist(err) {
logger.Warn("IFB module not loaded, skipping IFB setup and ingress traffic shaping.")
logger.Info("Setup bandwidth limit of %s for peer IP %s (egress class %s, ingress class %s)", bandwidthLimit, ip, classID, ifbClassID)
return nil
}
cmd = exec.Command("tc", "class", "add", "dev", ifbName, "parent", "2:", "classid", ifbClassID,
"htb", "rate", bandwidthLimit, "ceil", bandwidthLimit)
if output, err := cmd.CombinedOutput(); err != nil {
if strings.Contains(string(output), "File exists") {
cmd = exec.Command("tc", "class", "replace", "dev", ifbName, "parent", "2:", "classid", ifbClassID,
"htb", "rate", bandwidthLimit, "ceil", bandwidthLimit)
if output, err := cmd.CombinedOutput(); err != nil {
logger.Warn("Failed to replace IFB class for peer IP %s: %v, output: %s", ip, err, string(output))
} else {
logger.Debug("Replaced existing IFB class %s for peer IP %s", ifbClassID, ip)
}
} else {
logger.Warn("Failed to add IFB class for peer IP %s: %v, output: %s", ip, err, string(output))
}
} else {
logger.Debug("Added IFB class %s for peer IP %s", ifbClassID, ip)
}
cmd = exec.Command("tc", "filter", "add", "dev", ifbName, "protocol", "ip", "parent", "2:",
"prio", "1", "u32", "match", "ip", "src", ip, "flowid", ifbClassID)
if output, err := cmd.CombinedOutput(); err != nil {
logger.Warn("Failed to add IFB ingress filter for peer IP %s: %v, output: %s", ip, err, string(output))
}
logger.Info("Setup bandwidth limit of %s for peer IP %s (egress class %s, ingress class %s)", bandwidthLimit, ip, classID, ifbClassID)
return nil
}
// removePeerBandwidthLimit removes TC rules for a specific peer IP
func removePeerBandwidthLimit(peerIP string) error {
// Parse the IP to get just the IP address
ip := peerIP
if strings.Contains(peerIP, "/") {
parsedIP, _, err := net.ParseCIDR(peerIP)
if err != nil {
return fmt.Errorf("failed to parse peer IP: %v", err)
}
ip = parsedIP.String()
}
// Generate the class ID based on the IP
ipParts := strings.Split(ip, ".")
if len(ipParts) != 4 {
return fmt.Errorf("invalid IPv4 address: %s", ip)
}
lastOctet := ipParts[3]
classID := fmt.Sprintf("1:%s", lastOctet)
// Remove filters for this IP
// List all filters to find the ones for this class
cmd := exec.Command("tc", "filter", "show", "dev", interfaceName, "parent", "1:")
output, err := cmd.CombinedOutput()
if err != nil {
logger.Warn("Failed to list filters for peer IP %s: %v, output: %s", ip, err, string(output))
} else {
// Parse the output to find filter handles that match this classID
// The output format includes lines like:
// filter parent 1: protocol ip pref 1 u32 chain 0 fh 800::800 order 2048 key ht 800 bkt 0 flowid 1:4
lines := strings.Split(string(output), "\n")
for _, line := range lines {
// Look for lines containing our flowid (classID)
if strings.Contains(line, "flowid "+classID) && strings.Contains(line, "fh ") {
// Extract handle (format: fh 800::800)
parts := strings.Fields(line)
var handle string
for j, part := range parts {
if part == "fh" && j+1 < len(parts) {
handle = parts[j+1]
break
}
}
if handle != "" {
// Delete this filter using the handle
delCmd := exec.Command("tc", "filter", "del", "dev", interfaceName, "parent", "1:", "handle", handle, "prio", "1", "u32")
if delOutput, delErr := delCmd.CombinedOutput(); delErr != nil {
logger.Debug("Failed to delete filter handle %s for peer IP %s: %v, output: %s", handle, ip, delErr, string(delOutput))
} else {
logger.Debug("Deleted filter handle %s for peer IP %s", handle, ip)
}
}
}
}
}
// Remove the egress class on wg0
cmd = exec.Command("tc", "class", "del", "dev", interfaceName, "classid", classID)
if output, err := cmd.CombinedOutput(); err != nil {
if !strings.Contains(string(output), "No such file or directory") && !strings.Contains(string(output), "Cannot find") {
logger.Warn("Failed to remove egress class for peer IP %s: %v, output: %s", ip, err, string(output))
}
}
// Remove the ingress class and filters on the IFB device
ifbClassID := fmt.Sprintf("2:%s", lastOctet)
// Check if the ifb kernel module is loaded (works inside containers too)
if _, err := os.Stat("/sys/module/ifb"); os.IsNotExist(err) {
logger.Warn("IFB module not loaded, skipping IFB setup and ingress traffic shaping")
logger.Info("Removed bandwidth limit for peer IP %s (egress class %s, ingress class %s)", ip, classID, ifbClassID)
return nil
}
cmd = exec.Command("tc", "filter", "show", "dev", ifbName, "parent", "2:")
output, err = cmd.CombinedOutput()
if err != nil {
logger.Warn("Failed to list IFB filters for peer IP %s: %v, output: %s", ip, err, string(output))
} else {
lines := strings.Split(string(output), "\n")
for _, line := range lines {
if strings.Contains(line, "flowid "+ifbClassID) && strings.Contains(line, "fh ") {
parts := strings.Fields(line)
var handle string
for j, part := range parts {
if part == "fh" && j+1 < len(parts) {
handle = parts[j+1]
break
}
}
if handle != "" {
delCmd := exec.Command("tc", "filter", "del", "dev", ifbName, "parent", "2:", "handle", handle, "prio", "1", "u32")
if delOutput, delErr := delCmd.CombinedOutput(); delErr != nil {
logger.Debug("Failed to delete IFB filter handle %s for peer IP %s: %v, output: %s", handle, ip, delErr, string(delOutput))
} else {
logger.Debug("Deleted IFB filter handle %s for peer IP %s", handle, ip)
}
}
}
}
}
cmd = exec.Command("tc", "class", "del", "dev", ifbName, "classid", ifbClassID)
if output, err := cmd.CombinedOutput(); err != nil {
if !strings.Contains(string(output), "No such file or directory") && !strings.Contains(string(output), "Cannot find") {
logger.Warn("Failed to remove IFB class for peer IP %s: %v, output: %s", ip, err, string(output))
}
}
logger.Info("Removed bandwidth limit for peer IP %s (egress class %s, ingress class %s)", ip, classID, ifbClassID)
return nil
}