Skip to content

Commit

Permalink
Fix UseDNSSystemWide keepNameserver
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Jan 6, 2021
1 parent 9c3e47c commit 0f720cc
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 22 deletions.
10 changes: 9 additions & 1 deletion pkg/unbound/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (c *configurator) UseDNSInternally(ip net.IP) {

// UseDNSSystemWide changes the nameserver to use for DNS system wide.
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
file, err := c.openFile(resolvConfFilepath, os.O_RDWR|os.O_TRUNC, 0644)
file, err := c.openFile(resolvConfFilepath, os.O_RDONLY, 0)
if err != nil {
return err
}
Expand All @@ -28,6 +28,9 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}

s := strings.TrimSuffix(string(data), "\n")

Expand All @@ -43,6 +46,11 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
}

s = strings.Join(lines, "\n") + "\n"

file, err = c.openFile(resolvConfFilepath, os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return err
}
_, err = file.WriteString(s)
if err != nil {
_ = file.Close()
Expand Down
98 changes: 77 additions & 21 deletions pkg/unbound/nameserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,54 @@ import (

func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel()

tests := map[string]struct {
ip net.IP
keepNameserver bool
data []byte
writtenData string
openErr error
firstOpenErr error
readErr error
firstCloseErr error
secondOpenErr error
writtenData string
writeErr error
closeErr error
secondCloseErr error
err error
}{
"no data": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n",
},
"open error": {
ip: net.IP{127, 0, 0, 1},
openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
"first open error": {
ip: net.IP{127, 0, 0, 1},
firstOpenErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"first close error": {
firstCloseErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"second open error": {
ip: net.IP{127, 0, 0, 1},
secondOpenErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n",
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"second close error": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n",
secondCloseErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"lines without nameserver": {
ip: net.IP{127, 0, 0, 1},
data: []byte("abc\ndef\n"),
Expand All @@ -68,9 +86,20 @@ func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)

file := mock_os.NewMockFile(mockCtrl)
if tc.openErr == nil {
firstReadCall := file.EXPECT().
type fileCall struct {
path string
flag int
perm os.FileMode
file os.File
err error
}

var fileCalls []fileCall

readOnlyFile := mock_os.NewMockFile(mockCtrl)

if tc.firstOpenErr == nil {
firstReadCall := readOnlyFile.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})).
DoAndReturn(func(b []byte) (int, error) {
copy(b, tc.data)
Expand All @@ -80,23 +109,50 @@ func Test_UseDNSSystemWide(t *testing.T) {
if readErr == nil {
readErr = io.EOF
}
finalReadCall := file.EXPECT().
finalReadCall := readOnlyFile.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})).
Return(0, readErr).After(firstReadCall)
if tc.readErr == nil {
writeCall := file.EXPECT().WriteString(tc.writtenData).
Return(0, tc.writeErr).After(finalReadCall)
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
} else {
file.EXPECT().Close().Return(tc.closeErr).After(finalReadCall)
readOnlyFile.EXPECT().Close().
Return(tc.firstCloseErr).
After(finalReadCall)
}

fileCalls = append(fileCalls, fileCall{
path: resolvConfFilepath,
flag: os.O_RDONLY,
perm: 0,
file: readOnlyFile,
err: tc.firstOpenErr,
}) // always return readOnlyFile

if tc.firstOpenErr == nil && tc.readErr == nil && tc.firstCloseErr == nil {
writeOnlyFile := mock_os.NewMockFile(mockCtrl)
if tc.secondOpenErr == nil {
writeCall := writeOnlyFile.EXPECT().
WriteString(tc.writtenData).
Return(0, tc.writeErr)
writeOnlyFile.EXPECT().
Close().
Return(tc.secondCloseErr).
After(writeCall)
}
fileCalls = append(fileCalls, fileCall{
path: resolvConfFilepath,
flag: os.O_WRONLY | os.O_TRUNC,
perm: os.FileMode(0644),
file: writeOnlyFile,
err: tc.secondOpenErr,
})
}

fileCallIndex := 0
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
assert.Equal(t, resolvConfFilepath, name)
assert.Equal(t, os.O_RDWR|os.O_TRUNC, flag)
assert.Equal(t, os.FileMode(0644), perm)
return file, tc.openErr
fileCall := fileCalls[fileCallIndex]
fileCallIndex++
assert.Equal(t, fileCall.path, name)
assert.Equal(t, fileCall.flag, flag)
assert.Equal(t, fileCall.perm, perm)
return fileCall.file, fileCall.err
}

c := &configurator{
Expand Down

0 comments on commit 0f720cc

Please sign in to comment.