Terminate on interface deletion
Program now terminates when the interface is removed Increases the number of os threads (relevant for Go <1.5, not tested) More consistent commenting Improved logging (additional peer information)
This commit is contained in:
		
							parent
							
								
									8393cbff52
								
							
						
					
					
						commit
						93e3848ea7
					
				@ -29,6 +29,6 @@ const (
 | 
			
		||||
	QueueInboundSize       = 1024
 | 
			
		||||
	QueueHandshakeSize     = 1024
 | 
			
		||||
	QueueHandshakeBusySize = QueueHandshakeSize / 8
 | 
			
		||||
	MinMessageSize         = MessageTransportSize // keep-alive
 | 
			
		||||
	MaxMessageSize         = 4096                 // TODO: make depend on the MTU?
 | 
			
		||||
	MinMessageSize         = MessageTransportSize // size of keep-alive
 | 
			
		||||
	MaxMessageSize         = (1 << 16) - 1
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -98,9 +98,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go device.RoutineBusyMonitor()
 | 
			
		||||
	go device.RoutineWriteToTUN(tun)
 | 
			
		||||
	go device.RoutineReadFromTUN(tun)
 | 
			
		||||
	go device.RoutineReceiveIncomming()
 | 
			
		||||
	go device.RoutineWriteToTUN(tun)
 | 
			
		||||
	go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
 | 
			
		||||
 | 
			
		||||
	return device
 | 
			
		||||
@ -141,5 +141,8 @@ func (device *Device) RemoveAllPeers() {
 | 
			
		||||
func (device *Device) Close() {
 | 
			
		||||
	device.RemoveAllPeers()
 | 
			
		||||
	close(device.signal.stop)
 | 
			
		||||
	close(device.queue.encryption)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) Wait() {
 | 
			
		||||
	<-device.signal.stop
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -5,17 +5,13 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	IPv4version           = 4
 | 
			
		||||
	IPv4offsetTotalLength = 2
 | 
			
		||||
	IPv4offsetSrc         = 12
 | 
			
		||||
	IPv4offsetDst         = IPv4offsetSrc + net.IPv4len
 | 
			
		||||
	IPv4headerSize        = 20
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	IPv6version             = 6
 | 
			
		||||
	IPv6offsetPayloadLength = 4
 | 
			
		||||
	IPv6offsetSrc           = 8
 | 
			
		||||
	IPv6offsetDst           = IPv6offsetSrc + net.IPv6len
 | 
			
		||||
	IPv6headerSize          = 40
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										31
									
								
								src/main.go
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								src/main.go
									
									
									
									
									
								
							@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os"
 | 
			
		||||
	"runtime"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* TODO: Fix logging
 | 
			
		||||
@ -18,6 +19,10 @@ func main() {
 | 
			
		||||
	}
 | 
			
		||||
	deviceName := os.Args[1]
 | 
			
		||||
 | 
			
		||||
	// increase number of go workers (for Go <1.5)
 | 
			
		||||
 | 
			
		||||
	runtime.GOMAXPROCS(runtime.NumCPU())
 | 
			
		||||
 | 
			
		||||
	// open TUN device
 | 
			
		||||
 | 
			
		||||
	tun, err := CreateTUN(deviceName)
 | 
			
		||||
@ -31,17 +36,21 @@ func main() {
 | 
			
		||||
 | 
			
		||||
	// start configuration lister
 | 
			
		||||
 | 
			
		||||
	socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
 | 
			
		||||
	l, err := net.Listen("unix", socketPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatal("listen error:", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		conn, err := l.Accept()
 | 
			
		||||
	go func() {
 | 
			
		||||
		socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
 | 
			
		||||
		l, err := net.Listen("unix", socketPath)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Fatal("accept error:", err)
 | 
			
		||||
			log.Fatal("listen error:", err)
 | 
			
		||||
		}
 | 
			
		||||
		go ipcHandle(device, conn)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
		for {
 | 
			
		||||
			conn, err := l.Accept()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Fatal("accept error:", err)
 | 
			
		||||
			}
 | 
			
		||||
			go ipcHandle(device, conn)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	device.Wait()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										19
									
								
								src/peer.go
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								src/peer.go
									
									
									
									
									
								
							@ -1,7 +1,9 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
@ -38,9 +40,9 @@ type Peer struct {
 | 
			
		||||
		/* Both keep-alive timers acts as one (see timers.go)
 | 
			
		||||
		 * They are kept seperate to simplify the implementation.
 | 
			
		||||
		 */
 | 
			
		||||
		keepalivePersistent      *time.Timer // set for persistent keepalives
 | 
			
		||||
		keepaliveAcknowledgement *time.Timer // set upon recieving messages
 | 
			
		||||
		zeroAllKeys              *time.Timer // zero all key material after RejectAfterTime*3
 | 
			
		||||
		keepalivePersistent *time.Timer // set for persistent keepalives
 | 
			
		||||
		keepalivePassive    *time.Timer // set upon recieving messages
 | 
			
		||||
		zeroAllKeys         *time.Timer // zero all key material after RejectAfterTime*3
 | 
			
		||||
	}
 | 
			
		||||
	queue struct {
 | 
			
		||||
		nonce    chan *QueueOutboundElement // nonce / pre-handshake queue
 | 
			
		||||
@ -63,8 +65,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
 | 
			
		||||
	peer.mac.Init(pk)
 | 
			
		||||
	peer.device = device
 | 
			
		||||
 | 
			
		||||
	peer.timer.keepalivePassive = NewStoppedTimer()
 | 
			
		||||
	peer.timer.keepalivePersistent = NewStoppedTimer()
 | 
			
		||||
	peer.timer.keepaliveAcknowledgement = NewStoppedTimer()
 | 
			
		||||
	peer.timer.zeroAllKeys = NewStoppedTimer()
 | 
			
		||||
 | 
			
		||||
	peer.flags.keepaliveWaiting = AtomicFalse
 | 
			
		||||
@ -115,6 +117,15 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
 | 
			
		||||
	return peer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (peer *Peer) String() string {
 | 
			
		||||
	return fmt.Sprintf(
 | 
			
		||||
		"peer(%d %s %s)",
 | 
			
		||||
		peer.id,
 | 
			
		||||
		peer.endpoint.String(),
 | 
			
		||||
		base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (peer *Peer) Close() {
 | 
			
		||||
	close(peer.signal.stop)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -4,6 +4,8 @@ import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"golang.org/x/crypto/chacha20poly1305"
 | 
			
		||||
	"golang.org/x/net/ipv4"
 | 
			
		||||
	"golang.org/x/net/ipv6"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
@ -362,7 +364,7 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logDebug.Println("Creating response...")
 | 
			
		||||
				logDebug.Println("Creating response message for", peer.String())
 | 
			
		||||
 | 
			
		||||
				outElem := device.NewOutboundElement()
 | 
			
		||||
				writer := bytes.NewBuffer(outElem.data[:0])
 | 
			
		||||
@ -416,6 +418,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
 | 
			
		||||
	var elem *QueueInboundElement
 | 
			
		||||
 | 
			
		||||
	device := peer.device
 | 
			
		||||
 | 
			
		||||
	logInfo := device.log.Info
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
 | 
			
		||||
 | 
			
		||||
@ -450,7 +454,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 | 
			
		||||
 | 
			
		||||
			peer.KeepKeyFreshReceiving()
 | 
			
		||||
 | 
			
		||||
			// check if confirming handshake
 | 
			
		||||
			// check if using new key-pair
 | 
			
		||||
 | 
			
		||||
			kp := &peer.keyPairs
 | 
			
		||||
			kp.mutex.Lock()
 | 
			
		||||
@ -465,17 +469,18 @@ func (peer *Peer) RoutineSequentialReceiver() {
 | 
			
		||||
			// check for keep-alive
 | 
			
		||||
 | 
			
		||||
			if len(elem.packet) == 0 {
 | 
			
		||||
				logDebug.Println("Received keep-alive from", peer.String())
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// verify source and strip padding
 | 
			
		||||
 | 
			
		||||
			switch elem.packet[0] >> 4 {
 | 
			
		||||
			case IPv4version:
 | 
			
		||||
			case ipv4.Version:
 | 
			
		||||
 | 
			
		||||
				// strip padding
 | 
			
		||||
 | 
			
		||||
				if len(elem.packet) < IPv4headerSize {
 | 
			
		||||
				if len(elem.packet) < ipv4.HeaderLen {
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
@ -487,31 +492,33 @@ func (peer *Peer) RoutineSequentialReceiver() {
 | 
			
		||||
 | 
			
		||||
				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
 | 
			
		||||
				if device.routingTable.LookupIPv4(dst) != peer {
 | 
			
		||||
					logInfo.Println("Packet with unallowed source IP from", peer.String())
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			case IPv6version:
 | 
			
		||||
			case ipv6.Version:
 | 
			
		||||
 | 
			
		||||
				// strip padding
 | 
			
		||||
 | 
			
		||||
				if len(elem.packet) < IPv6headerSize {
 | 
			
		||||
				if len(elem.packet) < ipv6.HeaderLen {
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
 | 
			
		||||
				length := binary.BigEndian.Uint16(field)
 | 
			
		||||
				length += IPv6headerSize
 | 
			
		||||
				length += ipv6.HeaderLen
 | 
			
		||||
				elem.packet = elem.packet[:length]
 | 
			
		||||
 | 
			
		||||
				// verify IPv6 source
 | 
			
		||||
 | 
			
		||||
				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
 | 
			
		||||
				if device.routingTable.LookupIPv6(dst) != peer {
 | 
			
		||||
					logInfo.Println("Packet with unallowed source IP from", peer.String())
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			default:
 | 
			
		||||
				logDebug.Println("Receieved packet with unknown IP version")
 | 
			
		||||
				logInfo.Println("Packet with invalid IP version from", peer.String())
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
@ -522,6 +529,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
 | 
			
		||||
 | 
			
		||||
	logError := device.log.Error
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, sequential tun writer, started")
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										69
									
								
								src/send.go
									
									
									
									
									
								
							
							
						
						
									
										69
									
								
								src/send.go
									
									
									
									
									
								
							@ -3,6 +3,8 @@ package main
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"golang.org/x/crypto/chacha20poly1305"
 | 
			
		||||
	"golang.org/x/net/ipv4"
 | 
			
		||||
	"golang.org/x/net/ipv6"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
@ -21,28 +23,26 @@ import (
 | 
			
		||||
 * The functions in this file occure (roughly) in the order packets are processed.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
/* A work unit
 | 
			
		||||
 *
 | 
			
		||||
 * The sequential consumers will attempt to take the lock,
 | 
			
		||||
 * workers release lock when they have completed work on the packet.
 | 
			
		||||
/* The sequential consumers will attempt to take the lock,
 | 
			
		||||
 * workers release lock when they have completed work (encryption) on the packet.
 | 
			
		||||
 *
 | 
			
		||||
 * If the element is inserted into the "encryption queue",
 | 
			
		||||
 * the content is preceeded by enough "junk" to contain the header
 | 
			
		||||
 * the content is preceeded by enough "junk" to contain the transport header
 | 
			
		||||
 * (to allow the construction of transport messages in-place)
 | 
			
		||||
 */
 | 
			
		||||
type QueueOutboundElement struct {
 | 
			
		||||
	dropped int32
 | 
			
		||||
	mutex   sync.Mutex
 | 
			
		||||
	data    [MaxMessageSize]byte
 | 
			
		||||
	packet  []byte   // slice of "data" (always!)
 | 
			
		||||
	nonce   uint64   // nonce for encryption
 | 
			
		||||
	keyPair *KeyPair // key-pair for encryption
 | 
			
		||||
	peer    *Peer    // related peer
 | 
			
		||||
	data    [MaxMessageSize]byte // slice holding the packet data
 | 
			
		||||
	packet  []byte               // slice of "data" (always!)
 | 
			
		||||
	nonce   uint64               // nonce for encryption
 | 
			
		||||
	keyPair *KeyPair             // key-pair for encryption
 | 
			
		||||
	peer    *Peer                // related peer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (peer *Peer) FlushNonceQueue() {
 | 
			
		||||
	elems := len(peer.queue.nonce)
 | 
			
		||||
	for i := 0; i < elems; i += 1 {
 | 
			
		||||
	for i := 0; i < elems; i++ {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-peer.queue.nonce:
 | 
			
		||||
		default:
 | 
			
		||||
@ -111,14 +111,18 @@ func addToEncryptionQueue(
 | 
			
		||||
 * Obs. Single instance per TUN device
 | 
			
		||||
 */
 | 
			
		||||
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
 | 
			
		||||
 | 
			
		||||
	if tun == nil {
 | 
			
		||||
		// dummy
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	elem := device.NewOutboundElement()
 | 
			
		||||
 | 
			
		||||
	device.log.Debug.Println("Routine, TUN Reader: started")
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logError := device.log.Error
 | 
			
		||||
 | 
			
		||||
	logDebug.Println("Routine, TUN Reader: started")
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		// read packet
 | 
			
		||||
 | 
			
		||||
@ -129,12 +133,17 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
 | 
			
		||||
		elem.packet = elem.data[MessageTransportHeaderSize:]
 | 
			
		||||
		size, err := tun.Read(elem.packet)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			device.log.Error.Println("Failed to read packet from TUN device:", err)
 | 
			
		||||
			continue
 | 
			
		||||
 | 
			
		||||
			// stop process
 | 
			
		||||
 | 
			
		||||
			logError.Println("Failed to read packet from TUN device:", err)
 | 
			
		||||
			device.Close()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		elem.packet = elem.packet[:size]
 | 
			
		||||
		if len(elem.packet) < IPv4headerSize {
 | 
			
		||||
			device.log.Error.Println("Packet too short, length:", size)
 | 
			
		||||
		if len(elem.packet) < ipv4.HeaderLen {
 | 
			
		||||
			logError.Println("Packet too short, length:", size)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -142,23 +151,24 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
 | 
			
		||||
 | 
			
		||||
		var peer *Peer
 | 
			
		||||
		switch elem.packet[0] >> 4 {
 | 
			
		||||
		case IPv4version:
 | 
			
		||||
		case ipv4.Version:
 | 
			
		||||
			dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
 | 
			
		||||
			peer = device.routingTable.LookupIPv4(dst)
 | 
			
		||||
 | 
			
		||||
		case IPv6version:
 | 
			
		||||
		case ipv6.Version:
 | 
			
		||||
			dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
 | 
			
		||||
			peer = device.routingTable.LookupIPv6(dst)
 | 
			
		||||
 | 
			
		||||
		default:
 | 
			
		||||
			device.log.Debug.Println("Receieved packet with unknown IP version")
 | 
			
		||||
			logDebug.Println("Receieved packet with unknown IP version")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if peer == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if peer.endpoint == nil {
 | 
			
		||||
			device.log.Debug.Println("No known endpoint for peer", peer.id)
 | 
			
		||||
			logDebug.Println("No known endpoint for peer", peer.String())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -184,7 +194,7 @@ func (peer *Peer) RoutineNonce() {
 | 
			
		||||
 | 
			
		||||
	device := peer.device
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, nonce worker, started for peer", peer.id)
 | 
			
		||||
	logDebug.Println("Routine, nonce worker, started for peer", peer.String())
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
 | 
			
		||||
@ -216,15 +226,15 @@ func (peer *Peer) RoutineNonce() {
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				signalSend(peer.signal.handshakeBegin)
 | 
			
		||||
				logDebug.Println("Waiting for key-pair, peer", peer.id)
 | 
			
		||||
				logDebug.Println("Awaiting key-pair for", peer.String())
 | 
			
		||||
 | 
			
		||||
				select {
 | 
			
		||||
				case <-peer.signal.newKeyPair:
 | 
			
		||||
					logDebug.Println("Key-pair negotiated for peer", peer.id)
 | 
			
		||||
					logDebug.Println("Key-pair negotiated for", peer.String())
 | 
			
		||||
					goto NextPacket
 | 
			
		||||
 | 
			
		||||
				case <-peer.signal.flushNonceQueue:
 | 
			
		||||
					logDebug.Println("Clearing queue for peer", peer.id)
 | 
			
		||||
					logDebug.Println("Clearing queue for", peer.String())
 | 
			
		||||
					peer.FlushNonceQueue()
 | 
			
		||||
					elem = nil
 | 
			
		||||
					goto NextPacket
 | 
			
		||||
@ -313,13 +323,14 @@ func (peer *Peer) RoutineSequentialSender() {
 | 
			
		||||
	device := peer.device
 | 
			
		||||
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, sequential sender, started for peer", peer.id)
 | 
			
		||||
	logDebug.Println("Routine, sequential sender, started for", peer.String())
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-peer.signal.stop:
 | 
			
		||||
			logDebug.Println("Routine, sequential sender, stopped for peer", peer.id)
 | 
			
		||||
			logDebug.Println("Routine, sequential sender, stopped for", peer.String())
 | 
			
		||||
			return
 | 
			
		||||
 | 
			
		||||
		case work := <-peer.queue.outbound:
 | 
			
		||||
			work.mutex.Lock()
 | 
			
		||||
			if work.IsDropped() {
 | 
			
		||||
@ -334,7 +345,7 @@ func (peer *Peer) RoutineSequentialSender() {
 | 
			
		||||
				defer peer.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
				if peer.endpoint == nil {
 | 
			
		||||
					logDebug.Println("No endpoint for peer:", peer.id)
 | 
			
		||||
					logDebug.Println("No endpoint for", peer.String())
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
@ -352,7 +363,7 @@ func (peer *Peer) RoutineSequentialSender() {
 | 
			
		||||
				}
 | 
			
		||||
				atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
 | 
			
		||||
 | 
			
		||||
				// reset keep-alive (passive keep-alives / acknowledgements)
 | 
			
		||||
				// reset keep-alive
 | 
			
		||||
 | 
			
		||||
				peer.TimerResetKeepalive()
 | 
			
		||||
			}()
 | 
			
		||||
 | 
			
		||||
@ -50,7 +50,7 @@ func (peer *Peer) KeepKeyFreshReceiving() {
 | 
			
		||||
 * - First transport message under the "next" key
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) EventHandshakeComplete() {
 | 
			
		||||
	peer.device.log.Debug.Println("Handshake completed")
 | 
			
		||||
	peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
 | 
			
		||||
	peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
 | 
			
		||||
	signalSend(peer.signal.handshakeCompleted)
 | 
			
		||||
}
 | 
			
		||||
@ -112,7 +112,7 @@ func (peer *Peer) TimerResetKeepalive() {
 | 
			
		||||
 | 
			
		||||
	// stop acknowledgement timer
 | 
			
		||||
 | 
			
		||||
	timerStop(peer.timer.keepaliveAcknowledgement)
 | 
			
		||||
	timerStop(peer.timer.keepalivePassive)
 | 
			
		||||
	atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -140,7 +140,7 @@ func (peer *Peer) RoutineTimerHandler() {
 | 
			
		||||
	device := peer.device
 | 
			
		||||
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, timer handler, started for peer", peer.id)
 | 
			
		||||
	logDebug.Println("Routine, timer handler, started for peer", peer.String())
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
@ -152,14 +152,14 @@ func (peer *Peer) RoutineTimerHandler() {
 | 
			
		||||
 | 
			
		||||
		case <-peer.timer.keepalivePersistent.C:
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Sending persistent keep-alive to peer", peer.id)
 | 
			
		||||
			logDebug.Println("Sending persistent keep-alive to", peer.String())
 | 
			
		||||
 | 
			
		||||
			peer.SendKeepAlive()
 | 
			
		||||
			peer.TimerResetKeepalive()
 | 
			
		||||
 | 
			
		||||
		case <-peer.timer.keepaliveAcknowledgement.C:
 | 
			
		||||
		case <-peer.timer.keepalivePassive.C:
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Sending passive persistent keep-alive to peer", peer.id)
 | 
			
		||||
			logDebug.Println("Sending passive persistent keep-alive to", peer.String())
 | 
			
		||||
 | 
			
		||||
			peer.SendKeepAlive()
 | 
			
		||||
			peer.TimerResetKeepalive()
 | 
			
		||||
@ -168,7 +168,7 @@ func (peer *Peer) RoutineTimerHandler() {
 | 
			
		||||
 | 
			
		||||
		case <-peer.timer.zeroAllKeys.C:
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Clearing all key material for peer", peer.id)
 | 
			
		||||
			logDebug.Println("Clearing all key material for", peer.String())
 | 
			
		||||
 | 
			
		||||
			// zero out key pairs
 | 
			
		||||
 | 
			
		||||
@ -208,14 +208,12 @@ func (peer *Peer) RoutineHandshakeInitiator() {
 | 
			
		||||
 | 
			
		||||
	var elem *QueueOutboundElement
 | 
			
		||||
 | 
			
		||||
	logInfo := device.log.Info
 | 
			
		||||
	logError := device.log.Error
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, handshake initator, started for peer", peer.id)
 | 
			
		||||
	logDebug.Println("Routine, handshake initator, started for", peer.String())
 | 
			
		||||
 | 
			
		||||
	for run := true; run; {
 | 
			
		||||
		var err error
 | 
			
		||||
		var attempts uint
 | 
			
		||||
		var deadline time.Time
 | 
			
		||||
	for {
 | 
			
		||||
 | 
			
		||||
		// wait for signal
 | 
			
		||||
 | 
			
		||||
@ -227,15 +225,17 @@ func (peer *Peer) RoutineHandshakeInitiator() {
 | 
			
		||||
 | 
			
		||||
		// wait for handshake
 | 
			
		||||
 | 
			
		||||
		run = func() bool {
 | 
			
		||||
			for {
 | 
			
		||||
		func() {
 | 
			
		||||
			var err error
 | 
			
		||||
			var deadline time.Time
 | 
			
		||||
			for attempts := uint(1); ; attempts++ {
 | 
			
		||||
 | 
			
		||||
				// clear completed signal
 | 
			
		||||
 | 
			
		||||
				select {
 | 
			
		||||
				case <-peer.signal.handshakeCompleted:
 | 
			
		||||
				case <-peer.signal.stop:
 | 
			
		||||
					return false
 | 
			
		||||
					return
 | 
			
		||||
				default:
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
@ -246,43 +246,39 @@ func (peer *Peer) RoutineHandshakeInitiator() {
 | 
			
		||||
				}
 | 
			
		||||
				elem, err = peer.BeginHandshakeInitiation()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logError.Println("Failed to create initiation message:", err)
 | 
			
		||||
					break
 | 
			
		||||
					logError.Println("Failed to create initiation message", err, "for", peer.String())
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// set timeout
 | 
			
		||||
 | 
			
		||||
				attempts += 1
 | 
			
		||||
				if attempts == 1 {
 | 
			
		||||
					deadline = time.Now().Add(MaxHandshakeAttemptTime)
 | 
			
		||||
				}
 | 
			
		||||
				timeout := time.NewTimer(RekeyTimeout)
 | 
			
		||||
				logDebug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id)
 | 
			
		||||
				logDebug.Println("Handshake initiation attempt", attempts, "queued for", peer.String())
 | 
			
		||||
 | 
			
		||||
				// wait for handshake or timeout
 | 
			
		||||
 | 
			
		||||
				select {
 | 
			
		||||
 | 
			
		||||
				case <-peer.signal.stop:
 | 
			
		||||
					return true
 | 
			
		||||
					return
 | 
			
		||||
 | 
			
		||||
				case <-peer.signal.handshakeCompleted:
 | 
			
		||||
					<-timeout.C
 | 
			
		||||
					return true
 | 
			
		||||
					return
 | 
			
		||||
 | 
			
		||||
				case <-timeout.C:
 | 
			
		||||
					logDebug.Println("Timeout")
 | 
			
		||||
 | 
			
		||||
					// check if sufficient time for retry
 | 
			
		||||
 | 
			
		||||
					if deadline.Before(time.Now().Add(RekeyTimeout)) {
 | 
			
		||||
						logInfo.Println("Handshake negotiation timed out for", peer.String())
 | 
			
		||||
						signalSend(peer.signal.flushNonceQueue)
 | 
			
		||||
						timerStop(peer.timer.keepalivePersistent)
 | 
			
		||||
						timerStop(peer.timer.keepaliveAcknowledgement)
 | 
			
		||||
						return true
 | 
			
		||||
						timerStop(peer.timer.keepalivePassive)
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return true
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		signalClear(peer.signal.handshakeBegin)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										19
									
								
								src/trie.go
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								src/trie.go
									
									
									
									
									
								
							@ -23,7 +23,8 @@ type Trie struct {
 | 
			
		||||
	bits  []byte
 | 
			
		||||
	peer  *Peer
 | 
			
		||||
 | 
			
		||||
	// Index of "branching" bit
 | 
			
		||||
	// index of "branching" bit
 | 
			
		||||
 | 
			
		||||
	bit_at_byte  uint
 | 
			
		||||
	bit_at_shift uint
 | 
			
		||||
}
 | 
			
		||||
@ -36,7 +37,7 @@ type Trie struct {
 | 
			
		||||
func commonBits(ip1 net.IP, ip2 net.IP) uint {
 | 
			
		||||
	var i uint
 | 
			
		||||
	size := uint(len(ip1))
 | 
			
		||||
	for i = 0; i < size; i += 1 {
 | 
			
		||||
	for i = 0; i < size; i++ {
 | 
			
		||||
		v := ip1[i] ^ ip2[i]
 | 
			
		||||
		if v != 0 {
 | 
			
		||||
			v >>= 1
 | 
			
		||||
@ -84,7 +85,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
 | 
			
		||||
		return node
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Walk recursivly
 | 
			
		||||
	// walk recursivly
 | 
			
		||||
 | 
			
		||||
	node.child[0] = node.child[0].RemovePeer(p)
 | 
			
		||||
	node.child[1] = node.child[1].RemovePeer(p)
 | 
			
		||||
@ -93,7 +94,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
 | 
			
		||||
		return node
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Remove peer & merge
 | 
			
		||||
	// remove peer & merge
 | 
			
		||||
 | 
			
		||||
	node.peer = nil
 | 
			
		||||
	if node.child[0] == nil {
 | 
			
		||||
@ -108,7 +109,7 @@ func (node *Trie) choose(ip net.IP) byte {
 | 
			
		||||
 | 
			
		||||
func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
 | 
			
		||||
	// At leaf
 | 
			
		||||
	// at leaf
 | 
			
		||||
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return &Trie{
 | 
			
		||||
@ -120,7 +121,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Traverse deeper
 | 
			
		||||
	// traverse deeper
 | 
			
		||||
 | 
			
		||||
	common := commonBits(node.bits, ip)
 | 
			
		||||
	if node.cidr <= cidr && common >= node.cidr {
 | 
			
		||||
@ -133,7 +134,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
		return node
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Split node
 | 
			
		||||
	// split node
 | 
			
		||||
 | 
			
		||||
	newNode := &Trie{
 | 
			
		||||
		bits:         ip,
 | 
			
		||||
@ -145,7 +146,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
 | 
			
		||||
	cidr = min(cidr, common)
 | 
			
		||||
 | 
			
		||||
	// Check for shorter prefix
 | 
			
		||||
	// check for shorter prefix
 | 
			
		||||
 | 
			
		||||
	if newNode.cidr == cidr {
 | 
			
		||||
		bit := newNode.choose(node.bits)
 | 
			
		||||
@ -153,7 +154,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
		return newNode
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create new parent for node & newNode
 | 
			
		||||
	// create new parent for node & newNode
 | 
			
		||||
 | 
			
		||||
	parent := &Trie{
 | 
			
		||||
		bits:         ip,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user