mirror of
https://github.com/fosrl/newt.git
synced 2026-03-12 18:04:28 -05:00
Compare commits
79 Commits
1.0.0-beta
...
holepunch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c5978d9c4e | ||
|
|
f08378b67e | ||
|
|
1501de691a | ||
|
|
2a19856556 | ||
|
|
f4e17a4dd7 | ||
|
|
ab544fc9ed | ||
|
|
f9e52c4d91 | ||
|
|
14eff8e16c | ||
|
|
067e079293 | ||
|
|
623be5ea0d | ||
|
|
72d264d427 | ||
|
|
a19fc8c588 | ||
|
|
dbc2a92456 | ||
|
|
437d8b67a4 | ||
|
|
6f1d4752f0 | ||
|
|
683312c78e | ||
|
|
29543aece3 | ||
|
|
e68a38e929 | ||
|
|
bc72c96b5e | ||
|
|
3d15ecb732 | ||
|
|
a69618310b | ||
|
|
ed8a2ccd23 | ||
|
|
5e673c829b | ||
|
|
cd3ec0b259 | ||
|
|
b68502de9e | ||
|
|
f6429b6eee | ||
|
|
8795c57b2e | ||
|
|
e8141a177b | ||
|
|
4aa718d55f | ||
|
|
afa93d8a3f | ||
|
|
270ee9cd19 | ||
|
|
0affef401c | ||
|
|
18d99de924 | ||
|
|
bff6707577 | ||
|
|
95eab504fa | ||
|
|
56e75902e3 | ||
|
|
45a1ab91d7 | ||
|
|
fb199cc94b | ||
|
|
66edae4288 | ||
|
|
f69a7f647d | ||
|
|
e8bd55bed9 | ||
|
|
b23eda9c06 | ||
|
|
92bc883b5b | ||
|
|
76503f3f2c | ||
|
|
9c3112f9bd | ||
|
|
462af30d16 | ||
|
|
fa6038eb38 | ||
|
|
f346b6cc5d | ||
|
|
f20b9ebb14 | ||
|
|
39bfe5b230 | ||
|
|
a1a3dd9ba2 | ||
|
|
7b1492f327 | ||
|
|
4e50819785 | ||
|
|
f8dccbec80 | ||
|
|
0c5c59cf00 | ||
|
|
868bb55f87 | ||
|
|
5b4245402a | ||
|
|
f7a705e6f8 | ||
|
|
3a63657822 | ||
|
|
759780508a | ||
|
|
533886f2e4 | ||
|
|
79f8745909 | ||
|
|
7b663027ac | ||
|
|
e90e55d982 | ||
|
|
a46fb23cdd | ||
|
|
10982b47a5 | ||
|
|
ab12098c9c | ||
|
|
446eb4d6f1 | ||
|
|
313afdb4c5 | ||
|
|
235a3b9426 | ||
|
|
c298ff52f3 | ||
|
|
75518b2e04 | ||
|
|
739f708ff7 | ||
|
|
2897b92f72 | ||
|
|
2c612d4018 | ||
|
|
41f0973308 | ||
|
|
4a791bdb6e | ||
|
|
9497f9c96f | ||
|
|
e17276b0c4 |
61
.github/workflows/cicd.yml
vendored
Normal file
61
.github/workflows/cicd.yml
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
name: CI/CD Pipeline
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: Build and Release
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||
|
||||
- name: Extract tag name
|
||||
id: get-tag
|
||||
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: 1.23.1
|
||||
|
||||
- name: Update version in main.go
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
if [ -f main.go ]; then
|
||||
sed -i 's/Newt version replaceme/Newt version '"$TAG"'/' main.go
|
||||
echo "Updated main.go with version $TAG"
|
||||
else
|
||||
echo "main.go not found"
|
||||
fi
|
||||
|
||||
- name: Build and push Docker images
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
make docker-build-release tag=$TAG
|
||||
|
||||
- name: Build binaries
|
||||
run: |
|
||||
make go-build-release
|
||||
|
||||
- name: Upload artifacts from /bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: binaries
|
||||
path: bin/
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1 +1,4 @@
|
||||
newt
|
||||
newt
|
||||
.DS_Store
|
||||
bin/
|
||||
nohup.out
|
||||
1
.go-version
Normal file
1
.go-version
Normal file
@@ -0,0 +1 @@
|
||||
1.23.2
|
||||
@@ -1,6 +1,12 @@
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please see the following page in our documentation with future plans and feature ideas if you are looking for a place to start.
|
||||
Contributions are welcome!
|
||||
|
||||
Please see the contribution and local development guide on the docs page before getting started:
|
||||
|
||||
https://docs.fossorial.io/development
|
||||
|
||||
For ideas about what features to work on and our future plans, please see the roadmap:
|
||||
|
||||
https://docs.fossorial.io/roadmap
|
||||
|
||||
@@ -15,4 +21,4 @@ By creating this pull request, I grant the project maintainers an unlimited,
|
||||
perpetual license to use, modify, and redistribute these contributions under any terms they
|
||||
choose, including both the AGPLv3 and the Fossorial Commercial license terms. I
|
||||
represent that I have the right to grant this license for all contributed content.
|
||||
```
|
||||
```
|
||||
|
||||
10
Dockerfile
10
Dockerfile
@@ -15,19 +15,13 @@ COPY . .
|
||||
# Build the application
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /newt
|
||||
|
||||
# Start a new stage from scratch
|
||||
FROM ubuntu:22.04 AS runner
|
||||
FROM alpine:3.19 AS runner
|
||||
|
||||
RUN apt-get update && apt-get install ca-certificates -y && rm -rf /var/lib/apt/lists/*
|
||||
RUN apk --no-cache add ca-certificates
|
||||
|
||||
# Copy the pre-built binary file from the previous stage and the entrypoint script
|
||||
COPY --from=builder /newt /usr/local/bin/
|
||||
COPY entrypoint.sh /
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
# Copy the entrypoint script
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
|
||||
# Command to run the executable
|
||||
CMD ["newt"]
|
||||
22
Makefile
22
Makefile
@@ -1,6 +1,14 @@
|
||||
|
||||
all: build push
|
||||
|
||||
docker-build-release:
|
||||
@if [ -z "$(tag)" ]; then \
|
||||
echo "Error: tag is required. Usage: make build-all tag=<tag>"; \
|
||||
exit 1; \
|
||||
fi
|
||||
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/newt:latest -f Dockerfile --push .
|
||||
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/newt:$(tag) -f Dockerfile --push .
|
||||
|
||||
build:
|
||||
docker build -t fosrl/newt:latest .
|
||||
|
||||
@@ -11,7 +19,19 @@ test:
|
||||
docker run fosrl/newt:latest
|
||||
|
||||
local:
|
||||
CGO_ENABLED=0 go build -o newt
|
||||
CGO_ENABLED=0 go build -o newt
|
||||
|
||||
go-build-release:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/newt_linux_arm64
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/newt_linux_arm32
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -o bin/newt_linux_arm32v6
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/newt_linux_amd64
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -o bin/newt_linux_riscv64
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/newt_darwin_arm64
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/newt_darwin_amd64
|
||||
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/newt_windows_amd64.exe
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -o bin/newt_freebsd_amd64
|
||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -o bin/newt_freebsd_arm64
|
||||
|
||||
clean:
|
||||
rm newt
|
||||
|
||||
61
README.md
61
README.md
@@ -19,7 +19,7 @@ _Sample output of a Newt container connected to Pangolin and hosting various res
|
||||
|
||||
### Registers with Pangolin
|
||||
|
||||
Using the Newt ID and a secret the client will make HTTP requests to Pangolin to receive a session token. Using that token it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket.
|
||||
Using the Newt ID and a secret, the client will make HTTP requests to Pangolin to receive a session token. Using that token, it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket.
|
||||
|
||||
### Receives WireGuard Control Messages
|
||||
|
||||
@@ -27,7 +27,7 @@ When Newt receives WireGuard control messages, it will use the information encod
|
||||
|
||||
### Receives Proxy Control Messages
|
||||
|
||||
When Newt receives WireGuard control messages, it will use the information encoded to crate local low level TCP and UDP proxies attached to the virtual tunnel in order to relay traffic to programmed targets.
|
||||
When Newt receives WireGuard control messages, it will use the information encoded to create a local low level TCP and UDP proxies attached to the virtual tunnel in order to relay traffic to programmed targets.
|
||||
|
||||
## CLI Args
|
||||
|
||||
@@ -36,7 +36,8 @@ When Newt receives WireGuard control messages, it will use the information encod
|
||||
- `secret`: A unique secret (not shared and kept private) used to authenticate the client ID with the websocket in order to receive commands.
|
||||
- `dns`: DNS server to use to resolve the endpoint
|
||||
- `log-level` (optional): The log level to use. Default: INFO
|
||||
|
||||
- `updown` (optional): A script to be called when targets are added or removed.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
@@ -46,6 +47,22 @@ Example:
|
||||
--endpoint https://example.com
|
||||
```
|
||||
|
||||
You can also run it with Docker compose. For example, a service in your `docker-compose.yml` might look like this using environment vars (recommended):
|
||||
|
||||
```yaml
|
||||
services:
|
||||
newt:
|
||||
image: fosrl/newt
|
||||
container_name: newt
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- PANGOLIN_ENDPOINT=https://example.com
|
||||
- NEWT_ID=2ix2t8xk22ubpfy
|
||||
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||
```
|
||||
|
||||
You can also pass the CLI args to the container:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
newt:
|
||||
@@ -53,11 +70,43 @@ services:
|
||||
container_name: newt
|
||||
restart: unless-stopped
|
||||
command:
|
||||
- --id 31frd0uzbjvp721 \
|
||||
- --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
||||
- --id 31frd0uzbjvp721
|
||||
- --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6
|
||||
- --endpoint https://example.com
|
||||
```
|
||||
|
||||
Finally a basic systemd service:
|
||||
|
||||
```
|
||||
[Unit]
|
||||
Description=Newt VPN Client
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
ExecStart=/usr/local/bin/newt --id 31frd0uzbjvp721 --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 --endpoint https://example.com
|
||||
Restart=always
|
||||
User=root
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
Make sure to `mv ./newt /usr/local/bin/newt`!
|
||||
|
||||
### Updown
|
||||
|
||||
You can pass in a updown script for Newt to call when it is adding or removing a target:
|
||||
|
||||
`--updown "python3 test.py"`
|
||||
|
||||
It will get called with args when a target is added:
|
||||
`python3 test.py add tcp localhost:8556`
|
||||
`python3 test.py remove tcp localhost:8556`
|
||||
|
||||
Returning a string from the script in the format of a target (`ip:dst` so `10.0.0.1:8080`) it will override the target and use this value instead to proxy.
|
||||
|
||||
You can look at updown.py as a reference script to get started!
|
||||
|
||||
## Build
|
||||
|
||||
### Container
|
||||
@@ -82,4 +131,4 @@ Newt is dual licensed under the AGPLv3 and the Fossorial Commercial license. For
|
||||
|
||||
## Contributions
|
||||
|
||||
Please see [CONTRIBUTIONS](./CONTRIBUTING.md) in the repository for guidelines and best practices.
|
||||
Please see [CONTRIBUTIONS](./CONTRIBUTING.md) in the repository for guidelines and best practices.
|
||||
|
||||
14
SECURITY.md
Normal file
14
SECURITY.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Security Policy
|
||||
|
||||
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
||||
|
||||
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
||||
2. Send a detailed report to [security@fossorial.io](mailto:security@fossorial.io) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
|
||||
- Description and location of the vulnerability.
|
||||
- Potential impact of the vulnerability.
|
||||
- Steps to reproduce the vulnerability.
|
||||
- Potential solutions to fix the vulnerability.
|
||||
- Your name/handle and a link for recognition (optional).
|
||||
|
||||
We aim to address the issue as soon as possible.
|
||||
10
docker-compose.yml
Normal file
10
docker-compose.yml
Normal file
@@ -0,0 +1,10 @@
|
||||
services:
|
||||
newt:
|
||||
image: fosrl/newt:latest
|
||||
container_name: newt
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- PANGOLIN_ENDPOINT=https://example.com
|
||||
- NEWT_ID=2ix2t8xk22ubpfy
|
||||
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||
- LOG_LEVEL=DEBUG
|
||||
@@ -1,7 +1,5 @@
|
||||
#!/bin/sh
|
||||
|
||||
# Sample from https://github.com/traefik/traefik-library-image/blob/5070edb25b03cca6802d75d5037576c840f73fdd/v3.1/alpine/entrypoint.sh
|
||||
|
||||
set -e
|
||||
|
||||
# first arg is `-f` or `--some-option`
|
||||
@@ -9,13 +7,4 @@ if [ "${1#-}" != "$1" ]; then
|
||||
set -- newt "$@"
|
||||
fi
|
||||
|
||||
# if our command is a valid newt subcommand, let's invoke it through newt instead
|
||||
# (this allows for "docker run newt version", etc)
|
||||
if newt "$1" --help >/dev/null 2>&1
|
||||
then
|
||||
set -- newt "$@"
|
||||
else
|
||||
echo "= '$1' is not a newt command: assuming shell execution." 1>&2
|
||||
fi
|
||||
|
||||
exec "$@"
|
||||
26
go.mod
26
go.mod
@@ -4,16 +4,28 @@ go 1.23.1
|
||||
|
||||
toolchain go1.23.2
|
||||
|
||||
require golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
require (
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/vishvananda/netlink v1.3.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
|
||||
golang.org/x/net v0.33.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
golang.org/x/crypto v0.28.0 // indirect
|
||||
golang.org/x/net v0.30.0 // indirect
|
||||
golang.org/x/sys v0.26.0 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/time v0.7.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
|
||||
)
|
||||
|
||||
52
go.sum
52
go.sum
@@ -1,20 +1,56 @@
|
||||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
|
||||
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||
|
||||
272
main.go
272
main.go
@@ -11,7 +11,10 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -19,6 +22,7 @@ import (
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/fosrl/newt/wg"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
@@ -53,7 +57,7 @@ func fixKey(key string) string {
|
||||
// Decode from base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
logger.Fatal("Error decoding base64:", err)
|
||||
logger.Fatal("Error decoding base64")
|
||||
}
|
||||
|
||||
// Convert to hex
|
||||
@@ -112,6 +116,27 @@ func ping(tnet *netstack.Net, dst string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := ping(tnet, serverIP)
|
||||
if err != nil {
|
||||
logger.Warn("Periodic ping failed: %v", err)
|
||||
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
|
||||
}
|
||||
case <-stopChan:
|
||||
logger.Info("Stopping ping check")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||
const (
|
||||
maxAttempts = 5
|
||||
@@ -190,6 +215,9 @@ func resolveDomain(domain string) (string, error) {
|
||||
host = strings.TrimPrefix(host, "https://")
|
||||
}
|
||||
|
||||
// if there are any trailing slashes, remove them
|
||||
host = strings.TrimSuffix(host, "/")
|
||||
|
||||
// Lookup IP addresses
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
@@ -222,39 +250,79 @@ func resolveDomain(domain string) (string, error) {
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
func getEnvWithDefault(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
var (
|
||||
endpoint string
|
||||
id string
|
||||
secret string
|
||||
mtu string
|
||||
mtuInt int
|
||||
dns string
|
||||
privateKey wgtypes.Key
|
||||
err error
|
||||
logLevel string
|
||||
updownScript string
|
||||
interfaceName string
|
||||
generateAndSaveKeyTo string
|
||||
)
|
||||
|
||||
func main() {
|
||||
var (
|
||||
endpoint string
|
||||
id string
|
||||
secret string
|
||||
dns string
|
||||
privateKey wgtypes.Key
|
||||
err error
|
||||
logLevel string
|
||||
)
|
||||
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
|
||||
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
|
||||
id = os.Getenv("NEWT_ID")
|
||||
secret = os.Getenv("NEWT_SECRET")
|
||||
mtu = os.Getenv("MTU")
|
||||
dns = os.Getenv("DNS")
|
||||
logLevel = os.Getenv("LOG_LEVEL")
|
||||
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
||||
interfaceName = os.Getenv("INTERFACE")
|
||||
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
||||
|
||||
if endpoint == "" {
|
||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
||||
}
|
||||
if id == "" {
|
||||
flag.StringVar(&id, "id", "", "Newt ID")
|
||||
}
|
||||
if secret == "" {
|
||||
flag.StringVar(&secret, "secret", "", "Newt secret")
|
||||
}
|
||||
if mtu == "" {
|
||||
flag.StringVar(&mtu, "mtu", "1280", "MTU to use")
|
||||
}
|
||||
if dns == "" {
|
||||
flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
|
||||
}
|
||||
if logLevel == "" {
|
||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
}
|
||||
if updownScript == "" {
|
||||
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
|
||||
}
|
||||
if interfaceName == "" {
|
||||
flag.StringVar(&interfaceName, "interface", "wg1", "Name of the WireGuard interface")
|
||||
}
|
||||
if generateAndSaveKeyTo == "" {
|
||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||
}
|
||||
|
||||
// do a --version check
|
||||
version := flag.Bool("version", false, "Print the version")
|
||||
|
||||
// Define CLI flags with default values from environment variables
|
||||
flag.StringVar(&endpoint, "endpoint", os.Getenv("PANGOLIN_ENDPOINT"), "Endpoint of your pangolin server")
|
||||
flag.StringVar(&id, "id", os.Getenv("NEWT_ID"), "Newt ID")
|
||||
flag.StringVar(&secret, "secret", os.Getenv("NEWT_SECRET"), "Newt secret")
|
||||
flag.StringVar(&dns, "dns", getEnvWithDefault("DEFAULT_DNS", "8.8.8.8"), "DNS server to use")
|
||||
flag.StringVar(&logLevel, "log-level", getEnvWithDefault("LOG_LEVEL", "INFO"), "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
flag.Parse()
|
||||
|
||||
if *version {
|
||||
fmt.Println("Newt version replaceme")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
logger.Init()
|
||||
loggerLevel := parseLogLevel(logLevel)
|
||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||
|
||||
// Validate required fields
|
||||
if endpoint == "" || id == "" || secret == "" {
|
||||
logger.Fatal("endpoint, id, and secret are required either via CLI flags or environment variables")
|
||||
// parse the mtu string into an int
|
||||
mtuInt, err = strconv.Atoi(mtu)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse MTU: %v", err)
|
||||
}
|
||||
|
||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||
@@ -272,6 +340,7 @@ func main() {
|
||||
logger.Fatal("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
var wgService *wg.WireGuardService
|
||||
// Create TUN device and network stack
|
||||
var tun tun.Device
|
||||
var tnet *netstack.Net
|
||||
@@ -280,6 +349,30 @@ func main() {
|
||||
var connected bool
|
||||
var wgData WgData
|
||||
|
||||
if generateAndSaveKeyTo != "" {
|
||||
// make sure we are running on linux
|
||||
if runtime.GOOS != "linux" {
|
||||
logger.Fatal("Tunnel management is only supported on Linux right now!")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var host = endpoint
|
||||
if strings.HasPrefix(host, "http://") {
|
||||
host = strings.TrimPrefix(host, "http://")
|
||||
} else if strings.HasPrefix(host, "https://") {
|
||||
host = strings.TrimPrefix(host, "https://")
|
||||
}
|
||||
|
||||
host = strings.TrimSuffix(host, "/")
|
||||
|
||||
// Create WireGuard service
|
||||
wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||
}
|
||||
defer wgService.Close()
|
||||
}
|
||||
|
||||
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received terminate message")
|
||||
if pm != nil {
|
||||
@@ -291,17 +384,21 @@ func main() {
|
||||
client.Close()
|
||||
})
|
||||
|
||||
pingStopChan := make(chan struct{})
|
||||
defer close(pingStopChan)
|
||||
|
||||
// Register handlers for different message types
|
||||
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received registration message")
|
||||
|
||||
if connected {
|
||||
logger.Info("Already connected! Put I will send a ping anyway...")
|
||||
logger.Info("Already connected! But I will send a ping anyway...")
|
||||
// ping(tnet, wgData.ServerIP)
|
||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
||||
if err != nil {
|
||||
// Handle complete failure after all retries
|
||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||
logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -317,11 +414,15 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
if wgService != nil {
|
||||
wgService.SetServerPubKey(wgData.PublicKey)
|
||||
}
|
||||
|
||||
logger.Info("Received: %+v", msg)
|
||||
tun, tnet, err = netstack.CreateNetTUN(
|
||||
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
|
||||
[]netip.Addr{netip.MustParseAddr(dns)},
|
||||
1420)
|
||||
mtuInt)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create TUN device: %v", err)
|
||||
}
|
||||
@@ -363,6 +464,12 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
if err != nil {
|
||||
// Handle complete failure after all retries
|
||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||
fmt.Sprintf("%s", privateKey)
|
||||
}
|
||||
|
||||
if !connected {
|
||||
logger.Info("Starting ping check")
|
||||
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
|
||||
}
|
||||
|
||||
// Create proxy manager
|
||||
@@ -379,6 +486,13 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
|
||||
}
|
||||
|
||||
// first make sure the wpgService has a port
|
||||
if wgService != nil {
|
||||
// add a udp proxy for localost and the wgService port
|
||||
// TODO: make sure this port is not used in a target
|
||||
pm.AddTarget("udp", wgData.TunnelIP, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port))
|
||||
}
|
||||
|
||||
err = pm.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start proxy manager: %v", err)
|
||||
@@ -403,11 +517,6 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
|
||||
}
|
||||
|
||||
err = pm.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start proxy manager: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
|
||||
@@ -428,11 +537,6 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
|
||||
}
|
||||
|
||||
err = pm.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start proxy manager: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
||||
@@ -487,10 +591,20 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
return err
|
||||
}
|
||||
|
||||
if wgService != nil {
|
||||
wgService.LoadRemoteConfig()
|
||||
}
|
||||
|
||||
logger.Info("Sent registration message")
|
||||
return nil
|
||||
})
|
||||
|
||||
client.OnTokenUpdate(func(token string) {
|
||||
if wgService != nil {
|
||||
wgService.SetToken(token)
|
||||
}
|
||||
})
|
||||
|
||||
// Connect to the WebSocket server
|
||||
if err := client.Connect(); err != nil {
|
||||
logger.Fatal("Failed to connect to server: %v", err)
|
||||
@@ -502,8 +616,21 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
|
||||
// Cleanup
|
||||
dev.Close()
|
||||
|
||||
if wgService != nil {
|
||||
wgService.Close()
|
||||
}
|
||||
|
||||
if pm != nil {
|
||||
pm.Stop()
|
||||
}
|
||||
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
logger.Info("Exiting...")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func parseTargetData(data interface{}) (TargetData, error) {
|
||||
@@ -540,6 +667,18 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
||||
|
||||
if action == "add" {
|
||||
target := parts[1] + ":" + parts[2]
|
||||
|
||||
// Call updown script if provided
|
||||
processedTarget := target
|
||||
if updownScript != "" {
|
||||
newTarget, err := executeUpdownScript(action, proto, target)
|
||||
if err != nil {
|
||||
logger.Warn("Updown script error: %v", err)
|
||||
} else if newTarget != "" {
|
||||
processedTarget = newTarget
|
||||
}
|
||||
}
|
||||
|
||||
// Only remove the specific target if it exists
|
||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
||||
if err != nil {
|
||||
@@ -550,10 +689,21 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
||||
}
|
||||
|
||||
// Add the new target
|
||||
pm.AddTarget(proto, tunnelIP, port, target)
|
||||
pm.AddTarget(proto, tunnelIP, port, processedTarget)
|
||||
|
||||
} else if action == "remove" {
|
||||
logger.Info("Removing target with port %d", port)
|
||||
|
||||
target := parts[1] + ":" + parts[2]
|
||||
|
||||
// Call updown script if provided
|
||||
if updownScript != "" {
|
||||
_, err := executeUpdownScript(action, proto, target)
|
||||
if err != nil {
|
||||
logger.Warn("Updown script error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
||||
if err != nil {
|
||||
logger.Error("Failed to remove target: %v", err)
|
||||
@@ -564,3 +714,45 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func executeUpdownScript(action, proto, target string) (string, error) {
|
||||
if updownScript == "" {
|
||||
return target, nil
|
||||
}
|
||||
|
||||
// Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py")
|
||||
parts := strings.Fields(updownScript)
|
||||
if len(parts) == 0 {
|
||||
return target, fmt.Errorf("invalid updown script command")
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if len(parts) == 1 {
|
||||
// If it's a single executable
|
||||
logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target)
|
||||
cmd = exec.Command(parts[0], action, proto, target)
|
||||
} else {
|
||||
// If it includes interpreter and script
|
||||
args := append(parts[1:], action, proto, target)
|
||||
logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target)
|
||||
cmd = exec.Command(parts[0], args...)
|
||||
}
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return "", fmt.Errorf("updown script execution failed (exit code %d): %s",
|
||||
exitErr.ExitCode(), string(exitErr.Stderr))
|
||||
}
|
||||
return "", fmt.Errorf("updown script execution failed: %v", err)
|
||||
}
|
||||
|
||||
// If the script returns a new target, use it
|
||||
newTarget := strings.TrimSpace(string(output))
|
||||
if newTarget != "" {
|
||||
logger.Info("Updown script returned new target: %s", newTarget)
|
||||
return newTarget, nil
|
||||
}
|
||||
|
||||
return target, nil
|
||||
}
|
||||
|
||||
202
network/network.go
Normal file
202
network/network.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/net/bpf"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
udpProtocol = 17
|
||||
// EmptyUDPSize is the size of an empty UDP packet
|
||||
EmptyUDPSize = 28
|
||||
timeout = time.Second * 10
|
||||
)
|
||||
|
||||
// Server stores data relating to the server
|
||||
type Server struct {
|
||||
Hostname string
|
||||
Addr *net.IPAddr
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// PeerNet stores data about a peer's endpoint
|
||||
type PeerNet struct {
|
||||
Resolved bool
|
||||
IP net.IP
|
||||
Port uint16
|
||||
NewtID string
|
||||
}
|
||||
|
||||
// GetClientIP gets source ip address that will be used when sending data to dstIP
|
||||
func GetClientIP(dstIP net.IP) net.IP {
|
||||
routes, err := netlink.RouteGet(dstIP)
|
||||
if err != nil {
|
||||
log.Fatalln("Error getting route:", err)
|
||||
}
|
||||
return routes[0].Src
|
||||
}
|
||||
|
||||
// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr
|
||||
func HostToAddr(hostStr string) *net.IPAddr {
|
||||
remoteAddrs, err := net.LookupHost(hostStr)
|
||||
if err != nil {
|
||||
log.Fatalln("Error parsing remote address:", err)
|
||||
}
|
||||
|
||||
for _, addrStr := range remoteAddrs {
|
||||
if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil {
|
||||
return remoteAddr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering
|
||||
func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn {
|
||||
packetConn, err := net.ListenPacket("ip4:udp", client.IP.String())
|
||||
if err != nil {
|
||||
log.Fatalln("Error creating packetConn:", err)
|
||||
}
|
||||
|
||||
rawConn, err := ipv4.NewRawConn(packetConn)
|
||||
if err != nil {
|
||||
log.Fatalln("Error creating rawConn:", err)
|
||||
}
|
||||
|
||||
ApplyBPF(rawConn, server, client)
|
||||
|
||||
return rawConn
|
||||
}
|
||||
|
||||
// ApplyBPF constructs a BPF program and applies it to the RawConn
|
||||
func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) {
|
||||
const ipv4HeaderLen = 20
|
||||
const srcIPOffset = 12
|
||||
const srcPortOffset = ipv4HeaderLen + 0
|
||||
const dstPortOffset = ipv4HeaderLen + 2
|
||||
|
||||
ipArr := []byte(server.Addr.IP.To4())
|
||||
ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3])
|
||||
|
||||
bpfRaw, err := bpf.Assemble([]bpf.Instruction{
|
||||
bpf.LoadAbsolute{Off: srcIPOffset, Size: 4},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0},
|
||||
|
||||
bpf.LoadAbsolute{Off: srcPortOffset, Size: 2},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0},
|
||||
|
||||
bpf.LoadAbsolute{Off: dstPortOffset, Size: 2},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0},
|
||||
|
||||
bpf.RetConstant{Val: 1<<(8*4) - 1},
|
||||
bpf.RetConstant{Val: 0},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatalln("Error assembling BPF:", err)
|
||||
}
|
||||
|
||||
err = rawConn.SetBPF(bpfRaw)
|
||||
if err != nil {
|
||||
log.Fatalln("Error setting BPF:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// MakePacket constructs a request packet to send to the server
|
||||
func MakePacket(payload []byte, server *Server, client *PeerNet) []byte {
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
|
||||
opts := gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}
|
||||
|
||||
ipHeader := layers.IPv4{
|
||||
SrcIP: client.IP,
|
||||
DstIP: server.Addr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
|
||||
udpHeader := layers.UDP{
|
||||
SrcPort: layers.UDPPort(client.Port),
|
||||
DstPort: layers.UDPPort(server.Port),
|
||||
}
|
||||
|
||||
payloadLayer := gopacket.Payload(payload)
|
||||
|
||||
udpHeader.SetNetworkLayerForChecksum(&ipHeader)
|
||||
|
||||
gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// SendPacket sends packet to the Server
|
||||
func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
|
||||
fullPacket := MakePacket(packet, server, client)
|
||||
_, err := conn.WriteToIP(fullPacket, server.Addr)
|
||||
return err
|
||||
}
|
||||
|
||||
// SendDataPacket sends a JSON payload to the Server
|
||||
func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
return SendPacket(jsonData, conn, server, client)
|
||||
}
|
||||
|
||||
// RecvPacket receives a UDP packet from server
|
||||
func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) {
|
||||
err := conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
response := make([]byte, 4096)
|
||||
n, err := conn.Read(response)
|
||||
if err != nil {
|
||||
return nil, n, err
|
||||
}
|
||||
return response, n, nil
|
||||
}
|
||||
|
||||
// RecvDataPacket receives and unmarshals a JSON packet from server
|
||||
func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) {
|
||||
response, n, err := RecvPacket(conn, server, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract payload from UDP packet
|
||||
payload := response[EmptyUDPSize:n]
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// ParseResponse takes a response packet and parses it into an IP and port
|
||||
func ParseResponse(response []byte) (net.IP, uint16) {
|
||||
ip := net.IP(response[:4])
|
||||
port := binary.BigEndian.Uint16(response[4:6])
|
||||
return ip, port
|
||||
}
|
||||
|
||||
func parseForBPF(response []byte) (srcIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
srcIP = net.IP(response[12:16])
|
||||
srcPort = binary.BigEndian.Uint16(response[20:22])
|
||||
dstPort = binary.BigEndian.Uint16(response[22:24])
|
||||
return
|
||||
}
|
||||
530
proxy/manager.go
530
proxy/manager.go
@@ -9,326 +9,344 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
)
|
||||
|
||||
// Target represents a proxy target with its address and port
|
||||
type Target struct {
|
||||
Address string
|
||||
Port int
|
||||
}
|
||||
|
||||
// ProxyManager handles the creation and management of proxy connections
|
||||
type ProxyManager struct {
|
||||
tnet *netstack.Net
|
||||
tcpTargets map[string]map[int]string // map[listenIP]map[port]targetAddress
|
||||
udpTargets map[string]map[int]string
|
||||
listeners []*gonet.TCPListener
|
||||
udpConns []*gonet.UDPConn
|
||||
running bool
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewProxyManager creates a new proxy manager instance
|
||||
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
|
||||
return &ProxyManager{
|
||||
tnet: tnet,
|
||||
tnet: tnet,
|
||||
tcpTargets: make(map[string]map[int]string),
|
||||
udpTargets: make(map[string]map[int]string),
|
||||
listeners: make([]*gonet.TCPListener, 0),
|
||||
udpConns: make([]*gonet.UDPConn, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target string) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
// AddTarget adds as new target for proxying
|
||||
func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr string) error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
logger.Info("Adding target: %s://%s:%d -> %s", protocol, listen, port, target)
|
||||
|
||||
newTarget := ProxyTarget{
|
||||
Protocol: protocol,
|
||||
Listen: listen,
|
||||
Port: port,
|
||||
Target: target,
|
||||
cancel: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
switch proto {
|
||||
case "tcp":
|
||||
if pm.tcpTargets[listenIP] == nil {
|
||||
pm.tcpTargets[listenIP] = make(map[int]string)
|
||||
}
|
||||
pm.tcpTargets[listenIP][port] = targetAddr
|
||||
case "udp":
|
||||
if pm.udpTargets[listenIP] == nil {
|
||||
pm.udpTargets[listenIP] = make(map[int]string)
|
||||
}
|
||||
pm.udpTargets[listenIP][port] = targetAddr
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
|
||||
pm.targets = append(pm.targets, newTarget)
|
||||
if pm.running {
|
||||
return pm.startTarget(proto, listenIP, port, targetAddr)
|
||||
} else {
|
||||
logger.Debug("Not adding target because not running")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
protocol = strings.ToLower(protocol)
|
||||
if protocol != "tcp" && protocol != "udp" {
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
for i, target := range pm.targets {
|
||||
if target.Listen == listen &&
|
||||
target.Port == port &&
|
||||
strings.ToLower(target.Protocol) == protocol {
|
||||
|
||||
// Signal the serving goroutine to stop
|
||||
select {
|
||||
case <-target.cancel:
|
||||
// Channel is already closed, no need to close it again
|
||||
default:
|
||||
close(target.cancel)
|
||||
}
|
||||
|
||||
// Close the appropriate listener/connection based on protocol
|
||||
target.Lock()
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
if target.listener != nil {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
// Listener was already closed by Stop()
|
||||
default:
|
||||
target.listener.Close()
|
||||
}
|
||||
}
|
||||
case "udp":
|
||||
if target.udpConn != nil {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
// Connection was already closed by Stop()
|
||||
default:
|
||||
target.udpConn.Close()
|
||||
}
|
||||
switch proto {
|
||||
case "tcp":
|
||||
if targets, ok := pm.tcpTargets[listenIP]; ok {
|
||||
delete(targets, port)
|
||||
// Remove and close the corresponding TCP listener
|
||||
for i, listener := range pm.listeners {
|
||||
if addr, ok := listener.Addr().(*net.TCPAddr); ok && addr.Port == port {
|
||||
listener.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
// Remove from slice
|
||||
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
target.Unlock()
|
||||
|
||||
// Wait for the target to fully stop
|
||||
<-target.done
|
||||
|
||||
// Remove the target from the slice
|
||||
pm.targets = append(pm.targets[:i], pm.targets[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("target not found for %s %s:%d", protocol, listen, port)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) Start() error {
|
||||
pm.RLock()
|
||||
defer pm.RUnlock()
|
||||
|
||||
for i := range pm.targets {
|
||||
target := &pm.targets[i]
|
||||
|
||||
target.Lock()
|
||||
// If target is already running, skip it
|
||||
if target.listener != nil || target.udpConn != nil {
|
||||
target.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
// Mark the target as starting by creating a nil listener/connection
|
||||
// This prevents other goroutines from trying to start it
|
||||
if strings.ToLower(target.Protocol) == "tcp" {
|
||||
target.listener = nil
|
||||
} else {
|
||||
target.udpConn = nil
|
||||
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
||||
}
|
||||
target.Unlock()
|
||||
case "udp":
|
||||
if targets, ok := pm.udpTargets[listenIP]; ok {
|
||||
delete(targets, port)
|
||||
// Remove and close the corresponding UDP connection
|
||||
for i, conn := range pm.udpConns {
|
||||
if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok && addr.Port == port {
|
||||
conn.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
// Remove from slice
|
||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch strings.ToLower(target.Protocol) {
|
||||
case "tcp":
|
||||
go pm.serveTCP(target)
|
||||
case "udp":
|
||||
go pm.serveUDP(target)
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", target.Protocol)
|
||||
// Start begins listening for all configured proxy targets
|
||||
func (pm *ProxyManager) Start() error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if pm.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start TCP targets
|
||||
for listenIP, targets := range pm.tcpTargets {
|
||||
for port, targetAddr := range targets {
|
||||
if err := pm.startTarget("tcp", listenIP, port, targetAddr); err != nil {
|
||||
return fmt.Errorf("failed to start TCP target: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start UDP targets
|
||||
for listenIP, targets := range pm.udpTargets {
|
||||
for port, targetAddr := range targets {
|
||||
if err := pm.startTarget("udp", listenIP, port, targetAddr); err != nil {
|
||||
return fmt.Errorf("failed to start UDP target: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pm.running = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) Stop() error {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := range pm.targets {
|
||||
target := &pm.targets[i]
|
||||
wg.Add(1)
|
||||
go func(t *ProxyTarget) {
|
||||
defer wg.Done()
|
||||
close(t.cancel)
|
||||
t.Lock()
|
||||
if t.listener != nil {
|
||||
t.listener.Close()
|
||||
}
|
||||
if t.udpConn != nil {
|
||||
t.udpConn.Close()
|
||||
}
|
||||
t.Unlock()
|
||||
// Wait for the target to fully stop
|
||||
<-t.done
|
||||
}(target)
|
||||
if !pm.running {
|
||||
return nil
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Set running to false first to signal handlers to stop
|
||||
pm.running = false
|
||||
|
||||
// Close TCP listeners
|
||||
for i := len(pm.listeners) - 1; i >= 0; i-- {
|
||||
listener := pm.listeners[i]
|
||||
if err := listener.Close(); err != nil {
|
||||
logger.Error("Error closing TCP listener: %v", err)
|
||||
}
|
||||
// Remove from slice
|
||||
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
||||
}
|
||||
|
||||
// Close UDP connections
|
||||
for i := len(pm.udpConns) - 1; i >= 0; i-- {
|
||||
conn := pm.udpConns[i]
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Error("Error closing UDP connection: %v", err)
|
||||
}
|
||||
// Remove from slice
|
||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||
}
|
||||
|
||||
// Clear the target maps
|
||||
for k := range pm.tcpTargets {
|
||||
delete(pm.tcpTargets, k)
|
||||
}
|
||||
for k := range pm.udpTargets {
|
||||
delete(pm.udpTargets, k)
|
||||
}
|
||||
|
||||
// Give active connections a chance to close gracefully
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
|
||||
defer close(target.done) // Signal that this target is fully stopped
|
||||
func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr string) error {
|
||||
switch proto {
|
||||
case "tcp":
|
||||
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{Port: port})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create TCP listener: %v", err)
|
||||
}
|
||||
|
||||
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{
|
||||
IP: net.ParseIP(target.Listen),
|
||||
Port: target.Port,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Info("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err)
|
||||
return
|
||||
pm.listeners = append(pm.listeners, listener)
|
||||
go pm.handleTCPProxy(listener, targetAddr)
|
||||
|
||||
case "udp":
|
||||
addr := &net.UDPAddr{Port: port}
|
||||
conn, err := pm.tnet.ListenUDP(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create UDP listener: %v", err)
|
||||
}
|
||||
|
||||
pm.udpConns = append(pm.udpConns, conn)
|
||||
go pm.handleUDPProxy(conn, targetAddr)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
|
||||
target.Lock()
|
||||
target.listener = listener
|
||||
target.Unlock()
|
||||
logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr)
|
||||
|
||||
defer listener.Close()
|
||||
logger.Info("TCP proxy listening on %s", listener.Addr())
|
||||
|
||||
var activeConns sync.WaitGroup
|
||||
acceptDone := make(chan struct{})
|
||||
|
||||
// Goroutine to handle shutdown signal
|
||||
go func() {
|
||||
<-target.cancel
|
||||
close(acceptDone)
|
||||
listener.Close()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
// Wait for active connections to finish
|
||||
activeConns.Wait()
|
||||
// Check if we're shutting down or the listener was closed
|
||||
if !pm.running {
|
||||
return
|
||||
default:
|
||||
logger.Info("Failed to accept TCP connection: %v", err)
|
||||
// Don't return here, try to accept new connections
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for specific network errors that indicate the listener is closed
|
||||
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
||||
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error("Error accepting TCP connection: %v", err)
|
||||
// Don't hammer the CPU if we hit a temporary error
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
activeConns.Add(1)
|
||||
go func() {
|
||||
defer activeConns.Done()
|
||||
pm.handleTCPConnection(conn, target.Target, acceptDone)
|
||||
target, err := net.Dial("tcp", targetAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error connecting to target: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Create a WaitGroup to ensure both copy operations complete
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(target, conn)
|
||||
target.Close()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(conn, target)
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
// Wait for both copies to complete
|
||||
wg.Wait()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string, done chan struct{}) {
|
||||
defer clientConn.Close()
|
||||
|
||||
serverConn, err := net.Dial("tcp", target)
|
||||
if err != nil {
|
||||
logger.Info("Failed to connect to target %s: %v", target, err)
|
||||
return
|
||||
}
|
||||
defer serverConn.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Client -> Server
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
io.Copy(serverConn, clientConn)
|
||||
}
|
||||
}()
|
||||
|
||||
// Server -> Client
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
io.Copy(clientConn, serverConn)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
|
||||
defer close(target.done) // Signal that this target is fully stopped
|
||||
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.ParseIP(target.Listen),
|
||||
Port: target.Port,
|
||||
}
|
||||
|
||||
conn, err := pm.tnet.ListenUDP(addr)
|
||||
if err != nil {
|
||||
logger.Info("Failed to start UDP listener for %s:%d: %v", target.Listen, target.Port, err)
|
||||
return
|
||||
}
|
||||
|
||||
target.Lock()
|
||||
target.udpConn = conn
|
||||
target.Unlock()
|
||||
|
||||
defer conn.Close()
|
||||
logger.Info("UDP proxy listening on %s", conn.LocalAddr())
|
||||
|
||||
buffer := make([]byte, 65535)
|
||||
var activeConns sync.WaitGroup
|
||||
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
buffer := make([]byte, 65507) // Max UDP packet size
|
||||
clientConns := make(map[string]*net.UDPConn)
|
||||
var clientsMutex sync.RWMutex
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
activeConns.Wait() // Wait for all active UDP handlers to complete
|
||||
return
|
||||
default:
|
||||
n, remoteAddr, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
activeConns.Wait()
|
||||
return
|
||||
default:
|
||||
logger.Info("Failed to read UDP packet: %v", err)
|
||||
continue
|
||||
}
|
||||
n, remoteAddr, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
if !pm.running {
|
||||
return
|
||||
}
|
||||
|
||||
targetAddr, err := net.ResolveUDPAddr("udp", target.Target)
|
||||
// Check for connection closed conditions
|
||||
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
|
||||
logger.Info("UDP connection closed, stopping proxy handler")
|
||||
|
||||
// Clean up existing client connections
|
||||
clientsMutex.Lock()
|
||||
for _, targetConn := range clientConns {
|
||||
targetConn.Close()
|
||||
}
|
||||
clientConns = nil
|
||||
clientsMutex.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error("Error reading UDP packet: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
clientKey := remoteAddr.String()
|
||||
clientsMutex.RLock()
|
||||
targetConn, exists := clientConns[clientKey]
|
||||
clientsMutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
||||
if err != nil {
|
||||
logger.Info("Failed to resolve target address %s: %v", target.Target, err)
|
||||
logger.Error("Error resolving target address: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
activeConns.Add(1)
|
||||
go func(data []byte, remote net.Addr) {
|
||||
defer activeConns.Done()
|
||||
targetConn, err := net.DialUDP("udp", nil, targetAddr)
|
||||
if err != nil {
|
||||
logger.Info("Failed to connect to target %s: %v", target.Target, err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error connecting to target: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case <-target.cancel:
|
||||
return
|
||||
default:
|
||||
_, err = targetConn.Write(data)
|
||||
clientsMutex.Lock()
|
||||
clientConns[clientKey] = targetConn
|
||||
clientsMutex.Unlock()
|
||||
|
||||
go func() {
|
||||
buffer := make([]byte, 65507)
|
||||
for {
|
||||
n, _, err := targetConn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
logger.Info("Failed to write to target: %v", err)
|
||||
logger.Error("Error reading from target: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 65535)
|
||||
n, err := targetConn.Read(response)
|
||||
_, err = conn.WriteTo(buffer[:n], remoteAddr)
|
||||
if err != nil {
|
||||
logger.Info("Failed to read response from target: %v", err)
|
||||
logger.Error("Error writing to client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.WriteTo(response[:n], remote)
|
||||
if err != nil {
|
||||
logger.Info("Failed to write response to client: %v", err)
|
||||
}
|
||||
}
|
||||
}(buffer[:n], remoteAddr)
|
||||
}()
|
||||
}
|
||||
|
||||
_, err = targetConn.Write(buffer[:n])
|
||||
if err != nil {
|
||||
logger.Error("Error writing to target: %v", err)
|
||||
targetConn.Close()
|
||||
clientsMutex.Lock()
|
||||
delete(clientConns, clientKey)
|
||||
clientsMutex.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
)
|
||||
|
||||
type ProxyTarget struct {
|
||||
Protocol string
|
||||
Listen string
|
||||
Port int
|
||||
Target string
|
||||
cancel chan struct{} // Channel to signal shutdown
|
||||
done chan struct{} // Channel to signal completion
|
||||
listener net.Listener // For TCP
|
||||
udpConn net.PacketConn // For UDP
|
||||
sync.Mutex // Protect access to connection
|
||||
}
|
||||
|
||||
type ProxyManager struct {
|
||||
targets []ProxyTarget
|
||||
tnet *netstack.Net
|
||||
log *log.Logger
|
||||
sync.RWMutex // Protect access to targets slice
|
||||
}
|
||||
77
updown.py
Normal file
77
updown.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Sample updown script for Newt proxy
|
||||
Usage: update.py <action> <protocol> <target>
|
||||
|
||||
Parameters:
|
||||
- action: 'add' or 'remove'
|
||||
- protocol: 'tcp' or 'udp'
|
||||
- target: the target address in format 'host:port'
|
||||
|
||||
If the action is 'add', the script can return a modified target that
|
||||
will be used instead of the original.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
LOG_FILE = "/tmp/newt-updown.log"
|
||||
logging.basicConfig(
|
||||
filename=LOG_FILE,
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
def log_event(action, protocol, target):
|
||||
"""Log each event to a file for auditing purposes"""
|
||||
timestamp = datetime.now().isoformat()
|
||||
event = {
|
||||
"timestamp": timestamp,
|
||||
"action": action,
|
||||
"protocol": protocol,
|
||||
"target": target
|
||||
}
|
||||
logging.info(json.dumps(event))
|
||||
|
||||
def handle_add(protocol, target):
|
||||
"""Handle 'add' action"""
|
||||
logging.info(f"Adding {protocol} target: {target}")
|
||||
|
||||
def handle_remove(protocol, target):
|
||||
"""Handle 'remove' action"""
|
||||
logging.info(f"Removing {protocol} target: {target}")
|
||||
# For remove action, no return value is expected or used
|
||||
|
||||
def main():
|
||||
# Check arguments
|
||||
if len(sys.argv) != 4:
|
||||
logging.error(f"Invalid arguments: {sys.argv}")
|
||||
sys.exit(1)
|
||||
|
||||
action = sys.argv[1]
|
||||
protocol = sys.argv[2]
|
||||
target = sys.argv[3]
|
||||
|
||||
# Log the event
|
||||
log_event(action, protocol, target)
|
||||
|
||||
# Handle the action
|
||||
if action == "add":
|
||||
new_target = handle_add(protocol, target)
|
||||
# Print the new target to stdout (if empty, no change will be made)
|
||||
if new_target and new_target != target:
|
||||
print(new_target)
|
||||
elif action == "remove":
|
||||
handle_remove(protocol, target)
|
||||
else:
|
||||
logging.error(f"Unknown action: {action}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except Exception as e:
|
||||
logging.error(f"Unhandled exception: {e}")
|
||||
sys.exit(1)
|
||||
@@ -27,7 +27,8 @@ type Client struct {
|
||||
isConnected bool
|
||||
reconnectMux sync.RWMutex
|
||||
|
||||
onConnect func() error
|
||||
onConnect func() error
|
||||
onTokenUpdate func(token string)
|
||||
}
|
||||
|
||||
type ClientOption func(*Client)
|
||||
@@ -45,6 +46,10 @@ func (c *Client) OnConnect(callback func() error) {
|
||||
c.onConnect = callback
|
||||
}
|
||||
|
||||
func (c *Client) OnTokenUpdate(callback func(token string)) {
|
||||
c.onTokenUpdate = callback
|
||||
}
|
||||
|
||||
// NewClient creates a new Newt client
|
||||
func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
|
||||
config := &Config{
|
||||
@@ -228,6 +233,10 @@ func (c *Client) getToken() (string, error) {
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
// print out the token response for debugging
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(resp.Body)
|
||||
logger.Info("Token response: %s", buf.String())
|
||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
@@ -266,6 +275,8 @@ func (c *Client) establishConnection() error {
|
||||
return fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
|
||||
c.onTokenUpdate(token)
|
||||
|
||||
// Parse the base URL to determine protocol and hostname
|
||||
baseURL, err := url.Parse(c.baseURL)
|
||||
if err != nil {
|
||||
@@ -288,6 +299,7 @@ func (c *Client) establishConnection() error {
|
||||
// Add token to query parameters
|
||||
q := u.Query()
|
||||
q.Set("token", token)
|
||||
q.Set("clientType", "newt")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
// Connect to WebSocket
|
||||
@@ -305,6 +317,10 @@ func (c *Client) establishConnection() error {
|
||||
go c.readPump()
|
||||
|
||||
if c.onConnect != nil {
|
||||
err := c.saveConfig()
|
||||
if err != nil {
|
||||
logger.Error("Failed to save config: %v", err)
|
||||
}
|
||||
if err := c.onConnect(); err != nil {
|
||||
logger.Error("OnConnect callback failed: %v", err)
|
||||
}
|
||||
|
||||
801
wg/wg.go
Normal file
801
wg/wg.go
Normal file
@@ -0,0 +1,801 @@
|
||||
package wg
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/exp/rand"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Peers []Peer `json:"peers"`
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
AllowedIPs []string `json:"allowedIps"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
type PeerBandwidth struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
BytesIn float64 `json:"bytesIn"`
|
||||
BytesOut float64 `json:"bytesOut"`
|
||||
}
|
||||
|
||||
type PeerReading struct {
|
||||
BytesReceived int64
|
||||
BytesTransmitted int64
|
||||
LastChecked time.Time
|
||||
}
|
||||
|
||||
type WireGuardService struct {
|
||||
interfaceName string
|
||||
mtu int
|
||||
client *websocket.Client
|
||||
wgClient *wgctrl.Client
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
newtId string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
Port uint16
|
||||
stopHolepunch chan struct{}
|
||||
host string
|
||||
serverPubKey string
|
||||
token string
|
||||
}
|
||||
|
||||
// Add this type definition
|
||||
type fixedPortBind struct {
|
||||
port uint16
|
||||
conn.Bind
|
||||
}
|
||||
|
||||
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
// Ignore the requested port and use our fixed port
|
||||
return b.Bind.Open(b.port)
|
||||
}
|
||||
|
||||
func NewFixedPortBind(port uint16) conn.Bind {
|
||||
return &fixedPortBind{
|
||||
port: port,
|
||||
Bind: conn.NewDefaultBind(),
|
||||
}
|
||||
}
|
||||
|
||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
if maxPort < minPort {
|
||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
// Create a slice of all ports in the range
|
||||
portRange := make([]uint16, maxPort-minPort+1)
|
||||
for i := range portRange {
|
||||
portRange[i] = minPort + uint16(i)
|
||||
}
|
||||
|
||||
// Fisher-Yates shuffle to randomize the port order
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
for i := len(portRange) - 1; i > 0; i-- {
|
||||
j := rand.Intn(i + 1)
|
||||
portRange[i], portRange[j] = portRange[j], portRange[i]
|
||||
}
|
||||
|
||||
// Try each port in the randomized order
|
||||
for _, port := range portRange {
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: int(port),
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
continue // Port is in use or there was an error, try next port
|
||||
}
|
||||
_ = conn.SetDeadline(time.Now())
|
||||
conn.Close()
|
||||
return port, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) {
|
||||
wgClient, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
|
||||
}
|
||||
|
||||
var key wgtypes.Key
|
||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
||||
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
|
||||
// generate a new private key
|
||||
key, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to generate private key: %v", err)
|
||||
}
|
||||
// save the key to the file
|
||||
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to save private key: %v", err)
|
||||
}
|
||||
} else {
|
||||
keyData, err := os.ReadFile(generateAndSaveKeyTo)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to read private key: %v", err)
|
||||
}
|
||||
key, err = wgtypes.ParseKey(string(keyData))
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse private key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
port, err := FindAvailableUDPPort(49152, 65535)
|
||||
if err != nil {
|
||||
fmt.Printf("Error finding available port: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service := &WireGuardService{
|
||||
interfaceName: interfaceName,
|
||||
mtu: mtu,
|
||||
client: wsClient,
|
||||
wgClient: wgClient,
|
||||
key: key,
|
||||
newtId: newtId,
|
||||
lastReadings: make(map[string]PeerReading),
|
||||
Port: port,
|
||||
stopHolepunch: make(chan struct{}),
|
||||
host: host,
|
||||
}
|
||||
|
||||
// Register websocket handlers
|
||||
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
|
||||
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
||||
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) Close() {
|
||||
s.wgClient.Close()
|
||||
// Remove the WireGuard interface
|
||||
if err := s.removeInterface(); err != nil {
|
||||
logger.Error("Failed to remove WireGuard interface: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) SetServerPubKey(serverPubKey string) {
|
||||
s.serverPubKey = serverPubKey
|
||||
}
|
||||
|
||||
func (s *WireGuardService) SetToken(token string) {
|
||||
s.token = token
|
||||
}
|
||||
|
||||
func (s *WireGuardService) LoadRemoteConfig() error {
|
||||
|
||||
// get the exising wireguard port
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err == nil {
|
||||
s.Port = uint16(device.ListenPort)
|
||||
logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port)
|
||||
}
|
||||
|
||||
err = s.client.SendMessage("newt/wg/get-config", map[string]interface{}{
|
||||
"publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()),
|
||||
"port": s.Port,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("Requesting WireGuard configuration from remote server")
|
||||
|
||||
go s.periodicBandwidthCheck()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
var config WgConfig
|
||||
|
||||
logger.Info("Received message: %v", msg)
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &config); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
s.config = config
|
||||
|
||||
// Ensure the WireGuard interface and peers are configured
|
||||
if err := s.ensureWireguardInterface(config); err != nil {
|
||||
logger.Error("Failed to ensure WireGuard interface: %v", err)
|
||||
}
|
||||
|
||||
if err := s.ensureWireguardPeers(config.Peers); err != nil {
|
||||
logger.Error("Failed to ensure WireGuard peers: %v", err)
|
||||
}
|
||||
|
||||
if err := s.sendUDPHolePunch(s.host + ":21820"); err != nil {
|
||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
||||
}
|
||||
|
||||
// start the UDP holepunch
|
||||
go s.keepSendingUDPHolePunch(s.host)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
// Check if the WireGuard interface exists
|
||||
_, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
if _, ok := err.(netlink.LinkNotFoundError); ok {
|
||||
// Interface doesn't exist, so create it
|
||||
err = s.createWireGuardInterface()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard interface: %v", err)
|
||||
}
|
||||
logger.Info("Created WireGuard interface %s\n", s.interfaceName)
|
||||
} else {
|
||||
logger.Fatal("Error checking for WireGuard interface: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Info("WireGuard interface %s already exists\n", s.interfaceName)
|
||||
|
||||
// get the exising wireguard port
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
// get the existing port
|
||||
s.Port = uint16(device.ListenPort)
|
||||
logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Assigning IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName)
|
||||
// Assign IP address to the interface
|
||||
err = s.assignIPAddress(wgconfig.IpAddress)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to assign IP address: %v", err)
|
||||
}
|
||||
|
||||
// Check if the interface already exists
|
||||
_, err = s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interface %s does not exist", s.interfaceName)
|
||||
}
|
||||
|
||||
// Parse the private key
|
||||
key, err := wgtypes.ParseKey(s.key.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private key: %v", err)
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
PrivateKey: &key,
|
||||
ListenPort: new(int),
|
||||
}
|
||||
|
||||
// Use the service's fixed port instead of the config port
|
||||
*config.ListenPort = int(s.Port)
|
||||
|
||||
// Create and configure the WireGuard interface
|
||||
err = s.wgClient.ConfigureDevice(s.interfaceName, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
// bring up the interface
|
||||
link, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetMTU(link, s.mtu); err != nil {
|
||||
return fmt.Errorf("failed to set MTU: %v", err)
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetUp(link); err != nil {
|
||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||
}
|
||||
|
||||
// if err := s.ensureMSSClamping(); err != nil {
|
||||
// logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||
// }
|
||||
|
||||
logger.Info("WireGuard interface %s created and configured", s.interfaceName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) createWireGuardInterface() error {
|
||||
wgLink := &netlink.GenericLink{
|
||||
LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName},
|
||||
LinkType: "wireguard",
|
||||
}
|
||||
return netlink.LinkAdd(wgLink)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
||||
link, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
addr, err := netlink.ParseAddr(ipAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse IP address: %v", err)
|
||||
}
|
||||
|
||||
return netlink.AddrAdd(link, addr)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||
// get the current peers
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
// get the peer public keys
|
||||
var currentPeers []string
|
||||
for _, peer := range device.Peers {
|
||||
currentPeers = append(currentPeers, peer.PublicKey.String())
|
||||
}
|
||||
|
||||
// remove any peers that are not in the config
|
||||
for _, peer := range currentPeers {
|
||||
found := false
|
||||
for _, configPeer := range peers {
|
||||
if peer == configPeer.PublicKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := s.removePeer(peer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add any peers that are in the config but not in the current peers
|
||||
for _, configPeer := range peers {
|
||||
found := false
|
||||
for _, peer := range currentPeers {
|
||||
if configPeer.PublicKey == peer {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := s.addPeer(configPeer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
||||
var peer Peer
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &peer); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
}
|
||||
|
||||
err = s.addPeer(peer)
|
||||
if err != nil {
|
||||
logger.Info("Error adding peer: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) addPeer(peer Peer) error {
|
||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
// parse allowed IPs into array of net.IPNet
|
||||
var allowedIPs []net.IPNet
|
||||
for _, ipStr := range peer.AllowedIPs {
|
||||
_, ipNet, err := net.ParseCIDR(ipStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse allowed IP: %v", err)
|
||||
}
|
||||
allowedIPs = append(allowedIPs, *ipNet)
|
||||
}
|
||||
// add keep alive using *time.Duration of 1 second
|
||||
keepalive := time.Second
|
||||
|
||||
var peerConfig wgtypes.PeerConfig
|
||||
if peer.Endpoint != "" {
|
||||
endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve endpoint address: %w", err)
|
||||
}
|
||||
|
||||
// make the endpoint localhost to test
|
||||
|
||||
peerConfig = wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
AllowedIPs: allowedIPs,
|
||||
PersistentKeepaliveInterval: &keepalive,
|
||||
Endpoint: endpoint,
|
||||
}
|
||||
} else {
|
||||
peerConfig = wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
AllowedIPs: allowedIPs,
|
||||
PersistentKeepaliveInterval: &keepalive,
|
||||
}
|
||||
logger.Info("Added peer with no endpoint!")
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
|
||||
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("Peer %s added successfully", peer.PublicKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
|
||||
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
|
||||
type RemoveRequest struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
}
|
||||
|
||||
var request RemoveRequest
|
||||
if err := json.Unmarshal(jsonData, &request); err != nil {
|
||||
logger.Info("Error unmarshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.removePeer(request.PublicKey); err != nil {
|
||||
logger.Info("Error removing peer: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) removePeer(publicKey string) error {
|
||||
pubKey, err := wgtypes.ParseKey(publicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
|
||||
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("Peer %s removed successfully", publicKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) periodicBandwidthCheck() {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if err := s.reportPeerBandwidth(); err != nil {
|
||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
peerBandwidths := []PeerBandwidth{}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, peer := range device.Peers {
|
||||
publicKey := peer.PublicKey.String()
|
||||
currentReading := PeerReading{
|
||||
BytesReceived: peer.ReceiveBytes,
|
||||
BytesTransmitted: peer.TransmitBytes,
|
||||
LastChecked: now,
|
||||
}
|
||||
|
||||
var bytesInDiff, bytesOutDiff float64
|
||||
lastReading, exists := s.lastReadings[publicKey]
|
||||
|
||||
if exists {
|
||||
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
|
||||
if timeDiff > 0 {
|
||||
// Calculate bytes transferred since last reading
|
||||
bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived)
|
||||
bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted)
|
||||
|
||||
// Handle counter wraparound (if the counter resets or overflows)
|
||||
if bytesInDiff < 0 {
|
||||
bytesInDiff = float64(currentReading.BytesReceived)
|
||||
}
|
||||
if bytesOutDiff < 0 {
|
||||
bytesOutDiff = float64(currentReading.BytesTransmitted)
|
||||
}
|
||||
|
||||
// Convert to MB
|
||||
bytesInMB := bytesInDiff / (1024 * 1024)
|
||||
bytesOutMB := bytesOutDiff / (1024 * 1024)
|
||||
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: bytesInMB,
|
||||
BytesOut: bytesOutMB,
|
||||
})
|
||||
} else {
|
||||
// If readings are too close together or time hasn't passed, report 0
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: 0,
|
||||
BytesOut: 0,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// For first reading of a peer, report 0 to establish baseline
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: 0,
|
||||
BytesOut: 0,
|
||||
})
|
||||
}
|
||||
|
||||
// Update the last reading
|
||||
s.lastReadings[publicKey] = currentReading
|
||||
}
|
||||
|
||||
// Clean up old peers
|
||||
for publicKey := range s.lastReadings {
|
||||
found := false
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey.String() == publicKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
delete(s.lastReadings, publicKey)
|
||||
}
|
||||
}
|
||||
|
||||
return peerBandwidths, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) reportPeerBandwidth() error {
|
||||
bandwidths, err := s.calculatePeerBandwidth()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
||||
}
|
||||
|
||||
err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{
|
||||
"bandwidthData": bandwidths,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
|
||||
|
||||
if s.serverPubKey == "" || s.token == "" {
|
||||
return fmt.Errorf("server public key or token is not set")
|
||||
}
|
||||
|
||||
// Parse server address
|
||||
serverSplit := strings.Split(serverAddr, ":")
|
||||
if len(serverSplit) < 2 {
|
||||
return fmt.Errorf("invalid server address format, expected hostname:port")
|
||||
}
|
||||
|
||||
serverHostname := serverSplit[0]
|
||||
serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse server port: %v", err)
|
||||
}
|
||||
|
||||
// Resolve server hostname to IP
|
||||
serverIPAddr := network.HostToAddr(serverHostname)
|
||||
if serverIPAddr == nil {
|
||||
return fmt.Errorf("failed to resolve server hostname")
|
||||
}
|
||||
|
||||
// Get client IP based on route to server
|
||||
clientIP := network.GetClientIP(serverIPAddr.IP)
|
||||
|
||||
// Create server and client configs
|
||||
server := &network.Server{
|
||||
Hostname: serverHostname,
|
||||
Addr: serverIPAddr,
|
||||
Port: uint16(serverPort),
|
||||
}
|
||||
|
||||
client := &network.PeerNet{
|
||||
IP: clientIP,
|
||||
Port: s.Port,
|
||||
NewtID: s.newtId,
|
||||
}
|
||||
|
||||
// Setup raw connection with BPF filtering
|
||||
rawConn := network.SetupRawConn(server, client)
|
||||
defer rawConn.Close()
|
||||
|
||||
// Create JSON payload
|
||||
payload := struct {
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
NewtID: s.newtId,
|
||||
Token: s.token,
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := s.encryptPayload(payloadBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %v", err)
|
||||
}
|
||||
|
||||
// Send the encrypted packet using the raw connection
|
||||
err = network.SendDataPacket(encryptedPayload, rawConn, server, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send UDP packet: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) {
|
||||
// Generate an ephemeral keypair for this message
|
||||
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||
}
|
||||
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||
|
||||
// Parse the server's public key
|
||||
serverPubKey, err := wgtypes.ParseKey(s.serverPubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||
}
|
||||
|
||||
// Use X25519 for key exchange (replacing deprecated ScalarMult)
|
||||
var ephPrivKeyFixed [32]byte
|
||||
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||
|
||||
// Perform X25519 key exchange
|
||||
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||
}
|
||||
|
||||
// Create an AEAD cipher using the shared secret
|
||||
aead, err := chacha20poly1305.New(sharedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload
|
||||
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||
|
||||
// Prepare the final encrypted message
|
||||
encryptedMsg := struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}{
|
||||
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}
|
||||
|
||||
return encryptedMsg, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopHolepunch:
|
||||
logger.Info("Stopping UDP holepunch")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
|
||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) removeInterface() error {
|
||||
// Remove the WireGuard interface
|
||||
link, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
err = netlink.LinkDel(link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete interface: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("WireGuard interface %s removed successfully", s.interfaceName)
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user