diff --git a/internal/firewall/outboundsubnets.go b/internal/firewall/outboundsubnets.go index fdd65492..3d8711f5 100644 --- a/internal/firewall/outboundsubnets.go +++ b/internal/firewall/outboundsubnets.go @@ -25,8 +25,7 @@ func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (e c.logger.Info("setting allowed subnets through firewall...") - subnetsToAdd := subnet.FindSubnetsToAdd(c.outboundSubnets, subnets) - subnetsToRemove := subnet.FindSubnetsToRemove(c.outboundSubnets, subnets) + subnetsToAdd, subnetsToRemove := subnet.FindSubnetsToChange(c.outboundSubnets, subnets) if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 { return nil } diff --git a/internal/routing/outboundsubnets.go b/internal/routing/outboundsubnets.go index 87ebadc3..61a14f4c 100644 --- a/internal/routing/outboundsubnets.go +++ b/internal/routing/outboundsubnets.go @@ -29,8 +29,8 @@ func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet, r.stateMutex.Lock() defer r.stateMutex.Unlock() - subnetsToRemove := subnet.FindSubnetsToRemove(r.outboundSubnets, outboundSubnets) - subnetsToAdd := subnet.FindSubnetsToAdd(r.outboundSubnets, outboundSubnets) + subnetsToAdd, subnetsToRemove := subnet.FindSubnetsToChange( + r.outboundSubnets, outboundSubnets) if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 { return nil diff --git a/internal/subnet/subnets.go b/internal/subnet/subnets.go index f857d529..c58a31f1 100644 --- a/internal/subnet/subnets.go +++ b/internal/subnet/subnets.go @@ -4,7 +4,13 @@ import ( "net" ) -func FindSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IPNet) { +func FindSubnetsToChange(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd, subnetsToRemove []net.IPNet) { + subnetsToAdd = findSubnetsToAdd(oldSubnets, newSubnets) + subnetsToRemove = findSubnetsToRemove(oldSubnets, newSubnets) + return subnetsToAdd, subnetsToRemove +} + +func findSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IPNet) { for _, newSubnet := range newSubnets { found := false for _, oldSubnet := range oldSubnets { @@ -20,7 +26,7 @@ func FindSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IP return subnetsToAdd } -func FindSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove []net.IPNet) { +func findSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove []net.IPNet) { for _, oldSubnet := range oldSubnets { found := false for _, newSubnet := range newSubnets {