Added new UDPBind interface
This commit is contained in:
		
							parent
							
								
									2d856045a0
								
							
						
					
					
						commit
						a72b0f7ae5
					
				
							
								
								
									
										83
									
								
								src/conn.go
									
									
									
									
									
								
							
							
						
						
									
										83
									
								
								src/conn.go
									
									
									
									
									
								
							@ -5,6 +5,14 @@ import (
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type UDPBind interface {
 | 
			
		||||
	SetMark(value uint32) error
 | 
			
		||||
	ReceiveIPv6(buff []byte, end *Endpoint) (int, error)
 | 
			
		||||
	ReceiveIPv4(buff []byte, end *Endpoint) (int, error)
 | 
			
		||||
	Send(buff []byte, end *Endpoint) error
 | 
			
		||||
	Close() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
 | 
			
		||||
 | 
			
		||||
	// ensure that the host is an IP address
 | 
			
		||||
@ -26,19 +34,6 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
 | 
			
		||||
	return addr, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ListenerClose(l *Listener) (err error) {
 | 
			
		||||
	if l.active {
 | 
			
		||||
		err = CloseIPv4Socket(l.sock)
 | 
			
		||||
		l.active = false
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *Listener) Init() {
 | 
			
		||||
	l.update = make(chan struct{}, 1)
 | 
			
		||||
	ListenerClose(l)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ListeningUpdate(device *Device) error {
 | 
			
		||||
	netc := &device.net
 | 
			
		||||
	netc.mutex.Lock()
 | 
			
		||||
@ -46,11 +41,7 @@ func ListeningUpdate(device *Device) error {
 | 
			
		||||
 | 
			
		||||
	// close existing sockets
 | 
			
		||||
 | 
			
		||||
	if err := ListenerClose(&netc.ipv4); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := ListenerClose(&netc.ipv6); err != nil {
 | 
			
		||||
	if err := device.net.bind.Close(); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -58,45 +49,22 @@ func ListeningUpdate(device *Device) error {
 | 
			
		||||
 | 
			
		||||
	if device.tun.isUp.Get() {
 | 
			
		||||
 | 
			
		||||
		// listen on IPv4
 | 
			
		||||
		// bind to new port
 | 
			
		||||
 | 
			
		||||
		{
 | 
			
		||||
			list := &netc.ipv6
 | 
			
		||||
			sock, port, err := CreateIPv4Socket(netc.port)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			netc.port = port
 | 
			
		||||
			list.sock = sock
 | 
			
		||||
			list.active = true
 | 
			
		||||
 | 
			
		||||
			if err := SetMark(list.sock, netc.fwmark); err != nil {
 | 
			
		||||
				ListenerClose(list)
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			signalSend(list.update)
 | 
			
		||||
		var err error
 | 
			
		||||
		netc.bind, netc.port, err = CreateUDPBind(netc.port)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// listen on IPv6
 | 
			
		||||
		// set mark
 | 
			
		||||
 | 
			
		||||
		{
 | 
			
		||||
			list := &netc.ipv6
 | 
			
		||||
			sock, port, err := CreateIPv6Socket(netc.port)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			netc.port = port
 | 
			
		||||
			list.sock = sock
 | 
			
		||||
			list.active = true
 | 
			
		||||
 | 
			
		||||
			if err := SetMark(list.sock, netc.fwmark); err != nil {
 | 
			
		||||
				ListenerClose(list)
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			signalSend(list.update)
 | 
			
		||||
		err = netc.bind.SetMark(netc.fwmark)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// TODO: clear endpoint caches
 | 
			
		||||
		// TODO: clear endpoint (src) caches
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
@ -106,16 +74,5 @@ func ListeningClose(device *Device) error {
 | 
			
		||||
	netc := &device.net
 | 
			
		||||
	netc.mutex.Lock()
 | 
			
		||||
	defer netc.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	if err := ListenerClose(&netc.ipv4); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	signalSend(netc.ipv4.update)
 | 
			
		||||
 | 
			
		||||
	if err := ListenerClose(&netc.ipv6); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	signalSend(netc.ipv6.update)
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
	return netc.bind.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -14,35 +14,158 @@ import (
 | 
			
		||||
	"unsafe"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import "fmt"
 | 
			
		||||
 | 
			
		||||
/* Supports source address caching
 | 
			
		||||
 *
 | 
			
		||||
 * Currently there is no way to achieve this within the net package:
 | 
			
		||||
 * See e.g. https://github.com/golang/go/issues/17930
 | 
			
		||||
 * So this code is platform dependent.
 | 
			
		||||
 *
 | 
			
		||||
 * It is important that the endpoint is only updated after the packet content has been authenticated!
 | 
			
		||||
 * So this code is remains platform dependent.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type Endpoint struct {
 | 
			
		||||
	// source (selected based on dst type)
 | 
			
		||||
	// (could use RawSockaddrAny and unsafe)
 | 
			
		||||
	// TODO: Merge
 | 
			
		||||
	src6   unix.RawSockaddrInet6
 | 
			
		||||
	src4   unix.RawSockaddrInet4
 | 
			
		||||
	src4if int32
 | 
			
		||||
 | 
			
		||||
	dst unix.RawSockaddrAny
 | 
			
		||||
	src unix.RawSockaddrInet6
 | 
			
		||||
	dst unix.RawSockaddrInet6
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Socket int
 | 
			
		||||
type IPv4Source struct {
 | 
			
		||||
	src     unix.RawSockaddrInet4
 | 
			
		||||
	Ifindex int32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Returns a byte representation of the source field(s)
 | 
			
		||||
 * for use in "under load" cookie computations.
 | 
			
		||||
 */
 | 
			
		||||
func (endpoint *Endpoint) Source() []byte {
 | 
			
		||||
	return nil
 | 
			
		||||
type Bind struct {
 | 
			
		||||
	sock4 int
 | 
			
		||||
	sock6 int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
 | 
			
		||||
	var err error
 | 
			
		||||
	var bind Bind
 | 
			
		||||
 | 
			
		||||
	bind.sock6, port, err = create6(port)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, port, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	bind.sock4, port, err = create4(port)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		unix.Close(bind.sock6)
 | 
			
		||||
	}
 | 
			
		||||
	return &bind, port, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind *Bind) SetMark(value uint32) error {
 | 
			
		||||
	err := unix.SetsockoptInt(
 | 
			
		||||
		bind.sock6,
 | 
			
		||||
		unix.SOL_SOCKET,
 | 
			
		||||
		unix.SO_MARK,
 | 
			
		||||
		int(value),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return unix.SetsockoptInt(
 | 
			
		||||
		bind.sock4,
 | 
			
		||||
		unix.SOL_SOCKET,
 | 
			
		||||
		unix.SO_MARK,
 | 
			
		||||
		int(value),
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind *Bind) Close() error {
 | 
			
		||||
	err1 := unix.Close(bind.sock6)
 | 
			
		||||
	err2 := unix.Close(bind.sock4)
 | 
			
		||||
	if err1 != nil {
 | 
			
		||||
		return err1
 | 
			
		||||
	}
 | 
			
		||||
	return err2
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind *Bind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) {
 | 
			
		||||
	return receive6(
 | 
			
		||||
		bind.sock6,
 | 
			
		||||
		buff,
 | 
			
		||||
		end,
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind *Bind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) {
 | 
			
		||||
	return receive4(
 | 
			
		||||
		bind.sock4,
 | 
			
		||||
		buff,
 | 
			
		||||
		end,
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind *Bind) Send(buff []byte, end *Endpoint) error {
 | 
			
		||||
	switch end.src.Family {
 | 
			
		||||
	case unix.AF_INET6:
 | 
			
		||||
		return send6(bind.sock6, end, buff)
 | 
			
		||||
	case unix.AF_INET:
 | 
			
		||||
		return send4(bind.sock4, end, buff)
 | 
			
		||||
	default:
 | 
			
		||||
		return errors.New("Unknown address family of source")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func sockaddrToString(addr unix.RawSockaddrInet6) string {
 | 
			
		||||
	var udpAddr net.UDPAddr
 | 
			
		||||
 | 
			
		||||
	switch addr.Family {
 | 
			
		||||
	case unix.AF_INET6:
 | 
			
		||||
		udpAddr.Port = int(addr.Port)
 | 
			
		||||
		udpAddr.IP = addr.Addr[:]
 | 
			
		||||
		return udpAddr.String()
 | 
			
		||||
 | 
			
		||||
	case unix.AF_INET:
 | 
			
		||||
		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
 | 
			
		||||
		udpAddr.Port = int(ptr.Port)
 | 
			
		||||
		udpAddr.IP = net.IPv4(
 | 
			
		||||
			ptr.Addr[0],
 | 
			
		||||
			ptr.Addr[1],
 | 
			
		||||
			ptr.Addr[2],
 | 
			
		||||
			ptr.Addr[3],
 | 
			
		||||
		)
 | 
			
		||||
		return udpAddr.String()
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return "<unknown address family>"
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) DestinationIP() net.IP {
 | 
			
		||||
	switch end.dst.Family {
 | 
			
		||||
	case unix.AF_INET6:
 | 
			
		||||
		return end.dst.Addr[:]
 | 
			
		||||
	case unix.AF_INET:
 | 
			
		||||
		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
 | 
			
		||||
		return net.IPv4(
 | 
			
		||||
			ptr.Addr[0],
 | 
			
		||||
			ptr.Addr[1],
 | 
			
		||||
			ptr.Addr[2],
 | 
			
		||||
			ptr.Addr[3],
 | 
			
		||||
		)
 | 
			
		||||
	default:
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) SourceToBytes() []byte {
 | 
			
		||||
	ptr := unsafe.Pointer(&end.src)
 | 
			
		||||
	arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
 | 
			
		||||
	return arr[:]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) SourceToString() string {
 | 
			
		||||
	return sockaddrToString(end.src)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) DestinationToString() string {
 | 
			
		||||
	return sockaddrToString(end.dst)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) ClearSrc() {
 | 
			
		||||
	end.src = unix.RawSockaddrInet6{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func zoneToUint32(zone string) (uint32, error) {
 | 
			
		||||
@ -56,7 +179,7 @@ func zoneToUint32(zone string) (uint32, error) {
 | 
			
		||||
	return uint32(n), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
 | 
			
		||||
func create4(port uint16) (int, uint16, error) {
 | 
			
		||||
 | 
			
		||||
	// create socket
 | 
			
		||||
 | 
			
		||||
@ -100,18 +223,10 @@ func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
 | 
			
		||||
		unix.Close(fd)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return Socket(fd), uint16(addr.Port), err
 | 
			
		||||
	return fd, uint16(addr.Port), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CloseIPv4Socket(sock Socket) error {
 | 
			
		||||
	return unix.Close(int(sock))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CloseIPv6Socket(sock Socket) error {
 | 
			
		||||
	return unix.Close(int(sock))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
 | 
			
		||||
func create6(port uint16) (int, uint16, error) {
 | 
			
		||||
 | 
			
		||||
	// create socket
 | 
			
		||||
 | 
			
		||||
@ -166,13 +281,7 @@ func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
 | 
			
		||||
		unix.Close(fd)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return Socket(fd), uint16(addr.Port), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) ClearSrc() {
 | 
			
		||||
	end.src4if = 0
 | 
			
		||||
	end.src4 = unix.RawSockaddrInet4{}
 | 
			
		||||
	end.src6 = unix.RawSockaddrInet6{}
 | 
			
		||||
	return fd, uint16(addr.Port), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) Set(s string) error {
 | 
			
		||||
@ -187,23 +296,23 @@ func (end *Endpoint) Set(s string) error {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
 | 
			
		||||
		ptr.Family = unix.AF_INET6
 | 
			
		||||
		ptr.Port = uint16(addr.Port)
 | 
			
		||||
		ptr.Flowinfo = 0
 | 
			
		||||
		ptr.Scope_id = zone
 | 
			
		||||
		copy(ptr.Addr[:], ipv6[:])
 | 
			
		||||
		dst := &end.dst
 | 
			
		||||
		dst.Family = unix.AF_INET6
 | 
			
		||||
		dst.Port = uint16(addr.Port)
 | 
			
		||||
		dst.Flowinfo = 0
 | 
			
		||||
		dst.Scope_id = zone
 | 
			
		||||
		copy(dst.Addr[:], ipv6[:])
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ipv4 := addr.IP.To4()
 | 
			
		||||
	if ipv4 != nil {
 | 
			
		||||
		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
 | 
			
		||||
		ptr.Family = unix.AF_INET
 | 
			
		||||
		ptr.Port = uint16(addr.Port)
 | 
			
		||||
		ptr.Zero = [8]byte{}
 | 
			
		||||
		copy(ptr.Addr[:], ipv4)
 | 
			
		||||
		dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
 | 
			
		||||
		dst.Family = unix.AF_INET
 | 
			
		||||
		dst.Port = uint16(addr.Port)
 | 
			
		||||
		dst.Zero = [8]byte{}
 | 
			
		||||
		copy(dst.Addr[:], ipv4)
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
@ -211,7 +320,7 @@ func (end *Endpoint) Set(s string) error {
 | 
			
		||||
	return errors.New("Failed to recognize IP address format")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func send6(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
func send6(sock int, end *Endpoint, buff []byte) error {
 | 
			
		||||
 | 
			
		||||
	// construct message header
 | 
			
		||||
 | 
			
		||||
@ -229,8 +338,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
			Len:   unix.SizeofInet6Pktinfo,
 | 
			
		||||
		},
 | 
			
		||||
		unix.Inet6Pktinfo{
 | 
			
		||||
			Addr:    end.src6.Addr,
 | 
			
		||||
			Ifindex: end.src6.Scope_id,
 | 
			
		||||
			Addr:    end.src.Addr,
 | 
			
		||||
			Ifindex: end.src.Scope_id,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -248,7 +357,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
 | 
			
		||||
	_, _, errno := unix.Syscall(
 | 
			
		||||
		unix.SYS_SENDMSG,
 | 
			
		||||
		sock,
 | 
			
		||||
		uintptr(sock),
 | 
			
		||||
		uintptr(unsafe.Pointer(&msghdr)),
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
@ -258,7 +367,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
	return errno
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func send4(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
func send4(sock int, end *Endpoint, buff []byte) error {
 | 
			
		||||
 | 
			
		||||
	// construct message header
 | 
			
		||||
 | 
			
		||||
@ -266,6 +375,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
 | 
			
		||||
	iovec.SetLen(len(buff))
 | 
			
		||||
 | 
			
		||||
	src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
 | 
			
		||||
 | 
			
		||||
	cmsg := struct {
 | 
			
		||||
		cmsghdr unix.Cmsghdr
 | 
			
		||||
		pktinfo unix.Inet4Pktinfo
 | 
			
		||||
@ -276,8 +387,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
			Len:   unix.SizeofInet4Pktinfo,
 | 
			
		||||
		},
 | 
			
		||||
		unix.Inet4Pktinfo{
 | 
			
		||||
			Spec_dst: end.src4.Addr,
 | 
			
		||||
			Ifindex:  end.src4if,
 | 
			
		||||
			Spec_dst: src4.src.Addr,
 | 
			
		||||
			Ifindex:  src4.Ifindex,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -295,7 +406,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
 | 
			
		||||
	_, _, errno := unix.Syscall(
 | 
			
		||||
		unix.SYS_SENDMSG,
 | 
			
		||||
		sock,
 | 
			
		||||
		uintptr(sock),
 | 
			
		||||
		uintptr(unsafe.Pointer(&msghdr)),
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
@ -305,28 +416,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
	return errno
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error {
 | 
			
		||||
 | 
			
		||||
	// extract underlying file descriptor
 | 
			
		||||
 | 
			
		||||
	file, err := c.File()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	sock := file.Fd()
 | 
			
		||||
 | 
			
		||||
	// send depending on address family of dst
 | 
			
		||||
 | 
			
		||||
	family := *((*uint16)(unsafe.Pointer(&end.dst)))
 | 
			
		||||
	if family == unix.AF_INET {
 | 
			
		||||
		return send4(sock, end, buff)
 | 
			
		||||
	} else if family == unix.AF_INET6 {
 | 
			
		||||
		return send6(sock, end, buff)
 | 
			
		||||
	}
 | 
			
		||||
	return errors.New("Unknown address family of source")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
 | 
			
		||||
func receive4(sock int, buff []byte, end *Endpoint) (int, error) {
 | 
			
		||||
 | 
			
		||||
	// contruct message header
 | 
			
		||||
 | 
			
		||||
@ -360,22 +450,21 @@ func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
 | 
			
		||||
		return 0, errno
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fmt.Println(msghdr)
 | 
			
		||||
	fmt.Println(cmsg)
 | 
			
		||||
 | 
			
		||||
	// update source cache
 | 
			
		||||
 | 
			
		||||
	if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
 | 
			
		||||
		cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
 | 
			
		||||
		cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
 | 
			
		||||
		end.src4.Addr = cmsg.pktinfo.Spec_dst
 | 
			
		||||
		end.src4if = cmsg.pktinfo.Ifindex
 | 
			
		||||
		src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
 | 
			
		||||
		src4.src.Family = unix.AF_INET
 | 
			
		||||
		src4.src.Addr = cmsg.pktinfo.Spec_dst
 | 
			
		||||
		src4.Ifindex = cmsg.pktinfo.Ifindex
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return int(size), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
 | 
			
		||||
func receive6(sock int, buff []byte, end *Endpoint) (int, error) {
 | 
			
		||||
 | 
			
		||||
	// contruct message header
 | 
			
		||||
 | 
			
		||||
@ -414,18 +503,10 @@ func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
 | 
			
		||||
	if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
 | 
			
		||||
		cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
 | 
			
		||||
		cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
 | 
			
		||||
		end.src6.Addr = cmsg.pktinfo.Addr
 | 
			
		||||
		end.src6.Scope_id = cmsg.pktinfo.Ifindex
 | 
			
		||||
		end.src.Family = unix.AF_INET6
 | 
			
		||||
		end.src.Addr = cmsg.pktinfo.Addr
 | 
			
		||||
		end.src.Scope_id = cmsg.pktinfo.Ifindex
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return int(size), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SetMark(sock Socket, value uint32) error {
 | 
			
		||||
	return unix.SetsockoptInt(
 | 
			
		||||
		int(sock),
 | 
			
		||||
		unix.SOL_SOCKET,
 | 
			
		||||
		unix.SO_MARK,
 | 
			
		||||
		int(value),
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -5,10 +5,8 @@ import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"golang.org/x/crypto/blake2s"
 | 
			
		||||
	"golang.org/x/crypto/chacha20poly1305"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unsafe"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type CookieChecker struct {
 | 
			
		||||
@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
 | 
			
		||||
	return hmac.Equal(mac1[:], msg[smac1:smac2])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
 | 
			
		||||
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
 | 
			
		||||
	st.mutex.RLock()
 | 
			
		||||
	defer st.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
 | 
			
		||||
	var cookie [blake2s.Size128]byte
 | 
			
		||||
	func() {
 | 
			
		||||
		mac, _ := blake2s.New128(st.mac2.secret[:])
 | 
			
		||||
		mac.Write(src.IP)
 | 
			
		||||
		mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
 | 
			
		||||
		mac.Write(src)
 | 
			
		||||
		mac.Sum(cookie[:0])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
 | 
			
		||||
func (st *CookieChecker) CreateReply(
 | 
			
		||||
	msg []byte,
 | 
			
		||||
	recv uint32,
 | 
			
		||||
	src *net.UDPAddr,
 | 
			
		||||
	src []byte,
 | 
			
		||||
) (*MessageCookieReply, error) {
 | 
			
		||||
 | 
			
		||||
	st.mutex.RLock()
 | 
			
		||||
@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply(
 | 
			
		||||
	var cookie [blake2s.Size128]byte
 | 
			
		||||
	func() {
 | 
			
		||||
		mac, _ := blake2s.New128(st.mac2.secret[:])
 | 
			
		||||
		mac.Write(src.IP)
 | 
			
		||||
		mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
 | 
			
		||||
		mac.Write(src)
 | 
			
		||||
		mac.Sum(cookie[:0])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,18 +1,14 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"golang.org/x/net/ipv4"
 | 
			
		||||
	"golang.org/x/net/ipv6"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Listener struct {
 | 
			
		||||
	sock   Socket
 | 
			
		||||
	active bool
 | 
			
		||||
	update chan struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Device struct {
 | 
			
		||||
	log       *Logger // collection of loggers for levels
 | 
			
		||||
	idCounter uint    // for assigning debug ids to peers
 | 
			
		||||
@ -27,8 +23,7 @@ type Device struct {
 | 
			
		||||
	}
 | 
			
		||||
	net struct {
 | 
			
		||||
		mutex  sync.RWMutex
 | 
			
		||||
		ipv4   Listener
 | 
			
		||||
		ipv6   Listener
 | 
			
		||||
		bind   UDPBind
 | 
			
		||||
		port   uint16
 | 
			
		||||
		fwmark uint32
 | 
			
		||||
	}
 | 
			
		||||
@ -43,9 +38,8 @@ type Device struct {
 | 
			
		||||
		handshake  chan QueueHandshakeElement
 | 
			
		||||
	}
 | 
			
		||||
	signal struct {
 | 
			
		||||
		stop             chan struct{} // halts all go routines
 | 
			
		||||
		updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
 | 
			
		||||
		updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
 | 
			
		||||
		stop       chan struct{}
 | 
			
		||||
		updateBind chan struct{}
 | 
			
		||||
	}
 | 
			
		||||
	underLoadUntil atomic.Value
 | 
			
		||||
	ratelimiter    Ratelimiter
 | 
			
		||||
@ -146,8 +140,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
 | 
			
		||||
	device.tun.device = tun
 | 
			
		||||
 | 
			
		||||
	device.indices.Init()
 | 
			
		||||
	device.net.ipv4.Init()
 | 
			
		||||
	device.net.ipv6.Init()
 | 
			
		||||
	device.ratelimiter.Init()
 | 
			
		||||
 | 
			
		||||
	device.routingTable.Reset()
 | 
			
		||||
@ -181,8 +173,8 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
 | 
			
		||||
	go device.RoutineReadFromTUN()
 | 
			
		||||
	go device.RoutineTUNEventReader()
 | 
			
		||||
	go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
 | 
			
		||||
	go device.RoutineReceiveIncomming(&device.net.ipv4)
 | 
			
		||||
	go device.RoutineReceiveIncomming(&device.net.ipv6)
 | 
			
		||||
	go device.RoutineReceiveIncomming(ipv4.Version)
 | 
			
		||||
	go device.RoutineReceiveIncomming(ipv6.Version)
 | 
			
		||||
	return device
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,6 @@ import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
@ -15,8 +14,8 @@ type Peer struct {
 | 
			
		||||
	persistentKeepaliveInterval uint64
 | 
			
		||||
	keyPairs                    KeyPairs
 | 
			
		||||
	handshake                   Handshake
 | 
			
		||||
	endpoint                    Endpoint
 | 
			
		||||
	device                      *Device
 | 
			
		||||
	endpoint                    *net.UDPAddr
 | 
			
		||||
	stats                       struct {
 | 
			
		||||
		txBytes           uint64 // bytes send to peer (endpoint)
 | 
			
		||||
		rxBytes           uint64 // bytes received from peer
 | 
			
		||||
@ -134,7 +133,7 @@ func (peer *Peer) String() string {
 | 
			
		||||
	return fmt.Sprintf(
 | 
			
		||||
		"peer(%d %s %s)",
 | 
			
		||||
		peer.id,
 | 
			
		||||
		peer.endpoint.String(),
 | 
			
		||||
		peer.endpoint.DestinationToString(),
 | 
			
		||||
		base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -97,17 +97,6 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, receive incomming, started")
 | 
			
		||||
 | 
			
		||||
	var listener *Listener
 | 
			
		||||
 | 
			
		||||
	switch IPVersion {
 | 
			
		||||
	case ipv4.Version:
 | 
			
		||||
		listener = &device.net.ipv4
 | 
			
		||||
	case ipv6.Version:
 | 
			
		||||
		listener = &device.net.ipv6
 | 
			
		||||
	default:
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
 | 
			
		||||
		// wait for new conn
 | 
			
		||||
@ -118,15 +107,14 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
 | 
			
		||||
		case <-device.signal.stop:
 | 
			
		||||
			return
 | 
			
		||||
 | 
			
		||||
		case <-listener.update:
 | 
			
		||||
		case <-device.signal.updateBind:
 | 
			
		||||
 | 
			
		||||
			// fetch new socket
 | 
			
		||||
 | 
			
		||||
			device.net.mutex.RLock()
 | 
			
		||||
			sock := listener.sock
 | 
			
		||||
			okay := listener.active
 | 
			
		||||
			bind := device.net.bind
 | 
			
		||||
			device.net.mutex.RUnlock()
 | 
			
		||||
			if !okay {
 | 
			
		||||
			if bind == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
@ -145,10 +133,13 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
 | 
			
		||||
 | 
			
		||||
				var endpoint Endpoint
 | 
			
		||||
 | 
			
		||||
				if IPVersion == ipv6.Version {
 | 
			
		||||
					size, err = endpoint.ReceiveIPv4(sock, buffer[:])
 | 
			
		||||
				} else {
 | 
			
		||||
					size, err = endpoint.ReceiveIPv6(sock, buffer[:])
 | 
			
		||||
				switch IPVersion {
 | 
			
		||||
				case ipv4.Version:
 | 
			
		||||
					size, err = bind.ReceiveIPv4(buffer[:], &endpoint)
 | 
			
		||||
				case ipv6.Version:
 | 
			
		||||
					size, err = bind.ReceiveIPv6(buffer[:], &endpoint)
 | 
			
		||||
				default:
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if err != nil {
 | 
			
		||||
@ -340,15 +331,19 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			srcBytes := elem.endpoint.SourceToBytes()
 | 
			
		||||
			if device.IsUnderLoad() {
 | 
			
		||||
				if !device.mac.CheckMAC2(elem.packet, elem.source) {
 | 
			
		||||
 | 
			
		||||
				// verify MAC2 field
 | 
			
		||||
 | 
			
		||||
				if !device.mac.CheckMAC2(elem.packet, srcBytes) {
 | 
			
		||||
 | 
			
		||||
					// construct cookie reply
 | 
			
		||||
 | 
			
		||||
					logDebug.Println("Sending cookie reply to:", elem.source.String())
 | 
			
		||||
					logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString())
 | 
			
		||||
 | 
			
		||||
					sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
 | 
			
		||||
					reply, err := device.mac.CreateReply(elem.packet, sender, elem.source)
 | 
			
		||||
					reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logError.Println("Failed to create cookie reply:", err)
 | 
			
		||||
						return
 | 
			
		||||
@ -358,9 +353,9 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
 | 
			
		||||
					writer := bytes.NewBuffer(temp[:0])
 | 
			
		||||
					binary.Write(writer, binary.LittleEndian, reply)
 | 
			
		||||
					_, err = device.net.conn.WriteToUDP(
 | 
			
		||||
					device.net.bind.Send(
 | 
			
		||||
						writer.Bytes(),
 | 
			
		||||
						elem.source,
 | 
			
		||||
						&elem.endpoint,
 | 
			
		||||
					)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logDebug.Println("Failed to send cookie reply:", err)
 | 
			
		||||
@ -368,7 +363,11 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if !device.ratelimiter.Allow(elem.source.IP) {
 | 
			
		||||
				// check ratelimiter
 | 
			
		||||
 | 
			
		||||
				if !device.ratelimiter.Allow(
 | 
			
		||||
					elem.endpoint.DestinationIP(),
 | 
			
		||||
				) {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
@ -399,8 +398,7 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
			if peer == nil {
 | 
			
		||||
				logInfo.Println(
 | 
			
		||||
					"Recieved invalid initiation message from",
 | 
			
		||||
					elem.source.IP.String(),
 | 
			
		||||
					elem.source.Port,
 | 
			
		||||
					elem.endpoint.DestinationToString(),
 | 
			
		||||
				)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
@ -414,7 +412,7 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
			// TODO: Discover destination address also, only update on change
 | 
			
		||||
 | 
			
		||||
			peer.mutex.Lock()
 | 
			
		||||
			peer.endpoint = elem.source
 | 
			
		||||
			peer.endpoint = elem.endpoint
 | 
			
		||||
			peer.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			// create response
 | 
			
		||||
@ -460,8 +458,7 @@ func (device *Device) RoutineHandshake() {
 | 
			
		||||
			if peer == nil {
 | 
			
		||||
				logInfo.Println(
 | 
			
		||||
					"Recieved invalid response message from",
 | 
			
		||||
					elem.source.IP.String(),
 | 
			
		||||
					elem.source.Port,
 | 
			
		||||
					elem.endpoint.DestinationToString(),
 | 
			
		||||
				)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user