diff --git a/device/device_test.go b/device/device_test.go index db5a3c0..b6212b5 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -9,21 +9,17 @@ package device * without network dependencies */ -import "testing" +import ( + "bytes" + "testing" +) func TestDevice(t *testing.T) { // prepare tun devices for generating traffic - tun1, err := CreateDummyTUN("tun1") - if err != nil { - t.Error("failed to create tun:", err.Error()) - } - - tun2, err := CreateDummyTUN("tun2") - if err != nil { - t.Error("failed to create tun:", err.Error()) - } + tun1 := newDummyTUN("tun1") + tun2 := newDummyTUN("tun2") _ = tun1 _ = tun2 @@ -46,3 +42,27 @@ func TestDevice(t *testing.T) { // create binds } + +func randDevice(t *testing.T) *Device { + sk, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + tun := newDummyTUN("dummy") + logger := NewLogger(LogLevelError, "") + device := NewDevice(tun, logger) + device.SetPrivateKey(sk) + return device +} + +func assertNil(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } +} + +func assertEqual(t *testing.T, a []byte, b []byte) { + if bytes.Compare(a, b) != 0 { + t.Fatal(a, "!=", b) + } +} diff --git a/device/tun_test.go b/device/tun_test.go new file mode 100644 index 0000000..fbe4c1d --- /dev/null +++ b/device/tun_test.go @@ -0,0 +1,56 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + "os" + + "golang.zx2c4.com/wireguard/tun" +) + +// newDummyTUN creates a dummy TUN device with the specified name. +func newDummyTUN(name string) tun.TUNDevice { + return &dummyTUN{ + name: name, + packets: make(chan []byte, 100), + events: make(chan tun.TUNEvent, 10), + } +} + +// A dummyTUN is a tun.TUNDevice which is used in unit tests. +type dummyTUN struct { + name string + mtu int + packets chan []byte + events chan tun.TUNEvent +} + +func (d *dummyTUN) Events() chan tun.TUNEvent { return d.events } +func (*dummyTUN) File() *os.File { return nil } +func (*dummyTUN) Flush() error { return nil } +func (d *dummyTUN) MTU() (int, error) { return d.mtu, nil } +func (d *dummyTUN) Name() (string, error) { return d.name, nil } + +func (d *dummyTUN) Close() error { + close(d.events) + close(d.packets) + return nil +} + +func (d *dummyTUN) Read(b []byte, offset int) (int, error) { + buf, ok := <-d.packets + if !ok { + return 0, errors.New("device closed") + } + copy(b[offset:], buf) + return len(buf), nil +} + +func (d *dummyTUN) Write(b []byte, offset int) (int, error) { + d.packets <- b[offset:] + return len(b), nil +} diff --git a/tun/helper_test.go b/tun/helper_test.go deleted file mode 100644 index 4fa0357..0000000 --- a/tun/helper_test.go +++ /dev/null @@ -1,93 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package tun - -import ( - "bytes" - "errors" - "os" - "testing" - - "golang.zx2c4.com/wireguard/tun" -) - -/* Helpers for writing unit tests - */ - -type DummyTUN struct { - name string - mtu int - packets chan []byte - events chan tun.TUNEvent -} - -func (tun *DummyTUN) File() *os.File { - return nil -} - -func (tun *DummyTUN) Name() (string, error) { - return tun.name, nil -} - -func (tun *DummyTUN) MTU() (int, error) { - return tun.mtu, nil -} - -func (tun *DummyTUN) Write(d []byte, offset int) (int, error) { - tun.packets <- d[offset:] - return len(d), nil -} - -func (tun *DummyTUN) Close() error { - close(tun.events) - close(tun.packets) - return nil -} - -func (tun *DummyTUN) Events() chan tun.TUNEvent { - return tun.events -} - -func (tun *DummyTUN) Read(d []byte, offset int) (int, error) { - t, ok := <-tun.packets - if !ok { - return 0, errors.New("device closed") - } - copy(d[offset:], t) - return len(t), nil -} - -func CreateDummyTUN(name string) (tun.TUNDevice, error) { - var dummy DummyTUN - dummy.mtu = 0 - dummy.packets = make(chan []byte, 100) - dummy.events = make(chan tun.TUNEvent, 10) - return &dummy, nil -} - -func assertNil(t *testing.T, err error) { - if err != nil { - t.Fatal(err) - } -} - -func assertEqual(t *testing.T, a []byte, b []byte) { - if bytes.Compare(a, b) != 0 { - t.Fatal(a, "!=", b) - } -} - -func randDevice(t *testing.T) *Device { - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) - } - tun, _ := CreateDummyTUN("dummy") - logger := NewLogger(LogLevelError, "") - device := NewDevice(tun, logger) - device.SetPrivateKey(sk) - return device -}