mirror of
https://github.com/fosrl/newt.git
synced 2026-05-07 00:20:00 -05:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e92c42876 | ||
|
|
ffd26f9a6d | ||
|
|
bf33a66043 | ||
|
|
df3aa60cf5 | ||
|
|
cc663f1636 | ||
|
|
af2ecf486a | ||
|
|
a0d2bb999a |
14
common.go
14
common.go
@@ -279,7 +279,7 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
|
|||||||
|
|
||||||
// More lenient threshold for declaring connection lost under load
|
// More lenient threshold for declaring connection lost under load
|
||||||
failureThreshold := 4
|
failureThreshold := 4
|
||||||
if consecutiveFailures >= failureThreshold {
|
if consecutiveFailures >= failureThreshold && currentInterval < maxInterval {
|
||||||
if !connectionLost {
|
if !connectionLost {
|
||||||
connectionLost = true
|
connectionLost = true
|
||||||
logger.Warn("Connection to server lost after %d failures. Continuous reconnection attempts will be made.", consecutiveFailures)
|
logger.Warn("Connection to server lost after %d failures. Continuous reconnection attempts will be made.", consecutiveFailures)
|
||||||
@@ -309,14 +309,12 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if currentInterval < maxInterval {
|
currentInterval = time.Duration(float64(currentInterval) * 1.3) // Slower increase
|
||||||
currentInterval = time.Duration(float64(currentInterval) * 1.3) // Slower increase
|
if currentInterval > maxInterval {
|
||||||
if currentInterval > maxInterval {
|
currentInterval = maxInterval
|
||||||
currentInterval = maxInterval
|
|
||||||
}
|
|
||||||
ticker.Reset(currentInterval)
|
|
||||||
logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
|
|
||||||
}
|
}
|
||||||
|
ticker.Reset(currentInterval)
|
||||||
|
logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Track recent latencies
|
// Track recent latencies
|
||||||
|
|||||||
@@ -6,10 +6,8 @@
|
|||||||
package netstack2
|
package netstack2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -31,7 +29,7 @@ import (
|
|||||||
type HTTPTarget struct {
|
type HTTPTarget struct {
|
||||||
DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service
|
DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service
|
||||||
DestPort uint16 `json:"destPort"` // TCP port of the downstream service
|
DestPort uint16 `json:"destPort"` // TCP port of the downstream service
|
||||||
Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS
|
Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -324,24 +322,6 @@ func (sc *statusCapture) WriteHeader(code int) {
|
|||||||
sc.ResponseWriter.WriteHeader(code)
|
sc.ResponseWriter.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *statusCapture) Unwrap() http.ResponseWriter {
|
|
||||||
return sc.ResponseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sc *statusCapture) Flush() {
|
|
||||||
if flusher, ok := sc.ResponseWriter.(http.Flusher); ok {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sc *statusCapture) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
||||||
hijacker, ok := sc.ResponseWriter.(http.Hijacker)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, errors.New("underlying response writer does not support hijacking")
|
|
||||||
}
|
|
||||||
return hijacker.Hijack()
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleRequest is the http.Handler entry point. It retrieves the SubnetRule
|
// handleRequest is the http.Handler entry point. It retrieves the SubnetRule
|
||||||
// attached to the connection by ConnContext, selects the first configured
|
// attached to the connection by ConnContext, selects the first configured
|
||||||
// downstream target, and forwards the request via the cached ReverseProxy.
|
// downstream target, and forwards the request via the cached ReverseProxy.
|
||||||
@@ -356,16 +336,16 @@ func (h *HTTPHandler) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the rule is HTTPS and a TLS certificate is configured, but the
|
// If the rule is plain HTTP but has a TLS certificate configured, redirect
|
||||||
// incoming request arrived over plain HTTP, redirect to HTTPS.
|
// the client to the HTTPS equivalent of the requested URL.
|
||||||
if rule.Protocol == "https" && rule.TLSCert != "" && rule.TLSKey != "" && r.TLS == nil {
|
if rule.Protocol == "http" && rule.TLSCert != "" && rule.TLSKey != "" {
|
||||||
host := r.Host
|
host := r.Host
|
||||||
if host == "" {
|
if host == "" {
|
||||||
host = r.URL.Host
|
host = r.URL.Host
|
||||||
}
|
}
|
||||||
httpsURL := "https://" + host + r.RequestURI
|
httpsURL := "https://" + host + r.RequestURI
|
||||||
logger.Info("HTTP handler: redirecting %s %s -> %s (TLS cert present)", r.Method, r.URL.RequestURI(), httpsURL)
|
logger.Info("HTTP handler: redirecting %s %s -> %s (TLS cert present)", r.Method, r.URL.RequestURI(), httpsURL)
|
||||||
http.Redirect(w, r, httpsURL, http.StatusPermanentRedirect)
|
http.Redirect(w, r, httpsURL, http.StatusMovedPermanently)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,97 +0,0 @@
|
|||||||
package netstack2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestHTTPHandlerProxiesWebSocketUpgrade(t *testing.T) {
|
|
||||||
upgrader := websocket.Upgrader{}
|
|
||||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("upgrade failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
messageType, payload, err := conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("read failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := conn.WriteMessage(messageType, append([]byte("echo:"), payload...)); err != nil {
|
|
||||||
t.Errorf("write failed: %v", err)
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
backendURL, err := url.Parse(backend.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parse backend URL: %v", err)
|
|
||||||
}
|
|
||||||
backendHost, backendPort, err := net.SplitHostPort(backendURL.Host)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("split backend host: %v", err)
|
|
||||||
}
|
|
||||||
port, err := net.LookupPort("tcp", backendPort)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parse backend port: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := NewHTTPHandler(nil, nil)
|
|
||||||
rule := &SubnetRule{
|
|
||||||
Protocol: "http",
|
|
||||||
HTTPTargets: []HTTPTarget{
|
|
||||||
{
|
|
||||||
DestAddr: backendHost,
|
|
||||||
DestPort: uint16(port),
|
|
||||||
Scheme: backendURL.Scheme,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := context.WithValue(r.Context(), connCtxKey{}, rule)
|
|
||||||
handler.handleRequest(w, r.WithContext(ctx))
|
|
||||||
}))
|
|
||||||
defer frontend.Close()
|
|
||||||
|
|
||||||
frontendURL, err := url.Parse(frontend.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parse frontend URL: %v", err)
|
|
||||||
}
|
|
||||||
wsURL := url.URL{
|
|
||||||
Scheme: "ws",
|
|
||||||
Host: frontendURL.Host,
|
|
||||||
Path: "/socket",
|
|
||||||
RawQuery: "token=test",
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("dial websocket through proxy: %v", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
if err := conn.WriteMessage(websocket.TextMessage, []byte("hello")); err != nil {
|
|
||||||
t.Fatalf("write websocket message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
messageType, payload, err := conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("read websocket message: %v", err)
|
|
||||||
}
|
|
||||||
if messageType != websocket.TextMessage {
|
|
||||||
t.Fatalf("message type = %d, want %d", messageType, websocket.TextMessage)
|
|
||||||
}
|
|
||||||
if got, want := string(payload), "echo:hello"; got != want {
|
|
||||||
t.Fatalf("payload = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
import websockets
|
|
||||||
|
|
||||||
# Argument parsing: Check if HOST and PORT are provided
|
|
||||||
if len(sys.argv) < 3 or len(sys.argv) > 4:
|
|
||||||
print("Usage: python ws_client.py <HOST_IP> <HOST_PORT> [ws|wss]")
|
|
||||||
# Example: python ws_client.py 127.0.0.1 8765
|
|
||||||
# Example: python ws_client.py 127.0.0.1 8765 wss
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
HOST = sys.argv[1]
|
|
||||||
try:
|
|
||||||
PORT = int(sys.argv[2])
|
|
||||||
except ValueError:
|
|
||||||
print("Error: HOST_PORT must be an integer.")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if len(sys.argv) == 4:
|
|
||||||
SCHEME = sys.argv[3].lower()
|
|
||||||
if SCHEME not in ("ws", "wss"):
|
|
||||||
print("Error: scheme must be 'ws' or 'wss'.")
|
|
||||||
sys.exit(1)
|
|
||||||
else:
|
|
||||||
SCHEME = "ws"
|
|
||||||
|
|
||||||
URI = f"{SCHEME}://{HOST}:{PORT}"
|
|
||||||
|
|
||||||
# The message to send to the server
|
|
||||||
MESSAGE = "Hello WebSocket Server! How are you?"
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
print(f"Connecting to {URI}...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with websockets.connect(URI) as websocket:
|
|
||||||
print(f"Connected to server.")
|
|
||||||
print(f"Sending message: '{MESSAGE}'")
|
|
||||||
|
|
||||||
await websocket.send(MESSAGE)
|
|
||||||
|
|
||||||
response = await websocket.recv()
|
|
||||||
|
|
||||||
print("-" * 30)
|
|
||||||
print(f"Received response from server:")
|
|
||||||
print(f"-> Data: '{response}'")
|
|
||||||
|
|
||||||
except ConnectionRefusedError:
|
|
||||||
print(f"Error: Connection to {URI} was refused. Is the server running?")
|
|
||||||
except websockets.exceptions.InvalidMessage as e:
|
|
||||||
print(f"Error: Server did not respond with a valid WebSocket handshake: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during communication: {e}")
|
|
||||||
|
|
||||||
print("-" * 30)
|
|
||||||
print("Client finished.")
|
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
import websockets
|
|
||||||
|
|
||||||
# Optionally take in a positional arg for the port
|
|
||||||
if len(sys.argv) > 1:
|
|
||||||
try:
|
|
||||||
PORT = int(sys.argv[1])
|
|
||||||
except ValueError:
|
|
||||||
print("Invalid port number. Using default port 8765.")
|
|
||||||
PORT = 8765
|
|
||||||
else:
|
|
||||||
PORT = 8765
|
|
||||||
|
|
||||||
# Define the server host
|
|
||||||
HOST = "0.0.0.0"
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_client(websocket):
|
|
||||||
client_address = websocket.remote_address
|
|
||||||
print(f"Client connected: {client_address[0]}:{client_address[1]}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for message in websocket:
|
|
||||||
print("-" * 30)
|
|
||||||
print(f"Received message from {client_address[0]}:{client_address[1]}:")
|
|
||||||
print(f"-> Data: '{message}'")
|
|
||||||
|
|
||||||
response = f"Hello client! Server received: '{message.upper()}'"
|
|
||||||
|
|
||||||
await websocket.send(response)
|
|
||||||
print(f"Sent response back to client.")
|
|
||||||
|
|
||||||
except websockets.exceptions.ConnectionClosedOK:
|
|
||||||
print(f"Client {client_address[0]}:{client_address[1]} disconnected cleanly.")
|
|
||||||
except websockets.exceptions.ConnectionClosedError as e:
|
|
||||||
print(f"Client {client_address[0]}:{client_address[1]} disconnected with error: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
print(f"WebSocket Server listening on {HOST}:{PORT}")
|
|
||||||
async with websockets.serve(handle_client, HOST, PORT):
|
|
||||||
await asyncio.Future() # Run forever
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
asyncio.run(main())
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\nServer stopped.")
|
|
||||||
Reference in New Issue
Block a user