/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
 *
 * This implements userspace semantics of "sticky sockets", modeled after
 * WireGuard's kernelspace implementation. This is more or less a straight port
 * of the sticky-sockets.c example code:
 * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
 *
 * Currently there is no way to achieve this within the net package:
 * See e.g. https://github.com/golang/go/issues/17930
 * So this code is remains platform dependent.
 */

package device

import (
	"sync"
	"unsafe"

	"golang.org/x/sys/unix"

	"golang.zx2c4.com/wireguard/conn"
	"golang.zx2c4.com/wireguard/rwcancel"
)

func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
	if _, ok := bind.(*conn.LinuxSocketBind); !ok {
		return nil, nil
	}

	netlinkSock, err := createNetlinkRouteSocket()
	if err != nil {
		return nil, err
	}
	netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
	if err != nil {
		unix.Close(netlinkSock)
		return nil, err
	}

	go device.routineRouteListener(bind, netlinkSock, netlinkCancel)

	return netlinkCancel, nil
}

func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
	type peerEndpointPtr struct {
		peer     *Peer
		endpoint *conn.Endpoint
	}
	var reqPeer map[uint32]peerEndpointPtr
	var reqPeerLock sync.Mutex

	defer netlinkCancel.Close()
	defer unix.Close(netlinkSock)

	for msg := make([]byte, 1<<16); ; {
		var err error
		var msgn int
		for {
			msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
			if err == nil || !rwcancel.RetryAfterError(err) {
				break
			}
			if !netlinkCancel.ReadyRead() {
				return
			}
		}
		if err != nil {
			return
		}

		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {

			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))

			if uint(hdr.Len) > uint(len(remain)) {
				break
			}

			switch hdr.Type {
			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
				if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
					if uint(len(remain)) < uint(hdr.Len) {
						break
					}
					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
						for {
							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
								break
							}
							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
								break
							}
							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
								reqPeerLock.Lock()
								if reqPeer == nil {
									reqPeerLock.Unlock()
									break
								}
								pePtr, ok := reqPeer[hdr.Seq]
								reqPeerLock.Unlock()
								if !ok {
									break
								}
								pePtr.peer.Lock()
								if &pePtr.peer.endpoint != pePtr.endpoint {
									pePtr.peer.Unlock()
									break
								}
								if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
									pePtr.peer.Unlock()
									break
								}
								pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
								pePtr.peer.Unlock()
							}
							attr = attr[attrhdr.Len:]
						}
					}
					break
				}
				reqPeerLock.Lock()
				reqPeer = make(map[uint32]peerEndpointPtr)
				reqPeerLock.Unlock()
				go func() {
					device.peers.RLock()
					i := uint32(1)
					for _, peer := range device.peers.keyMap {
						peer.RLock()
						if peer.endpoint == nil {
							peer.RUnlock()
							continue
						}
						nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
						if nativeEP == nil {
							peer.RUnlock()
							continue
						}
						if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
							peer.RUnlock()
							break
						}
						nlmsg := struct {
							hdr     unix.NlMsghdr
							msg     unix.RtMsg
							dsthdr  unix.RtAttr
							dst     [4]byte
							srchdr  unix.RtAttr
							src     [4]byte
							markhdr unix.RtAttr
							mark    uint32
						}{
							unix.NlMsghdr{
								Type:  uint16(unix.RTM_GETROUTE),
								Flags: unix.NLM_F_REQUEST,
								Seq:   i,
							},
							unix.RtMsg{
								Family:  unix.AF_INET,
								Dst_len: 32,
								Src_len: 32,
							},
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_DST,
							},
							nativeEP.Dst4().Addr,
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_SRC,
							},
							nativeEP.Src4().Src,
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_MARK,
							},
							device.net.fwmark,
						}
						nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
						reqPeerLock.Lock()
						reqPeer[i] = peerEndpointPtr{
							peer:     peer,
							endpoint: &peer.endpoint,
						}
						reqPeerLock.Unlock()
						peer.RUnlock()
						i++
						_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
						if err != nil {
							break
						}
					}
					device.peers.RUnlock()
				}()
			}
			remain = remain[hdr.Len:]
		}
	}
}

func createNetlinkRouteSocket() (int, error) {
	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
	if err != nil {
		return -1, err
	}
	saddr := &unix.SockaddrNetlink{
		Family: unix.AF_NETLINK,
		Groups: unix.RTMGRP_IPV4_ROUTE,
	}
	err = unix.Bind(sock, saddr)
	if err != nil {
		unix.Close(sock)
		return -1, err
	}
	return sock, nil
}