Compare commits

...

65 Commits

Author SHA1 Message Date
Owen
72a9e111dc Localhost working - is this the best way to do it? 2025-12-05 16:33:43 -05:00
Owen
02949be245 Support connection testing in native 2025-12-04 21:48:32 -05:00
Owen
6d51cbf0c0 Check permissions 2025-12-04 21:39:32 -05:00
Owen
4dbf200cca Change DNS lookup to conntrack 2025-12-04 20:13:48 -05:00
Owen
d8b4fb4acb Change to disable clients 2025-12-04 20:13:35 -05:00
Owen
5dd5a56379 Add caching to the dns requests - is this good enough? 2025-12-03 22:00:23 -05:00
Owen
8c4d6e2e0a Working on more hp 2025-12-03 20:49:46 -05:00
Owen
284f1ce627 Also close the clients 2025-12-02 11:17:34 -05:00
Owen
cd466ac43f Fix some ipv4 in v6 issues 2025-12-01 17:54:38 -05:00
Owen
2256d1f041 Holepunch tester working? 2025-12-01 17:44:33 -05:00
Owen
40ca839771 Handle hp and other stuff 2025-12-01 16:20:30 -05:00
Owen
01ec6a0ce0 Handle holepunches better 2025-12-01 13:54:14 -05:00
Owen
d04f6cf702 Dont throw errors on cleanup 2025-11-30 19:45:25 -05:00
Owen
cdaff27964 Speed much better! 2025-11-30 11:24:50 -05:00
Owen
de96be810b Working but no wgtester? - revert if bad 2025-11-29 17:38:34 -05:00
Owen
5196effdb8 Kind of working - revert if not 2025-11-26 17:57:27 -05:00
Owen
d6edd6ca01 Make hp regular 2025-11-26 17:39:10 -05:00
Owen
1b1323b553 Move network to newt - handle --native mode 2025-11-26 15:06:16 -05:00
Owen
bb95d10e86 Rewriting desitnation works 2025-11-26 14:28:51 -05:00
Owen
da04746781 Add rewriteTo 2025-11-25 11:29:41 -05:00
Owen
61b9615aea Add utility functions 2025-11-23 17:07:40 -05:00
Owen
025c94e586 Export wireguard logger 2025-11-18 14:53:12 -05:00
Owen
75e666c396 Update logger to take in when initing 2025-11-17 21:49:07 -05:00
Owen
82a999eb87 Fix resolve 2025-11-17 18:07:36 -05:00
Owen
921e72f628 Update clients 2025-11-17 15:55:24 -05:00
Owen
46b33fdca6 Remove native and add util 2025-11-17 15:32:22 -05:00
Owen
9caa9fa31e Make logger extensible 2025-11-17 13:49:43 -05:00
Owen
dbbea6b34c Shift things around - remove native 2025-11-17 13:39:32 -05:00
Owen
491180c6a1 Remove proxy manager and break out subnet proxy 2025-11-15 21:46:32 -05:00
Owen
f49a276259 Centralize some functions 2025-11-15 16:32:02 -05:00
Owen
c71c6e0b1a Update to use new packages 2025-11-15 16:14:40 -05:00
Owen
972c9a9760 UDP WORKING! 2025-11-14 15:30:26 -05:00
Owen
8f7ee2a8dc TCP WORKING! 2025-11-14 15:23:20 -05:00
Owen
a737c3e8de REmove readme 2025-11-10 21:37:03 -05:00
Owen
1ba10c1b68 Experiment 2025-11-10 21:33:31 -05:00
Owen
2c8755f346 Using 2 nics not working 2025-11-05 21:46:29 -08:00
Owen
348cac66c8 Bring in netstack locally 2025-11-05 13:39:54 -08:00
Owen
6226a262d6 Merge branch 'main' into dev 2025-11-04 16:50:11 -08:00
Owen
5b70feb6a5 Merge branch 'main' into dev 2025-11-04 16:50:01 -08:00
Owen Schwartz
0ec18d6655 Merge pull request #169 from marcschaeferger/gh-action
Fix Github CICD Action and add Improvements
2025-10-28 21:11:56 -07:00
Marc Schäfer
7d60240572 testdata: add expected telemetry metrics for connection attempts and events 2025-10-28 23:17:05 +01:00
Marc Schäfer
ee3e7d1442 Added Improvements for CICD Action 2025-10-28 23:14:40 +01:00
Owen
527321a415 Update cicd 2025-10-27 21:26:14 -07:00
Owen
ff07692248 Merge branch 'main' into dev 2025-10-27 21:25:20 -07:00
Owen
8d3ae5afd7 Add doc for SKIP_TLS_VERIFY 2025-10-27 21:25:12 -07:00
Owen
ed99dce7e0 Add doc for SKIP_TLS_VERIFY 2025-10-27 21:24:22 -07:00
Owen Schwartz
f1e07272bd Merge pull request #166 from marcschaeferger/gh-action
Adding GHCR to CI/CD Release Workflow & further improvements
2025-10-20 17:21:05 -07:00
Marc Schäfer
a1a3d63fcf ci(actions): change runner from ubuntu-latest to amd64-runner for CI/CD workflows 2025-10-21 02:17:49 +02:00
Marc Schäfer
2a273dc435 ci(actions): add GHCR mirroring and cosign signing for Docker images
- mirror images from Docker Hub to GHCR using skopeo (preserves multi-arch manifests)
- login to GHCR via docker/login-action for signing/pushing
- install cosign and perform dual signing: keyless (OIDC) + key-based; verify signatures
- add required permissions for id-token/packages and reference necessary secrets
2025-10-21 00:22:32 +02:00
Marc Schäfer
ec05686523 ci(actions): pin action versions to commit SHAs for security
- Pin actions/checkout to SHA for v5.0.0
- Pin docker/setup-qemu-action to SHA for v3.6.0
- Pin docker/setup-buildx-action to SHA for v3.11.1
- Pin docker/login-action to SHA for v3.6.0
- Pin actions/setup-go to SHA for v6.0.0
- Pin actions/upload-artifact to SHA for v4.6.2
2025-10-21 00:21:28 +02:00
Owen Schwartz
915e7e44d1 Merge pull request #165 from marcschaeferger/ghcr
feat(actions): Sync Images from Docker to GHCR
2025-10-20 12:32:41 -07:00
Marc Schäfer
a729b91ac3 feat(actions): Sync Images from Docker to GHCR 2025-10-20 21:30:31 +02:00
Owen
ddc37658df Update domain 2025-10-19 15:12:15 -07:00
Owen
7c780f7a4f Merge branch 'dev' of github.com:fosrl/newt into dev 2025-10-16 21:09:41 -07:00
Owen
6b1c1ed077 Merge branch 'main' of github.com:fosrl/newt 2025-10-16 21:06:33 -07:00
Owen Schwartz
7a07437b22 Merge pull request #162 from marcschaeferger/otel
Adding OpenTelemetry Metrics and Tracing
2025-10-16 20:48:37 -07:00
Owen
d63d8d6f5e Add log message that the server is on 2025-10-16 20:42:02 -07:00
Owen
bda1d04f67 Add documentation for cli and reporg 2025-10-16 20:39:41 -07:00
Owen
7f8ee37c7f Update runner 2025-10-16 17:51:25 -07:00
Marc Schäfer
6d2073a478 Remove Coolify Code 2025-10-11 18:46:02 +02:00
Owen Schwartz
6048f244f1 Merge pull request #158 from fosrl/dependabot/go_modules/prod-patch-updates-46361b25de
Bump github.com/docker/docker from 28.5.0+incompatible to 28.5.1+incompatible in the prod-patch-updates group
2025-10-11 09:41:30 -07:00
Owen Schwartz
9fec22a53b Merge pull request #159 from fosrl/dependabot/go_modules/prod-minor-updates-a55d2abe4a
Bump the prod-minor-updates group with 2 updates
2025-10-11 09:41:20 -07:00
Marc Schäfer
c086e69dd0 Adding OpenTelemetry Metrics and Tracing 2025-10-11 18:19:51 +02:00
dependabot[bot]
c729ab5fc6 Bump the prod-minor-updates group with 2 updates
Bumps the prod-minor-updates group with 2 updates: [golang.org/x/crypto](https://github.com/golang/crypto) and [golang.org/x/net](https://github.com/golang/net).


Updates `golang.org/x/crypto` from 0.42.0 to 0.43.0
- [Commits](https://github.com/golang/crypto/compare/v0.42.0...v0.43.0)

Updates `golang.org/x/net` from 0.45.0 to 0.46.0
- [Commits](https://github.com/golang/net/compare/v0.45.0...v0.46.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.43.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: prod-minor-updates
- dependency-name: golang.org/x/net
  dependency-version: 0.46.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: prod-minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-10-09 09:34:56 +00:00
dependabot[bot]
552617cbb5 Bump github.com/docker/docker in the prod-patch-updates group
Bumps the prod-patch-updates group with 1 update: [github.com/docker/docker](https://github.com/docker/docker).


Updates `github.com/docker/docker` from 28.5.0+incompatible to 28.5.1+incompatible
- [Release notes](https://github.com/docker/docker/releases)
- [Commits](https://github.com/docker/docker/compare/v28.5.0...v28.5.1)

---
updated-dependencies:
- dependency-name: github.com/docker/docker
  dependency-version: 28.5.1+incompatible
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: prod-patch-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-10-09 09:34:49 +00:00
68 changed files with 11654 additions and 3023 deletions

5
.env.example Normal file
View File

@@ -0,0 +1,5 @@
# Copy this file to .env and fill in your values
# Required for connecting to Pangolin service
PANGOLIN_ENDPOINT=https://example.com
NEWT_ID=changeme-id
NEWT_SECRET=changeme-secret

View File

@@ -1,64 +1,601 @@
name: CI/CD Pipeline
permissions:
contents: read
contents: write # gh-release
packages: write # GHCR push
id-token: write # Keyless-Signatures & Attestations
attestations: write # actions/attest-build-provenance
security-events: write # upload-sarif
actions: read
on:
push:
tags:
- "*"
push:
tags:
- "*"
workflow_dispatch:
inputs:
version:
description: "SemVer version to release (e.g., 1.2.3, no leading 'v')"
required: true
type: string
publish_latest:
description: "Also publish the 'latest' image tag"
required: true
type: boolean
default: false
publish_minor:
description: "Also publish the 'major.minor' image tag (e.g., 1.2)"
required: true
type: boolean
default: false
target_branch:
description: "Branch to tag"
required: false
default: "main"
concurrency:
group: ${{ github.workflow }}-${{ github.event_name == 'workflow_dispatch' && github.event.inputs.version || github.ref_name }}
cancel-in-progress: true
jobs:
release:
name: Build and Release
runs-on: ubuntu-latest
prepare:
if: github.event_name == 'workflow_dispatch'
name: Prepare release (create tag)
runs-on: ubuntu-24.04
permissions:
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
fetch-depth: 0
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: Validate version input
shell: bash
env:
INPUT_VERSION: ${{ inputs.version }}
run: |
set -euo pipefail
if ! [[ "$INPUT_VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+(-rc\.[0-9]+)?$ ]]; then
echo "Invalid version: $INPUT_VERSION (expected X.Y.Z or X.Y.Z-rc.N)" >&2
exit 1
fi
- name: Create and push tag
shell: bash
env:
TARGET_BRANCH: ${{ inputs.target_branch }}
VERSION: ${{ inputs.version }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git fetch --prune origin
git checkout "$TARGET_BRANCH"
git pull --ff-only origin "$TARGET_BRANCH"
if git rev-parse -q --verify "refs/tags/$VERSION" >/dev/null; then
echo "Tag $VERSION already exists" >&2
exit 1
fi
git tag -a "$VERSION" -m "Release $VERSION"
git push origin "refs/tags/$VERSION"
release:
if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && github.actor != 'github-actions[bot]') }}
name: Build and Release
runs-on: ubuntu-24.04
timeout-minutes: 120
env:
DOCKERHUB_IMAGE: docker.io/${{ vars.DOCKER_HUB_USERNAME }}/${{ github.event.repository.name }}
GHCR_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
steps:
- name: Checkout code
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
fetch-depth: 0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Capture created timestamp
run: echo "IMAGE_CREATED=$(date -u +%Y-%m-%dT%H:%M:%SZ)" >> $GITHUB_ENV
shell: bash
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3.6.0
- name: Extract tag name
id: get-tag
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
- name: Install Go
uses: actions/setup-go@v6
with:
go-version: 1.25
- name: Log in to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
registry: docker.io
username: ${{ vars.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
- name: Update version in main.go
run: |
TAG=${{ env.TAG }}
if [ -f main.go ]; then
sed -i 's/version_replaceme/'"$TAG"'/' main.go
echo "Updated main.go with version $TAG"
else
echo "main.go not found"
fi
- name: Log in to GHCR
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push Docker images
run: |
TAG=${{ env.TAG }}
make docker-build-release tag=$TAG
- name: Normalize image names to lowercase
run: |
set -euo pipefail
echo "GHCR_IMAGE=${GHCR_IMAGE,,}" >> "$GITHUB_ENV"
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE,,}" >> "$GITHUB_ENV"
shell: bash
- name: Build binaries
run: |
make go-build-release
- name: Extract tag name
env:
EVENT_NAME: ${{ github.event_name }}
INPUT_VERSION: ${{ inputs.version }}
run: |
if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
echo "TAG=${INPUT_VERSION}" >> $GITHUB_ENV
else
echo "TAG=${{ github.ref_name }}" >> $GITHUB_ENV
fi
shell: bash
- name: Upload artifacts from /bin
uses: actions/upload-artifact@v4
with:
name: binaries
path: bin/
- name: Validate pushed tag format (no leading 'v')
if: ${{ github.event_name == 'push' }}
shell: bash
env:
TAG_GOT: ${{ env.TAG }}
run: |
set -euo pipefail
if [[ "$TAG_GOT" =~ ^[0-9]+\.[0-9]+\.[0-9]+(-rc\.[0-9]+)?$ ]]; then
echo "Tag OK: $TAG_GOT"
exit 0
fi
echo "ERROR: Tag '$TAG_GOT' is not allowed. Use 'X.Y.Z' or 'X.Y.Z-rc.N' (no leading 'v')." >&2
exit 1
- name: Wait for tag to be visible (dispatch only)
if: ${{ github.event_name == 'workflow_dispatch' }}
run: |
set -euo pipefail
for i in {1..90}; do
if git ls-remote --tags origin "refs/tags/${TAG}" | grep -qE "refs/tags/${TAG}$"; then
echo "Tag ${TAG} is visible on origin"; exit 0
fi
echo "Tag not yet visible, retrying... ($i/90)"
sleep 2
done
echo "Tag ${TAG} not visible after waiting"; exit 1
shell: bash
- name: Ensure repository is at the tagged commit (dispatch only)
if: ${{ github.event_name == 'workflow_dispatch' }}
run: |
set -euo pipefail
git fetch --tags --force
git checkout "refs/tags/${TAG}"
echo "Checked out $(git rev-parse --short HEAD) for tag ${TAG}"
shell: bash
- name: Detect release candidate (rc)
run: |
set -euo pipefail
if [[ "${TAG}" =~ ^[0-9]+\.[0-9]+\.[0-9]+-rc\.[0-9]+$ ]]; then
echo "IS_RC=true" >> $GITHUB_ENV
else
echo "IS_RC=false" >> $GITHUB_ENV
fi
shell: bash
- name: Install Go
uses: actions/setup-go@44694675825211faa026b3c33043df3e48a5fa00 # v6.0.0
with:
go-version-file: go.mod
- name: Resolve publish-latest flag
env:
EVENT_NAME: ${{ github.event_name }}
PL_INPUT: ${{ inputs.publish_latest }}
PL_VAR: ${{ vars.PUBLISH_LATEST }}
run: |
set -euo pipefail
val="false"
if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
if [ "${PL_INPUT}" = "true" ]; then val="true"; fi
else
if [ "${PL_VAR}" = "true" ]; then val="true"; fi
fi
echo "PUBLISH_LATEST=$val" >> $GITHUB_ENV
shell: bash
- name: Resolve publish-minor flag
env:
EVENT_NAME: ${{ github.event_name }}
PM_INPUT: ${{ inputs.publish_minor }}
PM_VAR: ${{ vars.PUBLISH_MINOR }}
run: |
set -euo pipefail
val="false"
if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
if [ "${PM_INPUT}" = "true" ]; then val="true"; fi
else
if [ "${PM_VAR}" = "true" ]; then val="true"; fi
fi
echo "PUBLISH_MINOR=$val" >> $GITHUB_ENV
shell: bash
- name: Cache Go modules
if: ${{ hashFiles('**/go.sum') != '' }}
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Go vet & test
if: ${{ hashFiles('**/go.mod') != '' }}
run: |
go version
go vet ./...
go test ./... -race -covermode=atomic
shell: bash
- name: Resolve license fallback
run: echo "IMAGE_LICENSE=${{ github.event.repository.license.spdx_id || 'NOASSERTION' }}" >> $GITHUB_ENV
shell: bash
- name: Resolve registries list (GHCR always, Docker Hub only if creds)
shell: bash
run: |
set -euo pipefail
images="${GHCR_IMAGE}"
if [ -n "${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}" ] && [ -n "${{ vars.DOCKER_HUB_USERNAME }}" ]; then
images="${images}\n${DOCKERHUB_IMAGE}"
fi
{
echo 'IMAGE_LIST<<EOF'
echo -e "$images"
echo 'EOF'
} >> "$GITHUB_ENV"
- name: Docker meta
id: meta
uses: docker/metadata-action@c1e51972afc2121e065aed6d45c65596fe445f3f # v5.8.0
with:
images: ${{ env.IMAGE_LIST }}
tags: |
type=semver,pattern={{version}},value=${{ env.TAG }}
type=semver,pattern={{major}}.{{minor}},value=${{ env.TAG }},enable=${{ env.PUBLISH_MINOR == 'true' && env.IS_RC != 'true' }}
type=raw,value=latest,enable=${{ env.PUBLISH_LATEST == 'true' && env.IS_RC != 'true' }}
flavor: |
latest=false
labels: |
org.opencontainers.image.title=${{ github.event.repository.name }}
org.opencontainers.image.version=${{ env.TAG }}
org.opencontainers.image.revision=${{ github.sha }}
org.opencontainers.image.source=${{ github.event.repository.html_url }}
org.opencontainers.image.url=${{ github.event.repository.html_url }}
org.opencontainers.image.documentation=${{ github.event.repository.html_url }}
org.opencontainers.image.description=${{ github.event.repository.description }}
org.opencontainers.image.licenses=${{ env.IMAGE_LICENSE }}
org.opencontainers.image.created=${{ env.IMAGE_CREATED }}
org.opencontainers.image.ref.name=${{ env.TAG }}
org.opencontainers.image.authors=${{ github.repository_owner }}
- name: Echo build config (non-secret)
shell: bash
env:
IMAGE_TITLE: ${{ github.event.repository.name }}
IMAGE_VERSION: ${{ env.TAG }}
IMAGE_REVISION: ${{ github.sha }}
IMAGE_SOURCE_URL: ${{ github.event.repository.html_url }}
IMAGE_URL: ${{ github.event.repository.html_url }}
IMAGE_DESCRIPTION: ${{ github.event.repository.description }}
IMAGE_LICENSE: ${{ env.IMAGE_LICENSE }}
DOCKERHUB_IMAGE: ${{ env.DOCKERHUB_IMAGE }}
GHCR_IMAGE: ${{ env.GHCR_IMAGE }}
DOCKER_HUB_USER: ${{ vars.DOCKER_HUB_USERNAME }}
REPO: ${{ github.repository }}
OWNER: ${{ github.repository_owner }}
WORKFLOW_REF: ${{ github.workflow_ref }}
REF: ${{ github.ref }}
REF_NAME: ${{ github.ref_name }}
RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
run: |
set -euo pipefail
echo "=== OCI Label Values ==="
echo "org.opencontainers.image.title=${IMAGE_TITLE}"
echo "org.opencontainers.image.version=${IMAGE_VERSION}"
echo "org.opencontainers.image.revision=${IMAGE_REVISION}"
echo "org.opencontainers.image.source=${IMAGE_SOURCE_URL}"
echo "org.opencontainers.image.url=${IMAGE_URL}"
echo "org.opencontainers.image.description=${IMAGE_DESCRIPTION}"
echo "org.opencontainers.image.licenses=${IMAGE_LICENSE}"
echo
echo "=== Images ==="
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE}"
echo "GHCR_IMAGE=${GHCR_IMAGE}"
echo "DOCKER_HUB_USERNAME=${DOCKER_HUB_USER}"
echo
echo "=== GitHub Kontext ==="
echo "repository=${REPO}"
echo "owner=${OWNER}"
echo "workflow_ref=${WORKFLOW_REF}"
echo "ref=${REF}"
echo "ref_name=${REF_NAME}"
echo "run_url=${RUN_URL}"
echo
echo "=== docker/metadata-action outputs (Tags/Labels), raw ==="
echo "::group::tags"
echo "${{ steps.meta.outputs.tags }}"
echo "::endgroup::"
echo "::group::labels"
echo "${{ steps.meta.outputs.labels }}"
echo "::endgroup::"
- name: Build and push (Docker Hub + GHCR)
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6.18.0
with:
context: .
push: true
platforms: linux/amd64,linux/arm64
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha,scope=${{ github.repository }}
cache-to: type=gha,mode=max,scope=${{ github.repository }}
provenance: mode=max
sbom: true
- name: Compute image digest refs
run: |
echo "DIGEST=${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
echo "GHCR_REF=$GHCR_IMAGE@${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
echo "DH_REF=$DOCKERHUB_IMAGE@${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
echo "Built digest: ${{ steps.build.outputs.digest }}"
shell: bash
- name: Attest build provenance (GHCR)
id: attest-ghcr
uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
with:
subject-name: ${{ env.GHCR_IMAGE }}
subject-digest: ${{ steps.build.outputs.digest }}
push-to-registry: true
show-summary: true
- name: Attest build provenance (Docker Hub)
continue-on-error: true
id: attest-dh
uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
with:
subject-name: index.docker.io/${{ vars.DOCKER_HUB_USERNAME }}/${{ github.event.repository.name }}
subject-digest: ${{ steps.build.outputs.digest }}
push-to-registry: true
show-summary: true
- name: Install cosign
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
with:
cosign-release: 'v3.0.2'
- name: Sanity check cosign private key
env:
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
run: |
set -euo pipefail
cosign public-key --key env://COSIGN_PRIVATE_KEY >/dev/null
shell: bash
- name: Sign GHCR image (digest) with key (recursive)
env:
COSIGN_YES: "true"
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
run: |
set -euo pipefail
echo "Signing ${GHCR_REF} (digest) recursively with provided key"
cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${GHCR_REF}"
shell: bash
- name: Generate SBOM (SPDX JSON)
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # v0.33.1
with:
image-ref: ${{ env.GHCR_IMAGE }}@${{ steps.build.outputs.digest }}
format: spdx-json
output: sbom.spdx.json
- name: Validate SBOM JSON
run: jq -e . sbom.spdx.json >/dev/null
shell: bash
- name: Minify SBOM JSON (optional hardening)
run: jq -c . sbom.spdx.json > sbom.min.json && mv sbom.min.json sbom.spdx.json
shell: bash
- name: Create SBOM attestation (GHCR, private key)
env:
COSIGN_YES: "true"
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
run: |
set -euo pipefail
cosign attest \
--key env://COSIGN_PRIVATE_KEY \
--type spdxjson \
--predicate sbom.spdx.json \
"${GHCR_REF}"
shell: bash
- name: Create SBOM attestation (Docker Hub, private key)
continue-on-error: true
env:
COSIGN_YES: "true"
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
COSIGN_DOCKER_MEDIA_TYPES: "1"
run: |
set -euo pipefail
cosign attest \
--key env://COSIGN_PRIVATE_KEY \
--type spdxjson \
--predicate sbom.spdx.json \
"${DH_REF}"
shell: bash
- name: Keyless sign & verify GHCR digest (OIDC)
env:
COSIGN_YES: "true"
WORKFLOW_REF: ${{ github.workflow_ref }} # owner/repo/.github/workflows/<file>@refs/tags/<tag>
ISSUER: https://token.actions.githubusercontent.com
run: |
set -euo pipefail
echo "Keyless signing ${GHCR_REF}"
cosign sign --rekor-url https://rekor.sigstore.dev --recursive "${GHCR_REF}"
echo "Verify keyless (OIDC) signature policy on ${GHCR_REF}"
cosign verify \
--certificate-oidc-issuer "${ISSUER}" \
--certificate-identity "https://github.com/${WORKFLOW_REF}" \
"${GHCR_REF}" -o text
shell: bash
- name: Sign Docker Hub image (digest) with key (recursive)
continue-on-error: true
env:
COSIGN_YES: "true"
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
COSIGN_DOCKER_MEDIA_TYPES: "1"
run: |
set -euo pipefail
echo "Signing ${DH_REF} (digest) recursively with provided key (Docker media types fallback)"
cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${DH_REF}"
shell: bash
- name: Keyless sign & verify Docker Hub digest (OIDC)
continue-on-error: true
env:
COSIGN_YES: "true"
ISSUER: https://token.actions.githubusercontent.com
COSIGN_DOCKER_MEDIA_TYPES: "1"
run: |
set -euo pipefail
echo "Keyless signing ${DH_REF} (force public-good Rekor)"
cosign sign --rekor-url https://rekor.sigstore.dev --recursive "${DH_REF}"
echo "Keyless verify via Rekor (strict identity)"
if ! cosign verify \
--rekor-url https://rekor.sigstore.dev \
--certificate-oidc-issuer "${ISSUER}" \
--certificate-identity "https://github.com/${{ github.workflow_ref }}" \
"${DH_REF}" -o text; then
echo "Rekor verify failed — retry offline bundle verify (no Rekor)"
if ! cosign verify \
--offline \
--certificate-oidc-issuer "${ISSUER}" \
--certificate-identity "https://github.com/${{ github.workflow_ref }}" \
"${DH_REF}" -o text; then
echo "Offline bundle verify failed — ignore tlog (TEMP for debugging)"
cosign verify \
--insecure-ignore-tlog=true \
--certificate-oidc-issuer "${ISSUER}" \
--certificate-identity "https://github.com/${{ github.workflow_ref }}" \
"${DH_REF}" -o text || true
fi
fi
- name: Verify signature (public key) GHCR digest + tag
env:
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
run: |
set -euo pipefail
TAG_VAR="${TAG}"
echo "Verifying (digest) ${GHCR_REF}"
cosign verify --key env://COSIGN_PUBLIC_KEY "$GHCR_REF" -o text
echo "Verifying (tag) $GHCR_IMAGE:$TAG_VAR"
cosign verify --key env://COSIGN_PUBLIC_KEY "$GHCR_IMAGE:$TAG_VAR" -o text
shell: bash
- name: Verify SBOM attestation (GHCR)
env:
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
run: cosign verify-attestation --key env://COSIGN_PUBLIC_KEY --type spdxjson "$GHCR_REF" -o text
shell: bash
- name: Verify SLSA provenance (GHCR)
env:
ISSUER: https://token.actions.githubusercontent.com
WFREF: ${{ github.workflow_ref }}
run: |
set -euo pipefail
# (optional) show which predicate types are present to aid debugging
cosign download attestation "$GHCR_REF" \
| jq -r '.payload | @base64d | fromjson | .predicateType' | sort -u || true
# Verify the SLSA v1 provenance attestation (predicate URL)
cosign verify-attestation \
--type 'https://slsa.dev/provenance/v1' \
--certificate-oidc-issuer "$ISSUER" \
--certificate-identity "https://github.com/${WFREF}" \
--rekor-url https://rekor.sigstore.dev \
"$GHCR_REF" -o text
shell: bash
- name: Verify signature (public key) Docker Hub digest
continue-on-error: true
env:
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
COSIGN_DOCKER_MEDIA_TYPES: "1"
run: |
set -euo pipefail
echo "Verifying (digest) ${DH_REF} with Docker media types"
cosign verify --key env://COSIGN_PUBLIC_KEY "${DH_REF}" -o text
shell: bash
- name: Verify signature (public key) Docker Hub tag
continue-on-error: true
env:
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
COSIGN_DOCKER_MEDIA_TYPES: "1"
run: |
set -euo pipefail
echo "Verifying (tag) $DOCKERHUB_IMAGE:$TAG with Docker media types"
cosign verify --key env://COSIGN_PUBLIC_KEY "$DOCKERHUB_IMAGE:$TAG" -o text
shell: bash
- name: Trivy scan (GHCR image)
id: trivy
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # v0.33.1
with:
image-ref: ${{ env.GHCR_IMAGE }}@${{ steps.build.outputs.digest }}
format: sarif
output: trivy-ghcr.sarif
ignore-unfixed: true
vuln-type: os,library
severity: CRITICAL,HIGH
exit-code: ${{ (vars.TRIVY_FAIL || '0') }}
- name: Upload SARIF
if: ${{ always() && hashFiles('trivy-ghcr.sarif') != '' }}
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v4.31.0
with:
sarif_file: trivy-ghcr.sarif
category: Image Vulnerability Scan
- name: Build binaries
env:
CGO_ENABLED: "0"
GOFLAGS: "-trimpath"
run: |
set -euo pipefail
TAG_VAR="${TAG}"
make go-build-release tag=$TAG_VAR
shell: bash
- name: Create GitHub Release
uses: softprops/action-gh-release@6da8fa9354ddfdc4aeace5fc48d7f679b5214090 # v2.4.1
with:
tag_name: ${{ env.TAG }}
generate_release_notes: true
prerelease: ${{ env.IS_RC == 'true' }}
files: |
bin/*
fail_on_unmatched_files: true
body: |
## Container Images
- GHCR: `${{ env.GHCR_REF }}`
- Docker Hub: `${{ env.DH_REF || 'N/A' }}`
**Digest:** `${{ steps.build.outputs.digest }}`

132
.github/workflows/mirror.yaml vendored Normal file
View File

@@ -0,0 +1,132 @@
name: Mirror & Sign (Docker Hub to GHCR)
on:
workflow_dispatch: {}
permissions:
contents: read
packages: write
id-token: write # for keyless OIDC
env:
SOURCE_IMAGE: docker.io/fosrl/newt
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
jobs:
mirror-and-dual-sign:
runs-on: amd64-runner
steps:
- name: Install skopeo + jq
run: |
sudo apt-get update -y
sudo apt-get install -y skopeo jq
skopeo --version
- name: Install cosign
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
- name: Input check
run: |
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
echo "Source : ${SOURCE_IMAGE}"
echo "Target : ${DEST_IMAGE}"
# Auth for skopeo (containers-auth)
- name: Skopeo login to GHCR
run: |
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
# Auth for cosign (docker-config)
- name: Docker login to GHCR (for cosign)
run: |
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
- name: List source tags
run: |
set -euo pipefail
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
| jq -r '.Tags[]' | sort -u > src-tags.txt
echo "Found source tags: $(wc -l < src-tags.txt)"
head -n 20 src-tags.txt || true
- name: List destination tags (skip existing)
run: |
set -euo pipefail
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
else
: > dst-tags.txt
fi
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
- name: Mirror, dual-sign, and verify
env:
# keyless
COSIGN_YES: "true"
# key-based
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
# verify
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
run: |
set -euo pipefail
copied=0; skipped=0; v_ok=0; errs=0
issuer="https://token.actions.githubusercontent.com"
id_regex="^https://github.com/${{ github.repository }}/.+"
while read -r tag; do
[ -z "$tag" ] && continue
if grep -Fxq "$tag" dst-tags.txt; then
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
skipped=$((skipped+1))
continue
fi
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
if ! skopeo copy --all --retry-times 3 \
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
errs=$((errs+1)); continue
fi
copied=$((copied+1))
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
ref="${DEST_IMAGE}@${digest}"
echo "==> cosign sign (keyless) --recursive ${ref}"
if ! cosign sign --recursive "${ref}"; then
echo "::warning title=Keyless sign failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign sign (key) --recursive ${ref}"
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
echo "::warning title=Key sign failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign verify (public key) ${ref}"
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
echo "::warning title=Verify(pubkey) failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign verify (keyless policy) ${ref}"
if ! cosign verify \
--certificate-oidc-issuer "${issuer}" \
--certificate-identity-regexp "${id_regex}" \
"${ref}" -o text; then
echo "::warning title=Verify(keyless) failed::${ref}"
errs=$((errs+1))
else
v_ok=$((v_ok+1))
fi
done < src-tags.txt
echo "---- Summary ----"
echo "Copied : $copied"
echo "Skipped : $skipped"
echo "Verified OK : $v_ok"
echo "Errors : $errs"

View File

@@ -11,13 +11,13 @@ on:
jobs:
test:
runs-on: ubuntu-latest
runs-on: amd64-runner
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Set up Go
uses: actions/setup-go@v6
uses: actions/setup-go@44694675825211faa026b3c33043df3e48a5fa00 # v6.0.0
with:
go-version: 1.25

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@ nohup.out
*.iml
certs/
newt_arm64
key

View File

@@ -4,11 +4,7 @@ Contributions are welcome!
Please see the contribution and local development guide on the docs page before getting started:
https://docs.fossorial.io/development
For ideas about what features to work on and our future plans, please see the roadmap:
https://docs.fossorial.io/roadmap
https://docs.pangolin.net/development/contributing
### Licensing Considerations
@@ -21,4 +17,4 @@ By creating this pull request, I grant the project maintainers an unlimited,
perpetual license to use, modify, and redistribute these contributions under any terms they
choose, including both the AGPLv3 and the Fossorial Commercial license terms. I
represent that I have the right to grant this license for all contributed content.
```
```

View File

@@ -1,5 +1,8 @@
FROM golang:1.25-alpine AS builder
# Install git and ca-certificates
RUN apk --no-cache add ca-certificates git tzdata
# Set the working directory inside the container
WORKDIR /app
@@ -13,7 +16,7 @@ RUN go mod download
COPY . .
# Build the application
RUN CGO_ENABLED=0 GOOS=linux go build -o /newt
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /newt
FROM alpine:3.22 AS runner
@@ -22,6 +25,9 @@ RUN apk --no-cache add ca-certificates tzdata
COPY --from=builder /newt /usr/local/bin/
COPY entrypoint.sh /
# Admin/metrics endpoint (Prometheus scrape)
EXPOSE 2112
RUN chmod +x /entrypoint.sh
ENTRYPOINT ["/entrypoint.sh"]
CMD ["newt"]
CMD ["newt"]

108
README.md
View File

@@ -9,7 +9,7 @@ Newt is a fully user space [WireGuard](https://www.wireguard.com/) tunnel client
Newt is used with Pangolin and Gerbil as part of the larger system. See documentation below:
- [Full Documentation](https://docs.fossorial.io)
- [Full Documentation](https://docs.pangolin.net)
## Preview
@@ -33,61 +33,109 @@ When Newt receives WireGuard control messages, it will use the information encod
## CLI Args
### Core Configuration
- `id`: Newt ID generated by Pangolin to identify the client.
- `secret`: A unique secret (not shared and kept private) used to authenticate the client ID with the websocket in order to receive commands.
- `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket.
- `mtu` (optional): MTU for the internal WG interface. Default: 1280
- `dns` (optional): DNS server to use to resolve the endpoint. Default: 9.9.9.9
- `blueprint-file` (optional): Path to blueprint file to define Pangolin resources and configurations.
- `no-cloud` (optional): Don't fail over to the cloud when using managed nodes in Pangolin Cloud. Default: false
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: INFO
- `enforce-hc-cert` (optional): Enforce certificate validation for health checks. Default: false (accepts any cert)
### Docker Integration
- `docker-socket` (optional): Set the Docker socket to use the container discovery integration
- `ping-interval` (optional): Interval for pinging the server. Default: 3s
- `ping-timeout` (optional): Timeout for each ping. Default: 5s
- `updown` (optional): A script to be called when targets are added or removed.
- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls)
- `tls-client-cert` (optional): Path to client certificate (PEM format, optional if using PKCS12). See [mTLS](#mtls)
- `tls-client-key` (optional): Path to private key for mTLS (PEM format, optional if using PKCS12)
- `tls-ca-cert` (optional): Path to CA certificate to verify server (PEM format, optional if using PKCS12)
- `docker-enforce-network-validation` (optional): Validate the container target is on the same network as the newt process. Default: false
- `health-file` (optional): Check if connection to WG server (pangolin) is ok. creates a file if ok, removes it if not ok. Can be used with docker healtcheck to restart newt
### Accpet Client Connection
- `accept-clients` (optional): Enable WireGuard server mode to accept incoming newt client connections. Default: false
- `generateAndSaveKeyTo` (optional): Path to save generated private key
- `native` (optional): Use native WireGuard interface when accepting clients (requires WireGuard kernel module and Linux, must run as root). Default: false (uses userspace netstack)
- `interface` (optional): Name of the WireGuard interface. Default: newt
- `keep-interface` (optional): Keep the WireGuard interface. Default: false
- `blueprint-file` (optional): Path to blueprint file to define Pangolin resources and configurations.
- `no-cloud` (optional): Don't fail over to the cloud when using managed nodes in Pangolin Cloud. Default: false
### Metrics & Observability
- `metrics` (optional): Enable Prometheus /metrics exporter. Default: true
- `otlp` (optional): Enable OTLP exporters (metrics/traces) to OTEL_EXPORTER_OTLP_ENDPOINT. Default: false
- `metrics-admin-addr` (optional): Admin/metrics bind address. Default: 127.0.0.1:2112
- `metrics-async-bytes` (optional): Enable async bytes counting (background flush; lower hot path overhead). Default: false
- `region` (optional): Optional region resource attribute for telemetry and metrics.
### Network Configuration
- `mtu` (optional): MTU for the internal WG interface. Default: 1280
- `dns` (optional): DNS server to use to resolve the endpoint. Default: 9.9.9.9
- `ping-interval` (optional): Interval for pinging the server. Default: 3s
- `ping-timeout` (optional): Timeout for each ping. Default: 5s
### Security & TLS
- `enforce-hc-cert` (optional): Enforce certificate validation for health checks. Default: false (accepts any cert)
- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS or path to client certificate (PEM format). See [mTLS](#mtls)
- `tls-client-key` (optional): Path to private key for mTLS (PEM format, optional if using PKCS12)
- `tls-ca-cert` (optional): Path to CA certificate to verify server (PEM format, optional if using PKCS12)
### Monitoring & Health
- `health-file` (optional): Check if connection to WG server (pangolin) is ok. creates a file if ok, removes it if not ok. Can be used with docker healtcheck to restart newt
- `updown` (optional): A script to be called when targets are added or removed.
## Environment Variables
All CLI arguments can be set using environment variables as an alternative to command line flags. Environment variables are particularly useful when running Newt in containerized environments.
### Core Configuration
- `PANGOLIN_ENDPOINT`: Endpoint of your pangolin server (equivalent to `--endpoint`)
- `NEWT_ID`: Newt ID generated by Pangolin (equivalent to `--id`)
- `NEWT_SECRET`: Newt secret for authentication (equivalent to `--secret`)
- `MTU`: MTU for the internal WG interface. Default: 1280 (equivalent to `--mtu`)
- `DNS`: DNS server to use to resolve the endpoint. Default: 9.9.9.9 (equivalent to `--dns`)
- `CONFIG_FILE`: Load the config json from this file instead of in the home folder.
- `BLUEPRINT_FILE`: Path to blueprint file to define Pangolin resources and configurations. (equivalent to `--blueprint-file`)
- `NO_CLOUD`: Don't fail over to the cloud when using managed nodes in Pangolin Cloud. Default: false (equivalent to `--no-cloud`)
- `LOG_LEVEL`: Log level (DEBUG, INFO, WARN, ERROR, FATAL). Default: INFO (equivalent to `--log-level`)
### Docker Integration
- `DOCKER_SOCKET`: Path to Docker socket for container discovery (equivalent to `--docker-socket`)
- `PING_INTERVAL`: Interval for pinging the server. Default: 3s (equivalent to `--ping-interval`)
- `PING_TIMEOUT`: Timeout for each ping. Default: 5s (equivalent to `--ping-timeout`)
- `UPDOWN_SCRIPT`: Path to updown script for target add/remove events (equivalent to `--updown`)
- `TLS_CLIENT_CERT`: Path to client certificate for mTLS (equivalent to `--tls-client-cert`)
- `TLS_CLIENT_CERT`: Path to client certificate for mTLS (equivalent to `--tls-client-cert`)
- `TLS_CLIENT_KEY`: Path to private key for mTLS (equivalent to `--tls-client-key`)
- `TLS_CA_CERT`: Path to CA certificate to verify server (equivalent to `--tls-ca-cert`)
- `DOCKER_ENFORCE_NETWORK_VALIDATION`: Validate container targets are on same network. Default: false (equivalent to `--docker-enforce-network-validation`)
- `ENFORCE_HC_CERT`: Enforce certificate validation for health checks. Default: false (equivalent to `--enforce-hc-cert`)
- `HEALTH_FILE`: Path to health file for connection monitoring (equivalent to `--health-file`)
### Accept Client Connections
- `ACCEPT_CLIENTS`: Enable WireGuard server mode. Default: false (equivalent to `--accept-clients`)
- `GENERATE_AND_SAVE_KEY_TO`: Path to save generated private key (equivalent to `--generateAndSaveKeyTo`)
- `USE_NATIVE_INTERFACE`: Use native WireGuard interface (Linux only). Default: false (equivalent to `--native`)
- `INTERFACE`: Name of the WireGuard interface. Default: newt (equivalent to `--interface`)
- `KEEP_INTERFACE`: Keep the WireGuard interface after shutdown. Default: false (equivalent to `--keep-interface`)
- `CONFIG_FILE`: Load the config json from this file instead of in the home folder.
- `BLUEPRINT_FILE`: Path to blueprint file to define Pangolin resources and configurations. (equivalent to `--blueprint-file`)
- `NO_CLOUD`: Don't fail over to the cloud when using managed nodes in Pangolin Cloud. Default: false (equivalent to `--no-cloud`)
### Monitoring & Health
- `HEALTH_FILE`: Path to health file for connection monitoring (equivalent to `--health-file`)
- `UPDOWN_SCRIPT`: Path to updown script for target add/remove events (equivalent to `--updown`)
### Metrics & Observability
- `NEWT_METRICS_PROMETHEUS_ENABLED`: Enable Prometheus /metrics exporter. Default: true (equivalent to `--metrics`)
- `NEWT_METRICS_OTLP_ENABLED`: Enable OTLP exporters (metrics/traces) to OTEL_EXPORTER_OTLP_ENDPOINT. Default: false (equivalent to `--otlp`)
- `NEWT_ADMIN_ADDR`: Admin/metrics bind address. Default: 127.0.0.1:2112 (equivalent to `--metrics-admin-addr`)
- `NEWT_METRICS_ASYNC_BYTES`: Enable async bytes counting (background flush; lower hot path overhead). Default: false (equivalent to `--metrics-async-bytes`)
- `NEWT_REGION`: Optional region resource attribute for telemetry and metrics (equivalent to `--region`)
### Network Configuration
- `MTU`: MTU for the internal WG interface. Default: 1280 (equivalent to `--mtu`)
- `DNS`: DNS server to use to resolve the endpoint. Default: 9.9.9.9 (equivalent to `--dns`)
- `PING_INTERVAL`: Interval for pinging the server. Default: 3s (equivalent to `--ping-interval`)
- `PING_TIMEOUT`: Timeout for each ping. Default: 5s (equivalent to `--ping-timeout`)
### Security & TLS
- `ENFORCE_HC_CERT`: Enforce certificate validation for health checks. Default: false (equivalent to `--enforce-hc-cert`)
- `TLS_CLIENT_CERT`: Path to client certificate for mTLS (equivalent to `--tls-client-cert`)
- `TLS_CLIENT_KEY`: Path to private key for mTLS (equivalent to `--tls-client-key`)
- `TLS_CA_CERT`: Path to CA certificate to verify server (equivalent to `--tls-ca-cert`)
- `SKIP_TLS_VERIFY`: Skip TLS verification for server connections. Default: false
## Loading secrets from files
@@ -98,7 +146,7 @@ $ cat ~/.config/newt-client/config.json
{
"id": "spmzu8rbpzj1qq6",
"secret": "f6v61mjutwme2kkydbw3fjo227zl60a2tsf5psw9r25hgae3",
"endpoint": "https://pangolin.fossorial.io",
"endpoint": "https://app.pangolin.net",
"tlsClientCert": ""
}
```

View File

@@ -3,7 +3,7 @@
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
2. Send a detailed report to [security@fossorial.io](mailto:security@fossorial.io) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
2. Send a detailed report to [security@pangolin.net](mailto:security@pangolin.net) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
- Description and location of the vulnerability.
- Potential impact of the vulnerability.

675
bind/shared_bind.go Normal file
View File

@@ -0,0 +1,675 @@
//go:build !js
package bind
import (
"bytes"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"sync/atomic"
"github.com/fosrl/newt/logger"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// Magic packet constants for connection testing
// These packets are intercepted by SharedBind and responded to directly,
// without being passed to the WireGuard device.
var (
// MagicTestRequest is the prefix for a test request packet
// Format: PANGOLIN_TEST_REQ + 8 bytes of random data (for echo)
MagicTestRequest = []byte("PANGOLIN_TEST_REQ")
// MagicTestResponse is the prefix for a test response packet
// Format: PANGOLIN_TEST_RSP + 8 bytes echoed from request
MagicTestResponse = []byte("PANGOLIN_TEST_RSP")
)
const (
// MagicPacketDataLen is the length of random data included in test packets
MagicPacketDataLen = 8
// MagicTestRequestLen is the total length of a test request packet
MagicTestRequestLen = 17 + MagicPacketDataLen // len("PANGOLIN_TEST_REQ") + 8
// MagicTestResponseLen is the total length of a test response packet
MagicTestResponseLen = 17 + MagicPacketDataLen // len("PANGOLIN_TEST_RSP") + 8
)
// PacketSource identifies where a packet came from
type PacketSource uint8
const (
SourceSocket PacketSource = iota // From physical UDP socket (hole-punched clients)
SourceNetstack // From netstack (relay through main tunnel)
)
// SourceAwareEndpoint wraps an endpoint with source information
type SourceAwareEndpoint struct {
wgConn.Endpoint
source PacketSource
}
// GetSource returns the source of this endpoint
func (e *SourceAwareEndpoint) GetSource() PacketSource {
return e.source
}
// injectedPacket represents a packet injected into the SharedBind from an internal source
type injectedPacket struct {
data []byte
endpoint wgConn.Endpoint
}
// Endpoint represents a network endpoint for the SharedBind
type Endpoint struct {
AddrPort netip.AddrPort
}
// ClearSrc implements the wgConn.Endpoint interface
func (e *Endpoint) ClearSrc() {}
// DstIP implements the wgConn.Endpoint interface
func (e *Endpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
// SrcIP implements the wgConn.Endpoint interface
func (e *Endpoint) SrcIP() netip.Addr {
return netip.Addr{}
}
// DstToBytes implements the wgConn.Endpoint interface
func (e *Endpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
// DstToString implements the wgConn.Endpoint interface
func (e *Endpoint) DstToString() string {
return e.AddrPort.String()
}
// SrcToString implements the wgConn.Endpoint interface
func (e *Endpoint) SrcToString() string {
return ""
}
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
// and hole punch senders. It wraps a single UDP connection and implements
// reference counting to prevent premature closure.
// It also supports receiving packets from a netstack and routing responses
// back through the appropriate source.
type SharedBind struct {
mu sync.RWMutex
// The underlying UDP connection (for hole-punched clients)
udpConn *net.UDPConn
// IPv4 and IPv6 packet connections for advanced features
ipv4PC *ipv4.PacketConn
ipv6PC *ipv6.PacketConn
// Reference counting to prevent closing while in use
refCount atomic.Int32
closed atomic.Bool
// Channels for receiving data
recvFuncs []wgConn.ReceiveFunc
// Port binding information
port uint16
// Channel for packets from netstack (from direct relay) - larger buffer for throughput
netstackPackets chan injectedPacket
// Netstack connection for sending responses back through the tunnel
// Using atomic.Pointer for lock-free access in hot path
netstackConn atomic.Pointer[net.PacketConn]
// Track which endpoints came from netstack (key: netip.AddrPort, value: struct{})
// Using netip.AddrPort directly as key is more efficient than string
netstackEndpoints sync.Map
// Pre-allocated message buffers for batch operations (Linux only)
ipv4Msgs []ipv4.Message
// Shutdown signal for receive goroutines
closeChan chan struct{}
// Callback for magic test responses (used for holepunch testing)
magicResponseCallback atomic.Pointer[func(addr netip.AddrPort, echoData []byte)]
}
// MagicResponseCallback is the function signature for magic packet response callbacks
type MagicResponseCallback func(addr netip.AddrPort, echoData []byte)
// New creates a new SharedBind from an existing UDP connection.
// The SharedBind takes ownership of the connection and will close it
// when all references are released.
func New(udpConn *net.UDPConn) (*SharedBind, error) {
if udpConn == nil {
return nil, fmt.Errorf("udpConn cannot be nil")
}
bind := &SharedBind{
udpConn: udpConn,
netstackPackets: make(chan injectedPacket, 1024), // Larger buffer for better throughput
closeChan: make(chan struct{}),
}
// Initialize reference count to 1 (the creator holds the first reference)
bind.refCount.Store(1)
// Get the local port
if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok {
bind.port = uint16(addr.Port)
}
return bind, nil
}
// SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel.
// This connection is used for relay traffic that should go back through the main tunnel.
func (b *SharedBind) SetNetstackConn(conn net.PacketConn) {
b.netstackConn.Store(&conn)
}
// GetNetstackConn returns the netstack connection if set
func (b *SharedBind) GetNetstackConn() net.PacketConn {
ptr := b.netstackConn.Load()
if ptr == nil {
return nil
}
return *ptr
}
// InjectPacket allows injecting a packet directly into the SharedBind's receive path.
// This is used for direct relay from netstack without going through the host network.
// The fromAddr should be the address the packet appears to come from.
func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error {
if b.closed.Load() {
return net.ErrClosed
}
// Unmap IPv4-in-IPv6 addresses to ensure consistency with parsed endpoints
if fromAddr.Addr().Is4In6() {
fromAddr = netip.AddrPortFrom(fromAddr.Addr().Unmap(), fromAddr.Port())
}
// Track this endpoint as coming from netstack so responses go back the same way
// Use AddrPort directly as key (more efficient than string)
b.netstackEndpoints.Store(fromAddr, struct{}{})
// Make a copy of the data to avoid issues with buffer reuse
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
select {
case b.netstackPackets <- injectedPacket{
data: dataCopy,
endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr},
}:
return nil
case <-b.closeChan:
return net.ErrClosed
default:
// Channel full, drop the packet
return fmt.Errorf("netstack packet buffer full")
}
}
// AddRef increments the reference count. Call this when sharing
// the bind with another component.
func (b *SharedBind) AddRef() {
newCount := b.refCount.Add(1)
// Optional: Add logging for debugging
_ = newCount // Placeholder for potential logging
}
// Release decrements the reference count. When it reaches zero,
// the underlying UDP connection is closed.
func (b *SharedBind) Release() error {
newCount := b.refCount.Add(-1)
// Optional: Add logging for debugging
_ = newCount // Placeholder for potential logging
if newCount < 0 {
// This should never happen with proper usage
b.refCount.Store(0)
return fmt.Errorf("SharedBind reference count went negative")
}
if newCount == 0 {
return b.closeConnection()
}
return nil
}
// closeConnection actually closes the UDP connection
func (b *SharedBind) closeConnection() error {
if !b.closed.CompareAndSwap(false, true) {
// Already closed
return nil
}
// Signal all goroutines to stop
close(b.closeChan)
b.mu.Lock()
defer b.mu.Unlock()
var err error
if b.udpConn != nil {
err = b.udpConn.Close()
b.udpConn = nil
}
b.ipv4PC = nil
b.ipv6PC = nil
// Clear netstack connection (but don't close it - it's managed externally)
b.netstackConn.Store(nil)
// Clear tracked netstack endpoints
b.netstackEndpoints = sync.Map{}
return err
}
// ClearNetstackConn clears the netstack connection and tracked endpoints.
// Call this when stopping the relay.
func (b *SharedBind) ClearNetstackConn() {
b.netstackConn.Store(nil)
// Clear tracked netstack endpoints
b.netstackEndpoints = sync.Map{}
}
// GetUDPConn returns the underlying UDP connection.
// The caller must not close this connection directly.
func (b *SharedBind) GetUDPConn() *net.UDPConn {
b.mu.RLock()
defer b.mu.RUnlock()
return b.udpConn
}
// GetRefCount returns the current reference count (for debugging)
func (b *SharedBind) GetRefCount() int32 {
return b.refCount.Load()
}
// IsClosed returns whether the bind is closed
func (b *SharedBind) IsClosed() bool {
return b.closed.Load()
}
// SetMagicResponseCallback sets a callback function that will be called when
// a magic test response packet is received. This is used for holepunch testing.
// Pass nil to clear the callback.
func (b *SharedBind) SetMagicResponseCallback(callback MagicResponseCallback) {
if callback == nil {
b.magicResponseCallback.Store(nil)
} else {
// Convert to the function type the atomic.Pointer expects
fn := func(addr netip.AddrPort, echoData []byte) {
callback(addr, echoData)
}
b.magicResponseCallback.Store(&fn)
}
}
// WriteToUDP writes data to a specific UDP address.
// This is thread-safe and can be used by hole punch senders.
func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) {
if b.closed.Load() {
return 0, net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
return conn.WriteToUDP(data, addr)
}
// Close implements the WireGuard Bind interface.
// It decrements the reference count and closes the connection if no references remain.
func (b *SharedBind) Close() error {
return b.Release()
}
// Open implements the WireGuard Bind interface.
// Since the connection is already open, this just sets up the receive functions.
func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
if b.closed.Load() {
return nil, 0, net.ErrClosed
}
b.mu.Lock()
defer b.mu.Unlock()
if b.udpConn == nil {
return nil, 0, net.ErrClosed
}
// Set up IPv4 and IPv6 packet connections for advanced features
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
// Pre-allocate message buffers for batch operations
batchSize := wgConn.IdealBatchSize
b.ipv4Msgs = make([]ipv4.Message, batchSize)
for i := range b.ipv4Msgs {
b.ipv4Msgs[i].OOB = make([]byte, 0)
}
}
// Create receive functions - one for socket, one for netstack
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
// Add socket receive function (reads from physical UDP socket)
recvFuncs = append(recvFuncs, b.makeReceiveSocket())
// Add netstack receive function (reads from injected packets channel)
recvFuncs = append(recvFuncs, b.makeReceiveNetstack())
b.recvFuncs = recvFuncs
return recvFuncs, b.port, nil
}
// makeReceiveSocket creates a receive function for physical UDP socket packets
func (b *SharedBind) makeReceiveSocket() wgConn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
if b.closed.Load() {
return 0, net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
pc := b.ipv4PC
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
// Use batch reading on Linux for performance
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
}
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
}
}
// makeReceiveNetstack creates a receive function for netstack-injected packets
func (b *SharedBind) makeReceiveNetstack() wgConn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
select {
case <-b.closeChan:
return 0, net.ErrClosed
case pkt := <-b.netstackPackets:
if len(pkt.data) <= len(bufs[0]) {
copy(bufs[0], pkt.data)
sizes[0] = len(pkt.data)
eps[0] = pkt.endpoint
return 1, nil
}
// Packet too large for buffer, skip it
return 0, nil
}
}
}
// receiveIPv4Batch uses batch reading for better performance on Linux
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
// Use pre-allocated messages, just update buffer pointers
numBufs := len(bufs)
if numBufs > len(b.ipv4Msgs) {
numBufs = len(b.ipv4Msgs)
}
for i := 0; i < numBufs; i++ {
b.ipv4Msgs[i].Buffers = [][]byte{bufs[i]}
}
numMsgs, err := pc.ReadBatch(b.ipv4Msgs[:numBufs], 0)
if err != nil {
return 0, err
}
// Process messages and filter out magic packets
writeIdx := 0
for i := 0; i < numMsgs; i++ {
if b.ipv4Msgs[i].N == 0 {
continue
}
// Check for magic packet
if b.ipv4Msgs[i].Addr != nil {
if udpAddr, ok := b.ipv4Msgs[i].Addr.(*net.UDPAddr); ok {
data := bufs[i][:b.ipv4Msgs[i].N]
if b.handleMagicPacket(data, udpAddr) {
// Magic packet handled, skip this message
continue
}
}
}
// Not a magic packet, include in output
if writeIdx != i {
// Need to copy data to the correct position
copy(bufs[writeIdx], bufs[i][:b.ipv4Msgs[i].N])
}
sizes[writeIdx] = b.ipv4Msgs[i].N
if b.ipv4Msgs[i].Addr != nil {
if udpAddr, ok := b.ipv4Msgs[i].Addr.(*net.UDPAddr); ok {
addrPort := udpAddr.AddrPort()
// Unmap IPv4-in-IPv6 addresses to ensure consistency with parsed endpoints
if addrPort.Addr().Is4In6() {
addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())
}
eps[writeIdx] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
}
}
writeIdx++
}
return writeIdx, nil
}
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
for {
n, addr, err := conn.ReadFromUDP(bufs[0])
if err != nil {
return 0, err
}
// Check for magic test packet and handle it directly
if b.handleMagicPacket(bufs[0][:n], addr) {
// Magic packet was handled, read another packet
continue
}
sizes[0] = n
if addr != nil {
addrPort := addr.AddrPort()
// Unmap IPv4-in-IPv6 addresses to ensure consistency with parsed endpoints
if addrPort.Addr().Is4In6() {
addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())
}
eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
}
return 1, nil
}
}
// handleMagicPacket checks if the packet is a magic test packet and responds if so.
// Returns true if the packet was a magic packet and was handled (should not be passed to WireGuard).
func (b *SharedBind) handleMagicPacket(data []byte, addr *net.UDPAddr) bool {
// Check if this is a test request packet
if len(data) >= MagicTestRequestLen && bytes.HasPrefix(data, MagicTestRequest) {
logger.Debug("Received magic test REQUEST from %s, sending response", addr.String())
// Extract the random data portion to echo back
echoData := data[len(MagicTestRequest) : len(MagicTestRequest)+MagicPacketDataLen]
// Build response packet
response := make([]byte, MagicTestResponseLen)
copy(response, MagicTestResponse)
copy(response[len(MagicTestResponse):], echoData)
// Send response back to sender
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn != nil {
_, _ = conn.WriteToUDP(response, addr)
}
return true
}
// Check if this is a test response packet
if len(data) >= MagicTestResponseLen && bytes.HasPrefix(data, MagicTestResponse) {
logger.Debug("Received magic test RESPONSE from %s", addr.String())
// Extract the echoed data
echoData := data[len(MagicTestResponse) : len(MagicTestResponse)+MagicPacketDataLen]
// Call the callback if set
callbackPtr := b.magicResponseCallback.Load()
if callbackPtr != nil {
callback := *callbackPtr
addrPort := addr.AddrPort()
// Unmap IPv4-in-IPv6 addresses to ensure consistency
if addrPort.Addr().Is4In6() {
addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())
}
callback(addrPort, echoData)
} else {
logger.Debug("Magic response received but no callback registered")
}
return true
}
return false
}
// Send implements the WireGuard Bind interface.
// It sends packets to the specified endpoint, routing through the appropriate
// source (netstack or physical socket) based on where the endpoint's packets came from.
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
if b.closed.Load() {
return net.ErrClosed
}
// Extract the destination address from the endpoint
var destAddrPort netip.AddrPort
// Try to cast to StdNetEndpoint first (most common case, avoid allocations)
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
destAddrPort = stdEp.AddrPort
} else {
// Fallback: construct from DstIP and DstToBytes
dstBytes := ep.DstToBytes()
if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes)
var addr netip.Addr
var port uint16
if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes)
addr, _ = netip.AddrFromSlice(dstBytes[:16])
port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8
} else { // IPv4
addr, _ = netip.AddrFromSlice(dstBytes[:4])
port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8
}
if addr.IsValid() {
destAddrPort = netip.AddrPortFrom(addr, port)
}
}
}
if !destAddrPort.IsValid() {
return fmt.Errorf("could not extract destination address from endpoint")
}
// Check if this endpoint came from netstack - if so, send through netstack
// Use AddrPort directly as key (more efficient than string conversion)
if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort); isNetstackEndpoint {
connPtr := b.netstackConn.Load()
if connPtr != nil && *connPtr != nil {
netstackConn := *connPtr
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
// Send all buffers through netstack
for _, buf := range bufs {
_, err := netstackConn.WriteTo(buf, destAddr)
if err != nil {
return err
}
}
return nil
}
// Fall through to socket if netstack conn not available
}
// Send through the physical UDP socket (for hole-punched clients)
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return net.ErrClosed
}
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
// Send all buffers to the destination
for _, buf := range bufs {
_, err := conn.WriteToUDP(buf, destAddr)
if err != nil {
return err
}
}
return nil
}
// SetMark implements the WireGuard Bind interface.
// It's a no-op for this implementation.
func (b *SharedBind) SetMark(mark uint32) error {
// Not implemented for this use case
return nil
}
// BatchSize returns the preferred batch size for sending packets.
func (b *SharedBind) BatchSize() int {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return wgConn.IdealBatchSize
}
return 1
}
// ParseEndpoint creates a new endpoint from a string address.
func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
addrPort, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil
}

605
bind/shared_bind_test.go Normal file
View File

@@ -0,0 +1,605 @@
//go:build !js
package bind
import (
"net"
"net/netip"
"sync"
"testing"
"time"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// TestSharedBindCreation tests basic creation and initialization
func TestSharedBindCreation(t *testing.T) {
// Create a UDP connection
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
defer udpConn.Close()
// Create SharedBind
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
if bind == nil {
t.Fatal("SharedBind is nil")
}
// Verify initial reference count
if bind.refCount.Load() != 1 {
t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load())
}
// Clean up
if err := bind.Close(); err != nil {
t.Errorf("Failed to close SharedBind: %v", err)
}
}
// TestSharedBindReferenceCount tests reference counting
func TestSharedBindReferenceCount(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
// Add references
bind.AddRef()
if bind.refCount.Load() != 2 {
t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load())
}
bind.AddRef()
if bind.refCount.Load() != 3 {
t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load())
}
// Release references
bind.Release()
if bind.refCount.Load() != 2 {
t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load())
}
bind.Release()
bind.Release() // This should close the connection
if !bind.closed.Load() {
t.Error("Expected bind to be closed after all references released")
}
}
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
func TestSharedBindWriteToUDP(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Send data
testData := []byte("Hello, SharedBind!")
n, err := senderBind.WriteToUDP(testData, receiverAddr)
if err != nil {
t.Fatalf("WriteToUDP failed: %v", err)
}
if n != len(testData) {
t.Errorf("Expected to send %d bytes, sent %d", len(testData), n)
}
// Receive data
buf := make([]byte, 1024)
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, _, err = receiverConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive data: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
}
}
// TestSharedBindConcurrentWrites tests thread-safety
func TestSharedBindConcurrentWrites(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Launch concurrent writes
numGoroutines := 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
data := []byte{byte(id)}
_, err := senderBind.WriteToUDP(data, receiverAddr)
if err != nil {
t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err)
}
}(i)
}
wg.Wait()
}
// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
func TestSharedBindWireGuardInterface(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer bind.Close()
// Test Open
recvFuncs, port, err := bind.Open(0)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
if len(recvFuncs) == 0 {
t.Error("Expected at least one receive function")
}
if port == 0 {
t.Error("Expected non-zero port")
}
// Test SetMark (should be a no-op)
if err := bind.SetMark(0); err != nil {
t.Errorf("SetMark failed: %v", err)
}
// Test BatchSize
batchSize := bind.BatchSize()
if batchSize <= 0 {
t.Error("Expected positive batch size")
}
}
// TestSharedBindSend tests the Send method with WireGuard endpoints
func TestSharedBindSend(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Create an endpoint
addrPort := receiverAddr.AddrPort()
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
// Send data
testData := []byte("WireGuard packet")
bufs := [][]byte{testData}
err = senderBind.Send(bufs, endpoint)
if err != nil {
t.Fatalf("Send failed: %v", err)
}
// Receive data
buf := make([]byte, 1024)
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, _, err := receiverConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive data: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
}
}
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
func TestSharedBindMultipleUsers(t *testing.T) {
// Create shared bind
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
sharedBind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
// Add reference for hole punch sender
sharedBind.AddRef()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
var wg sync.WaitGroup
// Simulate WireGuard using the bind
wg.Add(1)
go func() {
defer wg.Done()
addrPort := receiverAddr.AddrPort()
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
for i := 0; i < 10; i++ {
data := []byte("WireGuard packet")
bufs := [][]byte{data}
if err := sharedBind.Send(bufs, endpoint); err != nil {
t.Errorf("WireGuard Send failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
}()
// Simulate hole punch sender using the bind
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
data := []byte("Hole punch packet")
if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil {
t.Errorf("Hole punch WriteToUDP failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
}()
wg.Wait()
// Release the hole punch reference
sharedBind.Release()
// Close WireGuard's reference (should close the connection)
sharedBind.Close()
if !sharedBind.closed.Load() {
t.Error("Expected bind to be closed after all users released it")
}
}
// TestEndpoint tests the Endpoint implementation
func TestEndpoint(t *testing.T) {
addr := netip.MustParseAddr("192.168.1.1")
addrPort := netip.AddrPortFrom(addr, 51820)
ep := &Endpoint{AddrPort: addrPort}
// Test DstIP
if ep.DstIP() != addr {
t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP())
}
// Test DstToString
expected := "192.168.1.1:51820"
if ep.DstToString() != expected {
t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString())
}
// Test DstToBytes
bytes := ep.DstToBytes()
if len(bytes) == 0 {
t.Error("Expected DstToBytes to return non-empty slice")
}
// Test SrcIP (should be zero)
if ep.SrcIP().IsValid() {
t.Error("Expected SrcIP to be invalid")
}
// Test ClearSrc (should not panic)
ep.ClearSrc()
}
// TestParseEndpoint tests the ParseEndpoint method
func TestParseEndpoint(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer bind.Close()
tests := []struct {
name string
input string
wantErr bool
checkAddr func(*testing.T, wgConn.Endpoint)
}{
{
name: "valid IPv4",
input: "192.168.1.1:51820",
wantErr: false,
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
if ep.DstToString() != "192.168.1.1:51820" {
t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString())
}
},
},
{
name: "valid IPv6",
input: "[::1]:51820",
wantErr: false,
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
if ep.DstToString() != "[::1]:51820" {
t.Errorf("Expected [::1]:51820, got %s", ep.DstToString())
}
},
},
{
name: "invalid - missing port",
input: "192.168.1.1",
wantErr: true,
},
{
name: "invalid - bad format",
input: "not-an-address",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ep, err := bind.ParseEndpoint(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && tt.checkAddr != nil {
tt.checkAddr(t, ep)
}
})
}
}
// TestNetstackRouting tests that packets from netstack endpoints are routed back through netstack
func TestNetstackRouting(t *testing.T) {
// Create the SharedBind with a physical UDP socket
physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create physical UDP connection: %v", err)
}
sharedBind, err := New(physicalConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer sharedBind.Close()
// Create a mock "netstack" connection (just another UDP socket for testing)
netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create netstack UDP connection: %v", err)
}
defer netstackConn.Close()
// Set the netstack connection
sharedBind.SetNetstackConn(netstackConn)
// Create a "client" that would receive packets
clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create client UDP connection: %v", err)
}
defer clientConn.Close()
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
clientAddrPort := clientAddr.AddrPort()
// Inject a packet from the "netstack" source - this should track the endpoint
testData := []byte("test packet from netstack")
err = sharedBind.InjectPacket(testData, clientAddrPort)
if err != nil {
t.Fatalf("InjectPacket failed: %v", err)
}
// Now when we send a response to this endpoint, it should go through netstack
endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort}
responseData := []byte("response packet")
err = sharedBind.Send([][]byte{responseData}, endpoint)
if err != nil {
t.Fatalf("Send failed: %v", err)
}
// The packet should be received by the client from the netstack connection
buf := make([]byte, 1024)
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, fromAddr, err := clientConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive response: %v", err)
}
if string(buf[:n]) != string(responseData) {
t.Errorf("Expected to receive %q, got %q", responseData, buf[:n])
}
// Verify the response came from the netstack connection, not the physical one
netstackAddr := netstackConn.LocalAddr().(*net.UDPAddr)
if fromAddr.Port != netstackAddr.Port {
t.Errorf("Expected response from netstack port %d, got %d", netstackAddr.Port, fromAddr.Port)
}
}
// TestSocketRouting tests that packets from socket endpoints are routed through socket
func TestSocketRouting(t *testing.T) {
// Create the SharedBind with a physical UDP socket
physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create physical UDP connection: %v", err)
}
sharedBind, err := New(physicalConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer sharedBind.Close()
// Create a mock "netstack" connection
netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create netstack UDP connection: %v", err)
}
defer netstackConn.Close()
// Set the netstack connection
sharedBind.SetNetstackConn(netstackConn)
// Create a "client" that would receive packets (this simulates a hole-punched client)
clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create client UDP connection: %v", err)
}
defer clientConn.Close()
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
clientAddrPort := clientAddr.AddrPort()
// Don't inject from netstack - this endpoint is NOT tracked as netstack-sourced
// So Send should use the physical socket
endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort}
responseData := []byte("response packet via socket")
err = sharedBind.Send([][]byte{responseData}, endpoint)
if err != nil {
t.Fatalf("Send failed: %v", err)
}
// The packet should be received by the client from the physical connection
buf := make([]byte, 1024)
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, fromAddr, err := clientConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive response: %v", err)
}
if string(buf[:n]) != string(responseData) {
t.Errorf("Expected to receive %q, got %q", responseData, buf[:n])
}
// Verify the response came from the physical connection, not the netstack one
physicalAddr := physicalConn.LocalAddr().(*net.UDPAddr)
if fromAddr.Port != physicalAddr.Port {
t.Errorf("Expected response from physical port %d, got %d", physicalAddr.Port, fromAddr.Port)
}
}
// TestClearNetstackConn tests that clearing the netstack connection works correctly
func TestClearNetstackConn(t *testing.T) {
physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create physical UDP connection: %v", err)
}
sharedBind, err := New(physicalConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer sharedBind.Close()
// Set a netstack connection
netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create netstack UDP connection: %v", err)
}
defer netstackConn.Close()
sharedBind.SetNetstackConn(netstackConn)
// Inject a packet to track an endpoint
testAddrPort := netip.MustParseAddrPort("192.168.1.100:51820")
err = sharedBind.InjectPacket([]byte("test"), testAddrPort)
if err != nil {
t.Fatalf("InjectPacket failed: %v", err)
}
// Verify the endpoint is tracked
_, tracked := sharedBind.netstackEndpoints.Load(testAddrPort.String())
if !tracked {
t.Error("Expected endpoint to be tracked as netstack-sourced")
}
// Clear the netstack connection
sharedBind.ClearNetstackConn()
// Verify the netstack connection is cleared
if sharedBind.GetNetstackConn() != nil {
t.Error("Expected netstack connection to be nil after clear")
}
// Verify the tracked endpoints are cleared
_, stillTracked := sharedBind.netstackEndpoints.Load(testAddrPort.String())
if stillTracked {
t.Error("Expected endpoint tracking to be cleared")
}
}

View File

@@ -12,9 +12,9 @@ resources:
sso-roles:
- Member
sso-users:
- owen@fossorial.io
- owen@pangolin.net
whitelist-users:
- owen@fossorial.io
- owen@pangolin.net
targets:
# - site: glossy-plains-viscacha-rat
- hostname: localhost

View File

@@ -1,20 +1,17 @@
package main
import (
"fmt"
"strings"
"github.com/fosrl/newt/clients"
wgnetstack "github.com/fosrl/newt/clients"
"github.com/fosrl/newt/clients/permissions"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/fosrl/newt/wgnetstack"
"github.com/fosrl/newt/wgtester"
)
var wgService *wgnetstack.WireGuardService
var wgTesterServer *wgtester.Server
var wgService *clients.WireGuardService
var ready bool
func setupClients(client *websocket.Client) {
@@ -27,43 +24,29 @@ func setupClients(client *websocket.Client) {
host = strings.TrimSuffix(host, "/")
logger.Info("Setting up clients with netstack2...")
// if useNativeInterface is true make sure we have permission to use native interface
if useNativeInterface {
setupClientsNative(client, host)
} else {
setupClientsNetstack(client, host)
logger.Debug("Checking permissions for native interface")
err := permissions.CheckNativeInterfacePermissions()
if err != nil {
logger.Fatal("Insufficient permissions to create native TUN interface: %v", err)
return
}
}
ready = true
}
func setupClientsNetstack(client *websocket.Client, host string) {
logger.Info("Setting up clients with netstack...")
// Create WireGuard service
wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "9.9.9.9")
wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, host, id, client, dns, useNativeInterface)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
// // Set up callback to restart wgtester with netstack when WireGuard is ready
wgService.SetOnNetstackReady(func(tnet *netstack.Net) {
wgTesterServer = wgtester.NewServerWithNetstack("0.0.0.0", wgService.Port, id, tnet) // TODO: maybe make this the same ip of the wg server?
err := wgTesterServer.Start()
if err != nil {
logger.Error("Failed to start WireGuard tester server: %v", err)
}
})
wgService.SetOnNetstackClose(func() {
if wgTesterServer != nil {
wgTesterServer.Stop()
wgTesterServer = nil
}
})
client.OnTokenUpdate(func(token string) {
wgService.SetToken(token)
})
ready = true
}
func setDownstreamTNetstack(tnet *netstack.Net) {
@@ -75,16 +58,9 @@ func setDownstreamTNetstack(tnet *netstack.Net) {
func closeClients() {
logger.Info("Closing clients...")
if wgService != nil {
wgService.Close(!keepInterface)
wgService.Close()
wgService = nil
}
closeWgServiceNative()
if wgTesterServer != nil {
wgTesterServer.Stop()
wgTesterServer = nil
}
}
func clientsHandleNewtConnection(publicKey string, endpoint string) {
@@ -103,8 +79,6 @@ func clientsHandleNewtConnection(publicKey string, endpoint string) {
if wgService != nil {
wgService.StartHolepunch(publicKey, endpoint)
}
clientsHandleNewtConnectionNative(publicKey, endpoint)
}
func clientsOnConnect() {
@@ -114,19 +88,17 @@ func clientsOnConnect() {
if wgService != nil {
wgService.LoadRemoteConfig()
}
clientsOnConnectNative()
}
func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) {
// clientsStartDirectRelay starts a direct UDP relay from the main tunnel netstack
// to the clients' WireGuard, bypassing the proxy for better performance.
func clientsStartDirectRelay(tunnelIP string) {
if !ready {
return
}
// add a udp proxy for localost and the wgService port
// TODO: make sure this port is not used in a target
if wgService != nil {
pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port))
if err := wgService.StartDirectUDPRelay(tunnelIP); err != nil {
logger.Error("Failed to start direct UDP relay: %v", err)
}
}
clientsAddProxyTargetNative(pm, tunnelIp)
}

1253
clients/clients.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,18 @@
//go:build darwin
package permissions
import (
"fmt"
"os"
)
// CheckNativeInterfacePermissions checks if the process has sufficient
// permissions to create a native TUN interface on macOS.
// This typically requires root privileges.
func CheckNativeInterfacePermissions() error {
if os.Geteuid() == 0 {
return nil
}
return fmt.Errorf("insufficient permissions: need root to create TUN interface on macOS")
}

View File

@@ -0,0 +1,96 @@
//go:build linux
package permissions
import (
"fmt"
"os"
"unsafe"
"github.com/fosrl/newt/logger"
"golang.org/x/sys/unix"
)
const (
// TUN device constants
tunDevice = "/dev/net/tun"
ifnamsiz = 16
iffTun = 0x0001
iffNoPi = 0x1000
tunSetIff = 0x400454ca
)
// ifReq is the structure for TUNSETIFF ioctl
type ifReq struct {
Name [ifnamsiz]byte
Flags uint16
_ [22]byte // padding to match kernel structure
}
// CheckNativeInterfacePermissions checks if the process has sufficient
// permissions to create a native TUN interface on Linux.
// This requires either root privileges (UID 0) or CAP_NET_ADMIN capability.
func CheckNativeInterfacePermissions() error {
logger.Debug("Checking native interface permissions on Linux")
// Check if running as root
if os.Geteuid() == 0 {
logger.Debug("Running as root, sufficient permissions for native TUN interface")
return nil
}
// Check for CAP_NET_ADMIN capability
caps := unix.CapUserHeader{
Version: unix.LINUX_CAPABILITY_VERSION_3,
Pid: 0, // 0 means current process
}
var data [2]unix.CapUserData
if err := unix.Capget(&caps, &data[0]); err != nil {
logger.Debug("Failed to get capabilities: %v, will try creating test TUN", err)
} else {
// CAP_NET_ADMIN is capability bit 12
const CAP_NET_ADMIN = 12
if data[0].Effective&(1<<CAP_NET_ADMIN) != 0 {
logger.Debug("Process has CAP_NET_ADMIN capability, sufficient permissions for native TUN interface")
return nil
}
logger.Debug("Process does not have CAP_NET_ADMIN capability in effective set")
}
// Actually try to create a TUN interface to verify permissions
// This is the most reliable check as it tests the actual operation
return tryCreateTestTun()
}
// tryCreateTestTun attempts to create a temporary TUN interface to verify
// we have the necessary permissions. This tests the actual ioctl call that
// will be used when creating the real interface.
func tryCreateTestTun() error {
f, err := os.OpenFile(tunDevice, os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("cannot open %s: %v (need root or CAP_NET_ADMIN capability)", tunDevice, err)
}
defer f.Close()
// Try to create a TUN interface with a test name
// Using a random-ish name to avoid conflicts
var req ifReq
copy(req.Name[:], "tuntest0")
req.Flags = iffTun | iffNoPi
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
f.Fd(),
uintptr(tunSetIff),
uintptr(unsafe.Pointer(&req)),
)
if errno != 0 {
return fmt.Errorf("cannot create TUN interface (ioctl TUNSETIFF failed): %v (need root or CAP_NET_ADMIN capability)", errno)
}
// Success - the interface will be automatically destroyed when we close the fd
logger.Debug("Successfully created test TUN interface, sufficient permissions for native TUN interface")
return nil
}

View File

@@ -0,0 +1,38 @@
//go:build windows
package permissions
import (
"fmt"
"golang.org/x/sys/windows"
)
// CheckNativeInterfacePermissions checks if the process has sufficient
// permissions to create a native TUN interface on Windows.
// This requires Administrator privileges.
func CheckNativeInterfacePermissions() error {
var sid *windows.SID
err := windows.AllocateAndInitializeSid(
&windows.SECURITY_NT_AUTHORITY,
2,
windows.SECURITY_BUILTIN_DOMAIN_RID,
windows.DOMAIN_ALIAS_RID_ADMINS,
0, 0, 0, 0, 0, 0,
&sid)
if err != nil {
return fmt.Errorf("failed to initialize SID: %v", err)
}
defer windows.FreeSid(sid)
token := windows.Token(0)
member, err := token.IsMember(sid)
if err != nil {
return fmt.Errorf("failed to check admin group membership: %v", err)
}
if !member {
return fmt.Errorf("insufficient permissions: need Administrator to create TUN interface on Windows")
}
return nil
}

View File

@@ -2,11 +2,9 @@ package main
import (
"bytes"
"encoding/base64"
"encoding/hex"
"context"
"encoding/json"
"fmt"
"net"
"os"
"os/exec"
"strings"
@@ -14,32 +12,20 @@ import (
"math/rand"
"github.com/fosrl/newt/internal/telemetry"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"gopkg.in/yaml.v3"
)
func fixKey(key string) string {
// Remove any whitespace
key = strings.TrimSpace(key)
// Decode from base64
decoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
logger.Fatal("Error decoding base64: %v", err)
}
// Convert to hex
return hex.EncodeToString(decoded)
}
const msgHealthFileWriteFailed = "Failed to write health file: %v"
func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration, error) {
logger.Debug("Pinging %s", dst)
// logger.Debug("Pinging %s", dst)
socket, err := tnet.Dial("ping4", dst)
if err != nil {
return 0, fmt.Errorf("failed to create ICMP socket: %w", err)
@@ -98,7 +84,7 @@ func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration,
latency := time.Since(start)
logger.Debug("Ping to %s successful, latency: %v", dst, latency)
// logger.Debug("Ping to %s successful, latency: %v", dst, latency)
return latency, nil
}
@@ -136,7 +122,7 @@ func reliablePing(tnet *netstack.Net, dst string, baseTimeout time.Duration, max
// If we get at least one success, we can return early for health checks
if successCount > 0 {
avgLatency := totalLatency / time.Duration(successCount)
logger.Debug("Reliable ping succeeded after %d attempts, avg latency: %v", attempt, avgLatency)
// logger.Debug("Reliable ping succeeded after %d attempts, avg latency: %v", attempt, avgLatency)
return avgLatency, nil
}
}
@@ -176,7 +162,7 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
if healthFile != "" {
err := os.WriteFile(healthFile, []byte("ok"), 0644)
if err != nil {
logger.Warn("Failed to write health file: %v", err)
logger.Warn(msgHealthFileWriteFailed, err)
}
}
return stopChan, nil
@@ -217,11 +203,13 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
if healthFile != "" {
err := os.WriteFile(healthFile, []byte("ok"), 0644)
if err != nil {
logger.Warn("Failed to write health file: %v", err)
logger.Warn(msgHealthFileWriteFailed, err)
}
}
return
}
case <-pingStopChan:
// Stop the goroutine when signaled
return
}
}
}()
@@ -230,7 +218,7 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
return stopChan, fmt.Errorf("initial ping attempts failed, continuing in background")
}
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} {
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, tunnelID string) chan struct{} {
maxInterval := 6 * time.Second
currentInterval := pingInterval
consecutiveFailures := 0
@@ -293,6 +281,9 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
if !connectionLost {
connectionLost = true
logger.Warn("Connection to server lost after %d failures. Continuous reconnection attempts will be made.", consecutiveFailures)
if tunnelID != "" {
telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout)
}
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
// Send registration message to the server for backward compatibility
err := client.SendMessage("newt/wg/register", map[string]interface{}{
@@ -319,6 +310,10 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
} else {
// Track recent latencies
recentLatencies = append(recentLatencies, latency)
// Record tunnel latency (limit sampling to this periodic check)
if tunnelID != "" {
telemetry.ObserveTunnelLatency(context.Background(), tunnelID, "wireguard", latency.Seconds())
}
if len(recentLatencies) > 10 {
recentLatencies = recentLatencies[1:]
}
@@ -353,89 +348,6 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
return pingStopChan
}
func parseLogLevel(level string) logger.LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":
return logger.DEBUG
case "INFO":
return logger.INFO
case "WARN":
return logger.WARN
case "ERROR":
return logger.ERROR
case "FATAL":
return logger.FATAL
default:
return logger.INFO // default to INFO if invalid level provided
}
}
func mapToWireGuardLogLevel(level logger.LogLevel) int {
switch level {
case logger.DEBUG:
return device.LogLevelVerbose
// case logger.INFO:
// return device.LogLevel
case logger.WARN:
return device.LogLevelError
case logger.ERROR, logger.FATAL:
return device.LogLevelSilent
default:
return device.LogLevelSilent
}
}
func resolveDomain(domain string) (string, error) {
// Check if there's a port in the domain
host, port, err := net.SplitHostPort(domain)
if err != nil {
// No port found, use the domain as is
host = domain
port = ""
}
// Remove any protocol prefix if present
if strings.HasPrefix(host, "http://") {
host = strings.TrimPrefix(host, "http://")
} else if strings.HasPrefix(host, "https://") {
host = strings.TrimPrefix(host, "https://")
}
// if there are any trailing slashes, remove them
host = strings.TrimSuffix(host, "/")
// Lookup IP addresses
ips, err := net.LookupIP(host)
if err != nil {
return "", fmt.Errorf("DNS lookup failed: %v", err)
}
if len(ips) == 0 {
return "", fmt.Errorf("no IP addresses found for domain %s", host)
}
// Get the first IPv4 address if available
var ipAddr string
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ipAddr = ipv4.String()
break
}
}
// If no IPv4 found, use the first IP (might be IPv6)
if ipAddr == "" {
ipAddr = ips[0].String()
}
// Add port back if it existed
if port != "" {
ipAddr = net.JoinHostPort(ipAddr, port)
}
return ipAddr, nil
}
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)
@@ -468,7 +380,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
continue
}
if action == "add" {
switch action {
case "add":
target := parts[1] + ":" + parts[2]
// Call updown script if provided
@@ -494,7 +407,7 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
// Add the new target
pm.AddTarget(proto, tunnelIP, port, processedTarget)
} else if action == "remove" {
case "remove":
logger.Info("Removing target with port %d", port)
target := parts[1] + ":" + parts[2]
@@ -512,6 +425,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
logger.Error("Failed to remove target: %v", err)
return err
}
default:
logger.Info("Unknown action: %s", action)
}
}

View File

@@ -0,0 +1,41 @@
services:
newt:
build: .
image: newt:dev
env_file:
- .env
environment:
- NEWT_METRICS_PROMETHEUS_ENABLED=false # important: disable direct /metrics scraping
- NEWT_METRICS_OTLP_ENABLED=true # OTLP to the Collector
# optional:
# - NEWT_METRICS_INCLUDE_TUNNEL_ID=false
# When using the Collector pattern, do NOT map the Newt admin/metrics port
# (2112) on the application service. Mapping 2112 here can cause port
# conflicts and may result in duplicated Prometheus scraping (app AND
# collector being scraped for the same metrics). Instead either:
# - leave ports unset on the app service (recommended), or
# - map 2112 only on a dedicated metrics/collector service that is
# responsible for exposing metrics to Prometheus.
# Example: do NOT map here
# ports: []
# Example: map 2112 only on a collector service
# collector:
# ports:
# - "2112:2112" # collector's prometheus exporter (scraped by Prometheus)
otel-collector:
image: otel/opentelemetry-collector-contrib:latest
command: ["--config=/etc/otelcol/config.yaml"]
volumes:
- ./examples/otel-collector.yaml:/etc/otelcol/config.yaml:ro
ports:
- "4317:4317" # OTLP gRPC
- "8889:8889" # Prometheus Exporter (scraped by Prometheus)
prometheus:
image: prom/prometheus:latest
volumes:
- ./examples/prometheus.with-collector.yml:/etc/prometheus/prometheus.yml:ro
ports:
- "9090:9090"

View File

@@ -0,0 +1,56 @@
name: Newt-Metrics
services:
# Recommended Variant A: Direct Prometheus scrape of Newt (/metrics)
# Optional: You may add the Collector service and enable OTLP export, but do NOT
# scrape both Newt and the Collector for the same process.
newt:
build: .
image: newt:dev
env_file:
- .env
environment:
OTEL_SERVICE_NAME: newt
NEWT_METRICS_PROMETHEUS_ENABLED: "true"
NEWT_METRICS_OTLP_ENABLED: "false" # avoid double-scrape by default
NEWT_ADMIN_ADDR: ":2112"
# Base NEWT configuration
PANGOLIN_ENDPOINT: ${PANGOLIN_ENDPOINT}
NEWT_ID: ${NEWT_ID}
NEWT_SECRET: ${NEWT_SECRET}
LOG_LEVEL: "DEBUG"
ports:
- "2112:2112"
# Optional Variant B: Enable the Collector and switch Prometheus scrape to it.
# collector:
# image: otel/opentelemetry-collector-contrib:0.136.0
# command: ["--config=/etc/otelcol/config.yaml"]
# volumes:
# - ./examples/otel-collector.yaml:/etc/otelcol/config.yaml:ro
# ports:
# - "4317:4317" # OTLP gRPC in
# - "8889:8889" # Prometheus scrape out
prometheus:
image: prom/prometheus:v3.6.0
volumes:
- ./examples/prometheus.yml:/etc/prometheus/prometheus.yml:ro
ports:
- "9090:9090"
grafana:
image: grafana/grafana:12.2.0
container_name: newt-metrics-grafana
restart: unless-stopped
environment:
- GF_SECURITY_ADMIN_USER=admin
- GF_SECURITY_ADMIN_PASSWORD=admin
ports:
- "3005:3000"
depends_on:
- prometheus
volumes:
- ./examples/grafana/provisioning/datasources:/etc/grafana/provisioning/datasources:ro
- ./examples/grafana/provisioning/dashboards:/etc/grafana/provisioning/dashboards:ro
- ./examples/grafana/dashboards:/var/lib/grafana/dashboards:ro

167
examples/README.md Normal file
View File

@@ -0,0 +1,167 @@
# Extensible Logger
This logger package provides a flexible logging system that can be extended with custom log writers.
## Basic Usage (Current Behavior)
The logger works exactly as before with no changes required:
```go
package main
import "your-project/logger"
func main() {
// Use default logger
logger.Info("This works as before")
logger.Debug("Debug message")
logger.Error("Error message")
// Or create a custom instance
log := logger.NewLogger()
log.SetLevel(logger.INFO)
log.Info("Custom logger instance")
}
```
## Custom Log Writers
To use a custom log backend, implement the `LogWriter` interface:
```go
type LogWriter interface {
Write(level LogLevel, timestamp time.Time, message string)
}
```
### Example: OS Log Writer (macOS/iOS)
```go
package main
import "your-project/logger"
func main() {
// Create an OS log writer
osWriter := logger.NewOSLogWriter(
"net.pangolin.Pangolin.PacketTunnel",
"PangolinGo",
"MyApp",
)
// Create a logger with the OS log writer
log := logger.NewLoggerWithWriter(osWriter)
log.SetLevel(logger.DEBUG)
// Use it just like the standard logger
log.Info("This message goes to os_log")
log.Error("Error logged to os_log")
}
```
### Example: Custom Writer
```go
package main
import (
"fmt"
"time"
"your-project/logger"
)
// CustomWriter writes logs to a custom destination
type CustomWriter struct {
// your custom fields
}
func (w *CustomWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
// Your custom logging logic
fmt.Printf("[CUSTOM] %s [%s] %s\n", timestamp.Format(time.RFC3339), level.String(), message)
}
func main() {
customWriter := &CustomWriter{}
log := logger.NewLoggerWithWriter(customWriter)
log.Info("Custom logging!")
}
```
### Example: Multi-Writer (Log to Multiple Destinations)
```go
package main
import (
"time"
"your-project/logger"
)
// MultiWriter writes to multiple log writers
type MultiWriter struct {
writers []logger.LogWriter
}
func NewMultiWriter(writers ...logger.LogWriter) *MultiWriter {
return &MultiWriter{writers: writers}
}
func (w *MultiWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
for _, writer := range w.writers {
writer.Write(level, timestamp, message)
}
}
func main() {
// Log to both standard output and OS log
standardWriter := logger.NewStandardWriter()
osWriter := logger.NewOSLogWriter("com.example.app", "Main", "App")
multiWriter := NewMultiWriter(standardWriter, osWriter)
log := logger.NewLoggerWithWriter(multiWriter)
log.Info("This goes to both stdout and os_log!")
}
```
## API Reference
### Creating Loggers
- `NewLogger()` - Creates a logger with the default StandardWriter
- `NewLoggerWithWriter(writer LogWriter)` - Creates a logger with a custom writer
### Built-in Writers
- `NewStandardWriter()` - Standard writer that outputs to stdout (default)
- `NewOSLogWriter(subsystem, category, prefix string)` - OS log writer for macOS/iOS (example)
### Logger Methods
- `SetLevel(level LogLevel)` - Set minimum log level
- `SetOutput(output *os.File)` - Set output file (StandardWriter only)
- `Debug(format string, args ...interface{})` - Log debug message
- `Info(format string, args ...interface{})` - Log info message
- `Warn(format string, args ...interface{})` - Log warning message
- `Error(format string, args ...interface{})` - Log error message
- `Fatal(format string, args ...interface{})` - Log fatal message and exit
### Global Functions
For convenience, you can use global functions that use the default logger:
- `logger.Debug(format, args...)`
- `logger.Info(format, args...)`
- `logger.Warn(format, args...)`
- `logger.Error(format, args...)`
- `logger.Fatal(format, args...)`
- `logger.SetOutput(output *os.File)`
## Migration Guide
No changes needed! The logger maintains 100% backward compatibility. Your existing code will continue to work without modifications.
If you want to switch to a custom writer:
1. Create your writer implementing `LogWriter`
2. Use `NewLoggerWithWriter()` instead of `NewLogger()`
3. That's it!

View File

@@ -0,0 +1,898 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "grafana",
"uid": "-- Grafana --"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 0,
"id": null,
"links": [],
"liveNow": false,
"panels": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"decimals": 0,
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 500
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 7,
"w": 6,
"x": 0,
"y": 0
},
"id": 1,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "horizontal",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"textMode": "value_and_name"
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "go_goroutine_count",
"instant": true,
"legendFormat": "",
"refId": "A"
}
],
"title": "Goroutines",
"transformations": [],
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"decimals": 1,
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "orange",
"value": 256
},
{
"color": "red",
"value": 512
}
]
},
"unit": "bytes"
},
"overrides": []
},
"gridPos": {
"h": 7,
"w": 6,
"x": 6,
"y": 0
},
"id": 2,
"options": {
"colorMode": "background",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "horizontal",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"textMode": "value_and_name"
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "go_memory_gc_goal_bytes / 1024 / 1024",
"format": "time_series",
"instant": true,
"legendFormat": "",
"refId": "A"
}
],
"title": "GC Target Heap (MiB)",
"transformations": [],
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"decimals": 2,
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "orange",
"value": 10
},
{
"color": "red",
"value": 25
}
]
},
"unit": "ops"
},
"overrides": []
},
"gridPos": {
"h": 7,
"w": 6,
"x": 12,
"y": 0
},
"id": 3,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "horizontal",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"textMode": "value_and_name"
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "sum(rate(http_server_request_duration_seconds_count[$__rate_interval]))",
"instant": false,
"legendFormat": "req/s",
"refId": "A"
}
],
"title": "HTTP Requests / s",
"transformations": [],
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"decimals": 3,
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "orange",
"value": 0.1
},
{
"color": "red",
"value": 0.5
}
]
},
"unit": "ops"
},
"overrides": []
},
"gridPos": {
"h": 7,
"w": 6,
"x": 18,
"y": 0
},
"id": 4,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "horizontal",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"textMode": "value_and_name"
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "sum(rate(newt_connection_errors_total{site_id=~\"$site_id\"}[$__rate_interval]))",
"instant": false,
"legendFormat": "errors/s",
"refId": "A"
}
],
"title": "Connection Errors / s",
"transformations": [],
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"custom": {},
"mappings": [],
"unit": "bytes"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 7
},
"id": 5,
"options": {
"legend": {
"calcs": [],
"displayMode": "table",
"placement": "right"
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "sum(go_memory_used_bytes)",
"legendFormat": "Used",
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "go_memory_gc_goal_bytes",
"legendFormat": "GC Goal",
"refId": "B"
}
],
"title": "Go Heap Usage vs GC Goal",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"custom": {},
"decimals": 0,
"mappings": [],
"unit": "short"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 12,
"y": 7
},
"id": 6,
"options": {
"legend": {
"calcs": [],
"displayMode": "table",
"placement": "right"
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "rate(go_memory_allocations_total[$__rate_interval])",
"legendFormat": "Allocations/s",
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "rate(go_memory_allocated_bytes_total[$__rate_interval])",
"legendFormat": "Allocated bytes/s",
"refId": "B"
}
],
"title": "Allocation Activity",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"custom": {},
"mappings": [],
"unit": "s"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 16
},
"id": 7,
"options": {
"legend": {
"calcs": [],
"displayMode": "table",
"placement": "right"
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum(rate(http_server_request_duration_seconds_bucket[$__rate_interval])) by (le))",
"legendFormat": "p50",
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "histogram_quantile(0.95, sum(rate(http_server_request_duration_seconds_bucket[$__rate_interval])) by (le))",
"legendFormat": "p95",
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum(rate(http_server_request_duration_seconds_bucket[$__rate_interval])) by (le))",
"legendFormat": "p99",
"refId": "C"
}
],
"title": "HTTP Request Duration Quantiles",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"custom": {},
"mappings": [],
"unit": "ops"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 12,
"y": 16
},
"id": 8,
"options": {
"legend": {
"calcs": [],
"displayMode": "table",
"placement": "right"
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "sum(rate(http_server_request_duration_seconds_count[$__rate_interval])) by (http_response_status_code)",
"legendFormat": "{{http_response_status_code}}",
"refId": "A"
}
],
"title": "HTTP Requests by Status",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"custom": {},
"mappings": [],
"unit": "ops"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 25
},
"id": 9,
"options": {
"legend": {
"calcs": [],
"displayMode": "table",
"placement": "right"
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "sum(rate(newt_connection_attempts_total{site_id=~\"$site_id\"}[$__rate_interval])) by (transport, result)",
"legendFormat": "{{transport}} • {{result}}",
"refId": "A"
}
],
"title": "Connection Attempts by Transport/Result",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"custom": {},
"mappings": [],
"unit": "ops"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 12,
"y": 25
},
"id": 10,
"options": {
"legend": {
"calcs": [],
"displayMode": "table",
"placement": "right"
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "sum(rate(newt_connection_errors_total{site_id=~\"$site_id\"}[$__rate_interval])) by (transport, error_type)",
"legendFormat": "{{transport}} • {{error_type}}",
"refId": "A"
}
],
"title": "Connection Errors by Type",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"custom": {},
"decimals": 3,
"mappings": [],
"unit": "s"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 34
},
"id": 11,
"options": {
"legend": {
"calcs": [],
"displayMode": "table",
"placement": "right"
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum(rate(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\", tunnel_id=~\"$tunnel_id\"}[$__rate_interval])) by (le))",
"legendFormat": "p50",
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "histogram_quantile(0.95, sum(rate(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\", tunnel_id=~\"$tunnel_id\"}[$__rate_interval])) by (le))",
"legendFormat": "p95",
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum(rate(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\", tunnel_id=~\"$tunnel_id\"}[$__rate_interval])) by (le))",
"legendFormat": "p99",
"refId": "C"
}
],
"title": "Tunnel Latency Quantiles",
"type": "timeseries"
},
{
"cards": {},
"color": {
"cardColor": "#b4ff00",
"colorScale": "sqrt",
"colorScheme": "interpolateTurbo"
},
"dataFormat": "tsbuckets",
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"custom": {},
"mappings": [],
"unit": "s"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 12,
"y": 34
},
"heatmap": {},
"hideZeroBuckets": true,
"id": 12,
"legend": {
"show": false
},
"options": {
"calculate": true,
"cellGap": 2,
"cellSize": "auto",
"color": {
"exponent": 0.5
},
"exemplars": {
"color": "rgba(255,255,255,0.7)"
},
"filterValues": {
"le": 1e-9
},
"legend": {
"show": false
},
"tooltip": {
"mode": "single",
"show": true
},
"xAxis": {
"show": true
},
"yAxis": {
"decimals": 3,
"show": true
}
},
"pluginVersion": "11.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "sum(rate(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\", tunnel_id=~\"$tunnel_id\"}[$__rate_interval])) by (le)",
"format": "heatmap",
"legendFormat": "{{le}}",
"refId": "A"
}
],
"title": "Tunnel Latency Bucket Rate",
"type": "heatmap"
}
],
"refresh": "30s",
"schemaVersion": 39,
"style": "dark",
"tags": [
"newt",
"otel"
],
"templating": {
"list": [
{
"current": {
"selected": false,
"text": "Prometheus",
"value": "prometheus"
},
"hide": 0,
"label": "Datasource",
"name": "DS_PROMETHEUS",
"options": [],
"query": "prometheus",
"refresh": 1,
"type": "datasource"
},
{
"current": {
"selected": false,
"text": "All",
"value": "$__all"
},
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"definition": "label_values(target_info, site_id)",
"hide": 0,
"includeAll": true,
"label": "Site",
"multi": true,
"name": "site_id",
"options": [],
"query": {
"query": "label_values(target_info, site_id)",
"refId": "SiteIdVar"
},
"refresh": 2,
"regex": "",
"skipUrlSync": false,
"sort": 1,
"tagValuesQuery": "",
"tags": [],
"tagsQuery": "",
"type": "query",
"useTags": false
},
{
"current": {
"selected": false,
"text": "All",
"value": "$__all"
},
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"definition": "label_values(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\"}, tunnel_id)",
"hide": 0,
"includeAll": true,
"label": "Tunnel",
"multi": true,
"name": "tunnel_id",
"options": [],
"query": {
"query": "label_values(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\"}, tunnel_id)",
"refId": "TunnelVar"
},
"refresh": 2,
"regex": "",
"skipUrlSync": false,
"sort": 1,
"tagValuesQuery": "",
"tags": [],
"tagsQuery": "",
"type": "query",
"useTags": false
}
]
},
"time": {
"from": "now-6h",
"to": "now"
},
"timepicker": {
"refresh_intervals": [
"10s",
"30s",
"1m",
"5m",
"15m",
"30m",
"1h",
"2h",
"1d"
],
"time_options": [
"5m",
"15m",
"1h",
"6h",
"12h",
"24h",
"2d",
"7d",
"30d"
]
},
"timezone": "browser",
"title": "Newt Overview",
"uid": "newt-overview",
"version": 1,
"weekStart": ""
}

View File

@@ -0,0 +1,9 @@
apiVersion: 1
providers:
- name: "newt"
folder: "Newt"
type: file
disableDeletion: false
allowUiUpdates: true
options:
path: /var/lib/grafana/dashboards

View File

@@ -0,0 +1,9 @@
apiVersion: 1
datasources:
- name: Prometheus
type: prometheus
access: proxy
url: http://prometheus:9090
uid: prometheus
isDefault: true
editable: true

161
examples/logger_examples.go Normal file
View File

@@ -0,0 +1,161 @@
// Example usage patterns for the extensible logger
package main
import (
"fmt"
"os"
"time"
"github.com/fosrl/newt/logger"
)
// Example 1: Using the default logger (works exactly as before)
func exampleDefaultLogger() {
logger.Info("Starting application")
logger.Debug("Debug information")
logger.Warn("Warning message")
logger.Error("Error occurred")
}
// Example 2: Using a custom logger instance with standard writer
func exampleCustomInstance() {
log := logger.NewLogger()
log.SetLevel(logger.INFO)
log.Info("This is from a custom instance")
}
// Example 3: Custom writer that adds JSON formatting
type JSONWriter struct{}
func (w *JSONWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
fmt.Printf("{\"time\":\"%s\",\"level\":\"%s\",\"message\":\"%s\"}\n",
timestamp.Format(time.RFC3339),
level.String(),
message)
}
func exampleJSONLogger() {
jsonWriter := &JSONWriter{}
log := logger.NewLoggerWithWriter(jsonWriter)
log.Info("This will be logged as JSON")
}
// Example 4: File writer
type FileWriter struct {
file *os.File
}
func NewFileWriter(filename string) (*FileWriter, error) {
file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return nil, err
}
return &FileWriter{file: file}, nil
}
func (w *FileWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
fmt.Fprintf(w.file, "[%s] %s: %s\n",
timestamp.Format("2006-01-02 15:04:05"),
level.String(),
message)
}
func (w *FileWriter) Close() error {
return w.file.Close()
}
func exampleFileLogger() {
fileWriter, err := NewFileWriter("/tmp/app.log")
if err != nil {
panic(err)
}
defer fileWriter.Close()
log := logger.NewLoggerWithWriter(fileWriter)
log.Info("This goes to a file")
}
// Example 5: Multi-writer to log to multiple destinations
type MultiWriter struct {
writers []logger.LogWriter
}
func NewMultiWriter(writers ...logger.LogWriter) *MultiWriter {
return &MultiWriter{writers: writers}
}
func (w *MultiWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
for _, writer := range w.writers {
writer.Write(level, timestamp, message)
}
}
func exampleMultiWriter() {
// Log to both stdout and a file
standardWriter := logger.NewStandardWriter()
fileWriter, _ := NewFileWriter("/tmp/app.log")
multiWriter := NewMultiWriter(standardWriter, fileWriter)
log := logger.NewLoggerWithWriter(multiWriter)
log.Info("This goes to both stdout and file!")
}
// Example 6: Conditional writer (only log errors to a specific destination)
type ErrorOnlyWriter struct {
writer logger.LogWriter
}
func NewErrorOnlyWriter(writer logger.LogWriter) *ErrorOnlyWriter {
return &ErrorOnlyWriter{writer: writer}
}
func (w *ErrorOnlyWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
if level >= logger.ERROR {
w.writer.Write(level, timestamp, message)
}
}
func exampleConditionalWriter() {
errorWriter, _ := NewFileWriter("/tmp/errors.log")
errorOnlyWriter := NewErrorOnlyWriter(errorWriter)
log := logger.NewLoggerWithWriter(errorOnlyWriter)
log.Info("This won't be logged")
log.Error("This will be logged to errors.log")
}
/* Example 7: OS Log Writer (macOS/iOS only)
// Uncomment on Darwin platforms
func exampleOSLogWriter() {
osWriter := logger.NewOSLogWriter(
"net.pangolin.Pangolin.PacketTunnel",
"PangolinGo",
"MyApp",
)
log := logger.NewLoggerWithWriter(osWriter)
log.Info("This goes to os_log and can be viewed with Console.app")
}
*/
func main() {
fmt.Println("=== Example 1: Default Logger ===")
exampleDefaultLogger()
fmt.Println("\n=== Example 2: Custom Instance ===")
exampleCustomInstance()
fmt.Println("\n=== Example 3: JSON Logger ===")
exampleJSONLogger()
fmt.Println("\n=== Example 4: File Logger ===")
exampleFileLogger()
fmt.Println("\n=== Example 5: Multi-Writer ===")
exampleMultiWriter()
fmt.Println("\n=== Example 6: Conditional Writer ===")
exampleConditionalWriter()
}

View File

@@ -0,0 +1,86 @@
//go:build darwin
// +build darwin
package main
/*
#cgo CFLAGS: -I../PacketTunnel
#include "../PacketTunnel/OSLogBridge.h"
#include <stdlib.h>
*/
import "C"
import (
"fmt"
"runtime"
"time"
"unsafe"
)
// OSLogWriter is a LogWriter implementation that writes to Apple's os_log
type OSLogWriter struct {
subsystem string
category string
prefix string
}
// NewOSLogWriter creates a new OSLogWriter
func NewOSLogWriter(subsystem, category, prefix string) *OSLogWriter {
writer := &OSLogWriter{
subsystem: subsystem,
category: category,
prefix: prefix,
}
// Initialize the OS log bridge
cSubsystem := C.CString(subsystem)
cCategory := C.CString(category)
defer C.free(unsafe.Pointer(cSubsystem))
defer C.free(unsafe.Pointer(cCategory))
C.initOSLogBridge(cSubsystem, cCategory)
return writer
}
// Write implements the LogWriter interface
func (w *OSLogWriter) Write(level LogLevel, timestamp time.Time, message string) {
// Get caller information (skip 3 frames to get to the actual caller)
_, file, line, ok := runtime.Caller(3)
if !ok {
file = "unknown"
line = 0
} else {
// Get just the filename, not the full path
for i := len(file) - 1; i > 0; i-- {
if file[i] == '/' {
file = file[i+1:]
break
}
}
}
formattedTime := timestamp.Format("2006-01-02 15:04:05.000")
fullMessage := fmt.Sprintf("[%s] [%s] [%s] %s:%d - %s",
formattedTime, level.String(), w.prefix, file, line, message)
cMessage := C.CString(fullMessage)
defer C.free(unsafe.Pointer(cMessage))
// Map Go log levels to os_log levels:
// 0=DEBUG, 1=INFO, 2=DEFAULT (WARN), 3=ERROR
var osLogLevel C.int
switch level {
case DEBUG:
osLogLevel = 0 // DEBUG
case INFO:
osLogLevel = 1 // INFO
case WARN:
osLogLevel = 2 // DEFAULT
case ERROR, FATAL:
osLogLevel = 3 // ERROR
default:
osLogLevel = 2 // DEFAULT
}
C.logToOSLog(osLogLevel, cMessage)
}

View File

@@ -0,0 +1,61 @@
# Variant A: Direct scrape of Newt (/metrics) via Prometheus (no Collector needed)
# Note: Newt already exposes labels like site_id, protocol, direction. Do not promote
# resource attributes into labels when scraping Newt directly.
#
# Example Prometheus scrape config:
# global:
# scrape_interval: 15s
# scrape_configs:
# - job_name: newt
# static_configs:
# - targets: ["newt:2112"]
#
# Variant B: Use OTEL Collector (Newt -> OTLP -> Collector -> Prometheus)
# This pipeline scrapes metrics from the Collector's Prometheus exporter.
# Labels are already on datapoints; promotion from resource is OPTIONAL and typically NOT required.
# If you enable transform/promote below, ensure you do not duplicate labels.
receivers:
otlp:
protocols:
grpc:
endpoint: ":4317"
processors:
memory_limiter:
check_interval: 5s
limit_percentage: 80
spike_limit_percentage: 25
resourcedetection:
detectors: [env, system]
timeout: 5s
batch: {}
# OPTIONAL: Only enable if you need to promote resource attributes to labels.
# WARNING: Newt already provides site_id as a label; avoid double-promotion.
# transform/promote:
# error_mode: ignore
# metric_statements:
# - context: datapoint
# statements:
# - set(attributes["service_instance_id"], resource.attributes["service.instance.id"]) where resource.attributes["service.instance.id"] != nil
# - set(attributes["site_id"], resource.attributes["site_id"]) where resource.attributes["site_id"] != nil
exporters:
prometheus:
endpoint: ":8889"
send_timestamps: true
# prometheusremotewrite:
# endpoint: http://mimir:9009/api/v1/push
debug:
verbosity: basic
service:
pipelines:
metrics:
receivers: [otlp]
processors: [memory_limiter, resourcedetection, batch] # add transform/promote if you really need it
exporters: [prometheus]
traces:
receivers: [otlp]
processors: [memory_limiter, resourcedetection, batch]
exporters: [debug]

View File

@@ -0,0 +1,16 @@
global:
scrape_interval: 15s
scrape_configs:
# IMPORTANT: Do not scrape Newt directly; scrape only the Collector!
- job_name: 'otel-collector'
static_configs:
- targets: ['otel-collector:8889']
# optional: limit metric cardinality
relabel_configs:
- action: labeldrop
regex: 'tunnel_id'
# - action: keep
# source_labels: [site_id]
# regex: '(site-a|site-b)'

21
examples/prometheus.yml Normal file
View File

@@ -0,0 +1,21 @@
global:
scrape_interval: 15s
scrape_configs:
- job_name: 'newt'
scrape_interval: 15s
static_configs:
- targets: ['newt:2112'] # /metrics
relabel_configs:
# optional: drop tunnel_id
- action: labeldrop
regex: 'tunnel_id'
# optional: allow only specific sites
- action: keep
source_labels: [site_id]
regex: '(site-a|site-b)'
# WARNING: Do not enable this together with the 'newt' job above or you will double-count.
# - job_name: 'otel-collector'
# static_configs:
# - targets: ['otel-collector:8889']

60
go.mod
View File

@@ -3,32 +3,49 @@ module github.com/fosrl/newt
go 1.25
require (
github.com/docker/docker v28.5.0+incompatible
github.com/docker/docker v28.5.1+incompatible
github.com/google/gopacket v1.1.19
github.com/gorilla/websocket v1.5.3
github.com/prometheus/client_golang v1.23.2
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.42.0
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
golang.org/x/net v0.45.0
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0
go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0
go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
go.opentelemetry.io/otel/exporters/prometheus v0.60.0
go.opentelemetry.io/otel/metric v1.38.0
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/sdk/metric v1.38.0
golang.org/x/crypto v0.44.0
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6
golang.org/x/net v0.47.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
google.golang.org/grpc v1.76.0
gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
software.sslmate.com/src/go-pkcs12 v0.6.0
)
require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/Microsoft/go-winio v0.6.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v0.3.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.4.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
@@ -37,18 +54,29 @@ require (
github.com/moby/sys/atomicwriter v0.1.0 // indirect
github.com/moby/term v0.5.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/otlptranslator v0.0.2 // indirect
github.com/prometheus/procfs v0.17.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.36.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.36.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.30.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/time v0.12.0 // indirect
golang.org/x/tools v0.39.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/protobuf v1.36.8 // indirect
)

156
go.sum
View File

@@ -1,12 +1,15 @@
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8=
github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2yDvg=
github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/containerd/errdefs v0.3.0 h1:FSZgGOeK4yuT/+DnF07/Olde/q4KBoMsaamhXxIMDp4=
github.com/containerd/errdefs v0.3.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
@@ -15,12 +18,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.5.0+incompatible h1:ZdSQoRUE9XxhFI/B8YLvhnEFMmYN9Pp8Egd2qcaFk1E=
github.com/docker/docker v28.5.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
@@ -28,6 +31,8 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -38,14 +43,20 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5ukBEgSGXEN89zeH1Jo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI=
github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc h1:GN2Lv3MGO7AS6PrRoT6yV5+wkrOpcszoIsO4+4ds248=
github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
@@ -64,71 +75,99 @@ github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/otlptranslator v0.0.2 h1:+1CdeLVrRQ6Psmhnobldo0kTp96Rj80DRXRd5OSnMEQ=
github.com/prometheus/otlptranslator v0.0.2/go.mod h1:P8AwMgdD7XEr6QRUJ2QWLpiAZTgTE2UYgjlu3svompI=
github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0=
github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 h1:Hf9xI/XLML9ElpiHVDNwvqI0hIFlzV8dgIr35kV1kRU=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0/go.mod h1:NfchwuyNoMcZ5MLHwPrODwUF1HWCXWrL31s8gSAdIKY=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0 h1:dNzwXjZKpMpE2JhmO+9HsPl42NIXFIFSUSSs0fiqra0=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0/go.mod h1:90PoxvaEB5n6AOdZvi+yWJQoE95U8Dhhw2bSyRqnTD0=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.36.0 h1:nRVXXvf78e00EwY6Wp0YII8ww2JVWshZ20HfTlE11AM=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.36.0/go.mod h1:r49hO7CgrxY9Voaj3Xe8pANWtr0Oq916d0XAmOoCZAQ=
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.opentelemetry.io/proto/otlp v1.6.0 h1:jQjP+AQyTf+Fe7OKj/MfkDrmK4MNVtw2NpXsf9fefDI=
go.opentelemetry.io/proto/otlp v1.6.0/go.mod h1:cicgGehlFuNdgZkcALOCh3VE6K/u2tAjzlRhDwmVpZc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg=
go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0 h1:PeBoRj6af6xMI7qCupwFvTbbnd49V7n5YpG6pg8iDYQ=
go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0/go.mod h1:ingqBCtMCe8I4vpz/UVzCW6sxoqgZB37nao91mLQ3Bw=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 h1:vl9obrcoWVKp/lwl8tRE33853I8Xru9HFbw/skNeLs8=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0/go.mod h1:GAXRxmLJcVM3u22IjTg74zWBrRCKq8BnOqUVLodpcpw=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 h1:aTL7F04bJHUlztTsNGJ2l+6he8c+y/b//eR0jjjemT4=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0/go.mod h1:kldtb7jDTeol0l3ewcmd8SDvx3EmIE7lyvqbasU3QC4=
go.opentelemetry.io/otel/exporters/prometheus v0.60.0 h1:cGtQxGvZbnrWdC2GyjZi0PDKVSLWP/Jocix3QWfXtbo=
go.opentelemetry.io/otel/exporters/prometheus v0.60.0/go.mod h1:hkd1EekxNo69PTV4OWFGZcKQiIqg0RfuWExcPKFvepk=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4=
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
@@ -136,15 +175,16 @@ golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+Z
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
google.golang.org/genproto v0.0.0-20230920204549-e6e6cdab5c13 h1:vlzZttNJGVqTsRFU9AmdnrcO1Znh8Ew9kCD//yjigk0=
google.golang.org/genproto/googleapis/api v0.0.0-20250519155744-55703ea1f237 h1:Kog3KlB4xevJlAcbbbzPfRG0+X9fdoGM+UBRKVz6Wr0=
google.golang.org/genproto/googleapis/api v0.0.0-20250519155744-55703ea1f237/go.mod h1:ezi0AVyMKDWy5xAncvjLWH7UcLBB5n7y2fQ8MzjJcto=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 h1:cJfm9zPbe1e873mHJzmQ1nwVEeRDU/T1wXDK2kUSU34=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.72.1 h1:HR03wO6eyZ7lknl75XlxABNVLLFc2PAb6mHlYh756mA=
google.golang.org/grpc v1.72.1/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A=
google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

521
holepunch/holepunch.go Normal file
View File

@@ -0,0 +1,521 @@
package holepunch
import (
"encoding/json"
"fmt"
"net"
"sync"
"time"
"github.com/fosrl/newt/bind"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
mrand "golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// ExitNode represents a WireGuard exit node for hole punching
type ExitNode struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
// Manager handles UDP hole punching operations
type Manager struct {
mu sync.Mutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
ID string
token string
publicKey string
clientType string
exitNodes map[string]ExitNode // key is endpoint
updateChan chan struct{} // signals the goroutine to refresh exit nodes
sendHolepunchInterval time.Duration
}
const sendHolepunchIntervalMax = 60 * time.Second
const sendHolepunchIntervalMin = 1 * time.Second
// NewManager creates a new hole punch manager
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager {
return &Manager{
sharedBind: sharedBind,
ID: ID,
clientType: clientType,
publicKey: publicKey,
exitNodes: make(map[string]ExitNode),
sendHolepunchInterval: sendHolepunchIntervalMin,
}
}
// SetToken updates the authentication token used for hole punching
func (m *Manager) SetToken(token string) {
m.mu.Lock()
defer m.mu.Unlock()
m.token = token
}
// IsRunning returns whether hole punching is currently active
func (m *Manager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.running
}
// Stop stops any ongoing hole punch operations
func (m *Manager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return
}
if m.stopChan != nil {
close(m.stopChan)
m.stopChan = nil
}
if m.updateChan != nil {
close(m.updateChan)
m.updateChan = nil
}
m.running = false
logger.Info("Hole punch manager stopped")
}
// AddExitNode adds a new exit node to the rotation if it doesn't already exist
func (m *Manager) AddExitNode(exitNode ExitNode) bool {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.exitNodes[exitNode.Endpoint]; exists {
logger.Debug("Exit node %s already exists in rotation", exitNode.Endpoint)
return false
}
m.exitNodes[exitNode.Endpoint] = exitNode
logger.Info("Added exit node %s to hole punch rotation", exitNode.Endpoint)
// Signal the goroutine to refresh if running
if m.running && m.updateChan != nil {
select {
case m.updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
return true
}
// RemoveExitNode removes an exit node from the rotation
func (m *Manager) RemoveExitNode(endpoint string) bool {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.exitNodes[endpoint]; !exists {
logger.Debug("Exit node %s not found in rotation", endpoint)
return false
}
delete(m.exitNodes, endpoint)
logger.Info("Removed exit node %s from hole punch rotation", endpoint)
// Signal the goroutine to refresh if running
if m.running && m.updateChan != nil {
select {
case m.updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
return true
}
// GetExitNodes returns a copy of the current exit nodes
func (m *Manager) GetExitNodes() []ExitNode {
m.mu.Lock()
defer m.mu.Unlock()
nodes := make([]ExitNode, 0, len(m.exitNodes))
for _, node := range m.exitNodes {
nodes = append(nodes, node)
}
return nodes
}
// ResetInterval resets the hole punch interval back to the minimum value,
// allowing it to climb back up through exponential backoff.
// This is useful when network conditions change or connectivity is restored.
func (m *Manager) ResetInterval() {
m.mu.Lock()
defer m.mu.Unlock()
if m.sendHolepunchInterval != sendHolepunchIntervalMin {
m.sendHolepunchInterval = sendHolepunchIntervalMin
logger.Info("Reset hole punch interval to minimum (%v)", sendHolepunchIntervalMin)
}
// Signal the goroutine to apply the new interval if running
if m.running && m.updateChan != nil {
select {
case m.updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// TriggerHolePunch sends an immediate hole punch packet to all configured exit nodes
// This is useful for triggering hole punching on demand without waiting for the interval
func (m *Manager) TriggerHolePunch() error {
m.mu.Lock()
if len(m.exitNodes) == 0 {
m.mu.Unlock()
return fmt.Errorf("no exit nodes configured")
}
// Get a copy of exit nodes to work with
currentExitNodes := make([]ExitNode, 0, len(m.exitNodes))
for _, node := range m.exitNodes {
currentExitNodes = append(currentExitNodes, node)
}
m.mu.Unlock()
logger.Info("Triggering on-demand hole punch to %d exit nodes", len(currentExitNodes))
// Send hole punch to all exit nodes
successCount := 0
for _, exitNode := range currentExitNodes {
host, err := util.ResolveDomain(exitNode.Endpoint)
if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
continue
}
if err := m.sendHolePunch(remoteAddr, exitNode.PublicKey); err != nil {
logger.Warn("Failed to send on-demand hole punch to %s: %v", exitNode.Endpoint, err)
continue
}
logger.Debug("Sent on-demand hole punch to %s", exitNode.Endpoint)
successCount++
}
if successCount == 0 {
return fmt.Errorf("failed to send hole punch to any exit node")
}
logger.Info("Successfully sent on-demand hole punch to %d/%d exit nodes", successCount, len(currentExitNodes))
return nil
}
// StartMultipleExitNodes starts hole punching to multiple exit nodes
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
if len(exitNodes) == 0 {
m.mu.Unlock()
logger.Warn("No exit nodes provided for hole punching")
return fmt.Errorf("no exit nodes provided")
}
// Populate exit nodes map
m.exitNodes = make(map[string]ExitNode)
for _, node := range exitNodes {
m.exitNodes[node.Endpoint] = node
}
m.running = true
m.stopChan = make(chan struct{})
m.updateChan = make(chan struct{}, 1)
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
go m.runMultipleExitNodes()
return nil
}
// Start starts hole punching with the current set of exit nodes
func (m *Manager) Start() error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running")
return fmt.Errorf("hole punch already running")
}
if len(m.exitNodes) == 0 {
m.mu.Unlock()
logger.Warn("No exit nodes configured for hole punching")
return fmt.Errorf("no exit nodes configured")
}
m.running = true
m.stopChan = make(chan struct{})
m.updateChan = make(chan struct{}, 1)
m.mu.Unlock()
logger.Info("Starting UDP hole punch with %d exit nodes", len(m.exitNodes))
go m.runMultipleExitNodes()
return nil
}
// runMultipleExitNodes performs hole punching to multiple exit nodes
func (m *Manager) runMultipleExitNodes() {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for all exit nodes")
}()
// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
}
resolveNodes := func() []resolvedExitNode {
m.mu.Lock()
currentExitNodes := make([]ExitNode, 0, len(m.exitNodes))
for _, node := range m.exitNodes {
currentExitNodes = append(currentExitNodes, node)
}
m.mu.Unlock()
var resolvedNodes []resolvedExitNode
for _, exitNode := range currentExitNodes {
host, err := util.ResolveDomain(exitNode.Endpoint)
if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
continue
}
resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
}
return resolvedNodes
}
resolvedNodes := resolveNodes()
if len(resolvedNodes) == 0 {
logger.Error("No exit nodes could be resolved")
return
}
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
}
}
// Start with minimum interval
m.mu.Lock()
m.sendHolepunchInterval = sendHolepunchIntervalMin
m.mu.Unlock()
ticker := time.NewTicker(m.sendHolepunchInterval)
defer ticker.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-m.updateChan:
// Re-resolve exit nodes when update is signaled
logger.Info("Refreshing exit nodes for hole punching")
resolvedNodes = resolveNodes()
if len(resolvedNodes) == 0 {
logger.Warn("No exit nodes available after refresh")
}
// Reset interval to minimum on update
m.mu.Lock()
m.sendHolepunchInterval = sendHolepunchIntervalMin
m.mu.Unlock()
ticker.Reset(m.sendHolepunchInterval)
// Send immediate hole punch to newly resolved nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
}
}
case <-ticker.C:
// Send hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
}
}
// Exponential backoff: double the interval up to max
m.mu.Lock()
newInterval := m.sendHolepunchInterval * 2
if newInterval > sendHolepunchIntervalMax {
newInterval = sendHolepunchIntervalMax
}
if newInterval != m.sendHolepunchInterval {
m.sendHolepunchInterval = newInterval
ticker.Reset(m.sendHolepunchInterval)
logger.Debug("Increased hole punch interval to %v", m.sendHolepunchInterval)
}
m.mu.Unlock()
}
}
}
// sendHolePunch sends an encrypted hole punch packet using the shared bind
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
m.mu.Lock()
token := m.token
ID := m.ID
m.mu.Unlock()
if serverPubKey == "" || token == "" {
return fmt.Errorf("server public key or OLM token is empty")
}
var payload interface{}
if m.clientType == "newt" {
payload = struct {
ID string `json:"newtId"`
Token string `json:"token"`
PublicKey string `json:"publicKey"`
}{
ID: ID,
Token: token,
PublicKey: m.publicKey,
}
} else {
payload = struct {
ID string `json:"olmId"`
Token string `json:"token"`
PublicKey string `json:"publicKey"`
}{
ID: ID,
Token: token,
PublicKey: m.publicKey,
}
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %w", err)
}
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
}
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to write to UDP: %w", err)
}
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
return nil
}
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := mrand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}

343
holepunch/tester.go Normal file
View File

@@ -0,0 +1,343 @@
package holepunch
import (
"crypto/rand"
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/fosrl/newt/bind"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
)
// TestResult represents the result of a connection test
type TestResult struct {
// Success indicates whether the test was successful
Success bool
// RTT is the round-trip time of the test packet
RTT time.Duration
// Endpoint is the endpoint that was tested
Endpoint string
// Error contains any error that occurred during the test
Error error
}
// TestConnectionOptions configures the connection test
type TestConnectionOptions struct {
// Timeout is how long to wait for a response (default: 5 seconds)
Timeout time.Duration
// Retries is the number of times to retry on failure (default: 0)
Retries int
}
// DefaultTestOptions returns the default test options
func DefaultTestOptions() TestConnectionOptions {
return TestConnectionOptions{
Timeout: 5 * time.Second,
Retries: 0,
}
}
// HolepunchTester monitors holepunch connectivity using magic packets
type HolepunchTester struct {
sharedBind *bind.SharedBind
mu sync.RWMutex
running bool
stopChan chan struct{}
// Pending requests waiting for responses (key: echo data as string)
pendingRequests sync.Map // map[string]*pendingRequest
// Callback when connection status changes
callback HolepunchStatusCallback
}
// HolepunchStatus represents the status of a holepunch connection
type HolepunchStatus struct {
Endpoint string
Connected bool
RTT time.Duration
}
// HolepunchStatusCallback is called when holepunch status changes
type HolepunchStatusCallback func(status HolepunchStatus)
// pendingRequest tracks a pending test request
type pendingRequest struct {
endpoint string
sentAt time.Time
replyChan chan time.Duration
}
// NewHolepunchTester creates a new holepunch tester using the given SharedBind
func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester {
return &HolepunchTester{
sharedBind: sharedBind,
}
}
// SetCallback sets the callback for connection status changes
func (t *HolepunchTester) SetCallback(callback HolepunchStatusCallback) {
t.mu.Lock()
defer t.mu.Unlock()
t.callback = callback
}
// Start begins listening for magic packet responses
func (t *HolepunchTester) Start() error {
t.mu.Lock()
defer t.mu.Unlock()
if t.running {
return fmt.Errorf("tester already running")
}
if t.sharedBind == nil {
return fmt.Errorf("sharedBind is nil")
}
t.running = true
t.stopChan = make(chan struct{})
// Register our callback with the SharedBind to receive magic responses
t.sharedBind.SetMagicResponseCallback(t.handleResponse)
logger.Debug("HolepunchTester started")
return nil
}
// Stop stops the tester
func (t *HolepunchTester) Stop() {
t.mu.Lock()
defer t.mu.Unlock()
if !t.running {
return
}
t.running = false
close(t.stopChan)
// Clear the callback
if t.sharedBind != nil {
t.sharedBind.SetMagicResponseCallback(nil)
}
// Cancel all pending requests
t.pendingRequests.Range(func(key, value interface{}) bool {
if req, ok := value.(*pendingRequest); ok {
close(req.replyChan)
}
t.pendingRequests.Delete(key)
return true
})
logger.Debug("HolepunchTester stopped")
}
// handleResponse is called by SharedBind when a magic response is received
func (t *HolepunchTester) handleResponse(addr netip.AddrPort, echoData []byte) {
logger.Debug("Received magic response from %s", addr.String())
key := string(echoData)
value, ok := t.pendingRequests.LoadAndDelete(key)
if !ok {
// No matching request found
logger.Debug("No pending request found for magic response from %s", addr.String())
return
}
req := value.(*pendingRequest)
rtt := time.Since(req.sentAt)
logger.Debug("Magic response matched pending request for %s (RTT: %v)", req.endpoint, rtt)
// Send RTT to the waiting goroutine (non-blocking)
select {
case req.replyChan <- rtt:
default:
}
}
// TestEndpoint sends a magic test packet to the endpoint and waits for a response.
// This uses the SharedBind so packets come from the same source port as WireGuard.
func (t *HolepunchTester) TestEndpoint(endpoint string, timeout time.Duration) TestResult {
result := TestResult{
Endpoint: endpoint,
}
t.mu.RLock()
running := t.running
sharedBind := t.sharedBind
t.mu.RUnlock()
if !running {
result.Error = fmt.Errorf("tester not running")
return result
}
if sharedBind == nil || sharedBind.IsClosed() {
result.Error = fmt.Errorf("sharedBind is nil or closed")
return result
}
// Resolve the endpoint
host, err := util.ResolveDomain(endpoint)
if err != nil {
host = endpoint
}
_, _, err = net.SplitHostPort(host)
if err != nil {
host = net.JoinHostPort(host, "21820")
}
remoteAddr, err := net.ResolveUDPAddr("udp", host)
if err != nil {
result.Error = fmt.Errorf("failed to resolve UDP address %s: %w", host, err)
return result
}
// Generate random data for the test packet
randomData := make([]byte, bind.MagicPacketDataLen)
if _, err := rand.Read(randomData); err != nil {
result.Error = fmt.Errorf("failed to generate random data: %w", err)
return result
}
// Create a pending request
req := &pendingRequest{
endpoint: endpoint,
sentAt: time.Now(),
replyChan: make(chan time.Duration, 1),
}
key := string(randomData)
t.pendingRequests.Store(key, req)
// Build the test request packet
request := make([]byte, bind.MagicTestRequestLen)
copy(request, bind.MagicTestRequest)
copy(request[len(bind.MagicTestRequest):], randomData)
// Send the test packet
_, err = sharedBind.WriteToUDP(request, remoteAddr)
if err != nil {
t.pendingRequests.Delete(key)
result.Error = fmt.Errorf("failed to send test packet: %w", err)
return result
}
// Wait for response with timeout
select {
case rtt, ok := <-req.replyChan:
if ok {
result.Success = true
result.RTT = rtt
} else {
result.Error = fmt.Errorf("request cancelled")
}
case <-time.After(timeout):
t.pendingRequests.Delete(key)
result.Error = fmt.Errorf("timeout waiting for response")
}
return result
}
// TestConnectionWithBind sends a magic test packet using an existing SharedBind.
// This is useful when you want to test the connection through the same socket
// that WireGuard is using, which tests the actual hole-punched path.
func TestConnectionWithBind(sharedBind *bind.SharedBind, endpoint string, opts *TestConnectionOptions) TestResult {
if opts == nil {
defaultOpts := DefaultTestOptions()
opts = &defaultOpts
}
result := TestResult{
Endpoint: endpoint,
}
if sharedBind == nil {
result.Error = fmt.Errorf("sharedBind is nil")
return result
}
if sharedBind.IsClosed() {
result.Error = fmt.Errorf("sharedBind is closed")
return result
}
// Resolve the endpoint
host, err := util.ResolveDomain(endpoint)
if err != nil {
host = endpoint
}
_, _, err = net.SplitHostPort(host)
if err != nil {
host = net.JoinHostPort(host, "21820")
}
remoteAddr, err := net.ResolveUDPAddr("udp", host)
if err != nil {
result.Error = fmt.Errorf("failed to resolve UDP address %s: %w", host, err)
return result
}
// Generate random data for the test packet
randomData := make([]byte, bind.MagicPacketDataLen)
if _, err := rand.Read(randomData); err != nil {
result.Error = fmt.Errorf("failed to generate random data: %w", err)
return result
}
// Build the test request packet
request := make([]byte, bind.MagicTestRequestLen)
copy(request, bind.MagicTestRequest)
copy(request[len(bind.MagicTestRequest):], randomData)
// Get the underlying UDP connection to set read deadline and read response
udpConn := sharedBind.GetUDPConn()
if udpConn == nil {
result.Error = fmt.Errorf("could not get UDP connection from SharedBind")
return result
}
attempts := opts.Retries + 1
for attempt := 0; attempt < attempts; attempt++ {
if attempt > 0 {
logger.Debug("Retrying connection test to %s (attempt %d/%d)", endpoint, attempt+1, attempts)
}
// Note: We can't easily set a read deadline on the shared connection
// without affecting WireGuard, so we use a goroutine with timeout instead
startTime := time.Now()
// Send the test packet through the shared bind
_, err = sharedBind.WriteToUDP(request, remoteAddr)
if err != nil {
result.Error = fmt.Errorf("failed to send test packet: %w", err)
if attempt < attempts-1 {
continue
}
return result
}
// For shared bind test, we send the packet but can't easily wait for
// response without interfering with WireGuard's receive loop.
// The response will be handled by SharedBind automatically.
// We consider the test successful if the send succeeded.
// For a full round-trip test, use TestConnection() with a separate socket.
result.RTT = time.Since(startTime)
result.Success = true
result.Error = nil
logger.Debug("Test packet sent to %s via SharedBind", endpoint)
return result
}
return result
}

View File

@@ -0,0 +1,80 @@
package state
import (
"sync"
"sync/atomic"
"time"
"github.com/fosrl/newt/internal/telemetry"
)
// TelemetryView is a minimal, thread-safe implementation to feed observables.
// Since one Newt process represents one site, we expose a single logical site.
// site_id is a resource attribute, so we do not emit per-site labels here.
type TelemetryView struct {
online atomic.Bool
lastHBUnix atomic.Int64 // unix seconds
// per-tunnel sessions
sessMu sync.RWMutex
sessions map[string]*atomic.Int64
}
var (
globalView atomic.Pointer[TelemetryView]
)
// Global returns a singleton TelemetryView.
func Global() *TelemetryView {
if v := globalView.Load(); v != nil { return v }
v := &TelemetryView{ sessions: make(map[string]*atomic.Int64) }
globalView.Store(v)
telemetry.RegisterStateView(v)
return v
}
// Instrumentation helpers
func (v *TelemetryView) IncSessions(tunnelID string) {
v.sessMu.Lock(); defer v.sessMu.Unlock()
c := v.sessions[tunnelID]
if c == nil { c = &atomic.Int64{}; v.sessions[tunnelID] = c }
c.Add(1)
}
func (v *TelemetryView) DecSessions(tunnelID string) {
v.sessMu.Lock(); defer v.sessMu.Unlock()
if c := v.sessions[tunnelID]; c != nil {
c.Add(-1)
if c.Load() <= 0 { delete(v.sessions, tunnelID) }
}
}
func (v *TelemetryView) ClearTunnel(tunnelID string) {
v.sessMu.Lock(); defer v.sessMu.Unlock()
delete(v.sessions, tunnelID)
}
func (v *TelemetryView) SetOnline(b bool) { v.online.Store(b) }
func (v *TelemetryView) TouchHeartbeat() { v.lastHBUnix.Store(time.Now().Unix()) }
// --- telemetry.StateView interface ---
func (v *TelemetryView) ListSites() []string { return []string{"self"} }
func (v *TelemetryView) Online(_ string) (bool, bool) { return v.online.Load(), true }
func (v *TelemetryView) LastHeartbeat(_ string) (time.Time, bool) {
sec := v.lastHBUnix.Load()
if sec == 0 { return time.Time{}, false }
return time.Unix(sec, 0), true
}
func (v *TelemetryView) ActiveSessions(_ string) (int64, bool) {
// aggregated sessions (not used for per-tunnel gauge)
v.sessMu.RLock(); defer v.sessMu.RUnlock()
var sum int64
for _, c := range v.sessions { if c != nil { sum += c.Load() } }
return sum, true
}
// Extended accessor used by telemetry callback to publish per-tunnel samples.
func (v *TelemetryView) SessionsByTunnel() map[string]int64 {
v.sessMu.RLock(); defer v.sessMu.RUnlock()
out := make(map[string]int64, len(v.sessions))
for id, c := range v.sessions { if c != nil && c.Load() > 0 { out[id] = c.Load() } }
return out
}

View File

@@ -0,0 +1,19 @@
package telemetry
// Protocol labels (low-cardinality)
const (
ProtocolTCP = "tcp"
ProtocolUDP = "udp"
)
// Reconnect reason bins (fixed, low-cardinality)
const (
ReasonServerRequest = "server_request"
ReasonTimeout = "timeout"
ReasonPeerClose = "peer_close"
ReasonNetworkChange = "network_change"
ReasonAuthError = "auth_error"
ReasonHandshakeError = "handshake_error"
ReasonConfigChange = "config_change"
ReasonError = "error"
)

View File

@@ -0,0 +1,32 @@
package telemetry
import "testing"
func TestAllowedConstants(t *testing.T) {
allowedReasons := map[string]struct{}{
ReasonServerRequest: {},
ReasonTimeout: {},
ReasonPeerClose: {},
ReasonNetworkChange: {},
ReasonAuthError: {},
ReasonHandshakeError: {},
ReasonConfigChange: {},
ReasonError: {},
}
for k := range allowedReasons {
if k == "" {
t.Fatalf("empty reason constant")
}
}
allowedProtocols := map[string]struct{}{
ProtocolTCP: {},
ProtocolUDP: {},
}
for k := range allowedProtocols {
if k == "" {
t.Fatalf("empty protocol constant")
}
}
}

View File

@@ -0,0 +1,542 @@
package telemetry
import (
"context"
"sync"
"sync/atomic"
"time"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
// Instruments and helpers for Newt metrics following the naming, units, and
// low-cardinality label guidance from the issue description.
//
// Counters end with _total, durations are in seconds, sizes in bytes.
// Only low-cardinality stable labels are supported: tunnel_id,
// transport, direction, result, reason, error_type.
var (
initOnce sync.Once
meter metric.Meter
// Site / Registration
mSiteRegistrations metric.Int64Counter
mSiteOnline metric.Int64ObservableGauge
mSiteLastHeartbeat metric.Float64ObservableGauge
// Tunnel / Sessions
mTunnelSessions metric.Int64ObservableGauge
mTunnelBytes metric.Int64Counter
mTunnelLatency metric.Float64Histogram
mReconnects metric.Int64Counter
// Connection / NAT
mConnAttempts metric.Int64Counter
mConnErrors metric.Int64Counter
// Config/Restart
mConfigReloads metric.Int64Counter
mConfigApply metric.Float64Histogram
mCertRotationTotal metric.Int64Counter
mProcessStartTime metric.Float64ObservableGauge
// Build info
mBuildInfo metric.Int64ObservableGauge
// WebSocket
mWSConnectLatency metric.Float64Histogram
mWSMessages metric.Int64Counter
mWSDisconnects metric.Int64Counter
mWSKeepaliveFailure metric.Int64Counter
mWSSessionDuration metric.Float64Histogram
mWSConnected metric.Int64ObservableGauge
mWSReconnects metric.Int64Counter
// Proxy
mProxyActiveConns metric.Int64ObservableGauge
mProxyBufferBytes metric.Int64ObservableGauge
mProxyAsyncBacklogByte metric.Int64ObservableGauge
mProxyDropsTotal metric.Int64Counter
mProxyAcceptsTotal metric.Int64Counter
mProxyConnDuration metric.Float64Histogram
mProxyConnectionsTotal metric.Int64Counter
buildVersion string
buildCommit string
processStartUnix = float64(time.Now().UnixNano()) / 1e9
wsConnectedState atomic.Int64
)
// Proxy connection lifecycle events.
const (
ProxyConnectionOpened = "opened"
ProxyConnectionClosed = "closed"
)
// attrsWithSite appends site/region labels only when explicitly enabled to keep
// label cardinality low by default.
func attrsWithSite(extra ...attribute.KeyValue) []attribute.KeyValue {
attrs := make([]attribute.KeyValue, len(extra))
copy(attrs, extra)
if ShouldIncludeSiteLabels() {
attrs = append(attrs, siteAttrs()...)
}
return attrs
}
func registerInstruments() error {
var err error
initOnce.Do(func() {
meter = otel.Meter("newt")
if e := registerSiteInstruments(); e != nil {
err = e
return
}
if e := registerTunnelInstruments(); e != nil {
err = e
return
}
if e := registerConnInstruments(); e != nil {
err = e
return
}
if e := registerConfigInstruments(); e != nil {
err = e
return
}
if e := registerBuildWSProxyInstruments(); e != nil {
err = e
return
}
})
return err
}
func registerSiteInstruments() error {
var err error
mSiteRegistrations, err = meter.Int64Counter("newt_site_registrations_total",
metric.WithDescription("Total site registration attempts"))
if err != nil {
return err
}
mSiteOnline, err = meter.Int64ObservableGauge("newt_site_online",
metric.WithDescription("Site online (0/1)"))
if err != nil {
return err
}
mSiteLastHeartbeat, err = meter.Float64ObservableGauge("newt_site_last_heartbeat_timestamp_seconds",
metric.WithDescription("Unix timestamp of the last site heartbeat"),
metric.WithUnit("s"))
if err != nil {
return err
}
return nil
}
func registerTunnelInstruments() error {
var err error
mTunnelSessions, err = meter.Int64ObservableGauge("newt_tunnel_sessions",
metric.WithDescription("Active tunnel sessions"))
if err != nil {
return err
}
mTunnelBytes, err = meter.Int64Counter("newt_tunnel_bytes_total",
metric.WithDescription("Tunnel bytes ingress/egress"),
metric.WithUnit("By"))
if err != nil {
return err
}
mTunnelLatency, err = meter.Float64Histogram("newt_tunnel_latency_seconds",
metric.WithDescription("Per-tunnel latency in seconds"),
metric.WithUnit("s"))
if err != nil {
return err
}
mReconnects, err = meter.Int64Counter("newt_tunnel_reconnects_total",
metric.WithDescription("Tunnel reconnect events"))
if err != nil {
return err
}
return nil
}
func registerConnInstruments() error {
var err error
mConnAttempts, err = meter.Int64Counter("newt_connection_attempts_total",
metric.WithDescription("Connection attempts"))
if err != nil {
return err
}
mConnErrors, err = meter.Int64Counter("newt_connection_errors_total",
metric.WithDescription("Connection errors by type"))
if err != nil {
return err
}
return nil
}
func registerConfigInstruments() error {
mConfigReloads, _ = meter.Int64Counter("newt_config_reloads_total",
metric.WithDescription("Configuration reloads"))
mConfigApply, _ = meter.Float64Histogram("newt_config_apply_seconds",
metric.WithDescription("Configuration apply duration in seconds"),
metric.WithUnit("s"))
mCertRotationTotal, _ = meter.Int64Counter("newt_cert_rotation_total",
metric.WithDescription("Certificate rotation events (success/failure)"))
mProcessStartTime, _ = meter.Float64ObservableGauge("process_start_time_seconds",
metric.WithDescription("Unix timestamp of the process start time"),
metric.WithUnit("s"))
if mProcessStartTime != nil {
if _, err := meter.RegisterCallback(func(ctx context.Context, o metric.Observer) error {
o.ObserveFloat64(mProcessStartTime, processStartUnix)
return nil
}, mProcessStartTime); err != nil {
otel.Handle(err)
}
}
return nil
}
func registerBuildWSProxyInstruments() error {
// Build info gauge (value 1 with version/commit attributes)
mBuildInfo, _ = meter.Int64ObservableGauge("newt_build_info",
metric.WithDescription("Newt build information (value is always 1)"))
// WebSocket
mWSConnectLatency, _ = meter.Float64Histogram("newt_websocket_connect_latency_seconds",
metric.WithDescription("WebSocket connect latency in seconds"),
metric.WithUnit("s"))
mWSMessages, _ = meter.Int64Counter("newt_websocket_messages_total",
metric.WithDescription("WebSocket messages by direction and type"))
mWSDisconnects, _ = meter.Int64Counter("newt_websocket_disconnects_total",
metric.WithDescription("WebSocket disconnects by reason/result"))
mWSKeepaliveFailure, _ = meter.Int64Counter("newt_websocket_keepalive_failures_total",
metric.WithDescription("WebSocket keepalive (ping/pong) failures"))
mWSSessionDuration, _ = meter.Float64Histogram("newt_websocket_session_duration_seconds",
metric.WithDescription("Duration of established WebSocket sessions"),
metric.WithUnit("s"))
mWSConnected, _ = meter.Int64ObservableGauge("newt_websocket_connected",
metric.WithDescription("WebSocket connection state (1=connected, 0=disconnected)"))
mWSReconnects, _ = meter.Int64Counter("newt_websocket_reconnects_total",
metric.WithDescription("WebSocket reconnect attempts by reason"))
// Proxy
mProxyActiveConns, _ = meter.Int64ObservableGauge("newt_proxy_active_connections",
metric.WithDescription("Proxy active connections per tunnel and protocol"))
mProxyBufferBytes, _ = meter.Int64ObservableGauge("newt_proxy_buffer_bytes",
metric.WithDescription("Proxy buffer bytes (may approximate async backlog)"),
metric.WithUnit("By"))
mProxyAsyncBacklogByte, _ = meter.Int64ObservableGauge("newt_proxy_async_backlog_bytes",
metric.WithDescription("Unflushed async byte backlog per tunnel and protocol"),
metric.WithUnit("By"))
mProxyDropsTotal, _ = meter.Int64Counter("newt_proxy_drops_total",
metric.WithDescription("Proxy drops due to write errors"))
mProxyAcceptsTotal, _ = meter.Int64Counter("newt_proxy_accept_total",
metric.WithDescription("Proxy connection accepts by protocol and result"))
mProxyConnDuration, _ = meter.Float64Histogram("newt_proxy_connection_duration_seconds",
metric.WithDescription("Duration of completed proxy connections"),
metric.WithUnit("s"))
mProxyConnectionsTotal, _ = meter.Int64Counter("newt_proxy_connections_total",
metric.WithDescription("Proxy connection lifecycle events by protocol"))
// Register a default callback for build info if version/commit set
reg, e := meter.RegisterCallback(func(ctx context.Context, o metric.Observer) error {
if buildVersion == "" && buildCommit == "" {
return nil
}
attrs := []attribute.KeyValue{}
if buildVersion != "" {
attrs = append(attrs, attribute.String("version", buildVersion))
}
if buildCommit != "" {
attrs = append(attrs, attribute.String("commit", buildCommit))
}
if ShouldIncludeSiteLabels() {
attrs = append(attrs, siteAttrs()...)
}
o.ObserveInt64(mBuildInfo, 1, metric.WithAttributes(attrs...))
return nil
}, mBuildInfo)
if e != nil {
otel.Handle(e)
} else {
// Provide a functional stopper that unregisters the callback
obsStopper = func() { _ = reg.Unregister() }
}
if mWSConnected != nil {
if regConn, err := meter.RegisterCallback(func(ctx context.Context, o metric.Observer) error {
val := wsConnectedState.Load()
o.ObserveInt64(mWSConnected, val, metric.WithAttributes(attrsWithSite()...))
return nil
}, mWSConnected); err != nil {
otel.Handle(err)
} else {
wsConnStopper = func() { _ = regConn.Unregister() }
}
}
return nil
}
// Observable registration: Newt can register a callback to report gauges.
// Call SetObservableCallback once to start observing online status, last
// heartbeat seconds, and active sessions.
var (
obsOnce sync.Once
obsStopper func()
proxyObsOnce sync.Once
proxyStopper func()
wsConnStopper func()
)
// SetObservableCallback registers a single callback that will be invoked
// on collection. Use the provided observer to emit values for the observable
// gauges defined here.
//
// Example inside your code (where you have access to current state):
//
// telemetry.SetObservableCallback(func(ctx context.Context, o metric.Observer) error {
// o.ObserveInt64(mSiteOnline, 1)
// o.ObserveFloat64(mSiteLastHeartbeat, float64(lastHB.Unix()))
// o.ObserveInt64(mTunnelSessions, int64(len(activeSessions)))
// return nil
// })
func SetObservableCallback(cb func(context.Context, metric.Observer) error) {
obsOnce.Do(func() {
reg, e := meter.RegisterCallback(cb, mSiteOnline, mSiteLastHeartbeat, mTunnelSessions)
if e != nil {
otel.Handle(e)
obsStopper = func() {
// no-op: registration failed; keep stopper callable
}
return
}
// Provide a functional stopper mirroring proxy/build-info behavior
obsStopper = func() { _ = reg.Unregister() }
})
}
// SetProxyObservableCallback registers a callback to observe proxy gauges.
func SetProxyObservableCallback(cb func(context.Context, metric.Observer) error) {
proxyObsOnce.Do(func() {
reg, e := meter.RegisterCallback(cb, mProxyActiveConns, mProxyBufferBytes, mProxyAsyncBacklogByte)
if e != nil {
otel.Handle(e)
proxyStopper = func() {
// no-op: registration failed; keep stopper callable
}
return
}
// Provide a functional stopper to unregister later if needed
proxyStopper = func() { _ = reg.Unregister() }
})
}
// Build info registration
func RegisterBuildInfo(version, commit string) {
buildVersion = version
buildCommit = commit
}
// Config reloads
func IncConfigReload(ctx context.Context, result string) {
mConfigReloads.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
attribute.String("result", result),
)...))
}
// Helpers for counters/histograms
func IncSiteRegistration(ctx context.Context, result string) {
attrs := []attribute.KeyValue{
attribute.String("result", result),
}
mSiteRegistrations.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
}
func AddTunnelBytes(ctx context.Context, tunnelID, direction string, n int64) {
attrs := []attribute.KeyValue{
attribute.String("direction", direction),
}
if ShouldIncludeTunnelID() && tunnelID != "" {
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
}
mTunnelBytes.Add(ctx, n, metric.WithAttributes(attrsWithSite(attrs...)...))
}
// AddTunnelBytesSet adds bytes using a pre-built attribute.Set to avoid per-call allocations.
func AddTunnelBytesSet(ctx context.Context, n int64, attrs attribute.Set) {
mTunnelBytes.Add(ctx, n, metric.WithAttributeSet(attrs))
}
// --- WebSocket helpers ---
func ObserveWSConnectLatency(ctx context.Context, seconds float64, result, errorType string) {
attrs := []attribute.KeyValue{
attribute.String("transport", "websocket"),
attribute.String("result", result),
}
if errorType != "" {
attrs = append(attrs, attribute.String("error_type", errorType))
}
mWSConnectLatency.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(attrs...)...))
}
func IncWSMessage(ctx context.Context, direction, msgType string) {
mWSMessages.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
attribute.String("direction", direction),
attribute.String("msg_type", msgType),
)...))
}
func IncWSDisconnect(ctx context.Context, reason, result string) {
mWSDisconnects.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
attribute.String("reason", reason),
attribute.String("result", result),
)...))
}
func IncWSKeepaliveFailure(ctx context.Context, reason string) {
mWSKeepaliveFailure.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
attribute.String("reason", reason),
)...))
}
// SetWSConnectionState updates the backing gauge for the WebSocket connected state.
func SetWSConnectionState(connected bool) {
if connected {
wsConnectedState.Store(1)
} else {
wsConnectedState.Store(0)
}
}
// IncWSReconnect increments the WebSocket reconnect counter with a bounded reason label.
func IncWSReconnect(ctx context.Context, reason string) {
if reason == "" {
reason = "unknown"
}
mWSReconnects.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
attribute.String("reason", reason),
)...))
}
func ObserveWSSessionDuration(ctx context.Context, seconds float64, result string) {
mWSSessionDuration.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(
attribute.String("result", result),
)...))
}
// --- Proxy helpers ---
func ObserveProxyActiveConnsObs(o metric.Observer, value int64, attrs []attribute.KeyValue) {
o.ObserveInt64(mProxyActiveConns, value, metric.WithAttributes(attrs...))
}
func ObserveProxyBufferBytesObs(o metric.Observer, value int64, attrs []attribute.KeyValue) {
o.ObserveInt64(mProxyBufferBytes, value, metric.WithAttributes(attrs...))
}
func ObserveProxyAsyncBacklogObs(o metric.Observer, value int64, attrs []attribute.KeyValue) {
o.ObserveInt64(mProxyAsyncBacklogByte, value, metric.WithAttributes(attrs...))
}
func IncProxyDrops(ctx context.Context, tunnelID, protocol string) {
attrs := []attribute.KeyValue{
attribute.String("protocol", protocol),
}
if ShouldIncludeTunnelID() && tunnelID != "" {
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
}
mProxyDropsTotal.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
}
func IncProxyAccept(ctx context.Context, tunnelID, protocol, result, reason string) {
attrs := []attribute.KeyValue{
attribute.String("protocol", protocol),
attribute.String("result", result),
}
if reason != "" {
attrs = append(attrs, attribute.String("reason", reason))
}
if ShouldIncludeTunnelID() && tunnelID != "" {
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
}
mProxyAcceptsTotal.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
}
func ObserveProxyConnectionDuration(ctx context.Context, tunnelID, protocol, result string, seconds float64) {
attrs := []attribute.KeyValue{
attribute.String("protocol", protocol),
attribute.String("result", result),
}
if ShouldIncludeTunnelID() && tunnelID != "" {
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
}
mProxyConnDuration.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(attrs...)...))
}
// IncProxyConnectionEvent records proxy connection lifecycle events (opened/closed).
func IncProxyConnectionEvent(ctx context.Context, tunnelID, protocol, event string) {
if event == "" {
event = "unknown"
}
attrs := []attribute.KeyValue{
attribute.String("protocol", protocol),
attribute.String("event", event),
}
if ShouldIncludeTunnelID() && tunnelID != "" {
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
}
mProxyConnectionsTotal.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
}
// --- Config/PKI helpers ---
func ObserveConfigApply(ctx context.Context, phase, result string, seconds float64) {
mConfigApply.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(
attribute.String("phase", phase),
attribute.String("result", result),
)...))
}
func IncCertRotation(ctx context.Context, result string) {
mCertRotationTotal.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
attribute.String("result", result),
)...))
}
func ObserveTunnelLatency(ctx context.Context, tunnelID, transport string, seconds float64) {
attrs := []attribute.KeyValue{
attribute.String("transport", transport),
}
if ShouldIncludeTunnelID() && tunnelID != "" {
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
}
mTunnelLatency.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(attrs...)...))
}
func IncReconnect(ctx context.Context, tunnelID, initiator, reason string) {
attrs := []attribute.KeyValue{
attribute.String("initiator", initiator),
attribute.String("reason", reason),
}
if ShouldIncludeTunnelID() && tunnelID != "" {
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
}
mReconnects.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
}
func IncConnAttempt(ctx context.Context, transport, result string) {
mConnAttempts.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
attribute.String("transport", transport),
attribute.String("result", result),
)...))
}
func IncConnError(ctx context.Context, transport, typ string) {
mConnErrors.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
attribute.String("transport", transport),
attribute.String("error_type", typ),
)...))
}

View File

@@ -0,0 +1,59 @@
package telemetry
import (
"sync"
"time"
)
func resetMetricsForTest() {
initOnce = sync.Once{}
obsOnce = sync.Once{}
proxyObsOnce = sync.Once{}
obsStopper = nil
proxyStopper = nil
if wsConnStopper != nil {
wsConnStopper()
}
wsConnStopper = nil
meter = nil
mSiteRegistrations = nil
mSiteOnline = nil
mSiteLastHeartbeat = nil
mTunnelSessions = nil
mTunnelBytes = nil
mTunnelLatency = nil
mReconnects = nil
mConnAttempts = nil
mConnErrors = nil
mConfigReloads = nil
mConfigApply = nil
mCertRotationTotal = nil
mProcessStartTime = nil
mBuildInfo = nil
mWSConnectLatency = nil
mWSMessages = nil
mWSDisconnects = nil
mWSKeepaliveFailure = nil
mWSSessionDuration = nil
mWSConnected = nil
mWSReconnects = nil
mProxyActiveConns = nil
mProxyBufferBytes = nil
mProxyAsyncBacklogByte = nil
mProxyDropsTotal = nil
mProxyAcceptsTotal = nil
mProxyConnDuration = nil
mProxyConnectionsTotal = nil
processStartUnix = float64(time.Now().UnixNano()) / 1e9
wsConnectedState.Store(0)
includeTunnelIDVal.Store(false)
includeSiteLabelVal.Store(false)
}

View File

@@ -0,0 +1,106 @@
package telemetry
import (
"context"
"sync/atomic"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
// StateView provides a read-only view for observable gauges.
// Implementations must be concurrency-safe and avoid blocking operations.
// All methods should be fast and use RLocks where applicable.
type StateView interface {
// ListSites returns a stable, low-cardinality list of site IDs to expose.
ListSites() []string
// Online returns whether the site is online.
Online(siteID string) (online bool, ok bool)
// LastHeartbeat returns the last heartbeat time for a site.
LastHeartbeat(siteID string) (t time.Time, ok bool)
// ActiveSessions returns the current number of active sessions for a site (across tunnels),
// or scoped to site if your model is site-scoped.
ActiveSessions(siteID string) (n int64, ok bool)
}
var (
stateView atomic.Value // of type StateView
)
// RegisterStateView sets the global StateView used by the default observable callback.
func RegisterStateView(v StateView) {
stateView.Store(v)
// If instruments are registered, ensure a callback exists.
if v != nil {
SetObservableCallback(func(ctx context.Context, o metric.Observer) error {
if any := stateView.Load(); any != nil {
if sv, ok := any.(StateView); ok {
for _, siteID := range sv.ListSites() {
observeSiteOnlineFor(o, sv, siteID)
observeLastHeartbeatFor(o, sv, siteID)
observeSessionsFor(o, siteID, sv)
}
}
}
return nil
})
}
}
func observeSiteOnlineFor(o metric.Observer, sv StateView, siteID string) {
if online, ok := sv.Online(siteID); ok {
val := int64(0)
if online {
val = 1
}
o.ObserveInt64(mSiteOnline, val, metric.WithAttributes(
attribute.String("site_id", siteID),
))
}
}
func observeLastHeartbeatFor(o metric.Observer, sv StateView, siteID string) {
if t, ok := sv.LastHeartbeat(siteID); ok {
ts := float64(t.UnixNano()) / 1e9
o.ObserveFloat64(mSiteLastHeartbeat, ts, metric.WithAttributes(
attribute.String("site_id", siteID),
))
}
}
func observeSessionsFor(o metric.Observer, siteID string, any interface{}) {
if tm, ok := any.(interface{ SessionsByTunnel() map[string]int64 }); ok {
sessions := tm.SessionsByTunnel()
// If tunnel_id labels are enabled, preserve existing per-tunnel observations
if ShouldIncludeTunnelID() {
for tid, n := range sessions {
attrs := []attribute.KeyValue{
attribute.String("site_id", siteID),
}
if tid != "" {
attrs = append(attrs, attribute.String("tunnel_id", tid))
}
o.ObserveInt64(mTunnelSessions, n, metric.WithAttributes(attrs...))
}
return
}
// When tunnel_id is disabled, collapse per-tunnel counts into a single site-level value
var total int64
for _, n := range sessions {
total += n
}
// If there are no per-tunnel entries, fall back to ActiveSessions() if available
if total == 0 {
if svAny := stateView.Load(); svAny != nil {
if sv, ok := svAny.(StateView); ok {
if n, ok2 := sv.ActiveSessions(siteID); ok2 {
total = n
}
}
}
}
o.ObserveInt64(mTunnelSessions, total, metric.WithAttributes(attribute.String("site_id", siteID)))
return
}
}

View File

@@ -0,0 +1,384 @@
package telemetry
import (
"context"
"errors"
"net/http"
"os"
"strings"
"sync/atomic"
"time"
promclient "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/contrib/instrumentation/runtime"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/exporters/prometheus"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/resource"
"go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
"google.golang.org/grpc/credentials"
)
// Config controls telemetry initialization via env flags.
//
// Defaults align with the issue requirements:
// - Prometheus exporter enabled by default (/metrics)
// - OTLP exporter disabled by default
// - Durations in seconds, bytes in raw bytes
// - Admin HTTP server address configurable (for mounting /metrics)
type Config struct {
ServiceName string
ServiceVersion string
// Optional resource attributes
SiteID string
Region string
PromEnabled bool
OTLPEnabled bool
OTLPEndpoint string // host:port
OTLPInsecure bool
MetricExportInterval time.Duration
AdminAddr string // e.g.: ":2112"
// Optional build info for newt_build_info metric
BuildVersion string
BuildCommit string
}
// FromEnv reads configuration from environment variables.
//
// NEWT_METRICS_PROMETHEUS_ENABLED (default: true)
// NEWT_METRICS_OTLP_ENABLED (default: false)
// OTEL_EXPORTER_OTLP_ENDPOINT (default: "localhost:4317")
// OTEL_EXPORTER_OTLP_INSECURE (default: true)
// OTEL_METRIC_EXPORT_INTERVAL (default: 15s)
// OTEL_SERVICE_NAME (default: "newt")
// OTEL_SERVICE_VERSION (default: "")
// NEWT_ADMIN_ADDR (default: ":2112")
func FromEnv() Config {
// Prefer explicit NEWT_* env vars, then fall back to OTEL_RESOURCE_ATTRIBUTES
site := getenv("NEWT_SITE_ID", "")
if site == "" {
site = getenv("NEWT_ID", "")
}
region := os.Getenv("NEWT_REGION")
if site == "" || region == "" {
if ra := os.Getenv("OTEL_RESOURCE_ATTRIBUTES"); ra != "" {
m := parseResourceAttributes(ra)
if site == "" {
site = m["site_id"]
}
if region == "" {
region = m["region"]
}
}
}
return Config{
ServiceName: getenv("OTEL_SERVICE_NAME", "newt"),
ServiceVersion: os.Getenv("OTEL_SERVICE_VERSION"),
SiteID: site,
Region: region,
PromEnabled: getenv("NEWT_METRICS_PROMETHEUS_ENABLED", "true") == "true",
OTLPEnabled: getenv("NEWT_METRICS_OTLP_ENABLED", "false") == "true",
OTLPEndpoint: getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4317"),
OTLPInsecure: getenv("OTEL_EXPORTER_OTLP_INSECURE", "true") == "true",
MetricExportInterval: getdur("OTEL_METRIC_EXPORT_INTERVAL", 15*time.Second),
AdminAddr: getenv("NEWT_ADMIN_ADDR", ":2112"),
}
}
// Setup holds initialized telemetry providers and (optionally) a /metrics handler.
// Call Shutdown when the process terminates to flush exporters.
type Setup struct {
MeterProvider *metric.MeterProvider
TracerProvider *trace.TracerProvider
PrometheusHandler http.Handler // nil if Prometheus exporter disabled
shutdowns []func(context.Context) error
}
// Init configures OpenTelemetry metrics and (optionally) tracing.
//
// It sets a global MeterProvider and TracerProvider, registers runtime instrumentation,
// installs recommended histogram views for *_latency_seconds, and returns a Setup with
// a Shutdown method to flush exporters.
func Init(ctx context.Context, cfg Config) (*Setup, error) {
// Configure tunnel_id label inclusion from env (default true)
if getenv("NEWT_METRICS_INCLUDE_TUNNEL_ID", "true") == "true" {
includeTunnelIDVal.Store(true)
} else {
includeTunnelIDVal.Store(false)
}
if getenv("NEWT_METRICS_INCLUDE_SITE_LABELS", "true") == "true" {
includeSiteLabelVal.Store(true)
} else {
includeSiteLabelVal.Store(false)
}
res := buildResource(ctx, cfg)
UpdateSiteInfo(cfg.SiteID, cfg.Region)
s := &Setup{}
readers, promHandler, shutdowns, err := setupMetricExport(ctx, cfg, res)
if err != nil {
return nil, err
}
s.PrometheusHandler = promHandler
// Build provider
mp := buildMeterProvider(res, readers)
otel.SetMeterProvider(mp)
s.MeterProvider = mp
s.shutdowns = append(s.shutdowns, mp.Shutdown)
// Optional tracing
if cfg.OTLPEnabled {
if tp, shutdown := setupTracing(ctx, cfg, res); tp != nil {
otel.SetTracerProvider(tp)
s.TracerProvider = tp
s.shutdowns = append(s.shutdowns, func(c context.Context) error {
return errors.Join(shutdown(c), tp.Shutdown(c))
})
}
}
// Add metric exporter shutdowns
s.shutdowns = append(s.shutdowns, shutdowns...)
// Runtime metrics
_ = runtime.Start(runtime.WithMeterProvider(mp))
// Instruments
if err := registerInstruments(); err != nil {
return nil, err
}
if cfg.BuildVersion != "" || cfg.BuildCommit != "" {
RegisterBuildInfo(cfg.BuildVersion, cfg.BuildCommit)
}
return s, nil
}
func buildResource(ctx context.Context, cfg Config) *resource.Resource {
attrs := []attribute.KeyValue{
semconv.ServiceName(cfg.ServiceName),
semconv.ServiceVersion(cfg.ServiceVersion),
}
if cfg.SiteID != "" {
attrs = append(attrs, attribute.String("site_id", cfg.SiteID))
}
if cfg.Region != "" {
attrs = append(attrs, attribute.String("region", cfg.Region))
}
res, _ := resource.New(ctx, resource.WithFromEnv(), resource.WithHost(), resource.WithAttributes(attrs...))
return res
}
func setupMetricExport(ctx context.Context, cfg Config, _ *resource.Resource) ([]metric.Reader, http.Handler, []func(context.Context) error, error) {
var readers []metric.Reader
var shutdowns []func(context.Context) error
var promHandler http.Handler
if cfg.PromEnabled {
reg := promclient.NewRegistry()
exp, err := prometheus.New(prometheus.WithRegisterer(reg))
if err != nil {
return nil, nil, nil, err
}
readers = append(readers, exp)
promHandler = promhttp.HandlerFor(reg, promhttp.HandlerOpts{})
}
if cfg.OTLPEnabled {
mopts := []otlpmetricgrpc.Option{otlpmetricgrpc.WithEndpoint(cfg.OTLPEndpoint)}
if hdrs := parseOTLPHeaders(os.Getenv("OTEL_EXPORTER_OTLP_HEADERS")); len(hdrs) > 0 {
mopts = append(mopts, otlpmetricgrpc.WithHeaders(hdrs))
}
if cfg.OTLPInsecure {
mopts = append(mopts, otlpmetricgrpc.WithInsecure())
} else if certFile := os.Getenv("OTEL_EXPORTER_OTLP_CERTIFICATE"); certFile != "" {
if creds, cerr := credentials.NewClientTLSFromFile(certFile, ""); cerr == nil {
mopts = append(mopts, otlpmetricgrpc.WithTLSCredentials(creds))
}
}
mexp, err := otlpmetricgrpc.New(ctx, mopts...)
if err != nil {
return nil, nil, nil, err
}
readers = append(readers, metric.NewPeriodicReader(mexp, metric.WithInterval(cfg.MetricExportInterval)))
shutdowns = append(shutdowns, mexp.Shutdown)
}
return readers, promHandler, shutdowns, nil
}
func buildMeterProvider(res *resource.Resource, readers []metric.Reader) *metric.MeterProvider {
var mpOpts []metric.Option
mpOpts = append(mpOpts, metric.WithResource(res))
for _, r := range readers {
mpOpts = append(mpOpts, metric.WithReader(r))
}
mpOpts = append(mpOpts, metric.WithView(metric.NewView(
metric.Instrument{Name: "newt_*_latency_seconds"},
metric.Stream{Aggregation: metric.AggregationExplicitBucketHistogram{Boundaries: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30}}},
)))
mpOpts = append(mpOpts, metric.WithView(metric.NewView(
metric.Instrument{Name: "newt_*"},
metric.Stream{AttributeFilter: func(kv attribute.KeyValue) bool {
k := string(kv.Key)
switch k {
case "tunnel_id", "transport", "direction", "protocol", "result", "reason", "initiator", "error_type", "msg_type", "phase", "version", "commit", "site_id", "region":
return true
default:
return false
}
}},
)))
return metric.NewMeterProvider(mpOpts...)
}
func setupTracing(ctx context.Context, cfg Config, res *resource.Resource) (*trace.TracerProvider, func(context.Context) error) {
topts := []otlptracegrpc.Option{otlptracegrpc.WithEndpoint(cfg.OTLPEndpoint)}
if hdrs := parseOTLPHeaders(os.Getenv("OTEL_EXPORTER_OTLP_HEADERS")); len(hdrs) > 0 {
topts = append(topts, otlptracegrpc.WithHeaders(hdrs))
}
if cfg.OTLPInsecure {
topts = append(topts, otlptracegrpc.WithInsecure())
} else if certFile := os.Getenv("OTEL_EXPORTER_OTLP_CERTIFICATE"); certFile != "" {
if creds, cerr := credentials.NewClientTLSFromFile(certFile, ""); cerr == nil {
topts = append(topts, otlptracegrpc.WithTLSCredentials(creds))
}
}
exp, err := otlptracegrpc.New(ctx, topts...)
if err != nil {
return nil, nil
}
tp := trace.NewTracerProvider(trace.WithBatcher(exp), trace.WithResource(res))
return tp, exp.Shutdown
}
// Shutdown flushes exporters and providers in reverse init order.
func (s *Setup) Shutdown(ctx context.Context) error {
var err error
for i := len(s.shutdowns) - 1; i >= 0; i-- {
err = errors.Join(err, s.shutdowns[i](ctx))
}
return err
}
func parseOTLPHeaders(h string) map[string]string {
m := map[string]string{}
if h == "" {
return m
}
pairs := strings.Split(h, ",")
for _, p := range pairs {
kv := strings.SplitN(strings.TrimSpace(p), "=", 2)
if len(kv) == 2 {
m[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
}
}
return m
}
// parseResourceAttributes parses OTEL_RESOURCE_ATTRIBUTES formatted as k=v,k2=v2
func parseResourceAttributes(s string) map[string]string {
m := map[string]string{}
if s == "" {
return m
}
parts := strings.Split(s, ",")
for _, p := range parts {
kv := strings.SplitN(strings.TrimSpace(p), "=", 2)
if len(kv) == 2 {
m[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
}
}
return m
}
// Global site/region used to enrich metric labels.
var siteIDVal atomic.Value
var regionVal atomic.Value
var (
includeTunnelIDVal atomic.Value // bool; default true
includeSiteLabelVal atomic.Value // bool; default false
)
// UpdateSiteInfo updates the global site_id and region used for metric labels.
// Thread-safe via atomic.Value: subsequent metric emissions will include
// the new labels, prior emissions remain unchanged.
func UpdateSiteInfo(siteID, region string) {
if siteID != "" {
siteIDVal.Store(siteID)
}
if region != "" {
regionVal.Store(region)
}
}
func getSiteID() string {
if v, ok := siteIDVal.Load().(string); ok {
return v
}
return ""
}
func getRegion() string {
if v, ok := regionVal.Load().(string); ok {
return v
}
return ""
}
// siteAttrs returns label KVs for site_id and region (if set).
func siteAttrs() []attribute.KeyValue {
var out []attribute.KeyValue
if s := getSiteID(); s != "" {
out = append(out, attribute.String("site_id", s))
}
if r := getRegion(); r != "" {
out = append(out, attribute.String("region", r))
}
return out
}
// SiteLabelKVs exposes site label KVs for other packages (e.g., proxy manager).
func SiteLabelKVs() []attribute.KeyValue {
if !ShouldIncludeSiteLabels() {
return nil
}
return siteAttrs()
}
// ShouldIncludeTunnelID returns whether tunnel_id labels should be emitted.
func ShouldIncludeTunnelID() bool {
if v, ok := includeTunnelIDVal.Load().(bool); ok {
return v
}
return true
}
// ShouldIncludeSiteLabels returns whether site_id/region should be emitted as
// metric labels in addition to resource attributes.
func ShouldIncludeSiteLabels() bool {
if v, ok := includeSiteLabelVal.Load().(bool); ok {
return v
}
return false
}
func getenv(k, d string) string {
if v := os.Getenv(k); v != "" {
return v
}
return d
}
func getdur(k string, d time.Duration) time.Duration {
if v := os.Getenv(k); v != "" {
if p, e := time.ParseDuration(v); e == nil {
return p
}
}
return d
}

View File

@@ -0,0 +1,53 @@
package telemetry
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"go.opentelemetry.io/otel/attribute"
)
// Test that disallowed attributes are filtered from the exposition.
func TestAttributeFilterDropsUnknownKeys(t *testing.T) {
ctx := context.Background()
resetMetricsForTest()
t.Setenv("NEWT_METRICS_INCLUDE_SITE_LABELS", "true")
cfg := Config{ServiceName: "newt", PromEnabled: true, AdminAddr: "127.0.0.1:0"}
tel, err := Init(ctx, cfg)
if err != nil {
t.Fatalf("init: %v", err)
}
defer func() { _ = tel.Shutdown(context.Background()) }()
if tel.PrometheusHandler == nil {
t.Fatalf("prom handler nil")
}
ts := httptest.NewServer(tel.PrometheusHandler)
defer ts.Close()
// Add samples with disallowed attribute keys
for _, k := range []string{"forbidden", "site_id", "host"} {
set := attribute.NewSet(attribute.String(k, "x"))
AddTunnelBytesSet(ctx, 123, set)
}
time.Sleep(50 * time.Millisecond)
resp, err := http.Get(ts.URL)
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
body := string(b)
if strings.Contains(body, "forbidden=") {
t.Fatalf("unexpected forbidden attribute leaked into metrics: %s", body)
}
if !strings.Contains(body, "site_id=\"x\"") {
t.Fatalf("expected allowed attribute site_id to be present in metrics, got: %s", body)
}
}

View File

@@ -0,0 +1,76 @@
package telemetry
import (
"bufio"
"context"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// Golden test that /metrics contains expected metric names.
func TestMetricsGoldenContains(t *testing.T) {
ctx := context.Background()
resetMetricsForTest()
t.Setenv("NEWT_METRICS_INCLUDE_SITE_LABELS", "true")
cfg := Config{ServiceName: "newt", PromEnabled: true, AdminAddr: "127.0.0.1:0", BuildVersion: "test"}
tel, err := Init(ctx, cfg)
if err != nil {
t.Fatalf("telemetry init error: %v", err)
}
defer func() { _ = tel.Shutdown(context.Background()) }()
if tel.PrometheusHandler == nil {
t.Fatalf("prom handler nil")
}
ts := httptest.NewServer(tel.PrometheusHandler)
defer ts.Close()
// Trigger counters to ensure they appear in the scrape
IncConnAttempt(ctx, "websocket", "success")
IncWSReconnect(ctx, "io_error")
IncProxyConnectionEvent(ctx, "", "tcp", ProxyConnectionOpened)
if tel.MeterProvider != nil {
_ = tel.MeterProvider.ForceFlush(ctx)
}
time.Sleep(100 * time.Millisecond)
var body string
for i := 0; i < 5; i++ {
resp, err := http.Get(ts.URL)
if err != nil {
t.Fatalf("GET metrics failed: %v", err)
}
b, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
body = string(b)
if strings.Contains(body, "newt_connection_attempts_total") {
break
}
time.Sleep(100 * time.Millisecond)
}
f, err := os.Open(filepath.Join("testdata", "expected_contains.golden"))
if err != nil {
t.Fatalf("read golden: %v", err)
}
defer f.Close()
s := bufio.NewScanner(f)
for s.Scan() {
needle := strings.TrimSpace(s.Text())
if needle == "" {
continue
}
if !strings.Contains(body, needle) {
t.Fatalf("expected metrics body to contain %q. body=\n%s", needle, body)
}
}
if err := s.Err(); err != nil {
t.Fatalf("scan golden: %v", err)
}
}

View File

@@ -0,0 +1,65 @@
package telemetry
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// Smoke test that /metrics contains at least one newt_* metric when Prom exporter is enabled.
func TestMetricsSmoke(t *testing.T) {
ctx := context.Background()
resetMetricsForTest()
t.Setenv("NEWT_METRICS_INCLUDE_SITE_LABELS", "true")
cfg := Config{
ServiceName: "newt",
PromEnabled: true,
OTLPEnabled: false,
AdminAddr: "127.0.0.1:0",
BuildVersion: "test",
BuildCommit: "deadbeef",
MetricExportInterval: 5 * time.Second,
}
tel, err := Init(ctx, cfg)
if err != nil {
t.Fatalf("telemetry init error: %v", err)
}
defer func() { _ = tel.Shutdown(context.Background()) }()
// Serve the Prom handler on a test server
if tel.PrometheusHandler == nil {
t.Fatalf("Prometheus handler nil; PromEnabled should enable it")
}
ts := httptest.NewServer(tel.PrometheusHandler)
defer ts.Close()
// Record a simple metric and then fetch /metrics
IncConnAttempt(ctx, "websocket", "success")
if tel.MeterProvider != nil {
_ = tel.MeterProvider.ForceFlush(ctx)
}
// Give the exporter a tick to collect
time.Sleep(100 * time.Millisecond)
var body string
for i := 0; i < 5; i++ {
resp, err := http.Get(ts.URL)
if err != nil {
t.Fatalf("GET /metrics failed: %v", err)
}
b, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
body = string(b)
if strings.Contains(body, "newt_connection_attempts_total") {
break
}
time.Sleep(100 * time.Millisecond)
}
if !strings.Contains(body, "newt_connection_attempts_total") {
t.Fatalf("expected newt_connection_attempts_total in metrics, got:\n%s", body)
}
}

View File

@@ -0,0 +1,3 @@
newt_connection_attempts_total
newt_websocket_reconnects_total
newt_proxy_connections_total

1
key
View File

@@ -1 +0,0 @@
oBvcoMJZXGzTZ4X+aNSCCQIjroREFBeRCs+a328xWGA=

View File

@@ -1,74 +0,0 @@
//go:build linux
package main
import (
"fmt"
"os"
"runtime"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/newt/wg"
"github.com/fosrl/newt/wgtester"
)
var wgServiceNative *wg.WireGuardService
func setupClientsNative(client *websocket.Client, host string) {
if runtime.GOOS != "linux" {
logger.Fatal("Tunnel management is only supported on Linux right now!")
os.Exit(1)
}
// make sure we are sudo
if os.Geteuid() != 0 {
logger.Fatal("You must run this program as root to manage tunnels on Linux.")
os.Exit(1)
}
// Create WireGuard service
wgServiceNative, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
wgTesterServer = wgtester.NewServer("0.0.0.0", wgServiceNative.Port, id) // TODO: maybe make this the same ip of the wg server?
err := wgTesterServer.Start()
if err != nil {
logger.Error("Failed to start WireGuard tester server: %v", err)
}
client.OnTokenUpdate(func(token string) {
wgServiceNative.SetToken(token)
})
}
func closeWgServiceNative() {
if wgServiceNative != nil {
wgServiceNative.Close(!keepInterface)
wgServiceNative = nil
}
}
func clientsOnConnectNative() {
if wgServiceNative != nil {
wgServiceNative.LoadRemoteConfig()
}
}
func clientsHandleNewtConnectionNative(publicKey, endpoint string) {
if wgServiceNative != nil {
wgServiceNative.StartHolepunch(publicKey, endpoint)
}
}
func clientsAddProxyTargetNative(pm *proxy.ProxyManager, tunnelIp string) {
// add a udp proxy for localost and the wgService port
// TODO: make sure this port is not used in a target
if wgServiceNative != nil {
pm.AddTarget("udp", tunnelIp, int(wgServiceNative.Port), fmt.Sprintf("127.0.0.1:%d", wgServiceNative.Port))
}
}

View File

@@ -2,16 +2,15 @@ package logger
import (
"fmt"
"io"
"log"
"os"
"strings"
"sync"
"time"
)
// Logger struct holds the logger instance
type Logger struct {
logger *log.Logger
writer LogWriter
level LogLevel
}
@@ -20,17 +19,29 @@ var (
once sync.Once
)
// NewLogger creates a new logger instance
// NewLogger creates a new logger instance with the default StandardWriter
func NewLogger() *Logger {
return &Logger{
logger: log.New(os.Stdout, "", 0),
writer: NewStandardWriter(),
level: DEBUG,
}
}
// NewLoggerWithWriter creates a new logger instance with a custom LogWriter
func NewLoggerWithWriter(writer LogWriter) *Logger {
return &Logger{
writer: writer,
level: DEBUG,
}
}
// Init initializes the default logger
func Init() *Logger {
func Init(logger *Logger) *Logger {
once.Do(func() {
if logger != nil {
defaultLogger = logger
return
}
defaultLogger = NewLogger()
})
return defaultLogger
@@ -39,7 +50,7 @@ func Init() *Logger {
// GetLogger returns the default logger instance
func GetLogger() *Logger {
if defaultLogger == nil {
Init()
Init(nil)
}
return defaultLogger
}
@@ -49,9 +60,11 @@ func (l *Logger) SetLevel(level LogLevel) {
l.level = level
}
// SetOutput sets the output destination for the logger
func (l *Logger) SetOutput(w io.Writer) {
l.logger.SetOutput(w)
// SetOutput sets the output destination for the logger (only works with StandardWriter)
func (l *Logger) SetOutput(output *os.File) {
if sw, ok := l.writer.(*StandardWriter); ok {
sw.SetOutput(output)
}
}
// log handles the actual logging
@@ -60,24 +73,8 @@ func (l *Logger) log(level LogLevel, format string, args ...interface{}) {
return
}
// Get timezone from environment variable or use local timezone
timezone := os.Getenv("LOGGER_TIMEZONE")
var location *time.Location
var err error
if timezone != "" {
location, err = time.LoadLocation(timezone)
if err != nil {
// If invalid timezone, fall back to local
location = time.Local
}
} else {
location = time.Local
}
timestamp := time.Now().In(location).Format("2006/01/02 15:04:05")
message := fmt.Sprintf(format, args...)
l.logger.Printf("%s: %s %s", level.String(), timestamp, message)
l.writer.Write(level, time.Now(), message)
}
// Debug logs debug level messages
@@ -128,6 +125,29 @@ func Fatal(format string, args ...interface{}) {
}
// SetOutput sets the output destination for the default logger
func SetOutput(w io.Writer) {
GetLogger().SetOutput(w)
func SetOutput(output *os.File) {
GetLogger().SetOutput(output)
}
// WireGuardLogger is a wrapper type that matches WireGuard's Logger interface
type WireGuardLogger struct {
Verbosef func(format string, args ...any)
Errorf func(format string, args ...any)
}
// GetWireGuardLogger returns a WireGuard-compatible logger that writes to the newt logger
// The prepend string is added as a prefix to all log messages
func (l *Logger) GetWireGuardLogger(prepend string) *WireGuardLogger {
return &WireGuardLogger{
Verbosef: func(format string, args ...any) {
// if the format string contains "Sending keepalive packet", skip debug logging to reduce noise
if strings.Contains(format, "Sending keepalive packet") {
return
}
l.Debug(prepend+format, args...)
},
Errorf: func(format string, args ...any) {
l.Error(prepend+format, args...)
},
}
}

54
logger/writer.go Normal file
View File

@@ -0,0 +1,54 @@
package logger
import (
"fmt"
"os"
"time"
)
// LogWriter is an interface for writing log messages
// Implement this interface to create custom log backends (OS log, syslog, etc.)
type LogWriter interface {
// Write writes a log message with the given level, timestamp, and formatted message
Write(level LogLevel, timestamp time.Time, message string)
}
// StandardWriter is the default log writer that writes to an io.Writer
type StandardWriter struct {
output *os.File
timezone *time.Location
}
// NewStandardWriter creates a new standard writer with the default configuration
func NewStandardWriter() *StandardWriter {
// Get timezone from environment variable or use local timezone
timezone := os.Getenv("LOGGER_TIMEZONE")
var location *time.Location
var err error
if timezone != "" {
location, err = time.LoadLocation(timezone)
if err != nil {
// If invalid timezone, fall back to local
location = time.Local
}
} else {
location = time.Local
}
return &StandardWriter{
output: os.Stdout,
timezone: location,
}
}
// SetOutput sets the output destination
func (w *StandardWriter) SetOutput(output *os.File) {
w.output = output
}
// Write implements the LogWriter interface
func (w *StandardWriter) Write(level LogLevel, timestamp time.Time, message string) {
formattedTime := timestamp.In(w.timezone).Format("2006/01/02 15:04:05")
fmt.Fprintf(w.output, "%s: %s %s\n", level.String(), formattedTime, message)
}

249
main.go
View File

@@ -1,7 +1,9 @@
package main
import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"net"
@@ -20,8 +22,12 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/updates"
"github.com/fosrl/newt/util"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/newt/internal/state"
"github.com/fosrl/newt/internal/telemetry"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
@@ -91,6 +97,14 @@ func (s *stringSlice) Set(value string) error {
return nil
}
const (
fmtErrMarshaling = "Error marshaling data: %v"
fmtReceivedMsg = "Received: %+v"
topicWGRegister = "newt/wg/register"
msgNoTunnelOrProxy = "No tunnel IP or proxy manager available"
fmtErrParsingTargetData = "Error parsing target data: %v"
)
var (
endpoint string
id string
@@ -102,9 +116,7 @@ var (
err error
logLevel string
interfaceName string
generateAndSaveKeyTo string
keepInterface bool
acceptClients bool
disableClients bool
updownScript string
dockerSocket string
dockerEnforceNetworkValidation string
@@ -120,8 +132,17 @@ var (
preferEndpoint string
healthMonitor *healthcheck.Monitor
enforceHealthcheckCert bool
blueprintFile string
noCloud bool
// Build/version (can be overridden via -ldflags "-X main.newtVersion=...")
newtVersion = "version_replaceme"
// Observability/metrics flags
metricsEnabled bool
otlpEnabled bool
adminAddr string
region string
metricsAsyncBytes bool
blueprintFile string
noCloud bool
// New mTLS configuration variables
tlsClientCert string
@@ -133,6 +154,10 @@ var (
)
func main() {
// Prepare context for graceful shutdown and signal handling
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
id = os.Getenv("NEWT_ID")
@@ -142,11 +167,16 @@ func main() {
logLevel = os.Getenv("LOG_LEVEL")
updownScript = os.Getenv("UPDOWN_SCRIPT")
interfaceName = os.Getenv("INTERFACE")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
keepInterfaceEnv := os.Getenv("KEEP_INTERFACE")
keepInterface = keepInterfaceEnv == "true"
acceptClientsEnv := os.Getenv("ACCEPT_CLIENTS")
acceptClients = acceptClientsEnv == "true"
// Metrics/observability env mirrors
metricsEnabledEnv := os.Getenv("NEWT_METRICS_PROMETHEUS_ENABLED")
otlpEnabledEnv := os.Getenv("NEWT_METRICS_OTLP_ENABLED")
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
regionEnv := os.Getenv("NEWT_REGION")
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
disableClients = disableClientsEnv == "true"
useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE")
useNativeInterface = useNativeInterfaceEnv == "true"
enforceHealthcheckCertEnv := os.Getenv("ENFORCE_HC_CERT")
@@ -205,17 +235,11 @@ func main() {
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "newt", "Name of the WireGuard interface")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
if keepInterfaceEnv == "" {
flag.BoolVar(&keepInterface, "keep-interface", false, "Keep the WireGuard interface")
}
if useNativeInterfaceEnv == "" {
flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface (requires WireGuard kernel module) and linux")
flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface")
}
if acceptClientsEnv == "" {
flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface")
if disableClientsEnv == "" {
flag.BoolVar(&disableClients, "disable-clients", false, "Disable clients on the WireGuard interface")
}
if enforceHealthcheckCertEnv == "" {
flag.BoolVar(&enforceHealthcheckCert, "enforce-hc-cert", false, "Enforce certificate validation for health checks (default: false, accepts any cert)")
@@ -286,6 +310,43 @@ func main() {
flag.BoolVar(&noCloud, "no-cloud", false, "Disable cloud failover")
}
// Metrics/observability flags (mirror ENV if unset)
if metricsEnabledEnv == "" {
flag.BoolVar(&metricsEnabled, "metrics", true, "Enable Prometheus /metrics exporter")
} else {
if v, err := strconv.ParseBool(metricsEnabledEnv); err == nil {
metricsEnabled = v
} else {
metricsEnabled = true
}
}
if otlpEnabledEnv == "" {
flag.BoolVar(&otlpEnabled, "otlp", false, "Enable OTLP exporters (metrics/traces) to OTEL_EXPORTER_OTLP_ENDPOINT")
} else {
if v, err := strconv.ParseBool(otlpEnabledEnv); err == nil {
otlpEnabled = v
}
}
if adminAddrEnv == "" {
flag.StringVar(&adminAddr, "metrics-admin-addr", "127.0.0.1:2112", "Admin/metrics bind address")
} else {
adminAddr = adminAddrEnv
}
// Async bytes toggle
if asyncBytesEnv == "" {
flag.BoolVar(&metricsAsyncBytes, "metrics-async-bytes", false, "Enable async bytes counting (background flush; lower hot path overhead)")
} else {
if v, err := strconv.ParseBool(asyncBytesEnv); err == nil {
metricsAsyncBytes = v
}
}
// Optional region flag (resource attribute)
if regionEnv == "" {
flag.StringVar(&region, "region", "", "Optional region resource attribute (also NEWT_REGION)")
} else {
region = regionEnv
}
// do a --version check
version := flag.Bool("version", false, "Print the version")
@@ -296,16 +357,62 @@ func main() {
tlsClientCAs = append(tlsClientCAs, tlsClientCAsFlag...)
}
logger.Init()
loggerLevel := parseLogLevel(logLevel)
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
logger.Init(nil)
loggerLevel := util.ParseLogLevel(logLevel)
logger.GetLogger().SetLevel(loggerLevel)
// Initialize telemetry after flags are parsed (so flags override env)
tcfg := telemetry.FromEnv()
tcfg.PromEnabled = metricsEnabled
tcfg.OTLPEnabled = otlpEnabled
if adminAddr != "" {
tcfg.AdminAddr = adminAddr
}
// Resource attributes (if available)
tcfg.SiteID = id
tcfg.Region = region
// Build info
tcfg.BuildVersion = newtVersion
tcfg.BuildCommit = os.Getenv("NEWT_COMMIT")
tel, telErr := telemetry.Init(ctx, tcfg)
if telErr != nil {
logger.Warn("Telemetry init failed: %v", telErr)
}
if tel != nil {
// Admin HTTP server (exposes /metrics when Prometheus exporter is enabled)
logger.Info("Starting metrics server on %s", tcfg.AdminAddr)
mux := http.NewServeMux()
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) })
if tel.PrometheusHandler != nil {
mux.Handle("/metrics", tel.PrometheusHandler)
}
admin := &http.Server{
Addr: tcfg.AdminAddr,
Handler: otelhttp.NewHandler(mux, "newt-admin"),
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
ReadHeaderTimeout: 5 * time.Second,
IdleTimeout: 30 * time.Second,
}
go func() {
if err := admin.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Warn("admin http error: %v", err)
}
}()
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = admin.Shutdown(ctx)
}()
defer func() { _ = tel.Shutdown(context.Background()) }()
}
newtVersion := "version_replaceme"
if *version {
fmt.Println("Newt version " + newtVersion)
os.Exit(0)
} else {
logger.Info("Newt version " + newtVersion)
logger.Info("Newt version %s", newtVersion)
}
if err := updates.CheckForUpdate("fosrl", "newt", newtVersion); err != nil {
@@ -376,6 +483,8 @@ func main() {
}
endpoint = client.GetConfig().Endpoint // Update endpoint from config
id = client.GetConfig().ID // Update ID from config
// Update site labels for metrics with the resolved ID
telemetry.UpdateSiteInfo(id, region)
// output env var values if set
logger.Debug("Endpoint: %v", endpoint)
@@ -419,7 +528,7 @@ func main() {
var wgData WgData
var dockerEventMonitor *docker.EventMonitor
if acceptClients {
if !disableClients {
setupClients(client)
}
@@ -484,6 +593,10 @@ func main() {
// Register handlers for different message types
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received registration message")
regResult := "success"
defer func() {
telemetry.IncSiteRegistration(ctx, regResult)
}()
if stopFunc != nil {
stopFunc() // stop the ws from sending more requests
stopFunc = nil // reset stopFunc to nil to avoid double stopping
@@ -502,43 +615,48 @@ func main() {
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
logger.Info(fmtErrMarshaling, err)
regResult = "failure"
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
regResult = "failure"
return
}
logger.Debug("Received: %+v", msg)
logger.Debug(fmtReceivedMsg, msg)
tun, tnet, err = netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
[]netip.Addr{netip.MustParseAddr(dns)},
mtuInt)
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
regResult = "failure"
}
setDownstreamTNetstack(tnet)
// Create WireGuard device
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
mapToWireGuardLogLevel(loggerLevel),
util.MapToWireGuardLogLevel(loggerLevel),
"wireguard: ",
))
host, _, err := net.SplitHostPort(wgData.Endpoint)
if err != nil {
logger.Error("Failed to split endpoint: %v", err)
regResult = "failure"
return
}
logger.Info("Connecting to endpoint: %s", host)
endpoint, err := resolveDomain(wgData.Endpoint)
endpoint, err := util.ResolveDomain(wgData.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint: %v", err)
regResult = "failure"
return
}
@@ -549,17 +667,19 @@ func main() {
public_key=%s
allowed_ip=%s/32
endpoint=%s
persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
err = dev.IpcSet(config)
if err != nil {
logger.Error("Failed to configure WireGuard device: %v", err)
regResult = "failure"
}
// Bring up the device
err = dev.Up()
if err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
regResult = "failure"
}
logger.Debug("WireGuard device created. Lets ping the server now...")
@@ -572,9 +692,13 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
}
// Use reliable ping for initial connection test
logger.Debug("Testing initial connection with reliable ping...")
_, err = reliablePing(tnet, wgData.ServerIP, pingTimeout, 5)
lat, err := reliablePing(tnet, wgData.ServerIP, pingTimeout, 5)
if err == nil && wgData.PublicKey != "" {
telemetry.ObserveTunnelLatency(ctx, wgData.PublicKey, "wireguard", lat.Seconds())
}
if err != nil {
logger.Warn("Initial reliable ping failed, but continuing: %v", err)
regResult = "failure"
} else {
logger.Debug("Initial connection test successful")
}
@@ -585,11 +709,14 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
// as the pings will continue in the background
if !connected {
logger.Debug("Starting ping check")
pingStopChan = startPingCheck(tnet, wgData.ServerIP, client)
pingStopChan = startPingCheck(tnet, wgData.ServerIP, client, wgData.PublicKey)
}
// Create proxy manager
pm = proxy.NewProxyManager(tnet)
pm.SetAsyncBytes(metricsAsyncBytes)
// Set tunnel_id for metrics (WireGuard peer public key)
pm.SetTunnelID(wgData.PublicKey)
connected = true
@@ -610,7 +737,8 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
// }
}
clientsAddProxyTarget(pm, wgData.TunnelIP)
// Start direct UDP relay from main tunnel to clients' WireGuard (bypasses proxy)
clientsStartDirectRelay(wgData.TunnelIP)
if err := healthMonitor.AddTargets(wgData.HealthCheckTargets); err != nil {
logger.Error("Failed to bulk add health check targets: %v", err)
@@ -626,10 +754,19 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
client.RegisterHandler("newt/wg/reconnect", func(msg websocket.WSMessage) {
logger.Info("Received reconnect message")
if wgData.PublicKey != "" {
telemetry.IncReconnect(ctx, wgData.PublicKey, "server", telemetry.ReasonServerRequest)
}
// Close the WireGuard device and TUN
closeWgTunnel()
// Clear metrics attrs and sessions for the tunnel
if pm != nil {
pm.ClearTunnelID()
state.Global().ClearTunnel(wgData.PublicKey)
}
// Mark as disconnected
connected = false
@@ -648,9 +785,13 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) {
logger.Info("Received termination message")
if wgData.PublicKey != "" {
telemetry.IncReconnect(ctx, wgData.PublicKey, "server", telemetry.ReasonServerRequest)
}
// Close the WireGuard device and TUN
closeWgTunnel()
closeClients()
if stopFunc != nil {
stopFunc() // stop the ws from sending more requests
@@ -675,7 +816,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
logger.Info(fmtErrMarshaling, err)
return
}
if err := json.Unmarshal(jsonData, &exitNodeData); err != nil {
@@ -716,7 +857,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
},
}
stopFunc = client.SendMessageInterval("newt/wg/register", map[string]interface{}{
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(),
"pingResults": pingResults,
"newtVersion": newtVersion,
@@ -819,7 +960,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
}
// Send the ping results to the cloud for selection
stopFunc = client.SendMessageInterval("newt/wg/register", map[string]interface{}{
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(),
"pingResults": pingResults,
"newtVersion": newtVersion,
@@ -829,17 +970,17 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
})
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
logger.Debug("Received: %+v", msg)
logger.Debug(fmtReceivedMsg, msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
logger.Info(msgNoTunnelOrProxy)
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
logger.Info(fmtErrParsingTargetData, err)
return
}
@@ -854,17 +995,17 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
})
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
logger.Info(fmtReceivedMsg, msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
logger.Info(msgNoTunnelOrProxy)
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
logger.Info(fmtErrParsingTargetData, err)
return
}
@@ -879,17 +1020,17 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
})
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
logger.Info(fmtReceivedMsg, msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
logger.Info(msgNoTunnelOrProxy)
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
logger.Info(fmtErrParsingTargetData, err)
return
}
@@ -904,17 +1045,17 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
})
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
logger.Info(fmtReceivedMsg, msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
logger.Info(msgNoTunnelOrProxy)
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
logger.Info(fmtErrParsingTargetData, err)
return
}
@@ -998,7 +1139,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
logger.Info(fmtErrMarshaling, err)
return
}
if err := json.Unmarshal(jsonData, &sshPublicKeyData); err != nil {
@@ -1155,9 +1296,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
}
if err := healthMonitor.EnableTarget(requestData.ID); err != nil {
logger.Error("Failed to enable health check target %s: %v", requestData.ID, err)
logger.Error("Failed to enable health check target %d: %v", requestData.ID, err)
} else {
logger.Info("Enabled health check target: %s", requestData.ID)
logger.Info("Enabled health check target: %d", requestData.ID)
}
})
@@ -1180,9 +1321,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
}
if err := healthMonitor.DisableTarget(requestData.ID); err != nil {
logger.Error("Failed to disable health check target %s: %v", requestData.ID, err)
logger.Error("Failed to disable health check target %d: %v", requestData.ID, err)
} else {
logger.Info("Disabled health check target: %s", requestData.ID)
logger.Info("Disabled health check target: %d", requestData.ID)
}
})
@@ -1252,7 +1393,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
}
// Send registration message to the server for backward compatibility
err := client.SendMessage("newt/wg/register", map[string]interface{}{
err := client.SendMessage(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(),
"newtVersion": newtVersion,
"backwardsCompatible": true,

350
netstack2/handlers.go Normal file
View File

@@ -0,0 +1,350 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package netstack2
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/fosrl/newt/logger"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
// defaultWndSize if set to zero, the default
// receive window buffer size is used instead.
defaultWndSize = 0
// maxConnAttempts specifies the maximum number
// of in-flight tcp connection attempts.
maxConnAttempts = 2 << 10
// tcpKeepaliveCount is the maximum number of
// TCP keep-alive probes to send before giving up
// and killing the connection if no response is
// obtained from the other end.
tcpKeepaliveCount = 9
// tcpKeepaliveIdle specifies the time a connection
// must remain idle before the first TCP keepalive
// packet is sent. Once this time is reached,
// tcpKeepaliveInterval option is used instead.
tcpKeepaliveIdle = 60 * time.Second
// tcpKeepaliveInterval specifies the interval
// time between sending TCP keepalive packets.
tcpKeepaliveInterval = 30 * time.Second
// tcpConnectTimeout is the default timeout for TCP handshakes.
tcpConnectTimeout = 5 * time.Second
// tcpWaitTimeout implements a TCP half-close timeout.
tcpWaitTimeout = 60 * time.Second
// udpSessionTimeout is the default timeout for UDP sessions.
udpSessionTimeout = 60 * time.Second
// Buffer size for copying data
bufferSize = 32 * 1024
)
// TCPHandler handles TCP connections from netstack
type TCPHandler struct {
stack *stack.Stack
proxyHandler *ProxyHandler
}
// UDPHandler handles UDP connections from netstack
type UDPHandler struct {
stack *stack.Stack
proxyHandler *ProxyHandler
}
// NewTCPHandler creates a new TCP handler
func NewTCPHandler(s *stack.Stack, ph *ProxyHandler) *TCPHandler {
return &TCPHandler{stack: s, proxyHandler: ph}
}
// NewUDPHandler creates a new UDP handler
func NewUDPHandler(s *stack.Stack, ph *ProxyHandler) *UDPHandler {
return &UDPHandler{stack: s, proxyHandler: ph}
}
// InstallTCPHandler installs the TCP forwarder on the stack
func (h *TCPHandler) InstallTCPHandler() error {
tcpForwarder := tcp.NewForwarder(h.stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
var (
wq waiter.Queue
ep tcpip.Endpoint
err tcpip.Error
id = r.ID()
)
// Perform a TCP three-way handshake
ep, err = r.CreateEndpoint(&wq)
if err != nil {
// RST: prevent potential half-open TCP connection leak
r.Complete(true)
return
}
defer r.Complete(false)
// Set socket options
setTCPSocketOptions(h.stack, ep)
// Create TCP connection from netstack endpoint
netstackConn := gonet.NewTCPConn(&wq, ep)
// Handle the connection in a goroutine
go h.handleTCPConn(netstackConn, id)
})
h.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
return nil
}
// handleTCPConn handles a TCP connection by proxying it to the actual target
func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.TransportEndpointID) {
defer netstackConn.Close()
// Extract source and target address from the connection ID
srcIP := id.RemoteAddress.String()
srcPort := id.RemotePort
dstIP := id.LocalAddress.String()
dstPort := id.LocalPort
logger.Info("TCP Forwarder: Handling connection %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
// Check if there's a destination rewrite for this connection (e.g., localhost targets)
actualDstIP := dstIP
if h.proxyHandler != nil {
if rewrittenAddr, ok := h.proxyHandler.LookupDestinationRewrite(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber)); ok {
actualDstIP = rewrittenAddr.String()
logger.Info("TCP Forwarder: Using rewritten destination %s (original: %s)", actualDstIP, dstIP)
}
}
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Create context with timeout for connection establishment
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
defer cancel()
// Dial the actual target using standard net package
var d net.Dialer
targetConn, err := d.DialContext(ctx, "tcp", targetAddr)
if err != nil {
logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err)
// Connection failed, netstack will handle RST
return
}
defer targetConn.Close()
logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr)
// Bidirectional copy between netstack and target
pipeTCP(netstackConn, targetConn)
}
// pipeTCP copies data bidirectionally between two connections
func pipeTCP(origin, remote net.Conn) {
wg := sync.WaitGroup{}
wg.Add(2)
go unidirectionalStreamTCP(remote, origin, "origin->remote", &wg)
go unidirectionalStreamTCP(origin, remote, "remote->origin", &wg)
wg.Wait()
}
// unidirectionalStreamTCP copies data in one direction
func unidirectionalStreamTCP(dst, src net.Conn, dir string, wg *sync.WaitGroup) {
defer wg.Done()
buf := make([]byte, bufferSize)
_, _ = io.CopyBuffer(dst, src, buf)
// Do the upload/download side TCP half-close
if cr, ok := src.(interface{ CloseRead() error }); ok {
cr.CloseRead()
}
if cw, ok := dst.(interface{ CloseWrite() error }); ok {
cw.CloseWrite()
}
// Set TCP half-close timeout
dst.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
}
// setTCPSocketOptions sets TCP socket options for better performance
func setTCPSocketOptions(s *stack.Stack, ep tcpip.Endpoint) {
// TCP keepalive options
ep.SocketOptions().SetKeepAlive(true)
idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle)
ep.SetSockOpt(&idle)
interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval)
ep.SetSockOpt(&interval)
ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount)
// TCP send/recv buffer size
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &ss); err == nil {
ep.SocketOptions().SetSendBufferSize(int64(ss.Default), false)
}
var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &rs); err == nil {
ep.SocketOptions().SetReceiveBufferSize(int64(rs.Default), false)
}
}
// InstallUDPHandler installs the UDP forwarder on the stack
func (h *UDPHandler) InstallUDPHandler() error {
udpForwarder := udp.NewForwarder(h.stack, func(r *udp.ForwarderRequest) {
var (
wq waiter.Queue
id = r.ID()
)
ep, err := r.CreateEndpoint(&wq)
if err != nil {
return
}
// Create UDP connection from netstack endpoint
netstackConn := gonet.NewUDPConn(&wq, ep)
// Handle the connection in a goroutine
go h.handleUDPConn(netstackConn, id)
})
h.stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
return nil
}
// handleUDPConn handles a UDP connection by proxying it to the actual target
func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.TransportEndpointID) {
defer netstackConn.Close()
// Extract source and target address from the connection ID
srcIP := id.RemoteAddress.String()
srcPort := id.RemotePort
dstIP := id.LocalAddress.String()
dstPort := id.LocalPort
logger.Info("UDP Forwarder: Handling connection %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
// Check if there's a destination rewrite for this connection (e.g., localhost targets)
actualDstIP := dstIP
if h.proxyHandler != nil {
if rewrittenAddr, ok := h.proxyHandler.LookupDestinationRewrite(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber)); ok {
actualDstIP = rewrittenAddr.String()
logger.Info("UDP Forwarder: Using rewritten destination %s (original: %s)", actualDstIP, dstIP)
}
}
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Resolve target address
remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil {
logger.Info("UDP Forwarder: Failed to resolve %s: %v", targetAddr, err)
return
}
// Resolve client address (for sending responses back)
clientAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", srcIP, srcPort))
if err != nil {
logger.Info("UDP Forwarder: Failed to resolve client %s:%d: %v", srcIP, srcPort, err)
return
}
// Create unconnected UDP socket (so we can use WriteTo)
targetConn, err := net.ListenUDP("udp", nil)
if err != nil {
logger.Info("UDP Forwarder: Failed to create UDP socket: %v", err)
return
}
defer targetConn.Close()
logger.Info("UDP Forwarder: Successfully created UDP socket for %s, starting bidirectional copy", targetAddr)
// Bidirectional copy between netstack and target
pipeUDP(netstackConn, targetConn, remoteUDPAddr, clientAddr, udpSessionTimeout)
}
// pipeUDP copies UDP packets bidirectionally
func pipeUDP(origin, remote net.PacketConn, serverAddr, clientAddr net.Addr, timeout time.Duration) {
wg := sync.WaitGroup{}
wg.Add(2)
// Read from origin (netstack), write to remote (target server)
go unidirectionalPacketStream(remote, origin, serverAddr, "origin->remote", &wg, timeout)
// Read from remote (target server), write to origin (netstack) with client address
go unidirectionalPacketStream(origin, remote, clientAddr, "remote->origin", &wg, timeout)
wg.Wait()
}
// unidirectionalPacketStream copies packets in one direction
func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup, timeout time.Duration) {
defer wg.Done()
logger.Info("UDP %s: Starting packet stream (to=%v)", dir, to)
err := copyPacketData(dst, src, to, timeout)
if err != nil {
logger.Info("UDP %s: Stream ended with error: %v", dir, err)
} else {
logger.Info("UDP %s: Stream ended (timeout)", dir)
}
}
// copyPacketData copies UDP packet data with timeout
func copyPacketData(dst, src net.PacketConn, to net.Addr, timeout time.Duration) error {
buf := make([]byte, 65535) // Max UDP packet size
for {
src.SetReadDeadline(time.Now().Add(timeout))
n, srcAddr, err := src.ReadFrom(buf)
if ne, ok := err.(net.Error); ok && ne.Timeout() {
return nil // ignore I/O timeout
} else if err == io.EOF {
return nil // ignore EOF
} else if err != nil {
return err
}
logger.Info("UDP copyPacketData: Read %d bytes from %v", n, srcAddr)
// Determine write destination
writeAddr := to
if writeAddr == nil {
// If no destination specified, use the source address from the packet
writeAddr = srcAddr
}
written, err := dst.WriteTo(buf[:n], writeAddr)
if err != nil {
logger.Info("UDP copyPacketData: Write error to %v: %v", writeAddr, err)
return err
}
logger.Info("UDP copyPacketData: Wrote %d bytes to %v", written, writeAddr)
dst.SetReadDeadline(time.Now().Add(timeout))
}
}

710
netstack2/proxy.go Normal file
View File

@@ -0,0 +1,710 @@
package netstack2
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/fosrl/newt/logger"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
// PortRange represents an allowed range of ports (inclusive)
type PortRange struct {
Min uint16
Max uint16
}
// SubnetRule represents a subnet with optional port restrictions and source address
// When RewriteTo is set, DNAT (Destination Network Address Translation) is performed:
// - Incoming packets: destination IP is rewritten to the resolved RewriteTo address
// - Outgoing packets: source IP is rewritten back to the original destination
//
// RewriteTo can be either:
// - An IP address with CIDR notation (e.g., "192.168.1.1/32")
// - A domain name (e.g., "example.com") which will be resolved at request time
//
// This allows transparent proxying where traffic appears to come from the rewritten address
type SubnetRule struct {
SourcePrefix netip.Prefix // Source IP prefix (who is sending)
DestPrefix netip.Prefix // Destination IP prefix (where it's going)
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
PortRanges []PortRange // empty slice means all ports allowed
}
// ruleKey is used as a map key for fast O(1) lookups
type ruleKey struct {
sourcePrefix string
destPrefix string
}
// SubnetLookup provides fast IP subnet and port matching with O(1) lookup performance
type SubnetLookup struct {
mu sync.RWMutex
rules map[ruleKey]*SubnetRule // Map for O(1) lookups by prefix combination
}
// NewSubnetLookup creates a new subnet lookup table
func NewSubnetLookup() *SubnetLookup {
return &SubnetLookup{
rules: make(map[ruleKey]*SubnetRule),
}
}
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
// If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) {
sl.mu.Lock()
defer sl.mu.Unlock()
key := ruleKey{
sourcePrefix: sourcePrefix.String(),
destPrefix: destPrefix.String(),
}
sl.rules[key] = &SubnetRule{
SourcePrefix: sourcePrefix,
DestPrefix: destPrefix,
RewriteTo: rewriteTo,
PortRanges: portRanges,
}
}
// RemoveSubnet removes a subnet rule from the lookup table
func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) {
sl.mu.Lock()
defer sl.mu.Unlock()
key := ruleKey{
sourcePrefix: sourcePrefix.String(),
destPrefix: destPrefix.String(),
}
delete(sl.rules, key)
}
// Match checks if a source IP, destination IP, and port match any subnet rule
// Returns the matched rule if BOTH:
// - The source IP is in the rule's source prefix
// - The destination IP is in the rule's destination prefix
// - The port is in an allowed range (or no port restrictions exist)
//
// Returns nil if no rule matches
func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) *SubnetRule {
sl.mu.RLock()
defer sl.mu.RUnlock()
// Iterate through all rules to find matching source and destination prefixes
// This is O(n) but necessary since we need to check prefix containment, not exact match
for _, rule := range sl.rules {
// Check if source and destination IPs match their respective prefixes
if !rule.SourcePrefix.Contains(srcIP) {
continue
}
if !rule.DestPrefix.Contains(dstIP) {
continue
}
// Both IPs match - now check port restrictions
// If no port ranges specified, all ports are allowed
if len(rule.PortRanges) == 0 {
return rule
}
// Check if port is in any of the allowed ranges
for _, pr := range rule.PortRanges {
if port >= pr.Min && port <= pr.Max {
return rule
}
}
}
return nil
}
// connKey uniquely identifies a connection for NAT tracking
type connKey struct {
srcIP string
srcPort uint16
dstIP string
dstPort uint16
proto uint8
}
// destKey identifies a destination for handler lookups (without source port since it may change)
type destKey struct {
srcIP string
dstIP string
dstPort uint16
proto uint8
}
// natState tracks NAT translation state for reverse translation
type natState struct {
originalDst netip.Addr // Original destination before DNAT
rewrittenTo netip.Addr // The address we rewrote to
}
// ProxyHandler handles packet injection and extraction for promiscuous mode
type ProxyHandler struct {
proxyStack *stack.Stack
proxyEp *channel.Endpoint
proxyNotifyHandle *channel.NotificationHandle
tcpHandler *TCPHandler
udpHandler *UDPHandler
subnetLookup *SubnetLookup
natTable map[connKey]*natState
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
natMu sync.RWMutex
enabled bool
}
// ProxyHandlerOptions configures the proxy handler
type ProxyHandlerOptions struct {
EnableTCP bool
EnableUDP bool
MTU int
}
// NewProxyHandler creates a new proxy handler for promiscuous mode
func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
if !options.EnableTCP && !options.EnableUDP {
return nil, nil // No proxy needed
}
handler := &ProxyHandler{
enabled: true,
subnetLookup: NewSubnetLookup(),
natTable: make(map[connKey]*natState),
destRewriteTable: make(map[destKey]netip.Addr),
proxyEp: channel.New(1024, uint32(options.MTU), ""),
proxyStack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
}),
}
// Initialize TCP handler if enabled
if options.EnableTCP {
handler.tcpHandler = NewTCPHandler(handler.proxyStack, handler)
if err := handler.tcpHandler.InstallTCPHandler(); err != nil {
return nil, fmt.Errorf("failed to install TCP handler: %v", err)
}
}
// Initialize UDP handler if enabled
if options.EnableUDP {
handler.udpHandler = NewUDPHandler(handler.proxyStack, handler)
if err := handler.udpHandler.InstallUDPHandler(); err != nil {
return nil, fmt.Errorf("failed to install UDP handler: %v", err)
}
}
// // Example 1: Add a rule with no port restrictions (all ports allowed)
// // This accepts all traffic FROM 10.0.0.0/24 TO 10.20.20.0/24
// sourceSubnet := netip.MustParsePrefix("10.0.0.0/24")
// destSubnet := netip.MustParsePrefix("10.20.20.0/24")
// handler.AddSubnetRule(sourceSubnet, destSubnet, nil)
// // Example 2: Add a rule with specific port ranges
// // This accepts traffic FROM 10.0.0.5/32 TO 10.20.21.21/32 only on ports 80, 443, and 8000-9000
// sourceIP := netip.MustParsePrefix("10.0.0.5/32")
// destIP := netip.MustParsePrefix("10.20.21.21/32")
// handler.AddSubnetRule(sourceIP, destIP, []PortRange{
// {Min: 80, Max: 80},
// {Min: 443, Max: 443},
// {Min: 8000, Max: 9000},
// })
return handler, nil
}
// AddSubnetRule adds a subnet with optional port restrictions to the proxy handler
// sourcePrefix: The IP prefix of the peer sending the data
// destPrefix: The IP prefix of the destination
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
// If portRanges is nil or empty, all ports are allowed for this subnet
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) {
if p == nil || !p.enabled {
return
}
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges)
}
// RemoveSubnetRule removes a subnet from the proxy handler
func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
if p == nil || !p.enabled {
return
}
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
}
// LookupDestinationRewrite looks up the rewritten destination for a connection
// This is used by TCP/UDP handlers to find the actual target address
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
if p == nil || !p.enabled {
return netip.Addr{}, false
}
key := destKey{
srcIP: srcIP,
dstIP: dstIP,
dstPort: dstPort,
proto: proto,
}
p.natMu.RLock()
defer p.natMu.RUnlock()
addr, ok := p.destRewriteTable[key]
return addr, ok
}
// resolveRewriteAddress resolves a rewrite address which can be either:
// - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly
// - A plain IP address (e.g., "192.168.1.1") - returns the IP directly
// - A domain name (e.g., "example.com") - performs DNS lookup
func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) {
logger.Debug("Resolving rewrite address: %s", rewriteTo)
// First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32")
if prefix, err := netip.ParsePrefix(rewriteTo); err == nil {
return prefix.Addr(), nil
}
// Try to parse as a plain IP address (e.g., "192.168.1.1")
if addr, err := netip.ParseAddr(rewriteTo); err == nil {
return addr, nil
}
// Not an IP address, treat as domain name - perform DNS lookup
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ips, err := net.DefaultResolver.LookupIP(ctx, "ip4", rewriteTo)
if err != nil {
return netip.Addr{}, fmt.Errorf("failed to resolve domain %s: %w", rewriteTo, err)
}
if len(ips) == 0 {
return netip.Addr{}, fmt.Errorf("no IP addresses found for domain %s", rewriteTo)
}
// Use the first resolved IP address
ip := ips[0]
if ip4 := ip.To4(); ip4 != nil {
addr := netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]})
logger.Debug("Resolved %s to %s", rewriteTo, addr)
return addr, nil
}
return netip.Addr{}, fmt.Errorf("no IPv4 address found for domain %s", rewriteTo)
}
// Initialize sets up the promiscuous NIC with the netTun's notification system
func (p *ProxyHandler) Initialize(notifiable channel.Notification) error {
if p == nil || !p.enabled {
return nil
}
// Add notification handler
p.proxyNotifyHandle = p.proxyEp.AddNotify(notifiable)
// Create NIC with promiscuous mode
tcpipErr := p.proxyStack.CreateNICWithOptions(1, p.proxyEp, stack.NICOptions{
Disabled: false,
QDisc: nil,
})
if tcpipErr != nil {
return fmt.Errorf("CreateNIC (proxy): %v", tcpipErr)
}
// Enable promiscuous mode - accepts packets for any destination IP
if tcpipErr := p.proxyStack.SetPromiscuousMode(1, true); tcpipErr != nil {
return fmt.Errorf("SetPromiscuousMode: %v", tcpipErr)
}
// Enable spoofing - allows sending packets from any source IP
if tcpipErr := p.proxyStack.SetSpoofing(1, true); tcpipErr != nil {
return fmt.Errorf("SetSpoofing: %v", tcpipErr)
}
// Add default route
p.proxyStack.AddRoute(tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
})
return nil
}
// HandleIncomingPacket processes incoming packets and determines if they should
// be injected into the proxy stack
func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
if p == nil || !p.enabled {
return false
}
// Check minimum packet size
if len(packet) < header.IPv4MinimumSize {
return false
}
// Only handle IPv4 for now
if packet[0]>>4 != 4 {
return false
}
// Parse IPv4 header
ipv4Header := header.IPv4(packet)
srcIP := ipv4Header.SourceAddress()
dstIP := ipv4Header.DestinationAddress()
// Convert gvisor tcpip.Address to netip.Addr
srcBytes := srcIP.As4()
srcAddr := netip.AddrFrom4(srcBytes)
dstBytes := dstIP.As4()
dstAddr := netip.AddrFrom4(dstBytes)
// Parse transport layer to get destination port
var dstPort uint16
protocol := ipv4Header.TransportProtocol()
headerLen := int(ipv4Header.HeaderLength())
// Extract port based on protocol
switch protocol {
case header.TCPProtocolNumber:
if len(packet) < headerLen+header.TCPMinimumSize {
return false
}
tcpHeader := header.TCP(packet[headerLen:])
dstPort = tcpHeader.DestinationPort()
case header.UDPProtocolNumber:
if len(packet) < headerLen+header.UDPMinimumSize {
return false
}
udpHeader := header.UDP(packet[headerLen:])
dstPort = udpHeader.DestinationPort()
default:
// For other protocols (ICMP, etc.), use port 0 (must match rules with no port restrictions)
dstPort = 0
}
// Check if the source IP, destination IP, and port match any subnet rule
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort)
if matchedRule != nil {
// Check if we need to perform DNAT
if matchedRule.RewriteTo != "" {
// Create connection tracking key using original destination
// This allows us to check if we've already resolved for this connection
var srcPort uint16
switch protocol {
case header.TCPProtocolNumber:
tcpHeader := header.TCP(packet[headerLen:])
srcPort = tcpHeader.SourcePort()
case header.UDPProtocolNumber:
udpHeader := header.UDP(packet[headerLen:])
srcPort = udpHeader.SourcePort()
}
// Key using original destination to track the connection
key := connKey{
srcIP: srcAddr.String(),
srcPort: srcPort,
dstIP: dstAddr.String(),
dstPort: dstPort,
proto: uint8(protocol),
}
// Key for handler lookups (doesn't include srcPort for flexibility)
dKey := destKey{
srcIP: srcAddr.String(),
dstIP: dstAddr.String(),
dstPort: dstPort,
proto: uint8(protocol),
}
// Check if we already have a NAT entry for this connection
p.natMu.RLock()
existingEntry, exists := p.natTable[key]
p.natMu.RUnlock()
var newDst netip.Addr
if exists {
// Use the previously resolved address for this connection
newDst = existingEntry.rewrittenTo
logger.Debug("Using existing NAT entry for connection: %s -> %s", dstAddr, newDst)
} else {
// New connection - resolve the rewrite address
var err error
newDst, err = p.resolveRewriteAddress(matchedRule.RewriteTo)
if err != nil {
// Failed to resolve, skip DNAT but still proxy the packet
logger.Debug("Failed to resolve rewrite address: %v", err)
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
return true
}
// Store NAT state for this connection
p.natMu.Lock()
p.natTable[key] = &natState{
originalDst: dstAddr,
rewrittenTo: newDst,
}
// Store destination rewrite for handler lookups
p.destRewriteTable[dKey] = newDst
p.natMu.Unlock()
logger.Debug("New NAT entry for connection: %s -> %s", dstAddr, newDst)
}
// Check if target is loopback - if so, don't rewrite packet destination
// as gVisor will drop martian packets. Instead, the handlers will use
// destRewriteTable to find the actual target address.
if !newDst.IsLoopback() {
// Rewrite the packet only for non-loopback destinations
packet = p.rewritePacketDestination(packet, newDst)
if packet == nil {
return false
}
} else {
logger.Debug("Target is loopback, not rewriting packet - handlers will use rewrite table")
}
}
// Inject into proxy stack
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
return true
}
return false
}
// rewritePacketDestination rewrites the destination IP in a packet and recalculates checksums
func (p *ProxyHandler) rewritePacketDestination(packet []byte, newDst netip.Addr) []byte {
if len(packet) < header.IPv4MinimumSize {
return nil
}
// Make a copy to avoid modifying the original
pkt := make([]byte, len(packet))
copy(pkt, packet)
ipv4Header := header.IPv4(pkt)
headerLen := int(ipv4Header.HeaderLength())
// Rewrite destination IP
newDstBytes := newDst.As4()
newDstAddr := tcpip.AddrFrom4(newDstBytes)
ipv4Header.SetDestinationAddress(newDstAddr)
// Recalculate IP checksum
ipv4Header.SetChecksum(0)
ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
// Update transport layer checksum if needed
protocol := ipv4Header.TransportProtocol()
switch protocol {
case header.TCPProtocolNumber:
if len(pkt) >= headerLen+header.TCPMinimumSize {
tcpHeader := header.TCP(pkt[headerLen:])
tcpHeader.SetChecksum(0)
xsum := header.PseudoHeaderChecksum(
header.TCPProtocolNumber,
ipv4Header.SourceAddress(),
ipv4Header.DestinationAddress(),
uint16(len(pkt)-headerLen),
)
xsum = checksum.Checksum(pkt[headerLen:], xsum)
tcpHeader.SetChecksum(^xsum)
}
case header.UDPProtocolNumber:
if len(pkt) >= headerLen+header.UDPMinimumSize {
udpHeader := header.UDP(pkt[headerLen:])
udpHeader.SetChecksum(0)
xsum := header.PseudoHeaderChecksum(
header.UDPProtocolNumber,
ipv4Header.SourceAddress(),
ipv4Header.DestinationAddress(),
uint16(len(pkt)-headerLen),
)
xsum = checksum.Checksum(pkt[headerLen:], xsum)
udpHeader.SetChecksum(^xsum)
}
}
return pkt
}
// rewritePacketSource rewrites the source IP in a packet and recalculates checksums (for reverse NAT)
func (p *ProxyHandler) rewritePacketSource(packet []byte, newSrc netip.Addr) []byte {
if len(packet) < header.IPv4MinimumSize {
return nil
}
// Make a copy to avoid modifying the original
pkt := make([]byte, len(packet))
copy(pkt, packet)
ipv4Header := header.IPv4(pkt)
headerLen := int(ipv4Header.HeaderLength())
// Rewrite source IP
newSrcBytes := newSrc.As4()
newSrcAddr := tcpip.AddrFrom4(newSrcBytes)
ipv4Header.SetSourceAddress(newSrcAddr)
// Recalculate IP checksum
ipv4Header.SetChecksum(0)
ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
// Update transport layer checksum if needed
protocol := ipv4Header.TransportProtocol()
switch protocol {
case header.TCPProtocolNumber:
if len(pkt) >= headerLen+header.TCPMinimumSize {
tcpHeader := header.TCP(pkt[headerLen:])
tcpHeader.SetChecksum(0)
xsum := header.PseudoHeaderChecksum(
header.TCPProtocolNumber,
ipv4Header.SourceAddress(),
ipv4Header.DestinationAddress(),
uint16(len(pkt)-headerLen),
)
xsum = checksum.Checksum(pkt[headerLen:], xsum)
tcpHeader.SetChecksum(^xsum)
}
case header.UDPProtocolNumber:
if len(pkt) >= headerLen+header.UDPMinimumSize {
udpHeader := header.UDP(pkt[headerLen:])
udpHeader.SetChecksum(0)
xsum := header.PseudoHeaderChecksum(
header.UDPProtocolNumber,
ipv4Header.SourceAddress(),
ipv4Header.DestinationAddress(),
uint16(len(pkt)-headerLen),
)
xsum = checksum.Checksum(pkt[headerLen:], xsum)
udpHeader.SetChecksum(^xsum)
}
}
return pkt
}
// ReadOutgoingPacket reads packets from the proxy stack that need to be
// sent back through the tunnel
func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
if p == nil || !p.enabled {
return nil
}
pkt := p.proxyEp.Read()
if pkt != nil {
view := pkt.ToView()
pkt.DecRef()
// Check if we need to perform reverse NAT
packet := view.AsSlice()
if len(packet) >= header.IPv4MinimumSize && packet[0]>>4 == 4 {
ipv4Header := header.IPv4(packet)
srcIP := ipv4Header.SourceAddress()
dstIP := ipv4Header.DestinationAddress()
protocol := ipv4Header.TransportProtocol()
headerLen := int(ipv4Header.HeaderLength())
// Extract ports
var srcPort, dstPort uint16
switch protocol {
case header.TCPProtocolNumber:
if len(packet) >= headerLen+header.TCPMinimumSize {
tcpHeader := header.TCP(packet[headerLen:])
srcPort = tcpHeader.SourcePort()
dstPort = tcpHeader.DestinationPort()
}
case header.UDPProtocolNumber:
if len(packet) >= headerLen+header.UDPMinimumSize {
udpHeader := header.UDP(packet[headerLen:])
srcPort = udpHeader.SourcePort()
dstPort = udpHeader.DestinationPort()
}
}
// Look up NAT state for reverse translation
// The key uses the original dst (before rewrite), so for replies we need to
// find the entry where the rewritten address matches the current source
p.natMu.RLock()
var natEntry *natState
for k, entry := range p.natTable {
// Match: reply's dst should be original src, reply's src should be rewritten dst
if k.srcIP == dstIP.String() && k.srcPort == dstPort &&
entry.rewrittenTo.String() == srcIP.String() && k.dstPort == srcPort &&
k.proto == uint8(protocol) {
natEntry = entry
break
}
}
p.natMu.RUnlock()
if natEntry != nil {
// Perform reverse NAT - rewrite source to original destination
packet = p.rewritePacketSource(packet, natEntry.originalDst)
if packet != nil {
return buffer.NewViewWithData(packet)
}
}
}
return view
}
return nil
}
// Close cleans up the proxy handler resources
func (p *ProxyHandler) Close() error {
if p == nil || !p.enabled {
return nil
}
if p.proxyStack != nil {
p.proxyStack.RemoveNIC(1)
p.proxyStack.Close()
}
if p.proxyEp != nil {
if p.proxyNotifyHandle != nil {
p.proxyEp.RemoveNotify(p.proxyNotifyHandle)
}
p.proxyEp.Close()
}
return nil
}

1149
netstack2/tun.go Normal file

File diff suppressed because it is too large Load Diff

165
network/interface.go Normal file
View File

@@ -0,0 +1,165 @@
package network
import (
"fmt"
"net"
"os/exec"
"regexp"
"runtime"
"strconv"
"time"
"github.com/fosrl/newt/logger"
"github.com/vishvananda/netlink"
)
// ConfigureInterface configures a network interface with an IP address and brings it up
func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error {
logger.Info("The tunnel IP is: %s", tunnelIp)
// Parse the IP address and network
ip, ipNet, err := net.ParseCIDR(tunnelIp)
if err != nil {
return fmt.Errorf("invalid IP address: %v", err)
}
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
mask := net.IP(ipNet.Mask).String()
destinationAddress := ip.String()
logger.Debug("The destination address is: %s", destinationAddress)
// network.SetTunnelRemoteAddress() // what does this do?
SetIPv4Settings([]string{destinationAddress}, []string{mask})
SetMTU(mtu)
if interfaceName == "" {
return nil
}
switch runtime.GOOS {
case "linux":
return configureLinux(interfaceName, ip, ipNet)
case "darwin":
return configureDarwin(interfaceName, ip, ipNet)
case "windows":
return configureWindows(interfaceName, ip, ipNet)
default:
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
}
}
// waitForInterfaceUp polls the network interface until it's up or times out
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
deadline := time.Now().Add(timeout)
pollInterval := 500 * time.Millisecond
for time.Now().Before(deadline) {
// Check if interface exists and is up
iface, err := net.InterfaceByName(interfaceName)
if err == nil {
// Check if interface is up
if iface.Flags&net.FlagUp != 0 {
// Check if it has the expected IP
addrs, err := iface.Addrs()
if err == nil {
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if ok && ipNet.IP.Equal(expectedIP) {
logger.Info("Interface %s is up with correct IP", interfaceName)
return nil // Interface is up with correct IP
}
}
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName)
}
} else {
logger.Info("Interface %s exists but is not up yet", interfaceName)
}
} else {
logger.Info("Interface %s not found yet: %v", interfaceName, err)
}
// Wait before next check
time.Sleep(pollInterval)
}
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
}
func FindUnusedUTUN() (string, error) {
ifaces, err := net.Interfaces()
if err != nil {
return "", fmt.Errorf("failed to list interfaces: %v", err)
}
used := make(map[int]bool)
re := regexp.MustCompile(`^utun(\d+)$`)
for _, iface := range ifaces {
if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 {
if num, err := strconv.Atoi(matches[1]); err == nil {
used[num] = true
}
}
}
// Try utun0 up to utun255.
for i := 0; i < 256; i++ {
if !used[i] {
return fmt.Sprintf("utun%d", i), nil
}
}
return "", fmt.Errorf("no unused utun interface found")
}
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
logger.Info("Configuring darwin interface: %s", interfaceName)
prefix, _ := ipNet.Mask.Size()
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
logger.Info("Running command: %v", cmd)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out)
}
// Bring up the interface
cmd = exec.Command("ifconfig", interfaceName, "up")
logger.Info("Running command: %v", cmd)
out, err = cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out)
}
return nil
}
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
// Get the interface
link, err := netlink.LinkByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
}
// Create the IP address attributes
addr := &netlink.Addr{
IPNet: &net.IPNet{
IP: ip,
Mask: ipNet.Mask,
},
}
// Add the IP address to the interface
if err := netlink.AddrAdd(link, addr); err != nil {
return fmt.Errorf("failed to add IP address: %v", err)
}
// Bring up the interface
if err := netlink.LinkSetUp(link); err != nil {
return fmt.Errorf("failed to bring up interface: %v", err)
}
return nil
}

View File

@@ -0,0 +1,12 @@
//go:build !windows
package network
import (
"fmt"
"net"
)
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
return fmt.Errorf("configureWindows called on non-Windows platform")
}

View File

@@ -0,0 +1,63 @@
//go:build windows
package network
import (
"fmt"
"net"
"net/netip"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
logger.Info("Configuring Windows interface: %s", interfaceName)
// Get the LUID for the interface
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
}
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err)
}
// Create the IP address prefix
maskBits, _ := ipNet.Mask.Size()
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
var addr netip.Addr
if ip4 := ip.To4(); ip4 != nil {
// IPv4 address
addr, _ = netip.AddrFromSlice(ip4)
} else {
// IPv6 address
addr, _ = netip.AddrFromSlice(ip)
}
if !addr.IsValid() {
return fmt.Errorf("failed to convert IP address")
}
prefix := netip.PrefixFrom(addr, maskBits)
// Add the IP address to the interface
logger.Info("Adding IP address %s to interface %s", prefix.String(), interfaceName)
err = luid.AddIPAddress(prefix)
if err != nil {
return fmt.Errorf("failed to add IP address: %v", err)
}
// This was required when we were using the subprocess "netsh" command to bring up the interface.
// With the winipcfg library, the interface should already be up after adding the IP so we dont
// need this step anymore as far as I can tell.
// // Wait for the interface to be up and have the correct IP
// err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
// if err != nil {
// return fmt.Errorf("interface did not come up within timeout: %v", err)
// }
return nil
}

View File

@@ -1,195 +0,0 @@
package network
import (
"encoding/binary"
"encoding/json"
"fmt"
"log"
"net"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/vishvananda/netlink"
"golang.org/x/net/bpf"
"golang.org/x/net/ipv4"
)
const (
udpProtocol = 17
// EmptyUDPSize is the size of an empty UDP packet
EmptyUDPSize = 28
timeout = time.Second * 10
)
// Server stores data relating to the server
type Server struct {
Hostname string
Addr *net.IPAddr
Port uint16
}
// PeerNet stores data about a peer's endpoint
type PeerNet struct {
Resolved bool
IP net.IP
Port uint16
NewtID string
}
// GetClientIP gets source ip address that will be used when sending data to dstIP
func GetClientIP(dstIP net.IP) net.IP {
routes, err := netlink.RouteGet(dstIP)
if err != nil {
log.Fatalln("Error getting route:", err)
}
return routes[0].Src
}
// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr
func HostToAddr(hostStr string) *net.IPAddr {
remoteAddrs, err := net.LookupHost(hostStr)
if err != nil {
log.Fatalln("Error parsing remote address:", err)
}
for _, addrStr := range remoteAddrs {
if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil {
return remoteAddr
}
}
return nil
}
// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering
func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn {
packetConn, err := net.ListenPacket("ip4:udp", client.IP.String())
if err != nil {
log.Fatalln("Error creating packetConn:", err)
}
rawConn, err := ipv4.NewRawConn(packetConn)
if err != nil {
log.Fatalln("Error creating rawConn:", err)
}
ApplyBPF(rawConn, server, client)
return rawConn
}
// ApplyBPF constructs a BPF program and applies it to the RawConn
func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) {
const ipv4HeaderLen = 20
const srcIPOffset = 12
const srcPortOffset = ipv4HeaderLen + 0
const dstPortOffset = ipv4HeaderLen + 2
ipArr := []byte(server.Addr.IP.To4())
ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3])
bpfRaw, err := bpf.Assemble([]bpf.Instruction{
bpf.LoadAbsolute{Off: srcIPOffset, Size: 4},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0},
bpf.LoadAbsolute{Off: srcPortOffset, Size: 2},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0},
bpf.LoadAbsolute{Off: dstPortOffset, Size: 2},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0},
bpf.RetConstant{Val: 1<<(8*4) - 1},
bpf.RetConstant{Val: 0},
})
if err != nil {
log.Fatalln("Error assembling BPF:", err)
}
err = rawConn.SetBPF(bpfRaw)
if err != nil {
log.Fatalln("Error setting BPF:", err)
}
}
// MakePacket constructs a request packet to send to the server
func MakePacket(payload []byte, server *Server, client *PeerNet) []byte {
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}
ipHeader := layers.IPv4{
SrcIP: client.IP,
DstIP: server.Addr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpHeader := layers.UDP{
SrcPort: layers.UDPPort(client.Port),
DstPort: layers.UDPPort(server.Port),
}
payloadLayer := gopacket.Payload(payload)
udpHeader.SetNetworkLayerForChecksum(&ipHeader)
gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer)
return buf.Bytes()
}
// SendPacket sends packet to the Server
func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
fullPacket := MakePacket(packet, server, client)
_, err := conn.WriteToIP(fullPacket, server.Addr)
return err
}
// SendDataPacket sends a JSON payload to the Server
func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
jsonData, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}
return SendPacket(jsonData, conn, server, client)
}
// RecvPacket receives a UDP packet from server
func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) {
err := conn.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
return nil, 0, err
}
response := make([]byte, 4096)
n, err := conn.Read(response)
if err != nil {
return nil, n, err
}
return response, n, nil
}
// RecvDataPacket receives and unmarshals a JSON packet from server
func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) {
response, n, err := RecvPacket(conn, server, client)
if err != nil {
return nil, err
}
// Extract payload from UDP packet
payload := response[EmptyUDPSize:n]
return payload, nil
}
// ParseResponse takes a response packet and parses it into an IP and port
func ParseResponse(response []byte) (net.IP, uint16) {
ip := net.IP(response[:4])
port := binary.BigEndian.Uint16(response[4:6])
return ip, port
}

282
network/route.go Normal file
View File

@@ -0,0 +1,282 @@
package network
import (
"fmt"
"net"
"os/exec"
"runtime"
"strings"
"github.com/fosrl/newt/logger"
"github.com/vishvananda/netlink"
)
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
if runtime.GOOS != "darwin" {
return nil
}
var cmd *exec.Cmd
if gateway != "" {
// Route with specific gateway
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway)
} else if interfaceName != "" {
// Route via interface
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName)
} else {
return fmt.Errorf("either gateway or interface must be specified")
}
logger.Info("Running command: %v", cmd)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("route command failed: %v, output: %s", err, out)
}
return nil
}
func DarwinRemoveRoute(destination string) error {
if runtime.GOOS != "darwin" {
return nil
}
cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination)
logger.Info("Running command: %v", cmd)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
}
return nil
}
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
if runtime.GOOS != "linux" {
return nil
}
// Parse destination CIDR
_, ipNet, err := net.ParseCIDR(destination)
if err != nil {
return fmt.Errorf("invalid destination address: %v", err)
}
// Create route
route := &netlink.Route{
Dst: ipNet,
}
if gateway != "" {
// Route with specific gateway
gw := net.ParseIP(gateway)
if gw == nil {
return fmt.Errorf("invalid gateway address: %s", gateway)
}
route.Gw = gw
logger.Info("Adding route to %s via gateway %s", destination, gateway)
} else if interfaceName != "" {
// Route via interface
link, err := netlink.LinkByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
}
route.LinkIndex = link.Attrs().Index
logger.Info("Adding route to %s via interface %s", destination, interfaceName)
} else {
return fmt.Errorf("either gateway or interface must be specified")
}
// Add the route
if err := netlink.RouteAdd(route); err != nil {
return fmt.Errorf("failed to add route: %v", err)
}
return nil
}
func LinuxRemoveRoute(destination string) error {
if runtime.GOOS != "linux" {
return nil
}
// Parse destination CIDR
_, ipNet, err := net.ParseCIDR(destination)
if err != nil {
return fmt.Errorf("invalid destination address: %v", err)
}
// Create route to delete
route := &netlink.Route{
Dst: ipNet,
}
logger.Info("Removing route to %s", destination)
// Delete the route
if err := netlink.RouteDel(route); err != nil {
return fmt.Errorf("failed to delete route: %v", err)
}
return nil
}
// addRouteForServerIP adds an OS-specific route for the server IP
func AddRouteForServerIP(serverIP, interfaceName string) error {
if err := AddRouteForNetworkConfig(serverIP); err != nil {
return err
}
if interfaceName == "" {
return nil
}
if runtime.GOOS == "darwin" {
return DarwinAddRoute(serverIP, "", interfaceName)
}
// else if runtime.GOOS == "windows" {
// return WindowsAddRoute(serverIP, "", interfaceName)
// } else if runtime.GOOS == "linux" {
// return LinuxAddRoute(serverIP, "", interfaceName)
// }
return nil
}
// removeRouteForServerIP removes an OS-specific route for the server IP
func RemoveRouteForServerIP(serverIP string, interfaceName string) error {
if err := RemoveRouteForNetworkConfig(serverIP); err != nil {
return err
}
if interfaceName == "" {
return nil
}
if runtime.GOOS == "darwin" {
return DarwinRemoveRoute(serverIP)
}
// else if runtime.GOOS == "windows" {
// return WindowsRemoveRoute(serverIP)
// } else if runtime.GOOS == "linux" {
// return LinuxRemoveRoute(serverIP)
// }
return nil
}
func AddRouteForNetworkConfig(destination string) error {
// Parse the subnet to extract IP and mask
_, ipNet, err := net.ParseCIDR(destination)
if err != nil {
return fmt.Errorf("failed to parse subnet %s: %v", destination, err)
}
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
mask := net.IP(ipNet.Mask).String()
destinationAddress := ipNet.IP.String()
AddIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
return nil
}
func RemoveRouteForNetworkConfig(destination string) error {
// Parse the subnet to extract IP and mask
_, ipNet, err := net.ParseCIDR(destination)
if err != nil {
return fmt.Errorf("failed to parse subnet %s: %v", destination, err)
}
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
mask := net.IP(ipNet.Mask).String()
destinationAddress := ipNet.IP.String()
RemoveIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
return nil
}
// addRoutes adds routes for each subnet in RemoteSubnets
func AddRoutes(remoteSubnets []string, interfaceName string) error {
if len(remoteSubnets) == 0 {
return nil
}
// Add routes for each subnet
for _, subnet := range remoteSubnets {
subnet = strings.TrimSpace(subnet)
if subnet == "" {
continue
}
if err := AddRouteForNetworkConfig(subnet); err != nil {
logger.Error("Failed to add network config for subnet %s: %v", subnet, err)
continue
}
// Add route based on operating system
if interfaceName == "" {
continue
}
if runtime.GOOS == "darwin" {
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
return err
}
} else if runtime.GOOS == "windows" {
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
return err
}
} else if runtime.GOOS == "linux" {
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
return err
}
}
logger.Info("Added route for remote subnet: %s", subnet)
}
return nil
}
// removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets
func RemoveRoutes(remoteSubnets []string) error {
if len(remoteSubnets) == 0 {
return nil
}
// Remove routes for each subnet
for _, subnet := range remoteSubnets {
subnet = strings.TrimSpace(subnet)
if subnet == "" {
continue
}
if err := RemoveRouteForNetworkConfig(subnet); err != nil {
logger.Error("Failed to remove network config for subnet %s: %v", subnet, err)
continue
}
// Remove route based on operating system
if runtime.GOOS == "darwin" {
if err := DarwinRemoveRoute(subnet); err != nil {
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
return err
}
} else if runtime.GOOS == "windows" {
if err := WindowsRemoveRoute(subnet); err != nil {
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
return err
}
} else if runtime.GOOS == "linux" {
if err := LinuxRemoveRoute(subnet); err != nil {
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
return err
}
}
logger.Info("Removed route for remote subnet: %s", subnet)
}
return nil
}

View File

@@ -0,0 +1,11 @@
//go:build !windows
package network
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
return nil
}
func WindowsRemoveRoute(destination string) error {
return nil
}

148
network/route_windows.go Normal file
View File

@@ -0,0 +1,148 @@
//go:build windows
package network
import (
"fmt"
"net"
"net/netip"
"runtime"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
if runtime.GOOS != "windows" {
return nil
}
// Parse destination CIDR
_, ipNet, err := net.ParseCIDR(destination)
if err != nil {
return fmt.Errorf("invalid destination address: %v", err)
}
// Convert to netip.Prefix
maskBits, _ := ipNet.Mask.Size()
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
var addr netip.Addr
if ip4 := ipNet.IP.To4(); ip4 != nil {
// IPv4 address
addr, _ = netip.AddrFromSlice(ip4)
} else {
// IPv6 address
addr, _ = netip.AddrFromSlice(ipNet.IP)
}
if !addr.IsValid() {
return fmt.Errorf("failed to convert destination IP")
}
prefix := netip.PrefixFrom(addr, maskBits)
var luid winipcfg.LUID
var nextHop netip.Addr
if interfaceName != "" {
// Get the interface LUID - needed for both gateway and interface-only routes
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
}
luid, err = winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err)
}
}
if gateway != "" {
// Route with specific gateway
gwIP := net.ParseIP(gateway)
if gwIP == nil {
return fmt.Errorf("invalid gateway address: %s", gateway)
}
// Convert to correct IP version
if ip4 := gwIP.To4(); ip4 != nil {
nextHop, _ = netip.AddrFromSlice(ip4)
} else {
nextHop, _ = netip.AddrFromSlice(gwIP)
}
if !nextHop.IsValid() {
return fmt.Errorf("failed to convert gateway IP")
}
logger.Info("Adding route to %s via gateway %s on interface %s", destination, gateway, interfaceName)
} else if interfaceName != "" {
// Route via interface only
if addr.Is4() {
nextHop = netip.IPv4Unspecified()
} else {
nextHop = netip.IPv6Unspecified()
}
logger.Info("Adding route to %s via interface %s", destination, interfaceName)
} else {
return fmt.Errorf("either gateway or interface must be specified")
}
// Add the route using winipcfg
err = luid.AddRoute(prefix, nextHop, 1)
if err != nil {
return fmt.Errorf("failed to add route: %v", err)
}
return nil
}
func WindowsRemoveRoute(destination string) error {
// Parse destination CIDR
_, ipNet, err := net.ParseCIDR(destination)
if err != nil {
return fmt.Errorf("invalid destination address: %v", err)
}
// Convert to netip.Prefix
maskBits, _ := ipNet.Mask.Size()
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
var addr netip.Addr
if ip4 := ipNet.IP.To4(); ip4 != nil {
// IPv4 address
addr, _ = netip.AddrFromSlice(ip4)
} else {
// IPv6 address
addr, _ = netip.AddrFromSlice(ipNet.IP)
}
if !addr.IsValid() {
return fmt.Errorf("failed to convert destination IP")
}
prefix := netip.PrefixFrom(addr, maskBits)
// Get all routes and find the one to delete
// We need to get the LUID from the existing route
var family winipcfg.AddressFamily
if addr.Is4() {
family = 2 // AF_INET
} else {
family = 23 // AF_INET6
}
routes, err := winipcfg.GetIPForwardTable2(family)
if err != nil {
return fmt.Errorf("failed to get route table: %v", err)
}
// Find and delete matching route
for _, route := range routes {
routePrefix := route.DestinationPrefix.Prefix()
if routePrefix == prefix {
logger.Info("Removing route to %s", destination)
err = route.Delete()
if err != nil {
return fmt.Errorf("failed to delete route: %v", err)
}
return nil
}
}
return fmt.Errorf("route to %s not found", destination)
}

190
network/settings.go Normal file
View File

@@ -0,0 +1,190 @@
package network
import (
"encoding/json"
"sync"
"github.com/fosrl/newt/logger"
)
// NetworkSettings represents the network configuration for the tunnel
type NetworkSettings struct {
TunnelRemoteAddress string `json:"tunnel_remote_address,omitempty"`
MTU *int `json:"mtu,omitempty"`
DNSServers []string `json:"dns_servers,omitempty"`
IPv4Addresses []string `json:"ipv4_addresses,omitempty"`
IPv4SubnetMasks []string `json:"ipv4_subnet_masks,omitempty"`
IPv4IncludedRoutes []IPv4Route `json:"ipv4_included_routes,omitempty"`
IPv4ExcludedRoutes []IPv4Route `json:"ipv4_excluded_routes,omitempty"`
IPv6Addresses []string `json:"ipv6_addresses,omitempty"`
IPv6NetworkPrefixes []string `json:"ipv6_network_prefixes,omitempty"`
IPv6IncludedRoutes []IPv6Route `json:"ipv6_included_routes,omitempty"`
IPv6ExcludedRoutes []IPv6Route `json:"ipv6_excluded_routes,omitempty"`
}
// IPv4Route represents an IPv4 route
type IPv4Route struct {
DestinationAddress string `json:"destination_address"`
SubnetMask string `json:"subnet_mask,omitempty"`
GatewayAddress string `json:"gateway_address,omitempty"`
IsDefault bool `json:"is_default,omitempty"`
}
// IPv6Route represents an IPv6 route
type IPv6Route struct {
DestinationAddress string `json:"destination_address"`
NetworkPrefixLength int `json:"network_prefix_length,omitempty"`
GatewayAddress string `json:"gateway_address,omitempty"`
IsDefault bool `json:"is_default,omitempty"`
}
var (
networkSettings NetworkSettings
networkSettingsMutex sync.RWMutex
incrementor int
)
// SetTunnelRemoteAddress sets the tunnel remote address
func SetTunnelRemoteAddress(address string) {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
networkSettings.TunnelRemoteAddress = address
incrementor++
logger.Info("Set tunnel remote address: %s", address)
}
// SetMTU sets the MTU value
func SetMTU(mtu int) {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
networkSettings.MTU = &mtu
incrementor++
logger.Info("Set MTU: %d", mtu)
}
// SetDNSServers sets the DNS servers
func SetDNSServers(servers []string) {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
networkSettings.DNSServers = servers
incrementor++
logger.Info("Set DNS servers: %v", servers)
}
// SetIPv4Settings sets IPv4 addresses and subnet masks
func SetIPv4Settings(addresses []string, subnetMasks []string) {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
networkSettings.IPv4Addresses = addresses
networkSettings.IPv4SubnetMasks = subnetMasks
incrementor++
logger.Info("Set IPv4 addresses: %v, subnet masks: %v", addresses, subnetMasks)
}
// SetIPv4IncludedRoutes sets the included IPv4 routes
func SetIPv4IncludedRoutes(routes []IPv4Route) {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
networkSettings.IPv4IncludedRoutes = routes
incrementor++
logger.Info("Set IPv4 included routes: %d routes", len(routes))
}
func AddIPv4IncludedRoute(route IPv4Route) {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
// make sure it does not already exist
for _, r := range networkSettings.IPv4IncludedRoutes {
if r == route {
logger.Info("IPv4 included route already exists: %+v", route)
return
}
}
networkSettings.IPv4IncludedRoutes = append(networkSettings.IPv4IncludedRoutes, route)
incrementor++
logger.Info("Added IPv4 included route: %+v", route)
}
func RemoveIPv4IncludedRoute(route IPv4Route) {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
routes := networkSettings.IPv4IncludedRoutes
for i, r := range routes {
if r == route {
networkSettings.IPv4IncludedRoutes = append(routes[:i], routes[i+1:]...)
logger.Info("Removed IPv4 included route: %+v", route)
return
}
}
incrementor++
logger.Info("IPv4 included route not found for removal: %+v", route)
}
func SetIPv4ExcludedRoutes(routes []IPv4Route) {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
networkSettings.IPv4ExcludedRoutes = routes
incrementor++
logger.Info("Set IPv4 excluded routes: %d routes", len(routes))
}
// SetIPv6Settings sets IPv6 addresses and network prefixes
func SetIPv6Settings(addresses []string, networkPrefixes []string) {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
networkSettings.IPv6Addresses = addresses
networkSettings.IPv6NetworkPrefixes = networkPrefixes
incrementor++
logger.Info("Set IPv6 addresses: %v, network prefixes: %v", addresses, networkPrefixes)
}
// SetIPv6IncludedRoutes sets the included IPv6 routes
func SetIPv6IncludedRoutes(routes []IPv6Route) {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
networkSettings.IPv6IncludedRoutes = routes
incrementor++
logger.Info("Set IPv6 included routes: %d routes", len(routes))
}
// SetIPv6ExcludedRoutes sets the excluded IPv6 routes
func SetIPv6ExcludedRoutes(routes []IPv6Route) {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
networkSettings.IPv6ExcludedRoutes = routes
incrementor++
logger.Info("Set IPv6 excluded routes: %d routes", len(routes))
}
// ClearNetworkSettings clears all network settings
func ClearNetworkSettings() {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
networkSettings = NetworkSettings{}
incrementor++
logger.Info("Cleared all network settings")
}
func GetJSON() (string, error) {
networkSettingsMutex.RLock()
defer networkSettingsMutex.RUnlock()
data, err := json.MarshalIndent(networkSettings, "", " ")
if err != nil {
return "", err
}
return string(data), nil
}
func GetSettings() NetworkSettings {
networkSettingsMutex.RLock()
defer networkSettingsMutex.RUnlock()
return networkSettings
}
func GetIncrementor() int {
networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock()
return incrementor
}

View File

@@ -1,18 +1,28 @@
package proxy
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/fosrl/newt/internal/state"
"github.com/fosrl/newt/internal/telemetry"
"github.com/fosrl/newt/logger"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"golang.zx2c4.com/wireguard/tun/netstack"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
)
const errUnsupportedProtoFmt = "unsupported protocol: %s"
// Target represents a proxy target with its address and port
type Target struct {
Address string
@@ -28,6 +38,90 @@ type ProxyManager struct {
udpConns []*gonet.UDPConn
running bool
mutex sync.RWMutex
// telemetry (multi-tunnel)
currentTunnelID string
tunnels map[string]*tunnelEntry
asyncBytes bool
flushStop chan struct{}
}
// tunnelEntry holds per-tunnel attributes and (optional) async counters.
type tunnelEntry struct {
attrInTCP attribute.Set
attrOutTCP attribute.Set
attrInUDP attribute.Set
attrOutUDP attribute.Set
bytesInTCP atomic.Uint64
bytesOutTCP atomic.Uint64
bytesInUDP atomic.Uint64
bytesOutUDP atomic.Uint64
activeTCP atomic.Int64
activeUDP atomic.Int64
}
// countingWriter wraps an io.Writer and adds bytes to OTel counter using a pre-built attribute set.
type countingWriter struct {
ctx context.Context
w io.Writer
set attribute.Set
pm *ProxyManager
ent *tunnelEntry
out bool // false=in, true=out
proto string // "tcp" or "udp"
}
func (cw *countingWriter) Write(p []byte) (int, error) {
n, err := cw.w.Write(p)
if n > 0 {
if cw.pm != nil && cw.pm.asyncBytes && cw.ent != nil {
switch cw.proto {
case "tcp":
if cw.out {
cw.ent.bytesOutTCP.Add(uint64(n))
} else {
cw.ent.bytesInTCP.Add(uint64(n))
}
case "udp":
if cw.out {
cw.ent.bytesOutUDP.Add(uint64(n))
} else {
cw.ent.bytesInUDP.Add(uint64(n))
}
}
} else {
telemetry.AddTunnelBytesSet(cw.ctx, int64(n), cw.set)
}
}
return n, err
}
func classifyProxyError(err error) string {
if err == nil {
return ""
}
if errors.Is(err, net.ErrClosed) {
return "closed"
}
if ne, ok := err.(net.Error); ok {
if ne.Timeout() {
return "timeout"
}
if ne.Temporary() {
return "temporary"
}
}
msg := strings.ToLower(err.Error())
switch {
case strings.Contains(msg, "refused"):
return "refused"
case strings.Contains(msg, "reset"):
return "reset"
default:
return "io_error"
}
}
// NewProxyManager creates a new proxy manager instance
@@ -38,9 +132,77 @@ func NewProxyManager(tnet *netstack.Net) *ProxyManager {
udpTargets: make(map[string]map[int]string),
listeners: make([]*gonet.TCPListener, 0),
udpConns: make([]*gonet.UDPConn, 0),
tunnels: make(map[string]*tunnelEntry),
}
}
// SetTunnelID sets the WireGuard peer public key used as tunnel_id label.
func (pm *ProxyManager) SetTunnelID(id string) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.currentTunnelID = id
if _, ok := pm.tunnels[id]; !ok {
pm.tunnels[id] = &tunnelEntry{}
}
e := pm.tunnels[id]
// include site labels if available
site := telemetry.SiteLabelKVs()
build := func(base []attribute.KeyValue) attribute.Set {
if telemetry.ShouldIncludeTunnelID() {
base = append([]attribute.KeyValue{attribute.String("tunnel_id", id)}, base...)
}
base = append(site, base...)
return attribute.NewSet(base...)
}
e.attrInTCP = build([]attribute.KeyValue{
attribute.String("direction", "ingress"),
attribute.String("protocol", "tcp"),
})
e.attrOutTCP = build([]attribute.KeyValue{
attribute.String("direction", "egress"),
attribute.String("protocol", "tcp"),
})
e.attrInUDP = build([]attribute.KeyValue{
attribute.String("direction", "ingress"),
attribute.String("protocol", "udp"),
})
e.attrOutUDP = build([]attribute.KeyValue{
attribute.String("direction", "egress"),
attribute.String("protocol", "udp"),
})
}
// ClearTunnelID clears cached attribute sets for the current tunnel.
func (pm *ProxyManager) ClearTunnelID() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
id := pm.currentTunnelID
if id == "" {
return
}
if e, ok := pm.tunnels[id]; ok {
// final flush for this tunnel
inTCP := e.bytesInTCP.Swap(0)
outTCP := e.bytesOutTCP.Swap(0)
inUDP := e.bytesInUDP.Swap(0)
outUDP := e.bytesOutUDP.Swap(0)
if inTCP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(inTCP), e.attrInTCP)
}
if outTCP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(outTCP), e.attrOutTCP)
}
if inUDP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(inUDP), e.attrInUDP)
}
if outUDP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(outUDP), e.attrOutUDP)
}
delete(pm.tunnels, id)
}
pm.currentTunnelID = ""
}
// init function without tnet
func NewProxyManagerWithoutTNet() *ProxyManager {
return &ProxyManager{
@@ -75,7 +237,7 @@ func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr s
}
pm.udpTargets[listenIP][port] = targetAddr
default:
return fmt.Errorf("unsupported protocol: %s", proto)
return fmt.Errorf(errUnsupportedProtoFmt, proto)
}
if pm.running {
@@ -124,13 +286,28 @@ func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error {
return fmt.Errorf("target not found: %s:%d", listenIP, port)
}
default:
return fmt.Errorf("unsupported protocol: %s", proto)
return fmt.Errorf(errUnsupportedProtoFmt, proto)
}
return nil
}
// Start begins listening for all configured proxy targets
func (pm *ProxyManager) Start() error {
// Register proxy observables once per process
telemetry.SetProxyObservableCallback(func(ctx context.Context, o metric.Observer) error {
pm.mutex.RLock()
defer pm.mutex.RUnlock()
for _, e := range pm.tunnels {
// active connections
telemetry.ObserveProxyActiveConnsObs(o, e.activeTCP.Load(), e.attrOutTCP.ToSlice())
telemetry.ObserveProxyActiveConnsObs(o, e.activeUDP.Load(), e.attrOutUDP.ToSlice())
// backlog bytes (sum of unflushed counters)
b := int64(e.bytesInTCP.Load() + e.bytesOutTCP.Load() + e.bytesInUDP.Load() + e.bytesOutUDP.Load())
telemetry.ObserveProxyAsyncBacklogObs(o, b, e.attrOutTCP.ToSlice())
telemetry.ObserveProxyBufferBytesObs(o, b, e.attrOutTCP.ToSlice())
}
return nil
})
pm.mutex.Lock()
defer pm.mutex.Unlock()
@@ -160,6 +337,75 @@ func (pm *ProxyManager) Start() error {
return nil
}
func (pm *ProxyManager) SetAsyncBytes(b bool) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.asyncBytes = b
if b && pm.flushStop == nil {
pm.flushStop = make(chan struct{})
go pm.flushLoop()
}
}
func (pm *ProxyManager) flushLoop() {
flushInterval := 2 * time.Second
if v := os.Getenv("OTEL_METRIC_EXPORT_INTERVAL"); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
if d/2 < flushInterval {
flushInterval = d / 2
}
}
}
ticker := time.NewTicker(flushInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
pm.mutex.RLock()
for _, e := range pm.tunnels {
inTCP := e.bytesInTCP.Swap(0)
outTCP := e.bytesOutTCP.Swap(0)
inUDP := e.bytesInUDP.Swap(0)
outUDP := e.bytesOutUDP.Swap(0)
if inTCP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(inTCP), e.attrInTCP)
}
if outTCP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(outTCP), e.attrOutTCP)
}
if inUDP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(inUDP), e.attrInUDP)
}
if outUDP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(outUDP), e.attrOutUDP)
}
}
pm.mutex.RUnlock()
case <-pm.flushStop:
pm.mutex.RLock()
for _, e := range pm.tunnels {
inTCP := e.bytesInTCP.Swap(0)
outTCP := e.bytesOutTCP.Swap(0)
inUDP := e.bytesInUDP.Swap(0)
outUDP := e.bytesOutUDP.Swap(0)
if inTCP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(inTCP), e.attrInTCP)
}
if outTCP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(outTCP), e.attrOutTCP)
}
if inUDP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(inUDP), e.attrInUDP)
}
if outUDP > 0 {
telemetry.AddTunnelBytesSet(context.Background(), int64(outUDP), e.attrOutUDP)
}
}
pm.mutex.RUnlock()
return
}
}
}
func (pm *ProxyManager) Stop() error {
pm.mutex.Lock()
defer pm.mutex.Unlock()
@@ -227,7 +473,7 @@ func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr
go pm.handleUDPProxy(conn, targetAddr)
default:
return fmt.Errorf("unsupported protocol: %s", proto)
return fmt.Errorf(errUnsupportedProtoFmt, proto)
}
logger.Info("Started %s proxy to %s", proto, targetAddr)
@@ -236,54 +482,84 @@ func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr
return nil
}
// getEntry returns per-tunnel entry or nil.
func (pm *ProxyManager) getEntry(id string) *tunnelEntry {
pm.mutex.RLock()
e := pm.tunnels[id]
pm.mutex.RUnlock()
return e
}
func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) {
for {
conn, err := listener.Accept()
if err != nil {
// Check if we're shutting down or the listener was closed
telemetry.IncProxyAccept(context.Background(), pm.currentTunnelID, "tcp", "failure", classifyProxyError(err))
if !pm.running {
return
}
// Check for specific network errors that indicate the listener is closed
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
return
}
logger.Error("Error accepting TCP connection: %v", err)
// Don't hammer the CPU if we hit a temporary error
time.Sleep(100 * time.Millisecond)
continue
}
go func() {
tunnelID := pm.currentTunnelID
telemetry.IncProxyAccept(context.Background(), tunnelID, "tcp", "success", "")
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "tcp", telemetry.ProxyConnectionOpened)
if tunnelID != "" {
state.Global().IncSessions(tunnelID)
if e := pm.getEntry(tunnelID); e != nil {
e.activeTCP.Add(1)
}
}
go func(tunnelID string, accepted net.Conn) {
connStart := time.Now()
target, err := net.Dial("tcp", targetAddr)
if err != nil {
logger.Error("Error connecting to target: %v", err)
conn.Close()
accepted.Close()
telemetry.IncProxyAccept(context.Background(), tunnelID, "tcp", "failure", classifyProxyError(err))
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "tcp", telemetry.ProxyConnectionClosed)
telemetry.ObserveProxyConnectionDuration(context.Background(), tunnelID, "tcp", "failure", time.Since(connStart).Seconds())
return
}
// Create a WaitGroup to ensure both copy operations complete
entry := pm.getEntry(tunnelID)
if entry == nil {
entry = &tunnelEntry{}
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
go func(ent *tunnelEntry) {
defer wg.Done()
io.Copy(target, conn)
target.Close()
}()
cw := &countingWriter{ctx: context.Background(), w: target, set: ent.attrInTCP, pm: pm, ent: ent, out: false, proto: "tcp"}
_, _ = io.Copy(cw, accepted)
_ = target.Close()
}(entry)
go func() {
go func(ent *tunnelEntry) {
defer wg.Done()
io.Copy(conn, target)
conn.Close()
}()
cw := &countingWriter{ctx: context.Background(), w: accepted, set: ent.attrOutTCP, pm: pm, ent: ent, out: true, proto: "tcp"}
_, _ = io.Copy(cw, target)
_ = accepted.Close()
}(entry)
// Wait for both copies to complete
wg.Wait()
}()
if tunnelID != "" {
state.Global().DecSessions(tunnelID)
if e := pm.getEntry(tunnelID); e != nil {
e.activeTCP.Add(-1)
}
}
telemetry.ObserveProxyConnectionDuration(context.Background(), tunnelID, "tcp", "success", time.Since(connStart).Seconds())
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "tcp", telemetry.ProxyConnectionClosed)
}(tunnelID, conn)
}
}
@@ -326,6 +602,18 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
}
clientKey := remoteAddr.String()
// bytes from client -> target (direction=in)
if pm.currentTunnelID != "" && n > 0 {
if pm.asyncBytes {
if e := pm.getEntry(pm.currentTunnelID); e != nil {
e.bytesInUDP.Add(uint64(n))
}
} else {
if e := pm.getEntry(pm.currentTunnelID); e != nil {
telemetry.AddTunnelBytesSet(context.Background(), int64(n), e.attrInUDP)
}
}
}
clientsMutex.RLock()
targetConn, exists := clientConns[clientKey]
clientsMutex.RUnlock()
@@ -334,28 +622,44 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil {
logger.Error("Error resolving target address: %v", err)
telemetry.IncProxyAccept(context.Background(), pm.currentTunnelID, "udp", "failure", "resolve")
continue
}
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
if err != nil {
logger.Error("Error connecting to target: %v", err)
telemetry.IncProxyAccept(context.Background(), pm.currentTunnelID, "udp", "failure", classifyProxyError(err))
continue
}
tunnelID := pm.currentTunnelID
telemetry.IncProxyAccept(context.Background(), tunnelID, "udp", "success", "")
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionOpened)
// Only increment activeUDP after a successful DialUDP
if e := pm.getEntry(tunnelID); e != nil {
e.activeUDP.Add(1)
}
clientsMutex.Lock()
clientConns[clientKey] = targetConn
clientsMutex.Unlock()
go func(clientKey string, targetConn *net.UDPConn, remoteAddr net.Addr) {
go func(clientKey string, targetConn *net.UDPConn, remoteAddr net.Addr, tunnelID string) {
start := time.Now()
result := "success"
defer func() {
// Always clean up when this goroutine exits
clientsMutex.Lock()
if storedConn, exists := clientConns[clientKey]; exists && storedConn == targetConn {
delete(clientConns, clientKey)
targetConn.Close()
if e := pm.getEntry(tunnelID); e != nil {
e.activeUDP.Add(-1)
}
}
clientsMutex.Unlock()
telemetry.ObserveProxyConnectionDuration(context.Background(), tunnelID, "udp", result, time.Since(start).Seconds())
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
}()
buffer := make([]byte, 65507)
@@ -363,25 +667,52 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
n, _, err := targetConn.ReadFromUDP(buffer)
if err != nil {
logger.Error("Error reading from target: %v", err)
result = "failure"
return // defer will handle cleanup
}
// bytes from target -> client (direction=out)
if pm.currentTunnelID != "" && n > 0 {
if pm.asyncBytes {
if e := pm.getEntry(pm.currentTunnelID); e != nil {
e.bytesOutUDP.Add(uint64(n))
}
} else {
if e := pm.getEntry(pm.currentTunnelID); e != nil {
telemetry.AddTunnelBytesSet(context.Background(), int64(n), e.attrOutUDP)
}
}
}
_, err = conn.WriteTo(buffer[:n], remoteAddr)
if err != nil {
logger.Error("Error writing to client: %v", err)
telemetry.IncProxyDrops(context.Background(), pm.currentTunnelID, "udp")
result = "failure"
return // defer will handle cleanup
}
}
}(clientKey, targetConn, remoteAddr)
}(clientKey, targetConn, remoteAddr, tunnelID)
}
_, err = targetConn.Write(buffer[:n])
written, err := targetConn.Write(buffer[:n])
if err != nil {
logger.Error("Error writing to target: %v", err)
telemetry.IncProxyDrops(context.Background(), pm.currentTunnelID, "udp")
targetConn.Close()
clientsMutex.Lock()
delete(clientConns, clientKey)
clientsMutex.Unlock()
} else if pm.currentTunnelID != "" && written > 0 {
if pm.asyncBytes {
if e := pm.getEntry(pm.currentTunnelID); e != nil {
e.bytesInUDP.Add(uint64(written))
}
} else {
if e := pm.getEntry(pm.currentTunnelID); e != nil {
telemetry.AddTunnelBytesSet(context.Background(), int64(written), e.attrInUDP)
}
}
}
}
}

17
stub.go
View File

@@ -8,25 +8,32 @@ import (
)
func setupClientsNative(client *websocket.Client, host string) {
return // This function is not implemented for non-Linux systems.
_ = client
_ = host
// No-op for non-Linux systems
}
func closeWgServiceNative() {
// No-op for non-Linux systems
return
}
func clientsOnConnectNative() {
// No-op for non-Linux systems
return
}
func clientsHandleNewtConnectionNative(publicKey, endpoint string) {
_ = publicKey
_ = endpoint
// No-op for non-Linux systems
return
}
func clientsAddProxyTargetNative(pm *proxy.ProxyManager, tunnelIp string) {
_ = pm
_ = tunnelIp
// No-op for non-Linux systems
}
func clientsStartDirectRelayNative(tunnelIP string) {
_ = tunnelIP
// No-op for non-Linux systems
return
}

213
util/util.go Normal file
View File

@@ -0,0 +1,213 @@
package util
import (
"encoding/base64"
"encoding/binary"
"encoding/hex"
"fmt"
"net"
"strings"
mathrand "math/rand/v2"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/device"
)
func ResolveDomain(domain string) (string, error) {
// trim whitespace
domain = strings.TrimSpace(domain)
// Remove any protocol prefix if present (do this first, before splitting host/port)
domain = strings.TrimPrefix(domain, "http://")
domain = strings.TrimPrefix(domain, "https://")
// if there are any trailing slashes, remove them
domain = strings.TrimSuffix(domain, "/")
// Check if there's a port in the domain
host, port, err := net.SplitHostPort(domain)
if err != nil {
// No port found, use the domain as is
host = domain
port = ""
}
// Lookup IP addresses
ips, err := net.LookupIP(host)
if err != nil {
return "", fmt.Errorf("DNS lookup failed: %v", err)
}
if len(ips) == 0 {
return "", fmt.Errorf("no IP addresses found for domain %s", host)
}
// Get the first IPv4 address if available
var ipAddr string
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ipAddr = ipv4.String()
break
}
}
// If no IPv4 found, use the first IP (might be IPv6)
if ipAddr == "" {
ipAddr = ips[0].String()
}
// Add port back if it existed
if port != "" {
ipAddr = net.JoinHostPort(ipAddr, port)
}
return ipAddr, nil
}
func ParseLogLevel(level string) logger.LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":
return logger.DEBUG
case "INFO":
return logger.INFO
case "WARN":
return logger.WARN
case "ERROR":
return logger.ERROR
case "FATAL":
return logger.FATAL
default:
return logger.INFO // default to INFO if invalid level provided
}
}
// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
if maxPort < minPort {
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
}
// We need to check port+1 as well, so adjust the max port to avoid going out of range
adjustedMaxPort := maxPort - 1
if adjustedMaxPort < minPort {
return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort)
}
// Create a slice of all ports in the range (excluding the last one)
portRange := make([]uint16, adjustedMaxPort-minPort+1)
for i := range portRange {
portRange[i] = minPort + uint16(i)
}
// Fisher-Yates shuffle to randomize the port order
for i := len(portRange) - 1; i > 0; i-- {
j := mathrand.IntN(i + 1)
portRange[i], portRange[j] = portRange[j], portRange[i]
}
// Try each port in the randomized order
for _, port := range portRange {
// Check if port is available
addr1 := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: int(port),
}
conn1, err1 := net.ListenUDP("udp", addr1)
if err1 != nil {
continue // Port is in use or there was an error, try next port
}
conn1.Close()
return port, nil
}
return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort)
}
func FixKey(key string) string {
// Remove any whitespace
key = strings.TrimSpace(key)
// Decode from base64
decoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
logger.Fatal("Error decoding base64: %v", err)
}
// Convert to hex
return hex.EncodeToString(decoded)
}
// this is the opposite of FixKey
func UnfixKey(hexKey string) string {
// Decode from hex
decoded, err := hex.DecodeString(hexKey)
if err != nil {
logger.Fatal("Error decoding hex: %v", err)
}
// Convert to base64
return base64.StdEncoding.EncodeToString(decoded)
}
func MapToWireGuardLogLevel(level logger.LogLevel) int {
switch level {
case logger.DEBUG:
return device.LogLevelVerbose
// case logger.INFO:
// return device.LogLevel
case logger.WARN:
return device.LogLevelError
case logger.ERROR, logger.FATAL:
return device.LogLevelSilent
default:
return device.LogLevelSilent
}
}
// GetProtocol returns protocol number from IPv4 packet (fast path)
func GetProtocol(packet []byte) (uint8, bool) {
if len(packet) < 20 {
return 0, false
}
version := packet[0] >> 4
if version == 4 {
return packet[9], true
} else if version == 6 {
if len(packet) < 40 {
return 0, false
}
return packet[6], true
}
return 0, false
}
// GetDestPort returns destination port from TCP/UDP packet (fast path)
func GetDestPort(packet []byte) (uint16, bool) {
if len(packet) < 20 {
return 0, false
}
version := packet[0] >> 4
var headerLen int
if version == 4 {
ihl := packet[0] & 0x0F
headerLen = int(ihl) * 4
if len(packet) < headerLen+4 {
return 0, false
}
} else if version == 6 {
headerLen = 40
if len(packet) < headerLen+4 {
return 0, false
}
} else {
return 0, false
}
// Destination port is at bytes 2-3 of TCP/UDP header
port := binary.BigEndian.Uint16(packet[headerLen+2 : headerLen+4])
return port, true
}

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
@@ -18,6 +19,11 @@ import (
"github.com/fosrl/newt/logger"
"github.com/gorilla/websocket"
"context"
"github.com/fosrl/newt/internal/telemetry"
"go.opentelemetry.io/otel"
)
type Client struct {
@@ -37,6 +43,8 @@ type Client struct {
writeMux sync.Mutex
clientType string // Type of client (e.g., "newt", "olm")
tlsConfig TLSConfig
metricsCtxMu sync.RWMutex
metricsCtx context.Context
configNeedsSave bool // Flag to track if config needs to be saved
}
@@ -81,6 +89,26 @@ func (c *Client) OnTokenUpdate(callback func(token string)) {
c.onTokenUpdate = callback
}
func (c *Client) metricsContext() context.Context {
c.metricsCtxMu.RLock()
defer c.metricsCtxMu.RUnlock()
if c.metricsCtx != nil {
return c.metricsCtx
}
return context.Background()
}
func (c *Client) setMetricsContext(ctx context.Context) {
c.metricsCtxMu.Lock()
c.metricsCtx = ctx
c.metricsCtxMu.Unlock()
}
// MetricsContext exposes the context used for telemetry emission when a connection is active.
func (c *Client) MetricsContext() context.Context {
return c.metricsContext()
}
// NewClient creates a new websocket client
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{
@@ -140,6 +168,7 @@ func (c *Client) Close() error {
// Set connection status to false
c.setConnected(false)
telemetry.SetWSConnectionState(false)
// Close the WebSocket connection gracefully
if c.conn != nil {
@@ -170,7 +199,31 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
c.writeMux.Lock()
defer c.writeMux.Unlock()
return c.conn.WriteJSON(msg)
if err := c.conn.WriteJSON(msg); err != nil {
return err
}
telemetry.IncWSMessage(c.metricsContext(), "out", "text")
return nil
}
// SendMessage sends a message through the WebSocket connection
func (c *Client) SendMessageNoLog(messageType string, data interface{}) error {
if c.conn == nil {
return fmt.Errorf("not connected")
}
msg := WSMessage{
Type: messageType,
Data: data,
}
c.writeMux.Lock()
defer c.writeMux.Unlock()
if err := c.conn.WriteJSON(msg); err != nil {
return err
}
telemetry.IncWSMessage(c.metricsContext(), "out", "text")
return nil
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
@@ -265,8 +318,12 @@ func (c *Client) getToken() (string, error) {
return "", fmt.Errorf("failed to marshal token request data: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Create a new request
req, err := http.NewRequest(
req, err := http.NewRequestWithContext(
ctx,
"POST",
baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token",
bytes.NewBuffer(jsonData),
@@ -288,6 +345,8 @@ func (c *Client) getToken() (string, error) {
}
resp, err := client.Do(req)
if err != nil {
telemetry.IncConnAttempt(ctx, "auth", "failure")
telemetry.IncConnError(ctx, "auth", classifyConnError(err))
return "", fmt.Errorf("failed to request new token: %w", err)
}
defer resp.Body.Close()
@@ -295,6 +354,16 @@ func (c *Client) getToken() (string, error) {
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
telemetry.IncConnAttempt(ctx, "auth", "failure")
etype := "io_error"
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
etype = "auth_failed"
}
telemetry.IncConnError(ctx, "auth", etype)
// Reconnect reason mapping for auth failures
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonAuthError)
}
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
}
@@ -313,10 +382,55 @@ func (c *Client) getToken() (string, error) {
}
logger.Debug("Received token: %s", tokenResp.Data.Token)
telemetry.IncConnAttempt(ctx, "auth", "success")
return tokenResp.Data.Token, nil
}
// classifyConnError maps to fixed, low-cardinality error_type values.
// Allowed enum: dial_timeout, tls_handshake, auth_failed, io_error
func classifyConnError(err error) string {
if err == nil {
return ""
}
msg := strings.ToLower(err.Error())
switch {
case strings.Contains(msg, "tls") || strings.Contains(msg, "certificate"):
return "tls_handshake"
case strings.Contains(msg, "timeout") || strings.Contains(msg, "i/o timeout") || strings.Contains(msg, "deadline exceeded"):
return "dial_timeout"
case strings.Contains(msg, "unauthorized") || strings.Contains(msg, "forbidden"):
return "auth_failed"
default:
// Group remaining network/socket errors as io_error to avoid label explosion
return "io_error"
}
}
func classifyWSDisconnect(err error) (result, reason string) {
if err == nil {
return "success", "normal"
}
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
return "success", "normal"
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {
return "error", "timeout"
}
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
return "error", "unexpected_close"
}
msg := strings.ToLower(err.Error())
switch {
case strings.Contains(msg, "eof"):
return "error", "eof"
case strings.Contains(msg, "reset"):
return "error", "connection_reset"
default:
return "error", "read_error"
}
}
func (c *Client) connectWithRetry() {
for {
select {
@@ -335,9 +449,13 @@ func (c *Client) connectWithRetry() {
}
func (c *Client) establishConnection() error {
ctx := context.Background()
// Get token for authentication
token, err := c.getToken()
if err != nil {
telemetry.IncConnAttempt(ctx, "websocket", "failure")
telemetry.IncConnError(ctx, "websocket", classifyConnError(err))
return fmt.Errorf("failed to get token: %w", err)
}
@@ -370,7 +488,12 @@ func (c *Client) establishConnection() error {
q.Set("clientType", c.clientType)
u.RawQuery = q.Encode()
// Connect to WebSocket
// Connect to WebSocket (optional span)
tr := otel.Tracer("newt")
ctx, span := tr.Start(ctx, "ws.connect")
defer span.End()
start := time.Now()
dialer := websocket.DefaultDialer
// Use new TLS configuration method
@@ -392,18 +515,42 @@ func (c *Client) establishConnection() error {
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
conn, _, err := dialer.Dial(u.String(), nil)
conn, _, err := dialer.DialContext(ctx, u.String(), nil)
lat := time.Since(start).Seconds()
if err != nil {
telemetry.IncConnAttempt(ctx, "websocket", "failure")
etype := classifyConnError(err)
telemetry.IncConnError(ctx, "websocket", etype)
telemetry.ObserveWSConnectLatency(ctx, lat, "failure", etype)
// Map handshake-related errors to reconnect reasons where appropriate
if etype == "tls_handshake" {
telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonHandshakeError)
} else if etype == "dial_timeout" {
telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonTimeout)
} else {
telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonError)
}
telemetry.IncWSReconnect(ctx, etype)
return fmt.Errorf("failed to connect to WebSocket: %w", err)
}
telemetry.IncConnAttempt(ctx, "websocket", "success")
telemetry.ObserveWSConnectLatency(ctx, lat, "success", "")
c.conn = conn
c.setConnected(true)
telemetry.SetWSConnectionState(true)
c.setMetricsContext(ctx)
sessionStart := time.Now()
// Wire up pong handler for metrics
c.conn.SetPongHandler(func(appData string) error {
telemetry.IncWSMessage(c.metricsContext(), "in", "pong")
return nil
})
// Start the ping monitor
go c.pingMonitor()
// Start the read pump with disconnect detection
go c.readPumpWithDisconnectDetection()
go c.readPumpWithDisconnectDetection(sessionStart)
if c.onConnect != nil {
err := c.saveConfig()
@@ -496,6 +643,9 @@ func (c *Client) pingMonitor() {
}
c.writeMux.Lock()
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
}
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
@@ -505,6 +655,8 @@ func (c *Client) pingMonitor() {
return
default:
logger.Error("Ping failed: %v", err)
telemetry.IncWSKeepaliveFailure(c.metricsContext(), "ping_write")
telemetry.IncWSReconnect(c.metricsContext(), "ping_write")
c.reconnect()
return
}
@@ -514,17 +666,26 @@ func (c *Client) pingMonitor() {
}
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
func (c *Client) readPumpWithDisconnectDetection() {
func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
ctx := c.metricsContext()
disconnectReason := "shutdown"
disconnectResult := "success"
defer func() {
if c.conn != nil {
c.conn.Close()
}
if !started.IsZero() {
telemetry.ObserveWSSessionDuration(ctx, time.Since(started).Seconds(), disconnectResult)
}
telemetry.IncWSDisconnect(ctx, disconnectReason, disconnectResult)
// Only attempt reconnect if we're not shutting down
select {
case <-c.done:
// Shutting down, don't reconnect
return
default:
telemetry.IncWSReconnect(ctx, disconnectReason)
c.reconnect()
}
}()
@@ -532,23 +693,33 @@ func (c *Client) readPumpWithDisconnectDetection() {
for {
select {
case <-c.done:
disconnectReason = "shutdown"
disconnectResult = "success"
return
default:
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "in", "text")
}
if err != nil {
// Check if we're shutting down before logging error
select {
case <-c.done:
// Expected during shutdown, don't log as error
logger.Debug("WebSocket connection closed during shutdown")
disconnectReason = "shutdown"
disconnectResult = "success"
return
default:
// Unexpected error during normal operation
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
logger.Error("WebSocket read error: %v", err)
} else {
logger.Debug("WebSocket connection closed: %v", err)
disconnectResult, disconnectReason = classifyWSDisconnect(err)
if disconnectResult == "error" {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
logger.Error("WebSocket read error: %v", err)
} else {
logger.Debug("WebSocket connection closed: %v", err)
}
}
return // triggers reconnect via defer
}
@@ -565,6 +736,7 @@ func (c *Client) readPumpWithDisconnectDetection() {
func (c *Client) reconnect() {
c.setConnected(false)
telemetry.SetWSConnectionState(false)
if c.conn != nil {
c.conn.Close()
c.conn = nil

999
wg/wg.go
View File

@@ -1,999 +0,0 @@
//go:build linux
package wg
import (
"encoding/json"
"errors"
"fmt"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/newt/websocket"
"github.com/vishvananda/netlink"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type WgConfig struct {
IpAddress string `json:"ipAddress"`
Peers []Peer `json:"peers"`
}
type Peer struct {
PublicKey string `json:"publicKey"`
AllowedIPs []string `json:"allowedIps"`
Endpoint string `json:"endpoint"`
}
type PeerBandwidth struct {
PublicKey string `json:"publicKey"`
BytesIn float64 `json:"bytesIn"`
BytesOut float64 `json:"bytesOut"`
}
type PeerReading struct {
BytesReceived int64
BytesTransmitted int64
LastChecked time.Time
}
type WireGuardService struct {
interfaceName string
mtu int
client *websocket.Client
wgClient *wgctrl.Client
config WgConfig
key wgtypes.Key
keyFilePath string
newtId string
lastReadings map[string]PeerReading
mu sync.Mutex
Port uint16
stopHolepunch chan struct{}
host string
serverPubKey string
holePunchEndpoint string
token string
stopGetConfig func()
interfaceCreated bool
}
// Add this type definition
type fixedPortBind struct {
port uint16
conn.Bind
}
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
// Ignore the requested port and use our fixed port
return b.Bind.Open(b.port)
}
func NewFixedPortBind(port uint16) conn.Bind {
return &fixedPortBind{
port: port,
Bind: conn.NewDefaultBind(),
}
}
// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
if maxPort < minPort {
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
}
// We need to check port+1 as well, so adjust the max port to avoid going out of range
adjustedMaxPort := maxPort - 1
if adjustedMaxPort < minPort {
return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort)
}
// Create a slice of all ports in the range (excluding the last one)
portRange := make([]uint16, adjustedMaxPort-minPort+1)
for i := range portRange {
portRange[i] = minPort + uint16(i)
}
// Fisher-Yates shuffle to randomize the port order
rand.Seed(uint64(time.Now().UnixNano()))
for i := len(portRange) - 1; i > 0; i-- {
j := rand.Intn(i + 1)
portRange[i], portRange[j] = portRange[j], portRange[i]
}
// Try each port in the randomized order
for _, port := range portRange {
// Check if port is available
addr1 := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: int(port),
}
conn1, err1 := net.ListenUDP("udp", addr1)
if err1 != nil {
continue // Port is in use or there was an error, try next port
}
// Check if port+1 is also available
addr2 := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: int(port + 1),
}
conn2, err2 := net.ListenUDP("udp", addr2)
if err2 != nil {
// The next port is not available, so close the first connection and try again
conn1.Close()
continue
}
// Both ports are available, close connections and return the first port
conn1.Close()
conn2.Close()
return port, nil
}
return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort)
}
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) {
wgClient, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
}
var key wgtypes.Key
var port uint16
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
key, err = wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %v", err)
}
// Load or generate private key
if generateAndSaveKeyTo != "" {
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
keyData, err := os.ReadFile(generateAndSaveKeyTo)
if err != nil {
return nil, fmt.Errorf("failed to read private key: %v", err)
}
key, err = wgtypes.ParseKey(strings.TrimSpace(string(keyData)))
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %v", err)
}
} else {
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0600)
if err != nil {
return nil, fmt.Errorf("failed to save private key: %v", err)
}
}
}
// Get the existing wireguard port
device, err := wgClient.Device(interfaceName)
if err == nil {
port = uint16(device.ListenPort)
// also set the private key to the existing key
key = device.PrivateKey
if port != 0 {
logger.Info("WireGuard interface %s already exists with port %d\n", interfaceName, port)
} else {
port, err = FindAvailableUDPPort(49152, 65535)
if err != nil {
fmt.Printf("Error finding available port: %v\n", err)
return nil, err
}
}
} else {
port, err = FindAvailableUDPPort(49152, 65535)
if err != nil {
fmt.Printf("Error finding available port: %v\n", err)
return nil, err
}
}
service := &WireGuardService{
interfaceName: interfaceName,
mtu: mtu,
client: wsClient,
wgClient: wgClient,
key: key,
Port: port,
keyFilePath: generateAndSaveKeyTo,
newtId: newtId,
host: host,
lastReadings: make(map[string]PeerReading),
stopHolepunch: make(chan struct{}),
}
// Register websocket handlers
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer)
return service, nil
}
func (s *WireGuardService) Close(rm bool) {
if s.stopGetConfig != nil {
s.stopGetConfig()
s.stopGetConfig = nil
}
s.wgClient.Close()
// Remove the WireGuard interface
if rm {
if err := s.removeInterface(); err != nil {
logger.Error("Failed to remove WireGuard interface: %v", err)
}
// Remove the private key file
// if s.keyFilePath != "" {
// if err := os.Remove(s.keyFilePath); err != nil {
// logger.Error("Failed to remove private key file: %v", err)
// }
// }
}
}
func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) {
// if the device is already created dont start a new holepunch
if s.interfaceCreated {
return
}
s.serverPubKey = serverPubKey
s.holePunchEndpoint = endpoint
logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint)
s.stopHolepunch = make(chan struct{})
// start the UDP holepunch
go s.keepSendingUDPHolePunch(s.holePunchEndpoint)
}
func (s *WireGuardService) SetToken(token string) {
s.token = token
}
func (s *WireGuardService) LoadRemoteConfig() error {
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
"publicKey": s.key.PublicKey().String(),
"port": s.Port,
}, 2*time.Second)
logger.Info("Requesting WireGuard configuration from remote server")
go s.periodicBandwidthCheck()
return nil
}
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
var config WgConfig
logger.Debug("Received message: %v", msg)
logger.Info("Received WireGuard clients configuration from remote server")
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &config); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
s.config = config
if s.stopGetConfig != nil {
s.stopGetConfig()
s.stopGetConfig = nil
}
// Ensure the WireGuard interface and peers are configured
if err := s.ensureWireguardInterface(config); err != nil {
logger.Error("Failed to ensure WireGuard interface: %v", err)
}
if err := s.ensureWireguardPeers(config.Peers); err != nil {
logger.Error("Failed to ensure WireGuard peers: %v", err)
}
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Check if the WireGuard interface exists
_, err := netlink.LinkByName(s.interfaceName)
if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); ok {
// Interface doesn't exist, so create it
err = s.createWireGuardInterface()
if err != nil {
logger.Fatal("Failed to create WireGuard interface: %v", err)
}
s.interfaceCreated = true
logger.Info("Created WireGuard interface %s\n", s.interfaceName)
} else {
logger.Fatal("Error checking for WireGuard interface: %v", err)
}
} else {
logger.Info("WireGuard interface %s already exists\n", s.interfaceName)
// get the exising wireguard port
device, err := s.wgClient.Device(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get device: %v", err)
}
// get the existing port
s.Port = uint16(device.ListenPort)
logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port)
s.interfaceCreated = true
return nil
}
// stop the holepunch its a channel
if s.stopHolepunch != nil {
close(s.stopHolepunch)
s.stopHolepunch = nil
}
logger.Info("Assigning IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName)
// Assign IP address to the interface
err = s.assignIPAddress(wgconfig.IpAddress)
if err != nil {
logger.Fatal("Failed to assign IP address: %v", err)
}
// Check if the interface already exists
_, err = s.wgClient.Device(s.interfaceName)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("interface %s does not exist", s.interfaceName)
}
return fmt.Errorf("failed to get device: %v", err)
}
// Parse the private key
key, err := wgtypes.ParseKey(s.key.String())
if err != nil {
return fmt.Errorf("failed to parse private key: %v", err)
}
config := wgtypes.Config{
PrivateKey: &key,
ListenPort: new(int),
}
// Use the service's fixed port instead of the config port
*config.ListenPort = int(s.Port)
// Create and configure the WireGuard interface
err = s.wgClient.ConfigureDevice(s.interfaceName, config)
if err != nil {
return fmt.Errorf("failed to configure WireGuard device: %v", err)
}
// bring up the interface
link, err := netlink.LinkByName(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
if err := netlink.LinkSetMTU(link, s.mtu); err != nil {
return fmt.Errorf("failed to set MTU: %v", err)
}
if err := netlink.LinkSetUp(link); err != nil {
return fmt.Errorf("failed to bring up interface: %v", err)
}
// if err := s.ensureMSSClamping(); err != nil {
// logger.Warn("Failed to ensure MSS clamping: %v", err)
// }
logger.Info("WireGuard interface %s created and configured", s.interfaceName)
return nil
}
func (s *WireGuardService) createWireGuardInterface() error {
wgLink := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName},
LinkType: "wireguard",
}
return netlink.LinkAdd(wgLink)
}
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
link, err := netlink.LinkByName(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
addr, err := netlink.ParseAddr(ipAddress)
if err != nil {
return fmt.Errorf("failed to parse IP address: %v", err)
}
return netlink.AddrAdd(link, addr)
}
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
// get the current peers
device, err := s.wgClient.Device(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get device: %v", err)
}
// get the peer public keys
var currentPeers []string
for _, peer := range device.Peers {
currentPeers = append(currentPeers, peer.PublicKey.String())
}
// remove any peers that are not in the config
for _, peer := range currentPeers {
found := false
for _, configPeer := range peers {
if peer == configPeer.PublicKey {
found = true
break
}
}
if !found {
err := s.removePeer(peer)
if err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
}
}
// add any peers that are in the config but not in the current peers
for _, configPeer := range peers {
found := false
for _, peer := range currentPeers {
if configPeer.PublicKey == peer {
found = true
break
}
}
if !found {
err := s.addPeer(configPeer)
if err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
}
}
return nil
}
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
var peer Peer
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
}
if err := json.Unmarshal(jsonData, &peer); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
}
err = s.addPeer(peer)
if err != nil {
logger.Info("Error adding peer: %v", err)
return
}
}
func (s *WireGuardService) addPeer(peer Peer) error {
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
}
// parse allowed IPs into array of net.IPNet
var allowedIPs []net.IPNet
for _, ipStr := range peer.AllowedIPs {
_, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return fmt.Errorf("failed to parse allowed IP: %v", err)
}
allowedIPs = append(allowedIPs, *ipNet)
}
// add keep alive using *time.Duration of 1 second
keepalive := time.Second
var peerConfig wgtypes.PeerConfig
if peer.Endpoint != "" {
endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint)
if err != nil {
return fmt.Errorf("failed to resolve endpoint address: %w", err)
}
peerConfig = wgtypes.PeerConfig{
PublicKey: pubKey,
AllowedIPs: allowedIPs,
PersistentKeepaliveInterval: &keepalive,
Endpoint: endpoint,
}
} else {
peerConfig = wgtypes.PeerConfig{
PublicKey: pubKey,
AllowedIPs: allowedIPs,
PersistentKeepaliveInterval: &keepalive,
}
logger.Info("Added peer with no endpoint!")
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peerConfig},
}
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
logger.Info("Peer %s added successfully", peer.PublicKey)
return nil
}
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
type RemoveRequest struct {
PublicKey string `json:"publicKey"`
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
}
var request RemoveRequest
if err := json.Unmarshal(jsonData, &request); err != nil {
logger.Info("Error unmarshaling data: %v", err)
return
}
if err := s.removePeer(request.PublicKey); err != nil {
logger.Info("Error removing peer: %v", err)
return
}
}
func (s *WireGuardService) removePeer(publicKey string) error {
pubKey, err := wgtypes.ParseKey(publicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peerConfig},
}
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
logger.Info("Peer %s removed successfully", publicKey)
return nil
}
func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
// Define a struct to match the incoming message structure with optional fields
type UpdatePeerRequest struct {
PublicKey string `json:"publicKey"`
AllowedIPs []string `json:"allowedIps,omitempty"`
Endpoint string `json:"endpoint,omitempty"`
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
var request UpdatePeerRequest
if err := json.Unmarshal(jsonData, &request); err != nil {
logger.Info("Error unmarshaling peer data: %v", err)
return
}
// First, get the current peer configuration to preserve any unmodified fields
device, err := s.wgClient.Device(s.interfaceName)
if err != nil {
logger.Info("Error getting WireGuard device: %v", err)
return
}
pubKey, err := wgtypes.ParseKey(request.PublicKey)
if err != nil {
logger.Info("Error parsing public key: %v", err)
return
}
// Find the existing peer configuration
var currentPeer *wgtypes.Peer
for _, p := range device.Peers {
if p.PublicKey == pubKey {
currentPeer = &p
break
}
}
if currentPeer == nil {
logger.Info("Peer %s not found, cannot update", request.PublicKey)
return
}
// Create the update peer config
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
UpdateOnly: true,
}
// Keep the default persistent keepalive of 1 second
keepalive := time.Second
peerConfig.PersistentKeepaliveInterval = &keepalive
// Handle Endpoint field special case
// If Endpoint is included in the request but empty, we want to remove the endpoint
// If Endpoint is not included, we don't modify it
endpointSpecified := false
for key := range msg.Data.(map[string]interface{}) {
if key == "endpoint" {
endpointSpecified = true
break
}
}
// Only update AllowedIPs if provided in the request
if len(request.AllowedIPs) > 0 {
var allowedIPs []net.IPNet
for _, ipStr := range request.AllowedIPs {
_, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
logger.Info("Error parsing allowed IP %s: %v", ipStr, err)
return
}
allowedIPs = append(allowedIPs, *ipNet)
}
peerConfig.AllowedIPs = allowedIPs
peerConfig.ReplaceAllowedIPs = true
logger.Info("Updating AllowedIPs for peer %s", request.PublicKey)
} else if endpointSpecified && request.Endpoint == "" {
peerConfig.ReplaceAllowedIPs = false
}
if endpointSpecified {
if request.Endpoint != "" {
// Update to new endpoint
endpoint, err := net.ResolveUDPAddr("udp", request.Endpoint)
if err != nil {
logger.Info("Error resolving endpoint address %s: %v", request.Endpoint, err)
return
}
peerConfig.Endpoint = endpoint
logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint)
} else {
// specify any address to listen for any incoming packets
peerConfig.Endpoint = &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
}
logger.Info("Removing Endpoint for peer %s", request.PublicKey)
}
}
// Apply the configuration update
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peerConfig},
}
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
logger.Info("Error updating peer configuration: %v", err)
return
}
logger.Info("Peer %s updated successfully", request.PublicKey)
}
func (s *WireGuardService) periodicBandwidthCheck() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for range ticker.C {
if err := s.reportPeerBandwidth(); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
}
}
}
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
device, err := s.wgClient.Device(s.interfaceName)
if err != nil {
return nil, fmt.Errorf("failed to get device: %v", err)
}
peerBandwidths := []PeerBandwidth{}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
for _, peer := range device.Peers {
publicKey := peer.PublicKey.String()
currentReading := PeerReading{
BytesReceived: peer.ReceiveBytes,
BytesTransmitted: peer.TransmitBytes,
LastChecked: now,
}
var bytesInDiff, bytesOutDiff float64
lastReading, exists := s.lastReadings[publicKey]
if exists {
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
if timeDiff > 0 {
// Calculate bytes transferred since last reading
bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived)
bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted)
// Handle counter wraparound (if the counter resets or overflows)
if bytesInDiff < 0 {
bytesInDiff = float64(currentReading.BytesReceived)
}
if bytesOutDiff < 0 {
bytesOutDiff = float64(currentReading.BytesTransmitted)
}
// Convert to MB
bytesInMB := bytesInDiff / (1024 * 1024)
bytesOutMB := bytesOutDiff / (1024 * 1024)
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: bytesInMB,
BytesOut: bytesOutMB,
})
} else {
// If readings are too close together or time hasn't passed, report 0
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: 0,
BytesOut: 0,
})
}
} else {
// For first reading of a peer, report 0 to establish baseline
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: 0,
BytesOut: 0,
})
}
// Update the last reading
s.lastReadings[publicKey] = currentReading
}
// Clean up old peers
for publicKey := range s.lastReadings {
found := false
for _, peer := range device.Peers {
if peer.PublicKey.String() == publicKey {
found = true
break
}
}
if !found {
delete(s.lastReadings, publicKey)
}
}
return peerBandwidths, nil
}
func (s *WireGuardService) reportPeerBandwidth() error {
bandwidths, err := s.calculatePeerBandwidth()
if err != nil {
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
}
err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{
"bandwidthData": bandwidths,
})
if err != nil {
return fmt.Errorf("failed to send bandwidth data: %v", err)
}
return nil
}
func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
if s.serverPubKey == "" || s.token == "" {
logger.Debug("Server public key or token not set, skipping UDP hole punch")
return nil
}
// Parse server address
serverSplit := strings.Split(serverAddr, ":")
if len(serverSplit) < 2 {
return fmt.Errorf("invalid server address format, expected hostname:port")
}
serverHostname := serverSplit[0]
serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16)
if err != nil {
return fmt.Errorf("failed to parse server port: %v", err)
}
// Resolve server hostname to IP
serverIPAddr := network.HostToAddr(serverHostname)
if serverIPAddr == nil {
return fmt.Errorf("failed to resolve server hostname")
}
// Get client IP based on route to server
clientIP := network.GetClientIP(serverIPAddr.IP)
// Create server and client configs
server := &network.Server{
Hostname: serverHostname,
Addr: serverIPAddr,
Port: uint16(serverPort),
}
client := &network.PeerNet{
IP: clientIP,
Port: s.Port,
NewtID: s.newtId,
}
// Setup raw connection with BPF filtering
rawConn := network.SetupRawConn(server, client)
defer rawConn.Close()
// Create JSON payload
payload := struct {
NewtID string `json:"newtId"`
Token string `json:"token"`
}{
NewtID: s.newtId,
Token: s.token,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := s.encryptPayload(payloadBytes)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %v", err)
}
// Send the encrypted packet using the raw connection
err = network.SendDataPacket(encryptedPayload, rawConn, server, client)
if err != nil {
return fmt.Errorf("failed to send UDP packet: %v", err)
}
return nil
}
func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(s.serverPubKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange (replacing deprecated ScalarMult)
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}
func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
logger.Info("Starting UDP hole punch routine to %s:21820", host)
// send initial hole punch
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
logger.Debug("Failed to send initial UDP hole punch: %v", err)
}
ticker := time.NewTicker(3 * time.Second)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-s.stopHolepunch:
logger.Info("Stopping UDP holepunch")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds")
return
case <-ticker.C:
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
logger.Debug("Failed to send UDP hole punch: %v", err)
}
}
}
}
func (s *WireGuardService) removeInterface() error {
// Remove the WireGuard interface
link, err := netlink.LinkByName(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
err = netlink.LinkDel(link)
if err != nil {
return fmt.Errorf("failed to delete interface: %v", err)
}
logger.Info("WireGuard interface %s removed successfully", s.interfaceName)
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -3,12 +3,13 @@ package wgtester
import (
"encoding/binary"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/fosrl/newt/netstack2"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
)
@@ -39,7 +40,7 @@ type Server struct {
newtID string
outputPrefix string
useNetstack bool
tnet interface{} // Will be *netstack.Net when using netstack
tnet interface{} // Will be *netstack2.Net when using netstack
}
// NewServer creates a new connection test server using UDP
@@ -56,7 +57,7 @@ func NewServer(serverAddr string, serverPort uint16, newtID string) *Server {
}
// NewServerWithNetstack creates a new connection test server using WireGuard netstack
func NewServerWithNetstack(serverAddr string, serverPort uint16, newtID string, tnet *netstack.Net) *Server {
func NewServerWithNetstack(serverAddr string, serverPort uint16, newtID string, tnet *netstack2.Net) *Server {
return &Server{
serverAddr: serverAddr,
serverPort: serverPort + 1, // use the next port for the server
@@ -82,7 +83,7 @@ func (s *Server) Start() error {
if s.useNetstack && s.tnet != nil {
// Use WireGuard netstack
tnet := s.tnet.(*netstack.Net)
tnet := s.tnet.(*netstack2.Net)
udpAddr := &net.UDPAddr{Port: int(s.serverPort)}
netstackConn, err := tnet.ListenUDP(udpAddr)
if err != nil {
@@ -126,11 +127,11 @@ func (s *Server) Stop() {
s.conn.Close()
}
s.isRunning = false
logger.Info(s.outputPrefix + "Server stopped")
logger.Info("%sServer stopped", s.outputPrefix)
}
// RestartWithNetstack stops the current server and restarts it with netstack
func (s *Server) RestartWithNetstack(tnet *netstack.Net) error {
func (s *Server) RestartWithNetstack(tnet *netstack2.Net) error {
s.Stop()
// Update configuration to use netstack
@@ -161,7 +162,7 @@ func (s *Server) handleConnections() {
// Set read deadline to avoid blocking forever
err := s.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
if err != nil {
logger.Error(s.outputPrefix+"Error setting read deadline: %v", err)
logger.Error("%sError setting read deadline: %v", s.outputPrefix, err)
continue
}
@@ -187,7 +188,11 @@ func (s *Server) handleConnections() {
case <-s.shutdownCh:
return // Don't log error if we're shutting down
default:
logger.Error(s.outputPrefix+"Error reading from UDP: %v", err)
// Don't log EOF errors during shutdown - these are expected when connection is closed
if err == io.EOF {
return
}
logger.Error("%sError reading from UDP: %v", s.outputPrefix, err)
}
continue
}
@@ -219,7 +224,7 @@ func (s *Server) handleConnections() {
copy(responsePacket[5:13], buffer[5:13])
// Log response being sent for debugging
logger.Debug(s.outputPrefix+"Sending response to %s", addr.String())
// logger.Debug("%sSending response to %s", s.outputPrefix, addr.String())
// Send the response packet - handle both regular UDP and netstack UDP
if s.useNetstack {
@@ -233,9 +238,9 @@ func (s *Server) handleConnections() {
}
if err != nil {
logger.Error(s.outputPrefix+"Error sending response: %v", err)
logger.Error("%sError sending response: %v", s.outputPrefix, err)
} else {
logger.Debug(s.outputPrefix + "Response sent successfully")
// logger.Debug("%sResponse sent successfully", s.outputPrefix)
}
}
}