gluetun/internal/routing/rules_test.go

382 lines
9.3 KiB
Go

package routing
import (
"errors"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// makeNetipPrefix builds the IPv4 prefix n.n.n.0/24 from a single byte,
// so that test fixtures stay short and are easy to tell apart.
func makeNetipPrefix(n byte) netip.Prefix {
	const prefixLength = 24
	octets := [4]byte{n, n, n, 0}
	address := netip.AddrFrom4(octets)
	return netip.PrefixFrom(address, prefixLength)
}
// makeIPRule returns the rule produced by netlink.NewRule with its source,
// destination, table and priority fields set to the given arguments.
// Every other field keeps whatever value netlink.NewRule gave it.
func makeIPRule(src, dst netip.Prefix,
	table, priority int,
) netlink.Rule {
	ipRule := netlink.NewRule()
	ipRule.Src, ipRule.Dst = src, dst
	ipRule.Table, ipRule.Priority = table, priority
	return ipRule
}
func Test_Routing_addIPRule(t *testing.T) {
t.Parallel()
errDummy := errors.New("dummy error")
type ruleListCall struct {
rules []netlink.Rule
err error
}
type ruleAddCall struct {
expected bool
ruleToAdd netlink.Rule
err error
}
testCases := map[string]struct {
src netip.Prefix
dst netip.Prefix
table int
priority int
ruleList ruleListCall
ruleAdd ruleAddCall
err error
}{
"list error": {
ruleList: ruleListCall{
err: errDummy,
},
err: errors.New("listing rules: dummy error"),
},
"rule already exists": {
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
ruleList: ruleListCall{
rules: []netlink.Rule{
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
},
"add rule error": {
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
ruleAdd: ruleAddCall{
expected: true,
ruleToAdd: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
err: errDummy,
},
err: errors.New("adding ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99: dummy error"),
},
"add rule success": {
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
ruleList: ruleListCall{
rules: []netlink.Rule{
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
},
},
ruleAdd: ruleAddCall{
expected: true,
ruleToAdd: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().RuleList(netlink.FamilyAll).
Return(testCase.ruleList.rules, testCase.ruleList.err)
if testCase.ruleAdd.expected {
netLinker.EXPECT().RuleAdd(testCase.ruleAdd.ruleToAdd).
Return(testCase.ruleAdd.err)
}
r := Routing{
netLinker: netLinker,
}
err := r.addIPRule(testCase.src, testCase.dst,
testCase.table, testCase.priority)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
}
}
func Test_Routing_deleteIPRule(t *testing.T) {
t.Parallel()
errDummy := errors.New("dummy error")
type ruleListCall struct {
rules []netlink.Rule
err error
}
type ruleDelCall struct {
expected bool
ruleToDel netlink.Rule
err error
}
testCases := map[string]struct {
src netip.Prefix
dst netip.Prefix
table int
priority int
ruleList ruleListCall
ruleDel ruleDelCall
err error
}{
"list error": {
ruleList: ruleListCall{
err: errDummy,
},
err: errors.New("listing rules: dummy error"),
},
"rule delete error": {
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
ruleList: ruleListCall{
rules: []netlink.Rule{
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
ruleDel: ruleDelCall{
expected: true,
ruleToDel: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
err: errDummy,
},
err: errors.New("deleting rule ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99: dummy error"),
},
"rule deleted": {
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
ruleList: ruleListCall{
rules: []netlink.Rule{
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
ruleDel: ruleDelCall{
expected: true,
ruleToDel: makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
"rule does not exist": {
src: makeNetipPrefix(1),
dst: makeNetipPrefix(2),
table: 99,
priority: 99,
ruleList: ruleListCall{
rules: []netlink.Rule{
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
},
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().RuleList(netlink.FamilyAll).
Return(testCase.ruleList.rules, testCase.ruleList.err)
if testCase.ruleDel.expected {
netLinker.EXPECT().RuleDel(testCase.ruleDel.ruleToDel).
Return(testCase.ruleDel.err)
}
r := Routing{
netLinker: netLinker,
}
err := r.deleteIPRule(testCase.src, testCase.dst,
testCase.table, testCase.priority)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
}
}
// Test_rulesAreEqual checks the boolean returned by rulesAreEqual for pairs
// of rules that are either identical or differ in exactly one of source,
// destination, priority or table, plus the pair of two zero-value rules.
func Test_rulesAreEqual(t *testing.T) {
	t.Parallel()

	// prefix builds octet.octet.octet.octet/bits; the address is
	// deliberately left unmasked.
	prefix := func(octet byte, bits int) netip.Prefix {
		address := netip.AddrFrom4([4]byte{octet, octet, octet, octet})
		return netip.PrefixFrom(address, bits)
	}

	// newRule builds a rule literal with a /24 source and a /32 destination;
	// all fields other than the four set here stay at their zero value.
	newRule := func(srcOctet, dstOctet byte, priority, table int) netlink.Rule {
		const (
			srcBits = 24
			dstBits = 32
		)
		return netlink.Rule{
			Src:      prefix(srcOctet, srcBits),
			Dst:      prefix(dstOctet, dstBits),
			Priority: priority,
			Table:    table,
		}
	}

	testCases := map[string]struct {
		a     netlink.Rule
		b     netlink.Rule
		equal bool
	}{
		"both_empty": {
			equal: true,
		},
		"not_equal_by_src": {
			a: newRule(9, 2, 100, 101),
			b: newRule(1, 2, 100, 101),
		},
		"not_equal_by_dst": {
			a: newRule(1, 9, 100, 101),
			b: newRule(1, 2, 100, 101),
		},
		"not_equal_by_priority": {
			a: newRule(1, 2, 999, 101),
			b: newRule(1, 2, 100, 101),
		},
		"not_equal_by_table": {
			a: newRule(1, 2, 100, 999),
			b: newRule(1, 2, 100, 101),
		},
		"equal": {
			a:     newRule(1, 2, 100, 101),
			b:     newRule(1, 2, 100, 101),
			equal: true,
		},
	}

	// NOTE(review): the parallel subtest closure reads the range variables
	// directly; this relies on per-iteration loop variables (Go 1.22+) —
	// confirm against the module's go.mod.
	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			got := rulesAreEqual(tc.a, tc.b)

			assert.Equal(t, tc.equal, got)
		})
	}
}
// Test_ipPrefixesAreEqual checks the boolean returned by ipPrefixesAreEqual
// for zero-value (invalid) prefixes and for valid prefixes that match or
// differ by address, by prefix length, or by both.
func Test_ipPrefixesAreEqual(t *testing.T) {
	t.Parallel()

	// prefix builds octet.octet.octet.octet/bits; the address is
	// deliberately left unmasked.
	prefix := func(octet byte, bits int) netip.Prefix {
		address := netip.AddrFrom4([4]byte{octet, octet, octet, octet})
		return netip.PrefixFrom(address, bits)
	}

	var (
		ones24 = prefix(1, 24)
		ones32 = prefix(1, 32)
		twos24 = prefix(2, 24)
		twos32 = prefix(2, 32)
	)

	testCases := map[string]struct {
		a     netip.Prefix
		b     netip.Prefix
		equal bool
	}{
		"both_not_valid": {
			equal: true,
		},
		"first_not_valid": {
			b: ones24,
		},
		"second_not_valid": {
			a: ones24,
		},
		"both_equal": {
			a:     ones24,
			b:     ones24,
			equal: true,
		},
		"both_not_equal_by_IP": {
			a: ones24,
			b: twos24,
		},
		"both_not_equal_by_bits": {
			a: ones24,
			b: ones32,
		},
		"both_not_equal_by_IP_and_bits": {
			a: ones24,
			b: twos32,
		},
	}

	// NOTE(review): the parallel subtest closure reads the range variables
	// directly; this relies on per-iteration loop variables (Go 1.22+) —
	// confirm against the module's go.mod.
	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			got := ipPrefixesAreEqual(tc.a, tc.b)

			assert.Equal(t, tc.equal, got)
		})
	}
}