Compare commits

..

1 Commits
jit ... dev

Author SHA1 Message Date
André Gilerson
3f258d3500 Fix crash when peer has nil publicKey in site config
Skip sites with empty/nil publicKey instead of passing them to the
WireGuard UAPI layer, which expects a valid 64-char hex string. A nil
key occurs when a Newt site has never connected. Previously this caused
all sites to fail with "hex string does not fit the slice".
2026-03-07 20:44:25 -08:00
6 changed files with 61 additions and 305 deletions

View File

@@ -78,13 +78,6 @@ type MetadataChangeRequest struct {
Postures map[string]any `json:"postures"`
}
// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request.
// The target is named by either Site or Resource. The /jit-connect handler
// rejects a request that sets neither field and a request that sets both, so
// exactly one of them must be non-empty.
type JITConnectionRequest struct {
	// Site names the target site; it is omitted from the JSON when empty.
	Site string `json:"site,omitempty"`
	// Resource names the target resource; it is omitted from the JSON when empty.
	Resource string `json:"resource,omitempty"`
}
// API represents the HTTP server and its state
type API struct {
addr string
@@ -99,7 +92,6 @@ type API struct {
onExit func() error
onRebind func() error
onPowerMode func(PowerModeRequest) error
onJITConnect func(JITConnectionRequest) error
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
@@ -151,7 +143,6 @@ func (s *API) SetHandlers(
onExit func() error,
onRebind func() error,
onPowerMode func(PowerModeRequest) error,
onJITConnect func(JITConnectionRequest) error,
) {
s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg
@@ -160,7 +151,6 @@ func (s *API) SetHandlers(
s.onExit = onExit
s.onRebind = onRebind
s.onPowerMode = onPowerMode
s.onJITConnect = onJITConnect
}
// Start starts the HTTP server
@@ -179,7 +169,6 @@ func (s *API) Start() error {
mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/rebind", s.handleRebind)
mux.HandleFunc("/power-mode", s.handlePowerMode)
mux.HandleFunc("/jit-connect", s.handleJITConnect)
s.server = &http.Server{
Handler: mux,
@@ -644,54 +633,6 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
})
}
// handleJITConnect serves the /jit-connect endpoint.
// It accepts a POST whose JSON body names exactly one of site or resource,
// hands the decoded request to the registered onJITConnect callback and, when
// the callback succeeds, replies 202 Accepted with a small JSON status body.
//
// Error responses:
//   - 405 for any method other than POST
//   - 400 for an undecodable body, or when neither or both fields are set
//   - 501 when no onJITConnect callback has been registered
//   - 500 when the callback returns an error
func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var jitReq JITConnectionRequest
	decodeErr := json.NewDecoder(r.Body).Decode(&jitReq)
	if decodeErr != nil {
		http.Error(w, fmt.Sprintf("Invalid request body: %v", decodeErr), http.StatusBadRequest)
		return
	}

	// Exactly one of site/resource must be present; log whichever one it is.
	hasSite := jitReq.Site != ""
	hasResource := jitReq.Resource != ""
	switch {
	case !hasSite && !hasResource:
		http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest)
		return
	case hasSite && hasResource:
		http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest)
		return
	case hasSite:
		logger.Info("Received JIT connection request via API: site=%s", jitReq.Site)
	default:
		logger.Info("Received JIT connection request via API: resource=%s", jitReq.Resource)
	}

	// The request is logged before this check, so an unconfigured handler
	// still leaves a trace of the incoming request.
	if s.onJITConnect == nil {
		http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented)
		return
	}
	if err := s.onJITConnect(jitReq); err != nil {
		http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusAccepted)
	// The status line is already written; an encode failure cannot be reported.
	_ = json.NewEncoder(w).Encode(map[string]string{
		"status": "JIT connection request accepted",
	})
}
// handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {

View File

@@ -172,6 +172,12 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
for i := range wgData.Sites {
site := wgData.Sites[i]
if site.PublicKey == "" {
logger.Warn("Skipping site %d (%s): no public key available (site may not be connected)", site.SiteId, site.Name)
continue
}
var siteEndpoint string
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
if site.RelayEndpoint != "" {

View File

@@ -2,7 +2,6 @@ package olm
import (
"encoding/json"
"fmt"
"time"
"github.com/fosrl/newt/holepunch"
@@ -221,7 +220,6 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
logger.Info("Sync: Adding new peer for site %d", siteId)
o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// // TODO: do we need to send the message to the cloud to add the peer that way?
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
@@ -232,17 +230,9 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
// add the peer via the server
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId)
o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[chainId]; ok {
stop()
}
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
}, 1*time.Second, 10)
} else {
// Existing peer - check if update is needed

View File

@@ -2,8 +2,6 @@ package olm
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"net/http"
@@ -67,9 +65,7 @@ type Olm struct {
stopRegister func()
updateRegister func(newData any)
stopPeerSends map[string]func()
stopPeerInits map[string]func()
peerSendMu sync.Mutex
stopPeerSend func()
// WaitGroup to track tunnel lifecycle
tunnelWg sync.WaitGroup
@@ -120,13 +116,6 @@ func (o *Olm) initTunnelInfo(clientID string) error {
return nil
}
// generateChainId returns a random chain ID for tracking peer sender
// lifecycles: 8 bytes read from crypto/rand, rendered as 16 hex characters.
func generateChainId() string {
	var raw [8]byte
	// The read error is intentionally ignored; the ID is best effort.
	_, _ = rand.Read(raw[:])
	return hex.EncodeToString(raw[:])
}
func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
@@ -177,12 +166,10 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
apiServer.SetAgent(config.Agent)
newOlm := &Olm{
logFile: logFile,
olmCtx: ctx,
apiServer: apiServer,
olmConfig: config,
stopPeerSends: make(map[string]func()),
stopPeerInits: make(map[string]func()),
logFile: logFile,
olmCtx: ctx,
apiServer: apiServer,
olmConfig: config,
}
newOlm.registerAPICallbacks()
@@ -297,21 +284,6 @@ func (o *Olm) registerAPICallbacks() {
logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
return o.SetPowerMode(req.Mode)
},
func(req api.JITConnectionRequest) error {
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
chainId := generateChainId()
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": req.Site,
"resourceId": req.Resource,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.peerSendMu.Unlock()
return nil
},
)
}
@@ -406,7 +378,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
o.websocket.RegisterHandler("olm/sync", o.handleSync)
o.websocket.OnConnect(func() error {
@@ -449,7 +420,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
"userToken": userToken,
"fingerprint": o.fingerprint,
"postures": o.postures,
}, 2*time.Second, 10)
}, 1*time.Second, 10)
// Invoke onRegistered callback if configured
if o.olmConfig.OnRegistered != nil {
@@ -546,22 +517,6 @@ func (o *Olm) Close() {
o.stopRegister = nil
}
// Stop all pending peer init and send senders before closing websocket
o.peerSendMu.Lock()
for _, stop := range o.stopPeerInits {
if stop != nil {
stop()
}
}
o.stopPeerInits = make(map[string]func())
for _, stop := range o.stopPeerSends {
if stop != nil {
stop()
}
}
o.stopPeerSends = make(map[string]func())
o.peerSendMu.Unlock()
// send a disconnect message to the cloud to show disconnected
if o.websocket != nil {
o.websocket.SendMessage("olm/disconnecting", map[string]any{})

View File

@@ -20,38 +20,36 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
return
}
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var siteConfigMsg struct {
peers.SiteConfig
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
var siteConfig peers.SiteConfig
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
if siteConfigMsg.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
stop()
delete(o.stopPeerSends, siteConfigMsg.ChainId)
}
o.peerSendMu.Unlock()
if siteConfig.PublicKey == "" {
logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfig.SiteId, siteConfig.Name)
return
}
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil {
if err := o.peerManager.AddPeer(siteConfig); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId)
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
}
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
@@ -171,19 +169,12 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
return
}
var relayData struct {
peers.RelayPeerData
ChainId string `json:"chainId"`
}
var relayData peers.RelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err)
@@ -211,19 +202,12 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
return
}
var relayData struct {
peers.UnRelayPeerData
ChainId string `json:"chainId"`
}
var relayData peers.UnRelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
@@ -251,8 +235,7 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
}
var handshakeData struct {
SiteId int `json:"siteId"`
ChainId string `json:"chainId"`
SiteId int `json:"siteId"`
ExitNode struct {
PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"`
@@ -265,16 +248,6 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
return
}
// Stop the peer init sender for this chain, if any
if handshakeData.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
stop()
delete(o.stopPeerInits, handshakeData.ChainId)
}
o.peerSendMu.Unlock()
}
// Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
if exists {
@@ -305,64 +278,10 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry, keyed by chainId
chainId := handshakeData.ChainId
if chainId == "" {
chainId = generateChainId()
}
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
// Send handshake acknowledgment back to server with retry
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
}, 1*time.Second, 10)
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
}
// handleCancelChain handles a cancel-chain websocket message.
// The message data is expected to carry a "chainId". Any peer-init or peer-add
// interval sender registered under that ID is stopped and removed from its
// tracking map. Messages that cannot be decoded or that have an empty chainId
// are logged and ignored.
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
	logger.Debug("Received cancel-chain message: %v", msg.Data)

	// msg.Data is round-tripped through JSON to decode it into a typed struct.
	jsonData, err := json.Marshal(msg.Data)
	if err != nil {
		logger.Error("Error marshaling cancel-chain data: %v", err)
		return
	}

	var cancelData struct {
		ChainId string `json:"chainId"`
	}
	if err := json.Unmarshal(jsonData, &cancelData); err != nil {
		logger.Error("Error unmarshaling cancel-chain data: %v", err)
		return
	}

	if cancelData.ChainId == "" {
		logger.Warn("Received cancel-chain message with no chainId")
		return
	}

	o.peerSendMu.Lock()
	defer o.peerSendMu.Unlock()

	found := false
	if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
		// The stored stop func comes from SendMessageInterval with its error
		// discarded, so it may be nil; calling a nil func would panic while
		// the mutex is held.
		if stop != nil {
			stop()
		}
		delete(o.stopPeerInits, cancelData.ChainId)
		found = true
	}
	if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
		if stop != nil {
			stop()
		}
		delete(o.stopPeerSends, cancelData.ChainId)
		found = true
	}

	if found {
		logger.Info("Cancelled chain %s", cancelData.ChainId)
	} else {
		logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
	}
}

View File

@@ -2,8 +2,6 @@ package monitor
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"net/netip"
@@ -33,14 +31,10 @@ type PeerMonitor struct {
monitors map[int]*Client
mutex sync.Mutex
running bool
timeout time.Duration
timeout time.Duration
maxAttempts int
wsClient *websocket.Client
// Relay sender tracking
relaySends map[string]func()
relaySendMu sync.Mutex
// Netstack fields
middleDev *middleDevice.MiddleDevice
localIP string
@@ -53,13 +47,13 @@ type PeerMonitor struct {
nsWg sync.WaitGroup
// Holepunch testing fields
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
// Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
@@ -88,12 +82,6 @@ type PeerMonitor struct {
}
// NewPeerMonitor creates a new peer monitor with the given callback
func generateChainId() string {
	// A chain ID is idBytes random bytes from crypto/rand, hex-encoded
	// (two characters per byte). The read error is intentionally ignored.
	const idBytes = 8

	seed := make([]byte, idBytes)
	_, _ = rand.Read(seed)

	encoded := make([]byte, hex.EncodedLen(idBytes))
	hex.Encode(encoded, seed)
	return string(encoded)
}
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
@@ -111,7 +99,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool),
relayedPeers: make(map[int]bool),
relaySends: make(map[string]func()),
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
holepunchFailures: make(map[int]int),
// Rapid initial test settings: complete within ~1.5 seconds
@@ -409,23 +396,20 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
}
}
// sendRelay sends a relay message to the server with retry, keyed by chainId
// sendRelay sends a relay message to the server
func (pm *PeerMonitor) sendRelay(siteID int) error {
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
chainId := generateChainId()
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
"siteId": siteID,
"chainId": chainId,
}, 2*time.Second, 10)
pm.relaySendMu.Lock()
pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent relay message")
return nil
}
@@ -435,52 +419,23 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
return pm.sendRelay(siteID)
}
// sendUnRelay sends an unrelay message to the server with retry, keyed by chainId
// sendUnRelay sends an unrelay message to the server
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
chainId := generateChainId()
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID,
"chainId": chainId,
}, 2*time.Second, 10)
pm.relaySendMu.Lock()
pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent unrelay message")
return nil
}
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
// If chainId is empty, all active relay senders are stopped.
// In both cases the stopped entries are removed from the tracking map. An
// unknown, non-empty chainId only produces a warning.
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
	pm.relaySendMu.Lock()
	defer pm.relaySendMu.Unlock()

	if chainId == "" {
		for id, stop := range pm.relaySends {
			if stop != nil {
				stop()
			}
			delete(pm.relaySends, id)
		}
		logger.Info("Cancelled all relay senders")
		return
	}

	stop, ok := pm.relaySends[chainId]
	if !ok {
		logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
		return
	}
	// Same nil guard as the cancel-all branch: the stop func is stored from
	// SendMessageInterval with its error discarded, so it may be nil, and
	// calling a nil func would panic while the mutex is held.
	if stop != nil {
		stop()
	}
	delete(pm.relaySends, chainId)
	logger.Info("Cancelled relay sender for chain %s", chainId)
}
// Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
@@ -579,7 +534,7 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
pm.holepunchCurrentInterval = pm.holepunchMinInterval
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
case <-timer.C:
@@ -722,16 +677,6 @@ func (pm *PeerMonitor) Close() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
pm.stopHolepunchMonitor()
// Stop all pending relay senders
pm.relaySendMu.Lock()
for chainId, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, chainId)
}
pm.relaySendMu.Unlock()
pm.mutex.Lock()
defer pm.mutex.Unlock()