diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 588981bd..7f52c887 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "os" + "os/exec" "os/signal" "strings" "syscall" @@ -15,6 +16,7 @@ import ( _ "github.com/breml/rootcerts" "github.com/qdm12/gluetun/internal/alpine" "github.com/qdm12/gluetun/internal/cli" + "github.com/qdm12/gluetun/internal/command" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/sources/files" "github.com/qdm12/gluetun/internal/configuration/sources/secrets" @@ -41,7 +43,6 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/unzip" "github.com/qdm12/gluetun/internal/vpn" - "github.com/qdm12/golibs/command" "github.com/qdm12/gosettings/reader" "github.com/qdm12/gosettings/reader/sources/env" "github.com/qdm12/goshutdown" @@ -78,7 +79,7 @@ func main() { netLinkDebugLogger := logger.New(log.SetComponent("netlink")) netLinker := netlink.New(netLinkDebugLogger) cli := cli.New() - cmder := command.NewCmder() + cmder := command.New() reader := reader.New(reader.Settings{ Sources: []reader.Source{ @@ -145,7 +146,7 @@ var ( //nolint:gocognit,gocyclo,maintidx func _main(ctx context.Context, buildInfo models.BuildInformation, args []string, logger log.LoggerInterface, reader *reader.Reader, - tun Tun, netLinker netLinker, cmder command.RunStarter, + tun Tun, netLinker netLinker, cmder RunStarter, cli clier) error { if len(args) > 1 { // cli operation switch args[1] { @@ -591,3 +592,9 @@ type Tun interface { Check(tunDevice string) error Create(tunDevice string) error } + +type RunStarter interface { + Run(cmd *exec.Cmd) (output string, err error) + Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string, + waitError <-chan error, err error) +} diff --git a/internal/command/cmder.go b/internal/command/cmder.go new file mode 100644 index 00000000..86c9f679 --- /dev/null +++ b/internal/command/cmder.go @@ -0,0 +1,8 @@ +package command + +// Cmder handles running subprograms synchronously and asynchronously. +type Cmder struct{} + +func New() *Cmder { + return &Cmder{} +} diff --git a/internal/command/interfaces_local.go b/internal/command/interfaces_local.go new file mode 100644 index 00000000..2da739d1 --- /dev/null +++ b/internal/command/interfaces_local.go @@ -0,0 +1,11 @@ +package command + +import "io" + +type execCmd interface { + CombinedOutput() ([]byte, error) + StdoutPipe() (io.ReadCloser, error) + StderrPipe() (io.ReadCloser, error) + Start() error + Wait() error +} diff --git a/internal/command/mocks_generate_test.go b/internal/command/mocks_generate_test.go new file mode 100644 index 00000000..61598aa4 --- /dev/null +++ b/internal/command/mocks_generate_test.go @@ -0,0 +1,3 @@ +package command + +//go:generate mockgen -destination=mocks_local_test.go -package=$GOPACKAGE -source=interfaces_local.go diff --git a/internal/command/mocks_local_test.go b/internal/command/mocks_local_test.go new file mode 100644 index 00000000..f8400ab7 --- /dev/null +++ b/internal/command/mocks_local_test.go @@ -0,0 +1,108 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: interfaces_local.go + +// Package command is a generated GoMock package. +package command + +import ( + io "io" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockexecCmd is a mock of execCmd interface. +type MockexecCmd struct { + ctrl *gomock.Controller + recorder *MockexecCmdMockRecorder +} + +// MockexecCmdMockRecorder is the mock recorder for MockexecCmd. +type MockexecCmdMockRecorder struct { + mock *MockexecCmd +} + +// NewMockexecCmd creates a new mock instance. +func NewMockexecCmd(ctrl *gomock.Controller) *MockexecCmd { + mock := &MockexecCmd{ctrl: ctrl} + mock.recorder = &MockexecCmdMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockexecCmd) EXPECT() *MockexecCmdMockRecorder { + return m.recorder +} + +// CombinedOutput mocks base method. +func (m *MockexecCmd) CombinedOutput() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CombinedOutput") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CombinedOutput indicates an expected call of CombinedOutput. +func (mr *MockexecCmdMockRecorder) CombinedOutput() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CombinedOutput", reflect.TypeOf((*MockexecCmd)(nil).CombinedOutput)) +} + +// Start mocks base method. +func (m *MockexecCmd) Start() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start") + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockexecCmdMockRecorder) Start() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockexecCmd)(nil).Start)) +} + +// StderrPipe mocks base method. +func (m *MockexecCmd) StderrPipe() (io.ReadCloser, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StderrPipe") + ret0, _ := ret[0].(io.ReadCloser) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// StderrPipe indicates an expected call of StderrPipe. +func (mr *MockexecCmdMockRecorder) StderrPipe() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StderrPipe", reflect.TypeOf((*MockexecCmd)(nil).StderrPipe)) +} + +// StdoutPipe mocks base method. +func (m *MockexecCmd) StdoutPipe() (io.ReadCloser, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StdoutPipe") + ret0, _ := ret[0].(io.ReadCloser) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// StdoutPipe indicates an expected call of StdoutPipe. +func (mr *MockexecCmdMockRecorder) StdoutPipe() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StdoutPipe", reflect.TypeOf((*MockexecCmd)(nil).StdoutPipe)) +} + +// Wait mocks base method. +func (m *MockexecCmd) Wait() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Wait") + ret0, _ := ret[0].(error) + return ret0 +} + +// Wait indicates an expected call of Wait. +func (mr *MockexecCmdMockRecorder) Wait() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Wait", reflect.TypeOf((*MockexecCmd)(nil).Wait)) +} diff --git a/internal/command/run.go b/internal/command/run.go new file mode 100644 index 00000000..59d98a6e --- /dev/null +++ b/internal/command/run.go @@ -0,0 +1,30 @@ +package command + +import ( + "os/exec" + "strings" +) + +// Run runs a command in a blocking manner, returning its output and +// an error if it failed. +func (c *Cmder) Run(cmd *exec.Cmd) (output string, err error) { + return run(cmd) +} + +func run(cmd execCmd) (output string, err error) { + stdout, err := cmd.CombinedOutput() + output = string(stdout) + output = strings.TrimSuffix(output, "\n") + lines := stringToLines(output) + for i := range lines { + lines[i] = strings.TrimPrefix(lines[i], "'") + lines[i] = strings.TrimSuffix(lines[i], "'") + } + output = strings.Join(lines, "\n") + return output, err +} + +func stringToLines(s string) (lines []string) { + s = strings.TrimSuffix(s, "\n") + return strings.Split(s, "\n") +} diff --git a/internal/command/run_test.go b/internal/command/run_test.go new file mode 100644 index 00000000..0f7892c5 --- /dev/null +++ b/internal/command/run_test.go @@ -0,0 +1,55 @@ +package command + +import ( + "errors" + "testing" + + gomock "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_run(t *testing.T) { + t.Parallel() + + errDummy := errors.New("dummy") + + testCases := map[string]struct { + stdout []byte + cmdErr error + output string + err error + }{ + "no output": {}, + "cmd error": { + stdout: []byte("'hello \nworld'\n"), + cmdErr: errDummy, + output: "hello \nworld", + err: errDummy, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + mockCmd := NewMockexecCmd(ctrl) + + mockCmd.EXPECT().CombinedOutput().Return(testCase.stdout, testCase.cmdErr) + + output, err := run(mockCmd) + + if testCase.err != nil { + require.Error(t, err) + assert.Equal(t, testCase.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, testCase.output, output) + }) + } +} diff --git a/internal/command/start.go b/internal/command/start.go new file mode 100644 index 00000000..9b987039 --- /dev/null +++ b/internal/command/start.go @@ -0,0 +1,97 @@ +package command + +import ( + "bufio" + "errors" + "io" + "os" + "os/exec" +) + +// Start launches a command and streams stdout and stderr to channels. +// All the channels returned are ready only and won't be closed +// if the command fails later. +func (c *Cmder) Start(cmd *exec.Cmd) ( + stdoutLines, stderrLines <-chan string, + waitError <-chan error, startErr error) { + return start(cmd) +} + +func start(cmd execCmd) (stdoutLines, stderrLines <-chan string, + waitError <-chan error, startErr error) { + stop := make(chan struct{}) + stdoutReady := make(chan struct{}) + stdoutLinesCh := make(chan string) + stdoutDone := make(chan struct{}) + stderrReady := make(chan struct{}) + stderrLinesCh := make(chan string) + stderrDone := make(chan struct{}) + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, nil, nil, err + } + go streamToChannel(stdoutReady, stop, stdoutDone, stdout, stdoutLinesCh) + + stderr, err := cmd.StderrPipe() + if err != nil { + _ = stdout.Close() + close(stop) + <-stdoutDone + return nil, nil, nil, err + } + go streamToChannel(stderrReady, stop, stderrDone, stderr, stderrLinesCh) + + err = cmd.Start() + if err != nil { + _ = stdout.Close() + _ = stderr.Close() + close(stop) + <-stdoutDone + <-stderrDone + return nil, nil, nil, err + } + + waitErrorCh := make(chan error) + go func() { + err := cmd.Wait() + _ = stdout.Close() + _ = stderr.Close() + close(stop) + <-stdoutDone + <-stderrDone + waitErrorCh <- err + }() + + return stdoutLinesCh, stderrLinesCh, waitErrorCh, nil +} + +func streamToChannel(ready chan<- struct{}, + stop <-chan struct{}, done chan<- struct{}, + stream io.Reader, lines chan<- string) { + defer close(done) + close(ready) + scanner := bufio.NewScanner(stream) + lineBuffer := make([]byte, bufio.MaxScanTokenSize) // 64KB + const maxCapacity = 20 * 1024 * 1024 // 20MB + scanner.Buffer(lineBuffer, maxCapacity) + + for scanner.Scan() { + // scanner is closed if the context is canceled + // or if the command failed starting because the + // stream is closed (io.EOF error). + lines <- scanner.Text() + } + err := scanner.Err() + if err == nil || errors.Is(err, os.ErrClosed) { + return + } + + // ignore the error if it is stopped. + select { + case <-stop: + return + default: + lines <- "stream error: " + err.Error() + } +} diff --git a/internal/command/start_test.go b/internal/command/start_test.go new file mode 100644 index 00000000..5f808ae3 --- /dev/null +++ b/internal/command/start_test.go @@ -0,0 +1,119 @@ +package command + +import ( + "bytes" + "errors" + "io" + "strings" + "testing" + + gomock "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func linesToReadCloser(lines []string) io.ReadCloser { + s := strings.Join(lines, "\n") + return io.NopCloser(bytes.NewBufferString(s)) +} + +func Test_start(t *testing.T) { + t.Parallel() + + errDummy := errors.New("dummy") + + testCases := map[string]struct { + stdout []string + stdoutPipeErr error + stderr []string + stderrPipeErr error + startErr error + waitErr error + err error + }{ + "no output": {}, + "success": { + stdout: []string{"hello", "world"}, + stderr: []string{"some", "error"}, + }, + "stdout pipe error": { + stdoutPipeErr: errDummy, + err: errDummy, + }, + "stderr pipe error": { + stderrPipeErr: errDummy, + err: errDummy, + }, + "start error": { + startErr: errDummy, + err: errDummy, + }, + "wait error": { + waitErr: errDummy, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + stdout := linesToReadCloser(testCase.stdout) + stderr := linesToReadCloser(testCase.stderr) + + mockCmd := NewMockexecCmd(ctrl) + + mockCmd.EXPECT().StdoutPipe(). + Return(stdout, testCase.stdoutPipeErr) + if testCase.stdoutPipeErr == nil { + mockCmd.EXPECT().StderrPipe().Return(stderr, testCase.stderrPipeErr) + if testCase.stderrPipeErr == nil { + mockCmd.EXPECT().Start().Return(testCase.startErr) + if testCase.startErr == nil { + mockCmd.EXPECT().Wait().Return(testCase.waitErr) + } + } + } + + stdoutLines, stderrLines, waitError, err := start(mockCmd) + + if testCase.err != nil { + require.Error(t, err) + assert.Equal(t, testCase.err.Error(), err.Error()) + assert.Nil(t, stdoutLines) + assert.Nil(t, stderrLines) + assert.Nil(t, waitError) + return + } + + require.NoError(t, err) + + var stdoutIndex, stderrIndex int + + done := false + for !done { + select { + case line := <-stdoutLines: + assert.Equal(t, testCase.stdout[stdoutIndex], line) + stdoutIndex++ + case line := <-stderrLines: + assert.Equal(t, testCase.stderr[stderrIndex], line) + stderrIndex++ + case err := <-waitError: + if testCase.waitErr != nil { + require.Error(t, err) + assert.Equal(t, testCase.waitErr.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + done = true + } + } + + assert.Equal(t, len(testCase.stdout), stdoutIndex) + assert.Equal(t, len(testCase.stderr), stderrIndex) + }) + } +} diff --git a/internal/firewall/delete.go b/internal/firewall/delete.go index eb490784..7ad2a3b1 100644 --- a/internal/firewall/delete.go +++ b/internal/firewall/delete.go @@ -33,7 +33,7 @@ func isDeleteMatchInstruction(instruction string) bool { } func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string, - runner Runner, logger Logger) (err error) { + runner CmdRunner, logger Logger) (err error) { targetRule, err := parseIptablesInstruction(instruction) if err != nil { return fmt.Errorf("parsing iptables command: %w", err) @@ -68,7 +68,7 @@ func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string, // findLineNumber finds the line number of an iptables rule. // It returns 0 if the rule is not found. func findLineNumber(ctx context.Context, iptablesBinary string, - instruction iptablesInstruction, runner Runner, logger Logger) ( + instruction iptablesInstruction, runner CmdRunner, logger Logger) ( lineNumber uint16, err error) { listFlags := []string{"-t", instruction.table, "-L", instruction.chain, "--line-numbers", "-n", "-v"} diff --git a/internal/firewall/delete_test.go b/internal/firewall/delete_test.go index f50dfa1a..2095ecf9 100644 --- a/internal/firewall/delete_test.go +++ b/internal/firewall/delete_test.go @@ -62,7 +62,7 @@ func Test_deleteIPTablesRule(t *testing.T) { testCases := map[string]struct { instruction string - makeRunner func(ctrl *gomock.Controller) *MockRunner + makeRunner func(ctrl *gomock.Controller) *MockCmdRunner makeLogger func(ctrl *gomock.Controller) *MockLogger errWrapped error errMessage string @@ -75,8 +75,8 @@ func Test_deleteIPTablesRule(t *testing.T) { }, "list_error": { instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", - makeRunner: func(ctrl *gomock.Controller) *MockRunner { - runner := NewMockRunner(ctrl) + makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT(). Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")). Return("", errTest) @@ -93,8 +93,8 @@ func Test_deleteIPTablesRule(t *testing.T) { }, "rule_not_found": { instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", - makeRunner: func(ctrl *gomock.Controller) *MockRunner { - runner := NewMockRunner(ctrl) + makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")). Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes) num pkts bytes target prot opt in out source destination @@ -112,8 +112,8 @@ func Test_deleteIPTablesRule(t *testing.T) { }, "rule_found_delete_error": { instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", - makeRunner: func(ctrl *gomock.Controller) *MockRunner { - runner := NewMockRunner(ctrl) + makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")). Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+ "num pkts bytes target prot opt in out source destination \n"+ @@ -137,8 +137,8 @@ func Test_deleteIPTablesRule(t *testing.T) { }, "rule_found_delete_success": { instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", - makeRunner: func(ctrl *gomock.Controller) *MockRunner { - runner := NewMockRunner(ctrl) + makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")). Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+ "num pkts bytes target prot opt in out source destination \n"+ @@ -168,7 +168,7 @@ func Test_deleteIPTablesRule(t *testing.T) { ctx := context.Background() instruction := testCase.instruction - var runner *MockRunner + var runner *MockCmdRunner if testCase.makeRunner != nil { runner = testCase.makeRunner(ctrl) } diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index c9235c3f..b488ca2b 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/routing" - "github.com/qdm12/golibs/command" ) type Config struct { //nolint:maligned - runner command.Runner + runner CmdRunner logger Logger iptablesMutex sync.Mutex ip6tablesMutex sync.Mutex @@ -36,7 +35,7 @@ type Config struct { //nolint:maligned // NewConfig creates a new Config instance and returns an error // if no iptables implementation is available. func NewConfig(ctx context.Context, logger Logger, - runner command.Runner, defaultRoutes []routing.DefaultRoute, + runner CmdRunner, defaultRoutes []routing.DefaultRoute, localNetworks []routing.LocalNetwork) (config *Config, err error) { iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy") if err != nil { diff --git a/internal/firewall/interfaces.go b/internal/firewall/interfaces.go index a4c88dc6..768f38e2 100644 --- a/internal/firewall/interfaces.go +++ b/internal/firewall/interfaces.go @@ -1,9 +1,9 @@ package firewall -import "github.com/qdm12/golibs/command" +import "os/exec" -type Runner interface { - Run(cmd command.ExecCmd) (output string, err error) +type CmdRunner interface { + Run(cmd *exec.Cmd) (output string, err error) } type Logger interface { diff --git a/internal/firewall/ip6tables.go b/internal/firewall/ip6tables.go index e304ca4d..54bd3eae 100644 --- a/internal/firewall/ip6tables.go +++ b/internal/firewall/ip6tables.go @@ -6,14 +6,12 @@ import ( "fmt" "os/exec" "strings" - - "github.com/qdm12/golibs/command" ) // findIP6tablesSupported checks for multiple iptables implementations // and returns the iptables path that is supported. If none work, an // empty string path is returned. -func findIP6tablesSupported(ctx context.Context, runner command.Runner) ( +func findIP6tablesSupported(ctx context.Context, runner CmdRunner) ( ip6tablesPath string, err error) { ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-nft", "ip6tables-legacy") if errors.Is(err, ErrIPTablesNotSupported) { diff --git a/internal/firewall/mocks_generate_test.go b/internal/firewall/mocks_generate_test.go index 0d9c4541..ae563ca9 100644 --- a/internal/firewall/mocks_generate_test.go +++ b/internal/firewall/mocks_generate_test.go @@ -1,3 +1,3 @@ package firewall -//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Runner,Logger +//go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . CmdRunner,Logger diff --git a/internal/firewall/mocks_test.go b/internal/firewall/mocks_test.go index 61650abb..4490a904 100644 --- a/internal/firewall/mocks_test.go +++ b/internal/firewall/mocks_test.go @@ -1,41 +1,41 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: Runner,Logger) +// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: CmdRunner,Logger) // Package firewall is a generated GoMock package. package firewall import ( + exec "os/exec" reflect "reflect" gomock "github.com/golang/mock/gomock" - command "github.com/qdm12/golibs/command" ) -// MockRunner is a mock of Runner interface. -type MockRunner struct { +// MockCmdRunner is a mock of CmdRunner interface. +type MockCmdRunner struct { ctrl *gomock.Controller - recorder *MockRunnerMockRecorder + recorder *MockCmdRunnerMockRecorder } -// MockRunnerMockRecorder is the mock recorder for MockRunner. -type MockRunnerMockRecorder struct { - mock *MockRunner +// MockCmdRunnerMockRecorder is the mock recorder for MockCmdRunner. +type MockCmdRunnerMockRecorder struct { + mock *MockCmdRunner } -// NewMockRunner creates a new mock instance. -func NewMockRunner(ctrl *gomock.Controller) *MockRunner { - mock := &MockRunner{ctrl: ctrl} - mock.recorder = &MockRunnerMockRecorder{mock} +// NewMockCmdRunner creates a new mock instance. +func NewMockCmdRunner(ctrl *gomock.Controller) *MockCmdRunner { + mock := &MockCmdRunner{ctrl: ctrl} + mock.recorder = &MockCmdRunnerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockRunner) EXPECT() *MockRunnerMockRecorder { +func (m *MockCmdRunner) EXPECT() *MockCmdRunnerMockRecorder { return m.recorder } // Run mocks base method. -func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) { +func (m *MockCmdRunner) Run(arg0 *exec.Cmd) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Run", arg0) ret0, _ := ret[0].(string) @@ -44,9 +44,9 @@ func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) { } // Run indicates an expected call of Run. -func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call { +func (mr *MockCmdRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockCmdRunner)(nil).Run), arg0) } // MockLogger is a mock of Logger interface. diff --git a/internal/firewall/support.go b/internal/firewall/support.go index 3147b275..f637baf5 100644 --- a/internal/firewall/support.go +++ b/internal/firewall/support.go @@ -8,8 +8,6 @@ import ( "os/exec" "sort" "strings" - - "github.com/qdm12/golibs/command" ) var ( @@ -19,7 +17,7 @@ var ( ErrIPTablesNotSupported = errors.New("no iptables supported found") ) -func checkIptablesSupport(ctx context.Context, runner command.Runner, +func checkIptablesSupport(ctx context.Context, runner CmdRunner, iptablesPathsToTry ...string) (iptablesPath string, err error) { iptablesPathToUnsupportedMessage := make(map[string]string, len(iptablesPathsToTry)) for _, pathToTest := range iptablesPathsToTry { @@ -62,7 +60,7 @@ func checkIptablesSupport(ctx context.Context, runner command.Runner, } func testIptablesPath(ctx context.Context, path string, - runner command.Runner) (ok bool, unsupportedMessage string, + runner CmdRunner) (ok bool, unsupportedMessage string, criticalErr error) { // Just listing iptables rules often work but we need // to modify them to ensure we can support the iptables diff --git a/internal/firewall/support_test.go b/internal/firewall/support_test.go index a5d43067..4f7039bd 100644 --- a/internal/firewall/support_test.go +++ b/internal/firewall/support_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/qdm12/golibs/command" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -41,15 +40,15 @@ func Test_checkIptablesSupport(t *testing.T) { const inputPolicy = "ACCEPT" testCases := map[string]struct { - buildRunner func(ctrl *gomock.Controller) command.Runner + buildRunner func(ctrl *gomock.Controller) CmdRunner iptablesPathsToTry []string iptablesPath string errSentinel error errMessage string }{ "critical error when checking": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher("path1")). Return("", nil) runner.EXPECT().Run(newDeleteTestRuleMatcher("path1")). @@ -62,8 +61,8 @@ func Test_checkIptablesSupport(t *testing.T) { "output (exit code 4)", }, "found valid path": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher("path1")). Return("", nil) runner.EXPECT().Run(newDeleteTestRuleMatcher("path1")). @@ -78,8 +77,8 @@ func Test_checkIptablesSupport(t *testing.T) { iptablesPath: "path1", }, "all permission denied": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher("path1")). Return("Permission denied (you must be root) more context", errDummy) runner.EXPECT().Run(newAppendTestRuleMatcher("path2")). @@ -93,8 +92,8 @@ func Test_checkIptablesSupport(t *testing.T) { "path2: context: Permission denied (you must be root) (exit code 4)", }, "no valid path": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher("path1")). Return("output 1", errDummy) runner.EXPECT().Run(newAppendTestRuleMatcher("path2")). @@ -139,15 +138,15 @@ func Test_testIptablesPath(t *testing.T) { const inputPolicy = "ACCEPT" testCases := map[string]struct { - buildRunner func(ctrl *gomock.Controller) command.Runner + buildRunner func(ctrl *gomock.Controller) CmdRunner ok bool unsupportedMessage string criticalErrWrapped error criticalErrMessage string }{ "append test rule permission denied": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher(path)). Return("Permission denied (you must be root)", errDummy) return runner @@ -155,8 +154,8 @@ func Test_testIptablesPath(t *testing.T) { unsupportedMessage: "Permission denied (you must be root) (exit code 4)", }, "append test rule unsupported": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher(path)). Return("some output", errDummy) return runner @@ -164,8 +163,8 @@ func Test_testIptablesPath(t *testing.T) { unsupportedMessage: "some output (exit code 4)", }, "remove test rule error": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newDeleteTestRuleMatcher(path)). Return("some output", errDummy) @@ -175,8 +174,8 @@ func Test_testIptablesPath(t *testing.T) { criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)", }, "list input rules permission denied": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newListInputRulesMatcher(path)). @@ -186,8 +185,8 @@ func Test_testIptablesPath(t *testing.T) { unsupportedMessage: "Permission denied (you must be root) (exit code 4)", }, "list input rules unsupported": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newListInputRulesMatcher(path)). @@ -197,8 +196,8 @@ func Test_testIptablesPath(t *testing.T) { unsupportedMessage: "some output (exit code 4)", }, "list input rules no policy": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newListInputRulesMatcher(path)). @@ -209,8 +208,8 @@ func Test_testIptablesPath(t *testing.T) { criticalErrMessage: "input policy not found: in INPUT rules: some\noutput", }, "set policy permission denied": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newListInputRulesMatcher(path)). @@ -222,8 +221,8 @@ func Test_testIptablesPath(t *testing.T) { unsupportedMessage: "Permission denied (you must be root) (exit code 4)", }, "set policy unsupported": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newListInputRulesMatcher(path)). @@ -235,8 +234,8 @@ func Test_testIptablesPath(t *testing.T) { unsupportedMessage: "some output (exit code 4)", }, "success": { - buildRunner: func(ctrl *gomock.Controller) command.Runner { - runner := NewMockRunner(ctrl) + buildRunner: func(ctrl *gomock.Controller) CmdRunner { + runner := NewMockCmdRunner(ctrl) runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil) runner.EXPECT().Run(newListInputRulesMatcher(path)). diff --git a/internal/openvpn/interfaces.go b/internal/openvpn/interfaces.go new file mode 100644 index 00000000..82f88703 --- /dev/null +++ b/internal/openvpn/interfaces.go @@ -0,0 +1,14 @@ +package openvpn + +import "os/exec" + +type CmdStarter interface { + Start(cmd *exec.Cmd) ( + stdoutLines, stderrLines <-chan string, + waitError <-chan error, startErr error) +} + +type CmdRunStarter interface { + Run(cmd *exec.Cmd) (output string, err error) + CmdStarter +} diff --git a/internal/openvpn/openvpn.go b/internal/openvpn/openvpn.go index 4e4b2c5b..14b08d01 100644 --- a/internal/openvpn/openvpn.go +++ b/internal/openvpn/openvpn.go @@ -2,19 +2,18 @@ package openvpn import ( "github.com/qdm12/gluetun/internal/constants/openvpn" - "github.com/qdm12/golibs/command" ) type Configurator struct { logger Infoer - cmder command.RunStarter + cmder CmdRunStarter configPath string authFilePath string askPassPath string puid, pgid int } -func New(logger Infoer, cmder command.RunStarter, +func New(logger Infoer, cmder CmdRunStarter, puid, pgid int) *Configurator { return &Configurator{ logger: logger, diff --git a/internal/openvpn/run.go b/internal/openvpn/run.go index 59ae5ea9..a8ad7762 100644 --- a/internal/openvpn/run.go +++ b/internal/openvpn/run.go @@ -4,16 +4,15 @@ import ( "context" "github.com/qdm12/gluetun/internal/configuration/settings" - "github.com/qdm12/golibs/command" ) type Runner struct { settings settings.OpenVPN - starter command.Starter + starter CmdStarter logger Logger } -func NewRunner(settings settings.OpenVPN, starter command.Starter, +func NewRunner(settings settings.OpenVPN, starter CmdStarter, logger Logger) *Runner { return &Runner{ starter: starter, @@ -37,12 +36,10 @@ func (r *Runner) Run(ctx context.Context, errCh chan<- error, ready chan<- struc select { case <-ctx.Done(): <-waitError - close(waitError) streamCancel() <-streamDone errCh <- ctx.Err() case err := <-waitError: - close(waitError) streamCancel() <-streamDone errCh <- err diff --git a/internal/openvpn/start.go b/internal/openvpn/start.go index 8ff7b99a..8eedda7b 100644 --- a/internal/openvpn/start.go +++ b/internal/openvpn/start.go @@ -8,7 +8,6 @@ import ( "syscall" "github.com/qdm12/gluetun/internal/constants/openvpn" - "github.com/qdm12/golibs/command" ) var ErrVersionUnknown = errors.New("OpenVPN version is unknown") @@ -18,8 +17,8 @@ const ( binOpenvpn26 = "openvpn2.6" ) -func start(ctx context.Context, starter command.Starter, version string, flags []string) ( - stdoutLines, stderrLines chan string, waitError chan error, err error) { +func start(ctx context.Context, starter CmdStarter, version string, flags []string) ( + stdoutLines, stderrLines <-chan string, waitError <-chan error, err error) { var bin string switch version { case openvpn.Openvpn25: diff --git a/internal/openvpn/stream.go b/internal/openvpn/stream.go index 5d6a9b6d..dba63b91 100644 --- a/internal/openvpn/stream.go +++ b/internal/openvpn/stream.go @@ -6,7 +6,7 @@ import ( ) func streamLines(ctx context.Context, done chan<- struct{}, - logger Logger, stdout, stderr chan string, + logger Logger, stdout, stderr <-chan string, tunnelReady chan<- struct{}) { defer close(done) @@ -16,10 +16,6 @@ func streamLines(ctx context.Context, done chan<- struct{}, errLine := false select { case <-ctx.Done(): - // Context should only be canceled after stdout and stderr are done - // being written to. - close(stdout) - close(stderr) return case line = <-stdout: case line = <-stderr: diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index d116b8dc..68103690 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -3,6 +3,7 @@ package vpn import ( "context" "net/netip" + "os/exec" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/models" @@ -92,3 +93,9 @@ type PublicIPLoop interface { RunOnce(ctx context.Context) (err error) ClearData() (err error) } + +type CmdStarter interface { + Start(cmd *exec.Cmd) ( + stdoutLines, stderrLines <-chan string, + waitError <-chan error, startErr error) +} diff --git a/internal/vpn/loop.go b/internal/vpn/loop.go index 2bbe3520..8a90ca26 100644 --- a/internal/vpn/loop.go +++ b/internal/vpn/loop.go @@ -9,7 +9,6 @@ import ( "github.com/qdm12/gluetun/internal/loopstate" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/vpn/state" - "github.com/qdm12/golibs/command" "github.com/qdm12/log" ) @@ -32,7 +31,7 @@ type Loop struct { publicip PublicIPLoop dnsLooper DNSLoop // Other objects - starter command.Starter // for OpenVPN + starter CmdStarter // for OpenVPN logger log.LoggerInterface client *http.Client // Internal channels and values @@ -52,7 +51,7 @@ const ( func NewLoop(vpnSettings settings.VPN, ipv6Supported bool, vpnInputPorts []uint16, providers Providers, storage Storage, openvpnConf OpenVPN, netLinker NetLinker, fw Firewall, routing Routing, - portForward PortForward, starter command.Starter, + portForward PortForward, starter CmdStarter, publicip PublicIPLoop, dnsLooper DNSLoop, logger log.LoggerInterface, client *http.Client, buildInfo models.BuildInformation, versionInfo bool) *Loop { diff --git a/internal/vpn/openvpn.go b/internal/vpn/openvpn.go index 0a248b1d..4f91de83 100644 --- a/internal/vpn/openvpn.go +++ b/internal/vpn/openvpn.go @@ -7,14 +7,13 @@ import ( "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/provider" - "github.com/qdm12/golibs/command" ) // setupOpenVPN sets OpenVPN up using the configurators and settings given. // It returns a serverName for port forwarding (PIA) and an error if it fails. func setupOpenVPN(ctx context.Context, fw Firewall, openvpnConf OpenVPN, providerConf provider.Provider, - settings settings.VPN, ipv6Supported bool, starter command.Starter, + settings settings.VPN, ipv6Supported bool, starter CmdStarter, logger openvpn.Logger) (runner *openvpn.Runner, serverName string, canPortForward bool, err error) { connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)