diff --git a/conn.go b/conn.go index 4b347ec..92f4cfe 100644 --- a/conn.go +++ b/conn.go @@ -123,7 +123,7 @@ func (device *Device) BindUpdate() error { var err error netc := &device.net - netc.bind, netc.port, err = CreateBind(netc.port) + netc.bind, netc.port, err = CreateBind(netc.port, device) if err != nil { netc.bind = nil netc.port = 0 diff --git a/conn_default.go b/conn_default.go index 047d5f6..7556210 100644 --- a/conn_default.go +++ b/conn_default.go @@ -81,7 +81,7 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { return conn, uaddr.Port, nil } -func CreateBind(uport uint16) (Bind, uint16, error) { +func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { var err error var bind NativeBind diff --git a/conn_linux.go b/conn_linux.go index a428138..2b920bf 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -55,11 +55,10 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { } type NativeBind struct { - sock4 int - sock6 int - netlinkSock int - lastEndpoint *NativeEndpoint - lastMark uint32 + sock4 int + sock6 int + netlinkSock int + lastMark uint32 } var _ Endpoint = (*NativeEndpoint)(nil) @@ -118,7 +117,7 @@ func createNetlinkRouteSocket() (int, error) { } -func CreateBind(port uint16) (*NativeBind, uint16, error) { +func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) { var err error var bind NativeBind @@ -127,7 +126,7 @@ func CreateBind(port uint16) (*NativeBind, uint16, error) { return nil, 0, err } - go bind.routineRouteListener() + go bind.routineRouteListener(device) bind.sock6, port, err = create6(port) if err != nil { @@ -171,8 +170,8 @@ func (bind *NativeBind) SetMark(value uint32) error { } func closeUnblock(fd int) error { - // shutdown to unblock readers - unix.Shutdown(fd, unix.SHUT_RD) + // shutdown to unblock readers and writers + unix.Shutdown(fd, unix.SHUT_RDWR) return unix.Close(fd) } @@ -206,7 +205,6 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { buff, &end, ) - bind.lastEndpoint = &end return n, &end, err } @@ -551,8 +549,8 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { return size, nil } -func (bind *NativeBind) routineRouteListener() { - // TODO: this function doesn't lock the endpoint it modifies +func (bind *NativeBind) routineRouteListener(device *Device) { + var reqPeer map[uint32]*Peer for msg := make([]byte, 1<<16); ; { msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) @@ -570,12 +568,7 @@ func (bind *NativeBind) routineRouteListener() { switch hdr.Type { case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: - - if bind.lastEndpoint == nil || bind.lastEndpoint.isV6 || bind.lastEndpoint.src4().ifindex == 0 { - break - } - - if hdr.Seq == 0xff { + if hdr.Seq <= MaxPeers { if uint(len(remain)) < uint(hdr.Len) { break } @@ -591,54 +584,90 @@ func (bind *NativeBind) routineRouteListener() { } if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) - if uint32(bind.lastEndpoint.src4().ifindex) != ifidx { - bind.lastEndpoint.ClearSrc() + if reqPeer == nil { + break } + peer, ok := reqPeer[hdr.Seq] + if !ok { + break + } + peer.mutex.RLock() + if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil { + peer.mutex.RUnlock() + break + } + if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 { + peer.mutex.RUnlock() + break + } + if uint32(peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx { + peer.mutex.RUnlock() + break + } + peer.mutex.RUnlock() + peer.mutex.Lock() + peer.endpoint.(*NativeEndpoint).ClearSrc() + peer.mutex.Unlock() } attr = attr[attrhdr.Len:] } } break } - - nlmsg := struct { - hdr unix.NlMsghdr - msg unix.RtMsg - dsthdr unix.RtAttr - dst [4]byte - srchdr unix.RtAttr - src [4]byte - markhdr unix.RtAttr - mark uint32 - }{ - unix.NlMsghdr{ - Type: uint16(unix.RTM_GETROUTE), - Flags: unix.NLM_F_REQUEST, - Seq: 0xff, - }, - unix.RtMsg{ - Family: unix.AF_INET, - Dst_len: 32, - Src_len: 32, - }, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_DST, - }, - bind.lastEndpoint.dst4().Addr, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_SRC, - }, - bind.lastEndpoint.src4().src, - unix.RtAttr{ - Len: 8, - Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix - }, - uint32(bind.lastMark), - } - nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) - unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) + reqPeer = make(map[uint32]*Peer) + go func() { + device.peers.mutex.RLock() + i := uint32(1) + for _, peer := range device.peers.keyMap { + peer.mutex.RLock() + if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil { + peer.mutex.RUnlock() + continue + } + nlmsg := struct { + hdr unix.NlMsghdr + msg unix.RtMsg + dsthdr unix.RtAttr + dst [4]byte + srchdr unix.RtAttr + src [4]byte + markhdr unix.RtAttr + mark uint32 + }{ + unix.NlMsghdr{ + Type: uint16(unix.RTM_GETROUTE), + Flags: unix.NLM_F_REQUEST, + Seq: i, + }, + unix.RtMsg{ + Family: unix.AF_INET, + Dst_len: 32, + Src_len: 32, + }, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_DST, + }, + peer.endpoint.(*NativeEndpoint).dst4().Addr, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_SRC, + }, + peer.endpoint.(*NativeEndpoint).src4().src, + unix.RtAttr{ + Len: 8, + Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix + }, + uint32(bind.lastMark), + } + nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) + reqPeer[i] = peer + peer.mutex.RUnlock() + i++ + unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) + } + device.peers.mutex.RUnlock() + }() } remain = remain[hdr.Len:] }