Merge branch 'source-caching'
This commit is contained in:
commit
b5ae42349c
115
src/conn.go
115
src/conn.go
@ -2,10 +2,35 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
|
||||||
|
*/
|
||||||
|
type Bind interface {
|
||||||
|
SetMark(value uint32) error
|
||||||
|
ReceiveIPv6(buff []byte) (int, Endpoint, error)
|
||||||
|
ReceiveIPv4(buff []byte) (int, Endpoint, error)
|
||||||
|
Send(buff []byte, end Endpoint) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
/* An Endpoint maintains the source/destination caching for a peer
|
||||||
|
*
|
||||||
|
* dst : the remote address of a peer ("endpoint" in uapi terminology)
|
||||||
|
* src : the local address from which datagrams originate going to the peer
|
||||||
|
*/
|
||||||
|
type Endpoint interface {
|
||||||
|
ClearSrc() // clears the source address
|
||||||
|
SrcToString() string // returns the local source address (ip:port)
|
||||||
|
DstToString() string // returns the destination address (ip:port)
|
||||||
|
DstToBytes() []byte // used for mac2 cookie calculations
|
||||||
|
DstIP() net.IP
|
||||||
|
SrcIP() net.IP
|
||||||
|
}
|
||||||
|
|
||||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||||
|
|
||||||
// ensure that the host is an IP address
|
// ensure that the host is an IP address
|
||||||
@ -27,63 +52,83 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
|
|||||||
return addr, err
|
return addr, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateUDPConn(device *Device) error {
|
/* Must hold device and net lock
|
||||||
|
*/
|
||||||
|
func unsafeCloseUDPListener(device *Device) error {
|
||||||
|
var err error
|
||||||
|
netc := &device.net
|
||||||
|
if netc.bind != nil {
|
||||||
|
err = netc.bind.Close()
|
||||||
|
netc.bind = nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// must inform all listeners
|
||||||
|
func UpdateUDPListener(device *Device) error {
|
||||||
|
device.mutex.Lock()
|
||||||
|
defer device.mutex.Unlock()
|
||||||
|
|
||||||
netc := &device.net
|
netc := &device.net
|
||||||
netc.mutex.Lock()
|
netc.mutex.Lock()
|
||||||
defer netc.mutex.Unlock()
|
defer netc.mutex.Unlock()
|
||||||
|
|
||||||
// close existing connection
|
// close existing sockets
|
||||||
|
|
||||||
if netc.conn != nil {
|
if err := unsafeCloseUDPListener(device); err != nil {
|
||||||
netc.conn.Close()
|
return err
|
||||||
netc.conn = nil
|
|
||||||
|
|
||||||
// We need for that fd to be closed in all other go routines, which
|
|
||||||
// means we have to wait. TODO: find less horrible way of doing this.
|
|
||||||
time.Sleep(time.Second / 2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// open new connection
|
// assumption: netc.update WaitGroup should be exactly 1
|
||||||
|
|
||||||
|
// open new sockets
|
||||||
|
|
||||||
if device.tun.isUp.Get() {
|
if device.tun.isUp.Get() {
|
||||||
|
|
||||||
// listen on new address
|
device.log.Debug.Println("UDP bind updating")
|
||||||
|
|
||||||
conn, err := net.ListenUDP("udp", netc.addr)
|
// bind to new port
|
||||||
|
|
||||||
|
var err error
|
||||||
|
netc.bind, netc.port, err = CreateBind(netc.port)
|
||||||
|
if err != nil {
|
||||||
|
netc.bind = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set mark
|
||||||
|
|
||||||
|
err = netc.bind.SetMark(netc.fwmark)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// set fwmark
|
// clear cached source addresses
|
||||||
|
|
||||||
err = setMark(netc.conn, netc.fwmark)
|
for _, peer := range device.peers {
|
||||||
if err != nil {
|
peer.mutex.Lock()
|
||||||
return err
|
if peer.endpoint != nil {
|
||||||
|
peer.endpoint.ClearSrc()
|
||||||
|
}
|
||||||
|
peer.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrieve port (may have been chosen by kernel)
|
// decrease waitgroup to 0
|
||||||
|
|
||||||
addr := conn.LocalAddr()
|
go device.RoutineReceiveIncomming(ipv4.Version, netc.bind)
|
||||||
netc.conn = conn
|
go device.RoutineReceiveIncomming(ipv6.Version, netc.bind)
|
||||||
netc.addr, _ = net.ResolveUDPAddr(
|
|
||||||
addr.Network(),
|
|
||||||
addr.String(),
|
|
||||||
)
|
|
||||||
|
|
||||||
// notify goroutines
|
device.log.Debug.Println("UDP bind has been updated")
|
||||||
|
|
||||||
signalSend(device.signal.newUDPConn)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func closeUDPConn(device *Device) {
|
func CloseUDPListener(device *Device) error {
|
||||||
netc := &device.net
|
device.mutex.Lock()
|
||||||
netc.mutex.Lock()
|
device.net.mutex.Lock()
|
||||||
if netc.conn != nil {
|
err := unsafeCloseUDPListener(device)
|
||||||
netc.conn.Close()
|
device.net.mutex.Unlock()
|
||||||
}
|
device.mutex.Unlock()
|
||||||
netc.mutex.Unlock()
|
return err
|
||||||
signalSend(device.signal.newUDPConn)
|
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,126 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setMark(conn *net.UDPConn, value uint32) error {
|
/* This code is meant to be a temporary solution
|
||||||
|
* on platforms for which the sticky socket / source caching behavior
|
||||||
|
* has not yet been implemented.
|
||||||
|
*
|
||||||
|
* See conn_linux.go for an implementation on the linux platform.
|
||||||
|
*/
|
||||||
|
|
||||||
|
type NativeBind struct {
|
||||||
|
ipv4 *net.UDPConn
|
||||||
|
ipv6 *net.UDPConn
|
||||||
|
}
|
||||||
|
|
||||||
|
type NativeEndpoint net.UDPAddr
|
||||||
|
|
||||||
|
var _ Bind = (*NativeBind)(nil)
|
||||||
|
var _ Endpoint = (*NativeEndpoint)(nil)
|
||||||
|
|
||||||
|
func CreateEndpoint(s string) (Endpoint, error) {
|
||||||
|
addr, err := parseEndpoint(s)
|
||||||
|
return (*NativeEndpoint)(addr), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_ *NativeEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
|
func (e *NativeEndpoint) DstIP() net.IP {
|
||||||
|
return (*net.UDPAddr)(e).IP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NativeEndpoint) SrcIP() net.IP {
|
||||||
|
return nil // not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NativeEndpoint) DstToBytes() []byte {
|
||||||
|
addr := (*net.UDPAddr)(e)
|
||||||
|
out := addr.IP
|
||||||
|
out = append(out, byte(addr.Port&0xff))
|
||||||
|
out = append(out, byte((addr.Port>>8)&0xff))
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NativeEndpoint) DstToString() string {
|
||||||
|
return (*net.UDPAddr)(e).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NativeEndpoint) SrcToString() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||||
|
|
||||||
|
// listen
|
||||||
|
|
||||||
|
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// retrieve port
|
||||||
|
|
||||||
|
laddr := conn.LocalAddr()
|
||||||
|
uaddr, err := net.ResolveUDPAddr(
|
||||||
|
laddr.Network(),
|
||||||
|
laddr.String(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return conn, uaddr.Port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateBind(uport uint16) (Bind, uint16, error) {
|
||||||
|
var err error
|
||||||
|
var bind NativeBind
|
||||||
|
|
||||||
|
port := int(uport)
|
||||||
|
|
||||||
|
bind.ipv4, port, err = listenNet("udp4", port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.ipv6, port, err = listenNet("udp6", port)
|
||||||
|
if err != nil {
|
||||||
|
bind.ipv4.Close()
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &bind, uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) Close() error {
|
||||||
|
err1 := bind.ipv4.Close()
|
||||||
|
err2 := bind.ipv6.Close()
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||||
|
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
|
||||||
|
return n, (*NativeEndpoint)(endpoint), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||||
|
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
|
||||||
|
return n, (*NativeEndpoint)(endpoint), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
|
||||||
|
var err error
|
||||||
|
nend := endpoint.(*NativeEndpoint)
|
||||||
|
if nend.IP.To16() != nil {
|
||||||
|
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||||
|
} else {
|
||||||
|
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) SetMark(_ uint32) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"net"
|
"net"
|
||||||
@ -15,20 +16,230 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
/* Supports source address caching
|
/* Supports source address caching
|
||||||
*
|
|
||||||
* It is important that the endpoint is only updated after the packet content has been authenticated.
|
|
||||||
*
|
*
|
||||||
* Currently there is no way to achieve this within the net package:
|
* Currently there is no way to achieve this within the net package:
|
||||||
* See e.g. https://github.com/golang/go/issues/17930
|
* See e.g. https://github.com/golang/go/issues/17930
|
||||||
|
* So this code is remains platform dependent.
|
||||||
*/
|
*/
|
||||||
type Endpoint struct {
|
type NativeEndpoint struct {
|
||||||
// source (selected based on dst type)
|
src unix.RawSockaddrInet6
|
||||||
// (could use RawSockaddrAny and unsafe)
|
dst unix.RawSockaddrInet6
|
||||||
srcIPv6 unix.RawSockaddrInet6
|
}
|
||||||
srcIPv4 unix.RawSockaddrInet4
|
|
||||||
srcIf4 int32
|
|
||||||
|
|
||||||
dst unix.RawSockaddrAny
|
type NativeBind struct {
|
||||||
|
sock4 int
|
||||||
|
sock6 int
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Endpoint = (*NativeEndpoint)(nil)
|
||||||
|
var _ Bind = NativeBind{}
|
||||||
|
|
||||||
|
type IPv4Source struct {
|
||||||
|
src unix.RawSockaddrInet4
|
||||||
|
Ifindex int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func htons(val uint16) uint16 {
|
||||||
|
var out [unsafe.Sizeof(val)]byte
|
||||||
|
binary.BigEndian.PutUint16(out[:], val)
|
||||||
|
return *((*uint16)(unsafe.Pointer(&out[0])))
|
||||||
|
}
|
||||||
|
|
||||||
|
func ntohs(val uint16) uint16 {
|
||||||
|
tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
|
||||||
|
return binary.BigEndian.Uint16((*tmp)[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateEndpoint(s string) (Endpoint, error) {
|
||||||
|
var end NativeEndpoint
|
||||||
|
addr, err := parseEndpoint(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv4 := addr.IP.To4()
|
||||||
|
if ipv4 != nil {
|
||||||
|
dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
|
||||||
|
dst.Family = unix.AF_INET
|
||||||
|
dst.Port = htons(uint16(addr.Port))
|
||||||
|
dst.Zero = [8]byte{}
|
||||||
|
copy(dst.Addr[:], ipv4)
|
||||||
|
end.ClearSrc()
|
||||||
|
return &end, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv6 := addr.IP.To16()
|
||||||
|
if ipv6 != nil {
|
||||||
|
zone, err := zoneToUint32(addr.Zone)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dst := &end.dst
|
||||||
|
dst.Family = unix.AF_INET6
|
||||||
|
dst.Port = htons(uint16(addr.Port))
|
||||||
|
dst.Flowinfo = 0
|
||||||
|
dst.Scope_id = zone
|
||||||
|
copy(dst.Addr[:], ipv6[:])
|
||||||
|
end.ClearSrc()
|
||||||
|
return &end, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("Failed to recognize IP address format")
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateBind(port uint16) (Bind, uint16, error) {
|
||||||
|
var err error
|
||||||
|
var bind NativeBind
|
||||||
|
|
||||||
|
bind.sock6, port, err = create6(port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, port, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.sock4, port, err = create4(port)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(bind.sock6)
|
||||||
|
}
|
||||||
|
return bind, port, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind NativeBind) SetMark(value uint32) error {
|
||||||
|
err := unix.SetsockoptInt(
|
||||||
|
bind.sock6,
|
||||||
|
unix.SOL_SOCKET,
|
||||||
|
unix.SO_MARK,
|
||||||
|
int(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return unix.SetsockoptInt(
|
||||||
|
bind.sock4,
|
||||||
|
unix.SOL_SOCKET,
|
||||||
|
unix.SO_MARK,
|
||||||
|
int(value),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeUnblock(fd int) error {
|
||||||
|
// shutdown to unblock readers
|
||||||
|
unix.Shutdown(fd, unix.SHUT_RD)
|
||||||
|
return unix.Close(fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind NativeBind) Close() error {
|
||||||
|
err1 := closeUnblock(bind.sock6)
|
||||||
|
err2 := closeUnblock(bind.sock4)
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||||
|
var end NativeEndpoint
|
||||||
|
n, err := receive6(
|
||||||
|
bind.sock6,
|
||||||
|
buff,
|
||||||
|
&end,
|
||||||
|
)
|
||||||
|
return n, &end, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||||
|
var end NativeEndpoint
|
||||||
|
n, err := receive4(
|
||||||
|
bind.sock4,
|
||||||
|
buff,
|
||||||
|
&end,
|
||||||
|
)
|
||||||
|
return n, &end, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind NativeBind) Send(buff []byte, end Endpoint) error {
|
||||||
|
nend := end.(*NativeEndpoint)
|
||||||
|
switch nend.dst.Family {
|
||||||
|
case unix.AF_INET6:
|
||||||
|
return send6(bind.sock6, nend, buff)
|
||||||
|
case unix.AF_INET:
|
||||||
|
return send4(bind.sock4, nend, buff)
|
||||||
|
default:
|
||||||
|
return errors.New("Unknown address family of destination")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sockaddrToString(addr unix.RawSockaddrInet6) string {
|
||||||
|
var udpAddr net.UDPAddr
|
||||||
|
|
||||||
|
switch addr.Family {
|
||||||
|
case unix.AF_INET6:
|
||||||
|
udpAddr.Port = int(ntohs(addr.Port))
|
||||||
|
udpAddr.IP = addr.Addr[:]
|
||||||
|
return udpAddr.String()
|
||||||
|
|
||||||
|
case unix.AF_INET:
|
||||||
|
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
|
||||||
|
udpAddr.Port = int(ntohs(ptr.Port))
|
||||||
|
udpAddr.IP = net.IPv4(
|
||||||
|
ptr.Addr[0],
|
||||||
|
ptr.Addr[1],
|
||||||
|
ptr.Addr[2],
|
||||||
|
ptr.Addr[3],
|
||||||
|
)
|
||||||
|
return udpAddr.String()
|
||||||
|
|
||||||
|
default:
|
||||||
|
return "<unknown address family>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
|
||||||
|
switch addr.Family {
|
||||||
|
case unix.AF_INET6:
|
||||||
|
return addr.Addr[:]
|
||||||
|
case unix.AF_INET:
|
||||||
|
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
|
||||||
|
return net.IPv4(
|
||||||
|
ptr.Addr[0],
|
||||||
|
ptr.Addr[1],
|
||||||
|
ptr.Addr[2],
|
||||||
|
ptr.Addr[3],
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) SrcIP() net.IP {
|
||||||
|
return rawAddrToIP(end.src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) DstIP() net.IP {
|
||||||
|
return rawAddrToIP(end.dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) DstToBytes() []byte {
|
||||||
|
ptr := unsafe.Pointer(&end.src)
|
||||||
|
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
|
||||||
|
return arr[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) SrcToString() string {
|
||||||
|
return sockaddrToString(end.src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) DstToString() string {
|
||||||
|
return sockaddrToString(end.dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) ClearDst() {
|
||||||
|
end.dst = unix.RawSockaddrInet6{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) ClearSrc() {
|
||||||
|
end.src = unix.RawSockaddrInet6{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func zoneToUint32(zone string) (uint32, error) {
|
func zoneToUint32(zone string) (uint32, error) {
|
||||||
@ -42,51 +253,116 @@ func zoneToUint32(zone string) (uint32, error) {
|
|||||||
return uint32(n), err
|
return uint32(n), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (end *Endpoint) ClearSrc() {
|
func create4(port uint16) (int, uint16, error) {
|
||||||
end.srcIf4 = 0
|
|
||||||
end.srcIPv4 = unix.RawSockaddrInet4{}
|
// create socket
|
||||||
end.srcIPv6 = unix.RawSockaddrInet6{}
|
|
||||||
}
|
fd, err := unix.Socket(
|
||||||
|
unix.AF_INET,
|
||||||
|
unix.SOCK_DGRAM,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
func (end *Endpoint) Set(s string) error {
|
|
||||||
addr, err := parseEndpoint(s)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return -1, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ipv6 := addr.IP.To16()
|
addr := unix.SockaddrInet4{
|
||||||
if ipv6 != nil {
|
Port: int(port),
|
||||||
zone, err := zoneToUint32(addr.Zone)
|
}
|
||||||
if err != nil {
|
|
||||||
|
// set sockopts and bind
|
||||||
|
|
||||||
|
if err := func() error {
|
||||||
|
if err := unix.SetsockoptInt(
|
||||||
|
fd,
|
||||||
|
unix.SOL_SOCKET,
|
||||||
|
unix.SO_REUSEADDR,
|
||||||
|
1,
|
||||||
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
|
|
||||||
ptr.Family = unix.AF_INET6
|
if err := unix.SetsockoptInt(
|
||||||
ptr.Port = uint16(addr.Port)
|
fd,
|
||||||
ptr.Flowinfo = 0
|
unix.IPPROTO_IP,
|
||||||
ptr.Scope_id = zone
|
unix.IP_PKTINFO,
|
||||||
copy(ptr.Addr[:], ipv6[:])
|
1,
|
||||||
end.ClearSrc()
|
); err != nil {
|
||||||
return nil
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return unix.Bind(fd, &addr)
|
||||||
|
}(); err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
}
|
}
|
||||||
|
|
||||||
ipv4 := addr.IP.To4()
|
return fd, uint16(addr.Port), err
|
||||||
if ipv4 != nil {
|
|
||||||
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
|
|
||||||
ptr.Family = unix.AF_INET
|
|
||||||
ptr.Port = uint16(addr.Port)
|
|
||||||
ptr.Zero = [8]byte{}
|
|
||||||
copy(ptr.Addr[:], ipv4)
|
|
||||||
end.ClearSrc()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors.New("Failed to recognize IP address format")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func send6(sock uintptr, end *Endpoint, buff []byte) error {
|
func create6(port uint16) (int, uint16, error) {
|
||||||
var iovec unix.Iovec
|
|
||||||
|
|
||||||
|
// create socket
|
||||||
|
|
||||||
|
fd, err := unix.Socket(
|
||||||
|
unix.AF_INET6,
|
||||||
|
unix.SOCK_DGRAM,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return -1, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set sockopts and bind
|
||||||
|
|
||||||
|
addr := unix.SockaddrInet6{
|
||||||
|
Port: int(port),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := func() error {
|
||||||
|
|
||||||
|
if err := unix.SetsockoptInt(
|
||||||
|
fd,
|
||||||
|
unix.SOL_SOCKET,
|
||||||
|
unix.SO_REUSEADDR,
|
||||||
|
1,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unix.SetsockoptInt(
|
||||||
|
fd,
|
||||||
|
unix.IPPROTO_IPV6,
|
||||||
|
unix.IPV6_RECVPKTINFO,
|
||||||
|
1,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unix.SetsockoptInt(
|
||||||
|
fd,
|
||||||
|
unix.IPPROTO_IPV6,
|
||||||
|
unix.IPV6_V6ONLY,
|
||||||
|
1,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return unix.Bind(fd, &addr)
|
||||||
|
|
||||||
|
}(); err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fd, uint16(addr.Port), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
||||||
|
|
||||||
|
// construct message header
|
||||||
|
|
||||||
|
var iovec unix.Iovec
|
||||||
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
||||||
iovec.SetLen(len(buff))
|
iovec.SetLen(len(buff))
|
||||||
|
|
||||||
@ -97,11 +373,11 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
|
|||||||
unix.Cmsghdr{
|
unix.Cmsghdr{
|
||||||
Level: unix.IPPROTO_IPV6,
|
Level: unix.IPPROTO_IPV6,
|
||||||
Type: unix.IPV6_PKTINFO,
|
Type: unix.IPV6_PKTINFO,
|
||||||
Len: unix.SizeofInet6Pktinfo,
|
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
|
||||||
},
|
},
|
||||||
unix.Inet6Pktinfo{
|
unix.Inet6Pktinfo{
|
||||||
Addr: end.srcIPv6.Addr,
|
Addr: end.src.Addr,
|
||||||
Ifindex: end.srcIPv6.Scope_id,
|
Ifindex: end.src.Scope_id,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,22 +395,41 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
|
|||||||
|
|
||||||
_, _, errno := unix.Syscall(
|
_, _, errno := unix.Syscall(
|
||||||
unix.SYS_SENDMSG,
|
unix.SYS_SENDMSG,
|
||||||
sock,
|
uintptr(sock),
|
||||||
uintptr(unsafe.Pointer(&msghdr)),
|
uintptr(unsafe.Pointer(&msghdr)),
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if errno == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear src and retry
|
||||||
|
|
||||||
if errno == unix.EINVAL {
|
if errno == unix.EINVAL {
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
|
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
||||||
|
_, _, errno = unix.Syscall(
|
||||||
|
unix.SYS_SENDMSG,
|
||||||
|
uintptr(sock),
|
||||||
|
uintptr(unsafe.Pointer(&msghdr)),
|
||||||
|
0,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return errno
|
return errno
|
||||||
}
|
}
|
||||||
|
|
||||||
func send4(sock uintptr, end *Endpoint, buff []byte) error {
|
func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
||||||
var iovec unix.Iovec
|
|
||||||
|
|
||||||
|
// construct message header
|
||||||
|
|
||||||
|
var iovec unix.Iovec
|
||||||
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
||||||
iovec.SetLen(len(buff))
|
iovec.SetLen(len(buff))
|
||||||
|
|
||||||
|
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
|
||||||
|
|
||||||
cmsg := struct {
|
cmsg := struct {
|
||||||
cmsghdr unix.Cmsghdr
|
cmsghdr unix.Cmsghdr
|
||||||
pktinfo unix.Inet4Pktinfo
|
pktinfo unix.Inet4Pktinfo
|
||||||
@ -142,11 +437,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
|
|||||||
unix.Cmsghdr{
|
unix.Cmsghdr{
|
||||||
Level: unix.IPPROTO_IP,
|
Level: unix.IPPROTO_IP,
|
||||||
Type: unix.IP_PKTINFO,
|
Type: unix.IP_PKTINFO,
|
||||||
Len: unix.SizeofInet6Pktinfo,
|
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
||||||
},
|
},
|
||||||
unix.Inet4Pktinfo{
|
unix.Inet4Pktinfo{
|
||||||
Spec_dst: end.srcIPv4.Addr,
|
Spec_dst: src4.src.Addr,
|
||||||
Ifindex: end.srcIf4,
|
Ifindex: src4.Ifindex,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -156,51 +451,44 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
|
|||||||
Name: (*byte)(unsafe.Pointer(&end.dst)),
|
Name: (*byte)(unsafe.Pointer(&end.dst)),
|
||||||
Namelen: unix.SizeofSockaddrInet4,
|
Namelen: unix.SizeofSockaddrInet4,
|
||||||
Control: (*byte)(unsafe.Pointer(&cmsg)),
|
Control: (*byte)(unsafe.Pointer(&cmsg)),
|
||||||
|
Flags: 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
|
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
|
||||||
|
|
||||||
// sendmsg(sock, &msghdr, 0)
|
// sendmsg(sock, &msghdr, 0)
|
||||||
|
|
||||||
_, _, errno := unix.Syscall(
|
_, _, errno := unix.Syscall(
|
||||||
unix.SYS_SENDMSG,
|
unix.SYS_SENDMSG,
|
||||||
sock,
|
uintptr(sock),
|
||||||
uintptr(unsafe.Pointer(&msghdr)),
|
uintptr(unsafe.Pointer(&msghdr)),
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// clear source and try again
|
||||||
|
|
||||||
if errno == unix.EINVAL {
|
if errno == unix.EINVAL {
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
|
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
||||||
|
_, _, errno = unix.Syscall(
|
||||||
|
unix.SYS_SENDMSG,
|
||||||
|
uintptr(sock),
|
||||||
|
uintptr(unsafe.Pointer(&msghdr)),
|
||||||
|
0,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// errno = 0 is still an error instance
|
||||||
|
|
||||||
|
if errno == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return errno
|
return errno
|
||||||
}
|
}
|
||||||
|
|
||||||
func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
|
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
// extract underlying file descriptor
|
// contruct message header
|
||||||
|
|
||||||
file, err := c.File()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
sock := file.Fd()
|
|
||||||
|
|
||||||
// send depending on address family of dst
|
|
||||||
|
|
||||||
family := *((*uint16)(unsafe.Pointer(&end.dst)))
|
|
||||||
if family == unix.AF_INET {
|
|
||||||
return send4(sock, end, buff)
|
|
||||||
} else if family == unix.AF_INET6 {
|
|
||||||
return send6(sock, end, buff)
|
|
||||||
}
|
|
||||||
return errors.New("Unknown address family of source")
|
|
||||||
}
|
|
||||||
|
|
||||||
func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) {
|
|
||||||
|
|
||||||
file, err := c.File()
|
|
||||||
if err != nil {
|
|
||||||
return err, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var iovec unix.Iovec
|
var iovec unix.Iovec
|
||||||
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
||||||
@ -208,60 +496,87 @@ func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAdd
|
|||||||
|
|
||||||
var cmsg struct {
|
var cmsg struct {
|
||||||
cmsghdr unix.Cmsghdr
|
cmsghdr unix.Cmsghdr
|
||||||
pktinfo unix.Inet6Pktinfo // big enough
|
pktinfo unix.Inet4Pktinfo
|
||||||
|
}
|
||||||
|
|
||||||
|
var msghdr unix.Msghdr
|
||||||
|
msghdr.Iov = &iovec
|
||||||
|
msghdr.Iovlen = 1
|
||||||
|
msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
|
||||||
|
msghdr.Namelen = unix.SizeofSockaddrInet4
|
||||||
|
msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
|
||||||
|
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
|
||||||
|
|
||||||
|
// recvmsg(sock, &mskhdr, 0)
|
||||||
|
|
||||||
|
size, _, errno := unix.Syscall(
|
||||||
|
unix.SYS_RECVMSG,
|
||||||
|
uintptr(sock),
|
||||||
|
uintptr(unsafe.Pointer(&msghdr)),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if errno != 0 {
|
||||||
|
return 0, errno
|
||||||
|
}
|
||||||
|
|
||||||
|
// update source cache
|
||||||
|
|
||||||
|
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
||||||
|
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
||||||
|
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
||||||
|
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
|
||||||
|
src4.src.Family = unix.AF_INET
|
||||||
|
src4.src.Addr = cmsg.pktinfo.Spec_dst
|
||||||
|
src4.Ifindex = cmsg.pktinfo.Ifindex
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(size), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
|
// contruct message header
|
||||||
|
|
||||||
|
var iovec unix.Iovec
|
||||||
|
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
||||||
|
iovec.SetLen(len(buff))
|
||||||
|
|
||||||
|
var cmsg struct {
|
||||||
|
cmsghdr unix.Cmsghdr
|
||||||
|
pktinfo unix.Inet6Pktinfo
|
||||||
}
|
}
|
||||||
|
|
||||||
var msg unix.Msghdr
|
var msg unix.Msghdr
|
||||||
msg.Iov = &iovec
|
msg.Iov = &iovec
|
||||||
msg.Iovlen = 1
|
msg.Iovlen = 1
|
||||||
msg.Name = (*byte)(unsafe.Pointer(&end.dst))
|
msg.Name = (*byte)(unsafe.Pointer(&end.dst))
|
||||||
msg.Namelen = uint32(unix.SizeofSockaddrAny)
|
msg.Namelen = uint32(unix.SizeofSockaddrInet6)
|
||||||
msg.Control = (*byte)(unsafe.Pointer(&cmsg))
|
msg.Control = (*byte)(unsafe.Pointer(&cmsg))
|
||||||
msg.SetControllen(int(unsafe.Sizeof(cmsg)))
|
msg.SetControllen(int(unsafe.Sizeof(cmsg)))
|
||||||
|
|
||||||
_, _, errno := unix.Syscall(
|
// recvmsg(sock, &mskhdr, 0)
|
||||||
|
|
||||||
|
size, _, errno := unix.Syscall(
|
||||||
unix.SYS_RECVMSG,
|
unix.SYS_RECVMSG,
|
||||||
file.Fd(),
|
uintptr(sock),
|
||||||
uintptr(unsafe.Pointer(&msg)),
|
uintptr(unsafe.Pointer(&msg)),
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if errno != 0 {
|
if errno != 0 {
|
||||||
return errno, nil, nil
|
return 0, errno
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update source cache
|
||||||
|
|
||||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
|
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
|
||||||
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
|
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
|
||||||
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
|
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
|
||||||
|
end.src.Family = unix.AF_INET6
|
||||||
|
end.src.Addr = cmsg.pktinfo.Addr
|
||||||
|
end.src.Scope_id = cmsg.pktinfo.Ifindex
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
return int(size), nil
|
||||||
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
|
||||||
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
|
||||||
|
|
||||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
|
|
||||||
println(info)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setMark(conn *net.UDPConn, value uint32) error {
|
|
||||||
if conn == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
file, err := conn.File()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return unix.SetsockoptInt(
|
|
||||||
int(file.Fd()),
|
|
||||||
unix.SOL_SOCKET,
|
|
||||||
unix.SO_MARK,
|
|
||||||
int(value),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
@ -5,10 +5,8 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type CookieChecker struct {
|
type CookieChecker struct {
|
||||||
@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
|
|||||||
return hmac.Equal(mac1[:], msg[smac1:smac2])
|
return hmac.Equal(mac1[:], msg[smac1:smac2])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
|
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
|
||||||
st.mutex.RLock()
|
st.mutex.RLock()
|
||||||
defer st.mutex.RUnlock()
|
defer st.mutex.RUnlock()
|
||||||
|
|
||||||
@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
|
|||||||
var cookie [blake2s.Size128]byte
|
var cookie [blake2s.Size128]byte
|
||||||
func() {
|
func() {
|
||||||
mac, _ := blake2s.New128(st.mac2.secret[:])
|
mac, _ := blake2s.New128(st.mac2.secret[:])
|
||||||
mac.Write(src.IP)
|
mac.Write(src)
|
||||||
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
|
|
||||||
mac.Sum(cookie[:0])
|
mac.Sum(cookie[:0])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
|
|||||||
func (st *CookieChecker) CreateReply(
|
func (st *CookieChecker) CreateReply(
|
||||||
msg []byte,
|
msg []byte,
|
||||||
recv uint32,
|
recv uint32,
|
||||||
src *net.UDPAddr,
|
src []byte,
|
||||||
) (*MessageCookieReply, error) {
|
) (*MessageCookieReply, error) {
|
||||||
|
|
||||||
st.mutex.RLock()
|
st.mutex.RLock()
|
||||||
@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply(
|
|||||||
var cookie [blake2s.Size128]byte
|
var cookie [blake2s.Size128]byte
|
||||||
func() {
|
func() {
|
||||||
mac, _ := blake2s.New128(st.mac2.secret[:])
|
mac, _ := blake2s.New128(st.mac2.secret[:])
|
||||||
mac.Write(src.IP)
|
mac.Write(src)
|
||||||
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
|
|
||||||
mac.Sum(cookie[:0])
|
mac.Sum(cookie[:0])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -25,7 +24,7 @@ func TestCookieMAC1(t *testing.T) {
|
|||||||
|
|
||||||
// check mac1
|
// check mac1
|
||||||
|
|
||||||
src, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4000")
|
src := []byte{192, 168, 13, 37, 10, 10, 10}
|
||||||
|
|
||||||
checkMAC1 := func(msg []byte) {
|
checkMAC1 := func(msg []byte) {
|
||||||
generator.AddMacs(msg)
|
generator.AddMacs(msg)
|
||||||
@ -128,12 +127,12 @@ func TestCookieMAC1(t *testing.T) {
|
|||||||
|
|
||||||
msg[5] ^= 0x20
|
msg[5] ^= 0x20
|
||||||
|
|
||||||
srcBad1, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4001")
|
srcBad1 := []byte{192, 168, 13, 37, 40, 01}
|
||||||
if checker.CheckMAC2(msg, srcBad1) {
|
if checker.CheckMAC2(msg, srcBad1) {
|
||||||
t.Fatal("MAC2 generation/verification failed")
|
t.Fatal("MAC2 generation/verification failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
srcBad2, _ := net.ResolveUDPAddr("udp", "192.168.13.38:4000")
|
srcBad2 := []byte{192, 168, 13, 38, 40, 01}
|
||||||
if checker.CheckMAC2(msg, srcBad2) {
|
if checker.CheckMAC2(msg, srcBad2) {
|
||||||
t.Fatal("MAC2 generation/verification failed")
|
t.Fatal("MAC2 generation/verification failed")
|
||||||
}
|
}
|
||||||
|
@ -2,29 +2,25 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Daemonizes the process on linux
|
/* Daemonizes the process on linux
|
||||||
*
|
*
|
||||||
* This is done by spawning and releasing a copy with the --foreground flag
|
* This is done by spawning and releasing a copy with the --foreground flag
|
||||||
*
|
|
||||||
* TODO: Use env variable to spawn in background
|
|
||||||
*/
|
*/
|
||||||
|
func Daemonize(attr *os.ProcAttr) error {
|
||||||
|
// I would like to use os.Executable,
|
||||||
|
// however this means dropping support for Go <1.8
|
||||||
|
path, err := exec.LookPath(os.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func Daemonize() error {
|
|
||||||
argv := []string{os.Args[0], "--foreground"}
|
argv := []string{os.Args[0], "--foreground"}
|
||||||
argv = append(argv, os.Args[1:]...)
|
argv = append(argv, os.Args[1:]...)
|
||||||
attr := &os.ProcAttr{
|
|
||||||
Dir: ".",
|
|
||||||
Env: os.Environ(),
|
|
||||||
Files: []*os.File{
|
|
||||||
os.Stdin,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
process, err := os.StartProcess(
|
process, err := os.StartProcess(
|
||||||
argv[0],
|
path,
|
||||||
argv,
|
argv,
|
||||||
attr,
|
attr,
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -9,8 +8,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
log *Logger // collection of loggers for levels
|
closed AtomicBool // device is closed? (acting as guard)
|
||||||
idCounter uint // for assigning debug ids to peers
|
log *Logger // collection of loggers for levels
|
||||||
|
idCounter uint // for assigning debug ids to peers
|
||||||
fwMark uint32
|
fwMark uint32
|
||||||
tun struct {
|
tun struct {
|
||||||
device TUNDevice
|
device TUNDevice
|
||||||
@ -22,9 +22,9 @@ type Device struct {
|
|||||||
}
|
}
|
||||||
net struct {
|
net struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
addr *net.UDPAddr // UDP source address
|
bind Bind // bind interface
|
||||||
conn *net.UDPConn // UDP "connection"
|
port uint16 // listening port
|
||||||
fwmark uint32
|
fwmark uint32 // mark value (0 = disabled)
|
||||||
}
|
}
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
privateKey NoisePrivateKey
|
privateKey NoisePrivateKey
|
||||||
@ -37,8 +37,7 @@ type Device struct {
|
|||||||
handshake chan QueueHandshakeElement
|
handshake chan QueueHandshakeElement
|
||||||
}
|
}
|
||||||
signal struct {
|
signal struct {
|
||||||
stop chan struct{} // halts all go routines
|
stop chan struct{}
|
||||||
newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine)
|
|
||||||
}
|
}
|
||||||
underLoadUntil atomic.Value
|
underLoadUntil atomic.Value
|
||||||
ratelimiter Ratelimiter
|
ratelimiter Ratelimiter
|
||||||
@ -128,21 +127,23 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
|
|||||||
device.pool.messageBuffers.Put(msg)
|
device.pool.messageBuffers.Put(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDevice(tun TUNDevice, logLevel int) *Device {
|
func NewDevice(tun TUNDevice, logger *Logger) *Device {
|
||||||
device := new(Device)
|
device := new(Device)
|
||||||
|
|
||||||
device.mutex.Lock()
|
device.mutex.Lock()
|
||||||
defer device.mutex.Unlock()
|
defer device.mutex.Unlock()
|
||||||
|
|
||||||
device.log = NewLogger(logLevel, "("+tun.Name()+") ")
|
device.log = logger
|
||||||
device.peers = make(map[NoisePublicKey]*Peer)
|
device.peers = make(map[NoisePublicKey]*Peer)
|
||||||
device.tun.device = tun
|
device.tun.device = tun
|
||||||
|
|
||||||
device.indices.Init()
|
device.indices.Init()
|
||||||
device.ratelimiter.Init()
|
device.ratelimiter.Init()
|
||||||
|
|
||||||
device.routingTable.Reset()
|
device.routingTable.Reset()
|
||||||
device.underLoadUntil.Store(time.Time{})
|
device.underLoadUntil.Store(time.Time{})
|
||||||
|
|
||||||
// setup pools
|
// setup buffer pool
|
||||||
|
|
||||||
device.pool.messageBuffers = sync.Pool{
|
device.pool.messageBuffers = sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() interface{} {
|
||||||
@ -159,7 +160,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
|||||||
// prepare signals
|
// prepare signals
|
||||||
|
|
||||||
device.signal.stop = make(chan struct{})
|
device.signal.stop = make(chan struct{})
|
||||||
device.signal.newUDPConn = make(chan struct{}, 1)
|
|
||||||
|
// prepare net
|
||||||
|
|
||||||
|
device.net.port = 0
|
||||||
|
device.net.bind = nil
|
||||||
|
|
||||||
// start workers
|
// start workers
|
||||||
|
|
||||||
@ -168,12 +173,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
|||||||
go device.RoutineDecryption()
|
go device.RoutineDecryption()
|
||||||
go device.RoutineHandshake()
|
go device.RoutineHandshake()
|
||||||
}
|
}
|
||||||
|
go device.RoutineReadFromTUN()
|
||||||
go device.RoutineTUNEventReader()
|
go device.RoutineTUNEventReader()
|
||||||
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
||||||
go device.RoutineReadFromTUN()
|
|
||||||
go device.RoutineReceiveIncomming()
|
|
||||||
|
|
||||||
return device
|
return device
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -202,9 +204,13 @@ func (device *Device) RemoveAllPeers() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) Close() {
|
func (device *Device) Close() {
|
||||||
|
if device.closed.Swap(true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
device.log.Info.Println("Closing device")
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
close(device.signal.stop)
|
close(device.signal.stop)
|
||||||
closeUDPConn(device)
|
CloseUDPListener(device)
|
||||||
device.tun.device.Close()
|
device.tun.device.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,6 +16,10 @@ type DummyTUN struct {
|
|||||||
events chan TUNEvent
|
events chan TUNEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) File() *os.File {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (tun *DummyTUN) Name() string {
|
func (tun *DummyTUN) Name() string {
|
||||||
return tun.name
|
return tun.name
|
||||||
}
|
}
|
||||||
@ -67,7 +72,8 @@ func randDevice(t *testing.T) *Device {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
tun, _ := CreateDummyTUN("dummy")
|
tun, _ := CreateDummyTUN("dummy")
|
||||||
device := NewDevice(tun, LogLevelError)
|
logger := NewLogger(LogLevelError, "")
|
||||||
|
device := NewDevice(tun, logger)
|
||||||
device.SetPrivateKey(sk)
|
device.SetPrivateKey(sk)
|
||||||
return device
|
return device
|
||||||
}
|
}
|
||||||
|
134
src/main.go
134
src/main.go
@ -2,10 +2,15 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ENV_WG_TUN_FD = "WG_TUN_FD"
|
||||||
|
ENV_WG_UAPI_FD = "WG_UAPI_FD"
|
||||||
)
|
)
|
||||||
|
|
||||||
func printUsage() {
|
func printUsage() {
|
||||||
@ -43,28 +48,6 @@ func main() {
|
|||||||
interfaceName = os.Args[1]
|
interfaceName = os.Args[1]
|
||||||
}
|
}
|
||||||
|
|
||||||
// daemonize the process
|
|
||||||
|
|
||||||
if !foreground {
|
|
||||||
err := Daemonize()
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Failed to daemonize:", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// increase number of go workers (for Go <1.5)
|
|
||||||
|
|
||||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
|
||||||
|
|
||||||
// open TUN device
|
|
||||||
|
|
||||||
tun, err := CreateTUN(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Failed to create tun device:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// get log level (default: info)
|
// get log level (default: info)
|
||||||
|
|
||||||
logLevel := func() int {
|
logLevel := func() int {
|
||||||
@ -79,25 +62,103 @@ func main() {
|
|||||||
return LogLevelInfo
|
return LogLevelInfo
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
logger := NewLogger(
|
||||||
|
logLevel,
|
||||||
|
fmt.Sprintf("(%s) ", interfaceName),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.Debug.Println("Debug log enabled")
|
||||||
|
|
||||||
|
// open TUN device (or use supplied fd)
|
||||||
|
|
||||||
|
tun, err := func() (TUNDevice, error) {
|
||||||
|
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||||
|
if tunFdStr == "" {
|
||||||
|
return CreateTUN(interfaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// construct tun device from supplied fd
|
||||||
|
|
||||||
|
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
file := os.NewFile(uintptr(fd), "")
|
||||||
|
return CreateTUNFromFile(interfaceName, file)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error.Println("Failed to create TUN device:", err)
|
||||||
|
os.Exit(ExitSetupFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// open UAPI file (or use supplied fd)
|
||||||
|
|
||||||
|
fileUAPI, err := func() (*os.File, error) {
|
||||||
|
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
|
||||||
|
if uapiFdStr == "" {
|
||||||
|
return UAPIOpen(interfaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// use supplied fd
|
||||||
|
|
||||||
|
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.NewFile(uintptr(fd), ""), nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error.Println("UAPI listen error:", err)
|
||||||
|
os.Exit(ExitSetupFailed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// daemonize the process
|
||||||
|
|
||||||
|
if !foreground {
|
||||||
|
env := os.Environ()
|
||||||
|
env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
|
||||||
|
env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
|
||||||
|
attr := &os.ProcAttr{
|
||||||
|
Files: []*os.File{
|
||||||
|
nil, // stdin
|
||||||
|
nil, // stdout
|
||||||
|
nil, // stderr
|
||||||
|
tun.File(),
|
||||||
|
fileUAPI,
|
||||||
|
},
|
||||||
|
Dir: ".",
|
||||||
|
Env: env,
|
||||||
|
}
|
||||||
|
err = Daemonize(attr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error.Println("Failed to daemonize:", err)
|
||||||
|
os.Exit(ExitSetupFailed)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// increase number of go workers (for Go <1.5)
|
||||||
|
|
||||||
|
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||||
|
|
||||||
// create wireguard device
|
// create wireguard device
|
||||||
|
|
||||||
device := NewDevice(tun, logLevel)
|
device := NewDevice(tun, logger)
|
||||||
|
|
||||||
logInfo := device.log.Info
|
logger.Info.Println("Device started")
|
||||||
logError := device.log.Error
|
|
||||||
logInfo.Println("Starting device")
|
|
||||||
|
|
||||||
// start configuration lister
|
// start uapi listener
|
||||||
|
|
||||||
uapi, err := NewUAPIListener(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
logError.Fatal("UAPI listen error:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
errs := make(chan error)
|
errs := make(chan error)
|
||||||
term := make(chan os.Signal)
|
term := make(chan os.Signal)
|
||||||
wait := device.WaitChannel()
|
wait := device.WaitChannel()
|
||||||
|
|
||||||
|
uapi, err := UAPIListen(interfaceName, fileUAPI)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
conn, err := uapi.Accept()
|
conn, err := uapi.Accept()
|
||||||
@ -109,7 +170,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
logInfo.Println("UAPI listener started")
|
logger.Info.Println("UAPI listener started")
|
||||||
|
|
||||||
// wait for program to terminate
|
// wait for program to terminate
|
||||||
|
|
||||||
@ -122,9 +183,10 @@ func main() {
|
|||||||
case <-errs:
|
case <-errs:
|
||||||
}
|
}
|
||||||
|
|
||||||
// clean up UAPI bind
|
// clean up
|
||||||
|
|
||||||
uapi.Close()
|
uapi.Close()
|
||||||
|
device.Close()
|
||||||
|
|
||||||
logInfo.Println("Closing")
|
logger.Info.Println("Shutting down")
|
||||||
}
|
}
|
||||||
|
@ -21,6 +21,14 @@ func (a *AtomicBool) Get() bool {
|
|||||||
return atomic.LoadInt32(&a.flag) == AtomicTrue
|
return atomic.LoadInt32(&a.flag) == AtomicTrue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *AtomicBool) Swap(val bool) bool {
|
||||||
|
flag := AtomicFalse
|
||||||
|
if val {
|
||||||
|
flag = AtomicTrue
|
||||||
|
}
|
||||||
|
return atomic.SwapInt32(&a.flag, flag) == AtomicTrue
|
||||||
|
}
|
||||||
|
|
||||||
func (a *AtomicBool) Set(val bool) {
|
func (a *AtomicBool) Set(val bool) {
|
||||||
flag := AtomicFalse
|
flag := AtomicFalse
|
||||||
if val {
|
if val {
|
||||||
|
@ -117,8 +117,8 @@ func TestNoiseHandshake(t *testing.T) {
|
|||||||
var err error
|
var err error
|
||||||
var out []byte
|
var out []byte
|
||||||
var nonce [12]byte
|
var nonce [12]byte
|
||||||
out = key1.send.aead.Seal(out, nonce[:], testMsg, nil)
|
out = key1.send.Seal(out, nonce[:], testMsg, nil)
|
||||||
out, err = key2.receive.aead.Open(out[:0], nonce[:], out, nil)
|
out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
|
||||||
assertNil(t, err)
|
assertNil(t, err)
|
||||||
assertEqual(t, out, testMsg)
|
assertEqual(t, out, testMsg)
|
||||||
}()
|
}()
|
||||||
@ -128,8 +128,8 @@ func TestNoiseHandshake(t *testing.T) {
|
|||||||
var err error
|
var err error
|
||||||
var out []byte
|
var out []byte
|
||||||
var nonce [12]byte
|
var nonce [12]byte
|
||||||
out = key2.send.aead.Seal(out, nonce[:], testMsg, nil)
|
out = key2.send.Seal(out, nonce[:], testMsg, nil)
|
||||||
out, err = key1.receive.aead.Open(out[:0], nonce[:], out, nil)
|
out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
|
||||||
assertNil(t, err)
|
assertNil(t, err)
|
||||||
assertEqual(t, out, testMsg)
|
assertEqual(t, out, testMsg)
|
||||||
}()
|
}()
|
||||||
|
29
src/peer.go
29
src/peer.go
@ -4,7 +4,6 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -16,7 +15,7 @@ type Peer struct {
|
|||||||
keyPairs KeyPairs
|
keyPairs KeyPairs
|
||||||
handshake Handshake
|
handshake Handshake
|
||||||
device *Device
|
device *Device
|
||||||
endpoint *net.UDPAddr
|
endpoint Endpoint
|
||||||
stats struct {
|
stats struct {
|
||||||
txBytes uint64 // bytes send to peer (endpoint)
|
txBytes uint64 // bytes send to peer (endpoint)
|
||||||
rxBytes uint64 // bytes received from peer
|
rxBytes uint64 // bytes received from peer
|
||||||
@ -106,6 +105,10 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||||||
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
|
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
// reset endpoint
|
||||||
|
|
||||||
|
peer.endpoint = nil
|
||||||
|
|
||||||
// prepare queuing
|
// prepare queuing
|
||||||
|
|
||||||
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
|
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||||
@ -130,11 +133,31 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) SendBuffer(buffer []byte) error {
|
||||||
|
peer.device.net.mutex.RLock()
|
||||||
|
defer peer.device.net.mutex.RUnlock()
|
||||||
|
peer.mutex.RLock()
|
||||||
|
defer peer.mutex.RUnlock()
|
||||||
|
if peer.endpoint == nil {
|
||||||
|
return errors.New("No known endpoint for peer")
|
||||||
|
}
|
||||||
|
return peer.device.net.bind.Send(buffer, peer.endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Returns a short string identification for logging
|
||||||
|
*/
|
||||||
func (peer *Peer) String() string {
|
func (peer *Peer) String() string {
|
||||||
|
if peer.endpoint == nil {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"peer(%d unknown %s)",
|
||||||
|
peer.id,
|
||||||
|
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
|
||||||
|
)
|
||||||
|
}
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"peer(%d %s %s)",
|
"peer(%d %s %s)",
|
||||||
peer.id,
|
peer.id,
|
||||||
peer.endpoint.String(),
|
peer.endpoint.DstToString(),
|
||||||
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
|
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
271
src/receive.go
271
src/receive.go
@ -13,19 +13,20 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type QueueHandshakeElement struct {
|
type QueueHandshakeElement struct {
|
||||||
msgType uint32
|
msgType uint32
|
||||||
packet []byte
|
packet []byte
|
||||||
buffer *[MaxMessageSize]byte
|
endpoint Endpoint
|
||||||
source *net.UDPAddr
|
buffer *[MaxMessageSize]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueueInboundElement struct {
|
type QueueInboundElement struct {
|
||||||
dropped int32
|
dropped int32
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
buffer *[MaxMessageSize]byte
|
buffer *[MaxMessageSize]byte
|
||||||
packet []byte
|
packet []byte
|
||||||
counter uint64
|
counter uint64
|
||||||
keyPair *KeyPair
|
keyPair *KeyPair
|
||||||
|
endpoint Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
func (elem *QueueInboundElement) Drop() {
|
func (elem *QueueInboundElement) Drop() {
|
||||||
@ -92,130 +93,122 @@ func (device *Device) addToHandshakeQueue(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) RoutineReceiveIncomming() {
|
func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) {
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
logDebug.Println("Routine, receive incomming, started")
|
logDebug.Println("Routine, receive incomming, IP version:", IP)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
||||||
// wait for new conn
|
// receive datagrams until conn is closed
|
||||||
|
|
||||||
logDebug.Println("Waiting for udp socket")
|
buffer := device.GetMessageBuffer()
|
||||||
|
|
||||||
select {
|
var (
|
||||||
case <-device.signal.stop:
|
err error
|
||||||
return
|
size int
|
||||||
|
endpoint Endpoint
|
||||||
|
)
|
||||||
|
|
||||||
case <-device.signal.newUDPConn:
|
for {
|
||||||
|
|
||||||
// fetch connection
|
// read next datagram
|
||||||
|
|
||||||
device.net.mutex.RLock()
|
switch IP {
|
||||||
conn := device.net.conn
|
case ipv4.Version:
|
||||||
device.net.mutex.RUnlock()
|
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
|
||||||
if conn == nil {
|
case ipv6.Version:
|
||||||
|
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if size < MinMessageSize {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
logDebug.Println("Listening for inbound packets")
|
// check size of packet
|
||||||
|
|
||||||
// receive datagrams until conn is closed
|
packet := buffer[:size]
|
||||||
|
msgType := binary.LittleEndian.Uint32(packet[:4])
|
||||||
|
|
||||||
buffer := device.GetMessageBuffer()
|
var okay bool
|
||||||
|
|
||||||
for {
|
switch msgType {
|
||||||
|
|
||||||
// read next datagram
|
// check if transport
|
||||||
|
|
||||||
size, raddr, err := conn.ReadFromUDP(buffer[:])
|
case MessageTransportType:
|
||||||
|
|
||||||
if err != nil {
|
// check size
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if size < MinMessageSize {
|
if len(packet) < MessageTransportType {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// check size of packet
|
// lookup key pair
|
||||||
|
|
||||||
packet := buffer[:size]
|
receiver := binary.LittleEndian.Uint32(
|
||||||
msgType := binary.LittleEndian.Uint32(packet[:4])
|
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
||||||
|
)
|
||||||
var okay bool
|
value := device.indices.Lookup(receiver)
|
||||||
|
keyPair := value.keyPair
|
||||||
switch msgType {
|
if keyPair == nil {
|
||||||
|
|
||||||
// check if transport
|
|
||||||
|
|
||||||
case MessageTransportType:
|
|
||||||
|
|
||||||
// check size
|
|
||||||
|
|
||||||
if len(packet) < MessageTransportType {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookup key pair
|
|
||||||
|
|
||||||
receiver := binary.LittleEndian.Uint32(
|
|
||||||
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
|
||||||
)
|
|
||||||
value := device.indices.Lookup(receiver)
|
|
||||||
keyPair := value.keyPair
|
|
||||||
if keyPair == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// check key-pair expiry
|
|
||||||
|
|
||||||
if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// create work element
|
|
||||||
|
|
||||||
peer := value.peer
|
|
||||||
elem := &QueueInboundElement{
|
|
||||||
packet: packet,
|
|
||||||
buffer: buffer,
|
|
||||||
keyPair: keyPair,
|
|
||||||
dropped: AtomicFalse,
|
|
||||||
}
|
|
||||||
elem.mutex.Lock()
|
|
||||||
|
|
||||||
// add to decryption queues
|
|
||||||
|
|
||||||
device.addToDecryptionQueue(device.queue.decryption, elem)
|
|
||||||
device.addToInboundQueue(peer.queue.inbound, elem)
|
|
||||||
buffer = device.GetMessageBuffer()
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
// otherwise it is a handshake related packet
|
|
||||||
|
|
||||||
case MessageInitiationType:
|
|
||||||
okay = len(packet) == MessageInitiationSize
|
|
||||||
|
|
||||||
case MessageResponseType:
|
|
||||||
okay = len(packet) == MessageResponseSize
|
|
||||||
|
|
||||||
case MessageCookieReplyType:
|
|
||||||
okay = len(packet) == MessageCookieReplySize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if okay {
|
// check key-pair expiry
|
||||||
device.addToHandshakeQueue(
|
|
||||||
device.queue.handshake,
|
if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
|
||||||
QueueHandshakeElement{
|
continue
|
||||||
msgType: msgType,
|
|
||||||
buffer: buffer,
|
|
||||||
packet: packet,
|
|
||||||
source: raddr,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
buffer = device.GetMessageBuffer()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create work element
|
||||||
|
|
||||||
|
peer := value.peer
|
||||||
|
elem := &QueueInboundElement{
|
||||||
|
packet: packet,
|
||||||
|
buffer: buffer,
|
||||||
|
keyPair: keyPair,
|
||||||
|
dropped: AtomicFalse,
|
||||||
|
endpoint: endpoint,
|
||||||
|
}
|
||||||
|
elem.mutex.Lock()
|
||||||
|
|
||||||
|
// add to decryption queues
|
||||||
|
|
||||||
|
device.addToDecryptionQueue(device.queue.decryption, elem)
|
||||||
|
device.addToInboundQueue(peer.queue.inbound, elem)
|
||||||
|
buffer = device.GetMessageBuffer()
|
||||||
|
continue
|
||||||
|
|
||||||
|
// otherwise it is a fixed size & handshake related packet
|
||||||
|
|
||||||
|
case MessageInitiationType:
|
||||||
|
okay = len(packet) == MessageInitiationSize
|
||||||
|
|
||||||
|
case MessageResponseType:
|
||||||
|
okay = len(packet) == MessageResponseSize
|
||||||
|
|
||||||
|
case MessageCookieReplyType:
|
||||||
|
okay = len(packet) == MessageCookieReplySize
|
||||||
|
}
|
||||||
|
|
||||||
|
if okay {
|
||||||
|
device.addToHandshakeQueue(
|
||||||
|
device.queue.handshake,
|
||||||
|
QueueHandshakeElement{
|
||||||
|
msgType: msgType,
|
||||||
|
buffer: buffer,
|
||||||
|
packet: packet,
|
||||||
|
endpoint: endpoint,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
buffer = device.GetMessageBuffer()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -293,8 +286,6 @@ func (device *Device) RoutineHandshake() {
|
|||||||
|
|
||||||
// unmarshal packet
|
// unmarshal packet
|
||||||
|
|
||||||
logDebug.Println("Process cookie reply from:", elem.source.String())
|
|
||||||
|
|
||||||
var reply MessageCookieReply
|
var reply MessageCookieReply
|
||||||
reader := bytes.NewReader(elem.packet)
|
reader := bytes.NewReader(elem.packet)
|
||||||
err := binary.Read(reader, binary.LittleEndian, &reply)
|
err := binary.Read(reader, binary.LittleEndian, &reply)
|
||||||
@ -321,15 +312,25 @@ func (device *Device) RoutineHandshake() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// endpoints destination address is the source of the datagram
|
||||||
|
|
||||||
|
srcBytes := elem.endpoint.DstToBytes()
|
||||||
|
|
||||||
if device.IsUnderLoad() {
|
if device.IsUnderLoad() {
|
||||||
if !device.mac.CheckMAC2(elem.packet, elem.source) {
|
|
||||||
|
// verify MAC2 field
|
||||||
|
|
||||||
|
if !device.mac.CheckMAC2(elem.packet, srcBytes) {
|
||||||
|
|
||||||
// construct cookie reply
|
// construct cookie reply
|
||||||
|
|
||||||
logDebug.Println("Sending cookie reply to:", elem.source.String())
|
logDebug.Println(
|
||||||
|
"Sending cookie reply to:",
|
||||||
|
elem.endpoint.DstToString(),
|
||||||
|
)
|
||||||
|
|
||||||
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
|
sender := binary.LittleEndian.Uint32(elem.packet[4:8])
|
||||||
reply, err := device.mac.CreateReply(elem.packet, sender, elem.source)
|
reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to create cookie reply:", err)
|
logError.Println("Failed to create cookie reply:", err)
|
||||||
return
|
return
|
||||||
@ -339,17 +340,16 @@ func (device *Device) RoutineHandshake() {
|
|||||||
|
|
||||||
writer := bytes.NewBuffer(temp[:0])
|
writer := bytes.NewBuffer(temp[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, reply)
|
binary.Write(writer, binary.LittleEndian, reply)
|
||||||
_, err = device.net.conn.WriteToUDP(
|
device.net.bind.Send(writer.Bytes(), elem.endpoint)
|
||||||
writer.Bytes(),
|
|
||||||
elem.source,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logDebug.Println("Failed to send cookie reply:", err)
|
logDebug.Println("Failed to send cookie reply:", err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !device.ratelimiter.Allow(elem.source.IP) {
|
// check ratelimiter
|
||||||
|
|
||||||
|
if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -380,8 +380,7 @@ func (device *Device) RoutineHandshake() {
|
|||||||
if peer == nil {
|
if peer == nil {
|
||||||
logInfo.Println(
|
logInfo.Println(
|
||||||
"Recieved invalid initiation message from",
|
"Recieved invalid initiation message from",
|
||||||
elem.source.IP.String(),
|
elem.endpoint.DstToString(),
|
||||||
elem.source.Port,
|
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -392,10 +391,9 @@ func (device *Device) RoutineHandshake() {
|
|||||||
peer.TimerAnyAuthenticatedPacketReceived()
|
peer.TimerAnyAuthenticatedPacketReceived()
|
||||||
|
|
||||||
// update endpoint
|
// update endpoint
|
||||||
// TODO: Discover destination address also, only update on change
|
|
||||||
|
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
peer.endpoint = elem.source
|
peer.endpoint = elem.endpoint
|
||||||
peer.mutex.Unlock()
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
// create response
|
// create response
|
||||||
@ -418,9 +416,11 @@ func (device *Device) RoutineHandshake() {
|
|||||||
|
|
||||||
// send response
|
// send response
|
||||||
|
|
||||||
_, err = peer.SendBuffer(packet)
|
err = peer.SendBuffer(packet)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
peer.TimerAnyAuthenticatedPacketTraversal()
|
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||||
|
} else {
|
||||||
|
logError.Println("Failed to send response to:", peer.String(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
case MessageResponseType:
|
case MessageResponseType:
|
||||||
@ -441,12 +441,17 @@ func (device *Device) RoutineHandshake() {
|
|||||||
if peer == nil {
|
if peer == nil {
|
||||||
logInfo.Println(
|
logInfo.Println(
|
||||||
"Recieved invalid response message from",
|
"Recieved invalid response message from",
|
||||||
elem.source.IP.String(),
|
elem.endpoint.DstToString(),
|
||||||
elem.source.Port,
|
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update endpoint
|
||||||
|
|
||||||
|
peer.mutex.Lock()
|
||||||
|
peer.endpoint = elem.endpoint
|
||||||
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
logDebug.Println("Received handshake initation from", peer)
|
logDebug.Println("Received handshake initation from", peer)
|
||||||
|
|
||||||
peer.TimerEphemeralKeyCreated()
|
peer.TimerEphemeralKeyCreated()
|
||||||
@ -515,6 +520,12 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
}
|
}
|
||||||
kp.mutex.Unlock()
|
kp.mutex.Unlock()
|
||||||
|
|
||||||
|
// update endpoint
|
||||||
|
|
||||||
|
peer.mutex.Lock()
|
||||||
|
peer.endpoint = elem.endpoint
|
||||||
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
// check for keep-alive
|
// check for keep-alive
|
||||||
|
|
||||||
if len(elem.packet) == 0 {
|
if len(elem.packet) == 0 {
|
||||||
@ -546,7 +557,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
|
|
||||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||||
if device.routingTable.LookupIPv4(src) != peer {
|
if device.routingTable.LookupIPv4(src) != peer {
|
||||||
logInfo.Println("Packet with unallowed source IP from", peer.String())
|
logInfo.Println(
|
||||||
|
"IPv4 packet with unallowed source address from",
|
||||||
|
peer.String(),
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -571,7 +585,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
|
|
||||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||||
if device.routingTable.LookupIPv6(src) != peer {
|
if device.routingTable.LookupIPv6(src) != peer {
|
||||||
logInfo.Println("Packet with unallowed source IP from", peer.String())
|
logInfo.Println(
|
||||||
|
"IPv6 packet with unallowed source address from",
|
||||||
|
peer.String(),
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -580,7 +597,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// write to tun
|
// write to tun device
|
||||||
|
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
||||||
_, err := device.tun.device.Write(elem.packet)
|
_, err := device.tun.device.Write(elem.packet)
|
||||||
|
23
src/send.go
23
src/send.go
@ -2,7 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
@ -105,26 +104,6 @@ func addToEncryptionQueue(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
|
|
||||||
peer.device.net.mutex.RLock()
|
|
||||||
defer peer.device.net.mutex.RUnlock()
|
|
||||||
|
|
||||||
peer.mutex.RLock()
|
|
||||||
defer peer.mutex.RUnlock()
|
|
||||||
|
|
||||||
endpoint := peer.endpoint
|
|
||||||
if endpoint == nil {
|
|
||||||
return 0, errors.New("No known endpoint for peer")
|
|
||||||
}
|
|
||||||
|
|
||||||
conn := peer.device.net.conn
|
|
||||||
if conn == nil {
|
|
||||||
return 0, errors.New("No UDP socket for device")
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn.WriteToUDP(buffer, endpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Reads packets from the TUN and inserts
|
/* Reads packets from the TUN and inserts
|
||||||
* into nonce queue for peer
|
* into nonce queue for peer
|
||||||
*
|
*
|
||||||
@ -343,7 +322,7 @@ func (peer *Peer) RoutineSequentialSender() {
|
|||||||
// send message and return buffer to pool
|
// send message and return buffer to pool
|
||||||
|
|
||||||
length := uint64(len(elem.packet))
|
length := uint64(len(elem.packet))
|
||||||
_, err := peer.SendBuffer(elem.packet)
|
err := peer.SendBuffer(elem.packet)
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logDebug.Println("Failed to send authenticated packet to peer", peer.String())
|
logDebug.Println("Failed to send authenticated packet to peer", peer.String())
|
||||||
|
@ -20,6 +20,14 @@
|
|||||||
# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1
|
# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1
|
||||||
# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further
|
# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further
|
||||||
# details on how this is accomplished.
|
# details on how this is accomplished.
|
||||||
|
|
||||||
|
# This code is ported to the WireGuard-Go directly from the kernel project.
|
||||||
|
#
|
||||||
|
# Please ensure that you have installed the newest version of the WireGuard
|
||||||
|
# tools from the WireGuard project and before running these tests as:
|
||||||
|
#
|
||||||
|
# ./netns.sh <path to wireguard-go>
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
exec 3>&1
|
exec 3>&1
|
||||||
@ -27,8 +35,8 @@ export WG_HIDE_KEYS=never
|
|||||||
netns0="wg-test-$$-0"
|
netns0="wg-test-$$-0"
|
||||||
netns1="wg-test-$$-1"
|
netns1="wg-test-$$-1"
|
||||||
netns2="wg-test-$$-2"
|
netns2="wg-test-$$-2"
|
||||||
program="../wireguard-go"
|
program=$1
|
||||||
export LOG_LEVEL="error"
|
export LOG_LEVEL="info"
|
||||||
|
|
||||||
pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
|
pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
|
||||||
pp() { pretty "" "$*"; "$@"; }
|
pp() { pretty "" "$*"; "$@"; }
|
||||||
@ -72,13 +80,11 @@ pp ip netns add $netns2
|
|||||||
ip0 link set up dev lo
|
ip0 link set up dev lo
|
||||||
|
|
||||||
# ip0 link add dev wg1 type wireguard
|
# ip0 link add dev wg1 type wireguard
|
||||||
n0 $program -f wg1 &
|
n0 $program wg1
|
||||||
sleep 1
|
|
||||||
ip0 link set wg1 netns $netns1
|
ip0 link set wg1 netns $netns1
|
||||||
|
|
||||||
# ip0 link add dev wg1 type wireguard
|
# ip0 link add dev wg1 type wireguard
|
||||||
n0 $program -f wg2 &
|
n0 $program wg2
|
||||||
sleep 1
|
|
||||||
ip0 link set wg2 netns $netns2
|
ip0 link set wg2 netns $netns2
|
||||||
|
|
||||||
key1="$(pp wg genkey)"
|
key1="$(pp wg genkey)"
|
||||||
@ -185,14 +191,14 @@ ip0 -4 addr del 127.0.0.1/8 dev lo
|
|||||||
ip0 -4 addr add 127.212.121.99/8 dev lo
|
ip0 -4 addr add 127.212.121.99/8 dev lo
|
||||||
n0 wg set wg1 listen-port 9999
|
n0 wg set wg1 listen-port 9999
|
||||||
n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000
|
n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000
|
||||||
n1 ping6 -W 1 -c 1 fd00::20000
|
n1 ping6 -W 1 -c 1 fd00::2
|
||||||
[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]]
|
[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]]
|
||||||
|
|
||||||
# Test using IPv6 that roaming works
|
# Test using IPv6 that roaming works
|
||||||
n1 wg set wg1 listen-port 9998
|
n1 wg set wg1 listen-port 9998
|
||||||
n1 wg set wg1 peer "$pub2" endpoint [::1]:20000
|
n1 wg set wg1 peer "$pub2" endpoint [::1]:20000
|
||||||
n1 ping -W 1 -c 1 192.168.241.2
|
n1 ping -W 1 -c 1 192.168.241.2
|
||||||
[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]]
|
[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]]
|
||||||
|
|
||||||
# Test that crypto-RP filter works
|
# Test that crypto-RP filter works
|
||||||
n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24
|
n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24
|
||||||
@ -212,7 +218,7 @@ n2 ncat -u 192.168.241.1 1111 <<<"X"
|
|||||||
! read -r -N 1 -t 1 out <&4
|
! read -r -N 1 -t 1 out <&4
|
||||||
kill $nmap_pid
|
kill $nmap_pid
|
||||||
n0 wg set wg1 peer "$more_specific_key" remove
|
n0 wg set wg1 peer "$more_specific_key" remove
|
||||||
[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]]
|
[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]]
|
||||||
|
|
||||||
ip1 link del wg1
|
ip1 link del wg1
|
||||||
ip2 link del wg2
|
ip2 link del wg2
|
||||||
@ -263,7 +269,7 @@ n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to
|
|||||||
n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1
|
n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1
|
||||||
n1 ping -W 1 -c 1 192.168.241.2
|
n1 ping -W 1 -c 1 192.168.241.2
|
||||||
n2 ping -W 1 -c 1 192.168.241.1
|
n2 ping -W 1 -c 1 192.168.241.1
|
||||||
[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
|
[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
|
||||||
# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`).
|
# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`).
|
||||||
pp sleep 3
|
pp sleep 3
|
||||||
n2 ping -W 1 -c 1 192.168.241.1
|
n2 ping -W 1 -c 1 192.168.241.1
|
||||||
@ -289,7 +295,7 @@ ip2 link del wg2
|
|||||||
# ip1 link add dev wg1 type wireguard
|
# ip1 link add dev wg1 type wireguard
|
||||||
# ip2 link add dev wg1 type wireguard
|
# ip2 link add dev wg1 type wireguard
|
||||||
n1 $program wg1
|
n1 $program wg1
|
||||||
n2 $program wg1
|
n2 $program wg2
|
||||||
|
|
||||||
configure_peers
|
configure_peers
|
||||||
|
|
||||||
@ -336,17 +342,83 @@ waitiface $netns1 veth1
|
|||||||
waitiface $netns2 veth2
|
waitiface $netns2 veth2
|
||||||
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000
|
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000
|
||||||
n2 ping -W 1 -c 1 192.168.241.1
|
n2 ping -W 1 -c 1 192.168.241.1
|
||||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
|
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
|
||||||
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000
|
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000
|
||||||
n2 ping -W 1 -c 1 192.168.241.1
|
n2 ping -W 1 -c 1 192.168.241.1
|
||||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]]
|
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]]
|
||||||
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000
|
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000
|
||||||
n2 ping -W 1 -c 1 192.168.241.1
|
n2 ping -W 1 -c 1 192.168.241.1
|
||||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]]
|
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]]
|
||||||
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000
|
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000
|
||||||
n2 ping -W 1 -c 1 192.168.241.1
|
n2 ping -W 1 -c 1 192.168.241.1
|
||||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]]
|
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]]
|
||||||
|
|
||||||
ip1 link del veth1
|
ip1 link del veth1
|
||||||
ip1 link del wg1
|
ip1 link del wg1
|
||||||
ip2 link del wg2
|
ip2 link del wg2
|
||||||
|
|
||||||
|
# Test that Netlink/IPC is working properly by doing things that usually cause split responses
|
||||||
|
|
||||||
|
n0 $program wg0
|
||||||
|
sleep 5
|
||||||
|
config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" )
|
||||||
|
for a in {1..255}; do
|
||||||
|
for b in {0..255}; do
|
||||||
|
config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" )
|
||||||
|
done
|
||||||
|
done
|
||||||
|
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
|
||||||
|
i=0
|
||||||
|
for ip in $(n0 wg show wg0 allowed-ips); do
|
||||||
|
((++i))
|
||||||
|
done
|
||||||
|
((i == 255*256*2+1))
|
||||||
|
ip0 link del wg0
|
||||||
|
|
||||||
|
n0 $program wg0
|
||||||
|
config=( "[Interface]" "PrivateKey=$(wg genkey)" )
|
||||||
|
for a in {1..40}; do
|
||||||
|
config+=( "[Peer]" "PublicKey=$(wg genkey)" )
|
||||||
|
for b in {1..52}; do
|
||||||
|
config+=( "AllowedIPs=$a.$b.0.0/16" )
|
||||||
|
done
|
||||||
|
done
|
||||||
|
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
|
||||||
|
i=0
|
||||||
|
while read -r line; do
|
||||||
|
j=0
|
||||||
|
for ip in $line; do
|
||||||
|
((++j))
|
||||||
|
done
|
||||||
|
((j == 53))
|
||||||
|
((++i))
|
||||||
|
done < <(n0 wg show wg0 allowed-ips)
|
||||||
|
((i == 40))
|
||||||
|
ip0 link del wg0
|
||||||
|
|
||||||
|
n0 $program wg0
|
||||||
|
config=( )
|
||||||
|
for i in {1..29}; do
|
||||||
|
config+=( "[Peer]" "PublicKey=$(wg genkey)" )
|
||||||
|
done
|
||||||
|
config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" )
|
||||||
|
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
|
||||||
|
n0 wg showconf wg0 > /dev/null
|
||||||
|
ip0 link del wg0
|
||||||
|
|
||||||
|
! n0 wg show doesnotexist || false
|
||||||
|
|
||||||
|
declare -A objects
|
||||||
|
while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
|
||||||
|
[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
|
||||||
|
objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
|
||||||
|
done < /dev/kmsg
|
||||||
|
alldeleted=1
|
||||||
|
for object in "${!objects[@]}"; do
|
||||||
|
if [[ ${objects["$object"]} != *createddestroyed ]]; then
|
||||||
|
echo "Error: $object: merely ${objects["$object"]}" >&3
|
||||||
|
alldeleted=0
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
[[ $alldeleted -eq 1 ]]
|
||||||
|
pretty "" "Objects that were created were also destroyed."
|
||||||
|
@ -279,34 +279,31 @@ func (peer *Peer) RoutineHandshakeInitiator() {
|
|||||||
break AttemptHandshakes
|
break AttemptHandshakes
|
||||||
}
|
}
|
||||||
|
|
||||||
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
|
// marshal handshake message
|
||||||
|
|
||||||
// marshal and send
|
|
||||||
|
|
||||||
writer := bytes.NewBuffer(temp[:0])
|
writer := bytes.NewBuffer(temp[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, msg)
|
binary.Write(writer, binary.LittleEndian, msg)
|
||||||
packet := writer.Bytes()
|
packet := writer.Bytes()
|
||||||
peer.mac.AddMacs(packet)
|
peer.mac.AddMacs(packet)
|
||||||
|
|
||||||
_, err = peer.SendBuffer(packet)
|
// send to endpoint
|
||||||
if err != nil {
|
|
||||||
|
err = peer.SendBuffer(packet)
|
||||||
|
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
|
||||||
|
timeout := time.NewTimer(RekeyTimeout + jitter)
|
||||||
|
if err == nil {
|
||||||
|
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||||
|
logDebug.Println(
|
||||||
|
"Handshake initiation attempt",
|
||||||
|
attempts, "sent to", peer.String(),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
logError.Println(
|
logError.Println(
|
||||||
"Failed to send handshake initiation message to",
|
"Failed to send handshake initiation message to",
|
||||||
peer.String(), ":", err,
|
peer.String(), ":", err,
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.TimerAnyAuthenticatedPacketTraversal()
|
|
||||||
|
|
||||||
// set handshake timeout
|
|
||||||
|
|
||||||
timeout := time.NewTimer(RekeyTimeout + jitter)
|
|
||||||
logDebug.Println(
|
|
||||||
"Handshake initiation attempt",
|
|
||||||
attempts, "sent to", peer.String(),
|
|
||||||
)
|
|
||||||
|
|
||||||
// wait for handshake or timeout
|
// wait for handshake or timeout
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,6 +16,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TUNDevice interface {
|
type TUNDevice interface {
|
||||||
|
File() *os.File // returns the file descriptor of the device
|
||||||
Read([]byte) (int, error) // read a packet from the device (without any additional headers)
|
Read([]byte) (int, error) // read a packet from the device (without any additional headers)
|
||||||
Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
|
Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
|
||||||
MTU() (int, error) // returns the MTU of the device
|
MTU() (int, error) // returns the MTU of the device
|
||||||
@ -47,7 +49,7 @@ func (device *Device) RoutineTUNEventReader() {
|
|||||||
if !device.tun.isUp.Get() {
|
if !device.tun.isUp.Get() {
|
||||||
logInfo.Println("Interface set up")
|
logInfo.Println("Interface set up")
|
||||||
device.tun.isUp.Set(true)
|
device.tun.isUp.Set(true)
|
||||||
updateUDPConn(device)
|
UpdateUDPListener(device)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,7 +57,7 @@ func (device *Device) RoutineTUNEventReader() {
|
|||||||
if device.tun.isUp.Get() {
|
if device.tun.isUp.Get() {
|
||||||
logInfo.Println("Interface set down")
|
logInfo.Println("Interface set down")
|
||||||
device.tun.isUp.Set(false)
|
device.tun.isUp.Set(false)
|
||||||
closeUDPConn(device)
|
CloseUDPListener(device)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -56,6 +56,10 @@ type NativeTun struct {
|
|||||||
events chan TUNEvent // device related events
|
events chan TUNEvent // device related events
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) File() *os.File {
|
||||||
|
return tun.fd
|
||||||
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) RoutineNetlinkListener() {
|
func (tun *NativeTun) RoutineNetlinkListener() {
|
||||||
sock := int(C.bind_rtmgrp())
|
sock := int(C.bind_rtmgrp())
|
||||||
if sock < 0 {
|
if sock < 0 {
|
||||||
@ -222,7 +226,7 @@ func (tun *NativeTun) MTU() (int, error) {
|
|||||||
|
|
||||||
val := binary.LittleEndian.Uint32(ifr[16:20])
|
val := binary.LittleEndian.Uint32(ifr[16:20])
|
||||||
if val >= (1 << 31) {
|
if val >= (1 << 31) {
|
||||||
return int(val-(1<<31)) - (1 << 31), nil
|
return int(toInt32(val)), nil
|
||||||
}
|
}
|
||||||
return int(val), nil
|
return int(val), nil
|
||||||
}
|
}
|
||||||
@ -248,6 +252,29 @@ func (tun *NativeTun) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) {
|
||||||
|
device := &NativeTun{
|
||||||
|
fd: fd,
|
||||||
|
name: name,
|
||||||
|
events: make(chan TUNEvent, 5),
|
||||||
|
errors: make(chan error, 5),
|
||||||
|
}
|
||||||
|
|
||||||
|
// start event listener
|
||||||
|
|
||||||
|
var err error
|
||||||
|
device.index, err = getIFIndex(device.name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go device.RoutineNetlinkListener()
|
||||||
|
|
||||||
|
// set default MTU
|
||||||
|
|
||||||
|
return device, device.setMTU(DefaultMTU)
|
||||||
|
}
|
||||||
|
|
||||||
func CreateTUN(name string) (TUNDevice, error) {
|
func CreateTUN(name string) (TUNDevice, error) {
|
||||||
|
|
||||||
// open clone device
|
// open clone device
|
||||||
|
83
src/uapi.go
83
src/uapi.go
@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
send("private_key=" + device.privateKey.ToHex())
|
send("private_key=" + device.privateKey.ToHex())
|
||||||
}
|
}
|
||||||
|
|
||||||
if device.net.addr != nil {
|
if device.net.port != 0 {
|
||||||
send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
|
send(fmt.Sprintf("listen_port=%d", device.net.port))
|
||||||
}
|
}
|
||||||
|
|
||||||
if device.net.fwmark != 0 {
|
if device.net.fwmark != 0 {
|
||||||
send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
|
send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
|
||||||
}
|
}
|
||||||
@ -53,7 +54,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
send("public_key=" + peer.handshake.remoteStatic.ToHex())
|
send("public_key=" + peer.handshake.remoteStatic.ToHex())
|
||||||
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
|
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
|
||||||
if peer.endpoint != nil {
|
if peer.endpoint != nil {
|
||||||
send("endpoint=" + peer.endpoint.String())
|
send("endpoint=" + peer.endpoint.DstToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
|
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
|
||||||
@ -134,56 +135,38 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
case "listen_port":
|
case "listen_port":
|
||||||
port, err := strconv.ParseUint(value, 10, 16)
|
port, err := strconv.ParseUint(value, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set listen_port:", err)
|
logError.Println("Failed to parse listen_port:", err)
|
||||||
return &IPCError{Code: ipcErrorInvalid}
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
}
|
}
|
||||||
|
device.net.port = uint16(port)
|
||||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
|
if err := UpdateUDPListener(device); err != nil {
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to set listen_port:", err)
|
|
||||||
return &IPCError{Code: ipcErrorInvalid}
|
|
||||||
}
|
|
||||||
|
|
||||||
device.net.mutex.Lock()
|
|
||||||
device.net.addr = addr
|
|
||||||
device.net.mutex.Unlock()
|
|
||||||
|
|
||||||
err = updateUDPConn(device)
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to set listen_port:", err)
|
logError.Println("Failed to set listen_port:", err)
|
||||||
return &IPCError{Code: ipcErrorPortInUse}
|
return &IPCError{Code: ipcErrorPortInUse}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Clear source address of all peers
|
|
||||||
|
|
||||||
case "fwmark":
|
case "fwmark":
|
||||||
fwmark, err := strconv.ParseUint(value, 10, 32)
|
|
||||||
|
// parse fwmark field
|
||||||
|
|
||||||
|
fwmark, err := func() (uint32, error) {
|
||||||
|
if value == "" {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
mark, err := strconv.ParseUint(value, 10, 32)
|
||||||
|
return uint32(mark), err
|
||||||
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Invalid fwmark", err)
|
logError.Println("Invalid fwmark", err)
|
||||||
return &IPCError{Code: ipcErrorInvalid}
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
}
|
}
|
||||||
|
|
||||||
device.net.mutex.Lock()
|
device.net.mutex.Lock()
|
||||||
if fwmark > 0 || device.net.fwmark > 0 {
|
device.net.fwmark = uint32(fwmark)
|
||||||
device.net.fwmark = uint32(fwmark)
|
|
||||||
err := setMark(
|
|
||||||
device.net.conn,
|
|
||||||
device.net.fwmark,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to set fwmark:", err)
|
|
||||||
device.net.mutex.Unlock()
|
|
||||||
return &IPCError{Code: ipcErrorIO}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Clear source address of all peers
|
|
||||||
}
|
|
||||||
device.net.mutex.Unlock()
|
device.net.mutex.Unlock()
|
||||||
|
|
||||||
case "public_key":
|
case "public_key":
|
||||||
|
|
||||||
// switch to peer configuration
|
// switch to peer configuration
|
||||||
|
|
||||||
deviceConfig = false
|
deviceConfig = false
|
||||||
|
|
||||||
case "replace_peers":
|
case "replace_peers":
|
||||||
@ -218,7 +201,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
device.mutex.RLock()
|
device.mutex.RLock()
|
||||||
if device.publicKey.Equals(pubKey) {
|
if device.publicKey.Equals(pubKey) {
|
||||||
|
|
||||||
// create dummy instance
|
// create dummy instance (not added to device)
|
||||||
|
|
||||||
peer = &Peer{}
|
peer = &Peer{}
|
||||||
dummy = true
|
dummy = true
|
||||||
@ -244,6 +227,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case "remove":
|
case "remove":
|
||||||
|
|
||||||
|
// remove currently selected peer from device
|
||||||
|
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
logError.Println("Failed to set remove, invalid value:", value)
|
logError.Println("Failed to set remove, invalid value:", value)
|
||||||
return &IPCError{Code: ipcErrorInvalid}
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
@ -256,6 +242,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
dummy = true
|
dummy = true
|
||||||
|
|
||||||
case "preshared_key":
|
case "preshared_key":
|
||||||
|
|
||||||
|
// update PSK
|
||||||
|
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
err := peer.handshake.presharedKey.FromHex(value)
|
err := peer.handshake.presharedKey.FromHex(value)
|
||||||
peer.mutex.Unlock()
|
peer.mutex.Unlock()
|
||||||
@ -265,15 +254,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case "endpoint":
|
case "endpoint":
|
||||||
addr, err := parseEndpoint(value)
|
|
||||||
|
// set endpoint destination
|
||||||
|
|
||||||
|
err := func() error {
|
||||||
|
peer.mutex.Lock()
|
||||||
|
defer peer.mutex.Unlock()
|
||||||
|
endpoint, err := CreateEndpoint(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
peer.endpoint = endpoint
|
||||||
|
signalSend(peer.signal.handshakeReset)
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set endpoint:", value)
|
logError.Println("Failed to set endpoint:", value)
|
||||||
return &IPCError{Code: ipcErrorInvalid}
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
}
|
}
|
||||||
peer.mutex.Lock()
|
|
||||||
peer.endpoint = addr
|
|
||||||
peer.mutex.Unlock()
|
|
||||||
signalSend(peer.signal.handshakeReset)
|
|
||||||
|
|
||||||
case "persistent_keepalive_interval":
|
case "persistent_keepalive_interval":
|
||||||
|
|
||||||
|
@ -10,12 +10,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ipcErrorIO = -int64(unix.EIO)
|
ipcErrorIO = -int64(unix.EIO)
|
||||||
ipcErrorProtocol = -int64(unix.EPROTO)
|
ipcErrorProtocol = -int64(unix.EPROTO)
|
||||||
ipcErrorInvalid = -int64(unix.EINVAL)
|
ipcErrorInvalid = -int64(unix.EINVAL)
|
||||||
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
|
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
|
||||||
socketDirectory = "/var/run/wireguard"
|
socketDirectory = "/var/run/wireguard"
|
||||||
socketName = "%s.sock"
|
socketName = "%s.sock"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UAPIListener struct {
|
type UAPIListener struct {
|
||||||
@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func connectUnixSocket(path string) (net.Listener, error) {
|
func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||||
|
|
||||||
// attempt inital connection
|
// wrap file in listener
|
||||||
|
|
||||||
listener, err := net.Listen("unix", path)
|
listener, err := net.FileListener(file)
|
||||||
if err == nil {
|
|
||||||
return listener, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if active
|
|
||||||
|
|
||||||
_, err = net.Dial("unix", path)
|
|
||||||
if err == nil {
|
|
||||||
return nil, errors.New("Unix socket in use")
|
|
||||||
}
|
|
||||||
|
|
||||||
// attempt cleanup
|
|
||||||
|
|
||||||
err = os.Remove(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return net.Listen("unix", path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUAPIListener(name string) (net.Listener, error) {
|
|
||||||
|
|
||||||
// check if path exist
|
|
||||||
|
|
||||||
err := os.MkdirAll(socketDirectory, 077)
|
|
||||||
if err != nil && !os.IsExist(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open UNIX socket
|
|
||||||
|
|
||||||
socketPath := path.Join(
|
|
||||||
socketDirectory,
|
|
||||||
fmt.Sprintf(socketName, name),
|
|
||||||
)
|
|
||||||
|
|
||||||
listener, err := connectUnixSocket(socketPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
|
|||||||
|
|
||||||
// watch for deletion of socket
|
// watch for deletion of socket
|
||||||
|
|
||||||
|
socketPath := path.Join(
|
||||||
|
socketDirectory,
|
||||||
|
fmt.Sprintf(socketName, name),
|
||||||
|
)
|
||||||
|
|
||||||
uapi.inotifyFd, err = unix.InotifyInit()
|
uapi.inotifyFd, err = unix.InotifyInit()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
|
|||||||
go func(l *UAPIListener) {
|
go func(l *UAPIListener) {
|
||||||
var buff [4096]byte
|
var buff [4096]byte
|
||||||
for {
|
for {
|
||||||
unix.Read(uapi.inotifyFd, buff[:])
|
// start with lstat to avoid race condition
|
||||||
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
|
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
|
||||||
l.connErr <- err
|
l.connErr <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
unix.Read(uapi.inotifyFd, buff[:])
|
||||||
}
|
}
|
||||||
}(uapi)
|
}(uapi)
|
||||||
|
|
||||||
@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
|
|||||||
|
|
||||||
return uapi, nil
|
return uapi, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UAPIOpen(name string) (*os.File, error) {
|
||||||
|
|
||||||
|
// check if path exist
|
||||||
|
|
||||||
|
err := os.MkdirAll(socketDirectory, 0600)
|
||||||
|
if err != nil && !os.IsExist(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// open UNIX socket
|
||||||
|
|
||||||
|
socketPath := path.Join(
|
||||||
|
socketDirectory,
|
||||||
|
fmt.Sprintf(socketName, name),
|
||||||
|
)
|
||||||
|
|
||||||
|
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := func() (*net.UnixListener, error) {
|
||||||
|
|
||||||
|
// initial connection attempt
|
||||||
|
|
||||||
|
listener, err := net.ListenUnix("unix", addr)
|
||||||
|
if err == nil {
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if socket already active
|
||||||
|
|
||||||
|
_, err = net.Dial("unix", socketPath)
|
||||||
|
if err == nil {
|
||||||
|
return nil, errors.New("unix socket in use")
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup & attempt again
|
||||||
|
|
||||||
|
err = os.Remove(socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return net.ListenUnix("unix", addr)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return listener.File()
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user