diff --git a/conn_default.go b/conn_default.go index 14ed56c..92135cb 100644 --- a/conn_default.go +++ b/conn_default.go @@ -11,7 +11,9 @@ package main import ( "golang.org/x/sys/unix" "net" + "os" "runtime" + "syscall" ) /* This code is meant to be a temporary solution @@ -87,6 +89,18 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { return conn, uaddr.Port, nil } +func extractErrno(err error) error { + opErr, ok := err.(*net.OpError) + if !ok { + return nil + } + syscallErr, ok := opErr.Err.(*os.SyscallError) + if !ok { + return nil + } + return syscallErr.Err +} + func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { var err error var bind NativeBind @@ -94,13 +108,15 @@ func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { port := int(uport) bind.ipv4, port, err = listenNet("udp4", port) - if err != nil { + if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { return nil, 0, err } bind.ipv6, port, err = listenNet("udp6", port) - if err != nil { + if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { + return nil, 0, err bind.ipv4.Close() + bind.ipv4 = nil return nil, 0, err } @@ -108,8 +124,13 @@ func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { } func (bind *NativeBind) Close() error { - err1 := bind.ipv4.Close() - err2 := bind.ipv6.Close() + var err1, err2 error + if bind.ipv4 != nil { + err1 = bind.ipv4.Close() + } + if bind.ipv6 != nil { + err2 = bind.ipv6.Close() + } if err1 != nil { return err1 } @@ -117,6 +138,9 @@ func (bind *NativeBind) Close() error { } func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + if bind.ipv4 == nil { + return 0, nil, syscall.EAFNOSUPPORT + } n, endpoint, err := bind.ipv4.ReadFromUDP(buff) if endpoint != nil { endpoint.IP = endpoint.IP.To4() @@ -125,6 +149,9 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { } func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + if bind.ipv6 == nil { + return 0, nil, syscall.EAFNOSUPPORT + } n, endpoint, err := bind.ipv6.ReadFromUDP(buff) return n, (*NativeEndpoint)(endpoint), err } @@ -133,8 +160,14 @@ func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error { var err error nend := endpoint.(*NativeEndpoint) if nend.IP.To4() != nil { + if bind.ipv4 == nil { + return syscall.EAFNOSUPPORT + } _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) } else { + if bind.ipv6 == nil { + return syscall.EAFNOSUPPORT + } _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) } return err @@ -157,31 +190,29 @@ func (bind *NativeBind) SetMark(mark uint32) error { if fwmarkIoctl == 0 { return nil } - fd4, err1 := bind.ipv4.SyscallConn() - fd6, err2 := bind.ipv6.SyscallConn() - if err1 != nil { - return err1 + if bind.ipv4 != nil { + fd, err := bind.ipv4.SyscallConn() + if err != nil { + return err + } + err = fd.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) + }) + if err != nil { + return err + } } - if err2 != nil { - return err2 - } - err3 := fd4.Control(func(fd uintptr) { - err1 = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) - }) - err4 := fd6.Control(func(fd uintptr) { - err2 = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) - }) - if err1 != nil { - return err1 - } - if err2 != nil { - return err2 - } - if err3 != nil { - return err3 - } - if err4 != nil { - return err4 + if bind.ipv6 != nil { + fd, err := bind.ipv6.SyscallConn() + if err != nil { + return err + } + err = fd.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) + }) + if err != nil { + return err + } } return nil } diff --git a/conn_linux.go b/conn_linux.go index 0227f04..2b15d05 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -24,6 +24,7 @@ import ( "net" "strconv" "sync" + "syscall" "unsafe" ) @@ -140,40 +141,45 @@ func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) { go bind.routineRouteListener(device) bind.sock6, port, err = create6(port) - if err != nil { + if err != nil && err != syscall.EAFNOSUPPORT { bind.netlinkCancel.Cancel() - return nil, port, err + return nil, 0, err } bind.sock4, port, err = create4(port) - if err != nil { + if err != nil && err != syscall.EAFNOSUPPORT { bind.netlinkCancel.Cancel() unix.Close(bind.sock6) + return nil, 0, err } - return &bind, port, err + return &bind, port, nil } func (bind *NativeBind) SetMark(value uint32) error { - err := unix.SetsockoptInt( - bind.sock6, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) + if bind.sock6 != -1 { + err := unix.SetsockoptInt( + bind.sock6, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) - if err != nil { - return err + if err != nil { + return err + } } - err = unix.SetsockoptInt( - bind.sock4, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) + if bind.sock4 != -1 { + err := unix.SetsockoptInt( + bind.sock4, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) - if err != nil { - return err + if err != nil { + return err + } } bind.lastMark = value @@ -187,9 +193,14 @@ func closeUnblock(fd int) error { } func (bind *NativeBind) Close() error { - err1 := closeUnblock(bind.sock6) - err2 := closeUnblock(bind.sock4) - err3 := bind.netlinkCancel.Cancel() + var err1, err2, err3 error + if bind.sock6 != -1 { + err1 = closeUnblock(bind.sock6) + } + if bind.sock4 != -1 { + err2 = closeUnblock(bind.sock4) + } + err3 = bind.netlinkCancel.Cancel() if err1 != nil { return err1 @@ -202,6 +213,9 @@ func (bind *NativeBind) Close() error { func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { var end NativeEndpoint + if bind.sock6 == -1 { + return 0, nil, syscall.EAFNOSUPPORT + } n, err := receive6( bind.sock6, buff, @@ -212,6 +226,9 @@ func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { var end NativeEndpoint + if bind.sock4 == -1 { + return 0, nil, syscall.EAFNOSUPPORT + } n, err := receive4( bind.sock4, buff, @@ -223,8 +240,14 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (bind *NativeBind) Send(buff []byte, end Endpoint) error { nend := end.(*NativeEndpoint) if !nend.isV6 { + if bind.sock4 == -1 { + return syscall.EAFNOSUPPORT + } return send4(bind.sock4, nend, buff) } else { + if bind.sock6 == -1 { + return syscall.EAFNOSUPPORT + } return send6(bind.sock6, nend, buff) } }