mirror of
https://github.com/fosrl/olm.git
synced 2025-12-05 19:17:35 -06:00
1146 lines
31 KiB
Go
1146 lines
31 KiB
Go
package main
|
|
|
|
import (
	crand "crypto/rand"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"net"
	"os/exec"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"time"

	"github.com/fosrl/newt/logger"
	"github.com/fosrl/olm/peermonitor"
	"github.com/fosrl/olm/websocket"
	"github.com/vishvananda/netlink"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/crypto/curve25519"
	"golang.org/x/exp/rand"
	"golang.zx2c4.com/wireguard/conn"
	"golang.zx2c4.com/wireguard/device"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
|
|
|
|
type WgData struct {
|
|
Sites []SiteConfig `json:"sites"`
|
|
TunnelIP string `json:"tunnelIP"`
|
|
}
|
|
|
|
// SiteConfig describes a single site as consumed by ConfigurePeer.
type SiteConfig struct {
	// SiteId identifies the site; it is also the key used with the peer monitor.
	SiteId int `json:"siteId"`
	// Endpoint is resolved with resolveDomain before being used as the
	// WireGuard peer endpoint.
	Endpoint string `json:"endpoint"`
	// PublicKey is the peer's base64-encoded public key (it is run through fixKey).
	PublicKey string `json:"publicKey"`
	// ServerIP is the site's address; anything after a "/" is ignored when
	// the allowed IP and monitor address are derived from it.
	ServerIP string `json:"serverIP"`
	// ServerPort is the site's port; the peer monitor is addressed at ServerPort+1.
	ServerPort uint16 `json:"serverPort"`
	// RemoteSubnets is an optional comma-separated list of subnets that this
	// site can access.
	RemoteSubnets string `json:"remoteSubnets,omitempty"`
}
|
|
|
|
// TargetsByType groups target strings by transport protocol.
type TargetsByType struct {
	// UDP lists the UDP targets.
	UDP []string `json:"udp"`
	// TCP lists the TCP targets.
	TCP []string `json:"tcp"`
}
|
|
|
|
// TargetData is a JSON payload carrying a flat list of target strings.
type TargetData struct {
	// Targets is the list of targets.
	Targets []string `json:"targets"`
}
|
|
|
|
// HolePunchMessage is a JSON message carrying a newt identifier.
type HolePunchMessage struct {
	// NewtID is serialized under the "newtId" key.
	NewtID string `json:"newtId"`
}
|
|
|
|
// ExitNode identifies one exit node that UDP hole-punch packets are sent to.
type ExitNode struct {
	// Endpoint is resolved with resolveDomain before packets are sent.
	Endpoint string `json:"endpoint"`
	// PublicKey is the node's public key; hole-punch payloads are encrypted to it.
	PublicKey string `json:"publicKey"`
}
|
|
|
|
type HolePunchData struct {
|
|
ExitNodes []ExitNode `json:"exitNodes"`
|
|
}
|
|
|
|
// EncryptedHolePunchMessage is the wire format of an encrypted hole-punch
// payload. It has the same fields and JSON keys as the value produced by
// encryptPayload.
type EncryptedHolePunchMessage struct {
	// EphemeralPublicKey is the sender's per-message public key in string form.
	EphemeralPublicKey string `json:"ephemeralPublicKey"`
	// Nonce is the AEAD nonce; encoding/json emits []byte as base64.
	Nonce []byte `json:"nonce"`
	// Ciphertext is the sealed payload; encoding/json emits []byte as base64.
	Ciphertext []byte `json:"ciphertext"`
}
|
|
|
|
var (
|
|
peerMonitor *peermonitor.PeerMonitor
|
|
stopHolepunch chan struct{}
|
|
stopRegister func()
|
|
stopPing chan struct{}
|
|
olmToken string
|
|
holePunchRunning bool
|
|
)
|
|
|
|
// Names of WG_* environment variables. They are not referenced in the visible
// part of this file; presumably they mirror the variables used by
// wireguard-go — TODO confirm.
const (
	ENV_WG_TUN_FD             = "WG_TUN_FD"
	ENV_WG_UAPI_FD            = "WG_UAPI_FD"
	ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
)
|
|
|
|
type fixedPortBind struct {
|
|
port uint16
|
|
conn.Bind
|
|
}
|
|
|
|
// PeerAction represents a request to add, update, or remove a peer
|
|
type PeerAction struct {
|
|
Action string `json:"action"` // "add", "update", or "remove"
|
|
SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information
|
|
}
|
|
|
|
// UpdatePeerData is the payload needed to update an existing peer. Its
// fields and JSON keys match SiteConfig.
type UpdatePeerData struct {
	// SiteId identifies the site being updated.
	SiteId int `json:"siteId"`
	// Endpoint is the site's endpoint.
	Endpoint string `json:"endpoint"`
	// PublicKey is the site's public key.
	PublicKey string `json:"publicKey"`
	// ServerIP is the site's server address.
	ServerIP string `json:"serverIP"`
	// ServerPort is the site's server port.
	ServerPort uint16 `json:"serverPort"`
	// RemoteSubnets is an optional comma-separated list of subnets that this
	// site can access.
	RemoteSubnets string `json:"remoteSubnets,omitempty"`
}
|
|
|
|
// AddPeerData is the payload needed to add a new peer. Its fields and JSON
// keys match SiteConfig.
type AddPeerData struct {
	// SiteId identifies the site being added.
	SiteId int `json:"siteId"`
	// Endpoint is the site's endpoint.
	Endpoint string `json:"endpoint"`
	// PublicKey is the site's public key.
	PublicKey string `json:"publicKey"`
	// ServerIP is the site's server address.
	ServerIP string `json:"serverIP"`
	// ServerPort is the site's server port.
	ServerPort uint16 `json:"serverPort"`
	// RemoteSubnets is an optional comma-separated list of subnets that this
	// site can access.
	RemoteSubnets string `json:"remoteSubnets,omitempty"`
}
|
|
|
|
// RemovePeerData is the payload needed to remove a peer.
type RemovePeerData struct {
	// SiteId identifies the site whose peer is removed.
	SiteId int `json:"siteId"`
}
|
|
|
|
// RelayPeerData is a JSON payload naming a site together with an endpoint and
// public key. NOTE(review): presumably used when a peer is switched to a
// relay — confirm against the message handlers.
type RelayPeerData struct {
	// SiteId identifies the site.
	SiteId int `json:"siteId"`
	// Endpoint is the endpoint carried by the message.
	Endpoint string `json:"endpoint"`
	// PublicKey is the public key carried by the message.
	PublicKey string `json:"publicKey"`
}
|
|
|
|
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
// Ignore the requested port and use our fixed port
|
|
return b.Bind.Open(b.port)
|
|
}
|
|
|
|
func NewFixedPortBind(port uint16) conn.Bind {
|
|
return &fixedPortBind{
|
|
port: port,
|
|
Bind: conn.NewDefaultBind(),
|
|
}
|
|
}
|
|
|
|
func fixKey(key string) string {
|
|
// Remove any whitespace
|
|
key = strings.TrimSpace(key)
|
|
|
|
// Decode from base64
|
|
decoded, err := base64.StdEncoding.DecodeString(key)
|
|
if err != nil {
|
|
logger.Fatal("Error decoding base64")
|
|
}
|
|
|
|
// Convert to hex
|
|
return hex.EncodeToString(decoded)
|
|
}
|
|
|
|
func parseLogLevel(level string) logger.LogLevel {
|
|
switch strings.ToUpper(level) {
|
|
case "DEBUG":
|
|
return logger.DEBUG
|
|
case "INFO":
|
|
return logger.INFO
|
|
case "WARN":
|
|
return logger.WARN
|
|
case "ERROR":
|
|
return logger.ERROR
|
|
case "FATAL":
|
|
return logger.FATAL
|
|
default:
|
|
return logger.INFO // default to INFO if invalid level provided
|
|
}
|
|
}
|
|
|
|
func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
|
switch level {
|
|
case logger.DEBUG:
|
|
return device.LogLevelVerbose
|
|
// case logger.INFO:
|
|
// return device.LogLevel
|
|
case logger.WARN:
|
|
return device.LogLevelError
|
|
case logger.ERROR, logger.FATAL:
|
|
return device.LogLevelSilent
|
|
default:
|
|
return device.LogLevelSilent
|
|
}
|
|
}
|
|
|
|
// resolveDomain turns an endpoint string into an IP address, keeping an
// explicit port when one is present.
//
// Accepted forms are "host", "host:port", and either of those prefixed with
// "http://" or "https://" and/or followed by a path or trailing slashes,
// which are discarded. A bracketed IPv6 literal without a port ("[::1]") is
// also accepted.
//
// The first IPv4 address returned by the lookup is preferred; if there is
// none, the first address of any family is used. The result is "ip" when the
// input had no port and "ip:port" (with IPv6 addresses bracketed) when it did.
// An error is returned when the lookup fails or yields no addresses.
func resolveDomain(domain string) (string, error) {
	// Drop a leading scheme, if any.
	domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://")

	// Drop everything from the first slash on. This removes any number of
	// trailing slashes as well as a URL path, neither of which is part of
	// the host or port.
	if i := strings.IndexByte(domain, '/'); i >= 0 {
		domain = domain[:i]
	}

	host, port, err := net.SplitHostPort(domain)
	if err != nil {
		// No port present: the whole string is the host. Strip the brackets
		// of a bare IPv6 literal so the lookup below accepts it.
		host = strings.TrimSuffix(strings.TrimPrefix(domain, "["), "]")
		port = ""
	}

	ips, err := net.LookupIP(host)
	if err != nil {
		return "", fmt.Errorf("DNS lookup failed: %w", err)
	}
	if len(ips) == 0 {
		return "", fmt.Errorf("no IP addresses found for domain %s", host)
	}

	// Prefer IPv4; fall back to whatever came first (possibly IPv6).
	ipAddr := ips[0].String()
	for _, ip := range ips {
		if ipv4 := ip.To4(); ipv4 != nil {
			ipAddr = ipv4.String()
			break
		}
	}

	if port != "" {
		ipAddr = net.JoinHostPort(ipAddr, port)
	}
	return ipAddr, nil
}
|
|
|
|
func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
|
|
if serverPubKey == "" || olmToken == "" {
|
|
return nil
|
|
}
|
|
|
|
payload := struct {
|
|
OlmID string `json:"olmId"`
|
|
Token string `json:"token"`
|
|
}{
|
|
OlmID: olmID,
|
|
Token: olmToken,
|
|
}
|
|
|
|
// Convert payload to JSON
|
|
payloadBytes, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal payload: %v", err)
|
|
}
|
|
|
|
// Encrypt the payload using the server's WireGuard public key
|
|
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt payload: %v", err)
|
|
}
|
|
|
|
jsonData, err := json.Marshal(encryptedPayload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal encrypted payload: %v", err)
|
|
}
|
|
|
|
_, err = conn.WriteToUDP(jsonData, remoteAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to send UDP packet: %v", err)
|
|
}
|
|
|
|
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
|
|
|
return nil
|
|
}
|
|
|
|
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
|
|
// Generate an ephemeral keypair for this message
|
|
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
|
}
|
|
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
|
|
|
// Parse the server's public key
|
|
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
|
}
|
|
|
|
// Use X25519 for key exchange (replacing deprecated ScalarMult)
|
|
var ephPrivKeyFixed [32]byte
|
|
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
|
|
|
// Perform X25519 key exchange
|
|
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
|
}
|
|
|
|
// Create an AEAD cipher using the shared secret
|
|
aead, err := chacha20poly1305.New(sharedSecret)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
|
}
|
|
|
|
// Generate a random nonce
|
|
nonce := make([]byte, aead.NonceSize())
|
|
if _, err := rand.Read(nonce); err != nil {
|
|
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
|
}
|
|
|
|
// Encrypt the payload
|
|
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
|
|
|
// Prepare the final encrypted message
|
|
encryptedMsg := struct {
|
|
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
|
Nonce []byte `json:"nonce"`
|
|
Ciphertext []byte `json:"ciphertext"`
|
|
}{
|
|
EphemeralPublicKey: ephemeralPublicKey.String(),
|
|
Nonce: nonce,
|
|
Ciphertext: ciphertext,
|
|
}
|
|
|
|
return encryptedMsg, nil
|
|
}
|
|
|
|
func keepSendingUDPHolePunchToMultipleExitNodes(exitNodes []ExitNode, olmID string, sourcePort uint16) {
|
|
if len(exitNodes) == 0 {
|
|
logger.Warn("No exit nodes provided for hole punching")
|
|
return
|
|
}
|
|
|
|
// Check if hole punching is already running
|
|
if holePunchRunning {
|
|
logger.Debug("UDP hole punch already running, skipping new request")
|
|
return
|
|
}
|
|
|
|
// Set the flag to indicate hole punching is running
|
|
holePunchRunning = true
|
|
defer func() {
|
|
holePunchRunning = false
|
|
logger.Info("UDP hole punch goroutine ended")
|
|
}()
|
|
|
|
logger.Info("Starting UDP hole punch to %d exit nodes", len(exitNodes))
|
|
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
|
|
|
// Create the UDP connection once and reuse it for all exit nodes
|
|
localAddr := &net.UDPAddr{
|
|
Port: int(sourcePort),
|
|
IP: net.IPv4zero,
|
|
}
|
|
|
|
conn, err := net.ListenUDP("udp", localAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to bind UDP socket: %v", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Resolve all endpoints upfront
|
|
type resolvedExitNode struct {
|
|
remoteAddr *net.UDPAddr
|
|
publicKey string
|
|
endpointName string
|
|
}
|
|
|
|
var resolvedNodes []resolvedExitNode
|
|
for _, exitNode := range exitNodes {
|
|
host, err := resolveDomain(exitNode.Endpoint)
|
|
if err != nil {
|
|
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
|
continue
|
|
}
|
|
|
|
serverAddr := net.JoinHostPort(host, "21820")
|
|
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
|
|
continue
|
|
}
|
|
|
|
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
|
remoteAddr: remoteAddr,
|
|
publicKey: exitNode.PublicKey,
|
|
endpointName: exitNode.Endpoint,
|
|
})
|
|
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
|
}
|
|
|
|
if len(resolvedNodes) == 0 {
|
|
logger.Error("No exit nodes could be resolved")
|
|
return
|
|
}
|
|
|
|
// Send initial hole punch to all exit nodes
|
|
for _, node := range resolvedNodes {
|
|
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
|
|
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
|
|
}
|
|
}
|
|
|
|
ticker := time.NewTicker(250 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
|
|
timeout := time.NewTimer(15 * time.Second)
|
|
defer timeout.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-stopHolepunch:
|
|
logger.Info("Stopping UDP holepunch for all exit nodes")
|
|
return
|
|
case <-timeout.C:
|
|
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
|
|
return
|
|
case <-ticker.C:
|
|
// Send hole punch to all exit nodes
|
|
for _, node := range resolvedNodes {
|
|
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
|
|
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, serverPubKey string) {
|
|
|
|
// Check if hole punching is already running
|
|
if holePunchRunning {
|
|
logger.Debug("UDP hole punch already running, skipping new request")
|
|
return
|
|
}
|
|
|
|
// Set the flag to indicate hole punching is running
|
|
holePunchRunning = true
|
|
defer func() {
|
|
holePunchRunning = false
|
|
logger.Info("UDP hole punch goroutine ended")
|
|
}()
|
|
|
|
logger.Info("Starting UDP hole punch to %s", endpoint)
|
|
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
|
|
|
host, err := resolveDomain(endpoint)
|
|
if err != nil {
|
|
logger.Error("Failed to resolve endpoint: %v", err)
|
|
return
|
|
}
|
|
|
|
serverAddr := net.JoinHostPort(host, "21820")
|
|
|
|
// Create the UDP connection once and reuse it
|
|
localAddr := &net.UDPAddr{
|
|
Port: int(sourcePort),
|
|
IP: net.IPv4zero,
|
|
}
|
|
|
|
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to resolve UDP address: %v", err)
|
|
return
|
|
}
|
|
|
|
conn, err := net.ListenUDP("udp", localAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to bind UDP socket: %v", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Execute once immediately before starting the loop
|
|
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
|
|
logger.Error("Failed to send UDP hole punch: %v", err)
|
|
}
|
|
|
|
ticker := time.NewTicker(250 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
|
|
timeout := time.NewTimer(15 * time.Second)
|
|
defer timeout.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-stopHolepunch:
|
|
logger.Info("Stopping UDP holepunch")
|
|
return
|
|
case <-timeout.C:
|
|
logger.Info("UDP holepunch routine timed out after 15 seconds")
|
|
return
|
|
case <-ticker.C:
|
|
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
|
|
logger.Error("Failed to send UDP hole punch: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
|
if maxPort < minPort {
|
|
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
|
}
|
|
|
|
// Create a slice of all ports in the range
|
|
portRange := make([]uint16, maxPort-minPort+1)
|
|
for i := range portRange {
|
|
portRange[i] = minPort + uint16(i)
|
|
}
|
|
|
|
// Fisher-Yates shuffle to randomize the port order
|
|
rand.Seed(uint64(time.Now().UnixNano()))
|
|
for i := len(portRange) - 1; i > 0; i-- {
|
|
j := rand.Intn(i + 1)
|
|
portRange[i], portRange[j] = portRange[j], portRange[i]
|
|
}
|
|
|
|
// Try each port in the randomized order
|
|
for _, port := range portRange {
|
|
addr := &net.UDPAddr{
|
|
IP: net.ParseIP("127.0.0.1"),
|
|
Port: int(port),
|
|
}
|
|
conn, err := net.ListenUDP("udp", addr)
|
|
if err != nil {
|
|
continue // Port is in use or there was an error, try next port
|
|
}
|
|
_ = conn.SetDeadline(time.Now())
|
|
conn.Close()
|
|
return port, nil
|
|
}
|
|
|
|
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
|
|
}
|
|
|
|
func sendPing(olm *websocket.Client) error {
|
|
err := olm.SendMessage("olm/ping", map[string]interface{}{
|
|
"timestamp": time.Now().Unix(),
|
|
})
|
|
if err != nil {
|
|
logger.Error("Failed to send ping message: %v", err)
|
|
return err
|
|
}
|
|
logger.Debug("Sent ping message")
|
|
return nil
|
|
}
|
|
|
|
func keepSendingPing(olm *websocket.Client) {
|
|
// Send ping immediately on startup
|
|
if err := sendPing(olm); err != nil {
|
|
logger.Error("Failed to send initial ping: %v", err)
|
|
} else {
|
|
logger.Info("Sent initial ping message")
|
|
}
|
|
|
|
// Set up ticker for one minute intervals
|
|
ticker := time.NewTicker(1 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-stopPing:
|
|
logger.Info("Stopping ping messages")
|
|
return
|
|
case <-ticker.C:
|
|
if err := sendPing(olm); err != nil {
|
|
logger.Error("Failed to send periodic ping: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
|
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
|
|
siteHost, err := resolveDomain(siteConfig.Endpoint)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
|
}
|
|
|
|
// Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP
|
|
allowedIp := strings.Split(siteConfig.ServerIP, "/")
|
|
if len(allowedIp) > 1 {
|
|
allowedIp[1] = "32"
|
|
} else {
|
|
allowedIp = append(allowedIp, "32")
|
|
}
|
|
allowedIpStr := strings.Join(allowedIp, "/")
|
|
|
|
// Collect all allowed IPs in a slice
|
|
var allowedIPs []string
|
|
allowedIPs = append(allowedIPs, allowedIpStr)
|
|
|
|
// If we have anything in remoteSubnets, add those as well
|
|
if siteConfig.RemoteSubnets != "" {
|
|
// Split remote subnets by comma and add each one
|
|
remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",")
|
|
for _, subnet := range remoteSubnets {
|
|
subnet = strings.TrimSpace(subnet)
|
|
if subnet != "" {
|
|
allowedIPs = append(allowedIPs, subnet)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Construct WireGuard config for this peer
|
|
var configBuilder strings.Builder
|
|
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String())))
|
|
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(siteConfig.PublicKey)))
|
|
|
|
// Add each allowed IP separately
|
|
for _, allowedIP := range allowedIPs {
|
|
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
|
}
|
|
|
|
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
|
configBuilder.WriteString("persistent_keepalive_interval=1\n")
|
|
|
|
config := configBuilder.String()
|
|
logger.Debug("Configuring peer with config: %s", config)
|
|
|
|
err = dev.IpcSet(config)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
|
|
}
|
|
|
|
// Set up peer monitoring
|
|
if peerMonitor != nil {
|
|
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
|
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
|
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
|
|
|
|
primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable
|
|
if err != nil {
|
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
|
}
|
|
|
|
wgConfig := &peermonitor.WireGuardConfig{
|
|
SiteID: siteConfig.SiteId,
|
|
PublicKey: fixKey(siteConfig.PublicKey),
|
|
ServerIP: strings.Split(siteConfig.ServerIP, "/")[0],
|
|
Endpoint: siteConfig.Endpoint,
|
|
PrimaryRelay: primaryRelay,
|
|
}
|
|
|
|
err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig)
|
|
if err != nil {
|
|
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
|
|
} else {
|
|
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RemovePeer removes a peer from the WireGuard device
|
|
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
|
|
// Construct WireGuard config to remove the peer
|
|
var configBuilder strings.Builder
|
|
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(publicKey)))
|
|
configBuilder.WriteString("remove=true\n")
|
|
|
|
config := configBuilder.String()
|
|
logger.Debug("Removing peer with config: %s", config)
|
|
|
|
err := dev.IpcSet(config)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
|
|
}
|
|
|
|
// Stop monitoring this peer
|
|
if peerMonitor != nil {
|
|
peerMonitor.RemovePeer(siteId)
|
|
logger.Info("Stopped monitoring for site %d", siteId)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ConfigureInterface configures a network interface with an IP address and brings it up
|
|
func ConfigureInterface(interfaceName string, wgData WgData) error {
|
|
var ipAddr string = wgData.TunnelIP
|
|
|
|
// Parse the IP address and network
|
|
ip, ipNet, err := net.ParseCIDR(ipAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid IP address: %v", err)
|
|
}
|
|
|
|
switch runtime.GOOS {
|
|
case "linux":
|
|
return configureLinux(interfaceName, ip, ipNet)
|
|
case "darwin":
|
|
return configureDarwin(interfaceName, ip, ipNet)
|
|
case "windows":
|
|
return configureWindows(interfaceName, ip, ipNet)
|
|
default:
|
|
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
|
}
|
|
}
|
|
|
|
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
|
logger.Info("Configuring Windows interface: %s", interfaceName)
|
|
|
|
// Calculate mask string (e.g., 255.255.255.0)
|
|
maskBits, _ := ipNet.Mask.Size()
|
|
mask := net.CIDRMask(maskBits, 32)
|
|
maskIP := net.IP(mask)
|
|
|
|
// Set the IP address using netsh
|
|
cmd := exec.Command("netsh", "interface", "ipv4", "set", "address",
|
|
fmt.Sprintf("name=%s", interfaceName),
|
|
"source=static",
|
|
fmt.Sprintf("addr=%s", ip.String()),
|
|
fmt.Sprintf("mask=%s", maskIP.String()))
|
|
|
|
logger.Info("Running command: %v", cmd)
|
|
out, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
return fmt.Errorf("netsh command failed: %v, output: %s", err, out)
|
|
}
|
|
|
|
// Bring up the interface if needed (in Windows, setting the IP usually brings it up)
|
|
// But we'll explicitly enable it to be sure
|
|
cmd = exec.Command("netsh", "interface", "set", "interface",
|
|
interfaceName,
|
|
"admin=enable")
|
|
|
|
logger.Info("Running command: %v", cmd)
|
|
out, err = cmd.CombinedOutput()
|
|
if err != nil {
|
|
return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out)
|
|
}
|
|
|
|
// delay 2 seconds
|
|
time.Sleep(8 * time.Second)
|
|
|
|
// Wait for the interface to be up and have the correct IP
|
|
err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
|
|
if err != nil {
|
|
return fmt.Errorf("interface did not come up within timeout: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// waitForInterfaceUp polls the network interface until it's up or times out
|
|
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
|
|
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
|
deadline := time.Now().Add(timeout)
|
|
pollInterval := 500 * time.Millisecond
|
|
|
|
for time.Now().Before(deadline) {
|
|
// Check if interface exists and is up
|
|
iface, err := net.InterfaceByName(interfaceName)
|
|
if err == nil {
|
|
// Check if interface is up
|
|
if iface.Flags&net.FlagUp != 0 {
|
|
// Check if it has the expected IP
|
|
addrs, err := iface.Addrs()
|
|
if err == nil {
|
|
for _, addr := range addrs {
|
|
ipNet, ok := addr.(*net.IPNet)
|
|
if ok && ipNet.IP.Equal(expectedIP) {
|
|
logger.Info("Interface %s is up with correct IP", interfaceName)
|
|
return nil // Interface is up with correct IP
|
|
}
|
|
}
|
|
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName)
|
|
}
|
|
} else {
|
|
logger.Info("Interface %s exists but is not up yet", interfaceName)
|
|
}
|
|
} else {
|
|
logger.Info("Interface %s not found yet: %v", interfaceName, err)
|
|
}
|
|
|
|
// Wait before next check
|
|
time.Sleep(pollInterval)
|
|
}
|
|
|
|
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
|
}
|
|
|
|
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
|
if runtime.GOOS != "windows" {
|
|
return nil
|
|
}
|
|
|
|
var cmd *exec.Cmd
|
|
|
|
// Parse destination to get the IP and subnet
|
|
ip, ipNet, err := net.ParseCIDR(destination)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid destination address: %v", err)
|
|
}
|
|
|
|
// Calculate the subnet mask
|
|
maskBits, _ := ipNet.Mask.Size()
|
|
mask := net.CIDRMask(maskBits, 32)
|
|
maskIP := net.IP(mask)
|
|
|
|
if gateway != "" {
|
|
// Route with specific gateway
|
|
cmd = exec.Command("route", "add",
|
|
ip.String(),
|
|
"mask", maskIP.String(),
|
|
gateway,
|
|
"metric", "1")
|
|
} else if interfaceName != "" {
|
|
// First, get the interface index
|
|
indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces")
|
|
output, err := indexCmd.CombinedOutput()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get interface index: %v, output: %s", err, output)
|
|
}
|
|
|
|
// Parse the output to find the interface index
|
|
lines := strings.Split(string(output), "\n")
|
|
var ifIndex string
|
|
for _, line := range lines {
|
|
if strings.Contains(line, interfaceName) {
|
|
fields := strings.Fields(line)
|
|
if len(fields) > 0 {
|
|
ifIndex = fields[0]
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if ifIndex == "" {
|
|
return fmt.Errorf("could not find index for interface %s", interfaceName)
|
|
}
|
|
|
|
// Convert to integer to validate
|
|
idx, err := strconv.Atoi(ifIndex)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid interface index: %v", err)
|
|
}
|
|
|
|
// Route via interface using the index
|
|
cmd = exec.Command("route", "add",
|
|
ip.String(),
|
|
"mask", maskIP.String(),
|
|
"0.0.0.0",
|
|
"if", strconv.Itoa(idx))
|
|
} else {
|
|
return fmt.Errorf("either gateway or interface must be specified")
|
|
}
|
|
|
|
logger.Info("Running command: %v", cmd)
|
|
out, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func WindowsRemoveRoute(destination string) error {
|
|
// Parse destination to get the IP
|
|
ip, ipNet, err := net.ParseCIDR(destination)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid destination address: %v", err)
|
|
}
|
|
|
|
// Calculate the subnet mask
|
|
maskBits, _ := ipNet.Mask.Size()
|
|
mask := net.CIDRMask(maskBits, 32)
|
|
maskIP := net.IP(mask)
|
|
|
|
cmd := exec.Command("route", "delete",
|
|
ip.String(),
|
|
"mask", maskIP.String())
|
|
|
|
logger.Info("Running command: %v", cmd)
|
|
out, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// findUnusedUTUN returns the lowest-numbered name "utunN" with N in 0..255
// that is not currently the name of a network interface. An error is
// returned when the interfaces cannot be listed or all 256 names are taken.
func findUnusedUTUN() (string, error) {
	ifaces, err := net.Interfaces()
	if err != nil {
		return "", fmt.Errorf("failed to list interfaces: %v", err)
	}

	pattern := regexp.MustCompile(`^utun(\d+)$`)
	taken := map[int]bool{}
	for _, iface := range ifaces {
		groups := pattern.FindStringSubmatch(iface.Name)
		if len(groups) != 2 {
			continue
		}
		index, convErr := strconv.Atoi(groups[1])
		if convErr != nil {
			continue
		}
		taken[index] = true
	}

	// Candidates are utun0 through utun255.
	for index := 0; index <= 255; index++ {
		if taken[index] {
			continue
		}
		return "utun" + strconv.Itoa(index), nil
	}
	return "", fmt.Errorf("no unused utun interface found")
}
|
|
|
|
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
|
logger.Info("Configuring darwin interface: %s", interfaceName)
|
|
|
|
prefix, _ := ipNet.Mask.Size()
|
|
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
|
|
|
|
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
|
|
logger.Info("Running command: %v", cmd)
|
|
|
|
out, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out)
|
|
}
|
|
|
|
// Bring up the interface
|
|
cmd = exec.Command("ifconfig", interfaceName, "up")
|
|
logger.Info("Running command: %v", cmd)
|
|
|
|
out, err = cmd.CombinedOutput()
|
|
if err != nil {
|
|
return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
|
// Get the interface
|
|
link, err := netlink.LinkByName(interfaceName)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
|
}
|
|
|
|
// Create the IP address attributes
|
|
addr := &netlink.Addr{
|
|
IPNet: &net.IPNet{
|
|
IP: ip,
|
|
Mask: ipNet.Mask,
|
|
},
|
|
}
|
|
|
|
// Add the IP address to the interface
|
|
if err := netlink.AddrAdd(link, addr); err != nil {
|
|
return fmt.Errorf("failed to add IP address: %v", err)
|
|
}
|
|
|
|
// Bring up the interface
|
|
if err := netlink.LinkSetUp(link); err != nil {
|
|
return fmt.Errorf("failed to bring up interface: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
|
|
if runtime.GOOS != "darwin" {
|
|
return nil
|
|
}
|
|
|
|
var cmd *exec.Cmd
|
|
|
|
if gateway != "" {
|
|
// Route with specific gateway
|
|
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway)
|
|
} else if interfaceName != "" {
|
|
// Route via interface
|
|
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName)
|
|
} else {
|
|
return fmt.Errorf("either gateway or interface must be specified")
|
|
}
|
|
|
|
logger.Info("Running command: %v", cmd)
|
|
|
|
out, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func DarwinRemoveRoute(destination string) error {
|
|
if runtime.GOOS != "darwin" {
|
|
return nil
|
|
}
|
|
|
|
cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination)
|
|
logger.Info("Running command: %v", cmd)
|
|
|
|
out, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
|
|
if runtime.GOOS != "linux" {
|
|
return nil
|
|
}
|
|
|
|
var cmd *exec.Cmd
|
|
|
|
if gateway != "" {
|
|
// Route with specific gateway
|
|
cmd = exec.Command("ip", "route", "add", destination, "via", gateway)
|
|
} else if interfaceName != "" {
|
|
// Route via interface
|
|
cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName)
|
|
} else {
|
|
return fmt.Errorf("either gateway or interface must be specified")
|
|
}
|
|
|
|
logger.Info("Running command: %v", cmd)
|
|
|
|
out, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
return fmt.Errorf("ip route command failed: %v, output: %s", err, out)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func LinuxRemoveRoute(destination string) error {
|
|
if runtime.GOOS != "linux" {
|
|
return nil
|
|
}
|
|
|
|
cmd := exec.Command("ip", "route", "del", destination)
|
|
logger.Info("Running command: %v", cmd)
|
|
|
|
out, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addRouteForServerIP adds an OS-specific route for the server IP
|
|
func addRouteForServerIP(serverIP, interfaceName string) error {
|
|
if runtime.GOOS == "darwin" {
|
|
return DarwinAddRoute(serverIP, "", interfaceName)
|
|
}
|
|
// else if runtime.GOOS == "windows" {
|
|
// return WindowsAddRoute(serverIP, "", interfaceName)
|
|
// } else if runtime.GOOS == "linux" {
|
|
// return LinuxAddRoute(serverIP, "", interfaceName)
|
|
// }
|
|
return nil
|
|
}
|
|
|
|
// removeRouteForServerIP removes an OS-specific route for the server IP
|
|
func removeRouteForServerIP(serverIP string) error {
|
|
if runtime.GOOS == "darwin" {
|
|
return DarwinRemoveRoute(serverIP)
|
|
}
|
|
// else if runtime.GOOS == "windows" {
|
|
// return WindowsRemoveRoute(serverIP)
|
|
// } else if runtime.GOOS == "linux" {
|
|
// return LinuxRemoveRoute(serverIP)
|
|
// }
|
|
return nil
|
|
}
|
|
|
|
// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets
|
|
func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error {
|
|
if remoteSubnets == "" {
|
|
return nil
|
|
}
|
|
|
|
// Split remote subnets by comma and add routes for each one
|
|
subnets := strings.Split(remoteSubnets, ",")
|
|
for _, subnet := range subnets {
|
|
subnet = strings.TrimSpace(subnet)
|
|
if subnet == "" {
|
|
continue
|
|
}
|
|
|
|
// Add route based on operating system
|
|
if runtime.GOOS == "darwin" {
|
|
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
|
|
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
|
|
return err
|
|
}
|
|
} else if runtime.GOOS == "windows" {
|
|
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
|
|
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
|
|
return err
|
|
}
|
|
} else if runtime.GOOS == "linux" {
|
|
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
|
|
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
|
|
return err
|
|
}
|
|
}
|
|
|
|
logger.Info("Added route for remote subnet: %s", subnet)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets
|
|
func removeRoutesForRemoteSubnets(remoteSubnets string) error {
|
|
if remoteSubnets == "" {
|
|
return nil
|
|
}
|
|
|
|
// Split remote subnets by comma and remove routes for each one
|
|
subnets := strings.Split(remoteSubnets, ",")
|
|
for _, subnet := range subnets {
|
|
subnet = strings.TrimSpace(subnet)
|
|
if subnet == "" {
|
|
continue
|
|
}
|
|
|
|
// Remove route based on operating system
|
|
if runtime.GOOS == "darwin" {
|
|
if err := DarwinRemoveRoute(subnet); err != nil {
|
|
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
|
|
return err
|
|
}
|
|
} else if runtime.GOOS == "windows" {
|
|
if err := WindowsRemoveRoute(subnet); err != nil {
|
|
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
|
|
return err
|
|
}
|
|
} else if runtime.GOOS == "linux" {
|
|
if err := LinuxRemoveRoute(subnet); err != nil {
|
|
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
|
|
return err
|
|
}
|
|
}
|
|
|
|
logger.Info("Removed route for remote subnet: %s", subnet)
|
|
}
|
|
return nil
|
|
}
|