From 1eebdf88a320824b8f155caa1d5c725c38d51de8 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Tue, 30 May 2017 22:36:49 +0200 Subject: [PATCH] Beginning work on UAPI and routing table --- src/config.go | 190 +++++++++++++++++++++++++++++++++++++++++++++++ src/device.go | 14 ++++ src/main.go | 28 +++++++ src/misc.go | 8 ++ src/noise.go | 51 +++++++++++++ src/peer.go | 18 +++++ src/trie.go | 154 ++++++++++++++++++++++++++++++++++++++ src/trie_test.go | 66 ++++++++++++++++ 8 files changed, 529 insertions(+) create mode 100644 src/config.go create mode 100644 src/device.go create mode 100644 src/main.go create mode 100644 src/misc.go create mode 100644 src/noise.go create mode 100644 src/peer.go create mode 100644 src/trie.go create mode 100644 src/trie_test.go diff --git a/src/config.go b/src/config.go new file mode 100644 index 0000000..f6f1378 --- /dev/null +++ b/src/config.go @@ -0,0 +1,190 @@ +package main + +import ( + "bufio" + "errors" + "fmt" + "io" + "log" +) + +/* todo : use real error code + * Many of which will be the same + */ +const ( + ipcErrorNoPeer = 0 + ipcErrorNoKeyValue = 1 + ipcErrorInvalidKey = 2 + ipcErrorInvalidPrivateKey = 3 + ipcErrorInvalidPublicKey = 4 + ipcErrorInvalidPort = 5 +) + +type IPCError struct { + Code int +} + +func (s *IPCError) Error() string { + return fmt.Sprintf("IPC error: %d", s.Code) +} + +func (s *IPCError) ErrorCode() int { + return s.Code +} + +// Writes the configuration to the socket +func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) { + +} + +// Creates new config, from old and socket message +func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError { + + scanner := bufio.NewScanner(socket) + + dev.mutex.Lock() + defer dev.mutex.Unlock() + + for scanner.Scan() { + var key string + var value string + var peer *Peer + + // Parse line + + line := scanner.Text() + if line == "\n" { + break + } + fmt.Println(line) + n, err := fmt.Sscanf(line, "%s=%s\n", &key, &value) + if n != 2 || err != nil { + fmt.Println(err, n) + return &IPCError{Code: ipcErrorNoKeyValue} + } + + switch key { + + /* Interface configuration */ + + case "private_key": + if value == "" { + dev.privateKey = NoisePrivateKey{} + } else { + err := dev.privateKey.FromHex(value) + if err != nil { + return &IPCError{Code: ipcErrorInvalidPrivateKey} + } + } + + case "listen_port": + _, err := fmt.Sscanf(value, "%ud", &dev.listenPort) + if err != nil { + return &IPCError{Code: ipcErrorInvalidPort} + } + + case "fwmark": + panic(nil) // not handled yet + + case "public_key": + var pubKey NoisePublicKey + err := pubKey.FromHex(value) + if err != nil { + return &IPCError{Code: ipcErrorInvalidPublicKey} + } + found, ok := dev.peers[pubKey] + if ok { + peer = found + } else { + newPeer := &Peer{ + publicKey: pubKey, + } + peer = newPeer + dev.peers[pubKey] = newPeer + } + + case "replace_peers": + + default: + /* Peer configuration */ + + if peer == nil { + return &IPCError{Code: ipcErrorNoPeer} + } + + switch key { + + case "remove": + peer.mutex.Lock() + + peer = nil + + case "preshared_key": + func() { + peer.mutex.Lock() + defer peer.mutex.Unlock() + }() + + case "endpoint": + func() { + peer.mutex.Lock() + defer peer.mutex.Unlock() + }() + + case "persistent_keepalive_interval": + func() { + peer.mutex.Lock() + defer peer.mutex.Unlock() + }() + + case "replace_allowed_ips": + // remove peer from trie + + case "allowed_ip": + + /* Invalid key */ + + default: + return &IPCError{Code: ipcErrorInvalidKey} + } + } + } + + return nil +} + +func ipcListen(dev *Device, socket io.ReadWriter) error { + + buffered := func(s io.ReadWriter) *bufio.ReadWriter { + reader := bufio.NewReader(s) + writer := bufio.NewWriter(s) + return bufio.NewReadWriter(reader, writer) + }(socket) + + for { + op, err := buffered.ReadString('\n') + if err != nil { + return err + } + log.Println(op) + + switch op { + + case "set=1\n": + err := ipcSetOperation(dev, buffered) + if err != nil { + fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode()) + return err + } else { + fmt.Fprintf(buffered, "errno=0\n") + } + buffered.Flush() + + case "get=1\n": + + default: + return errors.New("handle this please") + } + } + +} diff --git a/src/device.go b/src/device.go new file mode 100644 index 0000000..cd0835c --- /dev/null +++ b/src/device.go @@ -0,0 +1,14 @@ +package main + +import ( + "sync" +) + +type Device struct { + mutex sync.RWMutex + peers map[NoisePublicKey]*Peer + privateKey NoisePrivateKey + publicKey NoisePublicKey + fwMark uint32 + listenPort uint16 +} diff --git a/src/main.go b/src/main.go new file mode 100644 index 0000000..0f5016d --- /dev/null +++ b/src/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "fmt" + "log" + "net" +) + +func main() { + l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock") + if err != nil { + log.Fatal("listen error:", err) + } + + for { + fd, err := l.Accept() + if err != nil { + log.Fatal("accept error:", err) + } + + var dev Device + go func(conn net.Conn) { + err := ipcListen(&dev, conn) + fmt.Println(err) + }(fd) + } + +} diff --git a/src/misc.go b/src/misc.go new file mode 100644 index 0000000..e1244d6 --- /dev/null +++ b/src/misc.go @@ -0,0 +1,8 @@ +package main + +func min(a uint, b uint) uint { + if a > b { + return b + } + return a +} diff --git a/src/noise.go b/src/noise.go new file mode 100644 index 0000000..d13bdd6 --- /dev/null +++ b/src/noise.go @@ -0,0 +1,51 @@ +package main + +import ( + "encoding/hex" + "errors" +) + +const ( + NoisePublicKeySize = 32 + NoisePrivateKeySize = 32 + NoiseSymmetricKeySize = 32 +) + +type ( + NoisePublicKey [NoisePublicKeySize]byte + NoisePrivateKey [NoisePrivateKeySize]byte + NoiseSymmetricKey [NoiseSymmetricKeySize]byte + NoiseNonce uint64 // padded to 12-bytes +) + +func (key *NoisePrivateKey) FromHex(s string) error { + slice, err := hex.DecodeString(s) + if err != nil { + return err + } + if len(slice) != NoisePrivateKeySize { + return errors.New("Invalid length of hex string for curve25519 point") + } + copy(key[:], slice) + return nil +} + +func (key *NoisePrivateKey) ToHex() string { + return hex.EncodeToString(key[:]) +} + +func (key *NoisePublicKey) FromHex(s string) error { + slice, err := hex.DecodeString(s) + if err != nil { + return err + } + if len(slice) != NoisePublicKeySize { + return errors.New("Invalid length of hex string for curve25519 scalar") + } + copy(key[:], slice) + return nil +} + +func (key *NoisePublicKey) ToHex() string { + return hex.EncodeToString(key[:]) +} diff --git a/src/peer.go b/src/peer.go new file mode 100644 index 0000000..7c000da --- /dev/null +++ b/src/peer.go @@ -0,0 +1,18 @@ +package main + +import ( + "sync" +) + +type KeyPair struct { + recieveKey NoiseSymmetricKey + recieveNonce NoiseNonce + sendKey NoiseSymmetricKey + sendNonce NoiseNonce +} + +type Peer struct { + mutex sync.RWMutex + publicKey NoisePublicKey + presharedKey NoiseSymmetricKey +} diff --git a/src/trie.go b/src/trie.go new file mode 100644 index 0000000..7fd7c5f --- /dev/null +++ b/src/trie.go @@ -0,0 +1,154 @@ +package main + +import "fmt" + +/* Syncronization must be done seperatly + * + */ + +type Trie struct { + cidr uint + child [2]*Trie + bits []byte + peer *Peer + + // Index of "branching" bit + // bit_at_shift + bit_at_byte uint + bit_at_shift uint +} + +/* Finds length of matching prefix + * Maybe there is a faster way + * + * Assumption: len(s1) == len(s2) + */ +func commonBits(s1 []byte, s2 []byte) uint { + var i uint + size := uint(len(s1)) + for i = 0; i < size; i += 1 { + v := s1[i] ^ s2[i] + if v != 0 { + v >>= 1 + if v == 0 { + return i*8 + 7 + } + + v >>= 1 + if v == 0 { + return i*8 + 6 + } + + v >>= 1 + if v == 0 { + return i*8 + 5 + } + + v >>= 1 + if v == 0 { + return i*8 + 4 + } + + v >>= 1 + if v == 0 { + return i*8 + 3 + } + + v >>= 1 + if v == 0 { + return i*8 + 2 + } + + v >>= 1 + if v == 0 { + return i*8 + 1 + } + return i * 8 + } + } + return i * 8 +} + +func (node *Trie) RemovePeer(p *Peer) *Trie { + if node == nil { + return node + } + + // Walk recursivly + + node.child[0] = node.child[0].RemovePeer(p) + node.child[1] = node.child[1].RemovePeer(p) + + if node.peer != p { + return node + } + + // Remove peer & merge + + node.peer = nil + if node.child[0] == nil { + return node.child[1] + } + return node.child[0] +} + +func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie { + if node == nil { + return &Trie{ + bits: key, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + } + + // Traverse deeper + + common := commonBits(node.bits, key) + if node.cidr <= cidr && common >= node.cidr { + // Check if match the t.bits[:t.cidr] exactly + if node.cidr == cidr { + node.peer = peer + return node + } + + // Go to child + bit := (key[node.bit_at_byte] >> node.bit_at_shift) & 1 + node.child[bit] = node.child[bit].Insert(key, cidr, peer) + return node + } + + // Split node + + fmt.Println("new", common) + + newNode := &Trie{ + bits: key, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + cidr = min(cidr, common) + node.cidr = cidr + node.bit_at_byte = cidr / 8 + node.bit_at_shift = 7 - (cidr % 8) + + // bval := node.bits[node.bit_at_byte] >> node.bit_at_shift // todo : remember index + // Work in progress + node.child[0] = newNode + node.child[1] = newNode + + return node +} + +func (t *Trie) Lookup(key []byte) *Peer { + if t == nil { + return nil + } + + return nil + +} diff --git a/src/trie_test.go b/src/trie_test.go new file mode 100644 index 0000000..ec4cde3 --- /dev/null +++ b/src/trie_test.go @@ -0,0 +1,66 @@ +package main + +import ( + "testing" +) + +type testPairCommonBits struct { + s1 []byte + s2 []byte + match uint +} + +type testPairTrieInsert struct { + key []byte + cidr uint + peer *Peer +} + +func printTrie(t *testing.T, p *Trie) { + if p == nil { + return + } + t.Log(p) + printTrie(t, p.child[0]) + printTrie(t, p.child[1]) +} + +func TestCommonBits(t *testing.T) { + + tests := []testPairCommonBits{ + {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, + {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, + {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, + {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, + {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, + } + + for _, p := range tests { + v := commonBits(p.s1, p.s2) + if v != p.match { + t.Error( + "For slice", p.s1, p.s2, + "expected match", p.match, + "got", v, + ) + } + } +} + +func TestTrieInsertV4(t *testing.T) { + var trie *Trie + + peer1 := Peer{} + peer2 := Peer{} + + tests := []testPairTrieInsert{ + {key: []byte{192, 168, 1, 1}, cidr: 24, peer: &peer1}, + {key: []byte{192, 169, 1, 1}, cidr: 24, peer: &peer2}, + } + + for _, p := range tests { + trie = trie.Insert(p.key, p.cidr, p.peer) + printTrie(t, trie) + } + +}