Work on UAPI

Cross-platform API (get operation)
Handshake initiation creation process
Outbound packet flow
Fixes from code-review
This commit is contained in:
Mathias Hall-Andersen 2017-06-28 23:45:45 +02:00
parent 8236f3afa2
commit 1f0976a26c
18 changed files with 707 additions and 243 deletions

9
src/Makefile Normal file
View File

@ -0,0 +1,9 @@
BINARY=wireguard-go
build:
go build -o ${BINARY}
clean:
if [ -f ${BINARY} ]; then rm ${BINARY}; fi
.PHONY: clean

View File

@ -11,7 +11,7 @@ import (
"time" "time"
) )
/* todo : use real error code /* TODO : use real error code
* Many of which will be the same * Many of which will be the same
*/ */
const ( const (
@ -37,8 +37,55 @@ func (s *IPCError) ErrorCode() int {
return s.Code return s.Code
} }
func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) { func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
device.mutex.RLock()
defer device.mutex.RUnlock()
// create lines
lines := make([]string, 0, 100)
send := func(line string) {
lines = append(lines, line)
}
if !device.privateKey.IsZero() {
send("private_key=" + device.privateKey.ToHex())
}
if device.address != nil {
send(fmt.Sprintf("listen_port=%d", device.address.Port))
}
for _, peer := range device.peers {
func() {
peer.mutex.RLock()
defer peer.mutex.RUnlock()
send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.String())
}
send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes))
send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes))
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
for _, ip := range device.routingTable.AllowedIPs(peer) {
send("allowed_ip=" + ip.String())
}
}()
}
// send lines
for _, line := range lines {
device.log.Debug.Println("config:", line)
_, err := socket.WriteString(line + "\n")
if err != nil {
return err
}
}
return nil
} }
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
@ -179,7 +226,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return nil return nil
} }
func ipcListen(dev *Device, socket io.ReadWriter) error { func ipcListen(device *Device, socket io.ReadWriter) error {
buffered := func(s io.ReadWriter) *bufio.ReadWriter { buffered := func(s io.ReadWriter) *bufio.ReadWriter {
reader := bufio.NewReader(s) reader := bufio.NewReader(s)
@ -187,6 +234,8 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
return bufio.NewReadWriter(reader, writer) return bufio.NewReadWriter(reader, writer)
}(socket) }(socket)
defer buffered.Flush()
for { for {
op, err := buffered.ReadString('\n') op, err := buffered.ReadString('\n')
if err != nil { if err != nil {
@ -197,17 +246,26 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
switch op { switch op {
case "set=1\n": case "set=1\n":
err := ipcSetOperation(dev, buffered) err := ipcSetOperation(device, buffered)
if err != nil { if err != nil {
fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode()) fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
return err return err
} else { } else {
fmt.Fprintf(buffered, "errno=0\n") fmt.Fprintf(buffered, "errno=0\n\n")
} }
buffered.Flush() buffered.Flush()
case "get=1\n": case "get=1\n":
err := ipcGetOperation(device, buffered)
if err != nil {
fmt.Fprintf(buffered, "errno=1\n\n") // fix
return err
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
}
buffered.Flush()
case "\n":
default: default:
return errors.New("handle this please") return errors.New("handle this please")
} }

View File

@ -8,9 +8,14 @@ const (
RekeyAfterMessage = (1 << 64) - (1 << 16) - 1 RekeyAfterMessage = (1 << 64) - (1 << 16) - 1
RekeyAfterTime = time.Second * 120 RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90 RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5 RekeyTimeout = time.Second * 5 // TODO: Exponential backoff
RejectAfterTime = time.Second * 180 RejectAfterTime = time.Second * 180
RejectAfterMessage = (1 << 64) - (1 << 4) - 1 RejectAfterMessage = (1 << 64) - (1 << 4) - 1
KeepaliveTimeout = time.Second * 10 KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 2 CookieRefreshTime = time.Second * 2
MaxHandshakeAttempTime = time.Second * 90
)
const (
QueueOutboundSize = 1024
) )

View File

@ -2,6 +2,7 @@ package main
import ( import (
"net" "net"
"runtime"
"sync" "sync"
) )
@ -16,7 +17,9 @@ type Device struct {
routingTable RoutingTable routingTable RoutingTable
indices IndexTable indices IndexTable
log *Logger log *Logger
queueWorkOutbound chan *OutboundWorkQueueElement queue struct {
encryption chan *QueueOutboundElement // parallel work queue
}
peers map[NoisePublicKey]*Peer peers map[NoisePublicKey]*Peer
mac MacStateDevice mac MacStateDevice
} }
@ -41,7 +44,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
} }
} }
func (device *Device) Init() { func NewDevice(tun TUNDevice) *Device {
device := new(Device)
device.mutex.Lock() device.mutex.Lock()
defer device.mutex.Unlock() defer device.mutex.Unlock()
@ -49,6 +54,14 @@ func (device *Device) Init() {
device.peers = make(map[NoisePublicKey]*Peer) device.peers = make(map[NoisePublicKey]*Peer)
device.indices.Init() device.indices.Init()
device.routingTable.Reset() device.routingTable.Reset()
// start workers
for i := 0; i < runtime.NumCPU(); i += 1 {
go device.RoutineEncryption()
}
go device.RoutineReadFromTUN(tun)
return device
} }
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {

172
src/handshake.go Normal file
View File

@ -0,0 +1,172 @@
package main
import (
"bytes"
"encoding/binary"
"net"
"sync/atomic"
"time"
)
/* Sends a keep-alive if no packets queued for peer
*
* Used by initiator of handshake and with active keep-alive
*/
func (peer *Peer) SendKeepAlive() bool {
if len(peer.queue.nonce) == 0 {
select {
case peer.queue.nonce <- []byte{}:
return true
default:
return false
}
}
return true
}
func (peer *Peer) RoutineHandshakeInitiator() {
var ongoing bool
var begun time.Time
var attempts uint
var timeout time.Timer
device := peer.device
work := new(QueueOutboundElement)
buffer := make([]byte, 0, 1024)
queueHandshakeInitiation := func() error {
work.mutex.Lock()
defer work.mutex.Unlock()
// create initiation
msg, err := device.CreateMessageInitiation(peer)
if err != nil {
return err
}
// create "work" element
writer := bytes.NewBuffer(buffer[:0])
binary.Write(writer, binary.LittleEndian, &msg)
work.packet = writer.Bytes()
peer.mac.AddMacs(work.packet)
peer.InsertOutbound(work)
return nil
}
for {
select {
case <-peer.signal.stopInitiator:
return
case <-peer.signal.newHandshake:
if ongoing {
continue
}
// create handshake
err := queueHandshakeInitiation()
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
}
// log when we began
begun = time.Now()
ongoing = true
attempts = 0
timeout.Reset(RekeyTimeout)
case <-peer.timer.sendKeepalive.C:
// active keep-alives
peer.SendKeepAlive()
case <-peer.timer.handshakeTimeout.C:
// check if we can stop trying
if time.Now().Sub(begun) > MaxHandshakeAttempTime {
peer.signal.flushNonceQueue <- true
peer.timer.sendKeepalive.Stop()
ongoing = false
continue
}
// otherwise, try again (exponental backoff)
attempts += 1
err := queueHandshakeInitiation()
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
}
peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout)
}
}
}
/* Handles packets related to handshake
*
*
*/
func (device *Device) HandshakeWorker(queue chan struct {
msg []byte
msgType uint32
addr *net.UDPAddr
}) {
for {
elem := <-queue
switch elem.msgType {
case MessageInitiationType:
if len(elem.msg) != MessageInitiationSize {
continue
}
// check for cookie
var msg MessageInitiation
binary.Read(nil, binary.LittleEndian, &msg)
case MessageResponseType:
if len(elem.msg) != MessageResponseSize {
continue
}
// check for cookie
case MessageCookieReplyType:
case MessageTransportType:
}
}
}
func (device *Device) KeepKeyFresh(peer *Peer) {
send := func() bool {
peer.keyPairs.mutex.RLock()
defer peer.keyPairs.mutex.RUnlock()
kp := peer.keyPairs.current
if kp == nil {
return false
}
nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessage {
return true
}
return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime
}()
if send {
}
}

64
src/helper_test.go Normal file
View File

@ -0,0 +1,64 @@
package main
import (
"bytes"
"testing"
)
/* Helpers for writing unit tests
*/
type DummyTUN struct {
name string
mtu uint
packets chan []byte
}
func (tun *DummyTUN) Name() string {
return tun.name
}
func (tun *DummyTUN) MTU() uint {
return tun.mtu
}
func (tun *DummyTUN) Write(d []byte) (int, error) {
tun.packets <- d
return len(d), nil
}
func (tun *DummyTUN) Read(d []byte) (int, error) {
t := <-tun.packets
copy(d, t)
return len(t), nil
}
func CreateDummyTUN(name string) (TUNDevice, error) {
var dummy DummyTUN
dummy.mtu = 1024
dummy.packets = make(chan []byte, 100)
return &dummy, nil
}
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a []byte, b []byte) {
if bytes.Compare(a, b) != 0 {
t.Fatal(a, "!=", b)
}
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun, _ := CreateDummyTUN("dummy")
device := NewDevice(tun)
device.SetPrivateKey(sk)
return device
}

View File

@ -8,6 +8,7 @@ const (
IPv4version = 4 IPv4version = 4
IPv4offsetSrc = 12 IPv4offsetSrc = 12
IPv4offsetDst = IPv4offsetSrc + net.IPv4len IPv4offsetDst = IPv4offsetSrc + net.IPv4len
IPv4headerSize = 20
) )
const ( const (

View File

@ -8,8 +8,8 @@ import (
) )
func TestMAC1(t *testing.T) { func TestMAC1(t *testing.T) {
dev1 := newDevice(t) dev1 := randDevice(t)
dev2 := newDevice(t) dev2 := randDevice(t)
peer1 := dev2.NewPeer(dev1.privateKey.publicKey()) peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey()) peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
@ -34,12 +34,10 @@ func TestMACs(t *testing.T) {
msg []byte, msg []byte,
receiver uint32, receiver uint32,
) bool { ) bool {
var device1 Device device1 := randDevice(t)
device1.Init()
device1.SetPrivateKey(sk1) device1.SetPrivateKey(sk1)
var device2 Device device2 := randDevice(t)
device2.Init()
device2.SetPrivateKey(sk2) device2.SetPrivateKey(sk2)
peer1 := device2.NewPeer(device1.privateKey.publicKey()) peer1 := device2.NewPeer(device1.privateKey.publicKey())

View File

@ -1,36 +1,30 @@
package main package main
import ( import (
"fmt"
)
func main() {
fd, err := CreateTUN("test0")
fmt.Println(fd, err)
queue := make(chan []byte, 1000)
// var device Device
// go OutgoingRoutingWorker(&device, queue)
for {
tmp := make([]byte, 1<<16)
n, err := fd.Read(tmp)
if err != nil {
break
}
queue <- tmp[:n]
}
}
/*
import (
"fmt"
"log" "log"
"net" "net"
) )
/*
*
* TODO: Fix logging
*/
func main() { func main() {
// Open TUN device
// TODO: Fix capabilities
tun, err := CreateTUN("test0")
log.Println(tun, err)
if err != nil {
return
}
device := NewDevice(tun)
// Start configuration lister
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock") l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
if err != nil { if err != nil {
log.Fatal("listen error:", err) log.Fatal("listen error:", err)
@ -41,12 +35,9 @@ func main() {
if err != nil { if err != nil {
log.Fatal("accept error:", err) log.Fatal("accept error:", err)
} }
var dev Device
go func(conn net.Conn) { go func(conn net.Conn) {
err := ipcListen(&dev, conn) err := ipcListen(device, conn)
fmt.Println(err) log.Println(err)
}(fd) }(fd)
} }
} }
*/

View File

@ -77,7 +77,7 @@ type MessageCookieReply struct {
type Handshake struct { type Handshake struct {
state int state int
mutex sync.Mutex mutex sync.RWMutex
hash [blake2s.Size]byte // hash value hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key chainKey [blake2s.Size]byte // chain key
presharedKey NoiseSymmetricKey // psk presharedKey NoiseSymmetricKey // psk
@ -205,19 +205,26 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
} }
hash = mixHash(hash, msg.Static[:]) hash = mixHash(hash, msg.Static[:])
// find peer // lookup peer
peer := device.LookupPeer(peerPK) peer := device.LookupPeer(peerPK)
if peer == nil { if peer == nil {
return nil return nil
} }
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock() // verify identity
var timestamp TAI64N
ok := func() bool {
// read lock handshake
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
// decrypt timestamp // decrypt timestamp
var timestamp TAI64N
func() { func() {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
chainKey, key = KDF2( chainKey, key = KDF2(
@ -228,26 +235,34 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
}() }()
if err != nil { if err != nil {
return nil return false
} }
hash = mixHash(hash, msg.Timestamp[:]) hash = mixHash(hash, msg.Timestamp[:])
// TODO: check for flood attack
// check for replay attack // check for replay attack
if !timestamp.After(handshake.lastTimestamp) { return timestamp.After(handshake.lastTimestamp)
}()
if !ok {
return nil return nil
} }
// TODO: check for flood attack
// update handshake state // update handshake state
handshake.mutex.Lock()
handshake.hash = hash handshake.hash = hash
handshake.chainKey = chainKey handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral handshake.remoteEphemeral = msg.Ephemeral
handshake.lastTimestamp = timestamp handshake.lastTimestamp = timestamp
handshake.state = HandshakeInitiationConsumed handshake.state = HandshakeInitiationConsumed
handshake.mutex.Unlock()
return peer return peer
} }
@ -320,16 +335,26 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return nil return nil
} }
handshake.mutex.Lock() var (
defer handshake.mutex.Unlock() hash [blake2s.Size]byte
chainKey [blake2s.Size]byte
)
ok := func() bool {
// read lock handshake
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
if handshake.state != HandshakeInitiationCreated { if handshake.state != HandshakeInitiationCreated {
return nil return false
} }
// finish 3-way DH // finish 3-way DH
hash := mixHash(handshake.hash, msg.Ephemeral[:]) hash = mixHash(handshake.hash, msg.Ephemeral[:])
chainKey := handshake.chainKey chainKey = handshake.chainKey
func() { func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
@ -350,17 +375,27 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil { if err != nil {
return nil return false
} }
hash = mixHash(hash, msg.Empty[:]) hash = mixHash(hash, msg.Empty[:])
return true
}()
if !ok {
return nil
}
// update handshake state // update handshake state
handshake.mutex.Lock()
handshake.hash = hash handshake.hash = hash
handshake.chainKey = chainKey handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed handshake.state = HandshakeResponseConsumed
handshake.mutex.Unlock()
return lookup.peer return lookup.peer
} }

View File

@ -6,29 +6,6 @@ import (
"testing" "testing"
) )
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a []byte, b []byte) {
if bytes.Compare(a, b) != 0 {
t.Fatal(a, "!=", b)
}
}
func newDevice(t *testing.T) *Device {
var device Device
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
device.Init()
device.SetPrivateKey(sk)
return &device
}
func TestCurveWrappers(t *testing.T) { func TestCurveWrappers(t *testing.T) {
sk1, err := newPrivateKey() sk1, err := newPrivateKey()
assertNil(t, err) assertNil(t, err)
@ -49,8 +26,8 @@ func TestCurveWrappers(t *testing.T) {
func TestNoiseHandshake(t *testing.T) { func TestNoiseHandshake(t *testing.T) {
dev1 := newDevice(t) dev1 := randDevice(t)
dev2 := newDevice(t) dev2 := randDevice(t)
peer1 := dev2.NewPeer(dev1.privateKey.publicKey()) peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey()) peer2 := dev1.NewPeer(dev2.privateKey.publicKey())

View File

@ -3,18 +3,18 @@ package main
import ( import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"golang.org/x/crypto/chacha20poly1305"
) )
const ( const (
NoisePublicKeySize = 32 NoisePublicKeySize = 32
NoisePrivateKeySize = 32 NoisePrivateKeySize = 32
NoiseSymmetricKeySize = 32
) )
type ( type (
NoisePublicKey [NoisePublicKeySize]byte NoisePublicKey [NoisePublicKeySize]byte
NoisePrivateKey [NoisePrivateKeySize]byte NoisePrivateKey [NoisePrivateKeySize]byte
NoiseSymmetricKey [NoiseSymmetricKeySize]byte NoiseSymmetricKey [chacha20poly1305.KeySize]byte
NoiseNonce uint64 // padded to 12-bytes NoiseNonce uint64 // padded to 12-bytes
) )
@ -30,6 +30,15 @@ func loadExactHex(dst []byte, src string) error {
return nil return nil
} }
func (key NoisePrivateKey) IsZero() bool {
for _, b := range key[:] {
if b != 0 {
return false
}
}
return true
}
func (key *NoisePrivateKey) FromHex(src string) error { func (key *NoisePrivateKey) FromHex(src string) error {
return loadExactHex(key[:], src) return loadExactHex(key[:], src)
} }

View File

@ -7,9 +7,7 @@ import (
"time" "time"
) )
const ( const ()
OutboundQueueSize = 64
)
type Peer struct { type Peer struct {
mutex sync.RWMutex mutex sync.RWMutex
@ -18,9 +16,25 @@ type Peer struct {
keyPairs KeyPairs keyPairs KeyPairs
handshake Handshake handshake Handshake
device *Device device *Device
queueInbound chan []byte tx_bytes uint64
queueOutbound chan *OutboundWorkQueueElement rx_bytes uint64
queueOutboundRouting chan []byte time struct {
lastSend time.Time // last send message
}
signal struct {
newHandshake chan bool
flushNonceQueue chan bool // empty queued packets
stopSending chan bool // stop sending pipeline
stopInitiator chan bool // stop initiator timer
}
timer struct {
sendKeepalive time.Timer
handshakeTimeout time.Timer
}
queue struct {
nonce chan []byte // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work
}
mac MacStatePeer mac MacStatePeer
} }
@ -33,7 +47,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
peer.device = device peer.device = device
peer.keyPairs.Init() peer.keyPairs.Init()
peer.mac.Init(pk) peer.mac.Init(pk)
peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize) peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.nonce = make(chan []byte, QueueOutboundSize)
// map public key // map public key
@ -54,5 +69,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
handshake.mutex.Unlock() handshake.mutex.Unlock()
peer.mutex.Unlock() peer.mutex.Unlock()
// start workers
peer.signal.stopSending = make(chan bool, 1)
peer.signal.stopInitiator = make(chan bool, 1)
peer.signal.newHandshake = make(chan bool, 1)
peer.signal.flushNonceQueue = make(chan bool, 1)
go peer.RoutineNonce()
go peer.RoutineHandshakeInitiator()
return &peer return &peer
} }
func (peer *Peer) Close() {
peer.signal.stopSending <- true
peer.signal.stopInitiator <- true
}

View File

@ -12,9 +12,20 @@ type RoutingTable struct {
mutex sync.RWMutex mutex sync.RWMutex
} }
func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
table.mutex.RLock()
defer table.mutex.RUnlock()
allowed := make([]net.IPNet, 10)
table.IPv4.AllowedIPs(peer, allowed)
table.IPv6.AllowedIPs(peer, allowed)
return allowed
}
func (table *RoutingTable) Reset() { func (table *RoutingTable) Reset() {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
table.IPv4 = nil table.IPv4 = nil
table.IPv6 = nil table.IPv6 = nil
} }
@ -22,6 +33,7 @@ func (table *RoutingTable) Reset() {
func (table *RoutingTable) RemovePeer(peer *Peer) { func (table *RoutingTable) RemovePeer(peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
table.IPv4 = table.IPv4.RemovePeer(peer) table.IPv4 = table.IPv4.RemovePeer(peer)
table.IPv6 = table.IPv6.RemovePeer(peer) table.IPv6 = table.IPv6.RemovePeer(peer)
} }

View File

@ -5,30 +5,78 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"net" "net"
"sync" "sync"
"time"
) )
/* Handles outbound flow /* Handles outbound flow
* *
* 1. TUN queue * 1. TUN queue
* 2. Routing * 2. Routing (sequential)
* 3. Per peer queuing * 3. Nonce assignment (sequential)
* 4. (work queuing) * 4. Encryption (parallel)
* 5. Transmission (sequential)
* *
* The order of packets (per peer) is maintained.
* The functions in this file occure (roughly) in the order packets are processed.
*/ */
type OutboundWorkQueueElement struct { /* A work unit
wg sync.WaitGroup *
* The sequential consumers will attempt to take the lock,
* workers release lock when they have completed work on the packet.
*/
type QueueOutboundElement struct {
mutex sync.Mutex
packet []byte packet []byte
nonce uint64 nonce uint64
keyPair *KeyPair keyPair *KeyPair
} }
func (peer *Peer) HandshakeWorker(handshakeQueue []byte) { func (peer *Peer) FlushNonceQueue() {
elems := len(peer.queue.nonce)
for i := 0; i < elems; i += 1 {
select {
case <-peer.queue.nonce:
default:
return
}
}
} }
func (device *Device) SendPacket(packet []byte) { func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
for {
select {
case peer.queue.outbound <- elem:
default:
select {
case <-peer.queue.outbound:
default:
}
}
}
}
/* Reads packets from the TUN and inserts
* into nonce queue for peer
*
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
for {
// read packet
packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
size, err := tun.Read(packet)
if err != nil {
device.log.Error.Println("Failed to read packet from TUN device:", err)
continue
}
packet = packet[:size]
if len(packet) < IPv4headerSize {
device.log.Error.Println("Packet too short, length:", len(packet))
continue
}
device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove.
// lookup peer // lookup peer
@ -43,69 +91,73 @@ func (device *Device) SendPacket(packet []byte) {
peer = device.routingTable.LookupIPv6(dst) peer = device.routingTable.LookupIPv6(dst)
default: default:
device.log.Debug.Println("receieved packet with unknown IP version") device.log.Debug.Println("Receieved packet with unknown IP version")
return return
} }
if peer == nil { if peer == nil {
device.log.Debug.Println("No peer configured for IP")
return return
} }
// insert into peer queue // insert into nonce/pre-handshake queue
for { for {
select { select {
case peer.queueOutboundRouting <- packet: case peer.queue.nonce <- packet:
default: default:
select { select {
case <-peer.queueOutboundRouting: case <-peer.queue.nonce:
default: default:
} }
continue continue
} }
break break
} }
}
} }
/* Go routine /* Queues packets when there is no handshake.
* Then assigns nonces to packets sequentially
* and creates "work" structs for workers
* *
* TODO: Avoid dynamic allocation of work queue elements
* *
* 1. waits for handshake. * Obs. A single instance per peer
* 2. assigns key pair & nonce
* 3. inserts to working queue
*
* TODO: avoid dynamic allocation of work queue elements
*/ */
func (peer *Peer) RoutineOutboundNonceWorker() { func (peer *Peer) RoutineNonce() {
var packet []byte var packet []byte
var keyPair *KeyPair var keyPair *KeyPair
var flushTimer time.Timer
for { for {
// wait for packet // wait for packet
if packet == nil { if packet == nil {
packet = <-peer.queueOutboundRouting select {
case packet = <-peer.queue.nonce:
case <-peer.signal.stopSending:
close(peer.queue.outbound)
return
}
} }
// wait for key pair // wait for key pair
for keyPair == nil { for keyPair == nil {
flushTimer.Reset(time.Second * 10) peer.signal.newHandshake <- true
// TODO: Handshake or NOP
select { select {
case <-peer.keyPairs.newKeyPair: case <-peer.keyPairs.newKeyPair:
keyPair = peer.keyPairs.Current() keyPair = peer.keyPairs.Current()
continue continue
case <-flushTimer.C: case <-peer.signal.flushNonceQueue:
size := len(peer.queueOutboundRouting) peer.FlushNonceQueue()
for i := 0; i < size; i += 1 {
<-peer.queueOutboundRouting
}
packet = nil packet = nil
continue
case <-peer.signal.stopSending:
close(peer.queue.outbound)
return
} }
break
} }
// process current packet // process current packet
@ -114,14 +166,13 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
// create work element // create work element
work := new(OutboundWorkQueueElement) work := new(QueueOutboundElement) // TODO: profile, maybe use pool
work.wg.Add(1)
work.keyPair = keyPair work.keyPair = keyPair
work.packet = packet work.packet = packet
work.nonce = keyPair.sendNonce work.nonce = keyPair.sendNonce
work.mutex.Lock()
packet = nil packet = nil
peer.queueOutbound <- work
keyPair.sendNonce += 1 keyPair.sendNonce += 1
// drop packets until there is space // drop packets until there is space
@ -129,46 +180,36 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
func() { func() {
for { for {
select { select {
case peer.device.queueWorkOutbound <- work: case peer.device.queue.encryption <- work:
return return
default: default:
drop := <-peer.device.queueWorkOutbound drop := <-peer.device.queue.encryption
drop.packet = nil drop.packet = nil
drop.wg.Done() drop.mutex.Unlock()
} }
} }
}() }()
peer.queue.outbound <- work
} }
} }
} }
/* Go routine /* Encrypts the elements in the queue
* * and marks them for sequential consumption (by releasing the mutex)
* sequentially reads packets from queue and sends to endpoint
* *
* Obs. One instance per core
*/ */
func (peer *Peer) RoutineSequential() { func (device *Device) RoutineEncryption() {
for work := range peer.queueOutbound {
work.wg.Wait()
if work.packet == nil {
continue
}
if peer.endpoint == nil {
continue
}
peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
}
}
func (device *Device) RoutineEncryptionWorker() {
var nonce [chacha20poly1305.NonceSize]byte var nonce [chacha20poly1305.NonceSize]byte
for work := range device.queueWorkOutbound { for work := range device.queue.encryption {
// pad packet // pad packet
padding := device.mtu - len(work.packet) padding := device.mtu - len(work.packet)
if padding < 0 { if padding < 0 {
// drop
work.packet = nil work.packet = nil
work.wg.Done() work.mutex.Unlock()
} }
for n := 0; n < padding; n += 1 { for n := 0; n < padding; n += 1 {
work.packet = append(work.packet, 0) work.packet = append(work.packet, 0)
@ -183,6 +224,30 @@ func (device *Device) RoutineEncryptionWorker() {
work.packet, work.packet,
nil, nil,
) )
work.wg.Done() work.mutex.Unlock()
}
}
/* Sequentially reads packets from queue and sends to endpoint
*
* Obs. Single instance per peer.
* The routine terminates then the outbound queue is closed.
*/
func (peer *Peer) RoutineSequential() {
for work := range peer.queue.outbound {
work.mutex.Lock()
func() {
peer.mutex.RLock()
defer peer.mutex.RUnlock()
if work.packet == nil {
return
}
if peer.endpoint == nil {
return
}
peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval)
}()
work.mutex.Unlock()
} }
} }

View File

@ -1,15 +1,20 @@
package main package main
import ( import (
"errors"
"net" "net"
) )
/* Binary trie /* Binary trie
*
* The net.IPs used here are not formatted the
* same way as those created by the "net" functions.
* Here the IPs are slices of either 4 or 16 byte (not always 16)
* *
* Syncronization done seperatly * Syncronization done seperatly
* See: routing.go * See: routing.go
* *
* Todo: Better commenting * TODO: Better commenting
*/ */
type Trie struct { type Trie struct {
@ -24,7 +29,7 @@ type Trie struct {
} }
/* Finds length of matching prefix /* Finds length of matching prefix
* Maybe there is a faster way * TODO: Make faster
* *
* Assumption: len(ip1) == len(ip2) * Assumption: len(ip1) == len(ip2)
*/ */
@ -189,3 +194,25 @@ func (node *Trie) Count() uint {
r := node.child[1].Count() r := node.child[1].Count()
return l + r return l + r
} }
func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
if node.peer == p {
var mask net.IPNet
mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
if len(node.bits) == net.IPv4len {
mask.IP = net.IPv4(
node.bits[0],
node.bits[1],
node.bits[2],
node.bits[3],
)
} else if len(node.bits) == net.IPv6len {
mask.IP = node.bits
} else {
panic(errors.New("bug: unexpected address length"))
}
results = append(results, mask)
}
node.child[0].AllowedIPs(p, results)
node.child[1].AllowedIPs(p, results)
}

View File

@ -1,6 +1,6 @@
package main package main
type TUN interface { type TUNDevice interface {
Read([]byte) (int, error) Read([]byte) (int, error)
Write([]byte) (int, error) Write([]byte) (int, error)
Name() string Name() string

View File

@ -9,9 +9,7 @@ import (
"unsafe" "unsafe"
) )
/* Platform dependent functions for interacting with /* Implementation of the TUN device interface for linux
* TUN devices on linux systems
*
*/ */
const CloneDevicePath = "/dev/net/tun" const CloneDevicePath = "/dev/net/tun"
@ -45,7 +43,7 @@ func (tun *NativeTun) Read(d []byte) (int, error) {
return tun.fd.Read(d) return tun.fd.Read(d)
} }
func CreateTUN(name string) (TUN, error) { func CreateTUN(name string) (TUNDevice, error) {
// Open clone device // Open clone device
fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0) fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0)
if err != nil { if err != nil {
@ -53,7 +51,7 @@ func CreateTUN(name string) (TUN, error) {
} }
// Prepare ifreq struct // Prepare ifreq struct
var ifr [18]byte var ifr [128]byte
var flags uint16 = IFF_TUN | IFF_NO_PI var flags uint16 = IFF_TUN | IFF_NO_PI
nameBytes := []byte(name) nameBytes := []byte(name)
if len(nameBytes) >= IFNAMSIZ { if len(nameBytes) >= IFNAMSIZ {