From 8e19e475bf3f552db8aee0984cbbf6063a2aaa52 Mon Sep 17 00:00:00 2001 From: Laurence Date: Wed, 29 Apr 2026 07:12:35 +0100 Subject: [PATCH] Support websocket upgrades in private HTTP proxy Preserve optional ResponseWriter interfaces through statusCapture so httputil.ReverseProxy can hijack upgraded websocket connections. Add a regression test covering websocket traffic through the HTTP handler path. --- netstack2/http_handler.go | 22 +++++++- netstack2/http_handler_test.go | 97 ++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 netstack2/http_handler_test.go diff --git a/netstack2/http_handler.go b/netstack2/http_handler.go index d894127..5e44844 100644 --- a/netstack2/http_handler.go +++ b/netstack2/http_handler.go @@ -6,8 +6,10 @@ package netstack2 import ( + "bufio" "context" "crypto/tls" + "errors" "fmt" "net" "net/http" @@ -29,7 +31,7 @@ import ( type HTTPTarget struct { DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service DestPort uint16 `json:"destPort"` // TCP port of the downstream service - Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS + Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS } // --------------------------------------------------------------------------- @@ -322,6 +324,24 @@ func (sc *statusCapture) WriteHeader(code int) { sc.ResponseWriter.WriteHeader(code) } +func (sc *statusCapture) Unwrap() http.ResponseWriter { + return sc.ResponseWriter +} + +func (sc *statusCapture) Flush() { + if flusher, ok := sc.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +func (sc *statusCapture) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := sc.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("underlying response writer does not support hijacking") + } + return hijacker.Hijack() +} + // handleRequest is the http.Handler entry point. It retrieves the SubnetRule // attached to the connection by ConnContext, selects the first configured // downstream target, and forwards the request via the cached ReverseProxy. diff --git a/netstack2/http_handler_test.go b/netstack2/http_handler_test.go new file mode 100644 index 0000000..a4cc3cd --- /dev/null +++ b/netstack2/http_handler_test.go @@ -0,0 +1,97 @@ +package netstack2 + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/websocket" +) + +func TestHTTPHandlerProxiesWebSocketUpgrade(t *testing.T) { + upgrader := websocket.Upgrader{} + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade failed: %v", err) + return + } + defer conn.Close() + + messageType, payload, err := conn.ReadMessage() + if err != nil { + t.Errorf("read failed: %v", err) + return + } + if err := conn.WriteMessage(messageType, append([]byte("echo:"), payload...)); err != nil { + t.Errorf("write failed: %v", err) + } + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse backend URL: %v", err) + } + backendHost, backendPort, err := net.SplitHostPort(backendURL.Host) + if err != nil { + t.Fatalf("split backend host: %v", err) + } + port, err := net.LookupPort("tcp", backendPort) + if err != nil { + t.Fatalf("parse backend port: %v", err) + } + + handler := NewHTTPHandler(nil, nil) + rule := &SubnetRule{ + Protocol: "http", + HTTPTargets: []HTTPTarget{ + { + DestAddr: backendHost, + DestPort: uint16(port), + Scheme: backendURL.Scheme, + }, + }, + } + + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), connCtxKey{}, rule) + handler.handleRequest(w, r.WithContext(ctx)) + })) + defer frontend.Close() + + frontendURL, err := url.Parse(frontend.URL) + if err != nil { + t.Fatalf("parse frontend URL: %v", err) + } + wsURL := url.URL{ + Scheme: "ws", + Host: frontendURL.Host, + Path: "/socket", + RawQuery: "token=test", + } + + conn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + if err != nil { + t.Fatalf("dial websocket through proxy: %v", err) + } + defer conn.Close() + + if err := conn.WriteMessage(websocket.TextMessage, []byte("hello")); err != nil { + t.Fatalf("write websocket message: %v", err) + } + + messageType, payload, err := conn.ReadMessage() + if err != nil { + t.Fatalf("read websocket message: %v", err) + } + if messageType != websocket.TextMessage { + t.Fatalf("message type = %d, want %d", messageType, websocket.TextMessage) + } + if got, want := string(payload), "echo:hello"; got != want { + t.Fatalf("payload = %q, want %q", got, want) + } +}