Inital implementation of trie
This commit is contained in:
		
							parent
							
								
									8ce921987f
								
							
						
					
					
						commit
						ec3d656beb
					
				@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* todo : use real error code
 | 
			
		||||
@ -18,6 +19,7 @@ const (
 | 
			
		||||
	ipcErrorInvalidPrivateKey = 3
 | 
			
		||||
	ipcErrorInvalidPublicKey  = 4
 | 
			
		||||
	ipcErrorInvalidPort       = 5
 | 
			
		||||
	ipcErrorInvalidIPAddress  = 6
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type IPCError struct {
 | 
			
		||||
@ -104,6 +106,10 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		case "replace_peers":
 | 
			
		||||
			if key == "true" {
 | 
			
		||||
				dev.RemoveAllPeers()
 | 
			
		||||
			}
 | 
			
		||||
			// todo: else fail
 | 
			
		||||
 | 
			
		||||
		default:
 | 
			
		||||
			/* Peer configuration */
 | 
			
		||||
@ -116,20 +122,27 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
 | 
			
		||||
			case "remove":
 | 
			
		||||
				peer.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
				dev.RemovePeer(peer.publicKey)
 | 
			
		||||
				peer = nil
 | 
			
		||||
 | 
			
		||||
			case "preshared_key":
 | 
			
		||||
				func() {
 | 
			
		||||
				err := func() error {
 | 
			
		||||
					peer.mutex.Lock()
 | 
			
		||||
					defer peer.mutex.Unlock()
 | 
			
		||||
					return peer.presharedKey.FromHex(value)
 | 
			
		||||
				}()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalidPublicKey}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			case "endpoint":
 | 
			
		||||
				func() {
 | 
			
		||||
					peer.mutex.Lock()
 | 
			
		||||
					defer peer.mutex.Unlock()
 | 
			
		||||
				}()
 | 
			
		||||
				ip := net.ParseIP(value)
 | 
			
		||||
				if ip == nil {
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalidIPAddress}
 | 
			
		||||
				}
 | 
			
		||||
				peer.mutex.Lock()
 | 
			
		||||
				peer.endpoint = ip
 | 
			
		||||
				peer.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			case "persistent_keepalive_interval":
 | 
			
		||||
				func() {
 | 
			
		||||
 | 
			
		||||
@ -5,10 +5,39 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Device struct {
 | 
			
		||||
	mutex      sync.RWMutex
 | 
			
		||||
	peers      map[NoisePublicKey]*Peer
 | 
			
		||||
	privateKey NoisePrivateKey
 | 
			
		||||
	publicKey  NoisePublicKey
 | 
			
		||||
	fwMark     uint32
 | 
			
		||||
	listenPort uint16
 | 
			
		||||
	mutex        sync.RWMutex
 | 
			
		||||
	peers        map[NoisePublicKey]*Peer
 | 
			
		||||
	privateKey   NoisePrivateKey
 | 
			
		||||
	publicKey    NoisePublicKey
 | 
			
		||||
	fwMark       uint32
 | 
			
		||||
	listenPort   uint16
 | 
			
		||||
	routingTable RoutingTable
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dev *Device) RemovePeer(key NoisePublicKey) {
 | 
			
		||||
	dev.mutex.Lock()
 | 
			
		||||
	defer dev.mutex.Unlock()
 | 
			
		||||
	peer, ok := dev.peers[key]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	peer.mutex.Lock()
 | 
			
		||||
	dev.routingTable.RemovePeer(peer)
 | 
			
		||||
	delete(dev.peers, key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dev *Device) RemoveAllAllowedIps(peer *Peer) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dev *Device) RemoveAllPeers() {
 | 
			
		||||
	dev.mutex.Lock()
 | 
			
		||||
	defer dev.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	for key, peer := range dev.peers {
 | 
			
		||||
		peer.mutex.Lock()
 | 
			
		||||
		dev.routingTable.RemovePeer(peer)
 | 
			
		||||
		delete(dev.peers, key)
 | 
			
		||||
		peer.mutex.Unlock()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										38
									
								
								src/noise.go
									
									
									
									
									
								
							
							
						
						
									
										38
									
								
								src/noise.go
									
									
									
									
									
								
							@ -18,34 +18,38 @@ type (
 | 
			
		||||
	NoiseNonce        uint64 // padded to 12-bytes
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (key *NoisePrivateKey) FromHex(s string) error {
 | 
			
		||||
	slice, err := hex.DecodeString(s)
 | 
			
		||||
func loadExactHex(dst []byte, src string) error {
 | 
			
		||||
	slice, err := hex.DecodeString(src)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if len(slice) != NoisePrivateKeySize {
 | 
			
		||||
		return errors.New("Invalid length of hex string for curve25519 point")
 | 
			
		||||
	if len(slice) != len(dst) {
 | 
			
		||||
		return errors.New("Hex string does not fit the slice")
 | 
			
		||||
	}
 | 
			
		||||
	copy(key[:], slice)
 | 
			
		||||
	copy(dst, slice)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (key *NoisePrivateKey) ToHex() string {
 | 
			
		||||
func (key *NoisePrivateKey) FromHex(src string) error {
 | 
			
		||||
	return loadExactHex(key[:], src)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (key NoisePrivateKey) ToHex() string {
 | 
			
		||||
	return hex.EncodeToString(key[:])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (key *NoisePublicKey) FromHex(s string) error {
 | 
			
		||||
	slice, err := hex.DecodeString(s)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if len(slice) != NoisePublicKeySize {
 | 
			
		||||
		return errors.New("Invalid length of hex string for curve25519 scalar")
 | 
			
		||||
	}
 | 
			
		||||
	copy(key[:], slice)
 | 
			
		||||
	return nil
 | 
			
		||||
func (key *NoisePublicKey) FromHex(src string) error {
 | 
			
		||||
	return loadExactHex(key[:], src)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (key *NoisePublicKey) ToHex() string {
 | 
			
		||||
func (key NoisePublicKey) ToHex() string {
 | 
			
		||||
	return hex.EncodeToString(key[:])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (key *NoiseSymmetricKey) FromHex(src string) error {
 | 
			
		||||
	return loadExactHex(key[:], src)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (key NoiseSymmetricKey) ToHex() string {
 | 
			
		||||
	return hex.EncodeToString(key[:])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -15,4 +16,5 @@ type Peer struct {
 | 
			
		||||
	mutex        sync.RWMutex
 | 
			
		||||
	publicKey    NoisePublicKey
 | 
			
		||||
	presharedKey NoiseSymmetricKey
 | 
			
		||||
	endpoint     net.IP
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										175
									
								
								src/ping-test.go
									
									
									
									
									
								
							
							
						
						
									
										175
									
								
								src/ping-test.go
									
									
									
									
									
								
							@ -1,175 +0,0 @@
 | 
			
		||||
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
 | 
			
		||||
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/dchest/blake2s"
 | 
			
		||||
	"github.com/titanous/noise"
 | 
			
		||||
	"golang.org/x/net/icmp"
 | 
			
		||||
	"golang.org/x/net/ipv4"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ipChecksum(buf []byte) uint16 {
 | 
			
		||||
	sum := uint32(0)
 | 
			
		||||
	for ; len(buf) >= 2; buf = buf[2:] {
 | 
			
		||||
		sum += uint32(buf[0])<<8 | uint32(buf[1])
 | 
			
		||||
	}
 | 
			
		||||
	if len(buf) > 0 {
 | 
			
		||||
		sum += uint32(buf[0]) << 8
 | 
			
		||||
	}
 | 
			
		||||
	for sum > 0xffff {
 | 
			
		||||
		sum = (sum >> 16) + (sum & 0xffff)
 | 
			
		||||
	}
 | 
			
		||||
	csum := ^uint16(sum)
 | 
			
		||||
	if csum == 0 {
 | 
			
		||||
		csum = 0xffff
 | 
			
		||||
	}
 | 
			
		||||
	return csum
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	ourPrivate, _ := base64.StdEncoding.DecodeString("WAmgVYXkbT2bCtdcDwolI88/iVi/aV3/PHcUBTQSYmo=")
 | 
			
		||||
	ourPublic, _ := base64.StdEncoding.DecodeString("K5sF9yESrSBsOXPd6TcpKNgqoy1Ik3ZFKl4FolzrRyI=")
 | 
			
		||||
	theirPublic, _ := base64.StdEncoding.DecodeString("qRCwZSKInrMAq5sepfCdaCsRJaoLe5jhtzfiw7CjbwM=")
 | 
			
		||||
	preshared, _ := base64.StdEncoding.DecodeString("FpCyhws9cxwWoV4xELtfJvjJN+zQVRPISllRWgeopVE=")
 | 
			
		||||
	cs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s)
 | 
			
		||||
	hs := noise.NewHandshakeState(noise.Config{
 | 
			
		||||
		CipherSuite:   cs,
 | 
			
		||||
		Random:        rand.Reader,
 | 
			
		||||
		Pattern:       noise.HandshakeIK,
 | 
			
		||||
		Initiator:     true,
 | 
			
		||||
		Prologue:      []byte("WireGuard v1 zx2c4 Jason@zx2c4.com"),
 | 
			
		||||
		PresharedKey:  preshared,
 | 
			
		||||
		PresharedKeyPlacement: 2,
 | 
			
		||||
		StaticKeypair: noise.DHKey{Private: ourPrivate, Public: ourPublic},
 | 
			
		||||
		PeerStatic:    theirPublic,
 | 
			
		||||
	})
 | 
			
		||||
	conn, err := net.Dial("udp", "demo.wireguard.io:12913")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("error dialing udp socket: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer conn.Close()
 | 
			
		||||
 | 
			
		||||
	// write handshake initiation packet
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	tai64n := make([]byte, 12)
 | 
			
		||||
	binary.BigEndian.PutUint64(tai64n[:], 4611686018427387914+uint64(now.Unix()))
 | 
			
		||||
	binary.BigEndian.PutUint32(tai64n[8:], uint32(now.UnixNano()))
 | 
			
		||||
	initiationPacket := make([]byte, 8)
 | 
			
		||||
	initiationPacket[0] = 1 // Type: Initiation
 | 
			
		||||
	initiationPacket[1] = 0 // Reserved
 | 
			
		||||
	initiationPacket[2] = 0	// Reserved
 | 
			
		||||
	initiationPacket[3] = 0	// Reserved
 | 
			
		||||
	binary.LittleEndian.PutUint32(initiationPacket[4:], 28) // Sender index: 28 (arbitrary)
 | 
			
		||||
	initiationPacket, _, _ = hs.WriteMessage(initiationPacket, tai64n)
 | 
			
		||||
	hasher, _ := blake2s.New(&blake2s.Config{Size: 32})
 | 
			
		||||
	hasher.Write([]byte("mac1----"))
 | 
			
		||||
	hasher.Write(theirPublic)
 | 
			
		||||
	hasher, _ = blake2s.New(&blake2s.Config{Size: 16, Key: hasher.Sum(nil)})
 | 
			
		||||
	hasher.Write(initiationPacket)
 | 
			
		||||
	initiationPacket = append(initiationPacket, hasher.Sum(nil)[:16]...)
 | 
			
		||||
	initiationPacket = append(initiationPacket, make([]byte, 16)...)
 | 
			
		||||
	if _, err := conn.Write(initiationPacket); err != nil {
 | 
			
		||||
		log.Fatalf("error writing initiation packet: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// read handshake response packet
 | 
			
		||||
	responsePacket := make([]byte, 92)
 | 
			
		||||
	n, err := conn.Read(responsePacket)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("error reading response packet: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	if n != len(responsePacket) {
 | 
			
		||||
		log.Fatalf("response packet too short: want %d, got %d", len(responsePacket), n)
 | 
			
		||||
	}
 | 
			
		||||
	if responsePacket[0] != 2 { // Type: Response
 | 
			
		||||
		log.Fatalf("response packet type wrong: want %d, got %d", 2, responsePacket[0])
 | 
			
		||||
	}
 | 
			
		||||
	if responsePacket[1] != 0 || responsePacket[2] != 0 || responsePacket[3] != 0 {
 | 
			
		||||
		log.Fatalf("response packet has non-zero reserved fields")
 | 
			
		||||
	}
 | 
			
		||||
	theirIndex := binary.LittleEndian.Uint32(responsePacket[4:])
 | 
			
		||||
	ourIndex := binary.LittleEndian.Uint32(responsePacket[8:])
 | 
			
		||||
	if ourIndex != 28 {
 | 
			
		||||
		log.Fatalf("response packet index wrong: want %d, got %d", 28, ourIndex)
 | 
			
		||||
	}
 | 
			
		||||
	payload, sendCipher, receiveCipher, err := hs.ReadMessage(nil, responsePacket[12:60])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("error reading handshake message: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	if len(payload) > 0 {
 | 
			
		||||
		log.Fatalf("unexpected payload: %x", payload)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// write ICMP Echo packet
 | 
			
		||||
	pingMessage, _ := (&icmp.Message{
 | 
			
		||||
		Type: ipv4.ICMPTypeEcho,
 | 
			
		||||
		Body: &icmp.Echo{
 | 
			
		||||
			ID:   921,
 | 
			
		||||
			Seq:  438,
 | 
			
		||||
			Data: []byte("WireGuard"),
 | 
			
		||||
		},
 | 
			
		||||
	}).Marshal(nil)
 | 
			
		||||
	pingHeader, err := (&ipv4.Header{
 | 
			
		||||
		Version:  ipv4.Version,
 | 
			
		||||
		Len:      ipv4.HeaderLen,
 | 
			
		||||
		TotalLen: ipv4.HeaderLen + len(pingMessage),
 | 
			
		||||
		Protocol: 1, // ICMP
 | 
			
		||||
		TTL:      20,
 | 
			
		||||
		Src:      net.IPv4(10, 189, 129, 2),
 | 
			
		||||
		Dst:      net.IPv4(10, 189, 129, 1),
 | 
			
		||||
	}).Marshal()
 | 
			
		||||
	binary.BigEndian.PutUint16(pingHeader[2:], uint16(ipv4.HeaderLen+len(pingMessage))) // fix the length endianness on BSDs
 | 
			
		||||
	pingData := append(pingHeader, pingMessage...)
 | 
			
		||||
	binary.BigEndian.PutUint16(pingData[10:], ipChecksum(pingData))
 | 
			
		||||
	pingPacket := make([]byte, 16)
 | 
			
		||||
	pingPacket[0] = 4 // Type: Data
 | 
			
		||||
	pingPacket[1] = 0 // Reserved
 | 
			
		||||
	pingPacket[2] = 0 // Reserved
 | 
			
		||||
	pingPacket[3] = 0 // Reserved
 | 
			
		||||
	binary.LittleEndian.PutUint32(pingPacket[4:], theirIndex)
 | 
			
		||||
	binary.LittleEndian.PutUint64(pingPacket[8:], 0) // Nonce
 | 
			
		||||
	pingPacket = sendCipher.Encrypt(pingPacket, nil, pingData)
 | 
			
		||||
	if _, err := conn.Write(pingPacket); err != nil {
 | 
			
		||||
		log.Fatalf("error writing ping message: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// read ICMP Echo Reply packet
 | 
			
		||||
	replyPacket := make([]byte, 128)
 | 
			
		||||
	n, err = conn.Read(replyPacket)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("error reading ping reply message: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	replyPacket = replyPacket[:n]
 | 
			
		||||
	if replyPacket[0] != 4 { // Type: Data
 | 
			
		||||
		log.Fatalf("unexpected reply packet type: %d", replyPacket[0])
 | 
			
		||||
	}
 | 
			
		||||
	if replyPacket[1] != 0 || replyPacket[2] != 0 || replyPacket[3] != 0 {
 | 
			
		||||
		log.Fatalf("reply packet has non-zero reserved fields")
 | 
			
		||||
	}
 | 
			
		||||
	replyPacket, err = receiveCipher.Decrypt(nil, nil, replyPacket[16:])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("error decrypting reply packet: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	replyHeaderLen := int(replyPacket[0]&0x0f) << 2
 | 
			
		||||
	replyLen := binary.BigEndian.Uint16(replyPacket[2:])
 | 
			
		||||
	replyMessage, err := icmp.ParseMessage(1, replyPacket[replyHeaderLen:replyLen])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("error parsing echo: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	echo, ok := replyMessage.Body.(*icmp.Echo)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		log.Fatalf("unexpected reply body type %T", replyMessage.Body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if echo.ID != 921 || echo.Seq != 438 || string(echo.Data) != "WireGuard" {
 | 
			
		||||
		log.Fatalf("incorrect echo response: %#v", echo)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										22
									
								
								src/routing.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								src/routing.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,22 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Thread-safe high level functions for cryptkey routing.
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type RoutingTable struct {
 | 
			
		||||
	IPv4  *Trie
 | 
			
		||||
	IPv6  *Trie
 | 
			
		||||
	mutex sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *RoutingTable) RemovePeer(peer *Peer) {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
	table.IPv4 = table.IPv4.RemovePeer(peer)
 | 
			
		||||
	table.IPv6 = table.IPv6.RemovePeer(peer)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										81
									
								
								src/trie.go
									
									
									
									
									
								
							
							
						
						
									
										81
									
								
								src/trie.go
									
									
									
									
									
								
							@ -1,9 +1,11 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import "fmt"
 | 
			
		||||
 | 
			
		||||
/* Syncronization must be done seperatly
 | 
			
		||||
/* Binary trie
 | 
			
		||||
 *
 | 
			
		||||
 * Syncronization done seperatly
 | 
			
		||||
 * See: routing.go
 | 
			
		||||
 *
 | 
			
		||||
 * Todo: Better commenting
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type Trie struct {
 | 
			
		||||
@ -13,7 +15,6 @@ type Trie struct {
 | 
			
		||||
	peer  *Peer
 | 
			
		||||
 | 
			
		||||
	// Index of "branching" bit
 | 
			
		||||
	// bit_at_shift
 | 
			
		||||
	bit_at_byte  uint
 | 
			
		||||
	bit_at_shift uint
 | 
			
		||||
}
 | 
			
		||||
@ -92,7 +93,14 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
 | 
			
		||||
	return node.child[0]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *Trie) choose(key []byte) byte {
 | 
			
		||||
	return (key[node.bit_at_byte] >> node.bit_at_shift) & 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
 | 
			
		||||
	// At leaf
 | 
			
		||||
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return &Trie{
 | 
			
		||||
			bits:         key,
 | 
			
		||||
@ -107,22 +115,17 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
 | 
			
		||||
	common := commonBits(node.bits, key)
 | 
			
		||||
	if node.cidr <= cidr && common >= node.cidr {
 | 
			
		||||
		// Check if match the t.bits[:t.cidr] exactly
 | 
			
		||||
		if node.cidr == cidr {
 | 
			
		||||
			node.peer = peer
 | 
			
		||||
			return node
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Go to child
 | 
			
		||||
		bit := (key[node.bit_at_byte] >> node.bit_at_shift) & 1
 | 
			
		||||
		bit := node.choose(key)
 | 
			
		||||
		node.child[bit] = node.child[bit].Insert(key, cidr, peer)
 | 
			
		||||
		return node
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Split node
 | 
			
		||||
 | 
			
		||||
	fmt.Println("new", common)
 | 
			
		||||
 | 
			
		||||
	newNode := &Trie{
 | 
			
		||||
		bits:         key,
 | 
			
		||||
		peer:         peer,
 | 
			
		||||
@ -132,23 +135,53 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cidr = min(cidr, common)
 | 
			
		||||
	node.cidr = cidr
 | 
			
		||||
	node.bit_at_byte = cidr / 8
 | 
			
		||||
	node.bit_at_shift = 7 - (cidr % 8)
 | 
			
		||||
 | 
			
		||||
	// bval := node.bits[node.bit_at_byte] >> node.bit_at_shift // todo : remember index
 | 
			
		||||
	// Work in progress
 | 
			
		||||
	node.child[0] = newNode
 | 
			
		||||
	node.child[1] = newNode
 | 
			
		||||
	// Check for shorter prefix
 | 
			
		||||
 | 
			
		||||
	return node
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *Trie) Lookup(key []byte) *Peer {
 | 
			
		||||
	if t == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	if newNode.cidr == cidr {
 | 
			
		||||
		bit := newNode.choose(node.bits)
 | 
			
		||||
		newNode.child[bit] = node
 | 
			
		||||
		return newNode
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
	// Create new parent for node & newNode
 | 
			
		||||
 | 
			
		||||
	parent := &Trie{
 | 
			
		||||
		bits:         key,
 | 
			
		||||
		peer:         nil,
 | 
			
		||||
		cidr:         cidr,
 | 
			
		||||
		bit_at_byte:  cidr / 8,
 | 
			
		||||
		bit_at_shift: 7 - (cidr % 8),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	bit := parent.choose(key)
 | 
			
		||||
	parent.child[bit] = newNode
 | 
			
		||||
	parent.child[bit^1] = node
 | 
			
		||||
 | 
			
		||||
	return parent
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *Trie) Lookup(key []byte) *Peer {
 | 
			
		||||
	var found *Peer
 | 
			
		||||
	size := uint(len(key))
 | 
			
		||||
	for node != nil && commonBits(node.bits, key) >= node.cidr {
 | 
			
		||||
		if node.peer != nil {
 | 
			
		||||
			found = node.peer
 | 
			
		||||
		}
 | 
			
		||||
		if node.bit_at_byte == size {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		bit := node.choose(key)
 | 
			
		||||
		node = node.child[bit]
 | 
			
		||||
	}
 | 
			
		||||
	return found
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *Trie) Count() uint {
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	l := node.child[0].Count()
 | 
			
		||||
	r := node.child[1].Count()
 | 
			
		||||
	return l + r
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										180
									
								
								src/trie_test.go
									
									
									
									
									
								
							
							
						
						
									
										180
									
								
								src/trie_test.go
									
									
									
									
									
								
							@ -4,6 +4,9 @@ import (
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Todo: More comprehensive
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type testPairCommonBits struct {
 | 
			
		||||
	s1    []byte
 | 
			
		||||
	s2    []byte
 | 
			
		||||
@ -16,6 +19,11 @@ type testPairTrieInsert struct {
 | 
			
		||||
	peer *Peer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type testPairTrieLookup struct {
 | 
			
		||||
	key  []byte
 | 
			
		||||
	peer *Peer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func printTrie(t *testing.T, p *Trie) {
 | 
			
		||||
	if p == nil {
 | 
			
		||||
		return
 | 
			
		||||
@ -41,26 +49,176 @@ func TestCommonBits(t *testing.T) {
 | 
			
		||||
			t.Error(
 | 
			
		||||
				"For slice", p.s1, p.s2,
 | 
			
		||||
				"expected match", p.match,
 | 
			
		||||
				"got", v,
 | 
			
		||||
				",but got", v,
 | 
			
		||||
			)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTrieInsertV4(t *testing.T) {
 | 
			
		||||
/* Test ported from kernel implementation:
 | 
			
		||||
 * selftest/routingtable.h
 | 
			
		||||
 */
 | 
			
		||||
func TestTrieIPv4(t *testing.T) {
 | 
			
		||||
	a := &Peer{}
 | 
			
		||||
	b := &Peer{}
 | 
			
		||||
	c := &Peer{}
 | 
			
		||||
	d := &Peer{}
 | 
			
		||||
	e := &Peer{}
 | 
			
		||||
	g := &Peer{}
 | 
			
		||||
	h := &Peer{}
 | 
			
		||||
 | 
			
		||||
	var trie *Trie
 | 
			
		||||
 | 
			
		||||
	peer1 := Peer{}
 | 
			
		||||
	peer2 := Peer{}
 | 
			
		||||
 | 
			
		||||
	tests := []testPairTrieInsert{
 | 
			
		||||
		{key: []byte{192, 168, 1, 1}, cidr: 24, peer: &peer1},
 | 
			
		||||
		{key: []byte{192, 169, 1, 1}, cidr: 24, peer: &peer2},
 | 
			
		||||
	insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
 | 
			
		||||
		trie = trie.Insert([]byte{a, b, c, d}, cidr, peer)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, p := range tests {
 | 
			
		||||
		trie = trie.Insert(p.key, p.cidr, p.peer)
 | 
			
		||||
		printTrie(t, trie)
 | 
			
		||||
	assertEQ := func(peer *Peer, a, b, c, d byte) {
 | 
			
		||||
		p := trie.Lookup([]byte{a, b, c, d})
 | 
			
		||||
		if p != peer {
 | 
			
		||||
			t.Error("Assert EQ failed")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	assertNEQ := func(peer *Peer, a, b, c, d byte) {
 | 
			
		||||
		p := trie.Lookup([]byte{a, b, c, d})
 | 
			
		||||
		if p == peer {
 | 
			
		||||
			t.Error("Assert NEQ failed")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	insert(a, 192, 168, 4, 0, 24)
 | 
			
		||||
	insert(b, 192, 168, 4, 4, 32)
 | 
			
		||||
	insert(c, 192, 168, 0, 0, 16)
 | 
			
		||||
	insert(d, 192, 95, 5, 64, 27)
 | 
			
		||||
	insert(c, 192, 95, 5, 65, 27) /* replaces previous entry, and maskself is required */
 | 
			
		||||
	insert(e, 0, 0, 0, 0, 0)
 | 
			
		||||
	insert(g, 64, 15, 112, 0, 20)
 | 
			
		||||
	insert(h, 64, 15, 123, 211, 25) /* maskself is required */
 | 
			
		||||
	insert(a, 10, 0, 0, 0, 25)
 | 
			
		||||
	insert(b, 10, 0, 0, 128, 25)
 | 
			
		||||
	insert(a, 10, 1, 0, 0, 30)
 | 
			
		||||
	insert(b, 10, 1, 0, 4, 30)
 | 
			
		||||
	insert(c, 10, 1, 0, 8, 29)
 | 
			
		||||
	insert(d, 10, 1, 0, 16, 29)
 | 
			
		||||
 | 
			
		||||
	assertEQ(a, 192, 168, 4, 20)
 | 
			
		||||
	assertEQ(a, 192, 168, 4, 0)
 | 
			
		||||
	assertEQ(b, 192, 168, 4, 4)
 | 
			
		||||
	assertEQ(c, 192, 168, 200, 182)
 | 
			
		||||
	assertEQ(c, 192, 95, 5, 68)
 | 
			
		||||
	assertEQ(e, 192, 95, 5, 96)
 | 
			
		||||
	assertEQ(g, 64, 15, 116, 26)
 | 
			
		||||
	assertEQ(g, 64, 15, 127, 3)
 | 
			
		||||
 | 
			
		||||
	insert(a, 1, 0, 0, 0, 32)
 | 
			
		||||
	insert(a, 64, 0, 0, 0, 32)
 | 
			
		||||
	insert(a, 128, 0, 0, 0, 32)
 | 
			
		||||
	insert(a, 192, 0, 0, 0, 32)
 | 
			
		||||
	insert(a, 255, 0, 0, 0, 32)
 | 
			
		||||
 | 
			
		||||
	assertEQ(a, 1, 0, 0, 0)
 | 
			
		||||
	assertEQ(a, 64, 0, 0, 0)
 | 
			
		||||
	assertEQ(a, 128, 0, 0, 0)
 | 
			
		||||
	assertEQ(a, 192, 0, 0, 0)
 | 
			
		||||
	assertEQ(a, 255, 0, 0, 0)
 | 
			
		||||
 | 
			
		||||
	trie = trie.RemovePeer(a)
 | 
			
		||||
 | 
			
		||||
	assertNEQ(a, 1, 0, 0, 0)
 | 
			
		||||
	assertNEQ(a, 64, 0, 0, 0)
 | 
			
		||||
	assertNEQ(a, 128, 0, 0, 0)
 | 
			
		||||
	assertNEQ(a, 192, 0, 0, 0)
 | 
			
		||||
	assertNEQ(a, 255, 0, 0, 0)
 | 
			
		||||
 | 
			
		||||
	trie = nil
 | 
			
		||||
 | 
			
		||||
	insert(a, 192, 168, 0, 0, 16)
 | 
			
		||||
	insert(a, 192, 168, 0, 0, 24)
 | 
			
		||||
 | 
			
		||||
	trie = trie.RemovePeer(a)
 | 
			
		||||
 | 
			
		||||
	assertNEQ(a, 192, 168, 0, 1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Test ported from kernel implementation:
 | 
			
		||||
 * selftest/routingtable.h
 | 
			
		||||
 */
 | 
			
		||||
func TestTrieIPv6(t *testing.T) {
 | 
			
		||||
	a := &Peer{}
 | 
			
		||||
	b := &Peer{}
 | 
			
		||||
	c := &Peer{}
 | 
			
		||||
	d := &Peer{}
 | 
			
		||||
	e := &Peer{}
 | 
			
		||||
	f := &Peer{}
 | 
			
		||||
	g := &Peer{}
 | 
			
		||||
	h := &Peer{}
 | 
			
		||||
 | 
			
		||||
	var trie *Trie
 | 
			
		||||
 | 
			
		||||
	expand := func(a uint32) []byte {
 | 
			
		||||
		var out [4]byte
 | 
			
		||||
		out[0] = byte(a >> 24 & 0xff)
 | 
			
		||||
		out[1] = byte(a >> 16 & 0xff)
 | 
			
		||||
		out[2] = byte(a >> 8 & 0xff)
 | 
			
		||||
		out[3] = byte(a & 0xff)
 | 
			
		||||
		return out[:]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
 | 
			
		||||
		var addr []byte
 | 
			
		||||
		addr = append(addr, expand(a)...)
 | 
			
		||||
		addr = append(addr, expand(b)...)
 | 
			
		||||
		addr = append(addr, expand(c)...)
 | 
			
		||||
		addr = append(addr, expand(d)...)
 | 
			
		||||
		trie = trie.Insert(addr, cidr, peer)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	assertEQ := func(peer *Peer, a, b, c, d uint32) {
 | 
			
		||||
		var addr []byte
 | 
			
		||||
		addr = append(addr, expand(a)...)
 | 
			
		||||
		addr = append(addr, expand(b)...)
 | 
			
		||||
		addr = append(addr, expand(c)...)
 | 
			
		||||
		addr = append(addr, expand(d)...)
 | 
			
		||||
		p := trie.Lookup(addr)
 | 
			
		||||
		if p != peer {
 | 
			
		||||
			t.Error("Assert EQ failed")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	/*
 | 
			
		||||
		assertNEQ := func(peer *Peer, a, b, c, d uint32) {
 | 
			
		||||
			var addr []byte
 | 
			
		||||
			addr = append(addr, expand(a)...)
 | 
			
		||||
			addr = append(addr, expand(b)...)
 | 
			
		||||
			addr = append(addr, expand(c)...)
 | 
			
		||||
			addr = append(addr, expand(d)...)
 | 
			
		||||
			p := trie.Lookup(addr)
 | 
			
		||||
			if p == peer {
 | 
			
		||||
				t.Error("Assert NEQ failed")
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	*/
 | 
			
		||||
 | 
			
		||||
	insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
 | 
			
		||||
	insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
 | 
			
		||||
	insert(e, 0, 0, 0, 0, 0)
 | 
			
		||||
	insert(f, 0, 0, 0, 0, 0)
 | 
			
		||||
	insert(g, 0x24046800, 0, 0, 0, 32)
 | 
			
		||||
	insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64)
 | 
			
		||||
	insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128)
 | 
			
		||||
	insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
 | 
			
		||||
	insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
 | 
			
		||||
 | 
			
		||||
	assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543)
 | 
			
		||||
	assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee)
 | 
			
		||||
	assertEQ(f, 0x26075300, 0x60006b01, 0, 0)
 | 
			
		||||
	assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006)
 | 
			
		||||
	assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678)
 | 
			
		||||
	assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678)
 | 
			
		||||
	assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678)
 | 
			
		||||
	assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678)
 | 
			
		||||
	assertEQ(h, 0x24046800, 0x40040800, 0, 0)
 | 
			
		||||
	assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
 | 
			
		||||
	assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user