Begin work on source address caching (linux)
This commit is contained in:
		
							parent
							
								
									c545d63bb9
								
							
						
					
					
						commit
						eefa47b0f9
					
				
							
								
								
									
										22
									
								
								src/conn.go
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								src/conn.go
									
									
									
									
									
								
							@ -1,9 +1,31 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
 | 
			
		||||
 | 
			
		||||
	// ensure that the host is an IP address
 | 
			
		||||
 | 
			
		||||
	host, _, err := net.SplitHostPort(s)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if ip := net.ParseIP(host); ip == nil {
 | 
			
		||||
		return nil, errors.New("Failed to parse IP address: " + host)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// parse address and port
 | 
			
		||||
 | 
			
		||||
	addr, err := net.ResolveUDPAddr("udp", s)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return addr, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateUDPConn(device *Device) error {
 | 
			
		||||
	netc := &device.net
 | 
			
		||||
	netc.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
@ -1,10 +1,253 @@
 | 
			
		||||
/* Copyright 2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 | 
			
		||||
 *
 | 
			
		||||
 * This implements userspace semantics of "sticky sockets", modeled after
 | 
			
		||||
 * WireGuard's kernelspace implementation.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"golang.org/x/sys/unix"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"unsafe"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Supports source address caching
 | 
			
		||||
 *
 | 
			
		||||
 * It is important that the endpoint is only updated after the packet content has been authenticated.
 | 
			
		||||
 *
 | 
			
		||||
 * Currently there is no way to achieve this within the net package:
 | 
			
		||||
 * See e.g. https://github.com/golang/go/issues/17930
 | 
			
		||||
 */
 | 
			
		||||
type Endpoint struct {
 | 
			
		||||
	// source (selected based on dst type)
 | 
			
		||||
	// (could use RawSockaddrAny and unsafe)
 | 
			
		||||
	srcIPv6 unix.RawSockaddrInet6
 | 
			
		||||
	srcIPv4 unix.RawSockaddrInet4
 | 
			
		||||
	srcIf4  int32
 | 
			
		||||
 | 
			
		||||
	dst unix.RawSockaddrAny
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func zoneToUint32(zone string) (uint32, error) {
 | 
			
		||||
	if zone == "" {
 | 
			
		||||
		return 0, nil
 | 
			
		||||
	}
 | 
			
		||||
	if intr, err := net.InterfaceByName(zone); err == nil {
 | 
			
		||||
		return uint32(intr.Index), nil
 | 
			
		||||
	}
 | 
			
		||||
	n, err := strconv.ParseUint(zone, 10, 32)
 | 
			
		||||
	return uint32(n), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) ClearSrc() {
 | 
			
		||||
	end.srcIf4 = 0
 | 
			
		||||
	end.srcIPv4 = unix.RawSockaddrInet4{}
 | 
			
		||||
	end.srcIPv6 = unix.RawSockaddrInet6{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (end *Endpoint) Set(s string) error {
 | 
			
		||||
	addr, err := parseEndpoint(s)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ipv6 := addr.IP.To16()
 | 
			
		||||
	if ipv6 != nil {
 | 
			
		||||
		zone, err := zoneToUint32(addr.Zone)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
 | 
			
		||||
		ptr.Family = unix.AF_INET6
 | 
			
		||||
		ptr.Port = uint16(addr.Port)
 | 
			
		||||
		ptr.Flowinfo = 0
 | 
			
		||||
		ptr.Scope_id = zone
 | 
			
		||||
		copy(ptr.Addr[:], ipv6[:])
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ipv4 := addr.IP.To4()
 | 
			
		||||
	if ipv4 != nil {
 | 
			
		||||
		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
 | 
			
		||||
		ptr.Family = unix.AF_INET
 | 
			
		||||
		ptr.Port = uint16(addr.Port)
 | 
			
		||||
		ptr.Zero = [8]byte{}
 | 
			
		||||
		copy(ptr.Addr[:], ipv4)
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return errors.New("Failed to recognize IP address format")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func send6(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
	var iovec unix.Iovec
 | 
			
		||||
 | 
			
		||||
	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
 | 
			
		||||
	iovec.SetLen(len(buff))
 | 
			
		||||
 | 
			
		||||
	cmsg := struct {
 | 
			
		||||
		cmsghdr unix.Cmsghdr
 | 
			
		||||
		pktinfo unix.Inet6Pktinfo
 | 
			
		||||
	}{
 | 
			
		||||
		unix.Cmsghdr{
 | 
			
		||||
			Level: unix.IPPROTO_IPV6,
 | 
			
		||||
			Type:  unix.IPV6_PKTINFO,
 | 
			
		||||
			Len:   unix.SizeofInet6Pktinfo,
 | 
			
		||||
		},
 | 
			
		||||
		unix.Inet6Pktinfo{
 | 
			
		||||
			Addr:    end.srcIPv6.Addr,
 | 
			
		||||
			Ifindex: end.srcIPv6.Scope_id,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msghdr := unix.Msghdr{
 | 
			
		||||
		Iov:     &iovec,
 | 
			
		||||
		Iovlen:  1,
 | 
			
		||||
		Name:    (*byte)(unsafe.Pointer(&end.dst)),
 | 
			
		||||
		Namelen: unix.SizeofSockaddrInet6,
 | 
			
		||||
		Control: (*byte)(unsafe.Pointer(&cmsg)),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
 | 
			
		||||
 | 
			
		||||
	// sendmsg(sock, &msghdr, 0)
 | 
			
		||||
 | 
			
		||||
	_, _, errno := unix.Syscall(
 | 
			
		||||
		unix.SYS_SENDMSG,
 | 
			
		||||
		sock,
 | 
			
		||||
		uintptr(unsafe.Pointer(&msghdr)),
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
	if errno == unix.EINVAL {
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
	}
 | 
			
		||||
	return errno
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func send4(sock uintptr, end *Endpoint, buff []byte) error {
 | 
			
		||||
	var iovec unix.Iovec
 | 
			
		||||
 | 
			
		||||
	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
 | 
			
		||||
	iovec.SetLen(len(buff))
 | 
			
		||||
 | 
			
		||||
	cmsg := struct {
 | 
			
		||||
		cmsghdr unix.Cmsghdr
 | 
			
		||||
		pktinfo unix.Inet4Pktinfo
 | 
			
		||||
	}{
 | 
			
		||||
		unix.Cmsghdr{
 | 
			
		||||
			Level: unix.IPPROTO_IP,
 | 
			
		||||
			Type:  unix.IP_PKTINFO,
 | 
			
		||||
			Len:   unix.SizeofInet6Pktinfo,
 | 
			
		||||
		},
 | 
			
		||||
		unix.Inet4Pktinfo{
 | 
			
		||||
			Spec_dst: end.srcIPv4.Addr,
 | 
			
		||||
			Ifindex:  end.srcIf4,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msghdr := unix.Msghdr{
 | 
			
		||||
		Iov:     &iovec,
 | 
			
		||||
		Iovlen:  1,
 | 
			
		||||
		Name:    (*byte)(unsafe.Pointer(&end.dst)),
 | 
			
		||||
		Namelen: unix.SizeofSockaddrInet4,
 | 
			
		||||
		Control: (*byte)(unsafe.Pointer(&cmsg)),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
 | 
			
		||||
 | 
			
		||||
	// sendmsg(sock, &msghdr, 0)
 | 
			
		||||
 | 
			
		||||
	_, _, errno := unix.Syscall(
 | 
			
		||||
		unix.SYS_SENDMSG,
 | 
			
		||||
		sock,
 | 
			
		||||
		uintptr(unsafe.Pointer(&msghdr)),
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
	if errno == unix.EINVAL {
 | 
			
		||||
		end.ClearSrc()
 | 
			
		||||
	}
 | 
			
		||||
	return errno
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
 | 
			
		||||
 | 
			
		||||
	// extract underlying file descriptor
 | 
			
		||||
 | 
			
		||||
	file, err := c.File()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	sock := file.Fd()
 | 
			
		||||
 | 
			
		||||
	// send depending on address family of dst
 | 
			
		||||
 | 
			
		||||
	family := *((*uint16)(unsafe.Pointer(&end.dst)))
 | 
			
		||||
	if family == unix.AF_INET {
 | 
			
		||||
		return send4(sock, end, buff)
 | 
			
		||||
	} else if family == unix.AF_INET6 {
 | 
			
		||||
		return send6(sock, end, buff)
 | 
			
		||||
	}
 | 
			
		||||
	return errors.New("Unknown address family of source")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) {
 | 
			
		||||
 | 
			
		||||
	file, err := c.File()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err, nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var iovec unix.Iovec
 | 
			
		||||
	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
 | 
			
		||||
	iovec.SetLen(len(buff))
 | 
			
		||||
 | 
			
		||||
	var cmsg struct {
 | 
			
		||||
		cmsghdr unix.Cmsghdr
 | 
			
		||||
		pktinfo unix.Inet6Pktinfo // big enough
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var msg unix.Msghdr
 | 
			
		||||
	msg.Iov = &iovec
 | 
			
		||||
	msg.Iovlen = 1
 | 
			
		||||
	msg.Name = (*byte)(unsafe.Pointer(&end.dst))
 | 
			
		||||
	msg.Namelen = uint32(unix.SizeofSockaddrAny)
 | 
			
		||||
	msg.Control = (*byte)(unsafe.Pointer(&cmsg))
 | 
			
		||||
	msg.SetControllen(int(unsafe.Sizeof(cmsg)))
 | 
			
		||||
 | 
			
		||||
	_, _, errno := unix.Syscall(
 | 
			
		||||
		unix.SYS_RECVMSG,
 | 
			
		||||
		file.Fd(),
 | 
			
		||||
		uintptr(unsafe.Pointer(&msg)),
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if errno != 0 {
 | 
			
		||||
		return errno, nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
 | 
			
		||||
		cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
 | 
			
		||||
		cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
 | 
			
		||||
		cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
 | 
			
		||||
		cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
 | 
			
		||||
 | 
			
		||||
		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
 | 
			
		||||
		println(info)
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func setMark(conn *net.UDPConn, value uint32) error {
 | 
			
		||||
	if conn == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
 | 
			
		||||
@ -29,6 +29,11 @@ func (a *AtomicBool) Set(val bool) {
 | 
			
		||||
	atomic.StoreInt32(&a.flag, flag)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toInt32(n uint32) int32 {
 | 
			
		||||
	mask := uint32(1 << 31)
 | 
			
		||||
	return int32(-(n & mask) + (n & ^mask))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func min(a uint, b uint) uint {
 | 
			
		||||
	if a > b {
 | 
			
		||||
		return b
 | 
			
		||||
 | 
			
		||||
@ -120,14 +120,6 @@ func (tun *NativeTun) Name() string {
 | 
			
		||||
	return tun.name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toInt32(val []byte) int32 {
 | 
			
		||||
	n := binary.LittleEndian.Uint32(val[:4])
 | 
			
		||||
	if n >= (1 << 31) {
 | 
			
		||||
		return -int32(^n) - 1
 | 
			
		||||
	}
 | 
			
		||||
	return int32(n)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getDummySock() (int, error) {
 | 
			
		||||
	return unix.Socket(
 | 
			
		||||
		unix.AF_INET,
 | 
			
		||||
@ -157,7 +149,8 @@ func getIFIndex(name string) (int32, error) {
 | 
			
		||||
		return 0, errno
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return toInt32(ifr[unix.IFNAMSIZ:]), nil
 | 
			
		||||
	index := binary.LittleEndian.Uint32(ifr[unix.IFNAMSIZ:])
 | 
			
		||||
	return toInt32(index), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tun *NativeTun) setMTU(n int) error {
 | 
			
		||||
 | 
			
		||||
@ -273,8 +273,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			case "endpoint":
 | 
			
		||||
				// TODO: Only IP and port
 | 
			
		||||
				addr, err := net.ResolveUDPAddr("udp", value)
 | 
			
		||||
				addr, err := parseEndpoint(value)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logError.Println("Failed to set endpoint:", value)
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalid}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user