wireguard-go/src/conn_linux.go

513 lines
9.2 KiB
Go
Raw Normal View History

/* Copyright 2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation.
*/
2017-08-25 14:53:23 +02:00
package main
import (
"errors"
2017-08-25 14:53:23 +02:00
"golang.org/x/sys/unix"
"net"
"strconv"
"unsafe"
2017-08-25 14:53:23 +02:00
)
/* Supports source address caching
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
2017-10-08 22:03:32 +02:00
* So this code remains platform dependent.
*/
type Endpoint struct {
2017-10-08 22:03:32 +02:00
src unix.RawSockaddrInet6
dst unix.RawSockaddrInet6
}
// IPv4Source is the IPv4 view of Endpoint.src. send4 and receive4 cast the
// RawSockaddrInet6 storage of Endpoint.src to this type so that the source
// address and the interface index are kept together. The field order
// defines that in-memory overlay and must not be changed.
type IPv4Source struct {
// src is the cached local IPv4 address (only Family and Addr are used here).
src unix.RawSockaddrInet4
// Ifindex is the interface index exchanged with the kernel via IP_PKTINFO.
Ifindex int32
}
2017-10-08 22:03:32 +02:00
// Bind groups the two UDP socket file descriptors used for transport,
// one per address family.
type Bind struct {
// sock4 is the IPv4 socket descriptor returned by create4.
sock4 int
// sock6 is the IPv6 socket descriptor returned by create6.
sock6 int
}
2017-10-08 22:03:32 +02:00
// CreateUDPBind opens the IPv6 and the IPv4 UDP sockets and wraps them in a
// Bind.
//
// The IPv6 socket is created first; the port value reported by create6 is
// then handed to create4. The returned uint16 is the port value reported by
// the last create call that ran.
//
// On failure the returned bind is nil and no socket is left open: a failing
// create call closes its own descriptor, and the already opened IPv6 socket
// is closed here when the IPv4 socket cannot be set up.
func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
	sock6, port, err := create6(port)
	if err != nil {
		return nil, port, err
	}

	sock4, port, err := create4(port)
	if err != nil {
		// Best-effort cleanup; the create4 error is the one worth reporting.
		unix.Close(sock6)
		return nil, port, err
	}

	return &Bind{sock4: sock4, sock6: sock6}, port, nil
}
// SetMark applies the firewall mark (SO_MARK) to both sockets of the bind.
// The IPv6 socket is handled first; the first failure is returned
// immediately, in which case the remaining socket is left untouched.
func (bind *Bind) SetMark(value uint32) error {
	mark := int(value)
	for _, fd := range [2]int{bind.sock6, bind.sock4} {
		if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, mark); err != nil {
			return err
		}
	}
	return nil
}
// Close closes both sockets. Both descriptors are always closed, even if
// the first close fails; the IPv6 error takes precedence over the IPv4 one
// when reporting.
func (bind *Bind) Close() error {
	err6 := unix.Close(bind.sock6)
	err4 := unix.Close(bind.sock4)
	if err6 == nil {
		return err4
	}
	return err6
}
// ReceiveIPv6 reads one datagram from the IPv6 socket into buff by
// delegating to receive6. It returns the number of bytes received.
// receive6 stores the sender address in end.dst and, when the kernel
// supplies packet info, the local address and interface in end.src.
func (bind *Bind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) {
return receive6(
bind.sock6,
buff,
end,
)
}
// ReceiveIPv4 reads one datagram from the IPv4 socket into buff by
// delegating to receive4. It returns the number of bytes received.
// receive4 stores the sender address in end.dst and, when the kernel
// supplies packet info, the local address and interface in end.src.
func (bind *Bind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) {
return receive4(
bind.sock4,
buff,
end,
)
}
// Send transmits buff to the destination stored in end, choosing the socket
// by the address family of the destination.
//
// The dispatch must look at end.dst, not end.src: the source is only a
// cache that is zeroed by ClearSrc (for instance right after Set), so its
// family is not a reliable indicator of which socket to use.
//
// An error is returned if the destination family is neither AF_INET6 nor
// AF_INET, or if the underlying send fails.
func (bind *Bind) Send(buff []byte, end *Endpoint) error {
	switch end.dst.Family {
	case unix.AF_INET6:
		return send6(bind.sock6, end, buff)
	case unix.AF_INET:
		return send4(bind.sock4, end, buff)
	default:
		return errors.New("Unknown address family of destination")
	}
}
// sockaddrToString formats a raw socket address as "ip:port" (or
// "[ip]:port" for IPv6) using net.UDPAddr.String.
//
// addr is interpreted according to its Family field: AF_INET6 uses the
// RawSockaddrInet6 layout directly, AF_INET reinterprets the same memory as
// RawSockaddrInet4. Any other family yields "<unknown address family>".
//
// The port field of a raw sockaddr is kept in network byte order, so it is
// decoded byte-wise instead of being read as a host-order integer.
func sockaddrToString(addr unix.RawSockaddrInet6) string {
	// hostPort decodes a big-endian (network order) port field.
	hostPort := func(raw *uint16) int {
		b := (*[2]byte)(unsafe.Pointer(raw))
		return int(b[0])<<8 | int(b[1])
	}

	var udpAddr net.UDPAddr
	switch addr.Family {
	case unix.AF_INET6:
		udpAddr.Port = hostPort(&addr.Port)
		udpAddr.IP = addr.Addr[:]
		return udpAddr.String()

	case unix.AF_INET:
		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
		udpAddr.Port = hostPort(&ptr.Port)
		udpAddr.IP = net.IPv4(
			ptr.Addr[0],
			ptr.Addr[1],
			ptr.Addr[2],
			ptr.Addr[3],
		)
		return udpAddr.String()

	default:
		return "<unknown address family>"
	}
}
// DestinationIP returns the IP address of the destination, or nil when the
// destination family is neither AF_INET6 nor AF_INET.
//
// For IPv6 the returned slice aliases the address bytes stored inside the
// endpoint; for IPv4 a freshly built net.IP is returned.
func (end *Endpoint) DestinationIP() net.IP {
	if end.dst.Family == unix.AF_INET6 {
		return end.dst.Addr[:]
	}
	if end.dst.Family != unix.AF_INET {
		return nil
	}

	// IPv4 destinations are stored using the RawSockaddrInet4 layout.
	v4 := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
	a, b, c, d := v4.Addr[0], v4.Addr[1], v4.Addr[2], v4.Addr[3]
	return net.IPv4(a, b, c, d)
}
// SourceToBytes exposes the raw memory of the cached source address as a
// byte slice of length unix.SizeofSockaddrInet6. The slice aliases the
// endpoint's storage; it is a view, not a copy.
func (end *Endpoint) SourceToBytes() []byte {
	raw := (*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&end.src))
	return raw[:]
}
// SourceToString formats the cached source address with sockaddrToString.
func (end *Endpoint) SourceToString() string {
return sockaddrToString(end.src)
}
// DestinationToString formats the destination address with
// sockaddrToString.
func (end *Endpoint) DestinationToString() string {
return sockaddrToString(end.dst)
}
// ClearSrc resets the cached source address to the zero value, which also
// zeroes its Family field.
func (end *Endpoint) ClearSrc() {
end.src = unix.RawSockaddrInet6{}
}
// zoneToUint32 converts an IPv6 zone identifier into a numeric scope id.
//
// An empty zone maps to 0. Otherwise the zone is first looked up as a
// network interface name; if no such interface is found, it is parsed as a
// base-10 unsigned 32-bit number. The parse error, if any, is returned.
func zoneToUint32(zone string) (uint32, error) {
	if len(zone) == 0 {
		return 0, nil
	}

	iface, err := net.InterfaceByName(zone)
	if err == nil {
		return uint32(iface.Index), nil
	}

	// Not an interface name: fall back to a literal numeric zone.
	index, err := strconv.ParseUint(zone, 10, 32)
	return uint32(index), err
}
2017-10-08 22:03:32 +02:00
func create4(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return -1, 0, err
}
addr := unix.SockaddrInet4{
Port: int(port),
}
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IP,
unix.IP_PKTINFO,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
}
2017-10-08 22:03:32 +02:00
return fd, uint16(addr.Port), err
}
2017-10-08 22:03:32 +02:00
func create6(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return -1, 0, err
}
// set sockopts and bind
addr := unix.SockaddrInet6{
Port: int(port),
}
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_RECVPKTINFO,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_V6ONLY,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
}
2017-10-08 22:03:32 +02:00
return fd, uint16(addr.Port), err
}
func (end *Endpoint) Set(s string) error {
addr, err := parseEndpoint(s)
if err != nil {
return err
}
ipv6 := addr.IP.To16()
if ipv6 != nil {
zone, err := zoneToUint32(addr.Zone)
if err != nil {
return err
}
2017-10-08 22:03:32 +02:00
dst := &end.dst
dst.Family = unix.AF_INET6
dst.Port = uint16(addr.Port)
dst.Flowinfo = 0
dst.Scope_id = zone
copy(dst.Addr[:], ipv6[:])
end.ClearSrc()
return nil
}
ipv4 := addr.IP.To4()
if ipv4 != nil {
2017-10-08 22:03:32 +02:00
dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
dst.Family = unix.AF_INET
dst.Port = uint16(addr.Port)
dst.Zero = [8]byte{}
copy(dst.Addr[:], ipv4)
end.ClearSrc()
return nil
}
return errors.New("Failed to recognize IP address format")
}
2017-10-08 22:03:32 +02:00
func send6(sock int, end *Endpoint, buff []byte) error {
// construct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
Len: unix.SizeofInet6Pktinfo,
},
unix.Inet6Pktinfo{
2017-10-08 22:03:32 +02:00
Addr: end.src.Addr,
Ifindex: end.src.Scope_id,
},
}
msghdr := unix.Msghdr{
Iov: &iovec,
Iovlen: 1,
Name: (*byte)(unsafe.Pointer(&end.dst)),
Namelen: unix.SizeofSockaddrInet6,
Control: (*byte)(unsafe.Pointer(&cmsg)),
}
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
// sendmsg(sock, &msghdr, 0)
_, _, errno := unix.Syscall(
unix.SYS_SENDMSG,
2017-10-08 22:03:32 +02:00
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
if errno == unix.EINVAL {
end.ClearSrc()
}
return errno
}
2017-10-08 22:03:32 +02:00
func send4(sock int, end *Endpoint, buff []byte) error {
// construct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
2017-10-08 22:03:32 +02:00
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
Len: unix.SizeofInet4Pktinfo,
},
unix.Inet4Pktinfo{
2017-10-08 22:03:32 +02:00
Spec_dst: src4.src.Addr,
Ifindex: src4.Ifindex,
},
}
msghdr := unix.Msghdr{
Iov: &iovec,
Iovlen: 1,
Name: (*byte)(unsafe.Pointer(&end.dst)),
Namelen: unix.SizeofSockaddrInet4,
Control: (*byte)(unsafe.Pointer(&cmsg)),
}
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
// sendmsg(sock, &msghdr, 0)
_, _, errno := unix.Syscall(
unix.SYS_SENDMSG,
2017-10-08 22:03:32 +02:00
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
if errno == unix.EINVAL {
end.ClearSrc()
}
return errno
}
2017-10-08 22:03:32 +02:00
func receive4(sock int, buff []byte, end *Endpoint) (int, error) {
// contruct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}
var msghdr unix.Msghdr
msghdr.Iov = &iovec
msghdr.Iovlen = 1
msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
msghdr.Namelen = unix.SizeofSockaddrInet4
msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
// recvmsg(sock, &mskhdr, 0)
size, _, errno := unix.Syscall(
unix.SYS_RECVMSG,
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
if errno != 0 {
return 0, errno
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
2017-10-08 22:03:32 +02:00
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
src4.src.Family = unix.AF_INET
src4.src.Addr = cmsg.pktinfo.Spec_dst
src4.Ifindex = cmsg.pktinfo.Ifindex
}
return int(size), nil
}
2017-10-08 22:03:32 +02:00
func receive6(sock int, buff []byte, end *Endpoint) (int, error) {
// contruct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}
var msg unix.Msghdr
msg.Iov = &iovec
msg.Iovlen = 1
msg.Name = (*byte)(unsafe.Pointer(&end.dst))
msg.Namelen = uint32(unix.SizeofSockaddrInet6)
msg.Control = (*byte)(unsafe.Pointer(&cmsg))
msg.SetControllen(int(unsafe.Sizeof(cmsg)))
// recvmsg(sock, &mskhdr, 0)
size, _, errno := unix.Syscall(
unix.SYS_RECVMSG,
uintptr(sock),
uintptr(unsafe.Pointer(&msg)),
0,
)
if errno != 0 {
return 0, errno
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
2017-10-08 22:03:32 +02:00
end.src.Family = unix.AF_INET6
end.src.Addr = cmsg.pktinfo.Addr
end.src.Scope_id = cmsg.pktinfo.Ifindex
}
return int(size), nil
}