Clear src cache if route changes to new ifindex
This commit is contained in:
		
							parent
							
								
									92261b770f
								
							
						
					
					
						commit
						b34604245e
					
				
							
								
								
									
										160
									
								
								conn_linux.go
									
									
									
									
									
								
							
							
						
						
									
										160
									
								
								conn_linux.go
									
									
									
									
									
								
							@ -53,12 +53,15 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NativeBind struct {
 | 
			
		||||
	sock4 int
 | 
			
		||||
	sock6 int
 | 
			
		||||
	sock4        int
 | 
			
		||||
	sock6        int
 | 
			
		||||
	netlinkSock  int
 | 
			
		||||
	lastEndpoint *NativeEndpoint
 | 
			
		||||
	lastMark     uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Endpoint = (*NativeEndpoint)(nil)
 | 
			
		||||
var _ Bind = NativeBind{}
 | 
			
		||||
var _ Bind = (*NativeBind)(nil)
 | 
			
		||||
 | 
			
		||||
func CreateEndpoint(s string) (Endpoint, error) {
 | 
			
		||||
	var end NativeEndpoint
 | 
			
		||||
@ -95,23 +98,50 @@ func CreateEndpoint(s string) (Endpoint, error) {
 | 
			
		||||
	return nil, errors.New("Invalid IP address")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateBind(port uint16) (Bind, uint16, error) {
 | 
			
		||||
func createNetlinkRouteSocket() (int, error) {
 | 
			
		||||
	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return -1, err
 | 
			
		||||
	}
 | 
			
		||||
	saddr := &unix.SockaddrNetlink{
 | 
			
		||||
		Family: unix.AF_NETLINK,
 | 
			
		||||
		Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
 | 
			
		||||
	}
 | 
			
		||||
	err = unix.Bind(sock, saddr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		unix.Close(sock)
 | 
			
		||||
		return -1, err
 | 
			
		||||
	}
 | 
			
		||||
	return sock, nil
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateBind(port uint16) (*NativeBind, uint16, error) {
 | 
			
		||||
	var err error
 | 
			
		||||
	var bind NativeBind
 | 
			
		||||
 | 
			
		||||
	bind.netlinkSock, err = createNetlinkRouteSocket()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go bind.routineRouteListener()
 | 
			
		||||
 | 
			
		||||
	bind.sock6, port, err = create6(port)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		unix.Close(bind.netlinkSock)
 | 
			
		||||
		return nil, port, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	bind.sock4, port, err = create4(port)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		unix.Close(bind.netlinkSock)
 | 
			
		||||
		unix.Close(bind.sock6)
 | 
			
		||||
	}
 | 
			
		||||
	return bind, port, err
 | 
			
		||||
	return &bind, port, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind NativeBind) SetMark(value uint32) error {
 | 
			
		||||
func (bind *NativeBind) SetMark(value uint32) error {
 | 
			
		||||
	err := unix.SetsockoptInt(
 | 
			
		||||
		bind.sock6,
 | 
			
		||||
		unix.SOL_SOCKET,
 | 
			
		||||
@ -123,12 +153,19 @@ func (bind NativeBind) SetMark(value uint32) error {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return unix.SetsockoptInt(
 | 
			
		||||
	err = unix.SetsockoptInt(
 | 
			
		||||
		bind.sock4,
 | 
			
		||||
		unix.SOL_SOCKET,
 | 
			
		||||
		unix.SO_MARK,
 | 
			
		||||
		int(value),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	bind.lastMark = value
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func closeUnblock(fd int) error {
 | 
			
		||||
@ -137,16 +174,20 @@ func closeUnblock(fd int) error {
 | 
			
		||||
	return unix.Close(fd)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind NativeBind) Close() error {
 | 
			
		||||
func (bind *NativeBind) Close() error {
 | 
			
		||||
	err1 := closeUnblock(bind.sock6)
 | 
			
		||||
	err2 := closeUnblock(bind.sock4)
 | 
			
		||||
	err3 := closeUnblock(bind.netlinkSock)
 | 
			
		||||
	if err1 != nil {
 | 
			
		||||
		return err1
 | 
			
		||||
	}
 | 
			
		||||
	return err2
 | 
			
		||||
	if err2 != nil {
 | 
			
		||||
		return err2
 | 
			
		||||
	}
 | 
			
		||||
	return err3
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
 | 
			
		||||
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
 | 
			
		||||
	var end NativeEndpoint
 | 
			
		||||
	n, err := receive6(
 | 
			
		||||
		bind.sock6,
 | 
			
		||||
@ -156,17 +197,18 @@ func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
 | 
			
		||||
	return n, &end, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
 | 
			
		||||
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
 | 
			
		||||
	var end NativeEndpoint
 | 
			
		||||
	n, err := receive4(
 | 
			
		||||
		bind.sock4,
 | 
			
		||||
		buff,
 | 
			
		||||
		&end,
 | 
			
		||||
	)
 | 
			
		||||
	bind.lastEndpoint = &end
 | 
			
		||||
	return n, &end, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind NativeBind) Send(buff []byte, end Endpoint) error {
 | 
			
		||||
func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
 | 
			
		||||
	nend := end.(*NativeEndpoint)
 | 
			
		||||
	if !nend.isV6 {
 | 
			
		||||
		return send4(bind.sock4, nend, buff)
 | 
			
		||||
@ -506,3 +548,97 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 | 
			
		||||
 | 
			
		||||
	return size, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bind *NativeBind) routineRouteListener() {
 | 
			
		||||
	// TODO: this function doesn't lock the endpoint it modifies
 | 
			
		||||
 | 
			
		||||
	for msg := make([]byte, 1<<16); ; {
 | 
			
		||||
		msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
 | 
			
		||||
 | 
			
		||||
			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
 | 
			
		||||
 | 
			
		||||
			if uint(hdr.Len) > uint(len(remain)) {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			switch hdr.Type {
 | 
			
		||||
			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
 | 
			
		||||
 | 
			
		||||
				if bind.lastEndpoint == nil || bind.lastEndpoint.isV6 || bind.lastEndpoint.src4().ifindex == 0 {
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if hdr.Seq == 0xff {
 | 
			
		||||
					if uint(len(remain)) < uint(hdr.Len) {
 | 
			
		||||
						break
 | 
			
		||||
					}
 | 
			
		||||
					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
 | 
			
		||||
						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
 | 
			
		||||
						for {
 | 
			
		||||
							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
 | 
			
		||||
								break
 | 
			
		||||
							}
 | 
			
		||||
							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
 | 
			
		||||
							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
 | 
			
		||||
								break
 | 
			
		||||
							}
 | 
			
		||||
							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
 | 
			
		||||
								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
 | 
			
		||||
								if uint32(bind.lastEndpoint.src4().ifindex) != ifidx {
 | 
			
		||||
									bind.lastEndpoint.ClearSrc()
 | 
			
		||||
								}
 | 
			
		||||
							}
 | 
			
		||||
							attr = attr[attrhdr.Len:]
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				nlmsg := struct {
 | 
			
		||||
					hdr     unix.NlMsghdr
 | 
			
		||||
					msg     unix.RtMsg
 | 
			
		||||
					dsthdr  unix.RtAttr
 | 
			
		||||
					dst     [4]byte
 | 
			
		||||
					srchdr  unix.RtAttr
 | 
			
		||||
					src     [4]byte
 | 
			
		||||
					markhdr unix.RtAttr
 | 
			
		||||
					mark    uint32
 | 
			
		||||
				}{
 | 
			
		||||
					unix.NlMsghdr{
 | 
			
		||||
						Type:  uint16(unix.RTM_GETROUTE),
 | 
			
		||||
						Flags: unix.NLM_F_REQUEST,
 | 
			
		||||
						Seq:   0xff,
 | 
			
		||||
					},
 | 
			
		||||
					unix.RtMsg{
 | 
			
		||||
						Family:  unix.AF_INET,
 | 
			
		||||
						Dst_len: 32,
 | 
			
		||||
						Src_len: 32,
 | 
			
		||||
					},
 | 
			
		||||
					unix.RtAttr{
 | 
			
		||||
						Len:  8,
 | 
			
		||||
						Type: unix.RTA_DST,
 | 
			
		||||
					},
 | 
			
		||||
					bind.lastEndpoint.dst4().Addr,
 | 
			
		||||
					unix.RtAttr{
 | 
			
		||||
						Len:  8,
 | 
			
		||||
						Type: unix.RTA_SRC,
 | 
			
		||||
					},
 | 
			
		||||
					bind.lastEndpoint.src4().src,
 | 
			
		||||
					unix.RtAttr{
 | 
			
		||||
						Len:  8,
 | 
			
		||||
						Type: 0x10, //unix.RTA_MARK  TODO: add this to x/sys/unix
 | 
			
		||||
					},
 | 
			
		||||
					uint32(bind.lastMark),
 | 
			
		||||
				}
 | 
			
		||||
				nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
 | 
			
		||||
				unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
 | 
			
		||||
			}
 | 
			
		||||
			remain = remain[hdr.Len:]
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -79,7 +79,6 @@ func (tun *NativeTun) RoutineNetlinkListener() {
 | 
			
		||||
	defer unix.Close(sock)
 | 
			
		||||
	saddr := &unix.SockaddrNetlink{
 | 
			
		||||
		Family: unix.AF_NETLINK,
 | 
			
		||||
		Pid:    uint32(os.Getpid()),
 | 
			
		||||
		Groups: uint32(groups),
 | 
			
		||||
	}
 | 
			
		||||
	err = unix.Bind(sock, saddr)
 | 
			
		||||
@ -90,7 +89,9 @@ func (tun *NativeTun) RoutineNetlinkListener() {
 | 
			
		||||
 | 
			
		||||
	// TODO: This function never actually exits in response to anything,
 | 
			
		||||
	// a go routine that goes forever. We'll want to fix that if this is
 | 
			
		||||
	// to ever be used as any sort of library.
 | 
			
		||||
	// to ever be used as any sort of library. See what we've done with
 | 
			
		||||
	// calling shutdown() on the netlink socket in conn_linux.go, and
 | 
			
		||||
	// change this to be more like that.
 | 
			
		||||
 | 
			
		||||
	for msg := make([]byte, 1<<16); ; {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user