From d21a94377916c7ce2f3b415f76f7b7d626a9dbcc Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 27 Apr 2023 13:41:05 +0000 Subject: [PATCH] chore(all): use `netip.Prefix` for ip networks - remove usage of `net.IPNet` - remove usage of `netaddr.IPPrefix` --- internal/configuration/settings/firewall.go | 10 +-- .../configuration/settings/helpers/copy.go | 36 -------- .../configuration/settings/helpers/merge.go | 26 ------ .../settings/helpers/override.go | 9 -- internal/configuration/settings/wireguard.go | 15 ++-- .../configuration/sources/env/firewall.go | 14 ++- .../configuration/sources/env/wireguard.go | 10 +-- internal/firewall/enable.go | 4 +- internal/firewall/firewall.go | 4 +- internal/firewall/iptables.go | 12 +-- internal/firewall/outboundsubnets.go | 10 +-- internal/netlink/ipnet.go | 11 --- internal/provider/utils/wireguard.go | 12 +-- internal/provider/utils/wireguard_test.go | 11 +-- internal/routing/conversion.go | 33 +++++++ internal/routing/conversion_test.go | 52 +++++++++++ internal/routing/inbound.go | 36 +++++--- internal/routing/ip.go | 32 +++++-- internal/routing/ip_test.go | 32 +++++++ internal/routing/local.go | 13 +-- internal/routing/outbound.go | 14 +-- internal/routing/routes.go | 9 +- internal/routing/routing.go | 4 +- internal/routing/rules.go | 15 ++-- internal/routing/rules_test.go | 88 +++++++++---------- internal/subnet/equal.go | 7 -- internal/subnet/subsets.go | 16 ++-- internal/wireguard/address.go | 11 +-- internal/wireguard/address_test.go | 27 +++--- internal/wireguard/constructor_test.go | 11 ++- internal/wireguard/settings.go | 18 ++-- internal/wireguard/settings_test.go | 57 ++++-------- 32 files changed, 344 insertions(+), 315 deletions(-) delete mode 100644 internal/netlink/ipnet.go create mode 100644 internal/routing/conversion.go create mode 100644 internal/routing/conversion_test.go delete mode 100644 internal/subnet/equal.go diff --git a/internal/configuration/settings/firewall.go b/internal/configuration/settings/firewall.go index b8a907dc..e5786a59 100644 --- a/internal/configuration/settings/firewall.go +++ b/internal/configuration/settings/firewall.go @@ -2,7 +2,7 @@ package settings import ( "fmt" - "net" + "net/netip" "github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gotree" @@ -12,7 +12,7 @@ import ( type Firewall struct { VPNInputPorts []uint16 InputPorts []uint16 - OutboundSubnets []net.IPNet + OutboundSubnets []netip.Prefix Enabled *bool Debug *bool } @@ -42,7 +42,7 @@ func (f *Firewall) copy() (copied Firewall) { return Firewall{ VPNInputPorts: helpers.CopyUint16Slice(f.VPNInputPorts), InputPorts: helpers.CopyUint16Slice(f.InputPorts), - OutboundSubnets: helpers.CopyIPNetSlice(f.OutboundSubnets), + OutboundSubnets: helpers.CopyNetipPrefixesSlice(f.OutboundSubnets), Enabled: helpers.CopyBoolPtr(f.Enabled), Debug: helpers.CopyBoolPtr(f.Debug), } @@ -55,7 +55,7 @@ func (f *Firewall) copy() (copied Firewall) { func (f *Firewall) mergeWith(other Firewall) { f.VPNInputPorts = helpers.MergeUint16Slices(f.VPNInputPorts, other.VPNInputPorts) f.InputPorts = helpers.MergeUint16Slices(f.InputPorts, other.InputPorts) - f.OutboundSubnets = helpers.MergeIPNetsSlices(f.OutboundSubnets, other.OutboundSubnets) + f.OutboundSubnets = helpers.MergeNetipPrefixesSlices(f.OutboundSubnets, other.OutboundSubnets) f.Enabled = helpers.MergeWithBool(f.Enabled, other.Enabled) f.Debug = helpers.MergeWithBool(f.Debug, other.Debug) } @@ -66,7 +66,7 @@ func (f *Firewall) mergeWith(other Firewall) { func (f *Firewall) overrideWith(other Firewall) { f.VPNInputPorts = helpers.OverrideWithUint16Slice(f.VPNInputPorts, other.VPNInputPorts) f.InputPorts = helpers.OverrideWithUint16Slice(f.InputPorts, other.InputPorts) - f.OutboundSubnets = helpers.OverrideWithIPNetsSlice(f.OutboundSubnets, other.OutboundSubnets) + f.OutboundSubnets = helpers.OverrideWithNetipPrefixesSlice(f.OutboundSubnets, other.OutboundSubnets) f.Enabled = helpers.OverrideWithBool(f.Enabled, other.Enabled) f.Debug = helpers.OverrideWithBool(f.Debug, other.Debug) } diff --git a/internal/configuration/settings/helpers/copy.go b/internal/configuration/settings/helpers/copy.go index 9047666e..5f9cd42c 100644 --- a/internal/configuration/settings/helpers/copy.go +++ b/internal/configuration/settings/helpers/copy.go @@ -90,30 +90,6 @@ func CopyIP(original net.IP) (copied net.IP) { return copied } -func CopyIPNet(original net.IPNet) (copied net.IPNet) { - if original.IP != nil { - copied.IP = make(net.IP, len(original.IP)) - copy(copied.IP, original.IP) - } - - if original.Mask != nil { - copied.Mask = make(net.IPMask, len(original.Mask)) - copy(copied.Mask, original.Mask) - } - - return copied -} - -func CopyIPNetPtr(original *net.IPNet) (copied *net.IPNet) { - if original == nil { - return nil - } - - copied = new(net.IPNet) - *copied = CopyIPNet(*original) - return copied -} - func CopyNetipAddress(original netip.Addr) (copied netip.Addr) { // AsSlice creates a new byte slice so no need to copy the bytes. bytes := original.AsSlice() @@ -158,18 +134,6 @@ func CopyUint16Slice(original []uint16) (copied []uint16) { return copied } -func CopyIPNetSlice(original []net.IPNet) (copied []net.IPNet) { - if original == nil { - return nil - } - - copied = make([]net.IPNet, len(original)) - for i := range original { - copied[i] = CopyIPNet(original[i]) - } - return copied -} - func CopyNetipPrefixesSlice(original []netip.Prefix) (copied []netip.Prefix) { if original == nil { return nil diff --git a/internal/configuration/settings/helpers/merge.go b/internal/configuration/settings/helpers/merge.go index 81b9087a..3d5cad5c 100644 --- a/internal/configuration/settings/helpers/merge.go +++ b/internal/configuration/settings/helpers/merge.go @@ -187,32 +187,6 @@ func MergeUint16Slices(a, b []uint16) (result []uint16) { return result } -func MergeIPNetsSlices(a, b []net.IPNet) (result []net.IPNet) { - if a == nil && b == nil { - return nil - } - - seen := make(map[string]struct{}, len(a)+len(b)) - result = make([]net.IPNet, 0, len(a)+len(b)) - for _, ipNet := range a { - key := ipNet.String() - if _, ok := seen[key]; ok { - continue // duplicate - } - result = append(result, ipNet) - seen[key] = struct{}{} - } - for _, ipNet := range b { - key := ipNet.String() - if _, ok := seen[key]; ok { - continue // duplicate - } - result = append(result, ipNet) - seen[key] = struct{}{} - } - return result -} - func MergeNetipAddressesSlices(a, b []netip.Addr) (result []netip.Addr) { if a == nil && b == nil { return nil diff --git a/internal/configuration/settings/helpers/override.go b/internal/configuration/settings/helpers/override.go index 165b9a16..e75d86e1 100644 --- a/internal/configuration/settings/helpers/override.go +++ b/internal/configuration/settings/helpers/override.go @@ -145,15 +145,6 @@ func OverrideWithUint16Slice(existing, other []uint16) (result []uint16) { return result } -func OverrideWithIPNetsSlice(existing, other []net.IPNet) (result []net.IPNet) { - if other == nil { - return existing - } - result = make([]net.IPNet, len(other)) - copy(result, other) - return result -} - func OverrideWithNetipAddressesSlice(existing, other []netip.Addr) (result []netip.Addr) { if other == nil { return existing diff --git a/internal/configuration/settings/wireguard.go b/internal/configuration/settings/wireguard.go index 82a92e50..82a23fb8 100644 --- a/internal/configuration/settings/wireguard.go +++ b/internal/configuration/settings/wireguard.go @@ -2,7 +2,7 @@ package settings import ( "fmt" - "net" + "net/netip" "regexp" "github.com/qdm12/gluetun/internal/configuration/settings/helpers" @@ -22,7 +22,7 @@ type Wireguard struct { // It cannot be nil in the internal state. PreSharedKey *string // Addresses are the Wireguard interface addresses. - Addresses []net.IPNet + Addresses []netip.Prefix // Interface is the name of the Wireguard interface // to create. It cannot be the empty string in the // internal state. @@ -78,13 +78,12 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error) return fmt.Errorf("%w", ErrWireguardInterfaceAddressNotSet) } for i, ipNet := range w.Addresses { - if ipNet.IP == nil || ipNet.Mask == nil { + if !ipNet.IsValid() { return fmt.Errorf("%w: for address at index %d: %s", ErrWireguardInterfaceAddressNotSet, i, ipNet.String()) } - ipv6Net := ipNet.IP.To4() == nil - if ipv6Net && !ipv6Supported { + if !ipv6Supported && ipNet.Addr().Is6() { return fmt.Errorf("%w: address %s", ErrWireguardInterfaceAddressIPv6, ipNet) } @@ -109,7 +108,7 @@ func (w *Wireguard) copy() (copied Wireguard) { return Wireguard{ PrivateKey: helpers.CopyStringPtr(w.PrivateKey), PreSharedKey: helpers.CopyStringPtr(w.PreSharedKey), - Addresses: helpers.CopyIPNetSlice(w.Addresses), + Addresses: helpers.CopyNetipPrefixesSlice(w.Addresses), Interface: w.Interface, Implementation: w.Implementation, } @@ -118,7 +117,7 @@ func (w *Wireguard) copy() (copied Wireguard) { func (w *Wireguard) mergeWith(other Wireguard) { w.PrivateKey = helpers.MergeWithStringPtr(w.PrivateKey, other.PrivateKey) w.PreSharedKey = helpers.MergeWithStringPtr(w.PreSharedKey, other.PreSharedKey) - w.Addresses = helpers.MergeIPNetsSlices(w.Addresses, other.Addresses) + w.Addresses = helpers.MergeNetipPrefixesSlices(w.Addresses, other.Addresses) w.Interface = helpers.MergeWithString(w.Interface, other.Interface) w.Implementation = helpers.MergeWithString(w.Implementation, other.Implementation) } @@ -126,7 +125,7 @@ func (w *Wireguard) mergeWith(other Wireguard) { func (w *Wireguard) overrideWith(other Wireguard) { w.PrivateKey = helpers.OverrideWithStringPtr(w.PrivateKey, other.PrivateKey) w.PreSharedKey = helpers.OverrideWithStringPtr(w.PreSharedKey, other.PreSharedKey) - w.Addresses = helpers.OverrideWithIPNetsSlice(w.Addresses, other.Addresses) + w.Addresses = helpers.OverrideWithNetipPrefixesSlice(w.Addresses, other.Addresses) w.Interface = helpers.OverrideWithString(w.Interface, other.Interface) w.Implementation = helpers.OverrideWithString(w.Implementation, other.Implementation) } diff --git a/internal/configuration/sources/env/firewall.go b/internal/configuration/sources/env/firewall.go index 3813104f..5c9e625a 100644 --- a/internal/configuration/sources/env/firewall.go +++ b/internal/configuration/sources/env/firewall.go @@ -3,7 +3,7 @@ package env import ( "errors" "fmt" - "net" + "net/netip" "strconv" "github.com/qdm12/gluetun/internal/configuration/settings" @@ -24,7 +24,7 @@ func (s *Source) readFirewall() (firewall settings.Firewall, err error) { outboundSubnetsKey, _ := s.getEnvWithRetro("FIREWALL_OUTBOUND_SUBNETS", "EXTRA_SUBNETS") outboundSubnetStrings := envToCSV(outboundSubnetsKey) - firewall.OutboundSubnets, err = stringsToIPNets(outboundSubnetStrings) + firewall.OutboundSubnets, err = stringsToNetipPrefixes(outboundSubnetStrings) if err != nil { return firewall, fmt.Errorf("environment variable %s: %w", outboundSubnetsKey, err) } @@ -65,18 +65,16 @@ func stringsToPorts(ss []string) (ports []uint16, err error) { return ports, nil } -func stringsToIPNets(ss []string) (ipNets []net.IPNet, err error) { +func stringsToNetipPrefixes(ss []string) (ipPrefixes []netip.Prefix, err error) { if len(ss) == 0 { return nil, nil } - ipNets = make([]net.IPNet, len(ss)) + ipPrefixes = make([]netip.Prefix, len(ss)) for i, s := range ss { - ip, ipNet, err := net.ParseCIDR(s) + ipPrefixes[i], err = netip.ParsePrefix(s) if err != nil { return nil, fmt.Errorf("parsing IP network %q: %w", s, err) } - ipNet.IP = ip - ipNets[i] = *ipNet } - return ipNets, nil + return ipPrefixes, nil } diff --git a/internal/configuration/sources/env/wireguard.go b/internal/configuration/sources/env/wireguard.go index a8ae7c97..80ab34d3 100644 --- a/internal/configuration/sources/env/wireguard.go +++ b/internal/configuration/sources/env/wireguard.go @@ -2,7 +2,7 @@ package env import ( "fmt" - "net" + "net/netip" "os" "strings" @@ -24,22 +24,20 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) { return wireguard, nil } -func (s *Source) readWireguardAddresses() (addresses []net.IPNet, err error) { +func (s *Source) readWireguardAddresses() (addresses []netip.Prefix, err error) { key, addressesCSV := s.getEnvWithRetro("WIREGUARD_ADDRESSES", "WIREGUARD_ADDRESS") if addressesCSV == "" { return nil, nil } addressStrings := strings.Split(addressesCSV, ",") - addresses = make([]net.IPNet, len(addressStrings)) + addresses = make([]netip.Prefix, len(addressStrings)) for i, addressString := range addressStrings { addressString = strings.TrimSpace(addressString) - ip, ipNet, err := net.ParseCIDR(addressString) + addresses[i], err = netip.ParsePrefix(addressString) if err != nil { return nil, fmt.Errorf("environment variable %s: %w", key, err) } - ipNet.IP = ip - addresses[i] = *ipNet } return addresses, nil diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index 1e625a62..47305bd0 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -98,7 +98,7 @@ func (c *Config) enable(ctx context.Context) (err error) { } for _, network := range c.localNetworks { - if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, *network.IPNet, remove); err != nil { + if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, network.IPNet, remove); err != nil { return err } if err = c.acceptIpv6MulticastOutput(ctx, network.InterfaceName, remove); err != nil { @@ -113,7 +113,7 @@ func (c *Config) enable(ctx context.Context) (err error) { // Allows packets from any IP address to go through eth0 / local network // to reach Gluetun. for _, network := range c.localNetworks { - if err := c.acceptInputToSubnet(ctx, network.InterfaceName, *network.IPNet, remove); err != nil { + if err := c.acceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil { return err } } diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 60c92a3b..26e7f10d 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -2,7 +2,7 @@ package firewall import ( "context" - "net" + "net/netip" "sync" "github.com/qdm12/gluetun/internal/models" @@ -27,7 +27,7 @@ type Config struct { //nolint:maligned enabled bool vpnConnection models.Connection vpnIntf string - outboundSubnets []net.IPNet + outboundSubnets []netip.Prefix allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping stateMutex sync.Mutex } diff --git a/internal/firewall/iptables.go b/internal/firewall/iptables.go index 06c5bfc8..4e9b56a4 100644 --- a/internal/firewall/iptables.go +++ b/internal/firewall/iptables.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "net/netip" "os" "os/exec" "strings" @@ -108,9 +109,8 @@ func (c *Config) acceptInputThroughInterface(ctx context.Context, intf string, r )) } -func (c *Config) acceptInputToSubnet(ctx context.Context, intf string, destination net.IPNet, remove bool) error { - isIP4Subnet := destination.IP.To4() != nil - +func (c *Config) acceptInputToSubnet(ctx context.Context, intf string, + destination netip.Prefix, remove bool) error { interfaceFlag := "-i " + intf if intf == "*" { // all interfaces interfaceFlag = "" @@ -119,7 +119,7 @@ func (c *Config) acceptInputToSubnet(ctx context.Context, intf string, destinati instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, destination.String()) - if isIP4Subnet { + if destination.Addr().Is4() { return c.runIptablesInstruction(ctx, instruction) } if c.ip6Tables == "" { @@ -157,8 +157,8 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context, // Thanks to @npawelek. func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context, - intf string, sourceIP net.IP, destinationSubnet net.IPNet, remove bool) error { - doIPv4 := sourceIP.To4() != nil && destinationSubnet.IP.To4() != nil + intf string, sourceIP net.IP, destinationSubnet netip.Prefix, remove bool) error { + doIPv4 := sourceIP.To4() != nil && destinationSubnet.Addr().Is4() interfaceFlag := "-o " + intf if intf == "*" { // all interfaces diff --git a/internal/firewall/outboundsubnets.go b/internal/firewall/outboundsubnets.go index 0851857b..b65e9c50 100644 --- a/internal/firewall/outboundsubnets.go +++ b/internal/firewall/outboundsubnets.go @@ -3,18 +3,18 @@ package firewall import ( "context" "fmt" - "net" + "net/netip" "github.com/qdm12/gluetun/internal/subnet" ) -func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (err error) { +func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []netip.Prefix) (err error) { c.stateMutex.Lock() defer c.stateMutex.Unlock() if !c.enabled { c.logger.Info("firewall disabled, only updating allowed subnets internal list") - c.outboundSubnets = make([]net.IPNet, len(subnets)) + c.outboundSubnets = make([]netip.Prefix, len(subnets)) copy(c.outboundSubnets, subnets) return nil } @@ -34,7 +34,7 @@ func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (e return nil } -func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet) { +func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []netip.Prefix) { const remove = true for _, subNet := range subnets { for _, defaultRoute := range c.defaultRoutes { @@ -49,7 +49,7 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet) } } -func (c *Config) addOutboundSubnets(ctx context.Context, subnets []net.IPNet) error { +func (c *Config) addOutboundSubnets(ctx context.Context, subnets []netip.Prefix) error { const remove = false for _, subnet := range subnets { for _, defaultRoute := range c.defaultRoutes { diff --git a/internal/netlink/ipnet.go b/internal/netlink/ipnet.go deleted file mode 100644 index fca0b804..00000000 --- a/internal/netlink/ipnet.go +++ /dev/null @@ -1,11 +0,0 @@ -package netlink - -import ( - "net" - - "github.com/vishvananda/netlink" -) - -func NewIPNet(ip net.IP) *net.IPNet { - return netlink.NewIPNet(ip) -} diff --git a/internal/provider/utils/wireguard.go b/internal/provider/utils/wireguard.go index 0b888266..d1afa948 100644 --- a/internal/provider/utils/wireguard.go +++ b/internal/provider/utils/wireguard.go @@ -2,6 +2,7 @@ package utils import ( "net" + "net/netip" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/models" @@ -25,17 +26,12 @@ func BuildWireguardSettings(connection models.Connection, copy(settings.Endpoint.IP, connection.IP) settings.Endpoint.Port = int(connection.Port) - settings.Addresses = make([]*net.IPNet, 0, len(userSettings.Addresses)) + settings.Addresses = make([]netip.Prefix, 0, len(userSettings.Addresses)) for _, address := range userSettings.Addresses { - ipv6Address := address.IP.To4() == nil - if !ipv6Supported && ipv6Address { + if !ipv6Supported && address.Addr().Is6() { continue } - addressCopy := new(net.IPNet) - addressCopy.IP = make(net.IP, len(address.IP)) - copy(addressCopy.IP, address.IP) - addressCopy.Mask = make(net.IPMask, len(address.Mask)) - copy(addressCopy.Mask, address.Mask) + addressCopy := netip.PrefixFrom(address.Addr(), address.Bits()) settings.Addresses = append(settings.Addresses, addressCopy) } diff --git a/internal/provider/utils/wireguard_test.go b/internal/provider/utils/wireguard_test.go index fc2fb2d8..6ec8026f 100644 --- a/internal/provider/utils/wireguard_test.go +++ b/internal/provider/utils/wireguard_test.go @@ -2,6 +2,7 @@ package utils import ( "net" + "net/netip" "testing" "github.com/qdm12/gluetun/internal/configuration/settings" @@ -30,9 +31,9 @@ func Test_BuildWireguardSettings(t *testing.T) { userSettings: settings.Wireguard{ PrivateKey: stringPtr("private"), PreSharedKey: stringPtr("pre-shared"), - Addresses: []net.IPNet{ - {IP: net.IPv4(1, 1, 1, 1), Mask: net.IPv4Mask(255, 255, 255, 255)}, - {IP: net.IPv6zero, Mask: net.IPv4Mask(255, 255, 255, 255)}, + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32), + netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32), }, Interface: "wg1", }, @@ -46,8 +47,8 @@ func Test_BuildWireguardSettings(t *testing.T) { IP: net.IPv4(1, 2, 3, 4), Port: 51821, }, - Addresses: []*net.IPNet{ - {IP: net.IPv4(1, 1, 1, 1), Mask: net.IPv4Mask(255, 255, 255, 255)}, + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32), }, RulePriority: 101, IPv6: boolPtr(false), diff --git a/internal/routing/conversion.go b/internal/routing/conversion.go new file mode 100644 index 00000000..439a114b --- /dev/null +++ b/internal/routing/conversion.go @@ -0,0 +1,33 @@ +package routing + +import ( + "fmt" + "net" + "net/netip" +) + +func NetipPrefixToIPNet(prefix *netip.Prefix) (ipNet *net.IPNet) { + if prefix == nil { + return nil + } + + s := prefix.String() + ip, ipNet, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + ipNet.IP = ip + return ipNet +} + +func netIPNetToNetipPrefix(ipNet net.IPNet) (prefix netip.Prefix) { + return netip.MustParsePrefix(ipNet.String()) +} + +func netIPToNetipAddress(ip net.IP) (address netip.Addr) { + address, ok := netip.AddrFromSlice(ip) + if !ok { + panic(fmt.Sprintf("converting %#v to netip.Addr failed", ip)) + } + return address +} diff --git a/internal/routing/conversion_test.go b/internal/routing/conversion_test.go new file mode 100644 index 00000000..a30dfbb5 --- /dev/null +++ b/internal/routing/conversion_test.go @@ -0,0 +1,52 @@ +package routing + +import ( + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_netIPToNetipAddress(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + ip net.IP + address netip.Addr + panicMessage string + }{ + "nil ip": { + panicMessage: "converting net.IP(nil) to netip.Addr failed", + }, + "IPv4": { + ip: net.IPv4(1, 2, 3, 4), + address: netip.AddrFrom4([4]byte{1, 2, 3, 4}), + }, + "IPv6": { + ip: net.IPv6zero, + address: netip.AddrFrom16([16]byte{}), + }, + "IPv4 prefixed with 0xffff": { + ip: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 1, 2, 3, 4}, + address: netip.AddrFrom4([4]byte{1, 2, 3, 4}), + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + if testCase.panicMessage != "" { + assert.PanicsWithValue(t, testCase.panicMessage, func() { + netIPToNetipAddress(testCase.ip) + }) + return + } + + address := netIPToNetipAddress(testCase.ip) + assert.Equal(t, testCase.address, address) + }) + } +} diff --git a/internal/routing/inbound.go b/internal/routing/inbound.go index 63795743..3814f858 100644 --- a/internal/routing/inbound.go +++ b/internal/routing/inbound.go @@ -2,7 +2,7 @@ package routing import ( "fmt" - "net" + "net/netip" "github.com/qdm12/gluetun/internal/netlink" ) @@ -17,8 +17,9 @@ func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err err return fmt.Errorf("adding rule: %w", err) } - defaultDestinationIPv4 := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)} - defaultDestinationIPv6 := net.IPNet{IP: net.IPv6zero, Mask: net.IPMask(net.IPv6zero)} + const bits = 0 + defaultDestinationIPv4 := netip.PrefixFrom(netip.AddrFrom4([4]byte{}), bits) + defaultDestinationIPv6 := netip.PrefixFrom(netip.AddrFrom16([16]byte{}), bits) for _, defaultRoute := range defaultRoutes { defaultDestination := defaultDestinationIPv4 @@ -36,8 +37,9 @@ func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err err } func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err error) { - defaultDestinationIPv4 := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)} - defaultDestinationIPv6 := net.IPNet{IP: net.IPv6zero, Mask: net.IPMask(net.IPv6zero)} + const bits = 0 + defaultDestinationIPv4 := netip.PrefixFrom(netip.AddrFrom4([4]byte{}), bits) + defaultDestinationIPv6 := netip.PrefixFrom(netip.AddrFrom16([16]byte{}), bits) for _, defaultRoute := range defaultRoutes { defaultDestination := defaultDestinationIPv4 @@ -60,9 +62,16 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) { for _, defaultRoute := range defaultRoutes { - defaultIPMasked32 := netlink.NewIPNet(defaultRoute.AssignedIP) - ruleDstNet := (*net.IPNet)(nil) - err = r.addIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority) + assignedIP := netIPToNetipAddress(defaultRoute.AssignedIP) + bits := 32 + if assignedIP.Is6() { + bits = 128 + } + r.logger.Debug(fmt.Sprintf("ASSIGNED IP IS %#v -> %s, bits %d", + defaultRoute.AssignedIP, assignedIP, bits)) + defaultIPMasked := netip.PrefixFrom(assignedIP, bits) + ruleDstNet := (*netip.Prefix)(nil) + err = r.addIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority) if err != nil { return fmt.Errorf("adding rule for default route %s: %w", defaultRoute, err) } @@ -73,9 +82,14 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) { for _, defaultRoute := range defaultRoutes { - defaultIPMasked32 := netlink.NewIPNet(defaultRoute.AssignedIP) - ruleDstNet := (*net.IPNet)(nil) - err = r.deleteIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority) + assignedIP := netIPToNetipAddress(defaultRoute.AssignedIP) + bits := 32 + if assignedIP.Is6() { + bits = 128 + } + defaultIPMasked := netip.PrefixFrom(assignedIP, bits) + ruleDstNet := (*netip.Prefix)(nil) + err = r.deleteIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority) if err != nil { return fmt.Errorf("deleting rule for default route %s: %w", defaultRoute, err) } diff --git a/internal/routing/ip.go b/internal/routing/ip.go index bf488298..4b7415ec 100644 --- a/internal/routing/ip.go +++ b/internal/routing/ip.go @@ -22,6 +22,18 @@ func ipMatchesFamily(ip net.IP, family int) bool { (family == netlink.FAMILY_V4 && ip.To4() != nil) } +func ensureNoIPv6WrappedIPv4(candidateIP net.IP) (resultIP net.IP) { + const ipv4Size = 4 + if candidateIP.To4() == nil || len(candidateIP) == ipv4Size { // ipv6 or ipv4 + return candidateIP + } + + // ipv6-wrapped ipv4 + resultIP = make(net.IP, ipv4Size) + copy(resultIP, candidateIP[12:16]) + return resultIP +} + func (r *Routing) assignedIP(interfaceName string, family int) (ip net.IP, err error) { iface, err := net.InterfaceByName(interfaceName) if err != nil { @@ -34,14 +46,22 @@ func (r *Routing) assignedIP(interfaceName string, family int) (ip net.IP, err e for _, address := range addresses { switch value := address.(type) { case *net.IPAddr: - if ipMatchesFamily(value.IP, family) { - return value.IP, nil - } + ip = value.IP case *net.IPNet: - if ipMatchesFamily(value.IP, family) { - return value.IP, nil - } + ip = value.IP + default: + continue } + + if !ipMatchesFamily(ip, family) { + continue + } + + // Ensure we don't return an IPv6-wrapped IPv4 address + // since netip.Address String method works differently than + // net.IP String method for this kind of addresses. + ip = ensureNoIPv6WrappedIPv4(ip) + return ip, nil } return nil, fmt.Errorf("%w: interface %s in %d addresses", errInterfaceIPNotFound, interfaceName, len(addresses)) diff --git a/internal/routing/ip_test.go b/internal/routing/ip_test.go index 1f89d881..ed2783ee 100644 --- a/internal/routing/ip_test.go +++ b/internal/routing/ip_test.go @@ -96,3 +96,35 @@ func Test_IPIsPrivate(t *testing.T) { }) } } + +func Test_ensureNoIPv6WrappedIPv4(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + candidateIP net.IP + resultIP net.IP + }{ + "nil": {}, + "ipv6": { + candidateIP: net.IPv6loopback, + resultIP: net.IPv6loopback, + }, + "ipv4": { + candidateIP: net.IP{1, 2, 3, 4}, + resultIP: net.IP{1, 2, 3, 4}, + }, + "ipv6_wrapped_ipv4": { + candidateIP: net.IPv4(1, 2, 3, 4), + resultIP: net.IP{1, 2, 3, 4}, + }, + } + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + resultIP := ensureNoIPv6WrappedIPv4(testCase.candidateIP) + assert.Equal(t, testCase.resultIP, resultIP) + }) + } +} diff --git a/internal/routing/local.go b/internal/routing/local.go index e67fa6f9..14072f68 100644 --- a/internal/routing/local.go +++ b/internal/routing/local.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net" + "net/netip" "github.com/qdm12/gluetun/internal/netlink" ) @@ -15,7 +16,7 @@ var ( ) type LocalNetwork struct { - IPNet *net.IPNet + IPNet netip.Prefix InterfaceName string IP net.IP } @@ -55,7 +56,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { var localNet LocalNetwork - localNet.IPNet = route.Dst + localNet.IPNet = netIPNetToNetipPrefix(*route.Dst) r.logger.Info("local ipnet found: " + localNet.IPNet.String()) link, err := r.netLinker.LinkByIndex(route.LinkIndex) @@ -66,7 +67,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { localNet.InterfaceName = link.Attrs().Name family := netlink.FAMILY_V6 - if localNet.IPNet.IP.To4() != nil { + if localNet.IPNet.Addr().Is4() { family = netlink.FAMILY_V4 } ip, err := r.assignedIP(localNet.InterfaceName, family) @@ -87,7 +88,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { } func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) { - for _, net := range subnets { + for _, subnet := range subnets { // The main table is a built-in value for Linux, see "man 8 ip-route" const mainTable = 254 @@ -96,9 +97,9 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) { const localPriority = 98 // Main table was setup correctly by Docker, just need to add rules to use it - err = r.addIPRule(nil, net.IPNet, mainTable, localPriority) + err = r.addIPRule(nil, &subnet.IPNet, mainTable, localPriority) if err != nil { - return fmt.Errorf("adding rule: %v: %w", net.IPNet, err) + return fmt.Errorf("adding rule: %v: %w", subnet.IPNet, err) } } return nil diff --git a/internal/routing/outbound.go b/internal/routing/outbound.go index a7a1488d..7a0531f3 100644 --- a/internal/routing/outbound.go +++ b/internal/routing/outbound.go @@ -2,7 +2,7 @@ package routing import ( "fmt" - "net" + "net/netip" "github.com/qdm12/gluetun/internal/subnet" ) @@ -12,7 +12,7 @@ const ( outboundPriority = 99 ) -func (r *Routing) SetOutboundRoutes(outboundSubnets []net.IPNet) error { +func (r *Routing) SetOutboundRoutes(outboundSubnets []netip.Prefix) error { defaultRoutes, err := r.DefaultRoutes() if err != nil { return err @@ -20,7 +20,7 @@ func (r *Routing) SetOutboundRoutes(outboundSubnets []net.IPNet) error { return r.setOutboundRoutes(outboundSubnets, defaultRoutes) } -func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet, +func (r *Routing) setOutboundRoutes(outboundSubnets []netip.Prefix, defaultRoutes []DefaultRoute) (err error) { r.stateMutex.Lock() defer r.stateMutex.Unlock() @@ -45,7 +45,7 @@ func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet, return nil } -func (r *Routing) removeOutboundSubnets(subnets []net.IPNet, +func (r *Routing) removeOutboundSubnets(subnets []netip.Prefix, defaultRoutes []DefaultRoute) (warnings []string) { for i, subNet := range subnets { for _, defaultRoute := range defaultRoutes { @@ -56,7 +56,7 @@ func (r *Routing) removeOutboundSubnets(subnets []net.IPNet, } } - ruleSrcNet := (*net.IPNet)(nil) + ruleSrcNet := (*netip.Prefix)(nil) ruleDstNet := &subnets[i] err := r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority) if err != nil { @@ -71,7 +71,7 @@ func (r *Routing) removeOutboundSubnets(subnets []net.IPNet, return warnings } -func (r *Routing) addOutboundSubnets(subnets []net.IPNet, +func (r *Routing) addOutboundSubnets(subnets []netip.Prefix, defaultRoutes []DefaultRoute) (err error) { for i, subnet := range subnets { for _, defaultRoute := range defaultRoutes { @@ -81,7 +81,7 @@ func (r *Routing) addOutboundSubnets(subnets []net.IPNet, } } - ruleSrcNet := (*net.IPNet)(nil) + ruleSrcNet := (*netip.Prefix)(nil) ruleDstNet := &subnets[i] err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority) if err != nil { diff --git a/internal/routing/routes.go b/internal/routing/routes.go index 8163d799..127ac856 100644 --- a/internal/routing/routes.go +++ b/internal/routing/routes.go @@ -3,12 +3,13 @@ package routing import ( "fmt" "net" + "net/netip" "strconv" "github.com/qdm12/gluetun/internal/netlink" ) -func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP, +func (r *Routing) addRouteVia(destination netip.Prefix, gateway net.IP, iface string, table int) error { destinationStr := destination.String() r.logger.Info("adding route for " + destinationStr) @@ -23,7 +24,7 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP, } route := netlink.Route{ - Dst: &destination, + Dst: NetipPrefixToIPNet(&destination), Gw: gateway, LinkIndex: link.Attrs().Index, Table: table, @@ -36,7 +37,7 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP, return nil } -func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP, +func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway net.IP, iface string, table int) (err error) { destinationStr := destination.String() r.logger.Info("deleting route for " + destinationStr) @@ -51,7 +52,7 @@ func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP, } route := netlink.Route{ - Dst: &destination, + Dst: NetipPrefixToIPNet(&destination), Gw: gateway, LinkIndex: link.Attrs().Index, Table: table, diff --git a/internal/routing/routing.go b/internal/routing/routing.go index bd590976..0a254655 100644 --- a/internal/routing/routing.go +++ b/internal/routing/routing.go @@ -1,7 +1,7 @@ package routing import ( - "net" + "net/netip" "sync" "github.com/qdm12/gluetun/internal/netlink" @@ -48,7 +48,7 @@ type Linker interface { type Routing struct { netLinker NetLinker logger Logger - outboundSubnets []net.IPNet + outboundSubnets []netip.Prefix stateMutex sync.RWMutex } diff --git a/internal/routing/rules.go b/internal/routing/rules.go index 0c9da122..772c58b0 100644 --- a/internal/routing/rules.go +++ b/internal/routing/rules.go @@ -4,17 +4,18 @@ import ( "bytes" "fmt" "net" + "net/netip" "github.com/qdm12/gluetun/internal/netlink" ) -func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error { +func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error { const add = true r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority)) rule := netlink.NewRule() - rule.Src = src - rule.Dst = dst + rule.Src = NetipPrefixToIPNet(src) + rule.Dst = NetipPrefixToIPNet(dst) rule.Priority = priority rule.Table = table @@ -35,13 +36,13 @@ func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error { return nil } -func (r *Routing) deleteIPRule(src, dst *net.IPNet, table, priority int) error { +func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) error { const add = false r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority)) rule := netlink.NewRule() - rule.Src = src - rule.Dst = dst + rule.Src = NetipPrefixToIPNet(src) + rule.Dst = NetipPrefixToIPNet(dst) rule.Priority = priority rule.Table = table @@ -60,7 +61,7 @@ func (r *Routing) deleteIPRule(src, dst *net.IPNet, table, priority int) error { return nil } -func ruleDbgMsg(add bool, src, dst *net.IPNet, +func ruleDbgMsg(add bool, src, dst *netip.Prefix, table, priority int) (debugMessage string) { debugMessage = "ip rule" diff --git a/internal/routing/rules_test.go b/internal/routing/rules_test.go index e24f07e7..2d01b645 100644 --- a/internal/routing/rules_test.go +++ b/internal/routing/rules_test.go @@ -3,6 +3,7 @@ package routing import ( "errors" "net" + "net/netip" "testing" "github.com/golang/mock/gomock" @@ -11,20 +12,17 @@ import ( "github.com/stretchr/testify/require" ) -func makeIPNet(t *testing.T, n byte) *net.IPNet { - t.Helper() - return &net.IPNet{ - IP: net.IPv4(n, n, n, 0), - Mask: net.IPv4Mask(255, 255, 255, 0), - } +func makeNetipPrefix(n byte) *netip.Prefix { + const bits = 24 + prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits) + return &prefix } -func makeIPRule(t *testing.T, src, dst *net.IPNet, +func makeIPRule(src, dst *netip.Prefix, table, priority int) *netlink.Rule { - t.Helper() rule := netlink.NewRule() - rule.Src = src - rule.Dst = dst + rule.Src = NetipPrefixToIPNet(src) + rule.Dst = NetipPrefixToIPNet(dst) rule.Table = table rule.Priority = priority return rule @@ -47,8 +45,8 @@ func Test_Routing_addIPRule(t *testing.T) { } testCases := map[string]struct { - src *net.IPNet - dst *net.IPNet + src *netip.Prefix + dst *netip.Prefix table int priority int dbgMsg string @@ -64,46 +62,46 @@ func Test_Routing_addIPRule(t *testing.T) { err: errors.New("listing rules: dummy error"), }, "rule already exists": { - src: makeIPNet(t, 1), - dst: makeIPNet(t, 2), + src: makeNetipPrefix(1), + dst: makeNetipPrefix(2), table: 99, priority: 99, dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", ruleList: ruleListCall{ rules: []netlink.Rule{ - *makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99), - *makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), + *makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), + *makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), }, }, }, "add rule error": { - src: makeIPNet(t, 1), - dst: makeIPNet(t, 2), + src: makeNetipPrefix(1), + dst: makeNetipPrefix(2), table: 99, priority: 99, dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", ruleAdd: ruleAddCall{ expected: true, - ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), + ruleToAdd: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), err: errDummy, }, err: errors.New("adding rule ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99: dummy error"), }, "add rule success": { - src: makeIPNet(t, 1), - dst: makeIPNet(t, 2), + src: makeNetipPrefix(1), + dst: makeNetipPrefix(2), table: 99, priority: 99, dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", ruleList: ruleListCall{ rules: []netlink.Rule{ - *makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99), - *makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 101, 101), + *makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), + *makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101), }, }, ruleAdd: ruleAddCall{ expected: true, - ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), + ruleToAdd: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), }, }, } @@ -160,8 +158,8 @@ func Test_Routing_deleteIPRule(t *testing.T) { } testCases := map[string]struct { - src *net.IPNet - dst *net.IPNet + src *netip.Prefix + dst *netip.Prefix table int priority int dbgMsg string @@ -177,50 +175,50 @@ func Test_Routing_deleteIPRule(t *testing.T) { err: errors.New("listing rules: dummy error"), }, "rule delete error": { - src: makeIPNet(t, 1), - dst: makeIPNet(t, 2), + src: makeNetipPrefix(1), + dst: makeNetipPrefix(2), table: 99, priority: 99, dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", ruleList: ruleListCall{ rules: []netlink.Rule{ - *makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), + *makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), }, }, ruleDel: ruleDelCall{ expected: true, - ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), + ruleToDel: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), err: errDummy, }, err: errors.New("deleting rule ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99: dummy error"), }, "rule deleted": { - src: makeIPNet(t, 1), - dst: makeIPNet(t, 2), + src: makeNetipPrefix(1), + dst: makeNetipPrefix(2), table: 99, priority: 99, dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", ruleList: ruleListCall{ rules: []netlink.Rule{ - *makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99), - *makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), + *makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), + *makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), }, }, ruleDel: ruleDelCall{ expected: true, - ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), + ruleToDel: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), }, }, "rule does not exist": { - src: makeIPNet(t, 1), - dst: makeIPNet(t, 2), + src: makeNetipPrefix(1), + dst: makeNetipPrefix(2), table: 99, priority: 99, dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", ruleList: ruleListCall{ rules: []netlink.Rule{ - *makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99), - *makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 101, 101), + *makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), + *makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101), }, }, }, @@ -266,8 +264,8 @@ func Test_ruleDbgMsg(t *testing.T) { testCases := map[string]struct { add bool - src *net.IPNet - dst *net.IPNet + src *netip.Prefix + dst *netip.Prefix table int priority int dbgMsg string @@ -277,15 +275,15 @@ func Test_ruleDbgMsg(t *testing.T) { }, "add rule": { add: true, - src: makeIPNet(t, 1), - dst: makeIPNet(t, 2), + src: makeNetipPrefix(1), + dst: makeNetipPrefix(2), table: 100, priority: 101, dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101", }, "del rule": { - src: makeIPNet(t, 1), - dst: makeIPNet(t, 2), + src: makeNetipPrefix(1), + dst: makeNetipPrefix(2), table: 100, priority: 101, dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101", diff --git a/internal/subnet/equal.go b/internal/subnet/equal.go deleted file mode 100644 index 1bfb0116..00000000 --- a/internal/subnet/equal.go +++ /dev/null @@ -1,7 +0,0 @@ -package subnet - -import "net" - -func subnetsAreEqual(a, b net.IPNet) bool { - return a.IP.Equal(b.IP) && a.Mask.String() == b.Mask.String() -} diff --git a/internal/subnet/subsets.go b/internal/subnet/subsets.go index 4e89c1f0..9f0e8bb8 100644 --- a/internal/subnet/subsets.go +++ b/internal/subnet/subsets.go @@ -1,20 +1,20 @@ package subnet import ( - "net" + "net/netip" ) -func FindSubnetsToChange(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd, subnetsToRemove []net.IPNet) { +func FindSubnetsToChange(oldSubnets, newSubnets []netip.Prefix) (subnetsToAdd, subnetsToRemove []netip.Prefix) { subnetsToAdd = findSubnetsToAdd(oldSubnets, newSubnets) subnetsToRemove = findSubnetsToRemove(oldSubnets, newSubnets) return subnetsToAdd, subnetsToRemove } -func findSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IPNet) { +func findSubnetsToAdd(oldSubnets, newSubnets []netip.Prefix) (subnetsToAdd []netip.Prefix) { for _, newSubnet := range newSubnets { found := false for _, oldSubnet := range oldSubnets { - if subnetsAreEqual(oldSubnet, newSubnet) { + if oldSubnet.String() == newSubnet.String() { found = true break } @@ -26,11 +26,11 @@ func findSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IP return subnetsToAdd } -func findSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove []net.IPNet) { +func findSubnetsToRemove(oldSubnets, newSubnets []netip.Prefix) (subnetsToRemove []netip.Prefix) { for _, oldSubnet := range oldSubnets { found := false for _, newSubnet := range newSubnets { - if subnetsAreEqual(oldSubnet, newSubnet) { + if oldSubnet.String() == newSubnet.String() { found = true break } @@ -42,10 +42,10 @@ func findSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove [] return subnetsToRemove } -func RemoveSubnetFromSubnets(subnets []net.IPNet, subnet net.IPNet) []net.IPNet { +func RemoveSubnetFromSubnets(subnets []netip.Prefix, subnet netip.Prefix) []netip.Prefix { L := len(subnets) for i := range subnets { - if subnetsAreEqual(subnet, subnets[i]) { + if subnet.String() == subnets[i].String() { subnets[i] = subnets[L-1] subnets = subnets[:L-1] break diff --git a/internal/wireguard/address.go b/internal/wireguard/address.go index b6c50b44..32914e28 100644 --- a/internal/wireguard/address.go +++ b/internal/wireguard/address.go @@ -2,21 +2,22 @@ package wireguard import ( "fmt" - "net" + "net/netip" "github.com/qdm12/gluetun/internal/netlink" + "github.com/qdm12/gluetun/internal/routing" ) func (w *Wireguard) addAddresses(link netlink.Link, - addresses []*net.IPNet) (err error) { + addresses []netip.Prefix) (err error) { for _, ipNet := range addresses { - ipNetIsIPv6 := ipNet.IP.To4() == nil - if !*w.settings.IPv6 && ipNetIsIPv6 { + if !*w.settings.IPv6 && ipNet.Addr().Is6() { continue } + ipNet := ipNet address := &netlink.Addr{ - IPNet: ipNet, + IPNet: routing.NetipPrefixToIPNet(&ipNet), } err = w.netlink.AddrAdd(link, address) diff --git a/internal/wireguard/address_test.go b/internal/wireguard/address_test.go index 34ff7ae7..7382a419 100644 --- a/internal/wireguard/address_test.go +++ b/internal/wireguard/address_test.go @@ -2,11 +2,12 @@ package wireguard import ( "errors" - "net" + "net/netip" "testing" "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/netlink" + "github.com/qdm12/gluetun/internal/routing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -14,8 +15,8 @@ import ( func Test_Wireguard_addAddresses(t *testing.T) { t.Parallel() - ipNetOne := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)} - ipNetTwo := &net.IPNet{IP: net.ParseIP("::1234"), Mask: net.CIDRMask(64, 128)} + ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32) + ipNetTwo := netip.PrefixFrom(netip.MustParseAddr("::1234"), 64) newLink := func() netlink.Link { linkAttrs := netlink.NewLinkAttrs() @@ -29,20 +30,20 @@ func Test_Wireguard_addAddresses(t *testing.T) { testCases := map[string]struct { link netlink.Link - addrs []*net.IPNet + addrs []netip.Prefix wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard err error }{ "success": { link: newLink(), - addrs: []*net.IPNet{ipNetOne, ipNetTwo}, + addrs: []netip.Prefix{ipNetOne, ipNetTwo}, wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { netLinker := NewMockNetLinker(ctrl) firstCall := netLinker.EXPECT(). - AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}). + AddrAdd(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}). Return(nil) netLinker.EXPECT(). - AddrAdd(link, &netlink.Addr{IPNet: ipNetTwo}). + AddrAdd(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}). Return(nil).After(firstCall) return &Wireguard{ netlink: netLinker, @@ -54,11 +55,11 @@ func Test_Wireguard_addAddresses(t *testing.T) { }, "first add error": { link: newLink(), - addrs: []*net.IPNet{ipNetOne, ipNetTwo}, + addrs: []netip.Prefix{ipNetOne, ipNetTwo}, wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { netLinker := NewMockNetLinker(ctrl) netLinker.EXPECT(). - AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}). + AddrAdd(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}). Return(errDummy) return &Wireguard{ netlink: netLinker, @@ -71,14 +72,14 @@ func Test_Wireguard_addAddresses(t *testing.T) { }, "second add error": { link: newLink(), - addrs: []*net.IPNet{ipNetOne, ipNetTwo}, + addrs: []netip.Prefix{ipNetOne, ipNetTwo}, wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { netLinker := NewMockNetLinker(ctrl) firstCall := netLinker.EXPECT(). - AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}). + AddrAdd(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}). Return(nil) netLinker.EXPECT(). - AddrAdd(link, &netlink.Addr{IPNet: ipNetTwo}). + AddrAdd(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}). Return(errDummy).After(firstCall) return &Wireguard{ netlink: netLinker, @@ -91,7 +92,7 @@ func Test_Wireguard_addAddresses(t *testing.T) { }, "ignore IPv6": { link: newLink(), - addrs: []*net.IPNet{ipNetTwo}, + addrs: []netip.Prefix{ipNetTwo}, wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { return &Wireguard{ settings: Settings{ diff --git a/internal/wireguard/constructor_test.go b/internal/wireguard/constructor_test.go index d6ebee94..572f999c 100644 --- a/internal/wireguard/constructor_test.go +++ b/internal/wireguard/constructor_test.go @@ -2,6 +2,7 @@ package wireguard import ( "net" + "net/netip" "testing" "github.com/stretchr/testify/assert" @@ -33,9 +34,8 @@ func Test_New(t *testing.T) { Endpoint: &net.UDPAddr{ IP: net.IPv4(1, 2, 3, 4), }, - Addresses: []*net.IPNet{{ - IP: net.IPv4(5, 6, 7, 8), - Mask: net.IPv4Mask(255, 255, 255, 255)}, + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32), }, FirewallMark: 100, }, @@ -50,9 +50,8 @@ func Test_New(t *testing.T) { IP: net.IPv4(1, 2, 3, 4), Port: 51820, }, - Addresses: []*net.IPNet{{ - IP: net.IPv4(5, 6, 7, 8), - Mask: net.IPv4Mask(255, 255, 255, 255)}, + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32), }, FirewallMark: 100, IPv6: ptr(false), diff --git a/internal/wireguard/settings.go b/internal/wireguard/settings.go index 9492bfed..d33b0053 100644 --- a/internal/wireguard/settings.go +++ b/internal/wireguard/settings.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net" + "net/netip" "regexp" "strings" @@ -24,7 +25,7 @@ type Settings struct { Endpoint *net.UDPAddr // Addresses assigned to the client. // Note IPv6 addresses are ignored if IPv6 is not supported. - Addresses []*net.IPNet + Addresses []netip.Prefix // FirewallMark to be used in routing tables and IP rules. // It defaults to 51820 if left to 0. FirewallMark int @@ -77,9 +78,7 @@ var ( ErrEndpointIPMissing = errors.New("endpoint IP is missing") ErrEndpointPortMissing = errors.New("endpoint port is missing") ErrAddressMissing = errors.New("interface address is missing") - ErrAddressNil = errors.New("interface address is nil") - ErrAddressIPMissing = errors.New("interface address IP is missing") - ErrAddressMaskMissing = errors.New("interface address mask is missing") + ErrAddressNotValid = errors.New("interface address is not valid") ErrFirewallMarkMissing = errors.New("firewall mark is missing") ErrImplementationInvalid = errors.New("invalid implementation") ) @@ -122,16 +121,9 @@ func (s *Settings) Check() (err error) { return fmt.Errorf("%w", ErrAddressMissing) } for i, addr := range s.Addresses { - switch { - case addr == nil: + if !addr.IsValid() { return fmt.Errorf("%w: for address %d of %d", - ErrAddressNil, i+1, len(s.Addresses)) - case addr.IP == nil: - return fmt.Errorf("%w: for address %d of %d", - ErrAddressIPMissing, i+1, len(s.Addresses)) - case addr.Mask == nil: - return fmt.Errorf("%w: for address %d of %d", - ErrAddressMaskMissing, i+1, len(s.Addresses)) + ErrAddressNotValid, i+1, len(s.Addresses)) } } diff --git a/internal/wireguard/settings_test.go b/internal/wireguard/settings_test.go index b49dded1..b2ec904e 100644 --- a/internal/wireguard/settings_test.go +++ b/internal/wireguard/settings_test.go @@ -3,6 +3,7 @@ package wireguard import ( "errors" "net" + "net/netip" "testing" "github.com/stretchr/testify/assert" @@ -177,7 +178,7 @@ func Test_Settings_Check(t *testing.T) { }, err: ErrAddressMissing, }, - "nil address": { + "invalid address": { settings: Settings{ InterfaceName: "wg0", PrivateKey: validKey1, @@ -186,35 +187,9 @@ func Test_Settings_Check(t *testing.T) { IP: net.IPv4(1, 2, 3, 4), Port: 51820, }, - Addresses: []*net.IPNet{nil}, + Addresses: []netip.Prefix{{}}, }, - err: errors.New("interface address is nil: for address 1 of 1"), - }, - "nil address IP": { - settings: Settings{ - InterfaceName: "wg0", - PrivateKey: validKey1, - PublicKey: validKey2, - Endpoint: &net.UDPAddr{ - IP: net.IPv4(1, 2, 3, 4), - Port: 51820, - }, - Addresses: []*net.IPNet{{}}, - }, - err: errors.New("interface address IP is missing: for address 1 of 1"), - }, - "nil address mask": { - settings: Settings{ - InterfaceName: "wg0", - PrivateKey: validKey1, - PublicKey: validKey2, - Endpoint: &net.UDPAddr{ - IP: net.IPv4(1, 2, 3, 4), - Port: 51820, - }, - Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4)}}, - }, - err: errors.New("interface address mask is missing: for address 1 of 1"), + err: errors.New("interface address is not valid: for address 1 of 1"), }, "zero firewall mark": { settings: Settings{ @@ -225,7 +200,9 @@ func Test_Settings_Check(t *testing.T) { IP: net.IPv4(1, 2, 3, 4), Port: 51820, }, - Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}}, + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), + }, }, err: ErrFirewallMarkMissing, }, @@ -238,7 +215,9 @@ func Test_Settings_Check(t *testing.T) { IP: net.IPv4(1, 2, 3, 4), Port: 51820, }, - Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}}, + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), + }, FirewallMark: 999, Implementation: "x", }, @@ -253,7 +232,9 @@ func Test_Settings_Check(t *testing.T) { IP: net.IPv4(1, 2, 3, 4), Port: 51820, }, - Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}}, + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), + }, FirewallMark: 999, Implementation: "userspace", }, @@ -356,9 +337,9 @@ func Test_Settings_Lines(t *testing.T) { }, FirewallMark: 999, RulePriority: 888, - Addresses: []*net.IPNet{ - {IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)}, - {IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)}, + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), }, IPv6: ptr(true), Implementation: "userspace", @@ -386,9 +367,9 @@ func Test_Settings_Lines(t *testing.T) { }, settings: Settings{ InterfaceName: "wg0", - Addresses: []*net.IPNet{ - {IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)}, - {IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)}, + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), }, IPv6: ptr(false), },