mirror of
https://github.com/fosrl/gerbil.git
synced 2026-03-22 12:54:30 -05:00
Compare commits
96 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40da38708c | ||
|
|
3af64d8bd3 | ||
|
|
fcead8cc15 | ||
|
|
20dad7bb8e | ||
|
|
a955aa6169 | ||
|
|
b118fef265 | ||
|
|
7985f97eb6 | ||
|
|
b9261b8fea | ||
|
|
c3e73d0189 | ||
|
|
df2fbdf160 | ||
|
|
cb4ac8199d | ||
|
|
dd4b86b3e5 | ||
|
|
bad290aa4e | ||
|
|
8c27d5e3bf | ||
|
|
7e7a37d49c | ||
|
|
d44aa97f32 | ||
|
|
b57ad74589 | ||
|
|
82256a3f6f | ||
|
|
9e140a94db | ||
|
|
d0c9ea5a57 | ||
|
|
c88810ef24 | ||
|
|
463a4eea79 | ||
|
|
4576a2e8a7 | ||
|
|
69c13adcdb | ||
|
|
3886c1a8c1 | ||
|
|
06eb4d4310 | ||
|
|
247c47b27f | ||
|
|
060038c29b | ||
|
|
5414d21dcd | ||
|
|
364fa020aa | ||
|
|
b96ee16fbf | ||
|
|
467d69aa7c | ||
|
|
7c7762ebc5 | ||
|
|
526f9c8b4e | ||
|
|
905983cf61 | ||
|
|
a0879114e2 | ||
|
|
0d54a07973 | ||
|
|
4cb2fde961 | ||
|
|
9602599565 | ||
|
|
11f858b341 | ||
|
|
29b2cb33a2 | ||
|
|
34290ffe09 | ||
|
|
1013d0591e | ||
|
|
2f6d62ab45 | ||
|
|
8d6ba79408 | ||
|
|
208b434cb7 | ||
|
|
39ce0ac407 | ||
|
|
72bee56412 | ||
|
|
b32da3a714 | ||
|
|
971452e5d3 | ||
|
|
bba4345b0f | ||
|
|
b2392fb250 | ||
|
|
697f4131e7 | ||
|
|
e282715251 | ||
|
|
709df6db3e | ||
|
|
cf2b436470 | ||
|
|
2a29021572 | ||
|
|
a3f9a89079 | ||
|
|
ee27bf3153 | ||
|
|
a90f681957 | ||
|
|
3afc82ef9a | ||
|
|
d3a16f4c59 | ||
|
|
2a1911a66f | ||
|
|
08341b2385 | ||
|
|
6cde07d479 | ||
|
|
06b1e84f99 | ||
|
|
2b7e93ec92 | ||
|
|
ca23ae7a30 | ||
|
|
661fd86305 | ||
|
|
594a499b95 | ||
|
|
44aed84827 | ||
|
|
bf038eb4a2 | ||
|
|
6da3129b4e | ||
|
|
ac0f9b6a82 | ||
|
|
16aef10cca | ||
|
|
19031ebdfd | ||
|
|
0eebbc51d5 | ||
|
|
d321a8ba7e | ||
|
|
3ea86222ca | ||
|
|
c3ebe930d9 | ||
|
|
f2b96f2a38 | ||
|
|
9038239bbe | ||
|
|
3e64eb9c4f | ||
|
|
92992b8c14 | ||
|
|
4ee9d77532 | ||
|
|
bd7a5bd4b0 | ||
|
|
1cd49f8ee3 | ||
|
|
7a919d867b | ||
|
|
ce50c627a7 | ||
|
|
691d5f0271 | ||
|
|
56151089e3 | ||
|
|
af7c1caf98 | ||
|
|
dd208ab67c | ||
|
|
8189d41a45 | ||
|
|
ea3477c8ce | ||
|
|
b03f8911a5 |
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Summary
|
||||
description: A clear and concise summary of the requested feature.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Motivation
|
||||
description: |
|
||||
Why is this feature important?
|
||||
Explain the problem this feature would solve or what use case it would enable.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Proposed Solution
|
||||
description: |
|
||||
How would you like to see this feature implemented?
|
||||
Provide as much detail as possible about the desired behavior, configuration, or changes.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Alternatives Considered
|
||||
description: Describe any alternative solutions or workarounds you've thought about.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Add any other context, mockups, or screenshots about the feature request here.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Before submitting, please:
|
||||
- Check if there is an existing issue for this feature.
|
||||
- Clearly explain the benefit and use case.
|
||||
- Be as specific as possible to help contributors evaluate and implement.
|
||||
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
name: Bug Report
|
||||
description: Create a bug report
|
||||
labels: []
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Describe the Bug
|
||||
description: A clear and concise description of what the bug is.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Environment
|
||||
description: Please fill out the relevant details below for your environment.
|
||||
value: |
|
||||
- OS Type & Version: (e.g., Ubuntu 22.04)
|
||||
- Pangolin Version:
|
||||
- Gerbil Version:
|
||||
- Traefik Version:
|
||||
- Newt Version:
|
||||
- Olm Version: (if applicable)
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: To Reproduce
|
||||
description: |
|
||||
Steps to reproduce the behavior, please provide a clear description of how to reproduce the issue, based on the linked minimal reproduction. Screenshots can be provided in the issue body below.
|
||||
|
||||
If using code blocks, make sure syntax highlighting is correct and double-check that the rendered preview is not broken.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: A clear and concise description of what you expected to happen.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Before posting the issue go through the steps you've written down to make sure the steps provided are detailed and clear.
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Contributors should be able to follow the steps provided in order to reproduce the bug.
|
||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Need help or have questions?
|
||||
url: https://github.com/orgs/fosrl/discussions
|
||||
about: Ask questions, get help, and discuss with other community members
|
||||
- name: Request a Feature
|
||||
url: https://github.com/orgs/fosrl/discussions/new?category=feature-requests
|
||||
about: Feature requests should be opened as discussions so others can upvote and comment
|
||||
179
.github/workflows/cicd.yml
vendored
179
.github/workflows/cicd.yml
vendored
@@ -1,52 +1,161 @@
|
||||
name: CI/CD Pipeline
|
||||
|
||||
# CI/CD workflow for building, publishing, mirroring, signing container images and building release binaries.
|
||||
# Actions are pinned to specific SHAs to reduce supply-chain risk. This workflow triggers on tag push events.
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write # for GHCR push
|
||||
id-token: write # for Cosign Keyless (OIDC) Signing
|
||||
|
||||
# Required secrets:
|
||||
# - DOCKER_HUB_USERNAME / DOCKER_HUB_ACCESS_TOKEN: push to Docker Hub
|
||||
# - GITHUB_TOKEN: used for GHCR login and OIDC keyless signing
|
||||
# - COSIGN_PRIVATE_KEY / COSIGN_PASSWORD / COSIGN_PUBLIC_KEY: for key-based signing
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- "[0-9]+.[0-9]+.[0-9]+"
|
||||
- "[0-9]+.[0-9]+.[0-9]+.rc.[0-9]+"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: Build and Release
|
||||
runs-on: ubuntu-latest
|
||||
release:
|
||||
name: Build and Release
|
||||
runs-on: amd64-runner
|
||||
# Job-level timeout to avoid runaway or stuck runs
|
||||
timeout-minutes: 120
|
||||
env:
|
||||
# Target images
|
||||
DOCKERHUB_IMAGE: docker.io/fosrl/${{ github.event.repository.name }}
|
||||
GHCR_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: docker.io
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||
|
||||
- name: Extract tag name
|
||||
id: get-tag
|
||||
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
- name: Extract tag name
|
||||
id: get-tag
|
||||
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
shell: bash
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: 1.25
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v6.2.0
|
||||
with:
|
||||
go-version: 1.25
|
||||
|
||||
- name: Build and push Docker images
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
make docker-build-release tag=$TAG
|
||||
- name: Update version in main.go
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
if [ -f main.go ]; then
|
||||
sed -i 's/version_replaceme/'"$TAG"'/' main.go
|
||||
echo "Updated main.go with version $TAG"
|
||||
else
|
||||
echo "main.go not found"
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Build binaries
|
||||
run: |
|
||||
make go-build-release
|
||||
- name: Build and push Docker images (Docker Hub)
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
make docker-build-release tag=$TAG
|
||||
echo "Built & pushed to: ${{ env.DOCKERHUB_IMAGE }}:${TAG}"
|
||||
shell: bash
|
||||
|
||||
- name: Upload artifacts from /bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: binaries
|
||||
path: bin/
|
||||
- name: Login in to GHCR
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install skopeo + jq
|
||||
# skopeo: copy/inspect images between registries
|
||||
# jq: JSON parsing tool used to extract digest values
|
||||
run: |
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y skopeo jq
|
||||
skopeo --version
|
||||
shell: bash
|
||||
|
||||
- name: Copy tag from Docker Hub to GHCR
|
||||
# Mirror the already-built image (all architectures) to GHCR so we can sign it
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG=${{ env.TAG }}
|
||||
echo "Copying ${{ env.DOCKERHUB_IMAGE }}:${TAG} -> ${{ env.GHCR_IMAGE }}:${TAG}"
|
||||
skopeo copy --all --retry-times 3 \
|
||||
docker://$DOCKERHUB_IMAGE:$TAG \
|
||||
docker://$GHCR_IMAGE:$TAG
|
||||
shell: bash
|
||||
|
||||
- name: Install cosign
|
||||
# cosign is used to sign and verify container images (key and keyless)
|
||||
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||
|
||||
- name: Dual-sign and verify (GHCR & Docker Hub)
|
||||
# Sign each image by digest using keyless (OIDC) and key-based signing,
|
||||
# then verify both the public key signature and the keyless OIDC signature.
|
||||
env:
|
||||
TAG: ${{ env.TAG }}
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
COSIGN_YES: "true"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
issuer="https://token.actions.githubusercontent.com"
|
||||
id_regex="^https://github.com/${{ github.repository }}/.+" # accept this repo (all workflows/refs)
|
||||
|
||||
for IMAGE in "${GHCR_IMAGE}" "${DOCKERHUB_IMAGE}"; do
|
||||
echo "Processing ${IMAGE}:${TAG}"
|
||||
|
||||
DIGEST="$(skopeo inspect --retry-times 3 docker://${IMAGE}:${TAG} | jq -r '.Digest')"
|
||||
REF="${IMAGE}@${DIGEST}"
|
||||
echo "Resolved digest: ${REF}"
|
||||
|
||||
echo "==> cosign sign (keyless) --recursive ${REF}"
|
||||
cosign sign --recursive "${REF}"
|
||||
|
||||
echo "==> cosign sign (key) --recursive ${REF}"
|
||||
cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${REF}"
|
||||
|
||||
echo "==> cosign verify (public key) ${REF}"
|
||||
cosign verify --key env://COSIGN_PUBLIC_KEY "${REF}" -o text
|
||||
|
||||
echo "==> cosign verify (keyless policy) ${REF}"
|
||||
cosign verify \
|
||||
--certificate-oidc-issuer "${issuer}" \
|
||||
--certificate-identity-regexp "${id_regex}" \
|
||||
"${REF}" -o text
|
||||
done
|
||||
shell: bash
|
||||
|
||||
- name: Build binaries
|
||||
run: |
|
||||
make go-build-release
|
||||
shell: bash
|
||||
|
||||
- name: Upload artifacts from /bin
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: binaries
|
||||
path: bin/
|
||||
|
||||
132
.github/workflows/mirror.yaml
vendored
Normal file
132
.github/workflows/mirror.yaml
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
name: Mirror & Sign (Docker Hub to GHCR)
|
||||
|
||||
on:
|
||||
workflow_dispatch: {}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write # for keyless OIDC
|
||||
|
||||
env:
|
||||
SOURCE_IMAGE: docker.io/fosrl/gerbil
|
||||
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||
|
||||
jobs:
|
||||
mirror-and-dual-sign:
|
||||
runs-on: amd64-runner
|
||||
steps:
|
||||
- name: Install skopeo + jq
|
||||
run: |
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y skopeo jq
|
||||
skopeo --version
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||
|
||||
- name: Input check
|
||||
run: |
|
||||
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
|
||||
echo "Source : ${SOURCE_IMAGE}"
|
||||
echo "Target : ${DEST_IMAGE}"
|
||||
|
||||
# Auth for skopeo (containers-auth)
|
||||
- name: Skopeo login to GHCR
|
||||
run: |
|
||||
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
||||
# Auth for cosign (docker-config)
|
||||
- name: Docker login to GHCR (for cosign)
|
||||
run: |
|
||||
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
|
||||
|
||||
- name: List source tags
|
||||
run: |
|
||||
set -euo pipefail
|
||||
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
|
||||
| jq -r '.Tags[]' | sort -u > src-tags.txt
|
||||
echo "Found source tags: $(wc -l < src-tags.txt)"
|
||||
head -n 20 src-tags.txt || true
|
||||
|
||||
- name: List destination tags (skip existing)
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
|
||||
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
|
||||
else
|
||||
: > dst-tags.txt
|
||||
fi
|
||||
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
|
||||
|
||||
- name: Mirror, dual-sign, and verify
|
||||
env:
|
||||
# keyless
|
||||
COSIGN_YES: "true"
|
||||
# key-based
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
# verify
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
copied=0; skipped=0; v_ok=0; errs=0
|
||||
|
||||
issuer="https://token.actions.githubusercontent.com"
|
||||
id_regex="^https://github.com/${{ github.repository }}/.+"
|
||||
|
||||
while read -r tag; do
|
||||
[ -z "$tag" ] && continue
|
||||
|
||||
if grep -Fxq "$tag" dst-tags.txt; then
|
||||
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
|
||||
skipped=$((skipped+1))
|
||||
continue
|
||||
fi
|
||||
|
||||
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
|
||||
if ! skopeo copy --all --retry-times 3 \
|
||||
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
|
||||
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
|
||||
errs=$((errs+1)); continue
|
||||
fi
|
||||
copied=$((copied+1))
|
||||
|
||||
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
|
||||
ref="${DEST_IMAGE}@${digest}"
|
||||
|
||||
echo "==> cosign sign (keyless) --recursive ${ref}"
|
||||
if ! cosign sign --recursive "${ref}"; then
|
||||
echo "::warning title=Keyless sign failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign sign (key) --recursive ${ref}"
|
||||
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
|
||||
echo "::warning title=Key sign failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign verify (public key) ${ref}"
|
||||
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
|
||||
echo "::warning title=Verify(pubkey) failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign verify (keyless policy) ${ref}"
|
||||
if ! cosign verify \
|
||||
--certificate-oidc-issuer "${issuer}" \
|
||||
--certificate-identity-regexp "${id_regex}" \
|
||||
"${ref}" -o text; then
|
||||
echo "::warning title=Verify(keyless) failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
else
|
||||
v_ok=$((v_ok+1))
|
||||
fi
|
||||
done < src-tags.txt
|
||||
|
||||
echo "---- Summary ----"
|
||||
echo "Copied : $copied"
|
||||
echo "Skipped : $skipped"
|
||||
echo "Verified OK : $v_ok"
|
||||
echo "Errors : $errs"
|
||||
11
.github/workflows/test.yml
vendored
11
.github/workflows/test.yml
vendored
@@ -1,5 +1,8 @@
|
||||
name: Run Tests
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
@@ -8,15 +11,15 @@ on:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: amd64-runner
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v6.2.0
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: 1.25
|
||||
|
||||
- name: Build go
|
||||
run: go build
|
||||
|
||||
@@ -4,11 +4,7 @@ Contributions are welcome!
|
||||
|
||||
Please see the contribution and local development guide on the docs page before getting started:
|
||||
|
||||
https://docs.fossorial.io/development
|
||||
|
||||
For ideas about what features to work on and our future plans, please see the roadmap:
|
||||
|
||||
https://docs.fossorial.io/roadmap
|
||||
https://docs.pangolin.net/development/contributing
|
||||
|
||||
### Licensing Considerations
|
||||
|
||||
|
||||
@@ -16,18 +16,13 @@ COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /gerbil
|
||||
|
||||
# Start a new stage from scratch
|
||||
FROM ubuntu:24.04 AS runner
|
||||
FROM alpine:3.23 AS runner
|
||||
|
||||
RUN apt-get update && apt-get install -y iptables iproute2 && rm -rf /var/lib/apt/lists/*
|
||||
RUN apk add --no-cache iptables iproute2
|
||||
|
||||
# Copy the pre-built binary file from the previous stage and the entrypoint script
|
||||
COPY --from=builder /gerbil /usr/local/bin/
|
||||
COPY entrypoint.sh /
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
# Copy the entrypoint script
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
|
||||
# Command to run the executable
|
||||
CMD ["gerbil"]
|
||||
18
README.md
18
README.md
@@ -6,7 +6,7 @@ Gerbil is a simple [WireGuard](https://www.wireguard.com/) interface management
|
||||
|
||||
Gerbil works with Pangolin, Newt, and Olm as part of the larger system. See documentation below:
|
||||
|
||||
- [Full Documentation](https://docs.fossorial.io)
|
||||
- [Full Documentation](https://docs.pangolin.net)
|
||||
|
||||
## Key Functions
|
||||
|
||||
@@ -20,7 +20,7 @@ Gerbil will create the peers defined in the config on the WireGuard interface. T
|
||||
|
||||
### Report Bandwidth
|
||||
|
||||
Bytes transmitted in and out of each peer are collected every 10 seconds, and incremental usage is reported via the "reportBandwidthTo" endpoint. This can be used to track data usage of each peer on the remote server.
|
||||
Bytes transmitted in and out of each peer are collected every 10 seconds, and incremental usage is reported via the api endpoint. This can be used to track data usage of each peer on the remote server.
|
||||
|
||||
### Handle client relaying
|
||||
|
||||
@@ -42,16 +42,15 @@ In single node (self hosted) Pangolin deployments this can be bypassed by using
|
||||
|
||||
## CLI Args
|
||||
|
||||
Important:
|
||||
- `reachableAt`: How should the remote server reach Gerbil's API?
|
||||
- `generateAndSaveKeyTo`: Where to save the generated WireGuard private key to persist across restarts.
|
||||
- `remoteConfig` (optional): Remote config location to HTTP get the JSON based config from. See `example_config.json`
|
||||
- `config` (optional): Local JSON file path to load config. Used if remote config is not supplied. See `example_config.json`
|
||||
|
||||
Note: You must use either `config` or `remoteConfig` to configure WireGuard.
|
||||
- `remoteConfig`: Remote config location to HTTP get the JSON based config from.
|
||||
|
||||
Others:
|
||||
- `reportBandwidthTo` (optional): **DEPRECATED** - Use `remoteConfig` instead. Remote HTTP endpoint to send peer bandwidth data
|
||||
- `interface` (optional): Name of the WireGuard interface created by Gerbil. Default: `wg0`
|
||||
- `listen` (optional): Port to listen on for HTTP server. Default: `:3003`
|
||||
- `listen` (optional): Port to listen on for HTTP server. Default: `:3004`
|
||||
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: `INFO`
|
||||
- `mtu` (optional): MTU of the WireGuard interface. Default: `1280`
|
||||
- `notify` (optional): URL to notify on peer changes
|
||||
@@ -66,7 +65,6 @@ Note: You must use either `config` or `remoteConfig` to configure WireGuard.
|
||||
All CLI arguments can also be provided via environment variables:
|
||||
|
||||
- `INTERFACE`: Name of the WireGuard interface
|
||||
- `CONFIG`: Path to local configuration file
|
||||
- `REMOTE_CONFIG`: URL of the remote config server
|
||||
- `LISTEN`: Address to listen on for HTTP server
|
||||
- `GENERATE_AND_SAVE_KEY_TO`: Path to save generated private key
|
||||
@@ -84,7 +82,7 @@ Example:
|
||||
|
||||
```bash
|
||||
./gerbil \
|
||||
--reachableAt=http://gerbil:3003 \
|
||||
--reachableAt=http://gerbil:3004 \
|
||||
--generateAndSaveKeyTo=/var/config/key \
|
||||
--remoteConfig=http://pangolin:3001/api/v1/
|
||||
```
|
||||
@@ -96,7 +94,7 @@ services:
|
||||
container_name: gerbil
|
||||
restart: unless-stopped
|
||||
command:
|
||||
- --reachableAt=http://gerbil:3003
|
||||
- --reachableAt=http://gerbil:3004
|
||||
- --generateAndSaveKeyTo=/var/config/key
|
||||
- --remoteConfig=http://pangolin:3001/api/v1/
|
||||
volumes:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
||||
|
||||
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
||||
2. Send a detailed report to [security@fossorial.io](mailto:security@fossorial.io) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
2. Send a detailed report to [security@pangolin.net](mailto:security@pangolin.net) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
|
||||
- Description and location of the vulnerability.
|
||||
- Potential impact of the vulnerability.
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
{
|
||||
"privateKey": "kBGTgk7c+zncEEoSnMl+jsLjVh5ZVoL/HwBSQem+d1M=",
|
||||
"listenPort": 51820,
|
||||
"ipAddress": "10.0.0.1/24",
|
||||
"peers": [
|
||||
{
|
||||
"publicKey": "5UzzoeveFVSzuqK3nTMS5bA1jIMs1fQffVQzJ8MXUQM=",
|
||||
"allowedIps": ["10.0.0.0/28"]
|
||||
},
|
||||
{
|
||||
"publicKey": "kYrZpuO2NsrFoBh1GMNgkhd1i9Rgtu1rAjbJ7qsfngU=",
|
||||
"allowedIps": ["10.0.0.16/28"]
|
||||
},
|
||||
{
|
||||
"publicKey": "1YfPUVr9ZF4zehkbI2BQhCxaRLz+Vtwa4vJwH+mpK0A=",
|
||||
"allowedIps": ["10.0.0.32/28"]
|
||||
},
|
||||
{
|
||||
"publicKey": "2/U4oyZ+sai336Dal/yExCphL8AxyqvIxMk4qsUy4iI=",
|
||||
"allowedIps": ["10.0.0.48/28"]
|
||||
}
|
||||
]
|
||||
}
|
||||
10
go.mod
10
go.mod
@@ -3,8 +3,10 @@ module github.com/fosrl/gerbil
|
||||
go 1.25
|
||||
|
||||
require (
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/sync v0.1.0
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
)
|
||||
|
||||
@@ -14,10 +16,8 @@ require (
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.4.1 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect
|
||||
)
|
||||
|
||||
12
go.sum
12
go.sum
@@ -16,16 +16,16 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
||||
|
||||
507
main.go
507
main.go
@@ -2,15 +2,21 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -21,25 +27,28 @@ import (
|
||||
"github.com/fosrl/gerbil/proxy"
|
||||
"github.com/fosrl/gerbil/relay"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
var (
|
||||
interfaceName string
|
||||
listenAddr string
|
||||
mtuInt int
|
||||
lastReadings = make(map[string]PeerReading)
|
||||
mu sync.Mutex
|
||||
wgMu sync.Mutex // Protects WireGuard operations
|
||||
notifyURL string
|
||||
proxyRelay *relay.UDPProxyServer
|
||||
proxySNI *proxy.SNIProxy
|
||||
interfaceName string
|
||||
listenAddr string
|
||||
mtuInt int
|
||||
lastReadings = make(map[string]PeerReading)
|
||||
mu sync.Mutex
|
||||
wgMu sync.Mutex // Protects WireGuard operations
|
||||
notifyURL string
|
||||
proxyRelay *relay.UDPProxyServer
|
||||
proxySNI *proxy.SNIProxy
|
||||
doTrafficShaping bool
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
PrivateKey string `json:"privateKey"`
|
||||
ListenPort int `json:"listenPort"`
|
||||
RelayPort int `json:"relayPort"`
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Peers []Peer `json:"peers"`
|
||||
}
|
||||
@@ -108,6 +117,8 @@ func parseLogLevel(level string) logger.LogLevel {
|
||||
}
|
||||
|
||||
func main() {
|
||||
go monitorMemory(1024 * 1024 * 512) // trigger if memory usage exceeds 512MB
|
||||
|
||||
var (
|
||||
err error
|
||||
wgconfig WgConfig
|
||||
@@ -121,6 +132,7 @@ func main() {
|
||||
localProxyAddr string
|
||||
localProxyPort int
|
||||
localOverridesStr string
|
||||
trustedUpstreamsStr string
|
||||
proxyProtocol bool
|
||||
)
|
||||
|
||||
@@ -138,7 +150,9 @@ func main() {
|
||||
localProxyAddr = os.Getenv("LOCAL_PROXY")
|
||||
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
|
||||
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
|
||||
trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS")
|
||||
proxyProtocolStr := os.Getenv("PROXY_PROTOCOL")
|
||||
doTrafficShapingStr := os.Getenv("DO_TRAFFIC_SHAPING")
|
||||
|
||||
if interfaceName == "" {
|
||||
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
||||
@@ -150,7 +164,7 @@ func main() {
|
||||
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL of the Pangolin server")
|
||||
}
|
||||
if listenAddr == "" {
|
||||
flag.StringVar(&listenAddr, "listen", ":3003", "Address to listen on")
|
||||
flag.StringVar(&listenAddr, "listen", "", "DEPRECATED (overridden by reachableAt): Address to listen on")
|
||||
}
|
||||
// DEPRECATED AND UNSED: reportBandwidthTo
|
||||
// allow reportBandwidthTo to be passed but dont do anything with it just thow it away
|
||||
@@ -160,9 +174,11 @@ func main() {
|
||||
if generateAndSaveKeyTo == "" {
|
||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||
}
|
||||
|
||||
if reachableAt == "" {
|
||||
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
|
||||
}
|
||||
|
||||
if logLevel == "" {
|
||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
}
|
||||
@@ -197,6 +213,9 @@ func main() {
|
||||
if localOverridesStr != "" {
|
||||
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
|
||||
}
|
||||
if trustedUpstreamsStr == "" {
|
||||
flag.StringVar(&trustedUpstreamsStr, "trusted-upstreams", "", "Comma-separated list of trusted upstream proxy domain names/IPs that can send PROXY protocol")
|
||||
}
|
||||
|
||||
if proxyProtocolStr != "" {
|
||||
proxyProtocol = strings.ToLower(proxyProtocolStr) == "true"
|
||||
@@ -205,11 +224,38 @@ func main() {
|
||||
flag.BoolVar(&proxyProtocol, "proxy-protocol", true, "Enable PROXY protocol v1 for preserving client IP")
|
||||
}
|
||||
|
||||
if doTrafficShapingStr != "" {
|
||||
doTrafficShaping = strings.ToLower(doTrafficShapingStr) == "true"
|
||||
}
|
||||
if doTrafficShapingStr == "" {
|
||||
flag.BoolVar(&doTrafficShaping, "do-traffic-shaping", false, "Whether to set up traffic shaping rules for peers (requires tc command and root privileges)")
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
logger.Init()
|
||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||
|
||||
// Base context for the application; cancel on SIGINT/SIGTERM
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
// try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
|
||||
if reachableAt != "" && listenAddr == "" {
|
||||
if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
|
||||
parts := strings.Split(reachableAt, ":")
|
||||
if len(parts) == 3 {
|
||||
port := parts[2]
|
||||
if strings.Contains(port, "/") {
|
||||
port = strings.Split(port, "/")[0]
|
||||
}
|
||||
listenAddr = ":" + port
|
||||
}
|
||||
}
|
||||
} else if listenAddr == "" {
|
||||
listenAddr = ":3003"
|
||||
}
|
||||
|
||||
mtuInt, err = strconv.Atoi(mtu)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse MTU: %v", err)
|
||||
@@ -301,10 +347,20 @@ func main() {
|
||||
// Ensure the WireGuard peers exist
|
||||
ensureWireguardPeers(wgconfig.Peers)
|
||||
|
||||
go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
|
||||
// Child error group derived from base context
|
||||
group, groupCtx := errgroup.WithContext(ctx)
|
||||
|
||||
// Periodic bandwidth reporting
|
||||
group.Go(func() error {
|
||||
return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
|
||||
})
|
||||
|
||||
// Start the UDP proxy server
|
||||
proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
|
||||
relayPort := wgconfig.RelayPort
|
||||
if relayPort == 0 {
|
||||
relayPort = 21820 // in case there is no relay port set, use 21820
|
||||
}
|
||||
proxyRelay = relay.NewUDPProxyServer(groupCtx, fmt.Sprintf(":%d", relayPort), remoteConfigURL, key, reachableAt)
|
||||
err = proxyRelay.Start()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to start UDP proxy server: %v", err)
|
||||
@@ -323,7 +379,16 @@ func main() {
|
||||
logger.Info("Local overrides configured: %v", localOverrides)
|
||||
}
|
||||
|
||||
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol)
|
||||
var trustedUpstreams []string
|
||||
if trustedUpstreamsStr != "" {
|
||||
trustedUpstreams = strings.Split(trustedUpstreamsStr, ",")
|
||||
for i, upstream := range trustedUpstreams {
|
||||
trustedUpstreams[i] = strings.TrimSpace(upstream)
|
||||
}
|
||||
logger.Info("Trusted upstreams configured: %v", trustedUpstreams)
|
||||
}
|
||||
|
||||
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol, trustedUpstreams)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create proxy: %v", err)
|
||||
}
|
||||
@@ -337,20 +402,42 @@ func main() {
|
||||
http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping)
|
||||
http.HandleFunc("/update-destinations", handleUpdateDestinations)
|
||||
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
|
||||
http.HandleFunc("/healthz", handleHealthz)
|
||||
logger.Info("Starting HTTP server on %s", listenAddr)
|
||||
|
||||
// Run HTTP server in a goroutine
|
||||
go func() {
|
||||
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||
logger.Error("HTTP server failed: %v", err)
|
||||
// HTTP server with graceful shutdown on context cancel
|
||||
server := &http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: nil,
|
||||
}
|
||||
group.Go(func() error {
|
||||
// http.ErrServerClosed is returned on graceful shutdown; not an error for us
|
||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
group.Go(func() error {
|
||||
<-groupCtx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = server.Shutdown(shutdownCtx)
|
||||
// Stop background components as the context is canceled
|
||||
if proxySNI != nil {
|
||||
_ = proxySNI.Stop()
|
||||
}
|
||||
if proxyRelay != nil {
|
||||
proxyRelay.Stop()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Keep the main goroutine running
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
logger.Info("Shutting down servers...")
|
||||
// Wait for all goroutines to finish
|
||||
if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
logger.Error("Service exited with error: %v", err)
|
||||
} else if errors.Is(err, context.Canceled) {
|
||||
logger.Info("Context cancelled, shutting down")
|
||||
}
|
||||
}
|
||||
|
||||
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
|
||||
@@ -477,6 +564,10 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||
}
|
||||
|
||||
if err := ensureWireguardFirewall(); err != nil {
|
||||
logger.Warn("Failed to ensure WireGuard firewall rules: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("WireGuard interface %s created and configured", interfaceName)
|
||||
|
||||
return nil
|
||||
@@ -607,7 +698,7 @@ func ensureMSSClamping() error {
|
||||
if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
@@ -624,7 +715,7 @@ func ensureMSSClamping() error {
|
||||
if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
@@ -645,6 +736,113 @@ func ensureMSSClamping() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureWireguardFirewall() error {
|
||||
// Rules to enforce:
|
||||
// 1. Allow established/related connections (responses to our outbound traffic)
|
||||
// 2. Allow ICMP ping packets
|
||||
// 3. Drop all other inbound traffic from peers
|
||||
|
||||
// Define the rules we want to ensure exist
|
||||
rules := [][]string{
|
||||
// Allow established and related connections (responses to outbound traffic)
|
||||
{
|
||||
"-A", "INPUT",
|
||||
"-i", interfaceName,
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "ESTABLISHED,RELATED",
|
||||
"-j", "ACCEPT",
|
||||
},
|
||||
// Allow ICMP ping requests
|
||||
{
|
||||
"-A", "INPUT",
|
||||
"-i", interfaceName,
|
||||
"-p", "icmp",
|
||||
"--icmp-type", "8",
|
||||
"-j", "ACCEPT",
|
||||
},
|
||||
// Drop all other inbound traffic from WireGuard interface
|
||||
{
|
||||
"-A", "INPUT",
|
||||
"-i", interfaceName,
|
||||
"-j", "DROP",
|
||||
},
|
||||
}
|
||||
|
||||
// First, try to delete any existing rules for this interface
|
||||
for _, rule := range rules {
|
||||
deleteArgs := make([]string, len(rule))
|
||||
copy(deleteArgs, rule)
|
||||
// Change -A to -D for deletion
|
||||
for i, arg := range deleteArgs {
|
||||
if arg == "-A" {
|
||||
deleteArgs[i] = "-D"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
deleteCmd := exec.Command("/usr/sbin/iptables", deleteArgs...)
|
||||
logger.Debug("Attempting to delete existing firewall rule: %v", deleteArgs)
|
||||
|
||||
// Try deletion multiple times to handle multiple existing rules
|
||||
for i := 0; i < 5; i++ {
|
||||
out, err := deleteCmd.CombinedOutput()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
logger.Debug("Deletion stopped: %v (output: %s)", exitErr.String(), string(out))
|
||||
}
|
||||
break // No more rules to delete
|
||||
}
|
||||
logger.Info("Deleted existing firewall rule (attempt %d)", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
// Now add the rules
|
||||
var errors []error
|
||||
for i, rule := range rules {
|
||||
addCmd := exec.Command("/usr/sbin/iptables", rule...)
|
||||
logger.Info("Adding WireGuard firewall rule %d: %v", i+1, rule)
|
||||
|
||||
if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Failed to add firewall rule %d: %v (output: %s)", i+1, err, string(out))
|
||||
logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify the rule was added by checking
|
||||
checkArgs := make([]string, len(rule))
|
||||
copy(checkArgs, rule)
|
||||
// Change -A to -C for check
|
||||
for j, arg := range checkArgs {
|
||||
if arg == "-A" {
|
||||
checkArgs[j] = "-C"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
checkCmd := exec.Command("/usr/sbin/iptables", checkArgs...)
|
||||
if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Rule verification failed for rule %d: %v (output: %s)", i+1, err, string(out))
|
||||
logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("Successfully added and verified WireGuard firewall rule %d", i+1)
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
var errMsgs []string
|
||||
for _, err := range errors {
|
||||
errMsgs = append(errMsgs, err.Error())
|
||||
}
|
||||
return fmt.Errorf("WireGuard firewall setup encountered errors:\n%s", strings.Join(errMsgs, "\n"))
|
||||
}
|
||||
|
||||
logger.Info("WireGuard firewall rules successfully configured for interface %s", interfaceName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func handlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
@@ -656,6 +854,15 @@ func handlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func handleHealthz(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func handleAddPeer(w http.ResponseWriter, r *http.Request) {
|
||||
var peer Peer
|
||||
if err := json.NewDecoder(r.Body).Decode(&peer); err != nil {
|
||||
@@ -688,17 +895,23 @@ func addPeerInternal(peer Peer) error {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
logger.Debug("Adding peer %s with AllowedIPs: %v", peer.PublicKey, peer.AllowedIPs)
|
||||
|
||||
// parse allowed IPs into array of net.IPNet
|
||||
var allowedIPs []net.IPNet
|
||||
var wgIPs []string
|
||||
for _, ipStr := range peer.AllowedIPs {
|
||||
logger.Debug("Parsing AllowedIP: %s", ipStr)
|
||||
_, ipNet, err := net.ParseCIDR(ipStr)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to parse allowed IP '%s' for peer %s: %v", ipStr, peer.PublicKey, err)
|
||||
return fmt.Errorf("failed to parse allowed IP: %v", err)
|
||||
}
|
||||
allowedIPs = append(allowedIPs, *ipNet)
|
||||
// Extract the IP address from the CIDR for relay cleanup
|
||||
wgIPs = append(wgIPs, ipNet.IP.String())
|
||||
extractedIP := ipNet.IP.String()
|
||||
wgIPs = append(wgIPs, extractedIP)
|
||||
logger.Debug("Extracted IP %s from AllowedIP %s", extractedIP, ipStr)
|
||||
}
|
||||
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
@@ -714,6 +927,18 @@ func addPeerInternal(peer Peer) error {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
|
||||
// Setup bandwidth limiting for each peer IP
|
||||
if doTrafficShaping {
|
||||
logger.Debug("doTrafficShaping is true, setting up bandwidth limits for %d IPs", len(wgIPs))
|
||||
for _, wgIP := range wgIPs {
|
||||
if err := setupPeerBandwidthLimit(wgIP); err != nil {
|
||||
logger.Warn("Failed to setup bandwidth limit for peer IP %s: %v", wgIP, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.Debug("doTrafficShaping is false, skipping bandwidth limit setup")
|
||||
}
|
||||
|
||||
// Clear relay connections for the peer's WireGuard IPs
|
||||
if proxyRelay != nil {
|
||||
for _, wgIP := range wgIPs {
|
||||
@@ -758,19 +983,17 @@ func removePeerInternal(publicKey string) error {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
// Get current peer info before removing to clear relay connections
|
||||
// Get current peer info before removing to clear relay connections and bandwidth limits
|
||||
var wgIPs []string
|
||||
if proxyRelay != nil {
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err == nil {
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey.String() == publicKey {
|
||||
// Extract WireGuard IPs from this peer's allowed IPs
|
||||
for _, allowedIP := range peer.AllowedIPs {
|
||||
wgIPs = append(wgIPs, allowedIP.IP.String())
|
||||
}
|
||||
break
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err == nil {
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey.String() == publicKey {
|
||||
// Extract WireGuard IPs from this peer's allowed IPs
|
||||
for _, allowedIP := range peer.AllowedIPs {
|
||||
wgIPs = append(wgIPs, allowedIP.IP.String())
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -788,6 +1011,15 @@ func removePeerInternal(publicKey string) error {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
|
||||
// Remove bandwidth limits for each peer IP
|
||||
if doTrafficShaping {
|
||||
for _, wgIP := range wgIPs {
|
||||
if err := removePeerBandwidthLimit(wgIP); err != nil {
|
||||
logger.Warn("Failed to remove bandwidth limit for peer IP %s: %v", wgIP, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear relay connections for the peer's WireGuard IPs
|
||||
if proxyRelay != nil {
|
||||
for _, wgIP := range wgIPs {
|
||||
@@ -945,13 +1177,18 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
func periodicBandwidthCheck(endpoint string) {
|
||||
func periodicBandwidthCheck(ctx context.Context, endpoint string) error {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if err := reportPeerBandwidth(endpoint); err != nil {
|
||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := reportPeerBandwidth(endpoint); err != nil {
|
||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -971,8 +1208,13 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Track the set of peers currently present on the device to prune stale readings efficiently
|
||||
currentPeerKeys := make(map[string]struct{}, len(device.Peers))
|
||||
|
||||
for _, peer := range device.Peers {
|
||||
publicKey := peer.PublicKey.String()
|
||||
currentPeerKeys[publicKey] = struct{}{}
|
||||
|
||||
currentReading := PeerReading{
|
||||
BytesReceived: peer.ReceiveBytes,
|
||||
BytesTransmitted: peer.TransmitBytes,
|
||||
@@ -1029,14 +1271,7 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
|
||||
// Clean up old peers
|
||||
for publicKey := range lastReadings {
|
||||
found := false
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey.String() == publicKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if _, exists := currentPeerKeys[publicKey]; !exists {
|
||||
delete(lastReadings, publicKey)
|
||||
}
|
||||
}
|
||||
@@ -1092,3 +1327,177 @@ func notifyPeerChange(action, publicKey string) {
|
||||
logger.Warn("Notify server returned non-OK: %s", resp.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func monitorMemory(limit uint64) {
|
||||
var m runtime.MemStats
|
||||
for {
|
||||
runtime.ReadMemStats(&m)
|
||||
if m.Alloc > limit {
|
||||
fmt.Printf("Memory spike detected (%d bytes). Dumping profile...\n", m.Alloc)
|
||||
|
||||
f, err := os.Create(fmt.Sprintf("/var/config/heap/heap-spike-%d.pprof", time.Now().Unix()))
|
||||
if err != nil {
|
||||
log.Println("could not create profile:", err)
|
||||
} else {
|
||||
pprof.WriteHeapProfile(f)
|
||||
f.Close()
|
||||
}
|
||||
|
||||
// Wait a while before checking again to avoid spamming profiles
|
||||
time.Sleep(5 * time.Minute)
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// setupPeerBandwidthLimit sets up TC (Traffic Control) to limit bandwidth for a specific peer IP
|
||||
// Currently hardcoded to 20 Mbps per peer
|
||||
func setupPeerBandwidthLimit(peerIP string) error {
|
||||
logger.Debug("setupPeerBandwidthLimit called for peer IP: %s", peerIP)
|
||||
const bandwidthLimit = "50mbit" // 50 Mbps limit per peer
|
||||
|
||||
// Parse the IP to get just the IP address (strip any CIDR notation if present)
|
||||
ip := peerIP
|
||||
if strings.Contains(peerIP, "/") {
|
||||
parsedIP, _, err := net.ParseCIDR(peerIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse peer IP: %v", err)
|
||||
}
|
||||
ip = parsedIP.String()
|
||||
}
|
||||
|
||||
// First, ensure we have a root qdisc on the interface (HTB - Hierarchical Token Bucket)
|
||||
// Check if qdisc already exists
|
||||
cmd := exec.Command("tc", "qdisc", "show", "dev", interfaceName)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check qdisc: %v, output: %s", err, string(output))
|
||||
}
|
||||
|
||||
// If no HTB qdisc exists, create one
|
||||
if !strings.Contains(string(output), "htb") {
|
||||
cmd = exec.Command("tc", "qdisc", "add", "dev", interfaceName, "root", "handle", "1:", "htb", "default", "9999")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to add root qdisc: %v, output: %s", err, string(output))
|
||||
}
|
||||
logger.Info("Created HTB root qdisc on %s", interfaceName)
|
||||
}
|
||||
|
||||
// Generate a unique class ID based on the IP address
|
||||
// We'll use the last octet of the IP as part of the class ID
|
||||
ipParts := strings.Split(ip, ".")
|
||||
if len(ipParts) != 4 {
|
||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
}
|
||||
lastOctet := ipParts[3]
|
||||
classID := fmt.Sprintf("1:%s", lastOctet)
|
||||
logger.Debug("Generated class ID %s for peer IP %s", classID, ip)
|
||||
|
||||
// Create a class for this peer with bandwidth limit
|
||||
cmd = exec.Command("tc", "class", "add", "dev", interfaceName, "parent", "1:", "classid", classID,
|
||||
"htb", "rate", bandwidthLimit, "ceil", bandwidthLimit)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
logger.Debug("tc class add failed for %s: %v, output: %s", ip, err, string(output))
|
||||
// If class already exists, try to replace it
|
||||
if strings.Contains(string(output), "File exists") {
|
||||
cmd = exec.Command("tc", "class", "replace", "dev", interfaceName, "parent", "1:", "classid", classID,
|
||||
"htb", "rate", bandwidthLimit, "ceil", bandwidthLimit)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to replace class: %v, output: %s", err, string(output))
|
||||
}
|
||||
logger.Debug("Successfully replaced existing class %s for peer IP %s", classID, ip)
|
||||
} else {
|
||||
return fmt.Errorf("failed to add class: %v, output: %s", err, string(output))
|
||||
}
|
||||
} else {
|
||||
logger.Debug("Successfully added new class %s for peer IP %s", classID, ip)
|
||||
}
|
||||
|
||||
// Add a filter to match traffic from this peer IP (ingress)
|
||||
cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, "protocol", "ip", "parent", "1:",
|
||||
"prio", "1", "u32", "match", "ip", "src", ip, "flowid", classID)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
// If filter fails, log but don't fail the peer addition
|
||||
logger.Warn("Failed to add ingress filter for peer IP %s: %v, output: %s", ip, err, string(output))
|
||||
}
|
||||
|
||||
// Add a filter to match traffic to this peer IP (egress)
|
||||
cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, "protocol", "ip", "parent", "1:",
|
||||
"prio", "1", "u32", "match", "ip", "dst", ip, "flowid", classID)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
// If filter fails, log but don't fail the peer addition
|
||||
logger.Warn("Failed to add egress filter for peer IP %s: %v, output: %s", ip, err, string(output))
|
||||
}
|
||||
|
||||
logger.Info("Setup bandwidth limit of %s for peer IP %s (class %s)", bandwidthLimit, ip, classID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removePeerBandwidthLimit removes TC rules for a specific peer IP
|
||||
func removePeerBandwidthLimit(peerIP string) error {
|
||||
// Parse the IP to get just the IP address
|
||||
ip := peerIP
|
||||
if strings.Contains(peerIP, "/") {
|
||||
parsedIP, _, err := net.ParseCIDR(peerIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse peer IP: %v", err)
|
||||
}
|
||||
ip = parsedIP.String()
|
||||
}
|
||||
|
||||
// Generate the class ID based on the IP
|
||||
ipParts := strings.Split(ip, ".")
|
||||
if len(ipParts) != 4 {
|
||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
}
|
||||
lastOctet := ipParts[3]
|
||||
classID := fmt.Sprintf("1:%s", lastOctet)
|
||||
|
||||
// Remove filters for this IP
|
||||
// List all filters to find the ones for this class
|
||||
cmd := exec.Command("tc", "filter", "show", "dev", interfaceName, "parent", "1:")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
logger.Warn("Failed to list filters for peer IP %s: %v, output: %s", ip, err, string(output))
|
||||
} else {
|
||||
// Parse the output to find filter handles that match this classID
|
||||
// The output format includes lines like:
|
||||
// filter parent 1: protocol ip pref 1 u32 chain 0 fh 800::800 order 2048 key ht 800 bkt 0 flowid 1:4
|
||||
lines := strings.Split(string(output), "\n")
|
||||
for _, line := range lines {
|
||||
// Look for lines containing our flowid (classID)
|
||||
if strings.Contains(line, "flowid "+classID) && strings.Contains(line, "fh ") {
|
||||
// Extract handle (format: fh 800::800)
|
||||
parts := strings.Fields(line)
|
||||
var handle string
|
||||
for j, part := range parts {
|
||||
if part == "fh" && j+1 < len(parts) {
|
||||
handle = parts[j+1]
|
||||
break
|
||||
}
|
||||
}
|
||||
if handle != "" {
|
||||
// Delete this filter using the handle
|
||||
delCmd := exec.Command("tc", "filter", "del", "dev", interfaceName, "parent", "1:", "handle", handle, "prio", "1", "u32")
|
||||
if delOutput, delErr := delCmd.CombinedOutput(); delErr != nil {
|
||||
logger.Debug("Failed to delete filter handle %s for peer IP %s: %v, output: %s", handle, ip, delErr, string(delOutput))
|
||||
} else {
|
||||
logger.Debug("Deleted filter handle %s for peer IP %s", handle, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the class
|
||||
cmd = exec.Command("tc", "class", "del", "dev", interfaceName, "classid", classID)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
// It's okay if the class doesn't exist
|
||||
if !strings.Contains(string(output), "No such file or directory") && !strings.Contains(string(output), "Cannot find") {
|
||||
logger.Warn("Failed to remove class for peer IP %s: %v, output: %s", ip, err, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Removed bandwidth limit for peer IP %s (class %s)", ip, classID)
|
||||
return nil
|
||||
}
|
||||
|
||||
320
proxy/proxy.go
320
proxy/proxy.go
@@ -11,6 +11,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -31,6 +32,16 @@ type RouteAPIResponse struct {
|
||||
Endpoints []string `json:"endpoints"`
|
||||
}
|
||||
|
||||
// ProxyProtocolInfo holds information parsed from incoming PROXY protocol header
|
||||
type ProxyProtocolInfo struct {
|
||||
Protocol string // TCP4 or TCP6
|
||||
SrcIP string
|
||||
DestIP string
|
||||
SrcPort int
|
||||
DestPort int
|
||||
OriginalConn net.Conn // The original connection after PROXY protocol parsing
|
||||
}
|
||||
|
||||
// SNIProxy represents the main proxy server
|
||||
type SNIProxy struct {
|
||||
port int
|
||||
@@ -55,6 +66,9 @@ type SNIProxy struct {
|
||||
// Track active tunnels by SNI
|
||||
activeTunnels map[string]*activeTunnel
|
||||
activeTunnelsLock sync.Mutex
|
||||
|
||||
// Trusted upstream proxies that can send PROXY protocol
|
||||
trustedUpstreams map[string]struct{}
|
||||
}
|
||||
|
||||
type activeTunnel struct {
|
||||
@@ -75,6 +89,194 @@ func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
// parseProxyProtocolHeader parses a PROXY protocol v1 header from the connection
|
||||
func (p *SNIProxy) parseProxyProtocolHeader(conn net.Conn) (*ProxyProtocolInfo, net.Conn, error) {
|
||||
// Check if the connection comes from a trusted upstream
|
||||
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to parse remote address: %w", err)
|
||||
}
|
||||
|
||||
// Resolve the remote IP to hostname to check if it's trusted
|
||||
// For simplicity, we'll check the IP directly in trusted upstreams
|
||||
// In production, you might want to do reverse DNS lookup
|
||||
if _, isTrusted := p.trustedUpstreams[remoteHost]; !isTrusted {
|
||||
// Not from trusted upstream, return original connection
|
||||
return nil, conn, nil
|
||||
}
|
||||
|
||||
// Set read timeout for PROXY protocol parsing
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to set read deadline: %w", err)
|
||||
}
|
||||
|
||||
// Read the first line (PROXY protocol header)
|
||||
buffer := make([]byte, 512) // PROXY protocol header should be much smaller
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
// If we can't read from trusted upstream, treat as regular connection
|
||||
logger.Debug("Could not read from trusted upstream %s, treating as regular connection: %v", remoteHost, err)
|
||||
// Clear read timeout before returning
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
||||
}
|
||||
return nil, conn, nil
|
||||
}
|
||||
|
||||
// Find the end of the first line (CRLF)
|
||||
headerEnd := bytes.Index(buffer[:n], []byte("\r\n"))
|
||||
if headerEnd == -1 {
|
||||
// No PROXY protocol header found, treat as regular TLS connection
|
||||
// Return the connection with the buffered data prepended
|
||||
logger.Debug("No PROXY protocol header from trusted upstream %s, treating as regular TLS connection", remoteHost)
|
||||
|
||||
// Clear read timeout
|
||||
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", err)
|
||||
}
|
||||
|
||||
// Create a reader that includes the buffered data + original connection
|
||||
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
||||
wrappedConn := &proxyProtocolConn{
|
||||
Conn: conn,
|
||||
reader: newReader,
|
||||
}
|
||||
return nil, wrappedConn, nil
|
||||
}
|
||||
|
||||
headerLine := string(buffer[:headerEnd])
|
||||
remainingData := buffer[headerEnd+2 : n]
|
||||
|
||||
// Parse PROXY protocol line: "PROXY TCP4/TCP6 srcIP destIP srcPort destPort"
|
||||
parts := strings.Fields(headerLine)
|
||||
if len(parts) != 6 || parts[0] != "PROXY" {
|
||||
// Check for PROXY UNKNOWN
|
||||
if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" {
|
||||
// PROXY UNKNOWN - use original connection info
|
||||
return nil, conn, nil
|
||||
}
|
||||
// Invalid PROXY protocol, but might be regular TLS - treat as such
|
||||
logger.Debug("Invalid PROXY protocol from trusted upstream %s, treating as regular TLS connection: %s", remoteHost, headerLine)
|
||||
|
||||
// Clear read timeout
|
||||
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", err)
|
||||
}
|
||||
|
||||
// Return the connection with all buffered data prepended
|
||||
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
||||
wrappedConn := &proxyProtocolConn{
|
||||
Conn: conn,
|
||||
reader: newReader,
|
||||
}
|
||||
return nil, wrappedConn, nil
|
||||
}
|
||||
|
||||
protocol := parts[1]
|
||||
srcIP := parts[2]
|
||||
destIP := parts[3]
|
||||
srcPort, err := strconv.Atoi(parts[4])
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4])
|
||||
}
|
||||
destPort, err := strconv.Atoi(parts[5])
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5])
|
||||
}
|
||||
|
||||
// Create a new reader that includes remaining data + original connection
|
||||
var newReader io.Reader
|
||||
if len(remainingData) > 0 {
|
||||
newReader = io.MultiReader(bytes.NewReader(remainingData), conn)
|
||||
} else {
|
||||
newReader = conn
|
||||
}
|
||||
|
||||
// Create a wrapper connection that reads from the combined reader
|
||||
wrappedConn := &proxyProtocolConn{
|
||||
Conn: conn,
|
||||
reader: newReader,
|
||||
}
|
||||
|
||||
proxyInfo := &ProxyProtocolInfo{
|
||||
Protocol: protocol,
|
||||
SrcIP: srcIP,
|
||||
DestIP: destIP,
|
||||
SrcPort: srcPort,
|
||||
DestPort: destPort,
|
||||
OriginalConn: wrappedConn,
|
||||
}
|
||||
|
||||
// Clear read timeout
|
||||
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to clear read deadline: %w", err)
|
||||
}
|
||||
|
||||
return proxyInfo, wrappedConn, nil
|
||||
}
|
||||
|
||||
// proxyProtocolConn wraps a connection to read from a custom reader
|
||||
type proxyProtocolConn struct {
|
||||
net.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (c *proxyProtocolConn) Read(b []byte) (int, error) {
|
||||
return c.reader.Read(b)
|
||||
}
|
||||
|
||||
// buildProxyProtocolHeaderFromInfo creates a PROXY protocol v1 header using ProxyProtocolInfo
|
||||
func (p *SNIProxy) buildProxyProtocolHeaderFromInfo(proxyInfo *ProxyProtocolInfo, targetAddr net.Addr) string {
|
||||
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
// Fallback for unknown address types
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
// Use the original client information from the PROXY protocol
|
||||
var targetIP string
|
||||
var protocol string
|
||||
|
||||
// Parse source IP to determine protocol family
|
||||
srcIP := net.ParseIP(proxyInfo.SrcIP)
|
||||
if srcIP == nil {
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
if srcIP.To4() != nil {
|
||||
// Source is IPv4, use TCP4 protocol
|
||||
protocol = "TCP4"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is also IPv4, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is IPv6, but we need IPv4 for consistent protocol family
|
||||
if targetTCP.IP.IsLoopback() {
|
||||
targetIP = "127.0.0.1"
|
||||
} else {
|
||||
targetIP = "127.0.0.1" // Safe fallback
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Source is IPv6, use TCP6 protocol
|
||||
protocol = "TCP6"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is IPv4, convert to IPv6 representation
|
||||
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is also IPv6, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||
protocol,
|
||||
proxyInfo.SrcIP,
|
||||
targetIP,
|
||||
proxyInfo.SrcPort,
|
||||
targetTCP.Port)
|
||||
}
|
||||
|
||||
// buildProxyProtocolHeader creates a PROXY protocol v1 header
|
||||
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
|
||||
clientTCP, ok := clientAddr.(*net.TCPAddr)
|
||||
@@ -131,7 +333,7 @@ func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
|
||||
}
|
||||
|
||||
// NewSNIProxy creates a new SNI proxy instance
|
||||
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool) (*SNIProxy, error) {
|
||||
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool, trustedUpstreams []string) (*SNIProxy, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create local overrides map
|
||||
@@ -142,19 +344,36 @@ func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, lo
|
||||
}
|
||||
}
|
||||
|
||||
// Create trusted upstreams map
|
||||
trustedMap := make(map[string]struct{})
|
||||
for _, upstream := range trustedUpstreams {
|
||||
if upstream != "" {
|
||||
// Add both the domain and potentially resolved IPs
|
||||
trustedMap[upstream] = struct{}{}
|
||||
|
||||
// Try to resolve the domain to IPs and add them too
|
||||
if ips, err := net.LookupIP(upstream); err == nil {
|
||||
for _, ip := range ips {
|
||||
trustedMap[ip.String()] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
proxy := &SNIProxy{
|
||||
port: port,
|
||||
cache: cache.New(3*time.Second, 10*time.Minute),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
localProxyAddr: localProxyAddr,
|
||||
localProxyPort: localProxyPort,
|
||||
remoteConfigURL: remoteConfigURL,
|
||||
publicKey: publicKey,
|
||||
proxyProtocol: proxyProtocol,
|
||||
localSNIs: make(map[string]struct{}),
|
||||
localOverrides: overridesMap,
|
||||
activeTunnels: make(map[string]*activeTunnel),
|
||||
port: port,
|
||||
cache: cache.New(3*time.Second, 10*time.Minute),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
localProxyAddr: localProxyAddr,
|
||||
localProxyPort: localProxyPort,
|
||||
remoteConfigURL: remoteConfigURL,
|
||||
publicKey: publicKey,
|
||||
proxyProtocol: proxyProtocol,
|
||||
localSNIs: make(map[string]struct{}),
|
||||
localOverrides: overridesMap,
|
||||
activeTunnels: make(map[string]*activeTunnel),
|
||||
trustedUpstreams: trustedMap,
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
@@ -270,14 +489,35 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
|
||||
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
|
||||
|
||||
// Check for PROXY protocol from trusted upstream
|
||||
var proxyInfo *ProxyProtocolInfo
|
||||
var actualClientConn net.Conn = clientConn
|
||||
|
||||
if len(p.trustedUpstreams) > 0 {
|
||||
var err error
|
||||
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to parse PROXY protocol: %v", err)
|
||||
return
|
||||
}
|
||||
if proxyInfo != nil {
|
||||
logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d",
|
||||
proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort)
|
||||
} else {
|
||||
// No PROXY protocol detected, but connection is from trusted upstream
|
||||
// This is fine - treat as regular connection
|
||||
logger.Debug("No PROXY protocol detected from trusted upstream, treating as regular connection")
|
||||
}
|
||||
}
|
||||
|
||||
// Set read timeout for SNI extraction
|
||||
if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
if err := actualClientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
logger.Debug("Failed to set read deadline: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract SNI hostname
|
||||
hostname, clientReader, err := p.extractSNI(clientConn)
|
||||
hostname, clientReader, err := p.extractSNI(actualClientConn)
|
||||
if err != nil {
|
||||
logger.Debug("SNI extraction failed: %v", err)
|
||||
return
|
||||
@@ -291,13 +531,20 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
logger.Debug("SNI hostname detected: %s", hostname)
|
||||
|
||||
// Remove read timeout for normal operation
|
||||
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
if err := actualClientConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get routing information
|
||||
route, err := p.getRoute(hostname, clientConn.RemoteAddr().String())
|
||||
// Get routing information - use original client address if available from PROXY protocol
|
||||
var clientAddrStr string
|
||||
if proxyInfo != nil {
|
||||
clientAddrStr = fmt.Sprintf("%s:%d", proxyInfo.SrcIP, proxyInfo.SrcPort)
|
||||
} else {
|
||||
clientAddrStr = clientConn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
route, err := p.getRoute(hostname, clientAddrStr)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to get route for %s: %v", hostname, err)
|
||||
return
|
||||
@@ -325,7 +572,14 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
|
||||
// Send PROXY protocol header if enabled
|
||||
if p.proxyProtocol {
|
||||
proxyHeader := buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
|
||||
var proxyHeader string
|
||||
if proxyInfo != nil {
|
||||
// Use original client info from PROXY protocol
|
||||
proxyHeader = p.buildProxyProtocolHeaderFromInfo(proxyInfo, targetConn.LocalAddr())
|
||||
} else {
|
||||
// Use direct client connection info
|
||||
proxyHeader = buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
|
||||
}
|
||||
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
|
||||
|
||||
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
|
||||
@@ -341,7 +595,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
tunnel = &activeTunnel{}
|
||||
p.activeTunnels[hostname] = tunnel
|
||||
}
|
||||
tunnel.conns = append(tunnel.conns, clientConn)
|
||||
tunnel.conns = append(tunnel.conns, actualClientConn)
|
||||
p.activeTunnelsLock.Unlock()
|
||||
|
||||
defer func() {
|
||||
@@ -350,7 +604,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
if tunnel, ok := p.activeTunnels[hostname]; ok {
|
||||
newConns := make([]net.Conn, 0, len(tunnel.conns))
|
||||
for _, c := range tunnel.conns {
|
||||
if c != clientConn {
|
||||
if c != actualClientConn {
|
||||
newConns = append(newConns, c)
|
||||
}
|
||||
}
|
||||
@@ -364,7 +618,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
}()
|
||||
|
||||
// Start bidirectional data transfer
|
||||
p.pipe(clientConn, targetConn, clientReader)
|
||||
p.pipe(actualClientConn, targetConn, clientReader)
|
||||
}
|
||||
|
||||
// getRoute retrieves routing information for a hostname
|
||||
@@ -504,14 +758,20 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// closeOnce ensures we only close connections once
|
||||
var closeOnce sync.Once
|
||||
closeConns := func() {
|
||||
closeOnce.Do(func() {
|
||||
// Close both connections to unblock any pending reads
|
||||
clientConn.Close()
|
||||
targetConn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
// Copy data from client to target (using the buffered reader)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
|
||||
tcpConn.CloseWrite()
|
||||
}
|
||||
}()
|
||||
defer closeConns()
|
||||
|
||||
// Use a large buffer for better performance
|
||||
buf := make([]byte, 32*1024)
|
||||
@@ -524,11 +784,7 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader)
|
||||
// Copy data from target to client
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
|
||||
tcpConn.CloseWrite()
|
||||
}
|
||||
}()
|
||||
defer closeConns()
|
||||
|
||||
// Use a large buffer for better performance
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
@@ -76,3 +76,44 @@ func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
|
||||
t.Errorf("Expected %q, got %q", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
|
||||
proxy, err := NewSNIProxy(8443, "", "", "127.0.0.1", 443, nil, true, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SNI proxy: %v", err)
|
||||
}
|
||||
|
||||
// Test IPv4 case
|
||||
proxyInfo := &ProxyProtocolInfo{
|
||||
Protocol: "TCP4",
|
||||
SrcIP: "10.0.0.1",
|
||||
DestIP: "192.168.1.100",
|
||||
SrcPort: 12345,
|
||||
DestPort: 443,
|
||||
}
|
||||
|
||||
targetAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
|
||||
header := proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
|
||||
|
||||
expected := "PROXY TCP4 10.0.0.1 127.0.0.1 12345 8080\r\n"
|
||||
if header != expected {
|
||||
t.Errorf("Expected header '%s', got '%s'", expected, header)
|
||||
}
|
||||
|
||||
// Test IPv6 case
|
||||
proxyInfo = &ProxyProtocolInfo{
|
||||
Protocol: "TCP6",
|
||||
SrcIP: "2001:db8::1",
|
||||
DestIP: "2001:db8::2",
|
||||
SrcPort: 12345,
|
||||
DestPort: 443,
|
||||
}
|
||||
|
||||
targetAddr, _ = net.ResolveTCPAddr("tcp6", "[::1]:8080")
|
||||
header = proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
|
||||
|
||||
expected = "PROXY TCP6 2001:db8::1 ::1 12345 8080\r\n"
|
||||
if header != expected {
|
||||
t.Errorf("Expected header '%s', got '%s'", expected, header)
|
||||
}
|
||||
}
|
||||
|
||||
486
relay/relay.go
486
relay/relay.go
@@ -2,12 +2,14 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -24,20 +26,22 @@ type EncryptedHolePunchMessage struct {
|
||||
}
|
||||
|
||||
type HolePunchMessage struct {
|
||||
OlmID string `json:"olmId"`
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
OlmID string `json:"olmId"`
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
type ClientEndpoint struct {
|
||||
OlmID string `json:"olmId"`
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
ReachableAt string `json:"reachableAt"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
OlmID string `json:"olmId"`
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
ReachableAt string `json:"reachableAt"`
|
||||
ExitNodePublicKey string `json:"exitNodePublicKey"`
|
||||
ClientPublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
// Updated to support multiple destination peers
|
||||
@@ -58,12 +62,52 @@ type DestinationConn struct {
|
||||
|
||||
// Type for storing WireGuard handshake information
|
||||
type WireGuardSession struct {
|
||||
mu sync.RWMutex
|
||||
ReceiverIndex uint32
|
||||
SenderIndex uint32
|
||||
DestAddr *net.UDPAddr
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
// GetSenderIndex returns the SenderIndex in a thread-safe manner
|
||||
func (s *WireGuardSession) GetSenderIndex() uint32 {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.SenderIndex
|
||||
}
|
||||
|
||||
// GetDestAddr returns the DestAddr in a thread-safe manner
|
||||
func (s *WireGuardSession) GetDestAddr() *net.UDPAddr {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.DestAddr
|
||||
}
|
||||
|
||||
// GetLastSeen returns the LastSeen timestamp in a thread-safe manner
|
||||
func (s *WireGuardSession) GetLastSeen() time.Time {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.LastSeen
|
||||
}
|
||||
|
||||
// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner
|
||||
func (s *WireGuardSession) UpdateLastSeen() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.LastSeen = time.Now()
|
||||
}
|
||||
|
||||
// Type for tracking bidirectional communication patterns to rebuild sessions
|
||||
type CommunicationPattern struct {
|
||||
FromClient *net.UDPAddr // The client address
|
||||
ToDestination *net.UDPAddr // The destination address
|
||||
ClientIndex uint32 // The receiver index seen from client
|
||||
DestIndex uint32 // The receiver index seen from destination
|
||||
LastFromClient time.Time // Last packet from client to destination
|
||||
LastFromDest time.Time // Last packet from destination to client
|
||||
PacketCount int // Number of packets observed
|
||||
}
|
||||
|
||||
type InitialMappings struct {
|
||||
Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
|
||||
}
|
||||
@@ -75,6 +119,13 @@ type Packet struct {
|
||||
n int
|
||||
}
|
||||
|
||||
// holePunchRateLimitEntry tracks hole punch message counts within a sliding 1-second window.
|
||||
type holePunchRateLimitEntry struct {
|
||||
mu sync.Mutex
|
||||
count int
|
||||
windowStart time.Time
|
||||
}
|
||||
|
||||
// WireGuard message types
|
||||
const (
|
||||
WireGuardMessageTypeHandshakeInitiation = 1
|
||||
@@ -101,22 +152,35 @@ type UDPProxyServer struct {
|
||||
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
|
||||
privateKey wgtypes.Key
|
||||
packetChan chan Packet
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Session tracking for WireGuard peers
|
||||
// Key format: "senderIndex:receiverIndex"
|
||||
wgSessions sync.Map
|
||||
// Communication pattern tracking for rebuilding sessions
|
||||
// Key format: "clientIP:clientPort-destIP:destPort"
|
||||
commPatterns sync.Map
|
||||
// Rate limiter for encrypted hole punch messages, keyed by "ip:port"
|
||||
holePunchRateLimiter sync.Map
|
||||
// Cache for resolved UDP addresses to avoid per-packet DNS lookups
|
||||
// Key: "ip:port" string, Value: *net.UDPAddr
|
||||
addrCache sync.Map
|
||||
// ReachableAt is the URL where this server can be reached
|
||||
ReachableAt string
|
||||
}
|
||||
|
||||
// NewUDPProxyServer initializes the server with a buffered packet channel.
|
||||
func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
||||
// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
|
||||
func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
return &UDPProxyServer{
|
||||
addr: addr,
|
||||
serverURL: serverURL,
|
||||
privateKey: privateKey,
|
||||
packetChan: make(chan Packet, 1000),
|
||||
packetChan: make(chan Packet, 50000), // Increased from 1000 to handle high throughput
|
||||
ReachableAt: reachableAt,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,8 +202,13 @@ func (s *UDPProxyServer) Start() error {
|
||||
s.conn = conn
|
||||
logger.Info("UDP server listening on %s", s.addr)
|
||||
|
||||
// Start a fixed number of worker goroutines.
|
||||
workerCount := 10 // TODO: Make this configurable or pick it better!
|
||||
// Start worker goroutines based on CPU cores for better parallelism
|
||||
// At high throughput (160+ Mbps), we need many workers to avoid bottlenecks
|
||||
workerCount := runtime.NumCPU() * 10
|
||||
if workerCount < 20 {
|
||||
workerCount = 20 // Minimum 20 workers
|
||||
}
|
||||
logger.Info("Starting %d packet workers (CPUs: %d)", workerCount, runtime.NumCPU())
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go s.packetWorker()
|
||||
}
|
||||
@@ -156,21 +225,61 @@ func (s *UDPProxyServer) Start() error {
|
||||
// Start the proxy mapping cleanup routine
|
||||
go s.cleanupIdleProxyMappings()
|
||||
|
||||
// Start the communication pattern cleanup routine
|
||||
go s.cleanupIdleCommunicationPatterns()
|
||||
|
||||
// Start the hole punch rate limiter cleanup routine
|
||||
go s.cleanupHolePunchRateLimiter()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UDPProxyServer) Stop() {
|
||||
s.conn.Close()
|
||||
// Signal all background goroutines to stop
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
// Close listener to unblock reads
|
||||
if s.conn != nil {
|
||||
_ = s.conn.Close()
|
||||
}
|
||||
// Close all downstream UDP connections
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
if dc, ok := value.(*DestinationConn); ok && dc.conn != nil {
|
||||
_ = dc.conn.Close()
|
||||
}
|
||||
return true
|
||||
})
|
||||
// Close packet channel to stop workers
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
default:
|
||||
}
|
||||
close(s.packetChan)
|
||||
}
|
||||
|
||||
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
|
||||
func (s *UDPProxyServer) readPackets() {
|
||||
for {
|
||||
// Exit promptly if context is canceled
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
buf := bufferPool.Get().([]byte)
|
||||
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
logger.Error("Error reading UDP packet: %v", err)
|
||||
continue
|
||||
// If we're shutting down, exit
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
bufferPool.Put(buf[:1500])
|
||||
return
|
||||
default:
|
||||
logger.Error("Error reading UDP packet: %v", err)
|
||||
bufferPool.Put(buf[:1500])
|
||||
continue
|
||||
}
|
||||
}
|
||||
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
|
||||
}
|
||||
@@ -184,6 +293,27 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
// Process as a WireGuard packet.
|
||||
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
|
||||
} else {
|
||||
// Rate limit: allow at most 2 hole punch messages per IP:Port per second
|
||||
rateLimitKey := packet.remoteAddr.String()
|
||||
entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{
|
||||
windowStart: time.Now(),
|
||||
})
|
||||
rlEntry := entryVal.(*holePunchRateLimitEntry)
|
||||
rlEntry.mu.Lock()
|
||||
now := time.Now()
|
||||
if now.Sub(rlEntry.windowStart) >= time.Second {
|
||||
rlEntry.count = 0
|
||||
rlEntry.windowStart = now
|
||||
}
|
||||
rlEntry.count++
|
||||
allowed := rlEntry.count <= 2
|
||||
rlEntry.mu.Unlock()
|
||||
if !allowed {
|
||||
// logger.Debug("Rate limiting hole punch message from %s", rateLimitKey)
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
}
|
||||
|
||||
// Process as an encrypted hole punch message
|
||||
var encMsg EncryptedHolePunchMessage
|
||||
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
|
||||
@@ -203,7 +333,7 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
// This appears to be an encrypted message
|
||||
decryptedData, err := s.decryptMessage(encMsg)
|
||||
if err != nil {
|
||||
logger.Error("Failed to decrypt message: %v", err)
|
||||
// logger.Error("Failed to decrypt message: %v", err)
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
@@ -219,14 +349,15 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
}
|
||||
|
||||
endpoint := ClientEndpoint{
|
||||
NewtID: msg.NewtID,
|
||||
OlmID: msg.OlmID,
|
||||
Token: msg.Token,
|
||||
IP: packet.remoteAddr.IP.String(),
|
||||
Port: packet.remoteAddr.Port,
|
||||
Timestamp: time.Now().Unix(),
|
||||
ReachableAt: s.ReachableAt,
|
||||
PublicKey: s.privateKey.PublicKey().String(),
|
||||
NewtID: msg.NewtID,
|
||||
OlmID: msg.OlmID,
|
||||
Token: msg.Token,
|
||||
IP: packet.remoteAddr.IP.String(),
|
||||
Port: packet.remoteAddr.Port,
|
||||
Timestamp: time.Now().Unix(),
|
||||
ReachableAt: s.ReachableAt,
|
||||
ExitNodePublicKey: s.privateKey.PublicKey().String(),
|
||||
ClientPublicKey: msg.PublicKey,
|
||||
}
|
||||
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port)
|
||||
s.notifyServer(endpoint)
|
||||
@@ -327,6 +458,43 @@ func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// cachedAddr holds a resolved UDP address with TTL
|
||||
type cachedAddr struct {
|
||||
addr *net.UDPAddr
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// addrCacheTTL is how long resolved addresses are cached before re-resolving
|
||||
const addrCacheTTL = 5 * time.Minute
|
||||
|
||||
// getCachedAddr returns a cached UDP address or resolves and caches it.
|
||||
// This avoids per-packet DNS lookups which are a major throughput bottleneck.
|
||||
func (s *UDPProxyServer) getCachedAddr(ip string, port int) (*net.UDPAddr, error) {
|
||||
key := fmt.Sprintf("%s:%d", ip, port)
|
||||
|
||||
// Check cache first
|
||||
if cached, ok := s.addrCache.Load(key); ok {
|
||||
entry := cached.(*cachedAddr)
|
||||
if time.Now().Before(entry.expiresAt) {
|
||||
return entry.addr, nil
|
||||
}
|
||||
// Cache expired, delete and re-resolve
|
||||
s.addrCache.Delete(key)
|
||||
}
|
||||
|
||||
// Resolve and cache
|
||||
addr, err := net.ResolveUDPAddr("udp", key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.addrCache.Store(key, &cachedAddr{
|
||||
addr: addr,
|
||||
expiresAt: time.Now().Add(addrCacheTTL),
|
||||
})
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
// Updated to handle multi-peer WireGuard communication
|
||||
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
|
||||
if len(packet) == 0 {
|
||||
@@ -361,7 +529,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations)
|
||||
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
@@ -375,7 +543,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Error("Failed to forward handshake initiation: %v", err)
|
||||
logger.Debug("Failed to forward handshake initiation: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -397,7 +565,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
|
||||
// Forward the response to the original sender
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
@@ -425,13 +593,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
// First check for existing sessions to see if we know where to send this packet
|
||||
s.wgSessions.Range(func(k, v interface{}) bool {
|
||||
session := v.(*WireGuardSession)
|
||||
if session.SenderIndex == receiverIndex {
|
||||
// Found matching session
|
||||
destAddr = session.DestAddr
|
||||
|
||||
// Update last seen time
|
||||
session.LastSeen = time.Now()
|
||||
s.wgSessions.Store(k, session)
|
||||
// Check if session matches (read lock for check)
|
||||
if session.GetSenderIndex() == receiverIndex {
|
||||
// Found matching session - get dest addr and update last seen
|
||||
destAddr = session.GetDestAddr()
|
||||
session.UpdateLastSeen()
|
||||
return false // stop iteration
|
||||
}
|
||||
return true // continue iteration
|
||||
@@ -445,6 +611,9 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
return
|
||||
}
|
||||
|
||||
// Track communication pattern for session rebuilding
|
||||
s.trackCommunicationPattern(remoteAddr, destAddr, receiverIndex, true)
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to forward transport data: %v", err)
|
||||
@@ -453,7 +622,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
// No known session, fall back to forwarding to all peers
|
||||
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
@@ -465,6 +634,9 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
continue
|
||||
}
|
||||
|
||||
// Track communication pattern for session rebuilding
|
||||
s.trackCommunicationPattern(remoteAddr, destAddr, receiverIndex, true)
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to forward transport data: %v", err)
|
||||
@@ -478,7 +650,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
|
||||
// Forward to all peers
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
@@ -548,6 +720,9 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String())
|
||||
} else if ok && buffer[0] == WireGuardMessageTypeTransportData {
|
||||
// Track communication pattern for session rebuilding (reverse direction)
|
||||
s.trackCommunicationPattern(destAddr, remoteAddr, receiverIndex, false)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -562,49 +737,69 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
||||
// Add a cleanup method to periodically remove idle connections
|
||||
func (s *UDPProxyServer) cleanupIdleConnections() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
destConn := value.(*DestinationConn)
|
||||
if now.Sub(destConn.lastUsed) > 10*time.Minute {
|
||||
destConn.conn.Close()
|
||||
s.connections.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
destConn := value.(*DestinationConn)
|
||||
if now.Sub(destConn.lastUsed) > 10*time.Minute {
|
||||
destConn.conn.Close()
|
||||
s.connections.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// New method to periodically remove idle sessions
|
||||
func (s *UDPProxyServer) cleanupIdleSessions() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
session := value.(*WireGuardSession)
|
||||
if now.Sub(session.LastSeen) > 15*time.Minute {
|
||||
s.wgSessions.Delete(key)
|
||||
logger.Debug("Removed idle session: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
session := value.(*WireGuardSession)
|
||||
// Use thread-safe method to read LastSeen
|
||||
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
|
||||
s.wgSessions.Delete(key)
|
||||
logger.Debug("Removed idle session: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// New method to periodically remove idle proxy mappings
|
||||
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
mapping := value.(ProxyMapping)
|
||||
// Remove mappings that haven't been used in 30 minutes
|
||||
if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
||||
s.proxyMappings.Delete(key)
|
||||
logger.Debug("Removed idle proxy mapping: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
mapping := value.(ProxyMapping)
|
||||
// Remove mappings that haven't been used in 30 minutes
|
||||
if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
||||
s.proxyMappings.Delete(key)
|
||||
logger.Debug("Removed idle proxy mapping: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -709,8 +904,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
|
||||
keyStr := key.(string)
|
||||
session := value.(*WireGuardSession)
|
||||
|
||||
// Check if the session's destination address contains the WG IP
|
||||
if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
|
||||
// Check if the session's destination address contains the WG IP (thread-safe)
|
||||
destAddr := session.GetDestAddr()
|
||||
if destAddr != nil && destAddr.IP.String() == ip {
|
||||
keysToDelete = append(keysToDelete, keyStr)
|
||||
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
|
||||
}
|
||||
@@ -722,7 +918,7 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
|
||||
s.wgSessions.Delete(key)
|
||||
}
|
||||
|
||||
logger.Info("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip)
|
||||
logger.Debug("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip)
|
||||
}
|
||||
|
||||
// // clearProxyMappingsForWGIP removes all proxy mappings that have destinations pointing to a specific WireGuard IP
|
||||
@@ -823,3 +1019,145 @@ func (s *UDPProxyServer) UpdateDestinationInMappings(oldDest, newDest PeerDestin
|
||||
|
||||
return updatedCount
|
||||
}
|
||||
|
||||
// trackCommunicationPattern tracks bidirectional communication patterns to rebuild sessions
|
||||
func (s *UDPProxyServer) trackCommunicationPattern(fromAddr, toAddr *net.UDPAddr, receiverIndex uint32, fromClient bool) {
|
||||
var clientAddr, destAddr *net.UDPAddr
|
||||
var clientIndex, destIndex uint32
|
||||
|
||||
if fromClient {
|
||||
clientAddr = fromAddr
|
||||
destAddr = toAddr
|
||||
clientIndex = receiverIndex
|
||||
destIndex = 0 // We don't know the destination index yet
|
||||
} else {
|
||||
clientAddr = toAddr
|
||||
destAddr = fromAddr
|
||||
clientIndex = 0 // We don't know the client index yet
|
||||
destIndex = receiverIndex
|
||||
}
|
||||
|
||||
patternKey := fmt.Sprintf("%s-%s", clientAddr.String(), destAddr.String())
|
||||
now := time.Now()
|
||||
|
||||
if existingPattern, ok := s.commPatterns.Load(patternKey); ok {
|
||||
pattern := existingPattern.(*CommunicationPattern)
|
||||
|
||||
// Update the pattern
|
||||
if fromClient {
|
||||
pattern.LastFromClient = now
|
||||
if pattern.ClientIndex == 0 {
|
||||
pattern.ClientIndex = clientIndex
|
||||
}
|
||||
} else {
|
||||
pattern.LastFromDest = now
|
||||
if pattern.DestIndex == 0 {
|
||||
pattern.DestIndex = destIndex
|
||||
}
|
||||
}
|
||||
|
||||
pattern.PacketCount++
|
||||
s.commPatterns.Store(patternKey, pattern)
|
||||
|
||||
// Check if we have bidirectional communication and can rebuild a session
|
||||
s.tryRebuildSession(pattern)
|
||||
} else {
|
||||
// Create new pattern
|
||||
pattern := &CommunicationPattern{
|
||||
FromClient: clientAddr,
|
||||
ToDestination: destAddr,
|
||||
ClientIndex: clientIndex,
|
||||
DestIndex: destIndex,
|
||||
PacketCount: 1,
|
||||
}
|
||||
|
||||
if fromClient {
|
||||
pattern.LastFromClient = now
|
||||
} else {
|
||||
pattern.LastFromDest = now
|
||||
}
|
||||
|
||||
s.commPatterns.Store(patternKey, pattern)
|
||||
}
|
||||
}
|
||||
|
||||
// tryRebuildSession attempts to rebuild a WireGuard session from communication patterns
|
||||
func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
|
||||
// Check if we have bidirectional communication within a reasonable time window
|
||||
timeDiff := pattern.LastFromClient.Sub(pattern.LastFromDest)
|
||||
if timeDiff < 0 {
|
||||
timeDiff = -timeDiff
|
||||
}
|
||||
|
||||
// Only rebuild if we have recent bidirectional communication and both indices
|
||||
if timeDiff < 30*time.Second && pattern.ClientIndex != 0 && pattern.DestIndex != 0 && pattern.PacketCount >= 4 {
|
||||
// Create session mapping: client's index maps to destination
|
||||
sessionKey := fmt.Sprintf("%d:%d", pattern.DestIndex, pattern.ClientIndex)
|
||||
|
||||
// Check if we already have this session
|
||||
if _, exists := s.wgSessions.Load(sessionKey); !exists {
|
||||
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
||||
ReceiverIndex: pattern.DestIndex,
|
||||
SenderIndex: pattern.ClientIndex,
|
||||
DestAddr: pattern.ToDestination,
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
|
||||
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
|
||||
// cleanupHolePunchRateLimiter periodically evicts stale rate limit entries to prevent unbounded growth.
|
||||
func (s *UDPProxyServer) cleanupHolePunchRateLimiter() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.holePunchRateLimiter.Range(func(key, value interface{}) bool {
|
||||
rlEntry := value.(*holePunchRateLimitEntry)
|
||||
rlEntry.mu.Lock()
|
||||
stale := now.Sub(rlEntry.windowStart) > 10*time.Second
|
||||
rlEntry.mu.Unlock()
|
||||
if stale {
|
||||
s.holePunchRateLimiter.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.commPatterns.Range(func(key, value interface{}) bool {
|
||||
pattern := value.(*CommunicationPattern)
|
||||
|
||||
// Get the most recent activity
|
||||
lastActivity := pattern.LastFromClient
|
||||
if pattern.LastFromDest.After(lastActivity) {
|
||||
lastActivity = pattern.LastFromDest
|
||||
}
|
||||
|
||||
// Remove patterns that haven't had activity in 20 minutes
|
||||
if now.Sub(lastActivity) > 20*time.Minute {
|
||||
s.commPatterns.Delete(key)
|
||||
logger.Debug("Removed idle communication pattern: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user