diff --git a/src/device.go b/src/device.go index b272544..a15961a 100644 --- a/src/device.go +++ b/src/device.go @@ -11,7 +11,11 @@ type Device struct { log *Logger // collection of loggers for levels idCounter uint // for assigning debug ids to peers fwMark uint32 - net struct { + pool struct { + // pools objects for reuse + messageBuffers sync.Pool + } + net struct { // seperate for performance reasons mutex sync.RWMutex addr *net.UDPAddr // UDP source address @@ -57,6 +61,14 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) { } } +func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { + return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) +} + +func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { + device.pool.messageBuffers.Put(msg) +} + func NewDevice(tun TUNDevice, logLevel int) *Device { device := new(Device) @@ -78,6 +90,14 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.net.addr, _ = net.ResolveUDPAddr(addr.Network(), addr.String()) device.net.mutex.Unlock() + // setup pools + + device.pool.messageBuffers = sync.Pool{ + New: func() interface{} { + return new([MaxMessageSize]byte) + }, + } + // create queues device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) diff --git a/src/receive.go b/src/receive.go index f2bf70b..31f74e2 100644 --- a/src/receive.go +++ b/src/receive.go @@ -15,12 +15,14 @@ import ( type QueueHandshakeElement struct { msgType uint32 packet []byte + buffer *[MaxMessageSize]byte source *net.UDPAddr } type QueueInboundElement struct { dropped int32 mutex sync.Mutex + buffer *[MaxMessageSize]byte packet []byte counter uint64 keyPair *KeyPair @@ -34,7 +36,7 @@ func (elem *QueueInboundElement) IsDropped() bool { return atomic.LoadInt32(&elem.dropped) == AtomicTrue } -func addToInboundQueue( +func (device *Device) addToInboundQueue( queue chan *QueueInboundElement, element *QueueInboundElement, ) { @@ -52,7 +54,7 @@ func addToInboundQueue( } } -func addToHandshakeQueue( +func (device *Device) addToHandshakeQueue( queue chan QueueHandshakeElement, element QueueHandshakeElement, ) { @@ -62,7 +64,8 @@ func addToHandshakeQueue( return default: select { - case <-queue: + case elem := <-queue: + device.PutMessageBuffer(elem.buffer) default: } } @@ -70,9 +73,6 @@ func addToHandshakeQueue( } /* Routine determining the busy state of the interface - * - * TODO: prehaps nicer to do this in response to events - * TODO: more well reasoned definition of "busy" */ func (device *Device) RoutineBusyMonitor() { samples := 0 @@ -109,10 +109,11 @@ func (device *Device) RoutineBusyMonitor() { func (device *Device) RoutineReceiveIncomming() { + logInfo := device.log.Info logDebug := device.log.Debug logDebug.Println("Routine, receive incomming, started") - var buffer []byte + var buffer *[MaxMessageSize]byte for { @@ -127,7 +128,7 @@ func (device *Device) RoutineReceiveIncomming() { // read next datagram if buffer == nil { - buffer = make([]byte, MaxMessageSize) + buffer = device.GetMessageBuffer() } device.net.mutex.RLock() @@ -140,7 +141,7 @@ func (device *Device) RoutineReceiveIncomming() { conn.SetReadDeadline(time.Now().Add(time.Second)) - size, raddr, err := conn.ReadFromUDP(buffer) + size, raddr, err := conn.ReadFromUDP(buffer[:]) if err != nil || size < MinMessageSize { continue } @@ -157,10 +158,11 @@ func (device *Device) RoutineReceiveIncomming() { // add to handshake queue - addToHandshakeQueue( + device.addToHandshakeQueue( device.queue.handshake, QueueHandshakeElement{ msgType: msgType, + buffer: buffer, packet: packet, source: raddr, }, @@ -210,21 +212,22 @@ func (device *Device) RoutineReceiveIncomming() { // add to peer queue peer := value.peer - work := new(QueueInboundElement) - work.packet = packet - work.keyPair = keyPair - work.dropped = AtomicFalse + work := &QueueInboundElement{ + packet: packet, + buffer: buffer, + keyPair: keyPair, + dropped: AtomicFalse, + } work.mutex.Lock() // add to decryption queues - addToInboundQueue(device.queue.decryption, work) - addToInboundQueue(peer.queue.inbound, work) + device.addToInboundQueue(device.queue.decryption, work) + device.addToInboundQueue(peer.queue.inbound, work) buffer = nil default: - // unknown message type - logDebug.Println("Got unknown message from:", raddr) + logInfo.Println("Got unknown message from:", raddr) } }() } @@ -261,7 +264,12 @@ func (device *Device) RoutineDecryption() { var err error copy(nonce[4:], counter) elem.counter = binary.LittleEndian.Uint64(counter) - elem.packet, err = elem.keyPair.receive.Open(elem.packet[:0], nonce[:], content, nil) + elem.packet, err = elem.keyPair.receive.Open( + elem.buffer[:0], + nonce[:], + content, + nil, + ) if err != nil { elem.Drop() } @@ -373,12 +381,16 @@ func (device *Device) RoutineHandshake() { logDebug.Println("Creating response message for", peer.String()) outElem := device.NewOutboundElement() - writer := bytes.NewBuffer(outElem.data[:0]) + writer := bytes.NewBuffer(outElem.buffer[:0]) binary.Write(writer, binary.LittleEndian, response) outElem.packet = writer.Bytes() peer.mac.AddMacs(outElem.packet) addToOutboundQueue(peer.queue.outbound, outElem) + // create new keypair + + peer.NewKeyPair() + case MessageResponseType: // unmarshal @@ -414,7 +426,7 @@ func (device *Device) RoutineHandshake() { peer.EventHandshakeComplete() default: - device.log.Error.Println("Invalid message type in handshake queue") + logError.Println("Invalid message type in handshake queue") } }() } @@ -529,7 +541,7 @@ func (peer *Peer) RoutineSequentialReceiver() { } atomic.AddUint64(&peer.rxBytes, uint64(len(elem.packet))) - addToInboundQueue(device.queue.inbound, elem) + device.addToInboundQueue(device.queue.inbound, elem) }() } } @@ -546,6 +558,7 @@ func (device *Device) RoutineWriteToTUN(tun TUNDevice) { return case elem := <-device.queue.inbound: _, err := tun.Write(elem.packet) + device.PutMessageBuffer(elem.buffer) if err != nil { logError.Println("Failed to write packet to TUN device:", err) } diff --git a/src/send.go b/src/send.go index d8ddc82..7a2fe44 100644 --- a/src/send.go +++ b/src/send.go @@ -33,11 +33,11 @@ import ( type QueueOutboundElement struct { dropped int32 mutex sync.Mutex - data [MaxMessageSize]byte // slice holding the packet data - packet []byte // slice of "data" (always!) - nonce uint64 // nonce for encryption - keyPair *KeyPair // key-pair for encryption - peer *Peer // related peer + buffer *[MaxMessageSize]byte // slice holding the packet data + packet []byte // slice of "data" (always!) + nonce uint64 // nonce for encryption + keyPair *KeyPair // key-pair for encryption + peer *Peer // related peer } func (peer *Peer) FlushNonceQueue() { @@ -51,13 +51,11 @@ func (peer *Peer) FlushNonceQueue() { } } -/* - * Assumption: The mutex of the returned element is released - */ func (device *Device) NewOutboundElement() *QueueOutboundElement { - // TODO: profile, consider sync.Pool - elem := new(QueueOutboundElement) - return elem + return &QueueOutboundElement{ + dropped: AtomicFalse, + buffer: device.pool.messageBuffers.Get().(*[MaxMessageSize]byte), + } } func (elem *QueueOutboundElement) Drop() { @@ -130,7 +128,7 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { elem = device.NewOutboundElement() } - elem.packet = elem.data[MessageTransportHeaderSize:] + elem.packet = elem.buffer[MessageTransportHeaderSize:] size, err := tun.Read(elem.packet) if err != nil { @@ -284,7 +282,7 @@ func (device *Device) RoutineEncryption() { // populate header fields func() { - header := work.data[:MessageTransportHeaderSize] + header := work.buffer[:MessageTransportHeaderSize] fieldType := header[0:4] fieldReceiver := header[4:8] @@ -305,7 +303,7 @@ func (device *Device) RoutineEncryption() { nil, ) length := MessageTransportHeaderSize + len(work.packet) - work.packet = work.data[:length] + work.packet = work.buffer[:length] work.mutex.Unlock() // refresh key if necessary @@ -333,12 +331,16 @@ func (peer *Peer) RoutineSequentialSender() { case work := <-peer.queue.outbound: work.mutex.Lock() - if work.IsDropped() { - continue - } func() { + // return buffer to pool after processing + + defer device.PutMessageBuffer(work.buffer) + if work.IsDropped() { + return + } + // send to endpoint peer.mutex.RLock() @@ -357,10 +359,13 @@ func (peer *Peer) RoutineSequentialSender() { return } + // send message and return buffer to pool + _, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint) if err != nil { return } + atomic.AddUint64(&peer.txBytes, uint64(len(work.packet))) // reset keep-alive diff --git a/src/timers.go b/src/timers.go index 2e5046e..9140e41 100644 --- a/src/timers.go +++ b/src/timers.go @@ -128,7 +128,7 @@ func (peer *Peer) BeginHandshakeInitiation() (*QueueOutboundElement, error) { // marshal & schedule for sending - writer := bytes.NewBuffer(elem.data[:0]) + writer := bytes.NewBuffer(elem.buffer[:0]) binary.Write(writer, binary.LittleEndian, msg) elem.packet = writer.Bytes() peer.mac.AddMacs(elem.packet)