Merge branch 'source-caching'
This commit is contained in:
		
						commit
						b5ae42349c
					
				
							
								
								
									
										115
									
								
								src/conn.go
									
									
									
									
									
								
							
							
						
						
									
										115
									
								
								src/conn.go
									
									
									
									
									
								
							@ -2,10 +2,35 @@ package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"golang.org/x/net/ipv4"
 | 
			
		||||
	"golang.org/x/net/ipv6"
 | 
			
		||||
	"net"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
 | 
			
		||||
 */
 | 
			
		||||
type Bind interface {
 | 
			
		||||
	SetMark(value uint32) error
 | 
			
		||||
	ReceiveIPv6(buff []byte) (int, Endpoint, error)
 | 
			
		||||
	ReceiveIPv4(buff []byte) (int, Endpoint, error)
 | 
			
		||||
	Send(buff []byte, end Endpoint) error
 | 
			
		||||
	Close() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* An Endpoint maintains the source/destination caching for a peer
 | 
			
		||||
 *
 | 
			
		||||
 * dst : the remote address of a peer ("endpoint" in uapi terminology)
 | 
			
		||||
 * src : the local address from which datagrams originate going to the peer
 | 
			
		||||
 */
 | 
			
		||||
type Endpoint interface {
 | 
			
		||||
	ClearSrc()           // clears the source address
 | 
			
		||||
	SrcToString() string // returns the local source address (ip:port)
 | 
			
		||||
	DstToString() string // returns the destination address (ip:port)
 | 
			
		||||
	DstToBytes() []byte  // used for mac2 cookie calculations
 | 
			
		||||
	DstIP() net.IP
 | 
			
		||||
	SrcIP() net.IP
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
 | 
			
		||||
 | 
			
		||||
	// ensure that the host is an IP address
 | 
			
		||||
@ -27,63 +52,83 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
 | 
			
		||||
	return addr, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateUDPConn(device *Device) error {
 | 
			
		||||
/* Must hold device and net lock
 | 
			
		||||
 */
 | 
			
		||||
func unsafeCloseUDPListener(device *Device) error {
 | 
			
		||||
	var err error
 | 
			
		||||
	netc := &device.net
 | 
			
		||||
	if netc.bind != nil {
 | 
			
		||||
		err = netc.bind.Close()
 | 
			
		||||
		netc.bind = nil
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// must inform all listeners
 | 
			
		||||
func UpdateUDPListener(device *Device) error {
 | 
			
		||||
	device.mutex.Lock()
 | 
			
		||||
	defer device.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	netc := &device.net
 | 
			
		||||
	netc.mutex.Lock()
 | 
			
		||||
	defer netc.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	// close existing connection
 | 
			
		||||
	// close existing sockets
 | 
			
		||||
 | 
			
		||||
	if netc.conn != nil {
 | 
			
		||||
		netc.conn.Close()
 | 
			
		||||
		netc.conn = nil
 | 
			
		||||
 | 
			
		||||
		// We need for that fd to be closed in all other go routines, which
 | 
			
		||||
		// means we have to wait. TODO: find less horrible way of doing this.
 | 
			
		||||
		time.Sleep(time.Second / 2)
 | 
			
		||||
	if err := unsafeCloseUDPListener(device); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// open new connection
 | 
			
		||||
	// assumption: netc.update WaitGroup should be exactly 1
 | 
			
		||||
 | 
			
		||||
	// open new sockets
 | 
			
		||||
 | 
			
		||||
	if device.tun.isUp.Get() {
 | 
			
		||||
 | 
			
		||||
		// listen on new address
 | 
			
		||||
		device.log.Debug.Println("UDP bind updating")
 | 
			
		||||
 | 
			
		||||
		conn, err := net.ListenUDP("udp", netc.addr)
 | 
			
		||||
		// bind to new port
 | 
			
		||||
 | 
			
		||||
		var err error
 | 
			
		||||
		netc.bind, netc.port, err = CreateBind(netc.port)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			netc.bind = nil
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// set mark
 | 
			
		||||
 | 
			
		||||
		err = netc.bind.SetMark(netc.fwmark)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// set fwmark
 | 
			
		||||
		// clear cached source addresses
 | 
			
		||||
 | 
			
		||||
		err = setMark(netc.conn, netc.fwmark)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		for _, peer := range device.peers {
 | 
			
		||||
			peer.mutex.Lock()
 | 
			
		||||
			if peer.endpoint != nil {
 | 
			
		||||
				peer.endpoint.ClearSrc()
 | 
			
		||||
			}
 | 
			
		||||
			peer.mutex.Unlock()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// retrieve port (may have been chosen by kernel)
 | 
			
		||||
		// decrease waitgroup to 0
 | 
			
		||||
 | 
			
		||||
		addr := conn.LocalAddr()
 | 
			
		||||
		netc.conn = conn
 | 
			
		||||
		netc.addr, _ = net.ResolveUDPAddr(
 | 
			
		||||
			addr.Network(),
 | 
			
		||||
			addr.String(),
 | 
			
		||||
		)
 | 
			
		||||
		go device.RoutineReceiveIncomming(ipv4.Version, netc.bind)
 | 
			
		||||
		go device.RoutineReceiveIncomming(ipv6.Version, netc.bind)
 | 
			
		||||
 | 
			
		||||
		// notify goroutines
 | 
			
		||||
 | 
			
		||||
		signalSend(device.signal.newUDPConn)
 | 
			
		||||
		device.log.Debug.Println("UDP bind has been updated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func closeUDPConn(device *Device) {
 | 
			
		||||
	netc := &device.net
 | 
			
		||||
	netc.mutex.Lock()
 | 
			
		||||
	if netc.conn != nil {
 | 
			
		||||
		netc.conn.Close()
 | 
			
		||||
	}
 | 
			
		||||
	netc.mutex.Unlock()
 | 
			
		||||
	signalSend(device.signal.newUDPConn)
 | 
			
		||||
func CloseUDPListener(device *Device) error {
 | 
			
		||||
	device.mutex.Lock()
 | 
			
		||||
	device.net.mutex.Lock()
 | 
			
		||||
	err := unsafeCloseUDPListener(device)
 | 
			
		||||
	device.net.mutex.Unlock()
 | 
			
		||||
	device.mutex.Unlock()
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,126 @@ import (
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func setMark(conn *net.UDPConn, value uint32) error {
 | 
			
		||||
/* This code is meant to be a temporary solution
 | 
			
		||||
 * on platforms for which the sticky socket / source caching behavior
 | 
			
		||||
 * has not yet been implemented.
 | 
			
		||||
 *
 | 
			
		||||
 * See conn_linux.go for an implementation on the linux platform.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type NativeBind struct {
 | 
			
		||||
	ipv4 *net.UDPConn
 | 
			
		||||
	ipv6 *net.UDPConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NativeEndpoint net.UDPAddr
 | 
			
		||||
 | 
			
		||||
var _ Bind = (*NativeBind)(nil)
 | 
			
		||||
var _ Endpoint = (*NativeEndpoint)(nil)
 | 
			
		||||
 | 
			
		||||
func CreateEndpoint(s string) (Endpoint, error) {
 | 
			
		||||
	addr, err := parseEndpoint(s)
 | 
			
		||||
	return (*NativeEndpoint)(addr), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (_ *NativeEndpoint) ClearSrc() {}
 | 
			
		||||
 | 
			
		||||
func (e *NativeEndpoint) DstIP() net.IP {
 | 
			
		||||
	return (*net.UDPAddr)(e).IP
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *NativeEndpoint) SrcIP() net.IP {
 | 
			
		||||
	return nil // not supported
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *NativeEndpoint) DstToBytes() []byte {
 | 
			
		||||
	addr := (*net.UDPAddr)(e)
 | 
			
		||||
	out := addr.IP
 | 
			
		||||
	out = append(out, byte(addr.Port&0xff))
 | 
			
		||||
	out = append(out, byte((addr.Port>>8)&0xff))
 | 
			
		||||
	return out
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *NativeEndpoint) DstToString() string {
 | 
			
		||||
	return (*net.UDPAddr)(e).String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *NativeEndpoint) SrcToString() string {
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
 | 
			
		||||
 | 
			
		||||
	// listen
 | 
			
		||||
 | 
			
		||||
	conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// retrieve port
 | 
			
		||||
 | 
			
		||||
	laddr := conn.LocalAddr()
 | 
			
		||||
	uaddr, err := net.ResolveUDPAddr(
 | 
			
		||||
		laddr.Network(),
 | 
			
		||||
		laddr.String(),
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, 0, err
 | 
			
		||||
	}
 | 
			
		||||
	return conn, uaddr.Port, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateBind(uport uint16) (Bind, uint16, error) {
 | 
			
		||||
	var err error
 | 
			
		||||
	var bind NativeBind
 | 
			
		||||
 | 
			
		||||
	port := int(uport)
 | 
			
		||||
 | 
			
		||||
	bind.ipv4, port, err = listenNet("udp4", port)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	bind.ipv6, port, err = listenNet("udp6", port)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		bind.ipv4.Close()
 | 
			
		||||
		return nil, 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &bind, uint16(port), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind *NativeBind) Close() error {
 | 
			
		||||
	err1 := bind.ipv4.Close()
 | 
			
		||||
	err2 := bind.ipv6.Close()
 | 
			
		||||
	if err1 != nil {
 | 
			
		||||
		return err1
 | 
			
		||||
	}
 | 
			
		||||
	return err2
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
 | 
			
		||||
	n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
 | 
			
		||||
	return n, (*NativeEndpoint)(endpoint), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
 | 
			
		||||
	n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
 | 
			
		||||
	return n, (*NativeEndpoint)(endpoint), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
 | 
			
		||||
	var err error
 | 
			
		||||
	nend := endpoint.(*NativeEndpoint)
 | 
			
		||||
	if nend.IP.To16() != nil {
 | 
			
		||||
		_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
 | 
			
		||||
	} else {
 | 
			
		||||
		_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind *NativeBind) SetMark(_ uint32) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -7,6 +7,7 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"golang.org/x/sys/unix"
 | 
			
		||||
	"net"
 | 
			
		||||
@ -15,20 +16,230 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Supports source address caching
 | 
			
		||||
 *
 | 
			
		||||
 * It is important that the endpoint is only updated after the packet content has been authenticated.
 | 
			
		||||
 *
 | 
			
		||||
 * Currently there is no way to achieve this within the net package:
 | 
			
		||||
 * See e.g. https://github.com/golang/go/issues/17930
 | 
			
		||||
 * So this code is remains platform dependent.
 | 
			
		||||
 */
 | 
			
		||||
type Endpoint struct {
 | 
			
		||||
	// source (selected based on dst type)
 | 
			
		||||
	// (could use RawSockaddrAny and unsafe)
 | 
			
		||||
	srcIPv6 unix.RawSockaddrInet6
 | 
			
		||||
	srcIPv4 unix.RawSockaddrInet4
 | 
			
		||||
	srcIf4  int32
 | 
			
		||||
type NativeEndpoint struct {
 | 
			
		||||
	src unix.RawSockaddrInet6
 | 
			
		||||
	dst unix.RawSockaddrInet6
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
	dst unix.RawSockaddrAny
 | 
			
		||||
type NativeBind struct {
 | 
			
		||||
	sock4 int
 | 
			
		||||
	sock6 int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Endpoint = (*NativeEndpoint)(nil)
 | 
			
		||||
var _ Bind = NativeBind{}
 | 
			
		||||
 | 
			
		||||
type IPv4Source struct {
 | 
			
		||||
	src     unix.RawSockaddrInet4
 | 
			
		||||
	Ifindex int32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func htons(val uint16) uint16 {
 | 
			
		||||
	var out [unsafe.Sizeof(val)]byte
 | 
			
		||||
	binary.BigEndian.PutUint16(out[:], val)
 | 
			
		||||
	return *((*uint16)(unsafe.Pointer(&out[0])))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ntohs(val uint16) uint16 {
 | 
			
		||||
	tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
 | 
			
		||||
	return binary.BigEndian.Uint16((*tmp)[:])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateEndpoint(s string) (Endpoint, error) {
 | 
			
		||||
	var end NativeEndpoint
 | 
			
		||||
	addr, err := parseEndpoint(s)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ipv4 := addr.IP.To4()
 | 
			
		||||
	if ipv4 != nil {
 | 
			
		||||
		dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
 | 
			
		||||
		dst.Family = unix.AF_INET
 | 
			
		||||
		dst.Port = htons(uint16(addr.Port))
 | 
			
		||||
		dst.Zero = [8]byte{}
 | 
			
		||||
		copy(dst.Addr[:], ipv4)
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
		return &end, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ipv6 := addr.IP.To16()
 | 
			
		||||
	if ipv6 != nil {
 | 
			
		||||
		zone, err := zoneToUint32(addr.Zone)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		dst := &end.dst
 | 
			
		||||
		dst.Family = unix.AF_INET6
 | 
			
		||||
		dst.Port = htons(uint16(addr.Port))
 | 
			
		||||
		dst.Flowinfo = 0
 | 
			
		||||
		dst.Scope_id = zone
 | 
			
		||||
		copy(dst.Addr[:], ipv6[:])
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
		return &end, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, errors.New("Failed to recognize IP address format")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateBind(port uint16) (Bind, uint16, error) {
 | 
			
		||||
	var err error
 | 
			
		||||
	var bind NativeBind
 | 
			
		||||
 | 
			
		||||
	bind.sock6, port, err = create6(port)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, port, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	bind.sock4, port, err = create4(port)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		unix.Close(bind.sock6)
 | 
			
		||||
	}
 | 
			
		||||
	return bind, port, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind NativeBind) SetMark(value uint32) error {
 | 
			
		||||
	err := unix.SetsockoptInt(
 | 
			
		||||
		bind.sock6,
 | 
			
		||||
		unix.SOL_SOCKET,
 | 
			
		||||
		unix.SO_MARK,
 | 
			
		||||
		int(value),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return unix.SetsockoptInt(
 | 
			
		||||
		bind.sock4,
 | 
			
		||||
		unix.SOL_SOCKET,
 | 
			
		||||
		unix.SO_MARK,
 | 
			
		||||
		int(value),
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func closeUnblock(fd int) error {
 | 
			
		||||
	// shutdown to unblock readers
 | 
			
		||||
	unix.Shutdown(fd, unix.SHUT_RD)
 | 
			
		||||
	return unix.Close(fd)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind NativeBind) Close() error {
 | 
			
		||||
	err1 := closeUnblock(bind.sock6)
 | 
			
		||||
	err2 := closeUnblock(bind.sock4)
 | 
			
		||||
	if err1 != nil {
 | 
			
		||||
		return err1
 | 
			
		||||
	}
 | 
			
		||||
	return err2
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
 | 
			
		||||
	var end NativeEndpoint
 | 
			
		||||
	n, err := receive6(
 | 
			
		||||
		bind.sock6,
 | 
			
		||||
		buff,
 | 
			
		||||
		&end,
 | 
			
		||||
	)
 | 
			
		||||
	return n, &end, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
 | 
			
		||||
	var end NativeEndpoint
 | 
			
		||||
	n, err := receive4(
 | 
			
		||||
		bind.sock4,
 | 
			
		||||
		buff,
 | 
			
		||||
		&end,
 | 
			
		||||
	)
 | 
			
		||||
	return n, &end, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind NativeBind) Send(buff []byte, end Endpoint) error {
 | 
			
		||||
	nend := end.(*NativeEndpoint)
 | 
			
		||||
	switch nend.dst.Family {
 | 
			
		||||
	case unix.AF_INET6:
 | 
			
		||||
		return send6(bind.sock6, nend, buff)
 | 
			
		||||
	case unix.AF_INET:
 | 
			
		||||
		return send4(bind.sock4, nend, buff)
 | 
			
		||||
	default:
 | 
			
		||||
		return errors.New("Unknown address family of destination")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func sockaddrToString(addr unix.RawSockaddrInet6) string {
 | 
			
		||||
	var udpAddr net.UDPAddr
 | 
			
		||||
 | 
			
		||||
	switch addr.Family {
 | 
			
		||||
	case unix.AF_INET6:
 | 
			
		||||
		udpAddr.Port = int(ntohs(addr.Port))
 | 
			
		||||
		udpAddr.IP = addr.Addr[:]
 | 
			
		||||
		return udpAddr.String()
 | 
			
		||||
 | 
			
		||||
	case unix.AF_INET:
 | 
			
		||||
		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
 | 
			
		||||
		udpAddr.Port = int(ntohs(ptr.Port))
 | 
			
		||||
		udpAddr.IP = net.IPv4(
 | 
			
		||||
			ptr.Addr[0],
 | 
			
		||||
			ptr.Addr[1],
 | 
			
		||||
			ptr.Addr[2],
 | 
			
		||||
			ptr.Addr[3],
 | 
			
		||||
		)
 | 
			
		||||
		return udpAddr.String()
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return "<unknown address family>"
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
 | 
			
		||||
	switch addr.Family {
 | 
			
		||||
	case unix.AF_INET6:
 | 
			
		||||
		return addr.Addr[:]
 | 
			
		||||
	case unix.AF_INET:
 | 
			
		||||
		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
 | 
			
		||||
		return net.IPv4(
 | 
			
		||||
			ptr.Addr[0],
 | 
			
		||||
			ptr.Addr[1],
 | 
			
		||||
			ptr.Addr[2],
 | 
			
		||||
			ptr.Addr[3],
 | 
			
		||||
		)
 | 
			
		||||
	default:
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *NativeEndpoint) SrcIP() net.IP {
 | 
			
		||||
	return rawAddrToIP(end.src)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *NativeEndpoint) DstIP() net.IP {
 | 
			
		||||
	return rawAddrToIP(end.dst)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *NativeEndpoint) DstToBytes() []byte {
 | 
			
		||||
	ptr := unsafe.Pointer(&end.src)
 | 
			
		||||
	arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
 | 
			
		||||
	return arr[:]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *NativeEndpoint) SrcToString() string {
 | 
			
		||||
	return sockaddrToString(end.src)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *NativeEndpoint) DstToString() string {
 | 
			
		||||
	return sockaddrToString(end.dst)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *NativeEndpoint) ClearDst() {
 | 
			
		||||
	end.dst = unix.RawSockaddrInet6{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *NativeEndpoint) ClearSrc() {
 | 
			
		||||
	end.src = unix.RawSockaddrInet6{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func zoneToUint32(zone string) (uint32, error) {
 | 
			
		||||
@ -42,51 +253,116 @@ func zoneToUint32(zone string) (uint32, error) {
 | 
			
		||||
	return uint32(n), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) ClearSrc() {
 | 
			
		||||
	end.srcIf4 = 0
 | 
			
		||||
	end.srcIPv4 = unix.RawSockaddrInet4{}
 | 
			
		||||
	end.srcIPv6 = unix.RawSockaddrInet6{}
 | 
			
		||||
}
 | 
			
		||||
func create4(port uint16) (int, uint16, error) {
 | 
			
		||||
 | 
			
		||||
	// create socket
 | 
			
		||||
 | 
			
		||||
	fd, err := unix.Socket(
 | 
			
		||||
		unix.AF_INET,
 | 
			
		||||
		unix.SOCK_DGRAM,
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) Set(s string) error {
 | 
			
		||||
	addr, err := parseEndpoint(s)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return -1, 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ipv6 := addr.IP.To16()
 | 
			
		||||
	if ipv6 != nil {
 | 
			
		||||
		zone, err := zoneToUint32(addr.Zone)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
	addr := unix.SockaddrInet4{
 | 
			
		||||
		Port: int(port),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// set sockopts and bind
 | 
			
		||||
 | 
			
		||||
	if err := func() error {
 | 
			
		||||
		if err := unix.SetsockoptInt(
 | 
			
		||||
			fd,
 | 
			
		||||
			unix.SOL_SOCKET,
 | 
			
		||||
			unix.SO_REUSEADDR,
 | 
			
		||||
			1,
 | 
			
		||||
		); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
 | 
			
		||||
		ptr.Family = unix.AF_INET6
 | 
			
		||||
		ptr.Port = uint16(addr.Port)
 | 
			
		||||
		ptr.Flowinfo = 0
 | 
			
		||||
		ptr.Scope_id = zone
 | 
			
		||||
		copy(ptr.Addr[:], ipv6[:])
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
		return nil
 | 
			
		||||
 | 
			
		||||
		if err := unix.SetsockoptInt(
 | 
			
		||||
			fd,
 | 
			
		||||
			unix.IPPROTO_IP,
 | 
			
		||||
			unix.IP_PKTINFO,
 | 
			
		||||
			1,
 | 
			
		||||
		); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return unix.Bind(fd, &addr)
 | 
			
		||||
	}(); err != nil {
 | 
			
		||||
		unix.Close(fd)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ipv4 := addr.IP.To4()
 | 
			
		||||
	if ipv4 != nil {
 | 
			
		||||
		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
 | 
			
		||||
		ptr.Family = unix.AF_INET
 | 
			
		||||
		ptr.Port = uint16(addr.Port)
 | 
			
		||||
		ptr.Zero = [8]byte{}
 | 
			
		||||
		copy(ptr.Addr[:], ipv4)
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return errors.New("Failed to recognize IP address format")
 | 
			
		||||
	return fd, uint16(addr.Port), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func send6(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
	var iovec unix.Iovec
 | 
			
		||||
func create6(port uint16) (int, uint16, error) {
 | 
			
		||||
 | 
			
		||||
	// create socket
 | 
			
		||||
 | 
			
		||||
	fd, err := unix.Socket(
 | 
			
		||||
		unix.AF_INET6,
 | 
			
		||||
		unix.SOCK_DGRAM,
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return -1, 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// set sockopts and bind
 | 
			
		||||
 | 
			
		||||
	addr := unix.SockaddrInet6{
 | 
			
		||||
		Port: int(port),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := func() error {
 | 
			
		||||
 | 
			
		||||
		if err := unix.SetsockoptInt(
 | 
			
		||||
			fd,
 | 
			
		||||
			unix.SOL_SOCKET,
 | 
			
		||||
			unix.SO_REUSEADDR,
 | 
			
		||||
			1,
 | 
			
		||||
		); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err := unix.SetsockoptInt(
 | 
			
		||||
			fd,
 | 
			
		||||
			unix.IPPROTO_IPV6,
 | 
			
		||||
			unix.IPV6_RECVPKTINFO,
 | 
			
		||||
			1,
 | 
			
		||||
		); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err := unix.SetsockoptInt(
 | 
			
		||||
			fd,
 | 
			
		||||
			unix.IPPROTO_IPV6,
 | 
			
		||||
			unix.IPV6_V6ONLY,
 | 
			
		||||
			1,
 | 
			
		||||
		); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return unix.Bind(fd, &addr)
 | 
			
		||||
 | 
			
		||||
	}(); err != nil {
 | 
			
		||||
		unix.Close(fd)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fd, uint16(addr.Port), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func send6(sock int, end *NativeEndpoint, buff []byte) error {
 | 
			
		||||
 | 
			
		||||
	// construct message header
 | 
			
		||||
 | 
			
		||||
	var iovec unix.Iovec
 | 
			
		||||
	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
 | 
			
		||||
	iovec.SetLen(len(buff))
 | 
			
		||||
 | 
			
		||||
@ -97,11 +373,11 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
		unix.Cmsghdr{
 | 
			
		||||
			Level: unix.IPPROTO_IPV6,
 | 
			
		||||
			Type:  unix.IPV6_PKTINFO,
 | 
			
		||||
			Len:   unix.SizeofInet6Pktinfo,
 | 
			
		||||
			Len:   unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
 | 
			
		||||
		},
 | 
			
		||||
		unix.Inet6Pktinfo{
 | 
			
		||||
			Addr:    end.srcIPv6.Addr,
 | 
			
		||||
			Ifindex: end.srcIPv6.Scope_id,
 | 
			
		||||
			Addr:    end.src.Addr,
 | 
			
		||||
			Ifindex: end.src.Scope_id,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -119,22 +395,41 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
 | 
			
		||||
	_, _, errno := unix.Syscall(
 | 
			
		||||
		unix.SYS_SENDMSG,
 | 
			
		||||
		sock,
 | 
			
		||||
		uintptr(sock),
 | 
			
		||||
		uintptr(unsafe.Pointer(&msghdr)),
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if errno == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// clear src and retry
 | 
			
		||||
 | 
			
		||||
	if errno == unix.EINVAL {
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
		cmsg.pktinfo = unix.Inet6Pktinfo{}
 | 
			
		||||
		_, _, errno = unix.Syscall(
 | 
			
		||||
			unix.SYS_SENDMSG,
 | 
			
		||||
			uintptr(sock),
 | 
			
		||||
			uintptr(unsafe.Pointer(&msghdr)),
 | 
			
		||||
			0,
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return errno
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func send4(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
	var iovec unix.Iovec
 | 
			
		||||
func send4(sock int, end *NativeEndpoint, buff []byte) error {
 | 
			
		||||
 | 
			
		||||
	// construct message header
 | 
			
		||||
 | 
			
		||||
	var iovec unix.Iovec
 | 
			
		||||
	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
 | 
			
		||||
	iovec.SetLen(len(buff))
 | 
			
		||||
 | 
			
		||||
	src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
 | 
			
		||||
 | 
			
		||||
	cmsg := struct {
 | 
			
		||||
		cmsghdr unix.Cmsghdr
 | 
			
		||||
		pktinfo unix.Inet4Pktinfo
 | 
			
		||||
@ -142,11 +437,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
		unix.Cmsghdr{
 | 
			
		||||
			Level: unix.IPPROTO_IP,
 | 
			
		||||
			Type:  unix.IP_PKTINFO,
 | 
			
		||||
			Len:   unix.SizeofInet6Pktinfo,
 | 
			
		||||
			Len:   unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
 | 
			
		||||
		},
 | 
			
		||||
		unix.Inet4Pktinfo{
 | 
			
		||||
			Spec_dst: end.srcIPv4.Addr,
 | 
			
		||||
			Ifindex:  end.srcIf4,
 | 
			
		||||
			Spec_dst: src4.src.Addr,
 | 
			
		||||
			Ifindex:  src4.Ifindex,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -156,51 +451,44 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
		Name:    (*byte)(unsafe.Pointer(&end.dst)),
 | 
			
		||||
		Namelen: unix.SizeofSockaddrInet4,
 | 
			
		||||
		Control: (*byte)(unsafe.Pointer(&cmsg)),
 | 
			
		||||
		Flags:   0,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
 | 
			
		||||
 | 
			
		||||
	// sendmsg(sock, &msghdr, 0)
 | 
			
		||||
 | 
			
		||||
	_, _, errno := unix.Syscall(
 | 
			
		||||
		unix.SYS_SENDMSG,
 | 
			
		||||
		sock,
 | 
			
		||||
		uintptr(sock),
 | 
			
		||||
		uintptr(unsafe.Pointer(&msghdr)),
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	// clear source and try again
 | 
			
		||||
 | 
			
		||||
	if errno == unix.EINVAL {
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
		cmsg.pktinfo = unix.Inet4Pktinfo{}
 | 
			
		||||
		_, _, errno = unix.Syscall(
 | 
			
		||||
			unix.SYS_SENDMSG,
 | 
			
		||||
			uintptr(sock),
 | 
			
		||||
			uintptr(unsafe.Pointer(&msghdr)),
 | 
			
		||||
			0,
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// errno = 0 is still an error instance
 | 
			
		||||
 | 
			
		||||
	if errno == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return errno
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
 | 
			
		||||
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 | 
			
		||||
 | 
			
		||||
	// extract underlying file descriptor
 | 
			
		||||
 | 
			
		||||
	file, err := c.File()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	sock := file.Fd()
 | 
			
		||||
 | 
			
		||||
	// send depending on address family of dst
 | 
			
		||||
 | 
			
		||||
	family := *((*uint16)(unsafe.Pointer(&end.dst)))
 | 
			
		||||
	if family == unix.AF_INET {
 | 
			
		||||
		return send4(sock, end, buff)
 | 
			
		||||
	} else if family == unix.AF_INET6 {
 | 
			
		||||
		return send6(sock, end, buff)
 | 
			
		||||
	}
 | 
			
		||||
	return errors.New("Unknown address family of source")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) {
 | 
			
		||||
 | 
			
		||||
	file, err := c.File()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err, nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	// contruct message header
 | 
			
		||||
 | 
			
		||||
	var iovec unix.Iovec
 | 
			
		||||
	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
 | 
			
		||||
@ -208,60 +496,87 @@ func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAdd
 | 
			
		||||
 | 
			
		||||
	var cmsg struct {
 | 
			
		||||
		cmsghdr unix.Cmsghdr
 | 
			
		||||
		pktinfo unix.Inet6Pktinfo // big enough
 | 
			
		||||
		pktinfo unix.Inet4Pktinfo
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var msghdr unix.Msghdr
 | 
			
		||||
	msghdr.Iov = &iovec
 | 
			
		||||
	msghdr.Iovlen = 1
 | 
			
		||||
	msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
 | 
			
		||||
	msghdr.Namelen = unix.SizeofSockaddrInet4
 | 
			
		||||
	msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
 | 
			
		||||
	msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
 | 
			
		||||
 | 
			
		||||
	// recvmsg(sock, &mskhdr, 0)
 | 
			
		||||
 | 
			
		||||
	size, _, errno := unix.Syscall(
 | 
			
		||||
		unix.SYS_RECVMSG,
 | 
			
		||||
		uintptr(sock),
 | 
			
		||||
		uintptr(unsafe.Pointer(&msghdr)),
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if errno != 0 {
 | 
			
		||||
		return 0, errno
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update source cache
 | 
			
		||||
 | 
			
		||||
	if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
 | 
			
		||||
		cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
 | 
			
		||||
		cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
 | 
			
		||||
		src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
 | 
			
		||||
		src4.src.Family = unix.AF_INET
 | 
			
		||||
		src4.src.Addr = cmsg.pktinfo.Spec_dst
 | 
			
		||||
		src4.Ifindex = cmsg.pktinfo.Ifindex
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return int(size), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 | 
			
		||||
 | 
			
		||||
	// contruct message header
 | 
			
		||||
 | 
			
		||||
	var iovec unix.Iovec
 | 
			
		||||
	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
 | 
			
		||||
	iovec.SetLen(len(buff))
 | 
			
		||||
 | 
			
		||||
	var cmsg struct {
 | 
			
		||||
		cmsghdr unix.Cmsghdr
 | 
			
		||||
		pktinfo unix.Inet6Pktinfo
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var msg unix.Msghdr
 | 
			
		||||
	msg.Iov = &iovec
 | 
			
		||||
	msg.Iovlen = 1
 | 
			
		||||
	msg.Name = (*byte)(unsafe.Pointer(&end.dst))
 | 
			
		||||
	msg.Namelen = uint32(unix.SizeofSockaddrAny)
 | 
			
		||||
	msg.Namelen = uint32(unix.SizeofSockaddrInet6)
 | 
			
		||||
	msg.Control = (*byte)(unsafe.Pointer(&cmsg))
 | 
			
		||||
	msg.SetControllen(int(unsafe.Sizeof(cmsg)))
 | 
			
		||||
 | 
			
		||||
	_, _, errno := unix.Syscall(
 | 
			
		||||
	// recvmsg(sock, &mskhdr, 0)
 | 
			
		||||
 | 
			
		||||
	size, _, errno := unix.Syscall(
 | 
			
		||||
		unix.SYS_RECVMSG,
 | 
			
		||||
		file.Fd(),
 | 
			
		||||
		uintptr(sock),
 | 
			
		||||
		uintptr(unsafe.Pointer(&msg)),
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if errno != 0 {
 | 
			
		||||
		return errno, nil, nil
 | 
			
		||||
		return 0, errno
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update source cache
 | 
			
		||||
 | 
			
		||||
	if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
 | 
			
		||||
		cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
 | 
			
		||||
		cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
 | 
			
		||||
 | 
			
		||||
		end.src.Family = unix.AF_INET6
 | 
			
		||||
		end.src.Addr = cmsg.pktinfo.Addr
 | 
			
		||||
		end.src.Scope_id = cmsg.pktinfo.Ifindex
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
 | 
			
		||||
		cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
 | 
			
		||||
		cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
 | 
			
		||||
 | 
			
		||||
		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
 | 
			
		||||
		println(info)
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func setMark(conn *net.UDPConn, value uint32) error {
 | 
			
		||||
	if conn == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	file, err := conn.File()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return unix.SetsockoptInt(
 | 
			
		||||
		int(file.Fd()),
 | 
			
		||||
		unix.SOL_SOCKET,
 | 
			
		||||
		unix.SO_MARK,
 | 
			
		||||
		int(value),
 | 
			
		||||
	)
 | 
			
		||||
	return int(size), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -5,10 +5,8 @@ import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"golang.org/x/crypto/blake2s"
 | 
			
		||||
	"golang.org/x/crypto/chacha20poly1305"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unsafe"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type CookieChecker struct {
 | 
			
		||||
@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
 | 
			
		||||
	return hmac.Equal(mac1[:], msg[smac1:smac2])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
 | 
			
		||||
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
 | 
			
		||||
	st.mutex.RLock()
 | 
			
		||||
	defer st.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
 | 
			
		||||
	var cookie [blake2s.Size128]byte
 | 
			
		||||
	func() {
 | 
			
		||||
		mac, _ := blake2s.New128(st.mac2.secret[:])
 | 
			
		||||
		mac.Write(src.IP)
 | 
			
		||||
		mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
 | 
			
		||||
		mac.Write(src)
 | 
			
		||||
		mac.Sum(cookie[:0])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
 | 
			
		||||
func (st *CookieChecker) CreateReply(
 | 
			
		||||
	msg []byte,
 | 
			
		||||
	recv uint32,
 | 
			
		||||
	src *net.UDPAddr,
 | 
			
		||||
	src []byte,
 | 
			
		||||
) (*MessageCookieReply, error) {
 | 
			
		||||
 | 
			
		||||
	st.mutex.RLock()
 | 
			
		||||
@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply(
 | 
			
		||||
	var cookie [blake2s.Size128]byte
 | 
			
		||||
	func() {
 | 
			
		||||
		mac, _ := blake2s.New128(st.mac2.secret[:])
 | 
			
		||||
		mac.Write(src.IP)
 | 
			
		||||
		mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
 | 
			
		||||
		mac.Write(src)
 | 
			
		||||
		mac.Sum(cookie[:0])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,6 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -25,7 +24,7 @@ func TestCookieMAC1(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	// check mac1
 | 
			
		||||
 | 
			
		||||
	src, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4000")
 | 
			
		||||
	src := []byte{192, 168, 13, 37, 10, 10, 10}
 | 
			
		||||
 | 
			
		||||
	checkMAC1 := func(msg []byte) {
 | 
			
		||||
		generator.AddMacs(msg)
 | 
			
		||||
@ -128,12 +127,12 @@ func TestCookieMAC1(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
		msg[5] ^= 0x20
 | 
			
		||||
 | 
			
		||||
		srcBad1, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4001")
 | 
			
		||||
		srcBad1 := []byte{192, 168, 13, 37, 40, 01}
 | 
			
		||||
		if checker.CheckMAC2(msg, srcBad1) {
 | 
			
		||||
			t.Fatal("MAC2 generation/verification failed")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		srcBad2, _ := net.ResolveUDPAddr("udp", "192.168.13.38:4000")
 | 
			
		||||
		srcBad2 := []byte{192, 168, 13, 38, 40, 01}
 | 
			
		||||
		if checker.CheckMAC2(msg, srcBad2) {
 | 
			
		||||
			t.Fatal("MAC2 generation/verification failed")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -2,29 +2,25 @@ package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Daemonizes the process on linux
 | 
			
		||||
 *
 | 
			
		||||
 * This is done by spawning and releasing a copy with the --foreground flag
 | 
			
		||||
 *
 | 
			
		||||
 * TODO: Use env variable to spawn in background
 | 
			
		||||
 */
 | 
			
		||||
func Daemonize(attr *os.ProcAttr) error {
 | 
			
		||||
	// I would like to use os.Executable,
 | 
			
		||||
	// however this means dropping support for Go <1.8
 | 
			
		||||
	path, err := exec.LookPath(os.Args[0])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
func Daemonize() error {
 | 
			
		||||
	argv := []string{os.Args[0], "--foreground"}
 | 
			
		||||
	argv = append(argv, os.Args[1:]...)
 | 
			
		||||
	attr := &os.ProcAttr{
 | 
			
		||||
		Dir: ".",
 | 
			
		||||
		Env: os.Environ(),
 | 
			
		||||
		Files: []*os.File{
 | 
			
		||||
			os.Stdin,
 | 
			
		||||
			nil,
 | 
			
		||||
			nil,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	process, err := os.StartProcess(
 | 
			
		||||
		argv[0],
 | 
			
		||||
		path,
 | 
			
		||||
		argv,
 | 
			
		||||
		attr,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,6 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
@ -9,8 +8,9 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Device struct {
 | 
			
		||||
	log       *Logger // collection of loggers for levels
 | 
			
		||||
	idCounter uint    // for assigning debug ids to peers
 | 
			
		||||
	closed    AtomicBool // device is closed? (acting as guard)
 | 
			
		||||
	log       *Logger    // collection of loggers for levels
 | 
			
		||||
	idCounter uint       // for assigning debug ids to peers
 | 
			
		||||
	fwMark    uint32
 | 
			
		||||
	tun       struct {
 | 
			
		||||
		device TUNDevice
 | 
			
		||||
@ -22,9 +22,9 @@ type Device struct {
 | 
			
		||||
	}
 | 
			
		||||
	net struct {
 | 
			
		||||
		mutex  sync.RWMutex
 | 
			
		||||
		addr   *net.UDPAddr // UDP source address
 | 
			
		||||
		conn   *net.UDPConn // UDP "connection"
 | 
			
		||||
		fwmark uint32
 | 
			
		||||
		bind   Bind   // bind interface
 | 
			
		||||
		port   uint16 // listening port
 | 
			
		||||
		fwmark uint32 // mark value (0 = disabled)
 | 
			
		||||
	}
 | 
			
		||||
	mutex        sync.RWMutex
 | 
			
		||||
	privateKey   NoisePrivateKey
 | 
			
		||||
@ -37,8 +37,7 @@ type Device struct {
 | 
			
		||||
		handshake  chan QueueHandshakeElement
 | 
			
		||||
	}
 | 
			
		||||
	signal struct {
 | 
			
		||||
		stop       chan struct{} // halts all go routines
 | 
			
		||||
		newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine)
 | 
			
		||||
		stop chan struct{}
 | 
			
		||||
	}
 | 
			
		||||
	underLoadUntil atomic.Value
 | 
			
		||||
	ratelimiter    Ratelimiter
 | 
			
		||||
@ -128,21 +127,23 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
 | 
			
		||||
	device.pool.messageBuffers.Put(msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDevice(tun TUNDevice, logLevel int) *Device {
 | 
			
		||||
func NewDevice(tun TUNDevice, logger *Logger) *Device {
 | 
			
		||||
	device := new(Device)
 | 
			
		||||
 | 
			
		||||
	device.mutex.Lock()
 | 
			
		||||
	defer device.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	device.log = NewLogger(logLevel, "("+tun.Name()+") ")
 | 
			
		||||
	device.log = logger
 | 
			
		||||
	device.peers = make(map[NoisePublicKey]*Peer)
 | 
			
		||||
	device.tun.device = tun
 | 
			
		||||
 | 
			
		||||
	device.indices.Init()
 | 
			
		||||
	device.ratelimiter.Init()
 | 
			
		||||
 | 
			
		||||
	device.routingTable.Reset()
 | 
			
		||||
	device.underLoadUntil.Store(time.Time{})
 | 
			
		||||
 | 
			
		||||
	// setup pools
 | 
			
		||||
	// setup buffer pool
 | 
			
		||||
 | 
			
		||||
	device.pool.messageBuffers = sync.Pool{
 | 
			
		||||
		New: func() interface{} {
 | 
			
		||||
@ -159,7 +160,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
 | 
			
		||||
	// prepare signals
 | 
			
		||||
 | 
			
		||||
	device.signal.stop = make(chan struct{})
 | 
			
		||||
	device.signal.newUDPConn = make(chan struct{}, 1)
 | 
			
		||||
 | 
			
		||||
	// prepare net
 | 
			
		||||
 | 
			
		||||
	device.net.port = 0
 | 
			
		||||
	device.net.bind = nil
 | 
			
		||||
 | 
			
		||||
	// start workers
 | 
			
		||||
 | 
			
		||||
@ -168,12 +173,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
 | 
			
		||||
		go device.RoutineDecryption()
 | 
			
		||||
		go device.RoutineHandshake()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go device.RoutineReadFromTUN()
 | 
			
		||||
	go device.RoutineTUNEventReader()
 | 
			
		||||
	go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
 | 
			
		||||
	go device.RoutineReadFromTUN()
 | 
			
		||||
	go device.RoutineReceiveIncomming()
 | 
			
		||||
 | 
			
		||||
	return device
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -202,9 +204,13 @@ func (device *Device) RemoveAllPeers() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) Close() {
 | 
			
		||||
	if device.closed.Swap(true) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	device.log.Info.Println("Closing device")
 | 
			
		||||
	device.RemoveAllPeers()
 | 
			
		||||
	close(device.signal.stop)
 | 
			
		||||
	closeUDPConn(device)
 | 
			
		||||
	CloseUDPListener(device)
 | 
			
		||||
	device.tun.device.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"os"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -15,6 +16,10 @@ type DummyTUN struct {
 | 
			
		||||
	events  chan TUNEvent
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tun *DummyTUN) File() *os.File {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tun *DummyTUN) Name() string {
 | 
			
		||||
	return tun.name
 | 
			
		||||
}
 | 
			
		||||
@ -67,7 +72,8 @@ func randDevice(t *testing.T) *Device {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	tun, _ := CreateDummyTUN("dummy")
 | 
			
		||||
	device := NewDevice(tun, LogLevelError)
 | 
			
		||||
	logger := NewLogger(LogLevelError, "")
 | 
			
		||||
	device := NewDevice(tun, logger)
 | 
			
		||||
	device.SetPrivateKey(sk)
 | 
			
		||||
	return device
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										134
									
								
								src/main.go
									
									
									
									
									
								
							
							
						
						
									
										134
									
								
								src/main.go
									
									
									
									
									
								
							@ -2,10 +2,15 @@ package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/signal"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ENV_WG_TUN_FD  = "WG_TUN_FD"
 | 
			
		||||
	ENV_WG_UAPI_FD = "WG_UAPI_FD"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func printUsage() {
 | 
			
		||||
@ -43,28 +48,6 @@ func main() {
 | 
			
		||||
		interfaceName = os.Args[1]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// daemonize the process
 | 
			
		||||
 | 
			
		||||
	if !foreground {
 | 
			
		||||
		err := Daemonize()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Println("Failed to daemonize:", err)
 | 
			
		||||
		}
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// increase number of go workers (for Go <1.5)
 | 
			
		||||
 | 
			
		||||
	runtime.GOMAXPROCS(runtime.NumCPU())
 | 
			
		||||
 | 
			
		||||
	// open TUN device
 | 
			
		||||
 | 
			
		||||
	tun, err := CreateTUN(interfaceName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("Failed to create tun device:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// get log level (default: info)
 | 
			
		||||
 | 
			
		||||
	logLevel := func() int {
 | 
			
		||||
@ -79,25 +62,103 @@ func main() {
 | 
			
		||||
		return LogLevelInfo
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	logger := NewLogger(
 | 
			
		||||
		logLevel,
 | 
			
		||||
		fmt.Sprintf("(%s) ", interfaceName),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	logger.Debug.Println("Debug log enabled")
 | 
			
		||||
 | 
			
		||||
	// open TUN device (or use supplied fd)
 | 
			
		||||
 | 
			
		||||
	tun, err := func() (TUNDevice, error) {
 | 
			
		||||
		tunFdStr := os.Getenv(ENV_WG_TUN_FD)
 | 
			
		||||
		if tunFdStr == "" {
 | 
			
		||||
			return CreateTUN(interfaceName)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// construct tun device from supplied fd
 | 
			
		||||
 | 
			
		||||
		fd, err := strconv.ParseUint(tunFdStr, 10, 32)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		file := os.NewFile(uintptr(fd), "")
 | 
			
		||||
		return CreateTUNFromFile(interfaceName, file)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error.Println("Failed to create TUN device:", err)
 | 
			
		||||
		os.Exit(ExitSetupFailed)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// open UAPI file (or use supplied fd)
 | 
			
		||||
 | 
			
		||||
	fileUAPI, err := func() (*os.File, error) {
 | 
			
		||||
		uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
 | 
			
		||||
		if uapiFdStr == "" {
 | 
			
		||||
			return UAPIOpen(interfaceName)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// use supplied fd
 | 
			
		||||
 | 
			
		||||
		fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return os.NewFile(uintptr(fd), ""), nil
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error.Println("UAPI listen error:", err)
 | 
			
		||||
		os.Exit(ExitSetupFailed)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// daemonize the process
 | 
			
		||||
 | 
			
		||||
	if !foreground {
 | 
			
		||||
		env := os.Environ()
 | 
			
		||||
		env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
 | 
			
		||||
		env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
 | 
			
		||||
		attr := &os.ProcAttr{
 | 
			
		||||
			Files: []*os.File{
 | 
			
		||||
				nil, // stdin
 | 
			
		||||
				nil, // stdout
 | 
			
		||||
				nil, // stderr
 | 
			
		||||
				tun.File(),
 | 
			
		||||
				fileUAPI,
 | 
			
		||||
			},
 | 
			
		||||
			Dir: ".",
 | 
			
		||||
			Env: env,
 | 
			
		||||
		}
 | 
			
		||||
		err = Daemonize(attr)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error.Println("Failed to daemonize:", err)
 | 
			
		||||
			os.Exit(ExitSetupFailed)
 | 
			
		||||
		}
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// increase number of go workers (for Go <1.5)
 | 
			
		||||
 | 
			
		||||
	runtime.GOMAXPROCS(runtime.NumCPU())
 | 
			
		||||
 | 
			
		||||
	// create wireguard device
 | 
			
		||||
 | 
			
		||||
	device := NewDevice(tun, logLevel)
 | 
			
		||||
	device := NewDevice(tun, logger)
 | 
			
		||||
 | 
			
		||||
	logInfo := device.log.Info
 | 
			
		||||
	logError := device.log.Error
 | 
			
		||||
	logInfo.Println("Starting device")
 | 
			
		||||
	logger.Info.Println("Device started")
 | 
			
		||||
 | 
			
		||||
	// start configuration lister
 | 
			
		||||
 | 
			
		||||
	uapi, err := NewUAPIListener(interfaceName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logError.Fatal("UAPI listen error:", err)
 | 
			
		||||
	}
 | 
			
		||||
	// start uapi listener
 | 
			
		||||
 | 
			
		||||
	errs := make(chan error)
 | 
			
		||||
	term := make(chan os.Signal)
 | 
			
		||||
	wait := device.WaitChannel()
 | 
			
		||||
 | 
			
		||||
	uapi, err := UAPIListen(interfaceName, fileUAPI)
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			conn, err := uapi.Accept()
 | 
			
		||||
@ -109,7 +170,7 @@ func main() {
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	logInfo.Println("UAPI listener started")
 | 
			
		||||
	logger.Info.Println("UAPI listener started")
 | 
			
		||||
 | 
			
		||||
	// wait for program to terminate
 | 
			
		||||
 | 
			
		||||
@ -122,9 +183,10 @@ func main() {
 | 
			
		||||
	case <-errs:
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// clean up UAPI bind
 | 
			
		||||
	// clean up
 | 
			
		||||
 | 
			
		||||
	uapi.Close()
 | 
			
		||||
	device.Close()
 | 
			
		||||
 | 
			
		||||
	logInfo.Println("Closing")
 | 
			
		||||
	logger.Info.Println("Shutting down")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -21,6 +21,14 @@ func (a *AtomicBool) Get() bool {
 | 
			
		||||
	return atomic.LoadInt32(&a.flag) == AtomicTrue
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *AtomicBool) Swap(val bool) bool {
 | 
			
		||||
	flag := AtomicFalse
 | 
			
		||||
	if val {
 | 
			
		||||
		flag = AtomicTrue
 | 
			
		||||
	}
 | 
			
		||||
	return atomic.SwapInt32(&a.flag, flag) == AtomicTrue
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *AtomicBool) Set(val bool) {
 | 
			
		||||
	flag := AtomicFalse
 | 
			
		||||
	if val {
 | 
			
		||||
 | 
			
		||||
@ -117,8 +117,8 @@ func TestNoiseHandshake(t *testing.T) {
 | 
			
		||||
		var err error
 | 
			
		||||
		var out []byte
 | 
			
		||||
		var nonce [12]byte
 | 
			
		||||
		out = key1.send.aead.Seal(out, nonce[:], testMsg, nil)
 | 
			
		||||
		out, err = key2.receive.aead.Open(out[:0], nonce[:], out, nil)
 | 
			
		||||
		out = key1.send.Seal(out, nonce[:], testMsg, nil)
 | 
			
		||||
		out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
 | 
			
		||||
		assertNil(t, err)
 | 
			
		||||
		assertEqual(t, out, testMsg)
 | 
			
		||||
	}()
 | 
			
		||||
@ -128,8 +128,8 @@ func TestNoiseHandshake(t *testing.T) {
 | 
			
		||||
		var err error
 | 
			
		||||
		var out []byte
 | 
			
		||||
		var nonce [12]byte
 | 
			
		||||
		out = key2.send.aead.Seal(out, nonce[:], testMsg, nil)
 | 
			
		||||
		out, err = key1.receive.aead.Open(out[:0], nonce[:], out, nil)
 | 
			
		||||
		out = key2.send.Seal(out, nonce[:], testMsg, nil)
 | 
			
		||||
		out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
 | 
			
		||||
		assertNil(t, err)
 | 
			
		||||
		assertEqual(t, out, testMsg)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										29
									
								
								src/peer.go
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								src/peer.go
									
									
									
									
									
								
							@ -4,7 +4,6 @@ import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
@ -16,7 +15,7 @@ type Peer struct {
 | 
			
		||||
	keyPairs                    KeyPairs
 | 
			
		||||
	handshake                   Handshake
 | 
			
		||||
	device                      *Device
 | 
			
		||||
	endpoint                    *net.UDPAddr
 | 
			
		||||
	endpoint                    Endpoint
 | 
			
		||||
	stats                       struct {
 | 
			
		||||
		txBytes           uint64 // bytes send to peer (endpoint)
 | 
			
		||||
		rxBytes           uint64 // bytes received from peer
 | 
			
		||||
@ -106,6 +105,10 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 | 
			
		||||
	handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
 | 
			
		||||
	handshake.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	// reset endpoint
 | 
			
		||||
 | 
			
		||||
	peer.endpoint = nil
 | 
			
		||||
 | 
			
		||||
	// prepare queuing
 | 
			
		||||
 | 
			
		||||
	peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
 | 
			
		||||
@ -130,11 +133,31 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 | 
			
		||||
	return peer, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (peer *Peer) SendBuffer(buffer []byte) error {
 | 
			
		||||
	peer.device.net.mutex.RLock()
 | 
			
		||||
	defer peer.device.net.mutex.RUnlock()
 | 
			
		||||
	peer.mutex.RLock()
 | 
			
		||||
	defer peer.mutex.RUnlock()
 | 
			
		||||
	if peer.endpoint == nil {
 | 
			
		||||
		return errors.New("No known endpoint for peer")
 | 
			
		||||
	}
 | 
			
		||||
	return peer.device.net.bind.Send(buffer, peer.endpoint)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Returns a short string identification for logging
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) String() string {
 | 
			
		||||
	if peer.endpoint == nil {
 | 
			
		||||
		return fmt.Sprintf(
 | 
			
		||||
			"peer(%d unknown %s)",
 | 
			
		||||
			peer.id,
 | 
			
		||||
			base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf(
 | 
			
		||||
		"peer(%d %s %s)",
 | 
			
		||||
		peer.id,
 | 
			
		||||
		peer.endpoint.String(),
 | 
			
		||||
		peer.endpoint.DstToString(),
 | 
			
		||||
		base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										271
									
								
								src/receive.go
									
									
									
									
									
								
							
							
						
						
									
										271
									
								
								src/receive.go
									
									
									
									
									
								
							@ -13,19 +13,20 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type QueueHandshakeElement struct {
 | 
			
		||||
	msgType uint32
 | 
			
		||||
	packet  []byte
 | 
			
		||||
	buffer  *[MaxMessageSize]byte
 | 
			
		||||
	source  *net.UDPAddr
 | 
			
		||||
	msgType  uint32
 | 
			
		||||
	packet   []byte
 | 
			
		||||
	endpoint Endpoint
 | 
			
		||||
	buffer   *[MaxMessageSize]byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type QueueInboundElement struct {
 | 
			
		||||
	dropped int32
 | 
			
		||||
	mutex   sync.Mutex
 | 
			
		||||
	buffer  *[MaxMessageSize]byte
 | 
			
		||||
	packet  []byte
 | 
			
		||||
	counter uint64
 | 
			
		||||
	keyPair *KeyPair
 | 
			
		||||
	dropped  int32
 | 
			
		||||
	mutex    sync.Mutex
 | 
			
		||||
	buffer   *[MaxMessageSize]byte
 | 
			
		||||
	packet   []byte
 | 
			
		||||
	counter  uint64
 | 
			
		||||
	keyPair  *KeyPair
 | 
			
		||||
	endpoint Endpoint
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (elem *QueueInboundElement) Drop() {
 | 
			
		||||
@ -92,130 +93,122 @@ func (device *Device) addToHandshakeQueue(
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) RoutineReceiveIncomming() {
 | 
			
		||||
func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) {
 | 
			
		||||
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, receive incomming, started")
 | 
			
		||||
	logDebug.Println("Routine, receive incomming, IP version:", IP)
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
 | 
			
		||||
		// wait for new conn
 | 
			
		||||
		// receive datagrams until conn is closed
 | 
			
		||||
 | 
			
		||||
		logDebug.Println("Waiting for udp socket")
 | 
			
		||||
		buffer := device.GetMessageBuffer()
 | 
			
		||||
 | 
			
		||||
		select {
 | 
			
		||||
		case <-device.signal.stop:
 | 
			
		||||
			return
 | 
			
		||||
		var (
 | 
			
		||||
			err      error
 | 
			
		||||
			size     int
 | 
			
		||||
			endpoint Endpoint
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		case <-device.signal.newUDPConn:
 | 
			
		||||
		for {
 | 
			
		||||
 | 
			
		||||
			// fetch connection
 | 
			
		||||
			// read next datagram
 | 
			
		||||
 | 
			
		||||
			device.net.mutex.RLock()
 | 
			
		||||
			conn := device.net.conn
 | 
			
		||||
			device.net.mutex.RUnlock()
 | 
			
		||||
			if conn == nil {
 | 
			
		||||
			switch IP {
 | 
			
		||||
			case ipv4.Version:
 | 
			
		||||
				size, endpoint, err = bind.ReceiveIPv4(buffer[:])
 | 
			
		||||
			case ipv6.Version:
 | 
			
		||||
				size, endpoint, err = bind.ReceiveIPv6(buffer[:])
 | 
			
		||||
			default:
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if size < MinMessageSize {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Listening for inbound packets")
 | 
			
		||||
			// check size of packet
 | 
			
		||||
 | 
			
		||||
			// receive datagrams until conn is closed
 | 
			
		||||
			packet := buffer[:size]
 | 
			
		||||
			msgType := binary.LittleEndian.Uint32(packet[:4])
 | 
			
		||||
 | 
			
		||||
			buffer := device.GetMessageBuffer()
 | 
			
		||||
			var okay bool
 | 
			
		||||
 | 
			
		||||
			for {
 | 
			
		||||
			switch msgType {
 | 
			
		||||
 | 
			
		||||
				// read next datagram
 | 
			
		||||
			// check if transport
 | 
			
		||||
 | 
			
		||||
				size, raddr, err := conn.ReadFromUDP(buffer[:])
 | 
			
		||||
			case MessageTransportType:
 | 
			
		||||
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
				// check size
 | 
			
		||||
 | 
			
		||||
				if size < MinMessageSize {
 | 
			
		||||
				if len(packet) < MessageTransportType {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// check size of packet
 | 
			
		||||
				// lookup key pair
 | 
			
		||||
 | 
			
		||||
				packet := buffer[:size]
 | 
			
		||||
				msgType := binary.LittleEndian.Uint32(packet[:4])
 | 
			
		||||
 | 
			
		||||
				var okay bool
 | 
			
		||||
 | 
			
		||||
				switch msgType {
 | 
			
		||||
 | 
			
		||||
				// check if transport
 | 
			
		||||
 | 
			
		||||
				case MessageTransportType:
 | 
			
		||||
 | 
			
		||||
					// check size
 | 
			
		||||
 | 
			
		||||
					if len(packet) < MessageTransportType {
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// lookup key pair
 | 
			
		||||
 | 
			
		||||
					receiver := binary.LittleEndian.Uint32(
 | 
			
		||||
						packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
 | 
			
		||||
					)
 | 
			
		||||
					value := device.indices.Lookup(receiver)
 | 
			
		||||
					keyPair := value.keyPair
 | 
			
		||||
					if keyPair == nil {
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// check key-pair expiry
 | 
			
		||||
 | 
			
		||||
					if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// create work element
 | 
			
		||||
 | 
			
		||||
					peer := value.peer
 | 
			
		||||
					elem := &QueueInboundElement{
 | 
			
		||||
						packet:  packet,
 | 
			
		||||
						buffer:  buffer,
 | 
			
		||||
						keyPair: keyPair,
 | 
			
		||||
						dropped: AtomicFalse,
 | 
			
		||||
					}
 | 
			
		||||
					elem.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
					// add to decryption queues
 | 
			
		||||
 | 
			
		||||
					device.addToDecryptionQueue(device.queue.decryption, elem)
 | 
			
		||||
					device.addToInboundQueue(peer.queue.inbound, elem)
 | 
			
		||||
					buffer = device.GetMessageBuffer()
 | 
			
		||||
				receiver := binary.LittleEndian.Uint32(
 | 
			
		||||
					packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
 | 
			
		||||
				)
 | 
			
		||||
				value := device.indices.Lookup(receiver)
 | 
			
		||||
				keyPair := value.keyPair
 | 
			
		||||
				if keyPair == nil {
 | 
			
		||||
					continue
 | 
			
		||||
 | 
			
		||||
				// otherwise it is a handshake related packet
 | 
			
		||||
 | 
			
		||||
				case MessageInitiationType:
 | 
			
		||||
					okay = len(packet) == MessageInitiationSize
 | 
			
		||||
 | 
			
		||||
				case MessageResponseType:
 | 
			
		||||
					okay = len(packet) == MessageResponseSize
 | 
			
		||||
 | 
			
		||||
				case MessageCookieReplyType:
 | 
			
		||||
					okay = len(packet) == MessageCookieReplySize
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if okay {
 | 
			
		||||
					device.addToHandshakeQueue(
 | 
			
		||||
						device.queue.handshake,
 | 
			
		||||
						QueueHandshakeElement{
 | 
			
		||||
							msgType: msgType,
 | 
			
		||||
							buffer:  buffer,
 | 
			
		||||
							packet:  packet,
 | 
			
		||||
							source:  raddr,
 | 
			
		||||
						},
 | 
			
		||||
					)
 | 
			
		||||
					buffer = device.GetMessageBuffer()
 | 
			
		||||
				// check key-pair expiry
 | 
			
		||||
 | 
			
		||||
				if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// create work element
 | 
			
		||||
 | 
			
		||||
				peer := value.peer
 | 
			
		||||
				elem := &QueueInboundElement{
 | 
			
		||||
					packet:   packet,
 | 
			
		||||
					buffer:   buffer,
 | 
			
		||||
					keyPair:  keyPair,
 | 
			
		||||
					dropped:  AtomicFalse,
 | 
			
		||||
					endpoint: endpoint,
 | 
			
		||||
				}
 | 
			
		||||
				elem.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
				// add to decryption queues
 | 
			
		||||
 | 
			
		||||
				device.addToDecryptionQueue(device.queue.decryption, elem)
 | 
			
		||||
				device.addToInboundQueue(peer.queue.inbound, elem)
 | 
			
		||||
				buffer = device.GetMessageBuffer()
 | 
			
		||||
				continue
 | 
			
		||||
 | 
			
		||||
			// otherwise it is a fixed size & handshake related packet
 | 
			
		||||
 | 
			
		||||
			case MessageInitiationType:
 | 
			
		||||
				okay = len(packet) == MessageInitiationSize
 | 
			
		||||
 | 
			
		||||
			case MessageResponseType:
 | 
			
		||||
				okay = len(packet) == MessageResponseSize
 | 
			
		||||
 | 
			
		||||
			case MessageCookieReplyType:
 | 
			
		||||
				okay = len(packet) == MessageCookieReplySize
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if okay {
 | 
			
		||||
				device.addToHandshakeQueue(
 | 
			
		||||
					device.queue.handshake,
 | 
			
		||||
					QueueHandshakeElement{
 | 
			
		||||
						msgType:  msgType,
 | 
			
		||||
						buffer:   buffer,
 | 
			
		||||
						packet:   packet,
 | 
			
		||||
						endpoint: endpoint,
 | 
			
		||||
					},
 | 
			
		||||
				)
 | 
			
		||||
				buffer = device.GetMessageBuffer()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@ -293,8 +286,6 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
 | 
			
		||||
			// unmarshal packet
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Process cookie reply from:", elem.source.String())
 | 
			
		||||
 | 
			
		||||
			var reply MessageCookieReply
 | 
			
		||||
			reader := bytes.NewReader(elem.packet)
 | 
			
		||||
			err := binary.Read(reader, binary.LittleEndian, &reply)
 | 
			
		||||
@ -321,15 +312,25 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// endpoints destination address is the source of the datagram
 | 
			
		||||
 | 
			
		||||
			srcBytes := elem.endpoint.DstToBytes()
 | 
			
		||||
 | 
			
		||||
			if device.IsUnderLoad() {
 | 
			
		||||
				if !device.mac.CheckMAC2(elem.packet, elem.source) {
 | 
			
		||||
 | 
			
		||||
				// verify MAC2 field
 | 
			
		||||
 | 
			
		||||
				if !device.mac.CheckMAC2(elem.packet, srcBytes) {
 | 
			
		||||
 | 
			
		||||
					// construct cookie reply
 | 
			
		||||
 | 
			
		||||
					logDebug.Println("Sending cookie reply to:", elem.source.String())
 | 
			
		||||
					logDebug.Println(
 | 
			
		||||
						"Sending cookie reply to:",
 | 
			
		||||
						elem.endpoint.DstToString(),
 | 
			
		||||
					)
 | 
			
		||||
 | 
			
		||||
					sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
 | 
			
		||||
					reply, err := device.mac.CreateReply(elem.packet, sender, elem.source)
 | 
			
		||||
					sender := binary.LittleEndian.Uint32(elem.packet[4:8])
 | 
			
		||||
					reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logError.Println("Failed to create cookie reply:", err)
 | 
			
		||||
						return
 | 
			
		||||
@ -339,17 +340,16 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
 | 
			
		||||
					writer := bytes.NewBuffer(temp[:0])
 | 
			
		||||
					binary.Write(writer, binary.LittleEndian, reply)
 | 
			
		||||
					_, err = device.net.conn.WriteToUDP(
 | 
			
		||||
						writer.Bytes(),
 | 
			
		||||
						elem.source,
 | 
			
		||||
					)
 | 
			
		||||
					device.net.bind.Send(writer.Bytes(), elem.endpoint)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logDebug.Println("Failed to send cookie reply:", err)
 | 
			
		||||
					}
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if !device.ratelimiter.Allow(elem.source.IP) {
 | 
			
		||||
				// check ratelimiter
 | 
			
		||||
 | 
			
		||||
				if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
@ -380,8 +380,7 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
			if peer == nil {
 | 
			
		||||
				logInfo.Println(
 | 
			
		||||
					"Recieved invalid initiation message from",
 | 
			
		||||
					elem.source.IP.String(),
 | 
			
		||||
					elem.source.Port,
 | 
			
		||||
					elem.endpoint.DstToString(),
 | 
			
		||||
				)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
@ -392,10 +391,9 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
			peer.TimerAnyAuthenticatedPacketReceived()
 | 
			
		||||
 | 
			
		||||
			// update endpoint
 | 
			
		||||
			// TODO: Discover destination address also, only update on change
 | 
			
		||||
 | 
			
		||||
			peer.mutex.Lock()
 | 
			
		||||
			peer.endpoint = elem.source
 | 
			
		||||
			peer.endpoint = elem.endpoint
 | 
			
		||||
			peer.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			// create response
 | 
			
		||||
@ -418,9 +416,11 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
 | 
			
		||||
			// send response
 | 
			
		||||
 | 
			
		||||
			_, err = peer.SendBuffer(packet)
 | 
			
		||||
			err = peer.SendBuffer(packet)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				peer.TimerAnyAuthenticatedPacketTraversal()
 | 
			
		||||
			} else {
 | 
			
		||||
				logError.Println("Failed to send response to:", peer.String(), err)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		case MessageResponseType:
 | 
			
		||||
@ -441,12 +441,17 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
			if peer == nil {
 | 
			
		||||
				logInfo.Println(
 | 
			
		||||
					"Recieved invalid response message from",
 | 
			
		||||
					elem.source.IP.String(),
 | 
			
		||||
					elem.source.Port,
 | 
			
		||||
					elem.endpoint.DstToString(),
 | 
			
		||||
				)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// update endpoint
 | 
			
		||||
 | 
			
		||||
			peer.mutex.Lock()
 | 
			
		||||
			peer.endpoint = elem.endpoint
 | 
			
		||||
			peer.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Received handshake initation from", peer)
 | 
			
		||||
 | 
			
		||||
			peer.TimerEphemeralKeyCreated()
 | 
			
		||||
@ -515,6 +520,12 @@ func (peer *Peer) RoutineSequentialReceiver() {
 | 
			
		||||
			}
 | 
			
		||||
			kp.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			// update endpoint
 | 
			
		||||
 | 
			
		||||
			peer.mutex.Lock()
 | 
			
		||||
			peer.endpoint = elem.endpoint
 | 
			
		||||
			peer.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			// check for keep-alive
 | 
			
		||||
 | 
			
		||||
			if len(elem.packet) == 0 {
 | 
			
		||||
@ -546,7 +557,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
 | 
			
		||||
 | 
			
		||||
				src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
 | 
			
		||||
				if device.routingTable.LookupIPv4(src) != peer {
 | 
			
		||||
					logInfo.Println("Packet with unallowed source IP from", peer.String())
 | 
			
		||||
					logInfo.Println(
 | 
			
		||||
						"IPv4 packet with unallowed source address from",
 | 
			
		||||
						peer.String(),
 | 
			
		||||
					)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
@ -571,7 +585,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
 | 
			
		||||
 | 
			
		||||
				src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
 | 
			
		||||
				if device.routingTable.LookupIPv6(src) != peer {
 | 
			
		||||
					logInfo.Println("Packet with unallowed source IP from", peer.String())
 | 
			
		||||
					logInfo.Println(
 | 
			
		||||
						"IPv6 packet with unallowed source address from",
 | 
			
		||||
						peer.String(),
 | 
			
		||||
					)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
@ -580,7 +597,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// write to tun
 | 
			
		||||
			// write to tun device
 | 
			
		||||
 | 
			
		||||
			atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
 | 
			
		||||
			_, err := device.tun.device.Write(elem.packet)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										23
									
								
								src/send.go
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								src/send.go
									
									
									
									
									
								
							@ -2,7 +2,6 @@ package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"golang.org/x/crypto/chacha20poly1305"
 | 
			
		||||
	"golang.org/x/net/ipv4"
 | 
			
		||||
	"golang.org/x/net/ipv6"
 | 
			
		||||
@ -105,26 +104,6 @@ func addToEncryptionQueue(
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
 | 
			
		||||
	peer.device.net.mutex.RLock()
 | 
			
		||||
	defer peer.device.net.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
	peer.mutex.RLock()
 | 
			
		||||
	defer peer.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
	endpoint := peer.endpoint
 | 
			
		||||
	if endpoint == nil {
 | 
			
		||||
		return 0, errors.New("No known endpoint for peer")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	conn := peer.device.net.conn
 | 
			
		||||
	if conn == nil {
 | 
			
		||||
		return 0, errors.New("No UDP socket for device")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return conn.WriteToUDP(buffer, endpoint)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Reads packets from the TUN and inserts
 | 
			
		||||
 * into nonce queue for peer
 | 
			
		||||
 *
 | 
			
		||||
@ -343,7 +322,7 @@ func (peer *Peer) RoutineSequentialSender() {
 | 
			
		||||
			// send message and return buffer to pool
 | 
			
		||||
 | 
			
		||||
			length := uint64(len(elem.packet))
 | 
			
		||||
			_, err := peer.SendBuffer(elem.packet)
 | 
			
		||||
			err := peer.SendBuffer(elem.packet)
 | 
			
		||||
			device.PutMessageBuffer(elem.buffer)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logDebug.Println("Failed to send authenticated packet to peer", peer.String())
 | 
			
		||||
 | 
			
		||||
@ -20,6 +20,14 @@
 | 
			
		||||
# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1
 | 
			
		||||
# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further
 | 
			
		||||
# details on how this is accomplished.
 | 
			
		||||
 | 
			
		||||
# This code is ported to the WireGuard-Go directly from the kernel project.
 | 
			
		||||
#
 | 
			
		||||
# Please ensure that you have installed the newest version of the WireGuard
 | 
			
		||||
# tools from the WireGuard project and before running these tests as:
 | 
			
		||||
#
 | 
			
		||||
# ./netns.sh <path to wireguard-go>
 | 
			
		||||
 | 
			
		||||
set -e
 | 
			
		||||
 | 
			
		||||
exec 3>&1
 | 
			
		||||
@ -27,8 +35,8 @@ export WG_HIDE_KEYS=never
 | 
			
		||||
netns0="wg-test-$$-0"
 | 
			
		||||
netns1="wg-test-$$-1"
 | 
			
		||||
netns2="wg-test-$$-2"
 | 
			
		||||
program="../wireguard-go"
 | 
			
		||||
export LOG_LEVEL="error"
 | 
			
		||||
program=$1
 | 
			
		||||
export LOG_LEVEL="info"
 | 
			
		||||
 | 
			
		||||
pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
 | 
			
		||||
pp() { pretty "" "$*"; "$@"; }
 | 
			
		||||
@ -72,13 +80,11 @@ pp ip netns add $netns2
 | 
			
		||||
ip0 link set up dev lo
 | 
			
		||||
 | 
			
		||||
# ip0 link add dev wg1 type wireguard
 | 
			
		||||
n0 $program -f wg1 &
 | 
			
		||||
sleep 1
 | 
			
		||||
n0 $program wg1
 | 
			
		||||
ip0 link set wg1 netns $netns1
 | 
			
		||||
 | 
			
		||||
# ip0 link add dev wg1 type wireguard
 | 
			
		||||
n0 $program -f wg2 &
 | 
			
		||||
sleep 1
 | 
			
		||||
n0 $program wg2
 | 
			
		||||
ip0 link set wg2 netns $netns2
 | 
			
		||||
 | 
			
		||||
key1="$(pp wg genkey)"
 | 
			
		||||
@ -185,14 +191,14 @@ ip0 -4 addr del 127.0.0.1/8 dev lo
 | 
			
		||||
ip0 -4 addr add 127.212.121.99/8 dev lo
 | 
			
		||||
n0 wg set wg1 listen-port 9999
 | 
			
		||||
n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000
 | 
			
		||||
n1 ping6 -W 1 -c 1 fd00::20000
 | 
			
		||||
[[ $(n2 wg show wg2 endpoints) == "$pub1    127.212.121.99:9999" ]]
 | 
			
		||||
n1 ping6 -W 1 -c 1 fd00::2
 | 
			
		||||
[[ $(n2 wg show wg2 endpoints) == "$pub1	127.212.121.99:9999" ]]
 | 
			
		||||
 | 
			
		||||
# Test using IPv6 that roaming works
 | 
			
		||||
n1 wg set wg1 listen-port 9998
 | 
			
		||||
n1 wg set wg1 peer "$pub2" endpoint [::1]:20000
 | 
			
		||||
n1 ping -W 1 -c 1 192.168.241.2
 | 
			
		||||
[[ $(n2 wg show wg2 endpoints) == "$pub1    [::1]:9998" ]]
 | 
			
		||||
[[ $(n2 wg show wg2 endpoints) == "$pub1	[::1]:9998" ]]
 | 
			
		||||
 | 
			
		||||
# Test that crypto-RP filter works
 | 
			
		||||
n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24
 | 
			
		||||
@ -212,7 +218,7 @@ n2 ncat -u 192.168.241.1 1111 <<<"X"
 | 
			
		||||
! read -r -N 1 -t 1 out <&4
 | 
			
		||||
kill $nmap_pid
 | 
			
		||||
n0 wg set wg1 peer "$more_specific_key" remove
 | 
			
		||||
[[ $(n1 wg show wg1 endpoints) == "$pub2    [::1]:9997" ]]
 | 
			
		||||
[[ $(n1 wg show wg1 endpoints) == "$pub2	[::1]:9997" ]]
 | 
			
		||||
 | 
			
		||||
ip1 link del wg1
 | 
			
		||||
ip2 link del wg2
 | 
			
		||||
@ -263,7 +269,7 @@ n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to
 | 
			
		||||
n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1
 | 
			
		||||
n1 ping -W 1 -c 1 192.168.241.2
 | 
			
		||||
n2 ping -W 1 -c 1 192.168.241.1
 | 
			
		||||
[[ $(n2 wg show wg2 endpoints) == "$pub1    10.0.0.1:10000" ]]
 | 
			
		||||
[[ $(n2 wg show wg2 endpoints) == "$pub1	10.0.0.1:10000" ]]
 | 
			
		||||
# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`).
 | 
			
		||||
pp sleep 3
 | 
			
		||||
n2 ping -W 1 -c 1 192.168.241.1
 | 
			
		||||
@ -289,7 +295,7 @@ ip2 link del wg2
 | 
			
		||||
# ip1 link add dev wg1 type wireguard
 | 
			
		||||
# ip2 link add dev wg1 type wireguard
 | 
			
		||||
n1 $program wg1
 | 
			
		||||
n2 $program wg1
 | 
			
		||||
n2 $program wg2
 | 
			
		||||
 | 
			
		||||
configure_peers
 | 
			
		||||
 | 
			
		||||
@ -336,17 +342,83 @@ waitiface $netns1 veth1
 | 
			
		||||
waitiface $netns2 veth2
 | 
			
		||||
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000
 | 
			
		||||
n2 ping -W 1 -c 1 192.168.241.1
 | 
			
		||||
[[ $(n0 wg show wg2 endpoints) == "$pub1    10.0.0.1:10000" ]]
 | 
			
		||||
[[ $(n0 wg show wg2 endpoints) == "$pub1	10.0.0.1:10000" ]]
 | 
			
		||||
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000
 | 
			
		||||
n2 ping -W 1 -c 1 192.168.241.1
 | 
			
		||||
[[ $(n0 wg show wg2 endpoints) == "$pub1    [fd00:aa::1]:10000" ]]
 | 
			
		||||
[[ $(n0 wg show wg2 endpoints) == "$pub1	[fd00:aa::1]:10000" ]]
 | 
			
		||||
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000
 | 
			
		||||
n2 ping -W 1 -c 1 192.168.241.1
 | 
			
		||||
[[ $(n0 wg show wg2 endpoints) == "$pub1    10.0.0.2:10000" ]]
 | 
			
		||||
[[ $(n0 wg show wg2 endpoints) == "$pub1	10.0.0.2:10000" ]]
 | 
			
		||||
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000
 | 
			
		||||
n2 ping -W 1 -c 1 192.168.241.1
 | 
			
		||||
[[ $(n0 wg show wg2 endpoints) == "$pub1    [fd00:aa::2]:10000" ]]
 | 
			
		||||
[[ $(n0 wg show wg2 endpoints) == "$pub1	[fd00:aa::2]:10000" ]]
 | 
			
		||||
 | 
			
		||||
ip1 link del veth1
 | 
			
		||||
ip1 link del wg1
 | 
			
		||||
ip2 link del wg2
 | 
			
		||||
 | 
			
		||||
# Test that Netlink/IPC is working properly by doing things that usually cause split responses
 | 
			
		||||
 | 
			
		||||
n0 $program wg0
 | 
			
		||||
sleep 5
 | 
			
		||||
config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" )
 | 
			
		||||
for a in {1..255}; do
 | 
			
		||||
    for b in {0..255}; do
 | 
			
		||||
        config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" )
 | 
			
		||||
    done
 | 
			
		||||
done
 | 
			
		||||
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
 | 
			
		||||
i=0
 | 
			
		||||
for ip in $(n0 wg show wg0 allowed-ips); do
 | 
			
		||||
    ((++i))
 | 
			
		||||
done
 | 
			
		||||
((i == 255*256*2+1))
 | 
			
		||||
ip0 link del wg0
 | 
			
		||||
 | 
			
		||||
n0 $program wg0
 | 
			
		||||
config=( "[Interface]" "PrivateKey=$(wg genkey)" )
 | 
			
		||||
for a in {1..40}; do
 | 
			
		||||
    config+=( "[Peer]" "PublicKey=$(wg genkey)" )
 | 
			
		||||
    for b in {1..52}; do
 | 
			
		||||
        config+=( "AllowedIPs=$a.$b.0.0/16" )
 | 
			
		||||
    done
 | 
			
		||||
done
 | 
			
		||||
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
 | 
			
		||||
i=0
 | 
			
		||||
while read -r line; do
 | 
			
		||||
    j=0
 | 
			
		||||
    for ip in $line; do
 | 
			
		||||
        ((++j))
 | 
			
		||||
    done
 | 
			
		||||
    ((j == 53))
 | 
			
		||||
    ((++i))
 | 
			
		||||
done < <(n0 wg show wg0 allowed-ips)
 | 
			
		||||
((i == 40))
 | 
			
		||||
ip0 link del wg0
 | 
			
		||||
 | 
			
		||||
n0 $program wg0
 | 
			
		||||
config=( )
 | 
			
		||||
for i in {1..29}; do
 | 
			
		||||
    config+=( "[Peer]" "PublicKey=$(wg genkey)" )
 | 
			
		||||
done
 | 
			
		||||
config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" )
 | 
			
		||||
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
 | 
			
		||||
n0 wg showconf wg0 > /dev/null
 | 
			
		||||
ip0 link del wg0
 | 
			
		||||
 | 
			
		||||
! n0 wg show doesnotexist || false
 | 
			
		||||
 | 
			
		||||
declare -A objects
 | 
			
		||||
while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
 | 
			
		||||
    [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
 | 
			
		||||
    objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
 | 
			
		||||
done < /dev/kmsg
 | 
			
		||||
alldeleted=1
 | 
			
		||||
for object in "${!objects[@]}"; do
 | 
			
		||||
    if [[ ${objects["$object"]} != *createddestroyed ]]; then
 | 
			
		||||
        echo "Error: $object: merely ${objects["$object"]}" >&3
 | 
			
		||||
        alldeleted=0
 | 
			
		||||
    fi
 | 
			
		||||
done
 | 
			
		||||
[[ $alldeleted -eq 1 ]]
 | 
			
		||||
pretty "" "Objects that were created were also destroyed."
 | 
			
		||||
 | 
			
		||||
@ -279,34 +279,31 @@ func (peer *Peer) RoutineHandshakeInitiator() {
 | 
			
		||||
				break AttemptHandshakes
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
 | 
			
		||||
 | 
			
		||||
			// marshal and send
 | 
			
		||||
			// marshal handshake message
 | 
			
		||||
 | 
			
		||||
			writer := bytes.NewBuffer(temp[:0])
 | 
			
		||||
			binary.Write(writer, binary.LittleEndian, msg)
 | 
			
		||||
			packet := writer.Bytes()
 | 
			
		||||
			peer.mac.AddMacs(packet)
 | 
			
		||||
 | 
			
		||||
			_, err = peer.SendBuffer(packet)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
			// send to endpoint
 | 
			
		||||
 | 
			
		||||
			err = peer.SendBuffer(packet)
 | 
			
		||||
			jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
 | 
			
		||||
			timeout := time.NewTimer(RekeyTimeout + jitter)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				peer.TimerAnyAuthenticatedPacketTraversal()
 | 
			
		||||
				logDebug.Println(
 | 
			
		||||
					"Handshake initiation attempt",
 | 
			
		||||
					attempts, "sent to", peer.String(),
 | 
			
		||||
				)
 | 
			
		||||
			} else {
 | 
			
		||||
				logError.Println(
 | 
			
		||||
					"Failed to send handshake initiation message to",
 | 
			
		||||
					peer.String(), ":", err,
 | 
			
		||||
				)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			peer.TimerAnyAuthenticatedPacketTraversal()
 | 
			
		||||
 | 
			
		||||
			// set handshake timeout
 | 
			
		||||
 | 
			
		||||
			timeout := time.NewTimer(RekeyTimeout + jitter)
 | 
			
		||||
			logDebug.Println(
 | 
			
		||||
				"Handshake initiation attempt",
 | 
			
		||||
				attempts, "sent to", peer.String(),
 | 
			
		||||
			)
 | 
			
		||||
 | 
			
		||||
			// wait for handshake or timeout
 | 
			
		||||
 | 
			
		||||
			select {
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"os"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -15,6 +16,7 @@ const (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TUNDevice interface {
 | 
			
		||||
	File() *os.File            // returns the file descriptor of the device
 | 
			
		||||
	Read([]byte) (int, error)  // read a packet from the device (without any additional headers)
 | 
			
		||||
	Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
 | 
			
		||||
	MTU() (int, error)         // returns the MTU of the device
 | 
			
		||||
@ -47,7 +49,7 @@ func (device *Device) RoutineTUNEventReader() {
 | 
			
		||||
			if !device.tun.isUp.Get() {
 | 
			
		||||
				logInfo.Println("Interface set up")
 | 
			
		||||
				device.tun.isUp.Set(true)
 | 
			
		||||
				updateUDPConn(device)
 | 
			
		||||
				UpdateUDPListener(device)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -55,7 +57,7 @@ func (device *Device) RoutineTUNEventReader() {
 | 
			
		||||
			if device.tun.isUp.Get() {
 | 
			
		||||
				logInfo.Println("Interface set down")
 | 
			
		||||
				device.tun.isUp.Set(false)
 | 
			
		||||
				closeUDPConn(device)
 | 
			
		||||
				CloseUDPListener(device)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -56,6 +56,10 @@ type NativeTun struct {
 | 
			
		||||
	events chan TUNEvent // device related events
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tun *NativeTun) File() *os.File {
 | 
			
		||||
	return tun.fd
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tun *NativeTun) RoutineNetlinkListener() {
 | 
			
		||||
	sock := int(C.bind_rtmgrp())
 | 
			
		||||
	if sock < 0 {
 | 
			
		||||
@ -222,7 +226,7 @@ func (tun *NativeTun) MTU() (int, error) {
 | 
			
		||||
 | 
			
		||||
	val := binary.LittleEndian.Uint32(ifr[16:20])
 | 
			
		||||
	if val >= (1 << 31) {
 | 
			
		||||
		return int(val-(1<<31)) - (1 << 31), nil
 | 
			
		||||
		return int(toInt32(val)), nil
 | 
			
		||||
	}
 | 
			
		||||
	return int(val), nil
 | 
			
		||||
}
 | 
			
		||||
@ -248,6 +252,29 @@ func (tun *NativeTun) Close() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) {
 | 
			
		||||
	device := &NativeTun{
 | 
			
		||||
		fd:     fd,
 | 
			
		||||
		name:   name,
 | 
			
		||||
		events: make(chan TUNEvent, 5),
 | 
			
		||||
		errors: make(chan error, 5),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// start event listener
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	device.index, err = getIFIndex(device.name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go device.RoutineNetlinkListener()
 | 
			
		||||
 | 
			
		||||
	// set default MTU
 | 
			
		||||
 | 
			
		||||
	return device, device.setMTU(DefaultMTU)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateTUN(name string) (TUNDevice, error) {
 | 
			
		||||
 | 
			
		||||
	// open clone device
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										83
									
								
								src/uapi.go
									
									
									
									
									
								
							
							
						
						
									
										83
									
								
								src/uapi.go
									
									
									
									
									
								
							@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
		send("private_key=" + device.privateKey.ToHex())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if device.net.addr != nil {
 | 
			
		||||
		send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
 | 
			
		||||
	if device.net.port != 0 {
 | 
			
		||||
		send(fmt.Sprintf("listen_port=%d", device.net.port))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if device.net.fwmark != 0 {
 | 
			
		||||
		send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
 | 
			
		||||
	}
 | 
			
		||||
@ -53,7 +54,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
			send("public_key=" + peer.handshake.remoteStatic.ToHex())
 | 
			
		||||
			send("preshared_key=" + peer.handshake.presharedKey.ToHex())
 | 
			
		||||
			if peer.endpoint != nil {
 | 
			
		||||
				send("endpoint=" + peer.endpoint.String())
 | 
			
		||||
				send("endpoint=" + peer.endpoint.DstToString())
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
 | 
			
		||||
@ -134,56 +135,38 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
			case "listen_port":
 | 
			
		||||
				port, err := strconv.ParseUint(value, 10, 16)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logError.Println("Failed to set listen_port:", err)
 | 
			
		||||
					logError.Println("Failed to parse listen_port:", err)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalid}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logError.Println("Failed to set listen_port:", err)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalid}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				device.net.mutex.Lock()
 | 
			
		||||
				device.net.addr = addr
 | 
			
		||||
				device.net.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
				err = updateUDPConn(device)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
				device.net.port = uint16(port)
 | 
			
		||||
				if err := UpdateUDPListener(device); err != nil {
 | 
			
		||||
					logError.Println("Failed to set listen_port:", err)
 | 
			
		||||
					return &IPCError{Code: ipcErrorPortInUse}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// TODO: Clear source address of all peers
 | 
			
		||||
 | 
			
		||||
			case "fwmark":
 | 
			
		||||
				fwmark, err := strconv.ParseUint(value, 10, 32)
 | 
			
		||||
 | 
			
		||||
				// parse fwmark field
 | 
			
		||||
 | 
			
		||||
				fwmark, err := func() (uint32, error) {
 | 
			
		||||
					if value == "" {
 | 
			
		||||
						return 0, nil
 | 
			
		||||
					}
 | 
			
		||||
					mark, err := strconv.ParseUint(value, 10, 32)
 | 
			
		||||
					return uint32(mark), err
 | 
			
		||||
				}()
 | 
			
		||||
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logError.Println("Invalid fwmark", err)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalid}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				device.net.mutex.Lock()
 | 
			
		||||
				if fwmark > 0 || device.net.fwmark > 0 {
 | 
			
		||||
					device.net.fwmark = uint32(fwmark)
 | 
			
		||||
					err := setMark(
 | 
			
		||||
						device.net.conn,
 | 
			
		||||
						device.net.fwmark,
 | 
			
		||||
					)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logError.Println("Failed to set fwmark:", err)
 | 
			
		||||
						device.net.mutex.Unlock()
 | 
			
		||||
						return &IPCError{Code: ipcErrorIO}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// TODO: Clear source address of all peers
 | 
			
		||||
				}
 | 
			
		||||
				device.net.fwmark = uint32(fwmark)
 | 
			
		||||
				device.net.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			case "public_key":
 | 
			
		||||
 | 
			
		||||
				// switch to peer configuration
 | 
			
		||||
 | 
			
		||||
				deviceConfig = false
 | 
			
		||||
 | 
			
		||||
			case "replace_peers":
 | 
			
		||||
@ -218,7 +201,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
				device.mutex.RLock()
 | 
			
		||||
				if device.publicKey.Equals(pubKey) {
 | 
			
		||||
 | 
			
		||||
					// create dummy instance
 | 
			
		||||
					// create dummy instance (not added to device)
 | 
			
		||||
 | 
			
		||||
					peer = &Peer{}
 | 
			
		||||
					dummy = true
 | 
			
		||||
@ -244,6 +227,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			case "remove":
 | 
			
		||||
 | 
			
		||||
				// remove currently selected peer from device
 | 
			
		||||
 | 
			
		||||
				if value != "true" {
 | 
			
		||||
					logError.Println("Failed to set remove, invalid value:", value)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalid}
 | 
			
		||||
@ -256,6 +242,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
				dummy = true
 | 
			
		||||
 | 
			
		||||
			case "preshared_key":
 | 
			
		||||
 | 
			
		||||
				// update PSK
 | 
			
		||||
 | 
			
		||||
				peer.mutex.Lock()
 | 
			
		||||
				err := peer.handshake.presharedKey.FromHex(value)
 | 
			
		||||
				peer.mutex.Unlock()
 | 
			
		||||
@ -265,15 +254,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			case "endpoint":
 | 
			
		||||
				addr, err := parseEndpoint(value)
 | 
			
		||||
 | 
			
		||||
				// set endpoint destination
 | 
			
		||||
 | 
			
		||||
				err := func() error {
 | 
			
		||||
					peer.mutex.Lock()
 | 
			
		||||
					defer peer.mutex.Unlock()
 | 
			
		||||
					endpoint, err := CreateEndpoint(value)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return err
 | 
			
		||||
					}
 | 
			
		||||
					peer.endpoint = endpoint
 | 
			
		||||
					signalSend(peer.signal.handshakeReset)
 | 
			
		||||
					return nil
 | 
			
		||||
				}()
 | 
			
		||||
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logError.Println("Failed to set endpoint:", value)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalid}
 | 
			
		||||
				}
 | 
			
		||||
				peer.mutex.Lock()
 | 
			
		||||
				peer.endpoint = addr
 | 
			
		||||
				peer.mutex.Unlock()
 | 
			
		||||
				signalSend(peer.signal.handshakeReset)
 | 
			
		||||
 | 
			
		||||
			case "persistent_keepalive_interval":
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -10,12 +10,12 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ipcErrorIO         = -int64(unix.EIO)
 | 
			
		||||
	ipcErrorProtocol   = -int64(unix.EPROTO)
 | 
			
		||||
	ipcErrorInvalid    = -int64(unix.EINVAL)
 | 
			
		||||
	ipcErrorPortInUse  = -int64(unix.EADDRINUSE)
 | 
			
		||||
	socketDirectory    = "/var/run/wireguard"
 | 
			
		||||
	socketName         = "%s.sock"
 | 
			
		||||
	ipcErrorIO        = -int64(unix.EIO)
 | 
			
		||||
	ipcErrorProtocol  = -int64(unix.EPROTO)
 | 
			
		||||
	ipcErrorInvalid   = -int64(unix.EINVAL)
 | 
			
		||||
	ipcErrorPortInUse = -int64(unix.EADDRINUSE)
 | 
			
		||||
	socketDirectory   = "/var/run/wireguard"
 | 
			
		||||
	socketName        = "%s.sock"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type UAPIListener struct {
 | 
			
		||||
@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func connectUnixSocket(path string) (net.Listener, error) {
 | 
			
		||||
func UAPIListen(name string, file *os.File) (net.Listener, error) {
 | 
			
		||||
 | 
			
		||||
	// attempt inital connection
 | 
			
		||||
	// wrap file in listener
 | 
			
		||||
 | 
			
		||||
	listener, err := net.Listen("unix", path)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return listener, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// check if active
 | 
			
		||||
 | 
			
		||||
	_, err = net.Dial("unix", path)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return nil, errors.New("Unix socket in use")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// attempt cleanup
 | 
			
		||||
 | 
			
		||||
	err = os.Remove(path)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return net.Listen("unix", path)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewUAPIListener(name string) (net.Listener, error) {
 | 
			
		||||
 | 
			
		||||
	// check if path exist
 | 
			
		||||
 | 
			
		||||
	err := os.MkdirAll(socketDirectory, 077)
 | 
			
		||||
	if err != nil && !os.IsExist(err) {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// open UNIX socket
 | 
			
		||||
 | 
			
		||||
	socketPath := path.Join(
 | 
			
		||||
		socketDirectory,
 | 
			
		||||
		fmt.Sprintf(socketName, name),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	listener, err := connectUnixSocket(socketPath)
 | 
			
		||||
	listener, err := net.FileListener(file)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
 | 
			
		||||
 | 
			
		||||
	// watch for deletion of socket
 | 
			
		||||
 | 
			
		||||
	socketPath := path.Join(
 | 
			
		||||
		socketDirectory,
 | 
			
		||||
		fmt.Sprintf(socketName, name),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	uapi.inotifyFd, err = unix.InotifyInit()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
 | 
			
		||||
	go func(l *UAPIListener) {
 | 
			
		||||
		var buff [4096]byte
 | 
			
		||||
		for {
 | 
			
		||||
			unix.Read(uapi.inotifyFd, buff[:])
 | 
			
		||||
			// start with lstat to avoid race condition
 | 
			
		||||
			if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
 | 
			
		||||
				l.connErr <- err
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			unix.Read(uapi.inotifyFd, buff[:])
 | 
			
		||||
		}
 | 
			
		||||
	}(uapi)
 | 
			
		||||
 | 
			
		||||
@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
 | 
			
		||||
 | 
			
		||||
	return uapi, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UAPIOpen(name string) (*os.File, error) {
 | 
			
		||||
 | 
			
		||||
	// check if path exist
 | 
			
		||||
 | 
			
		||||
	err := os.MkdirAll(socketDirectory, 0600)
 | 
			
		||||
	if err != nil && !os.IsExist(err) {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// open UNIX socket
 | 
			
		||||
 | 
			
		||||
	socketPath := path.Join(
 | 
			
		||||
		socketDirectory,
 | 
			
		||||
		fmt.Sprintf(socketName, name),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	addr, err := net.ResolveUnixAddr("unix", socketPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	listener, err := func() (*net.UnixListener, error) {
 | 
			
		||||
 | 
			
		||||
		// initial connection attempt
 | 
			
		||||
 | 
			
		||||
		listener, err := net.ListenUnix("unix", addr)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			return listener, nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// check if socket already active
 | 
			
		||||
 | 
			
		||||
		_, err = net.Dial("unix", socketPath)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			return nil, errors.New("unix socket in use")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// cleanup & attempt again
 | 
			
		||||
 | 
			
		||||
		err = os.Remove(socketPath)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return net.ListenUnix("unix", addr)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return listener.File()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user