Work on UAPI
Cross-platform API (get operation) Handshake initiation creation process Outbound packet flow Fixes from code-review
This commit is contained in:
parent
8236f3afa2
commit
1f0976a26c
9
src/Makefile
Normal file
9
src/Makefile
Normal file
@ -0,0 +1,9 @@
|
||||
BINARY=wireguard-go
|
||||
|
||||
build:
|
||||
go build -o ${BINARY}
|
||||
|
||||
clean:
|
||||
if [ -f ${BINARY} ]; then rm ${BINARY}; fi
|
||||
|
||||
.PHONY: clean
|
@ -11,7 +11,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
/* todo : use real error code
|
||||
/* TODO : use real error code
|
||||
* Many of which will be the same
|
||||
*/
|
||||
const (
|
||||
@ -37,8 +37,55 @@ func (s *IPCError) ErrorCode() int {
|
||||
return s.Code
|
||||
}
|
||||
|
||||
func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
|
||||
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
|
||||
|
||||
device.mutex.RLock()
|
||||
defer device.mutex.RUnlock()
|
||||
|
||||
// create lines
|
||||
|
||||
lines := make([]string, 0, 100)
|
||||
send := func(line string) {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
|
||||
if !device.privateKey.IsZero() {
|
||||
send("private_key=" + device.privateKey.ToHex())
|
||||
}
|
||||
|
||||
if device.address != nil {
|
||||
send(fmt.Sprintf("listen_port=%d", device.address.Port))
|
||||
}
|
||||
|
||||
for _, peer := range device.peers {
|
||||
func() {
|
||||
peer.mutex.RLock()
|
||||
defer peer.mutex.RUnlock()
|
||||
send("public_key=" + peer.handshake.remoteStatic.ToHex())
|
||||
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
|
||||
if peer.endpoint != nil {
|
||||
send("endpoint=" + peer.endpoint.String())
|
||||
}
|
||||
send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes))
|
||||
send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes))
|
||||
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
|
||||
for _, ip := range device.routingTable.AllowedIPs(peer) {
|
||||
send("allowed_ip=" + ip.String())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// send lines
|
||||
|
||||
for _, line := range lines {
|
||||
device.log.Debug.Println("config:", line)
|
||||
_, err := socket.WriteString(line + "\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
@ -179,7 +226,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ipcListen(dev *Device, socket io.ReadWriter) error {
|
||||
func ipcListen(device *Device, socket io.ReadWriter) error {
|
||||
|
||||
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
||||
reader := bufio.NewReader(s)
|
||||
@ -187,6 +234,8 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
|
||||
return bufio.NewReadWriter(reader, writer)
|
||||
}(socket)
|
||||
|
||||
defer buffered.Flush()
|
||||
|
||||
for {
|
||||
op, err := buffered.ReadString('\n')
|
||||
if err != nil {
|
||||
@ -197,17 +246,26 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
|
||||
switch op {
|
||||
|
||||
case "set=1\n":
|
||||
err := ipcSetOperation(dev, buffered)
|
||||
err := ipcSetOperation(device, buffered)
|
||||
if err != nil {
|
||||
fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode())
|
||||
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
|
||||
return err
|
||||
} else {
|
||||
fmt.Fprintf(buffered, "errno=0\n")
|
||||
fmt.Fprintf(buffered, "errno=0\n\n")
|
||||
}
|
||||
buffered.Flush()
|
||||
|
||||
case "get=1\n":
|
||||
err := ipcGetOperation(device, buffered)
|
||||
if err != nil {
|
||||
fmt.Fprintf(buffered, "errno=1\n\n") // fix
|
||||
return err
|
||||
} else {
|
||||
fmt.Fprintf(buffered, "errno=0\n\n")
|
||||
}
|
||||
buffered.Flush()
|
||||
|
||||
case "\n":
|
||||
default:
|
||||
return errors.New("handle this please")
|
||||
}
|
||||
|
@ -8,9 +8,14 @@ const (
|
||||
RekeyAfterMessage = (1 << 64) - (1 << 16) - 1
|
||||
RekeyAfterTime = time.Second * 120
|
||||
RekeyAttemptTime = time.Second * 90
|
||||
RekeyTimeout = time.Second * 5
|
||||
RekeyTimeout = time.Second * 5 // TODO: Exponential backoff
|
||||
RejectAfterTime = time.Second * 180
|
||||
RejectAfterMessage = (1 << 64) - (1 << 4) - 1
|
||||
KeepaliveTimeout = time.Second * 10
|
||||
CookieRefreshTime = time.Second * 2
|
||||
MaxHandshakeAttempTime = time.Second * 90
|
||||
)
|
||||
|
||||
const (
|
||||
QueueOutboundSize = 1024
|
||||
)
|
||||
|
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@ -16,7 +17,9 @@ type Device struct {
|
||||
routingTable RoutingTable
|
||||
indices IndexTable
|
||||
log *Logger
|
||||
queueWorkOutbound chan *OutboundWorkQueueElement
|
||||
queue struct {
|
||||
encryption chan *QueueOutboundElement // parallel work queue
|
||||
}
|
||||
peers map[NoisePublicKey]*Peer
|
||||
mac MacStateDevice
|
||||
}
|
||||
@ -41,7 +44,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) Init() {
|
||||
func NewDevice(tun TUNDevice) *Device {
|
||||
device := new(Device)
|
||||
|
||||
device.mutex.Lock()
|
||||
defer device.mutex.Unlock()
|
||||
|
||||
@ -49,6 +54,14 @@ func (device *Device) Init() {
|
||||
device.peers = make(map[NoisePublicKey]*Peer)
|
||||
device.indices.Init()
|
||||
device.routingTable.Reset()
|
||||
|
||||
// start workers
|
||||
|
||||
for i := 0; i < runtime.NumCPU(); i += 1 {
|
||||
go device.RoutineEncryption()
|
||||
}
|
||||
go device.RoutineReadFromTUN(tun)
|
||||
return device
|
||||
}
|
||||
|
||||
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
||||
|
172
src/handshake.go
Normal file
172
src/handshake.go
Normal file
@ -0,0 +1,172 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
/* Sends a keep-alive if no packets queued for peer
|
||||
*
|
||||
* Used by initiator of handshake and with active keep-alive
|
||||
*/
|
||||
func (peer *Peer) SendKeepAlive() bool {
|
||||
if len(peer.queue.nonce) == 0 {
|
||||
select {
|
||||
case peer.queue.nonce <- []byte{}:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (peer *Peer) RoutineHandshakeInitiator() {
|
||||
var ongoing bool
|
||||
var begun time.Time
|
||||
var attempts uint
|
||||
var timeout time.Timer
|
||||
|
||||
device := peer.device
|
||||
work := new(QueueOutboundElement)
|
||||
buffer := make([]byte, 0, 1024)
|
||||
|
||||
queueHandshakeInitiation := func() error {
|
||||
work.mutex.Lock()
|
||||
defer work.mutex.Unlock()
|
||||
|
||||
// create initiation
|
||||
|
||||
msg, err := device.CreateMessageInitiation(peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// create "work" element
|
||||
|
||||
writer := bytes.NewBuffer(buffer[:0])
|
||||
binary.Write(writer, binary.LittleEndian, &msg)
|
||||
work.packet = writer.Bytes()
|
||||
peer.mac.AddMacs(work.packet)
|
||||
peer.InsertOutbound(work)
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-peer.signal.stopInitiator:
|
||||
return
|
||||
|
||||
case <-peer.signal.newHandshake:
|
||||
if ongoing {
|
||||
continue
|
||||
}
|
||||
|
||||
// create handshake
|
||||
|
||||
err := queueHandshakeInitiation()
|
||||
if err != nil {
|
||||
device.log.Error.Println("Failed to create initiation message:", err)
|
||||
}
|
||||
|
||||
// log when we began
|
||||
|
||||
begun = time.Now()
|
||||
ongoing = true
|
||||
attempts = 0
|
||||
timeout.Reset(RekeyTimeout)
|
||||
|
||||
case <-peer.timer.sendKeepalive.C:
|
||||
|
||||
// active keep-alives
|
||||
|
||||
peer.SendKeepAlive()
|
||||
|
||||
case <-peer.timer.handshakeTimeout.C:
|
||||
|
||||
// check if we can stop trying
|
||||
|
||||
if time.Now().Sub(begun) > MaxHandshakeAttempTime {
|
||||
peer.signal.flushNonceQueue <- true
|
||||
peer.timer.sendKeepalive.Stop()
|
||||
ongoing = false
|
||||
continue
|
||||
}
|
||||
|
||||
// otherwise, try again (exponental backoff)
|
||||
|
||||
attempts += 1
|
||||
err := queueHandshakeInitiation()
|
||||
if err != nil {
|
||||
device.log.Error.Println("Failed to create initiation message:", err)
|
||||
}
|
||||
peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Handles packets related to handshake
|
||||
*
|
||||
*
|
||||
*/
|
||||
func (device *Device) HandshakeWorker(queue chan struct {
|
||||
msg []byte
|
||||
msgType uint32
|
||||
addr *net.UDPAddr
|
||||
}) {
|
||||
for {
|
||||
elem := <-queue
|
||||
|
||||
switch elem.msgType {
|
||||
case MessageInitiationType:
|
||||
if len(elem.msg) != MessageInitiationSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// check for cookie
|
||||
|
||||
var msg MessageInitiation
|
||||
|
||||
binary.Read(nil, binary.LittleEndian, &msg)
|
||||
|
||||
case MessageResponseType:
|
||||
if len(elem.msg) != MessageResponseSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// check for cookie
|
||||
|
||||
case MessageCookieReplyType:
|
||||
|
||||
case MessageTransportType:
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) KeepKeyFresh(peer *Peer) {
|
||||
|
||||
send := func() bool {
|
||||
peer.keyPairs.mutex.RLock()
|
||||
defer peer.keyPairs.mutex.RUnlock()
|
||||
|
||||
kp := peer.keyPairs.current
|
||||
if kp == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
nonce := atomic.LoadUint64(&kp.sendNonce)
|
||||
if nonce > RekeyAfterMessage {
|
||||
return true
|
||||
}
|
||||
|
||||
return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime
|
||||
}()
|
||||
|
||||
if send {
|
||||
|
||||
}
|
||||
}
|
64
src/helper_test.go
Normal file
64
src/helper_test.go
Normal file
@ -0,0 +1,64 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
/* Helpers for writing unit tests
|
||||
*/
|
||||
|
||||
type DummyTUN struct {
|
||||
name string
|
||||
mtu uint
|
||||
packets chan []byte
|
||||
}
|
||||
|
||||
func (tun *DummyTUN) Name() string {
|
||||
return tun.name
|
||||
}
|
||||
|
||||
func (tun *DummyTUN) MTU() uint {
|
||||
return tun.mtu
|
||||
}
|
||||
|
||||
func (tun *DummyTUN) Write(d []byte) (int, error) {
|
||||
tun.packets <- d
|
||||
return len(d), nil
|
||||
}
|
||||
|
||||
func (tun *DummyTUN) Read(d []byte) (int, error) {
|
||||
t := <-tun.packets
|
||||
copy(d, t)
|
||||
return len(t), nil
|
||||
}
|
||||
|
||||
func CreateDummyTUN(name string) (TUNDevice, error) {
|
||||
var dummy DummyTUN
|
||||
dummy.mtu = 1024
|
||||
dummy.packets = make(chan []byte, 100)
|
||||
return &dummy, nil
|
||||
}
|
||||
|
||||
func assertNil(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertEqual(t *testing.T, a []byte, b []byte) {
|
||||
if bytes.Compare(a, b) != 0 {
|
||||
t.Fatal(a, "!=", b)
|
||||
}
|
||||
}
|
||||
|
||||
func randDevice(t *testing.T) *Device {
|
||||
sk, err := newPrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tun, _ := CreateDummyTUN("dummy")
|
||||
device := NewDevice(tun)
|
||||
device.SetPrivateKey(sk)
|
||||
return device
|
||||
}
|
@ -8,6 +8,7 @@ const (
|
||||
IPv4version = 4
|
||||
IPv4offsetSrc = 12
|
||||
IPv4offsetDst = IPv4offsetSrc + net.IPv4len
|
||||
IPv4headerSize = 20
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -8,8 +8,8 @@ import (
|
||||
)
|
||||
|
||||
func TestMAC1(t *testing.T) {
|
||||
dev1 := newDevice(t)
|
||||
dev2 := newDevice(t)
|
||||
dev1 := randDevice(t)
|
||||
dev2 := randDevice(t)
|
||||
|
||||
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
|
||||
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
|
||||
@ -34,12 +34,10 @@ func TestMACs(t *testing.T) {
|
||||
msg []byte,
|
||||
receiver uint32,
|
||||
) bool {
|
||||
var device1 Device
|
||||
device1.Init()
|
||||
device1 := randDevice(t)
|
||||
device1.SetPrivateKey(sk1)
|
||||
|
||||
var device2 Device
|
||||
device2.Init()
|
||||
device2 := randDevice(t)
|
||||
device2.SetPrivateKey(sk2)
|
||||
|
||||
peer1 := device2.NewPeer(device1.privateKey.publicKey())
|
||||
|
53
src/main.go
53
src/main.go
@ -1,36 +1,30 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fd, err := CreateTUN("test0")
|
||||
fmt.Println(fd, err)
|
||||
|
||||
queue := make(chan []byte, 1000)
|
||||
|
||||
// var device Device
|
||||
|
||||
// go OutgoingRoutingWorker(&device, queue)
|
||||
|
||||
for {
|
||||
tmp := make([]byte, 1<<16)
|
||||
n, err := fd.Read(tmp)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
queue <- tmp[:n]
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
)
|
||||
|
||||
/*
|
||||
*
|
||||
* TODO: Fix logging
|
||||
*/
|
||||
|
||||
func main() {
|
||||
// Open TUN device
|
||||
|
||||
// TODO: Fix capabilities
|
||||
|
||||
tun, err := CreateTUN("test0")
|
||||
log.Println(tun, err)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
device := NewDevice(tun)
|
||||
|
||||
// Start configuration lister
|
||||
|
||||
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
|
||||
if err != nil {
|
||||
log.Fatal("listen error:", err)
|
||||
@ -41,12 +35,9 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal("accept error:", err)
|
||||
}
|
||||
|
||||
var dev Device
|
||||
go func(conn net.Conn) {
|
||||
err := ipcListen(&dev, conn)
|
||||
fmt.Println(err)
|
||||
err := ipcListen(device, conn)
|
||||
log.Println(err)
|
||||
}(fd)
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
@ -77,7 +77,7 @@ type MessageCookieReply struct {
|
||||
|
||||
type Handshake struct {
|
||||
state int
|
||||
mutex sync.Mutex
|
||||
mutex sync.RWMutex
|
||||
hash [blake2s.Size]byte // hash value
|
||||
chainKey [blake2s.Size]byte // chain key
|
||||
presharedKey NoiseSymmetricKey // psk
|
||||
@ -205,19 +205,26 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||
}
|
||||
hash = mixHash(hash, msg.Static[:])
|
||||
|
||||
// find peer
|
||||
// lookup peer
|
||||
|
||||
peer := device.LookupPeer(peerPK)
|
||||
if peer == nil {
|
||||
return nil
|
||||
}
|
||||
handshake := &peer.handshake
|
||||
handshake.mutex.Lock()
|
||||
defer handshake.mutex.Unlock()
|
||||
|
||||
// verify identity
|
||||
|
||||
var timestamp TAI64N
|
||||
ok := func() bool {
|
||||
|
||||
// read lock handshake
|
||||
|
||||
handshake.mutex.RLock()
|
||||
defer handshake.mutex.RUnlock()
|
||||
|
||||
// decrypt timestamp
|
||||
|
||||
var timestamp TAI64N
|
||||
func() {
|
||||
var key [chacha20poly1305.KeySize]byte
|
||||
chainKey, key = KDF2(
|
||||
@ -228,26 +235,34 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
||||
}()
|
||||
if err != nil {
|
||||
return nil
|
||||
return false
|
||||
}
|
||||
hash = mixHash(hash, msg.Timestamp[:])
|
||||
|
||||
// TODO: check for flood attack
|
||||
|
||||
// check for replay attack
|
||||
|
||||
if !timestamp.After(handshake.lastTimestamp) {
|
||||
return timestamp.After(handshake.lastTimestamp)
|
||||
}()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: check for flood attack
|
||||
|
||||
// update handshake state
|
||||
|
||||
handshake.mutex.Lock()
|
||||
|
||||
handshake.hash = hash
|
||||
handshake.chainKey = chainKey
|
||||
handshake.remoteIndex = msg.Sender
|
||||
handshake.remoteEphemeral = msg.Ephemeral
|
||||
handshake.lastTimestamp = timestamp
|
||||
handshake.state = HandshakeInitiationConsumed
|
||||
|
||||
handshake.mutex.Unlock()
|
||||
|
||||
return peer
|
||||
}
|
||||
|
||||
@ -320,16 +335,26 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
return nil
|
||||
}
|
||||
|
||||
handshake.mutex.Lock()
|
||||
defer handshake.mutex.Unlock()
|
||||
var (
|
||||
hash [blake2s.Size]byte
|
||||
chainKey [blake2s.Size]byte
|
||||
)
|
||||
|
||||
ok := func() bool {
|
||||
|
||||
// read lock handshake
|
||||
|
||||
handshake.mutex.RLock()
|
||||
defer handshake.mutex.RUnlock()
|
||||
|
||||
if handshake.state != HandshakeInitiationCreated {
|
||||
return nil
|
||||
return false
|
||||
}
|
||||
|
||||
// finish 3-way DH
|
||||
|
||||
hash := mixHash(handshake.hash, msg.Ephemeral[:])
|
||||
chainKey := handshake.chainKey
|
||||
hash = mixHash(handshake.hash, msg.Ephemeral[:])
|
||||
chainKey = handshake.chainKey
|
||||
|
||||
func() {
|
||||
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
||||
@ -350,17 +375,27 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
aead, _ := chacha20poly1305.New(key[:])
|
||||
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
||||
if err != nil {
|
||||
return nil
|
||||
return false
|
||||
}
|
||||
hash = mixHash(hash, msg.Empty[:])
|
||||
return true
|
||||
}()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// update handshake state
|
||||
|
||||
handshake.mutex.Lock()
|
||||
|
||||
handshake.hash = hash
|
||||
handshake.chainKey = chainKey
|
||||
handshake.remoteIndex = msg.Sender
|
||||
handshake.state = HandshakeResponseConsumed
|
||||
|
||||
handshake.mutex.Unlock()
|
||||
|
||||
return lookup.peer
|
||||
}
|
||||
|
||||
|
@ -6,29 +6,6 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func assertNil(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertEqual(t *testing.T, a []byte, b []byte) {
|
||||
if bytes.Compare(a, b) != 0 {
|
||||
t.Fatal(a, "!=", b)
|
||||
}
|
||||
}
|
||||
|
||||
func newDevice(t *testing.T) *Device {
|
||||
var device Device
|
||||
sk, err := newPrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
device.Init()
|
||||
device.SetPrivateKey(sk)
|
||||
return &device
|
||||
}
|
||||
|
||||
func TestCurveWrappers(t *testing.T) {
|
||||
sk1, err := newPrivateKey()
|
||||
assertNil(t, err)
|
||||
@ -49,8 +26,8 @@ func TestCurveWrappers(t *testing.T) {
|
||||
|
||||
func TestNoiseHandshake(t *testing.T) {
|
||||
|
||||
dev1 := newDevice(t)
|
||||
dev2 := newDevice(t)
|
||||
dev1 := randDevice(t)
|
||||
dev2 := randDevice(t)
|
||||
|
||||
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
|
||||
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
|
||||
|
@ -3,18 +3,18 @@ package main
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
const (
|
||||
NoisePublicKeySize = 32
|
||||
NoisePrivateKeySize = 32
|
||||
NoiseSymmetricKeySize = 32
|
||||
)
|
||||
|
||||
type (
|
||||
NoisePublicKey [NoisePublicKeySize]byte
|
||||
NoisePrivateKey [NoisePrivateKeySize]byte
|
||||
NoiseSymmetricKey [NoiseSymmetricKeySize]byte
|
||||
NoiseSymmetricKey [chacha20poly1305.KeySize]byte
|
||||
NoiseNonce uint64 // padded to 12-bytes
|
||||
)
|
||||
|
||||
@ -30,6 +30,15 @@ func loadExactHex(dst []byte, src string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (key NoisePrivateKey) IsZero() bool {
|
||||
for _, b := range key[:] {
|
||||
if b != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (key *NoisePrivateKey) FromHex(src string) error {
|
||||
return loadExactHex(key[:], src)
|
||||
}
|
||||
|
44
src/peer.go
44
src/peer.go
@ -7,9 +7,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
OutboundQueueSize = 64
|
||||
)
|
||||
const ()
|
||||
|
||||
type Peer struct {
|
||||
mutex sync.RWMutex
|
||||
@ -18,9 +16,25 @@ type Peer struct {
|
||||
keyPairs KeyPairs
|
||||
handshake Handshake
|
||||
device *Device
|
||||
queueInbound chan []byte
|
||||
queueOutbound chan *OutboundWorkQueueElement
|
||||
queueOutboundRouting chan []byte
|
||||
tx_bytes uint64
|
||||
rx_bytes uint64
|
||||
time struct {
|
||||
lastSend time.Time // last send message
|
||||
}
|
||||
signal struct {
|
||||
newHandshake chan bool
|
||||
flushNonceQueue chan bool // empty queued packets
|
||||
stopSending chan bool // stop sending pipeline
|
||||
stopInitiator chan bool // stop initiator timer
|
||||
}
|
||||
timer struct {
|
||||
sendKeepalive time.Timer
|
||||
handshakeTimeout time.Timer
|
||||
}
|
||||
queue struct {
|
||||
nonce chan []byte // nonce / pre-handshake queue
|
||||
outbound chan *QueueOutboundElement // sequential ordering of work
|
||||
}
|
||||
mac MacStatePeer
|
||||
}
|
||||
|
||||
@ -33,7 +47,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
|
||||
peer.device = device
|
||||
peer.keyPairs.Init()
|
||||
peer.mac.Init(pk)
|
||||
peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
|
||||
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||
peer.queue.nonce = make(chan []byte, QueueOutboundSize)
|
||||
|
||||
// map public key
|
||||
|
||||
@ -54,5 +69,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
|
||||
handshake.mutex.Unlock()
|
||||
peer.mutex.Unlock()
|
||||
|
||||
// start workers
|
||||
|
||||
peer.signal.stopSending = make(chan bool, 1)
|
||||
peer.signal.stopInitiator = make(chan bool, 1)
|
||||
peer.signal.newHandshake = make(chan bool, 1)
|
||||
peer.signal.flushNonceQueue = make(chan bool, 1)
|
||||
|
||||
go peer.RoutineNonce()
|
||||
go peer.RoutineHandshakeInitiator()
|
||||
|
||||
return &peer
|
||||
}
|
||||
|
||||
func (peer *Peer) Close() {
|
||||
peer.signal.stopSending <- true
|
||||
peer.signal.stopInitiator <- true
|
||||
}
|
||||
|
@ -12,9 +12,20 @@ type RoutingTable struct {
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
|
||||
table.mutex.RLock()
|
||||
defer table.mutex.RUnlock()
|
||||
|
||||
allowed := make([]net.IPNet, 10)
|
||||
table.IPv4.AllowedIPs(peer, allowed)
|
||||
table.IPv6.AllowedIPs(peer, allowed)
|
||||
return allowed
|
||||
}
|
||||
|
||||
func (table *RoutingTable) Reset() {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
table.IPv4 = nil
|
||||
table.IPv6 = nil
|
||||
}
|
||||
@ -22,6 +33,7 @@ func (table *RoutingTable) Reset() {
|
||||
func (table *RoutingTable) RemovePeer(peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
table.IPv4 = table.IPv4.RemovePeer(peer)
|
||||
table.IPv6 = table.IPv6.RemovePeer(peer)
|
||||
}
|
||||
|
177
src/send.go
177
src/send.go
@ -5,30 +5,78 @@ import (
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
/* Handles outbound flow
|
||||
*
|
||||
* 1. TUN queue
|
||||
* 2. Routing
|
||||
* 3. Per peer queuing
|
||||
* 4. (work queuing)
|
||||
* 2. Routing (sequential)
|
||||
* 3. Nonce assignment (sequential)
|
||||
* 4. Encryption (parallel)
|
||||
* 5. Transmission (sequential)
|
||||
*
|
||||
* The order of packets (per peer) is maintained.
|
||||
* The functions in this file occure (roughly) in the order packets are processed.
|
||||
*/
|
||||
|
||||
type OutboundWorkQueueElement struct {
|
||||
wg sync.WaitGroup
|
||||
/* A work unit
|
||||
*
|
||||
* The sequential consumers will attempt to take the lock,
|
||||
* workers release lock when they have completed work on the packet.
|
||||
*/
|
||||
type QueueOutboundElement struct {
|
||||
mutex sync.Mutex
|
||||
packet []byte
|
||||
nonce uint64
|
||||
keyPair *KeyPair
|
||||
}
|
||||
|
||||
func (peer *Peer) HandshakeWorker(handshakeQueue []byte) {
|
||||
|
||||
func (peer *Peer) FlushNonceQueue() {
|
||||
elems := len(peer.queue.nonce)
|
||||
for i := 0; i < elems; i += 1 {
|
||||
select {
|
||||
case <-peer.queue.nonce:
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) SendPacket(packet []byte) {
|
||||
func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
|
||||
for {
|
||||
select {
|
||||
case peer.queue.outbound <- elem:
|
||||
default:
|
||||
select {
|
||||
case <-peer.queue.outbound:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Reads packets from the TUN and inserts
|
||||
* into nonce queue for peer
|
||||
*
|
||||
* Obs. Single instance per TUN device
|
||||
*/
|
||||
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
|
||||
for {
|
||||
// read packet
|
||||
|
||||
packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
|
||||
size, err := tun.Read(packet)
|
||||
if err != nil {
|
||||
device.log.Error.Println("Failed to read packet from TUN device:", err)
|
||||
continue
|
||||
}
|
||||
packet = packet[:size]
|
||||
if len(packet) < IPv4headerSize {
|
||||
device.log.Error.Println("Packet too short, length:", len(packet))
|
||||
continue
|
||||
}
|
||||
|
||||
device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove.
|
||||
|
||||
// lookup peer
|
||||
|
||||
@ -43,69 +91,73 @@ func (device *Device) SendPacket(packet []byte) {
|
||||
peer = device.routingTable.LookupIPv6(dst)
|
||||
|
||||
default:
|
||||
device.log.Debug.Println("receieved packet with unknown IP version")
|
||||
device.log.Debug.Println("Receieved packet with unknown IP version")
|
||||
return
|
||||
}
|
||||
|
||||
if peer == nil {
|
||||
device.log.Debug.Println("No peer configured for IP")
|
||||
return
|
||||
}
|
||||
|
||||
// insert into peer queue
|
||||
// insert into nonce/pre-handshake queue
|
||||
|
||||
for {
|
||||
select {
|
||||
case peer.queueOutboundRouting <- packet:
|
||||
case peer.queue.nonce <- packet:
|
||||
default:
|
||||
select {
|
||||
case <-peer.queueOutboundRouting:
|
||||
case <-peer.queue.nonce:
|
||||
default:
|
||||
}
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Go routine
|
||||
/* Queues packets when there is no handshake.
|
||||
* Then assigns nonces to packets sequentially
|
||||
* and creates "work" structs for workers
|
||||
*
|
||||
* TODO: Avoid dynamic allocation of work queue elements
|
||||
*
|
||||
* 1. waits for handshake.
|
||||
* 2. assigns key pair & nonce
|
||||
* 3. inserts to working queue
|
||||
*
|
||||
* TODO: avoid dynamic allocation of work queue elements
|
||||
* Obs. A single instance per peer
|
||||
*/
|
||||
func (peer *Peer) RoutineOutboundNonceWorker() {
|
||||
func (peer *Peer) RoutineNonce() {
|
||||
var packet []byte
|
||||
var keyPair *KeyPair
|
||||
var flushTimer time.Timer
|
||||
|
||||
for {
|
||||
|
||||
// wait for packet
|
||||
|
||||
if packet == nil {
|
||||
packet = <-peer.queueOutboundRouting
|
||||
select {
|
||||
case packet = <-peer.queue.nonce:
|
||||
case <-peer.signal.stopSending:
|
||||
close(peer.queue.outbound)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// wait for key pair
|
||||
|
||||
for keyPair == nil {
|
||||
flushTimer.Reset(time.Second * 10)
|
||||
// TODO: Handshake or NOP
|
||||
peer.signal.newHandshake <- true
|
||||
select {
|
||||
case <-peer.keyPairs.newKeyPair:
|
||||
keyPair = peer.keyPairs.Current()
|
||||
continue
|
||||
case <-flushTimer.C:
|
||||
size := len(peer.queueOutboundRouting)
|
||||
for i := 0; i < size; i += 1 {
|
||||
<-peer.queueOutboundRouting
|
||||
}
|
||||
case <-peer.signal.flushNonceQueue:
|
||||
peer.FlushNonceQueue()
|
||||
packet = nil
|
||||
continue
|
||||
case <-peer.signal.stopSending:
|
||||
close(peer.queue.outbound)
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// process current packet
|
||||
@ -114,14 +166,13 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
|
||||
|
||||
// create work element
|
||||
|
||||
work := new(OutboundWorkQueueElement)
|
||||
work.wg.Add(1)
|
||||
work := new(QueueOutboundElement) // TODO: profile, maybe use pool
|
||||
work.keyPair = keyPair
|
||||
work.packet = packet
|
||||
work.nonce = keyPair.sendNonce
|
||||
work.mutex.Lock()
|
||||
|
||||
packet = nil
|
||||
peer.queueOutbound <- work
|
||||
keyPair.sendNonce += 1
|
||||
|
||||
// drop packets until there is space
|
||||
@ -129,46 +180,36 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
|
||||
func() {
|
||||
for {
|
||||
select {
|
||||
case peer.device.queueWorkOutbound <- work:
|
||||
case peer.device.queue.encryption <- work:
|
||||
return
|
||||
default:
|
||||
drop := <-peer.device.queueWorkOutbound
|
||||
drop := <-peer.device.queue.encryption
|
||||
drop.packet = nil
|
||||
drop.wg.Done()
|
||||
drop.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
peer.queue.outbound <- work
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Go routine
|
||||
*
|
||||
* sequentially reads packets from queue and sends to endpoint
|
||||
/* Encrypts the elements in the queue
|
||||
* and marks them for sequential consumption (by releasing the mutex)
|
||||
*
|
||||
* Obs. One instance per core
|
||||
*/
|
||||
func (peer *Peer) RoutineSequential() {
|
||||
for work := range peer.queueOutbound {
|
||||
work.wg.Wait()
|
||||
if work.packet == nil {
|
||||
continue
|
||||
}
|
||||
if peer.endpoint == nil {
|
||||
continue
|
||||
}
|
||||
peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) RoutineEncryptionWorker() {
|
||||
func (device *Device) RoutineEncryption() {
|
||||
var nonce [chacha20poly1305.NonceSize]byte
|
||||
for work := range device.queueWorkOutbound {
|
||||
for work := range device.queue.encryption {
|
||||
|
||||
// pad packet
|
||||
|
||||
padding := device.mtu - len(work.packet)
|
||||
if padding < 0 {
|
||||
// drop
|
||||
work.packet = nil
|
||||
work.wg.Done()
|
||||
work.mutex.Unlock()
|
||||
}
|
||||
for n := 0; n < padding; n += 1 {
|
||||
work.packet = append(work.packet, 0)
|
||||
@ -183,6 +224,30 @@ func (device *Device) RoutineEncryptionWorker() {
|
||||
work.packet,
|
||||
nil,
|
||||
)
|
||||
work.wg.Done()
|
||||
work.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
/* Sequentially reads packets from queue and sends to endpoint
|
||||
*
|
||||
* Obs. Single instance per peer.
|
||||
* The routine terminates then the outbound queue is closed.
|
||||
*/
|
||||
func (peer *Peer) RoutineSequential() {
|
||||
for work := range peer.queue.outbound {
|
||||
work.mutex.Lock()
|
||||
func() {
|
||||
peer.mutex.RLock()
|
||||
defer peer.mutex.RUnlock()
|
||||
if work.packet == nil {
|
||||
return
|
||||
}
|
||||
if peer.endpoint == nil {
|
||||
return
|
||||
}
|
||||
peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
|
||||
peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval)
|
||||
}()
|
||||
work.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
31
src/trie.go
31
src/trie.go
@ -1,15 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
/* Binary trie
|
||||
*
|
||||
* The net.IPs used here are not formatted the
|
||||
* same way as those created by the "net" functions.
|
||||
* Here the IPs are slices of either 4 or 16 byte (not always 16)
|
||||
*
|
||||
* Syncronization done seperatly
|
||||
* See: routing.go
|
||||
*
|
||||
* Todo: Better commenting
|
||||
* TODO: Better commenting
|
||||
*/
|
||||
|
||||
type Trie struct {
|
||||
@ -24,7 +29,7 @@ type Trie struct {
|
||||
}
|
||||
|
||||
/* Finds length of matching prefix
|
||||
* Maybe there is a faster way
|
||||
* TODO: Make faster
|
||||
*
|
||||
* Assumption: len(ip1) == len(ip2)
|
||||
*/
|
||||
@ -189,3 +194,25 @@ func (node *Trie) Count() uint {
|
||||
r := node.child[1].Count()
|
||||
return l + r
|
||||
}
|
||||
|
||||
func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
|
||||
if node.peer == p {
|
||||
var mask net.IPNet
|
||||
mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
|
||||
if len(node.bits) == net.IPv4len {
|
||||
mask.IP = net.IPv4(
|
||||
node.bits[0],
|
||||
node.bits[1],
|
||||
node.bits[2],
|
||||
node.bits[3],
|
||||
)
|
||||
} else if len(node.bits) == net.IPv6len {
|
||||
mask.IP = node.bits
|
||||
} else {
|
||||
panic(errors.New("bug: unexpected address length"))
|
||||
}
|
||||
results = append(results, mask)
|
||||
}
|
||||
node.child[0].AllowedIPs(p, results)
|
||||
node.child[1].AllowedIPs(p, results)
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
package main
|
||||
|
||||
type TUN interface {
|
||||
type TUNDevice interface {
|
||||
Read([]byte) (int, error)
|
||||
Write([]byte) (int, error)
|
||||
Name() string
|
||||
|
@ -9,9 +9,7 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
/* Platform dependent functions for interacting with
|
||||
* TUN devices on linux systems
|
||||
*
|
||||
/* Implementation of the TUN device interface for linux
|
||||
*/
|
||||
|
||||
const CloneDevicePath = "/dev/net/tun"
|
||||
@ -45,7 +43,7 @@ func (tun *NativeTun) Read(d []byte) (int, error) {
|
||||
return tun.fd.Read(d)
|
||||
}
|
||||
|
||||
func CreateTUN(name string) (TUN, error) {
|
||||
func CreateTUN(name string) (TUNDevice, error) {
|
||||
// Open clone device
|
||||
fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
@ -53,7 +51,7 @@ func CreateTUN(name string) (TUN, error) {
|
||||
}
|
||||
|
||||
// Prepare ifreq struct
|
||||
var ifr [18]byte
|
||||
var ifr [128]byte
|
||||
var flags uint16 = IFF_TUN | IFF_NO_PI
|
||||
nameBytes := []byte(name)
|
||||
if len(nameBytes) >= IFNAMSIZ {
|
||||
|
Loading…
Reference in New Issue
Block a user