chore(all): use netip.Prefix for ip networks

- remove usage of `net.IPNet`
- remove usage of `netaddr.IPPrefix`
This commit is contained in:
Quentin McGaw 2023-04-27 13:41:05 +00:00
parent 801a7fd6fe
commit d21a943779
No known key found for this signature in database
32 changed files with 344 additions and 315 deletions

View File

@ -2,7 +2,7 @@ package settings
import (
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gotree"
@ -12,7 +12,7 @@ import (
type Firewall struct {
VPNInputPorts []uint16
InputPorts []uint16
OutboundSubnets []net.IPNet
OutboundSubnets []netip.Prefix
Enabled *bool
Debug *bool
}
@ -42,7 +42,7 @@ func (f *Firewall) copy() (copied Firewall) {
return Firewall{
VPNInputPorts: helpers.CopyUint16Slice(f.VPNInputPorts),
InputPorts: helpers.CopyUint16Slice(f.InputPorts),
OutboundSubnets: helpers.CopyIPNetSlice(f.OutboundSubnets),
OutboundSubnets: helpers.CopyNetipPrefixesSlice(f.OutboundSubnets),
Enabled: helpers.CopyBoolPtr(f.Enabled),
Debug: helpers.CopyBoolPtr(f.Debug),
}
@ -55,7 +55,7 @@ func (f *Firewall) copy() (copied Firewall) {
func (f *Firewall) mergeWith(other Firewall) {
f.VPNInputPorts = helpers.MergeUint16Slices(f.VPNInputPorts, other.VPNInputPorts)
f.InputPorts = helpers.MergeUint16Slices(f.InputPorts, other.InputPorts)
f.OutboundSubnets = helpers.MergeIPNetsSlices(f.OutboundSubnets, other.OutboundSubnets)
f.OutboundSubnets = helpers.MergeNetipPrefixesSlices(f.OutboundSubnets, other.OutboundSubnets)
f.Enabled = helpers.MergeWithBool(f.Enabled, other.Enabled)
f.Debug = helpers.MergeWithBool(f.Debug, other.Debug)
}
@ -66,7 +66,7 @@ func (f *Firewall) mergeWith(other Firewall) {
func (f *Firewall) overrideWith(other Firewall) {
f.VPNInputPorts = helpers.OverrideWithUint16Slice(f.VPNInputPorts, other.VPNInputPorts)
f.InputPorts = helpers.OverrideWithUint16Slice(f.InputPorts, other.InputPorts)
f.OutboundSubnets = helpers.OverrideWithIPNetsSlice(f.OutboundSubnets, other.OutboundSubnets)
f.OutboundSubnets = helpers.OverrideWithNetipPrefixesSlice(f.OutboundSubnets, other.OutboundSubnets)
f.Enabled = helpers.OverrideWithBool(f.Enabled, other.Enabled)
f.Debug = helpers.OverrideWithBool(f.Debug, other.Debug)
}

View File

@ -90,30 +90,6 @@ func CopyIP(original net.IP) (copied net.IP) {
return copied
}
func CopyIPNet(original net.IPNet) (copied net.IPNet) {
if original.IP != nil {
copied.IP = make(net.IP, len(original.IP))
copy(copied.IP, original.IP)
}
if original.Mask != nil {
copied.Mask = make(net.IPMask, len(original.Mask))
copy(copied.Mask, original.Mask)
}
return copied
}
func CopyIPNetPtr(original *net.IPNet) (copied *net.IPNet) {
if original == nil {
return nil
}
copied = new(net.IPNet)
*copied = CopyIPNet(*original)
return copied
}
func CopyNetipAddress(original netip.Addr) (copied netip.Addr) {
// AsSlice creates a new byte slice so no need to copy the bytes.
bytes := original.AsSlice()
@ -158,18 +134,6 @@ func CopyUint16Slice(original []uint16) (copied []uint16) {
return copied
}
func CopyIPNetSlice(original []net.IPNet) (copied []net.IPNet) {
if original == nil {
return nil
}
copied = make([]net.IPNet, len(original))
for i := range original {
copied[i] = CopyIPNet(original[i])
}
return copied
}
func CopyNetipPrefixesSlice(original []netip.Prefix) (copied []netip.Prefix) {
if original == nil {
return nil

View File

@ -187,32 +187,6 @@ func MergeUint16Slices(a, b []uint16) (result []uint16) {
return result
}
func MergeIPNetsSlices(a, b []net.IPNet) (result []net.IPNet) {
if a == nil && b == nil {
return nil
}
seen := make(map[string]struct{}, len(a)+len(b))
result = make([]net.IPNet, 0, len(a)+len(b))
for _, ipNet := range a {
key := ipNet.String()
if _, ok := seen[key]; ok {
continue // duplicate
}
result = append(result, ipNet)
seen[key] = struct{}{}
}
for _, ipNet := range b {
key := ipNet.String()
if _, ok := seen[key]; ok {
continue // duplicate
}
result = append(result, ipNet)
seen[key] = struct{}{}
}
return result
}
func MergeNetipAddressesSlices(a, b []netip.Addr) (result []netip.Addr) {
if a == nil && b == nil {
return nil

View File

@ -145,15 +145,6 @@ func OverrideWithUint16Slice(existing, other []uint16) (result []uint16) {
return result
}
func OverrideWithIPNetsSlice(existing, other []net.IPNet) (result []net.IPNet) {
if other == nil {
return existing
}
result = make([]net.IPNet, len(other))
copy(result, other)
return result
}
func OverrideWithNetipAddressesSlice(existing, other []netip.Addr) (result []netip.Addr) {
if other == nil {
return existing

View File

@ -2,7 +2,7 @@ package settings
import (
"fmt"
"net"
"net/netip"
"regexp"
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
@ -22,7 +22,7 @@ type Wireguard struct {
// It cannot be nil in the internal state.
PreSharedKey *string
// Addresses are the Wireguard interface addresses.
Addresses []net.IPNet
Addresses []netip.Prefix
// Interface is the name of the Wireguard interface
// to create. It cannot be the empty string in the
// internal state.
@ -78,13 +78,12 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
return fmt.Errorf("%w", ErrWireguardInterfaceAddressNotSet)
}
for i, ipNet := range w.Addresses {
if ipNet.IP == nil || ipNet.Mask == nil {
if !ipNet.IsValid() {
return fmt.Errorf("%w: for address at index %d: %s",
ErrWireguardInterfaceAddressNotSet, i, ipNet.String())
}
ipv6Net := ipNet.IP.To4() == nil
if ipv6Net && !ipv6Supported {
if !ipv6Supported && ipNet.Addr().Is6() {
return fmt.Errorf("%w: address %s",
ErrWireguardInterfaceAddressIPv6, ipNet)
}
@ -109,7 +108,7 @@ func (w *Wireguard) copy() (copied Wireguard) {
return Wireguard{
PrivateKey: helpers.CopyStringPtr(w.PrivateKey),
PreSharedKey: helpers.CopyStringPtr(w.PreSharedKey),
Addresses: helpers.CopyIPNetSlice(w.Addresses),
Addresses: helpers.CopyNetipPrefixesSlice(w.Addresses),
Interface: w.Interface,
Implementation: w.Implementation,
}
@ -118,7 +117,7 @@ func (w *Wireguard) copy() (copied Wireguard) {
func (w *Wireguard) mergeWith(other Wireguard) {
w.PrivateKey = helpers.MergeWithStringPtr(w.PrivateKey, other.PrivateKey)
w.PreSharedKey = helpers.MergeWithStringPtr(w.PreSharedKey, other.PreSharedKey)
w.Addresses = helpers.MergeIPNetsSlices(w.Addresses, other.Addresses)
w.Addresses = helpers.MergeNetipPrefixesSlices(w.Addresses, other.Addresses)
w.Interface = helpers.MergeWithString(w.Interface, other.Interface)
w.Implementation = helpers.MergeWithString(w.Implementation, other.Implementation)
}
@ -126,7 +125,7 @@ func (w *Wireguard) mergeWith(other Wireguard) {
func (w *Wireguard) overrideWith(other Wireguard) {
w.PrivateKey = helpers.OverrideWithStringPtr(w.PrivateKey, other.PrivateKey)
w.PreSharedKey = helpers.OverrideWithStringPtr(w.PreSharedKey, other.PreSharedKey)
w.Addresses = helpers.OverrideWithIPNetsSlice(w.Addresses, other.Addresses)
w.Addresses = helpers.OverrideWithNetipPrefixesSlice(w.Addresses, other.Addresses)
w.Interface = helpers.OverrideWithString(w.Interface, other.Interface)
w.Implementation = helpers.OverrideWithString(w.Implementation, other.Implementation)
}

View File

@ -3,7 +3,7 @@ package env
import (
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"github.com/qdm12/gluetun/internal/configuration/settings"
@ -24,7 +24,7 @@ func (s *Source) readFirewall() (firewall settings.Firewall, err error) {
outboundSubnetsKey, _ := s.getEnvWithRetro("FIREWALL_OUTBOUND_SUBNETS", "EXTRA_SUBNETS")
outboundSubnetStrings := envToCSV(outboundSubnetsKey)
firewall.OutboundSubnets, err = stringsToIPNets(outboundSubnetStrings)
firewall.OutboundSubnets, err = stringsToNetipPrefixes(outboundSubnetStrings)
if err != nil {
return firewall, fmt.Errorf("environment variable %s: %w", outboundSubnetsKey, err)
}
@ -65,18 +65,16 @@ func stringsToPorts(ss []string) (ports []uint16, err error) {
return ports, nil
}
func stringsToIPNets(ss []string) (ipNets []net.IPNet, err error) {
func stringsToNetipPrefixes(ss []string) (ipPrefixes []netip.Prefix, err error) {
if len(ss) == 0 {
return nil, nil
}
ipNets = make([]net.IPNet, len(ss))
ipPrefixes = make([]netip.Prefix, len(ss))
for i, s := range ss {
ip, ipNet, err := net.ParseCIDR(s)
ipPrefixes[i], err = netip.ParsePrefix(s)
if err != nil {
return nil, fmt.Errorf("parsing IP network %q: %w", s, err)
}
ipNet.IP = ip
ipNets[i] = *ipNet
}
return ipNets, nil
return ipPrefixes, nil
}

View File

@ -2,7 +2,7 @@ package env
import (
"fmt"
"net"
"net/netip"
"os"
"strings"
@ -24,22 +24,20 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) {
return wireguard, nil
}
func (s *Source) readWireguardAddresses() (addresses []net.IPNet, err error) {
func (s *Source) readWireguardAddresses() (addresses []netip.Prefix, err error) {
key, addressesCSV := s.getEnvWithRetro("WIREGUARD_ADDRESSES", "WIREGUARD_ADDRESS")
if addressesCSV == "" {
return nil, nil
}
addressStrings := strings.Split(addressesCSV, ",")
addresses = make([]net.IPNet, len(addressStrings))
addresses = make([]netip.Prefix, len(addressStrings))
for i, addressString := range addressStrings {
addressString = strings.TrimSpace(addressString)
ip, ipNet, err := net.ParseCIDR(addressString)
addresses[i], err = netip.ParsePrefix(addressString)
if err != nil {
return nil, fmt.Errorf("environment variable %s: %w", key, err)
}
ipNet.IP = ip
addresses[i] = *ipNet
}
return addresses, nil

View File

@ -98,7 +98,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
}
for _, network := range c.localNetworks {
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, *network.IPNet, remove); err != nil {
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, network.IPNet, remove); err != nil {
return err
}
if err = c.acceptIpv6MulticastOutput(ctx, network.InterfaceName, remove); err != nil {
@ -113,7 +113,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
// Allows packets from any IP address to go through eth0 / local network
// to reach Gluetun.
for _, network := range c.localNetworks {
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, *network.IPNet, remove); err != nil {
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil {
return err
}
}

View File

@ -2,7 +2,7 @@ package firewall
import (
"context"
"net"
"net/netip"
"sync"
"github.com/qdm12/gluetun/internal/models"
@ -27,7 +27,7 @@ type Config struct { //nolint:maligned
enabled bool
vpnConnection models.Connection
vpnIntf string
outboundSubnets []net.IPNet
outboundSubnets []netip.Prefix
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
stateMutex sync.Mutex
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"os"
"os/exec"
"strings"
@ -108,9 +109,8 @@ func (c *Config) acceptInputThroughInterface(ctx context.Context, intf string, r
))
}
func (c *Config) acceptInputToSubnet(ctx context.Context, intf string, destination net.IPNet, remove bool) error {
isIP4Subnet := destination.IP.To4() != nil
func (c *Config) acceptInputToSubnet(ctx context.Context, intf string,
destination netip.Prefix, remove bool) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
@ -119,7 +119,7 @@ func (c *Config) acceptInputToSubnet(ctx context.Context, intf string, destinati
instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destination.String())
if isIP4Subnet {
if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction)
}
if c.ip6Tables == "" {
@ -157,8 +157,8 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
// Thanks to @npawelek.
func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context,
intf string, sourceIP net.IP, destinationSubnet net.IPNet, remove bool) error {
doIPv4 := sourceIP.To4() != nil && destinationSubnet.IP.To4() != nil
intf string, sourceIP net.IP, destinationSubnet netip.Prefix, remove bool) error {
doIPv4 := sourceIP.To4() != nil && destinationSubnet.Addr().Is4()
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces

View File

@ -3,18 +3,18 @@ package firewall
import (
"context"
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/subnet"
)
func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (err error) {
func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []netip.Prefix) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed subnets internal list")
c.outboundSubnets = make([]net.IPNet, len(subnets))
c.outboundSubnets = make([]netip.Prefix, len(subnets))
copy(c.outboundSubnets, subnets)
return nil
}
@ -34,7 +34,7 @@ func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (e
return nil
}
func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet) {
func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []netip.Prefix) {
const remove = true
for _, subNet := range subnets {
for _, defaultRoute := range c.defaultRoutes {
@ -49,7 +49,7 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet)
}
}
func (c *Config) addOutboundSubnets(ctx context.Context, subnets []net.IPNet) error {
func (c *Config) addOutboundSubnets(ctx context.Context, subnets []netip.Prefix) error {
const remove = false
for _, subnet := range subnets {
for _, defaultRoute := range c.defaultRoutes {

View File

@ -1,11 +0,0 @@
package netlink
import (
"net"
"github.com/vishvananda/netlink"
)
func NewIPNet(ip net.IP) *net.IPNet {
return netlink.NewIPNet(ip)
}

View File

@ -2,6 +2,7 @@ package utils
import (
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
@ -25,17 +26,12 @@ func BuildWireguardSettings(connection models.Connection,
copy(settings.Endpoint.IP, connection.IP)
settings.Endpoint.Port = int(connection.Port)
settings.Addresses = make([]*net.IPNet, 0, len(userSettings.Addresses))
settings.Addresses = make([]netip.Prefix, 0, len(userSettings.Addresses))
for _, address := range userSettings.Addresses {
ipv6Address := address.IP.To4() == nil
if !ipv6Supported && ipv6Address {
if !ipv6Supported && address.Addr().Is6() {
continue
}
addressCopy := new(net.IPNet)
addressCopy.IP = make(net.IP, len(address.IP))
copy(addressCopy.IP, address.IP)
addressCopy.Mask = make(net.IPMask, len(address.Mask))
copy(addressCopy.Mask, address.Mask)
addressCopy := netip.PrefixFrom(address.Addr(), address.Bits())
settings.Addresses = append(settings.Addresses, addressCopy)
}

View File

@ -2,6 +2,7 @@ package utils
import (
"net"
"net/netip"
"testing"
"github.com/qdm12/gluetun/internal/configuration/settings"
@ -30,9 +31,9 @@ func Test_BuildWireguardSettings(t *testing.T) {
userSettings: settings.Wireguard{
PrivateKey: stringPtr("private"),
PreSharedKey: stringPtr("pre-shared"),
Addresses: []net.IPNet{
{IP: net.IPv4(1, 1, 1, 1), Mask: net.IPv4Mask(255, 255, 255, 255)},
{IP: net.IPv6zero, Mask: net.IPv4Mask(255, 255, 255, 255)},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
},
Interface: "wg1",
},
@ -46,8 +47,8 @@ func Test_BuildWireguardSettings(t *testing.T) {
IP: net.IPv4(1, 2, 3, 4),
Port: 51821,
},
Addresses: []*net.IPNet{
{IP: net.IPv4(1, 1, 1, 1), Mask: net.IPv4Mask(255, 255, 255, 255)},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
},
RulePriority: 101,
IPv6: boolPtr(false),

View File

@ -0,0 +1,33 @@
package routing
import (
"fmt"
"net"
"net/netip"
)
func NetipPrefixToIPNet(prefix *netip.Prefix) (ipNet *net.IPNet) {
if prefix == nil {
return nil
}
s := prefix.String()
ip, ipNet, err := net.ParseCIDR(s)
if err != nil {
panic(err)
}
ipNet.IP = ip
return ipNet
}
func netIPNetToNetipPrefix(ipNet net.IPNet) (prefix netip.Prefix) {
return netip.MustParsePrefix(ipNet.String())
}
func netIPToNetipAddress(ip net.IP) (address netip.Addr) {
address, ok := netip.AddrFromSlice(ip)
if !ok {
panic(fmt.Sprintf("converting %#v to netip.Addr failed", ip))
}
return address
}

View File

@ -0,0 +1,52 @@
package routing
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_netIPToNetipAddress(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
ip net.IP
address netip.Addr
panicMessage string
}{
"nil ip": {
panicMessage: "converting net.IP(nil) to netip.Addr failed",
},
"IPv4": {
ip: net.IPv4(1, 2, 3, 4),
address: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
},
"IPv6": {
ip: net.IPv6zero,
address: netip.AddrFrom16([16]byte{}),
},
"IPv4 prefixed with 0xffff": {
ip: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 1, 2, 3, 4},
address: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
if testCase.panicMessage != "" {
assert.PanicsWithValue(t, testCase.panicMessage, func() {
netIPToNetipAddress(testCase.ip)
})
return
}
address := netIPToNetipAddress(testCase.ip)
assert.Equal(t, testCase.address, address)
})
}
}

View File

@ -2,7 +2,7 @@ package routing
import (
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
@ -17,8 +17,9 @@ func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err err
return fmt.Errorf("adding rule: %w", err)
}
defaultDestinationIPv4 := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
defaultDestinationIPv6 := net.IPNet{IP: net.IPv6zero, Mask: net.IPMask(net.IPv6zero)}
const bits = 0
defaultDestinationIPv4 := netip.PrefixFrom(netip.AddrFrom4([4]byte{}), bits)
defaultDestinationIPv6 := netip.PrefixFrom(netip.AddrFrom16([16]byte{}), bits)
for _, defaultRoute := range defaultRoutes {
defaultDestination := defaultDestinationIPv4
@ -36,8 +37,9 @@ func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err err
}
func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err error) {
defaultDestinationIPv4 := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
defaultDestinationIPv6 := net.IPNet{IP: net.IPv6zero, Mask: net.IPMask(net.IPv6zero)}
const bits = 0
defaultDestinationIPv4 := netip.PrefixFrom(netip.AddrFrom4([4]byte{}), bits)
defaultDestinationIPv6 := netip.PrefixFrom(netip.AddrFrom16([16]byte{}), bits)
for _, defaultRoute := range defaultRoutes {
defaultDestination := defaultDestinationIPv4
@ -60,9 +62,16 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes {
defaultIPMasked32 := netlink.NewIPNet(defaultRoute.AssignedIP)
ruleDstNet := (*net.IPNet)(nil)
err = r.addIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
assignedIP := netIPToNetipAddress(defaultRoute.AssignedIP)
bits := 32
if assignedIP.Is6() {
bits = 128
}
r.logger.Debug(fmt.Sprintf("ASSIGNED IP IS %#v -> %s, bits %d",
defaultRoute.AssignedIP, assignedIP, bits))
defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
ruleDstNet := (*netip.Prefix)(nil)
err = r.addIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority)
if err != nil {
return fmt.Errorf("adding rule for default route %s: %w", defaultRoute, err)
}
@ -73,9 +82,14 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes {
defaultIPMasked32 := netlink.NewIPNet(defaultRoute.AssignedIP)
ruleDstNet := (*net.IPNet)(nil)
err = r.deleteIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
assignedIP := netIPToNetipAddress(defaultRoute.AssignedIP)
bits := 32
if assignedIP.Is6() {
bits = 128
}
defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
ruleDstNet := (*netip.Prefix)(nil)
err = r.deleteIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority)
if err != nil {
return fmt.Errorf("deleting rule for default route %s: %w", defaultRoute, err)
}

View File

@ -22,6 +22,18 @@ func ipMatchesFamily(ip net.IP, family int) bool {
(family == netlink.FAMILY_V4 && ip.To4() != nil)
}
func ensureNoIPv6WrappedIPv4(candidateIP net.IP) (resultIP net.IP) {
const ipv4Size = 4
if candidateIP.To4() == nil || len(candidateIP) == ipv4Size { // ipv6 or ipv4
return candidateIP
}
// ipv6-wrapped ipv4
resultIP = make(net.IP, ipv4Size)
copy(resultIP, candidateIP[12:16])
return resultIP
}
func (r *Routing) assignedIP(interfaceName string, family int) (ip net.IP, err error) {
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
@ -34,14 +46,22 @@ func (r *Routing) assignedIP(interfaceName string, family int) (ip net.IP, err e
for _, address := range addresses {
switch value := address.(type) {
case *net.IPAddr:
if ipMatchesFamily(value.IP, family) {
return value.IP, nil
}
ip = value.IP
case *net.IPNet:
if ipMatchesFamily(value.IP, family) {
return value.IP, nil
}
ip = value.IP
default:
continue
}
if !ipMatchesFamily(ip, family) {
continue
}
// Ensure we don't return an IPv6-wrapped IPv4 address
// since netip.Address String method works differently than
// net.IP String method for this kind of addresses.
ip = ensureNoIPv6WrappedIPv4(ip)
return ip, nil
}
return nil, fmt.Errorf("%w: interface %s in %d addresses",
errInterfaceIPNotFound, interfaceName, len(addresses))

View File

@ -96,3 +96,35 @@ func Test_IPIsPrivate(t *testing.T) {
})
}
}
func Test_ensureNoIPv6WrappedIPv4(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
candidateIP net.IP
resultIP net.IP
}{
"nil": {},
"ipv6": {
candidateIP: net.IPv6loopback,
resultIP: net.IPv6loopback,
},
"ipv4": {
candidateIP: net.IP{1, 2, 3, 4},
resultIP: net.IP{1, 2, 3, 4},
},
"ipv6_wrapped_ipv4": {
candidateIP: net.IPv4(1, 2, 3, 4),
resultIP: net.IP{1, 2, 3, 4},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
resultIP := ensureNoIPv6WrappedIPv4(testCase.candidateIP)
assert.Equal(t, testCase.resultIP, resultIP)
})
}
}

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
@ -15,7 +16,7 @@ var (
)
type LocalNetwork struct {
IPNet *net.IPNet
IPNet netip.Prefix
InterfaceName string
IP net.IP
}
@ -55,7 +56,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
var localNet LocalNetwork
localNet.IPNet = route.Dst
localNet.IPNet = netIPNetToNetipPrefix(*route.Dst)
r.logger.Info("local ipnet found: " + localNet.IPNet.String())
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
@ -66,7 +67,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
localNet.InterfaceName = link.Attrs().Name
family := netlink.FAMILY_V6
if localNet.IPNet.IP.To4() != nil {
if localNet.IPNet.Addr().Is4() {
family = netlink.FAMILY_V4
}
ip, err := r.assignedIP(localNet.InterfaceName, family)
@ -87,7 +88,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
}
func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
for _, net := range subnets {
for _, subnet := range subnets {
// The main table is a built-in value for Linux, see "man 8 ip-route"
const mainTable = 254
@ -96,9 +97,9 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
const localPriority = 98
// Main table was setup correctly by Docker, just need to add rules to use it
err = r.addIPRule(nil, net.IPNet, mainTable, localPriority)
err = r.addIPRule(nil, &subnet.IPNet, mainTable, localPriority)
if err != nil {
return fmt.Errorf("adding rule: %v: %w", net.IPNet, err)
return fmt.Errorf("adding rule: %v: %w", subnet.IPNet, err)
}
}
return nil

View File

@ -2,7 +2,7 @@ package routing
import (
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/subnet"
)
@ -12,7 +12,7 @@ const (
outboundPriority = 99
)
func (r *Routing) SetOutboundRoutes(outboundSubnets []net.IPNet) error {
func (r *Routing) SetOutboundRoutes(outboundSubnets []netip.Prefix) error {
defaultRoutes, err := r.DefaultRoutes()
if err != nil {
return err
@ -20,7 +20,7 @@ func (r *Routing) SetOutboundRoutes(outboundSubnets []net.IPNet) error {
return r.setOutboundRoutes(outboundSubnets, defaultRoutes)
}
func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet,
func (r *Routing) setOutboundRoutes(outboundSubnets []netip.Prefix,
defaultRoutes []DefaultRoute) (err error) {
r.stateMutex.Lock()
defer r.stateMutex.Unlock()
@ -45,7 +45,7 @@ func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet,
return nil
}
func (r *Routing) removeOutboundSubnets(subnets []net.IPNet,
func (r *Routing) removeOutboundSubnets(subnets []netip.Prefix,
defaultRoutes []DefaultRoute) (warnings []string) {
for i, subNet := range subnets {
for _, defaultRoute := range defaultRoutes {
@ -56,7 +56,7 @@ func (r *Routing) removeOutboundSubnets(subnets []net.IPNet,
}
}
ruleSrcNet := (*net.IPNet)(nil)
ruleSrcNet := (*netip.Prefix)(nil)
ruleDstNet := &subnets[i]
err := r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
if err != nil {
@ -71,7 +71,7 @@ func (r *Routing) removeOutboundSubnets(subnets []net.IPNet,
return warnings
}
func (r *Routing) addOutboundSubnets(subnets []net.IPNet,
func (r *Routing) addOutboundSubnets(subnets []netip.Prefix,
defaultRoutes []DefaultRoute) (err error) {
for i, subnet := range subnets {
for _, defaultRoute := range defaultRoutes {
@ -81,7 +81,7 @@ func (r *Routing) addOutboundSubnets(subnets []net.IPNet,
}
}
ruleSrcNet := (*net.IPNet)(nil)
ruleSrcNet := (*netip.Prefix)(nil)
ruleDstNet := &subnets[i]
err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
if err != nil {

View File

@ -3,12 +3,13 @@ package routing
import (
"fmt"
"net"
"net/netip"
"strconv"
"github.com/qdm12/gluetun/internal/netlink"
)
func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP,
func (r *Routing) addRouteVia(destination netip.Prefix, gateway net.IP,
iface string, table int) error {
destinationStr := destination.String()
r.logger.Info("adding route for " + destinationStr)
@ -23,7 +24,7 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP,
}
route := netlink.Route{
Dst: &destination,
Dst: NetipPrefixToIPNet(&destination),
Gw: gateway,
LinkIndex: link.Attrs().Index,
Table: table,
@ -36,7 +37,7 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP,
return nil
}
func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP,
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway net.IP,
iface string, table int) (err error) {
destinationStr := destination.String()
r.logger.Info("deleting route for " + destinationStr)
@ -51,7 +52,7 @@ func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP,
}
route := netlink.Route{
Dst: &destination,
Dst: NetipPrefixToIPNet(&destination),
Gw: gateway,
LinkIndex: link.Attrs().Index,
Table: table,

View File

@ -1,7 +1,7 @@
package routing
import (
"net"
"net/netip"
"sync"
"github.com/qdm12/gluetun/internal/netlink"
@ -48,7 +48,7 @@ type Linker interface {
type Routing struct {
netLinker NetLinker
logger Logger
outboundSubnets []net.IPNet
outboundSubnets []netip.Prefix
stateMutex sync.RWMutex
}

View File

@ -4,17 +4,18 @@ import (
"bytes"
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error {
func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error {
const add = true
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Src = NetipPrefixToIPNet(src)
rule.Dst = NetipPrefixToIPNet(dst)
rule.Priority = priority
rule.Table = table
@ -35,13 +36,13 @@ func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error {
return nil
}
func (r *Routing) deleteIPRule(src, dst *net.IPNet, table, priority int) error {
func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) error {
const add = false
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Src = NetipPrefixToIPNet(src)
rule.Dst = NetipPrefixToIPNet(dst)
rule.Priority = priority
rule.Table = table
@ -60,7 +61,7 @@ func (r *Routing) deleteIPRule(src, dst *net.IPNet, table, priority int) error {
return nil
}
func ruleDbgMsg(add bool, src, dst *net.IPNet,
func ruleDbgMsg(add bool, src, dst *netip.Prefix,
table, priority int) (debugMessage string) {
debugMessage = "ip rule"

View File

@ -3,6 +3,7 @@ package routing
import (
"errors"
"net"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
@ -11,20 +12,17 @@ import (
"github.com/stretchr/testify/require"
)
func makeIPNet(t *testing.T, n byte) *net.IPNet {
t.Helper()
return &net.IPNet{
IP: net.IPv4(n, n, n, 0),
Mask: net.IPv4Mask(255, 255, 255, 0),
}
func makeNetipPrefix(n byte) *netip.Prefix {
const bits = 24
prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
return &prefix
}
func makeIPRule(t *testing.T, src, dst *net.IPNet,
func makeIPRule(src, dst *netip.Prefix,
table, priority int) *netlink.Rule {
t.Helper()
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Src = NetipPrefixToIPNet(src)
rule.Dst = NetipPrefixToIPNet(dst)
rule.Table = table
rule.Priority = priority
return rule
@ -47,8 +45,8 @@ func Test_Routing_addIPRule(t *testing.T) {
}
testCases := map[string]struct {
src *net.IPNet
dst *net.IPNet
src *netip.Prefix
dst *netip.Prefix
table int
priority int
dbgMsg string
@ -64,46 +62,46 @@ func Test_Routing_addIPRule(t *testing.T) {
err: errors.New("listing rules: dummy error"),
},
"rule already exists": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99),
*makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
},
"add rule error": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleAdd: ruleAddCall{
expected: true,
ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
ruleToAdd: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
err: errDummy,
},
err: errors.New("adding rule ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99: dummy error"),
},
"add rule success": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99),
*makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 101, 101),
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
},
},
ruleAdd: ruleAddCall{
expected: true,
ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
ruleToAdd: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
}
@ -160,8 +158,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
}
testCases := map[string]struct {
src *net.IPNet
dst *net.IPNet
src *netip.Prefix
dst *netip.Prefix
table int
priority int
dbgMsg string
@ -177,50 +175,50 @@ func Test_Routing_deleteIPRule(t *testing.T) {
err: errors.New("listing rules: dummy error"),
},
"rule delete error": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
ruleDel: ruleDelCall{
expected: true,
ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
ruleToDel: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
err: errDummy,
},
err: errors.New("deleting rule ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99: dummy error"),
},
"rule deleted": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99),
*makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
ruleDel: ruleDelCall{
expected: true,
ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
ruleToDel: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
"rule does not exist": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99),
*makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 101, 101),
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
},
},
},
@ -266,8 +264,8 @@ func Test_ruleDbgMsg(t *testing.T) {
testCases := map[string]struct {
add bool
src *net.IPNet
dst *net.IPNet
src *netip.Prefix
dst *netip.Prefix
table int
priority int
dbgMsg string
@ -277,15 +275,15 @@ func Test_ruleDbgMsg(t *testing.T) {
},
"add rule": {
add: true,
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 100,
priority: 101,
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
},
"del rule": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 100,
priority: 101,
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",

View File

@ -1,7 +0,0 @@
package subnet
import "net"
func subnetsAreEqual(a, b net.IPNet) bool {
return a.IP.Equal(b.IP) && a.Mask.String() == b.Mask.String()
}

View File

@ -1,20 +1,20 @@
package subnet
import (
"net"
"net/netip"
)
func FindSubnetsToChange(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd, subnetsToRemove []net.IPNet) {
func FindSubnetsToChange(oldSubnets, newSubnets []netip.Prefix) (subnetsToAdd, subnetsToRemove []netip.Prefix) {
subnetsToAdd = findSubnetsToAdd(oldSubnets, newSubnets)
subnetsToRemove = findSubnetsToRemove(oldSubnets, newSubnets)
return subnetsToAdd, subnetsToRemove
}
func findSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IPNet) {
func findSubnetsToAdd(oldSubnets, newSubnets []netip.Prefix) (subnetsToAdd []netip.Prefix) {
for _, newSubnet := range newSubnets {
found := false
for _, oldSubnet := range oldSubnets {
if subnetsAreEqual(oldSubnet, newSubnet) {
if oldSubnet.String() == newSubnet.String() {
found = true
break
}
@ -26,11 +26,11 @@ func findSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IP
return subnetsToAdd
}
func findSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove []net.IPNet) {
func findSubnetsToRemove(oldSubnets, newSubnets []netip.Prefix) (subnetsToRemove []netip.Prefix) {
for _, oldSubnet := range oldSubnets {
found := false
for _, newSubnet := range newSubnets {
if subnetsAreEqual(oldSubnet, newSubnet) {
if oldSubnet.String() == newSubnet.String() {
found = true
break
}
@ -42,10 +42,10 @@ func findSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove []
return subnetsToRemove
}
func RemoveSubnetFromSubnets(subnets []net.IPNet, subnet net.IPNet) []net.IPNet {
func RemoveSubnetFromSubnets(subnets []netip.Prefix, subnet netip.Prefix) []netip.Prefix {
L := len(subnets)
for i := range subnets {
if subnetsAreEqual(subnet, subnets[i]) {
if subnet.String() == subnets[i].String() {
subnets[i] = subnets[L-1]
subnets = subnets[:L-1]
break

View File

@ -2,21 +2,22 @@ package wireguard
import (
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/routing"
)
func (w *Wireguard) addAddresses(link netlink.Link,
addresses []*net.IPNet) (err error) {
addresses []netip.Prefix) (err error) {
for _, ipNet := range addresses {
ipNetIsIPv6 := ipNet.IP.To4() == nil
if !*w.settings.IPv6 && ipNetIsIPv6 {
if !*w.settings.IPv6 && ipNet.Addr().Is6() {
continue
}
ipNet := ipNet
address := &netlink.Addr{
IPNet: ipNet,
IPNet: routing.NetipPrefixToIPNet(&ipNet),
}
err = w.netlink.AddrAdd(link, address)

View File

@ -2,11 +2,12 @@ package wireguard
import (
"errors"
"net"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/routing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -14,8 +15,8 @@ import (
func Test_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
ipNetOne := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
ipNetTwo := &net.IPNet{IP: net.ParseIP("::1234"), Mask: net.CIDRMask(64, 128)}
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
ipNetTwo := netip.PrefixFrom(netip.MustParseAddr("::1234"), 64)
newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs()
@ -29,20 +30,20 @@ func Test_Wireguard_addAddresses(t *testing.T) {
testCases := map[string]struct {
link netlink.Link
addrs []*net.IPNet
addrs []netip.Prefix
wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard
err error
}{
"success": {
link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}).
AddrAdd(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
Return(nil)
netLinker.EXPECT().
AddrAdd(link, &netlink.Addr{IPNet: ipNetTwo}).
AddrAdd(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}).
Return(nil).After(firstCall)
return &Wireguard{
netlink: netLinker,
@ -54,11 +55,11 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
"first add error": {
link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().
AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}).
AddrAdd(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
Return(errDummy)
return &Wireguard{
netlink: netLinker,
@ -71,14 +72,14 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
"second add error": {
link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}).
AddrAdd(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
Return(nil)
netLinker.EXPECT().
AddrAdd(link, &netlink.Addr{IPNet: ipNetTwo}).
AddrAdd(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}).
Return(errDummy).After(firstCall)
return &Wireguard{
netlink: netLinker,
@ -91,7 +92,7 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
"ignore IPv6": {
link: newLink(),
addrs: []*net.IPNet{ipNetTwo},
addrs: []netip.Prefix{ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
return &Wireguard{
settings: Settings{

View File

@ -2,6 +2,7 @@ package wireguard
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
@ -33,9 +34,8 @@ func Test_New(t *testing.T) {
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
},
Addresses: []*net.IPNet{{
IP: net.IPv4(5, 6, 7, 8),
Mask: net.IPv4Mask(255, 255, 255, 255)},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
},
FirewallMark: 100,
},
@ -50,9 +50,8 @@ func Test_New(t *testing.T) {
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{
IP: net.IPv4(5, 6, 7, 8),
Mask: net.IPv4Mask(255, 255, 255, 255)},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
},
FirewallMark: 100,
IPv6: ptr(false),

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"regexp"
"strings"
@ -24,7 +25,7 @@ type Settings struct {
Endpoint *net.UDPAddr
// Addresses assigned to the client.
// Note IPv6 addresses are ignored if IPv6 is not supported.
Addresses []*net.IPNet
Addresses []netip.Prefix
// FirewallMark to be used in routing tables and IP rules.
// It defaults to 51820 if left to 0.
FirewallMark int
@ -77,9 +78,7 @@ var (
ErrEndpointIPMissing = errors.New("endpoint IP is missing")
ErrEndpointPortMissing = errors.New("endpoint port is missing")
ErrAddressMissing = errors.New("interface address is missing")
ErrAddressNil = errors.New("interface address is nil")
ErrAddressIPMissing = errors.New("interface address IP is missing")
ErrAddressMaskMissing = errors.New("interface address mask is missing")
ErrAddressNotValid = errors.New("interface address is not valid")
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
ErrImplementationInvalid = errors.New("invalid implementation")
)
@ -122,16 +121,9 @@ func (s *Settings) Check() (err error) {
return fmt.Errorf("%w", ErrAddressMissing)
}
for i, addr := range s.Addresses {
switch {
case addr == nil:
if !addr.IsValid() {
return fmt.Errorf("%w: for address %d of %d",
ErrAddressNil, i+1, len(s.Addresses))
case addr.IP == nil:
return fmt.Errorf("%w: for address %d of %d",
ErrAddressIPMissing, i+1, len(s.Addresses))
case addr.Mask == nil:
return fmt.Errorf("%w: for address %d of %d",
ErrAddressMaskMissing, i+1, len(s.Addresses))
ErrAddressNotValid, i+1, len(s.Addresses))
}
}

View File

@ -3,6 +3,7 @@ package wireguard
import (
"errors"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
@ -177,7 +178,7 @@ func Test_Settings_Check(t *testing.T) {
},
err: ErrAddressMissing,
},
"nil address": {
"invalid address": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
@ -186,35 +187,9 @@ func Test_Settings_Check(t *testing.T) {
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{nil},
Addresses: []netip.Prefix{{}},
},
err: errors.New("interface address is nil: for address 1 of 1"),
},
"nil address IP": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{}},
},
err: errors.New("interface address IP is missing: for address 1 of 1"),
},
"nil address mask": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4)}},
},
err: errors.New("interface address mask is missing: for address 1 of 1"),
err: errors.New("interface address is not valid: for address 1 of 1"),
},
"zero firewall mark": {
settings: Settings{
@ -225,7 +200,9 @@ func Test_Settings_Check(t *testing.T) {
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
},
err: ErrFirewallMarkMissing,
},
@ -238,7 +215,9 @@ func Test_Settings_Check(t *testing.T) {
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
FirewallMark: 999,
Implementation: "x",
},
@ -253,7 +232,9 @@ func Test_Settings_Check(t *testing.T) {
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
FirewallMark: 999,
Implementation: "userspace",
},
@ -356,9 +337,9 @@ func Test_Settings_Lines(t *testing.T) {
},
FirewallMark: 999,
RulePriority: 888,
Addresses: []*net.IPNet{
{IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)},
{IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
},
IPv6: ptr(true),
Implementation: "userspace",
@ -386,9 +367,9 @@ func Test_Settings_Lines(t *testing.T) {
},
settings: Settings{
InterfaceName: "wg0",
Addresses: []*net.IPNet{
{IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)},
{IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
},
IPv6: ptr(false),
},