Fixed deadlock in index.go
This commit is contained in:
		
							parent
							
								
									dd4da93749
								
							
						
					
					
						commit
						c5d7efc246
					
				
							
								
								
									
										162
									
								
								src/config.go
									
									
									
									
									
								
							
							
						
						
									
										162
									
								
								src/config.go
									
									
									
									
									
								
							@ -8,39 +8,36 @@ import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"syscall"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// #include <errno.h>
 | 
			
		||||
import "C"
 | 
			
		||||
 | 
			
		||||
/* TODO: More fine grained?
 | 
			
		||||
 */
 | 
			
		||||
const (
 | 
			
		||||
	ipcErrorNoPeer       = C.EPROTO
 | 
			
		||||
	ipcErrorNoKeyValue   = C.EPROTO
 | 
			
		||||
	ipcErrorInvalidKey   = C.EPROTO
 | 
			
		||||
	ipcErrorInvalidValue = C.EPROTO
 | 
			
		||||
	ipcErrorIO           = syscall.EIO
 | 
			
		||||
	ipcErrorNoPeer       = syscall.EPROTO
 | 
			
		||||
	ipcErrorNoKeyValue   = syscall.EPROTO
 | 
			
		||||
	ipcErrorInvalidKey   = syscall.EPROTO
 | 
			
		||||
	ipcErrorInvalidValue = syscall.EPROTO
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type IPCError struct {
 | 
			
		||||
	Code int
 | 
			
		||||
	Code syscall.Errno
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *IPCError) Error() string {
 | 
			
		||||
	return fmt.Sprintf("IPC error: %d", s.Code)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *IPCError) ErrorCode() int {
 | 
			
		||||
	return s.Code
 | 
			
		||||
func (s *IPCError) ErrorCode() uintptr {
 | 
			
		||||
	return uintptr(s.Code)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
 | 
			
		||||
 | 
			
		||||
	device.mutex.RLock()
 | 
			
		||||
	defer device.mutex.RUnlock()
 | 
			
		||||
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
 | 
			
		||||
	// create lines
 | 
			
		||||
 | 
			
		||||
	device.mutex.RLock()
 | 
			
		||||
 | 
			
		||||
	lines := make([]string, 0, 100)
 | 
			
		||||
	send := func(line string) {
 | 
			
		||||
		lines = append(lines, line)
 | 
			
		||||
@ -63,19 +60,25 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
 | 
			
		||||
			}
 | 
			
		||||
			send(fmt.Sprintf("tx_bytes=%d", peer.txBytes))
 | 
			
		||||
			send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes))
 | 
			
		||||
			send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
 | 
			
		||||
			send(fmt.Sprintf("persistent_keepalive_interval=%d",
 | 
			
		||||
				atomic.LoadUint64(&peer.persistentKeepaliveInterval),
 | 
			
		||||
			))
 | 
			
		||||
			for _, ip := range device.routingTable.AllowedIPs(peer) {
 | 
			
		||||
				send("allowed_ip=" + ip.String())
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	device.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
	// send lines
 | 
			
		||||
 | 
			
		||||
	for _, line := range lines {
 | 
			
		||||
		_, err := socket.WriteString(line + "\n")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return &IPCError{
 | 
			
		||||
				Code: ipcErrorIO,
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -83,13 +86,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
	logger := device.log.Debug
 | 
			
		||||
	scanner := bufio.NewScanner(socket)
 | 
			
		||||
	logError := device.log.Error
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
 | 
			
		||||
	var peer *Peer
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
 | 
			
		||||
		// Parse line
 | 
			
		||||
		// parse line
 | 
			
		||||
 | 
			
		||||
		line := scanner.Text()
 | 
			
		||||
		if line == "" {
 | 
			
		||||
@ -97,7 +101,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
		}
 | 
			
		||||
		parts := strings.Split(line, "=")
 | 
			
		||||
		if len(parts) != 2 {
 | 
			
		||||
			device.log.Debug.Println(parts)
 | 
			
		||||
			return &IPCError{Code: ipcErrorNoKeyValue}
 | 
			
		||||
		}
 | 
			
		||||
		key := parts[0]
 | 
			
		||||
@ -105,7 +108,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
 | 
			
		||||
		switch key {
 | 
			
		||||
 | 
			
		||||
		/* Interface configuration */
 | 
			
		||||
		/* interface configuration */
 | 
			
		||||
 | 
			
		||||
		case "private_key":
 | 
			
		||||
			if value == "" {
 | 
			
		||||
@ -116,7 +119,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
				var sk NoisePrivateKey
 | 
			
		||||
				err := sk.FromHex(value)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Println("Failed to set private_key:", err)
 | 
			
		||||
					logError.Println("Failed to set private_key:", err)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalidValue}
 | 
			
		||||
				}
 | 
			
		||||
				device.SetPrivateKey(sk)
 | 
			
		||||
@ -126,22 +129,26 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
			var port int
 | 
			
		||||
			_, err := fmt.Sscanf(value, "%d", &port)
 | 
			
		||||
			if err != nil || port > (1<<16) || port < 0 {
 | 
			
		||||
				logger.Println("Failed to set listen_port:", err)
 | 
			
		||||
				logError.Println("Failed to set listen_port:", err)
 | 
			
		||||
				return &IPCError{Code: ipcErrorInvalidValue}
 | 
			
		||||
			}
 | 
			
		||||
			device.net.mutex.Lock()
 | 
			
		||||
			device.net.addr.Port = port
 | 
			
		||||
			device.net.conn, err = net.ListenUDP("udp", device.net.addr)
 | 
			
		||||
			device.net.mutex.Unlock()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logError.Println("Failed to create UDP listener:", err)
 | 
			
		||||
				return &IPCError{Code: ipcErrorInvalidValue}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		case "fwmark":
 | 
			
		||||
			logger.Println("FWMark not handled yet")
 | 
			
		||||
			logError.Println("FWMark not handled yet")
 | 
			
		||||
 | 
			
		||||
		case "public_key":
 | 
			
		||||
			var pubKey NoisePublicKey
 | 
			
		||||
			err := pubKey.FromHex(value)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Println("Failed to get peer by public_key:", err)
 | 
			
		||||
				logError.Println("Failed to get peer by public_key:", err)
 | 
			
		||||
				return &IPCError{Code: ipcErrorInvalidValue}
 | 
			
		||||
			}
 | 
			
		||||
			device.mutex.RLock()
 | 
			
		||||
@ -153,22 +160,23 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
				peer = device.NewPeer(pubKey)
 | 
			
		||||
			}
 | 
			
		||||
			if peer == nil {
 | 
			
		||||
				panic(errors.New("bug: failed to find peer"))
 | 
			
		||||
				panic(errors.New("bug: failed to find / create peer"))
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		case "replace_peers":
 | 
			
		||||
			if value == "true" {
 | 
			
		||||
				device.RemoveAllPeers()
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Println("Failed to set replace_peers, invalid value:", value)
 | 
			
		||||
				logError.Println("Failed to set replace_peers, invalid value:", value)
 | 
			
		||||
				return &IPCError{Code: ipcErrorInvalidValue}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		default:
 | 
			
		||||
			/* Peer configuration */
 | 
			
		||||
 | 
			
		||||
			/* peer configuration */
 | 
			
		||||
 | 
			
		||||
			if peer == nil {
 | 
			
		||||
				logger.Println("No peer referenced, before peer operation")
 | 
			
		||||
				logError.Println("No peer referenced, before peer operation")
 | 
			
		||||
				return &IPCError{Code: ipcErrorNoPeer}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
@ -178,7 +186,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
				peer.mutex.Lock()
 | 
			
		||||
				device.RemovePeer(peer.handshake.remoteStatic)
 | 
			
		||||
				peer.mutex.Unlock()
 | 
			
		||||
				logger.Println("Remove peer")
 | 
			
		||||
				logDebug.Println("Removing", peer.String())
 | 
			
		||||
				peer = nil
 | 
			
		||||
 | 
			
		||||
			case "preshared_key":
 | 
			
		||||
@ -188,14 +196,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
					return peer.handshake.presharedKey.FromHex(value)
 | 
			
		||||
				}()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Println("Failed to set preshared_key:", err)
 | 
			
		||||
					logError.Println("Failed to set preshared_key:", err)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalidValue}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			case "endpoint":
 | 
			
		||||
				addr, err := net.ResolveUDPAddr("udp", value)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Println("Failed to set endpoint:", value)
 | 
			
		||||
					logError.Println("Failed to set endpoint:", value)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalidValue}
 | 
			
		||||
				}
 | 
			
		||||
				peer.mutex.Lock()
 | 
			
		||||
@ -205,35 +213,34 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
			case "persistent_keepalive_interval":
 | 
			
		||||
				secs, err := strconv.ParseInt(value, 10, 64)
 | 
			
		||||
				if secs < 0 || err != nil {
 | 
			
		||||
					logger.Println("Failed to set persistent_keepalive_interval:", err)
 | 
			
		||||
					logError.Println("Failed to set persistent_keepalive_interval:", err)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalidValue}
 | 
			
		||||
				}
 | 
			
		||||
				peer.mutex.Lock()
 | 
			
		||||
				peer.persistentKeepaliveInterval = uint64(secs)
 | 
			
		||||
				peer.mutex.Unlock()
 | 
			
		||||
				atomic.StoreUint64(
 | 
			
		||||
					&peer.persistentKeepaliveInterval,
 | 
			
		||||
					uint64(secs),
 | 
			
		||||
				)
 | 
			
		||||
 | 
			
		||||
			case "replace_allowed_ips":
 | 
			
		||||
				if value == "true" {
 | 
			
		||||
					device.routingTable.RemovePeer(peer)
 | 
			
		||||
				} else {
 | 
			
		||||
					logger.Println("Failed to set replace_allowed_ips, invalid value:", value)
 | 
			
		||||
					logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalidValue}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			case "allowed_ip":
 | 
			
		||||
				_, network, err := net.ParseCIDR(value)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Println("Failed to set allowed_ip:", err)
 | 
			
		||||
					logError.Println("Failed to set allowed_ip:", err)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalidValue}
 | 
			
		||||
				}
 | 
			
		||||
				ones, _ := network.Mask.Size()
 | 
			
		||||
				logger.Println(network, ones, network.IP)
 | 
			
		||||
				logError.Println(network, ones, network.IP)
 | 
			
		||||
				device.routingTable.Insert(network.IP, uint(ones), peer)
 | 
			
		||||
 | 
			
		||||
			/* Invalid key */
 | 
			
		||||
 | 
			
		||||
			default:
 | 
			
		||||
				logger.Println("Invalid key:", key)
 | 
			
		||||
				logError.Println("Invalid UAPI key:", key)
 | 
			
		||||
				return &IPCError{Code: ipcErrorInvalidKey}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
@ -244,46 +251,45 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
 | 
			
		||||
func ipcHandle(device *Device, socket net.Conn) {
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		buffered := func(s io.ReadWriter) *bufio.ReadWriter {
 | 
			
		||||
			reader := bufio.NewReader(s)
 | 
			
		||||
			writer := bufio.NewWriter(s)
 | 
			
		||||
			return bufio.NewReadWriter(reader, writer)
 | 
			
		||||
		}(socket)
 | 
			
		||||
	defer socket.Close()
 | 
			
		||||
 | 
			
		||||
		defer buffered.Flush()
 | 
			
		||||
	buffered := func(s io.ReadWriter) *bufio.ReadWriter {
 | 
			
		||||
		reader := bufio.NewReader(s)
 | 
			
		||||
		writer := bufio.NewWriter(s)
 | 
			
		||||
		return bufio.NewReadWriter(reader, writer)
 | 
			
		||||
	}(socket)
 | 
			
		||||
 | 
			
		||||
		op, err := buffered.ReadString('\n')
 | 
			
		||||
	defer buffered.Flush()
 | 
			
		||||
 | 
			
		||||
	op, err := buffered.ReadString('\n')
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch op {
 | 
			
		||||
 | 
			
		||||
	case "set=1\n":
 | 
			
		||||
		device.log.Debug.Println("Config, set operation")
 | 
			
		||||
		err := ipcSetOperation(device, buffered)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return
 | 
			
		||||
			fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
 | 
			
		||||
		} else {
 | 
			
		||||
			fmt.Fprintf(buffered, "errno=0\n\n")
 | 
			
		||||
		}
 | 
			
		||||
		return
 | 
			
		||||
 | 
			
		||||
		switch op {
 | 
			
		||||
 | 
			
		||||
		case "set=1\n":
 | 
			
		||||
			device.log.Debug.Println("Config, set operation")
 | 
			
		||||
			err := ipcSetOperation(device, buffered)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
 | 
			
		||||
			} else {
 | 
			
		||||
				fmt.Fprintf(buffered, "errno=0\n\n")
 | 
			
		||||
			}
 | 
			
		||||
			break
 | 
			
		||||
 | 
			
		||||
		case "get=1\n":
 | 
			
		||||
			device.log.Debug.Println("Config, get operation")
 | 
			
		||||
			err := ipcGetOperation(device, buffered)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				fmt.Fprintf(buffered, "errno=1\n\n") // fix
 | 
			
		||||
			} else {
 | 
			
		||||
				fmt.Fprintf(buffered, "errno=0\n\n")
 | 
			
		||||
			}
 | 
			
		||||
			break
 | 
			
		||||
 | 
			
		||||
		default:
 | 
			
		||||
			device.log.Info.Println("Invalid UAPI operation:", op)
 | 
			
		||||
	case "get=1\n":
 | 
			
		||||
		device.log.Debug.Println("Config, get operation")
 | 
			
		||||
		err := ipcGetOperation(device, buffered)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
 | 
			
		||||
		} else {
 | 
			
		||||
			fmt.Fprintf(buffered, "errno=0\n\n")
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
		return
 | 
			
		||||
 | 
			
		||||
	socket.Close()
 | 
			
		||||
	default:
 | 
			
		||||
		device.log.Error.Println("Invalid UAPI operation:", op)
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -78,7 +78,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
 | 
			
		||||
	defer device.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	device.log = NewLogger(logLevel)
 | 
			
		||||
	// device.mtu = tun.MTU()
 | 
			
		||||
	device.peers = make(map[NoisePublicKey]*Peer)
 | 
			
		||||
	device.indices.Init()
 | 
			
		||||
	device.ratelimiter.Init()
 | 
			
		||||
@ -131,12 +130,21 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
 | 
			
		||||
 | 
			
		||||
func (device *Device) RoutineMTUUpdater(tun TUNDevice) {
 | 
			
		||||
	logError := device.log.Error
 | 
			
		||||
	for ; ; time.Sleep(time.Second) {
 | 
			
		||||
	for ; ; time.Sleep(5 * time.Second) {
 | 
			
		||||
 | 
			
		||||
		// load updated MTU
 | 
			
		||||
 | 
			
		||||
		mtu, err := tun.MTU()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logError.Println("Failed to load updated MTU of device:", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// upper bound of mtu
 | 
			
		||||
 | 
			
		||||
		if mtu+MessageTransportSize > MaxMessageSize {
 | 
			
		||||
			mtu = MaxMessageSize - MessageTransportSize
 | 
			
		||||
		}
 | 
			
		||||
		atomic.StoreInt32(&device.mtu, int32(mtu))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -6,8 +6,6 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Index=0 is reserved for unset indecies
 | 
			
		||||
 *
 | 
			
		||||
 * TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
@ -72,12 +70,12 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
 | 
			
		||||
 | 
			
		||||
		table.mutex.RLock()
 | 
			
		||||
		_, ok := table.table[index]
 | 
			
		||||
		table.mutex.RUnlock()
 | 
			
		||||
		if ok {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		table.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
		// replace index
 | 
			
		||||
		// map index to handshake
 | 
			
		||||
 | 
			
		||||
		table.mutex.Lock()
 | 
			
		||||
		_, found := table.table[index]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										20
									
								
								src/main.go
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								src/main.go
									
									
									
									
									
								
							@ -17,12 +17,14 @@ func main() {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch os.Args[1] {
 | 
			
		||||
 | 
			
		||||
	case "-f", "--foreground":
 | 
			
		||||
		foreground = true
 | 
			
		||||
		if len(os.Args) != 3 {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		interfaceName = os.Args[2]
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		foreground = false
 | 
			
		||||
		if len(os.Args) != 2 {
 | 
			
		||||
@ -48,8 +50,8 @@ func main() {
 | 
			
		||||
	// open TUN device
 | 
			
		||||
 | 
			
		||||
	tun, err := CreateTUN(interfaceName)
 | 
			
		||||
	log.Println(tun, err)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("Failed to create tun device:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -69,11 +71,15 @@ func main() {
 | 
			
		||||
	}
 | 
			
		||||
	defer uapi.Close()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		conn, err := uapi.Accept()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logError.Fatal("accept error:", err)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			conn, err := uapi.Accept()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logError.Fatal("UAPI accept error:", err)
 | 
			
		||||
			}
 | 
			
		||||
			go ipcHandle(device, conn)
 | 
			
		||||
		}
 | 
			
		||||
		go ipcHandle(device, conn)
 | 
			
		||||
	}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	device.Wait()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -459,7 +459,8 @@ func (peer *Peer) NewKeyPair() *KeyPair {
 | 
			
		||||
 | 
			
		||||
	// remap index
 | 
			
		||||
 | 
			
		||||
	peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
 | 
			
		||||
	indices := &peer.device.indices
 | 
			
		||||
	indices.Insert(handshake.localIndex, IndexTableEntry{
 | 
			
		||||
		peer:      peer,
 | 
			
		||||
		keyPair:   keyPair,
 | 
			
		||||
		handshake: nil,
 | 
			
		||||
@ -476,7 +477,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
 | 
			
		||||
			if kp.previous != nil {
 | 
			
		||||
				kp.previous.send = nil
 | 
			
		||||
				kp.previous.receive = nil
 | 
			
		||||
				peer.device.indices.Delete(kp.previous.localIndex)
 | 
			
		||||
				indices.Delete(kp.previous.localIndex)
 | 
			
		||||
			}
 | 
			
		||||
			kp.previous = kp.current
 | 
			
		||||
			kp.current = keyPair
 | 
			
		||||
 | 
			
		||||
@ -212,18 +212,18 @@ func (device *Device) RoutineReceiveIncomming() {
 | 
			
		||||
				// add to peer queue
 | 
			
		||||
 | 
			
		||||
				peer := value.peer
 | 
			
		||||
				work := &QueueInboundElement{
 | 
			
		||||
				elem := &QueueInboundElement{
 | 
			
		||||
					packet:  packet,
 | 
			
		||||
					buffer:  buffer,
 | 
			
		||||
					keyPair: keyPair,
 | 
			
		||||
					dropped: AtomicFalse,
 | 
			
		||||
				}
 | 
			
		||||
				work.mutex.Lock()
 | 
			
		||||
				elem.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
				// add to decryption queues
 | 
			
		||||
 | 
			
		||||
				device.addToInboundQueue(device.queue.decryption, work)
 | 
			
		||||
				device.addToInboundQueue(peer.queue.inbound, work)
 | 
			
		||||
				device.addToInboundQueue(device.queue.decryption, elem)
 | 
			
		||||
				device.addToInboundQueue(peer.queue.inbound, elem)
 | 
			
		||||
				buffer = nil
 | 
			
		||||
 | 
			
		||||
			default:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										81
									
								
								src/send.go
									
									
									
									
									
								
							
							
						
						
									
										81
									
								
								src/send.go
									
									
									
									
									
								
							@ -270,50 +270,65 @@ func (peer *Peer) RoutineNonce() {
 | 
			
		||||
 * Obs. One instance per core
 | 
			
		||||
 */
 | 
			
		||||
func (device *Device) RoutineEncryption() {
 | 
			
		||||
 | 
			
		||||
	var elem *QueueOutboundElement
 | 
			
		||||
	var nonce [chacha20poly1305.NonceSize]byte
 | 
			
		||||
	for work := range device.queue.encryption {
 | 
			
		||||
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, encryption worker, started")
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
 | 
			
		||||
		// fetch next element
 | 
			
		||||
 | 
			
		||||
		select {
 | 
			
		||||
		case elem = <-device.queue.encryption:
 | 
			
		||||
		case <-device.signal.stop:
 | 
			
		||||
			logDebug.Println("Routine, encryption worker, stopped")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// check if dropped
 | 
			
		||||
 | 
			
		||||
		if work.IsDropped() {
 | 
			
		||||
		if elem.IsDropped() {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// populate header fields
 | 
			
		||||
 | 
			
		||||
		header := work.buffer[:MessageTransportHeaderSize]
 | 
			
		||||
		header := elem.buffer[:MessageTransportHeaderSize]
 | 
			
		||||
 | 
			
		||||
		fieldType := header[0:4]
 | 
			
		||||
		fieldReceiver := header[4:8]
 | 
			
		||||
		fieldNonce := header[8:16]
 | 
			
		||||
 | 
			
		||||
		binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
 | 
			
		||||
		binary.LittleEndian.PutUint32(fieldReceiver, work.keyPair.remoteIndex)
 | 
			
		||||
		binary.LittleEndian.PutUint64(fieldNonce, work.nonce)
 | 
			
		||||
		binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex)
 | 
			
		||||
		binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
 | 
			
		||||
 | 
			
		||||
		// pad content to MTU size
 | 
			
		||||
 | 
			
		||||
		mtu := int(atomic.LoadInt32(&device.mtu))
 | 
			
		||||
		for i := len(work.packet); i < mtu; i++ {
 | 
			
		||||
			work.packet = append(work.packet, 0)
 | 
			
		||||
		for i := len(elem.packet); i < mtu; i++ {
 | 
			
		||||
			elem.packet = append(elem.packet, 0)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// encrypt content
 | 
			
		||||
 | 
			
		||||
		binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
 | 
			
		||||
		work.packet = work.keyPair.send.Seal(
 | 
			
		||||
			work.packet[:0],
 | 
			
		||||
		binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
 | 
			
		||||
		elem.packet = elem.keyPair.send.Seal(
 | 
			
		||||
			elem.packet[:0],
 | 
			
		||||
			nonce[:],
 | 
			
		||||
			work.packet,
 | 
			
		||||
			elem.packet,
 | 
			
		||||
			nil,
 | 
			
		||||
		)
 | 
			
		||||
		length := MessageTransportHeaderSize + len(work.packet)
 | 
			
		||||
		work.packet = work.buffer[:length]
 | 
			
		||||
		work.mutex.Unlock()
 | 
			
		||||
		length := MessageTransportHeaderSize + len(elem.packet)
 | 
			
		||||
		elem.packet = elem.buffer[:length]
 | 
			
		||||
		elem.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
		// refresh key if necessary
 | 
			
		||||
 | 
			
		||||
		work.peer.KeepKeyFreshSending()
 | 
			
		||||
		elem.peer.KeepKeyFreshSending()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -334,49 +349,43 @@ func (peer *Peer) RoutineSequentialSender() {
 | 
			
		||||
			logDebug.Println("Routine, sequential sender, stopped for", peer.String())
 | 
			
		||||
			return
 | 
			
		||||
 | 
			
		||||
		case work := <-peer.queue.outbound:
 | 
			
		||||
			work.mutex.Lock()
 | 
			
		||||
		case elem := <-peer.queue.outbound:
 | 
			
		||||
			elem.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
			func() {
 | 
			
		||||
 | 
			
		||||
				// return buffer to pool after processing
 | 
			
		||||
 | 
			
		||||
				defer device.PutMessageBuffer(work.buffer)
 | 
			
		||||
				if work.IsDropped() {
 | 
			
		||||
				if elem.IsDropped() {
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// send to endpoint
 | 
			
		||||
				// get endpoint and connection
 | 
			
		||||
 | 
			
		||||
				peer.mutex.RLock()
 | 
			
		||||
				defer peer.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
				if peer.endpoint == nil {
 | 
			
		||||
				endpoint := peer.endpoint
 | 
			
		||||
				peer.mutex.RUnlock()
 | 
			
		||||
				if endpoint == nil {
 | 
			
		||||
					logDebug.Println("No endpoint for", peer.String())
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				device.net.mutex.RLock()
 | 
			
		||||
				defer device.net.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
				if device.net.conn == nil {
 | 
			
		||||
				conn := device.net.conn
 | 
			
		||||
				device.net.mutex.RUnlock()
 | 
			
		||||
				if conn == nil {
 | 
			
		||||
					logDebug.Println("No source for device")
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// send message and return buffer to pool
 | 
			
		||||
				// send message and refresh keys
 | 
			
		||||
 | 
			
		||||
				_, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint)
 | 
			
		||||
				_, err := conn.WriteToUDP(elem.packet, endpoint)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
 | 
			
		||||
 | 
			
		||||
				// reset keep-alive
 | 
			
		||||
 | 
			
		||||
				atomic.AddUint64(&peer.txBytes, uint64(len(elem.packet)))
 | 
			
		||||
				peer.TimerResetKeepalive()
 | 
			
		||||
			}()
 | 
			
		||||
 | 
			
		||||
			device.PutMessageBuffer(elem.buffer)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -138,6 +138,7 @@ func (peer *Peer) BeginHandshakeInitiation() (*QueueOutboundElement, error) {
 | 
			
		||||
 | 
			
		||||
func (peer *Peer) RoutineTimerHandler() {
 | 
			
		||||
	device := peer.device
 | 
			
		||||
	indices := &device.indices
 | 
			
		||||
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, timer handler, started for peer", peer.String())
 | 
			
		||||
@ -170,29 +171,42 @@ func (peer *Peer) RoutineTimerHandler() {
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Clearing all key material for", peer.String())
 | 
			
		||||
 | 
			
		||||
			// zero out key pairs
 | 
			
		||||
			kp := &peer.keyPairs
 | 
			
		||||
			kp.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
			func() {
 | 
			
		||||
				kp := &peer.keyPairs
 | 
			
		||||
				kp.mutex.Lock()
 | 
			
		||||
				// best we can do is wait for GC :( ?
 | 
			
		||||
				kp.current = nil
 | 
			
		||||
				kp.previous = nil
 | 
			
		||||
				kp.next = nil
 | 
			
		||||
				kp.mutex.Unlock()
 | 
			
		||||
			}()
 | 
			
		||||
			hs := &peer.handshake
 | 
			
		||||
			hs.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
			// unmap local indecies
 | 
			
		||||
 | 
			
		||||
			indices.mutex.Lock()
 | 
			
		||||
			if kp.previous != nil {
 | 
			
		||||
				delete(indices.table, kp.previous.localIndex)
 | 
			
		||||
			}
 | 
			
		||||
			if kp.current != nil {
 | 
			
		||||
				delete(indices.table, kp.current.localIndex)
 | 
			
		||||
			}
 | 
			
		||||
			if kp.next != nil {
 | 
			
		||||
				delete(indices.table, kp.next.localIndex)
 | 
			
		||||
			}
 | 
			
		||||
			delete(indices.table, hs.localIndex)
 | 
			
		||||
			indices.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			// zero out key pairs (TODO: better than wait for GC)
 | 
			
		||||
 | 
			
		||||
			kp.current = nil
 | 
			
		||||
			kp.previous = nil
 | 
			
		||||
			kp.next = nil
 | 
			
		||||
			kp.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			// zero out handshake
 | 
			
		||||
 | 
			
		||||
			func() {
 | 
			
		||||
				hs := &peer.handshake
 | 
			
		||||
				hs.mutex.Lock()
 | 
			
		||||
				hs.localEphemeral = NoisePrivateKey{}
 | 
			
		||||
				hs.remoteEphemeral = NoisePublicKey{}
 | 
			
		||||
				hs.chainKey = [blake2s.Size]byte{}
 | 
			
		||||
				hs.hash = [blake2s.Size]byte{}
 | 
			
		||||
				hs.mutex.Unlock()
 | 
			
		||||
			}()
 | 
			
		||||
			hs.localIndex = 0
 | 
			
		||||
			hs.localEphemeral = NoisePrivateKey{}
 | 
			
		||||
			hs.remoteEphemeral = NoisePublicKey{}
 | 
			
		||||
			hs.chainKey = [blake2s.Size]byte{}
 | 
			
		||||
			hs.hash = [blake2s.Size]byte{}
 | 
			
		||||
			hs.mutex.Unlock()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user