Compare commits

...

27 Commits

Author SHA1 Message Date
Owen
c5978d9c4e Handle port correctly and delete interface 2025-03-27 22:12:54 -04:00
Owen
f08378b67e Fix segfault if no wgService created 2025-03-25 20:54:09 -04:00
Owen
1501de691a Handle encrypted messages 2025-03-15 21:47:22 -04:00
Owen
2a19856556 Merge branch 'main' into holepunch 2025-03-15 11:33:34 -04:00
Owen Schwartz
f4e17a4dd7 Merge pull request #22 from fosrl/dev
Fix 51820 typo
2025-03-14 18:52:18 -04:00
Owen
ab544fc9ed Fix 51820 typo 2025-03-14 18:51:33 -04:00
Owen
f9e52c4d91 Working on encryption 2025-03-14 18:49:50 -04:00
Owen
14eff8e16c Merge branch 'holepunch' of github.com:fosrl/newt_dg into holepunch 2025-03-12 20:41:38 -04:00
Owen
067e079293 Handle / better 2025-03-12 20:37:57 -04:00
Owen
5e673c829b Clean up when wg is used 2025-02-24 10:05:35 -05:00
Owen
cd3ec0b259 Support relay switch 2025-02-23 20:18:25 -05:00
Owen
b68502de9e Basic relay working! 2025-02-23 16:49:24 -05:00
Owen
f6429b6eee Basic holepunch working 2025-02-23 01:00:16 -05:00
Owen
8795c57b2e HP works! 2025-02-22 12:53:23 -05:00
Owen
4aa718d55f Initial hp working but need to fix port issue 2025-02-22 11:21:13 -05:00
Owen
afa93d8a3f Add static port and udp hole punch 2025-02-21 22:27:24 -05:00
Owen
270ee9cd19 Fix panic 2025-02-21 20:33:31 -05:00
Owen
0affef401c Properly handle key 2025-02-21 18:04:36 -05:00
Owen
18d99de924 Handle messages correctly 2025-02-21 17:13:00 -05:00
Owen
bff6707577 Basic create wg seems to be working 2025-02-21 16:20:03 -05:00
Owen
95eab504fa Get wg working 2025-02-21 16:12:12 -05:00
Owen
56e75902e3 Adjust ws types 2025-02-21 12:44:52 -05:00
Owen
45a1ab91d7 Dont always do wg 2025-02-20 22:10:02 -05:00
Owen
fb199cc94b Tidy 2025-02-20 22:07:27 -05:00
Owen
66edae4288 Clean up implementation 2025-02-20 21:01:44 -05:00
Owen
f69a7f647d Move wg into more of a class 2025-02-20 20:37:31 -05:00
Owen
e8bd55bed9 Copy in gerbil wg config 2025-02-20 20:04:01 -05:00
7 changed files with 1157 additions and 27 deletions

3
.gitignore vendored
View File

@@ -1,3 +1,4 @@
newt
.DS_Store
bin/
bin/
nohup.out

17
go.mod
View File

@@ -5,18 +5,27 @@ go 1.23.1
toolchain go1.23.2
require (
github.com/google/gopacket v1.1.19
github.com/gorilla/websocket v1.5.3
golang.org/x/net v0.30.0
github.com/vishvananda/netlink v1.3.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/net v0.33.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
)
require (
github.com/google/btree v1.1.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
golang.org/x/crypto v0.28.0 // indirect
golang.org/x/sys v0.26.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/time v0.7.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
)

50
go.sum
View File

@@ -2,21 +2,55 @@ github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=

101
main.go
View File

@@ -13,6 +13,7 @@ import (
"os"
"os/exec"
"os/signal"
"runtime"
"strconv"
"strings"
"syscall"
@@ -21,6 +22,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/newt/wg"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
@@ -55,7 +57,7 @@ func fixKey(key string) string {
// Decode from base64
decoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
logger.Fatal("Error decoding base64:", err)
logger.Fatal("Error decoding base64")
}
// Convert to hex
@@ -213,6 +215,9 @@ func resolveDomain(domain string) (string, error) {
host = strings.TrimPrefix(host, "https://")
}
// if there are any trailing slashes, remove them
host = strings.TrimSuffix(host, "/")
// Lookup IP addresses
ips, err := net.LookupIP(host)
if err != nil {
@@ -246,16 +251,18 @@ func resolveDomain(domain string) (string, error) {
}
var (
endpoint string
id string
secret string
mtu string
mtuInt int
dns string
privateKey wgtypes.Key
err error
logLevel string
updownScript string
endpoint string
id string
secret string
mtu string
mtuInt int
dns string
privateKey wgtypes.Key
err error
logLevel string
updownScript string
interfaceName string
generateAndSaveKeyTo string
)
func main() {
@@ -267,6 +274,8 @@ func main() {
dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL")
updownScript = os.Getenv("UPDOWN_SCRIPT")
interfaceName = os.Getenv("INTERFACE")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -289,6 +298,12 @@ func main() {
if updownScript == "" {
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
}
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg1", "Name of the WireGuard interface")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
// do a --version check
version := flag.Bool("version", false, "Print the version")
@@ -325,6 +340,7 @@ func main() {
logger.Fatal("Failed to create client: %v", err)
}
var wgService *wg.WireGuardService
// Create TUN device and network stack
var tun tun.Device
var tnet *netstack.Net
@@ -333,6 +349,30 @@ func main() {
var connected bool
var wgData WgData
if generateAndSaveKeyTo != "" {
// make sure we are running on linux
if runtime.GOOS != "linux" {
logger.Fatal("Tunnel management is only supported on Linux right now!")
os.Exit(1)
}
var host = endpoint
if strings.HasPrefix(host, "http://") {
host = strings.TrimPrefix(host, "http://")
} else if strings.HasPrefix(host, "https://") {
host = strings.TrimPrefix(host, "https://")
}
host = strings.TrimSuffix(host, "/")
// Create WireGuard service
wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
defer wgService.Close()
}
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
if pm != nil {
@@ -358,7 +398,7 @@ func main() {
if err != nil {
// Handle complete failure after all retries
logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err)
logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?")
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
}
return
}
@@ -374,6 +414,10 @@ func main() {
return
}
if wgService != nil {
wgService.SetServerPubKey(wgData.PublicKey)
}
logger.Info("Received: %+v", msg)
tun, tnet, err = netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
@@ -420,6 +464,7 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
if err != nil {
// Handle complete failure after all retries
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
fmt.Sprintf("%s", privateKey)
}
if !connected {
@@ -441,6 +486,13 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
}
// first make sure the wpgService has a port
if wgService != nil {
// add a udp proxy for localost and the wgService port
// TODO: make sure this port is not used in a target
pm.AddTarget("udp", wgData.TunnelIP, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port))
}
err = pm.Start()
if err != nil {
logger.Error("Failed to start proxy manager: %v", err)
@@ -539,10 +591,20 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
return err
}
if wgService != nil {
wgService.LoadRemoteConfig()
}
logger.Info("Sent registration message")
return nil
})
client.OnTokenUpdate(func(token string) {
if wgService != nil {
wgService.SetToken(token)
}
})
// Connect to the WebSocket server
if err := client.Connect(); err != nil {
logger.Fatal("Failed to connect to server: %v", err)
@@ -554,8 +616,21 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
// Cleanup
dev.Close()
if wgService != nil {
wgService.Close()
}
if pm != nil {
pm.Stop()
}
if client != nil {
client.Close()
}
logger.Info("Exiting...")
os.Exit(0)
}
func parseTargetData(data interface{}) (TargetData, error) {

202
network/network.go Normal file
View File

@@ -0,0 +1,202 @@
package network
import (
"encoding/binary"
"encoding/json"
"fmt"
"log"
"net"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/vishvananda/netlink"
"golang.org/x/net/bpf"
"golang.org/x/net/ipv4"
)
const (
udpProtocol = 17
// EmptyUDPSize is the size of an empty UDP packet
EmptyUDPSize = 28
timeout = time.Second * 10
)
// Server stores data relating to the server
type Server struct {
Hostname string
Addr *net.IPAddr
Port uint16
}
// PeerNet stores data about a peer's endpoint
type PeerNet struct {
Resolved bool
IP net.IP
Port uint16
NewtID string
}
// GetClientIP gets source ip address that will be used when sending data to dstIP
func GetClientIP(dstIP net.IP) net.IP {
routes, err := netlink.RouteGet(dstIP)
if err != nil {
log.Fatalln("Error getting route:", err)
}
return routes[0].Src
}
// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr
func HostToAddr(hostStr string) *net.IPAddr {
remoteAddrs, err := net.LookupHost(hostStr)
if err != nil {
log.Fatalln("Error parsing remote address:", err)
}
for _, addrStr := range remoteAddrs {
if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil {
return remoteAddr
}
}
return nil
}
// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering
func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn {
packetConn, err := net.ListenPacket("ip4:udp", client.IP.String())
if err != nil {
log.Fatalln("Error creating packetConn:", err)
}
rawConn, err := ipv4.NewRawConn(packetConn)
if err != nil {
log.Fatalln("Error creating rawConn:", err)
}
ApplyBPF(rawConn, server, client)
return rawConn
}
// ApplyBPF constructs a BPF program and applies it to the RawConn
func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) {
const ipv4HeaderLen = 20
const srcIPOffset = 12
const srcPortOffset = ipv4HeaderLen + 0
const dstPortOffset = ipv4HeaderLen + 2
ipArr := []byte(server.Addr.IP.To4())
ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3])
bpfRaw, err := bpf.Assemble([]bpf.Instruction{
bpf.LoadAbsolute{Off: srcIPOffset, Size: 4},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0},
bpf.LoadAbsolute{Off: srcPortOffset, Size: 2},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0},
bpf.LoadAbsolute{Off: dstPortOffset, Size: 2},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0},
bpf.RetConstant{Val: 1<<(8*4) - 1},
bpf.RetConstant{Val: 0},
})
if err != nil {
log.Fatalln("Error assembling BPF:", err)
}
err = rawConn.SetBPF(bpfRaw)
if err != nil {
log.Fatalln("Error setting BPF:", err)
}
}
// MakePacket constructs a request packet to send to the server
func MakePacket(payload []byte, server *Server, client *PeerNet) []byte {
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}
ipHeader := layers.IPv4{
SrcIP: client.IP,
DstIP: server.Addr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpHeader := layers.UDP{
SrcPort: layers.UDPPort(client.Port),
DstPort: layers.UDPPort(server.Port),
}
payloadLayer := gopacket.Payload(payload)
udpHeader.SetNetworkLayerForChecksum(&ipHeader)
gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer)
return buf.Bytes()
}
// SendPacket sends packet to the Server
func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
fullPacket := MakePacket(packet, server, client)
_, err := conn.WriteToIP(fullPacket, server.Addr)
return err
}
// SendDataPacket sends a JSON payload to the Server
func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
jsonData, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}
return SendPacket(jsonData, conn, server, client)
}
// RecvPacket receives a UDP packet from server
func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) {
err := conn.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
return nil, 0, err
}
response := make([]byte, 4096)
n, err := conn.Read(response)
if err != nil {
return nil, n, err
}
return response, n, nil
}
// RecvDataPacket receives and unmarshals a JSON packet from server
func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) {
response, n, err := RecvPacket(conn, server, client)
if err != nil {
return nil, err
}
// Extract payload from UDP packet
payload := response[EmptyUDPSize:n]
return payload, nil
}
// ParseResponse takes a response packet and parses it into an IP and port
func ParseResponse(response []byte) (net.IP, uint16) {
ip := net.IP(response[:4])
port := binary.BigEndian.Uint16(response[4:6])
return ip, port
}
func parseForBPF(response []byte) (srcIP net.IP, srcPort uint16, dstPort uint16) {
srcIP = net.IP(response[12:16])
srcPort = binary.BigEndian.Uint16(response[20:22])
dstPort = binary.BigEndian.Uint16(response[22:24])
return
}

View File

@@ -27,7 +27,8 @@ type Client struct {
isConnected bool
reconnectMux sync.RWMutex
onConnect func() error
onConnect func() error
onTokenUpdate func(token string)
}
type ClientOption func(*Client)
@@ -45,6 +46,10 @@ func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
func (c *Client) OnTokenUpdate(callback func(token string)) {
c.onTokenUpdate = callback
}
// NewClient creates a new Newt client
func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
config := &Config{
@@ -270,6 +275,8 @@ func (c *Client) establishConnection() error {
return fmt.Errorf("failed to get token: %w", err)
}
c.onTokenUpdate(token)
// Parse the base URL to determine protocol and hostname
baseURL, err := url.Parse(c.baseURL)
if err != nil {
@@ -292,6 +299,7 @@ func (c *Client) establishConnection() error {
// Add token to query parameters
q := u.Query()
q.Set("token", token)
q.Set("clientType", "newt")
u.RawQuery = q.Encode()
// Connect to WebSocket

801
wg/wg.go Normal file
View File

@@ -0,0 +1,801 @@
package wg
import (
"encoding/json"
"fmt"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/newt/websocket"
"github.com/vishvananda/netlink"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type WgConfig struct {
IpAddress string `json:"ipAddress"`
Peers []Peer `json:"peers"`
}
type Peer struct {
PublicKey string `json:"publicKey"`
AllowedIPs []string `json:"allowedIps"`
Endpoint string `json:"endpoint"`
}
type PeerBandwidth struct {
PublicKey string `json:"publicKey"`
BytesIn float64 `json:"bytesIn"`
BytesOut float64 `json:"bytesOut"`
}
type PeerReading struct {
BytesReceived int64
BytesTransmitted int64
LastChecked time.Time
}
type WireGuardService struct {
interfaceName string
mtu int
client *websocket.Client
wgClient *wgctrl.Client
config WgConfig
key wgtypes.Key
newtId string
lastReadings map[string]PeerReading
mu sync.Mutex
Port uint16
stopHolepunch chan struct{}
host string
serverPubKey string
token string
}
// Add this type definition
type fixedPortBind struct {
port uint16
conn.Bind
}
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
// Ignore the requested port and use our fixed port
return b.Bind.Open(b.port)
}
func NewFixedPortBind(port uint16) conn.Bind {
return &fixedPortBind{
port: port,
Bind: conn.NewDefaultBind(),
}
}
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
if maxPort < minPort {
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
}
// Create a slice of all ports in the range
portRange := make([]uint16, maxPort-minPort+1)
for i := range portRange {
portRange[i] = minPort + uint16(i)
}
// Fisher-Yates shuffle to randomize the port order
rand.Seed(uint64(time.Now().UnixNano()))
for i := len(portRange) - 1; i > 0; i-- {
j := rand.Intn(i + 1)
portRange[i], portRange[j] = portRange[j], portRange[i]
}
// Try each port in the randomized order
for _, port := range portRange {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: int(port),
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
continue // Port is in use or there was an error, try next port
}
_ = conn.SetDeadline(time.Now())
conn.Close()
return port, nil
}
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
}
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) {
wgClient, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
}
var key wgtypes.Key
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
// generate a new private key
key, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
// save the key to the file
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
if err != nil {
logger.Fatal("Failed to save private key: %v", err)
}
} else {
keyData, err := os.ReadFile(generateAndSaveKeyTo)
if err != nil {
logger.Fatal("Failed to read private key: %v", err)
}
key, err = wgtypes.ParseKey(string(keyData))
if err != nil {
logger.Fatal("Failed to parse private key: %v", err)
}
}
port, err := FindAvailableUDPPort(49152, 65535)
if err != nil {
fmt.Printf("Error finding available port: %v\n", err)
return nil, err
}
service := &WireGuardService{
interfaceName: interfaceName,
mtu: mtu,
client: wsClient,
wgClient: wgClient,
key: key,
newtId: newtId,
lastReadings: make(map[string]PeerReading),
Port: port,
stopHolepunch: make(chan struct{}),
host: host,
}
// Register websocket handlers
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
return service, nil
}
func (s *WireGuardService) Close() {
s.wgClient.Close()
// Remove the WireGuard interface
if err := s.removeInterface(); err != nil {
logger.Error("Failed to remove WireGuard interface: %v", err)
}
}
func (s *WireGuardService) SetServerPubKey(serverPubKey string) {
s.serverPubKey = serverPubKey
}
func (s *WireGuardService) SetToken(token string) {
s.token = token
}
func (s *WireGuardService) LoadRemoteConfig() error {
// get the exising wireguard port
device, err := s.wgClient.Device(s.interfaceName)
if err == nil {
s.Port = uint16(device.ListenPort)
logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port)
}
err = s.client.SendMessage("newt/wg/get-config", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()),
"port": s.Port,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Requesting WireGuard configuration from remote server")
go s.periodicBandwidthCheck()
return nil
}
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
var config WgConfig
logger.Info("Received message: %v", msg)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &config); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
s.config = config
// Ensure the WireGuard interface and peers are configured
if err := s.ensureWireguardInterface(config); err != nil {
logger.Error("Failed to ensure WireGuard interface: %v", err)
}
if err := s.ensureWireguardPeers(config.Peers); err != nil {
logger.Error("Failed to ensure WireGuard peers: %v", err)
}
if err := s.sendUDPHolePunch(s.host + ":21820"); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
// start the UDP holepunch
go s.keepSendingUDPHolePunch(s.host)
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Check if the WireGuard interface exists
_, err := netlink.LinkByName(s.interfaceName)
if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); ok {
// Interface doesn't exist, so create it
err = s.createWireGuardInterface()
if err != nil {
logger.Fatal("Failed to create WireGuard interface: %v", err)
}
logger.Info("Created WireGuard interface %s\n", s.interfaceName)
} else {
logger.Fatal("Error checking for WireGuard interface: %v", err)
}
} else {
logger.Info("WireGuard interface %s already exists\n", s.interfaceName)
// get the exising wireguard port
device, err := s.wgClient.Device(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get device: %v", err)
}
// get the existing port
s.Port = uint16(device.ListenPort)
logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port)
return nil
}
logger.Info("Assigning IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName)
// Assign IP address to the interface
err = s.assignIPAddress(wgconfig.IpAddress)
if err != nil {
logger.Fatal("Failed to assign IP address: %v", err)
}
// Check if the interface already exists
_, err = s.wgClient.Device(s.interfaceName)
if err != nil {
return fmt.Errorf("interface %s does not exist", s.interfaceName)
}
// Parse the private key
key, err := wgtypes.ParseKey(s.key.String())
if err != nil {
return fmt.Errorf("failed to parse private key: %v", err)
}
config := wgtypes.Config{
PrivateKey: &key,
ListenPort: new(int),
}
// Use the service's fixed port instead of the config port
*config.ListenPort = int(s.Port)
// Create and configure the WireGuard interface
err = s.wgClient.ConfigureDevice(s.interfaceName, config)
if err != nil {
return fmt.Errorf("failed to configure WireGuard device: %v", err)
}
// bring up the interface
link, err := netlink.LinkByName(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
if err := netlink.LinkSetMTU(link, s.mtu); err != nil {
return fmt.Errorf("failed to set MTU: %v", err)
}
if err := netlink.LinkSetUp(link); err != nil {
return fmt.Errorf("failed to bring up interface: %v", err)
}
// if err := s.ensureMSSClamping(); err != nil {
// logger.Warn("Failed to ensure MSS clamping: %v", err)
// }
logger.Info("WireGuard interface %s created and configured", s.interfaceName)
return nil
}
func (s *WireGuardService) createWireGuardInterface() error {
wgLink := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName},
LinkType: "wireguard",
}
return netlink.LinkAdd(wgLink)
}
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
link, err := netlink.LinkByName(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
addr, err := netlink.ParseAddr(ipAddress)
if err != nil {
return fmt.Errorf("failed to parse IP address: %v", err)
}
return netlink.AddrAdd(link, addr)
}
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
// get the current peers
device, err := s.wgClient.Device(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get device: %v", err)
}
// get the peer public keys
var currentPeers []string
for _, peer := range device.Peers {
currentPeers = append(currentPeers, peer.PublicKey.String())
}
// remove any peers that are not in the config
for _, peer := range currentPeers {
found := false
for _, configPeer := range peers {
if peer == configPeer.PublicKey {
found = true
break
}
}
if !found {
err := s.removePeer(peer)
if err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
}
}
// add any peers that are in the config but not in the current peers
for _, configPeer := range peers {
found := false
for _, peer := range currentPeers {
if configPeer.PublicKey == peer {
found = true
break
}
}
if !found {
err := s.addPeer(configPeer)
if err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
}
}
return nil
}
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
var peer Peer
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
}
if err := json.Unmarshal(jsonData, &peer); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
}
err = s.addPeer(peer)
if err != nil {
logger.Info("Error adding peer: %v", err)
return
}
}
func (s *WireGuardService) addPeer(peer Peer) error {
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
}
// parse allowed IPs into array of net.IPNet
var allowedIPs []net.IPNet
for _, ipStr := range peer.AllowedIPs {
_, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return fmt.Errorf("failed to parse allowed IP: %v", err)
}
allowedIPs = append(allowedIPs, *ipNet)
}
// add keep alive using *time.Duration of 1 second
keepalive := time.Second
var peerConfig wgtypes.PeerConfig
if peer.Endpoint != "" {
endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint)
if err != nil {
return fmt.Errorf("failed to resolve endpoint address: %w", err)
}
// make the endpoint localhost to test
peerConfig = wgtypes.PeerConfig{
PublicKey: pubKey,
AllowedIPs: allowedIPs,
PersistentKeepaliveInterval: &keepalive,
Endpoint: endpoint,
}
} else {
peerConfig = wgtypes.PeerConfig{
PublicKey: pubKey,
AllowedIPs: allowedIPs,
PersistentKeepaliveInterval: &keepalive,
}
logger.Info("Added peer with no endpoint!")
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peerConfig},
}
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
logger.Info("Peer %s added successfully", peer.PublicKey)
return nil
}
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
type RemoveRequest struct {
PublicKey string `json:"publicKey"`
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
}
var request RemoveRequest
if err := json.Unmarshal(jsonData, &request); err != nil {
logger.Info("Error unmarshaling data: %v", err)
return
}
if err := s.removePeer(request.PublicKey); err != nil {
logger.Info("Error removing peer: %v", err)
return
}
}
func (s *WireGuardService) removePeer(publicKey string) error {
pubKey, err := wgtypes.ParseKey(publicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peerConfig},
}
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
logger.Info("Peer %s removed successfully", publicKey)
return nil
}
func (s *WireGuardService) periodicBandwidthCheck() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for range ticker.C {
if err := s.reportPeerBandwidth(); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
}
}
}
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
device, err := s.wgClient.Device(s.interfaceName)
if err != nil {
return nil, fmt.Errorf("failed to get device: %v", err)
}
peerBandwidths := []PeerBandwidth{}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
for _, peer := range device.Peers {
publicKey := peer.PublicKey.String()
currentReading := PeerReading{
BytesReceived: peer.ReceiveBytes,
BytesTransmitted: peer.TransmitBytes,
LastChecked: now,
}
var bytesInDiff, bytesOutDiff float64
lastReading, exists := s.lastReadings[publicKey]
if exists {
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
if timeDiff > 0 {
// Calculate bytes transferred since last reading
bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived)
bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted)
// Handle counter wraparound (if the counter resets or overflows)
if bytesInDiff < 0 {
bytesInDiff = float64(currentReading.BytesReceived)
}
if bytesOutDiff < 0 {
bytesOutDiff = float64(currentReading.BytesTransmitted)
}
// Convert to MB
bytesInMB := bytesInDiff / (1024 * 1024)
bytesOutMB := bytesOutDiff / (1024 * 1024)
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: bytesInMB,
BytesOut: bytesOutMB,
})
} else {
// If readings are too close together or time hasn't passed, report 0
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: 0,
BytesOut: 0,
})
}
} else {
// For first reading of a peer, report 0 to establish baseline
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: 0,
BytesOut: 0,
})
}
// Update the last reading
s.lastReadings[publicKey] = currentReading
}
// Clean up old peers
for publicKey := range s.lastReadings {
found := false
for _, peer := range device.Peers {
if peer.PublicKey.String() == publicKey {
found = true
break
}
}
if !found {
delete(s.lastReadings, publicKey)
}
}
return peerBandwidths, nil
}
func (s *WireGuardService) reportPeerBandwidth() error {
bandwidths, err := s.calculatePeerBandwidth()
if err != nil {
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
}
err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{
"bandwidthData": bandwidths,
})
if err != nil {
return fmt.Errorf("failed to send bandwidth data: %v", err)
}
return nil
}
func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
if s.serverPubKey == "" || s.token == "" {
return fmt.Errorf("server public key or token is not set")
}
// Parse server address
serverSplit := strings.Split(serverAddr, ":")
if len(serverSplit) < 2 {
return fmt.Errorf("invalid server address format, expected hostname:port")
}
serverHostname := serverSplit[0]
serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16)
if err != nil {
return fmt.Errorf("failed to parse server port: %v", err)
}
// Resolve server hostname to IP
serverIPAddr := network.HostToAddr(serverHostname)
if serverIPAddr == nil {
return fmt.Errorf("failed to resolve server hostname")
}
// Get client IP based on route to server
clientIP := network.GetClientIP(serverIPAddr.IP)
// Create server and client configs
server := &network.Server{
Hostname: serverHostname,
Addr: serverIPAddr,
Port: uint16(serverPort),
}
client := &network.PeerNet{
IP: clientIP,
Port: s.Port,
NewtID: s.newtId,
}
// Setup raw connection with BPF filtering
rawConn := network.SetupRawConn(server, client)
defer rawConn.Close()
// Create JSON payload
payload := struct {
NewtID string `json:"newtId"`
Token string `json:"token"`
}{
NewtID: s.newtId,
Token: s.token,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := s.encryptPayload(payloadBytes)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %v", err)
}
// Send the encrypted packet using the raw connection
err = network.SendDataPacket(encryptedPayload, rawConn, server, client)
if err != nil {
return fmt.Errorf("failed to send UDP packet: %v", err)
}
return nil
}
func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(s.serverPubKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange (replacing deprecated ScalarMult)
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}
func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
ticker := time.NewTicker(3 * time.Second)
defer ticker.Stop()
for {
select {
case <-s.stopHolepunch:
logger.Info("Stopping UDP holepunch")
return
case <-ticker.C:
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
}
}
}
func (s *WireGuardService) removeInterface() error {
// Remove the WireGuard interface
link, err := netlink.LinkByName(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
err = netlink.LinkDel(link)
if err != nil {
return fmt.Errorf("failed to delete interface: %v", err)
}
logger.Info("WireGuard interface %s removed successfully", s.interfaceName)
return nil
}