mirror of
https://github.com/fosrl/newt.git
synced 2026-03-12 18:04:28 -05:00
Compare commits
197 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
759e4c5bac | ||
|
|
8609be130e | ||
|
|
e06b8de0a7 | ||
|
|
0af6fb8fef | ||
|
|
9526768dfe | ||
|
|
051ab6ca9d | ||
|
|
2055b773fd | ||
|
|
1c9c98e2f6 | ||
|
|
9c57677493 | ||
|
|
ff825a51dd | ||
|
|
cdfc5733f0 | ||
|
|
cadbb50bdf | ||
|
|
4ac33c824b | ||
|
|
d91228f636 | ||
|
|
6c3b85bb9a | ||
|
|
77d99f1722 | ||
|
|
43e1341352 | ||
|
|
daa1a90e05 | ||
|
|
3739c237c7 | ||
|
|
ddde1758e5 | ||
|
|
dca29781f3 | ||
|
|
91bfd69179 | ||
|
|
060d876429 | ||
|
|
69952efe89 | ||
|
|
66949ca047 | ||
|
|
8c12db6dff | ||
|
|
b84d465763 | ||
|
|
a62567997d | ||
|
|
9bb4bbccb8 | ||
|
|
c3fad797e5 | ||
|
|
0168b4796e | ||
|
|
6c05d76c88 | ||
|
|
a701add824 | ||
|
|
d754cea397 | ||
|
|
31d52ad3ff | ||
|
|
e1ee4dc8f2 | ||
|
|
f9b6f36b4f | ||
|
|
0e961761b8 | ||
|
|
baf1b9b972 | ||
|
|
f078136b5a | ||
|
|
ca341a8bb0 | ||
|
|
80ae03997a | ||
|
|
5c94789d9a | ||
|
|
6c65cc8e5e | ||
|
|
a21a8e90fa | ||
|
|
3d5335f2cb | ||
|
|
94788edce3 | ||
|
|
2bbe037544 | ||
|
|
9b015e9f7c | ||
|
|
3305f711b9 | ||
|
|
ff7fe1275b | ||
|
|
1cbf41e094 | ||
|
|
9bc35433ef | ||
|
|
b8349aab4e | ||
|
|
3f29a553ae | ||
|
|
745045f619 | ||
|
|
3783a12055 | ||
|
|
a9b84c8c09 | ||
|
|
5c5ef4c7e6 | ||
|
|
6e9249e664 | ||
|
|
55be2a52a5 | ||
|
|
058330d41b | ||
|
|
5e7b970115 | ||
|
|
dc180abba9 | ||
|
|
004bb9b12d | ||
|
|
0637360b31 | ||
|
|
d5e0771094 | ||
|
|
1dcb68d694 | ||
|
|
865ac4b682 | ||
|
|
de5627b0b7 | ||
|
|
44470abd54 | ||
|
|
4bb0537c39 | ||
|
|
92fb96f9bd | ||
|
|
b68b7fe49d | ||
|
|
1da424bb20 | ||
|
|
22e5104a41 | ||
|
|
b96adeaa5b | ||
|
|
533e0b9ca7 | ||
|
|
bd86abe8d5 | ||
|
|
d978b27ebc | ||
|
|
cdfcf49d89 | ||
|
|
2fb4bf09ea | ||
|
|
dddae547f5 | ||
|
|
73a14f5fa1 | ||
|
|
67d5217379 | ||
|
|
9f1f1328f6 | ||
|
|
30da7eaa8b | ||
|
|
0fca3457c3 | ||
|
|
1271e8235e | ||
|
|
24c6edf3e0 | ||
|
|
1875c987fe | ||
|
|
7cb1f7e2c2 | ||
|
|
3f4f4fa15c | ||
|
|
bf33a3d81f | ||
|
|
21ffc0ff4b | ||
|
|
13de05eec6 | ||
|
|
0e76b77adc | ||
|
|
c604f46065 | ||
|
|
f02e29f4dd | ||
|
|
6d79856895 | ||
|
|
bbece243dd | ||
|
|
6948066ae4 | ||
|
|
3bcafbf07a | ||
|
|
87e2eb33db | ||
|
|
5ce3f4502d | ||
|
|
e5e733123b | ||
|
|
f417ee32fb | ||
|
|
37c96d0b3e | ||
|
|
78dc39e153 | ||
|
|
71485743ad | ||
|
|
458912e5be | ||
|
|
2bc91d6c68 | ||
|
|
95c3efc365 | ||
|
|
72a9e111dc | ||
|
|
3c86edf0d5 | ||
|
|
32b1b817ac | ||
|
|
02949be245 | ||
|
|
6d51cbf0c0 | ||
|
|
4dbf200cca | ||
|
|
d8b4fb4acb | ||
|
|
ac691517ae | ||
|
|
8a45f6fd63 | ||
|
|
7f650bbfdf | ||
|
|
15b40b0f24 | ||
|
|
e27e6fbce8 | ||
|
|
f9fb13a0d7 | ||
|
|
8db50d94c0 | ||
|
|
09568c1aaf | ||
|
|
c7d656214f | ||
|
|
d981a82b1c | ||
|
|
5dd5a56379 | ||
|
|
8c4d6e2e0a | ||
|
|
284f1ce627 | ||
|
|
cd466ac43f | ||
|
|
2256d1f041 | ||
|
|
40ca839771 | ||
|
|
01ec6a0ce0 | ||
|
|
d04f6cf702 | ||
|
|
cdaff27964 | ||
|
|
de96be810b | ||
|
|
ba43083f04 | ||
|
|
5196effdb8 | ||
|
|
d6edd6ca01 | ||
|
|
1b1323b553 | ||
|
|
bb95d10e86 | ||
|
|
da04746781 | ||
|
|
a38e0b3e98 | ||
|
|
6ced7b5af0 | ||
|
|
61b9615aea | ||
|
|
39f5782583 | ||
|
|
025c94e586 | ||
|
|
75e666c396 | ||
|
|
82a999eb87 | ||
|
|
921e72f628 | ||
|
|
46b33fdca6 | ||
|
|
9caa9fa31e | ||
|
|
dbbea6b34c | ||
|
|
491180c6a1 | ||
|
|
f49a276259 | ||
|
|
c71c6e0b1a | ||
|
|
972c9a9760 | ||
|
|
8f7ee2a8dc | ||
|
|
a737c3e8de | ||
|
|
1ba10c1b68 | ||
|
|
b1f2fe8283 | ||
|
|
a1fdb06add | ||
|
|
25d5fab02b | ||
|
|
2c8755f346 | ||
|
|
348cac66c8 | ||
|
|
6226a262d6 | ||
|
|
5b70feb6a5 | ||
|
|
0ec18d6655 | ||
|
|
7d60240572 | ||
|
|
ee3e7d1442 | ||
|
|
527321a415 | ||
|
|
ff07692248 | ||
|
|
8d3ae5afd7 | ||
|
|
ed99dce7e0 | ||
|
|
f1e07272bd | ||
|
|
a1a3d63fcf | ||
|
|
2a273dc435 | ||
|
|
ec05686523 | ||
|
|
915e7e44d1 | ||
|
|
a729b91ac3 | ||
|
|
ddc37658df | ||
|
|
7c780f7a4f | ||
|
|
6b1c1ed077 | ||
|
|
7a07437b22 | ||
|
|
d63d8d6f5e | ||
|
|
bda1d04f67 | ||
|
|
7f8ee37c7f | ||
|
|
6d2073a478 | ||
|
|
6048f244f1 | ||
|
|
9fec22a53b | ||
|
|
c086e69dd0 | ||
|
|
c729ab5fc6 | ||
|
|
552617cbb5 |
5
.env.example
Normal file
5
.env.example
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copy this file to .env and fill in your values
|
||||
# Required for connecting to Pangolin service
|
||||
PANGOLIN_ENDPOINT=https://example.com
|
||||
NEWT_ID=changeme-id
|
||||
NEWT_SECRET=changeme-secret
|
||||
646
.github/workflows/cicd.yml
vendored
646
.github/workflows/cicd.yml
vendored
@@ -1,64 +1,616 @@
|
||||
name: CI/CD Pipeline
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
contents: write # gh-release
|
||||
packages: write # GHCR push
|
||||
id-token: write # Keyless-Signatures & Attestations
|
||||
attestations: write # actions/attest-build-provenance
|
||||
security-events: write # upload-sarif
|
||||
actions: read
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
push:
|
||||
tags:
|
||||
- "[0-9]+.[0-9]+.[0-9]+"
|
||||
- "[0-9]+.[0-9]+.[0-9]+-rc.[0-9]+"
|
||||
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: "SemVer version to release (e.g., 1.2.3, no leading 'v')"
|
||||
required: true
|
||||
type: string
|
||||
publish_latest:
|
||||
description: "Also publish the 'latest' image tag"
|
||||
required: true
|
||||
type: boolean
|
||||
default: false
|
||||
publish_minor:
|
||||
description: "Also publish the 'major.minor' image tag (e.g., 1.2)"
|
||||
required: true
|
||||
type: boolean
|
||||
default: false
|
||||
target_branch:
|
||||
description: "Branch to tag"
|
||||
required: false
|
||||
default: "main"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event_name == 'workflow_dispatch' && github.event.inputs.version || github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: Build and Release
|
||||
runs-on: ubuntu-latest
|
||||
prepare:
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
name: Prepare release (create tag)
|
||||
runs-on: ubuntu-24.04
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
- name: Validate version input
|
||||
shell: bash
|
||||
env:
|
||||
INPUT_VERSION: ${{ inputs.version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if ! [[ "$INPUT_VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+(-rc\.[0-9]+)?$ ]]; then
|
||||
echo "Invalid version: $INPUT_VERSION (expected X.Y.Z or X.Y.Z-rc.N)" >&2
|
||||
exit 1
|
||||
fi
|
||||
- name: Create and push tag
|
||||
shell: bash
|
||||
env:
|
||||
TARGET_BRANCH: ${{ inputs.target_branch }}
|
||||
VERSION: ${{ inputs.version }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
git fetch --prune origin
|
||||
git checkout "$TARGET_BRANCH"
|
||||
git pull --ff-only origin "$TARGET_BRANCH"
|
||||
if git rev-parse -q --verify "refs/tags/$VERSION" >/dev/null; then
|
||||
echo "Tag $VERSION already exists" >&2
|
||||
exit 1
|
||||
fi
|
||||
git tag -a "$VERSION" -m "Release $VERSION"
|
||||
git push origin "refs/tags/$VERSION"
|
||||
release:
|
||||
if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && github.actor != 'github-actions[bot]') }}
|
||||
name: Build and Release
|
||||
runs-on: ubuntu-24.04
|
||||
timeout-minutes: 120
|
||||
env:
|
||||
DOCKERHUB_IMAGE: docker.io/fosrl/${{ github.event.repository.name }}
|
||||
GHCR_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Capture created timestamp
|
||||
run: echo "IMAGE_CREATED=$(date -u +%Y-%m-%dT%H:%M:%SZ)" >> $GITHUB_ENV
|
||||
shell: bash
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
|
||||
|
||||
- name: Extract tag name
|
||||
id: get-tag
|
||||
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 1.25
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: docker.io
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||
|
||||
- name: Update version in main.go
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
if [ -f main.go ]; then
|
||||
sed -i 's/version_replaceme/'"$TAG"'/' main.go
|
||||
echo "Updated main.go with version $TAG"
|
||||
else
|
||||
echo "main.go not found"
|
||||
fi
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
make docker-build-release tag=$TAG
|
||||
- name: Normalize image names to lowercase
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "GHCR_IMAGE=${GHCR_IMAGE,,}" >> "$GITHUB_ENV"
|
||||
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE,,}" >> "$GITHUB_ENV"
|
||||
shell: bash
|
||||
|
||||
- name: Build binaries
|
||||
run: |
|
||||
make go-build-release
|
||||
- name: Extract tag name
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
INPUT_VERSION: ${{ inputs.version }}
|
||||
run: |
|
||||
if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
|
||||
echo "TAG=${INPUT_VERSION}" >> $GITHUB_ENV
|
||||
else
|
||||
echo "TAG=${{ github.ref_name }}" >> $GITHUB_ENV
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Upload artifacts from /bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: binaries
|
||||
path: bin/
|
||||
- name: Validate pushed tag format (no leading 'v')
|
||||
if: ${{ github.event_name == 'push' }}
|
||||
shell: bash
|
||||
env:
|
||||
TAG_GOT: ${{ env.TAG }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if [[ "$TAG_GOT" =~ ^[0-9]+\.[0-9]+\.[0-9]+(-rc\.[0-9]+)?$ ]]; then
|
||||
echo "Tag OK: $TAG_GOT"
|
||||
exit 0
|
||||
fi
|
||||
echo "ERROR: Tag '$TAG_GOT' is not allowed. Use 'X.Y.Z' or 'X.Y.Z-rc.N' (no leading 'v')." >&2
|
||||
exit 1
|
||||
- name: Wait for tag to be visible (dispatch only)
|
||||
if: ${{ github.event_name == 'workflow_dispatch' }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
for i in {1..90}; do
|
||||
if git ls-remote --tags origin "refs/tags/${TAG}" | grep -qE "refs/tags/${TAG}$"; then
|
||||
echo "Tag ${TAG} is visible on origin"; exit 0
|
||||
fi
|
||||
echo "Tag not yet visible, retrying... ($i/90)"
|
||||
sleep 2
|
||||
done
|
||||
echo "Tag ${TAG} not visible after waiting"; exit 1
|
||||
shell: bash
|
||||
|
||||
- name: Update version in main.go
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
if [ -f main.go ]; then
|
||||
sed -i 's/version_replaceme/'"$TAG"'/' main.go
|
||||
echo "Updated main.go with version $TAG"
|
||||
else
|
||||
echo "main.go not found"
|
||||
fi
|
||||
|
||||
- name: Ensure repository is at the tagged commit (dispatch only)
|
||||
if: ${{ github.event_name == 'workflow_dispatch' }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
git fetch --tags --force
|
||||
git checkout "refs/tags/${TAG}"
|
||||
echo "Checked out $(git rev-parse --short HEAD) for tag ${TAG}"
|
||||
shell: bash
|
||||
|
||||
- name: Detect release candidate (rc)
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if [[ "${TAG}" =~ ^[0-9]+\.[0-9]+\.[0-9]+-rc\.[0-9]+$ ]]; then
|
||||
echo "IS_RC=true" >> $GITHUB_ENV
|
||||
else
|
||||
echo "IS_RC=false" >> $GITHUB_ENV
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Resolve publish-latest flag
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
PL_INPUT: ${{ inputs.publish_latest }}
|
||||
PL_VAR: ${{ vars.PUBLISH_LATEST }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
val="false"
|
||||
if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
|
||||
if [ "${PL_INPUT}" = "true" ]; then val="true"; fi
|
||||
else
|
||||
if [ "${PL_VAR}" = "true" ]; then val="true"; fi
|
||||
fi
|
||||
echo "PUBLISH_LATEST=$val" >> $GITHUB_ENV
|
||||
shell: bash
|
||||
|
||||
- name: Resolve publish-minor flag
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
PM_INPUT: ${{ inputs.publish_minor }}
|
||||
PM_VAR: ${{ vars.PUBLISH_MINOR }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
val="false"
|
||||
if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
|
||||
if [ "${PM_INPUT}" = "true" ]; then val="true"; fi
|
||||
else
|
||||
if [ "${PM_VAR}" = "true" ]; then val="true"; fi
|
||||
fi
|
||||
echo "PUBLISH_MINOR=$val" >> $GITHUB_ENV
|
||||
shell: bash
|
||||
|
||||
- name: Cache Go modules
|
||||
if: ${{ hashFiles('**/go.sum') != '' }}
|
||||
uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5.0.1
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
- name: Go vet & test
|
||||
if: ${{ hashFiles('**/go.mod') != '' }}
|
||||
run: |
|
||||
go version
|
||||
go vet ./...
|
||||
go test ./... -race -covermode=atomic
|
||||
shell: bash
|
||||
|
||||
- name: Resolve license fallback
|
||||
run: echo "IMAGE_LICENSE=${{ github.event.repository.license.spdx_id || 'NOASSERTION' }}" >> $GITHUB_ENV
|
||||
shell: bash
|
||||
|
||||
- name: Resolve registries list (GHCR always, Docker Hub only if creds)
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
images="${GHCR_IMAGE}"
|
||||
if [ -n "${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}" ] && [ -n "${{ secrets.DOCKER_HUB_USERNAME }}" ]; then
|
||||
images="${images}\n${DOCKERHUB_IMAGE}"
|
||||
fi
|
||||
{
|
||||
echo 'IMAGE_LIST<<EOF'
|
||||
echo -e "$images"
|
||||
echo 'EOF'
|
||||
} >> "$GITHUB_ENV"
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5.10.0
|
||||
with:
|
||||
images: ${{ env.IMAGE_LIST }}
|
||||
tags: |
|
||||
type=semver,pattern={{version}},value=${{ env.TAG }}
|
||||
type=semver,pattern={{major}}.{{minor}},value=${{ env.TAG }},enable=${{ env.PUBLISH_MINOR == 'true' && env.IS_RC != 'true' }}
|
||||
type=raw,value=latest,enable=${{ env.IS_RC != 'true' }}
|
||||
flavor: |
|
||||
latest=false
|
||||
labels: |
|
||||
org.opencontainers.image.title=${{ github.event.repository.name }}
|
||||
org.opencontainers.image.version=${{ env.TAG }}
|
||||
org.opencontainers.image.revision=${{ github.sha }}
|
||||
org.opencontainers.image.source=${{ github.event.repository.html_url }}
|
||||
org.opencontainers.image.url=${{ github.event.repository.html_url }}
|
||||
org.opencontainers.image.documentation=${{ github.event.repository.html_url }}
|
||||
org.opencontainers.image.description=${{ github.event.repository.description }}
|
||||
org.opencontainers.image.licenses=${{ env.IMAGE_LICENSE }}
|
||||
org.opencontainers.image.created=${{ env.IMAGE_CREATED }}
|
||||
org.opencontainers.image.ref.name=${{ env.TAG }}
|
||||
org.opencontainers.image.authors=${{ github.repository_owner }}
|
||||
- name: Echo build config (non-secret)
|
||||
shell: bash
|
||||
env:
|
||||
IMAGE_TITLE: ${{ github.event.repository.name }}
|
||||
IMAGE_VERSION: ${{ env.TAG }}
|
||||
IMAGE_REVISION: ${{ github.sha }}
|
||||
IMAGE_SOURCE_URL: ${{ github.event.repository.html_url }}
|
||||
IMAGE_URL: ${{ github.event.repository.html_url }}
|
||||
IMAGE_DESCRIPTION: ${{ github.event.repository.description }}
|
||||
IMAGE_LICENSE: ${{ env.IMAGE_LICENSE }}
|
||||
DOCKERHUB_IMAGE: ${{ env.DOCKERHUB_IMAGE }}
|
||||
GHCR_IMAGE: ${{ env.GHCR_IMAGE }}
|
||||
DOCKER_HUB_USER: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
REPO: ${{ github.repository }}
|
||||
OWNER: ${{ github.repository_owner }}
|
||||
WORKFLOW_REF: ${{ github.workflow_ref }}
|
||||
REF: ${{ github.ref }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "=== OCI Label Values ==="
|
||||
echo "org.opencontainers.image.title=${IMAGE_TITLE}"
|
||||
echo "org.opencontainers.image.version=${IMAGE_VERSION}"
|
||||
echo "org.opencontainers.image.revision=${IMAGE_REVISION}"
|
||||
echo "org.opencontainers.image.source=${IMAGE_SOURCE_URL}"
|
||||
echo "org.opencontainers.image.url=${IMAGE_URL}"
|
||||
echo "org.opencontainers.image.description=${IMAGE_DESCRIPTION}"
|
||||
echo "org.opencontainers.image.licenses=${IMAGE_LICENSE}"
|
||||
echo
|
||||
echo "=== Images ==="
|
||||
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE}"
|
||||
echo "GHCR_IMAGE=${GHCR_IMAGE}"
|
||||
echo "DOCKER_HUB_USERNAME=${DOCKER_HUB_USER}"
|
||||
echo
|
||||
echo "=== GitHub Kontext ==="
|
||||
echo "repository=${REPO}"
|
||||
echo "owner=${OWNER}"
|
||||
echo "workflow_ref=${WORKFLOW_REF}"
|
||||
echo "ref=${REF}"
|
||||
echo "ref_name=${REF_NAME}"
|
||||
echo "run_url=${RUN_URL}"
|
||||
echo
|
||||
echo "=== docker/metadata-action outputs (Tags/Labels), raw ==="
|
||||
echo "::group::tags"
|
||||
echo "${{ steps.meta.outputs.tags }}"
|
||||
echo "::endgroup::"
|
||||
echo "::group::labels"
|
||||
echo "${{ steps.meta.outputs.labels }}"
|
||||
echo "::endgroup::"
|
||||
- name: Build and push (Docker Hub + GHCR)
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6.18.0
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64,linux/arm/v7
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha,scope=${{ github.repository }}
|
||||
cache-to: type=gha,mode=max,scope=${{ github.repository }}
|
||||
provenance: mode=max
|
||||
sbom: true
|
||||
|
||||
- name: Compute image digest refs
|
||||
run: |
|
||||
echo "DIGEST=${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
|
||||
echo "GHCR_REF=$GHCR_IMAGE@${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
|
||||
echo "DH_REF=$DOCKERHUB_IMAGE@${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
|
||||
echo "Built digest: ${{ steps.build.outputs.digest }}"
|
||||
shell: bash
|
||||
|
||||
- name: Attest build provenance (GHCR)
|
||||
id: attest-ghcr
|
||||
uses: actions/attest-build-provenance@00014ed6ed5efc5b1ab7f7f34a39eb55d41aa4f8 # v3.1.0
|
||||
with:
|
||||
subject-name: ${{ env.GHCR_IMAGE }}
|
||||
subject-digest: ${{ steps.build.outputs.digest }}
|
||||
push-to-registry: true
|
||||
show-summary: true
|
||||
|
||||
- name: Attest build provenance (Docker Hub)
|
||||
continue-on-error: true
|
||||
id: attest-dh
|
||||
uses: actions/attest-build-provenance@00014ed6ed5efc5b1ab7f7f34a39eb55d41aa4f8 # v3.1.0
|
||||
with:
|
||||
subject-name: index.docker.io/fosrl/${{ github.event.repository.name }}
|
||||
subject-digest: ${{ steps.build.outputs.digest }}
|
||||
push-to-registry: true
|
||||
show-summary: true
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||
with:
|
||||
cosign-release: 'v3.0.2'
|
||||
|
||||
- name: Sanity check cosign private key
|
||||
env:
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
cosign public-key --key env://COSIGN_PRIVATE_KEY >/dev/null
|
||||
shell: bash
|
||||
|
||||
- name: Sign GHCR image (digest) with key (recursive)
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Signing ${GHCR_REF} (digest) recursively with provided key"
|
||||
cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${GHCR_REF}"
|
||||
echo "Waiting 30 seconds for signatures to propagate..."
|
||||
sleep 30
|
||||
shell: bash
|
||||
|
||||
- name: Generate SBOM (SPDX JSON)
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # v0.33.1
|
||||
with:
|
||||
image-ref: ${{ env.GHCR_IMAGE }}@${{ steps.build.outputs.digest }}
|
||||
format: spdx-json
|
||||
output: sbom.spdx.json
|
||||
|
||||
- name: Validate SBOM JSON
|
||||
run: jq -e . sbom.spdx.json >/dev/null
|
||||
shell: bash
|
||||
|
||||
- name: Minify SBOM JSON (optional hardening)
|
||||
run: jq -c . sbom.spdx.json > sbom.min.json && mv sbom.min.json sbom.spdx.json
|
||||
shell: bash
|
||||
|
||||
- name: Create SBOM attestation (GHCR, private key)
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
cosign attest \
|
||||
--key env://COSIGN_PRIVATE_KEY \
|
||||
--type spdxjson \
|
||||
--predicate sbom.spdx.json \
|
||||
"${GHCR_REF}"
|
||||
shell: bash
|
||||
|
||||
- name: Create SBOM attestation (Docker Hub, private key)
|
||||
continue-on-error: true
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
COSIGN_DOCKER_MEDIA_TYPES: "1"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
cosign attest \
|
||||
--key env://COSIGN_PRIVATE_KEY \
|
||||
--type spdxjson \
|
||||
--predicate sbom.spdx.json \
|
||||
"${DH_REF}"
|
||||
shell: bash
|
||||
|
||||
- name: Keyless sign & verify GHCR digest (OIDC)
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
WORKFLOW_REF: ${{ github.workflow_ref }} # owner/repo/.github/workflows/<file>@refs/tags/<tag>
|
||||
ISSUER: https://token.actions.githubusercontent.com
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Keyless signing ${GHCR_REF}"
|
||||
cosign sign --rekor-url https://rekor.sigstore.dev --recursive "${GHCR_REF}"
|
||||
echo "Verify keyless (OIDC) signature policy on ${GHCR_REF}"
|
||||
cosign verify \
|
||||
--certificate-oidc-issuer "${ISSUER}" \
|
||||
--certificate-identity "https://github.com/${WORKFLOW_REF}" \
|
||||
"${GHCR_REF}" -o text
|
||||
shell: bash
|
||||
|
||||
- name: Sign Docker Hub image (digest) with key (recursive)
|
||||
continue-on-error: true
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
COSIGN_DOCKER_MEDIA_TYPES: "1"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Signing ${DH_REF} (digest) recursively with provided key (Docker media types fallback)"
|
||||
cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${DH_REF}"
|
||||
shell: bash
|
||||
|
||||
- name: Keyless sign & verify Docker Hub digest (OIDC)
|
||||
continue-on-error: true
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
ISSUER: https://token.actions.githubusercontent.com
|
||||
COSIGN_DOCKER_MEDIA_TYPES: "1"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Keyless signing ${DH_REF} (force public-good Rekor)"
|
||||
cosign sign --rekor-url https://rekor.sigstore.dev --recursive "${DH_REF}"
|
||||
echo "Keyless verify via Rekor (strict identity)"
|
||||
if ! cosign verify \
|
||||
--rekor-url https://rekor.sigstore.dev \
|
||||
--certificate-oidc-issuer "${ISSUER}" \
|
||||
--certificate-identity "https://github.com/${{ github.workflow_ref }}" \
|
||||
"${DH_REF}" -o text; then
|
||||
echo "Rekor verify failed — retry offline bundle verify (no Rekor)"
|
||||
if ! cosign verify \
|
||||
--offline \
|
||||
--certificate-oidc-issuer "${ISSUER}" \
|
||||
--certificate-identity "https://github.com/${{ github.workflow_ref }}" \
|
||||
"${DH_REF}" -o text; then
|
||||
echo "Offline bundle verify failed — ignore tlog (TEMP for debugging)"
|
||||
cosign verify \
|
||||
--insecure-ignore-tlog=true \
|
||||
--certificate-oidc-issuer "${ISSUER}" \
|
||||
--certificate-identity "https://github.com/${{ github.workflow_ref }}" \
|
||||
"${DH_REF}" -o text || true
|
||||
fi
|
||||
fi
|
||||
- name: Verify signature (public key) GHCR digest + tag
|
||||
env:
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG_VAR="${TAG}"
|
||||
echo "Verifying (digest) ${GHCR_REF}"
|
||||
cosign verify --key env://COSIGN_PUBLIC_KEY "$GHCR_REF" -o text
|
||||
echo "Verifying (tag) $GHCR_IMAGE:$TAG_VAR"
|
||||
cosign verify --key env://COSIGN_PUBLIC_KEY "$GHCR_IMAGE:$TAG_VAR" -o text
|
||||
shell: bash
|
||||
|
||||
- name: Verify SBOM attestation (GHCR)
|
||||
env:
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
run: cosign verify-attestation --key env://COSIGN_PUBLIC_KEY --type spdxjson "$GHCR_REF" -o text
|
||||
shell: bash
|
||||
|
||||
- name: Verify SLSA provenance (GHCR)
|
||||
env:
|
||||
ISSUER: https://token.actions.githubusercontent.com
|
||||
WFREF: ${{ github.workflow_ref }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
# (optional) show which predicate types are present to aid debugging
|
||||
cosign download attestation "$GHCR_REF" \
|
||||
| jq -r '.payload | @base64d | fromjson | .predicateType' | sort -u || true
|
||||
# Verify the SLSA v1 provenance attestation (predicate URL)
|
||||
cosign verify-attestation \
|
||||
--type 'https://slsa.dev/provenance/v1' \
|
||||
--certificate-oidc-issuer "$ISSUER" \
|
||||
--certificate-identity "https://github.com/${WFREF}" \
|
||||
--rekor-url https://rekor.sigstore.dev \
|
||||
"$GHCR_REF" -o text
|
||||
shell: bash
|
||||
|
||||
- name: Verify signature (public key) Docker Hub digest
|
||||
continue-on-error: true
|
||||
env:
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
COSIGN_DOCKER_MEDIA_TYPES: "1"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Verifying (digest) ${DH_REF} with Docker media types"
|
||||
cosign verify --key env://COSIGN_PUBLIC_KEY "${DH_REF}" -o text
|
||||
shell: bash
|
||||
|
||||
- name: Verify signature (public key) Docker Hub tag
|
||||
continue-on-error: true
|
||||
env:
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
COSIGN_DOCKER_MEDIA_TYPES: "1"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Verifying (tag) $DOCKERHUB_IMAGE:$TAG with Docker media types"
|
||||
cosign verify --key env://COSIGN_PUBLIC_KEY "$DOCKERHUB_IMAGE:$TAG" -o text
|
||||
shell: bash
|
||||
|
||||
# - name: Trivy scan (GHCR image)
|
||||
# id: trivy
|
||||
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # v0.33.1
|
||||
# with:
|
||||
# image-ref: ${{ env.GHCR_IMAGE }}@${{ steps.build.outputs.digest }}
|
||||
# format: sarif
|
||||
# output: trivy-ghcr.sarif
|
||||
# ignore-unfixed: true
|
||||
# vuln-type: os,library
|
||||
# severity: CRITICAL,HIGH
|
||||
# exit-code: ${{ (vars.TRIVY_FAIL || '0') }}
|
||||
|
||||
# - name: Upload SARIF,trivy
|
||||
# if: ${{ always() && hashFiles('trivy-ghcr.sarif') != '' }}
|
||||
# uses: github/codeql-action/upload-sarif@fdbfb4d2750291e159f0156def62b853c2798ca2 # v4.31.5
|
||||
# with:
|
||||
# sarif_file: trivy-ghcr.sarif
|
||||
# category: Image Vulnerability Scan
|
||||
|
||||
- name: Build binaries
|
||||
env:
|
||||
CGO_ENABLED: "0"
|
||||
GOFLAGS: "-trimpath"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG_VAR="${TAG}"
|
||||
make -j 10 go-build-release tag=$TAG_VAR
|
||||
shell: bash
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: softprops/action-gh-release@5be0e66d93ac7ed76da52eca8bb058f665c3a5fe # v2.4.2
|
||||
with:
|
||||
tag_name: ${{ env.TAG }}
|
||||
generate_release_notes: true
|
||||
prerelease: ${{ env.IS_RC == 'true' }}
|
||||
files: |
|
||||
bin/*
|
||||
fail_on_unmatched_files: true
|
||||
draft: true
|
||||
body: |
|
||||
## Container Images
|
||||
- GHCR: `${{ env.GHCR_REF }}`
|
||||
- Docker Hub: `${{ env.DH_REF || 'N/A' }}`
|
||||
**Digest:** `${{ steps.build.outputs.digest }}`
|
||||
|
||||
132
.github/workflows/mirror.yaml
vendored
Normal file
132
.github/workflows/mirror.yaml
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
name: Mirror & Sign (Docker Hub to GHCR)
|
||||
|
||||
on:
|
||||
workflow_dispatch: {}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write # for keyless OIDC
|
||||
|
||||
env:
|
||||
SOURCE_IMAGE: docker.io/fosrl/newt
|
||||
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||
|
||||
jobs:
|
||||
mirror-and-dual-sign:
|
||||
runs-on: amd64-runner
|
||||
steps:
|
||||
- name: Install skopeo + jq
|
||||
run: |
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y skopeo jq
|
||||
skopeo --version
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||
|
||||
- name: Input check
|
||||
run: |
|
||||
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
|
||||
echo "Source : ${SOURCE_IMAGE}"
|
||||
echo "Target : ${DEST_IMAGE}"
|
||||
|
||||
# Auth for skopeo (containers-auth)
|
||||
- name: Skopeo login to GHCR
|
||||
run: |
|
||||
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
||||
# Auth for cosign (docker-config)
|
||||
- name: Docker login to GHCR (for cosign)
|
||||
run: |
|
||||
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
|
||||
|
||||
- name: List source tags
|
||||
run: |
|
||||
set -euo pipefail
|
||||
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
|
||||
| jq -r '.Tags[]' | sort -u > src-tags.txt
|
||||
echo "Found source tags: $(wc -l < src-tags.txt)"
|
||||
head -n 20 src-tags.txt || true
|
||||
|
||||
- name: List destination tags (skip existing)
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
|
||||
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
|
||||
else
|
||||
: > dst-tags.txt
|
||||
fi
|
||||
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
|
||||
|
||||
- name: Mirror, dual-sign, and verify
|
||||
env:
|
||||
# keyless
|
||||
COSIGN_YES: "true"
|
||||
# key-based
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
# verify
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
copied=0; skipped=0; v_ok=0; errs=0
|
||||
|
||||
issuer="https://token.actions.githubusercontent.com"
|
||||
id_regex="^https://github.com/${{ github.repository }}/.+"
|
||||
|
||||
while read -r tag; do
|
||||
[ -z "$tag" ] && continue
|
||||
|
||||
if grep -Fxq "$tag" dst-tags.txt; then
|
||||
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
|
||||
skipped=$((skipped+1))
|
||||
continue
|
||||
fi
|
||||
|
||||
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
|
||||
if ! skopeo copy --all --retry-times 3 \
|
||||
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
|
||||
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
|
||||
errs=$((errs+1)); continue
|
||||
fi
|
||||
copied=$((copied+1))
|
||||
|
||||
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
|
||||
ref="${DEST_IMAGE}@${digest}"
|
||||
|
||||
echo "==> cosign sign (keyless) --recursive ${ref}"
|
||||
if ! cosign sign --recursive "${ref}"; then
|
||||
echo "::warning title=Keyless sign failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign sign (key) --recursive ${ref}"
|
||||
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
|
||||
echo "::warning title=Key sign failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign verify (public key) ${ref}"
|
||||
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
|
||||
echo "::warning title=Verify(pubkey) failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign verify (keyless policy) ${ref}"
|
||||
if ! cosign verify \
|
||||
--certificate-oidc-issuer "${issuer}" \
|
||||
--certificate-identity-regexp "${id_regex}" \
|
||||
"${ref}" -o text; then
|
||||
echo "::warning title=Verify(keyless) failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
else
|
||||
v_ok=$((v_ok+1))
|
||||
fi
|
||||
done < src-tags.txt
|
||||
|
||||
echo "---- Summary ----"
|
||||
echo "Copied : $copied"
|
||||
echo "Skipped : $skipped"
|
||||
echo "Verified OK : $v_ok"
|
||||
echo "Errors : $errs"
|
||||
23
.github/workflows/nix-build.yml
vendored
Normal file
23
.github/workflows/nix-build.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Build Nix package
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
paths:
|
||||
- go.mod
|
||||
- go.sum
|
||||
|
||||
jobs:
|
||||
nix-build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install Nix
|
||||
uses: DeterminateSystems/nix-installer-action@main
|
||||
|
||||
- name: Build flake package
|
||||
run: |
|
||||
nix build .#pangolin-newt -L
|
||||
48
.github/workflows/nix-dependabot-update-hash.yml
vendored
Normal file
48
.github/workflows/nix-dependabot-update-hash.yml
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
name: Update Nix Package Hash On Dependabot PRs
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
nix-update:
|
||||
if: github.actor == 'dependabot[bot]'
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.head_ref }}
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install Nix
|
||||
uses: DeterminateSystems/nix-installer-action@main
|
||||
|
||||
- name: Run nix-update
|
||||
run: |
|
||||
nix run nixpkgs#nix-update -- --flake pangolin-newt --no-src --version skip
|
||||
|
||||
- name: Check for changes
|
||||
id: changes
|
||||
run: |
|
||||
if git diff --quiet; then
|
||||
echo "changed=false" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Commit and push changes
|
||||
if: steps.changes.outputs.changed == 'true'
|
||||
run: |
|
||||
git config user.name "dependabot[bot]"
|
||||
git config user.email "dependabot[bot]@users.noreply.github.com"
|
||||
|
||||
git add .
|
||||
git commit -m "chore(nix): fix hash for updated go dependencies"
|
||||
git push
|
||||
37
.github/workflows/stale-bot.yml
vendored
Normal file
37
.github/workflows/stale-bot.yml
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
name: Mark and Close Stale Issues
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *'
|
||||
workflow_dispatch: # Allow manual trigger
|
||||
|
||||
permissions:
|
||||
contents: write # only for delete-branch option
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
|
||||
with:
|
||||
days-before-stale: 14
|
||||
days-before-close: 14
|
||||
stale-issue-message: 'This issue has been automatically marked as stale due to 14 days of inactivity. It will be closed in 14 days if no further activity occurs.'
|
||||
close-issue-message: 'This issue has been automatically closed due to inactivity. If you believe this is still relevant, please open a new issue with up-to-date information.'
|
||||
stale-issue-label: 'stale'
|
||||
|
||||
exempt-issue-labels: 'needs investigating, networking, new feature, reverse proxy, bug, api, authentication, documentation, enhancement, help wanted, good first issue, question'
|
||||
|
||||
exempt-all-issue-assignees: true
|
||||
|
||||
only-labels: ''
|
||||
exempt-pr-labels: ''
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
|
||||
operations-per-run: 100
|
||||
remove-stale-when-updated: true
|
||||
delete-branch: false
|
||||
enable-statistics: true
|
||||
32
.github/workflows/test.yml
vendored
32
.github/workflows/test.yml
vendored
@@ -10,22 +10,30 @@ on:
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
test:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
target:
|
||||
- local
|
||||
- docker-build
|
||||
- go-build-release-darwin-amd64
|
||||
- go-build-release-darwin-arm64
|
||||
- go-build-release-freebsd-amd64
|
||||
- go-build-release-freebsd-arm64
|
||||
- go-build-release-linux-amd64
|
||||
- go-build-release-linux-arm32-v6
|
||||
- go-build-release-linux-arm32-v7
|
||||
- go-build-release-linux-riscv64
|
||||
- go-build-release-windows-amd64
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
|
||||
with:
|
||||
go-version: 1.25
|
||||
|
||||
- name: Build go
|
||||
run: go build
|
||||
|
||||
- name: Build Docker image
|
||||
run: make build
|
||||
|
||||
- name: Build binaries
|
||||
run: make go-build-release
|
||||
- name: Build targets via `make`
|
||||
run: make ${{ matrix.target }}
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,4 +1,3 @@
|
||||
newt
|
||||
.DS_Store
|
||||
bin/
|
||||
nohup.out
|
||||
@@ -6,3 +5,6 @@ nohup.out
|
||||
*.iml
|
||||
certs/
|
||||
newt_arm64
|
||||
key
|
||||
/.direnv/
|
||||
/result*
|
||||
|
||||
@@ -4,11 +4,7 @@ Contributions are welcome!
|
||||
|
||||
Please see the contribution and local development guide on the docs page before getting started:
|
||||
|
||||
https://docs.fossorial.io/development
|
||||
|
||||
For ideas about what features to work on and our future plans, please see the roadmap:
|
||||
|
||||
https://docs.fossorial.io/roadmap
|
||||
https://docs.pangolin.net/development/contributing
|
||||
|
||||
### Licensing Considerations
|
||||
|
||||
@@ -21,4 +17,4 @@ By creating this pull request, I grant the project maintainers an unlimited,
|
||||
perpetual license to use, modify, and redistribute these contributions under any terms they
|
||||
choose, including both the AGPLv3 and the Fossorial Commercial license terms. I
|
||||
represent that I have the right to grant this license for all contributed content.
|
||||
```
|
||||
```
|
||||
14
Dockerfile
14
Dockerfile
@@ -1,5 +1,8 @@
|
||||
FROM golang:1.25-alpine AS builder
|
||||
|
||||
# Install git and ca-certificates
|
||||
RUN apk --no-cache add ca-certificates git tzdata
|
||||
|
||||
# Set the working directory inside the container
|
||||
WORKDIR /app
|
||||
|
||||
@@ -13,15 +16,18 @@ RUN go mod download
|
||||
COPY . .
|
||||
|
||||
# Build the application
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /newt
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /newt
|
||||
|
||||
FROM alpine:3.22 AS runner
|
||||
FROM alpine:3.23 AS runner
|
||||
|
||||
RUN apk --no-cache add ca-certificates tzdata
|
||||
RUN apk --no-cache add ca-certificates tzdata iputils
|
||||
|
||||
COPY --from=builder /newt /usr/local/bin/
|
||||
COPY entrypoint.sh /
|
||||
|
||||
# Admin/metrics endpoint (Prometheus scrape)
|
||||
EXPOSE 2112
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
CMD ["newt"]
|
||||
CMD ["newt"]
|
||||
|
||||
83
Makefile
83
Makefile
@@ -1,37 +1,70 @@
|
||||
.PHONY: all local docker-build docker-build-release
|
||||
|
||||
all: build push
|
||||
all: local
|
||||
|
||||
local:
|
||||
CGO_ENABLED=0 go build -o ./bin/newt
|
||||
|
||||
docker-build:
|
||||
docker build -t fosrl/newt:latest .
|
||||
|
||||
docker-build-release:
|
||||
@if [ -z "$(tag)" ]; then \
|
||||
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
|
||||
exit 1; \
|
||||
fi
|
||||
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/newt:latest -f Dockerfile --push .
|
||||
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/newt:$(tag) -f Dockerfile --push .
|
||||
docker buildx build . \
|
||||
--platform linux/arm/v7,linux/arm64,linux/amd64 \
|
||||
-t fosrl/newt:latest \
|
||||
-t fosrl/newt:$(tag) \
|
||||
-f Dockerfile \
|
||||
--push
|
||||
|
||||
build:
|
||||
docker build -t fosrl/newt:latest .
|
||||
.PHONY: go-build-release \
|
||||
go-build-release-linux-arm64 go-build-release-linux-arm32-v7 \
|
||||
go-build-release-linux-arm32-v6 go-build-release-linux-amd64 \
|
||||
go-build-release-linux-riscv64 go-build-release-darwin-arm64 \
|
||||
go-build-release-darwin-amd64 go-build-release-windows-amd64 \
|
||||
go-build-release-freebsd-amd64 go-build-release-freebsd-arm64
|
||||
|
||||
push:
|
||||
docker push fosrl/newt:latest
|
||||
go-build-release: \
|
||||
go-build-release-linux-arm64 \
|
||||
go-build-release-linux-arm32-v7 \
|
||||
go-build-release-linux-arm32-v6 \
|
||||
go-build-release-linux-amd64 \
|
||||
go-build-release-linux-riscv64 \
|
||||
go-build-release-darwin-arm64 \
|
||||
go-build-release-darwin-amd64 \
|
||||
go-build-release-windows-amd64 \
|
||||
go-build-release-freebsd-amd64 \
|
||||
go-build-release-freebsd-arm64
|
||||
|
||||
test:
|
||||
docker run fosrl/newt:latest
|
||||
|
||||
local:
|
||||
CGO_ENABLED=0 go build -o newt
|
||||
|
||||
go-build-release:
|
||||
go-build-release-linux-arm64:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/newt_linux_arm64
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/newt_linux_arm32
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -o bin/newt_linux_arm32v6
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/newt_linux_amd64
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -o bin/newt_linux_riscv64
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/newt_darwin_arm64
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/newt_darwin_amd64
|
||||
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/newt_windows_amd64.exe
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -o bin/newt_freebsd_amd64
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -o bin/newt_freebsd_arm64
|
||||
|
||||
clean:
|
||||
rm newt
|
||||
go-build-release-linux-arm32-v7:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/newt_linux_arm32
|
||||
|
||||
go-build-release-linux-arm32-v6:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -o bin/newt_linux_arm32v6
|
||||
|
||||
go-build-release-linux-amd64:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/newt_linux_amd64
|
||||
|
||||
go-build-release-linux-riscv64:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -o bin/newt_linux_riscv64
|
||||
|
||||
go-build-release-darwin-arm64:
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/newt_darwin_arm64
|
||||
|
||||
go-build-release-darwin-amd64:
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/newt_darwin_amd64
|
||||
|
||||
go-build-release-windows-amd64:
|
||||
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/newt_windows_amd64.exe
|
||||
|
||||
go-build-release-freebsd-amd64:
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -o bin/newt_freebsd_amd64
|
||||
|
||||
go-build-release-freebsd-arm64:
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -o bin/newt_freebsd_arm64
|
||||
|
||||
365
README.md
365
README.md
@@ -9,13 +9,7 @@ Newt is a fully user space [WireGuard](https://www.wireguard.com/) tunnel client
|
||||
|
||||
Newt is used with Pangolin and Gerbil as part of the larger system. See documentation below:
|
||||
|
||||
- [Full Documentation](https://docs.fossorial.io)
|
||||
|
||||
## Preview
|
||||
|
||||
<img src="public/screenshots/preview.png" alt="Preview"/>
|
||||
|
||||
_Sample output of a Newt connected to Pangolin and hosting various resource target proxies._
|
||||
- [Full Documentation](https://docs.pangolin.net/manage/sites/understanding-sites)
|
||||
|
||||
## Key Functions
|
||||
|
||||
@@ -31,367 +25,14 @@ When Newt receives WireGuard control messages, it will use the information encod
|
||||
|
||||
When Newt receives WireGuard control messages, it will use the information encoded to create a local low level TCP and UDP proxies attached to the virtual tunnel in order to relay traffic to programmed targets.
|
||||
|
||||
## CLI Args
|
||||
|
||||
- `id`: Newt ID generated by Pangolin to identify the client.
|
||||
- `secret`: A unique secret (not shared and kept private) used to authenticate the client ID with the websocket in order to receive commands.
|
||||
- `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket.
|
||||
|
||||
- `mtu` (optional): MTU for the internal WG interface. Default: 1280
|
||||
- `dns` (optional): DNS server to use to resolve the endpoint. Default: 9.9.9.9
|
||||
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: INFO
|
||||
- `enforce-hc-cert` (optional): Enforce certificate validation for health checks. Default: false (accepts any cert)
|
||||
- `docker-socket` (optional): Set the Docker socket to use the container discovery integration
|
||||
- `ping-interval` (optional): Interval for pinging the server. Default: 3s
|
||||
- `ping-timeout` (optional): Timeout for each ping. Default: 5s
|
||||
- `updown` (optional): A script to be called when targets are added or removed.
|
||||
- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls)
|
||||
- `tls-client-cert` (optional): Path to client certificate (PEM format, optional if using PKCS12). See [mTLS](#mtls)
|
||||
- `tls-client-key` (optional): Path to private key for mTLS (PEM format, optional if using PKCS12)
|
||||
- `tls-ca-cert` (optional): Path to CA certificate to verify server (PEM format, optional if using PKCS12)
|
||||
- `docker-enforce-network-validation` (optional): Validate the container target is on the same network as the newt process. Default: false
|
||||
- `health-file` (optional): Check if connection to WG server (pangolin) is ok. creates a file if ok, removes it if not ok. Can be used with docker healtcheck to restart newt
|
||||
- `accept-clients` (optional): Enable WireGuard server mode to accept incoming newt client connections. Default: false
|
||||
- `generateAndSaveKeyTo` (optional): Path to save generated private key
|
||||
- `native` (optional): Use native WireGuard interface when accepting clients (requires WireGuard kernel module and Linux, must run as root). Default: false (uses userspace netstack)
|
||||
- `interface` (optional): Name of the WireGuard interface. Default: newt
|
||||
- `keep-interface` (optional): Keep the WireGuard interface. Default: false
|
||||
- `blueprint-file` (optional): Path to blueprint file to define Pangolin resources and configurations.
|
||||
- `no-cloud` (optional): Don't fail over to the cloud when using managed nodes in Pangolin Cloud. Default: false
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All CLI arguments can be set using environment variables as an alternative to command line flags. Environment variables are particularly useful when running Newt in containerized environments.
|
||||
|
||||
- `PANGOLIN_ENDPOINT`: Endpoint of your pangolin server (equivalent to `--endpoint`)
|
||||
- `NEWT_ID`: Newt ID generated by Pangolin (equivalent to `--id`)
|
||||
- `NEWT_SECRET`: Newt secret for authentication (equivalent to `--secret`)
|
||||
- `MTU`: MTU for the internal WG interface. Default: 1280 (equivalent to `--mtu`)
|
||||
- `DNS`: DNS server to use to resolve the endpoint. Default: 9.9.9.9 (equivalent to `--dns`)
|
||||
- `LOG_LEVEL`: Log level (DEBUG, INFO, WARN, ERROR, FATAL). Default: INFO (equivalent to `--log-level`)
|
||||
- `DOCKER_SOCKET`: Path to Docker socket for container discovery (equivalent to `--docker-socket`)
|
||||
- `PING_INTERVAL`: Interval for pinging the server. Default: 3s (equivalent to `--ping-interval`)
|
||||
- `PING_TIMEOUT`: Timeout for each ping. Default: 5s (equivalent to `--ping-timeout`)
|
||||
- `UPDOWN_SCRIPT`: Path to updown script for target add/remove events (equivalent to `--updown`)
|
||||
- `TLS_CLIENT_CERT`: Path to client certificate for mTLS (equivalent to `--tls-client-cert`)
|
||||
- `TLS_CLIENT_CERT`: Path to client certificate for mTLS (equivalent to `--tls-client-cert`)
|
||||
- `TLS_CLIENT_KEY`: Path to private key for mTLS (equivalent to `--tls-client-key`)
|
||||
- `TLS_CA_CERT`: Path to CA certificate to verify server (equivalent to `--tls-ca-cert`)
|
||||
- `DOCKER_ENFORCE_NETWORK_VALIDATION`: Validate container targets are on same network. Default: false (equivalent to `--docker-enforce-network-validation`)
|
||||
- `ENFORCE_HC_CERT`: Enforce certificate validation for health checks. Default: false (equivalent to `--enforce-hc-cert`)
|
||||
- `HEALTH_FILE`: Path to health file for connection monitoring (equivalent to `--health-file`)
|
||||
- `ACCEPT_CLIENTS`: Enable WireGuard server mode. Default: false (equivalent to `--accept-clients`)
|
||||
- `GENERATE_AND_SAVE_KEY_TO`: Path to save generated private key (equivalent to `--generateAndSaveKeyTo`)
|
||||
- `USE_NATIVE_INTERFACE`: Use native WireGuard interface (Linux only). Default: false (equivalent to `--native`)
|
||||
- `INTERFACE`: Name of the WireGuard interface. Default: newt (equivalent to `--interface`)
|
||||
- `KEEP_INTERFACE`: Keep the WireGuard interface after shutdown. Default: false (equivalent to `--keep-interface`)
|
||||
- `CONFIG_FILE`: Load the config json from this file instead of in the home folder.
|
||||
- `BLUEPRINT_FILE`: Path to blueprint file to define Pangolin resources and configurations. (equivalent to `--blueprint-file`)
|
||||
- `NO_CLOUD`: Don't fail over to the cloud when using managed nodes in Pangolin Cloud. Default: false (equivalent to `--no-cloud`)
|
||||
|
||||
## Loading secrets from files
|
||||
|
||||
You can use `CONFIG_FILE` to define a location of a config file to store the credentials between runs.
|
||||
|
||||
```
|
||||
$ cat ~/.config/newt-client/config.json
|
||||
{
|
||||
"id": "spmzu8rbpzj1qq6",
|
||||
"secret": "f6v61mjutwme2kkydbw3fjo227zl60a2tsf5psw9r25hgae3",
|
||||
"endpoint": "https://pangolin.fossorial.io",
|
||||
"tlsClientCert": ""
|
||||
}
|
||||
```
|
||||
|
||||
This file is also written to when newt first starts up. So you do not need to run every time with --id and secret if you have run it once!
|
||||
|
||||
Default locations:
|
||||
|
||||
- **macOS**: `~/Library/Application Support/newt-client/config.json`
|
||||
- **Windows**: `%PROGRAMDATA%\newt\newt-client\config.json`
|
||||
- **Linux/Others**: `~/.config/newt-client/config.json`
|
||||
|
||||
## Examples
|
||||
|
||||
**Note**: When both environment variables and CLI arguments are provided, CLI arguments take precedence.
|
||||
|
||||
- Example:
|
||||
|
||||
```bash
|
||||
newt \
|
||||
--id 31frd0uzbjvp721 \
|
||||
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
||||
--endpoint https://example.com
|
||||
```
|
||||
|
||||
You can also run it with Docker compose. For example, a service in your `docker-compose.yml` might look like this using environment vars (recommended):
|
||||
|
||||
```yaml
|
||||
services:
|
||||
newt:
|
||||
image: fosrl/newt
|
||||
container_name: newt
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- PANGOLIN_ENDPOINT=https://example.com
|
||||
- NEWT_ID=2ix2t8xk22ubpfy
|
||||
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||
- HEALTH_FILE=/tmp/healthy
|
||||
```
|
||||
|
||||
You can also pass the CLI args to the container:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
newt:
|
||||
image: fosrl/newt
|
||||
container_name: newt
|
||||
restart: unless-stopped
|
||||
command:
|
||||
- --id 31frd0uzbjvp721
|
||||
- --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6
|
||||
- --endpoint https://example.com
|
||||
- --health-file /tmp/healthy
|
||||
```
|
||||
|
||||
## Accept Client Connections
|
||||
|
||||
When the `--accept-clients` flag is enabled (or `ACCEPT_CLIENTS=true` environment variable is set), Newt operates as a WireGuard server that can accept incoming client connections from other devices. This enables peer-to-peer connectivity through the Newt instance.
|
||||
|
||||
### How It Works
|
||||
|
||||
In client acceptance mode, Newt:
|
||||
|
||||
- **Creates a WireGuard service** that can accept incoming connections from other WireGuard clients
|
||||
- **Starts a connection testing server** (WGTester) that responds to connectivity checks from remote clients
|
||||
- **Manages peer configurations** dynamically based on Pangolin's instructions
|
||||
- **Enables bidirectional communication** between the Newt instance and connected clients
|
||||
|
||||
### Use Cases
|
||||
|
||||
- **Site-to-site connectivity**: Connect multiple locations through a central Newt instance
|
||||
- **Client access to private networks**: Allow remote clients to access resources behind the Newt instance
|
||||
- **Development environments**: Provide developers secure access to internal services
|
||||
|
||||
### Client Tunneling Modes
|
||||
|
||||
Newt supports two WireGuard tunneling modes:
|
||||
|
||||
#### Userspace Mode (Default)
|
||||
|
||||
By default, Newt uses a fully userspace WireGuard implementation using [netstack](https://github.com/WireGuard/wireguard-go/blob/master/tun/netstack/examples/http_server.go). This mode:
|
||||
|
||||
- **Does not require root privileges**
|
||||
- **Works on all supported platforms** (Linux, Windows, macOS)
|
||||
- **Does not require WireGuard kernel module** to be installed
|
||||
- **Runs entirely in userspace** - no system network interface is created
|
||||
- **Is containerization-friendly** - works seamlessly in Docker containers
|
||||
|
||||
This is the recommended mode for most deployments, especially containerized environments.
|
||||
|
||||
In this mode, TCP and UDP is proxied out of newt from the remote client using TCP/UDP resources in Pangolin.
|
||||
|
||||
#### Native Mode (Linux only)
|
||||
|
||||
When using the `--native` flag or setting `USE_NATIVE_INTERFACE=true`, Newt uses the native WireGuard kernel module. This mode:
|
||||
|
||||
- **Requires root privileges** to create and manage network interfaces
|
||||
- **Only works on Linux** with the WireGuard kernel module installed
|
||||
- **Creates a real network interface** (e.g., `newt0`) on the system
|
||||
- **May offer better performance** for high-throughput scenarios
|
||||
- **Requires proper network permissions** and may conflict with existing network configurations
|
||||
|
||||
In this mode it functions like a traditional VPN interface - all data arrives on the interface and you must get it to the destination (or access things locally).
|
||||
|
||||
#### Native Mode Requirements
|
||||
|
||||
To use native mode:
|
||||
|
||||
1. Run on a Linux system
|
||||
2. Install the WireGuard kernel module
|
||||
3. Run Newt as root (`sudo`)
|
||||
4. Ensure the system allows creation of network interfaces
|
||||
|
||||
Docker Compose example:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
newt:
|
||||
image: fosrl/newt
|
||||
container_name: newt
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- PANGOLIN_ENDPOINT=https://example.com
|
||||
- NEWT_ID=2ix2t8xk22ubpfy
|
||||
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||
- ACCEPT_CLIENTS=true
|
||||
```
|
||||
|
||||
### Technical Details
|
||||
|
||||
When client acceptance is enabled:
|
||||
|
||||
- **WGTester Server**: Runs on `port + 1` (e.g., if WireGuard uses port 51820, WGTester uses 51821)
|
||||
- **Connection Testing**: Responds to UDP packets with magic header `0xDEADBEEF` for connectivity verification
|
||||
- **Dynamic Configuration**: Peer configurations are managed remotely through Pangolin
|
||||
- **Proxy Integration**: Can work with both userspace (netstack) and native WireGuard modes
|
||||
|
||||
**Note**: Client acceptance mode requires coordination with Pangolin for peer management and configuration distribution.
|
||||
|
||||
### Docker Socket Integration
|
||||
|
||||
Newt can integrate with the Docker socket to provide remote inspection of Docker containers. This allows Pangolin to query and retrieve detailed information about containers running on the Newt client, including metadata, network configuration, port mappings, and more.
|
||||
|
||||
**Configuration:**
|
||||
|
||||
You can specify the Docker socket path using the `--docker-socket` CLI argument or by setting the `DOCKER_SOCKET` environment variable. If the Docker socket is not available or accessible, Newt will gracefully disable Docker integration and continue normal operation.
|
||||
|
||||
Supported values include:
|
||||
|
||||
- Local UNIX socket (default):
|
||||
>You must mount the socket file into the container using a volume, so Newt can access it.
|
||||
|
||||
`unix:///var/run/docker.sock`
|
||||
|
||||
- TCP socket (e.g., via Docker Socket Proxy):
|
||||
|
||||
`tcp://localhost:2375`
|
||||
|
||||
- HTTP/HTTPS endpoints (e.g., remote Docker APIs):
|
||||
|
||||
`http://your-host:2375`
|
||||
|
||||
- SSH connections (experimental, requires SSH setup):
|
||||
|
||||
`ssh://user@host`
|
||||
|
||||
|
||||
```yaml
|
||||
services:
|
||||
newt:
|
||||
image: fosrl/newt
|
||||
container_name: newt
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
environment:
|
||||
- PANGOLIN_ENDPOINT=https://example.com
|
||||
- NEWT_ID=2ix2t8xk22ubpfy
|
||||
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||
- DOCKER_SOCKET=unix:///var/run/docker.sock
|
||||
```
|
||||
>If you previously used just a path like `/var/run/docker.sock`, it still works — Newt assumes it is a UNIX socket by default.
|
||||
|
||||
#### Hostnames vs IPs
|
||||
|
||||
When the Docker Socket Integration is used, depending on the network which Newt is run with, either the hostname (generally considered the container name) or the IP address of the container will be sent to Pangolin. Here are some of the scenarios where IPs or hostname of the container will be utilised:
|
||||
|
||||
- **Running in Network Mode 'host'**: IP addresses will be used
|
||||
- **Running in Network Mode 'bridge'**: IP addresses will be used
|
||||
- **Running in docker-compose without a network specification**: Docker compose creates a network for the compose by default, hostnames will be used
|
||||
- **Running on docker-compose with defined network**: Hostnames will be used
|
||||
|
||||
### Docker Enforce Network Validation
|
||||
|
||||
When run as a Docker container, Newt can validate that the target being provided is on the same network as the Newt container and only return containers directly accessible by Newt. Validation will be carried out against either the hostname/IP Address and the Port number to ensure the running container is exposing the ports to Newt.
|
||||
|
||||
It is important to note that if the Newt container is run with a network mode of `host` that this feature will not work. Running in `host` mode causes the container to share its resources with the host machine, therefore making it so the specific host container information for Newt cannot be retrieved to be able to carry out network validation.
|
||||
|
||||
**Configuration:**
|
||||
|
||||
Validation is `false` by default. It can be enabled via setting the `--docker-enforce-network-validation` CLI argument or by setting the `DOCKER_ENFORCE_NETWORK_VALIDATION` environment variable.
|
||||
|
||||
If validation is enforced and the Docker socket is available, Newt will **not** add the target as it cannot be verified. A warning will be presented in the Newt logs.
|
||||
|
||||
### Updown
|
||||
|
||||
You can pass in a updown script for Newt to call when it is adding or removing a target:
|
||||
|
||||
`--updown "python3 test.py"`
|
||||
|
||||
It will get called with args when a target is added:
|
||||
`python3 test.py add tcp localhost:8556`
|
||||
`python3 test.py remove tcp localhost:8556`
|
||||
|
||||
Returning a string from the script in the format of a target (`ip:dst` so `10.0.0.1:8080`) it will override the target and use this value instead to proxy.
|
||||
|
||||
You can look at updown.py as a reference script to get started!
|
||||
|
||||
### mTLS
|
||||
|
||||
Newt supports mutual TLS (mTLS) authentication if the server is configured to request a client certificate. You can use either a PKCS12 (.p12/.pfx) file or split PEM files for the client cert, private key, and CA.
|
||||
|
||||
#### Option 1: PKCS12 (Legacy)
|
||||
|
||||
> This is the original method and still supported.
|
||||
|
||||
* File must contain:
|
||||
|
||||
* Client private key
|
||||
* Public certificate
|
||||
* CA certificate
|
||||
* Encrypted `.p12` files are **not supported**
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
newt \
|
||||
--id 31frd0uzbjvp721 \
|
||||
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
||||
--endpoint https://example.com \
|
||||
--tls-client-cert ./client.p12
|
||||
```
|
||||
|
||||
#### Option 2: Split PEM Files (Preferred)
|
||||
|
||||
You can now provide separate files for:
|
||||
|
||||
* `--tls-client-cert`: client certificate (`.crt` or `.pem`)
|
||||
* `--tls-client-key`: client private key (`.key` or `.pem`)
|
||||
* `--tls-ca-cert`: CA cert to verify the server
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
newt \
|
||||
--id 31frd0uzbjvp721 \
|
||||
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
||||
--endpoint https://example.com \
|
||||
--tls-client-cert ./client.crt \
|
||||
--tls-client-key ./client.key \
|
||||
--tls-ca-cert ./ca.crt
|
||||
```
|
||||
|
||||
|
||||
```yaml
|
||||
services:
|
||||
newt:
|
||||
image: fosrl/newt
|
||||
container_name: newt
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- PANGOLIN_ENDPOINT=https://example.com
|
||||
- NEWT_ID=2ix2t8xk22ubpfy
|
||||
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||
- TLS_CLIENT_CERT=./client.p12
|
||||
```
|
||||
|
||||
## Build
|
||||
|
||||
### Container
|
||||
|
||||
Ensure Docker is installed.
|
||||
|
||||
```bash
|
||||
make
|
||||
```
|
||||
|
||||
### Binary
|
||||
|
||||
Make sure to have Go 1.23.1 installed.
|
||||
Make sure to have Go 1.25 installed.
|
||||
|
||||
```bash
|
||||
make local
|
||||
make
|
||||
```
|
||||
|
||||
### Nix Flake
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
||||
|
||||
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
||||
2. Send a detailed report to [security@fossorial.io](mailto:security@fossorial.io) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
2. Send a detailed report to [security@pangolin.net](mailto:security@pangolin.net) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
|
||||
- Description and location of the vulnerability.
|
||||
- Potential impact of the vulnerability.
|
||||
|
||||
151
authdaemon.go
Normal file
151
authdaemon.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/fosrl/newt/authdaemon"
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPrincipalsPath = "/var/run/auth-daemon/principals"
|
||||
defaultCACertPath = "/etc/ssh/ca.pem"
|
||||
)
|
||||
|
||||
var (
|
||||
errPresharedKeyRequired = errors.New("auth-daemon-key is required when --auth-daemon is enabled")
|
||||
errRootRequired = errors.New("auth-daemon must be run as root (use sudo)")
|
||||
authDaemonServer *authdaemon.Server // Global auth daemon server instance
|
||||
)
|
||||
|
||||
// startAuthDaemon initializes and starts the auth daemon in the background.
|
||||
// It validates requirements (Linux, root, preshared key) and starts the server
|
||||
// in a goroutine so it runs alongside normal newt operation.
|
||||
func startAuthDaemon(ctx context.Context) error {
|
||||
// Validation
|
||||
if runtime.GOOS != "linux" {
|
||||
return fmt.Errorf("auth-daemon is only supported on Linux, not %s", runtime.GOOS)
|
||||
}
|
||||
if os.Geteuid() != 0 {
|
||||
return errRootRequired
|
||||
}
|
||||
|
||||
// Use defaults if not set
|
||||
principalsFile := authDaemonPrincipalsFile
|
||||
if principalsFile == "" {
|
||||
principalsFile = defaultPrincipalsPath
|
||||
}
|
||||
caCertPath := authDaemonCACertPath
|
||||
if caCertPath == "" {
|
||||
caCertPath = defaultCACertPath
|
||||
}
|
||||
|
||||
// Create auth daemon server
|
||||
cfg := authdaemon.Config{
|
||||
DisableHTTPS: true, // We run without HTTP server in newt
|
||||
PresharedKey: "this-key-is-not-used", // Not used in embedded mode, but set to non-empty to satisfy validation
|
||||
PrincipalsFilePath: principalsFile,
|
||||
CACertPath: caCertPath,
|
||||
Force: true,
|
||||
}
|
||||
|
||||
srv, err := authdaemon.NewServer(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create auth daemon server: %w", err)
|
||||
}
|
||||
|
||||
authDaemonServer = srv
|
||||
|
||||
// Start the auth daemon in a goroutine so it runs alongside newt
|
||||
go func() {
|
||||
logger.Info("Auth daemon starting (native mode, no HTTP server)")
|
||||
if err := srv.Run(ctx); err != nil {
|
||||
logger.Error("Auth daemon error: %v", err)
|
||||
}
|
||||
logger.Info("Auth daemon stopped")
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
// runPrincipalsCmd executes the principals subcommand logic
|
||||
func runPrincipalsCmd(args []string) {
|
||||
opts := struct {
|
||||
PrincipalsFile string
|
||||
Username string
|
||||
}{
|
||||
PrincipalsFile: defaultPrincipalsPath,
|
||||
}
|
||||
|
||||
// Parse flags manually
|
||||
for i := 0; i < len(args); i++ {
|
||||
switch args[i] {
|
||||
case "--principals-file":
|
||||
if i+1 >= len(args) {
|
||||
fmt.Fprintf(os.Stderr, "Error: --principals-file requires a value\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
opts.PrincipalsFile = args[i+1]
|
||||
i++
|
||||
case "--username":
|
||||
if i+1 >= len(args) {
|
||||
fmt.Fprintf(os.Stderr, "Error: --username requires a value\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
opts.Username = args[i+1]
|
||||
i++
|
||||
case "--help", "-h":
|
||||
printPrincipalsHelp()
|
||||
os.Exit(0)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Error: unknown flag: %s\n", args[i])
|
||||
printPrincipalsHelp()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Validation
|
||||
if opts.Username == "" {
|
||||
fmt.Fprintf(os.Stderr, "Error: username is required\n")
|
||||
printPrincipalsHelp()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Get principals
|
||||
list, err := authdaemon.GetPrincipals(opts.PrincipalsFile, opts.Username)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if len(list) == 0 {
|
||||
fmt.Println("")
|
||||
return
|
||||
}
|
||||
for _, principal := range list {
|
||||
fmt.Println(principal)
|
||||
}
|
||||
}
|
||||
|
||||
func printPrincipalsHelp() {
|
||||
fmt.Fprintf(os.Stderr, `Usage: newt principals [flags]
|
||||
|
||||
Output principals for a username (for AuthorizedPrincipalsCommand in sshd_config).
|
||||
Read the principals file and print principals that match the given username, one per line.
|
||||
Configure in sshd_config with AuthorizedPrincipalsCommand and %%u for the username.
|
||||
|
||||
Flags:
|
||||
--principals-file string Path to the principals file (default "%s")
|
||||
--username string Username to look up (required)
|
||||
--help, -h Show this help message
|
||||
|
||||
Example:
|
||||
newt principals --username alice
|
||||
|
||||
`, defaultPrincipalsPath)
|
||||
}
|
||||
27
authdaemon/connection.go
Normal file
27
authdaemon/connection.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package authdaemon
|
||||
|
||||
import (
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// ProcessConnection runs the same logic as POST /connection: CA cert, user create/reconcile, principals.
|
||||
// Use this when DisableHTTPS is true (e.g. embedded in Newt) instead of calling the API.
|
||||
func (s *Server) ProcessConnection(req ConnectionRequest) {
|
||||
logger.Info("connection: niceId=%q username=%q metadata.sudo=%v metadata.homedir=%v",
|
||||
req.NiceId, req.Username, req.Metadata.Sudo, req.Metadata.Homedir)
|
||||
|
||||
cfg := &s.cfg
|
||||
if cfg.CACertPath != "" {
|
||||
if err := writeCACertIfNotExists(cfg.CACertPath, req.CaCert, cfg.Force); err != nil {
|
||||
logger.Warn("auth-daemon: write CA cert: %v", err)
|
||||
}
|
||||
}
|
||||
if err := ensureUser(req.Username, req.Metadata); err != nil {
|
||||
logger.Warn("auth-daemon: ensure user: %v", err)
|
||||
}
|
||||
if cfg.PrincipalsFilePath != "" {
|
||||
if err := writePrincipals(cfg.PrincipalsFilePath, req.Username, req.NiceId); err != nil {
|
||||
logger.Warn("auth-daemon: write principals: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
224
authdaemon/host_linux.go
Normal file
224
authdaemon/host_linux.go
Normal file
@@ -0,0 +1,224 @@
|
||||
//go:build linux
|
||||
|
||||
package authdaemon
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// writeCACertIfNotExists writes contents to path. If the file already exists: when force is false, skip; when force is true, overwrite only if content differs.
|
||||
func writeCACertIfNotExists(path, contents string, force bool) error {
|
||||
contents = strings.TrimSpace(contents)
|
||||
if contents != "" && !strings.HasSuffix(contents, "\n") {
|
||||
contents += "\n"
|
||||
}
|
||||
existing, err := os.ReadFile(path)
|
||||
if err == nil {
|
||||
existingStr := strings.TrimSpace(string(existing))
|
||||
if existingStr != "" && !strings.HasSuffix(existingStr, "\n") {
|
||||
existingStr += "\n"
|
||||
}
|
||||
if existingStr == contents {
|
||||
logger.Debug("auth-daemon: CA cert unchanged at %s, skipping write", path)
|
||||
return nil
|
||||
}
|
||||
if !force {
|
||||
logger.Debug("auth-daemon: CA cert already exists at %s, skipping write (Force disabled)", path)
|
||||
return nil
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("read %s: %w", path, err)
|
||||
}
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("mkdir %s: %w", dir, err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(contents), 0644); err != nil {
|
||||
return fmt.Errorf("write CA cert: %w", err)
|
||||
}
|
||||
logger.Info("auth-daemon: wrote CA cert to %s", path)
|
||||
return nil
|
||||
}
|
||||
|
||||
// writePrincipals updates the principals file at path: JSON object keyed by username, value is array of principals. Adds username and niceId to that user's list (deduped).
|
||||
func writePrincipals(path, username, niceId string) error {
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
username = strings.TrimSpace(username)
|
||||
niceId = strings.TrimSpace(niceId)
|
||||
if username == "" {
|
||||
return nil
|
||||
}
|
||||
data := make(map[string][]string)
|
||||
if raw, err := os.ReadFile(path); err == nil {
|
||||
_ = json.Unmarshal(raw, &data)
|
||||
}
|
||||
list := data[username]
|
||||
seen := make(map[string]struct{}, len(list)+2)
|
||||
for _, p := range list {
|
||||
seen[p] = struct{}{}
|
||||
}
|
||||
for _, p := range []string{username, niceId} {
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p]; !ok {
|
||||
seen[p] = struct{}{}
|
||||
list = append(list, p)
|
||||
}
|
||||
}
|
||||
data[username] = list
|
||||
body, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal principals: %w", err)
|
||||
}
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("mkdir %s: %w", dir, err)
|
||||
}
|
||||
if err := os.WriteFile(path, body, 0644); err != nil {
|
||||
return fmt.Errorf("write principals: %w", err)
|
||||
}
|
||||
logger.Debug("auth-daemon: wrote principals to %s", path)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sudoGroup returns the name of the sudo group (wheel or sudo) that exists on the system. Prefers wheel.
|
||||
func sudoGroup() string {
|
||||
f, err := os.Open("/etc/group")
|
||||
if err != nil {
|
||||
return "sudo"
|
||||
}
|
||||
defer f.Close()
|
||||
sc := bufio.NewScanner(f)
|
||||
hasWheel := false
|
||||
hasSudo := false
|
||||
for sc.Scan() {
|
||||
line := sc.Text()
|
||||
if strings.HasPrefix(line, "wheel:") {
|
||||
hasWheel = true
|
||||
}
|
||||
if strings.HasPrefix(line, "sudo:") {
|
||||
hasSudo = true
|
||||
}
|
||||
}
|
||||
if hasWheel {
|
||||
return "wheel"
|
||||
}
|
||||
if hasSudo {
|
||||
return "sudo"
|
||||
}
|
||||
return "sudo"
|
||||
}
|
||||
|
||||
// ensureUser creates the system user if missing, or reconciles sudo and homedir to match meta.
|
||||
func ensureUser(username string, meta ConnectionMetadata) error {
|
||||
if username == "" {
|
||||
return nil
|
||||
}
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
if _, ok := err.(user.UnknownUserError); !ok {
|
||||
return fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
return createUser(username, meta)
|
||||
}
|
||||
return reconcileUser(u, meta)
|
||||
}
|
||||
|
||||
func createUser(username string, meta ConnectionMetadata) error {
|
||||
args := []string{"-s", "/bin/bash"}
|
||||
if meta.Homedir {
|
||||
args = append(args, "-m")
|
||||
} else {
|
||||
args = append(args, "-M")
|
||||
}
|
||||
args = append(args, username)
|
||||
cmd := exec.Command("useradd", args...)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("useradd %s: %w (output: %s)", username, err, string(out))
|
||||
}
|
||||
logger.Info("auth-daemon: created user %s (homedir=%v)", username, meta.Homedir)
|
||||
if meta.Sudo {
|
||||
group := sudoGroup()
|
||||
cmd := exec.Command("usermod", "-aG", group, username)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
logger.Warn("auth-daemon: usermod -aG %s %s: %v (output: %s)", group, username, err, string(out))
|
||||
} else {
|
||||
logger.Info("auth-daemon: added %s to %s", username, group)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustAtoi(s string) int {
|
||||
n, _ := strconv.Atoi(s)
|
||||
return n
|
||||
}
|
||||
|
||||
func reconcileUser(u *user.User, meta ConnectionMetadata) error {
|
||||
group := sudoGroup()
|
||||
inGroup, err := userInGroup(u.Username, group)
|
||||
if err != nil {
|
||||
logger.Warn("auth-daemon: check group %s: %v", group, err)
|
||||
inGroup = false
|
||||
}
|
||||
if meta.Sudo && !inGroup {
|
||||
cmd := exec.Command("usermod", "-aG", group, u.Username)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
logger.Warn("auth-daemon: usermod -aG %s %s: %v (output: %s)", group, u.Username, err, string(out))
|
||||
} else {
|
||||
logger.Info("auth-daemon: added %s to %s", u.Username, group)
|
||||
}
|
||||
} else if !meta.Sudo && inGroup {
|
||||
cmd := exec.Command("gpasswd", "-d", u.Username, group)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
logger.Warn("auth-daemon: gpasswd -d %s %s: %v (output: %s)", u.Username, group, err, string(out))
|
||||
} else {
|
||||
logger.Info("auth-daemon: removed %s from %s", u.Username, group)
|
||||
}
|
||||
}
|
||||
if meta.Homedir && u.HomeDir != "" {
|
||||
if st, err := os.Stat(u.HomeDir); err != nil || !st.IsDir() {
|
||||
if err := os.MkdirAll(u.HomeDir, 0755); err != nil {
|
||||
logger.Warn("auth-daemon: mkdir %s: %v", u.HomeDir, err)
|
||||
} else {
|
||||
uid, gid := mustAtoi(u.Uid), mustAtoi(u.Gid)
|
||||
_ = os.Chown(u.HomeDir, uid, gid)
|
||||
logger.Info("auth-daemon: created home %s for %s", u.HomeDir, u.Username)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func userInGroup(username, groupName string) (bool, error) {
|
||||
// getent group wheel returns "wheel:x:10:user1,user2"
|
||||
cmd := exec.Command("getent", "group", groupName)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
parts := strings.SplitN(strings.TrimSpace(string(out)), ":", 4)
|
||||
if len(parts) < 4 {
|
||||
return false, nil
|
||||
}
|
||||
members := strings.Split(parts[3], ",")
|
||||
for _, m := range members {
|
||||
if strings.TrimSpace(m) == username {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
22
authdaemon/host_stub.go
Normal file
22
authdaemon/host_stub.go
Normal file
@@ -0,0 +1,22 @@
|
||||
//go:build !linux
|
||||
|
||||
package authdaemon
|
||||
|
||||
import "fmt"
|
||||
|
||||
var errLinuxOnly = fmt.Errorf("auth-daemon PAM agent is only supported on Linux")
|
||||
|
||||
// writeCACertIfNotExists returns an error on non-Linux.
|
||||
func writeCACertIfNotExists(path, contents string, force bool) error {
|
||||
return errLinuxOnly
|
||||
}
|
||||
|
||||
// ensureUser returns an error on non-Linux.
|
||||
func ensureUser(username string, meta ConnectionMetadata) error {
|
||||
return errLinuxOnly
|
||||
}
|
||||
|
||||
// writePrincipals returns an error on non-Linux.
|
||||
func writePrincipals(path, username, niceId string) error {
|
||||
return errLinuxOnly
|
||||
}
|
||||
28
authdaemon/principals.go
Normal file
28
authdaemon/principals.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package authdaemon
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// GetPrincipals reads the principals data file at path, looks up the given user, and returns that user's principals as a string slice.
|
||||
// The file format is JSON: object with username keys and array-of-principals values, e.g. {"alice":["alice","usr-123"],"bob":["bob","usr-456"]}.
|
||||
// If the user is not found or the file is missing, returns nil and nil.
|
||||
func GetPrincipals(path, user string) ([]string, error) {
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("principals file path is required")
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("read principals file: %w", err)
|
||||
}
|
||||
var m map[string][]string
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil, fmt.Errorf("parse principals file: %w", err)
|
||||
}
|
||||
return m[user], nil
|
||||
}
|
||||
56
authdaemon/routes.go
Normal file
56
authdaemon/routes.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package authdaemon
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// registerRoutes registers all API routes. Add new endpoints here.
|
||||
func (s *Server) registerRoutes() {
|
||||
s.mux.HandleFunc("/health", s.handleHealth)
|
||||
s.mux.HandleFunc("/connection", s.handleConnection)
|
||||
}
|
||||
|
||||
// ConnectionMetadata is the metadata object in POST /connection.
|
||||
type ConnectionMetadata struct {
|
||||
Sudo bool `json:"sudo"`
|
||||
Homedir bool `json:"homedir"`
|
||||
}
|
||||
|
||||
// ConnectionRequest is the JSON body for POST /connection.
|
||||
type ConnectionRequest struct {
|
||||
CaCert string `json:"caCert"`
|
||||
NiceId string `json:"niceId"`
|
||||
Username string `json:"username"`
|
||||
Metadata ConnectionMetadata `json:"metadata"`
|
||||
}
|
||||
|
||||
// healthResponse is the JSON body for GET /health.
|
||||
type healthResponse struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// handleHealth responds with 200 and {"status":"ok"}.
|
||||
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// handleConnection accepts POST with connection payload and delegates to ProcessConnection.
|
||||
func (s *Server) handleConnection(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var req ConnectionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
s.ProcessConnection(req)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
179
authdaemon/server.go
Normal file
179
authdaemon/server.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package authdaemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
// DisableHTTPS: when true, Run() does not start the HTTPS server (for embedded use inside Newt). Call ProcessConnection directly for connection events.
|
||||
DisableHTTPS bool
|
||||
Port int // Required when DisableHTTPS is false. Listen port for the HTTPS server. No default.
|
||||
PresharedKey string // Required when DisableHTTPS is false. HTTP auth (Authorization: Bearer <key> or X-Preshared-Key: <key>). No default.
|
||||
CACertPath string // Required. Where to write the CA cert (e.g. /etc/ssh/ca.pem). No default.
|
||||
Force bool // If true, overwrite existing CA cert (and other items) when content differs. Default false.
|
||||
PrincipalsFilePath string // Required. Path to the principals data file (JSON: username -> array of principals). No default.
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
cfg Config
|
||||
addr string
|
||||
presharedKey string
|
||||
mux *http.ServeMux
|
||||
tlsCert tls.Certificate
|
||||
}
|
||||
|
||||
// generateTLSCert creates a self-signed certificate and key in memory (no disk).
|
||||
func generateTLSCert() (tls.Certificate, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("serial: %w", err)
|
||||
}
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{
|
||||
CommonName: "localhost",
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost", "127.0.0.1"},
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("create certificate: %w", err)
|
||||
}
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("marshal key: %w", err)
|
||||
}
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("x509 key pair: %w", err)
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// authMiddleware wraps next and requires a valid preshared key on every request.
|
||||
// Accepts Authorization: Bearer <key> or X-Preshared-Key: <key>.
|
||||
func (s *Server) authMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := ""
|
||||
if v := r.Header.Get("Authorization"); strings.HasPrefix(v, "Bearer ") {
|
||||
key = strings.TrimSpace(strings.TrimPrefix(v, "Bearer "))
|
||||
}
|
||||
if key == "" {
|
||||
key = strings.TrimSpace(r.Header.Get("X-Preshared-Key"))
|
||||
}
|
||||
if key == "" || subtle.ConstantTimeCompare([]byte(key), []byte(s.presharedKey)) != 1 {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// NewServer builds a new auth-daemon server from cfg. Port, PresharedKey, CACertPath, and PrincipalsFilePath are required (no defaults).
|
||||
func NewServer(cfg Config) (*Server, error) {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil, fmt.Errorf("auth-daemon is only supported on Linux, not %s", runtime.GOOS)
|
||||
}
|
||||
if !cfg.DisableHTTPS {
|
||||
if cfg.Port <= 0 {
|
||||
return nil, fmt.Errorf("port is required and must be positive")
|
||||
}
|
||||
if cfg.PresharedKey == "" {
|
||||
return nil, fmt.Errorf("preshared key is required")
|
||||
}
|
||||
}
|
||||
if cfg.CACertPath == "" {
|
||||
return nil, fmt.Errorf("CACertPath is required")
|
||||
}
|
||||
if cfg.PrincipalsFilePath == "" {
|
||||
return nil, fmt.Errorf("PrincipalsFilePath is required")
|
||||
}
|
||||
s := &Server{cfg: cfg}
|
||||
if !cfg.DisableHTTPS {
|
||||
cert, err := generateTLSCert()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.addr = fmt.Sprintf(":%d", cfg.Port)
|
||||
s.presharedKey = cfg.PresharedKey
|
||||
s.mux = http.NewServeMux()
|
||||
s.tlsCert = cert
|
||||
s.registerRoutes()
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Run starts the HTTPS server (unless DisableHTTPS) and blocks until ctx is cancelled or the server errors.
|
||||
// When DisableHTTPS is true, Run() blocks on ctx only and does not listen; use ProcessConnection for connection events.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
if s.cfg.DisableHTTPS {
|
||||
logger.Info("auth-daemon running (HTTPS disabled)")
|
||||
<-ctx.Done()
|
||||
s.cleanupPrincipalsFile()
|
||||
return nil
|
||||
}
|
||||
tcfg := &tls.Config{
|
||||
Certificates: []tls.Certificate{s.tlsCert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
handler := s.authMiddleware(s.mux)
|
||||
srv := &http.Server{
|
||||
Addr: s.addr,
|
||||
Handler: handler,
|
||||
TLSConfig: tcfg,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Warn("auth-daemon shutdown: %v", err)
|
||||
}
|
||||
}()
|
||||
logger.Info("auth-daemon listening on https://127.0.0.1%s", s.addr)
|
||||
if err := srv.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
s.cleanupPrincipalsFile()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) cleanupPrincipalsFile() {
|
||||
if s.cfg.PrincipalsFilePath != "" {
|
||||
if err := os.Remove(s.cfg.PrincipalsFilePath); err != nil && !os.IsNotExist(err) {
|
||||
logger.Warn("auth-daemon: remove principals file: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
840
bind/shared_bind.go
Normal file
840
bind/shared_bind.go
Normal file
@@ -0,0 +1,840 @@
|
||||
//go:build !js
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
// Magic packet constants for connection testing
|
||||
// These packets are intercepted by SharedBind and responded to directly,
|
||||
// without being passed to the WireGuard device.
|
||||
var (
|
||||
// MagicTestRequest is the prefix for a test request packet
|
||||
// Format: PANGOLIN_TEST_REQ + 8 bytes of random data (for echo)
|
||||
MagicTestRequest = []byte("PANGOLIN_TEST_REQ")
|
||||
|
||||
// MagicTestResponse is the prefix for a test response packet
|
||||
// Format: PANGOLIN_TEST_RSP + 8 bytes echoed from request
|
||||
MagicTestResponse = []byte("PANGOLIN_TEST_RSP")
|
||||
)
|
||||
|
||||
const (
|
||||
// MagicPacketDataLen is the length of random data included in test packets
|
||||
MagicPacketDataLen = 8
|
||||
|
||||
// MagicTestRequestLen is the total length of a test request packet
|
||||
MagicTestRequestLen = 17 + MagicPacketDataLen // len("PANGOLIN_TEST_REQ") + 8
|
||||
|
||||
// MagicTestResponseLen is the total length of a test response packet
|
||||
MagicTestResponseLen = 17 + MagicPacketDataLen // len("PANGOLIN_TEST_RSP") + 8
|
||||
)
|
||||
|
||||
// PacketSource identifies where a packet came from
|
||||
type PacketSource uint8
|
||||
|
||||
const (
|
||||
SourceSocket PacketSource = iota // From physical UDP socket (hole-punched clients)
|
||||
SourceNetstack // From netstack (relay through main tunnel)
|
||||
)
|
||||
|
||||
// SourceAwareEndpoint wraps an endpoint with source information
|
||||
type SourceAwareEndpoint struct {
|
||||
wgConn.Endpoint
|
||||
source PacketSource
|
||||
}
|
||||
|
||||
// GetSource returns the source of this endpoint
|
||||
func (e *SourceAwareEndpoint) GetSource() PacketSource {
|
||||
return e.source
|
||||
}
|
||||
|
||||
// injectedPacket represents a packet injected into the SharedBind from an internal source
|
||||
type injectedPacket struct {
|
||||
data []byte
|
||||
endpoint wgConn.Endpoint
|
||||
}
|
||||
|
||||
// Endpoint represents a network endpoint for the SharedBind
|
||||
type Endpoint struct {
|
||||
AddrPort netip.AddrPort
|
||||
}
|
||||
|
||||
// ClearSrc implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) ClearSrc() {}
|
||||
|
||||
// DstIP implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) DstIP() netip.Addr {
|
||||
return e.AddrPort.Addr()
|
||||
}
|
||||
|
||||
// SrcIP implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
// DstToBytes implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) DstToBytes() []byte {
|
||||
b, _ := e.AddrPort.MarshalBinary()
|
||||
return b
|
||||
}
|
||||
|
||||
// DstToString implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) DstToString() string {
|
||||
return e.AddrPort.String()
|
||||
}
|
||||
|
||||
// SrcToString implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) SrcToString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
|
||||
// and hole punch senders. It wraps a single UDP connection and implements
|
||||
// reference counting to prevent premature closure.
|
||||
// It also supports receiving packets from a netstack and routing responses
|
||||
// back through the appropriate source.
|
||||
type SharedBind struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// The underlying UDP connection (for hole-punched clients)
|
||||
udpConn *net.UDPConn
|
||||
|
||||
// IPv4 and IPv6 packet connections for advanced features
|
||||
ipv4PC *ipv4.PacketConn
|
||||
ipv6PC *ipv6.PacketConn
|
||||
|
||||
// Reference counting to prevent closing while in use
|
||||
refCount atomic.Int32
|
||||
closed atomic.Bool
|
||||
|
||||
// Channels for receiving data
|
||||
recvFuncs []wgConn.ReceiveFunc
|
||||
|
||||
// Port binding information
|
||||
port uint16
|
||||
|
||||
// Channel for packets from netstack (from direct relay) - larger buffer for throughput
|
||||
netstackPackets chan injectedPacket
|
||||
|
||||
// Netstack connection for sending responses back through the tunnel
|
||||
// Using atomic.Pointer for lock-free access in hot path
|
||||
netstackConn atomic.Pointer[net.PacketConn]
|
||||
|
||||
// Track which endpoints came from netstack (key: netip.AddrPort, value: struct{})
|
||||
// Using netip.AddrPort directly as key is more efficient than string
|
||||
netstackEndpoints sync.Map
|
||||
|
||||
// Pre-allocated message buffers for batch operations (Linux only)
|
||||
ipv4Msgs []ipv4.Message
|
||||
|
||||
// Shutdown signal for receive goroutines
|
||||
closeChan chan struct{}
|
||||
|
||||
// Callback for magic test responses (used for holepunch testing)
|
||||
magicResponseCallback atomic.Pointer[func(addr netip.AddrPort, echoData []byte)]
|
||||
|
||||
// Rebinding state - used to keep receive goroutines alive during socket transition
|
||||
rebinding bool // true when socket is being replaced
|
||||
rebindingCond *sync.Cond // signaled when rebind completes
|
||||
}
|
||||
|
||||
// MagicResponseCallback is the function signature for magic packet response callbacks
|
||||
type MagicResponseCallback func(addr netip.AddrPort, echoData []byte)
|
||||
|
||||
// New creates a new SharedBind from an existing UDP connection.
|
||||
// The SharedBind takes ownership of the connection and will close it
|
||||
// when all references are released.
|
||||
func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
||||
if udpConn == nil {
|
||||
return nil, fmt.Errorf("udpConn cannot be nil")
|
||||
}
|
||||
|
||||
bind := &SharedBind{
|
||||
udpConn: udpConn,
|
||||
netstackPackets: make(chan injectedPacket, 1024), // Larger buffer for better throughput
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Initialize the rebinding condition variable
|
||||
bind.rebindingCond = sync.NewCond(&bind.mu)
|
||||
|
||||
// Initialize reference count to 1 (the creator holds the first reference)
|
||||
bind.refCount.Store(1)
|
||||
|
||||
// Get the local port
|
||||
if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok {
|
||||
bind.port = uint16(addr.Port)
|
||||
}
|
||||
|
||||
return bind, nil
|
||||
}
|
||||
|
||||
// SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel.
|
||||
// This connection is used for relay traffic that should go back through the main tunnel.
|
||||
func (b *SharedBind) SetNetstackConn(conn net.PacketConn) {
|
||||
b.netstackConn.Store(&conn)
|
||||
}
|
||||
|
||||
// GetNetstackConn returns the netstack connection if set
|
||||
func (b *SharedBind) GetNetstackConn() net.PacketConn {
|
||||
ptr := b.netstackConn.Load()
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
return *ptr
|
||||
}
|
||||
|
||||
// InjectPacket allows injecting a packet directly into the SharedBind's receive path.
|
||||
// This is used for direct relay from netstack without going through the host network.
|
||||
// The fromAddr should be the address the packet appears to come from.
|
||||
func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error {
|
||||
if b.closed.Load() {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
// Unmap IPv4-in-IPv6 addresses to ensure consistency with parsed endpoints
|
||||
if fromAddr.Addr().Is4In6() {
|
||||
fromAddr = netip.AddrPortFrom(fromAddr.Addr().Unmap(), fromAddr.Port())
|
||||
}
|
||||
|
||||
// Track this endpoint as coming from netstack so responses go back the same way
|
||||
// Use AddrPort directly as key (more efficient than string)
|
||||
b.netstackEndpoints.Store(fromAddr, struct{}{})
|
||||
|
||||
// Make a copy of the data to avoid issues with buffer reuse
|
||||
dataCopy := make([]byte, len(data))
|
||||
copy(dataCopy, data)
|
||||
|
||||
select {
|
||||
case b.netstackPackets <- injectedPacket{
|
||||
data: dataCopy,
|
||||
endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr},
|
||||
}:
|
||||
return nil
|
||||
case <-b.closeChan:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
// Channel full, drop the packet
|
||||
return fmt.Errorf("netstack packet buffer full")
|
||||
}
|
||||
}
|
||||
|
||||
// AddRef increments the reference count. Call this when sharing
|
||||
// the bind with another component.
|
||||
func (b *SharedBind) AddRef() {
|
||||
newCount := b.refCount.Add(1)
|
||||
// Optional: Add logging for debugging
|
||||
_ = newCount // Placeholder for potential logging
|
||||
}
|
||||
|
||||
// Release decrements the reference count. When it reaches zero,
|
||||
// the underlying UDP connection is closed.
|
||||
func (b *SharedBind) Release() error {
|
||||
newCount := b.refCount.Add(-1)
|
||||
// Optional: Add logging for debugging
|
||||
_ = newCount // Placeholder for potential logging
|
||||
|
||||
if newCount < 0 {
|
||||
// This should never happen with proper usage
|
||||
b.refCount.Store(0)
|
||||
return fmt.Errorf("SharedBind reference count went negative")
|
||||
}
|
||||
|
||||
if newCount == 0 {
|
||||
return b.closeConnection()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// closeConnection actually closes the UDP connection
|
||||
func (b *SharedBind) closeConnection() error {
|
||||
if !b.closed.CompareAndSwap(false, true) {
|
||||
// Already closed
|
||||
return nil
|
||||
}
|
||||
|
||||
// Signal all goroutines to stop
|
||||
close(b.closeChan)
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
var err error
|
||||
if b.udpConn != nil {
|
||||
err = b.udpConn.Close()
|
||||
b.udpConn = nil
|
||||
}
|
||||
|
||||
b.ipv4PC = nil
|
||||
b.ipv6PC = nil
|
||||
|
||||
// Clear netstack connection (but don't close it - it's managed externally)
|
||||
b.netstackConn.Store(nil)
|
||||
|
||||
// Clear tracked netstack endpoints
|
||||
b.netstackEndpoints = sync.Map{}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ClearNetstackConn clears the netstack connection and tracked endpoints.
|
||||
// Call this when stopping the relay.
|
||||
func (b *SharedBind) ClearNetstackConn() {
|
||||
b.netstackConn.Store(nil)
|
||||
|
||||
// Clear tracked netstack endpoints
|
||||
b.netstackEndpoints = sync.Map{}
|
||||
}
|
||||
|
||||
// GetUDPConn returns the underlying UDP connection.
|
||||
// The caller must not close this connection directly.
|
||||
func (b *SharedBind) GetUDPConn() *net.UDPConn {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return b.udpConn
|
||||
}
|
||||
|
||||
// GetRefCount returns the current reference count (for debugging)
|
||||
func (b *SharedBind) GetRefCount() int32 {
|
||||
return b.refCount.Load()
|
||||
}
|
||||
|
||||
// IsClosed returns whether the bind is closed
|
||||
func (b *SharedBind) IsClosed() bool {
|
||||
return b.closed.Load()
|
||||
}
|
||||
|
||||
// GetPort returns the current UDP port the bind is using.
|
||||
// This is useful when rebinding to try to reuse the same port.
|
||||
func (b *SharedBind) GetPort() uint16 {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return b.port
|
||||
}
|
||||
|
||||
// CloseSocket closes the underlying UDP connection to release the port,
|
||||
// but keeps the SharedBind in a state where it can accept a new connection via Rebind.
|
||||
// This allows the caller to close the old socket first, then bind a new socket
|
||||
// to the same port before calling Rebind.
|
||||
//
|
||||
// Returns the port that was being used, so the caller can attempt to rebind to it.
|
||||
// Sets the rebinding flag so receive goroutines will wait for the new socket
|
||||
// instead of exiting.
|
||||
func (b *SharedBind) CloseSocket() (uint16, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed.Load() {
|
||||
return 0, fmt.Errorf("bind is closed")
|
||||
}
|
||||
|
||||
port := b.port
|
||||
|
||||
// Set rebinding flag BEFORE closing the socket so receive goroutines
|
||||
// know to wait instead of exit
|
||||
b.rebinding = true
|
||||
|
||||
// Close the old connection to release the port
|
||||
if b.udpConn != nil {
|
||||
logger.Debug("Closing UDP connection to release port %d (rebinding)", port)
|
||||
b.udpConn.Close()
|
||||
b.udpConn = nil
|
||||
}
|
||||
|
||||
return port, nil
|
||||
}
|
||||
|
||||
// Rebind replaces the underlying UDP connection with a new one.
|
||||
// This is necessary when network connectivity changes (e.g., WiFi to cellular
|
||||
// transition on macOS/iOS) and the old socket becomes stale.
|
||||
//
|
||||
// The caller is responsible for creating the new UDP connection and passing it here.
|
||||
// After rebind, the caller should trigger a hole punch to re-establish NAT mappings.
|
||||
//
|
||||
// Note: Call CloseSocket() first if you need to rebind to the same port, as the
|
||||
// old socket must be closed before a new socket can bind to the same port.
|
||||
func (b *SharedBind) Rebind(newConn *net.UDPConn) error {
|
||||
if newConn == nil {
|
||||
return fmt.Errorf("newConn cannot be nil")
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed.Load() {
|
||||
return fmt.Errorf("bind is closed")
|
||||
}
|
||||
|
||||
// Close the old connection if it's still open
|
||||
// (it may have already been closed via CloseSocket)
|
||||
if b.udpConn != nil {
|
||||
logger.Debug("Closing old UDP connection during rebind")
|
||||
b.udpConn.Close()
|
||||
}
|
||||
|
||||
// Set up the new connection
|
||||
b.udpConn = newConn
|
||||
|
||||
// Update packet connections for the new socket
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
b.ipv4PC = ipv4.NewPacketConn(newConn)
|
||||
b.ipv6PC = ipv6.NewPacketConn(newConn)
|
||||
|
||||
// Re-initialize message buffers for batch operations
|
||||
batchSize := wgConn.IdealBatchSize
|
||||
b.ipv4Msgs = make([]ipv4.Message, batchSize)
|
||||
for i := range b.ipv4Msgs {
|
||||
b.ipv4Msgs[i].OOB = make([]byte, 0)
|
||||
}
|
||||
} else {
|
||||
// For non-Linux platforms, still set up ipv4PC for consistency
|
||||
b.ipv4PC = ipv4.NewPacketConn(newConn)
|
||||
b.ipv6PC = ipv6.NewPacketConn(newConn)
|
||||
}
|
||||
|
||||
// Update the port
|
||||
if addr, ok := newConn.LocalAddr().(*net.UDPAddr); ok {
|
||||
b.port = uint16(addr.Port)
|
||||
logger.Info("Rebound UDP socket to port %d", b.port)
|
||||
}
|
||||
|
||||
// Clear the rebinding flag and wake up any waiting receive goroutines
|
||||
b.rebinding = false
|
||||
b.rebindingCond.Broadcast()
|
||||
|
||||
logger.Debug("Rebind complete, signaled waiting receive goroutines")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMagicResponseCallback sets a callback function that will be called when
|
||||
// a magic test response packet is received. This is used for holepunch testing.
|
||||
// Pass nil to clear the callback.
|
||||
func (b *SharedBind) SetMagicResponseCallback(callback MagicResponseCallback) {
|
||||
if callback == nil {
|
||||
b.magicResponseCallback.Store(nil)
|
||||
} else {
|
||||
// Convert to the function type the atomic.Pointer expects
|
||||
fn := func(addr netip.AddrPort, echoData []byte) {
|
||||
callback(addr, echoData)
|
||||
}
|
||||
b.magicResponseCallback.Store(&fn)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteToUDP writes data to a specific UDP address.
|
||||
// This is thread-safe and can be used by hole punch senders.
|
||||
func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) {
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
return conn.WriteToUDP(data, addr)
|
||||
}
|
||||
|
||||
// Close implements the WireGuard Bind interface.
|
||||
// It decrements the reference count and closes the connection if no references remain.
|
||||
func (b *SharedBind) Close() error {
|
||||
return b.Release()
|
||||
}
|
||||
|
||||
// Open implements the WireGuard Bind interface.
|
||||
// Since the connection is already open, this just sets up the receive functions.
|
||||
func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||
if b.closed.Load() {
|
||||
return nil, 0, net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.udpConn == nil {
|
||||
return nil, 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Set up IPv4 and IPv6 packet connections for advanced features
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
|
||||
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
|
||||
|
||||
// Pre-allocate message buffers for batch operations
|
||||
batchSize := wgConn.IdealBatchSize
|
||||
b.ipv4Msgs = make([]ipv4.Message, batchSize)
|
||||
for i := range b.ipv4Msgs {
|
||||
b.ipv4Msgs[i].OOB = make([]byte, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// Create receive functions - one for socket, one for netstack
|
||||
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
|
||||
|
||||
// Add socket receive function (reads from physical UDP socket)
|
||||
recvFuncs = append(recvFuncs, b.makeReceiveSocket())
|
||||
|
||||
// Add netstack receive function (reads from injected packets channel)
|
||||
recvFuncs = append(recvFuncs, b.makeReceiveNetstack())
|
||||
|
||||
b.recvFuncs = recvFuncs
|
||||
return recvFuncs, b.port, nil
|
||||
}
|
||||
|
||||
// makeReceiveSocket creates a receive function for physical UDP socket packets
|
||||
func (b *SharedBind) makeReceiveSocket() wgConn.ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
for {
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
pc := b.ipv4PC
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
// Socket is nil - check if we're rebinding or truly closed
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Wait for rebind to complete
|
||||
b.mu.Lock()
|
||||
for b.rebinding && !b.closed.Load() {
|
||||
logger.Debug("Receive goroutine waiting for socket rebind to complete")
|
||||
b.rebindingCond.Wait()
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
// Check again after waking up
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Loop back to retry with new socket
|
||||
continue
|
||||
}
|
||||
|
||||
// Use batch reading on Linux for performance
|
||||
var n int
|
||||
var err error
|
||||
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
||||
n, err = b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
||||
} else {
|
||||
n, err = b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Check if this error is due to rebinding
|
||||
b.mu.RLock()
|
||||
rebinding := b.rebinding
|
||||
b.mu.RUnlock()
|
||||
|
||||
if rebinding {
|
||||
logger.Debug("Receive got error during rebind, waiting for new socket: %v", err)
|
||||
// Wait for rebind to complete and retry
|
||||
b.mu.Lock()
|
||||
for b.rebinding && !b.closed.Load() {
|
||||
b.rebindingCond.Wait()
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Retry with new socket
|
||||
continue
|
||||
}
|
||||
|
||||
// Not rebinding, return the error
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// makeReceiveNetstack creates a receive function for netstack-injected packets
|
||||
func (b *SharedBind) makeReceiveNetstack() wgConn.ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
select {
|
||||
case <-b.closeChan:
|
||||
return 0, net.ErrClosed
|
||||
case pkt := <-b.netstackPackets:
|
||||
if len(pkt.data) <= len(bufs[0]) {
|
||||
copy(bufs[0], pkt.data)
|
||||
sizes[0] = len(pkt.data)
|
||||
eps[0] = pkt.endpoint
|
||||
return 1, nil
|
||||
}
|
||||
// Packet too large for buffer, skip it
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// receiveIPv4Batch uses batch reading for better performance on Linux
|
||||
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||
// Use pre-allocated messages, just update buffer pointers
|
||||
numBufs := len(bufs)
|
||||
if numBufs > len(b.ipv4Msgs) {
|
||||
numBufs = len(b.ipv4Msgs)
|
||||
}
|
||||
|
||||
for i := 0; i < numBufs; i++ {
|
||||
b.ipv4Msgs[i].Buffers = [][]byte{bufs[i]}
|
||||
}
|
||||
|
||||
numMsgs, err := pc.ReadBatch(b.ipv4Msgs[:numBufs], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Process messages and filter out magic packets
|
||||
writeIdx := 0
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
if b.ipv4Msgs[i].N == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for magic packet
|
||||
if b.ipv4Msgs[i].Addr != nil {
|
||||
if udpAddr, ok := b.ipv4Msgs[i].Addr.(*net.UDPAddr); ok {
|
||||
data := bufs[i][:b.ipv4Msgs[i].N]
|
||||
if b.handleMagicPacket(data, udpAddr) {
|
||||
// Magic packet handled, skip this message
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not a magic packet, include in output
|
||||
if writeIdx != i {
|
||||
// Need to copy data to the correct position
|
||||
copy(bufs[writeIdx], bufs[i][:b.ipv4Msgs[i].N])
|
||||
}
|
||||
sizes[writeIdx] = b.ipv4Msgs[i].N
|
||||
|
||||
if b.ipv4Msgs[i].Addr != nil {
|
||||
if udpAddr, ok := b.ipv4Msgs[i].Addr.(*net.UDPAddr); ok {
|
||||
addrPort := udpAddr.AddrPort()
|
||||
// Unmap IPv4-in-IPv6 addresses to ensure consistency with parsed endpoints
|
||||
if addrPort.Addr().Is4In6() {
|
||||
addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())
|
||||
}
|
||||
eps[writeIdx] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
}
|
||||
}
|
||||
writeIdx++
|
||||
}
|
||||
|
||||
return writeIdx, nil
|
||||
}
|
||||
|
||||
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
|
||||
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||
// No read deadline - we rely on socket close to unblock during rebind.
|
||||
// The caller (makeReceiveSocket) handles rebind state when errors occur.
|
||||
for {
|
||||
n, addr, err := conn.ReadFromUDP(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Check for magic test packet and handle it directly
|
||||
if b.handleMagicPacket(bufs[0][:n], addr) {
|
||||
// Magic packet was handled, read another packet
|
||||
continue
|
||||
}
|
||||
|
||||
sizes[0] = n
|
||||
if addr != nil {
|
||||
addrPort := addr.AddrPort()
|
||||
// Unmap IPv4-in-IPv6 addresses to ensure consistency with parsed endpoints
|
||||
if addrPort.Addr().Is4In6() {
|
||||
addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())
|
||||
}
|
||||
eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
}
|
||||
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleMagicPacket checks if the packet is a magic test packet and responds if so.
|
||||
// Returns true if the packet was a magic packet and was handled (should not be passed to WireGuard).
|
||||
func (b *SharedBind) handleMagicPacket(data []byte, addr *net.UDPAddr) bool {
|
||||
// Check if this is a test request packet
|
||||
if len(data) >= MagicTestRequestLen && bytes.HasPrefix(data, MagicTestRequest) {
|
||||
// logger.Debug("Received magic test REQUEST from %s, sending response", addr.String())
|
||||
// Extract the random data portion to echo back
|
||||
echoData := data[len(MagicTestRequest) : len(MagicTestRequest)+MagicPacketDataLen]
|
||||
|
||||
// Build response packet
|
||||
response := make([]byte, MagicTestResponseLen)
|
||||
copy(response, MagicTestResponse)
|
||||
copy(response[len(MagicTestResponse):], echoData)
|
||||
|
||||
// Send response back to sender
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn != nil {
|
||||
_, _ = conn.WriteToUDP(response, addr)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if this is a test response packet
|
||||
if len(data) >= MagicTestResponseLen && bytes.HasPrefix(data, MagicTestResponse) {
|
||||
// logger.Debug("Received magic test RESPONSE from %s", addr.String())
|
||||
// Extract the echoed data
|
||||
echoData := data[len(MagicTestResponse) : len(MagicTestResponse)+MagicPacketDataLen]
|
||||
|
||||
// Call the callback if set
|
||||
callbackPtr := b.magicResponseCallback.Load()
|
||||
if callbackPtr != nil {
|
||||
callback := *callbackPtr
|
||||
addrPort := addr.AddrPort()
|
||||
// Unmap IPv4-in-IPv6 addresses to ensure consistency
|
||||
if addrPort.Addr().Is4In6() {
|
||||
addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())
|
||||
}
|
||||
callback(addrPort, echoData)
|
||||
} else {
|
||||
logger.Debug("Magic response received but no callback registered")
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Send implements the WireGuard Bind interface.
|
||||
// It sends packets to the specified endpoint, routing through the appropriate
|
||||
// source (netstack or physical socket) based on where the endpoint's packets came from.
|
||||
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
if b.closed.Load() {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
// Extract the destination address from the endpoint
|
||||
var destAddrPort netip.AddrPort
|
||||
|
||||
// Try to cast to StdNetEndpoint first (most common case, avoid allocations)
|
||||
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
|
||||
destAddrPort = stdEp.AddrPort
|
||||
} else {
|
||||
// Fallback: construct from DstIP and DstToBytes
|
||||
dstBytes := ep.DstToBytes()
|
||||
if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes)
|
||||
var addr netip.Addr
|
||||
var port uint16
|
||||
|
||||
if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes)
|
||||
addr, _ = netip.AddrFromSlice(dstBytes[:16])
|
||||
port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8
|
||||
} else { // IPv4
|
||||
addr, _ = netip.AddrFromSlice(dstBytes[:4])
|
||||
port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
destAddrPort = netip.AddrPortFrom(addr, port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !destAddrPort.IsValid() {
|
||||
return fmt.Errorf("could not extract destination address from endpoint")
|
||||
}
|
||||
|
||||
// Check if this endpoint came from netstack - if so, send through netstack
|
||||
// Use AddrPort directly as key (more efficient than string conversion)
|
||||
if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort); isNetstackEndpoint {
|
||||
connPtr := b.netstackConn.Load()
|
||||
if connPtr != nil && *connPtr != nil {
|
||||
netstackConn := *connPtr
|
||||
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
|
||||
// Send all buffers through netstack
|
||||
for _, buf := range bufs {
|
||||
_, err := netstackConn.WriteTo(buf, destAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Fall through to socket if netstack conn not available
|
||||
}
|
||||
|
||||
// Send through the physical UDP socket (for hole-punched clients)
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
|
||||
|
||||
// Send all buffers to the destination
|
||||
for _, buf := range bufs {
|
||||
_, err := conn.WriteToUDP(buf, destAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMark implements the WireGuard Bind interface.
|
||||
// It's a no-op for this implementation.
|
||||
func (b *SharedBind) SetMark(mark uint32) error {
|
||||
// Not implemented for this use case
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchSize returns the preferred batch size for sending packets.
|
||||
func (b *SharedBind) BatchSize() int {
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
return wgConn.IdealBatchSize
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// ParseEndpoint creates a new endpoint from a string address.
|
||||
func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
|
||||
addrPort, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
555
bind/shared_bind_test.go
Normal file
555
bind/shared_bind_test.go
Normal file
@@ -0,0 +1,555 @@
|
||||
//go:build !js
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
// TestSharedBindCreation tests basic creation and initialization
|
||||
func TestSharedBindCreation(t *testing.T) {
|
||||
// Create a UDP connection
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
// Create SharedBind
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
|
||||
if bind == nil {
|
||||
t.Fatal("SharedBind is nil")
|
||||
}
|
||||
|
||||
// Verify initial reference count
|
||||
if bind.refCount.Load() != 1 {
|
||||
t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
// Clean up
|
||||
if err := bind.Close(); err != nil {
|
||||
t.Errorf("Failed to close SharedBind: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindReferenceCount tests reference counting
|
||||
func TestSharedBindReferenceCount(t *testing.T) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
|
||||
// Add references
|
||||
bind.AddRef()
|
||||
if bind.refCount.Load() != 2 {
|
||||
t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
bind.AddRef()
|
||||
if bind.refCount.Load() != 3 {
|
||||
t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
// Release references
|
||||
bind.Release()
|
||||
if bind.refCount.Load() != 2 {
|
||||
t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
bind.Release()
|
||||
bind.Release() // This should close the connection
|
||||
|
||||
if !bind.closed.Load() {
|
||||
t.Error("Expected bind to be closed after all references released")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
|
||||
func TestSharedBindWriteToUDP(t *testing.T) {
|
||||
// Create sender
|
||||
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||
}
|
||||
|
||||
senderBind, err := New(senderConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||
}
|
||||
defer senderBind.Close()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
// Send data
|
||||
testData := []byte("Hello, SharedBind!")
|
||||
n, err := senderBind.WriteToUDP(testData, receiverAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteToUDP failed: %v", err)
|
||||
}
|
||||
|
||||
if n != len(testData) {
|
||||
t.Errorf("Expected to send %d bytes, sent %d", len(testData), n)
|
||||
}
|
||||
|
||||
// Receive data
|
||||
buf := make([]byte, 1024)
|
||||
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, _, err = receiverConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive data: %v", err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(testData) {
|
||||
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindConcurrentWrites tests thread-safety
|
||||
func TestSharedBindConcurrentWrites(t *testing.T) {
|
||||
// Create sender
|
||||
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||
}
|
||||
|
||||
senderBind, err := New(senderConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||
}
|
||||
defer senderBind.Close()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
// Launch concurrent writes
|
||||
numGoroutines := 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
data := []byte{byte(id)}
|
||||
_, err := senderBind.WriteToUDP(data, receiverAddr)
|
||||
if err != nil {
|
||||
t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
|
||||
func TestSharedBindWireGuardInterface(t *testing.T) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
defer bind.Close()
|
||||
|
||||
// Test Open
|
||||
recvFuncs, port, err := bind.Open(0)
|
||||
if err != nil {
|
||||
t.Fatalf("Open failed: %v", err)
|
||||
}
|
||||
|
||||
if len(recvFuncs) == 0 {
|
||||
t.Error("Expected at least one receive function")
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
t.Error("Expected non-zero port")
|
||||
}
|
||||
|
||||
// Test SetMark (should be a no-op)
|
||||
if err := bind.SetMark(0); err != nil {
|
||||
t.Errorf("SetMark failed: %v", err)
|
||||
}
|
||||
|
||||
// Test BatchSize
|
||||
batchSize := bind.BatchSize()
|
||||
if batchSize <= 0 {
|
||||
t.Error("Expected positive batch size")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindSend tests the Send method with WireGuard endpoints
|
||||
func TestSharedBindSend(t *testing.T) {
|
||||
// Create sender
|
||||
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||
}
|
||||
|
||||
senderBind, err := New(senderConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||
}
|
||||
defer senderBind.Close()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
// Create an endpoint
|
||||
addrPort := receiverAddr.AddrPort()
|
||||
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
|
||||
// Send data
|
||||
testData := []byte("WireGuard packet")
|
||||
bufs := [][]byte{testData}
|
||||
err = senderBind.Send(bufs, endpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("Send failed: %v", err)
|
||||
}
|
||||
|
||||
// Receive data
|
||||
buf := make([]byte, 1024)
|
||||
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, _, err := receiverConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive data: %v", err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(testData) {
|
||||
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
|
||||
func TestSharedBindMultipleUsers(t *testing.T) {
|
||||
// Create shared bind
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
sharedBind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
|
||||
// Add reference for hole punch sender
|
||||
sharedBind.AddRef()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Simulate WireGuard using the bind
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
addrPort := receiverAddr.AddrPort()
|
||||
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
data := []byte("WireGuard packet")
|
||||
bufs := [][]byte{data}
|
||||
if err := sharedBind.Send(bufs, endpoint); err != nil {
|
||||
t.Errorf("WireGuard Send failed: %v", err)
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
// Simulate hole punch sender using the bind
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 10; i++ {
|
||||
data := []byte("Hole punch packet")
|
||||
if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil {
|
||||
t.Errorf("Hole punch WriteToUDP failed: %v", err)
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Release the hole punch reference
|
||||
sharedBind.Release()
|
||||
|
||||
// Close WireGuard's reference (should close the connection)
|
||||
sharedBind.Close()
|
||||
|
||||
if !sharedBind.closed.Load() {
|
||||
t.Error("Expected bind to be closed after all users released it")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEndpoint tests the Endpoint implementation
|
||||
func TestEndpoint(t *testing.T) {
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
addrPort := netip.AddrPortFrom(addr, 51820)
|
||||
|
||||
ep := &Endpoint{AddrPort: addrPort}
|
||||
|
||||
// Test DstIP
|
||||
if ep.DstIP() != addr {
|
||||
t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP())
|
||||
}
|
||||
|
||||
// Test DstToString
|
||||
expected := "192.168.1.1:51820"
|
||||
if ep.DstToString() != expected {
|
||||
t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString())
|
||||
}
|
||||
|
||||
// Test DstToBytes
|
||||
bytes := ep.DstToBytes()
|
||||
if len(bytes) == 0 {
|
||||
t.Error("Expected DstToBytes to return non-empty slice")
|
||||
}
|
||||
|
||||
// Test SrcIP (should be zero)
|
||||
if ep.SrcIP().IsValid() {
|
||||
t.Error("Expected SrcIP to be invalid")
|
||||
}
|
||||
|
||||
// Test ClearSrc (should not panic)
|
||||
ep.ClearSrc()
|
||||
}
|
||||
|
||||
// TestParseEndpoint tests the ParseEndpoint method
|
||||
func TestParseEndpoint(t *testing.T) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
defer bind.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
checkAddr func(*testing.T, wgConn.Endpoint)
|
||||
}{
|
||||
{
|
||||
name: "valid IPv4",
|
||||
input: "192.168.1.1:51820",
|
||||
wantErr: false,
|
||||
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||
if ep.DstToString() != "192.168.1.1:51820" {
|
||||
t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid IPv6",
|
||||
input: "[::1]:51820",
|
||||
wantErr: false,
|
||||
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||
if ep.DstToString() != "[::1]:51820" {
|
||||
t.Errorf("Expected [::1]:51820, got %s", ep.DstToString())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid - missing port",
|
||||
input: "192.168.1.1",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - bad format",
|
||||
input: "not-an-address",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ep, err := bind.ParseEndpoint(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && tt.checkAddr != nil {
|
||||
tt.checkAddr(t, ep)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNetstackRouting tests that packets from netstack endpoints are routed back through netstack
|
||||
func TestNetstackRouting(t *testing.T) {
|
||||
// Create the SharedBind with a physical UDP socket
|
||||
physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create physical UDP connection: %v", err)
|
||||
}
|
||||
|
||||
sharedBind, err := New(physicalConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
defer sharedBind.Close()
|
||||
|
||||
// Create a mock "netstack" connection (just another UDP socket for testing)
|
||||
netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create netstack UDP connection: %v", err)
|
||||
}
|
||||
defer netstackConn.Close()
|
||||
|
||||
// Set the netstack connection
|
||||
sharedBind.SetNetstackConn(netstackConn)
|
||||
|
||||
// Create a "client" that would receive packets
|
||||
clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client UDP connection: %v", err)
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
||||
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
|
||||
clientAddrPort := clientAddr.AddrPort()
|
||||
|
||||
// Inject a packet from the "netstack" source - this should track the endpoint
|
||||
testData := []byte("test packet from netstack")
|
||||
err = sharedBind.InjectPacket(testData, clientAddrPort)
|
||||
if err != nil {
|
||||
t.Fatalf("InjectPacket failed: %v", err)
|
||||
}
|
||||
|
||||
// Now when we send a response to this endpoint, it should go through netstack
|
||||
endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort}
|
||||
responseData := []byte("response packet")
|
||||
err = sharedBind.Send([][]byte{responseData}, endpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("Send failed: %v", err)
|
||||
}
|
||||
|
||||
// The packet should be received by the client from the netstack connection
|
||||
buf := make([]byte, 1024)
|
||||
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, fromAddr, err := clientConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive response: %v", err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(responseData) {
|
||||
t.Errorf("Expected to receive %q, got %q", responseData, buf[:n])
|
||||
}
|
||||
|
||||
// Verify the response came from the netstack connection, not the physical one
|
||||
netstackAddr := netstackConn.LocalAddr().(*net.UDPAddr)
|
||||
if fromAddr.Port != netstackAddr.Port {
|
||||
t.Errorf("Expected response from netstack port %d, got %d", netstackAddr.Port, fromAddr.Port)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSocketRouting tests that packets from socket endpoints are routed through socket
|
||||
func TestSocketRouting(t *testing.T) {
|
||||
// Create the SharedBind with a physical UDP socket
|
||||
physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create physical UDP connection: %v", err)
|
||||
}
|
||||
|
||||
sharedBind, err := New(physicalConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
defer sharedBind.Close()
|
||||
|
||||
// Create a mock "netstack" connection
|
||||
netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create netstack UDP connection: %v", err)
|
||||
}
|
||||
defer netstackConn.Close()
|
||||
|
||||
// Set the netstack connection
|
||||
sharedBind.SetNetstackConn(netstackConn)
|
||||
|
||||
// Create a "client" that would receive packets (this simulates a hole-punched client)
|
||||
clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client UDP connection: %v", err)
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
||||
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
|
||||
clientAddrPort := clientAddr.AddrPort()
|
||||
|
||||
// Don't inject from netstack - this endpoint is NOT tracked as netstack-sourced
|
||||
// So Send should use the physical socket
|
||||
|
||||
endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort}
|
||||
responseData := []byte("response packet via socket")
|
||||
err = sharedBind.Send([][]byte{responseData}, endpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("Send failed: %v", err)
|
||||
}
|
||||
|
||||
// The packet should be received by the client from the physical connection
|
||||
buf := make([]byte, 1024)
|
||||
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, fromAddr, err := clientConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive response: %v", err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(responseData) {
|
||||
t.Errorf("Expected to receive %q, got %q", responseData, buf[:n])
|
||||
}
|
||||
|
||||
// Verify the response came from the physical connection, not the netstack one
|
||||
physicalAddr := physicalConn.LocalAddr().(*net.UDPAddr)
|
||||
if fromAddr.Port != physicalAddr.Port {
|
||||
t.Errorf("Expected response from physical port %d, got %d", physicalAddr.Port, fromAddr.Port)
|
||||
}
|
||||
}
|
||||
@@ -12,9 +12,9 @@ resources:
|
||||
sso-roles:
|
||||
- Member
|
||||
sso-users:
|
||||
- owen@fossorial.io
|
||||
- owen@pangolin.net
|
||||
whitelist-users:
|
||||
- owen@fossorial.io
|
||||
- owen@pangolin.net
|
||||
targets:
|
||||
# - site: glossy-plains-viscacha-rat
|
||||
- hostname: localhost
|
||||
|
||||
78
clients.go
78
clients.go
@@ -1,20 +1,17 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/clients"
|
||||
wgnetstack "github.com/fosrl/newt/clients"
|
||||
"github.com/fosrl/newt/clients/permissions"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/fosrl/newt/wgnetstack"
|
||||
"github.com/fosrl/newt/wgtester"
|
||||
)
|
||||
|
||||
var wgService *wgnetstack.WireGuardService
|
||||
var wgTesterServer *wgtester.Server
|
||||
var wgService *clients.WireGuardService
|
||||
var ready bool
|
||||
|
||||
func setupClients(client *websocket.Client) {
|
||||
@@ -27,43 +24,29 @@ func setupClients(client *websocket.Client) {
|
||||
|
||||
host = strings.TrimSuffix(host, "/")
|
||||
|
||||
logger.Debug("Setting up clients with netstack2...")
|
||||
|
||||
// if useNativeInterface is true make sure we have permission to use native interface
|
||||
if useNativeInterface {
|
||||
setupClientsNative(client, host)
|
||||
} else {
|
||||
setupClientsNetstack(client, host)
|
||||
logger.Debug("Checking permissions for native interface")
|
||||
err := permissions.CheckNativeInterfacePermissions()
|
||||
if err != nil {
|
||||
logger.Fatal("Insufficient permissions to create native TUN interface: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ready = true
|
||||
}
|
||||
|
||||
func setupClientsNetstack(client *websocket.Client, host string) {
|
||||
logger.Info("Setting up clients with netstack...")
|
||||
// Create WireGuard service
|
||||
wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "9.9.9.9")
|
||||
wgService, err = wgnetstack.NewWireGuardService(interfaceName, port, mtuInt, host, id, client, dns, useNativeInterface)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||
}
|
||||
|
||||
// // Set up callback to restart wgtester with netstack when WireGuard is ready
|
||||
wgService.SetOnNetstackReady(func(tnet *netstack.Net) {
|
||||
|
||||
wgTesterServer = wgtester.NewServerWithNetstack("0.0.0.0", wgService.Port, id, tnet) // TODO: maybe make this the same ip of the wg server?
|
||||
err := wgTesterServer.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start WireGuard tester server: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
wgService.SetOnNetstackClose(func() {
|
||||
if wgTesterServer != nil {
|
||||
wgTesterServer.Stop()
|
||||
wgTesterServer = nil
|
||||
}
|
||||
})
|
||||
|
||||
client.OnTokenUpdate(func(token string) {
|
||||
wgService.SetToken(token)
|
||||
})
|
||||
|
||||
ready = true
|
||||
}
|
||||
|
||||
func setDownstreamTNetstack(tnet *netstack.Net) {
|
||||
@@ -75,19 +58,12 @@ func setDownstreamTNetstack(tnet *netstack.Net) {
|
||||
func closeClients() {
|
||||
logger.Info("Closing clients...")
|
||||
if wgService != nil {
|
||||
wgService.Close(!keepInterface)
|
||||
wgService.Close()
|
||||
wgService = nil
|
||||
}
|
||||
|
||||
closeWgServiceNative()
|
||||
|
||||
if wgTesterServer != nil {
|
||||
wgTesterServer.Stop()
|
||||
wgTesterServer = nil
|
||||
}
|
||||
}
|
||||
|
||||
func clientsHandleNewtConnection(publicKey string, endpoint string) {
|
||||
func clientsHandleNewtConnection(publicKey string, endpoint string, relayPort uint16) {
|
||||
if !ready {
|
||||
return
|
||||
}
|
||||
@@ -101,10 +77,8 @@ func clientsHandleNewtConnection(publicKey string, endpoint string) {
|
||||
endpoint = strings.Join(parts[:len(parts)-1], ":")
|
||||
|
||||
if wgService != nil {
|
||||
wgService.StartHolepunch(publicKey, endpoint)
|
||||
wgService.StartHolepunch(publicKey, endpoint, relayPort)
|
||||
}
|
||||
|
||||
clientsHandleNewtConnectionNative(publicKey, endpoint)
|
||||
}
|
||||
|
||||
func clientsOnConnect() {
|
||||
@@ -114,19 +88,17 @@ func clientsOnConnect() {
|
||||
if wgService != nil {
|
||||
wgService.LoadRemoteConfig()
|
||||
}
|
||||
|
||||
clientsOnConnectNative()
|
||||
}
|
||||
|
||||
func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) {
|
||||
// clientsStartDirectRelay starts a direct UDP relay from the main tunnel netstack
|
||||
// to the clients' WireGuard, bypassing the proxy for better performance.
|
||||
func clientsStartDirectRelay(tunnelIP string) {
|
||||
if !ready {
|
||||
return
|
||||
}
|
||||
// add a udp proxy for localost and the wgService port
|
||||
// TODO: make sure this port is not used in a target
|
||||
if wgService != nil {
|
||||
pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port))
|
||||
if err := wgService.StartDirectUDPRelay(tunnelIP); err != nil {
|
||||
logger.Error("Failed to start direct UDP relay: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
clientsAddProxyTargetNative(pm, tunnelIp)
|
||||
}
|
||||
|
||||
1276
clients/clients.go
Normal file
1276
clients/clients.go
Normal file
File diff suppressed because it is too large
Load Diff
8
clients/permissions/permissions_android.go
Normal file
8
clients/permissions/permissions_android.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build android
|
||||
|
||||
package permissions
|
||||
|
||||
// CheckNativeInterfacePermissions always allows permission on Android.
|
||||
func CheckNativeInterfacePermissions() error {
|
||||
return nil
|
||||
}
|
||||
18
clients/permissions/permissions_darwin.go
Normal file
18
clients/permissions/permissions_darwin.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// CheckNativeInterfacePermissions checks if the process has sufficient
|
||||
// permissions to create a native TUN interface on macOS.
|
||||
// This typically requires root privileges.
|
||||
func CheckNativeInterfacePermissions() error {
|
||||
if os.Geteuid() == 0 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("insufficient permissions: need root to create TUN interface on macOS")
|
||||
}
|
||||
57
clients/permissions/permissions_freebsd.go
Normal file
57
clients/permissions/permissions_freebsd.go
Normal file
@@ -0,0 +1,57 @@
|
||||
//go:build freebsd
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// TUN device on FreeBSD
|
||||
tunDevice = "/dev/tun"
|
||||
ifnamsiz = 16
|
||||
iffTun = 0x0001
|
||||
iffNoPi = 0x1000
|
||||
)
|
||||
|
||||
// ifReq is the structure for TUN interface configuration
|
||||
type ifReq struct {
|
||||
Name [ifnamsiz]byte
|
||||
Flags uint16
|
||||
_ [22]byte // padding to match kernel structure
|
||||
}
|
||||
|
||||
// CheckNativeInterfacePermissions checks if the process has sufficient
|
||||
// permissions to create a native TUN interface on FreeBSD.
|
||||
// This requires root privileges (UID 0).
|
||||
func CheckNativeInterfacePermissions() error {
|
||||
logger.Debug("Checking native interface permissions on FreeBSD")
|
||||
|
||||
// Check if running as root
|
||||
if os.Geteuid() == 0 {
|
||||
logger.Debug("Running as root, sufficient permissions for native TUN interface")
|
||||
return nil
|
||||
}
|
||||
|
||||
// On FreeBSD, only root can create TUN interfaces
|
||||
// Try to open the TUN device to verify
|
||||
return tryOpenTunDevice()
|
||||
}
|
||||
|
||||
// tryOpenTunDevice attempts to open the TUN device to verify permissions.
|
||||
// On FreeBSD, /dev/tun is a cloning device that creates a new interface
|
||||
// when opened.
|
||||
func tryOpenTunDevice() error {
|
||||
// Try opening /dev/tun (cloning device)
|
||||
f, err := os.OpenFile(tunDevice, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open %s: %v (need root privileges)", tunDevice, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
logger.Debug("Successfully opened TUN device, sufficient permissions for native TUN interface")
|
||||
return nil
|
||||
}
|
||||
8
clients/permissions/permissions_ios.go
Normal file
8
clients/permissions/permissions_ios.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build ios
|
||||
|
||||
package permissions
|
||||
|
||||
// CheckNativeInterfacePermissions always allows permission on iOS.
|
||||
func CheckNativeInterfacePermissions() error {
|
||||
return nil
|
||||
}
|
||||
96
clients/permissions/permissions_linux.go
Normal file
96
clients/permissions/permissions_linux.go
Normal file
@@ -0,0 +1,96 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"unsafe"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
// TUN device constants
|
||||
tunDevice = "/dev/net/tun"
|
||||
ifnamsiz = 16
|
||||
iffTun = 0x0001
|
||||
iffNoPi = 0x1000
|
||||
tunSetIff = 0x400454ca
|
||||
)
|
||||
|
||||
// ifReq is the structure for TUNSETIFF ioctl
|
||||
type ifReq struct {
|
||||
Name [ifnamsiz]byte
|
||||
Flags uint16
|
||||
_ [22]byte // padding to match kernel structure
|
||||
}
|
||||
|
||||
// CheckNativeInterfacePermissions checks if the process has sufficient
|
||||
// permissions to create a native TUN interface on Linux.
|
||||
// This requires either root privileges (UID 0) or CAP_NET_ADMIN capability.
|
||||
func CheckNativeInterfacePermissions() error {
|
||||
logger.Debug("Checking native interface permissions on Linux")
|
||||
|
||||
// Check if running as root
|
||||
if os.Geteuid() == 0 {
|
||||
logger.Debug("Running as root, sufficient permissions for native TUN interface")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for CAP_NET_ADMIN capability
|
||||
caps := unix.CapUserHeader{
|
||||
Version: unix.LINUX_CAPABILITY_VERSION_3,
|
||||
Pid: 0, // 0 means current process
|
||||
}
|
||||
|
||||
var data [2]unix.CapUserData
|
||||
if err := unix.Capget(&caps, &data[0]); err != nil {
|
||||
logger.Debug("Failed to get capabilities: %v, will try creating test TUN", err)
|
||||
} else {
|
||||
// CAP_NET_ADMIN is capability bit 12
|
||||
const CAP_NET_ADMIN = 12
|
||||
if data[0].Effective&(1<<CAP_NET_ADMIN) != 0 {
|
||||
logger.Debug("Process has CAP_NET_ADMIN capability, sufficient permissions for native TUN interface")
|
||||
return nil
|
||||
}
|
||||
logger.Debug("Process does not have CAP_NET_ADMIN capability in effective set")
|
||||
}
|
||||
|
||||
// Actually try to create a TUN interface to verify permissions
|
||||
// This is the most reliable check as it tests the actual operation
|
||||
return tryCreateTestTun()
|
||||
}
|
||||
|
||||
// tryCreateTestTun attempts to create a temporary TUN interface to verify
|
||||
// we have the necessary permissions. This tests the actual ioctl call that
|
||||
// will be used when creating the real interface.
|
||||
func tryCreateTestTun() error {
|
||||
f, err := os.OpenFile(tunDevice, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open %s: %v (need root or CAP_NET_ADMIN capability)", tunDevice, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Try to create a TUN interface with a test name
|
||||
// Using a random-ish name to avoid conflicts
|
||||
var req ifReq
|
||||
copy(req.Name[:], "tuntest0")
|
||||
req.Flags = iffTun | iffNoPi
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
f.Fd(),
|
||||
uintptr(tunSetIff),
|
||||
uintptr(unsafe.Pointer(&req)),
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return fmt.Errorf("cannot create TUN interface (ioctl TUNSETIFF failed): %v (need root or CAP_NET_ADMIN capability)", errno)
|
||||
}
|
||||
|
||||
// Success - the interface will be automatically destroyed when we close the fd
|
||||
logger.Debug("Successfully created test TUN interface, sufficient permissions for native TUN interface")
|
||||
return nil
|
||||
}
|
||||
48
clients/permissions/permissions_windows.go
Normal file
48
clients/permissions/permissions_windows.go
Normal file
@@ -0,0 +1,48 @@
|
||||
//go:build windows
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc"
|
||||
)
|
||||
|
||||
// CheckNativeInterfacePermissions checks if the process has sufficient
|
||||
// permissions to create a native TUN interface on Windows.
|
||||
// This requires Administrator privileges and must be running as a Windows service.
|
||||
func CheckNativeInterfacePermissions() error {
|
||||
// Check if running as a Windows service
|
||||
isService, err := svc.IsWindowsService()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if running as Windows service: %v", err)
|
||||
}
|
||||
if !isService {
|
||||
return fmt.Errorf("native TUN interface requires running as a Windows service")
|
||||
}
|
||||
|
||||
var sid *windows.SID
|
||||
err = windows.AllocateAndInitializeSid(
|
||||
&windows.SECURITY_NT_AUTHORITY,
|
||||
2,
|
||||
windows.SECURITY_BUILTIN_DOMAIN_RID,
|
||||
windows.DOMAIN_ALIAS_RID_ADMINS,
|
||||
0, 0, 0, 0, 0, 0,
|
||||
&sid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize SID: %v", err)
|
||||
}
|
||||
defer windows.FreeSid(sid)
|
||||
|
||||
token := windows.Token(0)
|
||||
member, err := token.IsMember(sid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check admin group membership: %v", err)
|
||||
}
|
||||
|
||||
if !member {
|
||||
return fmt.Errorf("insufficient permissions: need Administrator to create TUN interface on Windows")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -2,11 +2,9 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
@@ -14,32 +12,20 @@ import (
|
||||
|
||||
"math/rand"
|
||||
|
||||
"github.com/fosrl/newt/internal/telemetry"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func fixKey(key string) string {
|
||||
// Remove any whitespace
|
||||
key = strings.TrimSpace(key)
|
||||
|
||||
// Decode from base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
logger.Fatal("Error decoding base64: %v", err)
|
||||
}
|
||||
|
||||
// Convert to hex
|
||||
return hex.EncodeToString(decoded)
|
||||
}
|
||||
const msgHealthFileWriteFailed = "Failed to write health file: %v"
|
||||
|
||||
func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration, error) {
|
||||
logger.Debug("Pinging %s", dst)
|
||||
// logger.Debug("Pinging %s", dst)
|
||||
socket, err := tnet.Dial("ping4", dst)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create ICMP socket: %w", err)
|
||||
@@ -98,7 +84,7 @@ func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration,
|
||||
|
||||
latency := time.Since(start)
|
||||
|
||||
logger.Debug("Ping to %s successful, latency: %v", dst, latency)
|
||||
// logger.Debug("Ping to %s successful, latency: %v", dst, latency)
|
||||
|
||||
return latency, nil
|
||||
}
|
||||
@@ -136,7 +122,7 @@ func reliablePing(tnet *netstack.Net, dst string, baseTimeout time.Duration, max
|
||||
// If we get at least one success, we can return early for health checks
|
||||
if successCount > 0 {
|
||||
avgLatency := totalLatency / time.Duration(successCount)
|
||||
logger.Debug("Reliable ping succeeded after %d attempts, avg latency: %v", attempt, avgLatency)
|
||||
// logger.Debug("Reliable ping succeeded after %d attempts, avg latency: %v", attempt, avgLatency)
|
||||
return avgLatency, nil
|
||||
}
|
||||
}
|
||||
@@ -176,7 +162,7 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
|
||||
if healthFile != "" {
|
||||
err := os.WriteFile(healthFile, []byte("ok"), 0644)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to write health file: %v", err)
|
||||
logger.Warn(msgHealthFileWriteFailed, err)
|
||||
}
|
||||
}
|
||||
return stopChan, nil
|
||||
@@ -217,11 +203,13 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
|
||||
if healthFile != "" {
|
||||
err := os.WriteFile(healthFile, []byte("ok"), 0644)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to write health file: %v", err)
|
||||
logger.Warn(msgHealthFileWriteFailed, err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
case <-pingStopChan:
|
||||
// Stop the goroutine when signaled
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -230,7 +218,7 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
|
||||
return stopChan, fmt.Errorf("initial ping attempts failed, continuing in background")
|
||||
}
|
||||
|
||||
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} {
|
||||
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, tunnelID string) chan struct{} {
|
||||
maxInterval := 6 * time.Second
|
||||
currentInterval := pingInterval
|
||||
consecutiveFailures := 0
|
||||
@@ -293,6 +281,9 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
|
||||
if !connectionLost {
|
||||
connectionLost = true
|
||||
logger.Warn("Connection to server lost after %d failures. Continuous reconnection attempts will be made.", consecutiveFailures)
|
||||
if tunnelID != "" {
|
||||
telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout)
|
||||
}
|
||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
|
||||
// Send registration message to the server for backward compatibility
|
||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||
@@ -319,6 +310,10 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
|
||||
} else {
|
||||
// Track recent latencies
|
||||
recentLatencies = append(recentLatencies, latency)
|
||||
// Record tunnel latency (limit sampling to this periodic check)
|
||||
if tunnelID != "" {
|
||||
telemetry.ObserveTunnelLatency(context.Background(), tunnelID, "wireguard", latency.Seconds())
|
||||
}
|
||||
if len(recentLatencies) > 10 {
|
||||
recentLatencies = recentLatencies[1:]
|
||||
}
|
||||
@@ -353,89 +348,6 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
|
||||
return pingStopChan
|
||||
}
|
||||
|
||||
func parseLogLevel(level string) logger.LogLevel {
|
||||
switch strings.ToUpper(level) {
|
||||
case "DEBUG":
|
||||
return logger.DEBUG
|
||||
case "INFO":
|
||||
return logger.INFO
|
||||
case "WARN":
|
||||
return logger.WARN
|
||||
case "ERROR":
|
||||
return logger.ERROR
|
||||
case "FATAL":
|
||||
return logger.FATAL
|
||||
default:
|
||||
return logger.INFO // default to INFO if invalid level provided
|
||||
}
|
||||
}
|
||||
|
||||
func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
||||
switch level {
|
||||
case logger.DEBUG:
|
||||
return device.LogLevelVerbose
|
||||
// case logger.INFO:
|
||||
// return device.LogLevel
|
||||
case logger.WARN:
|
||||
return device.LogLevelError
|
||||
case logger.ERROR, logger.FATAL:
|
||||
return device.LogLevelSilent
|
||||
default:
|
||||
return device.LogLevelSilent
|
||||
}
|
||||
}
|
||||
|
||||
func resolveDomain(domain string) (string, error) {
|
||||
// Check if there's a port in the domain
|
||||
host, port, err := net.SplitHostPort(domain)
|
||||
if err != nil {
|
||||
// No port found, use the domain as is
|
||||
host = domain
|
||||
port = ""
|
||||
}
|
||||
|
||||
// Remove any protocol prefix if present
|
||||
if strings.HasPrefix(host, "http://") {
|
||||
host = strings.TrimPrefix(host, "http://")
|
||||
} else if strings.HasPrefix(host, "https://") {
|
||||
host = strings.TrimPrefix(host, "https://")
|
||||
}
|
||||
|
||||
// if there are any trailing slashes, remove them
|
||||
host = strings.TrimSuffix(host, "/")
|
||||
|
||||
// Lookup IP addresses
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||
}
|
||||
|
||||
// Get the first IPv4 address if available
|
||||
var ipAddr string
|
||||
for _, ip := range ips {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
ipAddr = ipv4.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no IPv4 found, use the first IP (might be IPv6)
|
||||
if ipAddr == "" {
|
||||
ipAddr = ips[0].String()
|
||||
}
|
||||
|
||||
// Add port back if it existed
|
||||
if port != "" {
|
||||
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||
}
|
||||
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
func parseTargetData(data interface{}) (TargetData, error) {
|
||||
var targetData TargetData
|
||||
jsonData, err := json.Marshal(data)
|
||||
@@ -468,7 +380,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
||||
continue
|
||||
}
|
||||
|
||||
if action == "add" {
|
||||
switch action {
|
||||
case "add":
|
||||
target := parts[1] + ":" + parts[2]
|
||||
|
||||
// Call updown script if provided
|
||||
@@ -494,7 +407,7 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
||||
// Add the new target
|
||||
pm.AddTarget(proto, tunnelIP, port, processedTarget)
|
||||
|
||||
} else if action == "remove" {
|
||||
case "remove":
|
||||
logger.Info("Removing target with port %d", port)
|
||||
|
||||
target := parts[1] + ":" + parts[2]
|
||||
@@ -512,6 +425,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
||||
logger.Error("Failed to remove target: %v", err)
|
||||
return err
|
||||
}
|
||||
default:
|
||||
logger.Info("Unknown action: %s", action)
|
||||
}
|
||||
}
|
||||
|
||||
44
device/tun_unix.go
Normal file
44
device/tun_unix.go
Normal file
@@ -0,0 +1,44 @@
|
||||
//go:build !windows
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||
dupTunFd, err := unix.Dup(int(tunFd))
|
||||
if err != nil {
|
||||
logger.Error("Unable to dup tun fd: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(dupTunFd, true)
|
||||
if err != nil {
|
||||
unix.Close(dupTunFd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
|
||||
device, err := tun.CreateTUNFromFile(file, mtuInt)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func UapiOpen(interfaceName string) (*os.File, error) {
|
||||
return ipc.UAPIOpen(interfaceName)
|
||||
}
|
||||
|
||||
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
return ipc.UAPIListen(interfaceName, fileUAPI)
|
||||
}
|
||||
25
device/tun_windows.go
Normal file
25
device/tun_windows.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build windows
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||
return nil, errors.New("CreateTUNFromFile not supported on Windows")
|
||||
}
|
||||
|
||||
func UapiOpen(interfaceName string) (*os.File, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
// On Windows, UAPIListen only takes one parameter
|
||||
return ipc.UAPIListen(interfaceName)
|
||||
}
|
||||
41
docker-compose.metrics.collector.yml
Normal file
41
docker-compose.metrics.collector.yml
Normal file
@@ -0,0 +1,41 @@
|
||||
services:
|
||||
newt:
|
||||
build: .
|
||||
image: newt:dev
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- NEWT_METRICS_PROMETHEUS_ENABLED=false # important: disable direct /metrics scraping
|
||||
- NEWT_METRICS_OTLP_ENABLED=true # OTLP to the Collector
|
||||
# optional:
|
||||
# - NEWT_METRICS_INCLUDE_TUNNEL_ID=false
|
||||
# When using the Collector pattern, do NOT map the Newt admin/metrics port
|
||||
# (2112) on the application service. Mapping 2112 here can cause port
|
||||
# conflicts and may result in duplicated Prometheus scraping (app AND
|
||||
# collector being scraped for the same metrics). Instead either:
|
||||
# - leave ports unset on the app service (recommended), or
|
||||
# - map 2112 only on a dedicated metrics/collector service that is
|
||||
# responsible for exposing metrics to Prometheus.
|
||||
# Example: do NOT map here
|
||||
# ports: []
|
||||
# Example: map 2112 only on a collector service
|
||||
# collector:
|
||||
# ports:
|
||||
# - "2112:2112" # collector's prometheus exporter (scraped by Prometheus)
|
||||
|
||||
otel-collector:
|
||||
image: otel/opentelemetry-collector-contrib:latest
|
||||
command: ["--config=/etc/otelcol/config.yaml"]
|
||||
volumes:
|
||||
- ./examples/otel-collector.yaml:/etc/otelcol/config.yaml:ro
|
||||
ports:
|
||||
- "4317:4317" # OTLP gRPC
|
||||
- "8889:8889" # Prometheus Exporter (scraped by Prometheus)
|
||||
|
||||
prometheus:
|
||||
image: prom/prometheus:latest
|
||||
volumes:
|
||||
- ./examples/prometheus.with-collector.yml:/etc/prometheus/prometheus.yml:ro
|
||||
ports:
|
||||
- "9090:9090"
|
||||
|
||||
56
docker-compose.metrics.yml
Normal file
56
docker-compose.metrics.yml
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Newt-Metrics
|
||||
services:
|
||||
# Recommended Variant A: Direct Prometheus scrape of Newt (/metrics)
|
||||
# Optional: You may add the Collector service and enable OTLP export, but do NOT
|
||||
# scrape both Newt and the Collector for the same process.
|
||||
|
||||
newt:
|
||||
build: .
|
||||
image: newt:dev
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
OTEL_SERVICE_NAME: newt
|
||||
NEWT_METRICS_PROMETHEUS_ENABLED: "true"
|
||||
NEWT_METRICS_OTLP_ENABLED: "false" # avoid double-scrape by default
|
||||
NEWT_ADMIN_ADDR: ":2112"
|
||||
# Base NEWT configuration
|
||||
PANGOLIN_ENDPOINT: ${PANGOLIN_ENDPOINT}
|
||||
NEWT_ID: ${NEWT_ID}
|
||||
NEWT_SECRET: ${NEWT_SECRET}
|
||||
LOG_LEVEL: "DEBUG"
|
||||
ports:
|
||||
- "2112:2112"
|
||||
|
||||
# Optional Variant B: Enable the Collector and switch Prometheus scrape to it.
|
||||
# collector:
|
||||
# image: otel/opentelemetry-collector-contrib:0.136.0
|
||||
# command: ["--config=/etc/otelcol/config.yaml"]
|
||||
# volumes:
|
||||
# - ./examples/otel-collector.yaml:/etc/otelcol/config.yaml:ro
|
||||
# ports:
|
||||
# - "4317:4317" # OTLP gRPC in
|
||||
# - "8889:8889" # Prometheus scrape out
|
||||
|
||||
prometheus:
|
||||
image: prom/prometheus:v3.6.0
|
||||
volumes:
|
||||
- ./examples/prometheus.yml:/etc/prometheus/prometheus.yml:ro
|
||||
ports:
|
||||
- "9090:9090"
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana:12.2.0
|
||||
container_name: newt-metrics-grafana
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- GF_SECURITY_ADMIN_USER=admin
|
||||
- GF_SECURITY_ADMIN_PASSWORD=admin
|
||||
ports:
|
||||
- "3005:3000"
|
||||
depends_on:
|
||||
- prometheus
|
||||
volumes:
|
||||
- ./examples/grafana/provisioning/datasources:/etc/grafana/provisioning/datasources:ro
|
||||
- ./examples/grafana/provisioning/dashboards:/etc/grafana/provisioning/dashboards:ro
|
||||
- ./examples/grafana/dashboards:/var/lib/grafana/dashboards:ro
|
||||
167
examples/README.md
Normal file
167
examples/README.md
Normal file
@@ -0,0 +1,167 @@
|
||||
# Extensible Logger
|
||||
|
||||
This logger package provides a flexible logging system that can be extended with custom log writers.
|
||||
|
||||
## Basic Usage (Current Behavior)
|
||||
|
||||
The logger works exactly as before with no changes required:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import "your-project/logger"
|
||||
|
||||
func main() {
|
||||
// Use default logger
|
||||
logger.Info("This works as before")
|
||||
logger.Debug("Debug message")
|
||||
logger.Error("Error message")
|
||||
|
||||
// Or create a custom instance
|
||||
log := logger.NewLogger()
|
||||
log.SetLevel(logger.INFO)
|
||||
log.Info("Custom logger instance")
|
||||
}
|
||||
```
|
||||
|
||||
## Custom Log Writers
|
||||
|
||||
To use a custom log backend, implement the `LogWriter` interface:
|
||||
|
||||
```go
|
||||
type LogWriter interface {
|
||||
Write(level LogLevel, timestamp time.Time, message string)
|
||||
}
|
||||
```
|
||||
|
||||
### Example: OS Log Writer (macOS/iOS)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import "your-project/logger"
|
||||
|
||||
func main() {
|
||||
// Create an OS log writer
|
||||
osWriter := logger.NewOSLogWriter(
|
||||
"net.pangolin.Pangolin.PacketTunnel",
|
||||
"PangolinGo",
|
||||
"MyApp",
|
||||
)
|
||||
|
||||
// Create a logger with the OS log writer
|
||||
log := logger.NewLoggerWithWriter(osWriter)
|
||||
log.SetLevel(logger.DEBUG)
|
||||
|
||||
// Use it just like the standard logger
|
||||
log.Info("This message goes to os_log")
|
||||
log.Error("Error logged to os_log")
|
||||
}
|
||||
```
|
||||
|
||||
### Example: Custom Writer
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"your-project/logger"
|
||||
)
|
||||
|
||||
// CustomWriter writes logs to a custom destination
|
||||
type CustomWriter struct {
|
||||
// your custom fields
|
||||
}
|
||||
|
||||
func (w *CustomWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
|
||||
// Your custom logging logic
|
||||
fmt.Printf("[CUSTOM] %s [%s] %s\n", timestamp.Format(time.RFC3339), level.String(), message)
|
||||
}
|
||||
|
||||
func main() {
|
||||
customWriter := &CustomWriter{}
|
||||
log := logger.NewLoggerWithWriter(customWriter)
|
||||
log.Info("Custom logging!")
|
||||
}
|
||||
```
|
||||
|
||||
### Example: Multi-Writer (Log to Multiple Destinations)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
"your-project/logger"
|
||||
)
|
||||
|
||||
// MultiWriter writes to multiple log writers
|
||||
type MultiWriter struct {
|
||||
writers []logger.LogWriter
|
||||
}
|
||||
|
||||
func NewMultiWriter(writers ...logger.LogWriter) *MultiWriter {
|
||||
return &MultiWriter{writers: writers}
|
||||
}
|
||||
|
||||
func (w *MultiWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
|
||||
for _, writer := range w.writers {
|
||||
writer.Write(level, timestamp, message)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Log to both standard output and OS log
|
||||
standardWriter := logger.NewStandardWriter()
|
||||
osWriter := logger.NewOSLogWriter("com.example.app", "Main", "App")
|
||||
|
||||
multiWriter := NewMultiWriter(standardWriter, osWriter)
|
||||
log := logger.NewLoggerWithWriter(multiWriter)
|
||||
|
||||
log.Info("This goes to both stdout and os_log!")
|
||||
}
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Creating Loggers
|
||||
|
||||
- `NewLogger()` - Creates a logger with the default StandardWriter
|
||||
- `NewLoggerWithWriter(writer LogWriter)` - Creates a logger with a custom writer
|
||||
|
||||
### Built-in Writers
|
||||
|
||||
- `NewStandardWriter()` - Standard writer that outputs to stdout (default)
|
||||
- `NewOSLogWriter(subsystem, category, prefix string)` - OS log writer for macOS/iOS (example)
|
||||
|
||||
### Logger Methods
|
||||
|
||||
- `SetLevel(level LogLevel)` - Set minimum log level
|
||||
- `SetOutput(output *os.File)` - Set output file (StandardWriter only)
|
||||
- `Debug(format string, args ...interface{})` - Log debug message
|
||||
- `Info(format string, args ...interface{})` - Log info message
|
||||
- `Warn(format string, args ...interface{})` - Log warning message
|
||||
- `Error(format string, args ...interface{})` - Log error message
|
||||
- `Fatal(format string, args ...interface{})` - Log fatal message and exit
|
||||
|
||||
### Global Functions
|
||||
|
||||
For convenience, you can use global functions that use the default logger:
|
||||
|
||||
- `logger.Debug(format, args...)`
|
||||
- `logger.Info(format, args...)`
|
||||
- `logger.Warn(format, args...)`
|
||||
- `logger.Error(format, args...)`
|
||||
- `logger.Fatal(format, args...)`
|
||||
- `logger.SetOutput(output *os.File)`
|
||||
|
||||
## Migration Guide
|
||||
|
||||
No changes needed! The logger maintains 100% backward compatibility. Your existing code will continue to work without modifications.
|
||||
|
||||
If you want to switch to a custom writer:
|
||||
1. Create your writer implementing `LogWriter`
|
||||
2. Use `NewLoggerWithWriter()` instead of `NewLogger()`
|
||||
3. That's it!
|
||||
898
examples/grafana/dashboards/newt-overview.json
Normal file
898
examples/grafana/dashboards/newt-overview.json
Normal file
@@ -0,0 +1,898 @@
|
||||
{
|
||||
"annotations": {
|
||||
"list": [
|
||||
{
|
||||
"builtIn": 1,
|
||||
"datasource": {
|
||||
"type": "grafana",
|
||||
"uid": "-- Grafana --"
|
||||
},
|
||||
"enable": true,
|
||||
"hide": true,
|
||||
"iconColor": "rgba(0, 211, 255, 1)",
|
||||
"name": "Annotations & Alerts",
|
||||
"type": "dashboard"
|
||||
}
|
||||
]
|
||||
},
|
||||
"editable": true,
|
||||
"fiscalYearStartMonth": 0,
|
||||
"graphTooltip": 0,
|
||||
"id": null,
|
||||
"links": [],
|
||||
"liveNow": false,
|
||||
"panels": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "thresholds"
|
||||
},
|
||||
"decimals": 0,
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 500
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 7,
|
||||
"w": 6,
|
||||
"x": 0,
|
||||
"y": 0
|
||||
},
|
||||
"id": 1,
|
||||
"options": {
|
||||
"colorMode": "value",
|
||||
"graphMode": "area",
|
||||
"justifyMode": "auto",
|
||||
"orientation": "horizontal",
|
||||
"reduceOptions": {
|
||||
"calcs": [
|
||||
"lastNotNull"
|
||||
],
|
||||
"fields": "",
|
||||
"values": false
|
||||
},
|
||||
"textMode": "value_and_name"
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "go_goroutine_count",
|
||||
"instant": true,
|
||||
"legendFormat": "",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Goroutines",
|
||||
"transformations": [],
|
||||
"type": "stat"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "thresholds"
|
||||
},
|
||||
"decimals": 1,
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"color": "orange",
|
||||
"value": 256
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 512
|
||||
}
|
||||
]
|
||||
},
|
||||
"unit": "bytes"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 7,
|
||||
"w": 6,
|
||||
"x": 6,
|
||||
"y": 0
|
||||
},
|
||||
"id": 2,
|
||||
"options": {
|
||||
"colorMode": "background",
|
||||
"graphMode": "area",
|
||||
"justifyMode": "auto",
|
||||
"orientation": "horizontal",
|
||||
"reduceOptions": {
|
||||
"calcs": [
|
||||
"lastNotNull"
|
||||
],
|
||||
"fields": "",
|
||||
"values": false
|
||||
},
|
||||
"textMode": "value_and_name"
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "go_memory_gc_goal_bytes / 1024 / 1024",
|
||||
"format": "time_series",
|
||||
"instant": true,
|
||||
"legendFormat": "",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "GC Target Heap (MiB)",
|
||||
"transformations": [],
|
||||
"type": "stat"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "thresholds"
|
||||
},
|
||||
"decimals": 2,
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"color": "orange",
|
||||
"value": 10
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 25
|
||||
}
|
||||
]
|
||||
},
|
||||
"unit": "ops"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 7,
|
||||
"w": 6,
|
||||
"x": 12,
|
||||
"y": 0
|
||||
},
|
||||
"id": 3,
|
||||
"options": {
|
||||
"colorMode": "value",
|
||||
"graphMode": "area",
|
||||
"justifyMode": "auto",
|
||||
"orientation": "horizontal",
|
||||
"reduceOptions": {
|
||||
"calcs": [
|
||||
"lastNotNull"
|
||||
],
|
||||
"fields": "",
|
||||
"values": false
|
||||
},
|
||||
"textMode": "value_and_name"
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "sum(rate(http_server_request_duration_seconds_count[$__rate_interval]))",
|
||||
"instant": false,
|
||||
"legendFormat": "req/s",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "HTTP Requests / s",
|
||||
"transformations": [],
|
||||
"type": "stat"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "thresholds"
|
||||
},
|
||||
"decimals": 3,
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"color": "orange",
|
||||
"value": 0.1
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 0.5
|
||||
}
|
||||
]
|
||||
},
|
||||
"unit": "ops"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 7,
|
||||
"w": 6,
|
||||
"x": 18,
|
||||
"y": 0
|
||||
},
|
||||
"id": 4,
|
||||
"options": {
|
||||
"colorMode": "value",
|
||||
"graphMode": "area",
|
||||
"justifyMode": "auto",
|
||||
"orientation": "horizontal",
|
||||
"reduceOptions": {
|
||||
"calcs": [
|
||||
"lastNotNull"
|
||||
],
|
||||
"fields": "",
|
||||
"values": false
|
||||
},
|
||||
"textMode": "value_and_name"
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "sum(rate(newt_connection_errors_total{site_id=~\"$site_id\"}[$__rate_interval]))",
|
||||
"instant": false,
|
||||
"legendFormat": "errors/s",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Connection Errors / s",
|
||||
"transformations": [],
|
||||
"type": "stat"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"custom": {},
|
||||
"mappings": [],
|
||||
"unit": "bytes"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 9,
|
||||
"w": 12,
|
||||
"x": 0,
|
||||
"y": 7
|
||||
},
|
||||
"id": 5,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "table",
|
||||
"placement": "right"
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "multi",
|
||||
"sort": "desc"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "sum(go_memory_used_bytes)",
|
||||
"legendFormat": "Used",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "go_memory_gc_goal_bytes",
|
||||
"legendFormat": "GC Goal",
|
||||
"refId": "B"
|
||||
}
|
||||
],
|
||||
"title": "Go Heap Usage vs GC Goal",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"custom": {},
|
||||
"decimals": 0,
|
||||
"mappings": [],
|
||||
"unit": "short"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 9,
|
||||
"w": 12,
|
||||
"x": 12,
|
||||
"y": 7
|
||||
},
|
||||
"id": 6,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "table",
|
||||
"placement": "right"
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "multi",
|
||||
"sort": "desc"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "rate(go_memory_allocations_total[$__rate_interval])",
|
||||
"legendFormat": "Allocations/s",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "rate(go_memory_allocated_bytes_total[$__rate_interval])",
|
||||
"legendFormat": "Allocated bytes/s",
|
||||
"refId": "B"
|
||||
}
|
||||
],
|
||||
"title": "Allocation Activity",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"custom": {},
|
||||
"mappings": [],
|
||||
"unit": "s"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 9,
|
||||
"w": 12,
|
||||
"x": 0,
|
||||
"y": 16
|
||||
},
|
||||
"id": 7,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "table",
|
||||
"placement": "right"
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "multi",
|
||||
"sort": "desc"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.5, sum(rate(http_server_request_duration_seconds_bucket[$__rate_interval])) by (le))",
|
||||
"legendFormat": "p50",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.95, sum(rate(http_server_request_duration_seconds_bucket[$__rate_interval])) by (le))",
|
||||
"legendFormat": "p95",
|
||||
"refId": "B"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.99, sum(rate(http_server_request_duration_seconds_bucket[$__rate_interval])) by (le))",
|
||||
"legendFormat": "p99",
|
||||
"refId": "C"
|
||||
}
|
||||
],
|
||||
"title": "HTTP Request Duration Quantiles",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"custom": {},
|
||||
"mappings": [],
|
||||
"unit": "ops"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 9,
|
||||
"w": 12,
|
||||
"x": 12,
|
||||
"y": 16
|
||||
},
|
||||
"id": 8,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "table",
|
||||
"placement": "right"
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "multi",
|
||||
"sort": "desc"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "sum(rate(http_server_request_duration_seconds_count[$__rate_interval])) by (http_response_status_code)",
|
||||
"legendFormat": "{{http_response_status_code}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "HTTP Requests by Status",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"custom": {},
|
||||
"mappings": [],
|
||||
"unit": "ops"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 9,
|
||||
"w": 12,
|
||||
"x": 0,
|
||||
"y": 25
|
||||
},
|
||||
"id": 9,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "table",
|
||||
"placement": "right"
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "multi",
|
||||
"sort": "desc"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "sum(rate(newt_connection_attempts_total{site_id=~\"$site_id\"}[$__rate_interval])) by (transport, result)",
|
||||
"legendFormat": "{{transport}} • {{result}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Connection Attempts by Transport/Result",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"custom": {},
|
||||
"mappings": [],
|
||||
"unit": "ops"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 9,
|
||||
"w": 12,
|
||||
"x": 12,
|
||||
"y": 25
|
||||
},
|
||||
"id": 10,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "table",
|
||||
"placement": "right"
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "multi",
|
||||
"sort": "desc"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "sum(rate(newt_connection_errors_total{site_id=~\"$site_id\"}[$__rate_interval])) by (transport, error_type)",
|
||||
"legendFormat": "{{transport}} • {{error_type}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Connection Errors by Type",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"custom": {},
|
||||
"decimals": 3,
|
||||
"mappings": [],
|
||||
"unit": "s"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 9,
|
||||
"w": 12,
|
||||
"x": 0,
|
||||
"y": 34
|
||||
},
|
||||
"id": 11,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "table",
|
||||
"placement": "right"
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "multi",
|
||||
"sort": "desc"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.5, sum(rate(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\", tunnel_id=~\"$tunnel_id\"}[$__rate_interval])) by (le))",
|
||||
"legendFormat": "p50",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.95, sum(rate(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\", tunnel_id=~\"$tunnel_id\"}[$__rate_interval])) by (le))",
|
||||
"legendFormat": "p95",
|
||||
"refId": "B"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.99, sum(rate(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\", tunnel_id=~\"$tunnel_id\"}[$__rate_interval])) by (le))",
|
||||
"legendFormat": "p99",
|
||||
"refId": "C"
|
||||
}
|
||||
],
|
||||
"title": "Tunnel Latency Quantiles",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"cards": {},
|
||||
"color": {
|
||||
"cardColor": "#b4ff00",
|
||||
"colorScale": "sqrt",
|
||||
"colorScheme": "interpolateTurbo"
|
||||
},
|
||||
"dataFormat": "tsbuckets",
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"custom": {},
|
||||
"mappings": [],
|
||||
"unit": "s"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 9,
|
||||
"w": 12,
|
||||
"x": 12,
|
||||
"y": 34
|
||||
},
|
||||
"heatmap": {},
|
||||
"hideZeroBuckets": true,
|
||||
"id": 12,
|
||||
"legend": {
|
||||
"show": false
|
||||
},
|
||||
"options": {
|
||||
"calculate": true,
|
||||
"cellGap": 2,
|
||||
"cellSize": "auto",
|
||||
"color": {
|
||||
"exponent": 0.5
|
||||
},
|
||||
"exemplars": {
|
||||
"color": "rgba(255,255,255,0.7)"
|
||||
},
|
||||
"filterValues": {
|
||||
"le": 1e-9
|
||||
},
|
||||
"legend": {
|
||||
"show": false
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "single",
|
||||
"show": true
|
||||
},
|
||||
"xAxis": {
|
||||
"show": true
|
||||
},
|
||||
"yAxis": {
|
||||
"decimals": 3,
|
||||
"show": true
|
||||
}
|
||||
},
|
||||
"pluginVersion": "11.1.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "sum(rate(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\", tunnel_id=~\"$tunnel_id\"}[$__rate_interval])) by (le)",
|
||||
"format": "heatmap",
|
||||
"legendFormat": "{{le}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Tunnel Latency Bucket Rate",
|
||||
"type": "heatmap"
|
||||
}
|
||||
],
|
||||
"refresh": "30s",
|
||||
"schemaVersion": 39,
|
||||
"style": "dark",
|
||||
"tags": [
|
||||
"newt",
|
||||
"otel"
|
||||
],
|
||||
"templating": {
|
||||
"list": [
|
||||
{
|
||||
"current": {
|
||||
"selected": false,
|
||||
"text": "Prometheus",
|
||||
"value": "prometheus"
|
||||
},
|
||||
"hide": 0,
|
||||
"label": "Datasource",
|
||||
"name": "DS_PROMETHEUS",
|
||||
"options": [],
|
||||
"query": "prometheus",
|
||||
"refresh": 1,
|
||||
"type": "datasource"
|
||||
},
|
||||
{
|
||||
"current": {
|
||||
"selected": false,
|
||||
"text": "All",
|
||||
"value": "$__all"
|
||||
},
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"definition": "label_values(target_info, site_id)",
|
||||
"hide": 0,
|
||||
"includeAll": true,
|
||||
"label": "Site",
|
||||
"multi": true,
|
||||
"name": "site_id",
|
||||
"options": [],
|
||||
"query": {
|
||||
"query": "label_values(target_info, site_id)",
|
||||
"refId": "SiteIdVar"
|
||||
},
|
||||
"refresh": 2,
|
||||
"regex": "",
|
||||
"skipUrlSync": false,
|
||||
"sort": 1,
|
||||
"tagValuesQuery": "",
|
||||
"tags": [],
|
||||
"tagsQuery": "",
|
||||
"type": "query",
|
||||
"useTags": false
|
||||
},
|
||||
{
|
||||
"current": {
|
||||
"selected": false,
|
||||
"text": "All",
|
||||
"value": "$__all"
|
||||
},
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"definition": "label_values(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\"}, tunnel_id)",
|
||||
"hide": 0,
|
||||
"includeAll": true,
|
||||
"label": "Tunnel",
|
||||
"multi": true,
|
||||
"name": "tunnel_id",
|
||||
"options": [],
|
||||
"query": {
|
||||
"query": "label_values(newt_tunnel_latency_seconds_bucket{site_id=~\"$site_id\"}, tunnel_id)",
|
||||
"refId": "TunnelVar"
|
||||
},
|
||||
"refresh": 2,
|
||||
"regex": "",
|
||||
"skipUrlSync": false,
|
||||
"sort": 1,
|
||||
"tagValuesQuery": "",
|
||||
"tags": [],
|
||||
"tagsQuery": "",
|
||||
"type": "query",
|
||||
"useTags": false
|
||||
}
|
||||
]
|
||||
},
|
||||
"time": {
|
||||
"from": "now-6h",
|
||||
"to": "now"
|
||||
},
|
||||
"timepicker": {
|
||||
"refresh_intervals": [
|
||||
"10s",
|
||||
"30s",
|
||||
"1m",
|
||||
"5m",
|
||||
"15m",
|
||||
"30m",
|
||||
"1h",
|
||||
"2h",
|
||||
"1d"
|
||||
],
|
||||
"time_options": [
|
||||
"5m",
|
||||
"15m",
|
||||
"1h",
|
||||
"6h",
|
||||
"12h",
|
||||
"24h",
|
||||
"2d",
|
||||
"7d",
|
||||
"30d"
|
||||
]
|
||||
},
|
||||
"timezone": "browser",
|
||||
"title": "Newt Overview",
|
||||
"uid": "newt-overview",
|
||||
"version": 1,
|
||||
"weekStart": ""
|
||||
}
|
||||
9
examples/grafana/provisioning/dashboards/dashboard.yaml
Normal file
9
examples/grafana/provisioning/dashboards/dashboard.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
apiVersion: 1
|
||||
providers:
|
||||
- name: "newt"
|
||||
folder: "Newt"
|
||||
type: file
|
||||
disableDeletion: false
|
||||
allowUiUpdates: true
|
||||
options:
|
||||
path: /var/lib/grafana/dashboards
|
||||
@@ -0,0 +1,9 @@
|
||||
apiVersion: 1
|
||||
datasources:
|
||||
- name: Prometheus
|
||||
type: prometheus
|
||||
access: proxy
|
||||
url: http://prometheus:9090
|
||||
uid: prometheus
|
||||
isDefault: true
|
||||
editable: true
|
||||
161
examples/logger_examples.go
Normal file
161
examples/logger_examples.go
Normal file
@@ -0,0 +1,161 @@
|
||||
// Example usage patterns for the extensible logger
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// Example 1: Using the default logger (works exactly as before)
|
||||
func exampleDefaultLogger() {
|
||||
logger.Info("Starting application")
|
||||
logger.Debug("Debug information")
|
||||
logger.Warn("Warning message")
|
||||
logger.Error("Error occurred")
|
||||
}
|
||||
|
||||
// Example 2: Using a custom logger instance with standard writer
|
||||
func exampleCustomInstance() {
|
||||
log := logger.NewLogger()
|
||||
log.SetLevel(logger.INFO)
|
||||
log.Info("This is from a custom instance")
|
||||
}
|
||||
|
||||
// Example 3: Custom writer that adds JSON formatting
|
||||
type JSONWriter struct{}
|
||||
|
||||
func (w *JSONWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
|
||||
fmt.Printf("{\"time\":\"%s\",\"level\":\"%s\",\"message\":\"%s\"}\n",
|
||||
timestamp.Format(time.RFC3339),
|
||||
level.String(),
|
||||
message)
|
||||
}
|
||||
|
||||
func exampleJSONLogger() {
|
||||
jsonWriter := &JSONWriter{}
|
||||
log := logger.NewLoggerWithWriter(jsonWriter)
|
||||
log.Info("This will be logged as JSON")
|
||||
}
|
||||
|
||||
// Example 4: File writer
|
||||
type FileWriter struct {
|
||||
file *os.File
|
||||
}
|
||||
|
||||
func NewFileWriter(filename string) (*FileWriter, error) {
|
||||
file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &FileWriter{file: file}, nil
|
||||
}
|
||||
|
||||
func (w *FileWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
|
||||
fmt.Fprintf(w.file, "[%s] %s: %s\n",
|
||||
timestamp.Format("2006-01-02 15:04:05"),
|
||||
level.String(),
|
||||
message)
|
||||
}
|
||||
|
||||
func (w *FileWriter) Close() error {
|
||||
return w.file.Close()
|
||||
}
|
||||
|
||||
func exampleFileLogger() {
|
||||
fileWriter, err := NewFileWriter("/tmp/app.log")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer fileWriter.Close()
|
||||
|
||||
log := logger.NewLoggerWithWriter(fileWriter)
|
||||
log.Info("This goes to a file")
|
||||
}
|
||||
|
||||
// Example 5: Multi-writer to log to multiple destinations
|
||||
type MultiWriter struct {
|
||||
writers []logger.LogWriter
|
||||
}
|
||||
|
||||
func NewMultiWriter(writers ...logger.LogWriter) *MultiWriter {
|
||||
return &MultiWriter{writers: writers}
|
||||
}
|
||||
|
||||
func (w *MultiWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
|
||||
for _, writer := range w.writers {
|
||||
writer.Write(level, timestamp, message)
|
||||
}
|
||||
}
|
||||
|
||||
func exampleMultiWriter() {
|
||||
// Log to both stdout and a file
|
||||
standardWriter := logger.NewStandardWriter()
|
||||
fileWriter, _ := NewFileWriter("/tmp/app.log")
|
||||
|
||||
multiWriter := NewMultiWriter(standardWriter, fileWriter)
|
||||
log := logger.NewLoggerWithWriter(multiWriter)
|
||||
|
||||
log.Info("This goes to both stdout and file!")
|
||||
}
|
||||
|
||||
// Example 6: Conditional writer (only log errors to a specific destination)
|
||||
type ErrorOnlyWriter struct {
|
||||
writer logger.LogWriter
|
||||
}
|
||||
|
||||
func NewErrorOnlyWriter(writer logger.LogWriter) *ErrorOnlyWriter {
|
||||
return &ErrorOnlyWriter{writer: writer}
|
||||
}
|
||||
|
||||
func (w *ErrorOnlyWriter) Write(level logger.LogLevel, timestamp time.Time, message string) {
|
||||
if level >= logger.ERROR {
|
||||
w.writer.Write(level, timestamp, message)
|
||||
}
|
||||
}
|
||||
|
||||
func exampleConditionalWriter() {
|
||||
errorWriter, _ := NewFileWriter("/tmp/errors.log")
|
||||
errorOnlyWriter := NewErrorOnlyWriter(errorWriter)
|
||||
|
||||
log := logger.NewLoggerWithWriter(errorOnlyWriter)
|
||||
log.Info("This won't be logged")
|
||||
log.Error("This will be logged to errors.log")
|
||||
}
|
||||
|
||||
/* Example 7: OS Log Writer (macOS/iOS only)
|
||||
// Uncomment on Darwin platforms
|
||||
|
||||
func exampleOSLogWriter() {
|
||||
osWriter := logger.NewOSLogWriter(
|
||||
"net.pangolin.Pangolin.PacketTunnel",
|
||||
"PangolinGo",
|
||||
"MyApp",
|
||||
)
|
||||
|
||||
log := logger.NewLoggerWithWriter(osWriter)
|
||||
log.Info("This goes to os_log and can be viewed with Console.app")
|
||||
}
|
||||
*/
|
||||
|
||||
func main() {
|
||||
fmt.Println("=== Example 1: Default Logger ===")
|
||||
exampleDefaultLogger()
|
||||
|
||||
fmt.Println("\n=== Example 2: Custom Instance ===")
|
||||
exampleCustomInstance()
|
||||
|
||||
fmt.Println("\n=== Example 3: JSON Logger ===")
|
||||
exampleJSONLogger()
|
||||
|
||||
fmt.Println("\n=== Example 4: File Logger ===")
|
||||
exampleFileLogger()
|
||||
|
||||
fmt.Println("\n=== Example 5: Multi-Writer ===")
|
||||
exampleMultiWriter()
|
||||
|
||||
fmt.Println("\n=== Example 6: Conditional Writer ===")
|
||||
exampleConditionalWriter()
|
||||
}
|
||||
86
examples/oslog_writer_example.go
Normal file
86
examples/oslog_writer_example.go
Normal file
@@ -0,0 +1,86 @@
|
||||
//go:build darwin
|
||||
// +build darwin
|
||||
|
||||
package main
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I../PacketTunnel
|
||||
#include "../PacketTunnel/OSLogBridge.h"
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// OSLogWriter is a LogWriter implementation that writes to Apple's os_log
|
||||
type OSLogWriter struct {
|
||||
subsystem string
|
||||
category string
|
||||
prefix string
|
||||
}
|
||||
|
||||
// NewOSLogWriter creates a new OSLogWriter
|
||||
func NewOSLogWriter(subsystem, category, prefix string) *OSLogWriter {
|
||||
writer := &OSLogWriter{
|
||||
subsystem: subsystem,
|
||||
category: category,
|
||||
prefix: prefix,
|
||||
}
|
||||
|
||||
// Initialize the OS log bridge
|
||||
cSubsystem := C.CString(subsystem)
|
||||
cCategory := C.CString(category)
|
||||
defer C.free(unsafe.Pointer(cSubsystem))
|
||||
defer C.free(unsafe.Pointer(cCategory))
|
||||
|
||||
C.initOSLogBridge(cSubsystem, cCategory)
|
||||
|
||||
return writer
|
||||
}
|
||||
|
||||
// Write implements the LogWriter interface
|
||||
func (w *OSLogWriter) Write(level LogLevel, timestamp time.Time, message string) {
|
||||
// Get caller information (skip 3 frames to get to the actual caller)
|
||||
_, file, line, ok := runtime.Caller(3)
|
||||
if !ok {
|
||||
file = "unknown"
|
||||
line = 0
|
||||
} else {
|
||||
// Get just the filename, not the full path
|
||||
for i := len(file) - 1; i > 0; i-- {
|
||||
if file[i] == '/' {
|
||||
file = file[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
formattedTime := timestamp.Format("2006-01-02 15:04:05.000")
|
||||
fullMessage := fmt.Sprintf("[%s] [%s] [%s] %s:%d - %s",
|
||||
formattedTime, level.String(), w.prefix, file, line, message)
|
||||
|
||||
cMessage := C.CString(fullMessage)
|
||||
defer C.free(unsafe.Pointer(cMessage))
|
||||
|
||||
// Map Go log levels to os_log levels:
|
||||
// 0=DEBUG, 1=INFO, 2=DEFAULT (WARN), 3=ERROR
|
||||
var osLogLevel C.int
|
||||
switch level {
|
||||
case DEBUG:
|
||||
osLogLevel = 0 // DEBUG
|
||||
case INFO:
|
||||
osLogLevel = 1 // INFO
|
||||
case WARN:
|
||||
osLogLevel = 2 // DEFAULT
|
||||
case ERROR, FATAL:
|
||||
osLogLevel = 3 // ERROR
|
||||
default:
|
||||
osLogLevel = 2 // DEFAULT
|
||||
}
|
||||
|
||||
C.logToOSLog(osLogLevel, cMessage)
|
||||
}
|
||||
61
examples/otel-collector.yaml
Normal file
61
examples/otel-collector.yaml
Normal file
@@ -0,0 +1,61 @@
|
||||
# Variant A: Direct scrape of Newt (/metrics) via Prometheus (no Collector needed)
|
||||
# Note: Newt already exposes labels like site_id, protocol, direction. Do not promote
|
||||
# resource attributes into labels when scraping Newt directly.
|
||||
#
|
||||
# Example Prometheus scrape config:
|
||||
# global:
|
||||
# scrape_interval: 15s
|
||||
# scrape_configs:
|
||||
# - job_name: newt
|
||||
# static_configs:
|
||||
# - targets: ["newt:2112"]
|
||||
#
|
||||
# Variant B: Use OTEL Collector (Newt -> OTLP -> Collector -> Prometheus)
|
||||
# This pipeline scrapes metrics from the Collector's Prometheus exporter.
|
||||
# Labels are already on datapoints; promotion from resource is OPTIONAL and typically NOT required.
|
||||
# If you enable transform/promote below, ensure you do not duplicate labels.
|
||||
|
||||
receivers:
|
||||
otlp:
|
||||
protocols:
|
||||
grpc:
|
||||
endpoint: ":4317"
|
||||
|
||||
processors:
|
||||
memory_limiter:
|
||||
check_interval: 5s
|
||||
limit_percentage: 80
|
||||
spike_limit_percentage: 25
|
||||
resourcedetection:
|
||||
detectors: [env, system]
|
||||
timeout: 5s
|
||||
batch: {}
|
||||
# OPTIONAL: Only enable if you need to promote resource attributes to labels.
|
||||
# WARNING: Newt already provides site_id as a label; avoid double-promotion.
|
||||
# transform/promote:
|
||||
# error_mode: ignore
|
||||
# metric_statements:
|
||||
# - context: datapoint
|
||||
# statements:
|
||||
# - set(attributes["service_instance_id"], resource.attributes["service.instance.id"]) where resource.attributes["service.instance.id"] != nil
|
||||
# - set(attributes["site_id"], resource.attributes["site_id"]) where resource.attributes["site_id"] != nil
|
||||
|
||||
exporters:
|
||||
prometheus:
|
||||
endpoint: ":8889"
|
||||
send_timestamps: true
|
||||
# prometheusremotewrite:
|
||||
# endpoint: http://mimir:9009/api/v1/push
|
||||
debug:
|
||||
verbosity: basic
|
||||
|
||||
service:
|
||||
pipelines:
|
||||
metrics:
|
||||
receivers: [otlp]
|
||||
processors: [memory_limiter, resourcedetection, batch] # add transform/promote if you really need it
|
||||
exporters: [prometheus]
|
||||
traces:
|
||||
receivers: [otlp]
|
||||
processors: [memory_limiter, resourcedetection, batch]
|
||||
exporters: [debug]
|
||||
16
examples/prometheus.with-collector.yml
Normal file
16
examples/prometheus.with-collector.yml
Normal file
@@ -0,0 +1,16 @@
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
|
||||
scrape_configs:
|
||||
# IMPORTANT: Do not scrape Newt directly; scrape only the Collector!
|
||||
- job_name: 'otel-collector'
|
||||
static_configs:
|
||||
- targets: ['otel-collector:8889']
|
||||
|
||||
# optional: limit metric cardinality
|
||||
relabel_configs:
|
||||
- action: labeldrop
|
||||
regex: 'tunnel_id'
|
||||
# - action: keep
|
||||
# source_labels: [site_id]
|
||||
# regex: '(site-a|site-b)'
|
||||
21
examples/prometheus.yml
Normal file
21
examples/prometheus.yml
Normal file
@@ -0,0 +1,21 @@
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
|
||||
scrape_configs:
|
||||
- job_name: 'newt'
|
||||
scrape_interval: 15s
|
||||
static_configs:
|
||||
- targets: ['newt:2112'] # /metrics
|
||||
relabel_configs:
|
||||
# optional: drop tunnel_id
|
||||
- action: labeldrop
|
||||
regex: 'tunnel_id'
|
||||
# optional: allow only specific sites
|
||||
- action: keep
|
||||
source_labels: [site_id]
|
||||
regex: '(site-a|site-b)'
|
||||
|
||||
# WARNING: Do not enable this together with the 'newt' job above or you will double-count.
|
||||
# - job_name: 'otel-collector'
|
||||
# static_configs:
|
||||
# - targets: ['otel-collector:8889']
|
||||
8
flake.lock
generated
8
flake.lock
generated
@@ -2,16 +2,16 @@
|
||||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1756217674,
|
||||
"narHash": "sha256-TH1SfSP523QI7kcPiNtMAEuwZR3Jdz0MCDXPs7TS8uo=",
|
||||
"lastModified": 1763934636,
|
||||
"narHash": "sha256-9glbI7f1uU+yzQCq5LwLgdZqx6svOhZWkd4JRY265fc=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "4e7667a90c167f7a81d906e5a75cba4ad8bee620",
|
||||
"rev": "ee09932cedcef15aaf476f9343d1dea2cb77e261",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-25.05",
|
||||
"ref": "nixpkgs-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
|
||||
63
flake.nix
63
flake.nix
@@ -2,7 +2,7 @@
|
||||
description = "newt - A tunneling client for Pangolin";
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.05";
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
|
||||
};
|
||||
|
||||
outputs =
|
||||
@@ -22,30 +22,49 @@
|
||||
system:
|
||||
let
|
||||
pkgs = pkgsFor system;
|
||||
inherit (pkgs) lib;
|
||||
|
||||
# Update version when releasing
|
||||
version = "1.4.2";
|
||||
|
||||
# Update the version in a new source tree
|
||||
srcWithReplacedVersion = pkgs.runCommand "newt-src-with-version" { } ''
|
||||
cp -r ${./.} $out
|
||||
chmod -R +w $out
|
||||
rm -rf $out/.git $out/result $out/.envrc $out/.direnv
|
||||
sed -i "s/version_replaceme/${version}/g" $out/main.go
|
||||
'';
|
||||
version = "1.8.0";
|
||||
in
|
||||
{
|
||||
default = self.packages.${system}.pangolin-newt;
|
||||
|
||||
pangolin-newt = pkgs.buildGoModule {
|
||||
pname = "pangolin-newt";
|
||||
version = version;
|
||||
src = srcWithReplacedVersion;
|
||||
vendorHash = "sha256-PENsCO2yFxLVZNPgx2OP+gWVNfjJAfXkwWS7tzlm490=";
|
||||
meta = with pkgs.lib; {
|
||||
inherit version;
|
||||
src = pkgs.nix-gitignore.gitignoreSource [ ] ./.;
|
||||
|
||||
vendorHash = "sha256-Sib6AUCpMgxlMpTc2Esvs+UU0yduVOxWUgT44FHAI+k=";
|
||||
|
||||
nativeInstallCheckInputs = [ pkgs.versionCheckHook ];
|
||||
|
||||
env = {
|
||||
CGO_ENABLED = 0;
|
||||
};
|
||||
|
||||
ldflags = [
|
||||
"-s"
|
||||
"-w"
|
||||
"-X main.newtVersion=${version}"
|
||||
];
|
||||
|
||||
# Tests are broken due to a lack of Internet.
|
||||
# Disable running `go test`, and instead do
|
||||
# a simple version check instead.
|
||||
doCheck = false;
|
||||
doInstallCheck = true;
|
||||
|
||||
versionCheckProgramArg = [ "-version" ];
|
||||
|
||||
meta = {
|
||||
description = "A tunneling client for Pangolin";
|
||||
homepage = "https://github.com/fosrl/newt";
|
||||
license = licenses.gpl3;
|
||||
maintainers = [ ];
|
||||
license = lib.licenses.gpl3;
|
||||
maintainers = [
|
||||
lib.maintainers.water-sucks
|
||||
];
|
||||
mainProgram = "newt";
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -54,10 +73,20 @@
|
||||
system:
|
||||
let
|
||||
pkgs = pkgsFor system;
|
||||
|
||||
inherit (pkgs)
|
||||
go
|
||||
gopls
|
||||
gotools
|
||||
go-outline
|
||||
gopkgs
|
||||
godef
|
||||
golint
|
||||
;
|
||||
in
|
||||
{
|
||||
default = pkgs.mkShell {
|
||||
buildInputs = with pkgs; [
|
||||
buildInputs = [
|
||||
go
|
||||
gopls
|
||||
gotools
|
||||
|
||||
70
go.mod
70
go.mod
@@ -3,52 +3,74 @@ module github.com/fosrl/newt
|
||||
go 1.25
|
||||
|
||||
require (
|
||||
github.com/docker/docker v28.5.0+incompatible
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/docker/docker v28.5.2+incompatible
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.42.0
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
|
||||
golang.org/x/net v0.45.0
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0
|
||||
go.opentelemetry.io/contrib/instrumentation/runtime v0.64.0
|
||||
go.opentelemetry.io/otel v1.39.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.39.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.61.0
|
||||
go.opentelemetry.io/otel/metric v1.39.0
|
||||
go.opentelemetry.io/otel/sdk v1.39.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.39.0
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6
|
||||
golang.org/x/net v0.48.0
|
||||
golang.org/x/sys v0.39.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/grpc v1.77.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0
|
||||
software.sslmate.com/src/go-pkcs12 v0.7.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/errdefs v0.3.0 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.4.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/sys/atomicwriter v0.1.0 // indirect
|
||||
github.com/moby/term v0.5.2 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.67.4 // indirect
|
||||
github.com/prometheus/otlptranslator v1.0.0 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect
|
||||
go.opentelemetry.io/otel v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.36.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/sys v0.36.0 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.39.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/mod v0.30.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
golang.org/x/tools v0.39.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
)
|
||||
|
||||
192
go.sum
192
go.sum
@@ -1,12 +1,15 @@
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
|
||||
github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8=
|
||||
github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||
github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2yDvg=
|
||||
github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/containerd/errdefs v0.3.0 h1:FSZgGOeK4yuT/+DnF07/Olde/q4KBoMsaamhXxIMDp4=
|
||||
github.com/containerd/errdefs v0.3.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
@@ -15,12 +18,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v28.5.0+incompatible h1:ZdSQoRUE9XxhFI/B8YLvhnEFMmYN9Pp8Egd2qcaFk1E=
|
||||
github.com/docker/docker v28.5.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM=
|
||||
github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||
github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=
|
||||
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
@@ -28,32 +31,26 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5ukBEgSGXEN89zeH1Jo=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
|
||||
@@ -64,87 +61,106 @@ github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
|
||||
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
|
||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
|
||||
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc=
|
||||
github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI=
|
||||
github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEoIwkU+A6qos=
|
||||
github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM=
|
||||
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
|
||||
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 h1:Hf9xI/XLML9ElpiHVDNwvqI0hIFlzV8dgIr35kV1kRU=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0/go.mod h1:NfchwuyNoMcZ5MLHwPrODwUF1HWCXWrL31s8gSAdIKY=
|
||||
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
|
||||
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0 h1:dNzwXjZKpMpE2JhmO+9HsPl42NIXFIFSUSSs0fiqra0=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0/go.mod h1:90PoxvaEB5n6AOdZvi+yWJQoE95U8Dhhw2bSyRqnTD0=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.36.0 h1:nRVXXvf78e00EwY6Wp0YII8ww2JVWshZ20HfTlE11AM=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.36.0/go.mod h1:r49hO7CgrxY9Voaj3Xe8pANWtr0Oq916d0XAmOoCZAQ=
|
||||
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
|
||||
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
|
||||
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||
go.opentelemetry.io/proto/otlp v1.6.0 h1:jQjP+AQyTf+Fe7OKj/MfkDrmK4MNVtw2NpXsf9fefDI=
|
||||
go.opentelemetry.io/proto/otlp v1.6.0/go.mod h1:cicgGehlFuNdgZkcALOCh3VE6K/u2tAjzlRhDwmVpZc=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ=
|
||||
go.opentelemetry.io/contrib/instrumentation/runtime v0.64.0 h1:/+/+UjlXjFcdDlXxKL1PouzX8Z2Vl0OxolRKeBEgYDw=
|
||||
go.opentelemetry.io/contrib/instrumentation/runtime v0.64.0/go.mod h1:Ldm/PDuzY2DP7IypudopCR3OCOW42NJlN9+mNEroevo=
|
||||
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=
|
||||
go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.39.0 h1:cEf8jF6WbuGQWUVcqgyWtTR0kOOAWY1DYZ+UhvdmQPw=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.39.0/go.mod h1:k1lzV5n5U3HkGvTCJHraTAGJ7MqsgL1wrGwTj1Isfiw=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 h1:f0cb2XPmrqn4XMy9PNliTgRKJgS5WcL/u0/WRYGz4t0=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0/go.mod h1:vnakAaFckOMiMtOIhFI2MNH4FYrZzXCYxmb1LlhoGz8=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 h1:in9O8ESIOlwJAEGTkkf34DesGRAc/Pn8qJ7k3r/42LM=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0/go.mod h1:Rp0EXBm5tfnv0WL+ARyO/PHBEaEAT8UUHQ6AGJcSq6c=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 h1:aTL7F04bJHUlztTsNGJ2l+6he8c+y/b//eR0jjjemT4=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0/go.mod h1:kldtb7jDTeol0l3ewcmd8SDvx3EmIE7lyvqbasU3QC4=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.61.0 h1:cCyZS4dr67d30uDyh8etKM2QyDsQ4zC9ds3bdbrVoD0=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.61.0/go.mod h1:iivMuj3xpR2DkUrUya3TPS/Z9h3dz7h01GxU+fQBRNg=
|
||||
go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0=
|
||||
go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs=
|
||||
go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18=
|
||||
go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew=
|
||||
go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI=
|
||||
go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
|
||||
go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
|
||||
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||
google.golang.org/genproto v0.0.0-20230920204549-e6e6cdab5c13 h1:vlzZttNJGVqTsRFU9AmdnrcO1Znh8Ew9kCD//yjigk0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250519155744-55703ea1f237 h1:Kog3KlB4xevJlAcbbbzPfRG0+X9fdoGM+UBRKVz6Wr0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250519155744-55703ea1f237/go.mod h1:ezi0AVyMKDWy5xAncvjLWH7UcLBB5n7y2fQ8MzjJcto=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 h1:cJfm9zPbe1e873mHJzmQ1nwVEeRDU/T1wXDK2kUSU34=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
||||
google.golang.org/grpc v1.72.1 h1:HR03wO6eyZ7lknl75XlxABNVLLFc2PAb6mHlYh756mA=
|
||||
google.golang.org/grpc v1.72.1/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
|
||||
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
|
||||
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
@@ -154,5 +170,5 @@ gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o=
|
||||
gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||
software.sslmate.com/src/go-pkcs12 v0.7.0 h1:Db8W44cB54TWD7stUFFSWxdfpdn6fZVcDl0w3R4RVM0=
|
||||
software.sslmate.com/src/go-pkcs12 v0.7.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||
|
||||
@@ -48,6 +48,7 @@ type Config struct {
|
||||
Headers map[string]string `json:"hcHeaders"`
|
||||
Method string `json:"hcMethod"`
|
||||
Status int `json:"hcStatus"` // HTTP status code
|
||||
TLSServerName string `json:"hcTlsServerName"`
|
||||
}
|
||||
|
||||
// Target represents a health check target with its current status
|
||||
@@ -57,9 +58,10 @@ type Target struct {
|
||||
LastCheck time.Time `json:"lastCheck"`
|
||||
LastError string `json:"lastError,omitempty"`
|
||||
CheckCount int `json:"checkCount"`
|
||||
ticker *time.Ticker
|
||||
timer *time.Timer
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// StatusChangeCallback is called when any target's status changes
|
||||
@@ -70,7 +72,6 @@ type Monitor struct {
|
||||
targets map[int]*Target
|
||||
mutex sync.RWMutex
|
||||
callback StatusChangeCallback
|
||||
client *http.Client
|
||||
enforceCert bool
|
||||
}
|
||||
|
||||
@@ -78,21 +79,10 @@ type Monitor struct {
|
||||
func NewMonitor(callback StatusChangeCallback, enforceCert bool) *Monitor {
|
||||
logger.Debug("Creating new health check monitor with certificate enforcement: %t", enforceCert)
|
||||
|
||||
// Configure TLS settings based on certificate enforcement
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: !enforceCert,
|
||||
},
|
||||
}
|
||||
|
||||
return &Monitor{
|
||||
targets: make(map[int]*Target),
|
||||
callback: callback,
|
||||
enforceCert: enforceCert,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,6 +186,16 @@ func (m *Monitor) addTargetUnsafe(config Config) error {
|
||||
Status: StatusUnknown,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
// Configure TLS settings based on certificate enforcement
|
||||
InsecureSkipVerify: !m.enforceCert,
|
||||
// Use SNI TLS header if present
|
||||
ServerName: config.TLSServerName,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
m.targets[config.ID] = target
|
||||
@@ -315,26 +315,26 @@ func (m *Monitor) monitorTarget(target *Target) {
|
||||
go m.callback(m.GetTargets())
|
||||
}
|
||||
|
||||
// Set up ticker based on current status
|
||||
// Set up timer based on current status
|
||||
interval := time.Duration(target.Config.Interval) * time.Second
|
||||
if target.Status == StatusUnhealthy {
|
||||
interval = time.Duration(target.Config.UnhealthyInterval) * time.Second
|
||||
}
|
||||
|
||||
logger.Debug("Target %d: initial check interval set to %v", target.Config.ID, interval)
|
||||
target.ticker = time.NewTicker(interval)
|
||||
defer target.ticker.Stop()
|
||||
target.timer = time.NewTimer(interval)
|
||||
defer target.timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-target.ctx.Done():
|
||||
logger.Info("Stopping health check monitoring for target %d", target.Config.ID)
|
||||
return
|
||||
case <-target.ticker.C:
|
||||
case <-target.timer.C:
|
||||
oldStatus := target.Status
|
||||
m.performHealthCheck(target)
|
||||
|
||||
// Update ticker interval if status changed
|
||||
// Update timer interval if status changed
|
||||
newInterval := time.Duration(target.Config.Interval) * time.Second
|
||||
if target.Status == StatusUnhealthy {
|
||||
newInterval = time.Duration(target.Config.UnhealthyInterval) * time.Second
|
||||
@@ -343,11 +343,12 @@ func (m *Monitor) monitorTarget(target *Target) {
|
||||
if newInterval != interval {
|
||||
logger.Debug("Target %d: updating check interval from %v to %v due to status change",
|
||||
target.Config.ID, interval, newInterval)
|
||||
target.ticker.Stop()
|
||||
target.ticker = time.NewTicker(newInterval)
|
||||
interval = newInterval
|
||||
}
|
||||
|
||||
// Reset timer for next check with current interval
|
||||
target.timer.Reset(interval)
|
||||
|
||||
// Notify callback if status changed
|
||||
if oldStatus != target.Status && m.callback != nil {
|
||||
logger.Info("Target %d status changed: %s -> %s",
|
||||
@@ -398,11 +399,16 @@ func (m *Monitor) performHealthCheck(target *Target) {
|
||||
|
||||
// Add headers
|
||||
for key, value := range target.Config.Headers {
|
||||
req.Header.Set(key, value)
|
||||
// Handle Host header specially - it must be set on req.Host, not in headers
|
||||
if strings.EqualFold(key, "Host") {
|
||||
req.Host = value
|
||||
} else {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Perform request
|
||||
resp, err := m.client.Do(req)
|
||||
resp, err := target.client.Do(req)
|
||||
if err != nil {
|
||||
target.Status = StatusUnhealthy
|
||||
target.LastError = fmt.Sprintf("request failed: %v", err)
|
||||
|
||||
602
holepunch/holepunch.go
Normal file
602
holepunch/holepunch.go
Normal file
@@ -0,0 +1,602 @@
|
||||
package holepunch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
mrand "golang.org/x/exp/rand"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// ExitNode represents a WireGuard exit node for hole punching
|
||||
type ExitNode struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
RelayPort uint16 `json:"relayPort"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
SiteIds []int `json:"siteIds,omitempty"`
|
||||
}
|
||||
|
||||
// Manager handles UDP hole punching operations
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
sharedBind *bind.SharedBind
|
||||
ID string
|
||||
token string
|
||||
publicKey string
|
||||
clientType string
|
||||
exitNodes map[string]ExitNode // key is endpoint
|
||||
updateChan chan struct{} // signals the goroutine to refresh exit nodes
|
||||
|
||||
sendHolepunchInterval time.Duration
|
||||
sendHolepunchIntervalMin time.Duration
|
||||
sendHolepunchIntervalMax time.Duration
|
||||
defaultIntervalMin time.Duration
|
||||
defaultIntervalMax time.Duration
|
||||
}
|
||||
|
||||
const defaultSendHolepunchIntervalMax = 60 * time.Second
|
||||
const defaultSendHolepunchIntervalMin = 1 * time.Second
|
||||
|
||||
// NewManager creates a new hole punch manager
|
||||
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager {
|
||||
return &Manager{
|
||||
sharedBind: sharedBind,
|
||||
ID: ID,
|
||||
clientType: clientType,
|
||||
publicKey: publicKey,
|
||||
exitNodes: make(map[string]ExitNode),
|
||||
sendHolepunchInterval: defaultSendHolepunchIntervalMin,
|
||||
sendHolepunchIntervalMin: defaultSendHolepunchIntervalMin,
|
||||
sendHolepunchIntervalMax: defaultSendHolepunchIntervalMax,
|
||||
defaultIntervalMin: defaultSendHolepunchIntervalMin,
|
||||
defaultIntervalMax: defaultSendHolepunchIntervalMax,
|
||||
}
|
||||
}
|
||||
|
||||
// SetToken updates the authentication token used for hole punching
|
||||
func (m *Manager) SetToken(token string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.token = token
|
||||
}
|
||||
|
||||
// IsRunning returns whether hole punching is currently active
|
||||
func (m *Manager) IsRunning() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.running
|
||||
}
|
||||
|
||||
// Stop stops any ongoing hole punch operations
|
||||
func (m *Manager) Stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.running {
|
||||
return
|
||||
}
|
||||
|
||||
if m.stopChan != nil {
|
||||
close(m.stopChan)
|
||||
m.stopChan = nil
|
||||
}
|
||||
|
||||
if m.updateChan != nil {
|
||||
close(m.updateChan)
|
||||
m.updateChan = nil
|
||||
}
|
||||
|
||||
m.running = false
|
||||
logger.Info("Hole punch manager stopped")
|
||||
}
|
||||
|
||||
// AddExitNode adds a new exit node to the rotation if it doesn't already exist
|
||||
func (m *Manager) AddExitNode(exitNode ExitNode) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.exitNodes[exitNode.Endpoint]; exists {
|
||||
logger.Debug("Exit node %s already exists in rotation", exitNode.Endpoint)
|
||||
return false
|
||||
}
|
||||
|
||||
m.exitNodes[exitNode.Endpoint] = exitNode
|
||||
logger.Info("Added exit node %s to hole punch rotation", exitNode.Endpoint)
|
||||
|
||||
// Signal the goroutine to refresh if running
|
||||
if m.running && m.updateChan != nil {
|
||||
select {
|
||||
case m.updateChan <- struct{}{}:
|
||||
default:
|
||||
// Channel full or closed, skip
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// RemoveExitNode removes an exit node from the rotation
|
||||
func (m *Manager) RemoveExitNode(endpoint string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.exitNodes[endpoint]; !exists {
|
||||
logger.Debug("Exit node %s not found in rotation", endpoint)
|
||||
return false
|
||||
}
|
||||
|
||||
delete(m.exitNodes, endpoint)
|
||||
logger.Info("Removed exit node %s from hole punch rotation", endpoint)
|
||||
|
||||
// Signal the goroutine to refresh if running
|
||||
if m.running && m.updateChan != nil {
|
||||
select {
|
||||
case m.updateChan <- struct{}{}:
|
||||
default:
|
||||
// Channel full or closed, skip
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
/*
|
||||
RemoveExitNodesByPeer removes the peer ID from the SiteIds list in each exit node.
|
||||
If the SiteIds list becomes empty after removal, the exit node is removed entirely.
|
||||
Returns the number of exit nodes removed.
|
||||
*/
|
||||
func (m *Manager) RemoveExitNodesByPeer(peerID int) int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
removed := 0
|
||||
for endpoint, node := range m.exitNodes {
|
||||
// Remove peerID from SiteIds if present
|
||||
newSiteIds := make([]int, 0, len(node.SiteIds))
|
||||
for _, id := range node.SiteIds {
|
||||
if id != peerID {
|
||||
newSiteIds = append(newSiteIds, id)
|
||||
}
|
||||
}
|
||||
if len(newSiteIds) != len(node.SiteIds) {
|
||||
node.SiteIds = newSiteIds
|
||||
if len(node.SiteIds) == 0 {
|
||||
delete(m.exitNodes, endpoint)
|
||||
logger.Info("Removed exit node %s as no more site IDs remain after removing peer %d", endpoint, peerID)
|
||||
removed++
|
||||
} else {
|
||||
m.exitNodes[endpoint] = node
|
||||
logger.Info("Removed peer %d from exit node %s site IDs", peerID, endpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if removed > 0 {
|
||||
// Signal the goroutine to refresh if running
|
||||
if m.running && m.updateChan != nil {
|
||||
select {
|
||||
case m.updateChan <- struct{}{}:
|
||||
default:
|
||||
// Channel full or closed, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return removed
|
||||
}
|
||||
|
||||
// GetExitNodes returns a copy of the current exit nodes
|
||||
func (m *Manager) GetExitNodes() []ExitNode {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
nodes := make([]ExitNode, 0, len(m.exitNodes))
|
||||
for _, node := range m.exitNodes {
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// SetServerHolepunchInterval sets custom min and max intervals for hole punching.
|
||||
// This is useful for low power mode where longer intervals are desired.
|
||||
func (m *Manager) SetServerHolepunchInterval(min, max time.Duration) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.sendHolepunchIntervalMin = min
|
||||
m.sendHolepunchIntervalMax = max
|
||||
m.sendHolepunchInterval = min
|
||||
|
||||
logger.Info("Set hole punch intervals: min=%v, max=%v", min, max)
|
||||
|
||||
// Signal the goroutine to apply the new interval if running
|
||||
if m.running && m.updateChan != nil {
|
||||
select {
|
||||
case m.updateChan <- struct{}{}:
|
||||
default:
|
||||
// Channel full or closed, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetInterval returns the current min and max intervals
|
||||
func (m *Manager) GetServerHolepunchInterval() (min, max time.Duration) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.sendHolepunchIntervalMin, m.sendHolepunchIntervalMax
|
||||
}
|
||||
|
||||
// ResetServerHolepunchInterval resets the hole punch interval back to the default values.
|
||||
// This restores normal operation after low power mode or other custom settings.
|
||||
func (m *Manager) ResetServerHolepunchInterval() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.sendHolepunchIntervalMin = m.defaultIntervalMin
|
||||
m.sendHolepunchIntervalMax = m.defaultIntervalMax
|
||||
m.sendHolepunchInterval = m.defaultIntervalMin
|
||||
|
||||
logger.Info("Reset hole punch intervals to defaults: min=%v, max=%v", m.defaultIntervalMin, m.defaultIntervalMax)
|
||||
|
||||
// Signal the goroutine to apply the new interval if running
|
||||
if m.running && m.updateChan != nil {
|
||||
select {
|
||||
case m.updateChan <- struct{}{}:
|
||||
default:
|
||||
// Channel full or closed, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerHolePunch sends an immediate hole punch packet to all configured exit nodes
|
||||
// This is useful for triggering hole punching on demand without waiting for the interval
|
||||
func (m *Manager) TriggerHolePunch() error {
|
||||
m.mu.Lock()
|
||||
|
||||
if len(m.exitNodes) == 0 {
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("no exit nodes configured")
|
||||
}
|
||||
|
||||
// Get a copy of exit nodes to work with
|
||||
currentExitNodes := make([]ExitNode, 0, len(m.exitNodes))
|
||||
for _, node := range m.exitNodes {
|
||||
currentExitNodes = append(currentExitNodes, node)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
logger.Info("Triggering on-demand hole punch to %d exit nodes", len(currentExitNodes))
|
||||
|
||||
// Send hole punch to all exit nodes
|
||||
successCount := 0
|
||||
for _, exitNode := range currentExitNodes {
|
||||
host, err := util.ResolveDomain(exitNode.Endpoint)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, strconv.Itoa(int(exitNode.RelayPort)))
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.sendHolePunch(remoteAddr, exitNode.PublicKey); err != nil {
|
||||
logger.Warn("Failed to send on-demand hole punch to %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debug("Sent on-demand hole punch to %s", exitNode.Endpoint)
|
||||
successCount++
|
||||
}
|
||||
|
||||
if successCount == 0 {
|
||||
return fmt.Errorf("failed to send hole punch to any exit node")
|
||||
}
|
||||
|
||||
logger.Info("Successfully sent on-demand hole punch to %d/%d exit nodes", successCount, len(currentExitNodes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartMultipleExitNodes starts hole punching to multiple exit nodes
|
||||
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
|
||||
m.mu.Lock()
|
||||
|
||||
if m.running {
|
||||
m.mu.Unlock()
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return fmt.Errorf("hole punch already running")
|
||||
}
|
||||
|
||||
// Populate exit nodes map
|
||||
m.exitNodes = make(map[string]ExitNode)
|
||||
for _, node := range exitNodes {
|
||||
m.exitNodes[node.Endpoint] = node
|
||||
}
|
||||
|
||||
m.running = true
|
||||
m.stopChan = make(chan struct{})
|
||||
m.updateChan = make(chan struct{}, 1)
|
||||
m.mu.Unlock()
|
||||
|
||||
logger.Debug("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
|
||||
|
||||
go m.runMultipleExitNodes()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start starts hole punching with the current set of exit nodes
|
||||
func (m *Manager) Start() error {
|
||||
m.mu.Lock()
|
||||
|
||||
if m.running {
|
||||
m.mu.Unlock()
|
||||
logger.Debug("UDP hole punch already running")
|
||||
return fmt.Errorf("hole punch already running")
|
||||
}
|
||||
|
||||
m.running = true
|
||||
m.stopChan = make(chan struct{})
|
||||
m.updateChan = make(chan struct{}, 1)
|
||||
nodeCount := len(m.exitNodes)
|
||||
m.mu.Unlock()
|
||||
|
||||
if nodeCount == 0 {
|
||||
logger.Info("Starting UDP hole punch manager (waiting for exit nodes to be added)")
|
||||
} else {
|
||||
logger.Info("Starting UDP hole punch with %d exit nodes", nodeCount)
|
||||
}
|
||||
|
||||
go m.runMultipleExitNodes()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runMultipleExitNodes performs hole punching to multiple exit nodes
|
||||
func (m *Manager) runMultipleExitNodes() {
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
m.running = false
|
||||
m.mu.Unlock()
|
||||
logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
||||
}()
|
||||
|
||||
// Resolve all endpoints upfront
|
||||
type resolvedExitNode struct {
|
||||
remoteAddr *net.UDPAddr
|
||||
publicKey string
|
||||
endpointName string
|
||||
}
|
||||
|
||||
resolveNodes := func() []resolvedExitNode {
|
||||
m.mu.Lock()
|
||||
currentExitNodes := make([]ExitNode, 0, len(m.exitNodes))
|
||||
for _, node := range m.exitNodes {
|
||||
currentExitNodes = append(currentExitNodes, node)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
var resolvedNodes []resolvedExitNode
|
||||
for _, exitNode := range currentExitNodes {
|
||||
host, err := util.ResolveDomain(exitNode.Endpoint)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, strconv.Itoa(int(exitNode.RelayPort)))
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
||||
remoteAddr: remoteAddr,
|
||||
publicKey: exitNode.PublicKey,
|
||||
endpointName: exitNode.Endpoint,
|
||||
})
|
||||
logger.Debug("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
||||
}
|
||||
return resolvedNodes
|
||||
}
|
||||
|
||||
resolvedNodes := resolveNodes()
|
||||
|
||||
if len(resolvedNodes) == 0 {
|
||||
logger.Info("No exit nodes available yet, waiting for nodes to be added")
|
||||
} else {
|
||||
// Send initial hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start with minimum interval
|
||||
m.mu.Lock()
|
||||
m.sendHolepunchInterval = m.sendHolepunchIntervalMin
|
||||
m.mu.Unlock()
|
||||
|
||||
ticker := time.NewTicker(m.sendHolepunchInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopChan:
|
||||
logger.Debug("Hole punch stopped by signal")
|
||||
return
|
||||
case <-m.updateChan:
|
||||
// Re-resolve exit nodes when update is signaled
|
||||
logger.Info("Refreshing exit nodes for hole punching")
|
||||
resolvedNodes = resolveNodes()
|
||||
if len(resolvedNodes) == 0 {
|
||||
logger.Warn("No exit nodes available after refresh")
|
||||
} else {
|
||||
logger.Info("Updated resolved nodes count: %d", len(resolvedNodes))
|
||||
}
|
||||
// Reset interval to minimum on update
|
||||
m.mu.Lock()
|
||||
m.sendHolepunchInterval = m.sendHolepunchIntervalMin
|
||||
m.mu.Unlock()
|
||||
ticker.Reset(m.sendHolepunchInterval)
|
||||
// Send immediate hole punch to newly resolved nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
case <-ticker.C:
|
||||
// Send hole punch to all exit nodes (if any are available)
|
||||
if len(resolvedNodes) > 0 {
|
||||
for _, node := range resolvedNodes {
|
||||
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
// Exponential backoff: double the interval up to max
|
||||
m.mu.Lock()
|
||||
newInterval := m.sendHolepunchInterval * 2
|
||||
if newInterval > m.sendHolepunchIntervalMax {
|
||||
newInterval = m.sendHolepunchIntervalMax
|
||||
}
|
||||
if newInterval != m.sendHolepunchInterval {
|
||||
m.sendHolepunchInterval = newInterval
|
||||
ticker.Reset(m.sendHolepunchInterval)
|
||||
logger.Debug("Increased hole punch interval to %v", m.sendHolepunchInterval)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendHolePunch sends an encrypted hole punch packet using the shared bind
|
||||
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
|
||||
m.mu.Lock()
|
||||
token := m.token
|
||||
ID := m.ID
|
||||
m.mu.Unlock()
|
||||
|
||||
if serverPubKey == "" || token == "" {
|
||||
return fmt.Errorf("server public key or OLM token is empty")
|
||||
}
|
||||
|
||||
var payload interface{}
|
||||
if m.clientType == "newt" {
|
||||
payload = struct {
|
||||
ID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}{
|
||||
ID: ID,
|
||||
Token: token,
|
||||
PublicKey: m.publicKey,
|
||||
}
|
||||
} else {
|
||||
payload = struct {
|
||||
ID string `json:"olmId"`
|
||||
Token string `json:"token"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}{
|
||||
ID: ID,
|
||||
Token: token,
|
||||
PublicKey: m.publicKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %w", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(encryptedPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
|
||||
}
|
||||
|
||||
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write to UDP: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
|
||||
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
|
||||
// Generate an ephemeral keypair for this message
|
||||
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||
}
|
||||
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||
|
||||
// Parse the server's public key
|
||||
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||
}
|
||||
|
||||
// Use X25519 for key exchange
|
||||
var ephPrivKeyFixed [32]byte
|
||||
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||
|
||||
// Perform X25519 key exchange
|
||||
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||
}
|
||||
|
||||
// Create an AEAD cipher using the shared secret
|
||||
aead, err := chacha20poly1305.New(sharedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := mrand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload
|
||||
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||
|
||||
// Prepare the final encrypted message
|
||||
encryptedMsg := struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}{
|
||||
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}
|
||||
|
||||
return encryptedMsg, nil
|
||||
}
|
||||
404
holepunch/tester.go
Normal file
404
holepunch/tester.go
Normal file
@@ -0,0 +1,404 @@
|
||||
package holepunch
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
)
|
||||
|
||||
// TestResult represents the result of a connection test
|
||||
type TestResult struct {
|
||||
// Success indicates whether the test was successful
|
||||
Success bool
|
||||
// RTT is the round-trip time of the test packet
|
||||
RTT time.Duration
|
||||
// Endpoint is the endpoint that was tested
|
||||
Endpoint string
|
||||
// Error contains any error that occurred during the test
|
||||
Error error
|
||||
}
|
||||
|
||||
// TestConnectionOptions configures the connection test
|
||||
type TestConnectionOptions struct {
|
||||
// Timeout is how long to wait for a response (default: 5 seconds)
|
||||
Timeout time.Duration
|
||||
// Retries is the number of times to retry on failure (default: 0)
|
||||
Retries int
|
||||
}
|
||||
|
||||
// DefaultTestOptions returns the default test options
|
||||
func DefaultTestOptions() TestConnectionOptions {
|
||||
return TestConnectionOptions{
|
||||
Timeout: 5 * time.Second,
|
||||
Retries: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// cachedAddr holds a cached resolved UDP address
|
||||
type cachedAddr struct {
|
||||
addr *net.UDPAddr
|
||||
resolvedAt time.Time
|
||||
}
|
||||
|
||||
// HolepunchTester monitors holepunch connectivity using magic packets
|
||||
type HolepunchTester struct {
|
||||
sharedBind *bind.SharedBind
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
|
||||
// Pending requests waiting for responses (key: echo data as string)
|
||||
pendingRequests sync.Map // map[string]*pendingRequest
|
||||
|
||||
// Callback when connection status changes
|
||||
callback HolepunchStatusCallback
|
||||
|
||||
// Address cache to avoid repeated DNS/UDP resolution
|
||||
addrCache map[string]*cachedAddr
|
||||
addrCacheMu sync.RWMutex
|
||||
addrCacheTTL time.Duration // How long cached addresses are valid
|
||||
}
|
||||
|
||||
// HolepunchStatus represents the status of a holepunch connection
|
||||
type HolepunchStatus struct {
|
||||
Endpoint string
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
}
|
||||
|
||||
// HolepunchStatusCallback is called when holepunch status changes
|
||||
type HolepunchStatusCallback func(status HolepunchStatus)
|
||||
|
||||
// pendingRequest tracks a pending test request
|
||||
type pendingRequest struct {
|
||||
endpoint string
|
||||
sentAt time.Time
|
||||
replyChan chan time.Duration
|
||||
}
|
||||
|
||||
// NewHolepunchTester creates a new holepunch tester using the given SharedBind
|
||||
func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester {
|
||||
return &HolepunchTester{
|
||||
sharedBind: sharedBind,
|
||||
addrCache: make(map[string]*cachedAddr),
|
||||
addrCacheTTL: 5 * time.Minute, // Cache addresses for 5 minutes
|
||||
}
|
||||
}
|
||||
|
||||
// SetCallback sets the callback for connection status changes
|
||||
func (t *HolepunchTester) SetCallback(callback HolepunchStatusCallback) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.callback = callback
|
||||
}
|
||||
|
||||
// Start begins listening for magic packet responses
|
||||
func (t *HolepunchTester) Start() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.running {
|
||||
return fmt.Errorf("tester already running")
|
||||
}
|
||||
|
||||
if t.sharedBind == nil {
|
||||
return fmt.Errorf("sharedBind is nil")
|
||||
}
|
||||
|
||||
t.running = true
|
||||
t.stopChan = make(chan struct{})
|
||||
|
||||
// Register our callback with the SharedBind to receive magic responses
|
||||
t.sharedBind.SetMagicResponseCallback(t.handleResponse)
|
||||
|
||||
logger.Debug("HolepunchTester started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the tester
|
||||
func (t *HolepunchTester) Stop() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if !t.running {
|
||||
return
|
||||
}
|
||||
|
||||
t.running = false
|
||||
close(t.stopChan)
|
||||
|
||||
// Clear the callback
|
||||
if t.sharedBind != nil {
|
||||
t.sharedBind.SetMagicResponseCallback(nil)
|
||||
}
|
||||
|
||||
// Cancel all pending requests
|
||||
t.pendingRequests.Range(func(key, value interface{}) bool {
|
||||
if req, ok := value.(*pendingRequest); ok {
|
||||
close(req.replyChan)
|
||||
}
|
||||
t.pendingRequests.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
// Clear address cache
|
||||
t.addrCacheMu.Lock()
|
||||
t.addrCache = make(map[string]*cachedAddr)
|
||||
t.addrCacheMu.Unlock()
|
||||
|
||||
logger.Debug("HolepunchTester stopped")
|
||||
}
|
||||
|
||||
// resolveEndpoint resolves an endpoint to a UDP address, using cache when possible
|
||||
func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error) {
|
||||
// Check cache first
|
||||
t.addrCacheMu.RLock()
|
||||
cached, ok := t.addrCache[endpoint]
|
||||
ttl := t.addrCacheTTL
|
||||
t.addrCacheMu.RUnlock()
|
||||
|
||||
if ok && time.Since(cached.resolvedAt) < ttl {
|
||||
return cached.addr, nil
|
||||
}
|
||||
|
||||
// Resolve the endpoint
|
||||
host, err := util.ResolveDomain(endpoint)
|
||||
if err != nil {
|
||||
host = endpoint
|
||||
}
|
||||
|
||||
_, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
host = net.JoinHostPort(host, "21820")
|
||||
}
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve UDP address %s: %w", host, err)
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
t.addrCacheMu.Lock()
|
||||
t.addrCache[endpoint] = &cachedAddr{
|
||||
addr: remoteAddr,
|
||||
resolvedAt: time.Now(),
|
||||
}
|
||||
t.addrCacheMu.Unlock()
|
||||
|
||||
return remoteAddr, nil
|
||||
}
|
||||
|
||||
// InvalidateCache removes a specific endpoint from the address cache
|
||||
func (t *HolepunchTester) InvalidateCache(endpoint string) {
|
||||
t.addrCacheMu.Lock()
|
||||
delete(t.addrCache, endpoint)
|
||||
t.addrCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// ClearCache clears all cached addresses
|
||||
func (t *HolepunchTester) ClearCache() {
|
||||
t.addrCacheMu.Lock()
|
||||
t.addrCache = make(map[string]*cachedAddr)
|
||||
t.addrCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// handleResponse is called by SharedBind when a magic response is received
|
||||
func (t *HolepunchTester) handleResponse(addr netip.AddrPort, echoData []byte) {
|
||||
// logger.Debug("Received magic response from %s", addr.String())
|
||||
key := string(echoData)
|
||||
|
||||
value, ok := t.pendingRequests.LoadAndDelete(key)
|
||||
if !ok {
|
||||
// No matching request found
|
||||
logger.Debug("No pending request found for magic response from %s", addr.String())
|
||||
return
|
||||
}
|
||||
|
||||
req := value.(*pendingRequest)
|
||||
rtt := time.Since(req.sentAt)
|
||||
// logger.Debug("Magic response matched pending request for %s (RTT: %v)", req.endpoint, rtt)
|
||||
|
||||
// Send RTT to the waiting goroutine (non-blocking)
|
||||
select {
|
||||
case req.replyChan <- rtt:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestEndpoint sends a magic test packet to the endpoint and waits for a response.
|
||||
// This uses the SharedBind so packets come from the same source port as WireGuard.
|
||||
func (t *HolepunchTester) TestEndpoint(endpoint string, timeout time.Duration) TestResult {
|
||||
result := TestResult{
|
||||
Endpoint: endpoint,
|
||||
}
|
||||
|
||||
t.mu.RLock()
|
||||
running := t.running
|
||||
sharedBind := t.sharedBind
|
||||
t.mu.RUnlock()
|
||||
|
||||
if !running {
|
||||
result.Error = fmt.Errorf("tester not running")
|
||||
return result
|
||||
}
|
||||
|
||||
if sharedBind == nil || sharedBind.IsClosed() {
|
||||
result.Error = fmt.Errorf("sharedBind is nil or closed")
|
||||
return result
|
||||
}
|
||||
|
||||
// Resolve the endpoint (using cache)
|
||||
remoteAddr, err := t.resolveEndpoint(endpoint)
|
||||
if err != nil {
|
||||
result.Error = err
|
||||
return result
|
||||
}
|
||||
|
||||
// Generate random data for the test packet
|
||||
randomData := make([]byte, bind.MagicPacketDataLen)
|
||||
if _, err := rand.Read(randomData); err != nil {
|
||||
result.Error = fmt.Errorf("failed to generate random data: %w", err)
|
||||
return result
|
||||
}
|
||||
|
||||
// Create a pending request
|
||||
req := &pendingRequest{
|
||||
endpoint: endpoint,
|
||||
sentAt: time.Now(),
|
||||
replyChan: make(chan time.Duration, 1),
|
||||
}
|
||||
|
||||
key := string(randomData)
|
||||
t.pendingRequests.Store(key, req)
|
||||
|
||||
// Build the test request packet
|
||||
request := make([]byte, bind.MagicTestRequestLen)
|
||||
copy(request, bind.MagicTestRequest)
|
||||
copy(request[len(bind.MagicTestRequest):], randomData)
|
||||
|
||||
// Send the test packet
|
||||
_, err = sharedBind.WriteToUDP(request, remoteAddr)
|
||||
if err != nil {
|
||||
t.pendingRequests.Delete(key)
|
||||
result.Error = fmt.Errorf("failed to send test packet: %w", err)
|
||||
return result
|
||||
}
|
||||
|
||||
// Wait for response with timeout
|
||||
select {
|
||||
case rtt, ok := <-req.replyChan:
|
||||
if ok {
|
||||
result.Success = true
|
||||
result.RTT = rtt
|
||||
} else {
|
||||
result.Error = fmt.Errorf("request cancelled")
|
||||
}
|
||||
case <-time.After(timeout):
|
||||
t.pendingRequests.Delete(key)
|
||||
result.Error = fmt.Errorf("timeout waiting for response")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// TestConnectionWithBind sends a magic test packet using an existing SharedBind.
|
||||
// This is useful when you want to test the connection through the same socket
|
||||
// that WireGuard is using, which tests the actual hole-punched path.
|
||||
func TestConnectionWithBind(sharedBind *bind.SharedBind, endpoint string, opts *TestConnectionOptions) TestResult {
|
||||
if opts == nil {
|
||||
defaultOpts := DefaultTestOptions()
|
||||
opts = &defaultOpts
|
||||
}
|
||||
|
||||
result := TestResult{
|
||||
Endpoint: endpoint,
|
||||
}
|
||||
|
||||
if sharedBind == nil {
|
||||
result.Error = fmt.Errorf("sharedBind is nil")
|
||||
return result
|
||||
}
|
||||
|
||||
if sharedBind.IsClosed() {
|
||||
result.Error = fmt.Errorf("sharedBind is closed")
|
||||
return result
|
||||
}
|
||||
|
||||
// Resolve the endpoint
|
||||
host, err := util.ResolveDomain(endpoint)
|
||||
if err != nil {
|
||||
host = endpoint
|
||||
}
|
||||
|
||||
_, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
host = net.JoinHostPort(host, "21820")
|
||||
}
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", host)
|
||||
if err != nil {
|
||||
result.Error = fmt.Errorf("failed to resolve UDP address %s: %w", host, err)
|
||||
return result
|
||||
}
|
||||
|
||||
// Generate random data for the test packet
|
||||
randomData := make([]byte, bind.MagicPacketDataLen)
|
||||
if _, err := rand.Read(randomData); err != nil {
|
||||
result.Error = fmt.Errorf("failed to generate random data: %w", err)
|
||||
return result
|
||||
}
|
||||
|
||||
// Build the test request packet
|
||||
request := make([]byte, bind.MagicTestRequestLen)
|
||||
copy(request, bind.MagicTestRequest)
|
||||
copy(request[len(bind.MagicTestRequest):], randomData)
|
||||
|
||||
// Get the underlying UDP connection to set read deadline and read response
|
||||
udpConn := sharedBind.GetUDPConn()
|
||||
if udpConn == nil {
|
||||
result.Error = fmt.Errorf("could not get UDP connection from SharedBind")
|
||||
return result
|
||||
}
|
||||
|
||||
attempts := opts.Retries + 1
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
if attempt > 0 {
|
||||
logger.Debug("Retrying connection test to %s (attempt %d/%d)", endpoint, attempt+1, attempts)
|
||||
}
|
||||
|
||||
// Note: We can't easily set a read deadline on the shared connection
|
||||
// without affecting WireGuard, so we use a goroutine with timeout instead
|
||||
startTime := time.Now()
|
||||
|
||||
// Send the test packet through the shared bind
|
||||
_, err = sharedBind.WriteToUDP(request, remoteAddr)
|
||||
if err != nil {
|
||||
result.Error = fmt.Errorf("failed to send test packet: %w", err)
|
||||
if attempt < attempts-1 {
|
||||
continue
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// For shared bind test, we send the packet but can't easily wait for
|
||||
// response without interfering with WireGuard's receive loop.
|
||||
// The response will be handled by SharedBind automatically.
|
||||
// We consider the test successful if the send succeeded.
|
||||
// For a full round-trip test, use TestConnection() with a separate socket.
|
||||
|
||||
result.RTT = time.Since(startTime)
|
||||
result.Success = true
|
||||
result.Error = nil
|
||||
logger.Debug("Test packet sent to %s via SharedBind", endpoint)
|
||||
return result
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
80
internal/state/telemetry_view.go
Normal file
80
internal/state/telemetry_view.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/internal/telemetry"
|
||||
)
|
||||
|
||||
// TelemetryView is a minimal, thread-safe implementation to feed observables.
|
||||
// Since one Newt process represents one site, we expose a single logical site.
|
||||
// site_id is a resource attribute, so we do not emit per-site labels here.
|
||||
type TelemetryView struct {
|
||||
online atomic.Bool
|
||||
lastHBUnix atomic.Int64 // unix seconds
|
||||
// per-tunnel sessions
|
||||
sessMu sync.RWMutex
|
||||
sessions map[string]*atomic.Int64
|
||||
}
|
||||
|
||||
var (
|
||||
globalView atomic.Pointer[TelemetryView]
|
||||
)
|
||||
|
||||
// Global returns a singleton TelemetryView.
|
||||
func Global() *TelemetryView {
|
||||
if v := globalView.Load(); v != nil { return v }
|
||||
v := &TelemetryView{ sessions: make(map[string]*atomic.Int64) }
|
||||
globalView.Store(v)
|
||||
telemetry.RegisterStateView(v)
|
||||
return v
|
||||
}
|
||||
|
||||
// Instrumentation helpers
|
||||
func (v *TelemetryView) IncSessions(tunnelID string) {
|
||||
v.sessMu.Lock(); defer v.sessMu.Unlock()
|
||||
c := v.sessions[tunnelID]
|
||||
if c == nil { c = &atomic.Int64{}; v.sessions[tunnelID] = c }
|
||||
c.Add(1)
|
||||
}
|
||||
func (v *TelemetryView) DecSessions(tunnelID string) {
|
||||
v.sessMu.Lock(); defer v.sessMu.Unlock()
|
||||
if c := v.sessions[tunnelID]; c != nil {
|
||||
c.Add(-1)
|
||||
if c.Load() <= 0 { delete(v.sessions, tunnelID) }
|
||||
}
|
||||
}
|
||||
func (v *TelemetryView) ClearTunnel(tunnelID string) {
|
||||
v.sessMu.Lock(); defer v.sessMu.Unlock()
|
||||
delete(v.sessions, tunnelID)
|
||||
}
|
||||
func (v *TelemetryView) SetOnline(b bool) { v.online.Store(b) }
|
||||
func (v *TelemetryView) TouchHeartbeat() { v.lastHBUnix.Store(time.Now().Unix()) }
|
||||
|
||||
// --- telemetry.StateView interface ---
|
||||
|
||||
func (v *TelemetryView) ListSites() []string { return []string{"self"} }
|
||||
func (v *TelemetryView) Online(_ string) (bool, bool) { return v.online.Load(), true }
|
||||
func (v *TelemetryView) LastHeartbeat(_ string) (time.Time, bool) {
|
||||
sec := v.lastHBUnix.Load()
|
||||
if sec == 0 { return time.Time{}, false }
|
||||
return time.Unix(sec, 0), true
|
||||
}
|
||||
func (v *TelemetryView) ActiveSessions(_ string) (int64, bool) {
|
||||
// aggregated sessions (not used for per-tunnel gauge)
|
||||
v.sessMu.RLock(); defer v.sessMu.RUnlock()
|
||||
var sum int64
|
||||
for _, c := range v.sessions { if c != nil { sum += c.Load() } }
|
||||
return sum, true
|
||||
}
|
||||
|
||||
// Extended accessor used by telemetry callback to publish per-tunnel samples.
|
||||
func (v *TelemetryView) SessionsByTunnel() map[string]int64 {
|
||||
v.sessMu.RLock(); defer v.sessMu.RUnlock()
|
||||
out := make(map[string]int64, len(v.sessions))
|
||||
for id, c := range v.sessions { if c != nil && c.Load() > 0 { out[id] = c.Load() } }
|
||||
return out
|
||||
}
|
||||
|
||||
19
internal/telemetry/constants.go
Normal file
19
internal/telemetry/constants.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package telemetry
|
||||
|
||||
// Protocol labels (low-cardinality)
|
||||
const (
|
||||
ProtocolTCP = "tcp"
|
||||
ProtocolUDP = "udp"
|
||||
)
|
||||
|
||||
// Reconnect reason bins (fixed, low-cardinality)
|
||||
const (
|
||||
ReasonServerRequest = "server_request"
|
||||
ReasonTimeout = "timeout"
|
||||
ReasonPeerClose = "peer_close"
|
||||
ReasonNetworkChange = "network_change"
|
||||
ReasonAuthError = "auth_error"
|
||||
ReasonHandshakeError = "handshake_error"
|
||||
ReasonConfigChange = "config_change"
|
||||
ReasonError = "error"
|
||||
)
|
||||
32
internal/telemetry/constants_test.go
Normal file
32
internal/telemetry/constants_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package telemetry
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAllowedConstants(t *testing.T) {
|
||||
allowedReasons := map[string]struct{}{
|
||||
ReasonServerRequest: {},
|
||||
ReasonTimeout: {},
|
||||
ReasonPeerClose: {},
|
||||
ReasonNetworkChange: {},
|
||||
ReasonAuthError: {},
|
||||
ReasonHandshakeError: {},
|
||||
ReasonConfigChange: {},
|
||||
ReasonError: {},
|
||||
}
|
||||
for k := range allowedReasons {
|
||||
if k == "" {
|
||||
t.Fatalf("empty reason constant")
|
||||
}
|
||||
}
|
||||
|
||||
allowedProtocols := map[string]struct{}{
|
||||
ProtocolTCP: {},
|
||||
ProtocolUDP: {},
|
||||
}
|
||||
for k := range allowedProtocols {
|
||||
if k == "" {
|
||||
t.Fatalf("empty protocol constant")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
542
internal/telemetry/metrics.go
Normal file
542
internal/telemetry/metrics.go
Normal file
@@ -0,0 +1,542 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
// Instruments and helpers for Newt metrics following the naming, units, and
|
||||
// low-cardinality label guidance from the issue description.
|
||||
//
|
||||
// Counters end with _total, durations are in seconds, sizes in bytes.
|
||||
// Only low-cardinality stable labels are supported: tunnel_id,
|
||||
// transport, direction, result, reason, error_type.
|
||||
var (
|
||||
initOnce sync.Once
|
||||
|
||||
meter metric.Meter
|
||||
|
||||
// Site / Registration
|
||||
mSiteRegistrations metric.Int64Counter
|
||||
mSiteOnline metric.Int64ObservableGauge
|
||||
mSiteLastHeartbeat metric.Float64ObservableGauge
|
||||
|
||||
// Tunnel / Sessions
|
||||
mTunnelSessions metric.Int64ObservableGauge
|
||||
mTunnelBytes metric.Int64Counter
|
||||
mTunnelLatency metric.Float64Histogram
|
||||
mReconnects metric.Int64Counter
|
||||
|
||||
// Connection / NAT
|
||||
mConnAttempts metric.Int64Counter
|
||||
mConnErrors metric.Int64Counter
|
||||
|
||||
// Config/Restart
|
||||
mConfigReloads metric.Int64Counter
|
||||
mConfigApply metric.Float64Histogram
|
||||
mCertRotationTotal metric.Int64Counter
|
||||
mProcessStartTime metric.Float64ObservableGauge
|
||||
|
||||
// Build info
|
||||
mBuildInfo metric.Int64ObservableGauge
|
||||
|
||||
// WebSocket
|
||||
mWSConnectLatency metric.Float64Histogram
|
||||
mWSMessages metric.Int64Counter
|
||||
mWSDisconnects metric.Int64Counter
|
||||
mWSKeepaliveFailure metric.Int64Counter
|
||||
mWSSessionDuration metric.Float64Histogram
|
||||
mWSConnected metric.Int64ObservableGauge
|
||||
mWSReconnects metric.Int64Counter
|
||||
|
||||
// Proxy
|
||||
mProxyActiveConns metric.Int64ObservableGauge
|
||||
mProxyBufferBytes metric.Int64ObservableGauge
|
||||
mProxyAsyncBacklogByte metric.Int64ObservableGauge
|
||||
mProxyDropsTotal metric.Int64Counter
|
||||
mProxyAcceptsTotal metric.Int64Counter
|
||||
mProxyConnDuration metric.Float64Histogram
|
||||
mProxyConnectionsTotal metric.Int64Counter
|
||||
|
||||
buildVersion string
|
||||
buildCommit string
|
||||
processStartUnix = float64(time.Now().UnixNano()) / 1e9
|
||||
wsConnectedState atomic.Int64
|
||||
)
|
||||
|
||||
// Proxy connection lifecycle events.
|
||||
const (
|
||||
ProxyConnectionOpened = "opened"
|
||||
ProxyConnectionClosed = "closed"
|
||||
)
|
||||
|
||||
// attrsWithSite appends site/region labels only when explicitly enabled to keep
|
||||
// label cardinality low by default.
|
||||
func attrsWithSite(extra ...attribute.KeyValue) []attribute.KeyValue {
|
||||
attrs := make([]attribute.KeyValue, len(extra))
|
||||
copy(attrs, extra)
|
||||
if ShouldIncludeSiteLabels() {
|
||||
attrs = append(attrs, siteAttrs()...)
|
||||
}
|
||||
return attrs
|
||||
}
|
||||
|
||||
func registerInstruments() error {
|
||||
var err error
|
||||
initOnce.Do(func() {
|
||||
meter = otel.Meter("newt")
|
||||
if e := registerSiteInstruments(); e != nil {
|
||||
err = e
|
||||
return
|
||||
}
|
||||
if e := registerTunnelInstruments(); e != nil {
|
||||
err = e
|
||||
return
|
||||
}
|
||||
if e := registerConnInstruments(); e != nil {
|
||||
err = e
|
||||
return
|
||||
}
|
||||
if e := registerConfigInstruments(); e != nil {
|
||||
err = e
|
||||
return
|
||||
}
|
||||
if e := registerBuildWSProxyInstruments(); e != nil {
|
||||
err = e
|
||||
return
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func registerSiteInstruments() error {
|
||||
var err error
|
||||
mSiteRegistrations, err = meter.Int64Counter("newt_site_registrations_total",
|
||||
metric.WithDescription("Total site registration attempts"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mSiteOnline, err = meter.Int64ObservableGauge("newt_site_online",
|
||||
metric.WithDescription("Site online (0/1)"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mSiteLastHeartbeat, err = meter.Float64ObservableGauge("newt_site_last_heartbeat_timestamp_seconds",
|
||||
metric.WithDescription("Unix timestamp of the last site heartbeat"),
|
||||
metric.WithUnit("s"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerTunnelInstruments() error {
|
||||
var err error
|
||||
mTunnelSessions, err = meter.Int64ObservableGauge("newt_tunnel_sessions",
|
||||
metric.WithDescription("Active tunnel sessions"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mTunnelBytes, err = meter.Int64Counter("newt_tunnel_bytes_total",
|
||||
metric.WithDescription("Tunnel bytes ingress/egress"),
|
||||
metric.WithUnit("By"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mTunnelLatency, err = meter.Float64Histogram("newt_tunnel_latency_seconds",
|
||||
metric.WithDescription("Per-tunnel latency in seconds"),
|
||||
metric.WithUnit("s"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mReconnects, err = meter.Int64Counter("newt_tunnel_reconnects_total",
|
||||
metric.WithDescription("Tunnel reconnect events"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerConnInstruments() error {
|
||||
var err error
|
||||
mConnAttempts, err = meter.Int64Counter("newt_connection_attempts_total",
|
||||
metric.WithDescription("Connection attempts"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mConnErrors, err = meter.Int64Counter("newt_connection_errors_total",
|
||||
metric.WithDescription("Connection errors by type"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerConfigInstruments() error {
|
||||
mConfigReloads, _ = meter.Int64Counter("newt_config_reloads_total",
|
||||
metric.WithDescription("Configuration reloads"))
|
||||
mConfigApply, _ = meter.Float64Histogram("newt_config_apply_seconds",
|
||||
metric.WithDescription("Configuration apply duration in seconds"),
|
||||
metric.WithUnit("s"))
|
||||
mCertRotationTotal, _ = meter.Int64Counter("newt_cert_rotation_total",
|
||||
metric.WithDescription("Certificate rotation events (success/failure)"))
|
||||
mProcessStartTime, _ = meter.Float64ObservableGauge("process_start_time_seconds",
|
||||
metric.WithDescription("Unix timestamp of the process start time"),
|
||||
metric.WithUnit("s"))
|
||||
if mProcessStartTime != nil {
|
||||
if _, err := meter.RegisterCallback(func(ctx context.Context, o metric.Observer) error {
|
||||
o.ObserveFloat64(mProcessStartTime, processStartUnix)
|
||||
return nil
|
||||
}, mProcessStartTime); err != nil {
|
||||
otel.Handle(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerBuildWSProxyInstruments() error {
|
||||
// Build info gauge (value 1 with version/commit attributes)
|
||||
mBuildInfo, _ = meter.Int64ObservableGauge("newt_build_info",
|
||||
metric.WithDescription("Newt build information (value is always 1)"))
|
||||
// WebSocket
|
||||
mWSConnectLatency, _ = meter.Float64Histogram("newt_websocket_connect_latency_seconds",
|
||||
metric.WithDescription("WebSocket connect latency in seconds"),
|
||||
metric.WithUnit("s"))
|
||||
mWSMessages, _ = meter.Int64Counter("newt_websocket_messages_total",
|
||||
metric.WithDescription("WebSocket messages by direction and type"))
|
||||
mWSDisconnects, _ = meter.Int64Counter("newt_websocket_disconnects_total",
|
||||
metric.WithDescription("WebSocket disconnects by reason/result"))
|
||||
mWSKeepaliveFailure, _ = meter.Int64Counter("newt_websocket_keepalive_failures_total",
|
||||
metric.WithDescription("WebSocket keepalive (ping/pong) failures"))
|
||||
mWSSessionDuration, _ = meter.Float64Histogram("newt_websocket_session_duration_seconds",
|
||||
metric.WithDescription("Duration of established WebSocket sessions"),
|
||||
metric.WithUnit("s"))
|
||||
mWSConnected, _ = meter.Int64ObservableGauge("newt_websocket_connected",
|
||||
metric.WithDescription("WebSocket connection state (1=connected, 0=disconnected)"))
|
||||
mWSReconnects, _ = meter.Int64Counter("newt_websocket_reconnects_total",
|
||||
metric.WithDescription("WebSocket reconnect attempts by reason"))
|
||||
// Proxy
|
||||
mProxyActiveConns, _ = meter.Int64ObservableGauge("newt_proxy_active_connections",
|
||||
metric.WithDescription("Proxy active connections per tunnel and protocol"))
|
||||
mProxyBufferBytes, _ = meter.Int64ObservableGauge("newt_proxy_buffer_bytes",
|
||||
metric.WithDescription("Proxy buffer bytes (may approximate async backlog)"),
|
||||
metric.WithUnit("By"))
|
||||
mProxyAsyncBacklogByte, _ = meter.Int64ObservableGauge("newt_proxy_async_backlog_bytes",
|
||||
metric.WithDescription("Unflushed async byte backlog per tunnel and protocol"),
|
||||
metric.WithUnit("By"))
|
||||
mProxyDropsTotal, _ = meter.Int64Counter("newt_proxy_drops_total",
|
||||
metric.WithDescription("Proxy drops due to write errors"))
|
||||
mProxyAcceptsTotal, _ = meter.Int64Counter("newt_proxy_accept_total",
|
||||
metric.WithDescription("Proxy connection accepts by protocol and result"))
|
||||
mProxyConnDuration, _ = meter.Float64Histogram("newt_proxy_connection_duration_seconds",
|
||||
metric.WithDescription("Duration of completed proxy connections"),
|
||||
metric.WithUnit("s"))
|
||||
mProxyConnectionsTotal, _ = meter.Int64Counter("newt_proxy_connections_total",
|
||||
metric.WithDescription("Proxy connection lifecycle events by protocol"))
|
||||
// Register a default callback for build info if version/commit set
|
||||
reg, e := meter.RegisterCallback(func(ctx context.Context, o metric.Observer) error {
|
||||
if buildVersion == "" && buildCommit == "" {
|
||||
return nil
|
||||
}
|
||||
attrs := []attribute.KeyValue{}
|
||||
if buildVersion != "" {
|
||||
attrs = append(attrs, attribute.String("version", buildVersion))
|
||||
}
|
||||
if buildCommit != "" {
|
||||
attrs = append(attrs, attribute.String("commit", buildCommit))
|
||||
}
|
||||
if ShouldIncludeSiteLabels() {
|
||||
attrs = append(attrs, siteAttrs()...)
|
||||
}
|
||||
o.ObserveInt64(mBuildInfo, 1, metric.WithAttributes(attrs...))
|
||||
return nil
|
||||
}, mBuildInfo)
|
||||
if e != nil {
|
||||
otel.Handle(e)
|
||||
} else {
|
||||
// Provide a functional stopper that unregisters the callback
|
||||
obsStopper = func() { _ = reg.Unregister() }
|
||||
}
|
||||
if mWSConnected != nil {
|
||||
if regConn, err := meter.RegisterCallback(func(ctx context.Context, o metric.Observer) error {
|
||||
val := wsConnectedState.Load()
|
||||
o.ObserveInt64(mWSConnected, val, metric.WithAttributes(attrsWithSite()...))
|
||||
return nil
|
||||
}, mWSConnected); err != nil {
|
||||
otel.Handle(err)
|
||||
} else {
|
||||
wsConnStopper = func() { _ = regConn.Unregister() }
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Observable registration: Newt can register a callback to report gauges.
|
||||
// Call SetObservableCallback once to start observing online status, last
|
||||
// heartbeat seconds, and active sessions.
|
||||
|
||||
var (
|
||||
obsOnce sync.Once
|
||||
obsStopper func()
|
||||
proxyObsOnce sync.Once
|
||||
proxyStopper func()
|
||||
wsConnStopper func()
|
||||
)
|
||||
|
||||
// SetObservableCallback registers a single callback that will be invoked
|
||||
// on collection. Use the provided observer to emit values for the observable
|
||||
// gauges defined here.
|
||||
//
|
||||
// Example inside your code (where you have access to current state):
|
||||
//
|
||||
// telemetry.SetObservableCallback(func(ctx context.Context, o metric.Observer) error {
|
||||
// o.ObserveInt64(mSiteOnline, 1)
|
||||
// o.ObserveFloat64(mSiteLastHeartbeat, float64(lastHB.Unix()))
|
||||
// o.ObserveInt64(mTunnelSessions, int64(len(activeSessions)))
|
||||
// return nil
|
||||
// })
|
||||
func SetObservableCallback(cb func(context.Context, metric.Observer) error) {
|
||||
obsOnce.Do(func() {
|
||||
reg, e := meter.RegisterCallback(cb, mSiteOnline, mSiteLastHeartbeat, mTunnelSessions)
|
||||
if e != nil {
|
||||
otel.Handle(e)
|
||||
obsStopper = func() {
|
||||
// no-op: registration failed; keep stopper callable
|
||||
}
|
||||
return
|
||||
}
|
||||
// Provide a functional stopper mirroring proxy/build-info behavior
|
||||
obsStopper = func() { _ = reg.Unregister() }
|
||||
})
|
||||
}
|
||||
|
||||
// SetProxyObservableCallback registers a callback to observe proxy gauges.
|
||||
func SetProxyObservableCallback(cb func(context.Context, metric.Observer) error) {
|
||||
proxyObsOnce.Do(func() {
|
||||
reg, e := meter.RegisterCallback(cb, mProxyActiveConns, mProxyBufferBytes, mProxyAsyncBacklogByte)
|
||||
if e != nil {
|
||||
otel.Handle(e)
|
||||
proxyStopper = func() {
|
||||
// no-op: registration failed; keep stopper callable
|
||||
}
|
||||
return
|
||||
}
|
||||
// Provide a functional stopper to unregister later if needed
|
||||
proxyStopper = func() { _ = reg.Unregister() }
|
||||
})
|
||||
}
|
||||
|
||||
// Build info registration
|
||||
func RegisterBuildInfo(version, commit string) {
|
||||
buildVersion = version
|
||||
buildCommit = commit
|
||||
}
|
||||
|
||||
// Config reloads
|
||||
func IncConfigReload(ctx context.Context, result string) {
|
||||
mConfigReloads.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
|
||||
attribute.String("result", result),
|
||||
)...))
|
||||
}
|
||||
|
||||
// Helpers for counters/histograms
|
||||
|
||||
func IncSiteRegistration(ctx context.Context, result string) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("result", result),
|
||||
}
|
||||
mSiteRegistrations.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
|
||||
}
|
||||
|
||||
func AddTunnelBytes(ctx context.Context, tunnelID, direction string, n int64) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("direction", direction),
|
||||
}
|
||||
if ShouldIncludeTunnelID() && tunnelID != "" {
|
||||
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
|
||||
}
|
||||
mTunnelBytes.Add(ctx, n, metric.WithAttributes(attrsWithSite(attrs...)...))
|
||||
}
|
||||
|
||||
// AddTunnelBytesSet adds bytes using a pre-built attribute.Set to avoid per-call allocations.
|
||||
func AddTunnelBytesSet(ctx context.Context, n int64, attrs attribute.Set) {
|
||||
mTunnelBytes.Add(ctx, n, metric.WithAttributeSet(attrs))
|
||||
}
|
||||
|
||||
// --- WebSocket helpers ---
|
||||
|
||||
func ObserveWSConnectLatency(ctx context.Context, seconds float64, result, errorType string) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("transport", "websocket"),
|
||||
attribute.String("result", result),
|
||||
}
|
||||
if errorType != "" {
|
||||
attrs = append(attrs, attribute.String("error_type", errorType))
|
||||
}
|
||||
mWSConnectLatency.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(attrs...)...))
|
||||
}
|
||||
|
||||
func IncWSMessage(ctx context.Context, direction, msgType string) {
|
||||
mWSMessages.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
|
||||
attribute.String("direction", direction),
|
||||
attribute.String("msg_type", msgType),
|
||||
)...))
|
||||
}
|
||||
|
||||
func IncWSDisconnect(ctx context.Context, reason, result string) {
|
||||
mWSDisconnects.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
|
||||
attribute.String("reason", reason),
|
||||
attribute.String("result", result),
|
||||
)...))
|
||||
}
|
||||
|
||||
func IncWSKeepaliveFailure(ctx context.Context, reason string) {
|
||||
mWSKeepaliveFailure.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
|
||||
attribute.String("reason", reason),
|
||||
)...))
|
||||
}
|
||||
|
||||
// SetWSConnectionState updates the backing gauge for the WebSocket connected state.
|
||||
func SetWSConnectionState(connected bool) {
|
||||
if connected {
|
||||
wsConnectedState.Store(1)
|
||||
} else {
|
||||
wsConnectedState.Store(0)
|
||||
}
|
||||
}
|
||||
|
||||
// IncWSReconnect increments the WebSocket reconnect counter with a bounded reason label.
|
||||
func IncWSReconnect(ctx context.Context, reason string) {
|
||||
if reason == "" {
|
||||
reason = "unknown"
|
||||
}
|
||||
mWSReconnects.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
|
||||
attribute.String("reason", reason),
|
||||
)...))
|
||||
}
|
||||
|
||||
func ObserveWSSessionDuration(ctx context.Context, seconds float64, result string) {
|
||||
mWSSessionDuration.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(
|
||||
attribute.String("result", result),
|
||||
)...))
|
||||
}
|
||||
|
||||
// --- Proxy helpers ---
|
||||
|
||||
func ObserveProxyActiveConnsObs(o metric.Observer, value int64, attrs []attribute.KeyValue) {
|
||||
o.ObserveInt64(mProxyActiveConns, value, metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
func ObserveProxyBufferBytesObs(o metric.Observer, value int64, attrs []attribute.KeyValue) {
|
||||
o.ObserveInt64(mProxyBufferBytes, value, metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
func ObserveProxyAsyncBacklogObs(o metric.Observer, value int64, attrs []attribute.KeyValue) {
|
||||
o.ObserveInt64(mProxyAsyncBacklogByte, value, metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
func IncProxyDrops(ctx context.Context, tunnelID, protocol string) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("protocol", protocol),
|
||||
}
|
||||
if ShouldIncludeTunnelID() && tunnelID != "" {
|
||||
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
|
||||
}
|
||||
mProxyDropsTotal.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
|
||||
}
|
||||
|
||||
func IncProxyAccept(ctx context.Context, tunnelID, protocol, result, reason string) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("protocol", protocol),
|
||||
attribute.String("result", result),
|
||||
}
|
||||
if reason != "" {
|
||||
attrs = append(attrs, attribute.String("reason", reason))
|
||||
}
|
||||
if ShouldIncludeTunnelID() && tunnelID != "" {
|
||||
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
|
||||
}
|
||||
mProxyAcceptsTotal.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
|
||||
}
|
||||
|
||||
func ObserveProxyConnectionDuration(ctx context.Context, tunnelID, protocol, result string, seconds float64) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("protocol", protocol),
|
||||
attribute.String("result", result),
|
||||
}
|
||||
if ShouldIncludeTunnelID() && tunnelID != "" {
|
||||
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
|
||||
}
|
||||
mProxyConnDuration.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(attrs...)...))
|
||||
}
|
||||
|
||||
// IncProxyConnectionEvent records proxy connection lifecycle events (opened/closed).
|
||||
func IncProxyConnectionEvent(ctx context.Context, tunnelID, protocol, event string) {
|
||||
if event == "" {
|
||||
event = "unknown"
|
||||
}
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("protocol", protocol),
|
||||
attribute.String("event", event),
|
||||
}
|
||||
if ShouldIncludeTunnelID() && tunnelID != "" {
|
||||
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
|
||||
}
|
||||
mProxyConnectionsTotal.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
|
||||
}
|
||||
|
||||
// --- Config/PKI helpers ---
|
||||
|
||||
func ObserveConfigApply(ctx context.Context, phase, result string, seconds float64) {
|
||||
mConfigApply.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(
|
||||
attribute.String("phase", phase),
|
||||
attribute.String("result", result),
|
||||
)...))
|
||||
}
|
||||
|
||||
func IncCertRotation(ctx context.Context, result string) {
|
||||
mCertRotationTotal.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
|
||||
attribute.String("result", result),
|
||||
)...))
|
||||
}
|
||||
|
||||
func ObserveTunnelLatency(ctx context.Context, tunnelID, transport string, seconds float64) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("transport", transport),
|
||||
}
|
||||
if ShouldIncludeTunnelID() && tunnelID != "" {
|
||||
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
|
||||
}
|
||||
mTunnelLatency.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(attrs...)...))
|
||||
}
|
||||
|
||||
func IncReconnect(ctx context.Context, tunnelID, initiator, reason string) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("initiator", initiator),
|
||||
attribute.String("reason", reason),
|
||||
}
|
||||
if ShouldIncludeTunnelID() && tunnelID != "" {
|
||||
attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
|
||||
}
|
||||
mReconnects.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
|
||||
}
|
||||
|
||||
func IncConnAttempt(ctx context.Context, transport, result string) {
|
||||
mConnAttempts.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
|
||||
attribute.String("transport", transport),
|
||||
attribute.String("result", result),
|
||||
)...))
|
||||
}
|
||||
|
||||
func IncConnError(ctx context.Context, transport, typ string) {
|
||||
mConnErrors.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
|
||||
attribute.String("transport", transport),
|
||||
attribute.String("error_type", typ),
|
||||
)...))
|
||||
}
|
||||
59
internal/telemetry/metrics_test_helper.go
Normal file
59
internal/telemetry/metrics_test_helper.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func resetMetricsForTest() {
|
||||
initOnce = sync.Once{}
|
||||
obsOnce = sync.Once{}
|
||||
proxyObsOnce = sync.Once{}
|
||||
obsStopper = nil
|
||||
proxyStopper = nil
|
||||
if wsConnStopper != nil {
|
||||
wsConnStopper()
|
||||
}
|
||||
wsConnStopper = nil
|
||||
meter = nil
|
||||
|
||||
mSiteRegistrations = nil
|
||||
mSiteOnline = nil
|
||||
mSiteLastHeartbeat = nil
|
||||
|
||||
mTunnelSessions = nil
|
||||
mTunnelBytes = nil
|
||||
mTunnelLatency = nil
|
||||
mReconnects = nil
|
||||
|
||||
mConnAttempts = nil
|
||||
mConnErrors = nil
|
||||
|
||||
mConfigReloads = nil
|
||||
mConfigApply = nil
|
||||
mCertRotationTotal = nil
|
||||
mProcessStartTime = nil
|
||||
|
||||
mBuildInfo = nil
|
||||
|
||||
mWSConnectLatency = nil
|
||||
mWSMessages = nil
|
||||
mWSDisconnects = nil
|
||||
mWSKeepaliveFailure = nil
|
||||
mWSSessionDuration = nil
|
||||
mWSConnected = nil
|
||||
mWSReconnects = nil
|
||||
|
||||
mProxyActiveConns = nil
|
||||
mProxyBufferBytes = nil
|
||||
mProxyAsyncBacklogByte = nil
|
||||
mProxyDropsTotal = nil
|
||||
mProxyAcceptsTotal = nil
|
||||
mProxyConnDuration = nil
|
||||
mProxyConnectionsTotal = nil
|
||||
|
||||
processStartUnix = float64(time.Now().UnixNano()) / 1e9
|
||||
wsConnectedState.Store(0)
|
||||
includeTunnelIDVal.Store(false)
|
||||
includeSiteLabelVal.Store(false)
|
||||
}
|
||||
106
internal/telemetry/state_view.go
Normal file
106
internal/telemetry/state_view.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
// StateView provides a read-only view for observable gauges.
|
||||
// Implementations must be concurrency-safe and avoid blocking operations.
|
||||
// All methods should be fast and use RLocks where applicable.
|
||||
type StateView interface {
|
||||
// ListSites returns a stable, low-cardinality list of site IDs to expose.
|
||||
ListSites() []string
|
||||
// Online returns whether the site is online.
|
||||
Online(siteID string) (online bool, ok bool)
|
||||
// LastHeartbeat returns the last heartbeat time for a site.
|
||||
LastHeartbeat(siteID string) (t time.Time, ok bool)
|
||||
// ActiveSessions returns the current number of active sessions for a site (across tunnels),
|
||||
// or scoped to site if your model is site-scoped.
|
||||
ActiveSessions(siteID string) (n int64, ok bool)
|
||||
}
|
||||
|
||||
var (
|
||||
stateView atomic.Value // of type StateView
|
||||
)
|
||||
|
||||
// RegisterStateView sets the global StateView used by the default observable callback.
|
||||
func RegisterStateView(v StateView) {
|
||||
stateView.Store(v)
|
||||
// If instruments are registered, ensure a callback exists.
|
||||
if v != nil {
|
||||
SetObservableCallback(func(ctx context.Context, o metric.Observer) error {
|
||||
if any := stateView.Load(); any != nil {
|
||||
if sv, ok := any.(StateView); ok {
|
||||
for _, siteID := range sv.ListSites() {
|
||||
observeSiteOnlineFor(o, sv, siteID)
|
||||
observeLastHeartbeatFor(o, sv, siteID)
|
||||
observeSessionsFor(o, siteID, sv)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func observeSiteOnlineFor(o metric.Observer, sv StateView, siteID string) {
|
||||
if online, ok := sv.Online(siteID); ok {
|
||||
val := int64(0)
|
||||
if online {
|
||||
val = 1
|
||||
}
|
||||
o.ObserveInt64(mSiteOnline, val, metric.WithAttributes(
|
||||
attribute.String("site_id", siteID),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
func observeLastHeartbeatFor(o metric.Observer, sv StateView, siteID string) {
|
||||
if t, ok := sv.LastHeartbeat(siteID); ok {
|
||||
ts := float64(t.UnixNano()) / 1e9
|
||||
o.ObserveFloat64(mSiteLastHeartbeat, ts, metric.WithAttributes(
|
||||
attribute.String("site_id", siteID),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
func observeSessionsFor(o metric.Observer, siteID string, any interface{}) {
|
||||
if tm, ok := any.(interface{ SessionsByTunnel() map[string]int64 }); ok {
|
||||
sessions := tm.SessionsByTunnel()
|
||||
// If tunnel_id labels are enabled, preserve existing per-tunnel observations
|
||||
if ShouldIncludeTunnelID() {
|
||||
for tid, n := range sessions {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("site_id", siteID),
|
||||
}
|
||||
if tid != "" {
|
||||
attrs = append(attrs, attribute.String("tunnel_id", tid))
|
||||
}
|
||||
o.ObserveInt64(mTunnelSessions, n, metric.WithAttributes(attrs...))
|
||||
}
|
||||
return
|
||||
}
|
||||
// When tunnel_id is disabled, collapse per-tunnel counts into a single site-level value
|
||||
var total int64
|
||||
for _, n := range sessions {
|
||||
total += n
|
||||
}
|
||||
// If there are no per-tunnel entries, fall back to ActiveSessions() if available
|
||||
if total == 0 {
|
||||
if svAny := stateView.Load(); svAny != nil {
|
||||
if sv, ok := svAny.(StateView); ok {
|
||||
if n, ok2 := sv.ActiveSessions(siteID); ok2 {
|
||||
total = n
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
o.ObserveInt64(mTunnelSessions, total, metric.WithAttributes(attribute.String("site_id", siteID)))
|
||||
return
|
||||
}
|
||||
}
|
||||
384
internal/telemetry/telemetry.go
Normal file
384
internal/telemetry/telemetry.go
Normal file
@@ -0,0 +1,384 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
promclient "github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.opentelemetry.io/contrib/instrumentation/runtime"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
||||
"go.opentelemetry.io/otel/exporters/prometheus"
|
||||
"go.opentelemetry.io/otel/sdk/metric"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
"go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
// Config controls telemetry initialization via env flags.
|
||||
//
|
||||
// Defaults align with the issue requirements:
|
||||
// - Prometheus exporter enabled by default (/metrics)
|
||||
// - OTLP exporter disabled by default
|
||||
// - Durations in seconds, bytes in raw bytes
|
||||
// - Admin HTTP server address configurable (for mounting /metrics)
|
||||
type Config struct {
|
||||
ServiceName string
|
||||
ServiceVersion string
|
||||
|
||||
// Optional resource attributes
|
||||
SiteID string
|
||||
Region string
|
||||
|
||||
PromEnabled bool
|
||||
OTLPEnabled bool
|
||||
|
||||
OTLPEndpoint string // host:port
|
||||
OTLPInsecure bool
|
||||
|
||||
MetricExportInterval time.Duration
|
||||
AdminAddr string // e.g.: ":2112"
|
||||
|
||||
// Optional build info for newt_build_info metric
|
||||
BuildVersion string
|
||||
BuildCommit string
|
||||
}
|
||||
|
||||
// FromEnv reads configuration from environment variables.
|
||||
//
|
||||
// NEWT_METRICS_PROMETHEUS_ENABLED (default: true)
|
||||
// NEWT_METRICS_OTLP_ENABLED (default: false)
|
||||
// OTEL_EXPORTER_OTLP_ENDPOINT (default: "localhost:4317")
|
||||
// OTEL_EXPORTER_OTLP_INSECURE (default: true)
|
||||
// OTEL_METRIC_EXPORT_INTERVAL (default: 15s)
|
||||
// OTEL_SERVICE_NAME (default: "newt")
|
||||
// OTEL_SERVICE_VERSION (default: "")
|
||||
// NEWT_ADMIN_ADDR (default: ":2112")
|
||||
func FromEnv() Config {
|
||||
// Prefer explicit NEWT_* env vars, then fall back to OTEL_RESOURCE_ATTRIBUTES
|
||||
site := getenv("NEWT_SITE_ID", "")
|
||||
if site == "" {
|
||||
site = getenv("NEWT_ID", "")
|
||||
}
|
||||
region := os.Getenv("NEWT_REGION")
|
||||
if site == "" || region == "" {
|
||||
if ra := os.Getenv("OTEL_RESOURCE_ATTRIBUTES"); ra != "" {
|
||||
m := parseResourceAttributes(ra)
|
||||
if site == "" {
|
||||
site = m["site_id"]
|
||||
}
|
||||
if region == "" {
|
||||
region = m["region"]
|
||||
}
|
||||
}
|
||||
}
|
||||
return Config{
|
||||
ServiceName: getenv("OTEL_SERVICE_NAME", "newt"),
|
||||
ServiceVersion: os.Getenv("OTEL_SERVICE_VERSION"),
|
||||
SiteID: site,
|
||||
Region: region,
|
||||
PromEnabled: getenv("NEWT_METRICS_PROMETHEUS_ENABLED", "true") == "true",
|
||||
OTLPEnabled: getenv("NEWT_METRICS_OTLP_ENABLED", "false") == "true",
|
||||
OTLPEndpoint: getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4317"),
|
||||
OTLPInsecure: getenv("OTEL_EXPORTER_OTLP_INSECURE", "true") == "true",
|
||||
MetricExportInterval: getdur("OTEL_METRIC_EXPORT_INTERVAL", 15*time.Second),
|
||||
AdminAddr: getenv("NEWT_ADMIN_ADDR", ":2112"),
|
||||
}
|
||||
}
|
||||
|
||||
// Setup holds initialized telemetry providers and (optionally) a /metrics handler.
|
||||
// Call Shutdown when the process terminates to flush exporters.
|
||||
type Setup struct {
|
||||
MeterProvider *metric.MeterProvider
|
||||
TracerProvider *trace.TracerProvider
|
||||
|
||||
PrometheusHandler http.Handler // nil if Prometheus exporter disabled
|
||||
|
||||
shutdowns []func(context.Context) error
|
||||
}
|
||||
|
||||
// Init configures OpenTelemetry metrics and (optionally) tracing.
|
||||
//
|
||||
// It sets a global MeterProvider and TracerProvider, registers runtime instrumentation,
|
||||
// installs recommended histogram views for *_latency_seconds, and returns a Setup with
|
||||
// a Shutdown method to flush exporters.
|
||||
func Init(ctx context.Context, cfg Config) (*Setup, error) {
|
||||
// Configure tunnel_id label inclusion from env (default true)
|
||||
if getenv("NEWT_METRICS_INCLUDE_TUNNEL_ID", "true") == "true" {
|
||||
includeTunnelIDVal.Store(true)
|
||||
} else {
|
||||
includeTunnelIDVal.Store(false)
|
||||
}
|
||||
if getenv("NEWT_METRICS_INCLUDE_SITE_LABELS", "true") == "true" {
|
||||
includeSiteLabelVal.Store(true)
|
||||
} else {
|
||||
includeSiteLabelVal.Store(false)
|
||||
}
|
||||
res := buildResource(ctx, cfg)
|
||||
UpdateSiteInfo(cfg.SiteID, cfg.Region)
|
||||
|
||||
s := &Setup{}
|
||||
readers, promHandler, shutdowns, err := setupMetricExport(ctx, cfg, res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.PrometheusHandler = promHandler
|
||||
// Build provider
|
||||
mp := buildMeterProvider(res, readers)
|
||||
otel.SetMeterProvider(mp)
|
||||
s.MeterProvider = mp
|
||||
s.shutdowns = append(s.shutdowns, mp.Shutdown)
|
||||
// Optional tracing
|
||||
if cfg.OTLPEnabled {
|
||||
if tp, shutdown := setupTracing(ctx, cfg, res); tp != nil {
|
||||
otel.SetTracerProvider(tp)
|
||||
s.TracerProvider = tp
|
||||
s.shutdowns = append(s.shutdowns, func(c context.Context) error {
|
||||
return errors.Join(shutdown(c), tp.Shutdown(c))
|
||||
})
|
||||
}
|
||||
}
|
||||
// Add metric exporter shutdowns
|
||||
s.shutdowns = append(s.shutdowns, shutdowns...)
|
||||
// Runtime metrics
|
||||
_ = runtime.Start(runtime.WithMeterProvider(mp))
|
||||
// Instruments
|
||||
if err := registerInstruments(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.BuildVersion != "" || cfg.BuildCommit != "" {
|
||||
RegisterBuildInfo(cfg.BuildVersion, cfg.BuildCommit)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func buildResource(ctx context.Context, cfg Config) *resource.Resource {
|
||||
attrs := []attribute.KeyValue{
|
||||
semconv.ServiceName(cfg.ServiceName),
|
||||
semconv.ServiceVersion(cfg.ServiceVersion),
|
||||
}
|
||||
if cfg.SiteID != "" {
|
||||
attrs = append(attrs, attribute.String("site_id", cfg.SiteID))
|
||||
}
|
||||
if cfg.Region != "" {
|
||||
attrs = append(attrs, attribute.String("region", cfg.Region))
|
||||
}
|
||||
res, _ := resource.New(ctx, resource.WithFromEnv(), resource.WithHost(), resource.WithAttributes(attrs...))
|
||||
return res
|
||||
}
|
||||
|
||||
func setupMetricExport(ctx context.Context, cfg Config, _ *resource.Resource) ([]metric.Reader, http.Handler, []func(context.Context) error, error) {
|
||||
var readers []metric.Reader
|
||||
var shutdowns []func(context.Context) error
|
||||
var promHandler http.Handler
|
||||
if cfg.PromEnabled {
|
||||
reg := promclient.NewRegistry()
|
||||
exp, err := prometheus.New(prometheus.WithRegisterer(reg))
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
readers = append(readers, exp)
|
||||
promHandler = promhttp.HandlerFor(reg, promhttp.HandlerOpts{})
|
||||
}
|
||||
if cfg.OTLPEnabled {
|
||||
mopts := []otlpmetricgrpc.Option{otlpmetricgrpc.WithEndpoint(cfg.OTLPEndpoint)}
|
||||
if hdrs := parseOTLPHeaders(os.Getenv("OTEL_EXPORTER_OTLP_HEADERS")); len(hdrs) > 0 {
|
||||
mopts = append(mopts, otlpmetricgrpc.WithHeaders(hdrs))
|
||||
}
|
||||
if cfg.OTLPInsecure {
|
||||
mopts = append(mopts, otlpmetricgrpc.WithInsecure())
|
||||
} else if certFile := os.Getenv("OTEL_EXPORTER_OTLP_CERTIFICATE"); certFile != "" {
|
||||
if creds, cerr := credentials.NewClientTLSFromFile(certFile, ""); cerr == nil {
|
||||
mopts = append(mopts, otlpmetricgrpc.WithTLSCredentials(creds))
|
||||
}
|
||||
}
|
||||
mexp, err := otlpmetricgrpc.New(ctx, mopts...)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
readers = append(readers, metric.NewPeriodicReader(mexp, metric.WithInterval(cfg.MetricExportInterval)))
|
||||
shutdowns = append(shutdowns, mexp.Shutdown)
|
||||
}
|
||||
return readers, promHandler, shutdowns, nil
|
||||
}
|
||||
|
||||
func buildMeterProvider(res *resource.Resource, readers []metric.Reader) *metric.MeterProvider {
|
||||
var mpOpts []metric.Option
|
||||
mpOpts = append(mpOpts, metric.WithResource(res))
|
||||
for _, r := range readers {
|
||||
mpOpts = append(mpOpts, metric.WithReader(r))
|
||||
}
|
||||
mpOpts = append(mpOpts, metric.WithView(metric.NewView(
|
||||
metric.Instrument{Name: "newt_*_latency_seconds"},
|
||||
metric.Stream{Aggregation: metric.AggregationExplicitBucketHistogram{Boundaries: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30}}},
|
||||
)))
|
||||
mpOpts = append(mpOpts, metric.WithView(metric.NewView(
|
||||
metric.Instrument{Name: "newt_*"},
|
||||
metric.Stream{AttributeFilter: func(kv attribute.KeyValue) bool {
|
||||
k := string(kv.Key)
|
||||
switch k {
|
||||
case "tunnel_id", "transport", "direction", "protocol", "result", "reason", "initiator", "error_type", "msg_type", "phase", "version", "commit", "site_id", "region":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}},
|
||||
)))
|
||||
return metric.NewMeterProvider(mpOpts...)
|
||||
}
|
||||
|
||||
func setupTracing(ctx context.Context, cfg Config, res *resource.Resource) (*trace.TracerProvider, func(context.Context) error) {
|
||||
topts := []otlptracegrpc.Option{otlptracegrpc.WithEndpoint(cfg.OTLPEndpoint)}
|
||||
if hdrs := parseOTLPHeaders(os.Getenv("OTEL_EXPORTER_OTLP_HEADERS")); len(hdrs) > 0 {
|
||||
topts = append(topts, otlptracegrpc.WithHeaders(hdrs))
|
||||
}
|
||||
if cfg.OTLPInsecure {
|
||||
topts = append(topts, otlptracegrpc.WithInsecure())
|
||||
} else if certFile := os.Getenv("OTEL_EXPORTER_OTLP_CERTIFICATE"); certFile != "" {
|
||||
if creds, cerr := credentials.NewClientTLSFromFile(certFile, ""); cerr == nil {
|
||||
topts = append(topts, otlptracegrpc.WithTLSCredentials(creds))
|
||||
}
|
||||
}
|
||||
exp, err := otlptracegrpc.New(ctx, topts...)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
tp := trace.NewTracerProvider(trace.WithBatcher(exp), trace.WithResource(res))
|
||||
return tp, exp.Shutdown
|
||||
}
|
||||
|
||||
// Shutdown flushes exporters and providers in reverse init order.
|
||||
func (s *Setup) Shutdown(ctx context.Context) error {
|
||||
var err error
|
||||
for i := len(s.shutdowns) - 1; i >= 0; i-- {
|
||||
err = errors.Join(err, s.shutdowns[i](ctx))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func parseOTLPHeaders(h string) map[string]string {
|
||||
m := map[string]string{}
|
||||
if h == "" {
|
||||
return m
|
||||
}
|
||||
pairs := strings.Split(h, ",")
|
||||
for _, p := range pairs {
|
||||
kv := strings.SplitN(strings.TrimSpace(p), "=", 2)
|
||||
if len(kv) == 2 {
|
||||
m[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// parseResourceAttributes parses OTEL_RESOURCE_ATTRIBUTES formatted as k=v,k2=v2
|
||||
func parseResourceAttributes(s string) map[string]string {
|
||||
m := map[string]string{}
|
||||
if s == "" {
|
||||
return m
|
||||
}
|
||||
parts := strings.Split(s, ",")
|
||||
for _, p := range parts {
|
||||
kv := strings.SplitN(strings.TrimSpace(p), "=", 2)
|
||||
if len(kv) == 2 {
|
||||
m[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Global site/region used to enrich metric labels.
|
||||
var siteIDVal atomic.Value
|
||||
var regionVal atomic.Value
|
||||
var (
|
||||
includeTunnelIDVal atomic.Value // bool; default true
|
||||
includeSiteLabelVal atomic.Value // bool; default false
|
||||
)
|
||||
|
||||
// UpdateSiteInfo updates the global site_id and region used for metric labels.
|
||||
// Thread-safe via atomic.Value: subsequent metric emissions will include
|
||||
// the new labels, prior emissions remain unchanged.
|
||||
func UpdateSiteInfo(siteID, region string) {
|
||||
if siteID != "" {
|
||||
siteIDVal.Store(siteID)
|
||||
}
|
||||
if region != "" {
|
||||
regionVal.Store(region)
|
||||
}
|
||||
}
|
||||
|
||||
func getSiteID() string {
|
||||
if v, ok := siteIDVal.Load().(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getRegion() string {
|
||||
if v, ok := regionVal.Load().(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// siteAttrs returns label KVs for site_id and region (if set).
|
||||
func siteAttrs() []attribute.KeyValue {
|
||||
var out []attribute.KeyValue
|
||||
if s := getSiteID(); s != "" {
|
||||
out = append(out, attribute.String("site_id", s))
|
||||
}
|
||||
if r := getRegion(); r != "" {
|
||||
out = append(out, attribute.String("region", r))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SiteLabelKVs exposes site label KVs for other packages (e.g., proxy manager).
|
||||
func SiteLabelKVs() []attribute.KeyValue {
|
||||
if !ShouldIncludeSiteLabels() {
|
||||
return nil
|
||||
}
|
||||
return siteAttrs()
|
||||
}
|
||||
|
||||
// ShouldIncludeTunnelID returns whether tunnel_id labels should be emitted.
|
||||
func ShouldIncludeTunnelID() bool {
|
||||
if v, ok := includeTunnelIDVal.Load().(bool); ok {
|
||||
return v
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ShouldIncludeSiteLabels returns whether site_id/region should be emitted as
|
||||
// metric labels in addition to resource attributes.
|
||||
func ShouldIncludeSiteLabels() bool {
|
||||
if v, ok := includeSiteLabelVal.Load().(bool); ok {
|
||||
return v
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getenv(k, d string) string {
|
||||
if v := os.Getenv(k); v != "" {
|
||||
return v
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func getdur(k string, d time.Duration) time.Duration {
|
||||
if v := os.Getenv(k); v != "" {
|
||||
if p, e := time.ParseDuration(v); e == nil {
|
||||
return p
|
||||
}
|
||||
}
|
||||
return d
|
||||
}
|
||||
53
internal/telemetry/telemetry_attrfilter_test.go
Normal file
53
internal/telemetry/telemetry_attrfilter_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
)
|
||||
|
||||
// Test that disallowed attributes are filtered from the exposition.
|
||||
func TestAttributeFilterDropsUnknownKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
resetMetricsForTest()
|
||||
t.Setenv("NEWT_METRICS_INCLUDE_SITE_LABELS", "true")
|
||||
cfg := Config{ServiceName: "newt", PromEnabled: true, AdminAddr: "127.0.0.1:0"}
|
||||
tel, err := Init(ctx, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("init: %v", err)
|
||||
}
|
||||
defer func() { _ = tel.Shutdown(context.Background()) }()
|
||||
|
||||
if tel.PrometheusHandler == nil {
|
||||
t.Fatalf("prom handler nil")
|
||||
}
|
||||
ts := httptest.NewServer(tel.PrometheusHandler)
|
||||
defer ts.Close()
|
||||
|
||||
// Add samples with disallowed attribute keys
|
||||
for _, k := range []string{"forbidden", "site_id", "host"} {
|
||||
set := attribute.NewSet(attribute.String(k, "x"))
|
||||
AddTunnelBytesSet(ctx, 123, set)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
resp, err := http.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("GET: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
body := string(b)
|
||||
if strings.Contains(body, "forbidden=") {
|
||||
t.Fatalf("unexpected forbidden attribute leaked into metrics: %s", body)
|
||||
}
|
||||
if !strings.Contains(body, "site_id=\"x\"") {
|
||||
t.Fatalf("expected allowed attribute site_id to be present in metrics, got: %s", body)
|
||||
}
|
||||
}
|
||||
76
internal/telemetry/telemetry_golden_test.go
Normal file
76
internal/telemetry/telemetry_golden_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Golden test that /metrics contains expected metric names.
|
||||
func TestMetricsGoldenContains(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
resetMetricsForTest()
|
||||
t.Setenv("NEWT_METRICS_INCLUDE_SITE_LABELS", "true")
|
||||
cfg := Config{ServiceName: "newt", PromEnabled: true, AdminAddr: "127.0.0.1:0", BuildVersion: "test"}
|
||||
tel, err := Init(ctx, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("telemetry init error: %v", err)
|
||||
}
|
||||
defer func() { _ = tel.Shutdown(context.Background()) }()
|
||||
|
||||
if tel.PrometheusHandler == nil {
|
||||
t.Fatalf("prom handler nil")
|
||||
}
|
||||
ts := httptest.NewServer(tel.PrometheusHandler)
|
||||
defer ts.Close()
|
||||
|
||||
// Trigger counters to ensure they appear in the scrape
|
||||
IncConnAttempt(ctx, "websocket", "success")
|
||||
IncWSReconnect(ctx, "io_error")
|
||||
IncProxyConnectionEvent(ctx, "", "tcp", ProxyConnectionOpened)
|
||||
if tel.MeterProvider != nil {
|
||||
_ = tel.MeterProvider.ForceFlush(ctx)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var body string
|
||||
for i := 0; i < 5; i++ {
|
||||
resp, err := http.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("GET metrics failed: %v", err)
|
||||
}
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
body = string(b)
|
||||
if strings.Contains(body, "newt_connection_attempts_total") {
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
f, err := os.Open(filepath.Join("testdata", "expected_contains.golden"))
|
||||
if err != nil {
|
||||
t.Fatalf("read golden: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
needle := strings.TrimSpace(s.Text())
|
||||
if needle == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(body, needle) {
|
||||
t.Fatalf("expected metrics body to contain %q. body=\n%s", needle, body)
|
||||
}
|
||||
}
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatalf("scan golden: %v", err)
|
||||
}
|
||||
}
|
||||
65
internal/telemetry/telemetry_smoke_test.go
Normal file
65
internal/telemetry/telemetry_smoke_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Smoke test that /metrics contains at least one newt_* metric when Prom exporter is enabled.
|
||||
func TestMetricsSmoke(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
resetMetricsForTest()
|
||||
t.Setenv("NEWT_METRICS_INCLUDE_SITE_LABELS", "true")
|
||||
cfg := Config{
|
||||
ServiceName: "newt",
|
||||
PromEnabled: true,
|
||||
OTLPEnabled: false,
|
||||
AdminAddr: "127.0.0.1:0",
|
||||
BuildVersion: "test",
|
||||
BuildCommit: "deadbeef",
|
||||
MetricExportInterval: 5 * time.Second,
|
||||
}
|
||||
tel, err := Init(ctx, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("telemetry init error: %v", err)
|
||||
}
|
||||
defer func() { _ = tel.Shutdown(context.Background()) }()
|
||||
|
||||
// Serve the Prom handler on a test server
|
||||
if tel.PrometheusHandler == nil {
|
||||
t.Fatalf("Prometheus handler nil; PromEnabled should enable it")
|
||||
}
|
||||
ts := httptest.NewServer(tel.PrometheusHandler)
|
||||
defer ts.Close()
|
||||
|
||||
// Record a simple metric and then fetch /metrics
|
||||
IncConnAttempt(ctx, "websocket", "success")
|
||||
if tel.MeterProvider != nil {
|
||||
_ = tel.MeterProvider.ForceFlush(ctx)
|
||||
}
|
||||
// Give the exporter a tick to collect
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var body string
|
||||
for i := 0; i < 5; i++ {
|
||||
resp, err := http.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("GET /metrics failed: %v", err)
|
||||
}
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
body = string(b)
|
||||
if strings.Contains(body, "newt_connection_attempts_total") {
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
if !strings.Contains(body, "newt_connection_attempts_total") {
|
||||
t.Fatalf("expected newt_connection_attempts_total in metrics, got:\n%s", body)
|
||||
}
|
||||
}
|
||||
3
internal/telemetry/testdata/expected_contains.golden
vendored
Normal file
3
internal/telemetry/testdata/expected_contains.golden
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
newt_connection_attempts_total
|
||||
newt_websocket_reconnects_total
|
||||
newt_proxy_connections_total
|
||||
74
linux.go
74
linux.go
@@ -1,74 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/fosrl/newt/wg"
|
||||
"github.com/fosrl/newt/wgtester"
|
||||
)
|
||||
|
||||
var wgServiceNative *wg.WireGuardService
|
||||
|
||||
func setupClientsNative(client *websocket.Client, host string) {
|
||||
|
||||
if runtime.GOOS != "linux" {
|
||||
logger.Fatal("Tunnel management is only supported on Linux right now!")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// make sure we are sudo
|
||||
if os.Geteuid() != 0 {
|
||||
logger.Fatal("You must run this program as root to manage tunnels on Linux.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create WireGuard service
|
||||
wgServiceNative, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||
}
|
||||
|
||||
wgTesterServer = wgtester.NewServer("0.0.0.0", wgServiceNative.Port, id) // TODO: maybe make this the same ip of the wg server?
|
||||
err := wgTesterServer.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start WireGuard tester server: %v", err)
|
||||
}
|
||||
|
||||
client.OnTokenUpdate(func(token string) {
|
||||
wgServiceNative.SetToken(token)
|
||||
})
|
||||
}
|
||||
|
||||
func closeWgServiceNative() {
|
||||
if wgServiceNative != nil {
|
||||
wgServiceNative.Close(!keepInterface)
|
||||
wgServiceNative = nil
|
||||
}
|
||||
}
|
||||
|
||||
func clientsOnConnectNative() {
|
||||
if wgServiceNative != nil {
|
||||
wgServiceNative.LoadRemoteConfig()
|
||||
}
|
||||
}
|
||||
|
||||
func clientsHandleNewtConnectionNative(publicKey, endpoint string) {
|
||||
if wgServiceNative != nil {
|
||||
wgServiceNative.StartHolepunch(publicKey, endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
func clientsAddProxyTargetNative(pm *proxy.ProxyManager, tunnelIp string) {
|
||||
// add a udp proxy for localost and the wgService port
|
||||
// TODO: make sure this port is not used in a target
|
||||
if wgServiceNative != nil {
|
||||
pm.AddTarget("udp", tunnelIp, int(wgServiceNative.Port), fmt.Sprintf("127.0.0.1:%d", wgServiceNative.Port))
|
||||
}
|
||||
}
|
||||
@@ -2,16 +2,15 @@ package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Logger struct holds the logger instance
|
||||
type Logger struct {
|
||||
logger *log.Logger
|
||||
writer LogWriter
|
||||
level LogLevel
|
||||
}
|
||||
|
||||
@@ -20,17 +19,29 @@ var (
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// NewLogger creates a new logger instance
|
||||
// NewLogger creates a new logger instance with the default StandardWriter
|
||||
func NewLogger() *Logger {
|
||||
return &Logger{
|
||||
logger: log.New(os.Stdout, "", 0),
|
||||
writer: NewStandardWriter(),
|
||||
level: DEBUG,
|
||||
}
|
||||
}
|
||||
|
||||
// NewLoggerWithWriter creates a new logger instance with a custom LogWriter
|
||||
func NewLoggerWithWriter(writer LogWriter) *Logger {
|
||||
return &Logger{
|
||||
writer: writer,
|
||||
level: DEBUG,
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the default logger
|
||||
func Init() *Logger {
|
||||
func Init(logger *Logger) *Logger {
|
||||
once.Do(func() {
|
||||
if logger != nil {
|
||||
defaultLogger = logger
|
||||
return
|
||||
}
|
||||
defaultLogger = NewLogger()
|
||||
})
|
||||
return defaultLogger
|
||||
@@ -39,7 +50,7 @@ func Init() *Logger {
|
||||
// GetLogger returns the default logger instance
|
||||
func GetLogger() *Logger {
|
||||
if defaultLogger == nil {
|
||||
Init()
|
||||
Init(nil)
|
||||
}
|
||||
return defaultLogger
|
||||
}
|
||||
@@ -49,9 +60,11 @@ func (l *Logger) SetLevel(level LogLevel) {
|
||||
l.level = level
|
||||
}
|
||||
|
||||
// SetOutput sets the output destination for the logger
|
||||
func (l *Logger) SetOutput(w io.Writer) {
|
||||
l.logger.SetOutput(w)
|
||||
// SetOutput sets the output destination for the logger (only works with StandardWriter)
|
||||
func (l *Logger) SetOutput(output *os.File) {
|
||||
if sw, ok := l.writer.(*StandardWriter); ok {
|
||||
sw.SetOutput(output)
|
||||
}
|
||||
}
|
||||
|
||||
// log handles the actual logging
|
||||
@@ -60,24 +73,8 @@ func (l *Logger) log(level LogLevel, format string, args ...interface{}) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get timezone from environment variable or use local timezone
|
||||
timezone := os.Getenv("LOGGER_TIMEZONE")
|
||||
var location *time.Location
|
||||
var err error
|
||||
|
||||
if timezone != "" {
|
||||
location, err = time.LoadLocation(timezone)
|
||||
if err != nil {
|
||||
// If invalid timezone, fall back to local
|
||||
location = time.Local
|
||||
}
|
||||
} else {
|
||||
location = time.Local
|
||||
}
|
||||
|
||||
timestamp := time.Now().In(location).Format("2006/01/02 15:04:05")
|
||||
message := fmt.Sprintf(format, args...)
|
||||
l.logger.Printf("%s: %s %s", level.String(), timestamp, message)
|
||||
l.writer.Write(level, time.Now(), message)
|
||||
}
|
||||
|
||||
// Debug logs debug level messages
|
||||
@@ -128,6 +125,29 @@ func Fatal(format string, args ...interface{}) {
|
||||
}
|
||||
|
||||
// SetOutput sets the output destination for the default logger
|
||||
func SetOutput(w io.Writer) {
|
||||
GetLogger().SetOutput(w)
|
||||
func SetOutput(output *os.File) {
|
||||
GetLogger().SetOutput(output)
|
||||
}
|
||||
|
||||
// WireGuardLogger is a wrapper type that matches WireGuard's Logger interface
|
||||
type WireGuardLogger struct {
|
||||
Verbosef func(format string, args ...any)
|
||||
Errorf func(format string, args ...any)
|
||||
}
|
||||
|
||||
// GetWireGuardLogger returns a WireGuard-compatible logger that writes to the newt logger
|
||||
// The prepend string is added as a prefix to all log messages
|
||||
func (l *Logger) GetWireGuardLogger(prepend string) *WireGuardLogger {
|
||||
return &WireGuardLogger{
|
||||
Verbosef: func(format string, args ...any) {
|
||||
// if the format string contains "Sending keepalive packet", skip debug logging to reduce noise
|
||||
if strings.Contains(format, "Sending keepalive packet") {
|
||||
return
|
||||
}
|
||||
l.Debug(prepend+format, args...)
|
||||
},
|
||||
Errorf: func(format string, args ...any) {
|
||||
l.Error(prepend+format, args...)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
54
logger/writer.go
Normal file
54
logger/writer.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LogWriter is an interface for writing log messages
|
||||
// Implement this interface to create custom log backends (OS log, syslog, etc.)
|
||||
type LogWriter interface {
|
||||
// Write writes a log message with the given level, timestamp, and formatted message
|
||||
Write(level LogLevel, timestamp time.Time, message string)
|
||||
}
|
||||
|
||||
// StandardWriter is the default log writer that writes to an io.Writer
|
||||
type StandardWriter struct {
|
||||
output *os.File
|
||||
timezone *time.Location
|
||||
}
|
||||
|
||||
// NewStandardWriter creates a new standard writer with the default configuration
|
||||
func NewStandardWriter() *StandardWriter {
|
||||
// Get timezone from environment variable or use local timezone
|
||||
timezone := os.Getenv("LOGGER_TIMEZONE")
|
||||
var location *time.Location
|
||||
var err error
|
||||
|
||||
if timezone != "" {
|
||||
location, err = time.LoadLocation(timezone)
|
||||
if err != nil {
|
||||
// If invalid timezone, fall back to local
|
||||
location = time.Local
|
||||
}
|
||||
} else {
|
||||
location = time.Local
|
||||
}
|
||||
|
||||
return &StandardWriter{
|
||||
output: os.Stdout,
|
||||
timezone: location,
|
||||
}
|
||||
}
|
||||
|
||||
// SetOutput sets the output destination
|
||||
func (w *StandardWriter) SetOutput(output *os.File) {
|
||||
w.output = output
|
||||
}
|
||||
|
||||
// Write implements the LogWriter interface
|
||||
func (w *StandardWriter) Write(level LogLevel, timestamp time.Time, message string) {
|
||||
formattedTime := timestamp.In(w.timezone).Format("2006/01/02 15:04:05")
|
||||
fmt.Fprintf(w.output, "%s: %s %s\n", level.String(), formattedTime, message)
|
||||
}
|
||||
633
main.go
633
main.go
@@ -1,7 +1,11 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -9,19 +13,23 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/authdaemon"
|
||||
"github.com/fosrl/newt/docker"
|
||||
"github.com/fosrl/newt/healthcheck"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/updates"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
|
||||
"github.com/fosrl/newt/internal/state"
|
||||
"github.com/fosrl/newt/internal/telemetry"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@@ -31,6 +39,7 @@ import (
|
||||
|
||||
type WgData struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
RelayPort uint16 `json:"relayPort"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
ServerIP string `json:"serverIP"`
|
||||
TunnelIP string `json:"tunnelIP"`
|
||||
@@ -51,10 +60,6 @@ type ExitNodeData struct {
|
||||
ExitNodes []ExitNode `json:"exitNodes"`
|
||||
}
|
||||
|
||||
type SSHPublicKeyData struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
// ExitNode represents an exit node with an ID, endpoint, and weight.
|
||||
type ExitNode struct {
|
||||
ID int `json:"exitNodeId"`
|
||||
@@ -91,6 +96,14 @@ func (s *stringSlice) Set(value string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
fmtErrMarshaling = "Error marshaling data: %v"
|
||||
fmtReceivedMsg = "Received: %+v"
|
||||
topicWGRegister = "newt/wg/register"
|
||||
msgNoTunnelOrProxy = "No tunnel IP or proxy manager available"
|
||||
fmtErrParsingTargetData = "Error parsing target data: %v"
|
||||
)
|
||||
|
||||
var (
|
||||
endpoint string
|
||||
id string
|
||||
@@ -102,9 +115,8 @@ var (
|
||||
err error
|
||||
logLevel string
|
||||
interfaceName string
|
||||
generateAndSaveKeyTo string
|
||||
keepInterface bool
|
||||
acceptClients bool
|
||||
port uint16
|
||||
disableClients bool
|
||||
updownScript string
|
||||
dockerSocket string
|
||||
dockerEnforceNetworkValidation string
|
||||
@@ -120,8 +132,21 @@ var (
|
||||
preferEndpoint string
|
||||
healthMonitor *healthcheck.Monitor
|
||||
enforceHealthcheckCert bool
|
||||
blueprintFile string
|
||||
noCloud bool
|
||||
authDaemonKey string
|
||||
authDaemonPrincipalsFile string
|
||||
authDaemonCACertPath string
|
||||
authDaemonEnabled bool
|
||||
// Build/version (can be overridden via -ldflags "-X main.newtVersion=...")
|
||||
newtVersion = "version_replaceme"
|
||||
|
||||
// Observability/metrics flags
|
||||
metricsEnabled bool
|
||||
otlpEnabled bool
|
||||
adminAddr string
|
||||
region string
|
||||
metricsAsyncBytes bool
|
||||
blueprintFile string
|
||||
noCloud bool
|
||||
|
||||
// New mTLS configuration variables
|
||||
tlsClientCert string
|
||||
@@ -133,6 +158,49 @@ var (
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Check for subcommands first (only principals exits early)
|
||||
if len(os.Args) > 1 {
|
||||
switch os.Args[1] {
|
||||
case "auth-daemon":
|
||||
// Run principals subcommand only if the next argument is "principals"
|
||||
if len(os.Args) > 2 && os.Args[2] == "principals" {
|
||||
runPrincipalsCmd(os.Args[3:])
|
||||
return
|
||||
}
|
||||
|
||||
// auth-daemon subcommand without "principals" - show help
|
||||
fmt.Println("Error: auth-daemon subcommand requires 'principals' argument")
|
||||
fmt.Println()
|
||||
fmt.Println("Usage:")
|
||||
fmt.Println(" newt auth-daemon principals [options]")
|
||||
fmt.Println()
|
||||
|
||||
// If not "principals", exit the switch to continue with normal execution
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we're running as a Windows service
|
||||
if isWindowsService() {
|
||||
runService("NewtWireguardService", false, os.Args[1:])
|
||||
return
|
||||
}
|
||||
|
||||
// Handle service management commands on Windows (install, remove, start, stop, etc.)
|
||||
if handleServiceCommand() {
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare context for graceful shutdown and signal handling
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
// Run the main newt logic
|
||||
runNewtMain(ctx)
|
||||
}
|
||||
|
||||
// runNewtMain contains the main newt logic, extracted for service support
|
||||
func runNewtMain(ctx context.Context) {
|
||||
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
|
||||
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
|
||||
id = os.Getenv("NEWT_ID")
|
||||
@@ -142,11 +210,21 @@ func main() {
|
||||
logLevel = os.Getenv("LOG_LEVEL")
|
||||
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
||||
interfaceName = os.Getenv("INTERFACE")
|
||||
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
||||
keepInterfaceEnv := os.Getenv("KEEP_INTERFACE")
|
||||
keepInterface = keepInterfaceEnv == "true"
|
||||
acceptClientsEnv := os.Getenv("ACCEPT_CLIENTS")
|
||||
acceptClients = acceptClientsEnv == "true"
|
||||
portStr := os.Getenv("PORT")
|
||||
authDaemonKey = os.Getenv("AD_KEY")
|
||||
authDaemonPrincipalsFile = os.Getenv("AD_PRINCIPALS_FILE")
|
||||
authDaemonCACertPath = os.Getenv("AD_CA_CERT_PATH")
|
||||
authDaemonEnabledEnv := os.Getenv("AUTH_DAEMON_ENABLED")
|
||||
|
||||
// Metrics/observability env mirrors
|
||||
metricsEnabledEnv := os.Getenv("NEWT_METRICS_PROMETHEUS_ENABLED")
|
||||
otlpEnabledEnv := os.Getenv("NEWT_METRICS_OTLP_ENABLED")
|
||||
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
|
||||
regionEnv := os.Getenv("NEWT_REGION")
|
||||
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
|
||||
|
||||
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
|
||||
disableClients = disableClientsEnv == "true"
|
||||
useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE")
|
||||
useNativeInterface = useNativeInterfaceEnv == "true"
|
||||
enforceHealthcheckCertEnv := os.Getenv("ENFORCE_HC_CERT")
|
||||
@@ -205,17 +283,14 @@ func main() {
|
||||
if interfaceName == "" {
|
||||
flag.StringVar(&interfaceName, "interface", "newt", "Name of the WireGuard interface")
|
||||
}
|
||||
if generateAndSaveKeyTo == "" {
|
||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||
}
|
||||
if keepInterfaceEnv == "" {
|
||||
flag.BoolVar(&keepInterface, "keep-interface", false, "Keep the WireGuard interface")
|
||||
if portStr == "" {
|
||||
flag.StringVar(&portStr, "port", "", "Port for client WireGuard interface")
|
||||
}
|
||||
if useNativeInterfaceEnv == "" {
|
||||
flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface (requires WireGuard kernel module) and linux")
|
||||
flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface")
|
||||
}
|
||||
if acceptClientsEnv == "" {
|
||||
flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface")
|
||||
if disableClientsEnv == "" {
|
||||
flag.BoolVar(&disableClients, "disable-clients", false, "Disable clients on the WireGuard interface")
|
||||
}
|
||||
if enforceHealthcheckCertEnv == "" {
|
||||
flag.BoolVar(&enforceHealthcheckCert, "enforce-hc-cert", false, "Enforce certificate validation for health checks (default: false, accepts any cert)")
|
||||
@@ -232,10 +307,6 @@ func main() {
|
||||
// load the prefer endpoint just as a flag
|
||||
flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")
|
||||
|
||||
// if authorizedKeysFile == "" {
|
||||
// flag.StringVar(&authorizedKeysFile, "authorized-keys-file", "~/.ssh/authorized_keys", "Path to authorized keys file (if unset, no keys will be authorized)")
|
||||
// }
|
||||
|
||||
// Add new mTLS flags
|
||||
if tlsClientCert == "" {
|
||||
flag.StringVar(&tlsClientCert, "tls-client-cert-file", "", "Path to client certificate file (PEM/DER format)")
|
||||
@@ -273,6 +344,15 @@ func main() {
|
||||
pingTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
if portStr != "" {
|
||||
portInt, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to parse PORT, choosing a random port")
|
||||
} else {
|
||||
port = uint16(portInt)
|
||||
}
|
||||
}
|
||||
|
||||
if dockerEnforceNetworkValidation == "" {
|
||||
flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)")
|
||||
}
|
||||
@@ -286,6 +366,61 @@ func main() {
|
||||
flag.BoolVar(&noCloud, "no-cloud", false, "Disable cloud failover")
|
||||
}
|
||||
|
||||
// Metrics/observability flags (mirror ENV if unset)
|
||||
if metricsEnabledEnv == "" {
|
||||
flag.BoolVar(&metricsEnabled, "metrics", false, "Enable Prometheus metrics exporter")
|
||||
} else {
|
||||
if v, err := strconv.ParseBool(metricsEnabledEnv); err == nil {
|
||||
metricsEnabled = v
|
||||
} else {
|
||||
metricsEnabled = true
|
||||
}
|
||||
}
|
||||
if otlpEnabledEnv == "" {
|
||||
flag.BoolVar(&otlpEnabled, "otlp", false, "Enable OTLP exporters (metrics/traces) to OTEL_EXPORTER_OTLP_ENDPOINT")
|
||||
} else {
|
||||
if v, err := strconv.ParseBool(otlpEnabledEnv); err == nil {
|
||||
otlpEnabled = v
|
||||
}
|
||||
}
|
||||
if adminAddrEnv == "" {
|
||||
flag.StringVar(&adminAddr, "metrics-admin-addr", "127.0.0.1:2112", "Admin/metrics bind address")
|
||||
} else {
|
||||
adminAddr = adminAddrEnv
|
||||
}
|
||||
// Async bytes toggle
|
||||
if asyncBytesEnv == "" {
|
||||
flag.BoolVar(&metricsAsyncBytes, "metrics-async-bytes", false, "Enable async bytes counting (background flush; lower hot path overhead)")
|
||||
} else {
|
||||
if v, err := strconv.ParseBool(asyncBytesEnv); err == nil {
|
||||
metricsAsyncBytes = v
|
||||
}
|
||||
}
|
||||
// Optional region flag (resource attribute)
|
||||
if regionEnv == "" {
|
||||
flag.StringVar(®ion, "region", "", "Optional region resource attribute (also NEWT_REGION)")
|
||||
} else {
|
||||
region = regionEnv
|
||||
}
|
||||
|
||||
// Auth daemon flags
|
||||
if authDaemonKey == "" {
|
||||
flag.StringVar(&authDaemonKey, "ad-preshared-key", "", "Preshared key for auth daemon authentication (required when --auth-daemon is true)")
|
||||
}
|
||||
if authDaemonPrincipalsFile == "" {
|
||||
flag.StringVar(&authDaemonPrincipalsFile, "ad-principals-file", "/var/run/auth-daemon/principals", "Path to the principals file for auth daemon")
|
||||
}
|
||||
if authDaemonCACertPath == "" {
|
||||
flag.StringVar(&authDaemonCACertPath, "ad-ca-cert-path", "/etc/ssh/ca.pem", "Path to the CA certificate file for auth daemon")
|
||||
}
|
||||
if authDaemonEnabledEnv == "" {
|
||||
flag.BoolVar(&authDaemonEnabled, "auth-daemon", false, "Enable auth daemon mode (runs alongside normal newt operation)")
|
||||
} else {
|
||||
if v, err := strconv.ParseBool(authDaemonEnabledEnv); err == nil {
|
||||
authDaemonEnabled = v
|
||||
}
|
||||
}
|
||||
|
||||
// do a --version check
|
||||
version := flag.Bool("version", false, "Print the version")
|
||||
|
||||
@@ -296,16 +431,69 @@ func main() {
|
||||
tlsClientCAs = append(tlsClientCAs, tlsClientCAsFlag...)
|
||||
}
|
||||
|
||||
logger.Init()
|
||||
loggerLevel := parseLogLevel(logLevel)
|
||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||
|
||||
newtVersion := "version_replaceme"
|
||||
if *version {
|
||||
fmt.Println("Newt version " + newtVersion)
|
||||
os.Exit(0)
|
||||
} else {
|
||||
logger.Info("Newt version " + newtVersion)
|
||||
logger.Info("Newt version %s", newtVersion)
|
||||
}
|
||||
|
||||
logger.Init(nil)
|
||||
loggerLevel := util.ParseLogLevel(logLevel)
|
||||
|
||||
// Start auth daemon if enabled
|
||||
if authDaemonEnabled {
|
||||
if err := startAuthDaemon(ctx); err != nil {
|
||||
logger.Fatal("Failed to start auth daemon: %v", err)
|
||||
}
|
||||
}
|
||||
logger.GetLogger().SetLevel(loggerLevel)
|
||||
|
||||
// Initialize telemetry after flags are parsed (so flags override env)
|
||||
tcfg := telemetry.FromEnv()
|
||||
tcfg.PromEnabled = metricsEnabled
|
||||
tcfg.OTLPEnabled = otlpEnabled
|
||||
if adminAddr != "" {
|
||||
tcfg.AdminAddr = adminAddr
|
||||
}
|
||||
// Resource attributes (if available)
|
||||
tcfg.SiteID = id
|
||||
tcfg.Region = region
|
||||
// Build info
|
||||
tcfg.BuildVersion = newtVersion
|
||||
tcfg.BuildCommit = os.Getenv("NEWT_COMMIT")
|
||||
|
||||
tel, telErr := telemetry.Init(ctx, tcfg)
|
||||
if telErr != nil {
|
||||
logger.Warn("Telemetry init failed: %v", telErr)
|
||||
}
|
||||
if tel != nil {
|
||||
// Admin HTTP server (exposes /metrics when Prometheus exporter is enabled)
|
||||
logger.Debug("Starting metrics server on %s", tcfg.AdminAddr)
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) })
|
||||
if tel.PrometheusHandler != nil {
|
||||
mux.Handle("/metrics", tel.PrometheusHandler)
|
||||
}
|
||||
admin := &http.Server{
|
||||
Addr: tcfg.AdminAddr,
|
||||
Handler: otelhttp.NewHandler(mux, "newt-admin"),
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
IdleTimeout: 30 * time.Second,
|
||||
}
|
||||
go func() {
|
||||
if err := admin.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Warn("admin http error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = admin.Shutdown(ctx)
|
||||
}()
|
||||
defer func() { _ = tel.Shutdown(context.Background()) }()
|
||||
}
|
||||
|
||||
if err := updates.CheckForUpdate("fosrl", "newt", newtVersion); err != nil {
|
||||
@@ -376,6 +564,8 @@ func main() {
|
||||
}
|
||||
endpoint = client.GetConfig().Endpoint // Update endpoint from config
|
||||
id = client.GetConfig().ID // Update ID from config
|
||||
// Update site labels for metrics with the resolved ID
|
||||
telemetry.UpdateSiteInfo(id, region)
|
||||
|
||||
// output env var values if set
|
||||
logger.Debug("Endpoint: %v", endpoint)
|
||||
@@ -419,7 +609,7 @@ func main() {
|
||||
var wgData WgData
|
||||
var dockerEventMonitor *docker.EventMonitor
|
||||
|
||||
if acceptClients {
|
||||
if !disableClients {
|
||||
setupClients(client)
|
||||
}
|
||||
|
||||
@@ -484,6 +674,10 @@ func main() {
|
||||
// Register handlers for different message types
|
||||
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received registration message")
|
||||
regResult := "success"
|
||||
defer func() {
|
||||
telemetry.IncSiteRegistration(ctx, regResult)
|
||||
}()
|
||||
if stopFunc != nil {
|
||||
stopFunc() // stop the ws from sending more requests
|
||||
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
||||
@@ -502,64 +696,76 @@ func main() {
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
logger.Info(fmtErrMarshaling, err)
|
||||
regResult = "failure"
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
regResult = "failure"
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("Received: %+v", msg)
|
||||
logger.Debug(fmtReceivedMsg, msg)
|
||||
tun, tnet, err = netstack.CreateNetTUN(
|
||||
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
|
||||
[]netip.Addr{netip.MustParseAddr(dns)},
|
||||
mtuInt)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create TUN device: %v", err)
|
||||
regResult = "failure"
|
||||
}
|
||||
|
||||
setDownstreamTNetstack(tnet)
|
||||
|
||||
// Create WireGuard device
|
||||
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
|
||||
mapToWireGuardLogLevel(loggerLevel),
|
||||
"wireguard: ",
|
||||
util.MapToWireGuardLogLevel(loggerLevel),
|
||||
"gerbil-wireguard: ",
|
||||
))
|
||||
|
||||
host, _, err := net.SplitHostPort(wgData.Endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to split endpoint: %v", err)
|
||||
regResult = "failure"
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Connecting to endpoint: %s", host)
|
||||
|
||||
endpoint, err := resolveDomain(wgData.Endpoint)
|
||||
endpoint, err := util.ResolveDomain(wgData.Endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve endpoint: %v", err)
|
||||
regResult = "failure"
|
||||
return
|
||||
}
|
||||
|
||||
clientsHandleNewtConnection(wgData.PublicKey, endpoint)
|
||||
relayPort := wgData.RelayPort
|
||||
if relayPort == 0 {
|
||||
relayPort = 21820
|
||||
}
|
||||
|
||||
clientsHandleNewtConnection(wgData.PublicKey, endpoint, relayPort)
|
||||
|
||||
// Configure WireGuard
|
||||
config := fmt.Sprintf(`private_key=%s
|
||||
public_key=%s
|
||||
allowed_ip=%s/32
|
||||
endpoint=%s
|
||||
persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
||||
persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
||||
|
||||
err = dev.IpcSet(config)
|
||||
if err != nil {
|
||||
logger.Error("Failed to configure WireGuard device: %v", err)
|
||||
regResult = "failure"
|
||||
}
|
||||
|
||||
// Bring up the device
|
||||
err = dev.Up()
|
||||
if err != nil {
|
||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||
regResult = "failure"
|
||||
}
|
||||
|
||||
logger.Debug("WireGuard device created. Lets ping the server now...")
|
||||
@@ -572,9 +778,13 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
}
|
||||
// Use reliable ping for initial connection test
|
||||
logger.Debug("Testing initial connection with reliable ping...")
|
||||
_, err = reliablePing(tnet, wgData.ServerIP, pingTimeout, 5)
|
||||
lat, err := reliablePing(tnet, wgData.ServerIP, pingTimeout, 5)
|
||||
if err == nil && wgData.PublicKey != "" {
|
||||
telemetry.ObserveTunnelLatency(ctx, wgData.PublicKey, "wireguard", lat.Seconds())
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn("Initial reliable ping failed, but continuing: %v", err)
|
||||
regResult = "failure"
|
||||
} else {
|
||||
logger.Debug("Initial connection test successful")
|
||||
}
|
||||
@@ -585,11 +795,14 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
// as the pings will continue in the background
|
||||
if !connected {
|
||||
logger.Debug("Starting ping check")
|
||||
pingStopChan = startPingCheck(tnet, wgData.ServerIP, client)
|
||||
pingStopChan = startPingCheck(tnet, wgData.ServerIP, client, wgData.PublicKey)
|
||||
}
|
||||
|
||||
// Create proxy manager
|
||||
pm = proxy.NewProxyManager(tnet)
|
||||
pm.SetAsyncBytes(metricsAsyncBytes)
|
||||
// Set tunnel_id for metrics (WireGuard peer public key)
|
||||
pm.SetTunnelID(wgData.PublicKey)
|
||||
|
||||
connected = true
|
||||
|
||||
@@ -610,7 +823,8 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
// }
|
||||
}
|
||||
|
||||
clientsAddProxyTarget(pm, wgData.TunnelIP)
|
||||
// Start direct UDP relay from main tunnel to clients' WireGuard (bypasses proxy)
|
||||
clientsStartDirectRelay(wgData.TunnelIP)
|
||||
|
||||
if err := healthMonitor.AddTargets(wgData.HealthCheckTargets); err != nil {
|
||||
logger.Error("Failed to bulk add health check targets: %v", err)
|
||||
@@ -626,10 +840,19 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
|
||||
client.RegisterHandler("newt/wg/reconnect", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received reconnect message")
|
||||
if wgData.PublicKey != "" {
|
||||
telemetry.IncReconnect(ctx, wgData.PublicKey, "server", telemetry.ReasonServerRequest)
|
||||
}
|
||||
|
||||
// Close the WireGuard device and TUN
|
||||
closeWgTunnel()
|
||||
|
||||
// Clear metrics attrs and sessions for the tunnel
|
||||
if pm != nil {
|
||||
pm.ClearTunnelID()
|
||||
state.Global().ClearTunnel(wgData.PublicKey)
|
||||
}
|
||||
|
||||
// Mark as disconnected
|
||||
connected = false
|
||||
|
||||
@@ -648,9 +871,13 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
|
||||
client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received termination message")
|
||||
if wgData.PublicKey != "" {
|
||||
telemetry.IncReconnect(ctx, wgData.PublicKey, "server", telemetry.ReasonServerRequest)
|
||||
}
|
||||
|
||||
// Close the WireGuard device and TUN
|
||||
closeWgTunnel()
|
||||
closeClients()
|
||||
|
||||
if stopFunc != nil {
|
||||
stopFunc() // stop the ws from sending more requests
|
||||
@@ -675,7 +902,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
logger.Info(fmtErrMarshaling, err)
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &exitNodeData); err != nil {
|
||||
@@ -716,7 +943,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
},
|
||||
}
|
||||
|
||||
stopFunc = client.SendMessageInterval("newt/wg/register", map[string]interface{}{
|
||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"pingResults": pingResults,
|
||||
"newtVersion": newtVersion,
|
||||
@@ -819,7 +1046,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
}
|
||||
|
||||
// Send the ping results to the cloud for selection
|
||||
stopFunc = client.SendMessageInterval("newt/wg/register", map[string]interface{}{
|
||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"pingResults": pingResults,
|
||||
"newtVersion": newtVersion,
|
||||
@@ -829,17 +1056,17 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received: %+v", msg)
|
||||
logger.Debug(fmtReceivedMsg, msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
logger.Info(msgNoTunnelOrProxy)
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
logger.Info(fmtErrParsingTargetData, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -854,17 +1081,17 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
logger.Info(fmtReceivedMsg, msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
logger.Info(msgNoTunnelOrProxy)
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
logger.Info(fmtErrParsingTargetData, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -879,17 +1106,17 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
logger.Info(fmtReceivedMsg, msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
logger.Info(msgNoTunnelOrProxy)
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
logger.Info(fmtErrParsingTargetData, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -904,17 +1131,17 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
logger.Info(fmtReceivedMsg, msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
logger.Info(msgNoTunnelOrProxy)
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
logger.Info(fmtErrParsingTargetData, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -990,94 +1217,6 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
}
|
||||
})
|
||||
|
||||
// EXPERIMENTAL: WHAT SHOULD WE DO ABOUT SECURITY?
|
||||
client.RegisterHandler("newt/send/ssh/publicKey", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received SSH public key request")
|
||||
|
||||
var sshPublicKeyData SSHPublicKeyData
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &sshPublicKeyData); err != nil {
|
||||
logger.Info("Error unmarshaling SSH public key data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sshPublicKey := sshPublicKeyData.PublicKey
|
||||
|
||||
if authorizedKeysFile == "" {
|
||||
logger.Debug("No authorized keys file set, skipping public key response")
|
||||
return
|
||||
}
|
||||
|
||||
// Expand tilde to home directory if present
|
||||
expandedPath := authorizedKeysFile
|
||||
if strings.HasPrefix(authorizedKeysFile, "~/") {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
logger.Error("Failed to get user home directory: %v", err)
|
||||
return
|
||||
}
|
||||
expandedPath = filepath.Join(homeDir, authorizedKeysFile[2:])
|
||||
}
|
||||
|
||||
// if it is set but the file does not exist, create it
|
||||
if _, err := os.Stat(expandedPath); os.IsNotExist(err) {
|
||||
logger.Debug("Authorized keys file does not exist, creating it: %s", expandedPath)
|
||||
if err := os.MkdirAll(filepath.Dir(expandedPath), 0755); err != nil {
|
||||
logger.Error("Failed to create directory for authorized keys file: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := os.Create(expandedPath); err != nil {
|
||||
logger.Error("Failed to create authorized keys file: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the public key already exists in the file
|
||||
fileContent, err := os.ReadFile(expandedPath)
|
||||
if err != nil {
|
||||
logger.Error("Failed to read authorized keys file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the key already exists (trim whitespace for comparison)
|
||||
existingKeys := strings.Split(string(fileContent), "\n")
|
||||
keyAlreadyExists := false
|
||||
trimmedNewKey := strings.TrimSpace(sshPublicKey)
|
||||
|
||||
for _, existingKey := range existingKeys {
|
||||
if strings.TrimSpace(existingKey) == trimmedNewKey && trimmedNewKey != "" {
|
||||
keyAlreadyExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if keyAlreadyExists {
|
||||
logger.Info("SSH public key already exists in authorized keys file, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
// append the public key to the authorized keys file
|
||||
logger.Debug("Appending public key to authorized keys file: %s", sshPublicKey)
|
||||
file, err := os.OpenFile(expandedPath, os.O_APPEND|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
logger.Error("Failed to open authorized keys file: %v", err)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if _, err := file.WriteString(sshPublicKey + "\n"); err != nil {
|
||||
logger.Error("Failed to write public key to authorized keys file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("SSH public key appended to authorized keys file")
|
||||
})
|
||||
|
||||
// Register handler for adding health check targets
|
||||
client.RegisterHandler("newt/healthcheck/add", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received health check add request: %+v", msg)
|
||||
@@ -1155,9 +1294,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
}
|
||||
|
||||
if err := healthMonitor.EnableTarget(requestData.ID); err != nil {
|
||||
logger.Error("Failed to enable health check target %s: %v", requestData.ID, err)
|
||||
logger.Error("Failed to enable health check target %d: %v", requestData.ID, err)
|
||||
} else {
|
||||
logger.Info("Enabled health check target: %s", requestData.ID)
|
||||
logger.Info("Enabled health check target: %d", requestData.ID)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1180,9 +1319,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
}
|
||||
|
||||
if err := healthMonitor.DisableTarget(requestData.ID); err != nil {
|
||||
logger.Error("Failed to disable health check target %s: %v", requestData.ID, err)
|
||||
logger.Error("Failed to disable health check target %d: %v", requestData.ID, err)
|
||||
} else {
|
||||
logger.Info("Disabled health check target: %s", requestData.ID)
|
||||
logger.Info("Disabled health check target: %d", requestData.ID)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1233,6 +1372,168 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
}
|
||||
})
|
||||
|
||||
// Register handler for SSH certificate issued events
|
||||
client.RegisterHandler("newt/pam/connection", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received SSH certificate issued message")
|
||||
|
||||
// Define the structure of the incoming message
|
||||
type SSHCertData struct {
|
||||
MessageId int `json:"messageId"`
|
||||
AgentPort int `json:"agentPort"`
|
||||
AgentHost string `json:"agentHost"`
|
||||
CACert string `json:"caCert"`
|
||||
Username string `json:"username"`
|
||||
NiceID string `json:"niceId"`
|
||||
Metadata struct {
|
||||
Sudo bool `json:"sudo"`
|
||||
Homedir bool `json:"homedir"`
|
||||
} `json:"metadata"`
|
||||
}
|
||||
|
||||
var certData SSHCertData
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling SSH cert data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// print the received data for debugging
|
||||
logger.Debug("Received SSH cert data: %s", string(jsonData))
|
||||
|
||||
if err := json.Unmarshal(jsonData, &certData); err != nil {
|
||||
logger.Error("Error unmarshaling SSH cert data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if we're running the auth daemon internally
|
||||
if authDaemonServer != nil {
|
||||
// Call ProcessConnection directly when running internally
|
||||
logger.Debug("Calling internal auth daemon ProcessConnection for user %s", certData.Username)
|
||||
|
||||
authDaemonServer.ProcessConnection(authdaemon.ConnectionRequest{
|
||||
CaCert: certData.CACert,
|
||||
NiceId: certData.NiceID,
|
||||
Username: certData.Username,
|
||||
Metadata: authdaemon.ConnectionMetadata{
|
||||
Sudo: certData.Metadata.Sudo,
|
||||
Homedir: certData.Metadata.Homedir,
|
||||
},
|
||||
})
|
||||
|
||||
// Send success response back to cloud
|
||||
err = client.SendMessage("ws/round-trip/complete", map[string]interface{}{
|
||||
"messageId": certData.MessageId,
|
||||
"complete": true,
|
||||
})
|
||||
|
||||
logger.Info("Successfully processed connection via internal auth daemon for user %s", certData.Username)
|
||||
} else {
|
||||
// External auth daemon mode - make HTTP request
|
||||
// Check if auth daemon key is configured
|
||||
if authDaemonKey == "" {
|
||||
logger.Error("Auth daemon key not configured, cannot communicate with daemon")
|
||||
// Send failure response back to cloud
|
||||
err := client.SendMessage("ws/round-trip/complete", map[string]interface{}{
|
||||
"messageId": certData.MessageId,
|
||||
"complete": true,
|
||||
"error": "auth daemon key not configured",
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send SSH cert failure response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare the request body for the auth daemon
|
||||
requestBody := map[string]interface{}{
|
||||
"caCert": certData.CACert,
|
||||
"niceId": certData.NiceID,
|
||||
"username": certData.Username,
|
||||
"metadata": map[string]interface{}{
|
||||
"sudo": certData.Metadata.Sudo,
|
||||
"homedir": certData.Metadata.Homedir,
|
||||
},
|
||||
}
|
||||
|
||||
requestJSON, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
logger.Error("Failed to marshal auth daemon request: %v", err)
|
||||
// Send failure response
|
||||
client.SendMessage("ws/round-trip/complete", map[string]interface{}{
|
||||
"messageId": certData.MessageId,
|
||||
"complete": true,
|
||||
"error": fmt.Sprintf("failed to marshal request: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create HTTPS client that skips certificate verification
|
||||
// (auth daemon uses self-signed cert)
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
// Make the request to the auth daemon
|
||||
url := fmt.Sprintf("https://%s:%d/connection", certData.AgentHost, certData.AgentPort)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestJSON))
|
||||
if err != nil {
|
||||
logger.Error("Failed to create auth daemon request: %v", err)
|
||||
client.SendMessage("ws/round-trip/complete", map[string]interface{}{
|
||||
"messageId": certData.MessageId,
|
||||
"complete": true,
|
||||
"error": fmt.Sprintf("failed to create request: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+authDaemonKey)
|
||||
|
||||
logger.Debug("Sending SSH cert to auth daemon at %s", url)
|
||||
|
||||
// Send the request
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
logger.Error("Failed to connect to auth daemon: %v", err)
|
||||
client.SendMessage("ws/round-trip/complete", map[string]interface{}{
|
||||
"messageId": certData.MessageId,
|
||||
"complete": true,
|
||||
"error": fmt.Sprintf("failed to connect to auth daemon: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check response status
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.Error("Auth daemon returned non-OK status: %d", resp.StatusCode)
|
||||
client.SendMessage("ws/round-trip/complete", map[string]interface{}{
|
||||
"messageId": certData.MessageId,
|
||||
"complete": true,
|
||||
"error": fmt.Sprintf("auth daemon returned status %d", resp.StatusCode),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Successfully registered SSH certificate with external auth daemon for user %s", certData.Username)
|
||||
}
|
||||
|
||||
// Send success response back to cloud
|
||||
err = client.SendMessage("ws/round-trip/complete", map[string]interface{}{
|
||||
"messageId": certData.MessageId,
|
||||
"complete": true,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send SSH cert success response: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
client.OnConnect(func() error {
|
||||
publicKey = privateKey.PublicKey()
|
||||
logger.Debug("Public key: %s", publicKey)
|
||||
@@ -1248,11 +1549,16 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
"noCloud": noCloud,
|
||||
}, 3*time.Second)
|
||||
logger.Debug("Requesting exit nodes from server")
|
||||
clientsOnConnect()
|
||||
|
||||
if client.GetServerVersion() != "" { // to prevent issues with running newt > 1.7 versions with older servers
|
||||
clientsOnConnect()
|
||||
} else {
|
||||
logger.Warn("CLIENTS WILL NOT WORK ON THIS VERSION OF NEWT WITH THIS VERSION OF PANGOLIN, PLEASE UPDATE THE SERVER TO 1.13 OR HIGHER OR DOWNGRADE NEWT")
|
||||
}
|
||||
}
|
||||
|
||||
// Send registration message to the server for backward compatibility
|
||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||
err := client.SendMessage(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"newtVersion": newtVersion,
|
||||
"backwardsCompatible": true,
|
||||
@@ -1302,10 +1608,8 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for interrupt signal
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
// Wait for context cancellation (from signal or service stop)
|
||||
<-ctx.Done()
|
||||
|
||||
// Close clients first (including WGTester)
|
||||
closeClients()
|
||||
@@ -1330,7 +1634,20 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
client.Close()
|
||||
}
|
||||
logger.Info("Exiting...")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// runNewtMainWithArgs is used by the Windows service to run newt with specific arguments
|
||||
// It sets os.Args and then calls runNewtMain
|
||||
func runNewtMainWithArgs(ctx context.Context, args []string) {
|
||||
// Set os.Args to include the program name plus the provided args
|
||||
// This allows flag parsing to work correctly
|
||||
os.Args = append([]string{os.Args[0]}, args...)
|
||||
|
||||
// Setup Windows logging if running as a service
|
||||
setupWindowsEventLog()
|
||||
|
||||
// Run the main newt logic
|
||||
runNewtMain(ctx)
|
||||
}
|
||||
|
||||
// validateTLSConfig validates the TLS configuration
|
||||
|
||||
701
netstack2/handlers.go
Normal file
701
netstack2/handlers.go
Normal file
@@ -0,0 +1,701 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package netstack2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultWndSize if set to zero, the default
|
||||
// receive window buffer size is used instead.
|
||||
defaultWndSize = 0
|
||||
|
||||
// maxConnAttempts specifies the maximum number
|
||||
// of in-flight tcp connection attempts.
|
||||
maxConnAttempts = 2 << 10
|
||||
|
||||
// tcpKeepaliveCount is the maximum number of
|
||||
// TCP keep-alive probes to send before giving up
|
||||
// and killing the connection if no response is
|
||||
// obtained from the other end.
|
||||
tcpKeepaliveCount = 9
|
||||
|
||||
// tcpKeepaliveIdle specifies the time a connection
|
||||
// must remain idle before the first TCP keepalive
|
||||
// packet is sent. Once this time is reached,
|
||||
// tcpKeepaliveInterval option is used instead.
|
||||
tcpKeepaliveIdle = 60 * time.Second
|
||||
|
||||
// tcpKeepaliveInterval specifies the interval
|
||||
// time between sending TCP keepalive packets.
|
||||
tcpKeepaliveInterval = 30 * time.Second
|
||||
|
||||
// tcpConnectTimeout is the default timeout for TCP handshakes.
|
||||
tcpConnectTimeout = 5 * time.Second
|
||||
|
||||
// tcpWaitTimeout implements a TCP half-close timeout.
|
||||
tcpWaitTimeout = 60 * time.Second
|
||||
|
||||
// udpSessionTimeout is the default timeout for UDP sessions.
|
||||
udpSessionTimeout = 60 * time.Second
|
||||
|
||||
// Buffer size for copying data
|
||||
bufferSize = 32 * 1024
|
||||
|
||||
// icmpTimeout is the default timeout for ICMP ping requests.
|
||||
icmpTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// TCPHandler handles TCP connections from netstack
|
||||
type TCPHandler struct {
|
||||
stack *stack.Stack
|
||||
proxyHandler *ProxyHandler
|
||||
}
|
||||
|
||||
// UDPHandler handles UDP connections from netstack
|
||||
type UDPHandler struct {
|
||||
stack *stack.Stack
|
||||
proxyHandler *ProxyHandler
|
||||
}
|
||||
|
||||
// ICMPHandler handles ICMP packets from netstack
|
||||
type ICMPHandler struct {
|
||||
stack *stack.Stack
|
||||
proxyHandler *ProxyHandler
|
||||
}
|
||||
|
||||
// NewTCPHandler creates a new TCP handler
|
||||
func NewTCPHandler(s *stack.Stack, ph *ProxyHandler) *TCPHandler {
|
||||
return &TCPHandler{stack: s, proxyHandler: ph}
|
||||
}
|
||||
|
||||
// NewUDPHandler creates a new UDP handler
|
||||
func NewUDPHandler(s *stack.Stack, ph *ProxyHandler) *UDPHandler {
|
||||
return &UDPHandler{stack: s, proxyHandler: ph}
|
||||
}
|
||||
|
||||
// NewICMPHandler creates a new ICMP handler
|
||||
func NewICMPHandler(s *stack.Stack, ph *ProxyHandler) *ICMPHandler {
|
||||
return &ICMPHandler{stack: s, proxyHandler: ph}
|
||||
}
|
||||
|
||||
// InstallTCPHandler installs the TCP forwarder on the stack
|
||||
func (h *TCPHandler) InstallTCPHandler() error {
|
||||
tcpForwarder := tcp.NewForwarder(h.stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
|
||||
var (
|
||||
wq waiter.Queue
|
||||
ep tcpip.Endpoint
|
||||
err tcpip.Error
|
||||
id = r.ID()
|
||||
)
|
||||
|
||||
// Perform a TCP three-way handshake
|
||||
ep, err = r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
// RST: prevent potential half-open TCP connection leak
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
defer r.Complete(false)
|
||||
|
||||
// Set socket options
|
||||
setTCPSocketOptions(h.stack, ep)
|
||||
|
||||
// Create TCP connection from netstack endpoint
|
||||
netstackConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
// Handle the connection in a goroutine
|
||||
go h.handleTCPConn(netstackConn, id)
|
||||
})
|
||||
|
||||
h.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleTCPConn handles a TCP connection by proxying it to the actual target
|
||||
func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.TransportEndpointID) {
|
||||
defer netstackConn.Close()
|
||||
|
||||
// Extract source and target address from the connection ID
|
||||
srcIP := id.RemoteAddress.String()
|
||||
srcPort := id.RemotePort
|
||||
dstIP := id.LocalAddress.String()
|
||||
dstPort := id.LocalPort
|
||||
|
||||
logger.Info("TCP Forwarder: Handling connection %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
// Check if there's a destination rewrite for this connection (e.g., localhost targets)
|
||||
actualDstIP := dstIP
|
||||
if h.proxyHandler != nil {
|
||||
if rewrittenAddr, ok := h.proxyHandler.LookupDestinationRewrite(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber)); ok {
|
||||
actualDstIP = rewrittenAddr.String()
|
||||
logger.Info("TCP Forwarder: Using rewritten destination %s (original: %s)", actualDstIP, dstIP)
|
||||
}
|
||||
}
|
||||
|
||||
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
|
||||
|
||||
// Create context with timeout for connection establishment
|
||||
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Dial the actual target using standard net package
|
||||
var d net.Dialer
|
||||
targetConn, err := d.DialContext(ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err)
|
||||
// Connection failed, netstack will handle RST
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr)
|
||||
|
||||
// Bidirectional copy between netstack and target
|
||||
pipeTCP(netstackConn, targetConn)
|
||||
}
|
||||
|
||||
// pipeTCP copies data bidirectionally between two connections
|
||||
func pipeTCP(origin, remote net.Conn) {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
go unidirectionalStreamTCP(remote, origin, "origin->remote", &wg)
|
||||
go unidirectionalStreamTCP(origin, remote, "remote->origin", &wg)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// unidirectionalStreamTCP copies data in one direction
|
||||
func unidirectionalStreamTCP(dst, src net.Conn, dir string, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
_, _ = io.CopyBuffer(dst, src, buf)
|
||||
|
||||
// Do the upload/download side TCP half-close
|
||||
if cr, ok := src.(interface{ CloseRead() error }); ok {
|
||||
cr.CloseRead()
|
||||
}
|
||||
if cw, ok := dst.(interface{ CloseWrite() error }); ok {
|
||||
cw.CloseWrite()
|
||||
}
|
||||
|
||||
// Set TCP half-close timeout
|
||||
dst.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
|
||||
}
|
||||
|
||||
// setTCPSocketOptions sets TCP socket options for better performance
|
||||
func setTCPSocketOptions(s *stack.Stack, ep tcpip.Endpoint) {
|
||||
// TCP keepalive options
|
||||
ep.SocketOptions().SetKeepAlive(true)
|
||||
|
||||
idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle)
|
||||
ep.SetSockOpt(&idle)
|
||||
|
||||
interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval)
|
||||
ep.SetSockOpt(&interval)
|
||||
|
||||
ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount)
|
||||
|
||||
// TCP send/recv buffer size
|
||||
var ss tcpip.TCPSendBufferSizeRangeOption
|
||||
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &ss); err == nil {
|
||||
ep.SocketOptions().SetSendBufferSize(int64(ss.Default), false)
|
||||
}
|
||||
|
||||
var rs tcpip.TCPReceiveBufferSizeRangeOption
|
||||
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &rs); err == nil {
|
||||
ep.SocketOptions().SetReceiveBufferSize(int64(rs.Default), false)
|
||||
}
|
||||
}
|
||||
|
||||
// InstallUDPHandler installs the UDP forwarder on the stack
|
||||
func (h *UDPHandler) InstallUDPHandler() error {
|
||||
udpForwarder := udp.NewForwarder(h.stack, func(r *udp.ForwarderRequest) {
|
||||
var (
|
||||
wq waiter.Queue
|
||||
id = r.ID()
|
||||
)
|
||||
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Create UDP connection from netstack endpoint
|
||||
netstackConn := gonet.NewUDPConn(&wq, ep)
|
||||
|
||||
// Handle the connection in a goroutine
|
||||
go h.handleUDPConn(netstackConn, id)
|
||||
})
|
||||
|
||||
h.stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleUDPConn handles a UDP connection by proxying it to the actual target
|
||||
func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.TransportEndpointID) {
|
||||
defer netstackConn.Close()
|
||||
|
||||
// Extract source and target address from the connection ID
|
||||
srcIP := id.RemoteAddress.String()
|
||||
srcPort := id.RemotePort
|
||||
dstIP := id.LocalAddress.String()
|
||||
dstPort := id.LocalPort
|
||||
|
||||
logger.Info("UDP Forwarder: Handling connection %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
// Check if there's a destination rewrite for this connection (e.g., localhost targets)
|
||||
actualDstIP := dstIP
|
||||
if h.proxyHandler != nil {
|
||||
if rewrittenAddr, ok := h.proxyHandler.LookupDestinationRewrite(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber)); ok {
|
||||
actualDstIP = rewrittenAddr.String()
|
||||
logger.Info("UDP Forwarder: Using rewritten destination %s (original: %s)", actualDstIP, dstIP)
|
||||
}
|
||||
}
|
||||
|
||||
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
|
||||
|
||||
// Resolve target address
|
||||
remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
||||
if err != nil {
|
||||
logger.Info("UDP Forwarder: Failed to resolve %s: %v", targetAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve client address (for sending responses back)
|
||||
clientAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", srcIP, srcPort))
|
||||
if err != nil {
|
||||
logger.Info("UDP Forwarder: Failed to resolve client %s:%d: %v", srcIP, srcPort, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create unconnected UDP socket (so we can use WriteTo)
|
||||
targetConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
logger.Info("UDP Forwarder: Failed to create UDP socket: %v", err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
logger.Info("UDP Forwarder: Successfully created UDP socket for %s, starting bidirectional copy", targetAddr)
|
||||
|
||||
// Bidirectional copy between netstack and target
|
||||
pipeUDP(netstackConn, targetConn, remoteUDPAddr, clientAddr, udpSessionTimeout)
|
||||
}
|
||||
|
||||
// pipeUDP copies UDP packets bidirectionally
|
||||
func pipeUDP(origin, remote net.PacketConn, serverAddr, clientAddr net.Addr, timeout time.Duration) {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
// Read from origin (netstack), write to remote (target server)
|
||||
go unidirectionalPacketStream(remote, origin, serverAddr, "origin->remote", &wg, timeout)
|
||||
// Read from remote (target server), write to origin (netstack) with client address
|
||||
go unidirectionalPacketStream(origin, remote, clientAddr, "remote->origin", &wg, timeout)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// unidirectionalPacketStream copies packets in one direction
|
||||
func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup, timeout time.Duration) {
|
||||
defer wg.Done()
|
||||
|
||||
logger.Info("UDP %s: Starting packet stream (to=%v)", dir, to)
|
||||
err := copyPacketData(dst, src, to, timeout)
|
||||
if err != nil {
|
||||
logger.Info("UDP %s: Stream ended with error: %v", dir, err)
|
||||
} else {
|
||||
logger.Info("UDP %s: Stream ended (timeout)", dir)
|
||||
}
|
||||
}
|
||||
|
||||
// copyPacketData copies UDP packet data with timeout
|
||||
func copyPacketData(dst, src net.PacketConn, to net.Addr, timeout time.Duration) error {
|
||||
buf := make([]byte, 65535) // Max UDP packet size
|
||||
|
||||
for {
|
||||
src.SetReadDeadline(time.Now().Add(timeout))
|
||||
n, srcAddr, err := src.ReadFrom(buf)
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
return nil // ignore I/O timeout
|
||||
} else if err == io.EOF {
|
||||
return nil // ignore EOF
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("UDP copyPacketData: Read %d bytes from %v", n, srcAddr)
|
||||
|
||||
// Determine write destination
|
||||
writeAddr := to
|
||||
if writeAddr == nil {
|
||||
// If no destination specified, use the source address from the packet
|
||||
writeAddr = srcAddr
|
||||
}
|
||||
|
||||
written, err := dst.WriteTo(buf[:n], writeAddr)
|
||||
if err != nil {
|
||||
logger.Info("UDP copyPacketData: Write error to %v: %v", writeAddr, err)
|
||||
return err
|
||||
}
|
||||
logger.Info("UDP copyPacketData: Wrote %d bytes to %v", written, writeAddr)
|
||||
|
||||
dst.SetReadDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
}
|
||||
|
||||
// InstallICMPHandler installs the ICMP handler on the stack
|
||||
func (h *ICMPHandler) InstallICMPHandler() error {
|
||||
h.stack.SetTransportProtocolHandler(header.ICMPv4ProtocolNumber, h.handleICMPPacket)
|
||||
logger.Debug("ICMP Handler: Installed ICMP protocol handler")
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleICMPPacket handles incoming ICMP packets
|
||||
func (h *ICMPHandler) handleICMPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
logger.Debug("ICMP Handler: Received ICMP packet from %s to %s", id.RemoteAddress, id.LocalAddress)
|
||||
|
||||
// Get the ICMP header from the packet
|
||||
icmpData := pkt.TransportHeader().Slice()
|
||||
if len(icmpData) < header.ICMPv4MinimumSize {
|
||||
logger.Debug("ICMP Handler: Packet too small for ICMP header: %d bytes", len(icmpData))
|
||||
return false
|
||||
}
|
||||
|
||||
icmpHdr := header.ICMPv4(icmpData)
|
||||
icmpType := icmpHdr.Type()
|
||||
icmpCode := icmpHdr.Code()
|
||||
|
||||
logger.Debug("ICMP Handler: Type=%d, Code=%d, Ident=%d, Seq=%d",
|
||||
icmpType, icmpCode, icmpHdr.Ident(), icmpHdr.Sequence())
|
||||
|
||||
// Only handle Echo Request (ping)
|
||||
if icmpType != header.ICMPv4Echo {
|
||||
logger.Debug("ICMP Handler: Ignoring non-echo ICMP type: %d", icmpType)
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract source and destination addresses
|
||||
srcIP := id.RemoteAddress.String()
|
||||
dstIP := id.LocalAddress.String()
|
||||
|
||||
logger.Info("ICMP Handler: Echo Request from %s to %s (ident=%d, seq=%d)",
|
||||
srcIP, dstIP, icmpHdr.Ident(), icmpHdr.Sequence())
|
||||
|
||||
// Convert to netip.Addr for subnet matching
|
||||
srcAddr, err := netip.ParseAddr(srcIP)
|
||||
if err != nil {
|
||||
logger.Debug("ICMP Handler: Failed to parse source IP %s: %v", srcIP, err)
|
||||
return false
|
||||
}
|
||||
dstAddr, err := netip.ParseAddr(dstIP)
|
||||
if err != nil {
|
||||
logger.Debug("ICMP Handler: Failed to parse dest IP %s: %v", dstIP, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Check subnet rules (use port 0 for ICMP since it doesn't have ports)
|
||||
if h.proxyHandler == nil {
|
||||
logger.Debug("ICMP Handler: No proxy handler configured")
|
||||
return false
|
||||
}
|
||||
|
||||
matchedRule := h.proxyHandler.subnetLookup.Match(srcAddr, dstAddr, 0, header.ICMPv4ProtocolNumber)
|
||||
if matchedRule == nil {
|
||||
logger.Debug("ICMP Handler: No matching subnet rule for %s -> %s", srcIP, dstIP)
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Info("ICMP Handler: Matched subnet rule for %s -> %s", srcIP, dstIP)
|
||||
|
||||
// Determine actual destination (with possible rewrite)
|
||||
actualDstIP := dstIP
|
||||
if matchedRule.RewriteTo != "" {
|
||||
resolvedAddr, err := h.proxyHandler.resolveRewriteAddress(matchedRule.RewriteTo)
|
||||
if err != nil {
|
||||
logger.Info("ICMP Handler: Failed to resolve rewrite address %s: %v", matchedRule.RewriteTo, err)
|
||||
} else {
|
||||
actualDstIP = resolvedAddr.String()
|
||||
logger.Info("ICMP Handler: Using rewritten destination %s (original: %s)", actualDstIP, dstIP)
|
||||
}
|
||||
}
|
||||
|
||||
// Get the full ICMP payload (including the data after the header)
|
||||
icmpPayload := pkt.Data().AsRange().ToSlice()
|
||||
|
||||
// Handle the ping in a goroutine to avoid blocking
|
||||
go h.proxyPing(srcIP, dstIP, actualDstIP, icmpHdr.Ident(), icmpHdr.Sequence(), icmpPayload)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// proxyPing sends a ping to the actual destination and injects the reply back
|
||||
func (h *ICMPHandler) proxyPing(srcIP, originalDstIP, actualDstIP string, ident, seq uint16, payload []byte) {
|
||||
logger.Debug("ICMP Handler: Proxying ping from %s to %s (actual: %s), ident=%d, seq=%d",
|
||||
srcIP, originalDstIP, actualDstIP, ident, seq)
|
||||
|
||||
// Try three methods in order: ip4:icmp -> udp4 -> ping command
|
||||
// Track which method succeeded so we can handle identifier matching correctly
|
||||
method, success := h.tryICMPMethods(actualDstIP, ident, seq, payload)
|
||||
|
||||
if !success {
|
||||
logger.Info("ICMP Handler: All ping methods failed for %s", actualDstIP)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("ICMP Handler: Ping successful to %s using %s, injecting reply (ident=%d, seq=%d)",
|
||||
actualDstIP, method, ident, seq)
|
||||
|
||||
// Build the reply packet to inject back into the netstack
|
||||
// The reply should appear to come from the original destination (before rewrite)
|
||||
h.injectICMPReply(srcIP, originalDstIP, ident, seq, payload)
|
||||
}
|
||||
|
||||
// tryICMPMethods tries all available ICMP methods in order
|
||||
func (h *ICMPHandler) tryICMPMethods(actualDstIP string, ident, seq uint16, payload []byte) (string, bool) {
|
||||
if h.tryRawICMP(actualDstIP, ident, seq, payload, false) {
|
||||
return "raw ICMP", true
|
||||
}
|
||||
if h.tryUnprivilegedICMP(actualDstIP, ident, seq, payload) {
|
||||
return "unprivileged ICMP", true
|
||||
}
|
||||
if h.tryPingCommand(actualDstIP, ident, seq, payload) {
|
||||
return "ping command", true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// tryRawICMP attempts to ping using raw ICMP sockets (requires CAP_NET_RAW or root)
|
||||
func (h *ICMPHandler) tryRawICMP(actualDstIP string, ident, seq uint16, payload []byte, ignoreIdent bool) bool {
|
||||
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
logger.Debug("ICMP Handler: Raw ICMP socket not available: %v", err)
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
logger.Debug("ICMP Handler: Using raw ICMP socket")
|
||||
return h.sendAndReceiveICMP(conn, actualDstIP, ident, seq, payload, false, ignoreIdent)
|
||||
}
|
||||
|
||||
// tryUnprivilegedICMP attempts to ping using unprivileged ICMP (requires ping_group_range configured)
|
||||
func (h *ICMPHandler) tryUnprivilegedICMP(actualDstIP string, ident, seq uint16, payload []byte) bool {
|
||||
conn, err := icmp.ListenPacket("udp4", "0.0.0.0")
|
||||
if err != nil {
|
||||
logger.Debug("ICMP Handler: Unprivileged ICMP socket not available: %v", err)
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
logger.Debug("ICMP Handler: Using unprivileged ICMP socket")
|
||||
// Unprivileged ICMP doesn't let us control the identifier, so we ignore it in matching
|
||||
return h.sendAndReceiveICMP(conn, actualDstIP, ident, seq, payload, true, true)
|
||||
}
|
||||
|
||||
// sendAndReceiveICMP sends an ICMP echo request and waits for the reply
|
||||
func (h *ICMPHandler) sendAndReceiveICMP(conn *icmp.PacketConn, actualDstIP string, ident, seq uint16, payload []byte, isUnprivileged bool, ignoreIdent bool) bool {
|
||||
// Build the ICMP echo request message
|
||||
echoMsg := &icmp.Message{
|
||||
Type: ipv4.ICMPTypeEcho,
|
||||
Code: 0,
|
||||
Body: &icmp.Echo{
|
||||
ID: int(ident),
|
||||
Seq: int(seq),
|
||||
Data: payload,
|
||||
},
|
||||
}
|
||||
|
||||
msgBytes, err := echoMsg.Marshal(nil)
|
||||
if err != nil {
|
||||
logger.Debug("ICMP Handler: Failed to marshal ICMP message: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Resolve destination address based on socket type
|
||||
var writeErr error
|
||||
if isUnprivileged {
|
||||
// For unprivileged ICMP, use UDP-style addressing
|
||||
udpAddr := &net.UDPAddr{IP: net.ParseIP(actualDstIP)}
|
||||
logger.Debug("ICMP Handler: Sending ping to %s (unprivileged)", udpAddr.String())
|
||||
conn.SetDeadline(time.Now().Add(icmpTimeout))
|
||||
_, writeErr = conn.WriteTo(msgBytes, udpAddr)
|
||||
} else {
|
||||
// For raw ICMP, use IP addressing
|
||||
dst, err := net.ResolveIPAddr("ip4", actualDstIP)
|
||||
if err != nil {
|
||||
logger.Debug("ICMP Handler: Failed to resolve destination %s: %v", actualDstIP, err)
|
||||
return false
|
||||
}
|
||||
logger.Debug("ICMP Handler: Sending ping to %s (raw)", dst.String())
|
||||
conn.SetDeadline(time.Now().Add(icmpTimeout))
|
||||
_, writeErr = conn.WriteTo(msgBytes, dst)
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
logger.Debug("ICMP Handler: Failed to send ping to %s: %v", actualDstIP, writeErr)
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Debug("ICMP Handler: Ping sent to %s, waiting for reply (ident=%d, seq=%d)", actualDstIP, ident, seq)
|
||||
|
||||
// Wait for reply - loop to filter out non-matching packets
|
||||
replyBuf := make([]byte, 1500)
|
||||
|
||||
for {
|
||||
n, peer, err := conn.ReadFrom(replyBuf)
|
||||
if err != nil {
|
||||
logger.Debug("ICMP Handler: Failed to receive ping reply from %s: %v", actualDstIP, err)
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Debug("ICMP Handler: Received %d bytes from %s", n, peer.String())
|
||||
|
||||
// Parse the reply
|
||||
replyMsg, err := icmp.ParseMessage(1, replyBuf[:n])
|
||||
if err != nil {
|
||||
logger.Debug("ICMP Handler: Failed to parse ICMP message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's an echo reply (type 0), not an echo request (type 8)
|
||||
if replyMsg.Type != ipv4.ICMPTypeEchoReply {
|
||||
logger.Debug("ICMP Handler: Received non-echo-reply type: %v, continuing to wait", replyMsg.Type)
|
||||
continue
|
||||
}
|
||||
|
||||
reply, ok := replyMsg.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
logger.Debug("ICMP Handler: Invalid echo reply body type, continuing to wait")
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify the sequence matches what we sent
|
||||
// For unprivileged ICMP, the kernel controls the identifier, so we only check sequence
|
||||
if reply.Seq != int(seq) {
|
||||
logger.Debug("ICMP Handler: Reply seq mismatch: got seq=%d, want seq=%d", reply.Seq, seq)
|
||||
continue
|
||||
}
|
||||
|
||||
if !ignoreIdent && reply.ID != int(ident) {
|
||||
logger.Debug("ICMP Handler: Reply ident mismatch: got ident=%d, want ident=%d", reply.ID, ident)
|
||||
continue
|
||||
}
|
||||
|
||||
// Found matching reply
|
||||
logger.Debug("ICMP Handler: Received valid echo reply")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// tryPingCommand attempts to ping using the system ping command (always works, but less control)
|
||||
func (h *ICMPHandler) tryPingCommand(actualDstIP string, ident, seq uint16, payload []byte) bool {
|
||||
logger.Debug("ICMP Handler: Attempting to use system ping command")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), icmpTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Send one ping with timeout
|
||||
// -c 1: count = 1 packet
|
||||
// -W 5: timeout = 5 seconds
|
||||
// -q: quiet output (just summary)
|
||||
cmd := exec.CommandContext(ctx, "ping", "-c", "1", "-W", "5", "-q", actualDstIP)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
logger.Debug("ICMP Handler: System ping command failed: %v, output: %s", err, string(output))
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Debug("ICMP Handler: System ping command succeeded")
|
||||
return true
|
||||
}
|
||||
|
||||
// injectICMPReply creates an ICMP echo reply packet and queues it to be sent back through the tunnel
|
||||
func (h *ICMPHandler) injectICMPReply(dstIP, srcIP string, ident, seq uint16, payload []byte) {
|
||||
logger.Debug("ICMP Handler: Creating reply from %s to %s (ident=%d, seq=%d)",
|
||||
srcIP, dstIP, ident, seq)
|
||||
|
||||
// Parse addresses
|
||||
srcAddr, err := netip.ParseAddr(srcIP)
|
||||
if err != nil {
|
||||
logger.Info("ICMP Handler: Failed to parse source IP for reply: %v", err)
|
||||
return
|
||||
}
|
||||
dstAddr, err := netip.ParseAddr(dstIP)
|
||||
if err != nil {
|
||||
logger.Info("ICMP Handler: Failed to parse dest IP for reply: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate total packet size
|
||||
ipHeaderLen := header.IPv4MinimumSize
|
||||
icmpHeaderLen := header.ICMPv4MinimumSize
|
||||
totalLen := ipHeaderLen + icmpHeaderLen + len(payload)
|
||||
|
||||
// Create the packet buffer
|
||||
pkt := make([]byte, totalLen)
|
||||
|
||||
// Build IPv4 header
|
||||
ipHdr := header.IPv4(pkt[:ipHeaderLen])
|
||||
ipHdr.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(totalLen),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: tcpip.AddrFrom4(srcAddr.As4()),
|
||||
DstAddr: tcpip.AddrFrom4(dstAddr.As4()),
|
||||
})
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
|
||||
// Build ICMP header
|
||||
icmpHdr := header.ICMPv4(pkt[ipHeaderLen : ipHeaderLen+icmpHeaderLen])
|
||||
icmpHdr.SetType(header.ICMPv4EchoReply)
|
||||
icmpHdr.SetCode(0)
|
||||
icmpHdr.SetIdent(ident)
|
||||
icmpHdr.SetSequence(seq)
|
||||
|
||||
// Copy payload
|
||||
copy(pkt[ipHeaderLen+icmpHeaderLen:], payload)
|
||||
|
||||
// Calculate ICMP checksum (covers ICMP header + payload)
|
||||
icmpHdr.SetChecksum(0)
|
||||
icmpData := pkt[ipHeaderLen:]
|
||||
icmpHdr.SetChecksum(^checksum.Checksum(icmpData, 0))
|
||||
|
||||
logger.Debug("ICMP Handler: Built reply packet, total length=%d", totalLen)
|
||||
|
||||
// Queue the packet to be sent back through the tunnel
|
||||
if h.proxyHandler != nil {
|
||||
if h.proxyHandler.QueueICMPReply(pkt) {
|
||||
logger.Info("ICMP Handler: Queued echo reply packet for transmission")
|
||||
} else {
|
||||
logger.Info("ICMP Handler: Failed to queue echo reply packet")
|
||||
}
|
||||
} else {
|
||||
logger.Info("ICMP Handler: Cannot queue reply - proxy handler not available")
|
||||
}
|
||||
}
|
||||
797
netstack2/proxy.go
Normal file
797
netstack2/proxy.go
Normal file
@@ -0,0 +1,797 @@
|
||||
package netstack2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering
|
||||
// Protocol can be "tcp", "udp", or "" (empty string means both protocols)
|
||||
type PortRange struct {
|
||||
Min uint16
|
||||
Max uint16
|
||||
Protocol string // "tcp", "udp", or "" for both
|
||||
}
|
||||
|
||||
// SubnetRule represents a subnet with optional port restrictions and source address
|
||||
// When RewriteTo is set, DNAT (Destination Network Address Translation) is performed:
|
||||
// - Incoming packets: destination IP is rewritten to the resolved RewriteTo address
|
||||
// - Outgoing packets: source IP is rewritten back to the original destination
|
||||
//
|
||||
// RewriteTo can be either:
|
||||
// - An IP address with CIDR notation (e.g., "192.168.1.1/32")
|
||||
// - A domain name (e.g., "example.com") which will be resolved at request time
|
||||
//
|
||||
// This allows transparent proxying where traffic appears to come from the rewritten address
|
||||
type SubnetRule struct {
|
||||
SourcePrefix netip.Prefix // Source IP prefix (who is sending)
|
||||
DestPrefix netip.Prefix // Destination IP prefix (where it's going)
|
||||
DisableIcmp bool // If true, ICMP traffic is blocked for this subnet
|
||||
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
|
||||
PortRanges []PortRange // empty slice means all ports allowed
|
||||
}
|
||||
|
||||
// ruleKey is used as a map key for fast O(1) lookups
|
||||
type ruleKey struct {
|
||||
sourcePrefix string
|
||||
destPrefix string
|
||||
}
|
||||
|
||||
// SubnetLookup provides fast IP subnet and port matching with O(1) lookup performance
|
||||
type SubnetLookup struct {
|
||||
mu sync.RWMutex
|
||||
rules map[ruleKey]*SubnetRule // Map for O(1) lookups by prefix combination
|
||||
}
|
||||
|
||||
// NewSubnetLookup creates a new subnet lookup table
|
||||
func NewSubnetLookup() *SubnetLookup {
|
||||
return &SubnetLookup{
|
||||
rules: make(map[ruleKey]*SubnetRule),
|
||||
}
|
||||
}
|
||||
|
||||
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
|
||||
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
|
||||
key := ruleKey{
|
||||
sourcePrefix: sourcePrefix.String(),
|
||||
destPrefix: destPrefix.String(),
|
||||
}
|
||||
|
||||
sl.rules[key] = &SubnetRule{
|
||||
SourcePrefix: sourcePrefix,
|
||||
DestPrefix: destPrefix,
|
||||
DisableIcmp: disableIcmp,
|
||||
RewriteTo: rewriteTo,
|
||||
PortRanges: portRanges,
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveSubnet removes a subnet rule from the lookup table
|
||||
func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
|
||||
key := ruleKey{
|
||||
sourcePrefix: sourcePrefix.String(),
|
||||
destPrefix: destPrefix.String(),
|
||||
}
|
||||
|
||||
delete(sl.rules, key)
|
||||
}
|
||||
|
||||
// Match checks if a source IP, destination IP, port, and protocol match any subnet rule
|
||||
// Returns the matched rule if ALL of these conditions are met:
|
||||
// - The source IP is in the rule's source prefix
|
||||
// - The destination IP is in the rule's destination prefix
|
||||
// - The port is in an allowed range (or no port restrictions exist)
|
||||
// - The protocol matches (or the port range allows both protocols)
|
||||
//
|
||||
// proto should be header.TCPProtocolNumber or header.UDPProtocolNumber
|
||||
// Returns nil if no rule matches
|
||||
func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16, proto tcpip.TransportProtocolNumber) *SubnetRule {
|
||||
sl.mu.RLock()
|
||||
defer sl.mu.RUnlock()
|
||||
|
||||
// Iterate through all rules to find matching source and destination prefixes
|
||||
// This is O(n) but necessary since we need to check prefix containment, not exact match
|
||||
for _, rule := range sl.rules {
|
||||
// Check if source and destination IPs match their respective prefixes
|
||||
if !rule.SourcePrefix.Contains(srcIP) {
|
||||
continue
|
||||
}
|
||||
if !rule.DestPrefix.Contains(dstIP) {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.DisableIcmp && (proto == header.ICMPv4ProtocolNumber || proto == header.ICMPv6ProtocolNumber) {
|
||||
// ICMP is disabled for this subnet
|
||||
return nil
|
||||
}
|
||||
|
||||
// Both IPs match - now check port restrictions
|
||||
// If no port ranges specified, all ports are allowed
|
||||
if len(rule.PortRanges) == 0 {
|
||||
return rule
|
||||
}
|
||||
|
||||
// Check if port and protocol are in any of the allowed ranges
|
||||
for _, pr := range rule.PortRanges {
|
||||
if port >= pr.Min && port <= pr.Max {
|
||||
// Check protocol compatibility
|
||||
if pr.Protocol == "" {
|
||||
// Empty protocol means allow both TCP and UDP
|
||||
return rule
|
||||
}
|
||||
// Check if the packet protocol matches the port range protocol
|
||||
if (pr.Protocol == "tcp" && proto == header.TCPProtocolNumber) ||
|
||||
(pr.Protocol == "udp" && proto == header.UDPProtocolNumber) {
|
||||
return rule
|
||||
}
|
||||
// Port matches but protocol doesn't - continue checking other ranges
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// connKey uniquely identifies a connection for NAT tracking
|
||||
type connKey struct {
|
||||
srcIP string
|
||||
srcPort uint16
|
||||
dstIP string
|
||||
dstPort uint16
|
||||
proto uint8
|
||||
}
|
||||
|
||||
// destKey identifies a destination for handler lookups (without source port since it may change)
|
||||
type destKey struct {
|
||||
srcIP string
|
||||
dstIP string
|
||||
dstPort uint16
|
||||
proto uint8
|
||||
}
|
||||
|
||||
// natState tracks NAT translation state for reverse translation
|
||||
type natState struct {
|
||||
originalDst netip.Addr // Original destination before DNAT
|
||||
rewrittenTo netip.Addr // The address we rewrote to
|
||||
}
|
||||
|
||||
// ProxyHandler handles packet injection and extraction for promiscuous mode
|
||||
type ProxyHandler struct {
|
||||
proxyStack *stack.Stack
|
||||
proxyEp *channel.Endpoint
|
||||
proxyNotifyHandle *channel.NotificationHandle
|
||||
tcpHandler *TCPHandler
|
||||
udpHandler *UDPHandler
|
||||
icmpHandler *ICMPHandler
|
||||
subnetLookup *SubnetLookup
|
||||
natTable map[connKey]*natState
|
||||
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
|
||||
natMu sync.RWMutex
|
||||
enabled bool
|
||||
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
|
||||
notifiable channel.Notification // Notification handler for triggering reads
|
||||
}
|
||||
|
||||
// ProxyHandlerOptions configures the proxy handler
|
||||
type ProxyHandlerOptions struct {
|
||||
EnableTCP bool
|
||||
EnableUDP bool
|
||||
EnableICMP bool
|
||||
MTU int
|
||||
}
|
||||
|
||||
// NewProxyHandler creates a new proxy handler for promiscuous mode
|
||||
func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
||||
if !options.EnableTCP && !options.EnableUDP && !options.EnableICMP {
|
||||
return nil, nil // No proxy needed
|
||||
}
|
||||
|
||||
handler := &ProxyHandler{
|
||||
enabled: true,
|
||||
subnetLookup: NewSubnetLookup(),
|
||||
natTable: make(map[connKey]*natState),
|
||||
destRewriteTable: make(map[destKey]netip.Addr),
|
||||
icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets
|
||||
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
||||
proxyStack: stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
ipv4.NewProtocol,
|
||||
ipv6.NewProtocol,
|
||||
},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
icmp.NewProtocol4,
|
||||
icmp.NewProtocol6,
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
// Initialize TCP handler if enabled
|
||||
if options.EnableTCP {
|
||||
handler.tcpHandler = NewTCPHandler(handler.proxyStack, handler)
|
||||
if err := handler.tcpHandler.InstallTCPHandler(); err != nil {
|
||||
return nil, fmt.Errorf("failed to install TCP handler: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize UDP handler if enabled
|
||||
if options.EnableUDP {
|
||||
handler.udpHandler = NewUDPHandler(handler.proxyStack, handler)
|
||||
if err := handler.udpHandler.InstallUDPHandler(); err != nil {
|
||||
return nil, fmt.Errorf("failed to install UDP handler: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize ICMP handler if enabled
|
||||
if options.EnableICMP {
|
||||
handler.icmpHandler = NewICMPHandler(handler.proxyStack, handler)
|
||||
if err := handler.icmpHandler.InstallICMPHandler(); err != nil {
|
||||
return nil, fmt.Errorf("failed to install ICMP handler: %v", err)
|
||||
}
|
||||
logger.Debug("ProxyHandler: ICMP handler enabled")
|
||||
}
|
||||
|
||||
// // Example 1: Add a rule with no port restrictions (all ports allowed)
|
||||
// // This accepts all traffic FROM 10.0.0.0/24 TO 10.20.20.0/24
|
||||
// sourceSubnet := netip.MustParsePrefix("10.0.0.0/24")
|
||||
// destSubnet := netip.MustParsePrefix("10.20.20.0/24")
|
||||
// handler.AddSubnetRule(sourceSubnet, destSubnet, nil)
|
||||
|
||||
// // Example 2: Add a rule with specific port ranges
|
||||
// // This accepts traffic FROM 10.0.0.5/32 TO 10.20.21.21/32 only on ports 80, 443, and 8000-9000
|
||||
// sourceIP := netip.MustParsePrefix("10.0.0.5/32")
|
||||
// destIP := netip.MustParsePrefix("10.20.21.21/32")
|
||||
// handler.AddSubnetRule(sourceIP, destIP, []PortRange{
|
||||
// {Min: 80, Max: 80},
|
||||
// {Min: 443, Max: 443},
|
||||
// {Min: 8000, Max: 9000},
|
||||
// })
|
||||
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
// AddSubnetRule adds a subnet with optional port restrictions to the proxy handler
|
||||
// sourcePrefix: The IP prefix of the peer sending the data
|
||||
// destPrefix: The IP prefix of the destination
|
||||
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
|
||||
if p == nil || !p.enabled {
|
||||
return
|
||||
}
|
||||
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
|
||||
}
|
||||
|
||||
// RemoveSubnetRule removes a subnet from the proxy handler
|
||||
func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
|
||||
if p == nil || !p.enabled {
|
||||
return
|
||||
}
|
||||
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
|
||||
}
|
||||
|
||||
// LookupDestinationRewrite looks up the rewritten destination for a connection
|
||||
// This is used by TCP/UDP handlers to find the actual target address
|
||||
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
|
||||
if p == nil || !p.enabled {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
key := destKey{
|
||||
srcIP: srcIP,
|
||||
dstIP: dstIP,
|
||||
dstPort: dstPort,
|
||||
proto: proto,
|
||||
}
|
||||
|
||||
p.natMu.RLock()
|
||||
defer p.natMu.RUnlock()
|
||||
|
||||
addr, ok := p.destRewriteTable[key]
|
||||
return addr, ok
|
||||
}
|
||||
|
||||
// resolveRewriteAddress resolves a rewrite address which can be either:
|
||||
// - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly
|
||||
// - A plain IP address (e.g., "192.168.1.1") - returns the IP directly
|
||||
// - A domain name (e.g., "example.com") - performs DNS lookup
|
||||
func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) {
|
||||
logger.Debug("Resolving rewrite address: %s", rewriteTo)
|
||||
|
||||
// First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32")
|
||||
if prefix, err := netip.ParsePrefix(rewriteTo); err == nil {
|
||||
return prefix.Addr(), nil
|
||||
}
|
||||
|
||||
// Try to parse as a plain IP address (e.g., "192.168.1.1")
|
||||
if addr, err := netip.ParseAddr(rewriteTo); err == nil {
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
// Not an IP address, treat as domain name - perform DNS lookup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIP(ctx, "ip4", rewriteTo)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("failed to resolve domain %s: %w", rewriteTo, err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("no IP addresses found for domain %s", rewriteTo)
|
||||
}
|
||||
|
||||
// Use the first resolved IP address
|
||||
ip := ips[0]
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
addr := netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]})
|
||||
logger.Debug("Resolved %s to %s", rewriteTo, addr)
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
return netip.Addr{}, fmt.Errorf("no IPv4 address found for domain %s", rewriteTo)
|
||||
}
|
||||
|
||||
// Initialize sets up the promiscuous NIC with the netTun's notification system
|
||||
func (p *ProxyHandler) Initialize(notifiable channel.Notification) error {
|
||||
if p == nil || !p.enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store notifiable for triggering notifications on ICMP replies
|
||||
p.notifiable = notifiable
|
||||
|
||||
// Add notification handler
|
||||
p.proxyNotifyHandle = p.proxyEp.AddNotify(notifiable)
|
||||
|
||||
// Create NIC with promiscuous mode
|
||||
tcpipErr := p.proxyStack.CreateNICWithOptions(1, p.proxyEp, stack.NICOptions{
|
||||
Disabled: false,
|
||||
QDisc: nil,
|
||||
})
|
||||
if tcpipErr != nil {
|
||||
return fmt.Errorf("CreateNIC (proxy): %v", tcpipErr)
|
||||
}
|
||||
|
||||
// Enable promiscuous mode - accepts packets for any destination IP
|
||||
if tcpipErr := p.proxyStack.SetPromiscuousMode(1, true); tcpipErr != nil {
|
||||
return fmt.Errorf("SetPromiscuousMode: %v", tcpipErr)
|
||||
}
|
||||
|
||||
// Enable spoofing - allows sending packets from any source IP
|
||||
if tcpipErr := p.proxyStack.SetSpoofing(1, true); tcpipErr != nil {
|
||||
return fmt.Errorf("SetSpoofing: %v", tcpipErr)
|
||||
}
|
||||
|
||||
// Add default route
|
||||
p.proxyStack.AddRoute(tcpip.Route{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: 1,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleIncomingPacket processes incoming packets and determines if they should
|
||||
// be injected into the proxy stack
|
||||
func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
||||
if p == nil || !p.enabled {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check minimum packet size
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only handle IPv4 for now
|
||||
if packet[0]>>4 != 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse IPv4 header
|
||||
ipv4Header := header.IPv4(packet)
|
||||
srcIP := ipv4Header.SourceAddress()
|
||||
dstIP := ipv4Header.DestinationAddress()
|
||||
|
||||
// Convert gvisor tcpip.Address to netip.Addr
|
||||
srcBytes := srcIP.As4()
|
||||
srcAddr := netip.AddrFrom4(srcBytes)
|
||||
dstBytes := dstIP.As4()
|
||||
dstAddr := netip.AddrFrom4(dstBytes)
|
||||
|
||||
// Parse transport layer to get destination port
|
||||
var dstPort uint16
|
||||
protocol := ipv4Header.TransportProtocol()
|
||||
headerLen := int(ipv4Header.HeaderLength())
|
||||
|
||||
// Extract port based on protocol
|
||||
switch protocol {
|
||||
case header.TCPProtocolNumber:
|
||||
if len(packet) < headerLen+header.TCPMinimumSize {
|
||||
return false
|
||||
}
|
||||
tcpHeader := header.TCP(packet[headerLen:])
|
||||
dstPort = tcpHeader.DestinationPort()
|
||||
case header.UDPProtocolNumber:
|
||||
if len(packet) < headerLen+header.UDPMinimumSize {
|
||||
return false
|
||||
}
|
||||
udpHeader := header.UDP(packet[headerLen:])
|
||||
dstPort = udpHeader.DestinationPort()
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
// ICMP doesn't have ports, use port 0 (must match rules with no port restrictions)
|
||||
dstPort = 0
|
||||
logger.Debug("HandleIncomingPacket: ICMP packet from %s to %s", srcAddr, dstAddr)
|
||||
default:
|
||||
// For other protocols, use port 0 (must match rules with no port restrictions)
|
||||
dstPort = 0
|
||||
logger.Debug("HandleIncomingPacket: Unknown protocol %d from %s to %s", protocol, srcAddr, dstAddr)
|
||||
}
|
||||
|
||||
// Check if the source IP, destination IP, port, and protocol match any subnet rule
|
||||
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol)
|
||||
if matchedRule != nil {
|
||||
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)",
|
||||
srcAddr, dstAddr, protocol, dstPort)
|
||||
// Check if we need to perform DNAT
|
||||
if matchedRule.RewriteTo != "" {
|
||||
// Create connection tracking key using original destination
|
||||
// This allows us to check if we've already resolved for this connection
|
||||
var srcPort uint16
|
||||
switch protocol {
|
||||
case header.TCPProtocolNumber:
|
||||
tcpHeader := header.TCP(packet[headerLen:])
|
||||
srcPort = tcpHeader.SourcePort()
|
||||
case header.UDPProtocolNumber:
|
||||
udpHeader := header.UDP(packet[headerLen:])
|
||||
srcPort = udpHeader.SourcePort()
|
||||
}
|
||||
|
||||
// Key using original destination to track the connection
|
||||
key := connKey{
|
||||
srcIP: srcAddr.String(),
|
||||
srcPort: srcPort,
|
||||
dstIP: dstAddr.String(),
|
||||
dstPort: dstPort,
|
||||
proto: uint8(protocol),
|
||||
}
|
||||
|
||||
// Key for handler lookups (doesn't include srcPort for flexibility)
|
||||
dKey := destKey{
|
||||
srcIP: srcAddr.String(),
|
||||
dstIP: dstAddr.String(),
|
||||
dstPort: dstPort,
|
||||
proto: uint8(protocol),
|
||||
}
|
||||
|
||||
// Check if we already have a NAT entry for this connection
|
||||
p.natMu.RLock()
|
||||
existingEntry, exists := p.natTable[key]
|
||||
p.natMu.RUnlock()
|
||||
|
||||
var newDst netip.Addr
|
||||
if exists {
|
||||
// Use the previously resolved address for this connection
|
||||
newDst = existingEntry.rewrittenTo
|
||||
logger.Debug("Using existing NAT entry for connection: %s -> %s", dstAddr, newDst)
|
||||
} else {
|
||||
// New connection - resolve the rewrite address
|
||||
var err error
|
||||
newDst, err = p.resolveRewriteAddress(matchedRule.RewriteTo)
|
||||
if err != nil {
|
||||
// Failed to resolve, skip DNAT but still proxy the packet
|
||||
logger.Debug("Failed to resolve rewrite address: %v", err)
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||
return true
|
||||
}
|
||||
|
||||
// Store NAT state for this connection
|
||||
p.natMu.Lock()
|
||||
p.natTable[key] = &natState{
|
||||
originalDst: dstAddr,
|
||||
rewrittenTo: newDst,
|
||||
}
|
||||
// Store destination rewrite for handler lookups
|
||||
p.destRewriteTable[dKey] = newDst
|
||||
p.natMu.Unlock()
|
||||
logger.Debug("New NAT entry for connection: %s -> %s", dstAddr, newDst)
|
||||
}
|
||||
|
||||
// Check if target is loopback - if so, don't rewrite packet destination
|
||||
// as gVisor will drop martian packets. Instead, the handlers will use
|
||||
// destRewriteTable to find the actual target address.
|
||||
if !newDst.IsLoopback() {
|
||||
// Rewrite the packet only for non-loopback destinations
|
||||
packet = p.rewritePacketDestination(packet, newDst)
|
||||
if packet == nil {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
logger.Debug("Target is loopback, not rewriting packet - handlers will use rewrite table")
|
||||
}
|
||||
}
|
||||
|
||||
// Inject into proxy stack
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||
logger.Debug("HandleIncomingPacket: Injected packet into proxy stack (proto=%d)", protocol)
|
||||
return true
|
||||
}
|
||||
|
||||
// logger.Debug("HandleIncomingPacket: No matching rule for %s -> %s (proto=%d, port=%d)",
|
||||
// srcAddr, dstAddr, protocol, dstPort)
|
||||
return false
|
||||
}
|
||||
|
||||
// rewritePacketDestination rewrites the destination IP in a packet and recalculates checksums
|
||||
func (p *ProxyHandler) rewritePacketDestination(packet []byte, newDst netip.Addr) []byte {
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make a copy to avoid modifying the original
|
||||
pkt := make([]byte, len(packet))
|
||||
copy(pkt, packet)
|
||||
|
||||
ipv4Header := header.IPv4(pkt)
|
||||
headerLen := int(ipv4Header.HeaderLength())
|
||||
|
||||
// Rewrite destination IP
|
||||
newDstBytes := newDst.As4()
|
||||
newDstAddr := tcpip.AddrFrom4(newDstBytes)
|
||||
ipv4Header.SetDestinationAddress(newDstAddr)
|
||||
|
||||
// Recalculate IP checksum
|
||||
ipv4Header.SetChecksum(0)
|
||||
ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
|
||||
|
||||
// Update transport layer checksum if needed
|
||||
protocol := ipv4Header.TransportProtocol()
|
||||
switch protocol {
|
||||
case header.TCPProtocolNumber:
|
||||
if len(pkt) >= headerLen+header.TCPMinimumSize {
|
||||
tcpHeader := header.TCP(pkt[headerLen:])
|
||||
tcpHeader.SetChecksum(0)
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.TCPProtocolNumber,
|
||||
ipv4Header.SourceAddress(),
|
||||
ipv4Header.DestinationAddress(),
|
||||
uint16(len(pkt)-headerLen),
|
||||
)
|
||||
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||
tcpHeader.SetChecksum(^xsum)
|
||||
}
|
||||
case header.UDPProtocolNumber:
|
||||
if len(pkt) >= headerLen+header.UDPMinimumSize {
|
||||
udpHeader := header.UDP(pkt[headerLen:])
|
||||
udpHeader.SetChecksum(0)
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.UDPProtocolNumber,
|
||||
ipv4Header.SourceAddress(),
|
||||
ipv4Header.DestinationAddress(),
|
||||
uint16(len(pkt)-headerLen),
|
||||
)
|
||||
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||
udpHeader.SetChecksum(^xsum)
|
||||
}
|
||||
}
|
||||
|
||||
return pkt
|
||||
}
|
||||
|
||||
// rewritePacketSource rewrites the source IP in a packet and recalculates checksums (for reverse NAT)
|
||||
func (p *ProxyHandler) rewritePacketSource(packet []byte, newSrc netip.Addr) []byte {
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make a copy to avoid modifying the original
|
||||
pkt := make([]byte, len(packet))
|
||||
copy(pkt, packet)
|
||||
|
||||
ipv4Header := header.IPv4(pkt)
|
||||
headerLen := int(ipv4Header.HeaderLength())
|
||||
|
||||
// Rewrite source IP
|
||||
newSrcBytes := newSrc.As4()
|
||||
newSrcAddr := tcpip.AddrFrom4(newSrcBytes)
|
||||
ipv4Header.SetSourceAddress(newSrcAddr)
|
||||
|
||||
// Recalculate IP checksum
|
||||
ipv4Header.SetChecksum(0)
|
||||
ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
|
||||
|
||||
// Update transport layer checksum if needed
|
||||
protocol := ipv4Header.TransportProtocol()
|
||||
switch protocol {
|
||||
case header.TCPProtocolNumber:
|
||||
if len(pkt) >= headerLen+header.TCPMinimumSize {
|
||||
tcpHeader := header.TCP(pkt[headerLen:])
|
||||
tcpHeader.SetChecksum(0)
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.TCPProtocolNumber,
|
||||
ipv4Header.SourceAddress(),
|
||||
ipv4Header.DestinationAddress(),
|
||||
uint16(len(pkt)-headerLen),
|
||||
)
|
||||
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||
tcpHeader.SetChecksum(^xsum)
|
||||
}
|
||||
case header.UDPProtocolNumber:
|
||||
if len(pkt) >= headerLen+header.UDPMinimumSize {
|
||||
udpHeader := header.UDP(pkt[headerLen:])
|
||||
udpHeader.SetChecksum(0)
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.UDPProtocolNumber,
|
||||
ipv4Header.SourceAddress(),
|
||||
ipv4Header.DestinationAddress(),
|
||||
uint16(len(pkt)-headerLen),
|
||||
)
|
||||
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||
udpHeader.SetChecksum(^xsum)
|
||||
}
|
||||
}
|
||||
|
||||
return pkt
|
||||
}
|
||||
|
||||
// ReadOutgoingPacket reads packets from the proxy stack that need to be
|
||||
// sent back through the tunnel
|
||||
func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
|
||||
if p == nil || !p.enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// First check for ICMP reply packets (non-blocking)
|
||||
select {
|
||||
case icmpReply := <-p.icmpReplies:
|
||||
logger.Debug("ReadOutgoingPacket: Returning ICMP reply packet (%d bytes)", len(icmpReply))
|
||||
return buffer.NewViewWithData(icmpReply)
|
||||
default:
|
||||
// No ICMP reply available, continue to check proxy endpoint
|
||||
}
|
||||
|
||||
pkt := p.proxyEp.Read()
|
||||
if pkt != nil {
|
||||
view := pkt.ToView()
|
||||
pkt.DecRef()
|
||||
|
||||
// Check if we need to perform reverse NAT
|
||||
packet := view.AsSlice()
|
||||
if len(packet) >= header.IPv4MinimumSize && packet[0]>>4 == 4 {
|
||||
ipv4Header := header.IPv4(packet)
|
||||
srcIP := ipv4Header.SourceAddress()
|
||||
dstIP := ipv4Header.DestinationAddress()
|
||||
protocol := ipv4Header.TransportProtocol()
|
||||
headerLen := int(ipv4Header.HeaderLength())
|
||||
|
||||
// Extract ports
|
||||
var srcPort, dstPort uint16
|
||||
switch protocol {
|
||||
case header.TCPProtocolNumber:
|
||||
if len(packet) >= headerLen+header.TCPMinimumSize {
|
||||
tcpHeader := header.TCP(packet[headerLen:])
|
||||
srcPort = tcpHeader.SourcePort()
|
||||
dstPort = tcpHeader.DestinationPort()
|
||||
}
|
||||
case header.UDPProtocolNumber:
|
||||
if len(packet) >= headerLen+header.UDPMinimumSize {
|
||||
udpHeader := header.UDP(packet[headerLen:])
|
||||
srcPort = udpHeader.SourcePort()
|
||||
dstPort = udpHeader.DestinationPort()
|
||||
}
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
// ICMP packets don't need NAT translation in our implementation
|
||||
// since we construct reply packets with the correct addresses
|
||||
logger.Debug("ReadOutgoingPacket: ICMP packet from %s to %s", srcIP, dstIP)
|
||||
return view
|
||||
}
|
||||
|
||||
// Look up NAT state for reverse translation
|
||||
// The key uses the original dst (before rewrite), so for replies we need to
|
||||
// find the entry where the rewritten address matches the current source
|
||||
p.natMu.RLock()
|
||||
var natEntry *natState
|
||||
for k, entry := range p.natTable {
|
||||
// Match: reply's dst should be original src, reply's src should be rewritten dst
|
||||
if k.srcIP == dstIP.String() && k.srcPort == dstPort &&
|
||||
entry.rewrittenTo.String() == srcIP.String() && k.dstPort == srcPort &&
|
||||
k.proto == uint8(protocol) {
|
||||
natEntry = entry
|
||||
break
|
||||
}
|
||||
}
|
||||
p.natMu.RUnlock()
|
||||
|
||||
if natEntry != nil {
|
||||
// Perform reverse NAT - rewrite source to original destination
|
||||
packet = p.rewritePacketSource(packet, natEntry.originalDst)
|
||||
if packet != nil {
|
||||
return buffer.NewViewWithData(packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return view
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueICMPReply queues an ICMP reply packet to be sent back through the tunnel
|
||||
func (p *ProxyHandler) QueueICMPReply(packet []byte) bool {
|
||||
if p == nil || !p.enabled {
|
||||
return false
|
||||
}
|
||||
|
||||
select {
|
||||
case p.icmpReplies <- packet:
|
||||
logger.Debug("QueueICMPReply: Queued ICMP reply packet (%d bytes)", len(packet))
|
||||
// Trigger notification so WriteNotify picks up the packet
|
||||
if p.notifiable != nil {
|
||||
p.notifiable.WriteNotify()
|
||||
}
|
||||
return true
|
||||
default:
|
||||
logger.Info("QueueICMPReply: ICMP reply channel full, dropping packet")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Close cleans up the proxy handler resources
|
||||
func (p *ProxyHandler) Close() error {
|
||||
if p == nil || !p.enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close ICMP replies channel
|
||||
if p.icmpReplies != nil {
|
||||
close(p.icmpReplies)
|
||||
}
|
||||
|
||||
if p.proxyStack != nil {
|
||||
p.proxyStack.RemoveNIC(1)
|
||||
p.proxyStack.Close()
|
||||
}
|
||||
|
||||
if p.proxyEp != nil {
|
||||
if p.proxyNotifyHandle != nil {
|
||||
p.proxyEp.RemoveNotify(p.proxyNotifyHandle)
|
||||
}
|
||||
p.proxyEp.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
1152
netstack2/tun.go
Normal file
1152
netstack2/tun.go
Normal file
File diff suppressed because it is too large
Load Diff
169
network/interface.go
Normal file
169
network/interface.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// ConfigureInterface configures a network interface with an IP address and brings it up
|
||||
func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error {
|
||||
logger.Info("The tunnel IP is: %s", tunnelIp)
|
||||
|
||||
// Parse the IP address and network
|
||||
ip, ipNet, err := net.ParseCIDR(tunnelIp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid IP address: %v", err)
|
||||
}
|
||||
|
||||
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||
mask := net.IP(ipNet.Mask).String()
|
||||
destinationAddress := ip.String()
|
||||
|
||||
logger.Debug("The destination address is: %s", destinationAddress)
|
||||
|
||||
// network.SetTunnelRemoteAddress() // what does this do?
|
||||
SetIPv4Settings([]string{destinationAddress}, []string{mask})
|
||||
SetMTU(mtu)
|
||||
|
||||
if interfaceName == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
return configureLinux(interfaceName, ip, ipNet)
|
||||
case "darwin":
|
||||
return configureDarwin(interfaceName, ip, ipNet)
|
||||
case "windows":
|
||||
return configureWindows(interfaceName, ip, ipNet)
|
||||
case "android":
|
||||
return nil
|
||||
case "ios":
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForInterfaceUp polls the network interface until it's up or times out
|
||||
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
|
||||
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||
deadline := time.Now().Add(timeout)
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
// Check if interface exists and is up
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err == nil {
|
||||
// Check if interface is up
|
||||
if iface.Flags&net.FlagUp != 0 {
|
||||
// Check if it has the expected IP
|
||||
addrs, err := iface.Addrs()
|
||||
if err == nil {
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if ok && ipNet.IP.Equal(expectedIP) {
|
||||
logger.Info("Interface %s is up with correct IP", interfaceName)
|
||||
return nil // Interface is up with correct IP
|
||||
}
|
||||
}
|
||||
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Interface %s exists but is not up yet", interfaceName)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Interface %s not found yet: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
// Wait before next check
|
||||
time.Sleep(pollInterval)
|
||||
}
|
||||
|
||||
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||
}
|
||||
|
||||
func FindUnusedUTUN() (string, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to list interfaces: %v", err)
|
||||
}
|
||||
used := make(map[int]bool)
|
||||
re := regexp.MustCompile(`^utun(\d+)$`)
|
||||
for _, iface := range ifaces {
|
||||
if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 {
|
||||
if num, err := strconv.Atoi(matches[1]); err == nil {
|
||||
used[num] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// Try utun0 up to utun255.
|
||||
for i := 0; i < 256; i++ {
|
||||
if !used[i] {
|
||||
return fmt.Sprintf("utun%d", i), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no unused utun interface found")
|
||||
}
|
||||
|
||||
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
logger.Info("Configuring darwin interface: %s", interfaceName)
|
||||
|
||||
prefix, _ := ipNet.Mask.Size()
|
||||
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
|
||||
|
||||
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
cmd = exec.Command("ifconfig", interfaceName, "up")
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err = cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
// Get the interface
|
||||
link, err := netlink.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
// Create the IP address attributes
|
||||
addr := &netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: ip,
|
||||
Mask: ipNet.Mask,
|
||||
},
|
||||
}
|
||||
|
||||
// Add the IP address to the interface
|
||||
if err := netlink.AddrAdd(link, addr); err != nil {
|
||||
return fmt.Errorf("failed to add IP address: %v", err)
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
if err := netlink.LinkSetUp(link); err != nil {
|
||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
12
network/interface_notwindows.go
Normal file
12
network/interface_notwindows.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
return fmt.Errorf("configureWindows called on non-Windows platform")
|
||||
}
|
||||
63
network/interface_windows.go
Normal file
63
network/interface_windows.go
Normal file
@@ -0,0 +1,63 @@
|
||||
//go:build windows
|
||||
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
logger.Info("Configuring Windows interface: %s", interfaceName)
|
||||
|
||||
// Get the LUID for the interface
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
// Create the IP address prefix
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
|
||||
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||
var addr netip.Addr
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
// IPv4 address
|
||||
addr, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
// IPv6 address
|
||||
addr, _ = netip.AddrFromSlice(ip)
|
||||
}
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("failed to convert IP address")
|
||||
}
|
||||
prefix := netip.PrefixFrom(addr, maskBits)
|
||||
|
||||
// Add the IP address to the interface
|
||||
logger.Info("Adding IP address %s to interface %s", prefix.String(), interfaceName)
|
||||
err = luid.AddIPAddress(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add IP address: %v", err)
|
||||
}
|
||||
|
||||
// This was required when we were using the subprocess "netsh" command to bring up the interface.
|
||||
// With the winipcfg library, the interface should already be up after adding the IP so we dont
|
||||
// need this step anymore as far as I can tell.
|
||||
|
||||
// // Wait for the interface to be up and have the correct IP
|
||||
// err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("interface did not come up within timeout: %v", err)
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/net/bpf"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
udpProtocol = 17
|
||||
// EmptyUDPSize is the size of an empty UDP packet
|
||||
EmptyUDPSize = 28
|
||||
timeout = time.Second * 10
|
||||
)
|
||||
|
||||
// Server stores data relating to the server
|
||||
type Server struct {
|
||||
Hostname string
|
||||
Addr *net.IPAddr
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// PeerNet stores data about a peer's endpoint
|
||||
type PeerNet struct {
|
||||
Resolved bool
|
||||
IP net.IP
|
||||
Port uint16
|
||||
NewtID string
|
||||
}
|
||||
|
||||
// GetClientIP gets source ip address that will be used when sending data to dstIP
|
||||
func GetClientIP(dstIP net.IP) net.IP {
|
||||
routes, err := netlink.RouteGet(dstIP)
|
||||
if err != nil {
|
||||
log.Fatalln("Error getting route:", err)
|
||||
}
|
||||
return routes[0].Src
|
||||
}
|
||||
|
||||
// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr
|
||||
func HostToAddr(hostStr string) *net.IPAddr {
|
||||
remoteAddrs, err := net.LookupHost(hostStr)
|
||||
if err != nil {
|
||||
log.Fatalln("Error parsing remote address:", err)
|
||||
}
|
||||
|
||||
for _, addrStr := range remoteAddrs {
|
||||
if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil {
|
||||
return remoteAddr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering
|
||||
func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn {
|
||||
packetConn, err := net.ListenPacket("ip4:udp", client.IP.String())
|
||||
if err != nil {
|
||||
log.Fatalln("Error creating packetConn:", err)
|
||||
}
|
||||
|
||||
rawConn, err := ipv4.NewRawConn(packetConn)
|
||||
if err != nil {
|
||||
log.Fatalln("Error creating rawConn:", err)
|
||||
}
|
||||
|
||||
ApplyBPF(rawConn, server, client)
|
||||
|
||||
return rawConn
|
||||
}
|
||||
|
||||
// ApplyBPF constructs a BPF program and applies it to the RawConn
|
||||
func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) {
|
||||
const ipv4HeaderLen = 20
|
||||
const srcIPOffset = 12
|
||||
const srcPortOffset = ipv4HeaderLen + 0
|
||||
const dstPortOffset = ipv4HeaderLen + 2
|
||||
|
||||
ipArr := []byte(server.Addr.IP.To4())
|
||||
ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3])
|
||||
|
||||
bpfRaw, err := bpf.Assemble([]bpf.Instruction{
|
||||
bpf.LoadAbsolute{Off: srcIPOffset, Size: 4},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0},
|
||||
|
||||
bpf.LoadAbsolute{Off: srcPortOffset, Size: 2},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0},
|
||||
|
||||
bpf.LoadAbsolute{Off: dstPortOffset, Size: 2},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0},
|
||||
|
||||
bpf.RetConstant{Val: 1<<(8*4) - 1},
|
||||
bpf.RetConstant{Val: 0},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatalln("Error assembling BPF:", err)
|
||||
}
|
||||
|
||||
err = rawConn.SetBPF(bpfRaw)
|
||||
if err != nil {
|
||||
log.Fatalln("Error setting BPF:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// MakePacket constructs a request packet to send to the server
|
||||
func MakePacket(payload []byte, server *Server, client *PeerNet) []byte {
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
|
||||
opts := gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}
|
||||
|
||||
ipHeader := layers.IPv4{
|
||||
SrcIP: client.IP,
|
||||
DstIP: server.Addr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
|
||||
udpHeader := layers.UDP{
|
||||
SrcPort: layers.UDPPort(client.Port),
|
||||
DstPort: layers.UDPPort(server.Port),
|
||||
}
|
||||
|
||||
payloadLayer := gopacket.Payload(payload)
|
||||
|
||||
udpHeader.SetNetworkLayerForChecksum(&ipHeader)
|
||||
|
||||
gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// SendPacket sends packet to the Server
|
||||
func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
|
||||
fullPacket := MakePacket(packet, server, client)
|
||||
_, err := conn.WriteToIP(fullPacket, server.Addr)
|
||||
return err
|
||||
}
|
||||
|
||||
// SendDataPacket sends a JSON payload to the Server
|
||||
func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
return SendPacket(jsonData, conn, server, client)
|
||||
}
|
||||
|
||||
// RecvPacket receives a UDP packet from server
|
||||
func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) {
|
||||
err := conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
response := make([]byte, 4096)
|
||||
n, err := conn.Read(response)
|
||||
if err != nil {
|
||||
return nil, n, err
|
||||
}
|
||||
return response, n, nil
|
||||
}
|
||||
|
||||
// RecvDataPacket receives and unmarshals a JSON packet from server
|
||||
func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) {
|
||||
response, n, err := RecvPacket(conn, server, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract payload from UDP packet
|
||||
payload := response[EmptyUDPSize:n]
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// ParseResponse takes a response packet and parses it into an IP and port
|
||||
func ParseResponse(response []byte) (net.IP, uint16) {
|
||||
ip := net.IP(response[:4])
|
||||
port := binary.BigEndian.Uint16(response[4:6])
|
||||
return ip, port
|
||||
}
|
||||
286
network/route.go
Normal file
286
network/route.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
|
||||
if gateway != "" {
|
||||
// Route with specific gateway
|
||||
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway)
|
||||
} else if interfaceName != "" {
|
||||
// Route via interface
|
||||
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName)
|
||||
} else {
|
||||
return fmt.Errorf("either gateway or interface must be specified")
|
||||
}
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DarwinRemoveRoute(destination string) error {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination)
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse destination CIDR
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Create route
|
||||
route := &netlink.Route{
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if gateway != "" {
|
||||
// Route with specific gateway
|
||||
gw := net.ParseIP(gateway)
|
||||
if gw == nil {
|
||||
return fmt.Errorf("invalid gateway address: %s", gateway)
|
||||
}
|
||||
route.Gw = gw
|
||||
logger.Info("Adding route to %s via gateway %s", destination, gateway)
|
||||
} else if interfaceName != "" {
|
||||
// Route via interface
|
||||
link, err := netlink.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||
}
|
||||
route.LinkIndex = link.Attrs().Index
|
||||
logger.Info("Adding route to %s via interface %s", destination, interfaceName)
|
||||
} else {
|
||||
return fmt.Errorf("either gateway or interface must be specified")
|
||||
}
|
||||
|
||||
// Add the route
|
||||
if err := netlink.RouteAdd(route); err != nil {
|
||||
return fmt.Errorf("failed to add route: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func LinuxRemoveRoute(destination string) error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse destination CIDR
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Create route to delete
|
||||
route := &netlink.Route{
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
logger.Info("Removing route to %s", destination)
|
||||
|
||||
// Delete the route
|
||||
if err := netlink.RouteDel(route); err != nil {
|
||||
return fmt.Errorf("failed to delete route: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRouteForServerIP adds an OS-specific route for the server IP
|
||||
func AddRouteForServerIP(serverIP, interfaceName string) error {
|
||||
if interfaceName == "" {
|
||||
return nil
|
||||
}
|
||||
// TODO: does this also need to be ios?
|
||||
if runtime.GOOS == "darwin" { // macos requires routes for each peer to be added but this messes with other platforms
|
||||
if err := AddRouteForNetworkConfig(serverIP); err != nil {
|
||||
return err
|
||||
}
|
||||
return DarwinAddRoute(serverIP, "", interfaceName)
|
||||
}
|
||||
// else if runtime.GOOS == "windows" {
|
||||
// return WindowsAddRoute(serverIP, "", interfaceName)
|
||||
// } else if runtime.GOOS == "linux" {
|
||||
// return LinuxAddRoute(serverIP, "", interfaceName)
|
||||
// }
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRouteForServerIP removes an OS-specific route for the server IP
|
||||
func RemoveRouteForServerIP(serverIP string, interfaceName string) error {
|
||||
if interfaceName == "" {
|
||||
return nil
|
||||
}
|
||||
// TODO: does this also need to be ios?
|
||||
if runtime.GOOS == "darwin" { // macos requires routes for each peer to be added but this messes with other platforms
|
||||
if err := RemoveRouteForNetworkConfig(serverIP); err != nil {
|
||||
return err
|
||||
}
|
||||
return DarwinRemoveRoute(serverIP)
|
||||
}
|
||||
// else if runtime.GOOS == "windows" {
|
||||
// return WindowsRemoveRoute(serverIP)
|
||||
// } else if runtime.GOOS == "linux" {
|
||||
// return LinuxRemoveRoute(serverIP)
|
||||
// }
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddRouteForNetworkConfig(destination string) error {
|
||||
// Parse the subnet to extract IP and mask
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse subnet %s: %v", destination, err)
|
||||
}
|
||||
|
||||
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||
mask := net.IP(ipNet.Mask).String()
|
||||
destinationAddress := ipNet.IP.String()
|
||||
|
||||
AddIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveRouteForNetworkConfig(destination string) error {
|
||||
// Parse the subnet to extract IP and mask
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse subnet %s: %v", destination, err)
|
||||
}
|
||||
|
||||
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||
mask := net.IP(ipNet.Mask).String()
|
||||
destinationAddress := ipNet.IP.String()
|
||||
|
||||
RemoveIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRoutes adds routes for each subnet in RemoteSubnets
|
||||
func AddRoutes(remoteSubnets []string, interfaceName string) error {
|
||||
if len(remoteSubnets) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add routes for each subnet
|
||||
for _, subnet := range remoteSubnets {
|
||||
subnet = strings.TrimSpace(subnet)
|
||||
if subnet == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := AddRouteForNetworkConfig(subnet); err != nil {
|
||||
logger.Error("Failed to add network config for subnet %s: %v", subnet, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add route based on operating system
|
||||
if interfaceName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
|
||||
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
|
||||
}
|
||||
case "windows":
|
||||
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
|
||||
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
|
||||
}
|
||||
case "linux":
|
||||
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
|
||||
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
|
||||
}
|
||||
case "android", "ios":
|
||||
// Routes handled by the OS/VPN service
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("Added route for remote subnet: %s", subnet)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets
|
||||
func RemoveRoutes(remoteSubnets []string) error {
|
||||
if len(remoteSubnets) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove routes for each subnet
|
||||
for _, subnet := range remoteSubnets {
|
||||
subnet = strings.TrimSpace(subnet)
|
||||
if subnet == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := RemoveRouteForNetworkConfig(subnet); err != nil {
|
||||
logger.Error("Failed to remove network config for subnet %s: %v", subnet, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove route based on operating system
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
if err := DarwinRemoveRoute(subnet); err != nil {
|
||||
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
|
||||
}
|
||||
case "windows":
|
||||
if err := WindowsRemoveRoute(subnet); err != nil {
|
||||
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
|
||||
}
|
||||
case "linux":
|
||||
if err := LinuxRemoveRoute(subnet); err != nil {
|
||||
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
|
||||
}
|
||||
case "android", "ios":
|
||||
// Routes handled by the OS/VPN service
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("Removed route for remote subnet: %s", subnet)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
11
network/route_notwindows.go
Normal file
11
network/route_notwindows.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !windows
|
||||
|
||||
package network
|
||||
|
||||
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func WindowsRemoveRoute(destination string) error {
|
||||
return nil
|
||||
}
|
||||
148
network/route_windows.go
Normal file
148
network/route_windows.go
Normal file
@@ -0,0 +1,148 @@
|
||||
//go:build windows
|
||||
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse destination CIDR
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Convert to netip.Prefix
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
|
||||
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||
var addr netip.Addr
|
||||
if ip4 := ipNet.IP.To4(); ip4 != nil {
|
||||
// IPv4 address
|
||||
addr, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
// IPv6 address
|
||||
addr, _ = netip.AddrFromSlice(ipNet.IP)
|
||||
}
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("failed to convert destination IP")
|
||||
}
|
||||
prefix := netip.PrefixFrom(addr, maskBits)
|
||||
|
||||
var luid winipcfg.LUID
|
||||
var nextHop netip.Addr
|
||||
|
||||
if interfaceName != "" {
|
||||
// Get the interface LUID - needed for both gateway and interface-only routes
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
luid, err = winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err)
|
||||
}
|
||||
}
|
||||
|
||||
if gateway != "" {
|
||||
// Route with specific gateway
|
||||
gwIP := net.ParseIP(gateway)
|
||||
if gwIP == nil {
|
||||
return fmt.Errorf("invalid gateway address: %s", gateway)
|
||||
}
|
||||
// Convert to correct IP version
|
||||
if ip4 := gwIP.To4(); ip4 != nil {
|
||||
nextHop, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
nextHop, _ = netip.AddrFromSlice(gwIP)
|
||||
}
|
||||
if !nextHop.IsValid() {
|
||||
return fmt.Errorf("failed to convert gateway IP")
|
||||
}
|
||||
logger.Info("Adding route to %s via gateway %s on interface %s", destination, gateway, interfaceName)
|
||||
} else if interfaceName != "" {
|
||||
// Route via interface only
|
||||
if addr.Is4() {
|
||||
nextHop = netip.IPv4Unspecified()
|
||||
} else {
|
||||
nextHop = netip.IPv6Unspecified()
|
||||
}
|
||||
logger.Info("Adding route to %s via interface %s", destination, interfaceName)
|
||||
} else {
|
||||
return fmt.Errorf("either gateway or interface must be specified")
|
||||
}
|
||||
|
||||
// Add the route using winipcfg
|
||||
err = luid.AddRoute(prefix, nextHop, 1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add route: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func WindowsRemoveRoute(destination string) error {
|
||||
// Parse destination CIDR
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Convert to netip.Prefix
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
|
||||
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||
var addr netip.Addr
|
||||
if ip4 := ipNet.IP.To4(); ip4 != nil {
|
||||
// IPv4 address
|
||||
addr, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
// IPv6 address
|
||||
addr, _ = netip.AddrFromSlice(ipNet.IP)
|
||||
}
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("failed to convert destination IP")
|
||||
}
|
||||
prefix := netip.PrefixFrom(addr, maskBits)
|
||||
|
||||
// Get all routes and find the one to delete
|
||||
// We need to get the LUID from the existing route
|
||||
var family winipcfg.AddressFamily
|
||||
if addr.Is4() {
|
||||
family = 2 // AF_INET
|
||||
} else {
|
||||
family = 23 // AF_INET6
|
||||
}
|
||||
|
||||
routes, err := winipcfg.GetIPForwardTable2(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get route table: %v", err)
|
||||
}
|
||||
|
||||
// Find and delete matching route
|
||||
for _, route := range routes {
|
||||
routePrefix := route.DestinationPrefix.Prefix()
|
||||
if routePrefix == prefix {
|
||||
logger.Info("Removing route to %s", destination)
|
||||
err = route.Delete()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete route: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("route to %s not found", destination)
|
||||
}
|
||||
190
network/settings.go
Normal file
190
network/settings.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// NetworkSettings represents the network configuration for the tunnel
|
||||
type NetworkSettings struct {
|
||||
TunnelRemoteAddress string `json:"tunnel_remote_address,omitempty"`
|
||||
MTU *int `json:"mtu,omitempty"`
|
||||
DNSServers []string `json:"dns_servers,omitempty"`
|
||||
IPv4Addresses []string `json:"ipv4_addresses,omitempty"`
|
||||
IPv4SubnetMasks []string `json:"ipv4_subnet_masks,omitempty"`
|
||||
IPv4IncludedRoutes []IPv4Route `json:"ipv4_included_routes,omitempty"`
|
||||
IPv4ExcludedRoutes []IPv4Route `json:"ipv4_excluded_routes,omitempty"`
|
||||
IPv6Addresses []string `json:"ipv6_addresses,omitempty"`
|
||||
IPv6NetworkPrefixes []string `json:"ipv6_network_prefixes,omitempty"`
|
||||
IPv6IncludedRoutes []IPv6Route `json:"ipv6_included_routes,omitempty"`
|
||||
IPv6ExcludedRoutes []IPv6Route `json:"ipv6_excluded_routes,omitempty"`
|
||||
}
|
||||
|
||||
// IPv4Route represents an IPv4 route
|
||||
type IPv4Route struct {
|
||||
DestinationAddress string `json:"destination_address"`
|
||||
SubnetMask string `json:"subnet_mask,omitempty"`
|
||||
GatewayAddress string `json:"gateway_address,omitempty"`
|
||||
IsDefault bool `json:"is_default,omitempty"`
|
||||
}
|
||||
|
||||
// IPv6Route represents an IPv6 route
|
||||
type IPv6Route struct {
|
||||
DestinationAddress string `json:"destination_address"`
|
||||
NetworkPrefixLength int `json:"network_prefix_length,omitempty"`
|
||||
GatewayAddress string `json:"gateway_address,omitempty"`
|
||||
IsDefault bool `json:"is_default,omitempty"`
|
||||
}
|
||||
|
||||
var (
|
||||
networkSettings NetworkSettings
|
||||
networkSettingsMutex sync.RWMutex
|
||||
incrementor int
|
||||
)
|
||||
|
||||
// SetTunnelRemoteAddress sets the tunnel remote address
|
||||
func SetTunnelRemoteAddress(address string) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.TunnelRemoteAddress = address
|
||||
incrementor++
|
||||
logger.Info("Set tunnel remote address: %s", address)
|
||||
}
|
||||
|
||||
// SetMTU sets the MTU value
|
||||
func SetMTU(mtu int) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.MTU = &mtu
|
||||
incrementor++
|
||||
logger.Info("Set MTU: %d", mtu)
|
||||
}
|
||||
|
||||
// SetDNSServers sets the DNS servers
|
||||
func SetDNSServers(servers []string) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.DNSServers = servers
|
||||
incrementor++
|
||||
logger.Info("Set DNS servers: %v", servers)
|
||||
}
|
||||
|
||||
// SetIPv4Settings sets IPv4 addresses and subnet masks
|
||||
func SetIPv4Settings(addresses []string, subnetMasks []string) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv4Addresses = addresses
|
||||
networkSettings.IPv4SubnetMasks = subnetMasks
|
||||
incrementor++
|
||||
logger.Info("Set IPv4 addresses: %v, subnet masks: %v", addresses, subnetMasks)
|
||||
}
|
||||
|
||||
// SetIPv4IncludedRoutes sets the included IPv4 routes
|
||||
func SetIPv4IncludedRoutes(routes []IPv4Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv4IncludedRoutes = routes
|
||||
incrementor++
|
||||
logger.Info("Set IPv4 included routes: %d routes", len(routes))
|
||||
}
|
||||
|
||||
func AddIPv4IncludedRoute(route IPv4Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
|
||||
// make sure it does not already exist
|
||||
for _, r := range networkSettings.IPv4IncludedRoutes {
|
||||
if r == route {
|
||||
logger.Info("IPv4 included route already exists: %+v", route)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
networkSettings.IPv4IncludedRoutes = append(networkSettings.IPv4IncludedRoutes, route)
|
||||
incrementor++
|
||||
logger.Info("Added IPv4 included route: %+v", route)
|
||||
}
|
||||
|
||||
func RemoveIPv4IncludedRoute(route IPv4Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
routes := networkSettings.IPv4IncludedRoutes
|
||||
for i, r := range routes {
|
||||
if r == route {
|
||||
networkSettings.IPv4IncludedRoutes = append(routes[:i], routes[i+1:]...)
|
||||
logger.Info("Removed IPv4 included route: %+v", route)
|
||||
break
|
||||
}
|
||||
}
|
||||
incrementor++
|
||||
logger.Info("IPv4 included route not found for removal: %+v", route)
|
||||
}
|
||||
|
||||
func SetIPv4ExcludedRoutes(routes []IPv4Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv4ExcludedRoutes = routes
|
||||
incrementor++
|
||||
logger.Info("Set IPv4 excluded routes: %d routes", len(routes))
|
||||
}
|
||||
|
||||
// SetIPv6Settings sets IPv6 addresses and network prefixes
|
||||
func SetIPv6Settings(addresses []string, networkPrefixes []string) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv6Addresses = addresses
|
||||
networkSettings.IPv6NetworkPrefixes = networkPrefixes
|
||||
incrementor++
|
||||
logger.Info("Set IPv6 addresses: %v, network prefixes: %v", addresses, networkPrefixes)
|
||||
}
|
||||
|
||||
// SetIPv6IncludedRoutes sets the included IPv6 routes
|
||||
func SetIPv6IncludedRoutes(routes []IPv6Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv6IncludedRoutes = routes
|
||||
incrementor++
|
||||
logger.Info("Set IPv6 included routes: %d routes", len(routes))
|
||||
}
|
||||
|
||||
// SetIPv6ExcludedRoutes sets the excluded IPv6 routes
|
||||
func SetIPv6ExcludedRoutes(routes []IPv6Route) {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings.IPv6ExcludedRoutes = routes
|
||||
incrementor++
|
||||
logger.Info("Set IPv6 excluded routes: %d routes", len(routes))
|
||||
}
|
||||
|
||||
// ClearNetworkSettings clears all network settings
|
||||
func ClearNetworkSettings() {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
networkSettings = NetworkSettings{}
|
||||
incrementor++
|
||||
logger.Info("Cleared all network settings")
|
||||
}
|
||||
|
||||
func GetJSON() (string, error) {
|
||||
networkSettingsMutex.RLock()
|
||||
defer networkSettingsMutex.RUnlock()
|
||||
data, err := json.MarshalIndent(networkSettings, "", " ")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func GetSettings() NetworkSettings {
|
||||
networkSettingsMutex.RLock()
|
||||
defer networkSettingsMutex.RUnlock()
|
||||
return networkSettings
|
||||
}
|
||||
|
||||
func GetIncrementor() int {
|
||||
networkSettingsMutex.Lock()
|
||||
defer networkSettingsMutex.Unlock()
|
||||
return incrementor
|
||||
}
|
||||
152
newt.iss
Normal file
152
newt.iss
Normal file
@@ -0,0 +1,152 @@
|
||||
; Script generated by the Inno Setup Script Wizard.
|
||||
; SEE THE DOCUMENTATION FOR DETAILS ON CREATING INNO SETUP SCRIPT FILES!
|
||||
|
||||
#define MyAppName "newt"
|
||||
#define MyAppVersion "1.0.0"
|
||||
#define MyAppPublisher "Fossorial Inc."
|
||||
#define MyAppURL "https://pangolin.net"
|
||||
#define MyAppExeName "newt.exe"
|
||||
|
||||
[Setup]
|
||||
; NOTE: The value of AppId uniquely identifies this application. Do not use the same AppId value in installers for other applications.
|
||||
; (To generate a new GUID, click Tools | Generate GUID inside the IDE.)
|
||||
AppId={{25A1E3C4-F273-4334-8DF3-47408E83012D}
|
||||
AppName={#MyAppName}
|
||||
AppVersion={#MyAppVersion}
|
||||
;AppVerName={#MyAppName} {#MyAppVersion}
|
||||
AppPublisher={#MyAppPublisher}
|
||||
AppPublisherURL={#MyAppURL}
|
||||
AppSupportURL={#MyAppURL}
|
||||
AppUpdatesURL={#MyAppURL}
|
||||
DefaultDirName={autopf}\{#MyAppName}
|
||||
UninstallDisplayIcon={app}\{#MyAppExeName}
|
||||
; "ArchitecturesAllowed=x64compatible" specifies that Setup cannot run
|
||||
; on anything but x64 and Windows 11 on Arm.
|
||||
ArchitecturesAllowed=x64compatible
|
||||
; "ArchitecturesInstallIn64BitMode=x64compatible" requests that the
|
||||
; install be done in "64-bit mode" on x64 or Windows 11 on Arm,
|
||||
; meaning it should use the native 64-bit Program Files directory and
|
||||
; the 64-bit view of the registry.
|
||||
ArchitecturesInstallIn64BitMode=x64compatible
|
||||
DefaultGroupName={#MyAppName}
|
||||
DisableProgramGroupPage=yes
|
||||
; Uncomment the following line to run in non administrative install mode (install for current user only).
|
||||
;PrivilegesRequired=lowest
|
||||
OutputBaseFilename=mysetup
|
||||
SolidCompression=yes
|
||||
WizardStyle=modern
|
||||
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed
|
||||
RestartIfNeededByRun=no
|
||||
ChangesEnvironment=true
|
||||
|
||||
[Languages]
|
||||
Name: "english"; MessagesFile: "compiler:Default.isl"
|
||||
|
||||
[Files]
|
||||
; The 'DestName' flag ensures that 'newt_windows_amd64.exe' is installed as 'newt.exe'
|
||||
Source: "Z:\newt_windows_amd64.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}"; Flags: ignoreversion
|
||||
Source: "Z:\wintun.dll"; DestDir: "{app}"; Flags: ignoreversion
|
||||
; NOTE: Don't use "Flags: ignoreversion" on any shared system files
|
||||
|
||||
[Icons]
|
||||
Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"
|
||||
|
||||
[Registry]
|
||||
; Add the application's installation directory to the system PATH environment variable.
|
||||
; HKLM (HKEY_LOCAL_MACHINE) is used for system-wide changes.
|
||||
; The 'Path' variable is located under 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'.
|
||||
; ValueType: expandsz allows for environment variables (like %ProgramFiles%) in the path.
|
||||
; ValueData: "{olddata};{app}" appends the current application directory to the existing PATH.
|
||||
; Note: Removal during uninstallation is handled by CurUninstallStepChanged procedure in [Code] section.
|
||||
; Check: NeedsAddPath ensures this is applied only if the path is not already present.
|
||||
[Registry]
|
||||
; Add the application's installation directory to the system PATH.
|
||||
Root: HKLM; Subkey: "SYSTEM\CurrentControlSet\Control\Session Manager\Environment"; \
|
||||
ValueType: expandsz; ValueName: "Path"; ValueData: "{olddata};{app}"; \
|
||||
Check: NeedsAddPath(ExpandConstant('{app}'))
|
||||
|
||||
[Code]
|
||||
function NeedsAddPath(Path: string): boolean;
|
||||
var
|
||||
OrigPath: string;
|
||||
begin
|
||||
if not RegQueryStringValue(HKEY_LOCAL_MACHINE,
|
||||
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||
'Path', OrigPath)
|
||||
then begin
|
||||
// Path variable doesn't exist at all, so we definitely need to add it.
|
||||
Result := True;
|
||||
exit;
|
||||
end;
|
||||
|
||||
// Perform a case-insensitive check to see if the path is already present.
|
||||
// We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2).
|
||||
if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then
|
||||
Result := False
|
||||
else
|
||||
Result := True;
|
||||
end;
|
||||
|
||||
procedure RemovePathEntry(PathToRemove: string);
|
||||
var
|
||||
OrigPath: string;
|
||||
NewPath: string;
|
||||
PathList: TStringList;
|
||||
I: Integer;
|
||||
begin
|
||||
if not RegQueryStringValue(HKEY_LOCAL_MACHINE,
|
||||
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||
'Path', OrigPath)
|
||||
then begin
|
||||
// Path variable doesn't exist, nothing to remove
|
||||
exit;
|
||||
end;
|
||||
|
||||
// Create a string list to parse the PATH entries
|
||||
PathList := TStringList.Create;
|
||||
try
|
||||
// Split the PATH by semicolons
|
||||
PathList.Delimiter := ';';
|
||||
PathList.StrictDelimiter := True;
|
||||
PathList.DelimitedText := OrigPath;
|
||||
|
||||
// Find and remove the matching entry (case-insensitive)
|
||||
for I := PathList.Count - 1 downto 0 do
|
||||
begin
|
||||
if CompareText(Trim(PathList[I]), Trim(PathToRemove)) = 0 then
|
||||
begin
|
||||
Log('Found and removing PATH entry: ' + PathList[I]);
|
||||
PathList.Delete(I);
|
||||
end;
|
||||
end;
|
||||
|
||||
// Reconstruct the PATH
|
||||
NewPath := PathList.DelimitedText;
|
||||
|
||||
// Write the new PATH back to the registry
|
||||
if RegWriteExpandStringValue(HKEY_LOCAL_MACHINE,
|
||||
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||
'Path', NewPath)
|
||||
then
|
||||
Log('Successfully removed path entry: ' + PathToRemove)
|
||||
else
|
||||
Log('Failed to write modified PATH to registry');
|
||||
finally
|
||||
PathList.Free;
|
||||
end;
|
||||
end;
|
||||
|
||||
procedure CurUninstallStepChanged(CurUninstallStep: TUninstallStep);
|
||||
var
|
||||
AppPath: string;
|
||||
begin
|
||||
if CurUninstallStep = usUninstall then
|
||||
begin
|
||||
// Get the application installation path
|
||||
AppPath := ExpandConstant('{app}');
|
||||
Log('Removing PATH entry for: ' + AppPath);
|
||||
|
||||
// Remove only our path entry from the system PATH
|
||||
RemovePathEntry(AppPath);
|
||||
end;
|
||||
end;
|
||||
379
proxy/manager.go
379
proxy/manager.go
@@ -1,18 +1,28 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/internal/state"
|
||||
"github.com/fosrl/newt/internal/telemetry"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
)
|
||||
|
||||
const errUnsupportedProtoFmt = "unsupported protocol: %s"
|
||||
|
||||
// Target represents a proxy target with its address and port
|
||||
type Target struct {
|
||||
Address string
|
||||
@@ -28,6 +38,90 @@ type ProxyManager struct {
|
||||
udpConns []*gonet.UDPConn
|
||||
running bool
|
||||
mutex sync.RWMutex
|
||||
|
||||
// telemetry (multi-tunnel)
|
||||
currentTunnelID string
|
||||
tunnels map[string]*tunnelEntry
|
||||
asyncBytes bool
|
||||
flushStop chan struct{}
|
||||
}
|
||||
|
||||
// tunnelEntry holds per-tunnel attributes and (optional) async counters.
|
||||
type tunnelEntry struct {
|
||||
attrInTCP attribute.Set
|
||||
attrOutTCP attribute.Set
|
||||
attrInUDP attribute.Set
|
||||
attrOutUDP attribute.Set
|
||||
|
||||
bytesInTCP atomic.Uint64
|
||||
bytesOutTCP atomic.Uint64
|
||||
bytesInUDP atomic.Uint64
|
||||
bytesOutUDP atomic.Uint64
|
||||
|
||||
activeTCP atomic.Int64
|
||||
activeUDP atomic.Int64
|
||||
}
|
||||
|
||||
// countingWriter wraps an io.Writer and adds bytes to OTel counter using a pre-built attribute set.
|
||||
type countingWriter struct {
|
||||
ctx context.Context
|
||||
w io.Writer
|
||||
set attribute.Set
|
||||
pm *ProxyManager
|
||||
ent *tunnelEntry
|
||||
out bool // false=in, true=out
|
||||
proto string // "tcp" or "udp"
|
||||
}
|
||||
|
||||
func (cw *countingWriter) Write(p []byte) (int, error) {
|
||||
n, err := cw.w.Write(p)
|
||||
if n > 0 {
|
||||
if cw.pm != nil && cw.pm.asyncBytes && cw.ent != nil {
|
||||
switch cw.proto {
|
||||
case "tcp":
|
||||
if cw.out {
|
||||
cw.ent.bytesOutTCP.Add(uint64(n))
|
||||
} else {
|
||||
cw.ent.bytesInTCP.Add(uint64(n))
|
||||
}
|
||||
case "udp":
|
||||
if cw.out {
|
||||
cw.ent.bytesOutUDP.Add(uint64(n))
|
||||
} else {
|
||||
cw.ent.bytesInUDP.Add(uint64(n))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
telemetry.AddTunnelBytesSet(cw.ctx, int64(n), cw.set)
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func classifyProxyError(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return "closed"
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok {
|
||||
if ne.Timeout() {
|
||||
return "timeout"
|
||||
}
|
||||
if ne.Temporary() {
|
||||
return "temporary"
|
||||
}
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
switch {
|
||||
case strings.Contains(msg, "refused"):
|
||||
return "refused"
|
||||
case strings.Contains(msg, "reset"):
|
||||
return "reset"
|
||||
default:
|
||||
return "io_error"
|
||||
}
|
||||
}
|
||||
|
||||
// NewProxyManager creates a new proxy manager instance
|
||||
@@ -38,9 +132,77 @@ func NewProxyManager(tnet *netstack.Net) *ProxyManager {
|
||||
udpTargets: make(map[string]map[int]string),
|
||||
listeners: make([]*gonet.TCPListener, 0),
|
||||
udpConns: make([]*gonet.UDPConn, 0),
|
||||
tunnels: make(map[string]*tunnelEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// SetTunnelID sets the WireGuard peer public key used as tunnel_id label.
|
||||
func (pm *ProxyManager) SetTunnelID(id string) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
pm.currentTunnelID = id
|
||||
if _, ok := pm.tunnels[id]; !ok {
|
||||
pm.tunnels[id] = &tunnelEntry{}
|
||||
}
|
||||
e := pm.tunnels[id]
|
||||
// include site labels if available
|
||||
site := telemetry.SiteLabelKVs()
|
||||
build := func(base []attribute.KeyValue) attribute.Set {
|
||||
if telemetry.ShouldIncludeTunnelID() {
|
||||
base = append([]attribute.KeyValue{attribute.String("tunnel_id", id)}, base...)
|
||||
}
|
||||
base = append(site, base...)
|
||||
return attribute.NewSet(base...)
|
||||
}
|
||||
e.attrInTCP = build([]attribute.KeyValue{
|
||||
attribute.String("direction", "ingress"),
|
||||
attribute.String("protocol", "tcp"),
|
||||
})
|
||||
e.attrOutTCP = build([]attribute.KeyValue{
|
||||
attribute.String("direction", "egress"),
|
||||
attribute.String("protocol", "tcp"),
|
||||
})
|
||||
e.attrInUDP = build([]attribute.KeyValue{
|
||||
attribute.String("direction", "ingress"),
|
||||
attribute.String("protocol", "udp"),
|
||||
})
|
||||
e.attrOutUDP = build([]attribute.KeyValue{
|
||||
attribute.String("direction", "egress"),
|
||||
attribute.String("protocol", "udp"),
|
||||
})
|
||||
}
|
||||
|
||||
// ClearTunnelID clears cached attribute sets for the current tunnel.
|
||||
func (pm *ProxyManager) ClearTunnelID() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
id := pm.currentTunnelID
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
if e, ok := pm.tunnels[id]; ok {
|
||||
// final flush for this tunnel
|
||||
inTCP := e.bytesInTCP.Swap(0)
|
||||
outTCP := e.bytesOutTCP.Swap(0)
|
||||
inUDP := e.bytesInUDP.Swap(0)
|
||||
outUDP := e.bytesOutUDP.Swap(0)
|
||||
if inTCP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(inTCP), e.attrInTCP)
|
||||
}
|
||||
if outTCP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(outTCP), e.attrOutTCP)
|
||||
}
|
||||
if inUDP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(inUDP), e.attrInUDP)
|
||||
}
|
||||
if outUDP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(outUDP), e.attrOutUDP)
|
||||
}
|
||||
delete(pm.tunnels, id)
|
||||
}
|
||||
pm.currentTunnelID = ""
|
||||
}
|
||||
|
||||
// init function without tnet
|
||||
func NewProxyManagerWithoutTNet() *ProxyManager {
|
||||
return &ProxyManager{
|
||||
@@ -75,7 +237,7 @@ func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr s
|
||||
}
|
||||
pm.udpTargets[listenIP][port] = targetAddr
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
return fmt.Errorf(errUnsupportedProtoFmt, proto)
|
||||
}
|
||||
|
||||
if pm.running {
|
||||
@@ -124,13 +286,28 @@ func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error {
|
||||
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
return fmt.Errorf(errUnsupportedProtoFmt, proto)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start begins listening for all configured proxy targets
|
||||
func (pm *ProxyManager) Start() error {
|
||||
// Register proxy observables once per process
|
||||
telemetry.SetProxyObservableCallback(func(ctx context.Context, o metric.Observer) error {
|
||||
pm.mutex.RLock()
|
||||
defer pm.mutex.RUnlock()
|
||||
for _, e := range pm.tunnels {
|
||||
// active connections
|
||||
telemetry.ObserveProxyActiveConnsObs(o, e.activeTCP.Load(), e.attrOutTCP.ToSlice())
|
||||
telemetry.ObserveProxyActiveConnsObs(o, e.activeUDP.Load(), e.attrOutUDP.ToSlice())
|
||||
// backlog bytes (sum of unflushed counters)
|
||||
b := int64(e.bytesInTCP.Load() + e.bytesOutTCP.Load() + e.bytesInUDP.Load() + e.bytesOutUDP.Load())
|
||||
telemetry.ObserveProxyAsyncBacklogObs(o, b, e.attrOutTCP.ToSlice())
|
||||
telemetry.ObserveProxyBufferBytesObs(o, b, e.attrOutTCP.ToSlice())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
@@ -160,6 +337,75 @@ func (pm *ProxyManager) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) SetAsyncBytes(b bool) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
pm.asyncBytes = b
|
||||
if b && pm.flushStop == nil {
|
||||
pm.flushStop = make(chan struct{})
|
||||
go pm.flushLoop()
|
||||
}
|
||||
}
|
||||
func (pm *ProxyManager) flushLoop() {
|
||||
flushInterval := 2 * time.Second
|
||||
if v := os.Getenv("OTEL_METRIC_EXPORT_INTERVAL"); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
||||
if d/2 < flushInterval {
|
||||
flushInterval = d / 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ticker := time.NewTicker(flushInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
pm.mutex.RLock()
|
||||
for _, e := range pm.tunnels {
|
||||
inTCP := e.bytesInTCP.Swap(0)
|
||||
outTCP := e.bytesOutTCP.Swap(0)
|
||||
inUDP := e.bytesInUDP.Swap(0)
|
||||
outUDP := e.bytesOutUDP.Swap(0)
|
||||
if inTCP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(inTCP), e.attrInTCP)
|
||||
}
|
||||
if outTCP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(outTCP), e.attrOutTCP)
|
||||
}
|
||||
if inUDP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(inUDP), e.attrInUDP)
|
||||
}
|
||||
if outUDP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(outUDP), e.attrOutUDP)
|
||||
}
|
||||
}
|
||||
pm.mutex.RUnlock()
|
||||
case <-pm.flushStop:
|
||||
pm.mutex.RLock()
|
||||
for _, e := range pm.tunnels {
|
||||
inTCP := e.bytesInTCP.Swap(0)
|
||||
outTCP := e.bytesOutTCP.Swap(0)
|
||||
inUDP := e.bytesInUDP.Swap(0)
|
||||
outUDP := e.bytesOutUDP.Swap(0)
|
||||
if inTCP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(inTCP), e.attrInTCP)
|
||||
}
|
||||
if outTCP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(outTCP), e.attrOutTCP)
|
||||
}
|
||||
if inUDP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(inUDP), e.attrInUDP)
|
||||
}
|
||||
if outUDP > 0 {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(outUDP), e.attrOutUDP)
|
||||
}
|
||||
}
|
||||
pm.mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) Stop() error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
@@ -227,7 +473,7 @@ func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr
|
||||
go pm.handleUDPProxy(conn, targetAddr)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
return fmt.Errorf(errUnsupportedProtoFmt, proto)
|
||||
}
|
||||
|
||||
logger.Info("Started %s proxy to %s", proto, targetAddr)
|
||||
@@ -236,54 +482,84 @@ func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr
|
||||
return nil
|
||||
}
|
||||
|
||||
// getEntry returns per-tunnel entry or nil.
|
||||
func (pm *ProxyManager) getEntry(id string) *tunnelEntry {
|
||||
pm.mutex.RLock()
|
||||
e := pm.tunnels[id]
|
||||
pm.mutex.RUnlock()
|
||||
return e
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
// Check if we're shutting down or the listener was closed
|
||||
telemetry.IncProxyAccept(context.Background(), pm.currentTunnelID, "tcp", "failure", classifyProxyError(err))
|
||||
if !pm.running {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for specific network errors that indicate the listener is closed
|
||||
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
||||
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error("Error accepting TCP connection: %v", err)
|
||||
// Don't hammer the CPU if we hit a temporary error
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
go func() {
|
||||
tunnelID := pm.currentTunnelID
|
||||
telemetry.IncProxyAccept(context.Background(), tunnelID, "tcp", "success", "")
|
||||
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "tcp", telemetry.ProxyConnectionOpened)
|
||||
if tunnelID != "" {
|
||||
state.Global().IncSessions(tunnelID)
|
||||
if e := pm.getEntry(tunnelID); e != nil {
|
||||
e.activeTCP.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
go func(tunnelID string, accepted net.Conn) {
|
||||
connStart := time.Now()
|
||||
target, err := net.Dial("tcp", targetAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error connecting to target: %v", err)
|
||||
conn.Close()
|
||||
accepted.Close()
|
||||
telemetry.IncProxyAccept(context.Background(), tunnelID, "tcp", "failure", classifyProxyError(err))
|
||||
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "tcp", telemetry.ProxyConnectionClosed)
|
||||
telemetry.ObserveProxyConnectionDuration(context.Background(), tunnelID, "tcp", "failure", time.Since(connStart).Seconds())
|
||||
return
|
||||
}
|
||||
|
||||
// Create a WaitGroup to ensure both copy operations complete
|
||||
entry := pm.getEntry(tunnelID)
|
||||
if entry == nil {
|
||||
entry = &tunnelEntry{}
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
go func(ent *tunnelEntry) {
|
||||
defer wg.Done()
|
||||
io.Copy(target, conn)
|
||||
target.Close()
|
||||
}()
|
||||
cw := &countingWriter{ctx: context.Background(), w: target, set: ent.attrInTCP, pm: pm, ent: ent, out: false, proto: "tcp"}
|
||||
_, _ = io.Copy(cw, accepted)
|
||||
_ = target.Close()
|
||||
}(entry)
|
||||
|
||||
go func() {
|
||||
go func(ent *tunnelEntry) {
|
||||
defer wg.Done()
|
||||
io.Copy(conn, target)
|
||||
conn.Close()
|
||||
}()
|
||||
cw := &countingWriter{ctx: context.Background(), w: accepted, set: ent.attrOutTCP, pm: pm, ent: ent, out: true, proto: "tcp"}
|
||||
_, _ = io.Copy(cw, target)
|
||||
_ = accepted.Close()
|
||||
}(entry)
|
||||
|
||||
// Wait for both copies to complete
|
||||
wg.Wait()
|
||||
}()
|
||||
if tunnelID != "" {
|
||||
state.Global().DecSessions(tunnelID)
|
||||
if e := pm.getEntry(tunnelID); e != nil {
|
||||
e.activeTCP.Add(-1)
|
||||
}
|
||||
}
|
||||
telemetry.ObserveProxyConnectionDuration(context.Background(), tunnelID, "tcp", "success", time.Since(connStart).Seconds())
|
||||
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "tcp", telemetry.ProxyConnectionClosed)
|
||||
}(tunnelID, conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,6 +602,18 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
}
|
||||
|
||||
clientKey := remoteAddr.String()
|
||||
// bytes from client -> target (direction=in)
|
||||
if pm.currentTunnelID != "" && n > 0 {
|
||||
if pm.asyncBytes {
|
||||
if e := pm.getEntry(pm.currentTunnelID); e != nil {
|
||||
e.bytesInUDP.Add(uint64(n))
|
||||
}
|
||||
} else {
|
||||
if e := pm.getEntry(pm.currentTunnelID); e != nil {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(n), e.attrInUDP)
|
||||
}
|
||||
}
|
||||
}
|
||||
clientsMutex.RLock()
|
||||
targetConn, exists := clientConns[clientKey]
|
||||
clientsMutex.RUnlock()
|
||||
@@ -334,28 +622,44 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error resolving target address: %v", err)
|
||||
telemetry.IncProxyAccept(context.Background(), pm.currentTunnelID, "udp", "failure", "resolve")
|
||||
continue
|
||||
}
|
||||
|
||||
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error connecting to target: %v", err)
|
||||
telemetry.IncProxyAccept(context.Background(), pm.currentTunnelID, "udp", "failure", classifyProxyError(err))
|
||||
continue
|
||||
}
|
||||
tunnelID := pm.currentTunnelID
|
||||
telemetry.IncProxyAccept(context.Background(), tunnelID, "udp", "success", "")
|
||||
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionOpened)
|
||||
// Only increment activeUDP after a successful DialUDP
|
||||
if e := pm.getEntry(tunnelID); e != nil {
|
||||
e.activeUDP.Add(1)
|
||||
}
|
||||
|
||||
clientsMutex.Lock()
|
||||
clientConns[clientKey] = targetConn
|
||||
clientsMutex.Unlock()
|
||||
|
||||
go func(clientKey string, targetConn *net.UDPConn, remoteAddr net.Addr) {
|
||||
go func(clientKey string, targetConn *net.UDPConn, remoteAddr net.Addr, tunnelID string) {
|
||||
start := time.Now()
|
||||
result := "success"
|
||||
defer func() {
|
||||
// Always clean up when this goroutine exits
|
||||
clientsMutex.Lock()
|
||||
if storedConn, exists := clientConns[clientKey]; exists && storedConn == targetConn {
|
||||
delete(clientConns, clientKey)
|
||||
targetConn.Close()
|
||||
if e := pm.getEntry(tunnelID); e != nil {
|
||||
e.activeUDP.Add(-1)
|
||||
}
|
||||
}
|
||||
clientsMutex.Unlock()
|
||||
telemetry.ObserveProxyConnectionDuration(context.Background(), tunnelID, "udp", result, time.Since(start).Seconds())
|
||||
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
|
||||
}()
|
||||
|
||||
buffer := make([]byte, 65507)
|
||||
@@ -363,25 +667,52 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
n, _, err := targetConn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
logger.Error("Error reading from target: %v", err)
|
||||
result = "failure"
|
||||
return // defer will handle cleanup
|
||||
}
|
||||
|
||||
// bytes from target -> client (direction=out)
|
||||
if pm.currentTunnelID != "" && n > 0 {
|
||||
if pm.asyncBytes {
|
||||
if e := pm.getEntry(pm.currentTunnelID); e != nil {
|
||||
e.bytesOutUDP.Add(uint64(n))
|
||||
}
|
||||
} else {
|
||||
if e := pm.getEntry(pm.currentTunnelID); e != nil {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(n), e.attrOutUDP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, err = conn.WriteTo(buffer[:n], remoteAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error writing to client: %v", err)
|
||||
telemetry.IncProxyDrops(context.Background(), pm.currentTunnelID, "udp")
|
||||
result = "failure"
|
||||
return // defer will handle cleanup
|
||||
}
|
||||
}
|
||||
}(clientKey, targetConn, remoteAddr)
|
||||
}(clientKey, targetConn, remoteAddr, tunnelID)
|
||||
}
|
||||
|
||||
_, err = targetConn.Write(buffer[:n])
|
||||
written, err := targetConn.Write(buffer[:n])
|
||||
if err != nil {
|
||||
logger.Error("Error writing to target: %v", err)
|
||||
telemetry.IncProxyDrops(context.Background(), pm.currentTunnelID, "udp")
|
||||
targetConn.Close()
|
||||
clientsMutex.Lock()
|
||||
delete(clientConns, clientKey)
|
||||
clientsMutex.Unlock()
|
||||
} else if pm.currentTunnelID != "" && written > 0 {
|
||||
if pm.asyncBytes {
|
||||
if e := pm.getEntry(pm.currentTunnelID); e != nil {
|
||||
e.bytesInUDP.Add(uint64(written))
|
||||
}
|
||||
} else {
|
||||
if e := pm.getEntry(pm.currentTunnelID); e != nil {
|
||||
telemetry.AddTunnelBytesSet(context.Background(), int64(written), e.attrInUDP)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
59
service_unix.go
Normal file
59
service_unix.go
Normal file
@@ -0,0 +1,59 @@
|
||||
//go:build !windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Service management functions are not available on non-Windows platforms
|
||||
func installService() error {
|
||||
return fmt.Errorf("service management is only available on Windows")
|
||||
}
|
||||
|
||||
func removeService() error {
|
||||
return fmt.Errorf("service management is only available on Windows")
|
||||
}
|
||||
|
||||
func startService(args []string) error {
|
||||
_ = args // unused on Unix platforms
|
||||
return fmt.Errorf("service management is only available on Windows")
|
||||
}
|
||||
|
||||
func stopService() error {
|
||||
return fmt.Errorf("service management is only available on Windows")
|
||||
}
|
||||
|
||||
func getServiceStatus() (string, error) {
|
||||
return "", fmt.Errorf("service management is only available on Windows")
|
||||
}
|
||||
|
||||
func debugService(args []string) error {
|
||||
_ = args // unused on Unix platforms
|
||||
return fmt.Errorf("debug service is only available on Windows")
|
||||
}
|
||||
|
||||
func isWindowsService() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func runService(name string, isDebug bool, args []string) {
|
||||
// No-op on non-Windows platforms
|
||||
}
|
||||
|
||||
func setupWindowsEventLog() {
|
||||
// No-op on non-Windows platforms
|
||||
}
|
||||
|
||||
func watchLogFile(end bool) error {
|
||||
return fmt.Errorf("watching log file is only available on Windows")
|
||||
}
|
||||
|
||||
func showServiceConfig() {
|
||||
fmt.Println("Service configuration is only available on Windows")
|
||||
}
|
||||
|
||||
// handleServiceCommand returns false on non-Windows platforms
|
||||
func handleServiceCommand() bool {
|
||||
return false
|
||||
}
|
||||
760
service_windows.go
Normal file
760
service_windows.go
Normal file
@@ -0,0 +1,760 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.org/x/sys/windows/svc"
|
||||
"golang.org/x/sys/windows/svc/debug"
|
||||
"golang.org/x/sys/windows/svc/eventlog"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceName = "NewtWireguardService"
|
||||
serviceDisplayName = "Newt WireGuard Tunnel Service"
|
||||
serviceDescription = "Newt WireGuard tunnel service for secure network connectivity"
|
||||
)
|
||||
|
||||
// Global variable to store service arguments
|
||||
var serviceArgs []string
|
||||
|
||||
// getServiceArgsPath returns the path where service arguments are stored
|
||||
func getServiceArgsPath() string {
|
||||
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "newt")
|
||||
return filepath.Join(logDir, "service_args.json")
|
||||
}
|
||||
|
||||
// saveServiceArgs saves the service arguments to a file
|
||||
func saveServiceArgs(args []string) error {
|
||||
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "newt")
|
||||
err := os.MkdirAll(logDir, 0755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %v", err)
|
||||
}
|
||||
|
||||
argsPath := getServiceArgsPath()
|
||||
data, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal service args: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(argsPath, data, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write service args: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadServiceArgs loads the service arguments from a file
|
||||
func loadServiceArgs() ([]string, error) {
|
||||
argsPath := getServiceArgsPath()
|
||||
data, err := os.ReadFile(argsPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return []string{}, nil // Return empty args if file doesn't exist
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read service args: %v", err)
|
||||
}
|
||||
|
||||
var args []string
|
||||
err = json.Unmarshal(data, &args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal service args: %v", err)
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
type newtService struct {
|
||||
elog debug.Log
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
args []string
|
||||
}
|
||||
|
||||
func (s *newtService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
|
||||
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
|
||||
changes <- svc.Status{State: svc.StartPending}
|
||||
|
||||
s.elog.Info(1, fmt.Sprintf("Service Execute called with args: %v", args))
|
||||
|
||||
// Load saved service arguments
|
||||
savedArgs, err := loadServiceArgs()
|
||||
if err != nil {
|
||||
s.elog.Error(1, fmt.Sprintf("Failed to load service args: %v", err))
|
||||
// Continue with empty args if loading fails
|
||||
savedArgs = []string{}
|
||||
}
|
||||
s.elog.Info(1, fmt.Sprintf("Loaded saved service args: %v", savedArgs))
|
||||
|
||||
// Combine service start args with saved args, giving priority to service start args
|
||||
// Note: When the service is started via SCM, args[0] is the service name
|
||||
// When started via s.Start(args...), the args passed are exactly what we provide
|
||||
finalArgs := []string{}
|
||||
|
||||
// Check if we have args passed directly to Execute (from s.Start())
|
||||
if len(args) > 0 {
|
||||
// The first arg from SCM is the service name, but when we call s.Start(args...),
|
||||
// the args we pass become args[1:] in Execute. However, if started by SCM without
|
||||
// args, args[0] will be the service name.
|
||||
// We need to check if args[0] looks like the service name or a flag
|
||||
if len(args) == 1 && args[0] == serviceName {
|
||||
// Only service name, no actual args
|
||||
s.elog.Info(1, "Only service name in args, checking saved args")
|
||||
} else if len(args) > 1 && args[0] == serviceName {
|
||||
// Service name followed by actual args
|
||||
finalArgs = append(finalArgs, args[1:]...)
|
||||
s.elog.Info(1, fmt.Sprintf("Using service start parameters (after service name): %v", finalArgs))
|
||||
} else {
|
||||
// Args don't start with service name, use them all
|
||||
// This happens when args are passed via s.Start(args...)
|
||||
finalArgs = append(finalArgs, args...)
|
||||
s.elog.Info(1, fmt.Sprintf("Using service start parameters (direct): %v", finalArgs))
|
||||
}
|
||||
}
|
||||
|
||||
// If no service start parameters, use saved args
|
||||
if len(finalArgs) == 0 && len(savedArgs) > 0 {
|
||||
finalArgs = savedArgs
|
||||
s.elog.Info(1, fmt.Sprintf("Using saved service args: %v", finalArgs))
|
||||
}
|
||||
|
||||
s.elog.Info(1, fmt.Sprintf("Final args to use: %v", finalArgs))
|
||||
s.args = finalArgs
|
||||
|
||||
// Start the main newt functionality
|
||||
newtDone := make(chan struct{})
|
||||
go func() {
|
||||
s.runNewt()
|
||||
close(newtDone)
|
||||
}()
|
||||
|
||||
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
|
||||
s.elog.Info(1, "Service status set to Running")
|
||||
|
||||
for {
|
||||
select {
|
||||
case c := <-r:
|
||||
switch c.Cmd {
|
||||
case svc.Interrogate:
|
||||
changes <- c.CurrentStatus
|
||||
case svc.Stop, svc.Shutdown:
|
||||
s.elog.Info(1, "Service stopping")
|
||||
changes <- svc.Status{State: svc.StopPending}
|
||||
if s.stop != nil {
|
||||
s.stop()
|
||||
}
|
||||
// Wait for main logic to finish or timeout
|
||||
select {
|
||||
case <-newtDone:
|
||||
s.elog.Info(1, "Main logic finished gracefully")
|
||||
case <-time.After(10 * time.Second):
|
||||
s.elog.Info(1, "Timeout waiting for main logic to finish")
|
||||
}
|
||||
return false, 0
|
||||
default:
|
||||
s.elog.Error(1, fmt.Sprintf("Unexpected control request #%d", c))
|
||||
}
|
||||
case <-newtDone:
|
||||
s.elog.Info(1, "Main newt logic completed, stopping service")
|
||||
changes <- svc.Status{State: svc.StopPending}
|
||||
return false, 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *newtService) runNewt() {
|
||||
// Create a context that can be cancelled when the service stops
|
||||
s.ctx, s.stop = context.WithCancel(context.Background())
|
||||
|
||||
// Setup logging for service mode
|
||||
s.elog.Info(1, "Starting Newt main logic")
|
||||
|
||||
// Run the main newt logic and wait for it to complete
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.elog.Error(1, fmt.Sprintf("Panic in newt main: %v", r))
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Call the main newt function with stored arguments
|
||||
// Use s.ctx as the signal context since the service manages shutdown
|
||||
runNewtMainWithArgs(s.ctx, s.args)
|
||||
}()
|
||||
|
||||
// Wait for either context cancellation or main logic completion
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
s.elog.Info(1, "Newt service context cancelled")
|
||||
case <-done:
|
||||
s.elog.Info(1, "Newt main logic completed")
|
||||
}
|
||||
}
|
||||
|
||||
func runService(name string, isDebug bool, args []string) {
|
||||
var err error
|
||||
var elog debug.Log
|
||||
|
||||
if isDebug {
|
||||
elog = debug.New(name)
|
||||
fmt.Printf("Starting %s service in debug mode\n", name)
|
||||
} else {
|
||||
elog, err = eventlog.Open(name)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to open event log: %v\n", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
defer elog.Close()
|
||||
|
||||
elog.Info(1, fmt.Sprintf("Starting %s service", name))
|
||||
run := svc.Run
|
||||
if isDebug {
|
||||
run = debug.Run
|
||||
}
|
||||
|
||||
service := &newtService{elog: elog, args: args}
|
||||
err = run(name, service)
|
||||
if err != nil {
|
||||
elog.Error(1, fmt.Sprintf("%s service failed: %v", name, err))
|
||||
if isDebug {
|
||||
fmt.Printf("Service failed: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
elog.Info(1, fmt.Sprintf("%s service stopped", name))
|
||||
if isDebug {
|
||||
fmt.Printf("%s service stopped\n", name)
|
||||
}
|
||||
}
|
||||
|
||||
func installService() error {
|
||||
exepath, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get executable path: %v", err)
|
||||
}
|
||||
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err == nil {
|
||||
s.Close()
|
||||
return fmt.Errorf("service %s already exists", serviceName)
|
||||
}
|
||||
|
||||
config := mgr.Config{
|
||||
ServiceType: 0x10, // SERVICE_WIN32_OWN_PROCESS
|
||||
StartType: mgr.StartManual,
|
||||
ErrorControl: mgr.ErrorNormal,
|
||||
DisplayName: serviceDisplayName,
|
||||
Description: serviceDescription,
|
||||
BinaryPathName: exepath,
|
||||
}
|
||||
|
||||
s, err = m.CreateService(serviceName, exepath, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create service: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
err = eventlog.InstallAsEventCreate(serviceName, eventlog.Error|eventlog.Warning|eventlog.Info)
|
||||
if err != nil {
|
||||
s.Delete()
|
||||
return fmt.Errorf("failed to install event log: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeService() error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service %s is not installed", serviceName)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// Stop the service if it's running
|
||||
status, err := s.Query()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query service status: %v", err)
|
||||
}
|
||||
|
||||
if status.State != svc.Stopped {
|
||||
_, err = s.Control(svc.Stop)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stop service: %v", err)
|
||||
}
|
||||
|
||||
// Wait for service to stop
|
||||
timeout := time.Now().Add(30 * time.Second)
|
||||
for status.State != svc.Stopped {
|
||||
if timeout.Before(time.Now()) {
|
||||
return fmt.Errorf("timeout waiting for service to stop")
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
status, err = s.Query()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query service status: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = s.Delete()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete service: %v", err)
|
||||
}
|
||||
|
||||
err = eventlog.Remove(serviceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove event log: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func startService(args []string) error {
|
||||
fmt.Printf("Starting service with args: %v\n", args)
|
||||
|
||||
// Always save the service arguments so they can be loaded on service restart
|
||||
err := saveServiceArgs(args)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: failed to save service args: %v\n", err)
|
||||
// Continue anyway, args will still be passed directly
|
||||
} else {
|
||||
fmt.Printf("Saved service args to: %s\n", getServiceArgsPath())
|
||||
}
|
||||
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service %s is not installed", serviceName)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// Pass arguments directly to the service start call
|
||||
// Note: These args will appear in Execute() after the service name
|
||||
err = s.Start(args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start service: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stopService() error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service %s is not installed", serviceName)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
status, err := s.Control(svc.Stop)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stop service: %v", err)
|
||||
}
|
||||
|
||||
timeout := time.Now().Add(30 * time.Second)
|
||||
for status.State != svc.Stopped {
|
||||
if timeout.Before(time.Now()) {
|
||||
return fmt.Errorf("timeout waiting for service to stop")
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
status, err = s.Query()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query service status: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func debugService(args []string) error {
|
||||
// Save the service arguments before starting
|
||||
if len(args) > 0 {
|
||||
err := saveServiceArgs(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save service args: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start the service with the provided arguments
|
||||
err := startService(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start service: %v", err)
|
||||
}
|
||||
|
||||
// Watch the log file
|
||||
return watchLogFile(true)
|
||||
}
|
||||
|
||||
func watchLogFile(end bool) error {
|
||||
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "newt", "logs")
|
||||
logPath := filepath.Join(logDir, "newt.log")
|
||||
|
||||
// Ensure the log directory exists
|
||||
err := os.MkdirAll(logDir, 0755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create log directory: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the log file to be created if it doesn't exist
|
||||
var file *os.File
|
||||
for i := 0; i < 30; i++ { // Wait up to 15 seconds
|
||||
file, err = os.Open(logPath)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if i == 0 {
|
||||
fmt.Printf("Waiting for log file to be created...\n")
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open log file after waiting: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Seek to the end of the file to only show new logs
|
||||
_, err = file.Seek(0, 2)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to seek to end of file: %v", err)
|
||||
}
|
||||
|
||||
// Set up signal handling for graceful exit
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
// Create a ticker to check for new content
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
buffer := make([]byte, 4096)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sigCh:
|
||||
fmt.Printf("\n\nStopping log watch...\n")
|
||||
// stop the service if needed
|
||||
if end {
|
||||
fmt.Printf("Stopping service...\n")
|
||||
stopService()
|
||||
}
|
||||
fmt.Printf("Log watch stopped.\n")
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
// Read new content
|
||||
n, err := file.Read(buffer)
|
||||
if err != nil && err != io.EOF {
|
||||
// Try to reopen the file in case it was recreated
|
||||
file.Close()
|
||||
file, err = os.Open(logPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
// Print the new content
|
||||
fmt.Print(string(buffer[:n]))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getServiceStatus() (string, error) {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to connect to service manager: %v", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err != nil {
|
||||
return "Not Installed", nil
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
status, err := s.Query()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to query service status: %v", err)
|
||||
}
|
||||
|
||||
switch status.State {
|
||||
case svc.Stopped:
|
||||
return "Stopped", nil
|
||||
case svc.StartPending:
|
||||
return "Starting", nil
|
||||
case svc.StopPending:
|
||||
return "Stopping", nil
|
||||
case svc.Running:
|
||||
return "Running", nil
|
||||
case svc.ContinuePending:
|
||||
return "Continue Pending", nil
|
||||
case svc.PausePending:
|
||||
return "Pause Pending", nil
|
||||
case svc.Paused:
|
||||
return "Paused", nil
|
||||
default:
|
||||
return "Unknown", nil
|
||||
}
|
||||
}
|
||||
|
||||
// showServiceConfig displays current saved service configuration
|
||||
func showServiceConfig() {
|
||||
configPath := getServiceArgsPath()
|
||||
fmt.Printf("Service configuration file: %s\n", configPath)
|
||||
|
||||
args, err := loadServiceArgs()
|
||||
if err != nil {
|
||||
fmt.Printf("No saved configuration found or error loading: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
fmt.Println("No saved service arguments found")
|
||||
} else {
|
||||
fmt.Printf("Saved service arguments: %v\n", args)
|
||||
}
|
||||
}
|
||||
|
||||
func isWindowsService() bool {
|
||||
isWindowsService, err := svc.IsWindowsService()
|
||||
return err == nil && isWindowsService
|
||||
}
|
||||
|
||||
// rotateLogFile handles daily log rotation
|
||||
func rotateLogFile(logDir string, logFile string) error {
|
||||
// Get current log file info
|
||||
info, err := os.Stat(logFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // No current log file to rotate
|
||||
}
|
||||
return fmt.Errorf("failed to stat log file: %v", err)
|
||||
}
|
||||
|
||||
// Check if log file is from today
|
||||
now := time.Now()
|
||||
fileTime := info.ModTime()
|
||||
|
||||
// If the log file is from today, no rotation needed
|
||||
if now.Year() == fileTime.Year() && now.YearDay() == fileTime.YearDay() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create rotated filename with date
|
||||
rotatedName := fmt.Sprintf("newt-%s.log", fileTime.Format("2006-01-02"))
|
||||
rotatedPath := filepath.Join(logDir, rotatedName)
|
||||
|
||||
// Rename current log file to dated filename
|
||||
err = os.Rename(logFile, rotatedPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to rotate log file: %v", err)
|
||||
}
|
||||
|
||||
// Clean up old log files (keep last 30 days)
|
||||
cleanupOldLogFiles(logDir, 30)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOldLogFiles removes log files older than specified days
|
||||
func cleanupOldLogFiles(logDir string, daysToKeep int) {
|
||||
cutoff := time.Now().AddDate(0, 0, -daysToKeep)
|
||||
|
||||
files, err := os.ReadDir(logDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if !file.IsDir() && strings.HasPrefix(file.Name(), "newt-") && strings.HasSuffix(file.Name(), ".log") {
|
||||
filePath := filepath.Join(logDir, file.Name())
|
||||
info, err := file.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if info.ModTime().Before(cutoff) {
|
||||
os.Remove(filePath)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupWindowsEventLog() {
|
||||
// Create log directory if it doesn't exist
|
||||
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "newt", "logs")
|
||||
err := os.MkdirAll(logDir, 0755)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to create log directory: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
logFile := filepath.Join(logDir, "newt.log")
|
||||
|
||||
// Rotate log file if needed
|
||||
err = rotateLogFile(logDir, logFile)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to rotate log file: %v\n", err)
|
||||
// Continue anyway to create new log file
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to open log file: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Set the custom logger output
|
||||
logger.GetLogger().SetOutput(file)
|
||||
|
||||
log.Printf("Newt service logging initialized - log file: %s", logFile)
|
||||
}
|
||||
|
||||
// handleServiceCommand checks for service management commands and returns true if handled
|
||||
func handleServiceCommand() bool {
|
||||
if len(os.Args) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
command := os.Args[1]
|
||||
|
||||
switch command {
|
||||
case "install":
|
||||
err := installService()
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to install service: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("Service installed successfully")
|
||||
return true
|
||||
case "remove", "uninstall":
|
||||
err := removeService()
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to remove service: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("Service removed successfully")
|
||||
return true
|
||||
case "start":
|
||||
// Pass the remaining arguments (after "start") to the service
|
||||
serviceArgs := os.Args[2:]
|
||||
err := startService(serviceArgs)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to start service: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("Service started successfully")
|
||||
return true
|
||||
case "stop":
|
||||
err := stopService()
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to stop service: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("Service stopped successfully")
|
||||
return true
|
||||
case "status":
|
||||
status, err := getServiceStatus()
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to get service status: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("Service status: %s\n", status)
|
||||
return true
|
||||
case "debug":
|
||||
// get the status and if it is Not Installed then install it first
|
||||
status, err := getServiceStatus()
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to get service status: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if status == "Not Installed" {
|
||||
err := installService()
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to install service: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("Service installed successfully, now running in debug mode")
|
||||
}
|
||||
|
||||
// Pass the remaining arguments (after "debug") to the service
|
||||
serviceArgs := os.Args[2:]
|
||||
err = debugService(serviceArgs)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to debug service: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return true
|
||||
case "logs":
|
||||
err := watchLogFile(false)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to watch log file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return true
|
||||
case "config":
|
||||
showServiceConfig()
|
||||
return true
|
||||
case "service-help":
|
||||
fmt.Println("Newt WireGuard Tunnel")
|
||||
fmt.Println("\nWindows Service Management:")
|
||||
fmt.Println(" install Install the service")
|
||||
fmt.Println(" remove Remove the service")
|
||||
fmt.Println(" start [args] Start the service with optional arguments")
|
||||
fmt.Println(" stop Stop the service")
|
||||
fmt.Println(" status Show service status")
|
||||
fmt.Println(" debug [args] Run service in debug mode with optional arguments")
|
||||
fmt.Println(" logs Tail the service log file")
|
||||
fmt.Println(" config Show current service configuration")
|
||||
fmt.Println(" service-help Show this service help")
|
||||
fmt.Println("\nExamples:")
|
||||
fmt.Println(" newt start --endpoint https://example.com --id myid --secret mysecret")
|
||||
fmt.Println(" newt debug --endpoint https://example.com --id myid --secret mysecret")
|
||||
fmt.Println("\nFor normal console mode, run with standard flags (e.g., newt --endpoint ...)")
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
17
stub.go
17
stub.go
@@ -8,25 +8,32 @@ import (
|
||||
)
|
||||
|
||||
func setupClientsNative(client *websocket.Client, host string) {
|
||||
return // This function is not implemented for non-Linux systems.
|
||||
_ = client
|
||||
_ = host
|
||||
// No-op for non-Linux systems
|
||||
}
|
||||
|
||||
func closeWgServiceNative() {
|
||||
// No-op for non-Linux systems
|
||||
return
|
||||
}
|
||||
|
||||
func clientsOnConnectNative() {
|
||||
// No-op for non-Linux systems
|
||||
return
|
||||
}
|
||||
|
||||
func clientsHandleNewtConnectionNative(publicKey, endpoint string) {
|
||||
_ = publicKey
|
||||
_ = endpoint
|
||||
// No-op for non-Linux systems
|
||||
return
|
||||
}
|
||||
|
||||
func clientsAddProxyTargetNative(pm *proxy.ProxyManager, tunnelIp string) {
|
||||
_ = pm
|
||||
_ = tunnelIp
|
||||
// No-op for non-Linux systems
|
||||
}
|
||||
|
||||
func clientsStartDirectRelayNative(tunnelIP string) {
|
||||
_ = tunnelIP
|
||||
// No-op for non-Linux systems
|
||||
return
|
||||
}
|
||||
|
||||
49
udp_client.py
Normal file
49
udp_client.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import socket
|
||||
import sys
|
||||
|
||||
# Argument parsing: Check if IP and Port are provided
|
||||
if len(sys.argv) != 3:
|
||||
print("Usage: python udp_client.py <HOST_IP> <HOST_PORT>")
|
||||
# Example: python udp_client.py 127.0.0.1 12000
|
||||
sys.exit(1)
|
||||
|
||||
HOST = sys.argv[1]
|
||||
try:
|
||||
PORT = int(sys.argv[2])
|
||||
except ValueError:
|
||||
print("Error: HOST_PORT must be an integer.")
|
||||
sys.exit(1)
|
||||
|
||||
# The message to send to the server
|
||||
MESSAGE = "Hello UDP Server! How are you?"
|
||||
|
||||
# Create a UDP socket
|
||||
try:
|
||||
client_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
except socket.error as err:
|
||||
print(f"Failed to create socket: {err}")
|
||||
sys.exit()
|
||||
|
||||
try:
|
||||
print(f"Sending message to {HOST}:{PORT}...")
|
||||
|
||||
# Send the message (data must be encoded to bytes)
|
||||
client_socket.sendto(MESSAGE.encode('utf-8'), (HOST, PORT))
|
||||
|
||||
# Wait for the server's response (buffer size 1024 bytes)
|
||||
data, server_address = client_socket.recvfrom(1024)
|
||||
|
||||
# Decode and print the server's response
|
||||
response = data.decode('utf-8')
|
||||
print("-" * 30)
|
||||
print(f"Received response from server {server_address[0]}:{server_address[1]}:")
|
||||
print(f"-> Data: '{response}'")
|
||||
|
||||
except socket.error as err:
|
||||
print(f"Error during communication: {err}")
|
||||
|
||||
finally:
|
||||
# Close the socket
|
||||
client_socket.close()
|
||||
print("-" * 30)
|
||||
print("Client finished and socket closed.")
|
||||
58
udp_server.py
Normal file
58
udp_server.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import socket
|
||||
import sys
|
||||
|
||||
# optionally take in some positional args for the port
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
PORT = int(sys.argv[1])
|
||||
except ValueError:
|
||||
print("Invalid port number. Using default port 12000.")
|
||||
PORT = 12000
|
||||
else:
|
||||
PORT = 12000
|
||||
|
||||
# Define the server host and port
|
||||
HOST = '0.0.0.0' # Standard loopback interface address (localhost)
|
||||
|
||||
# Create a UDP socket
|
||||
try:
|
||||
server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
except socket.error as err:
|
||||
print(f"Failed to create socket: {err}")
|
||||
sys.exit()
|
||||
|
||||
# Bind the socket to the address
|
||||
try:
|
||||
server_socket.bind((HOST, PORT))
|
||||
print(f"UDP Server listening on {HOST}:{PORT}")
|
||||
except socket.error as err:
|
||||
print(f"Bind failed: {err}")
|
||||
server_socket.close()
|
||||
sys.exit()
|
||||
|
||||
# Wait for and process incoming data
|
||||
while True:
|
||||
try:
|
||||
# Receive data and the client's address (buffer size 1024 bytes)
|
||||
data, client_address = server_socket.recvfrom(1024)
|
||||
|
||||
# Decode the data and print the message
|
||||
message = data.decode('utf-8')
|
||||
print("-" * 30)
|
||||
print(f"Received message from {client_address[0]}:{client_address[1]}:")
|
||||
print(f"-> Data: '{message}'")
|
||||
|
||||
# Prepare the response message
|
||||
response_message = f"Hello client! Server received: '{message.upper()}'"
|
||||
|
||||
# Send the response back to the client
|
||||
server_socket.sendto(response_message.encode('utf-8'), client_address)
|
||||
print(f"Sent response back to client.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
break
|
||||
|
||||
# Clean up (though usually unreachable in an infinite server loop)
|
||||
server_socket.close()
|
||||
print("Server stopped.")
|
||||
@@ -119,7 +119,7 @@ func CheckForUpdate(owner, repo, currentVersion string) error {
|
||||
|
||||
// Check if update is available
|
||||
if currentVer.isNewer(latestVer) {
|
||||
printUpdateBanner(currentVer.String(), latestVer.String(), release.HTMLURL)
|
||||
printUpdateBanner(currentVer.String(), latestVer.String(), "curl -fsSL https://static.pangolin.net/get-newt.sh | bash")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -145,7 +145,7 @@ func printUpdateBanner(currentVersion, latestVersion, releaseURL string) {
|
||||
"║ A newer version is available! Please update to get the" + padRight("", contentWidth-56) + "║",
|
||||
"║ latest features, bug fixes, and security improvements." + padRight("", contentWidth-56) + "║",
|
||||
emptyLine,
|
||||
"║ Release URL: " + padRight(releaseURL, contentWidth-15) + "║",
|
||||
"║ Update: " + padRight(releaseURL, contentWidth-10) + "║",
|
||||
emptyLine,
|
||||
borderBot,
|
||||
}
|
||||
|
||||
226
util/util.go
Normal file
226
util/util.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
mathrand "math/rand/v2"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
func ResolveDomain(domain string) (string, error) {
|
||||
// trim whitespace
|
||||
domain = strings.TrimSpace(domain)
|
||||
|
||||
// Remove any protocol prefix if present (do this first, before splitting host/port)
|
||||
domain = strings.TrimPrefix(domain, "http://")
|
||||
domain = strings.TrimPrefix(domain, "https://")
|
||||
|
||||
// if there are any trailing slashes, remove them
|
||||
domain = strings.TrimSuffix(domain, "/")
|
||||
|
||||
// Check if there's a port in the domain
|
||||
host, port, err := net.SplitHostPort(domain)
|
||||
if err != nil {
|
||||
// No port found, use the domain as is
|
||||
host = domain
|
||||
port = ""
|
||||
}
|
||||
|
||||
// Check if host is already an IP address (IPv4 or IPv6)
|
||||
// For IPv6, the host from SplitHostPort will already have brackets stripped
|
||||
// but if there was no port, we need to handle bracketed IPv6 addresses
|
||||
cleanHost := strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
|
||||
if ip := net.ParseIP(cleanHost); ip != nil {
|
||||
// It's already an IP address, no need to resolve
|
||||
ipAddr := ip.String()
|
||||
if port != "" {
|
||||
return net.JoinHostPort(ipAddr, port), nil
|
||||
}
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
// Lookup IP addresses
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||
}
|
||||
|
||||
// Get the first IPv4 address if available
|
||||
var ipAddr string
|
||||
for _, ip := range ips {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
ipAddr = ipv4.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no IPv4 found, use the first IP (might be IPv6)
|
||||
if ipAddr == "" {
|
||||
ipAddr = ips[0].String()
|
||||
}
|
||||
|
||||
// Add port back if it existed
|
||||
if port != "" {
|
||||
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||
}
|
||||
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
func ParseLogLevel(level string) logger.LogLevel {
|
||||
switch strings.ToUpper(level) {
|
||||
case "DEBUG":
|
||||
return logger.DEBUG
|
||||
case "INFO":
|
||||
return logger.INFO
|
||||
case "WARN":
|
||||
return logger.WARN
|
||||
case "ERROR":
|
||||
return logger.ERROR
|
||||
case "FATAL":
|
||||
return logger.FATAL
|
||||
default:
|
||||
return logger.INFO // default to INFO if invalid level provided
|
||||
}
|
||||
}
|
||||
|
||||
// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester
|
||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
if maxPort < minPort {
|
||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
// We need to check port+1 as well, so adjust the max port to avoid going out of range
|
||||
adjustedMaxPort := maxPort - 1
|
||||
if adjustedMaxPort < minPort {
|
||||
return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
// Create a slice of all ports in the range (excluding the last one)
|
||||
portRange := make([]uint16, adjustedMaxPort-minPort+1)
|
||||
for i := range portRange {
|
||||
portRange[i] = minPort + uint16(i)
|
||||
}
|
||||
|
||||
// Fisher-Yates shuffle to randomize the port order
|
||||
for i := len(portRange) - 1; i > 0; i-- {
|
||||
j := mathrand.IntN(i + 1)
|
||||
portRange[i], portRange[j] = portRange[j], portRange[i]
|
||||
}
|
||||
|
||||
// Try each port in the randomized order
|
||||
for _, port := range portRange {
|
||||
// Check if port is available
|
||||
addr1 := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: int(port),
|
||||
}
|
||||
conn1, err1 := net.ListenUDP("udp", addr1)
|
||||
if err1 != nil {
|
||||
continue // Port is in use or there was an error, try next port
|
||||
}
|
||||
|
||||
conn1.Close()
|
||||
return port, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
func FixKey(key string) string {
|
||||
// Remove any whitespace
|
||||
key = strings.TrimSpace(key)
|
||||
|
||||
// Decode from base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
logger.Fatal("Error decoding base64: %v", err)
|
||||
}
|
||||
|
||||
// Convert to hex
|
||||
return hex.EncodeToString(decoded)
|
||||
}
|
||||
|
||||
// this is the opposite of FixKey
|
||||
func UnfixKey(hexKey string) string {
|
||||
// Decode from hex
|
||||
decoded, err := hex.DecodeString(hexKey)
|
||||
if err != nil {
|
||||
logger.Fatal("Error decoding hex: %v", err)
|
||||
}
|
||||
|
||||
// Convert to base64
|
||||
return base64.StdEncoding.EncodeToString(decoded)
|
||||
}
|
||||
|
||||
func MapToWireGuardLogLevel(level logger.LogLevel) int {
|
||||
switch level {
|
||||
case logger.DEBUG:
|
||||
return device.LogLevelVerbose
|
||||
// case logger.INFO:
|
||||
// return device.LogLevel
|
||||
case logger.WARN:
|
||||
return device.LogLevelError
|
||||
case logger.ERROR, logger.FATAL:
|
||||
return device.LogLevelSilent
|
||||
default:
|
||||
return device.LogLevelSilent
|
||||
}
|
||||
}
|
||||
|
||||
// GetProtocol returns protocol number from IPv4 packet (fast path)
|
||||
func GetProtocol(packet []byte) (uint8, bool) {
|
||||
if len(packet) < 20 {
|
||||
return 0, false
|
||||
}
|
||||
version := packet[0] >> 4
|
||||
if version == 4 {
|
||||
return packet[9], true
|
||||
} else if version == 6 {
|
||||
if len(packet) < 40 {
|
||||
return 0, false
|
||||
}
|
||||
return packet[6], true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetDestPort returns destination port from TCP/UDP packet (fast path)
|
||||
func GetDestPort(packet []byte) (uint16, bool) {
|
||||
if len(packet) < 20 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
version := packet[0] >> 4
|
||||
var headerLen int
|
||||
|
||||
if version == 4 {
|
||||
ihl := packet[0] & 0x0F
|
||||
headerLen = int(ihl) * 4
|
||||
if len(packet) < headerLen+4 {
|
||||
return 0, false
|
||||
}
|
||||
} else if version == 6 {
|
||||
headerLen = 40
|
||||
if len(packet) < headerLen+4 {
|
||||
return 0, false
|
||||
}
|
||||
} else {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Destination port is at bytes 2-3 of TCP/UDP header
|
||||
port := binary.BigEndian.Uint16(packet[headerLen+2 : headerLen+4])
|
||||
return port, true
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -18,6 +19,11 @@ import (
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"context"
|
||||
|
||||
"github.com/fosrl/newt/internal/telemetry"
|
||||
"go.opentelemetry.io/otel"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
@@ -37,7 +43,10 @@ type Client struct {
|
||||
writeMux sync.Mutex
|
||||
clientType string // Type of client (e.g., "newt", "olm")
|
||||
tlsConfig TLSConfig
|
||||
metricsCtxMu sync.RWMutex
|
||||
metricsCtx context.Context
|
||||
configNeedsSave bool // Flag to track if config needs to be saved
|
||||
serverVersion string
|
||||
}
|
||||
|
||||
type ClientOption func(*Client)
|
||||
@@ -81,6 +90,26 @@ func (c *Client) OnTokenUpdate(callback func(token string)) {
|
||||
c.onTokenUpdate = callback
|
||||
}
|
||||
|
||||
func (c *Client) metricsContext() context.Context {
|
||||
c.metricsCtxMu.RLock()
|
||||
defer c.metricsCtxMu.RUnlock()
|
||||
if c.metricsCtx != nil {
|
||||
return c.metricsCtx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
func (c *Client) setMetricsContext(ctx context.Context) {
|
||||
c.metricsCtxMu.Lock()
|
||||
c.metricsCtx = ctx
|
||||
c.metricsCtxMu.Unlock()
|
||||
}
|
||||
|
||||
// MetricsContext exposes the context used for telemetry emission when a connection is active.
|
||||
func (c *Client) MetricsContext() context.Context {
|
||||
return c.metricsContext()
|
||||
}
|
||||
|
||||
// NewClient creates a new websocket client
|
||||
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
|
||||
config := &Config{
|
||||
@@ -121,6 +150,10 @@ func (c *Client) GetConfig() *Config {
|
||||
return c.config
|
||||
}
|
||||
|
||||
func (c *Client) GetServerVersion() string {
|
||||
return c.serverVersion
|
||||
}
|
||||
|
||||
// Connect establishes the WebSocket connection
|
||||
func (c *Client) Connect() error {
|
||||
go c.connectWithRetry()
|
||||
@@ -140,6 +173,7 @@ func (c *Client) Close() error {
|
||||
|
||||
// Set connection status to false
|
||||
c.setConnected(false)
|
||||
telemetry.SetWSConnectionState(false)
|
||||
|
||||
// Close the WebSocket connection gracefully
|
||||
if c.conn != nil {
|
||||
@@ -170,7 +204,31 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||
|
||||
c.writeMux.Lock()
|
||||
defer c.writeMux.Unlock()
|
||||
return c.conn.WriteJSON(msg)
|
||||
if err := c.conn.WriteJSON(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
telemetry.IncWSMessage(c.metricsContext(), "out", "text")
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendMessage sends a message through the WebSocket connection
|
||||
func (c *Client) SendMessageNoLog(messageType string, data interface{}) error {
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
msg := WSMessage{
|
||||
Type: messageType,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
c.writeMux.Lock()
|
||||
defer c.writeMux.Unlock()
|
||||
if err := c.conn.WriteJSON(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
telemetry.IncWSMessage(c.metricsContext(), "out", "text")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
|
||||
@@ -265,8 +323,12 @@ func (c *Client) getToken() (string, error) {
|
||||
return "", fmt.Errorf("failed to marshal token request data: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Create a new request
|
||||
req, err := http.NewRequest(
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
"POST",
|
||||
baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token",
|
||||
bytes.NewBuffer(jsonData),
|
||||
@@ -288,18 +350,32 @@ func (c *Client) getToken() (string, error) {
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
telemetry.IncConnAttempt(ctx, "auth", "failure")
|
||||
telemetry.IncConnError(ctx, "auth", classifyConnError(err))
|
||||
return "", fmt.Errorf("failed to request new token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
logger.Debug("Token response body: %s", string(body))
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
logger.Error("Failed to get token with status code: %d", resp.StatusCode)
|
||||
telemetry.IncConnAttempt(ctx, "auth", "failure")
|
||||
etype := "io_error"
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
etype = "auth_failed"
|
||||
}
|
||||
telemetry.IncConnError(ctx, "auth", etype)
|
||||
// Reconnect reason mapping for auth failures
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonAuthError)
|
||||
}
|
||||
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
logger.Error("Failed to decode token response.")
|
||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
@@ -312,11 +388,61 @@ func (c *Client) getToken() (string, error) {
|
||||
return "", fmt.Errorf("received empty token from server")
|
||||
}
|
||||
|
||||
// print server version
|
||||
logger.Info("Server version: %s", tokenResp.Data.ServerVersion)
|
||||
|
||||
c.serverVersion = tokenResp.Data.ServerVersion
|
||||
|
||||
logger.Debug("Received token: %s", tokenResp.Data.Token)
|
||||
telemetry.IncConnAttempt(ctx, "auth", "success")
|
||||
|
||||
return tokenResp.Data.Token, nil
|
||||
}
|
||||
|
||||
// classifyConnError maps to fixed, low-cardinality error_type values.
|
||||
// Allowed enum: dial_timeout, tls_handshake, auth_failed, io_error
|
||||
func classifyConnError(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
switch {
|
||||
case strings.Contains(msg, "tls") || strings.Contains(msg, "certificate"):
|
||||
return "tls_handshake"
|
||||
case strings.Contains(msg, "timeout") || strings.Contains(msg, "i/o timeout") || strings.Contains(msg, "deadline exceeded"):
|
||||
return "dial_timeout"
|
||||
case strings.Contains(msg, "unauthorized") || strings.Contains(msg, "forbidden"):
|
||||
return "auth_failed"
|
||||
default:
|
||||
// Group remaining network/socket errors as io_error to avoid label explosion
|
||||
return "io_error"
|
||||
}
|
||||
}
|
||||
|
||||
func classifyWSDisconnect(err error) (result, reason string) {
|
||||
if err == nil {
|
||||
return "success", "normal"
|
||||
}
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
||||
return "success", "normal"
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
return "error", "timeout"
|
||||
}
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
return "error", "unexpected_close"
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
switch {
|
||||
case strings.Contains(msg, "eof"):
|
||||
return "error", "eof"
|
||||
case strings.Contains(msg, "reset"):
|
||||
return "error", "connection_reset"
|
||||
default:
|
||||
return "error", "read_error"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) connectWithRetry() {
|
||||
for {
|
||||
select {
|
||||
@@ -335,9 +461,13 @@ func (c *Client) connectWithRetry() {
|
||||
}
|
||||
|
||||
func (c *Client) establishConnection() error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Get token for authentication
|
||||
token, err := c.getToken()
|
||||
if err != nil {
|
||||
telemetry.IncConnAttempt(ctx, "websocket", "failure")
|
||||
telemetry.IncConnError(ctx, "websocket", classifyConnError(err))
|
||||
return fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
|
||||
@@ -370,7 +500,12 @@ func (c *Client) establishConnection() error {
|
||||
q.Set("clientType", c.clientType)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
// Connect to WebSocket
|
||||
// Connect to WebSocket (optional span)
|
||||
tr := otel.Tracer("newt")
|
||||
ctx, span := tr.Start(ctx, "ws.connect")
|
||||
defer span.End()
|
||||
|
||||
start := time.Now()
|
||||
dialer := websocket.DefaultDialer
|
||||
|
||||
// Use new TLS configuration method
|
||||
@@ -392,18 +527,42 @@ func (c *Client) establishConnection() error {
|
||||
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
}
|
||||
|
||||
conn, _, err := dialer.Dial(u.String(), nil)
|
||||
conn, _, err := dialer.DialContext(ctx, u.String(), nil)
|
||||
lat := time.Since(start).Seconds()
|
||||
if err != nil {
|
||||
telemetry.IncConnAttempt(ctx, "websocket", "failure")
|
||||
etype := classifyConnError(err)
|
||||
telemetry.IncConnError(ctx, "websocket", etype)
|
||||
telemetry.ObserveWSConnectLatency(ctx, lat, "failure", etype)
|
||||
// Map handshake-related errors to reconnect reasons where appropriate
|
||||
if etype == "tls_handshake" {
|
||||
telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonHandshakeError)
|
||||
} else if etype == "dial_timeout" {
|
||||
telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonTimeout)
|
||||
} else {
|
||||
telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonError)
|
||||
}
|
||||
telemetry.IncWSReconnect(ctx, etype)
|
||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||
}
|
||||
|
||||
telemetry.IncConnAttempt(ctx, "websocket", "success")
|
||||
telemetry.ObserveWSConnectLatency(ctx, lat, "success", "")
|
||||
c.conn = conn
|
||||
c.setConnected(true)
|
||||
telemetry.SetWSConnectionState(true)
|
||||
c.setMetricsContext(ctx)
|
||||
sessionStart := time.Now()
|
||||
// Wire up pong handler for metrics
|
||||
c.conn.SetPongHandler(func(appData string) error {
|
||||
telemetry.IncWSMessage(c.metricsContext(), "in", "pong")
|
||||
return nil
|
||||
})
|
||||
|
||||
// Start the ping monitor
|
||||
go c.pingMonitor()
|
||||
// Start the read pump with disconnect detection
|
||||
go c.readPumpWithDisconnectDetection()
|
||||
go c.readPumpWithDisconnectDetection(sessionStart)
|
||||
|
||||
if c.onConnect != nil {
|
||||
err := c.saveConfig()
|
||||
@@ -496,6 +655,9 @@ func (c *Client) pingMonitor() {
|
||||
}
|
||||
c.writeMux.Lock()
|
||||
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
|
||||
if err == nil {
|
||||
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
|
||||
}
|
||||
c.writeMux.Unlock()
|
||||
if err != nil {
|
||||
// Check if we're shutting down before logging error and reconnecting
|
||||
@@ -505,6 +667,8 @@ func (c *Client) pingMonitor() {
|
||||
return
|
||||
default:
|
||||
logger.Error("Ping failed: %v", err)
|
||||
telemetry.IncWSKeepaliveFailure(c.metricsContext(), "ping_write")
|
||||
telemetry.IncWSReconnect(c.metricsContext(), "ping_write")
|
||||
c.reconnect()
|
||||
return
|
||||
}
|
||||
@@ -514,17 +678,26 @@ func (c *Client) pingMonitor() {
|
||||
}
|
||||
|
||||
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
|
||||
func (c *Client) readPumpWithDisconnectDetection() {
|
||||
func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
|
||||
ctx := c.metricsContext()
|
||||
disconnectReason := "shutdown"
|
||||
disconnectResult := "success"
|
||||
|
||||
defer func() {
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
if !started.IsZero() {
|
||||
telemetry.ObserveWSSessionDuration(ctx, time.Since(started).Seconds(), disconnectResult)
|
||||
}
|
||||
telemetry.IncWSDisconnect(ctx, disconnectReason, disconnectResult)
|
||||
// Only attempt reconnect if we're not shutting down
|
||||
select {
|
||||
case <-c.done:
|
||||
// Shutting down, don't reconnect
|
||||
return
|
||||
default:
|
||||
telemetry.IncWSReconnect(ctx, disconnectReason)
|
||||
c.reconnect()
|
||||
}
|
||||
}()
|
||||
@@ -532,23 +705,33 @@ func (c *Client) readPumpWithDisconnectDetection() {
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
disconnectReason = "shutdown"
|
||||
disconnectResult = "success"
|
||||
return
|
||||
default:
|
||||
var msg WSMessage
|
||||
err := c.conn.ReadJSON(&msg)
|
||||
if err == nil {
|
||||
telemetry.IncWSMessage(c.metricsContext(), "in", "text")
|
||||
}
|
||||
if err != nil {
|
||||
// Check if we're shutting down before logging error
|
||||
select {
|
||||
case <-c.done:
|
||||
// Expected during shutdown, don't log as error
|
||||
logger.Debug("WebSocket connection closed during shutdown")
|
||||
disconnectReason = "shutdown"
|
||||
disconnectResult = "success"
|
||||
return
|
||||
default:
|
||||
// Unexpected error during normal operation
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
|
||||
logger.Error("WebSocket read error: %v", err)
|
||||
} else {
|
||||
logger.Debug("WebSocket connection closed: %v", err)
|
||||
disconnectResult, disconnectReason = classifyWSDisconnect(err)
|
||||
if disconnectResult == "error" {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
|
||||
logger.Error("WebSocket read error: %v", err)
|
||||
} else {
|
||||
logger.Debug("WebSocket connection closed: %v", err)
|
||||
}
|
||||
}
|
||||
return // triggers reconnect via defer
|
||||
}
|
||||
@@ -565,6 +748,7 @@ func (c *Client) readPumpWithDisconnectDetection() {
|
||||
|
||||
func (c *Client) reconnect() {
|
||||
c.setConnected(false)
|
||||
telemetry.SetWSConnectionState(false)
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
|
||||
@@ -9,7 +9,8 @@ type Config struct {
|
||||
|
||||
type TokenResponse struct {
|
||||
Data struct {
|
||||
Token string `json:"token"`
|
||||
Token string `json:"token"`
|
||||
ServerVersion string `json:"serverVersion"`
|
||||
} `json:"data"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
|
||||
999
wg/wg.go
999
wg/wg.go
@@ -1,999 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package wg
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/exp/rand"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Peers []Peer `json:"peers"`
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
AllowedIPs []string `json:"allowedIps"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
type PeerBandwidth struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
BytesIn float64 `json:"bytesIn"`
|
||||
BytesOut float64 `json:"bytesOut"`
|
||||
}
|
||||
|
||||
type PeerReading struct {
|
||||
BytesReceived int64
|
||||
BytesTransmitted int64
|
||||
LastChecked time.Time
|
||||
}
|
||||
|
||||
type WireGuardService struct {
|
||||
interfaceName string
|
||||
mtu int
|
||||
client *websocket.Client
|
||||
wgClient *wgctrl.Client
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
keyFilePath string
|
||||
newtId string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
Port uint16
|
||||
stopHolepunch chan struct{}
|
||||
host string
|
||||
serverPubKey string
|
||||
holePunchEndpoint string
|
||||
token string
|
||||
stopGetConfig func()
|
||||
interfaceCreated bool
|
||||
}
|
||||
|
||||
// Add this type definition
|
||||
type fixedPortBind struct {
|
||||
port uint16
|
||||
conn.Bind
|
||||
}
|
||||
|
||||
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
// Ignore the requested port and use our fixed port
|
||||
return b.Bind.Open(b.port)
|
||||
}
|
||||
|
||||
func NewFixedPortBind(port uint16) conn.Bind {
|
||||
return &fixedPortBind{
|
||||
port: port,
|
||||
Bind: conn.NewDefaultBind(),
|
||||
}
|
||||
}
|
||||
|
||||
// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester
|
||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
if maxPort < minPort {
|
||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
// We need to check port+1 as well, so adjust the max port to avoid going out of range
|
||||
adjustedMaxPort := maxPort - 1
|
||||
if adjustedMaxPort < minPort {
|
||||
return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
// Create a slice of all ports in the range (excluding the last one)
|
||||
portRange := make([]uint16, adjustedMaxPort-minPort+1)
|
||||
for i := range portRange {
|
||||
portRange[i] = minPort + uint16(i)
|
||||
}
|
||||
|
||||
// Fisher-Yates shuffle to randomize the port order
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
for i := len(portRange) - 1; i > 0; i-- {
|
||||
j := rand.Intn(i + 1)
|
||||
portRange[i], portRange[j] = portRange[j], portRange[i]
|
||||
}
|
||||
|
||||
// Try each port in the randomized order
|
||||
for _, port := range portRange {
|
||||
// Check if port is available
|
||||
addr1 := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: int(port),
|
||||
}
|
||||
conn1, err1 := net.ListenUDP("udp", addr1)
|
||||
if err1 != nil {
|
||||
continue // Port is in use or there was an error, try next port
|
||||
}
|
||||
|
||||
// Check if port+1 is also available
|
||||
addr2 := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: int(port + 1),
|
||||
}
|
||||
conn2, err2 := net.ListenUDP("udp", addr2)
|
||||
if err2 != nil {
|
||||
// The next port is not available, so close the first connection and try again
|
||||
conn1.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
// Both ports are available, close connections and return the first port
|
||||
conn1.Close()
|
||||
conn2.Close()
|
||||
return port, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) {
|
||||
wgClient, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
|
||||
}
|
||||
|
||||
var key wgtypes.Key
|
||||
var port uint16
|
||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
||||
key, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
// Load or generate private key
|
||||
if generateAndSaveKeyTo != "" {
|
||||
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
|
||||
keyData, err := os.ReadFile(generateAndSaveKeyTo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read private key: %v", err)
|
||||
}
|
||||
key, err = wgtypes.ParseKey(strings.TrimSpace(string(keyData)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %v", err)
|
||||
}
|
||||
} else {
|
||||
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save private key: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get the existing wireguard port
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err == nil {
|
||||
port = uint16(device.ListenPort)
|
||||
// also set the private key to the existing key
|
||||
key = device.PrivateKey
|
||||
if port != 0 {
|
||||
logger.Info("WireGuard interface %s already exists with port %d\n", interfaceName, port)
|
||||
} else {
|
||||
port, err = FindAvailableUDPPort(49152, 65535)
|
||||
if err != nil {
|
||||
fmt.Printf("Error finding available port: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
port, err = FindAvailableUDPPort(49152, 65535)
|
||||
if err != nil {
|
||||
fmt.Printf("Error finding available port: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
service := &WireGuardService{
|
||||
interfaceName: interfaceName,
|
||||
mtu: mtu,
|
||||
client: wsClient,
|
||||
wgClient: wgClient,
|
||||
key: key,
|
||||
Port: port,
|
||||
keyFilePath: generateAndSaveKeyTo,
|
||||
newtId: newtId,
|
||||
host: host,
|
||||
lastReadings: make(map[string]PeerReading),
|
||||
stopHolepunch: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Register websocket handlers
|
||||
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
|
||||
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
||||
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
|
||||
wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) Close(rm bool) {
|
||||
if s.stopGetConfig != nil {
|
||||
s.stopGetConfig()
|
||||
s.stopGetConfig = nil
|
||||
}
|
||||
|
||||
s.wgClient.Close()
|
||||
// Remove the WireGuard interface
|
||||
if rm {
|
||||
if err := s.removeInterface(); err != nil {
|
||||
logger.Error("Failed to remove WireGuard interface: %v", err)
|
||||
}
|
||||
|
||||
// Remove the private key file
|
||||
// if s.keyFilePath != "" {
|
||||
// if err := os.Remove(s.keyFilePath); err != nil {
|
||||
// logger.Error("Failed to remove private key file: %v", err)
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) {
|
||||
// if the device is already created dont start a new holepunch
|
||||
if s.interfaceCreated {
|
||||
return
|
||||
}
|
||||
|
||||
s.serverPubKey = serverPubKey
|
||||
s.holePunchEndpoint = endpoint
|
||||
|
||||
logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint)
|
||||
|
||||
s.stopHolepunch = make(chan struct{})
|
||||
|
||||
// start the UDP holepunch
|
||||
go s.keepSendingUDPHolePunch(s.holePunchEndpoint)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) SetToken(token string) {
|
||||
s.token = token
|
||||
}
|
||||
|
||||
func (s *WireGuardService) LoadRemoteConfig() error {
|
||||
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
||||
"publicKey": s.key.PublicKey().String(),
|
||||
"port": s.Port,
|
||||
}, 2*time.Second)
|
||||
|
||||
logger.Info("Requesting WireGuard configuration from remote server")
|
||||
go s.periodicBandwidthCheck()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
var config WgConfig
|
||||
|
||||
logger.Debug("Received message: %v", msg)
|
||||
logger.Info("Received WireGuard clients configuration from remote server")
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &config); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
s.config = config
|
||||
|
||||
if s.stopGetConfig != nil {
|
||||
s.stopGetConfig()
|
||||
s.stopGetConfig = nil
|
||||
}
|
||||
|
||||
// Ensure the WireGuard interface and peers are configured
|
||||
if err := s.ensureWireguardInterface(config); err != nil {
|
||||
logger.Error("Failed to ensure WireGuard interface: %v", err)
|
||||
}
|
||||
|
||||
if err := s.ensureWireguardPeers(config.Peers); err != nil {
|
||||
logger.Error("Failed to ensure WireGuard peers: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
// Check if the WireGuard interface exists
|
||||
_, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
if _, ok := err.(netlink.LinkNotFoundError); ok {
|
||||
// Interface doesn't exist, so create it
|
||||
err = s.createWireGuardInterface()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard interface: %v", err)
|
||||
}
|
||||
s.interfaceCreated = true
|
||||
logger.Info("Created WireGuard interface %s\n", s.interfaceName)
|
||||
} else {
|
||||
logger.Fatal("Error checking for WireGuard interface: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Info("WireGuard interface %s already exists\n", s.interfaceName)
|
||||
|
||||
// get the exising wireguard port
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
// get the existing port
|
||||
s.Port = uint16(device.ListenPort)
|
||||
logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port)
|
||||
|
||||
s.interfaceCreated = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// stop the holepunch its a channel
|
||||
if s.stopHolepunch != nil {
|
||||
close(s.stopHolepunch)
|
||||
s.stopHolepunch = nil
|
||||
}
|
||||
|
||||
logger.Info("Assigning IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName)
|
||||
// Assign IP address to the interface
|
||||
err = s.assignIPAddress(wgconfig.IpAddress)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to assign IP address: %v", err)
|
||||
}
|
||||
|
||||
// Check if the interface already exists
|
||||
_, err = s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("interface %s does not exist", s.interfaceName)
|
||||
}
|
||||
return fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
// Parse the private key
|
||||
key, err := wgtypes.ParseKey(s.key.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private key: %v", err)
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
PrivateKey: &key,
|
||||
ListenPort: new(int),
|
||||
}
|
||||
|
||||
// Use the service's fixed port instead of the config port
|
||||
*config.ListenPort = int(s.Port)
|
||||
|
||||
// Create and configure the WireGuard interface
|
||||
err = s.wgClient.ConfigureDevice(s.interfaceName, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
// bring up the interface
|
||||
link, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetMTU(link, s.mtu); err != nil {
|
||||
return fmt.Errorf("failed to set MTU: %v", err)
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetUp(link); err != nil {
|
||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||
}
|
||||
|
||||
// if err := s.ensureMSSClamping(); err != nil {
|
||||
// logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||
// }
|
||||
|
||||
logger.Info("WireGuard interface %s created and configured", s.interfaceName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) createWireGuardInterface() error {
|
||||
wgLink := &netlink.GenericLink{
|
||||
LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName},
|
||||
LinkType: "wireguard",
|
||||
}
|
||||
return netlink.LinkAdd(wgLink)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
||||
link, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
addr, err := netlink.ParseAddr(ipAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse IP address: %v", err)
|
||||
}
|
||||
|
||||
return netlink.AddrAdd(link, addr)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||
// get the current peers
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
// get the peer public keys
|
||||
var currentPeers []string
|
||||
for _, peer := range device.Peers {
|
||||
currentPeers = append(currentPeers, peer.PublicKey.String())
|
||||
}
|
||||
|
||||
// remove any peers that are not in the config
|
||||
for _, peer := range currentPeers {
|
||||
found := false
|
||||
for _, configPeer := range peers {
|
||||
if peer == configPeer.PublicKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := s.removePeer(peer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add any peers that are in the config but not in the current peers
|
||||
for _, configPeer := range peers {
|
||||
found := false
|
||||
for _, peer := range currentPeers {
|
||||
if configPeer.PublicKey == peer {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := s.addPeer(configPeer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
var peer Peer
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &peer); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
}
|
||||
|
||||
err = s.addPeer(peer)
|
||||
if err != nil {
|
||||
logger.Info("Error adding peer: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) addPeer(peer Peer) error {
|
||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
// parse allowed IPs into array of net.IPNet
|
||||
var allowedIPs []net.IPNet
|
||||
for _, ipStr := range peer.AllowedIPs {
|
||||
_, ipNet, err := net.ParseCIDR(ipStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse allowed IP: %v", err)
|
||||
}
|
||||
allowedIPs = append(allowedIPs, *ipNet)
|
||||
}
|
||||
// add keep alive using *time.Duration of 1 second
|
||||
keepalive := time.Second
|
||||
|
||||
var peerConfig wgtypes.PeerConfig
|
||||
if peer.Endpoint != "" {
|
||||
endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve endpoint address: %w", err)
|
||||
}
|
||||
|
||||
peerConfig = wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
AllowedIPs: allowedIPs,
|
||||
PersistentKeepaliveInterval: &keepalive,
|
||||
Endpoint: endpoint,
|
||||
}
|
||||
} else {
|
||||
peerConfig = wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
AllowedIPs: allowedIPs,
|
||||
PersistentKeepaliveInterval: &keepalive,
|
||||
}
|
||||
logger.Info("Added peer with no endpoint!")
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
|
||||
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("Peer %s added successfully", peer.PublicKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
|
||||
type RemoveRequest struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
}
|
||||
|
||||
var request RemoveRequest
|
||||
if err := json.Unmarshal(jsonData, &request); err != nil {
|
||||
logger.Info("Error unmarshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.removePeer(request.PublicKey); err != nil {
|
||||
logger.Info("Error removing peer: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) removePeer(publicKey string) error {
|
||||
pubKey, err := wgtypes.ParseKey(publicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
|
||||
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("Peer %s removed successfully", publicKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
// Define a struct to match the incoming message structure with optional fields
|
||||
type UpdatePeerRequest struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
AllowedIPs []string `json:"allowedIps,omitempty"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
}
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
var request UpdatePeerRequest
|
||||
if err := json.Unmarshal(jsonData, &request); err != nil {
|
||||
logger.Info("Error unmarshaling peer data: %v", err)
|
||||
return
|
||||
}
|
||||
// First, get the current peer configuration to preserve any unmodified fields
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
logger.Info("Error getting WireGuard device: %v", err)
|
||||
return
|
||||
}
|
||||
pubKey, err := wgtypes.ParseKey(request.PublicKey)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing public key: %v", err)
|
||||
return
|
||||
}
|
||||
// Find the existing peer configuration
|
||||
var currentPeer *wgtypes.Peer
|
||||
for _, p := range device.Peers {
|
||||
if p.PublicKey == pubKey {
|
||||
currentPeer = &p
|
||||
break
|
||||
}
|
||||
}
|
||||
if currentPeer == nil {
|
||||
logger.Info("Peer %s not found, cannot update", request.PublicKey)
|
||||
return
|
||||
}
|
||||
// Create the update peer config
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
UpdateOnly: true,
|
||||
}
|
||||
// Keep the default persistent keepalive of 1 second
|
||||
keepalive := time.Second
|
||||
peerConfig.PersistentKeepaliveInterval = &keepalive
|
||||
|
||||
// Handle Endpoint field special case
|
||||
// If Endpoint is included in the request but empty, we want to remove the endpoint
|
||||
// If Endpoint is not included, we don't modify it
|
||||
endpointSpecified := false
|
||||
for key := range msg.Data.(map[string]interface{}) {
|
||||
if key == "endpoint" {
|
||||
endpointSpecified = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Only update AllowedIPs if provided in the request
|
||||
if len(request.AllowedIPs) > 0 {
|
||||
var allowedIPs []net.IPNet
|
||||
for _, ipStr := range request.AllowedIPs {
|
||||
_, ipNet, err := net.ParseCIDR(ipStr)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing allowed IP %s: %v", ipStr, err)
|
||||
return
|
||||
}
|
||||
allowedIPs = append(allowedIPs, *ipNet)
|
||||
}
|
||||
peerConfig.AllowedIPs = allowedIPs
|
||||
peerConfig.ReplaceAllowedIPs = true
|
||||
logger.Info("Updating AllowedIPs for peer %s", request.PublicKey)
|
||||
} else if endpointSpecified && request.Endpoint == "" {
|
||||
peerConfig.ReplaceAllowedIPs = false
|
||||
}
|
||||
|
||||
if endpointSpecified {
|
||||
if request.Endpoint != "" {
|
||||
// Update to new endpoint
|
||||
endpoint, err := net.ResolveUDPAddr("udp", request.Endpoint)
|
||||
if err != nil {
|
||||
logger.Info("Error resolving endpoint address %s: %v", request.Endpoint, err)
|
||||
return
|
||||
}
|
||||
peerConfig.Endpoint = endpoint
|
||||
logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint)
|
||||
} else {
|
||||
// specify any address to listen for any incoming packets
|
||||
peerConfig.Endpoint = &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
}
|
||||
logger.Info("Removing Endpoint for peer %s", request.PublicKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the configuration update
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||
logger.Info("Error updating peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
logger.Info("Peer %s updated successfully", request.PublicKey)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) periodicBandwidthCheck() {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if err := s.reportPeerBandwidth(); err != nil {
|
||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
peerBandwidths := []PeerBandwidth{}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, peer := range device.Peers {
|
||||
publicKey := peer.PublicKey.String()
|
||||
currentReading := PeerReading{
|
||||
BytesReceived: peer.ReceiveBytes,
|
||||
BytesTransmitted: peer.TransmitBytes,
|
||||
LastChecked: now,
|
||||
}
|
||||
|
||||
var bytesInDiff, bytesOutDiff float64
|
||||
lastReading, exists := s.lastReadings[publicKey]
|
||||
|
||||
if exists {
|
||||
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
|
||||
if timeDiff > 0 {
|
||||
// Calculate bytes transferred since last reading
|
||||
bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived)
|
||||
bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted)
|
||||
|
||||
// Handle counter wraparound (if the counter resets or overflows)
|
||||
if bytesInDiff < 0 {
|
||||
bytesInDiff = float64(currentReading.BytesReceived)
|
||||
}
|
||||
if bytesOutDiff < 0 {
|
||||
bytesOutDiff = float64(currentReading.BytesTransmitted)
|
||||
}
|
||||
|
||||
// Convert to MB
|
||||
bytesInMB := bytesInDiff / (1024 * 1024)
|
||||
bytesOutMB := bytesOutDiff / (1024 * 1024)
|
||||
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: bytesInMB,
|
||||
BytesOut: bytesOutMB,
|
||||
})
|
||||
} else {
|
||||
// If readings are too close together or time hasn't passed, report 0
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: 0,
|
||||
BytesOut: 0,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// For first reading of a peer, report 0 to establish baseline
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: 0,
|
||||
BytesOut: 0,
|
||||
})
|
||||
}
|
||||
|
||||
// Update the last reading
|
||||
s.lastReadings[publicKey] = currentReading
|
||||
}
|
||||
|
||||
// Clean up old peers
|
||||
for publicKey := range s.lastReadings {
|
||||
found := false
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey.String() == publicKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
delete(s.lastReadings, publicKey)
|
||||
}
|
||||
}
|
||||
|
||||
return peerBandwidths, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) reportPeerBandwidth() error {
|
||||
bandwidths, err := s.calculatePeerBandwidth()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
||||
}
|
||||
|
||||
err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{
|
||||
"bandwidthData": bandwidths,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
|
||||
|
||||
if s.serverPubKey == "" || s.token == "" {
|
||||
logger.Debug("Server public key or token not set, skipping UDP hole punch")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse server address
|
||||
serverSplit := strings.Split(serverAddr, ":")
|
||||
if len(serverSplit) < 2 {
|
||||
return fmt.Errorf("invalid server address format, expected hostname:port")
|
||||
}
|
||||
|
||||
serverHostname := serverSplit[0]
|
||||
serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse server port: %v", err)
|
||||
}
|
||||
|
||||
// Resolve server hostname to IP
|
||||
serverIPAddr := network.HostToAddr(serverHostname)
|
||||
if serverIPAddr == nil {
|
||||
return fmt.Errorf("failed to resolve server hostname")
|
||||
}
|
||||
|
||||
// Get client IP based on route to server
|
||||
clientIP := network.GetClientIP(serverIPAddr.IP)
|
||||
|
||||
// Create server and client configs
|
||||
server := &network.Server{
|
||||
Hostname: serverHostname,
|
||||
Addr: serverIPAddr,
|
||||
Port: uint16(serverPort),
|
||||
}
|
||||
|
||||
client := &network.PeerNet{
|
||||
IP: clientIP,
|
||||
Port: s.Port,
|
||||
NewtID: s.newtId,
|
||||
}
|
||||
|
||||
// Setup raw connection with BPF filtering
|
||||
rawConn := network.SetupRawConn(server, client)
|
||||
defer rawConn.Close()
|
||||
|
||||
// Create JSON payload
|
||||
payload := struct {
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
NewtID: s.newtId,
|
||||
Token: s.token,
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := s.encryptPayload(payloadBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %v", err)
|
||||
}
|
||||
|
||||
// Send the encrypted packet using the raw connection
|
||||
err = network.SendDataPacket(encryptedPayload, rawConn, server, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send UDP packet: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) {
|
||||
// Generate an ephemeral keypair for this message
|
||||
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||
}
|
||||
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||
|
||||
// Parse the server's public key
|
||||
serverPubKey, err := wgtypes.ParseKey(s.serverPubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||
}
|
||||
|
||||
// Use X25519 for key exchange (replacing deprecated ScalarMult)
|
||||
var ephPrivKeyFixed [32]byte
|
||||
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||
|
||||
// Perform X25519 key exchange
|
||||
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||
}
|
||||
|
||||
// Create an AEAD cipher using the shared secret
|
||||
aead, err := chacha20poly1305.New(sharedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload
|
||||
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||
|
||||
// Prepare the final encrypted message
|
||||
encryptedMsg := struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}{
|
||||
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}
|
||||
|
||||
return encryptedMsg, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
|
||||
logger.Info("Starting UDP hole punch routine to %s:21820", host)
|
||||
|
||||
// send initial hole punch
|
||||
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
|
||||
logger.Debug("Failed to send initial UDP hole punch: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopHolepunch:
|
||||
logger.Info("Stopping UDP holepunch")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Info("UDP holepunch routine timed out after 15 seconds")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
|
||||
logger.Debug("Failed to send UDP hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) removeInterface() error {
|
||||
// Remove the WireGuard interface
|
||||
link, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
err = netlink.LinkDel(link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete interface: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("WireGuard interface %s removed successfully", s.interfaceName)
|
||||
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,12 +3,13 @@ package wgtester
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"github.com/fosrl/newt/netstack2"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
)
|
||||
|
||||
@@ -37,9 +38,8 @@ type Server struct {
|
||||
isRunning bool
|
||||
runningLock sync.Mutex
|
||||
newtID string
|
||||
outputPrefix string
|
||||
useNetstack bool
|
||||
tnet interface{} // Will be *netstack.Net when using netstack
|
||||
tnet interface{} // Will be *netstack2.Net when using netstack
|
||||
}
|
||||
|
||||
// NewServer creates a new connection test server using UDP
|
||||
@@ -49,20 +49,18 @@ func NewServer(serverAddr string, serverPort uint16, newtID string) *Server {
|
||||
serverPort: serverPort + 1, // use the next port for the server
|
||||
shutdownCh: make(chan struct{}),
|
||||
newtID: newtID,
|
||||
outputPrefix: "[WGTester] ",
|
||||
useNetstack: false,
|
||||
tnet: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// NewServerWithNetstack creates a new connection test server using WireGuard netstack
|
||||
func NewServerWithNetstack(serverAddr string, serverPort uint16, newtID string, tnet *netstack.Net) *Server {
|
||||
func NewServerWithNetstack(serverAddr string, serverPort uint16, newtID string, tnet *netstack2.Net) *Server {
|
||||
return &Server{
|
||||
serverAddr: serverAddr,
|
||||
serverPort: serverPort + 1, // use the next port for the server
|
||||
shutdownCh: make(chan struct{}),
|
||||
newtID: newtID,
|
||||
outputPrefix: "[WGTester] ",
|
||||
useNetstack: true,
|
||||
tnet: tnet,
|
||||
}
|
||||
@@ -82,7 +80,7 @@ func (s *Server) Start() error {
|
||||
|
||||
if s.useNetstack && s.tnet != nil {
|
||||
// Use WireGuard netstack
|
||||
tnet := s.tnet.(*netstack.Net)
|
||||
tnet := s.tnet.(*netstack2.Net)
|
||||
udpAddr := &net.UDPAddr{Port: int(s.serverPort)}
|
||||
netstackConn, err := tnet.ListenUDP(udpAddr)
|
||||
if err != nil {
|
||||
@@ -108,7 +106,7 @@ func (s *Server) Start() error {
|
||||
s.isRunning = true
|
||||
go s.handleConnections()
|
||||
|
||||
logger.Info("%sServer started on %s:%d", s.outputPrefix, s.serverAddr, s.serverPort)
|
||||
logger.Debug("WGTester Server started on %s:%d", s.serverAddr, s.serverPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -126,11 +124,11 @@ func (s *Server) Stop() {
|
||||
s.conn.Close()
|
||||
}
|
||||
s.isRunning = false
|
||||
logger.Info(s.outputPrefix + "Server stopped")
|
||||
logger.Info("WGTester Server stopped")
|
||||
}
|
||||
|
||||
// RestartWithNetstack stops the current server and restarts it with netstack
|
||||
func (s *Server) RestartWithNetstack(tnet *netstack.Net) error {
|
||||
func (s *Server) RestartWithNetstack(tnet *netstack2.Net) error {
|
||||
s.Stop()
|
||||
|
||||
// Update configuration to use netstack
|
||||
@@ -161,7 +159,7 @@ func (s *Server) handleConnections() {
|
||||
// Set read deadline to avoid blocking forever
|
||||
err := s.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
if err != nil {
|
||||
logger.Error(s.outputPrefix+"Error setting read deadline: %v", err)
|
||||
logger.Error("Error setting read deadline: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -187,7 +185,11 @@ func (s *Server) handleConnections() {
|
||||
case <-s.shutdownCh:
|
||||
return // Don't log error if we're shutting down
|
||||
default:
|
||||
logger.Error(s.outputPrefix+"Error reading from UDP: %v", err)
|
||||
// Don't log EOF errors during shutdown - these are expected when connection is closed
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
logger.Error("Error reading from UDP: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -219,7 +221,7 @@ func (s *Server) handleConnections() {
|
||||
copy(responsePacket[5:13], buffer[5:13])
|
||||
|
||||
// Log response being sent for debugging
|
||||
logger.Debug(s.outputPrefix+"Sending response to %s", addr.String())
|
||||
// logger.Debug("Sending response to %s", addr.String())
|
||||
|
||||
// Send the response packet - handle both regular UDP and netstack UDP
|
||||
if s.useNetstack {
|
||||
@@ -233,9 +235,9 @@ func (s *Server) handleConnections() {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Error(s.outputPrefix+"Error sending response: %v", err)
|
||||
logger.Error("Error sending response: %v", err)
|
||||
} else {
|
||||
logger.Debug(s.outputPrefix + "Response sent successfully")
|
||||
// logger.Debug("Response sent successfully")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user