Initial working full exchange

The implementation is now capable of connecting to another
wireguard instance, complete a handshake and exchange transport
messages.
This commit is contained in:
Mathias Hall-Andersen 2017-07-06 15:43:55 +02:00
parent 2aa0daf4d5
commit 59f9316f51
8 changed files with 184 additions and 201 deletions

View File

@ -21,5 +21,6 @@ const (
QueueInboundSize = 1024
QueueHandshakeSize = 1024
QueueHandshakeBusySize = QueueHandshakeSize / 8
MinMessageSize = MessageTransportSize
MinMessageSize = MessageTransportSize // keep-alive
MaxMessageSize = 4096
)

View File

@ -80,6 +80,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
device.queue.inbound = make(chan []byte, QueueInboundSize)
// prepare signals
@ -94,6 +95,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
}
go device.RoutineReadFromTUN(tun)
go device.RoutineReceiveIncomming()
go device.RoutineWriteToTUN(tun)
return device
}

View File

@ -12,9 +12,11 @@ import (
* Used by initiator of handshake and with active keep-alive
*/
func (peer *Peer) SendKeepAlive() bool {
elem := peer.device.NewOutboundElement()
elem.packet = nil
if len(peer.queue.nonce) == 0 {
select {
case peer.queue.nonce <- []byte{}:
case peer.queue.nonce <- elem:
return true
default:
return false
@ -60,11 +62,10 @@ func (peer *Peer) KeepKeyFreshSending() {
*/
func (peer *Peer) RoutineHandshakeInitiator() {
device := peer.device
buffer := make([]byte, 1024)
logger := device.log.Debug
timeout := stoppedTimer()
var work *QueueOutboundElement
var elem *QueueOutboundElement
logger.Println("Routine, handshake initator, started for peer", peer.id)
@ -94,25 +95,25 @@ func (peer *Peer) RoutineHandshakeInitiator() {
// create initiation
if work != nil {
work.mutex.Lock()
work.packet = nil
work.mutex.Unlock()
if elem != nil {
elem.Drop()
}
work = new(QueueOutboundElement)
elem = device.NewOutboundElement()
msg, err := device.CreateMessageInitiation(peer)
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
break
}
// schedule for sending
// marshal & schedule for sending
writer := bytes.NewBuffer(buffer[:0])
writer := bytes.NewBuffer(elem.data[:0])
binary.Write(writer, binary.LittleEndian, msg)
work.packet = writer.Bytes()
peer.mac.AddMacs(work.packet)
peer.InsertOutbound(work)
elem.packet = writer.Bytes()
peer.mac.AddMacs(elem.packet)
println(elem)
addToOutboundQueue(peer.queue.outbound, elem)
if attempts == 0 {
deadline = time.Now().Add(MaxHandshakeAttemptTime)
@ -132,9 +133,11 @@ func (peer *Peer) RoutineHandshakeInitiator() {
return
case <-peer.signal.handshakeCompleted:
device.log.Debug.Println("Handshake complete")
break HandshakeLoop
case <-timeout.C:
device.log.Debug.Println("Timeout")
if deadline.Before(time.Now().Add(RekeyTimeout)) {
peer.signal.flushNonceQueue <- struct{}{}
if !peer.timer.sendKeepalive.Stop() {

View File

@ -7,8 +7,7 @@ import (
)
type KeyPair struct {
recv cipher.AEAD
recvNonce uint64
receive cipher.AEAD
send cipher.AEAD
sendNonce uint64
isInitiator bool

View File

@ -446,10 +446,10 @@ func (peer *Peer) NewKeyPair() *KeyPair {
keyPair := new(KeyPair)
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0
keyPair.recvNonce = 0
keyPair.created = time.Now()
keyPair.isInitiator = isInitiator
keyPair.localIndex = peer.handshake.localIndex
keyPair.remoteIndex = peer.handshake.remoteIndex
@ -462,7 +462,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
})
handshake.localIndex = 0
// start timer for keypair
// TODO: start timer for keypair (clearing)
// rotate key pairs
@ -473,7 +473,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
if isInitiator {
if kp.previous != nil {
kp.previous.send = nil
kp.previous.recv = nil
kp.previous.receive = nil
peer.device.indices.Delete(kp.previous.localIndex)
}
kp.previous = kp.current

View File

@ -35,7 +35,7 @@ type Peer struct {
handshakeTimeout *time.Timer
}
queue struct {
nonce chan []byte // nonce / pre-handshake queue
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work
inbound chan *QueueInboundElement // sequential ordering of work
}
@ -78,9 +78,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
// prepare queuing
peer.queue.nonce = make(chan []byte, QueueOutboundSize)
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signaling

View File

@ -31,17 +31,39 @@ type QueueInboundElement struct {
func (elem *QueueInboundElement) Drop() {
atomic.StoreUint32(&elem.state, ElementStateDropped)
elem.mutex.Unlock()
}
func (elem *QueueInboundElement) IsDropped() bool {
return atomic.LoadUint32(&elem.state) == ElementStateDropped
}
func addToInboundQueue(
queue chan *QueueInboundElement,
element *QueueInboundElement,
) {
for {
select {
case queue <- element:
return
default:
select {
case old := <-queue:
old.Drop()
default:
}
}
}
}
func (device *Device) RoutineReceiveIncomming() {
var packet []byte
debugLog := device.log.Debug
debugLog.Println("Routine, receive incomming, started")
errorLog := device.log.Error
var buffer []byte // unsliced buffer
for {
// check if stopped
@ -54,28 +76,28 @@ func (device *Device) RoutineReceiveIncomming() {
// read next datagram
if packet == nil {
packet = make([]byte, 1<<16)
if buffer == nil {
buffer = make([]byte, MaxMessageSize)
}
device.net.mutex.RLock()
conn := device.net.conn
device.net.mutex.RUnlock()
if conn == nil {
time.Sleep(time.Second)
continue
}
conn.SetReadDeadline(time.Now().Add(time.Second))
size, raddr, err := conn.ReadFromUDP(packet)
if err != nil {
continue
}
if size < MinMessageSize {
size, raddr, err := conn.ReadFromUDP(buffer)
if err != nil || size < MinMessageSize {
continue
}
// handle packet
packet = packet[:size]
debugLog.Println("GOT:", packet)
packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
func() {
@ -112,6 +134,7 @@ func (device *Device) RoutineReceiveIncomming() {
// add to handshake queue
buffer = nil
device.queue.handshake <- QueueHandshakeElement{
msgType: msgType,
packet: packet,
@ -137,8 +160,6 @@ func (device *Device) RoutineReceiveIncomming() {
case MessageTransportType:
debugLog.Println("DEBUG: Got transport")
// lookup key pair
if len(packet) < MessageTransportSize {
@ -169,42 +190,15 @@ func (device *Device) RoutineReceiveIncomming() {
work.state = ElementStateOkay
work.mutex.Lock()
// add to parallel decryption queue
// add to decryption queues
func() {
for {
select {
case device.queue.decryption <- work:
return
default:
select {
case elem := <-device.queue.decryption:
elem.Drop()
default:
}
}
}
}()
// add to sequential inbound queue
func() {
for {
select {
case peer.queue.inbound <- work:
break
default:
select {
case elem := <-peer.queue.inbound:
elem.Drop()
default:
}
}
}
}()
addToInboundQueue(device.queue.decryption, work)
addToInboundQueue(peer.queue.inbound, work)
buffer = nil
default:
// unknown message type
debugLog.Println("Got unknown message from:", raddr)
}
}()
}
@ -214,6 +208,9 @@ func (device *Device) RoutineDecryption() {
var elem *QueueInboundElement
var nonce [chacha20poly1305.NonceSize]byte
logDebug := device.log.Debug
logDebug.Println("Routine, decryption, started for device")
for {
select {
case elem = <-device.queue.decryption:
@ -223,31 +220,25 @@ func (device *Device) RoutineDecryption() {
// check if dropped
state := atomic.LoadUint32(&elem.state)
if state != ElementStateOkay {
if elem.IsDropped() {
elem.mutex.Unlock()
continue
}
// split message into fields
counter := binary.LittleEndian.Uint64(
elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent],
)
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt with key-pair
var err error
binary.LittleEndian.PutUint64(nonce[4:], counter)
elem.packet, err = elem.keyPair.recv.Open(elem.packet[:0], nonce[:], content, nil)
copy(nonce[4:], counter)
elem.counter = binary.LittleEndian.Uint64(counter)
elem.packet, err = elem.keyPair.receive.Open(elem.packet[:0], nonce[:], content, nil)
if err != nil {
elem.Drop()
continue
}
// release to consumer
elem.counter = counter
elem.mutex.Unlock()
}
}
@ -261,6 +252,7 @@ func (device *Device) RoutineHandshake() {
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
logDebug.Println("Routine, handshake routine, started for device")
var elem QueueHandshakeElement
@ -332,13 +324,15 @@ func (device *Device) RoutineHandshake() {
}
sendSignal(peer.signal.handshakeCompleted)
logDebug.Println("Recieved valid response message for peer", peer.id)
peer.NewKeyPair()
kp := peer.NewKeyPair()
if kp == nil {
logDebug.Println("Failed to derieve key-pair")
}
peer.SendKeepAlive()
default:
device.log.Error.Println("Invalid message type in handshake queue")
}
}()
}
}
@ -348,7 +342,6 @@ func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device
logDebug := device.log.Debug
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
for {
@ -359,20 +352,15 @@ func (peer *Peer) RoutineSequentialReceiver() {
return
case elem = <-peer.queue.inbound:
}
elem.mutex.Lock()
// check if dropped
logDebug.Println("MESSSAGE:", elem)
state := atomic.LoadUint32(&elem.state)
if state != ElementStateOkay {
if elem.IsDropped() {
continue
}
// check for replay
// strip padding
// update timers
// check for keep-alive
@ -380,26 +368,30 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue
}
// strip padding
// insert into inbound TUN queue
device.queue.inbound <- elem.packet
}
// update key material
}
}
func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
for {
var packet []byte
logError := device.log.Error
logDebug := device.log.Debug
logDebug.Println("Routine, sequential tun writer, started")
for {
select {
case <-device.signal.stop:
case packet = <-device.queue.inbound:
}
size, err := tun.Write(packet)
device.log.Debug.Println("DEBUG:", size, err)
if err != nil {
return
case packet := <-device.queue.inbound:
_, err := tun.Write(packet)
if err != nil {
logError.Println("Failed to write packet to TUN device:", err)
}
}
}
}

View File

@ -25,14 +25,19 @@ import (
*
* The sequential consumers will attempt to take the lock,
* workers release lock when they have completed work on the packet.
*
* If the element is inserted into the "encryption queue",
* the content is preceeded by enough "junk" to contain the header
* (to allow the constuction of transport messages in-place)
*/
type QueueOutboundElement struct {
state uint32
mutex sync.Mutex
packet []byte
nonce uint64
keyPair *KeyPair
peer *Peer
data [MaxMessageSize]byte
packet []byte // slice of packet (sending)
nonce uint64 // nonce for encryption
keyPair *KeyPair // key-pair for encryption
peer *Peer // related peer
}
func (peer *Peer) FlushNonceQueue() {
@ -46,18 +51,9 @@ func (peer *Peer) FlushNonceQueue() {
}
}
func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
for {
select {
case peer.queue.outbound <- elem:
return
default:
select {
case <-peer.queue.outbound:
default:
}
}
}
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := new(QueueOutboundElement) // TODO: profile, consider sync.Pool
return elem
}
func (elem *QueueOutboundElement) Drop() {
@ -68,53 +64,74 @@ func (elem *QueueOutboundElement) IsDropped() bool {
return atomic.LoadUint32(&elem.state) == ElementStateDropped
}
func addToOutboundQueue(
queue chan *QueueOutboundElement,
element *QueueOutboundElement,
) {
for {
select {
case queue <- element:
return
default:
select {
case old := <-queue:
old.Drop()
default:
}
}
}
}
/* Reads packets from the TUN and inserts
* into nonce queue for peer
*
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
if tun.MTU() == 0 {
// Dummy
if tun == nil {
// dummy
return
}
elem := device.NewOutboundElement()
device.log.Debug.Println("Routine, TUN Reader: started")
for {
// read packet
packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
size, err := tun.Read(packet)
if elem == nil {
elem = device.NewOutboundElement()
}
elem.packet = elem.data[MessageTransportHeaderSize:]
size, err := tun.Read(elem.packet)
if err != nil {
device.log.Error.Println("Failed to read packet from TUN device:", err)
continue
}
packet = packet[:size]
if len(packet) < IPv4headerSize {
device.log.Error.Println("Packet too short, length:", len(packet))
elem.packet = elem.packet[:size]
if len(elem.packet) < IPv4headerSize {
device.log.Error.Println("Packet too short, length:", size)
continue
}
// lookup peer
var peer *Peer
switch packet[0] >> 4 {
switch elem.packet[0] >> 4 {
case IPv4version:
dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst)
device.log.Debug.Println("New IPv4 packet:", packet, dst)
case IPv6version:
dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst)
device.log.Debug.Println("New IPv6 packet:", packet, dst)
default:
device.log.Debug.Println("Receieved packet with unknown IP version")
}
if peer == nil {
device.log.Debug.Println("No peer configured for IP")
continue
}
if peer.endpoint == nil {
@ -124,18 +141,9 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
// insert into nonce/pre-handshake queue
for {
select {
case peer.queue.nonce <- packet:
default:
select {
case <-peer.queue.nonce:
default:
}
continue
}
break
}
addToOutboundQueue(peer.queue.nonce, elem)
elem = nil
}
}
@ -148,8 +156,8 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
* Obs. A single instance per peer
*/
func (peer *Peer) RoutineNonce() {
var packet []byte
var keyPair *KeyPair
var elem *QueueOutboundElement
device := peer.device
logger := device.log.Debug
@ -163,9 +171,9 @@ func (peer *Peer) RoutineNonce() {
// wait for packet
if packet == nil {
if elem == nil {
select {
case packet = <-peer.queue.nonce:
case elem = <-peer.queue.nonce:
case <-peer.signal.stop:
return
}
@ -198,7 +206,7 @@ func (peer *Peer) RoutineNonce() {
case <-peer.signal.flushNonceQueue:
logger.Println("Clearing queue for peer", peer.id)
peer.FlushNonceQueue()
packet = nil
elem = nil
goto NextPacket
case <-peer.signal.stop:
@ -208,36 +216,20 @@ func (peer *Peer) RoutineNonce() {
// process current packet
if packet != nil {
if elem != nil {
// create work element
work := new(QueueOutboundElement) // TODO: profile, maybe use pool
work.keyPair = keyPair
work.packet = packet
work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
work.peer = peer
work.mutex.Lock()
elem.keyPair = keyPair
elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
elem.peer = peer
elem.mutex.Lock()
packet = nil
// add to parallel processing and sequential consuming queue
// drop packets until there is space
func() {
for {
select {
case peer.device.queue.encryption <- work:
return
default:
select {
case elem := <-peer.device.queue.encryption:
elem.Drop()
default:
}
}
}
}()
peer.queue.outbound <- work
addToOutboundQueue(device.queue.encryption, elem)
addToOutboundQueue(peer.queue.outbound, elem)
elem = nil
}
}
}()
@ -257,42 +249,38 @@ func (device *Device) RoutineEncryption() {
continue
}
// pad packet
// populate header fields
padding := device.mtu - len(work.packet) - MessageTransportSize
if padding < 0 {
work.Drop()
continue
}
func() {
header := work.data[:MessageTransportHeaderSize]
for n := 0; n < padding; n += 1 {
work.packet = append(work.packet, 0)
}
content := work.packet[MessageTransportHeaderSize:]
copy(content, work.packet)
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
// prepare header
binary.LittleEndian.PutUint32(work.packet[:4], MessageTransportType)
binary.LittleEndian.PutUint32(work.packet[4:8], work.keyPair.remoteIndex)
binary.LittleEndian.PutUint64(work.packet[8:16], work.nonce)
device.log.Debug.Println(work.packet, work.nonce)
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, work.keyPair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, work.nonce)
}()
// encrypt content
binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
work.keyPair.send.Seal(
content[:0],
nonce[:],
content,
nil,
)
work.mutex.Unlock()
func() {
binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
work.packet = work.keyPair.send.Seal(
work.packet[:0],
nonce[:],
work.packet,
nil,
)
work.mutex.Unlock()
}()
device.log.Debug.Println(work.packet, work.nonce)
// reslice to include header
// initiate new handshake
work.packet = work.data[:MessageTransportHeaderSize+len(work.packet)]
// refresh key if necessary
work.peer.KeepKeyFreshSending()
}
@ -340,8 +328,6 @@ func (peer *Peer) RoutineSequentialSender() {
return
}
logger.Println(work.packet)
_, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint)
if err != nil {
return