diff --git a/src/config.go b/src/config.go index f6f1378..62af67a 100644 --- a/src/config.go +++ b/src/config.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "log" + "net" ) /* todo : use real error code @@ -18,6 +19,7 @@ const ( ipcErrorInvalidPrivateKey = 3 ipcErrorInvalidPublicKey = 4 ipcErrorInvalidPort = 5 + ipcErrorInvalidIPAddress = 6 ) type IPCError struct { @@ -104,6 +106,10 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError { } case "replace_peers": + if key == "true" { + dev.RemoveAllPeers() + } + // todo: else fail default: /* Peer configuration */ @@ -116,20 +122,27 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError { case "remove": peer.mutex.Lock() - + dev.RemovePeer(peer.publicKey) peer = nil case "preshared_key": - func() { + err := func() error { peer.mutex.Lock() defer peer.mutex.Unlock() + return peer.presharedKey.FromHex(value) }() + if err != nil { + return &IPCError{Code: ipcErrorInvalidPublicKey} + } case "endpoint": - func() { - peer.mutex.Lock() - defer peer.mutex.Unlock() - }() + ip := net.ParseIP(value) + if ip == nil { + return &IPCError{Code: ipcErrorInvalidIPAddress} + } + peer.mutex.Lock() + peer.endpoint = ip + peer.mutex.Unlock() case "persistent_keepalive_interval": func() { diff --git a/src/device.go b/src/device.go index cd0835c..d03057d 100644 --- a/src/device.go +++ b/src/device.go @@ -5,10 +5,39 @@ import ( ) type Device struct { - mutex sync.RWMutex - peers map[NoisePublicKey]*Peer - privateKey NoisePrivateKey - publicKey NoisePublicKey - fwMark uint32 - listenPort uint16 + mutex sync.RWMutex + peers map[NoisePublicKey]*Peer + privateKey NoisePrivateKey + publicKey NoisePublicKey + fwMark uint32 + listenPort uint16 + routingTable RoutingTable +} + +func (dev *Device) RemovePeer(key NoisePublicKey) { + dev.mutex.Lock() + defer dev.mutex.Unlock() + peer, ok := dev.peers[key] + if !ok { + return + } + peer.mutex.Lock() + dev.routingTable.RemovePeer(peer) + delete(dev.peers, key) +} + +func (dev *Device) RemoveAllAllowedIps(peer *Peer) { + +} + +func (dev *Device) RemoveAllPeers() { + dev.mutex.Lock() + defer dev.mutex.Unlock() + + for key, peer := range dev.peers { + peer.mutex.Lock() + dev.routingTable.RemovePeer(peer) + delete(dev.peers, key) + peer.mutex.Unlock() + } } diff --git a/src/noise.go b/src/noise.go index d13bdd6..5508f9a 100644 --- a/src/noise.go +++ b/src/noise.go @@ -18,34 +18,38 @@ type ( NoiseNonce uint64 // padded to 12-bytes ) -func (key *NoisePrivateKey) FromHex(s string) error { - slice, err := hex.DecodeString(s) +func loadExactHex(dst []byte, src string) error { + slice, err := hex.DecodeString(src) if err != nil { return err } - if len(slice) != NoisePrivateKeySize { - return errors.New("Invalid length of hex string for curve25519 point") + if len(slice) != len(dst) { + return errors.New("Hex string does not fit the slice") } - copy(key[:], slice) + copy(dst, slice) return nil } -func (key *NoisePrivateKey) ToHex() string { +func (key *NoisePrivateKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoisePrivateKey) ToHex() string { return hex.EncodeToString(key[:]) } -func (key *NoisePublicKey) FromHex(s string) error { - slice, err := hex.DecodeString(s) - if err != nil { - return err - } - if len(slice) != NoisePublicKeySize { - return errors.New("Invalid length of hex string for curve25519 scalar") - } - copy(key[:], slice) - return nil +func (key *NoisePublicKey) FromHex(src string) error { + return loadExactHex(key[:], src) } -func (key *NoisePublicKey) ToHex() string { +func (key NoisePublicKey) ToHex() string { + return hex.EncodeToString(key[:]) +} + +func (key *NoiseSymmetricKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoiseSymmetricKey) ToHex() string { return hex.EncodeToString(key[:]) } diff --git a/src/peer.go b/src/peer.go index 7c000da..7b2b2a6 100644 --- a/src/peer.go +++ b/src/peer.go @@ -1,6 +1,7 @@ package main import ( + "net" "sync" ) @@ -15,4 +16,5 @@ type Peer struct { mutex sync.RWMutex publicKey NoisePublicKey presharedKey NoiseSymmetricKey + endpoint net.IP } diff --git a/src/ping-test.go b/src/ping-test.go deleted file mode 100644 index 4b58891..0000000 --- a/src/ping-test.go +++ /dev/null @@ -1,175 +0,0 @@ -/* Copyright (C) 2015-2017 Jason A. Donenfeld . All Rights Reserved. */ - -package main - -import ( - "crypto/rand" - "encoding/base64" - "encoding/binary" - "log" - "net" - "time" - - "github.com/dchest/blake2s" - "github.com/titanous/noise" - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" -) - -func ipChecksum(buf []byte) uint16 { - sum := uint32(0) - for ; len(buf) >= 2; buf = buf[2:] { - sum += uint32(buf[0])<<8 | uint32(buf[1]) - } - if len(buf) > 0 { - sum += uint32(buf[0]) << 8 - } - for sum > 0xffff { - sum = (sum >> 16) + (sum & 0xffff) - } - csum := ^uint16(sum) - if csum == 0 { - csum = 0xffff - } - return csum -} - -func main() { - ourPrivate, _ := base64.StdEncoding.DecodeString("WAmgVYXkbT2bCtdcDwolI88/iVi/aV3/PHcUBTQSYmo=") - ourPublic, _ := base64.StdEncoding.DecodeString("K5sF9yESrSBsOXPd6TcpKNgqoy1Ik3ZFKl4FolzrRyI=") - theirPublic, _ := base64.StdEncoding.DecodeString("qRCwZSKInrMAq5sepfCdaCsRJaoLe5jhtzfiw7CjbwM=") - preshared, _ := base64.StdEncoding.DecodeString("FpCyhws9cxwWoV4xELtfJvjJN+zQVRPISllRWgeopVE=") - cs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s) - hs := noise.NewHandshakeState(noise.Config{ - CipherSuite: cs, - Random: rand.Reader, - Pattern: noise.HandshakeIK, - Initiator: true, - Prologue: []byte("WireGuard v1 zx2c4 Jason@zx2c4.com"), - PresharedKey: preshared, - PresharedKeyPlacement: 2, - StaticKeypair: noise.DHKey{Private: ourPrivate, Public: ourPublic}, - PeerStatic: theirPublic, - }) - conn, err := net.Dial("udp", "demo.wireguard.io:12913") - if err != nil { - log.Fatalf("error dialing udp socket: %s", err) - } - defer conn.Close() - - // write handshake initiation packet - now := time.Now() - tai64n := make([]byte, 12) - binary.BigEndian.PutUint64(tai64n[:], 4611686018427387914+uint64(now.Unix())) - binary.BigEndian.PutUint32(tai64n[8:], uint32(now.UnixNano())) - initiationPacket := make([]byte, 8) - initiationPacket[0] = 1 // Type: Initiation - initiationPacket[1] = 0 // Reserved - initiationPacket[2] = 0 // Reserved - initiationPacket[3] = 0 // Reserved - binary.LittleEndian.PutUint32(initiationPacket[4:], 28) // Sender index: 28 (arbitrary) - initiationPacket, _, _ = hs.WriteMessage(initiationPacket, tai64n) - hasher, _ := blake2s.New(&blake2s.Config{Size: 32}) - hasher.Write([]byte("mac1----")) - hasher.Write(theirPublic) - hasher, _ = blake2s.New(&blake2s.Config{Size: 16, Key: hasher.Sum(nil)}) - hasher.Write(initiationPacket) - initiationPacket = append(initiationPacket, hasher.Sum(nil)[:16]...) - initiationPacket = append(initiationPacket, make([]byte, 16)...) - if _, err := conn.Write(initiationPacket); err != nil { - log.Fatalf("error writing initiation packet: %s", err) - } - - // read handshake response packet - responsePacket := make([]byte, 92) - n, err := conn.Read(responsePacket) - if err != nil { - log.Fatalf("error reading response packet: %s", err) - } - if n != len(responsePacket) { - log.Fatalf("response packet too short: want %d, got %d", len(responsePacket), n) - } - if responsePacket[0] != 2 { // Type: Response - log.Fatalf("response packet type wrong: want %d, got %d", 2, responsePacket[0]) - } - if responsePacket[1] != 0 || responsePacket[2] != 0 || responsePacket[3] != 0 { - log.Fatalf("response packet has non-zero reserved fields") - } - theirIndex := binary.LittleEndian.Uint32(responsePacket[4:]) - ourIndex := binary.LittleEndian.Uint32(responsePacket[8:]) - if ourIndex != 28 { - log.Fatalf("response packet index wrong: want %d, got %d", 28, ourIndex) - } - payload, sendCipher, receiveCipher, err := hs.ReadMessage(nil, responsePacket[12:60]) - if err != nil { - log.Fatalf("error reading handshake message: %s", err) - } - if len(payload) > 0 { - log.Fatalf("unexpected payload: %x", payload) - } - - // write ICMP Echo packet - pingMessage, _ := (&icmp.Message{ - Type: ipv4.ICMPTypeEcho, - Body: &icmp.Echo{ - ID: 921, - Seq: 438, - Data: []byte("WireGuard"), - }, - }).Marshal(nil) - pingHeader, err := (&ipv4.Header{ - Version: ipv4.Version, - Len: ipv4.HeaderLen, - TotalLen: ipv4.HeaderLen + len(pingMessage), - Protocol: 1, // ICMP - TTL: 20, - Src: net.IPv4(10, 189, 129, 2), - Dst: net.IPv4(10, 189, 129, 1), - }).Marshal() - binary.BigEndian.PutUint16(pingHeader[2:], uint16(ipv4.HeaderLen+len(pingMessage))) // fix the length endianness on BSDs - pingData := append(pingHeader, pingMessage...) - binary.BigEndian.PutUint16(pingData[10:], ipChecksum(pingData)) - pingPacket := make([]byte, 16) - pingPacket[0] = 4 // Type: Data - pingPacket[1] = 0 // Reserved - pingPacket[2] = 0 // Reserved - pingPacket[3] = 0 // Reserved - binary.LittleEndian.PutUint32(pingPacket[4:], theirIndex) - binary.LittleEndian.PutUint64(pingPacket[8:], 0) // Nonce - pingPacket = sendCipher.Encrypt(pingPacket, nil, pingData) - if _, err := conn.Write(pingPacket); err != nil { - log.Fatalf("error writing ping message: %s", err) - } - - // read ICMP Echo Reply packet - replyPacket := make([]byte, 128) - n, err = conn.Read(replyPacket) - if err != nil { - log.Fatalf("error reading ping reply message: %s", err) - } - replyPacket = replyPacket[:n] - if replyPacket[0] != 4 { // Type: Data - log.Fatalf("unexpected reply packet type: %d", replyPacket[0]) - } - if replyPacket[1] != 0 || replyPacket[2] != 0 || replyPacket[3] != 0 { - log.Fatalf("reply packet has non-zero reserved fields") - } - replyPacket, err = receiveCipher.Decrypt(nil, nil, replyPacket[16:]) - if err != nil { - log.Fatalf("error decrypting reply packet: %s", err) - } - replyHeaderLen := int(replyPacket[0]&0x0f) << 2 - replyLen := binary.BigEndian.Uint16(replyPacket[2:]) - replyMessage, err := icmp.ParseMessage(1, replyPacket[replyHeaderLen:replyLen]) - if err != nil { - log.Fatalf("error parsing echo: %s", err) - } - echo, ok := replyMessage.Body.(*icmp.Echo) - if !ok { - log.Fatalf("unexpected reply body type %T", replyMessage.Body) - } - - if echo.ID != 921 || echo.Seq != 438 || string(echo.Data) != "WireGuard" { - log.Fatalf("incorrect echo response: %#v", echo) - } -} diff --git a/src/routing.go b/src/routing.go new file mode 100644 index 0000000..99b180c --- /dev/null +++ b/src/routing.go @@ -0,0 +1,22 @@ +package main + +import ( + "sync" +) + +/* Thread-safe high level functions for cryptkey routing. + * + */ + +type RoutingTable struct { + IPv4 *Trie + IPv6 *Trie + mutex sync.RWMutex +} + +func (table *RoutingTable) RemovePeer(peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + table.IPv4 = table.IPv4.RemovePeer(peer) + table.IPv6 = table.IPv6.RemovePeer(peer) +} diff --git a/src/trie.go b/src/trie.go index 7fd7c5f..31a4d92 100644 --- a/src/trie.go +++ b/src/trie.go @@ -1,9 +1,11 @@ package main -import "fmt" - -/* Syncronization must be done seperatly +/* Binary trie * + * Syncronization done seperatly + * See: routing.go + * + * Todo: Better commenting */ type Trie struct { @@ -13,7 +15,6 @@ type Trie struct { peer *Peer // Index of "branching" bit - // bit_at_shift bit_at_byte uint bit_at_shift uint } @@ -92,7 +93,14 @@ func (node *Trie) RemovePeer(p *Peer) *Trie { return node.child[0] } +func (node *Trie) choose(key []byte) byte { + return (key[node.bit_at_byte] >> node.bit_at_shift) & 1 +} + func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie { + + // At leaf + if node == nil { return &Trie{ bits: key, @@ -107,22 +115,17 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie { common := commonBits(node.bits, key) if node.cidr <= cidr && common >= node.cidr { - // Check if match the t.bits[:t.cidr] exactly if node.cidr == cidr { node.peer = peer return node } - - // Go to child - bit := (key[node.bit_at_byte] >> node.bit_at_shift) & 1 + bit := node.choose(key) node.child[bit] = node.child[bit].Insert(key, cidr, peer) return node } // Split node - fmt.Println("new", common) - newNode := &Trie{ bits: key, peer: peer, @@ -132,23 +135,53 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie { } cidr = min(cidr, common) - node.cidr = cidr - node.bit_at_byte = cidr / 8 - node.bit_at_shift = 7 - (cidr % 8) - // bval := node.bits[node.bit_at_byte] >> node.bit_at_shift // todo : remember index - // Work in progress - node.child[0] = newNode - node.child[1] = newNode + // Check for shorter prefix - return node -} - -func (t *Trie) Lookup(key []byte) *Peer { - if t == nil { - return nil + if newNode.cidr == cidr { + bit := newNode.choose(node.bits) + newNode.child[bit] = node + return newNode } - return nil + // Create new parent for node & newNode + parent := &Trie{ + bits: key, + peer: nil, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + bit := parent.choose(key) + parent.child[bit] = newNode + parent.child[bit^1] = node + + return parent +} + +func (node *Trie) Lookup(key []byte) *Peer { + var found *Peer + size := uint(len(key)) + for node != nil && commonBits(node.bits, key) >= node.cidr { + if node.peer != nil { + found = node.peer + } + if node.bit_at_byte == size { + break + } + bit := node.choose(key) + node = node.child[bit] + } + return found +} + +func (node *Trie) Count() uint { + if node == nil { + return 0 + } + l := node.child[0].Count() + r := node.child[1].Count() + return l + r } diff --git a/src/trie_test.go b/src/trie_test.go index ec4cde3..35af0aa 100644 --- a/src/trie_test.go +++ b/src/trie_test.go @@ -4,6 +4,9 @@ import ( "testing" ) +/* Todo: More comprehensive + */ + type testPairCommonBits struct { s1 []byte s2 []byte @@ -16,6 +19,11 @@ type testPairTrieInsert struct { peer *Peer } +type testPairTrieLookup struct { + key []byte + peer *Peer +} + func printTrie(t *testing.T, p *Trie) { if p == nil { return @@ -41,26 +49,176 @@ func TestCommonBits(t *testing.T) { t.Error( "For slice", p.s1, p.s2, "expected match", p.match, - "got", v, + ",but got", v, ) } } } -func TestTrieInsertV4(t *testing.T) { +/* Test ported from kernel implementation: + * selftest/routingtable.h + */ +func TestTrieIPv4(t *testing.T) { + a := &Peer{} + b := &Peer{} + c := &Peer{} + d := &Peer{} + e := &Peer{} + g := &Peer{} + h := &Peer{} + var trie *Trie - peer1 := Peer{} - peer2 := Peer{} - - tests := []testPairTrieInsert{ - {key: []byte{192, 168, 1, 1}, cidr: 24, peer: &peer1}, - {key: []byte{192, 169, 1, 1}, cidr: 24, peer: &peer2}, + insert := func(peer *Peer, a, b, c, d byte, cidr uint) { + trie = trie.Insert([]byte{a, b, c, d}, cidr, peer) } - for _, p := range tests { - trie = trie.Insert(p.key, p.cidr, p.peer) - printTrie(t, trie) + assertEQ := func(peer *Peer, a, b, c, d byte) { + p := trie.Lookup([]byte{a, b, c, d}) + if p != peer { + t.Error("Assert EQ failed") + } } + assertNEQ := func(peer *Peer, a, b, c, d byte) { + p := trie.Lookup([]byte{a, b, c, d}) + if p == peer { + t.Error("Assert NEQ failed") + } + } + + insert(a, 192, 168, 4, 0, 24) + insert(b, 192, 168, 4, 4, 32) + insert(c, 192, 168, 0, 0, 16) + insert(d, 192, 95, 5, 64, 27) + insert(c, 192, 95, 5, 65, 27) /* replaces previous entry, and maskself is required */ + insert(e, 0, 0, 0, 0, 0) + insert(g, 64, 15, 112, 0, 20) + insert(h, 64, 15, 123, 211, 25) /* maskself is required */ + insert(a, 10, 0, 0, 0, 25) + insert(b, 10, 0, 0, 128, 25) + insert(a, 10, 1, 0, 0, 30) + insert(b, 10, 1, 0, 4, 30) + insert(c, 10, 1, 0, 8, 29) + insert(d, 10, 1, 0, 16, 29) + + assertEQ(a, 192, 168, 4, 20) + assertEQ(a, 192, 168, 4, 0) + assertEQ(b, 192, 168, 4, 4) + assertEQ(c, 192, 168, 200, 182) + assertEQ(c, 192, 95, 5, 68) + assertEQ(e, 192, 95, 5, 96) + assertEQ(g, 64, 15, 116, 26) + assertEQ(g, 64, 15, 127, 3) + + insert(a, 1, 0, 0, 0, 32) + insert(a, 64, 0, 0, 0, 32) + insert(a, 128, 0, 0, 0, 32) + insert(a, 192, 0, 0, 0, 32) + insert(a, 255, 0, 0, 0, 32) + + assertEQ(a, 1, 0, 0, 0) + assertEQ(a, 64, 0, 0, 0) + assertEQ(a, 128, 0, 0, 0) + assertEQ(a, 192, 0, 0, 0) + assertEQ(a, 255, 0, 0, 0) + + trie = trie.RemovePeer(a) + + assertNEQ(a, 1, 0, 0, 0) + assertNEQ(a, 64, 0, 0, 0) + assertNEQ(a, 128, 0, 0, 0) + assertNEQ(a, 192, 0, 0, 0) + assertNEQ(a, 255, 0, 0, 0) + + trie = nil + + insert(a, 192, 168, 0, 0, 16) + insert(a, 192, 168, 0, 0, 24) + + trie = trie.RemovePeer(a) + + assertNEQ(a, 192, 168, 0, 1) +} + +/* Test ported from kernel implementation: + * selftest/routingtable.h + */ +func TestTrieIPv6(t *testing.T) { + a := &Peer{} + b := &Peer{} + c := &Peer{} + d := &Peer{} + e := &Peer{} + f := &Peer{} + g := &Peer{} + h := &Peer{} + + var trie *Trie + + expand := func(a uint32) []byte { + var out [4]byte + out[0] = byte(a >> 24 & 0xff) + out[1] = byte(a >> 16 & 0xff) + out[2] = byte(a >> 8 & 0xff) + out[3] = byte(a & 0xff) + return out[:] + } + + insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + trie = trie.Insert(addr, cidr, peer) + } + + assertEQ := func(peer *Peer, a, b, c, d uint32) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + p := trie.Lookup(addr) + if p != peer { + t.Error("Assert EQ failed") + } + } + + /* + assertNEQ := func(peer *Peer, a, b, c, d uint32) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + p := trie.Lookup(addr) + if p == peer { + t.Error("Assert NEQ failed") + } + } + */ + + insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) + insert(c, 0x26075300, 0x60006b00, 0, 0, 64) + insert(e, 0, 0, 0, 0, 0) + insert(f, 0, 0, 0, 0, 0) + insert(g, 0x24046800, 0, 0, 0, 32) + insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) + insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) + insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + + assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) + assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) + assertEQ(f, 0x26075300, 0x60006b01, 0, 0) + assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) + assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) + assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) + assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) + assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) + assertEQ(h, 0x24046800, 0x40040800, 0, 0) + assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) + assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) }