mirror of
https://github.com/fosrl/newt.git
synced 2026-03-16 11:50:30 -05:00
Compare commits
82 Commits
1.0.0-beta
...
hp-multi-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
494e30704b | ||
|
|
175718a48e | ||
|
|
6a146ed371 | ||
|
|
027d9a059f | ||
|
|
0ced66e157 | ||
|
|
6b0ca9cab5 | ||
|
|
e47ddaa916 | ||
|
|
65dc81ca8b | ||
|
|
09d6829f8b | ||
|
|
f677376fae | ||
|
|
72e0adc1bf | ||
|
|
b3e8bf7d12 | ||
|
|
c5978d9c4e | ||
|
|
7f9a31ac3e | ||
|
|
f08378b67e | ||
|
|
9d80161ab7 | ||
|
|
1501de691a | ||
|
|
2a19856556 | ||
|
|
f4e17a4dd7 | ||
|
|
ab544fc9ed | ||
|
|
f9e52c4d91 | ||
|
|
14eff8e16c | ||
|
|
067e079293 | ||
|
|
623be5ea0d | ||
|
|
72d264d427 | ||
|
|
a19fc8c588 | ||
|
|
dbc2a92456 | ||
|
|
437d8b67a4 | ||
|
|
6f1d4752f0 | ||
|
|
683312c78e | ||
|
|
29543aece3 | ||
|
|
e68a38e929 | ||
|
|
bc72c96b5e | ||
|
|
3d15ecb732 | ||
|
|
a69618310b | ||
|
|
ed8a2ccd23 | ||
|
|
5e673c829b | ||
|
|
cd3ec0b259 | ||
|
|
b68502de9e | ||
|
|
f6429b6eee | ||
|
|
8795c57b2e | ||
|
|
e8141a177b | ||
|
|
4aa718d55f | ||
|
|
afa93d8a3f | ||
|
|
270ee9cd19 | ||
|
|
0affef401c | ||
|
|
18d99de924 | ||
|
|
bff6707577 | ||
|
|
95eab504fa | ||
|
|
56e75902e3 | ||
|
|
45a1ab91d7 | ||
|
|
fb199cc94b | ||
|
|
66edae4288 | ||
|
|
f69a7f647d | ||
|
|
e8bd55bed9 | ||
|
|
b23eda9c06 | ||
|
|
92bc883b5b | ||
|
|
76503f3f2c | ||
|
|
9c3112f9bd | ||
|
|
462af30d16 | ||
|
|
fa6038eb38 | ||
|
|
f346b6cc5d | ||
|
|
f20b9ebb14 | ||
|
|
39bfe5b230 | ||
|
|
a1a3dd9ba2 | ||
|
|
7b1492f327 | ||
|
|
4e50819785 | ||
|
|
f8dccbec80 | ||
|
|
0c5c59cf00 | ||
|
|
868bb55f87 | ||
|
|
5b4245402a | ||
|
|
f7a705e6f8 | ||
|
|
3a63657822 | ||
|
|
759780508a | ||
|
|
533886f2e4 | ||
|
|
79f8745909 | ||
|
|
7b663027ac | ||
|
|
e90e55d982 | ||
|
|
a46fb23cdd | ||
|
|
10982b47a5 | ||
|
|
ab12098c9c | ||
|
|
446eb4d6f1 |
61
.github/workflows/cicd.yml
vendored
Normal file
61
.github/workflows/cicd.yml
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
name: CI/CD Pipeline
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: Build and Release
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||
|
||||
- name: Extract tag name
|
||||
id: get-tag
|
||||
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: 1.23.1
|
||||
|
||||
- name: Update version in main.go
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
if [ -f main.go ]; then
|
||||
sed -i 's/Newt version replaceme/Newt version '"$TAG"'/' main.go
|
||||
echo "Updated main.go with version $TAG"
|
||||
else
|
||||
echo "main.go not found"
|
||||
fi
|
||||
|
||||
- name: Build and push Docker images
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
make docker-build-release tag=$TAG
|
||||
|
||||
- name: Build binaries
|
||||
run: |
|
||||
make go-build-release
|
||||
|
||||
- name: Upload artifacts from /bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: binaries
|
||||
path: bin/
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1 +1,4 @@
|
||||
newt
|
||||
newt
|
||||
.DS_Store
|
||||
bin/
|
||||
nohup.out
|
||||
1
.go-version
Normal file
1
.go-version
Normal file
@@ -0,0 +1 @@
|
||||
1.23.2
|
||||
10
Dockerfile
10
Dockerfile
@@ -15,19 +15,13 @@ COPY . .
|
||||
# Build the application
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /newt
|
||||
|
||||
# Start a new stage from scratch
|
||||
FROM ubuntu:22.04 AS runner
|
||||
FROM alpine:3.19 AS runner
|
||||
|
||||
RUN apt-get update && apt-get install ca-certificates -y && rm -rf /var/lib/apt/lists/*
|
||||
RUN apk --no-cache add ca-certificates
|
||||
|
||||
# Copy the pre-built binary file from the previous stage and the entrypoint script
|
||||
COPY --from=builder /newt /usr/local/bin/
|
||||
COPY entrypoint.sh /
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
# Copy the entrypoint script
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
|
||||
# Command to run the executable
|
||||
CMD ["newt"]
|
||||
22
Makefile
22
Makefile
@@ -1,6 +1,14 @@
|
||||
|
||||
all: build push
|
||||
|
||||
docker-build-release:
|
||||
@if [ -z "$(tag)" ]; then \
|
||||
echo "Error: tag is required. Usage: make build-all tag=<tag>"; \
|
||||
exit 1; \
|
||||
fi
|
||||
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/newt:latest -f Dockerfile --push .
|
||||
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/newt:$(tag) -f Dockerfile --push .
|
||||
|
||||
build:
|
||||
docker build -t fosrl/newt:latest .
|
||||
|
||||
@@ -11,7 +19,19 @@ test:
|
||||
docker run fosrl/newt:latest
|
||||
|
||||
local:
|
||||
CGO_ENABLED=0 go build -o newt
|
||||
CGO_ENABLED=0 go build -o newt
|
||||
|
||||
go-build-release:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/newt_linux_arm64
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/newt_linux_arm32
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -o bin/newt_linux_arm32v6
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/newt_linux_amd64
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -o bin/newt_linux_riscv64
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/newt_darwin_arm64
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/newt_darwin_amd64
|
||||
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/newt_windows_amd64.exe
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -o bin/newt_freebsd_amd64
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -o bin/newt_freebsd_arm64
|
||||
|
||||
clean:
|
||||
rm newt
|
||||
|
||||
41
README.md
41
README.md
@@ -19,7 +19,7 @@ _Sample output of a Newt container connected to Pangolin and hosting various res
|
||||
|
||||
### Registers with Pangolin
|
||||
|
||||
Using the Newt ID and a secret the client will make HTTP requests to Pangolin to receive a session token. Using that token it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket.
|
||||
Using the Newt ID and a secret, the client will make HTTP requests to Pangolin to receive a session token. Using that token, it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket.
|
||||
|
||||
### Receives WireGuard Control Messages
|
||||
|
||||
@@ -27,7 +27,7 @@ When Newt receives WireGuard control messages, it will use the information encod
|
||||
|
||||
### Receives Proxy Control Messages
|
||||
|
||||
When Newt receives WireGuard control messages, it will use the information encoded to crate local low level TCP and UDP proxies attached to the virtual tunnel in order to relay traffic to programmed targets.
|
||||
When Newt receives WireGuard control messages, it will use the information encoded to create a local low level TCP and UDP proxies attached to the virtual tunnel in order to relay traffic to programmed targets.
|
||||
|
||||
## CLI Args
|
||||
|
||||
@@ -36,7 +36,8 @@ When Newt receives WireGuard control messages, it will use the information encod
|
||||
- `secret`: A unique secret (not shared and kept private) used to authenticate the client ID with the websocket in order to receive commands.
|
||||
- `dns`: DNS server to use to resolve the endpoint
|
||||
- `log-level` (optional): The log level to use. Default: INFO
|
||||
|
||||
- `updown` (optional): A script to be called when targets are added or removed.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
@@ -74,6 +75,38 @@ services:
|
||||
- --endpoint https://example.com
|
||||
```
|
||||
|
||||
Finally a basic systemd service:
|
||||
|
||||
```
|
||||
[Unit]
|
||||
Description=Newt VPN Client
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
ExecStart=/usr/local/bin/newt --id 31frd0uzbjvp721 --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 --endpoint https://example.com
|
||||
Restart=always
|
||||
User=root
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
Make sure to `mv ./newt /usr/local/bin/newt`!
|
||||
|
||||
### Updown
|
||||
|
||||
You can pass in a updown script for Newt to call when it is adding or removing a target:
|
||||
|
||||
`--updown "python3 test.py"`
|
||||
|
||||
It will get called with args when a target is added:
|
||||
`python3 test.py add tcp localhost:8556`
|
||||
`python3 test.py remove tcp localhost:8556`
|
||||
|
||||
Returning a string from the script in the format of a target (`ip:dst` so `10.0.0.1:8080`) it will override the target and use this value instead to proxy.
|
||||
|
||||
You can look at updown.py as a reference script to get started!
|
||||
|
||||
## Build
|
||||
|
||||
### Container
|
||||
@@ -98,4 +131,4 @@ Newt is dual licensed under the AGPLv3 and the Fossorial Commercial license. For
|
||||
|
||||
## Contributions
|
||||
|
||||
Please see [CONTRIBUTIONS](./CONTRIBUTING.md) in the repository for guidelines and best practices.
|
||||
Please see [CONTRIBUTIONS](./CONTRIBUTING.md) in the repository for guidelines and best practices.
|
||||
|
||||
14
SECURITY.md
Normal file
14
SECURITY.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Security Policy
|
||||
|
||||
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
||||
|
||||
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
||||
2. Send a detailed report to [security@fossorial.io](mailto:security@fossorial.io) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
|
||||
- Description and location of the vulnerability.
|
||||
- Potential impact of the vulnerability.
|
||||
- Steps to reproduce the vulnerability.
|
||||
- Potential solutions to fix the vulnerability.
|
||||
- Your name/handle and a link for recognition (optional).
|
||||
|
||||
We aim to address the issue as soon as possible.
|
||||
@@ -6,4 +6,5 @@ services:
|
||||
environment:
|
||||
- PANGOLIN_ENDPOINT=https://example.com
|
||||
- NEWT_ID=2ix2t8xk22ubpfy
|
||||
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||
- LOG_LEVEL=DEBUG
|
||||
26
go.mod
26
go.mod
@@ -4,16 +4,28 @@ go 1.23.1
|
||||
|
||||
toolchain go1.23.2
|
||||
|
||||
require golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
require (
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/vishvananda/netlink v1.3.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
|
||||
golang.org/x/net v0.33.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
golang.org/x/crypto v0.28.0 // indirect
|
||||
golang.org/x/net v0.30.0 // indirect
|
||||
golang.org/x/sys v0.26.0 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/time v0.7.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
|
||||
)
|
||||
|
||||
52
go.sum
52
go.sum
@@ -1,20 +1,56 @@
|
||||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
|
||||
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||
|
||||
@@ -53,7 +53,23 @@ func (l *Logger) log(level LogLevel, format string, args ...interface{}) {
|
||||
if level < l.level {
|
||||
return
|
||||
}
|
||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
||||
|
||||
// Get timezone from environment variable or use local timezone
|
||||
timezone := os.Getenv("LOGGER_TIMEZONE")
|
||||
var location *time.Location
|
||||
var err error
|
||||
|
||||
if timezone != "" {
|
||||
location, err = time.LoadLocation(timezone)
|
||||
if err != nil {
|
||||
// If invalid timezone, fall back to local
|
||||
location = time.Local
|
||||
}
|
||||
} else {
|
||||
location = time.Local
|
||||
}
|
||||
|
||||
timestamp := time.Now().In(location).Format("2006/01/02 15:04:05")
|
||||
message := fmt.Sprintf(format, args...)
|
||||
l.logger.Printf("%s: %s %s", level.String(), timestamp, message)
|
||||
}
|
||||
|
||||
387
main.go
387
main.go
@@ -11,7 +11,10 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -19,6 +22,8 @@ import (
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/fosrl/newt/wg"
|
||||
"github.com/fosrl/newt/wgtester"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
@@ -53,7 +58,7 @@ func fixKey(key string) string {
|
||||
// Decode from base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
logger.Fatal("Error decoding base64:", err)
|
||||
logger.Fatal("Error decoding base64")
|
||||
}
|
||||
|
||||
// Convert to hex
|
||||
@@ -113,7 +118,12 @@ func ping(tnet *netstack.Net, dst string) error {
|
||||
}
|
||||
|
||||
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
initialInterval := 10 * time.Second
|
||||
maxInterval := 60 * time.Second
|
||||
currentInterval := initialInterval
|
||||
consecutiveFailures := 0
|
||||
|
||||
ticker := time.NewTicker(currentInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
go func() {
|
||||
@@ -122,7 +132,34 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
||||
case <-ticker.C:
|
||||
err := ping(tnet, serverIP)
|
||||
if err != nil {
|
||||
logger.Warn("Periodic ping failed: %v", err)
|
||||
consecutiveFailures++
|
||||
logger.Warn("Periodic ping failed (%d consecutive failures): %v",
|
||||
consecutiveFailures, err)
|
||||
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
|
||||
|
||||
// Increase interval if we have consistent failures, with a maximum cap
|
||||
if consecutiveFailures >= 3 && currentInterval < maxInterval {
|
||||
// Increase by 50% each time, up to the maximum
|
||||
currentInterval = time.Duration(float64(currentInterval) * 1.5)
|
||||
if currentInterval > maxInterval {
|
||||
currentInterval = maxInterval
|
||||
}
|
||||
ticker.Reset(currentInterval)
|
||||
logger.Info("Increased ping check interval to %v due to consecutive failures",
|
||||
currentInterval)
|
||||
}
|
||||
} else {
|
||||
// On success, if we've backed off, gradually return to normal interval
|
||||
if currentInterval > initialInterval {
|
||||
currentInterval = time.Duration(float64(currentInterval) * 0.8)
|
||||
if currentInterval < initialInterval {
|
||||
currentInterval = initialInterval
|
||||
}
|
||||
ticker.Reset(currentInterval)
|
||||
logger.Info("Decreased ping check interval to %v after successful ping",
|
||||
currentInterval)
|
||||
}
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
case <-stopChan:
|
||||
logger.Info("Stopping ping check")
|
||||
@@ -132,34 +169,97 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
||||
}()
|
||||
}
|
||||
|
||||
// Function to track connection status and trigger reconnection as needed
|
||||
func monitorConnectionStatus(tnet *netstack.Net, serverIP string, client *websocket.Client) {
|
||||
const checkInterval = 30 * time.Second
|
||||
connectionLost := false
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Try a ping to see if connection is alive
|
||||
err := ping(tnet, serverIP)
|
||||
|
||||
if err != nil && !connectionLost {
|
||||
// We just lost connection
|
||||
connectionLost = true
|
||||
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
|
||||
|
||||
// Notify the user they might need to check their network
|
||||
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.")
|
||||
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.")
|
||||
} else if err == nil && connectionLost {
|
||||
// Connection has been restored
|
||||
connectionLost = false
|
||||
logger.Info("Connection to server restored!")
|
||||
|
||||
// Tell the server we're back
|
||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||
"publicKey": fmt.Sprintf("%s", privateKey.PublicKey()),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message after reconnection: %v", err)
|
||||
} else {
|
||||
logger.Info("Successfully re-registered with server after reconnection")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||
const (
|
||||
maxAttempts = 5
|
||||
retryDelay = 2 * time.Second
|
||||
initialMaxAttempts = 15
|
||||
initialRetryDelay = 2 * time.Second
|
||||
maxRetryDelay = 60 * time.Second // Cap the maximum delay
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
|
||||
|
||||
if err := ping(tnet, dst); err != nil {
|
||||
lastErr = err
|
||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||
|
||||
if attempt < maxAttempts {
|
||||
time.Sleep(retryDelay)
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("all ping attempts failed after %d tries, last error: %w",
|
||||
maxAttempts, lastErr)
|
||||
}
|
||||
attempt := 1
|
||||
retryDelay := initialRetryDelay
|
||||
|
||||
// First try with the initial parameters
|
||||
logger.Info("Ping attempt %d", attempt)
|
||||
if err := ping(tnet, dst); err == nil {
|
||||
// Successful ping
|
||||
return nil
|
||||
} else {
|
||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||
}
|
||||
|
||||
// This shouldn't be reached due to the return in the loop, but added for completeness
|
||||
return fmt.Errorf("unexpected error: all ping attempts failed")
|
||||
// Start a goroutine that will attempt pings indefinitely with increasing delays
|
||||
go func() {
|
||||
attempt = 2 // Continue from attempt 2
|
||||
|
||||
for {
|
||||
logger.Info("Ping attempt %d", attempt)
|
||||
|
||||
if err := ping(tnet, dst); err != nil {
|
||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||
|
||||
// Increase delay after certain thresholds but cap it
|
||||
if attempt%5 == 0 && retryDelay < maxRetryDelay {
|
||||
retryDelay = time.Duration(float64(retryDelay) * 1.5)
|
||||
if retryDelay > maxRetryDelay {
|
||||
retryDelay = maxRetryDelay
|
||||
}
|
||||
logger.Info("Increasing ping retry delay to %v", retryDelay)
|
||||
}
|
||||
|
||||
time.Sleep(retryDelay)
|
||||
attempt++
|
||||
} else {
|
||||
// Successful ping
|
||||
logger.Info("Ping succeeded after %d attempts", attempt)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Return an error for the first batch of attempts (to maintain compatibility with existing code)
|
||||
return fmt.Errorf("initial ping attempts failed, continuing in background")
|
||||
}
|
||||
|
||||
func parseLogLevel(level string) logger.LogLevel {
|
||||
@@ -210,6 +310,9 @@ func resolveDomain(domain string) (string, error) {
|
||||
host = strings.TrimPrefix(host, "https://")
|
||||
}
|
||||
|
||||
// if there are any trailing slashes, remove them
|
||||
host = strings.TrimSuffix(host, "/")
|
||||
|
||||
// Lookup IP addresses
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
@@ -242,23 +345,36 @@ func resolveDomain(domain string) (string, error) {
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
endpoint string
|
||||
id string
|
||||
secret string
|
||||
dns string
|
||||
privateKey wgtypes.Key
|
||||
err error
|
||||
logLevel string
|
||||
)
|
||||
var (
|
||||
endpoint string
|
||||
id string
|
||||
secret string
|
||||
mtu string
|
||||
mtuInt int
|
||||
dns string
|
||||
privateKey wgtypes.Key
|
||||
err error
|
||||
logLevel string
|
||||
updownScript string
|
||||
interfaceName string
|
||||
generateAndSaveKeyTo string
|
||||
rm bool
|
||||
acceptClients bool
|
||||
)
|
||||
|
||||
func main() {
|
||||
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
|
||||
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
|
||||
id = os.Getenv("NEWT_ID")
|
||||
secret = os.Getenv("NEWT_SECRET")
|
||||
mtu = os.Getenv("MTU")
|
||||
dns = os.Getenv("DNS")
|
||||
logLevel = os.Getenv("LOG_LEVEL")
|
||||
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
||||
interfaceName = os.Getenv("INTERFACE")
|
||||
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
||||
rm = os.Getenv("RM") == "true"
|
||||
acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true"
|
||||
|
||||
if endpoint == "" {
|
||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
||||
@@ -269,21 +385,45 @@ func main() {
|
||||
if secret == "" {
|
||||
flag.StringVar(&secret, "secret", "", "Newt secret")
|
||||
}
|
||||
if mtu == "" {
|
||||
flag.StringVar(&mtu, "mtu", "1280", "MTU to use")
|
||||
}
|
||||
if dns == "" {
|
||||
flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
|
||||
}
|
||||
if logLevel == "" {
|
||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
}
|
||||
if updownScript == "" {
|
||||
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
|
||||
}
|
||||
if interfaceName == "" {
|
||||
flag.StringVar(&interfaceName, "interface", "wg1", "Name of the WireGuard interface")
|
||||
}
|
||||
if generateAndSaveKeyTo == "" {
|
||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "/tmp/newtkey", "Path to save generated private key")
|
||||
}
|
||||
flag.BoolVar(&rm, "rm", false, "Remove the WireGuard interface")
|
||||
flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface")
|
||||
|
||||
// do a --version check
|
||||
version := flag.Bool("version", false, "Print the version")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *version {
|
||||
fmt.Println("Newt version replaceme")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
logger.Init()
|
||||
loggerLevel := parseLogLevel(logLevel)
|
||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||
|
||||
// Validate required fields
|
||||
if endpoint == "" || id == "" || secret == "" {
|
||||
logger.Fatal("endpoint, id, and secret are required either via CLI flags or environment variables")
|
||||
// parse the mtu string into an int
|
||||
mtuInt, err = strconv.Atoi(mtu)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse MTU: %v", err)
|
||||
}
|
||||
|
||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||
@@ -301,6 +441,7 @@ func main() {
|
||||
logger.Fatal("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
var wgService *wg.WireGuardService
|
||||
// Create TUN device and network stack
|
||||
var tun tun.Device
|
||||
var tnet *netstack.Net
|
||||
@@ -308,6 +449,40 @@ func main() {
|
||||
var pm *proxy.ProxyManager
|
||||
var connected bool
|
||||
var wgData WgData
|
||||
var wgTesterServer *wgtester.Server
|
||||
|
||||
if acceptClients {
|
||||
// make sure we are running on linux
|
||||
if runtime.GOOS != "linux" {
|
||||
logger.Fatal("Tunnel management is only supported on Linux right now!")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var host = endpoint
|
||||
if strings.HasPrefix(host, "http://") {
|
||||
host = strings.TrimPrefix(host, "http://")
|
||||
} else if strings.HasPrefix(host, "https://") {
|
||||
host = strings.TrimPrefix(host, "https://")
|
||||
}
|
||||
|
||||
host = strings.TrimSuffix(host, "/")
|
||||
|
||||
// Create WireGuard service
|
||||
wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||
}
|
||||
defer wgService.Close(rm)
|
||||
|
||||
wgTesterServer = wgtester.NewServer("0.0.0.0", wgService.Port, id) // TODO: maybe make this the same ip of the wg server?
|
||||
err := wgTesterServer.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start WireGuard tester server: %v", err)
|
||||
} else {
|
||||
// Make sure to stop the server on exit
|
||||
defer wgTesterServer.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received terminate message")
|
||||
@@ -329,12 +504,8 @@ func main() {
|
||||
|
||||
if connected {
|
||||
logger.Info("Already connected! But I will send a ping anyway...")
|
||||
// ping(tnet, wgData.ServerIP)
|
||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
||||
if err != nil {
|
||||
// Handle complete failure after all retries
|
||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||
}
|
||||
// Even if pingWithRetry returns an error, it will continue trying in the background
|
||||
_ = pingWithRetry(tnet, wgData.ServerIP) // Ignoring initial error as pings will continue
|
||||
return
|
||||
}
|
||||
|
||||
@@ -349,11 +520,15 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
if wgService != nil {
|
||||
wgService.SetServerPubKey(wgData.PublicKey)
|
||||
}
|
||||
|
||||
logger.Info("Received: %+v", msg)
|
||||
tun, tnet, err = netstack.CreateNetTUN(
|
||||
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
|
||||
[]netip.Addr{netip.MustParseAddr(dns)},
|
||||
1420)
|
||||
mtuInt)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create TUN device: %v", err)
|
||||
}
|
||||
@@ -389,17 +564,18 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
}
|
||||
|
||||
logger.Info("WireGuard device created. Lets ping the server now...")
|
||||
// Ping to bring the tunnel up on the server side quickly
|
||||
// ping(tnet, wgData.ServerIP)
|
||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
||||
if err != nil {
|
||||
// Handle complete failure after all retries
|
||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||
}
|
||||
|
||||
// Even if pingWithRetry returns an error, it will continue trying in the background
|
||||
_ = pingWithRetry(tnet, wgData.ServerIP)
|
||||
|
||||
// Always mark as connected and start the proxy manager regardless of initial ping result
|
||||
// as the pings will continue in the background
|
||||
if !connected {
|
||||
logger.Info("Starting ping check")
|
||||
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
|
||||
|
||||
// Start connection monitoring in a separate goroutine
|
||||
go monitorConnectionStatus(tnet, wgData.ServerIP, client)
|
||||
}
|
||||
|
||||
// Create proxy manager
|
||||
@@ -416,6 +592,13 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
|
||||
}
|
||||
|
||||
// first make sure the wpgService has a port
|
||||
if wgService != nil {
|
||||
// add a udp proxy for localost and the wgService port
|
||||
// TODO: make sure this port is not used in a target
|
||||
pm.AddTarget("udp", wgData.TunnelIP, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port))
|
||||
}
|
||||
|
||||
err = pm.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start proxy manager: %v", err)
|
||||
@@ -440,11 +623,6 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
|
||||
}
|
||||
|
||||
err = pm.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start proxy manager: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
|
||||
@@ -465,11 +643,6 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
|
||||
}
|
||||
|
||||
err = pm.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start proxy manager: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
||||
@@ -524,10 +697,20 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
return err
|
||||
}
|
||||
|
||||
if wgService != nil {
|
||||
wgService.LoadRemoteConfig()
|
||||
}
|
||||
|
||||
logger.Info("Sent registration message")
|
||||
return nil
|
||||
})
|
||||
|
||||
client.OnTokenUpdate(func(token string) {
|
||||
if wgService != nil {
|
||||
wgService.SetToken(token)
|
||||
}
|
||||
})
|
||||
|
||||
// Connect to the WebSocket server
|
||||
if err := client.Connect(); err != nil {
|
||||
logger.Fatal("Failed to connect to server: %v", err)
|
||||
@@ -539,8 +722,25 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
|
||||
// Cleanup
|
||||
dev.Close()
|
||||
|
||||
if wgService != nil {
|
||||
wgService.Close(rm)
|
||||
}
|
||||
|
||||
if wgTesterServer != nil {
|
||||
wgTesterServer.Stop()
|
||||
}
|
||||
|
||||
if pm != nil {
|
||||
pm.Stop()
|
||||
}
|
||||
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
logger.Info("Exiting...")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func parseTargetData(data interface{}) (TargetData, error) {
|
||||
@@ -577,6 +777,18 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
||||
|
||||
if action == "add" {
|
||||
target := parts[1] + ":" + parts[2]
|
||||
|
||||
// Call updown script if provided
|
||||
processedTarget := target
|
||||
if updownScript != "" {
|
||||
newTarget, err := executeUpdownScript(action, proto, target)
|
||||
if err != nil {
|
||||
logger.Warn("Updown script error: %v", err)
|
||||
} else if newTarget != "" {
|
||||
processedTarget = newTarget
|
||||
}
|
||||
}
|
||||
|
||||
// Only remove the specific target if it exists
|
||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
||||
if err != nil {
|
||||
@@ -587,10 +799,21 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
||||
}
|
||||
|
||||
// Add the new target
|
||||
pm.AddTarget(proto, tunnelIP, port, target)
|
||||
pm.AddTarget(proto, tunnelIP, port, processedTarget)
|
||||
|
||||
} else if action == "remove" {
|
||||
logger.Info("Removing target with port %d", port)
|
||||
|
||||
target := parts[1] + ":" + parts[2]
|
||||
|
||||
// Call updown script if provided
|
||||
if updownScript != "" {
|
||||
_, err := executeUpdownScript(action, proto, target)
|
||||
if err != nil {
|
||||
logger.Warn("Updown script error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
||||
if err != nil {
|
||||
logger.Error("Failed to remove target: %v", err)
|
||||
@@ -601,3 +824,45 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func executeUpdownScript(action, proto, target string) (string, error) {
|
||||
if updownScript == "" {
|
||||
return target, nil
|
||||
}
|
||||
|
||||
// Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py")
|
||||
parts := strings.Fields(updownScript)
|
||||
if len(parts) == 0 {
|
||||
return target, fmt.Errorf("invalid updown script command")
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if len(parts) == 1 {
|
||||
// If it's a single executable
|
||||
logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target)
|
||||
cmd = exec.Command(parts[0], action, proto, target)
|
||||
} else {
|
||||
// If it includes interpreter and script
|
||||
args := append(parts[1:], action, proto, target)
|
||||
logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target)
|
||||
cmd = exec.Command(parts[0], args...)
|
||||
}
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return "", fmt.Errorf("updown script execution failed (exit code %d): %s",
|
||||
exitErr.ExitCode(), string(exitErr.Stderr))
|
||||
}
|
||||
return "", fmt.Errorf("updown script execution failed: %v", err)
|
||||
}
|
||||
|
||||
// If the script returns a new target, use it
|
||||
newTarget := strings.TrimSpace(string(output))
|
||||
if newTarget != "" {
|
||||
logger.Info("Updown script returned new target: %s", newTarget)
|
||||
return newTarget, nil
|
||||
}
|
||||
|
||||
return target, nil
|
||||
}
|
||||
|
||||
195
network/network.go
Normal file
195
network/network.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/net/bpf"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
udpProtocol = 17
|
||||
// EmptyUDPSize is the size of an empty UDP packet
|
||||
EmptyUDPSize = 28
|
||||
timeout = time.Second * 10
|
||||
)
|
||||
|
||||
// Server stores data relating to the server
|
||||
type Server struct {
|
||||
Hostname string
|
||||
Addr *net.IPAddr
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// PeerNet stores data about a peer's endpoint
|
||||
type PeerNet struct {
|
||||
Resolved bool
|
||||
IP net.IP
|
||||
Port uint16
|
||||
NewtID string
|
||||
}
|
||||
|
||||
// GetClientIP gets source ip address that will be used when sending data to dstIP
|
||||
func GetClientIP(dstIP net.IP) net.IP {
|
||||
routes, err := netlink.RouteGet(dstIP)
|
||||
if err != nil {
|
||||
log.Fatalln("Error getting route:", err)
|
||||
}
|
||||
return routes[0].Src
|
||||
}
|
||||
|
||||
// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr
|
||||
func HostToAddr(hostStr string) *net.IPAddr {
|
||||
remoteAddrs, err := net.LookupHost(hostStr)
|
||||
if err != nil {
|
||||
log.Fatalln("Error parsing remote address:", err)
|
||||
}
|
||||
|
||||
for _, addrStr := range remoteAddrs {
|
||||
if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil {
|
||||
return remoteAddr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering
|
||||
func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn {
|
||||
packetConn, err := net.ListenPacket("ip4:udp", client.IP.String())
|
||||
if err != nil {
|
||||
log.Fatalln("Error creating packetConn:", err)
|
||||
}
|
||||
|
||||
rawConn, err := ipv4.NewRawConn(packetConn)
|
||||
if err != nil {
|
||||
log.Fatalln("Error creating rawConn:", err)
|
||||
}
|
||||
|
||||
ApplyBPF(rawConn, server, client)
|
||||
|
||||
return rawConn
|
||||
}
|
||||
|
||||
// ApplyBPF constructs a BPF program and applies it to the RawConn
|
||||
func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) {
|
||||
const ipv4HeaderLen = 20
|
||||
const srcIPOffset = 12
|
||||
const srcPortOffset = ipv4HeaderLen + 0
|
||||
const dstPortOffset = ipv4HeaderLen + 2
|
||||
|
||||
ipArr := []byte(server.Addr.IP.To4())
|
||||
ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3])
|
||||
|
||||
bpfRaw, err := bpf.Assemble([]bpf.Instruction{
|
||||
bpf.LoadAbsolute{Off: srcIPOffset, Size: 4},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0},
|
||||
|
||||
bpf.LoadAbsolute{Off: srcPortOffset, Size: 2},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0},
|
||||
|
||||
bpf.LoadAbsolute{Off: dstPortOffset, Size: 2},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0},
|
||||
|
||||
bpf.RetConstant{Val: 1<<(8*4) - 1},
|
||||
bpf.RetConstant{Val: 0},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatalln("Error assembling BPF:", err)
|
||||
}
|
||||
|
||||
err = rawConn.SetBPF(bpfRaw)
|
||||
if err != nil {
|
||||
log.Fatalln("Error setting BPF:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// MakePacket constructs a request packet to send to the server
|
||||
func MakePacket(payload []byte, server *Server, client *PeerNet) []byte {
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
|
||||
opts := gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}
|
||||
|
||||
ipHeader := layers.IPv4{
|
||||
SrcIP: client.IP,
|
||||
DstIP: server.Addr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
|
||||
udpHeader := layers.UDP{
|
||||
SrcPort: layers.UDPPort(client.Port),
|
||||
DstPort: layers.UDPPort(server.Port),
|
||||
}
|
||||
|
||||
payloadLayer := gopacket.Payload(payload)
|
||||
|
||||
udpHeader.SetNetworkLayerForChecksum(&ipHeader)
|
||||
|
||||
gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// SendPacket sends packet to the Server
|
||||
func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
|
||||
fullPacket := MakePacket(packet, server, client)
|
||||
_, err := conn.WriteToIP(fullPacket, server.Addr)
|
||||
return err
|
||||
}
|
||||
|
||||
// SendDataPacket sends a JSON payload to the Server
|
||||
func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
return SendPacket(jsonData, conn, server, client)
|
||||
}
|
||||
|
||||
// RecvPacket receives a UDP packet from server
|
||||
func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) {
|
||||
err := conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
response := make([]byte, 4096)
|
||||
n, err := conn.Read(response)
|
||||
if err != nil {
|
||||
return nil, n, err
|
||||
}
|
||||
return response, n, nil
|
||||
}
|
||||
|
||||
// RecvDataPacket receives and unmarshals a JSON packet from server
|
||||
func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) {
|
||||
response, n, err := RecvPacket(conn, server, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract payload from UDP packet
|
||||
payload := response[EmptyUDPSize:n]
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// ParseResponse takes a response packet and parses it into an IP and port
|
||||
func ParseResponse(response []byte) (net.IP, uint16) {
|
||||
ip := net.IP(response[:4])
|
||||
port := binary.BigEndian.Uint16(response[4:6])
|
||||
return ip, port
|
||||
}
|
||||
BIN
newt_arm64
Executable file
BIN
newt_arm64
Executable file
Binary file not shown.
530
proxy/manager.go
530
proxy/manager.go
@@ -9,326 +9,344 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
)
|
||||
|
||||
// Target represents a proxy target with its address and port
|
||||
type Target struct {
|
||||
Address string
|
||||
Port int
|
||||
}
|
||||
|
||||
// ProxyManager handles the creation and management of proxy connections
|
||||
type ProxyManager struct {
|
||||
tnet *netstack.Net
|
||||
tcpTargets map[string]map[int]string // map[listenIP]map[port]targetAddress
|
||||
udpTargets map[string]map[int]string
|
||||
listeners []*gonet.TCPListener
|
||||
udpConns []*gonet.UDPConn
|
||||
running bool
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewProxyManager creates a new proxy manager instance
|
||||
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
|
||||
return &ProxyManager{
|
||||
tnet: tnet,
|
||||
tnet: tnet,
|
||||
tcpTargets: make(map[string]map[int]string),
|
||||
udpTargets: make(map[string]map[int]string),
|
||||
listeners: make([]*gonet.TCPListener, 0),
|
||||
udpConns: make([]*gonet.UDPConn, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target string) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
// AddTarget adds as new target for proxying
|
||||
func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr string) error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
logger.Info("Adding target: %s://%s:%d -> %s", protocol, listen, port, target)
|
||||
|
||||
newTarget := ProxyTarget{
|
||||
Protocol: protocol,
|
||||
Listen: listen,
|
||||
Port: port,
|
||||
Target: target,
|
||||
cancel: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
switch proto {
|
||||
case "tcp":
|
||||
if pm.tcpTargets[listenIP] == nil {
|
||||
pm.tcpTargets[listenIP] = make(map[int]string)
|
||||
}
|
||||
pm.tcpTargets[listenIP][port] = targetAddr
|
||||
case "udp":
|
||||
if pm.udpTargets[listenIP] == nil {
|
||||
pm.udpTargets[listenIP] = make(map[int]string)
|
||||
}
|
||||
pm.udpTargets[listenIP][port] = targetAddr
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
|
||||
pm.targets = append(pm.targets, newTarget)
|
||||
if pm.running {
|
||||
return pm.startTarget(proto, listenIP, port, targetAddr)
|
||||
} else {
|
||||
logger.Debug("Not adding target because not running")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
protocol = strings.ToLower(protocol)
|
||||
if protocol != "tcp" && protocol != "udp" {
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
for i, target := range pm.targets {
|
||||
if target.Listen == listen &&
|
||||
target.Port == port &&
|
||||
strings.ToLower(target.Protocol) == protocol {
|
||||
|
||||
// Signal the serving goroutine to stop
|
||||
select {
|
||||
case <-target.cancel:
|
||||
// Channel is already closed, no need to close it again
|
||||
default:
|
||||
close(target.cancel)
|
||||
}
|
||||
|
||||
// Close the appropriate listener/connection based on protocol
|
||||
target.Lock()
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
if target.listener != nil {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
// Listener was already closed by Stop()
|
||||
default:
|
||||
target.listener.Close()
|
||||
}
|
||||
}
|
||||
case "udp":
|
||||
if target.udpConn != nil {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
// Connection was already closed by Stop()
|
||||
default:
|
||||
target.udpConn.Close()
|
||||
}
|
||||
switch proto {
|
||||
case "tcp":
|
||||
if targets, ok := pm.tcpTargets[listenIP]; ok {
|
||||
delete(targets, port)
|
||||
// Remove and close the corresponding TCP listener
|
||||
for i, listener := range pm.listeners {
|
||||
if addr, ok := listener.Addr().(*net.TCPAddr); ok && addr.Port == port {
|
||||
listener.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
// Remove from slice
|
||||
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
target.Unlock()
|
||||
|
||||
// Wait for the target to fully stop
|
||||
<-target.done
|
||||
|
||||
// Remove the target from the slice
|
||||
pm.targets = append(pm.targets[:i], pm.targets[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("target not found for %s %s:%d", protocol, listen, port)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) Start() error {
|
||||
pm.RLock()
|
||||
defer pm.RUnlock()
|
||||
|
||||
for i := range pm.targets {
|
||||
target := &pm.targets[i]
|
||||
|
||||
target.Lock()
|
||||
// If target is already running, skip it
|
||||
if target.listener != nil || target.udpConn != nil {
|
||||
target.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
// Mark the target as starting by creating a nil listener/connection
|
||||
// This prevents other goroutines from trying to start it
|
||||
if strings.ToLower(target.Protocol) == "tcp" {
|
||||
target.listener = nil
|
||||
} else {
|
||||
target.udpConn = nil
|
||||
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
||||
}
|
||||
target.Unlock()
|
||||
case "udp":
|
||||
if targets, ok := pm.udpTargets[listenIP]; ok {
|
||||
delete(targets, port)
|
||||
// Remove and close the corresponding UDP connection
|
||||
for i, conn := range pm.udpConns {
|
||||
if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok && addr.Port == port {
|
||||
conn.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
// Remove from slice
|
||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch strings.ToLower(target.Protocol) {
|
||||
case "tcp":
|
||||
go pm.serveTCP(target)
|
||||
case "udp":
|
||||
go pm.serveUDP(target)
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", target.Protocol)
|
||||
// Start begins listening for all configured proxy targets
|
||||
func (pm *ProxyManager) Start() error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if pm.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start TCP targets
|
||||
for listenIP, targets := range pm.tcpTargets {
|
||||
for port, targetAddr := range targets {
|
||||
if err := pm.startTarget("tcp", listenIP, port, targetAddr); err != nil {
|
||||
return fmt.Errorf("failed to start TCP target: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start UDP targets
|
||||
for listenIP, targets := range pm.udpTargets {
|
||||
for port, targetAddr := range targets {
|
||||
if err := pm.startTarget("udp", listenIP, port, targetAddr); err != nil {
|
||||
return fmt.Errorf("failed to start UDP target: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pm.running = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) Stop() error {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := range pm.targets {
|
||||
target := &pm.targets[i]
|
||||
wg.Add(1)
|
||||
go func(t *ProxyTarget) {
|
||||
defer wg.Done()
|
||||
close(t.cancel)
|
||||
t.Lock()
|
||||
if t.listener != nil {
|
||||
t.listener.Close()
|
||||
}
|
||||
if t.udpConn != nil {
|
||||
t.udpConn.Close()
|
||||
}
|
||||
t.Unlock()
|
||||
// Wait for the target to fully stop
|
||||
<-t.done
|
||||
}(target)
|
||||
if !pm.running {
|
||||
return nil
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Set running to false first to signal handlers to stop
|
||||
pm.running = false
|
||||
|
||||
// Close TCP listeners
|
||||
for i := len(pm.listeners) - 1; i >= 0; i-- {
|
||||
listener := pm.listeners[i]
|
||||
if err := listener.Close(); err != nil {
|
||||
logger.Error("Error closing TCP listener: %v", err)
|
||||
}
|
||||
// Remove from slice
|
||||
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
||||
}
|
||||
|
||||
// Close UDP connections
|
||||
for i := len(pm.udpConns) - 1; i >= 0; i-- {
|
||||
conn := pm.udpConns[i]
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Error("Error closing UDP connection: %v", err)
|
||||
}
|
||||
// Remove from slice
|
||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||
}
|
||||
|
||||
// Clear the target maps
|
||||
for k := range pm.tcpTargets {
|
||||
delete(pm.tcpTargets, k)
|
||||
}
|
||||
for k := range pm.udpTargets {
|
||||
delete(pm.udpTargets, k)
|
||||
}
|
||||
|
||||
// Give active connections a chance to close gracefully
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
|
||||
defer close(target.done) // Signal that this target is fully stopped
|
||||
func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr string) error {
|
||||
switch proto {
|
||||
case "tcp":
|
||||
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{Port: port})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create TCP listener: %v", err)
|
||||
}
|
||||
|
||||
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{
|
||||
IP: net.ParseIP(target.Listen),
|
||||
Port: target.Port,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Info("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err)
|
||||
return
|
||||
pm.listeners = append(pm.listeners, listener)
|
||||
go pm.handleTCPProxy(listener, targetAddr)
|
||||
|
||||
case "udp":
|
||||
addr := &net.UDPAddr{Port: port}
|
||||
conn, err := pm.tnet.ListenUDP(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create UDP listener: %v", err)
|
||||
}
|
||||
|
||||
pm.udpConns = append(pm.udpConns, conn)
|
||||
go pm.handleUDPProxy(conn, targetAddr)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
|
||||
target.Lock()
|
||||
target.listener = listener
|
||||
target.Unlock()
|
||||
logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr)
|
||||
|
||||
defer listener.Close()
|
||||
logger.Info("TCP proxy listening on %s", listener.Addr())
|
||||
|
||||
var activeConns sync.WaitGroup
|
||||
acceptDone := make(chan struct{})
|
||||
|
||||
// Goroutine to handle shutdown signal
|
||||
go func() {
|
||||
<-target.cancel
|
||||
close(acceptDone)
|
||||
listener.Close()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
// Wait for active connections to finish
|
||||
activeConns.Wait()
|
||||
// Check if we're shutting down or the listener was closed
|
||||
if !pm.running {
|
||||
return
|
||||
default:
|
||||
logger.Info("Failed to accept TCP connection: %v", err)
|
||||
// Don't return here, try to accept new connections
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for specific network errors that indicate the listener is closed
|
||||
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
||||
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error("Error accepting TCP connection: %v", err)
|
||||
// Don't hammer the CPU if we hit a temporary error
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
activeConns.Add(1)
|
||||
go func() {
|
||||
defer activeConns.Done()
|
||||
pm.handleTCPConnection(conn, target.Target, acceptDone)
|
||||
target, err := net.Dial("tcp", targetAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error connecting to target: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Create a WaitGroup to ensure both copy operations complete
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(target, conn)
|
||||
target.Close()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(conn, target)
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
// Wait for both copies to complete
|
||||
wg.Wait()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string, done chan struct{}) {
|
||||
defer clientConn.Close()
|
||||
|
||||
serverConn, err := net.Dial("tcp", target)
|
||||
if err != nil {
|
||||
logger.Info("Failed to connect to target %s: %v", target, err)
|
||||
return
|
||||
}
|
||||
defer serverConn.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Client -> Server
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
io.Copy(serverConn, clientConn)
|
||||
}
|
||||
}()
|
||||
|
||||
// Server -> Client
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
io.Copy(clientConn, serverConn)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
|
||||
defer close(target.done) // Signal that this target is fully stopped
|
||||
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.ParseIP(target.Listen),
|
||||
Port: target.Port,
|
||||
}
|
||||
|
||||
conn, err := pm.tnet.ListenUDP(addr)
|
||||
if err != nil {
|
||||
logger.Info("Failed to start UDP listener for %s:%d: %v", target.Listen, target.Port, err)
|
||||
return
|
||||
}
|
||||
|
||||
target.Lock()
|
||||
target.udpConn = conn
|
||||
target.Unlock()
|
||||
|
||||
defer conn.Close()
|
||||
logger.Info("UDP proxy listening on %s", conn.LocalAddr())
|
||||
|
||||
buffer := make([]byte, 65535)
|
||||
var activeConns sync.WaitGroup
|
||||
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
buffer := make([]byte, 65507) // Max UDP packet size
|
||||
clientConns := make(map[string]*net.UDPConn)
|
||||
var clientsMutex sync.RWMutex
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
activeConns.Wait() // Wait for all active UDP handlers to complete
|
||||
return
|
||||
default:
|
||||
n, remoteAddr, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
activeConns.Wait()
|
||||
return
|
||||
default:
|
||||
logger.Info("Failed to read UDP packet: %v", err)
|
||||
continue
|
||||
}
|
||||
n, remoteAddr, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
if !pm.running {
|
||||
return
|
||||
}
|
||||
|
||||
targetAddr, err := net.ResolveUDPAddr("udp", target.Target)
|
||||
// Check for connection closed conditions
|
||||
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
|
||||
logger.Info("UDP connection closed, stopping proxy handler")
|
||||
|
||||
// Clean up existing client connections
|
||||
clientsMutex.Lock()
|
||||
for _, targetConn := range clientConns {
|
||||
targetConn.Close()
|
||||
}
|
||||
clientConns = nil
|
||||
clientsMutex.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error("Error reading UDP packet: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
clientKey := remoteAddr.String()
|
||||
clientsMutex.RLock()
|
||||
targetConn, exists := clientConns[clientKey]
|
||||
clientsMutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
||||
if err != nil {
|
||||
logger.Info("Failed to resolve target address %s: %v", target.Target, err)
|
||||
logger.Error("Error resolving target address: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
activeConns.Add(1)
|
||||
go func(data []byte, remote net.Addr) {
|
||||
defer activeConns.Done()
|
||||
targetConn, err := net.DialUDP("udp", nil, targetAddr)
|
||||
if err != nil {
|
||||
logger.Info("Failed to connect to target %s: %v", target.Target, err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error connecting to target: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case <-target.cancel:
|
||||
return
|
||||
default:
|
||||
_, err = targetConn.Write(data)
|
||||
clientsMutex.Lock()
|
||||
clientConns[clientKey] = targetConn
|
||||
clientsMutex.Unlock()
|
||||
|
||||
go func() {
|
||||
buffer := make([]byte, 65507)
|
||||
for {
|
||||
n, _, err := targetConn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
logger.Info("Failed to write to target: %v", err)
|
||||
logger.Error("Error reading from target: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 65535)
|
||||
n, err := targetConn.Read(response)
|
||||
_, err = conn.WriteTo(buffer[:n], remoteAddr)
|
||||
if err != nil {
|
||||
logger.Info("Failed to read response from target: %v", err)
|
||||
logger.Error("Error writing to client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.WriteTo(response[:n], remote)
|
||||
if err != nil {
|
||||
logger.Info("Failed to write response to client: %v", err)
|
||||
}
|
||||
}
|
||||
}(buffer[:n], remoteAddr)
|
||||
}()
|
||||
}
|
||||
|
||||
_, err = targetConn.Write(buffer[:n])
|
||||
if err != nil {
|
||||
logger.Error("Error writing to target: %v", err)
|
||||
targetConn.Close()
|
||||
clientsMutex.Lock()
|
||||
delete(clientConns, clientKey)
|
||||
clientsMutex.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
)
|
||||
|
||||
type ProxyTarget struct {
|
||||
Protocol string
|
||||
Listen string
|
||||
Port int
|
||||
Target string
|
||||
cancel chan struct{} // Channel to signal shutdown
|
||||
done chan struct{} // Channel to signal completion
|
||||
listener net.Listener // For TCP
|
||||
udpConn net.PacketConn // For UDP
|
||||
sync.Mutex // Protect access to connection
|
||||
}
|
||||
|
||||
type ProxyManager struct {
|
||||
targets []ProxyTarget
|
||||
tnet *netstack.Net
|
||||
log *log.Logger
|
||||
sync.RWMutex // Protect access to targets slice
|
||||
}
|
||||
77
updown.py
Normal file
77
updown.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Sample updown script for Newt proxy
|
||||
Usage: update.py <action> <protocol> <target>
|
||||
|
||||
Parameters:
|
||||
- action: 'add' or 'remove'
|
||||
- protocol: 'tcp' or 'udp'
|
||||
- target: the target address in format 'host:port'
|
||||
|
||||
If the action is 'add', the script can return a modified target that
|
||||
will be used instead of the original.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
LOG_FILE = "/tmp/newt-updown.log"
|
||||
logging.basicConfig(
|
||||
filename=LOG_FILE,
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
def log_event(action, protocol, target):
|
||||
"""Log each event to a file for auditing purposes"""
|
||||
timestamp = datetime.now().isoformat()
|
||||
event = {
|
||||
"timestamp": timestamp,
|
||||
"action": action,
|
||||
"protocol": protocol,
|
||||
"target": target
|
||||
}
|
||||
logging.info(json.dumps(event))
|
||||
|
||||
def handle_add(protocol, target):
|
||||
"""Handle 'add' action"""
|
||||
logging.info(f"Adding {protocol} target: {target}")
|
||||
|
||||
def handle_remove(protocol, target):
|
||||
"""Handle 'remove' action"""
|
||||
logging.info(f"Removing {protocol} target: {target}")
|
||||
# For remove action, no return value is expected or used
|
||||
|
||||
def main():
|
||||
# Check arguments
|
||||
if len(sys.argv) != 4:
|
||||
logging.error(f"Invalid arguments: {sys.argv}")
|
||||
sys.exit(1)
|
||||
|
||||
action = sys.argv[1]
|
||||
protocol = sys.argv[2]
|
||||
target = sys.argv[3]
|
||||
|
||||
# Log the event
|
||||
log_event(action, protocol, target)
|
||||
|
||||
# Handle the action
|
||||
if action == "add":
|
||||
new_target = handle_add(protocol, target)
|
||||
# Print the new target to stdout (if empty, no change will be made)
|
||||
if new_target and new_target != target:
|
||||
print(new_target)
|
||||
elif action == "remove":
|
||||
handle_remove(protocol, target)
|
||||
else:
|
||||
logging.error(f"Unknown action: {action}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except Exception as e:
|
||||
logging.error(f"Unhandled exception: {e}")
|
||||
sys.exit(1)
|
||||
@@ -27,7 +27,8 @@ type Client struct {
|
||||
isConnected bool
|
||||
reconnectMux sync.RWMutex
|
||||
|
||||
onConnect func() error
|
||||
onConnect func() error
|
||||
onTokenUpdate func(token string)
|
||||
}
|
||||
|
||||
type ClientOption func(*Client)
|
||||
@@ -45,6 +46,10 @@ func (c *Client) OnConnect(callback func() error) {
|
||||
c.onConnect = callback
|
||||
}
|
||||
|
||||
func (c *Client) OnTokenUpdate(callback func(token string)) {
|
||||
c.onTokenUpdate = callback
|
||||
}
|
||||
|
||||
// NewClient creates a new Newt client
|
||||
func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
|
||||
config := &Config{
|
||||
@@ -228,6 +233,10 @@ func (c *Client) getToken() (string, error) {
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
// print out the token response for debugging
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(resp.Body)
|
||||
logger.Info("Token response: %s", buf.String())
|
||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
@@ -266,6 +275,8 @@ func (c *Client) establishConnection() error {
|
||||
return fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
|
||||
c.onTokenUpdate(token)
|
||||
|
||||
// Parse the base URL to determine protocol and hostname
|
||||
baseURL, err := url.Parse(c.baseURL)
|
||||
if err != nil {
|
||||
@@ -288,6 +299,7 @@ func (c *Client) establishConnection() error {
|
||||
// Add token to query parameters
|
||||
q := u.Query()
|
||||
q.Set("token", token)
|
||||
q.Set("clientType", "newt")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
// Connect to WebSocket
|
||||
@@ -305,6 +317,10 @@ func (c *Client) establishConnection() error {
|
||||
go c.readPump()
|
||||
|
||||
if c.onConnect != nil {
|
||||
err := c.saveConfig()
|
||||
if err != nil {
|
||||
logger.Error("Failed to save config: %v", err)
|
||||
}
|
||||
if err := c.onConnect(); err != nil {
|
||||
logger.Error("OnConnect callback failed: %v", err)
|
||||
}
|
||||
|
||||
979
wg/wg.go
Normal file
979
wg/wg.go
Normal file
@@ -0,0 +1,979 @@
|
||||
package wg
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/exp/rand"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Peers []Peer `json:"peers"`
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
AllowedIPs []string `json:"allowedIps"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
type PeerBandwidth struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
BytesIn float64 `json:"bytesIn"`
|
||||
BytesOut float64 `json:"bytesOut"`
|
||||
}
|
||||
|
||||
type PeerReading struct {
|
||||
BytesReceived int64
|
||||
BytesTransmitted int64
|
||||
LastChecked time.Time
|
||||
}
|
||||
|
||||
type WireGuardService struct {
|
||||
interfaceName string
|
||||
mtu int
|
||||
client *websocket.Client
|
||||
wgClient *wgctrl.Client
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
newtId string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
Port uint16
|
||||
stopHolepunch chan struct{}
|
||||
host string
|
||||
serverPubKey string
|
||||
token string
|
||||
stopGetConfig chan struct{}
|
||||
}
|
||||
|
||||
// Add this type definition
|
||||
type fixedPortBind struct {
|
||||
port uint16
|
||||
conn.Bind
|
||||
}
|
||||
|
||||
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
// Ignore the requested port and use our fixed port
|
||||
return b.Bind.Open(b.port)
|
||||
}
|
||||
|
||||
func NewFixedPortBind(port uint16) conn.Bind {
|
||||
return &fixedPortBind{
|
||||
port: port,
|
||||
Bind: conn.NewDefaultBind(),
|
||||
}
|
||||
}
|
||||
|
||||
// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester
|
||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
if maxPort < minPort {
|
||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
// We need to check port+1 as well, so adjust the max port to avoid going out of range
|
||||
adjustedMaxPort := maxPort - 1
|
||||
if adjustedMaxPort < minPort {
|
||||
return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
// Create a slice of all ports in the range (excluding the last one)
|
||||
portRange := make([]uint16, adjustedMaxPort-minPort+1)
|
||||
for i := range portRange {
|
||||
portRange[i] = minPort + uint16(i)
|
||||
}
|
||||
|
||||
// Fisher-Yates shuffle to randomize the port order
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
for i := len(portRange) - 1; i > 0; i-- {
|
||||
j := rand.Intn(i + 1)
|
||||
portRange[i], portRange[j] = portRange[j], portRange[i]
|
||||
}
|
||||
|
||||
// Try each port in the randomized order
|
||||
for _, port := range portRange {
|
||||
// Check if port is available
|
||||
addr1 := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: int(port),
|
||||
}
|
||||
conn1, err1 := net.ListenUDP("udp", addr1)
|
||||
if err1 != nil {
|
||||
continue // Port is in use or there was an error, try next port
|
||||
}
|
||||
|
||||
// Check if port+1 is also available
|
||||
addr2 := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: int(port + 1),
|
||||
}
|
||||
conn2, err2 := net.ListenUDP("udp", addr2)
|
||||
if err2 != nil {
|
||||
// The next port is not available, so close the first connection and try again
|
||||
conn1.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
// Both ports are available, close connections and return the first port
|
||||
conn1.Close()
|
||||
conn2.Close()
|
||||
return port, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) {
|
||||
wgClient, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
|
||||
}
|
||||
|
||||
var key wgtypes.Key
|
||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
||||
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
|
||||
// generate a new private key
|
||||
key, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to generate private key: %v", err)
|
||||
}
|
||||
// save the key to the file
|
||||
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to save private key: %v", err)
|
||||
}
|
||||
} else {
|
||||
keyData, err := os.ReadFile(generateAndSaveKeyTo)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to read private key: %v", err)
|
||||
}
|
||||
key, err = wgtypes.ParseKey(string(keyData))
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse private key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
service := &WireGuardService{
|
||||
interfaceName: interfaceName,
|
||||
mtu: mtu,
|
||||
client: wsClient,
|
||||
wgClient: wgClient,
|
||||
key: key,
|
||||
newtId: newtId,
|
||||
host: host,
|
||||
lastReadings: make(map[string]PeerReading),
|
||||
stopHolepunch: make(chan struct{}),
|
||||
stopGetConfig: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Get the existing wireguard port (keep this part)
|
||||
device, err := service.wgClient.Device(service.interfaceName)
|
||||
if err == nil {
|
||||
service.Port = uint16(device.ListenPort)
|
||||
logger.Info("WireGuard interface %s already exists with port %d\n", service.interfaceName, service.Port)
|
||||
} else {
|
||||
service.Port, err = FindAvailableUDPPort(49152, 65535)
|
||||
if err != nil {
|
||||
fmt.Printf("Error finding available port: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Register websocket handlers
|
||||
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
|
||||
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
||||
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
|
||||
wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer)
|
||||
|
||||
if err := service.sendUDPHolePunch(service.host + ":21820"); err != nil {
|
||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
||||
}
|
||||
|
||||
// start the UDP holepunch
|
||||
go service.keepSendingUDPHolePunch(service.host)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) Close(rm bool) {
|
||||
select {
|
||||
case <-s.stopGetConfig:
|
||||
// Already closed, do nothing
|
||||
default:
|
||||
close(s.stopGetConfig)
|
||||
}
|
||||
|
||||
s.wgClient.Close()
|
||||
// Remove the WireGuard interface
|
||||
if rm {
|
||||
if err := s.removeInterface(); err != nil {
|
||||
logger.Error("Failed to remove WireGuard interface: %v", err)
|
||||
}
|
||||
|
||||
// Remove the private key file
|
||||
if err := os.Remove(s.key.String()); err != nil {
|
||||
logger.Error("Failed to remove private key file: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) SetServerPubKey(serverPubKey string) {
|
||||
s.serverPubKey = serverPubKey
|
||||
}
|
||||
|
||||
func (s *WireGuardService) SetToken(token string) {
|
||||
s.token = token
|
||||
}
|
||||
|
||||
func (s *WireGuardService) LoadRemoteConfig() error {
|
||||
// Send the initial message
|
||||
err := s.sendGetConfigMessage()
|
||||
if err != nil {
|
||||
logger.Error("Failed to send initial get-config message: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Start goroutine to periodically send the message until config is received
|
||||
go s.keepSendingGetConfig()
|
||||
|
||||
go s.periodicBandwidthCheck()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
var config WgConfig
|
||||
|
||||
logger.Info("Received message: %v", msg)
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &config); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
s.config = config
|
||||
|
||||
close(s.stopGetConfig)
|
||||
|
||||
// Ensure the WireGuard interface and peers are configured
|
||||
if err := s.ensureWireguardInterface(config); err != nil {
|
||||
logger.Error("Failed to ensure WireGuard interface: %v", err)
|
||||
}
|
||||
|
||||
if err := s.ensureWireguardPeers(config.Peers); err != nil {
|
||||
logger.Error("Failed to ensure WireGuard peers: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
// Check if the WireGuard interface exists
|
||||
_, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
if _, ok := err.(netlink.LinkNotFoundError); ok {
|
||||
// Interface doesn't exist, so create it
|
||||
err = s.createWireGuardInterface()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard interface: %v", err)
|
||||
}
|
||||
logger.Info("Created WireGuard interface %s\n", s.interfaceName)
|
||||
} else {
|
||||
logger.Fatal("Error checking for WireGuard interface: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Info("WireGuard interface %s already exists\n", s.interfaceName)
|
||||
|
||||
// get the exising wireguard port
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
// get the existing port
|
||||
s.Port = uint16(device.ListenPort)
|
||||
logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Assigning IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName)
|
||||
// Assign IP address to the interface
|
||||
err = s.assignIPAddress(wgconfig.IpAddress)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to assign IP address: %v", err)
|
||||
}
|
||||
|
||||
// Check if the interface already exists
|
||||
_, err = s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interface %s does not exist", s.interfaceName)
|
||||
}
|
||||
|
||||
// Parse the private key
|
||||
key, err := wgtypes.ParseKey(s.key.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private key: %v", err)
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
PrivateKey: &key,
|
||||
ListenPort: new(int),
|
||||
}
|
||||
|
||||
// Use the service's fixed port instead of the config port
|
||||
*config.ListenPort = int(s.Port)
|
||||
|
||||
// Create and configure the WireGuard interface
|
||||
err = s.wgClient.ConfigureDevice(s.interfaceName, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
// bring up the interface
|
||||
link, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetMTU(link, s.mtu); err != nil {
|
||||
return fmt.Errorf("failed to set MTU: %v", err)
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetUp(link); err != nil {
|
||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||
}
|
||||
|
||||
// if err := s.ensureMSSClamping(); err != nil {
|
||||
// logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||
// }
|
||||
|
||||
logger.Info("WireGuard interface %s created and configured", s.interfaceName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) createWireGuardInterface() error {
|
||||
wgLink := &netlink.GenericLink{
|
||||
LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName},
|
||||
LinkType: "wireguard",
|
||||
}
|
||||
return netlink.LinkAdd(wgLink)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
||||
link, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
addr, err := netlink.ParseAddr(ipAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse IP address: %v", err)
|
||||
}
|
||||
|
||||
return netlink.AddrAdd(link, addr)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||
// get the current peers
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
// get the peer public keys
|
||||
var currentPeers []string
|
||||
for _, peer := range device.Peers {
|
||||
currentPeers = append(currentPeers, peer.PublicKey.String())
|
||||
}
|
||||
|
||||
// remove any peers that are not in the config
|
||||
for _, peer := range currentPeers {
|
||||
found := false
|
||||
for _, configPeer := range peers {
|
||||
if peer == configPeer.PublicKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := s.removePeer(peer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add any peers that are in the config but not in the current peers
|
||||
for _, configPeer := range peers {
|
||||
found := false
|
||||
for _, peer := range currentPeers {
|
||||
if configPeer.PublicKey == peer {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := s.addPeer(configPeer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
||||
logger.Info("Received message: %v", msg.Data)
|
||||
var peer Peer
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &peer); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
}
|
||||
|
||||
err = s.addPeer(peer)
|
||||
if err != nil {
|
||||
logger.Info("Error adding peer: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) addPeer(peer Peer) error {
|
||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
// parse allowed IPs into array of net.IPNet
|
||||
var allowedIPs []net.IPNet
|
||||
for _, ipStr := range peer.AllowedIPs {
|
||||
_, ipNet, err := net.ParseCIDR(ipStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse allowed IP: %v", err)
|
||||
}
|
||||
allowedIPs = append(allowedIPs, *ipNet)
|
||||
}
|
||||
// add keep alive using *time.Duration of 1 second
|
||||
keepalive := time.Second
|
||||
|
||||
var peerConfig wgtypes.PeerConfig
|
||||
if peer.Endpoint != "" {
|
||||
endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve endpoint address: %w", err)
|
||||
}
|
||||
|
||||
peerConfig = wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
AllowedIPs: allowedIPs,
|
||||
PersistentKeepaliveInterval: &keepalive,
|
||||
Endpoint: endpoint,
|
||||
}
|
||||
} else {
|
||||
peerConfig = wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
AllowedIPs: allowedIPs,
|
||||
PersistentKeepaliveInterval: &keepalive,
|
||||
}
|
||||
logger.Info("Added peer with no endpoint!")
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
|
||||
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("Peer %s added successfully", peer.PublicKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
|
||||
logger.Info("Received message: %v", msg.Data)
|
||||
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
|
||||
type RemoveRequest struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
}
|
||||
|
||||
var request RemoveRequest
|
||||
if err := json.Unmarshal(jsonData, &request); err != nil {
|
||||
logger.Info("Error unmarshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.removePeer(request.PublicKey); err != nil {
|
||||
logger.Info("Error removing peer: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) removePeer(publicKey string) error {
|
||||
pubKey, err := wgtypes.ParseKey(publicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
|
||||
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("Peer %s removed successfully", publicKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
|
||||
logger.Info("Received message: %v", msg.Data)
|
||||
// Define a struct to match the incoming message structure with optional fields
|
||||
type UpdatePeerRequest struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
AllowedIPs []string `json:"allowedIps,omitempty"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
}
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
var request UpdatePeerRequest
|
||||
if err := json.Unmarshal(jsonData, &request); err != nil {
|
||||
logger.Info("Error unmarshaling peer data: %v", err)
|
||||
return
|
||||
}
|
||||
// First, get the current peer configuration to preserve any unmodified fields
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
logger.Info("Error getting WireGuard device: %v", err)
|
||||
return
|
||||
}
|
||||
pubKey, err := wgtypes.ParseKey(request.PublicKey)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing public key: %v", err)
|
||||
return
|
||||
}
|
||||
// Find the existing peer configuration
|
||||
var currentPeer *wgtypes.Peer
|
||||
for _, p := range device.Peers {
|
||||
if p.PublicKey == pubKey {
|
||||
currentPeer = &p
|
||||
break
|
||||
}
|
||||
}
|
||||
if currentPeer == nil {
|
||||
logger.Info("Peer %s not found, cannot update", request.PublicKey)
|
||||
return
|
||||
}
|
||||
// Create the update peer config
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
UpdateOnly: true,
|
||||
}
|
||||
// Keep the default persistent keepalive of 1 second
|
||||
keepalive := time.Second
|
||||
peerConfig.PersistentKeepaliveInterval = &keepalive
|
||||
|
||||
// Handle Endpoint field special case
|
||||
// If Endpoint is included in the request but empty, we want to remove the endpoint
|
||||
// If Endpoint is not included, we don't modify it
|
||||
endpointSpecified := false
|
||||
for key := range msg.Data.(map[string]interface{}) {
|
||||
if key == "endpoint" {
|
||||
endpointSpecified = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Only update AllowedIPs if provided in the request
|
||||
if request.AllowedIPs != nil && len(request.AllowedIPs) > 0 {
|
||||
var allowedIPs []net.IPNet
|
||||
for _, ipStr := range request.AllowedIPs {
|
||||
_, ipNet, err := net.ParseCIDR(ipStr)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing allowed IP %s: %v", ipStr, err)
|
||||
return
|
||||
}
|
||||
allowedIPs = append(allowedIPs, *ipNet)
|
||||
}
|
||||
peerConfig.AllowedIPs = allowedIPs
|
||||
peerConfig.ReplaceAllowedIPs = true
|
||||
logger.Info("Updating AllowedIPs for peer %s", request.PublicKey)
|
||||
} else if endpointSpecified && request.Endpoint == "" {
|
||||
peerConfig.ReplaceAllowedIPs = false
|
||||
}
|
||||
|
||||
if endpointSpecified {
|
||||
if request.Endpoint != "" {
|
||||
// Update to new endpoint
|
||||
endpoint, err := net.ResolveUDPAddr("udp", request.Endpoint)
|
||||
if err != nil {
|
||||
logger.Info("Error resolving endpoint address %s: %v", request.Endpoint, err)
|
||||
return
|
||||
}
|
||||
peerConfig.Endpoint = endpoint
|
||||
logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint)
|
||||
} else {
|
||||
// specify any address to listen for any incoming packets
|
||||
peerConfig.Endpoint = &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
}
|
||||
logger.Info("Removing Endpoint for peer %s", request.PublicKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the configuration update
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||
logger.Info("Error updating peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
logger.Info("Peer %s updated successfully", request.PublicKey)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) periodicBandwidthCheck() {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if err := s.reportPeerBandwidth(); err != nil {
|
||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
peerBandwidths := []PeerBandwidth{}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, peer := range device.Peers {
|
||||
publicKey := peer.PublicKey.String()
|
||||
currentReading := PeerReading{
|
||||
BytesReceived: peer.ReceiveBytes,
|
||||
BytesTransmitted: peer.TransmitBytes,
|
||||
LastChecked: now,
|
||||
}
|
||||
|
||||
var bytesInDiff, bytesOutDiff float64
|
||||
lastReading, exists := s.lastReadings[publicKey]
|
||||
|
||||
if exists {
|
||||
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
|
||||
if timeDiff > 0 {
|
||||
// Calculate bytes transferred since last reading
|
||||
bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived)
|
||||
bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted)
|
||||
|
||||
// Handle counter wraparound (if the counter resets or overflows)
|
||||
if bytesInDiff < 0 {
|
||||
bytesInDiff = float64(currentReading.BytesReceived)
|
||||
}
|
||||
if bytesOutDiff < 0 {
|
||||
bytesOutDiff = float64(currentReading.BytesTransmitted)
|
||||
}
|
||||
|
||||
// Convert to MB
|
||||
bytesInMB := bytesInDiff / (1024 * 1024)
|
||||
bytesOutMB := bytesOutDiff / (1024 * 1024)
|
||||
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: bytesInMB,
|
||||
BytesOut: bytesOutMB,
|
||||
})
|
||||
} else {
|
||||
// If readings are too close together or time hasn't passed, report 0
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: 0,
|
||||
BytesOut: 0,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// For first reading of a peer, report 0 to establish baseline
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: 0,
|
||||
BytesOut: 0,
|
||||
})
|
||||
}
|
||||
|
||||
// Update the last reading
|
||||
s.lastReadings[publicKey] = currentReading
|
||||
}
|
||||
|
||||
// Clean up old peers
|
||||
for publicKey := range s.lastReadings {
|
||||
found := false
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey.String() == publicKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
delete(s.lastReadings, publicKey)
|
||||
}
|
||||
}
|
||||
|
||||
return peerBandwidths, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) reportPeerBandwidth() error {
|
||||
bandwidths, err := s.calculatePeerBandwidth()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
||||
}
|
||||
|
||||
err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{
|
||||
"bandwidthData": bandwidths,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
|
||||
|
||||
if s.serverPubKey == "" || s.token == "" {
|
||||
logger.Debug("Server public key or token not set, skipping UDP hole punch")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse server address
|
||||
serverSplit := strings.Split(serverAddr, ":")
|
||||
if len(serverSplit) < 2 {
|
||||
return fmt.Errorf("invalid server address format, expected hostname:port")
|
||||
}
|
||||
|
||||
serverHostname := serverSplit[0]
|
||||
serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse server port: %v", err)
|
||||
}
|
||||
|
||||
// Resolve server hostname to IP
|
||||
serverIPAddr := network.HostToAddr(serverHostname)
|
||||
if serverIPAddr == nil {
|
||||
return fmt.Errorf("failed to resolve server hostname")
|
||||
}
|
||||
|
||||
// Get client IP based on route to server
|
||||
clientIP := network.GetClientIP(serverIPAddr.IP)
|
||||
|
||||
// Create server and client configs
|
||||
server := &network.Server{
|
||||
Hostname: serverHostname,
|
||||
Addr: serverIPAddr,
|
||||
Port: uint16(serverPort),
|
||||
}
|
||||
|
||||
client := &network.PeerNet{
|
||||
IP: clientIP,
|
||||
Port: s.Port,
|
||||
NewtID: s.newtId,
|
||||
}
|
||||
|
||||
// Setup raw connection with BPF filtering
|
||||
rawConn := network.SetupRawConn(server, client)
|
||||
defer rawConn.Close()
|
||||
|
||||
// Create JSON payload
|
||||
payload := struct {
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
NewtID: s.newtId,
|
||||
Token: s.token,
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := s.encryptPayload(payloadBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %v", err)
|
||||
}
|
||||
|
||||
// Send the encrypted packet using the raw connection
|
||||
err = network.SendDataPacket(encryptedPayload, rawConn, server, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send UDP packet: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) {
|
||||
// Generate an ephemeral keypair for this message
|
||||
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||
}
|
||||
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||
|
||||
// Parse the server's public key
|
||||
serverPubKey, err := wgtypes.ParseKey(s.serverPubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||
}
|
||||
|
||||
// Use X25519 for key exchange (replacing deprecated ScalarMult)
|
||||
var ephPrivKeyFixed [32]byte
|
||||
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||
|
||||
// Perform X25519 key exchange
|
||||
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||
}
|
||||
|
||||
// Create an AEAD cipher using the shared secret
|
||||
aead, err := chacha20poly1305.New(sharedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload
|
||||
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||
|
||||
// Prepare the final encrypted message
|
||||
encryptedMsg := struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}{
|
||||
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}
|
||||
|
||||
return encryptedMsg, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopHolepunch:
|
||||
logger.Info("Stopping UDP holepunch")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
|
||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) removeInterface() error {
|
||||
// Remove the WireGuard interface
|
||||
link, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
err = netlink.LinkDel(link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete interface: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("WireGuard interface %s removed successfully", s.interfaceName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) sendGetConfigMessage() error {
|
||||
err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{
|
||||
"publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()),
|
||||
"port": s.Port,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send get-config message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Requesting WireGuard configuration from remote server")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) keepSendingGetConfig() {
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopGetConfig:
|
||||
logger.Info("Stopping get-config messages")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.sendGetConfigMessage(); err != nil {
|
||||
logger.Error("Failed to send periodic get-config: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
164
wgtester/wgtester.go
Normal file
164
wgtester/wgtester.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package wgtester
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// Magic bytes to identify our packets
|
||||
magicHeader uint32 = 0xDEADBEEF
|
||||
// Request packet type
|
||||
packetTypeRequest uint8 = 1
|
||||
// Response packet type
|
||||
packetTypeResponse uint8 = 2
|
||||
// Packet format:
|
||||
// - 4 bytes: magic header (0xDEADBEEF)
|
||||
// - 1 byte: packet type (1 = request, 2 = response)
|
||||
// - 8 bytes: timestamp (for round-trip timing)
|
||||
packetSize = 13
|
||||
)
|
||||
|
||||
// Server handles listening for connection check requests using UDP
|
||||
type Server struct {
|
||||
conn *net.UDPConn
|
||||
serverAddr string
|
||||
serverPort uint16
|
||||
shutdownCh chan struct{}
|
||||
isRunning bool
|
||||
runningLock sync.Mutex
|
||||
newtID string
|
||||
outputPrefix string
|
||||
}
|
||||
|
||||
// NewServer creates a new connection test server using UDP
|
||||
func NewServer(serverAddr string, serverPort uint16, newtID string) *Server {
|
||||
return &Server{
|
||||
serverAddr: serverAddr,
|
||||
serverPort: serverPort + 1, // use the next port for the server
|
||||
shutdownCh: make(chan struct{}),
|
||||
newtID: newtID,
|
||||
outputPrefix: "[WGTester] ",
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins listening for connection test packets using UDP
|
||||
func (s *Server) Start() error {
|
||||
s.runningLock.Lock()
|
||||
defer s.runningLock.Unlock()
|
||||
|
||||
if s.isRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
//create the address to listen on
|
||||
addr := net.JoinHostPort(s.serverAddr, fmt.Sprintf("%d", s.serverPort))
|
||||
|
||||
// Create UDP address to listen on
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create UDP connection
|
||||
conn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.conn = conn
|
||||
|
||||
s.isRunning = true
|
||||
go s.handleConnections()
|
||||
|
||||
logger.Info("%sServer started on %s:%d", s.outputPrefix, s.serverAddr, s.serverPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop shuts down the server
|
||||
func (s *Server) Stop() {
|
||||
s.runningLock.Lock()
|
||||
defer s.runningLock.Unlock()
|
||||
|
||||
if !s.isRunning {
|
||||
return
|
||||
}
|
||||
|
||||
close(s.shutdownCh)
|
||||
if s.conn != nil {
|
||||
s.conn.Close()
|
||||
}
|
||||
s.isRunning = false
|
||||
logger.Info(s.outputPrefix + "Server stopped")
|
||||
}
|
||||
|
||||
// handleConnections processes incoming packets
|
||||
func (s *Server) handleConnections() {
|
||||
buffer := make([]byte, 2000) // Buffer large enough for any UDP packet
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.shutdownCh:
|
||||
return
|
||||
default:
|
||||
// Set read deadline to avoid blocking forever
|
||||
err := s.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
if err != nil {
|
||||
logger.Error(s.outputPrefix+"Error setting read deadline: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Read from UDP connection
|
||||
n, addr, err := s.conn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
// Just a timeout, keep going
|
||||
continue
|
||||
}
|
||||
logger.Error(s.outputPrefix+"Error reading from UDP: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Process packet only if it meets minimum size requirements
|
||||
if n < packetSize {
|
||||
continue // Too small to be our packet
|
||||
}
|
||||
|
||||
// Check magic header
|
||||
magic := binary.BigEndian.Uint32(buffer[0:4])
|
||||
if magic != magicHeader {
|
||||
continue // Not our packet
|
||||
}
|
||||
|
||||
// Check packet type
|
||||
packetType := buffer[4]
|
||||
if packetType != packetTypeRequest {
|
||||
continue // Not a request packet
|
||||
}
|
||||
|
||||
// Create response packet
|
||||
responsePacket := make([]byte, packetSize)
|
||||
// Copy the same magic header
|
||||
binary.BigEndian.PutUint32(responsePacket[0:4], magicHeader)
|
||||
// Change the packet type to response
|
||||
responsePacket[4] = packetTypeResponse
|
||||
// Copy the timestamp (for RTT calculation)
|
||||
copy(responsePacket[5:13], buffer[5:13])
|
||||
|
||||
// Log response being sent for debugging
|
||||
logger.Debug(s.outputPrefix+"Sending response to %s", addr.String())
|
||||
|
||||
// Send the response packet directly to the source address
|
||||
_, err = s.conn.WriteToUDP(responsePacket, addr)
|
||||
if err != nil {
|
||||
logger.Error(s.outputPrefix+"Error sending response: %v", err)
|
||||
} else {
|
||||
logger.Debug(s.outputPrefix + "Response sent successfully")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user