mirror of
https://github.com/fosrl/olm.git
synced 2026-03-09 15:22:23 -05:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2d0e6a14c | ||
|
|
809dbe77de | ||
|
|
c67c2a60a1 | ||
|
|
051c0fdfd8 | ||
|
|
e7507e0837 |
59
api/api.go
59
api/api.go
@@ -78,6 +78,13 @@ type MetadataChangeRequest struct {
|
||||
Postures map[string]any `json:"postures"`
|
||||
}
|
||||
|
||||
// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request.
|
||||
// Either SiteID or ResourceID must be provided (but not necessarily both).
|
||||
type JITConnectionRequest struct {
|
||||
Site string `json:"site,omitempty"`
|
||||
Resource string `json:"resource,omitempty"`
|
||||
}
|
||||
|
||||
// API represents the HTTP server and its state
|
||||
type API struct {
|
||||
addr string
|
||||
@@ -92,6 +99,7 @@ type API struct {
|
||||
onExit func() error
|
||||
onRebind func() error
|
||||
onPowerMode func(PowerModeRequest) error
|
||||
onJITConnect func(JITConnectionRequest) error
|
||||
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
@@ -143,6 +151,7 @@ func (s *API) SetHandlers(
|
||||
onExit func() error,
|
||||
onRebind func() error,
|
||||
onPowerMode func(PowerModeRequest) error,
|
||||
onJITConnect func(JITConnectionRequest) error,
|
||||
) {
|
||||
s.onConnect = onConnect
|
||||
s.onSwitchOrg = onSwitchOrg
|
||||
@@ -151,6 +160,7 @@ func (s *API) SetHandlers(
|
||||
s.onExit = onExit
|
||||
s.onRebind = onRebind
|
||||
s.onPowerMode = onPowerMode
|
||||
s.onJITConnect = onJITConnect
|
||||
}
|
||||
|
||||
// Start starts the HTTP server
|
||||
@@ -169,6 +179,7 @@ func (s *API) Start() error {
|
||||
mux.HandleFunc("/health", s.handleHealth)
|
||||
mux.HandleFunc("/rebind", s.handleRebind)
|
||||
mux.HandleFunc("/power-mode", s.handlePowerMode)
|
||||
mux.HandleFunc("/jit-connect", s.handleJITConnect)
|
||||
|
||||
s.server = &http.Server{
|
||||
Handler: mux,
|
||||
@@ -633,6 +644,54 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// handleJITConnect handles the /jit-connect endpoint.
|
||||
// It initiates a dynamic Just-In-Time connection to a site identified by either
|
||||
// a site or a resource. Exactly one of the two must be provided.
|
||||
func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req JITConnectionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate that exactly one of site or resource is provided
|
||||
if req.Site == "" && req.Resource == "" {
|
||||
http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.Site != "" && req.Resource != "" {
|
||||
http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Site != "" {
|
||||
logger.Info("Received JIT connection request via API: site=%s", req.Site)
|
||||
} else {
|
||||
logger.Info("Received JIT connection request via API: resource=%s", req.Resource)
|
||||
}
|
||||
|
||||
if s.onJITConnect != nil {
|
||||
if err := s.onJITConnect(req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "JIT connection request accepted",
|
||||
})
|
||||
}
|
||||
|
||||
// handlePowerMode handles the /power-mode endpoint
|
||||
// This allows changing the power mode between "normal" and "low"
|
||||
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
16
olm/data.go
16
olm/data.go
@@ -2,6 +2,7 @@ package olm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/holepunch"
|
||||
@@ -220,6 +221,7 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
|
||||
logger.Info("Sync: Adding new peer for site %d", siteId)
|
||||
|
||||
o.holePunchManager.TriggerHolePunch()
|
||||
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||
|
||||
// // TODO: do we need to send the message to the cloud to add the peer that way?
|
||||
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
|
||||
@@ -230,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
|
||||
|
||||
// add the peer via the server
|
||||
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
|
||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
"siteId": expectedSite.SiteId,
|
||||
}, 1*time.Second, 10)
|
||||
chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId)
|
||||
o.peerSendMu.Lock()
|
||||
if stop, ok := o.stopPeerSends[chainId]; ok {
|
||||
stop()
|
||||
}
|
||||
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
"siteId": expectedSite.SiteId,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second, 10)
|
||||
o.stopPeerSends[chainId] = stopFunc
|
||||
o.peerSendMu.Unlock()
|
||||
|
||||
} else {
|
||||
// Existing peer - check if update is needed
|
||||
|
||||
57
olm/olm.go
57
olm/olm.go
@@ -2,6 +2,8 @@ package olm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -65,7 +67,9 @@ type Olm struct {
|
||||
stopRegister func()
|
||||
updateRegister func(newData any)
|
||||
|
||||
stopPeerSend func()
|
||||
stopPeerSends map[string]func()
|
||||
stopPeerInits map[string]func()
|
||||
peerSendMu sync.Mutex
|
||||
|
||||
// WaitGroup to track tunnel lifecycle
|
||||
tunnelWg sync.WaitGroup
|
||||
@@ -116,6 +120,13 @@ func (o *Olm) initTunnelInfo(clientID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateChainId generates a random chain ID for tracking peer sender lifecycles.
|
||||
func generateChainId() string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
||||
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
||||
|
||||
@@ -166,10 +177,12 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
||||
apiServer.SetAgent(config.Agent)
|
||||
|
||||
newOlm := &Olm{
|
||||
logFile: logFile,
|
||||
olmCtx: ctx,
|
||||
apiServer: apiServer,
|
||||
olmConfig: config,
|
||||
logFile: logFile,
|
||||
olmCtx: ctx,
|
||||
apiServer: apiServer,
|
||||
olmConfig: config,
|
||||
stopPeerSends: make(map[string]func()),
|
||||
stopPeerInits: make(map[string]func()),
|
||||
}
|
||||
|
||||
newOlm.registerAPICallbacks()
|
||||
@@ -284,6 +297,21 @@ func (o *Olm) registerAPICallbacks() {
|
||||
logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
|
||||
return o.SetPowerMode(req.Mode)
|
||||
},
|
||||
func(req api.JITConnectionRequest) error {
|
||||
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
|
||||
|
||||
chainId := generateChainId()
|
||||
o.peerSendMu.Lock()
|
||||
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
|
||||
"siteId": req.Site,
|
||||
"resourceId": req.Resource,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second, 10)
|
||||
o.stopPeerInits[chainId] = stopFunc
|
||||
o.peerSendMu.Unlock()
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -378,6 +406,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
|
||||
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
|
||||
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
|
||||
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
|
||||
o.websocket.RegisterHandler("olm/sync", o.handleSync)
|
||||
|
||||
o.websocket.OnConnect(func() error {
|
||||
@@ -420,7 +449,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
"userToken": userToken,
|
||||
"fingerprint": o.fingerprint,
|
||||
"postures": o.postures,
|
||||
}, 1*time.Second, 10)
|
||||
}, 2*time.Second, 10)
|
||||
|
||||
// Invoke onRegistered callback if configured
|
||||
if o.olmConfig.OnRegistered != nil {
|
||||
@@ -517,6 +546,22 @@ func (o *Olm) Close() {
|
||||
o.stopRegister = nil
|
||||
}
|
||||
|
||||
// Stop all pending peer init and send senders before closing websocket
|
||||
o.peerSendMu.Lock()
|
||||
for _, stop := range o.stopPeerInits {
|
||||
if stop != nil {
|
||||
stop()
|
||||
}
|
||||
}
|
||||
o.stopPeerInits = make(map[string]func())
|
||||
for _, stop := range o.stopPeerSends {
|
||||
if stop != nil {
|
||||
stop()
|
||||
}
|
||||
}
|
||||
o.stopPeerSends = make(map[string]func())
|
||||
o.peerSendMu.Unlock()
|
||||
|
||||
// send a disconnect message to the cloud to show disconnected
|
||||
if o.websocket != nil {
|
||||
o.websocket.SendMessage("olm/disconnecting", map[string]any{})
|
||||
|
||||
118
olm/peer.go
118
olm/peer.go
@@ -20,31 +20,38 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
if o.stopPeerSend != nil {
|
||||
o.stopPeerSend()
|
||||
o.stopPeerSend = nil
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var siteConfig peers.SiteConfig
|
||||
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
|
||||
var siteConfigMsg struct {
|
||||
peers.SiteConfig
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
|
||||
logger.Error("Error unmarshaling add data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if siteConfigMsg.ChainId != "" {
|
||||
o.peerSendMu.Lock()
|
||||
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
|
||||
stop()
|
||||
delete(o.stopPeerSends, siteConfigMsg.ChainId)
|
||||
}
|
||||
o.peerSendMu.Unlock()
|
||||
}
|
||||
|
||||
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||
|
||||
if err := o.peerManager.AddPeer(siteConfig); err != nil {
|
||||
if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil {
|
||||
logger.Error("Failed to add peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
|
||||
logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId)
|
||||
}
|
||||
|
||||
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
||||
@@ -164,12 +171,19 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
var relayData peers.RelayPeerData
|
||||
var relayData struct {
|
||||
peers.RelayPeerData
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||
logger.Error("Error unmarshaling relay data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
|
||||
monitor.CancelRelaySend(relayData.ChainId)
|
||||
}
|
||||
|
||||
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve primary relay endpoint: %v", err)
|
||||
@@ -197,12 +211,19 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
var relayData peers.UnRelayPeerData
|
||||
var relayData struct {
|
||||
peers.UnRelayPeerData
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||
logger.Error("Error unmarshaling relay data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
|
||||
monitor.CancelRelaySend(relayData.ChainId)
|
||||
}
|
||||
|
||||
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||
@@ -230,7 +251,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||
}
|
||||
|
||||
var handshakeData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
SiteId int `json:"siteId"`
|
||||
ChainId string `json:"chainId"`
|
||||
ExitNode struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
@@ -243,6 +265,16 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
// Stop the peer init sender for this chain, if any
|
||||
if handshakeData.ChainId != "" {
|
||||
o.peerSendMu.Lock()
|
||||
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
|
||||
stop()
|
||||
delete(o.stopPeerInits, handshakeData.ChainId)
|
||||
}
|
||||
o.peerSendMu.Unlock()
|
||||
}
|
||||
|
||||
// Get existing peer from PeerManager
|
||||
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
|
||||
if exists {
|
||||
@@ -273,10 +305,64 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
||||
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||
|
||||
// Send handshake acknowledgment back to server with retry
|
||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
"siteId": handshakeData.SiteId,
|
||||
}, 1*time.Second, 10)
|
||||
// Send handshake acknowledgment back to server with retry, keyed by chainId
|
||||
chainId := handshakeData.ChainId
|
||||
if chainId == "" {
|
||||
chainId = generateChainId()
|
||||
}
|
||||
o.peerSendMu.Lock()
|
||||
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
"siteId": handshakeData.SiteId,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second, 10)
|
||||
o.stopPeerSends[chainId] = stopFunc
|
||||
o.peerSendMu.Unlock()
|
||||
|
||||
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||
}
|
||||
|
||||
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
|
||||
logger.Debug("Received cancel-chain message: %v", msg.Data)
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling cancel-chain data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var cancelData struct {
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &cancelData); err != nil {
|
||||
logger.Error("Error unmarshaling cancel-chain data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if cancelData.ChainId == "" {
|
||||
logger.Warn("Received cancel-chain message with no chainId")
|
||||
return
|
||||
}
|
||||
|
||||
o.peerSendMu.Lock()
|
||||
defer o.peerSendMu.Unlock()
|
||||
|
||||
found := false
|
||||
|
||||
if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
|
||||
stop()
|
||||
delete(o.stopPeerInits, cancelData.ChainId)
|
||||
found = true
|
||||
}
|
||||
|
||||
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
|
||||
stop()
|
||||
delete(o.stopPeerSends, cancelData.ChainId)
|
||||
found = true
|
||||
}
|
||||
|
||||
if found {
|
||||
logger.Info("Cancelled chain %s", cancelData.ChainId)
|
||||
} else {
|
||||
logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package monitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -31,10 +33,14 @@ type PeerMonitor struct {
|
||||
monitors map[int]*Client
|
||||
mutex sync.Mutex
|
||||
running bool
|
||||
timeout time.Duration
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
wsClient *websocket.Client
|
||||
|
||||
// Relay sender tracking
|
||||
relaySends map[string]func()
|
||||
relaySendMu sync.Mutex
|
||||
|
||||
// Netstack fields
|
||||
middleDev *middleDevice.MiddleDevice
|
||||
localIP string
|
||||
@@ -47,13 +53,13 @@ type PeerMonitor struct {
|
||||
nsWg sync.WaitGroup
|
||||
|
||||
// Holepunch testing fields
|
||||
sharedBind *bind.SharedBind
|
||||
holepunchTester *holepunch.HolepunchTester
|
||||
holepunchTimeout time.Duration
|
||||
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
|
||||
holepunchStatus map[int]bool // siteID -> connected status
|
||||
holepunchStopChan chan struct{}
|
||||
holepunchUpdateChan chan struct{}
|
||||
sharedBind *bind.SharedBind
|
||||
holepunchTester *holepunch.HolepunchTester
|
||||
holepunchTimeout time.Duration
|
||||
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
|
||||
holepunchStatus map[int]bool // siteID -> connected status
|
||||
holepunchStopChan chan struct{}
|
||||
holepunchUpdateChan chan struct{}
|
||||
|
||||
// Relay tracking fields
|
||||
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
|
||||
@@ -82,6 +88,12 @@ type PeerMonitor struct {
|
||||
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||
func generateChainId() string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pm := &PeerMonitor{
|
||||
@@ -99,6 +111,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
||||
holepunchEndpoints: make(map[int]string),
|
||||
holepunchStatus: make(map[int]bool),
|
||||
relayedPeers: make(map[int]bool),
|
||||
relaySends: make(map[string]func()),
|
||||
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
|
||||
holepunchFailures: make(map[int]int),
|
||||
// Rapid initial test settings: complete within ~1.5 seconds
|
||||
@@ -396,20 +409,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
|
||||
}
|
||||
}
|
||||
|
||||
// sendRelay sends a relay message to the server
|
||||
// sendRelay sends a relay message to the server with retry, keyed by chainId
|
||||
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent relay message")
|
||||
chainId := generateChainId()
|
||||
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second, 10)
|
||||
|
||||
pm.relaySendMu.Lock()
|
||||
pm.relaySends[chainId] = stopFunc
|
||||
pm.relaySendMu.Unlock()
|
||||
|
||||
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -419,23 +435,52 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
|
||||
return pm.sendRelay(siteID)
|
||||
}
|
||||
|
||||
// sendUnRelay sends an unrelay message to the server
|
||||
// sendUnRelay sends an unrelay message to the server with retry, keyed by chainId
|
||||
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent unrelay message")
|
||||
chainId := generateChainId()
|
||||
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second, 10)
|
||||
|
||||
pm.relaySendMu.Lock()
|
||||
pm.relaySends[chainId] = stopFunc
|
||||
pm.relaySendMu.Unlock()
|
||||
|
||||
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
|
||||
// If chainId is empty, all active relay senders are stopped.
|
||||
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
|
||||
pm.relaySendMu.Lock()
|
||||
defer pm.relaySendMu.Unlock()
|
||||
|
||||
if chainId == "" {
|
||||
for id, stop := range pm.relaySends {
|
||||
if stop != nil {
|
||||
stop()
|
||||
}
|
||||
delete(pm.relaySends, id)
|
||||
}
|
||||
logger.Info("Cancelled all relay senders")
|
||||
return
|
||||
}
|
||||
|
||||
if stop, ok := pm.relaySends[chainId]; ok {
|
||||
stop()
|
||||
delete(pm.relaySends, chainId)
|
||||
logger.Info("Cancelled relay sender for chain %s", chainId)
|
||||
} else {
|
||||
logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops monitoring all peers
|
||||
func (pm *PeerMonitor) Stop() {
|
||||
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
||||
@@ -534,7 +579,7 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
|
||||
pm.holepunchCurrentInterval = pm.holepunchMinInterval
|
||||
currentInterval := pm.holepunchCurrentInterval
|
||||
pm.mutex.Unlock()
|
||||
|
||||
|
||||
timer.Reset(currentInterval)
|
||||
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
|
||||
case <-timer.C:
|
||||
@@ -677,6 +722,16 @@ func (pm *PeerMonitor) Close() {
|
||||
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
||||
pm.stopHolepunchMonitor()
|
||||
|
||||
// Stop all pending relay senders
|
||||
pm.relaySendMu.Lock()
|
||||
for chainId, stop := range pm.relaySends {
|
||||
if stop != nil {
|
||||
stop()
|
||||
}
|
||||
delete(pm.relaySends, chainId)
|
||||
}
|
||||
pm.relaySendMu.Unlock()
|
||||
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user