From cfa3bb3b64c73b0525ed71d75c65125703cac069 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 14 Dec 2021 11:03:36 +0000 Subject: [PATCH] feat(internal/wireguard): opportunistic kernelspace - Auto detect if kernelspace implementation is available - Fallback to Go userspace implementation if kernel is not available --- README.md | 2 +- go.mod | 4 +- go.sum | 11 +- internal/netlink/family.go | 23 ++- internal/netlink/family_test.go | 21 +++ internal/netlink/interface.go | 1 + internal/netlink/link.go | 5 +- internal/netlink/mock_netlink/interface.go | 15 ++ internal/routing/rules_test.go | 4 +- internal/wireguard/netlinker.go | 2 + internal/wireguard/netlinker_mock_test.go | 29 ++++ internal/wireguard/route_test.go | 2 +- internal/wireguard/rule_test.go | 4 +- internal/wireguard/run.go | 185 ++++++++++++++------- 14 files changed, 229 insertions(+), 79 deletions(-) create mode 100644 internal/netlink/family_test.go diff --git a/README.md b/README.md index aea164d7..494c6154 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ using Go, OpenVPN or Wireguard, iptables, DNS over TLS, ShadowSocks and an HTTP - Based on Alpine 3.14 for a small Docker image of 33MB - Supports: **Cyberghost**, **ExpressVPN**, **FastestVPN**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **Surfshark**, **TorGuard**, **VPNUnlimited**, **Vyprvpn**, **WeVPN**, **Windscribe** servers - Supports OpenVPN for all providers listed -- Supports Wireguard +- Supports Wireguard both kernelspace and userspace - For **Mullvad**, **Ivpn** and **Windscribe** - For **Torguard**, **VPN Unlimited** and **WeVPN** using [the custom provider](https://github.com/qdm12/gluetun/wiki/Custom-provider) - For custom Wireguard configurations using [the custom provider](https://github.com/qdm12/gluetun/wiki/Custom-provider) diff --git a/go.mod b/go.mod index c6ebdf91..82d1c927 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/qdm12/ss-server v0.3.0 github.com/qdm12/updated v0.0.0-20210603204757-205acfe6937e github.com/stretchr/testify v1.7.0 - github.com/vishvananda/netlink v1.1.0 + github.com/vishvananda/netlink v1.1.1-0.20211129163951-9ada19101fc5 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210803171230-4253848d036c @@ -34,7 +34,7 @@ require ( github.com/mr-tron/base58 v1.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect - github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect + github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect go4.org/intern v0.0.0-20210108033219-3eb7198706b2 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 // indirect golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect diff --git a/go.sum b/go.sum index 0030b9f9..d1083403 100644 --- a/go.sum +++ b/go.sum @@ -130,10 +130,10 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= -github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= -github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k= -github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= +github.com/vishvananda/netlink v1.1.1-0.20211129163951-9ada19101fc5 h1:b/k/BVWzWRS5v6AB0gf2ckFSbFsHN5jR0HoNso1pN+w= +github.com/vishvananda/netlink v1.1.1-0.20211129163951-9ada19101fc5/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/xanzy/ssh-agent v0.2.1/go.mod h1:mLlQY/MoOhWBj+gOGMQkOeiEvkx+8pJSI+0Bx9h2kr4= github.com/yl2chen/cidranger v1.0.2/go.mod h1:9U1yz7WPYDwf0vpNWFaeRh0bjwz5RVgRy/9UEQfHl0g= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -185,7 +185,6 @@ golang.org/x/sys v0.0.0-20190221075227-b4e8571b14e0/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -193,7 +192,9 @@ golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201118182958-a01c418693c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/netlink/family.go b/internal/netlink/family.go index 605e1b8e..6745b67c 100644 --- a/internal/netlink/family.go +++ b/internal/netlink/family.go @@ -1,6 +1,10 @@ package netlink -import "github.com/vishvananda/netlink" +import ( + "fmt" + + "github.com/vishvananda/netlink" +) //nolint:revive const ( @@ -8,3 +12,20 @@ const ( FAMILY_V4 = netlink.FAMILY_V4 FAMILY_V6 = netlink.FAMILY_V6 ) + +type WireguardChecker interface { + IsWireguardSupported() (ok bool, err error) +} + +func (n *NetLink) IsWireguardSupported() (ok bool, err error) { + families, err := netlink.GenlFamilyList() + if err != nil { + return false, fmt.Errorf("cannot list gen 1 families: %w", err) + } + for _, family := range families { + if family.Name == "wireguard" { + return true, nil + } + } + return false, nil +} diff --git a/internal/netlink/family_test.go b/internal/netlink/family_test.go new file mode 100644 index 00000000..5a9fb34e --- /dev/null +++ b/internal/netlink/family_test.go @@ -0,0 +1,21 @@ +package netlink + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_NetLink_IsWireguardSupported(t *testing.T) { + t.Skip() // TODO unskip once the data race problem with netlink.GenlFamilyList() is fixed + + t.Parallel() + netLink := &NetLink{} + ok, err := netLink.IsWireguardSupported() + require.NoError(t, err) + if ok { // cannot assert since this depends on kernel + t.Log("wireguard is supported") + } else { + t.Log("wireguard is not supported") + } +} diff --git a/internal/netlink/interface.go b/internal/netlink/interface.go index 39bfcfd9..8cb56254 100644 --- a/internal/netlink/interface.go +++ b/internal/netlink/interface.go @@ -9,4 +9,5 @@ type NetLinker interface { Linker Router Ruler + WireguardChecker } diff --git a/internal/netlink/link.go b/internal/netlink/link.go index cf769185..84d886b5 100644 --- a/internal/netlink/link.go +++ b/internal/netlink/link.go @@ -3,8 +3,9 @@ package netlink import "github.com/vishvananda/netlink" type ( - Link = netlink.Link - Bridge = netlink.Bridge + Link = netlink.Link + Bridge = netlink.Bridge + Wireguard = netlink.Wireguard ) var _ Linker = (*NetLink)(nil) diff --git a/internal/netlink/mock_netlink/interface.go b/internal/netlink/mock_netlink/interface.go index d96af938..7e87ebaf 100644 --- a/internal/netlink/mock_netlink/interface.go +++ b/internal/netlink/mock_netlink/interface.go @@ -63,6 +63,21 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrList", reflect.TypeOf((*MockNetLinker)(nil).AddrList), arg0, arg1) } +// IsWireguardSupported mocks base method. +func (m *MockNetLinker) IsWireguardSupported() (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsWireguardSupported") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsWireguardSupported indicates an expected call of IsWireguardSupported. +func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsWireguardSupported", reflect.TypeOf((*MockNetLinker)(nil).IsWireguardSupported)) +} + // LinkAdd mocks base method. func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error { m.ctrl.T.Helper() diff --git a/internal/routing/rules_test.go b/internal/routing/rules_test.go index f93240ff..7e5767ba 100644 --- a/internal/routing/rules_test.go +++ b/internal/routing/rules_test.go @@ -88,7 +88,7 @@ func Test_Routing_addIPRule(t *testing.T) { ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), err: errDummy, }, - err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 table 99"), + err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99"), }, "add rule success": { src: makeIPNet(t, 1), @@ -193,7 +193,7 @@ func Test_Routing_deleteIPRule(t *testing.T) { ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), err: errDummy, }, - err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 table 99"), + err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99"), }, "rule deleted": { src: makeIPNet(t, 1), diff --git a/internal/wireguard/netlinker.go b/internal/wireguard/netlinker.go index 7989ede7..de84bd2b 100644 --- a/internal/wireguard/netlinker.go +++ b/internal/wireguard/netlinker.go @@ -10,9 +10,11 @@ type NetLinker interface { RouteAdd(route *netlink.Route) error RuleAdd(rule *netlink.Rule) error RuleDel(rule *netlink.Rule) error + LinkAdd(link netlink.Link) (err error) LinkList() (links []netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error) LinkSetUp(link netlink.Link) error LinkSetDown(link netlink.Link) error LinkDel(link netlink.Link) error + IsWireguardSupported() (ok bool, err error) } diff --git a/internal/wireguard/netlinker_mock_test.go b/internal/wireguard/netlinker_mock_test.go index bf9fc94e..1bed2ca0 100644 --- a/internal/wireguard/netlinker_mock_test.go +++ b/internal/wireguard/netlinker_mock_test.go @@ -48,6 +48,35 @@ func (mr *MockNetLinkerMockRecorder) AddrAdd(arg0, arg1 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrAdd", reflect.TypeOf((*MockNetLinker)(nil).AddrAdd), arg0, arg1) } +// IsWireguardSupported mocks base method. +func (m *MockNetLinker) IsWireguardSupported() (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsWireguardSupported") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsWireguardSupported indicates an expected call of IsWireguardSupported. +func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsWireguardSupported", reflect.TypeOf((*MockNetLinker)(nil).IsWireguardSupported)) +} + +// LinkAdd mocks base method. +func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkAdd", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// LinkAdd indicates an expected call of LinkAdd. +func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkAdd", reflect.TypeOf((*MockNetLinker)(nil).LinkAdd), arg0) +} + // LinkByName mocks base method. func (m *MockNetLinker) LinkByName(arg0 string) (netlink.Link, error) { m.ctrl.T.Helper() diff --git a/internal/wireguard/route_test.go b/internal/wireguard/route_test.go index 3b0abfa9..967a16b0 100644 --- a/internal/wireguard/route_test.go +++ b/internal/wireguard/route_test.go @@ -53,7 +53,7 @@ func Test_Wireguard_addRoute(t *testing.T) { Table: firewallMark, }, routeAddErr: errDummy, - err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: Gw: Flags: [] Table: 51820}"), //nolint:lll + err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: Gw: Flags: [] Table: 51820 Realm: 0}"), //nolint:lll }, } diff --git a/internal/wireguard/rule_test.go b/internal/wireguard/rule_test.go index 796edde4..bd434b07 100644 --- a/internal/wireguard/rule_test.go +++ b/internal/wireguard/rule_test.go @@ -51,7 +51,7 @@ func Test_Wireguard_addRule(t *testing.T) { SuppressPrefixlen: -1, }, ruleAddErr: errDummy, - err: errors.New("dummy: when adding rule: ip rule 987: from table 456"), + err: errors.New("dummy: when adding rule: ip rule 987: from all to all table 456"), }, "rule delete error": { expectedRule: &netlink.Rule{ @@ -66,7 +66,7 @@ func Test_Wireguard_addRule(t *testing.T) { SuppressPrefixlen: -1, }, ruleDelErr: errDummy, - cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from table 456"), + cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from all to all table 456"), }, } diff --git a/internal/wireguard/run.go b/internal/wireguard/run.go index f47dad59..7d9312eb 100644 --- a/internal/wireguard/run.go +++ b/internal/wireguard/run.go @@ -6,6 +6,7 @@ import ( "fmt" "net" + "github.com/qdm12/gluetun/internal/netlink" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" @@ -15,7 +16,9 @@ import ( var ( ErrDetectIPv6 = errors.New("cannot detect IPv6 support") + ErrDetectKernel = errors.New("cannot detect Kernel support") ErrCreateTun = errors.New("cannot create TUN device") + ErrAddLink = errors.New("cannot add Wireguard link") ErrFindLink = errors.New("cannot find link") ErrFindDevice = errors.New("cannot find Wireguard device") ErrUAPISocketOpening = errors.New("cannot open UAPI socket") @@ -23,6 +26,7 @@ var ( ErrUAPIListen = errors.New("cannot listen on UAPI socket") ErrAddAddress = errors.New("cannot add address to wireguard interface") ErrConfigure = errors.New("cannot configure wireguard interface") + ErrDeviceInfo = errors.New("cannot get wireguard device information") ErrIfaceUp = errors.New("cannot set the interface to UP") ErrRouteAdd = errors.New("cannot add route for interface") ErrRuleAdd = errors.New("cannot add rule for interface") @@ -41,6 +45,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< return } + doKernel, err := w.netlink.IsWireguardSupported() + if err != nil { + waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err) + return + } + client, err := wgctrl.New() if err != nil { waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err) @@ -52,62 +62,21 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< defer closers.cleanup(w.logger) - tun, err := tun.CreateTUN(w.settings.InterfaceName, device.DefaultMTU) - if err != nil { - waitError <- fmt.Errorf("%w: %s", ErrCreateTun, err) - return + setupFunction := setupUserSpace + if doKernel { + w.logger.Info("Using available kernelspace implementation") + setupFunction = setupKernelSpace + } else { + w.logger.Info("Using userspace implementation since Kernel support does not exist") } - closers.add("closing TUN device", stepSeven, tun.Close) - - tunName, err := tun.Name() + link, waitAndCleanup, err := setupFunction(ctx, + w.settings.InterfaceName, w.netlink, &closers, w.logger) if err != nil { - waitError <- fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err) - return - } else if tunName != w.settings.InterfaceName { - waitError <- fmt.Errorf("%w: names don't match: expected %q and got %q", - ErrCreateTun, w.settings.InterfaceName, tunName) + waitError <- err return } - link, err := w.netlink.LinkByName(w.settings.InterfaceName) - if err != nil { - waitError <- fmt.Errorf("%w: %s: %s", ErrFindLink, w.settings.InterfaceName, err) - return - } - - bind := conn.NewDefaultBind() - - closers.add("closing bind", stepSeven, bind.Close) - - deviceLogger := makeDeviceLogger(w.logger) - device := device.NewDevice(tun, bind, deviceLogger) - - closers.add("closing Wireguard device", stepSix, func() error { - device.Close() - return nil - }) - - uapiFile, err := ipc.UAPIOpen(w.settings.InterfaceName) - if err != nil { - waitError <- fmt.Errorf("%w: %s", ErrUAPISocketOpening, err) - return - } - - closers.add("closing UAPI file", stepThree, uapiFile.Close) - - uapiListener, err := ipc.UAPIListen(w.settings.InterfaceName, uapiFile) - if err != nil { - waitError <- fmt.Errorf("%w: %s", ErrUAPIListen, err) - return - } - - closers.add("closing UAPI listener", stepTwo, uapiListener.Close) - - // acceptAndHandle exits when uapiListener is closed - uapiAcceptErrorCh := make(chan error) - go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh) - err = w.addAddresses(link, w.settings.Addresses) if err != nil { waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err) @@ -128,9 +97,6 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< closers.add("shutting down link", stepFour, func() error { return w.netlink.LinkSetDown(link) }) - closers.add("deleting link", stepFive, func() error { - return w.netlink.LinkDel(link) - }) err = w.addRoute(link, allIPv4(), w.settings.FirewallMark) if err != nil { @@ -158,20 +124,113 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< w.logger.Info("Wireguard is up") ready <- struct{}{} - select { - case <-ctx.Done(): - err = ctx.Err() - case err = <-uapiAcceptErrorCh: - close(uapiAcceptErrorCh) - case <-device.Wait(): - err = ErrDeviceWaited + waitError <- waitAndCleanup() +} + +type waitAndCleanupFunc func() error + +func setupKernelSpace(ctx context.Context, + interfaceName string, netLinker NetLinker, + closers *closers, logger Logger) ( + link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) { + linkAttrs := netlink.LinkAttrs{ + Name: interfaceName, + MTU: device.DefaultMTU, // TODO + } + link = &netlink.Wireguard{ + LinkAttrs: linkAttrs, + } + err = netLinker.LinkAdd(link) + if err != nil { + return nil, nil, fmt.Errorf("%w: %s", ErrAddLink, err) + } + closers.add("deleting link", stepFive, func() error { + return netLinker.LinkDel(link) + }) + + waitAndCleanup = func() error { + <-ctx.Done() + closers.cleanup(logger) + return ctx.Err() } - closers.cleanup(w.logger) + return link, waitAndCleanup, nil +} - <-uapiAcceptErrorCh // wait for acceptAndHandle to exit +func setupUserSpace(ctx context.Context, + interfaceName string, netLinker NetLinker, + closers *closers, logger Logger) ( + link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) { + tun, err := tun.CreateTUN(interfaceName, device.DefaultMTU) + if err != nil { + return nil, nil, fmt.Errorf("%w: %s", ErrCreateTun, err) + } - waitError <- err + closers.add("closing TUN device", stepSeven, tun.Close) + + tunName, err := tun.Name() + if err != nil { + return nil, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err) + } else if tunName != interfaceName { + return nil, nil, fmt.Errorf("%w: names don't match: expected %q and got %q", + ErrCreateTun, interfaceName, tunName) + } + + link, err = netLinker.LinkByName(interfaceName) + if err != nil { + return nil, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err) + } + closers.add("deleting link", stepFive, func() error { + return netLinker.LinkDel(link) + }) + + bind := conn.NewDefaultBind() + + closers.add("closing bind", stepSeven, bind.Close) + + deviceLogger := makeDeviceLogger(logger) + device := device.NewDevice(tun, bind, deviceLogger) + + closers.add("closing Wireguard device", stepSix, func() error { + device.Close() + return nil + }) + + uapiFile, err := ipc.UAPIOpen(interfaceName) + if err != nil { + return nil, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err) + } + + closers.add("closing UAPI file", stepThree, uapiFile.Close) + + uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile) + if err != nil { + return nil, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err) + } + + closers.add("closing UAPI listener", stepTwo, uapiListener.Close) + + // acceptAndHandle exits when uapiListener is closed + uapiAcceptErrorCh := make(chan error) + go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh) + waitAndCleanup = func() error { + select { + case <-ctx.Done(): + err = ctx.Err() + case err = <-uapiAcceptErrorCh: + close(uapiAcceptErrorCh) + case <-device.Wait(): + err = ErrDeviceWaited + } + + closers.cleanup(logger) + + <-uapiAcceptErrorCh // wait for acceptAndHandle to exit + + return err + } + + return link, waitAndCleanup, nil } func acceptAndHandle(uapi net.Listener, device *device.Device,