mirror of
https://github.com/fosrl/newt.git
synced 2026-05-08 00:48:55 -05:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
663e98af60 | ||
|
|
901ec71baf | ||
|
|
9bc0204f57 | ||
|
|
1e77b09e3b | ||
|
|
74fd3f3aa3 | ||
|
|
e8dc19a62b | ||
|
|
9ff32b8a8b | ||
|
|
27f7ca6bb9 | ||
|
|
5090907307 | ||
|
|
a6533b3fa0 | ||
|
|
57aa2e2e2c | ||
|
|
5724c516dc | ||
|
|
b33c3b8849 | ||
|
|
8e19e475bf | ||
|
|
9e92c42876 | ||
|
|
66c72bbe2e | ||
|
|
ffd26f9a6d | ||
|
|
7610aa40bf |
82
common.go
82
common.go
@@ -208,6 +208,7 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
|
|||||||
logger.Warn(msgHealthFileWriteFailed, err)
|
logger.Warn(msgHealthFileWriteFailed, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
case <-pingStopChan:
|
case <-pingStopChan:
|
||||||
// Stop the goroutine when signaled
|
// Stop the goroutine when signaled
|
||||||
@@ -220,6 +221,25 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
|
|||||||
return stopChan, fmt.Errorf("initial ping attempts failed, continuing in background")
|
return stopChan, fmt.Errorf("initial ping attempts failed, continuing in background")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shouldFireRecovery decides whether the data-plane recovery flow in
|
||||||
|
// startPingCheck should run on this tick. Recovery fires once when the
|
||||||
|
// consecutive-failure counter first crosses the threshold; the connectionLost
|
||||||
|
// flag prevents re-firing until a successful ping resets the state.
|
||||||
|
//
|
||||||
|
// This condition was previously inlined into startPingCheck and AND-ed with
|
||||||
|
// `currentInterval < maxInterval`, which silently broke recovery once
|
||||||
|
// pingInterval's default was bumped to 15s while maxInterval stayed at 6s
|
||||||
|
// (commit 8161fa6, March 2026): the gate became permanently false on default
|
||||||
|
// settings, so the recovery code never executed and ping failures climbed
|
||||||
|
// forever — the proximate cause of fosrl/newt#284, #310 and pangolin#1004.
|
||||||
|
//
|
||||||
|
// Recovery and backoff are independent concerns; the backoff ramp is now
|
||||||
|
// computed separately in the caller. Do not re-introduce currentInterval
|
||||||
|
// here.
|
||||||
|
func shouldFireRecovery(consecutiveFailures, failureThreshold int, connectionLost bool) bool {
|
||||||
|
return consecutiveFailures >= failureThreshold && !connectionLost
|
||||||
|
}
|
||||||
|
|
||||||
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, tunnelID string) chan struct{} {
|
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, tunnelID string) chan struct{} {
|
||||||
maxInterval := 6 * time.Second
|
maxInterval := 6 * time.Second
|
||||||
currentInterval := pingInterval
|
currentInterval := pingInterval
|
||||||
@@ -279,42 +299,44 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
|
|||||||
|
|
||||||
// More lenient threshold for declaring connection lost under load
|
// More lenient threshold for declaring connection lost under load
|
||||||
failureThreshold := 4
|
failureThreshold := 4
|
||||||
if consecutiveFailures >= failureThreshold && currentInterval < maxInterval {
|
if shouldFireRecovery(consecutiveFailures, failureThreshold, connectionLost) {
|
||||||
if !connectionLost {
|
connectionLost = true
|
||||||
connectionLost = true
|
logger.Warn("Connection to server lost after %d failures. Continuous reconnection attempts will be made.", consecutiveFailures)
|
||||||
logger.Warn("Connection to server lost after %d failures. Continuous reconnection attempts will be made.", consecutiveFailures)
|
if tunnelID != "" {
|
||||||
if tunnelID != "" {
|
telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout)
|
||||||
telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout)
|
}
|
||||||
}
|
pingChainId := generateChainId()
|
||||||
pingChainId := generateChainId()
|
pendingPingChainId = pingChainId
|
||||||
pendingPingChainId = pingChainId
|
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
|
||||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
|
"chainId": pingChainId,
|
||||||
"chainId": pingChainId,
|
}, 3*time.Second)
|
||||||
}, 3*time.Second)
|
// Send registration message to the server for backward compatibility
|
||||||
// Send registration message to the server for backward compatibility
|
bcChainId := generateChainId()
|
||||||
bcChainId := generateChainId()
|
pendingRegisterChainId = bcChainId
|
||||||
pendingRegisterChainId = bcChainId
|
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
"publicKey": publicKey.String(),
|
||||||
"publicKey": publicKey.String(),
|
"backwardsCompatible": true,
|
||||||
"backwardsCompatible": true,
|
"chainId": bcChainId,
|
||||||
"chainId": bcChainId,
|
})
|
||||||
})
|
if err != nil {
|
||||||
|
logger.Error("Failed to send registration message: %v", err)
|
||||||
|
}
|
||||||
|
if healthFile != "" {
|
||||||
|
err = os.Remove(healthFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
logger.Error("Failed to remove health file: %v", err)
|
||||||
}
|
|
||||||
if healthFile != "" {
|
|
||||||
err = os.Remove(healthFile)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to remove health file: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
currentInterval = time.Duration(float64(currentInterval) * 1.3) // Slower increase
|
}
|
||||||
|
// Backoff: ramp the periodic-ping interval up while we are
|
||||||
|
// past the failure threshold, capped at maxInterval. Kept
|
||||||
|
// independent of the recovery trigger above so the trigger
|
||||||
|
// fires on every outage regardless of pingInterval.
|
||||||
|
if consecutiveFailures >= failureThreshold && currentInterval < maxInterval {
|
||||||
|
currentInterval = time.Duration(float64(currentInterval) * 1.3)
|
||||||
if currentInterval > maxInterval {
|
if currentInterval > maxInterval {
|
||||||
currentInterval = maxInterval
|
currentInterval = maxInterval
|
||||||
}
|
}
|
||||||
ticker.Reset(currentInterval)
|
|
||||||
logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Track recent latencies
|
// Track recent latencies
|
||||||
|
|||||||
@@ -210,3 +210,42 @@ func TestParseTargetStringNetDialCompatibility(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestShouldFireRecovery is the regression guard for the broken trigger gate
|
||||||
|
// that prevented data-plane recovery from ever firing under default settings
|
||||||
|
// (fosrl/newt#284, #310, pangolin#1004). The pre-fix condition was
|
||||||
|
//
|
||||||
|
// consecutiveFailures >= failureThreshold && currentInterval < maxInterval
|
||||||
|
//
|
||||||
|
// which became permanently false once pingInterval's default was bumped from
|
||||||
|
// 3s to 15s in commit 8161fa6 — currentInterval starts at pingInterval=15s,
|
||||||
|
// maxInterval stayed at 6s, so 15<6 is false and the recovery branch never
|
||||||
|
// executed.
|
||||||
|
//
|
||||||
|
// The fix is to drop currentInterval from the trigger condition entirely;
|
||||||
|
// backoff is a separate concern computed in the caller. The cases below
|
||||||
|
// exercise the documented contract.
|
||||||
|
func TestShouldFireRecovery(t *testing.T) {
|
||||||
|
const threshold = 4
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
failures int
|
||||||
|
connectionLost bool
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"below threshold, fresh", 3, false, false},
|
||||||
|
{"below threshold, already lost", 3, true, false},
|
||||||
|
{"at threshold, fresh — recovery must fire", threshold, false, true},
|
||||||
|
{"at threshold, already lost — gate prevents re-fire", threshold, true, false},
|
||||||
|
{"far above threshold, fresh", 100, false, true},
|
||||||
|
{"far above threshold, already lost", 100, true, false},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
if got := shouldFireRecovery(c.failures, threshold, c.connectionLost); got != c.want {
|
||||||
|
t.Errorf("shouldFireRecovery(failures=%d, threshold=%d, lost=%v) = %v, want %v",
|
||||||
|
c.failures, threshold, c.connectionLost, got, c.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
inherit (pkgs) lib;
|
inherit (pkgs) lib;
|
||||||
|
|
||||||
# Update version when releasing
|
# Update version when releasing
|
||||||
version = "1.11.0";
|
version = "1.12.4";
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
default = self.packages.${system}.pangolin-newt;
|
default = self.packages.${system}.pangolin-newt;
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ type Config struct {
|
|||||||
Interval int `json:"hcInterval"` // in seconds
|
Interval int `json:"hcInterval"` // in seconds
|
||||||
UnhealthyInterval int `json:"hcUnhealthyInterval"` // in seconds
|
UnhealthyInterval int `json:"hcUnhealthyInterval"` // in seconds
|
||||||
Timeout int `json:"hcTimeout"` // in seconds
|
Timeout int `json:"hcTimeout"` // in seconds
|
||||||
FollowRedirects bool `json:"hcFollowRedirects"`
|
FollowRedirects *bool `json:"hcFollowRedirects"`
|
||||||
Headers map[string]string `json:"hcHeaders"`
|
Headers map[string]string `json:"hcHeaders"`
|
||||||
Method string `json:"hcMethod"`
|
Method string `json:"hcMethod"`
|
||||||
Status int `json:"hcStatus"` // HTTP status code
|
Status int `json:"hcStatus"` // HTTP status code
|
||||||
@@ -202,7 +202,9 @@ func (m *Monitor) addTargetUnsafe(config Config) error {
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
client: &http.Client{
|
client: &http.Client{
|
||||||
CheckRedirect: func() func(*http.Request, []*http.Request) error {
|
CheckRedirect: func() func(*http.Request, []*http.Request) error {
|
||||||
if !config.FollowRedirects {
|
// Default to following redirects if not explicitly configured
|
||||||
|
followRedirects := config.FollowRedirects == nil || *config.FollowRedirects
|
||||||
|
if !followRedirects {
|
||||||
return func(req *http.Request, via []*http.Request) error {
|
return func(req *http.Request, via []*http.Request) error {
|
||||||
return http.ErrUseLastResponse
|
return http.ErrUseLastResponse
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -152,20 +152,14 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
|
|||||||
srcAddr, _ := netip.ParseAddr(srcIP)
|
srcAddr, _ := netip.ParseAddr(srcIP)
|
||||||
dstAddr, _ := netip.ParseAddr(dstIP)
|
dstAddr, _ := netip.ParseAddr(dstIP)
|
||||||
rule := h.proxyHandler.subnetLookup.Match(srcAddr, dstAddr, dstPort, tcp.ProtocolNumber)
|
rule := h.proxyHandler.subnetLookup.Match(srcAddr, dstAddr, dstPort, tcp.ProtocolNumber)
|
||||||
if rule != nil {
|
if rule != nil && rule.Protocol != "" && len(rule.HTTPTargets) > 0 {
|
||||||
if rule.Protocol != "" {
|
logger.Info("TCP Forwarder: Routing %s:%d -> %s:%d to HTTP handler (%s)",
|
||||||
logger.Info("TCP Forwarder: Routing %s:%d -> %s:%d to HTTP handler (%s)",
|
srcIP, srcPort, dstIP, dstPort, rule.Protocol)
|
||||||
srcIP, srcPort, dstIP, dstPort, rule.Protocol)
|
h.proxyHandler.httpHandler.HandleConn(netstackConn, rule)
|
||||||
h.proxyHandler.httpHandler.HandleConn(netstackConn, rule)
|
|
||||||
} else {
|
|
||||||
// A matching HTTP rule exists but has no protocol configured —
|
|
||||||
// do not fall through to the raw TCP handler; drop the connection.
|
|
||||||
logger.Info("TCP Forwarder: Dropping %s:%d -> %s:%d (HTTP rule matched but no protocol set)",
|
|
||||||
srcIP, srcPort, dstIP, dstPort)
|
|
||||||
netstackConn.Close()
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Otherwise fall through to raw TCP forwarding (e.g. CIDR resources
|
||||||
|
// that happen to use port 80/443 without HTTP configuration).
|
||||||
}
|
}
|
||||||
|
|
||||||
defer netstackConn.Close()
|
defer netstackConn.Close()
|
||||||
|
|||||||
@@ -6,8 +6,10 @@
|
|||||||
package netstack2
|
package netstack2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -29,7 +31,7 @@ import (
|
|||||||
type HTTPTarget struct {
|
type HTTPTarget struct {
|
||||||
DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service
|
DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service
|
||||||
DestPort uint16 `json:"destPort"` // TCP port of the downstream service
|
DestPort uint16 `json:"destPort"` // TCP port of the downstream service
|
||||||
Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS
|
Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -322,6 +324,24 @@ func (sc *statusCapture) WriteHeader(code int) {
|
|||||||
sc.ResponseWriter.WriteHeader(code)
|
sc.ResponseWriter.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sc *statusCapture) Unwrap() http.ResponseWriter {
|
||||||
|
return sc.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *statusCapture) Flush() {
|
||||||
|
if flusher, ok := sc.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *statusCapture) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
hijacker, ok := sc.ResponseWriter.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, errors.New("underlying response writer does not support hijacking")
|
||||||
|
}
|
||||||
|
return hijacker.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
// handleRequest is the http.Handler entry point. It retrieves the SubnetRule
|
// handleRequest is the http.Handler entry point. It retrieves the SubnetRule
|
||||||
// attached to the connection by ConnContext, selects the first configured
|
// attached to the connection by ConnContext, selects the first configured
|
||||||
// downstream target, and forwards the request via the cached ReverseProxy.
|
// downstream target, and forwards the request via the cached ReverseProxy.
|
||||||
@@ -336,16 +356,16 @@ func (h *HTTPHandler) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the rule is plain HTTP but has a TLS certificate configured, redirect
|
// If the rule is HTTPS and a TLS certificate is configured, but the
|
||||||
// the client to the HTTPS equivalent of the requested URL.
|
// incoming request arrived over plain HTTP, redirect to HTTPS.
|
||||||
if rule.Protocol == "http" && rule.TLSCert != "" && rule.TLSKey != "" {
|
if rule.Protocol == "https" && rule.TLSCert != "" && rule.TLSKey != "" && r.TLS == nil {
|
||||||
host := r.Host
|
host := r.Host
|
||||||
if host == "" {
|
if host == "" {
|
||||||
host = r.URL.Host
|
host = r.URL.Host
|
||||||
}
|
}
|
||||||
httpsURL := "https://" + host + r.RequestURI
|
httpsURL := "https://" + host + r.RequestURI
|
||||||
logger.Info("HTTP handler: redirecting %s %s -> %s (TLS cert present)", r.Method, r.URL.RequestURI(), httpsURL)
|
logger.Info("HTTP handler: redirecting %s %s -> %s (TLS cert present)", r.Method, r.URL.RequestURI(), httpsURL)
|
||||||
http.Redirect(w, r, httpsURL, http.StatusMovedPermanently)
|
http.Redirect(w, r, httpsURL, http.StatusPermanentRedirect)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
97
netstack2/http_handler_test.go
Normal file
97
netstack2/http_handler_test.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package netstack2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHTTPHandlerProxiesWebSocketUpgrade(t *testing.T) {
|
||||||
|
upgrader := websocket.Upgrader{}
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("upgrade failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
messageType, payload, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("read failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := conn.WriteMessage(messageType, append([]byte("echo:"), payload...)); err != nil {
|
||||||
|
t.Errorf("write failed: %v", err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
backendURL, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse backend URL: %v", err)
|
||||||
|
}
|
||||||
|
backendHost, backendPort, err := net.SplitHostPort(backendURL.Host)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("split backend host: %v", err)
|
||||||
|
}
|
||||||
|
port, err := net.LookupPort("tcp", backendPort)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse backend port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewHTTPHandler(nil, nil)
|
||||||
|
rule := &SubnetRule{
|
||||||
|
Protocol: "http",
|
||||||
|
HTTPTargets: []HTTPTarget{
|
||||||
|
{
|
||||||
|
DestAddr: backendHost,
|
||||||
|
DestPort: uint16(port),
|
||||||
|
Scheme: backendURL.Scheme,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := context.WithValue(r.Context(), connCtxKey{}, rule)
|
||||||
|
handler.handleRequest(w, r.WithContext(ctx))
|
||||||
|
}))
|
||||||
|
defer frontend.Close()
|
||||||
|
|
||||||
|
frontendURL, err := url.Parse(frontend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse frontend URL: %v", err)
|
||||||
|
}
|
||||||
|
wsURL := url.URL{
|
||||||
|
Scheme: "ws",
|
||||||
|
Host: frontendURL.Host,
|
||||||
|
Path: "/socket",
|
||||||
|
RawQuery: "token=test",
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial websocket through proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if err := conn.WriteMessage(websocket.TextMessage, []byte("hello")); err != nil {
|
||||||
|
t.Fatalf("write websocket message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
messageType, payload, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read websocket message: %v", err)
|
||||||
|
}
|
||||||
|
if messageType != websocket.TextMessage {
|
||||||
|
t.Fatalf("message type = %d, want %d", messageType, websocket.TextMessage)
|
||||||
|
}
|
||||||
|
if got, want := string(payload), "echo:hello"; got != want {
|
||||||
|
t.Fatalf("payload = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -572,6 +572,18 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
|||||||
|
|
||||||
// Store destination rewrite for handler lookups
|
// Store destination rewrite for handler lookups
|
||||||
p.destRewriteTable[dKey] = newDst
|
p.destRewriteTable[dKey] = newDst
|
||||||
|
|
||||||
|
// Also store the resource ID under the rewritten destination key so that
|
||||||
|
// TCP/UDP handlers can find it after DNAT (they see the post-NAT dst IP).
|
||||||
|
if matchedRule.ResourceId != 0 {
|
||||||
|
rewrittenKey := destKey{
|
||||||
|
srcIP: srcAddr.String(),
|
||||||
|
dstIP: newDst.String(),
|
||||||
|
dstPort: dstPort,
|
||||||
|
proto: uint8(protocol),
|
||||||
|
}
|
||||||
|
p.resourceTable[rewrittenKey] = matchedRule.ResourceId
|
||||||
|
}
|
||||||
p.natMu.Unlock()
|
p.natMu.Unlock()
|
||||||
logger.Debug("New NAT entry for connection: %s -> %s", dstAddr, newDst)
|
logger.Debug("New NAT entry for connection: %s -> %s", dstAddr, newDst)
|
||||||
}
|
}
|
||||||
|
|||||||
60
testing/ws_client.py
Normal file
60
testing/ws_client.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
# Argument parsing: Check if HOST and PORT are provided
|
||||||
|
if len(sys.argv) < 3 or len(sys.argv) > 4:
|
||||||
|
print("Usage: python ws_client.py <HOST_IP> <HOST_PORT> [ws|wss]")
|
||||||
|
# Example: python ws_client.py 127.0.0.1 8765
|
||||||
|
# Example: python ws_client.py 127.0.0.1 8765 wss
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
HOST = sys.argv[1]
|
||||||
|
try:
|
||||||
|
PORT = int(sys.argv[2])
|
||||||
|
except ValueError:
|
||||||
|
print("Error: HOST_PORT must be an integer.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if len(sys.argv) == 4:
|
||||||
|
SCHEME = sys.argv[3].lower()
|
||||||
|
if SCHEME not in ("ws", "wss"):
|
||||||
|
print("Error: scheme must be 'ws' or 'wss'.")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
SCHEME = "ws"
|
||||||
|
|
||||||
|
URI = f"{SCHEME}://{HOST}:{PORT}"
|
||||||
|
|
||||||
|
# The message to send to the server
|
||||||
|
MESSAGE = "Hello WebSocket Server! How are you?"
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
print(f"Connecting to {URI}...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with websockets.connect(URI) as websocket:
|
||||||
|
print(f"Connected to server.")
|
||||||
|
print(f"Sending message: '{MESSAGE}'")
|
||||||
|
|
||||||
|
await websocket.send(MESSAGE)
|
||||||
|
|
||||||
|
response = await websocket.recv()
|
||||||
|
|
||||||
|
print("-" * 30)
|
||||||
|
print(f"Received response from server:")
|
||||||
|
print(f"-> Data: '{response}'")
|
||||||
|
|
||||||
|
except ConnectionRefusedError:
|
||||||
|
print(f"Error: Connection to {URI} was refused. Is the server running?")
|
||||||
|
except websockets.exceptions.InvalidMessage as e:
|
||||||
|
print(f"Error: Server did not respond with a valid WebSocket handshake: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during communication: {e}")
|
||||||
|
|
||||||
|
print("-" * 30)
|
||||||
|
print("Client finished.")
|
||||||
|
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
49
testing/ws_server.py
Normal file
49
testing/ws_server.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
# Optionally take in a positional arg for the port
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
try:
|
||||||
|
PORT = int(sys.argv[1])
|
||||||
|
except ValueError:
|
||||||
|
print("Invalid port number. Using default port 8765.")
|
||||||
|
PORT = 8765
|
||||||
|
else:
|
||||||
|
PORT = 8765
|
||||||
|
|
||||||
|
# Define the server host
|
||||||
|
HOST = "0.0.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_client(websocket):
|
||||||
|
client_address = websocket.remote_address
|
||||||
|
print(f"Client connected: {client_address[0]}:{client_address[1]}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for message in websocket:
|
||||||
|
print("-" * 30)
|
||||||
|
print(f"Received message from {client_address[0]}:{client_address[1]}:")
|
||||||
|
print(f"-> Data: '{message}'")
|
||||||
|
|
||||||
|
response = f"Hello client! Server received: '{message.upper()}'"
|
||||||
|
|
||||||
|
await websocket.send(response)
|
||||||
|
print(f"Sent response back to client.")
|
||||||
|
|
||||||
|
except websockets.exceptions.ConnectionClosedOK:
|
||||||
|
print(f"Client {client_address[0]}:{client_address[1]} disconnected cleanly.")
|
||||||
|
except websockets.exceptions.ConnectionClosedError as e:
|
||||||
|
print(f"Client {client_address[0]}:{client_address[1]} disconnected with error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
print(f"WebSocket Server listening on {HOST}:{PORT}")
|
||||||
|
async with websockets.serve(handle_client, HOST, PORT):
|
||||||
|
await asyncio.Future() # Run forever
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nServer stopped.")
|
||||||
@@ -48,7 +48,7 @@ type Client struct {
|
|||||||
metricsCtx context.Context
|
metricsCtx context.Context
|
||||||
configNeedsSave bool // Flag to track if config needs to be saved
|
configNeedsSave bool // Flag to track if config needs to be saved
|
||||||
serverVersion string
|
serverVersion string
|
||||||
configVersion int64 // Latest config version received from server
|
configVersion int64 // Latest config version received from server
|
||||||
configVersionMux sync.RWMutex
|
configVersionMux sync.RWMutex
|
||||||
processingMessage bool // Flag to track if a message is currently being processed
|
processingMessage bool // Flag to track if a message is currently being processed
|
||||||
processingMux sync.RWMutex // Protects processingMessage
|
processingMux sync.RWMutex // Protects processingMessage
|
||||||
@@ -271,13 +271,17 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
|||||||
stopChan := make(chan struct{})
|
stopChan := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
count := 0
|
count := 0
|
||||||
maxAttempts := 10
|
maxAttempts := 16
|
||||||
|
|
||||||
|
c.reconnectMux.RLock()
|
||||||
|
connected := c.isConnected
|
||||||
|
c.reconnectMux.RUnlock()
|
||||||
err := c.SendMessage(messageType, data) // Send immediately
|
err := c.SendMessage(messageType, data) // Send immediately
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to send initial message: %v", err)
|
logger.Error("Failed to send initial message: %v", err)
|
||||||
|
} else if connected {
|
||||||
|
count++
|
||||||
}
|
}
|
||||||
count++
|
|
||||||
|
|
||||||
ticker := time.NewTicker(interval)
|
ticker := time.NewTicker(interval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
@@ -288,11 +292,15 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
|||||||
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.reconnectMux.RLock()
|
||||||
|
connected = c.isConnected
|
||||||
|
c.reconnectMux.RUnlock()
|
||||||
err = c.SendMessage(messageType, data)
|
err = c.SendMessage(messageType, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to send message: %v", err)
|
logger.Error("Failed to send message: %v", err)
|
||||||
|
} else if connected {
|
||||||
|
count++
|
||||||
}
|
}
|
||||||
count++
|
|
||||||
case <-stopChan:
|
case <-stopChan:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -836,7 +844,7 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
|
|||||||
logger.Error("WebSocket failed to parse message: %v", err)
|
logger.Error("WebSocket failed to parse message: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
c.setConfigVersion(msg.ConfigVersion)
|
c.setConfigVersion(msg.ConfigVersion)
|
||||||
|
|
||||||
c.handlersMux.RLock()
|
c.handlersMux.RLock()
|
||||||
|
|||||||
Reference in New Issue
Block a user