mirror of
https://github.com/fosrl/newt.git
synced 2026-05-05 23:50:10 -05:00
Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
57aa2e2e2c | ||
|
|
5724c516dc | ||
|
|
b33c3b8849 | ||
|
|
8e19e475bf | ||
|
|
9e92c42876 | ||
|
|
66c72bbe2e | ||
|
|
ffd26f9a6d | ||
|
|
7610aa40bf | ||
|
|
bf33a66043 | ||
|
|
23caf57bf4 | ||
|
|
df3aa60cf5 | ||
|
|
5c43db466a | ||
|
|
cc663f1636 | ||
|
|
1a67ff30c2 | ||
|
|
bfd61ca511 | ||
|
|
294f99e024 | ||
|
|
af2ecf486a | ||
|
|
efd6743ce4 | ||
|
|
a0d2bb999a | ||
|
|
5d889fbc09 | ||
|
|
1a7cf06ff8 | ||
|
|
35a334c842 | ||
|
|
c8e5112a2a | ||
|
|
8bfb4659c0 | ||
|
|
309f9caad2 | ||
|
|
26de268466 | ||
|
|
0f927a37ab | ||
|
|
e8961c5de5 | ||
|
|
9bb8eaeadb | ||
|
|
d3d10d02e8 | ||
|
|
be1cd190e7 | ||
|
|
5c9d13bcca | ||
|
|
dc2e23380a | ||
|
|
12776d65c1 | ||
|
|
0569525743 | ||
|
|
342af9e42d | ||
|
|
092535441e | ||
|
|
5848c8d4b4 | ||
|
|
6becf0f719 | ||
|
|
47c646bc33 | ||
|
|
4d8d00241d | ||
|
|
31f899588f | ||
|
|
7e1e3408d5 | ||
|
|
d7c3c38d24 | ||
|
|
27e471942e | ||
|
|
184bfb12d6 |
1
.github/CODEOWNERS
vendored
Normal file
1
.github/CODEOWNERS
vendored
Normal file
@@ -0,0 +1 @@
|
||||
* @oschwartz10612 @miloschwartz
|
||||
9
.github/workflows/cicd.yml
vendored
9
.github/workflows/cicd.yml
vendored
@@ -110,15 +110,6 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Update version in flake.nix
|
||||
shell: bash
|
||||
env:
|
||||
VERSION: ${{ inputs.version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
sed -i "s/version = \"[0-9]*\.[0-9]*\.[0-9]*\(-rc\.[0-9]*\)\?\"/version = \"$VERSION\"/" flake.nix
|
||||
echo "Updated flake.nix version to $VERSION"
|
||||
|
||||
- name: Create and push tag
|
||||
shell: bash
|
||||
env:
|
||||
|
||||
2
Makefile
2
Makefile
@@ -6,7 +6,7 @@ VERSION ?= dev
|
||||
LDFLAGS = -X main.newtVersion=$(VERSION)
|
||||
|
||||
local:
|
||||
CGO_ENABLED=0 go build -o ./bin/newt
|
||||
CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/newt
|
||||
|
||||
docker-build:
|
||||
docker build -t fosrl/newt:latest .
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
resources:
|
||||
resource-nice-id:
|
||||
name: this is my resource
|
||||
protocol: http
|
||||
full-domain: level1.test3.example.com
|
||||
host-header: example.com
|
||||
tls-server-name: example.com
|
||||
auth:
|
||||
pincode: 123456
|
||||
password: sadfasdfadsf
|
||||
sso-enabled: true
|
||||
sso-roles:
|
||||
- Member
|
||||
sso-users:
|
||||
- owen@pangolin.net
|
||||
whitelist-users:
|
||||
- owen@pangolin.net
|
||||
targets:
|
||||
# - site: glossy-plains-viscacha-rat
|
||||
- hostname: localhost
|
||||
method: http
|
||||
port: 8000
|
||||
healthcheck:
|
||||
port: 8000
|
||||
hostname: localhost
|
||||
# - site: glossy-plains-viscacha-rat
|
||||
- hostname: localhost
|
||||
method: http
|
||||
port: 8001
|
||||
resource-nice-id2:
|
||||
name: this is other resource
|
||||
protocol: tcp
|
||||
proxy-port: 3000
|
||||
targets:
|
||||
# - site: glossy-plains-viscacha-rat
|
||||
- hostname: localhost
|
||||
port: 3000
|
||||
@@ -40,13 +40,17 @@ type WgConfig struct {
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
SourcePrefix string `json:"sourcePrefix"`
|
||||
SourcePrefixes []string `json:"sourcePrefixes"`
|
||||
DestPrefix string `json:"destPrefix"`
|
||||
RewriteTo string `json:"rewriteTo,omitempty"`
|
||||
DisableIcmp bool `json:"disableIcmp,omitempty"`
|
||||
PortRange []PortRange `json:"portRange,omitempty"`
|
||||
ResourceId int `json:"resourceId,omitempty"`
|
||||
SourcePrefix string `json:"sourcePrefix"`
|
||||
SourcePrefixes []string `json:"sourcePrefixes"`
|
||||
DestPrefix string `json:"destPrefix"`
|
||||
RewriteTo string `json:"rewriteTo,omitempty"`
|
||||
DisableIcmp bool `json:"disableIcmp,omitempty"`
|
||||
PortRange []PortRange `json:"portRange,omitempty"`
|
||||
ResourceId int `json:"resourceId,omitempty"`
|
||||
Protocol string `json:"protocol,omitempty"` // for now practicably either http or https
|
||||
HTTPTargets []netstack2.HTTPTarget `json:"httpTargets,omitempty"` // for http protocol, list of downstream services to load balance across
|
||||
TLSCert string `json:"tlsCert,omitempty"` // PEM-encoded certificate for incoming HTTPS termination
|
||||
TLSKey string `json:"tlsKey,omitempty"` // PEM-encoded private key for incoming HTTPS termination
|
||||
}
|
||||
|
||||
type PortRange struct {
|
||||
@@ -74,18 +78,18 @@ type PeerReading struct {
|
||||
}
|
||||
|
||||
type WireGuardService struct {
|
||||
interfaceName string
|
||||
mtu int
|
||||
client *websocket.Client
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
newtId string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
Port uint16
|
||||
host string
|
||||
serverPubKey string
|
||||
token string
|
||||
interfaceName string
|
||||
mtu int
|
||||
client *websocket.Client
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
newtId string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
Port uint16
|
||||
host string
|
||||
serverPubKey string
|
||||
token string
|
||||
stopGetConfig func()
|
||||
pendingConfigChainId string
|
||||
// Netstack fields
|
||||
@@ -697,7 +701,18 @@ func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
|
||||
})
|
||||
}
|
||||
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||
s.tnet.AddProxySubnetRule(netstack2.SubnetRule{
|
||||
SourcePrefix: sourcePrefix,
|
||||
DestPrefix: destPrefix,
|
||||
RewriteTo: target.RewriteTo,
|
||||
PortRanges: portRanges,
|
||||
DisableIcmp: target.DisableIcmp,
|
||||
ResourceId: target.ResourceId,
|
||||
Protocol: target.Protocol,
|
||||
HTTPTargets: target.HTTPTargets,
|
||||
TLSCert: target.TLSCert,
|
||||
TLSKey: target.TLSKey,
|
||||
})
|
||||
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
|
||||
}
|
||||
}
|
||||
@@ -835,6 +850,13 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
})
|
||||
})
|
||||
|
||||
// Configure the HTTP request log sender to ship compressed request logs via websocket
|
||||
s.tnet.SetHTTPRequestLogSender(func(data string) error {
|
||||
return s.client.SendMessageNoLog("newt/request-log", map[string]interface{}{
|
||||
"compressed": data,
|
||||
})
|
||||
})
|
||||
|
||||
// Create WireGuard device using the shared bind
|
||||
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
|
||||
device.LogLevelSilent, // Use silent logging by default - could be made configurable
|
||||
@@ -955,7 +977,18 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||
s.tnet.AddProxySubnetRule(netstack2.SubnetRule{
|
||||
SourcePrefix: sourcePrefix,
|
||||
DestPrefix: destPrefix,
|
||||
RewriteTo: target.RewriteTo,
|
||||
PortRanges: portRanges,
|
||||
DisableIcmp: target.DisableIcmp,
|
||||
ResourceId: target.ResourceId,
|
||||
Protocol: target.Protocol,
|
||||
HTTPTargets: target.HTTPTargets,
|
||||
TLSCert: target.TLSCert,
|
||||
TLSKey: target.TLSKey,
|
||||
})
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
}
|
||||
}
|
||||
@@ -1348,7 +1381,18 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
||||
continue
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||
s.tnet.AddProxySubnetRule(netstack2.SubnetRule{
|
||||
SourcePrefix: sourcePrefix,
|
||||
DestPrefix: destPrefix,
|
||||
RewriteTo: target.RewriteTo,
|
||||
PortRanges: portRanges,
|
||||
DisableIcmp: target.DisableIcmp,
|
||||
ResourceId: target.ResourceId,
|
||||
Protocol: target.Protocol,
|
||||
HTTPTargets: target.HTTPTargets,
|
||||
TLSCert: target.TLSCert,
|
||||
TLSKey: target.TLSKey,
|
||||
})
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
}
|
||||
}
|
||||
@@ -1466,7 +1510,18 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
||||
continue
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||
s.tnet.AddProxySubnetRule(netstack2.SubnetRule{
|
||||
SourcePrefix: sourcePrefix,
|
||||
DestPrefix: destPrefix,
|
||||
RewriteTo: target.RewriteTo,
|
||||
PortRanges: portRanges,
|
||||
DisableIcmp: target.DisableIcmp,
|
||||
ResourceId: target.ResourceId,
|
||||
Protocol: target.Protocol,
|
||||
HTTPTargets: target.HTTPTargets,
|
||||
TLSCert: target.TLSCert,
|
||||
TLSKey: target.TLSKey,
|
||||
})
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
}
|
||||
}
|
||||
|
||||
206
get-newt.sh
206
get-newt.sh
@@ -30,41 +30,38 @@ print_error() {
|
||||
|
||||
# Function to get latest version from GitHub API
|
||||
get_latest_version() {
|
||||
local latest_info
|
||||
|
||||
latest_info=""
|
||||
|
||||
if command -v curl >/dev/null 2>&1; then
|
||||
latest_info=$(curl -fsSL "$GITHUB_API_URL" 2>/dev/null)
|
||||
elif command -v wget >/dev/null 2>&1; then
|
||||
latest_info=$(wget -qO- "$GITHUB_API_URL" 2>/dev/null)
|
||||
else
|
||||
print_error "Neither curl nor wget is available. Please install one of them." >&2
|
||||
print_error "Neither curl nor wget is available."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
if [ -z "$latest_info" ]; then
|
||||
print_error "Failed to fetch latest version information" >&2
|
||||
print_error "Failed to fetch latest version info"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Extract version from JSON response (works without jq)
|
||||
local version=$(echo "$latest_info" | grep '"tag_name"' | head -1 | sed 's/.*"tag_name": *"\([^"]*\)".*/\1/')
|
||||
|
||||
|
||||
version=$(printf '%s' "$latest_info" | grep '"tag_name"' | head -1 | sed 's/.*"tag_name": *"\([^"]*\)".*/\1/')
|
||||
|
||||
if [ -z "$version" ]; then
|
||||
print_error "Could not parse version from GitHub API response" >&2
|
||||
print_error "Could not parse version from GitHub API response"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Remove 'v' prefix if present
|
||||
version=$(echo "$version" | sed 's/^v//')
|
||||
|
||||
echo "$version"
|
||||
|
||||
version=$(printf '%s' "$version" | sed 's/^v//')
|
||||
printf '%s' "$version"
|
||||
}
|
||||
|
||||
# Detect OS and architecture
|
||||
detect_platform() {
|
||||
local os arch
|
||||
|
||||
# Detect OS
|
||||
os=""
|
||||
arch=""
|
||||
|
||||
case "$(uname -s)" in
|
||||
Linux*) os="linux" ;;
|
||||
Darwin*) os="darwin" ;;
|
||||
@@ -75,12 +72,11 @@ detect_platform() {
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
# Detect architecture
|
||||
|
||||
case "$(uname -m)" in
|
||||
x86_64|amd64) arch="amd64" ;;
|
||||
arm64|aarch64) arch="arm64" ;;
|
||||
armv7l|armv6l)
|
||||
armv7l|armv6l)
|
||||
if [ "$os" = "linux" ]; then
|
||||
if [ "$(uname -m)" = "armv6l" ]; then
|
||||
arch="arm32v6"
|
||||
@@ -88,10 +84,10 @@ detect_platform() {
|
||||
arch="arm32"
|
||||
fi
|
||||
else
|
||||
arch="arm64" # Default for non-Linux ARM
|
||||
arch="arm64"
|
||||
fi
|
||||
;;
|
||||
riscv64)
|
||||
riscv64)
|
||||
if [ "$os" = "linux" ]; then
|
||||
arch="riscv64"
|
||||
else
|
||||
@@ -104,23 +100,68 @@ detect_platform() {
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
echo "${os}_${arch}"
|
||||
|
||||
printf '%s_%s' "$os" "$arch"
|
||||
}
|
||||
|
||||
# Get installation directory
|
||||
# Determine installation directory (default fallback)
|
||||
get_install_dir() {
|
||||
if [ "$OS" = "windows" ]; then
|
||||
echo "$HOME/bin"
|
||||
else
|
||||
# Prefer /usr/local/bin for system-wide installation
|
||||
echo "/usr/local/bin"
|
||||
case "$PLATFORM" in
|
||||
*windows*)
|
||||
echo "$HOME/bin"
|
||||
;;
|
||||
*)
|
||||
echo "/usr/local/bin"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Parse --path argument from args
|
||||
# Returns the value after --path, or empty string if not provided
|
||||
parse_path_arg() {
|
||||
while [ $# -gt 0 ]; do
|
||||
case "$1" in
|
||||
--path)
|
||||
if [ -n "$2" ]; then
|
||||
printf '%s' "$2"
|
||||
return
|
||||
fi
|
||||
;;
|
||||
--path=*)
|
||||
printf '%s' "${1#--path=}"
|
||||
return
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
}
|
||||
|
||||
# Detect an existing newt binary location.
|
||||
# Tries unprivileged which first, then sudo which (for binaries only visible to root).
|
||||
# Returns the full path of the binary, or empty string if not found.
|
||||
detect_existing_binary() {
|
||||
existing=""
|
||||
|
||||
# Try unprivileged which first
|
||||
existing=$(command -v newt 2>/dev/null || true)
|
||||
if [ -n "$existing" ]; then
|
||||
printf '%s' "$existing"
|
||||
return
|
||||
fi
|
||||
|
||||
# Try sudo which — some installations land in paths only root can see in $PATH
|
||||
if command -v sudo >/dev/null 2>&1; then
|
||||
existing=$(sudo which newt 2>/dev/null || true)
|
||||
if [ -n "$existing" ]; then
|
||||
printf '%s' "$existing"
|
||||
return
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# Check if we need sudo for installation
|
||||
needs_sudo() {
|
||||
local install_dir="$1"
|
||||
install_dir="$1"
|
||||
if [ -w "$install_dir" ] 2>/dev/null; then
|
||||
return 1 # No sudo needed
|
||||
else
|
||||
@@ -130,7 +171,7 @@ needs_sudo() {
|
||||
|
||||
# Get the appropriate command prefix (sudo or empty)
|
||||
get_sudo_cmd() {
|
||||
local install_dir="$1"
|
||||
install_dir="$1"
|
||||
if needs_sudo "$install_dir"; then
|
||||
if command -v sudo >/dev/null 2>&1; then
|
||||
echo "sudo"
|
||||
@@ -146,40 +187,46 @@ get_sudo_cmd() {
|
||||
|
||||
# Download and install newt
|
||||
install_newt() {
|
||||
local platform="$1"
|
||||
local install_dir="$2"
|
||||
local sudo_cmd="$3"
|
||||
local binary_name="newt_${platform}"
|
||||
local exe_suffix=""
|
||||
platform="$1"
|
||||
install_dir="$2"
|
||||
sudo_cmd="$3"
|
||||
custom_path="$4"
|
||||
binary_name="newt_${platform}"
|
||||
final_name="newt"
|
||||
|
||||
# Add .exe suffix for Windows
|
||||
case "$platform" in
|
||||
*windows*)
|
||||
binary_name="${binary_name}.exe"
|
||||
exe_suffix=".exe"
|
||||
final_name="newt.exe"
|
||||
;;
|
||||
esac
|
||||
|
||||
local download_url="${BASE_URL}/${binary_name}"
|
||||
local temp_file="/tmp/newt${exe_suffix}"
|
||||
local final_path="${install_dir}/newt${exe_suffix}"
|
||||
download_url="${BASE_URL}/${binary_name}"
|
||||
temp_file="/tmp/${final_name}"
|
||||
|
||||
# If a custom path is provided, use it directly; otherwise use install_dir/final_name
|
||||
if [ -n "$custom_path" ]; then
|
||||
final_path="$custom_path"
|
||||
install_dir=$(dirname "$final_path")
|
||||
else
|
||||
final_path="${install_dir}/${final_name}"
|
||||
fi
|
||||
|
||||
print_status "Downloading newt from ${download_url}"
|
||||
|
||||
# Download the binary
|
||||
if command -v curl >/dev/null 2>&1; then
|
||||
curl -fsSL "$download_url" -o "$temp_file"
|
||||
elif command -v wget >/dev/null 2>&1; then
|
||||
wget -q "$download_url" -O "$temp_file"
|
||||
else
|
||||
print_error "Neither curl nor wget is available. Please install one of them."
|
||||
print_error "Neither curl nor wget is available."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Make executable before moving
|
||||
chmod +x "$temp_file"
|
||||
|
||||
# Create install directory if it doesn't exist
|
||||
# Create install directory if it doesn't exist and move binary
|
||||
if [ -n "$sudo_cmd" ]; then
|
||||
$sudo_cmd mkdir -p "$install_dir"
|
||||
print_status "Using sudo to install to ${install_dir}"
|
||||
@@ -194,25 +241,25 @@ install_newt() {
|
||||
# Check if install directory is in PATH
|
||||
if ! echo "$PATH" | grep -q "$install_dir"; then
|
||||
print_warning "Install directory ${install_dir} is not in your PATH."
|
||||
print_warning "Add it to your PATH by adding this line to your shell profile:"
|
||||
print_warning "Add it with:"
|
||||
print_warning " export PATH=\"${install_dir}:\$PATH\""
|
||||
fi
|
||||
}
|
||||
|
||||
# Verify installation
|
||||
verify_installation() {
|
||||
local install_dir="$1"
|
||||
local exe_suffix=""
|
||||
|
||||
install_dir="$1"
|
||||
exe_suffix=""
|
||||
|
||||
case "$PLATFORM" in
|
||||
*windows*) exe_suffix=".exe" ;;
|
||||
esac
|
||||
|
||||
local newt_path="${install_dir}/newt${exe_suffix}"
|
||||
|
||||
if [ -f "$newt_path" ] && [ -x "$newt_path" ]; then
|
||||
|
||||
newt_path="${install_dir}/newt${exe_suffix}"
|
||||
|
||||
if [ -x "$newt_path" ]; then
|
||||
print_status "Installation successful!"
|
||||
print_status "newt version: $("$newt_path" --version 2>/dev/null || echo "unknown")"
|
||||
print_status "newt version: $("$newt_path" --version 2>/dev/null || printf 'unknown')"
|
||||
return 0
|
||||
else
|
||||
print_error "Installation failed. Binary not found or not executable."
|
||||
@@ -222,22 +269,40 @@ verify_installation() {
|
||||
|
||||
# Main installation process
|
||||
main() {
|
||||
print_status "Installing latest version of newt..."
|
||||
# --path explicitly overrides everything
|
||||
CUSTOM_PATH=$(parse_path_arg "$@")
|
||||
|
||||
# Get latest version
|
||||
print_status "Fetching latest version from GitHub..."
|
||||
if [ -n "$CUSTOM_PATH" ]; then
|
||||
print_status "Installing latest version of newt to ${CUSTOM_PATH} (--path override)..."
|
||||
else
|
||||
print_status "Installing latest version of newt..."
|
||||
fi
|
||||
|
||||
print_status "Fetching latest version..."
|
||||
VERSION=$(get_latest_version)
|
||||
print_status "Latest version: v${VERSION}"
|
||||
|
||||
# Set base URL with the fetched version
|
||||
BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}"
|
||||
|
||||
# Detect platform
|
||||
PLATFORM=$(detect_platform)
|
||||
print_status "Detected platform: ${PLATFORM}"
|
||||
|
||||
# Get install directory
|
||||
INSTALL_DIR=$(get_install_dir)
|
||||
if [ -n "$CUSTOM_PATH" ]; then
|
||||
# --path wins; derive INSTALL_DIR from it
|
||||
INSTALL_DIR=$(dirname "$CUSTOM_PATH")
|
||||
else
|
||||
# Try to find an existing installation so we update the right place
|
||||
EXISTING_BINARY=$(detect_existing_binary)
|
||||
if [ -n "$EXISTING_BINARY" ]; then
|
||||
print_status "Found existing newt binary at ${EXISTING_BINARY}"
|
||||
CUSTOM_PATH="$EXISTING_BINARY"
|
||||
INSTALL_DIR=$(dirname "$EXISTING_BINARY")
|
||||
print_status "Will update existing installation at ${INSTALL_DIR}"
|
||||
else
|
||||
INSTALL_DIR=$(get_install_dir)
|
||||
fi
|
||||
fi
|
||||
|
||||
print_status "Install directory: ${INSTALL_DIR}"
|
||||
|
||||
# Check if we need sudo
|
||||
@@ -246,13 +311,20 @@ main() {
|
||||
print_status "Root privileges required for installation to ${INSTALL_DIR}"
|
||||
fi
|
||||
|
||||
# Install newt
|
||||
install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD"
|
||||
install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD" "$CUSTOM_PATH"
|
||||
|
||||
# Verify installation
|
||||
if verify_installation "$INSTALL_DIR"; then
|
||||
if [ -n "$CUSTOM_PATH" ]; then
|
||||
if [ -x "$CUSTOM_PATH" ]; then
|
||||
print_status "Installation successful!"
|
||||
print_status "newt version: $("$CUSTOM_PATH" --version 2>/dev/null || printf 'unknown')"
|
||||
print_status "newt is ready to use!"
|
||||
else
|
||||
print_error "Installation failed. Binary not found or not executable at ${CUSTOM_PATH}."
|
||||
exit 1
|
||||
fi
|
||||
elif verify_installation "$INSTALL_DIR"; then
|
||||
print_status "newt is ready to use!"
|
||||
print_status "Run 'newt --help' to get started"
|
||||
print_status "Run 'newt --help' to get started."
|
||||
else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -37,33 +37,38 @@ func (s Health) String() string {
|
||||
|
||||
// Config holds the health check configuration for a target
|
||||
type Config struct {
|
||||
ID int `json:"id"`
|
||||
Enabled bool `json:"hcEnabled"`
|
||||
Path string `json:"hcPath"`
|
||||
Scheme string `json:"hcScheme"`
|
||||
Mode string `json:"hcMode"`
|
||||
Hostname string `json:"hcHostname"`
|
||||
Port int `json:"hcPort"`
|
||||
Interval int `json:"hcInterval"` // in seconds
|
||||
UnhealthyInterval int `json:"hcUnhealthyInterval"` // in seconds
|
||||
Timeout int `json:"hcTimeout"` // in seconds
|
||||
Headers map[string]string `json:"hcHeaders"`
|
||||
Method string `json:"hcMethod"`
|
||||
Status int `json:"hcStatus"` // HTTP status code
|
||||
TLSServerName string `json:"hcTlsServerName"`
|
||||
ID int `json:"id"`
|
||||
Enabled bool `json:"hcEnabled"`
|
||||
Path string `json:"hcPath"`
|
||||
Scheme string `json:"hcScheme"`
|
||||
Mode string `json:"hcMode"`
|
||||
Hostname string `json:"hcHostname"`
|
||||
Port int `json:"hcPort"`
|
||||
Interval int `json:"hcInterval"` // in seconds
|
||||
UnhealthyInterval int `json:"hcUnhealthyInterval"` // in seconds
|
||||
Timeout int `json:"hcTimeout"` // in seconds
|
||||
FollowRedirects *bool `json:"hcFollowRedirects"`
|
||||
Headers map[string]string `json:"hcHeaders"`
|
||||
Method string `json:"hcMethod"`
|
||||
Status int `json:"hcStatus"` // HTTP status code
|
||||
TLSServerName string `json:"hcTlsServerName"`
|
||||
HealthyThreshold int `json:"hcHealthyThreshold"` // consecutive successes required to become healthy
|
||||
UnhealthyThreshold int `json:"hcUnhealthyThreshold"` // consecutive failures required to become unhealthy
|
||||
}
|
||||
|
||||
// Target represents a health check target with its current status
|
||||
type Target struct {
|
||||
Config Config `json:"config"`
|
||||
Status Health `json:"status"`
|
||||
LastCheck time.Time `json:"lastCheck"`
|
||||
LastError string `json:"lastError,omitempty"`
|
||||
CheckCount int `json:"checkCount"`
|
||||
timer *time.Timer
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
client *http.Client
|
||||
Config Config `json:"config"`
|
||||
Status Health `json:"status"`
|
||||
LastCheck time.Time `json:"lastCheck"`
|
||||
LastError string `json:"lastError,omitempty"`
|
||||
CheckCount int `json:"checkCount"`
|
||||
timer *time.Timer
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
client *http.Client
|
||||
consecutiveSuccesses int
|
||||
consecutiveFailures int
|
||||
}
|
||||
|
||||
// StatusChangeCallback is called when any target's status changes
|
||||
@@ -165,9 +170,16 @@ func (m *Monitor) addTargetUnsafe(config Config) error {
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = 5
|
||||
}
|
||||
if config.HealthyThreshold == 0 {
|
||||
config.HealthyThreshold = 1
|
||||
}
|
||||
if config.UnhealthyThreshold == 0 {
|
||||
config.UnhealthyThreshold = 1
|
||||
}
|
||||
|
||||
logger.Debug("Target %d configuration: scheme=%s, method=%s, interval=%ds, timeout=%ds",
|
||||
config.ID, config.Scheme, config.Method, config.Interval, config.Timeout)
|
||||
logger.Debug("Target %d configuration: mode=%s, scheme=%s, method=%s, interval=%ds, timeout=%ds, healthyThreshold=%d, unhealthyThreshold=%d",
|
||||
config.ID, config.Mode, config.Scheme, config.Method, config.Interval, config.Timeout,
|
||||
config.HealthyThreshold, config.UnhealthyThreshold)
|
||||
|
||||
// Parse headers if provided as string
|
||||
if len(config.Headers) == 0 && config.Path != "" {
|
||||
@@ -189,6 +201,16 @@ func (m *Monitor) addTargetUnsafe(config Config) error {
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
client: &http.Client{
|
||||
CheckRedirect: func() func(*http.Request, []*http.Request) error {
|
||||
// Default to following redirects if not explicitly configured
|
||||
followRedirects := config.FollowRedirects == nil || *config.FollowRedirects
|
||||
if !followRedirects {
|
||||
return func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
// Configure TLS settings based on certificate enforcement
|
||||
@@ -230,7 +252,7 @@ func (m *Monitor) RemoveTarget(id int) error {
|
||||
|
||||
// Notify callback of status change
|
||||
if m.callback != nil {
|
||||
go m.callback(m.GetTargets())
|
||||
go m.callback(m.getAllTargetsUnsafe())
|
||||
}
|
||||
|
||||
logger.Info("Successfully removed target %d", id)
|
||||
@@ -263,7 +285,7 @@ func (m *Monitor) RemoveTargets(ids []int) error {
|
||||
|
||||
// Notify callback of status change if any targets were removed
|
||||
if len(notFound) != len(ids) && m.callback != nil {
|
||||
go m.callback(m.GetTargets())
|
||||
go m.callback(m.getAllTargetsUnsafe())
|
||||
}
|
||||
|
||||
if len(notFound) > 0 {
|
||||
@@ -361,12 +383,69 @@ func (m *Monitor) monitorTarget(target *Target) {
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthCheck performs a health check on a target
|
||||
// performHealthCheck performs a health check on a target and applies threshold logic
|
||||
func (m *Monitor) performHealthCheck(target *Target) {
|
||||
target.CheckCount++
|
||||
target.LastCheck = time.Now()
|
||||
target.LastError = ""
|
||||
|
||||
var passed bool
|
||||
var checkErr string
|
||||
|
||||
switch strings.ToLower(target.Config.Mode) {
|
||||
case "tcp":
|
||||
passed, checkErr = m.performTCPCheck(target)
|
||||
default:
|
||||
// "http", "https", or anything else falls through to HTTP
|
||||
passed, checkErr = m.performHTTPCheck(target)
|
||||
}
|
||||
|
||||
if passed {
|
||||
target.consecutiveFailures = 0
|
||||
target.consecutiveSuccesses++
|
||||
|
||||
logger.Debug("Target %d: check passed (consecutive successes: %d / threshold: %d)",
|
||||
target.Config.ID, target.consecutiveSuccesses, target.Config.HealthyThreshold)
|
||||
|
||||
if target.consecutiveSuccesses >= target.Config.HealthyThreshold {
|
||||
target.Status = StatusHealthy
|
||||
target.LastError = ""
|
||||
}
|
||||
} else {
|
||||
target.consecutiveSuccesses = 0
|
||||
target.consecutiveFailures++
|
||||
target.LastError = checkErr
|
||||
|
||||
logger.Debug("Target %d: check failed (consecutive failures: %d / threshold: %d): %s",
|
||||
target.Config.ID, target.consecutiveFailures, target.Config.UnhealthyThreshold, checkErr)
|
||||
|
||||
if target.consecutiveFailures >= target.Config.UnhealthyThreshold {
|
||||
target.Status = StatusUnhealthy
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performTCPCheck dials the target's host:port over TCP and returns whether it succeeded
|
||||
func (m *Monitor) performTCPCheck(target *Target) (bool, string) {
|
||||
address := net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port))
|
||||
timeout := time.Duration(target.Config.Timeout) * time.Second
|
||||
|
||||
logger.Debug("Target %d: performing TCP health check to %s (timeout: %v)",
|
||||
target.Config.ID, address, timeout)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", address, timeout)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("TCP dial failed: %v", err)
|
||||
logger.Warn("Target %d: %s", target.Config.ID, msg)
|
||||
return false, msg
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
logger.Debug("Target %d: TCP health check passed", target.Config.ID)
|
||||
return true, ""
|
||||
}
|
||||
|
||||
// performHTTPCheck performs an HTTP/HTTPS health check and returns whether it succeeded
|
||||
func (m *Monitor) performHTTPCheck(target *Target) (bool, string) {
|
||||
// Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports)
|
||||
host := target.Config.Hostname
|
||||
if target.Config.Port > 0 {
|
||||
@@ -380,7 +459,7 @@ func (m *Monitor) performHealthCheck(target *Target) {
|
||||
url += target.Config.Path
|
||||
}
|
||||
|
||||
logger.Debug("Target %d: performing health check %d to %s",
|
||||
logger.Debug("Target %d: performing HTTP health check %d to %s",
|
||||
target.Config.ID, target.CheckCount, url)
|
||||
|
||||
if target.Config.Scheme == "https" {
|
||||
@@ -388,16 +467,15 @@ func (m *Monitor) performHealthCheck(target *Target) {
|
||||
target.Config.ID, m.enforceCert)
|
||||
}
|
||||
|
||||
// Create request
|
||||
// Create request with timeout context
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(target.Config.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, target.Config.Method, url, nil)
|
||||
if err != nil {
|
||||
target.Status = StatusUnhealthy
|
||||
target.LastError = fmt.Sprintf("failed to create request: %v", err)
|
||||
logger.Warn("Target %d: failed to create request: %v", target.Config.ID, err)
|
||||
return
|
||||
msg := fmt.Sprintf("failed to create request: %v", err)
|
||||
logger.Warn("Target %d: %s", target.Config.ID, msg)
|
||||
return false, msg
|
||||
}
|
||||
|
||||
// Add headers
|
||||
@@ -413,43 +491,34 @@ func (m *Monitor) performHealthCheck(target *Target) {
|
||||
// Perform request
|
||||
resp, err := target.client.Do(req)
|
||||
if err != nil {
|
||||
target.Status = StatusUnhealthy
|
||||
target.LastError = fmt.Sprintf("request failed: %v", err)
|
||||
msg := fmt.Sprintf("request failed: %v", err)
|
||||
logger.Warn("Target %d: health check failed: %v", target.Config.ID, err)
|
||||
return
|
||||
return false, msg
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check response status
|
||||
var expectedStatus int
|
||||
if target.Config.Status > 0 {
|
||||
expectedStatus = target.Config.Status
|
||||
} else {
|
||||
expectedStatus = 0 // Use range check for 200-299
|
||||
// Check for specific status code
|
||||
logger.Debug("Target %d: checking status against expected code %d", target.Config.ID, target.Config.Status)
|
||||
if resp.StatusCode == target.Config.Status {
|
||||
logger.Debug("Target %d: health check passed (status: %d)", target.Config.ID, resp.StatusCode)
|
||||
return true, ""
|
||||
}
|
||||
msg := fmt.Sprintf("unexpected status code: %d (expected: %d)", resp.StatusCode, target.Config.Status)
|
||||
logger.Warn("Target %d: %s", target.Config.ID, msg)
|
||||
return false, msg
|
||||
}
|
||||
|
||||
if expectedStatus > 0 {
|
||||
logger.Debug("Target %d: checking health status against expected code %d", target.Config.ID, expectedStatus)
|
||||
// Check for specific status code
|
||||
if resp.StatusCode == expectedStatus {
|
||||
target.Status = StatusHealthy
|
||||
logger.Debug("Target %d: health check passed (status: %d, expected: %d)", target.Config.ID, resp.StatusCode, expectedStatus)
|
||||
} else {
|
||||
target.Status = StatusUnhealthy
|
||||
target.LastError = fmt.Sprintf("unexpected status code: %d (expected: %d)", resp.StatusCode, expectedStatus)
|
||||
logger.Warn("Target %d: health check failed with status code %d (expected: %d)", target.Config.ID, resp.StatusCode, expectedStatus)
|
||||
}
|
||||
} else {
|
||||
// Check for 2xx range
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
target.Status = StatusHealthy
|
||||
logger.Debug("Target %d: health check passed (status: %d)", target.Config.ID, resp.StatusCode)
|
||||
} else {
|
||||
target.Status = StatusUnhealthy
|
||||
target.LastError = fmt.Sprintf("unhealthy status code: %d", resp.StatusCode)
|
||||
logger.Warn("Target %d: health check failed with status code %d", target.Config.ID, resp.StatusCode)
|
||||
}
|
||||
// Default: check for 2xx range
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
logger.Debug("Target %d: health check passed (status: %d)", target.Config.ID, resp.StatusCode)
|
||||
return true, ""
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("unhealthy status code: %d", resp.StatusCode)
|
||||
logger.Warn("Target %d: health check failed with status code %d", target.Config.ID, resp.StatusCode)
|
||||
return false, msg
|
||||
}
|
||||
|
||||
// Stop stops monitoring all targets
|
||||
@@ -516,7 +585,7 @@ func (m *Monitor) DisableTarget(id int) error {
|
||||
|
||||
// Notify callback of status change
|
||||
if m.callback != nil {
|
||||
go m.callback(m.GetTargets())
|
||||
go m.callback(m.getAllTargetsUnsafe())
|
||||
}
|
||||
} else {
|
||||
logger.Debug("Target %d is already disabled", id)
|
||||
|
||||
18
main.go
18
main.go
@@ -129,6 +129,7 @@ var (
|
||||
dockerEnforceNetworkValidationBool bool
|
||||
pingInterval time.Duration
|
||||
pingTimeout time.Duration
|
||||
udpProxyIdleTimeout time.Duration
|
||||
publicKey wgtypes.Key
|
||||
pingStopChan chan struct{}
|
||||
stopFunc func()
|
||||
@@ -261,6 +262,7 @@ func runNewtMain(ctx context.Context) {
|
||||
dockerSocket = os.Getenv("DOCKER_SOCKET")
|
||||
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
||||
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
|
||||
udpProxyIdleTimeoutStr := os.Getenv("NEWT_UDP_PROXY_IDLE_TIMEOUT")
|
||||
dockerEnforceNetworkValidation = os.Getenv("DOCKER_ENFORCE_NETWORK_VALIDATION")
|
||||
healthFile = os.Getenv("HEALTH_FILE")
|
||||
// authorizedKeysFile = os.Getenv("AUTHORIZED_KEYS_FILE")
|
||||
@@ -337,6 +339,9 @@ func runNewtMain(ctx context.Context) {
|
||||
if pingTimeoutStr == "" {
|
||||
flag.StringVar(&pingTimeoutStr, "ping-timeout", "7s", " Timeout for each ping (default 7s)")
|
||||
}
|
||||
if udpProxyIdleTimeoutStr == "" {
|
||||
flag.StringVar(&udpProxyIdleTimeoutStr, "udp-proxy-idle-timeout", "90s", "Idle timeout for UDP proxied client flows before cleanup")
|
||||
}
|
||||
// load the prefer endpoint just as a flag
|
||||
flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")
|
||||
if provisioningKey == "" {
|
||||
@@ -386,6 +391,16 @@ func runNewtMain(ctx context.Context) {
|
||||
pingTimeout = 7 * time.Second
|
||||
}
|
||||
|
||||
if udpProxyIdleTimeoutStr != "" {
|
||||
udpProxyIdleTimeout, err = time.ParseDuration(udpProxyIdleTimeoutStr)
|
||||
if err != nil || udpProxyIdleTimeout <= 0 {
|
||||
fmt.Printf("Invalid NEWT_UDP_PROXY_IDLE_TIMEOUT/--udp-proxy-idle-timeout value: %s, using default 90 seconds\n", udpProxyIdleTimeoutStr)
|
||||
udpProxyIdleTimeout = 90 * time.Second
|
||||
}
|
||||
} else {
|
||||
udpProxyIdleTimeout = 90 * time.Second
|
||||
}
|
||||
|
||||
if dockerEnforceNetworkValidation == "" {
|
||||
flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)")
|
||||
}
|
||||
@@ -527,7 +542,7 @@ func runNewtMain(ctx context.Context) {
|
||||
if telErr != nil {
|
||||
logger.Warn("Telemetry init failed: %v", telErr)
|
||||
}
|
||||
if tel != nil {
|
||||
if tel != nil && (metricsEnabled || pprofEnabled) {
|
||||
// Admin HTTP server (exposes /metrics when Prometheus exporter is enabled)
|
||||
logger.Debug("Starting metrics server on %s", tcfg.AdminAddr)
|
||||
mux := http.NewServeMux()
|
||||
@@ -896,6 +911,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
// Create proxy manager
|
||||
pm = proxy.NewProxyManager(tnet)
|
||||
pm.SetAsyncBytes(metricsAsyncBytes)
|
||||
pm.SetUDPIdleTimeout(udpProxyIdleTimeout)
|
||||
// Set tunnel_id for metrics (WireGuard peer public key)
|
||||
pm.SetTunnelID(wgData.PublicKey)
|
||||
|
||||
|
||||
@@ -137,14 +137,33 @@ func (h *TCPHandler) InstallTCPHandler() error {
|
||||
|
||||
// handleTCPConn handles a TCP connection by proxying it to the actual target
|
||||
func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.TransportEndpointID) {
|
||||
defer netstackConn.Close()
|
||||
|
||||
// Extract source and target address from the connection ID
|
||||
// Extract source and target address from the connection ID first so they
|
||||
// are available for HTTP routing before any defer is set up.
|
||||
srcIP := id.RemoteAddress.String()
|
||||
srcPort := id.RemotePort
|
||||
dstIP := id.LocalAddress.String()
|
||||
dstPort := id.LocalPort
|
||||
|
||||
// For HTTP/HTTPS ports, look up the matching subnet rule. If the rule has
|
||||
// Protocol configured, hand the connection off to the HTTP handler which
|
||||
// takes full ownership of the lifecycle (the defer close must not be
|
||||
// installed before this point).
|
||||
if (dstPort == 80 || dstPort == 443) && h.proxyHandler != nil && h.proxyHandler.httpHandler != nil {
|
||||
srcAddr, _ := netip.ParseAddr(srcIP)
|
||||
dstAddr, _ := netip.ParseAddr(dstIP)
|
||||
rule := h.proxyHandler.subnetLookup.Match(srcAddr, dstAddr, dstPort, tcp.ProtocolNumber)
|
||||
if rule != nil && rule.Protocol != "" && len(rule.HTTPTargets) > 0 {
|
||||
logger.Info("TCP Forwarder: Routing %s:%d -> %s:%d to HTTP handler (%s)",
|
||||
srcIP, srcPort, dstIP, dstPort, rule.Protocol)
|
||||
h.proxyHandler.httpHandler.HandleConn(netstackConn, rule)
|
||||
return
|
||||
}
|
||||
// Otherwise fall through to raw TCP forwarding (e.g. CIDR resources
|
||||
// that happen to use port 80/443 without HTTP configuration).
|
||||
}
|
||||
|
||||
defer netstackConn.Close()
|
||||
|
||||
logger.Info("TCP Forwarder: Handling connection %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
// Check if there's a destination rewrite for this connection (e.g., localhost targets)
|
||||
|
||||
396
netstack2/http_handler.go
Normal file
396
netstack2/http_handler.go
Normal file
@@ -0,0 +1,396 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package netstack2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HTTPTarget
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// HTTPTarget describes a single downstream HTTP or HTTPS service that the
|
||||
// proxy should forward requests to.
|
||||
type HTTPTarget struct {
|
||||
DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service
|
||||
DestPort uint16 `json:"destPort"` // TCP port of the downstream service
|
||||
Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HTTPHandler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// HTTPHandler intercepts TCP connections from the netstack forwarder on ports
|
||||
// 80 and 443 and services them as HTTP or HTTPS, reverse-proxying each request
|
||||
// to downstream targets specified by the matching SubnetRule.
|
||||
//
|
||||
// HTTP and raw TCP are fully separate: a connection is only routed here when
|
||||
// its SubnetRule has Protocol set ("http" or "https"). All other connections
|
||||
// on those ports fall through to the normal raw-TCP path.
|
||||
//
|
||||
// Incoming TLS termination (Protocol == "https") is performed per-connection
|
||||
// using the certificate and key stored in the rule, so different subnet rules
|
||||
// can present different certificates without sharing any state.
|
||||
//
|
||||
// Outbound connections to downstream targets honour HTTPTarget.UseHTTPS
|
||||
// independently of the incoming protocol.
|
||||
type HTTPHandler struct {
|
||||
stack *stack.Stack
|
||||
proxyHandler *ProxyHandler
|
||||
requestLogger *HTTPRequestLogger
|
||||
|
||||
listener *chanListener
|
||||
server *http.Server
|
||||
|
||||
// proxyCache holds pre-built *httputil.ReverseProxy values keyed by the
|
||||
// canonical target URL string ("scheme://host:port"). Building a proxy is
|
||||
// cheap, but reusing one preserves the underlying http.Transport connection
|
||||
// pool, which matters for throughput.
|
||||
proxyCache sync.Map // map[string]*httputil.ReverseProxy
|
||||
|
||||
// tlsCache holds pre-parsed *tls.Config values keyed by the concatenation
|
||||
// of the PEM certificate and key. Parsing a keypair is relatively expensive
|
||||
// and the same cert is likely reused across many connections.
|
||||
tlsCache sync.Map // map[string]*tls.Config
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// chanListener – net.Listener backed by a channel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// chanListener implements net.Listener by receiving net.Conn values over a
|
||||
// buffered channel. This lets the netstack TCP forwarder hand off connections
|
||||
// directly to a running http.Server without any real OS socket.
|
||||
type chanListener struct {
|
||||
connCh chan net.Conn
|
||||
closed chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newChanListener() *chanListener {
|
||||
return &chanListener{
|
||||
connCh: make(chan net.Conn, 128),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Accept blocks until a connection is available or the listener is closed.
|
||||
func (l *chanListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case conn, ok := <-l.connCh:
|
||||
if !ok {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return conn, nil
|
||||
case <-l.closed:
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the listener; subsequent Accept calls return net.ErrClosed.
|
||||
func (l *chanListener) Close() error {
|
||||
l.once.Do(func() { close(l.closed) })
|
||||
return nil
|
||||
}
|
||||
|
||||
// Addr returns a placeholder address (the listener has no real OS socket).
|
||||
func (l *chanListener) Addr() net.Addr {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
|
||||
// send delivers conn to the listener. Returns false if the listener is already
|
||||
// closed, in which case the caller is responsible for closing conn.
|
||||
func (l *chanListener) send(conn net.Conn) bool {
|
||||
select {
|
||||
case l.connCh <- conn:
|
||||
return true
|
||||
case <-l.closed:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// httpConnCtx – conn wrapper that carries a SubnetRule through the listener
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// httpConnCtx wraps a net.Conn so the matching SubnetRule can be passed
|
||||
// through the chanListener into the http.Server's ConnContext callback,
|
||||
// making it available to request handlers via the request context.
|
||||
type httpConnCtx struct {
|
||||
net.Conn
|
||||
rule *SubnetRule
|
||||
}
|
||||
|
||||
// connCtxKey is the unexported context key used to store a *SubnetRule on the
|
||||
// per-connection context created by http.Server.ConnContext.
|
||||
type connCtxKey struct{}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Constructor and lifecycle
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// NewHTTPHandler creates an HTTPHandler attached to the given stack and
|
||||
// ProxyHandler. Call Start to begin serving connections.
|
||||
func NewHTTPHandler(s *stack.Stack, ph *ProxyHandler) *HTTPHandler {
|
||||
return &HTTPHandler{
|
||||
stack: s,
|
||||
proxyHandler: ph,
|
||||
}
|
||||
}
|
||||
|
||||
// SetRequestLogger attaches an HTTPRequestLogger so that every proxied request
|
||||
// is recorded and periodically shipped to the server.
|
||||
func (h *HTTPHandler) SetRequestLogger(rl *HTTPRequestLogger) {
|
||||
h.requestLogger = rl
|
||||
}
|
||||
|
||||
// Start launches the internal http.Server that services connections delivered
|
||||
// via HandleConn. The server runs for the lifetime of the HTTPHandler; call
|
||||
// Close to stop it.
|
||||
func (h *HTTPHandler) Start() error {
|
||||
h.listener = newChanListener()
|
||||
|
||||
h.server = &http.Server{
|
||||
Handler: http.HandlerFunc(h.handleRequest),
|
||||
// ConnContext runs once per accepted connection and attaches the
|
||||
// SubnetRule carried by httpConnCtx to the connection's context so
|
||||
// that handleRequest can retrieve it without any global state.
|
||||
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
|
||||
if cc, ok := c.(*httpConnCtx); ok {
|
||||
return context.WithValue(ctx, connCtxKey{}, cc.rule)
|
||||
}
|
||||
return ctx
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := h.server.Serve(h.listener); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("HTTP handler: server exited unexpectedly: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Debug("HTTP handler: ready — routing determined per SubnetRule on ports 80/443")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleConn accepts a TCP connection from the netstack forwarder together
|
||||
// with the SubnetRule that matched it. The HTTP handler takes full ownership
|
||||
// of the connection's lifecycle; the caller must NOT close conn after this call.
|
||||
//
|
||||
// When rule.Protocol is "https", TLS termination is performed on conn using
|
||||
// the certificate and key stored in rule.TLSCert and rule.TLSKey before the
|
||||
// connection is passed to the HTTP server. The HTTP server itself is always
|
||||
// plain-HTTP; TLS is fully unwrapped at this layer.
|
||||
func (h *HTTPHandler) HandleConn(conn net.Conn, rule *SubnetRule) {
|
||||
var effectiveConn net.Conn = conn
|
||||
|
||||
if rule.Protocol == "https" {
|
||||
tlsCfg, err := h.getTLSConfig(rule)
|
||||
if err != nil {
|
||||
logger.Error("HTTP handler: cannot build TLS config for connection from %s: %v",
|
||||
conn.RemoteAddr(), err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
// tls.Server wraps the raw conn; the TLS handshake is deferred until
|
||||
// the first Read, which the http.Server will trigger naturally.
|
||||
effectiveConn = tls.Server(conn, tlsCfg)
|
||||
}
|
||||
|
||||
wrapped := &httpConnCtx{Conn: effectiveConn, rule: rule}
|
||||
if !h.listener.send(wrapped) {
|
||||
// Listener is already closed — clean up the orphaned connection.
|
||||
effectiveConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the HTTP server and the underlying channel
|
||||
// listener, causing the goroutine started in Start to exit.
|
||||
func (h *HTTPHandler) Close() error {
|
||||
if h.server != nil {
|
||||
if err := h.server.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if h.listener != nil {
|
||||
h.listener.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// getTLSConfig returns a *tls.Config for the cert/key pair in rule, using a
|
||||
// cache to avoid re-parsing the same keypair on every connection.
|
||||
// The cache key is the concatenation of the PEM cert and key strings, so
|
||||
// different rules that happen to share the same material hit the same entry.
|
||||
func (h *HTTPHandler) getTLSConfig(rule *SubnetRule) (*tls.Config, error) {
|
||||
cacheKey := rule.TLSCert + "|" + rule.TLSKey
|
||||
if v, ok := h.tlsCache.Load(cacheKey); ok {
|
||||
return v.(*tls.Config), nil
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair([]byte(rule.TLSCert), []byte(rule.TLSKey))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse TLS keypair: %w", err)
|
||||
}
|
||||
cfg := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
// LoadOrStore is safe under concurrent calls: if two goroutines race here
|
||||
// both will produce a valid config; the loser's work is discarded.
|
||||
actual, _ := h.tlsCache.LoadOrStore(cacheKey, cfg)
|
||||
return actual.(*tls.Config), nil
|
||||
}
|
||||
|
||||
// getProxy returns a cached *httputil.ReverseProxy for the given target,
|
||||
// creating one on first use. Reusing the proxy preserves its http.Transport
|
||||
// connection pool, avoiding repeated TCP/TLS handshakes to the downstream.
|
||||
func (h *HTTPHandler) getProxy(target HTTPTarget) *httputil.ReverseProxy {
|
||||
scheme := target.Scheme
|
||||
cacheKey := fmt.Sprintf("%s://%s:%d", scheme, target.DestAddr, target.DestPort)
|
||||
|
||||
if v, ok := h.proxyCache.Load(cacheKey); ok {
|
||||
return v.(*httputil.ReverseProxy)
|
||||
}
|
||||
|
||||
targetURL := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: fmt.Sprintf("%s:%d", target.DestAddr, target.DestPort),
|
||||
}
|
||||
var transport http.RoundTripper = http.DefaultTransport
|
||||
if target.Scheme == "https" {
|
||||
// Allow self-signed certificates on downstream HTTPS targets.
|
||||
transport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true, //nolint:gosec // downstream self-signed certs are a supported configuration
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Rewrite: func(pr *httputil.ProxyRequest) {
|
||||
pr.SetURL(targetURL)
|
||||
// SetXForwarded sets X-Forwarded-For from the inbound request's
|
||||
// RemoteAddr (the WireGuard/netstack client address), along with
|
||||
// X-Forwarded-Host and X-Forwarded-Proto. Using Rewrite instead of
|
||||
// Director means the proxy does not append its own automatic
|
||||
// X-Forwarded-For entry, so the header is set exactly once.
|
||||
pr.SetXForwarded()
|
||||
},
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
logger.Error("HTTP handler: upstream error (%s %s -> %s): %v",
|
||||
r.Method, r.URL.RequestURI(), cacheKey, err)
|
||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||
}
|
||||
|
||||
actual, _ := h.proxyCache.LoadOrStore(cacheKey, proxy)
|
||||
return actual.(*httputil.ReverseProxy)
|
||||
}
|
||||
|
||||
// statusCapture wraps an http.ResponseWriter and records the HTTP status code
|
||||
// written by the upstream handler. If WriteHeader is never called the status
|
||||
// defaults to 200 (http.StatusOK), matching net/http semantics.
|
||||
type statusCapture struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (sc *statusCapture) WriteHeader(code int) {
|
||||
sc.status = code
|
||||
sc.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (sc *statusCapture) Unwrap() http.ResponseWriter {
|
||||
return sc.ResponseWriter
|
||||
}
|
||||
|
||||
func (sc *statusCapture) Flush() {
|
||||
if flusher, ok := sc.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *statusCapture) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := sc.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("underlying response writer does not support hijacking")
|
||||
}
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
|
||||
// handleRequest is the http.Handler entry point. It retrieves the SubnetRule
|
||||
// attached to the connection by ConnContext, selects the first configured
|
||||
// downstream target, and forwards the request via the cached ReverseProxy.
|
||||
//
|
||||
// TODO: add host/path-based routing across multiple HTTPTargets once the
|
||||
// configuration model evolves beyond a single target per rule.
|
||||
func (h *HTTPHandler) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
rule, _ := r.Context().Value(connCtxKey{}).(*SubnetRule)
|
||||
if rule == nil || len(rule.HTTPTargets) == 0 {
|
||||
logger.Error("HTTP handler: no downstream targets for request %s %s", r.Method, r.URL.RequestURI())
|
||||
http.Error(w, "no targets configured", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// If the rule is plain HTTP but has a TLS certificate configured, redirect
|
||||
// the client to the HTTPS equivalent of the requested URL.
|
||||
if rule.Protocol == "http" && rule.TLSCert != "" && rule.TLSKey != "" {
|
||||
host := r.Host
|
||||
if host == "" {
|
||||
host = r.URL.Host
|
||||
}
|
||||
httpsURL := "https://" + host + r.RequestURI
|
||||
logger.Info("HTTP handler: redirecting %s %s -> %s (TLS cert present)", r.Method, r.URL.RequestURI(), httpsURL)
|
||||
http.Redirect(w, r, httpsURL, http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
|
||||
target := rule.HTTPTargets[0]
|
||||
scheme := target.Scheme
|
||||
logger.Info("HTTP handler: %s %s -> %s://%s:%d",
|
||||
r.Method, r.URL.RequestURI(), scheme, target.DestAddr, target.DestPort)
|
||||
|
||||
timestamp := time.Now()
|
||||
sc := &statusCapture{ResponseWriter: w, status: http.StatusOK}
|
||||
|
||||
h.getProxy(target).ServeHTTP(sc, r)
|
||||
|
||||
if h.requestLogger != nil && rule.ResourceId != 0 {
|
||||
h.requestLogger.LogRequest(HTTPRequestLog{
|
||||
ResourceID: rule.ResourceId,
|
||||
Timestamp: timestamp,
|
||||
Method: r.Method,
|
||||
Scheme: rule.Protocol,
|
||||
Host: r.Host,
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
UserAgent: r.UserAgent(),
|
||||
SourceAddr: r.RemoteAddr,
|
||||
TLS: rule.Protocol == "https",
|
||||
})
|
||||
}
|
||||
}
|
||||
97
netstack2/http_handler_test.go
Normal file
97
netstack2/http_handler_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package netstack2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func TestHTTPHandlerProxiesWebSocketUpgrade(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{}
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
messageType, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Errorf("read failed: %v", err)
|
||||
return
|
||||
}
|
||||
if err := conn.WriteMessage(messageType, append([]byte("echo:"), payload...)); err != nil {
|
||||
t.Errorf("write failed: %v", err)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
backendURL, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse backend URL: %v", err)
|
||||
}
|
||||
backendHost, backendPort, err := net.SplitHostPort(backendURL.Host)
|
||||
if err != nil {
|
||||
t.Fatalf("split backend host: %v", err)
|
||||
}
|
||||
port, err := net.LookupPort("tcp", backendPort)
|
||||
if err != nil {
|
||||
t.Fatalf("parse backend port: %v", err)
|
||||
}
|
||||
|
||||
handler := NewHTTPHandler(nil, nil)
|
||||
rule := &SubnetRule{
|
||||
Protocol: "http",
|
||||
HTTPTargets: []HTTPTarget{
|
||||
{
|
||||
DestAddr: backendHost,
|
||||
DestPort: uint16(port),
|
||||
Scheme: backendURL.Scheme,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.WithValue(r.Context(), connCtxKey{}, rule)
|
||||
handler.handleRequest(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer frontend.Close()
|
||||
|
||||
frontendURL, err := url.Parse(frontend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse frontend URL: %v", err)
|
||||
}
|
||||
wsURL := url.URL{
|
||||
Scheme: "ws",
|
||||
Host: frontendURL.Host,
|
||||
Path: "/socket",
|
||||
RawQuery: "token=test",
|
||||
}
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket through proxy: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.WriteMessage(websocket.TextMessage, []byte("hello")); err != nil {
|
||||
t.Fatalf("write websocket message: %v", err)
|
||||
}
|
||||
|
||||
messageType, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("read websocket message: %v", err)
|
||||
}
|
||||
if messageType != websocket.TextMessage {
|
||||
t.Fatalf("message type = %d, want %d", messageType, websocket.TextMessage)
|
||||
}
|
||||
if got, want := string(payload), "echo:hello"; got != want {
|
||||
t.Fatalf("payload = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
175
netstack2/http_request_log.go
Normal file
175
netstack2/http_request_log.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package netstack2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// HTTPRequestLog represents a single HTTP/HTTPS request proxied through the handler.
|
||||
type HTTPRequestLog struct {
|
||||
RequestID string `json:"requestId"`
|
||||
ResourceID int `json:"resourceId"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Method string `json:"method"`
|
||||
Scheme string `json:"scheme"`
|
||||
Host string `json:"host"`
|
||||
Path string `json:"path"`
|
||||
RawQuery string `json:"rawQuery,omitempty"`
|
||||
UserAgent string `json:"userAgent,omitempty"`
|
||||
SourceAddr string `json:"sourceAddr"`
|
||||
TLS bool `json:"tls"`
|
||||
}
|
||||
|
||||
// HTTPRequestLogger buffers HTTP request logs and periodically flushes them
|
||||
// to the server via a configurable SendFunc.
|
||||
type HTTPRequestLogger struct {
|
||||
mu sync.Mutex
|
||||
pending []HTTPRequestLog
|
||||
sendFn SendFunc
|
||||
stopCh chan struct{}
|
||||
flushDone chan struct{}
|
||||
}
|
||||
|
||||
// NewHTTPRequestLogger creates a new HTTPRequestLogger and starts its background flush loop.
|
||||
func NewHTTPRequestLogger() *HTTPRequestLogger {
|
||||
rl := &HTTPRequestLogger{
|
||||
pending: make([]HTTPRequestLog, 0),
|
||||
stopCh: make(chan struct{}),
|
||||
flushDone: make(chan struct{}),
|
||||
}
|
||||
go rl.backgroundLoop()
|
||||
return rl
|
||||
}
|
||||
|
||||
// SetSendFunc sets the callback used to send compressed HTTP request log batches
|
||||
// to the server. This can be called after construction once the websocket
|
||||
// client is available.
|
||||
func (rl *HTTPRequestLogger) SetSendFunc(fn SendFunc) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
rl.sendFn = fn
|
||||
}
|
||||
|
||||
// LogRequest adds an HTTP request log entry to the buffer. If the buffer
|
||||
// reaches maxBufferedSessions entries a flush is triggered immediately.
|
||||
func (rl *HTTPRequestLogger) LogRequest(log HTTPRequestLog) {
|
||||
if log.RequestID == "" {
|
||||
log.RequestID = generateSessionID()
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
rl.pending = append(rl.pending, log)
|
||||
shouldFlush := len(rl.pending) >= maxBufferedSessions
|
||||
rl.mu.Unlock()
|
||||
|
||||
if shouldFlush {
|
||||
rl.flush()
|
||||
}
|
||||
}
|
||||
|
||||
// backgroundLoop handles periodic flushing of buffered request logs.
|
||||
func (rl *HTTPRequestLogger) backgroundLoop() {
|
||||
defer close(rl.flushDone)
|
||||
|
||||
ticker := time.NewTicker(flushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-rl.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
rl.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flush drains the pending buffer, compresses with zlib, and sends via the SendFunc.
|
||||
// On send failure the batch is re-queued, capped at maxBufferedSessions*5 entries
|
||||
// to prevent unbounded memory growth when the server is unreachable.
|
||||
func (rl *HTTPRequestLogger) flush() {
|
||||
rl.mu.Lock()
|
||||
if len(rl.pending) == 0 {
|
||||
rl.mu.Unlock()
|
||||
return
|
||||
}
|
||||
batch := rl.pending
|
||||
rl.pending = make([]HTTPRequestLog, 0)
|
||||
sendFn := rl.sendFn
|
||||
rl.mu.Unlock()
|
||||
|
||||
if sendFn == nil {
|
||||
logger.Debug("HTTP request logger: no send function configured, discarding %d requests", len(batch))
|
||||
return
|
||||
}
|
||||
|
||||
compressed, err := compressRequestLogs(batch)
|
||||
if err != nil {
|
||||
logger.Error("HTTP request logger: failed to compress %d requests: %v", len(batch), err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := sendFn(compressed); err != nil {
|
||||
logger.Error("HTTP request logger: failed to send %d requests: %v", len(batch), err)
|
||||
// Re-queue the batch so we don't lose data
|
||||
rl.mu.Lock()
|
||||
rl.pending = append(batch, rl.pending...)
|
||||
// Cap re-queued data to prevent unbounded growth if server is unreachable
|
||||
if len(rl.pending) > maxBufferedSessions*5 {
|
||||
dropped := len(rl.pending) - maxBufferedSessions*5
|
||||
rl.pending = rl.pending[:maxBufferedSessions*5]
|
||||
logger.Warn("HTTP request logger: buffer overflow, dropped %d oldest requests", dropped)
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("HTTP request logger: sent %d requests to server", len(batch))
|
||||
}
|
||||
|
||||
// compressRequestLogs JSON-encodes the request logs, compresses with zlib, and
|
||||
// returns a base64-encoded string suitable for embedding in a JSON message.
|
||||
func compressRequestLogs(logs []HTTPRequestLog) (string, error) {
|
||||
jsonData, err := json.Marshal(logs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
w, err := zlib.NewWriterLevel(&buf, zlib.BestCompression)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := w.Write(jsonData); err != nil {
|
||||
w.Close()
|
||||
return "", err
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
|
||||
}
|
||||
|
||||
// Close shuts down the background loop and performs one final flush to send
|
||||
// any remaining buffered requests to the server.
|
||||
func (rl *HTTPRequestLogger) Close() {
|
||||
select {
|
||||
case <-rl.stopCh:
|
||||
// Already closed
|
||||
return
|
||||
default:
|
||||
close(rl.stopCh)
|
||||
}
|
||||
|
||||
// Wait for the background loop to exit so we don't race on flush
|
||||
<-rl.flushDone
|
||||
|
||||
rl.flush()
|
||||
}
|
||||
@@ -53,6 +53,14 @@ type SubnetRule struct {
|
||||
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
|
||||
PortRanges []PortRange // empty slice means all ports allowed
|
||||
ResourceId int // Optional resource ID from the server for access logging
|
||||
|
||||
// HTTP proxy configuration (optional).
|
||||
// When Protocol is non-empty the TCP connection is handled by HTTPHandler
|
||||
// instead of the raw TCP forwarder.
|
||||
Protocol string // "", "http", or "https" — controls the incoming (client-facing) protocol
|
||||
HTTPTargets []HTTPTarget // downstream services to proxy requests to
|
||||
TLSCert string // PEM-encoded certificate for incoming HTTPS termination
|
||||
TLSKey string // PEM-encoded private key for incoming HTTPS termination
|
||||
}
|
||||
|
||||
// GetAllRules returns a copy of all subnet rules
|
||||
@@ -114,6 +122,7 @@ type ProxyHandler struct {
|
||||
tcpHandler *TCPHandler
|
||||
udpHandler *UDPHandler
|
||||
icmpHandler *ICMPHandler
|
||||
httpHandler *HTTPHandler
|
||||
subnetLookup *SubnetLookup
|
||||
natTable map[connKey]*natState
|
||||
reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT
|
||||
@@ -124,6 +133,7 @@ type ProxyHandler struct {
|
||||
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
|
||||
notifiable channel.Notification // Notification handler for triggering reads
|
||||
accessLogger *AccessLogger // Access logger for tracking sessions
|
||||
httpRequestLogger *HTTPRequestLogger // HTTP request logger for proxied HTTP/HTTPS requests
|
||||
}
|
||||
|
||||
// ProxyHandlerOptions configures the proxy handler
|
||||
@@ -164,12 +174,24 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
||||
}),
|
||||
}
|
||||
|
||||
// Initialize TCP handler if enabled
|
||||
// Initialize TCP handler if enabled. The HTTP handler piggybacks on the
|
||||
// TCP forwarder — TCPHandler.handleTCPConn checks the subnet rule for
|
||||
// ports 80/443 and routes matching connections to the HTTP handler, so
|
||||
// the HTTP handler is always initialised alongside TCP.
|
||||
if options.EnableTCP {
|
||||
handler.tcpHandler = NewTCPHandler(handler.proxyStack, handler)
|
||||
if err := handler.tcpHandler.InstallTCPHandler(); err != nil {
|
||||
return nil, fmt.Errorf("failed to install TCP handler: %v", err)
|
||||
}
|
||||
|
||||
handler.httpHandler = NewHTTPHandler(handler.proxyStack, handler)
|
||||
if err := handler.httpHandler.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start HTTP handler: %v", err)
|
||||
}
|
||||
|
||||
handler.httpRequestLogger = NewHTTPRequestLogger()
|
||||
handler.httpHandler.SetRequestLogger(handler.httpRequestLogger)
|
||||
logger.Debug("ProxyHandler: HTTP handler enabled")
|
||||
}
|
||||
|
||||
// Initialize UDP handler if enabled
|
||||
@@ -208,16 +230,14 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
// AddSubnetRule adds a subnet with optional port restrictions to the proxy handler
|
||||
// sourcePrefix: The IP prefix of the peer sending the data
|
||||
// destPrefix: The IP prefix of the destination
|
||||
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
|
||||
// AddSubnetRule adds a subnet rule to the proxy handler.
|
||||
// HTTP proxy behaviour is configured via rule.Protocol, rule.HTTPTargets,
|
||||
// rule.TLSCert, and rule.TLSKey; leave Protocol empty for raw TCP/UDP.
|
||||
func (p *ProxyHandler) AddSubnetRule(rule SubnetRule) {
|
||||
if p == nil || !p.enabled {
|
||||
return
|
||||
}
|
||||
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
|
||||
p.subnetLookup.AddSubnet(rule)
|
||||
}
|
||||
|
||||
// RemoveSubnetRule removes a subnet from the proxy handler
|
||||
@@ -273,6 +293,24 @@ func (p *ProxyHandler) SetAccessLogSender(fn SendFunc) {
|
||||
p.accessLogger.SetSendFunc(fn)
|
||||
}
|
||||
|
||||
// GetHTTPRequestLogger returns the HTTP request logger.
|
||||
func (p *ProxyHandler) GetHTTPRequestLogger() *HTTPRequestLogger {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return p.httpRequestLogger
|
||||
}
|
||||
|
||||
// SetHTTPRequestLogSender configures the function used to send compressed HTTP
|
||||
// request log batches to the server. This should be called once the websocket
|
||||
// client is available.
|
||||
func (p *ProxyHandler) SetHTTPRequestLogSender(fn SendFunc) {
|
||||
if p == nil || !p.enabled || p.httpRequestLogger == nil {
|
||||
return
|
||||
}
|
||||
p.httpRequestLogger.SetSendFunc(fn)
|
||||
}
|
||||
|
||||
// LookupDestinationRewrite looks up the rewritten destination for a connection
|
||||
// This is used by TCP/UDP handlers to find the actual target address
|
||||
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
|
||||
@@ -794,6 +832,16 @@ func (p *ProxyHandler) Close() error {
|
||||
p.accessLogger.Close()
|
||||
}
|
||||
|
||||
// Shut down HTTP request logger
|
||||
if p.httpRequestLogger != nil {
|
||||
p.httpRequestLogger.Close()
|
||||
}
|
||||
|
||||
// Shut down HTTP handler
|
||||
if p.httpHandler != nil {
|
||||
p.httpHandler.Close()
|
||||
}
|
||||
|
||||
// Close ICMP replies channel
|
||||
if p.icmpReplies != nil {
|
||||
close(p.icmpReplies)
|
||||
|
||||
@@ -44,24 +44,18 @@ func prefixEqual(a, b netip.Prefix) bool {
|
||||
return a.Masked() == b.Masked()
|
||||
}
|
||||
|
||||
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
|
||||
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
|
||||
// AddSubnet adds a subnet rule to the lookup table.
|
||||
// If rule.PortRanges is nil or empty, all ports are allowed.
|
||||
// rule.RewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com").
|
||||
// HTTP proxy behaviour is driven by rule.Protocol, rule.HTTPTargets, rule.TLSCert, and rule.TLSKey.
|
||||
func (sl *SubnetLookup) AddSubnet(rule SubnetRule) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
|
||||
rule := &SubnetRule{
|
||||
SourcePrefix: sourcePrefix,
|
||||
DestPrefix: destPrefix,
|
||||
DisableIcmp: disableIcmp,
|
||||
RewriteTo: rewriteTo,
|
||||
PortRanges: portRanges,
|
||||
ResourceId: resourceId,
|
||||
}
|
||||
rulePtr := &rule
|
||||
|
||||
// Canonicalize source prefix to handle host bits correctly
|
||||
canonicalSourcePrefix := sourcePrefix.Masked()
|
||||
canonicalSourcePrefix := rule.SourcePrefix.Masked()
|
||||
|
||||
// Get or create destination trie for this source prefix
|
||||
destTriePtr, exists := sl.sourceTrie.Get(canonicalSourcePrefix)
|
||||
@@ -76,12 +70,12 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite
|
||||
|
||||
// Canonicalize destination prefix to handle host bits correctly
|
||||
// BART masks prefixes internally, so we need to match that behavior in our bookkeeping
|
||||
canonicalDestPrefix := destPrefix.Masked()
|
||||
canonicalDestPrefix := rule.DestPrefix.Masked()
|
||||
|
||||
// Add rule to destination trie
|
||||
// Original behavior: overwrite if same (sourcePrefix, destPrefix) exists
|
||||
// Store as single-element slice to match original overwrite behavior
|
||||
destTriePtr.trie.Insert(canonicalDestPrefix, []*SubnetRule{rule})
|
||||
destTriePtr.trie.Insert(canonicalDestPrefix, []*SubnetRule{rulePtr})
|
||||
|
||||
// Update destTriePtr.rules - remove old rule with same canonical prefix if exists, then add new one
|
||||
// Use canonical comparison to handle cases like 10.0.0.5/24 vs 10.0.0.0/24
|
||||
@@ -91,7 +85,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite
|
||||
newRules = append(newRules, r)
|
||||
}
|
||||
}
|
||||
newRules = append(newRules, rule)
|
||||
newRules = append(newRules, rulePtr)
|
||||
destTriePtr.rules = newRules
|
||||
}
|
||||
|
||||
|
||||
@@ -351,13 +351,13 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
||||
return net.DialUDP(laddr, nil)
|
||||
}
|
||||
|
||||
// AddProxySubnetRule adds a subnet rule to the proxy handler
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
|
||||
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
|
||||
// AddProxySubnetRule adds a subnet rule to the proxy handler.
|
||||
// HTTP proxy behaviour is configured via rule.Protocol, rule.HTTPTargets,
|
||||
// rule.TLSCert, and rule.TLSKey; leave Protocol empty for raw TCP/UDP.
|
||||
func (net *Net) AddProxySubnetRule(rule SubnetRule) {
|
||||
tun := (*netTun)(net)
|
||||
if tun.proxyHandler != nil {
|
||||
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
|
||||
tun.proxyHandler.AddSubnetRule(rule)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -394,6 +394,16 @@ func (net *Net) SetAccessLogSender(fn SendFunc) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetHTTPRequestLogSender configures the function used to send compressed HTTP
|
||||
// request log batches to the server. This should be called once the websocket
|
||||
// client is available.
|
||||
func (net *Net) SetHTTPRequestLogSender(fn SendFunc) {
|
||||
tun := (*netTun)(net)
|
||||
if tun.proxyHandler != nil {
|
||||
tun.proxyHandler.SetHTTPRequestLogSender(fn)
|
||||
}
|
||||
}
|
||||
|
||||
type PingConn struct {
|
||||
laddr PingAddr
|
||||
raddr PingAddr
|
||||
|
||||
@@ -120,7 +120,7 @@ func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
prefix, _ := ipNet.Mask.Size()
|
||||
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
|
||||
|
||||
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
|
||||
cmd := exec.Command("/sbin/ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
@@ -129,7 +129,7 @@ func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
cmd = exec.Command("ifconfig", interfaceName, "up")
|
||||
cmd = exec.Command("/sbin/ifconfig", interfaceName, "up")
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err = cmd.CombinedOutput()
|
||||
|
||||
@@ -23,9 +23,31 @@ import (
|
||||
|
||||
const (
|
||||
errUnsupportedProtoFmt = "unsupported protocol: %s"
|
||||
maxUDPPacketSize = 65507
|
||||
maxUDPPacketSize = 65507 // Maximum UDP packet size
|
||||
defaultUDPIdleTimeout = 90 * time.Second
|
||||
)
|
||||
|
||||
// udpBufferPool provides reusable buffers for UDP packet handling.
|
||||
// This reduces GC pressure from frequent large allocations.
|
||||
var udpBufferPool = sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, maxUDPPacketSize)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
// getUDPBuffer retrieves a buffer from the pool.
|
||||
func getUDPBuffer() *[]byte {
|
||||
return udpBufferPool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
// putUDPBuffer clears and returns a buffer to the pool.
|
||||
func putUDPBuffer(buf *[]byte) {
|
||||
// Clear the buffer to prevent data leakage
|
||||
clear(*buf)
|
||||
udpBufferPool.Put(buf)
|
||||
}
|
||||
|
||||
// Target represents a proxy target with its address and port
|
||||
type Target struct {
|
||||
Address string
|
||||
@@ -47,6 +69,7 @@ type ProxyManager struct {
|
||||
tunnels map[string]*tunnelEntry
|
||||
asyncBytes bool
|
||||
flushStop chan struct{}
|
||||
udpIdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// tunnelEntry holds per-tunnel attributes and (optional) async counters.
|
||||
@@ -132,6 +155,7 @@ func NewProxyManager(tnet *netstack.Net) *ProxyManager {
|
||||
listeners: make([]*gonet.TCPListener, 0),
|
||||
udpConns: make([]*gonet.UDPConn, 0),
|
||||
tunnels: make(map[string]*tunnelEntry),
|
||||
udpIdleTimeout: defaultUDPIdleTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,6 +233,7 @@ func NewProxyManagerWithoutTNet() *ProxyManager {
|
||||
udpTargets: make(map[string]map[int]string),
|
||||
listeners: make([]*gonet.TCPListener, 0),
|
||||
udpConns: make([]*gonet.UDPConn, 0),
|
||||
udpIdleTimeout: defaultUDPIdleTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -345,6 +370,17 @@ func (pm *ProxyManager) SetAsyncBytes(b bool) {
|
||||
go pm.flushLoop()
|
||||
}
|
||||
}
|
||||
|
||||
// SetUDPIdleTimeout configures when idle UDP client flows are reclaimed.
|
||||
func (pm *ProxyManager) SetUDPIdleTimeout(d time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
if d <= 0 {
|
||||
pm.udpIdleTimeout = defaultUDPIdleTimeout
|
||||
return
|
||||
}
|
||||
pm.udpIdleTimeout = d
|
||||
}
|
||||
func (pm *ProxyManager) flushLoop() {
|
||||
flushInterval := 2 * time.Second
|
||||
if v := os.Getenv("OTEL_METRIC_EXPORT_INTERVAL"); v != "" {
|
||||
@@ -555,7 +591,9 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size
|
||||
bufPtr := getUDPBuffer()
|
||||
defer putUDPBuffer(bufPtr)
|
||||
buffer := *bufPtr
|
||||
clientConns := make(map[string]*net.UDPConn)
|
||||
var clientsMutex sync.RWMutex
|
||||
|
||||
@@ -623,6 +661,9 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
telemetry.IncProxyAccept(context.Background(), pm.currentTunnelID, "udp", "failure", classifyProxyError(err))
|
||||
continue
|
||||
}
|
||||
// Prevent idle UDP client goroutines from living forever and
|
||||
// retaining large per-connection buffers.
|
||||
_ = targetConn.SetReadDeadline(time.Now().Add(pm.udpIdleTimeout))
|
||||
tunnelID := pm.currentTunnelID
|
||||
telemetry.IncProxyAccept(context.Background(), tunnelID, "udp", "success", "")
|
||||
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionOpened)
|
||||
@@ -638,7 +679,10 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
go func(clientKey string, targetConn *net.UDPConn, remoteAddr net.Addr, tunnelID string) {
|
||||
start := time.Now()
|
||||
result := "success"
|
||||
bufPtr := getUDPBuffer()
|
||||
defer func() {
|
||||
// Return buffer to pool first
|
||||
putUDPBuffer(bufPtr)
|
||||
// Always clean up when this goroutine exits
|
||||
clientsMutex.Lock()
|
||||
if storedConn, exists := clientConns[clientKey]; exists && storedConn == targetConn {
|
||||
@@ -653,10 +697,14 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
|
||||
}()
|
||||
|
||||
buffer := make([]byte, maxUDPPacketSize)
|
||||
buffer := *bufPtr
|
||||
for {
|
||||
n, _, err := targetConn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return
|
||||
}
|
||||
// Connection closed is normal during cleanup
|
||||
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
|
||||
return // defer will handle cleanup, result stays "success"
|
||||
@@ -699,6 +747,8 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
delete(clientConns, clientKey)
|
||||
clientsMutex.Unlock()
|
||||
} else if pm.currentTunnelID != "" && written > 0 {
|
||||
// Extend idle timeout whenever client traffic is observed.
|
||||
_ = targetConn.SetReadDeadline(time.Now().Add(pm.udpIdleTimeout))
|
||||
if pm.asyncBytes {
|
||||
if e := pm.getEntry(pm.currentTunnelID); e != nil {
|
||||
e.bytesInUDP.Add(uint64(written))
|
||||
|
||||
60
testing/ws_client.py
Normal file
60
testing/ws_client.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import websockets
|
||||
|
||||
# Argument parsing: Check if HOST and PORT are provided
|
||||
if len(sys.argv) < 3 or len(sys.argv) > 4:
|
||||
print("Usage: python ws_client.py <HOST_IP> <HOST_PORT> [ws|wss]")
|
||||
# Example: python ws_client.py 127.0.0.1 8765
|
||||
# Example: python ws_client.py 127.0.0.1 8765 wss
|
||||
sys.exit(1)
|
||||
|
||||
HOST = sys.argv[1]
|
||||
try:
|
||||
PORT = int(sys.argv[2])
|
||||
except ValueError:
|
||||
print("Error: HOST_PORT must be an integer.")
|
||||
sys.exit(1)
|
||||
|
||||
if len(sys.argv) == 4:
|
||||
SCHEME = sys.argv[3].lower()
|
||||
if SCHEME not in ("ws", "wss"):
|
||||
print("Error: scheme must be 'ws' or 'wss'.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
SCHEME = "ws"
|
||||
|
||||
URI = f"{SCHEME}://{HOST}:{PORT}"
|
||||
|
||||
# The message to send to the server
|
||||
MESSAGE = "Hello WebSocket Server! How are you?"
|
||||
|
||||
|
||||
async def main():
|
||||
print(f"Connecting to {URI}...")
|
||||
|
||||
try:
|
||||
async with websockets.connect(URI) as websocket:
|
||||
print(f"Connected to server.")
|
||||
print(f"Sending message: '{MESSAGE}'")
|
||||
|
||||
await websocket.send(MESSAGE)
|
||||
|
||||
response = await websocket.recv()
|
||||
|
||||
print("-" * 30)
|
||||
print(f"Received response from server:")
|
||||
print(f"-> Data: '{response}'")
|
||||
|
||||
except ConnectionRefusedError:
|
||||
print(f"Error: Connection to {URI} was refused. Is the server running?")
|
||||
except websockets.exceptions.InvalidMessage as e:
|
||||
print(f"Error: Server did not respond with a valid WebSocket handshake: {e}")
|
||||
except Exception as e:
|
||||
print(f"Error during communication: {e}")
|
||||
|
||||
print("-" * 30)
|
||||
print("Client finished.")
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
49
testing/ws_server.py
Normal file
49
testing/ws_server.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import websockets
|
||||
|
||||
# Optionally take in a positional arg for the port
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
PORT = int(sys.argv[1])
|
||||
except ValueError:
|
||||
print("Invalid port number. Using default port 8765.")
|
||||
PORT = 8765
|
||||
else:
|
||||
PORT = 8765
|
||||
|
||||
# Define the server host
|
||||
HOST = "0.0.0.0"
|
||||
|
||||
|
||||
async def handle_client(websocket):
|
||||
client_address = websocket.remote_address
|
||||
print(f"Client connected: {client_address[0]}:{client_address[1]}")
|
||||
|
||||
try:
|
||||
async for message in websocket:
|
||||
print("-" * 30)
|
||||
print(f"Received message from {client_address[0]}:{client_address[1]}:")
|
||||
print(f"-> Data: '{message}'")
|
||||
|
||||
response = f"Hello client! Server received: '{message.upper()}'"
|
||||
|
||||
await websocket.send(response)
|
||||
print(f"Sent response back to client.")
|
||||
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
print(f"Client {client_address[0]}:{client_address[1]} disconnected cleanly.")
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
print(f"Client {client_address[0]}:{client_address[1]} disconnected with error: {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
print(f"WebSocket Server listening on {HOST}:{PORT}")
|
||||
async with websockets.serve(handle_client, HOST, PORT):
|
||||
await asyncio.Future() # Run forever
|
||||
|
||||
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\nServer stopped.")
|
||||
@@ -707,6 +707,10 @@ func (c *Client) sendPing() {
|
||||
}
|
||||
|
||||
c.writeMux.Lock()
|
||||
if c.conn == nil {
|
||||
c.writeMux.Unlock()
|
||||
return
|
||||
}
|
||||
err := c.conn.WriteJSON(pingMsg)
|
||||
if err == nil {
|
||||
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
|
||||
@@ -859,10 +863,12 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
|
||||
func (c *Client) reconnect() {
|
||||
c.setConnected(false)
|
||||
telemetry.SetWSConnectionState(false)
|
||||
c.writeMux.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.writeMux.Unlock()
|
||||
|
||||
// Only reconnect if we're not shutting down
|
||||
select {
|
||||
|
||||
@@ -71,6 +71,11 @@ func (c *Client) loadConfig() error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
if len(bytes.TrimSpace(data)) == 0 {
|
||||
logger.Info("Config file at %s is empty, will initialize it with provided values", configPath)
|
||||
c.configNeedsSave = true
|
||||
return nil
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
|
||||
35
websocket/config_test.go
Normal file
35
websocket/config_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadConfig_EmptyFileMarksConfigForSave(t *testing.T) {
|
||||
t.Setenv("CONFIG_FILE", "")
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(""), 0o644); err != nil {
|
||||
t.Fatalf("failed to create empty config file: %v", err)
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
config: &Config{
|
||||
Endpoint: "https://example.com",
|
||||
ProvisioningKey: "spk-test",
|
||||
},
|
||||
clientType: "newt",
|
||||
configFilePath: configPath,
|
||||
}
|
||||
|
||||
if err := client.loadConfig(); err != nil {
|
||||
t.Fatalf("loadConfig returned error for empty file: %v", err)
|
||||
}
|
||||
|
||||
if !client.configNeedsSave {
|
||||
t.Fatal("expected empty config file to mark configNeedsSave")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user