From 8e19e475bf3f552db8aee0984cbbf6063a2aaa52 Mon Sep 17 00:00:00 2001 From: Laurence Date: Wed, 29 Apr 2026 07:12:35 +0100 Subject: [PATCH 1/2] Support websocket upgrades in private HTTP proxy Preserve optional ResponseWriter interfaces through statusCapture so httputil.ReverseProxy can hijack upgraded websocket connections. Add a regression test covering websocket traffic through the HTTP handler path. --- netstack2/http_handler.go | 22 +++++++- netstack2/http_handler_test.go | 97 ++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 netstack2/http_handler_test.go diff --git a/netstack2/http_handler.go b/netstack2/http_handler.go index d894127..5e44844 100644 --- a/netstack2/http_handler.go +++ b/netstack2/http_handler.go @@ -6,8 +6,10 @@ package netstack2 import ( + "bufio" "context" "crypto/tls" + "errors" "fmt" "net" "net/http" @@ -29,7 +31,7 @@ import ( type HTTPTarget struct { DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service DestPort uint16 `json:"destPort"` // TCP port of the downstream service - Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS + Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS } // --------------------------------------------------------------------------- @@ -322,6 +324,24 @@ func (sc *statusCapture) WriteHeader(code int) { sc.ResponseWriter.WriteHeader(code) } +func (sc *statusCapture) Unwrap() http.ResponseWriter { + return sc.ResponseWriter +} + +func (sc *statusCapture) Flush() { + if flusher, ok := sc.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +func (sc *statusCapture) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := sc.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("underlying response writer does not support hijacking") + } + return hijacker.Hijack() +} + // handleRequest is the http.Handler entry point. It retrieves the SubnetRule // attached to the connection by ConnContext, selects the first configured // downstream target, and forwards the request via the cached ReverseProxy. diff --git a/netstack2/http_handler_test.go b/netstack2/http_handler_test.go new file mode 100644 index 0000000..a4cc3cd --- /dev/null +++ b/netstack2/http_handler_test.go @@ -0,0 +1,97 @@ +package netstack2 + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/websocket" +) + +func TestHTTPHandlerProxiesWebSocketUpgrade(t *testing.T) { + upgrader := websocket.Upgrader{} + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade failed: %v", err) + return + } + defer conn.Close() + + messageType, payload, err := conn.ReadMessage() + if err != nil { + t.Errorf("read failed: %v", err) + return + } + if err := conn.WriteMessage(messageType, append([]byte("echo:"), payload...)); err != nil { + t.Errorf("write failed: %v", err) + } + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse backend URL: %v", err) + } + backendHost, backendPort, err := net.SplitHostPort(backendURL.Host) + if err != nil { + t.Fatalf("split backend host: %v", err) + } + port, err := net.LookupPort("tcp", backendPort) + if err != nil { + t.Fatalf("parse backend port: %v", err) + } + + handler := NewHTTPHandler(nil, nil) + rule := &SubnetRule{ + Protocol: "http", + HTTPTargets: []HTTPTarget{ + { + DestAddr: backendHost, + DestPort: uint16(port), + Scheme: backendURL.Scheme, + }, + }, + } + + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), connCtxKey{}, rule) + handler.handleRequest(w, r.WithContext(ctx)) + })) + defer frontend.Close() + + frontendURL, err := url.Parse(frontend.URL) + if err != nil { + t.Fatalf("parse frontend URL: %v", err) + } + wsURL := url.URL{ + Scheme: "ws", + Host: frontendURL.Host, + Path: "/socket", + RawQuery: "token=test", + } + + conn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + if err != nil { + t.Fatalf("dial websocket through proxy: %v", err) + } + defer conn.Close() + + if err := conn.WriteMessage(websocket.TextMessage, []byte("hello")); err != nil { + t.Fatalf("write websocket message: %v", err) + } + + messageType, payload, err := conn.ReadMessage() + if err != nil { + t.Fatalf("read websocket message: %v", err) + } + if messageType != websocket.TextMessage { + t.Fatalf("message type = %d, want %d", messageType, websocket.TextMessage) + } + if got, want := string(payload), "echo:hello"; got != want { + t.Fatalf("payload = %q, want %q", got, want) + } +} From b33c3b88497e17d64dcdc692e6531e18434ad793 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 29 Apr 2026 15:57:31 -0700 Subject: [PATCH 2/2] Add some test scripts for ws and move to testing/ --- udp_client.py => testing/udp_client.py | 0 udp_server.py => testing/udp_server.py | 0 testing/ws_client.py | 60 ++++++++++++++++++++++++++ testing/ws_server.py | 49 +++++++++++++++++++++ 4 files changed, 109 insertions(+) rename udp_client.py => testing/udp_client.py (100%) rename udp_server.py => testing/udp_server.py (100%) create mode 100644 testing/ws_client.py create mode 100644 testing/ws_server.py diff --git a/udp_client.py b/testing/udp_client.py similarity index 100% rename from udp_client.py rename to testing/udp_client.py diff --git a/udp_server.py b/testing/udp_server.py similarity index 100% rename from udp_server.py rename to testing/udp_server.py diff --git a/testing/ws_client.py b/testing/ws_client.py new file mode 100644 index 0000000..5aa5c72 --- /dev/null +++ b/testing/ws_client.py @@ -0,0 +1,60 @@ +import asyncio +import sys +import websockets + +# Argument parsing: Check if HOST and PORT are provided +if len(sys.argv) < 3 or len(sys.argv) > 4: + print("Usage: python ws_client.py [ws|wss]") + # Example: python ws_client.py 127.0.0.1 8765 + # Example: python ws_client.py 127.0.0.1 8765 wss + sys.exit(1) + +HOST = sys.argv[1] +try: + PORT = int(sys.argv[2]) +except ValueError: + print("Error: HOST_PORT must be an integer.") + sys.exit(1) + +if len(sys.argv) == 4: + SCHEME = sys.argv[3].lower() + if SCHEME not in ("ws", "wss"): + print("Error: scheme must be 'ws' or 'wss'.") + sys.exit(1) +else: + SCHEME = "ws" + +URI = f"{SCHEME}://{HOST}:{PORT}" + +# The message to send to the server +MESSAGE = "Hello WebSocket Server! How are you?" + + +async def main(): + print(f"Connecting to {URI}...") + + try: + async with websockets.connect(URI) as websocket: + print(f"Connected to server.") + print(f"Sending message: '{MESSAGE}'") + + await websocket.send(MESSAGE) + + response = await websocket.recv() + + print("-" * 30) + print(f"Received response from server:") + print(f"-> Data: '{response}'") + + except ConnectionRefusedError: + print(f"Error: Connection to {URI} was refused. Is the server running?") + except websockets.exceptions.InvalidMessage as e: + print(f"Error: Server did not respond with a valid WebSocket handshake: {e}") + except Exception as e: + print(f"Error during communication: {e}") + + print("-" * 30) + print("Client finished.") + + +asyncio.run(main()) diff --git a/testing/ws_server.py b/testing/ws_server.py new file mode 100644 index 0000000..2e2880d --- /dev/null +++ b/testing/ws_server.py @@ -0,0 +1,49 @@ +import asyncio +import sys +import websockets + +# Optionally take in a positional arg for the port +if len(sys.argv) > 1: + try: + PORT = int(sys.argv[1]) + except ValueError: + print("Invalid port number. Using default port 8765.") + PORT = 8765 +else: + PORT = 8765 + +# Define the server host +HOST = "0.0.0.0" + + +async def handle_client(websocket): + client_address = websocket.remote_address + print(f"Client connected: {client_address[0]}:{client_address[1]}") + + try: + async for message in websocket: + print("-" * 30) + print(f"Received message from {client_address[0]}:{client_address[1]}:") + print(f"-> Data: '{message}'") + + response = f"Hello client! Server received: '{message.upper()}'" + + await websocket.send(response) + print(f"Sent response back to client.") + + except websockets.exceptions.ConnectionClosedOK: + print(f"Client {client_address[0]}:{client_address[1]} disconnected cleanly.") + except websockets.exceptions.ConnectionClosedError as e: + print(f"Client {client_address[0]}:{client_address[1]} disconnected with error: {e}") + + +async def main(): + print(f"WebSocket Server listening on {HOST}:{PORT}") + async with websockets.serve(handle_client, HOST, PORT): + await asyncio.Future() # Run forever + + +try: + asyncio.run(main()) +except KeyboardInterrupt: + print("\nServer stopped.")