diff --git a/src/conn.go b/src/conn.go index 74bb075..5b40a23 100644 --- a/src/conn.go +++ b/src/conn.go @@ -24,11 +24,9 @@ type Bind interface { */ type Endpoint interface { ClearSrc() // clears the source address - ClearDst() // clears the destination address SrcToString() string // returns the local source address (ip:port) DstToString() string // returns the destination address (ip:port) DstToBytes() []byte // used for mac2 cookie calculations - SetDst(string) error // used for manually setting the endpoint (uapi) DstIP() net.IP SrcIP() net.IP } @@ -92,7 +90,7 @@ func UpdateUDPListener(device *Device) error { // bind to new port var err error - netc.bind, netc.port, err = CreateUDPBind(netc.port) + netc.bind, netc.port, err = CreateBind(netc.port) if err != nil { netc.bind = nil return err diff --git a/src/conn_default.go b/src/conn_default.go index 31cab5c..34168c6 100644 --- a/src/conn_default.go +++ b/src/conn_default.go @@ -13,11 +13,68 @@ import ( * See conn_linux.go for an implementation on the linux platform. */ -type Endpoint *net.UDPAddr +type NativeBind struct { + ipv4 *net.UDPConn + ipv6 *net.UDPConn +} -type NativeBind *net.UDPConn +type NativeEndpoint net.UDPAddr -func CreateUDPBind(port uint16) (UDPBind, uint16, error) { +var _ Bind = (*NativeBind)(nil) +var _ Endpoint = (*NativeEndpoint)(nil) + +func CreateEndpoint(s string) (Endpoint, error) { + addr, err := parseEndpoint(s) + return (addr).(*NativeEndpoint), err +} + +func (_ *NativeEndpoint) ClearSrc() {} + +func (e *NativeEndpoint) DstIP() net.IP { + return (*net.UDPAddr)(e).IP +} + +func (e *NativeEndpoint) SrcIP() net.IP { + return nil // not supported +} + +func (e *NativeEndpoint) DstToBytes() []byte { + addr := (*net.UDPAddr)(e) + out := addr.IP.([]byte) + out = append(out, byte(addr.Port&0xff)) + out = append(out, byte((addr.Port>>8)&0xff)) + return out +} + +func (e *NativeEndpoint) DstToString() string { + return (*net.UDPAddr)(e).String() +} + +func (e *NativeEndpoint) SrcToString() string { + return "" +} + +func listenNet(net string, port int) (*net.UDPConn, int, error) { + + // listen + + conn, err := net.ListenUDP("udp", &UDPAddr{Port: port}) + if err != nil { + return nil, 0, err + } + + // retrieve port + + laddr := conn.LocalAddr() + uaddr, _ = net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), + ) + + return conn, uaddr.Port, nil +} + +func CreateBind(port uint16) (Bind, uint16, error) { // listen @@ -38,9 +95,3 @@ func CreateUDPBind(port uint16) (UDPBind, uint16, error) { ) return uaddr.Port } - -func (_ Endpoint) ClearSrc() {} - -func SetMark(conn *net.UDPConn, value uint32) error { - return nil -} diff --git a/src/conn_linux.go b/src/conn_linux.go index 46f873f..cdba74f 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -50,11 +50,44 @@ func ntohs(val uint16) uint16 { return binary.BigEndian.Uint16((*tmp)[:]) } -func NewEndpoint() Endpoint { - return &NativeEndpoint{} +func CreateEndpoint(s string) (Endpoint, error) { + var end NativeEndpoint + addr, err := parseEndpoint(s) + if err != nil { + return nil, err + } + + ipv4 := addr.IP.To4() + if ipv4 != nil { + dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + dst.Family = unix.AF_INET + dst.Port = htons(uint16(addr.Port)) + dst.Zero = [8]byte{} + copy(dst.Addr[:], ipv4) + end.ClearSrc() + return &end, nil + } + + ipv6 := addr.IP.To16() + if ipv6 != nil { + zone, err := zoneToUint32(addr.Zone) + if err != nil { + return nil, err + } + dst := &end.dst + dst.Family = unix.AF_INET6 + dst.Port = htons(uint16(addr.Port)) + dst.Flowinfo = 0 + dst.Scope_id = zone + copy(dst.Addr[:], ipv6[:]) + end.ClearSrc() + return &end, nil + } + + return nil, errors.New("Failed to recognize IP address format") } -func CreateUDPBind(port uint16) (Bind, uint16, error) { +func CreateBind(port uint16) (Bind, uint16, error) { var err error var bind NativeBind @@ -325,42 +358,6 @@ func create6(port uint16) (int, uint16, error) { return fd, uint16(addr.Port), err } -func (end *NativeEndpoint) SetDst(s string) error { - addr, err := parseEndpoint(s) - if err != nil { - return err - } - - ipv4 := addr.IP.To4() - if ipv4 != nil { - dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) - dst.Family = unix.AF_INET - dst.Port = htons(uint16(addr.Port)) - dst.Zero = [8]byte{} - copy(dst.Addr[:], ipv4) - end.ClearSrc() - return nil - } - - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) - if err != nil { - return err - } - dst := &end.dst - dst.Family = unix.AF_INET6 - dst.Port = htons(uint16(addr.Port)) - dst.Flowinfo = 0 - dst.Scope_id = zone - copy(dst.Addr[:], ipv6[:]) - end.ClearSrc() - return nil - } - - return errors.New("Failed to recognize IP address format") -} - func send6(sock int, end *NativeEndpoint, buff []byte) error { // construct message header diff --git a/src/uapi.go b/src/uapi.go index 670ecc4..dc8be66 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -260,9 +260,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { err := func() error { peer.mutex.Lock() defer peer.mutex.Unlock() - - endpoint := NewEndpoint() - if err := endpoint.SetDst(value); err != nil { + endpoint, err := CreateEndpoint(value) + if err != nil { return err } peer.endpoint = endpoint