Compare commits

...

7 Commits
1.12.1 ... dev

Author SHA1 Message Date
Owen
27f7ca6bb9 Try to fix failover not working 2026-05-05 11:40:39 -07:00
Owen
5090907307 Update status code 2026-04-30 15:55:52 -07:00
Owen
a6533b3fa0 Fix incorrect redirect logic 2026-04-29 21:11:07 -07:00
Owen Schwartz
5724c516dc Merge pull request #334 from LaurenceJJones/private-http-websocket
enhance(http): Support websocket upgrades
2026-04-29 15:58:30 -07:00
Owen
b33c3b8849 Add some test scripts for ws and move to testing/ 2026-04-29 15:57:31 -07:00
Laurence
8e19e475bf Support websocket upgrades in private HTTP proxy
Preserve optional ResponseWriter interfaces through statusCapture so httputil.ReverseProxy can hijack upgraded websocket connections. Add a regression test covering websocket traffic through the HTTP handler path.
2026-04-29 07:12:35 +01:00
Owen
66c72bbe2e Dont block tcp for http unless there are targets 2026-04-28 14:29:55 -07:00
8 changed files with 245 additions and 23 deletions

View File

@@ -279,7 +279,7 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
// More lenient threshold for declaring connection lost under load
failureThreshold := 4
if consecutiveFailures >= failureThreshold && currentInterval < maxInterval {
if consecutiveFailures >= failureThreshold {
if !connectionLost {
connectionLost = true
logger.Warn("Connection to server lost after %d failures. Continuous reconnection attempts will be made.", consecutiveFailures)
@@ -309,12 +309,14 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
}
}
}
currentInterval = time.Duration(float64(currentInterval) * 1.3) // Slower increase
if currentInterval > maxInterval {
currentInterval = maxInterval
if currentInterval < maxInterval {
currentInterval = time.Duration(float64(currentInterval) * 1.3) // Slower increase
if currentInterval > maxInterval {
currentInterval = maxInterval
}
ticker.Reset(currentInterval)
logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
}
ticker.Reset(currentInterval)
logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
}
} else {
// Track recent latencies

View File

@@ -152,20 +152,14 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
srcAddr, _ := netip.ParseAddr(srcIP)
dstAddr, _ := netip.ParseAddr(dstIP)
rule := h.proxyHandler.subnetLookup.Match(srcAddr, dstAddr, dstPort, tcp.ProtocolNumber)
if rule != nil {
if rule.Protocol != "" {
logger.Info("TCP Forwarder: Routing %s:%d -> %s:%d to HTTP handler (%s)",
srcIP, srcPort, dstIP, dstPort, rule.Protocol)
h.proxyHandler.httpHandler.HandleConn(netstackConn, rule)
} else {
// A matching HTTP rule exists but has no protocol configured —
// do not fall through to the raw TCP handler; drop the connection.
logger.Info("TCP Forwarder: Dropping %s:%d -> %s:%d (HTTP rule matched but no protocol set)",
srcIP, srcPort, dstIP, dstPort)
netstackConn.Close()
}
if rule != nil && rule.Protocol != "" && len(rule.HTTPTargets) > 0 {
logger.Info("TCP Forwarder: Routing %s:%d -> %s:%d to HTTP handler (%s)",
srcIP, srcPort, dstIP, dstPort, rule.Protocol)
h.proxyHandler.httpHandler.HandleConn(netstackConn, rule)
return
}
// Otherwise fall through to raw TCP forwarding (e.g. CIDR resources
// that happen to use port 80/443 without HTTP configuration).
}
defer netstackConn.Close()

View File

@@ -6,8 +6,10 @@
package netstack2
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
@@ -29,7 +31,7 @@ import (
type HTTPTarget struct {
DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service
DestPort uint16 `json:"destPort"` // TCP port of the downstream service
Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS
Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS
}
// ---------------------------------------------------------------------------
@@ -322,6 +324,24 @@ func (sc *statusCapture) WriteHeader(code int) {
sc.ResponseWriter.WriteHeader(code)
}
func (sc *statusCapture) Unwrap() http.ResponseWriter {
return sc.ResponseWriter
}
func (sc *statusCapture) Flush() {
if flusher, ok := sc.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
func (sc *statusCapture) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := sc.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("underlying response writer does not support hijacking")
}
return hijacker.Hijack()
}
// handleRequest is the http.Handler entry point. It retrieves the SubnetRule
// attached to the connection by ConnContext, selects the first configured
// downstream target, and forwards the request via the cached ReverseProxy.
@@ -336,16 +356,16 @@ func (h *HTTPHandler) handleRequest(w http.ResponseWriter, r *http.Request) {
return
}
// If the rule is plain HTTP but has a TLS certificate configured, redirect
// the client to the HTTPS equivalent of the requested URL.
if rule.Protocol == "http" && rule.TLSCert != "" && rule.TLSKey != "" {
// If the rule is HTTPS and a TLS certificate is configured, but the
// incoming request arrived over plain HTTP, redirect to HTTPS.
if rule.Protocol == "https" && rule.TLSCert != "" && rule.TLSKey != "" && r.TLS == nil {
host := r.Host
if host == "" {
host = r.URL.Host
}
httpsURL := "https://" + host + r.RequestURI
logger.Info("HTTP handler: redirecting %s %s -> %s (TLS cert present)", r.Method, r.URL.RequestURI(), httpsURL)
http.Redirect(w, r, httpsURL, http.StatusMovedPermanently)
http.Redirect(w, r, httpsURL, http.StatusPermanentRedirect)
return
}

View File

@@ -0,0 +1,97 @@
package netstack2
import (
"context"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/gorilla/websocket"
)
func TestHTTPHandlerProxiesWebSocketUpgrade(t *testing.T) {
upgrader := websocket.Upgrader{}
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade failed: %v", err)
return
}
defer conn.Close()
messageType, payload, err := conn.ReadMessage()
if err != nil {
t.Errorf("read failed: %v", err)
return
}
if err := conn.WriteMessage(messageType, append([]byte("echo:"), payload...)); err != nil {
t.Errorf("write failed: %v", err)
}
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse backend URL: %v", err)
}
backendHost, backendPort, err := net.SplitHostPort(backendURL.Host)
if err != nil {
t.Fatalf("split backend host: %v", err)
}
port, err := net.LookupPort("tcp", backendPort)
if err != nil {
t.Fatalf("parse backend port: %v", err)
}
handler := NewHTTPHandler(nil, nil)
rule := &SubnetRule{
Protocol: "http",
HTTPTargets: []HTTPTarget{
{
DestAddr: backendHost,
DestPort: uint16(port),
Scheme: backendURL.Scheme,
},
},
}
frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), connCtxKey{}, rule)
handler.handleRequest(w, r.WithContext(ctx))
}))
defer frontend.Close()
frontendURL, err := url.Parse(frontend.URL)
if err != nil {
t.Fatalf("parse frontend URL: %v", err)
}
wsURL := url.URL{
Scheme: "ws",
Host: frontendURL.Host,
Path: "/socket",
RawQuery: "token=test",
}
conn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
if err != nil {
t.Fatalf("dial websocket through proxy: %v", err)
}
defer conn.Close()
if err := conn.WriteMessage(websocket.TextMessage, []byte("hello")); err != nil {
t.Fatalf("write websocket message: %v", err)
}
messageType, payload, err := conn.ReadMessage()
if err != nil {
t.Fatalf("read websocket message: %v", err)
}
if messageType != websocket.TextMessage {
t.Fatalf("message type = %d, want %d", messageType, websocket.TextMessage)
}
if got, want := string(payload), "echo:hello"; got != want {
t.Fatalf("payload = %q, want %q", got, want)
}
}

60
testing/ws_client.py Normal file
View File

@@ -0,0 +1,60 @@
import asyncio
import sys
import websockets
# Argument parsing: Check if HOST and PORT are provided
if len(sys.argv) < 3 or len(sys.argv) > 4:
print("Usage: python ws_client.py <HOST_IP> <HOST_PORT> [ws|wss]")
# Example: python ws_client.py 127.0.0.1 8765
# Example: python ws_client.py 127.0.0.1 8765 wss
sys.exit(1)
HOST = sys.argv[1]
try:
PORT = int(sys.argv[2])
except ValueError:
print("Error: HOST_PORT must be an integer.")
sys.exit(1)
if len(sys.argv) == 4:
SCHEME = sys.argv[3].lower()
if SCHEME not in ("ws", "wss"):
print("Error: scheme must be 'ws' or 'wss'.")
sys.exit(1)
else:
SCHEME = "ws"
URI = f"{SCHEME}://{HOST}:{PORT}"
# The message to send to the server
MESSAGE = "Hello WebSocket Server! How are you?"
async def main():
print(f"Connecting to {URI}...")
try:
async with websockets.connect(URI) as websocket:
print(f"Connected to server.")
print(f"Sending message: '{MESSAGE}'")
await websocket.send(MESSAGE)
response = await websocket.recv()
print("-" * 30)
print(f"Received response from server:")
print(f"-> Data: '{response}'")
except ConnectionRefusedError:
print(f"Error: Connection to {URI} was refused. Is the server running?")
except websockets.exceptions.InvalidMessage as e:
print(f"Error: Server did not respond with a valid WebSocket handshake: {e}")
except Exception as e:
print(f"Error during communication: {e}")
print("-" * 30)
print("Client finished.")
asyncio.run(main())

49
testing/ws_server.py Normal file
View File

@@ -0,0 +1,49 @@
import asyncio
import sys
import websockets
# Optionally take in a positional arg for the port
if len(sys.argv) > 1:
try:
PORT = int(sys.argv[1])
except ValueError:
print("Invalid port number. Using default port 8765.")
PORT = 8765
else:
PORT = 8765
# Define the server host
HOST = "0.0.0.0"
async def handle_client(websocket):
client_address = websocket.remote_address
print(f"Client connected: {client_address[0]}:{client_address[1]}")
try:
async for message in websocket:
print("-" * 30)
print(f"Received message from {client_address[0]}:{client_address[1]}:")
print(f"-> Data: '{message}'")
response = f"Hello client! Server received: '{message.upper()}'"
await websocket.send(response)
print(f"Sent response back to client.")
except websockets.exceptions.ConnectionClosedOK:
print(f"Client {client_address[0]}:{client_address[1]} disconnected cleanly.")
except websockets.exceptions.ConnectionClosedError as e:
print(f"Client {client_address[0]}:{client_address[1]} disconnected with error: {e}")
async def main():
print(f"WebSocket Server listening on {HOST}:{PORT}")
async with websockets.serve(handle_client, HOST, PORT):
await asyncio.Future() # Run forever
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\nServer stopped.")