3bb8fec7e4
Accept packet vectors for reading and writing in the tun.Device and conn.Bind interfaces, so that the internal plumbing between these interfaces now passes a vector of packets. Vectors move untouched between these interfaces, i.e. if 128 packets are received from conn.Bind.Read(), 128 packets are passed to tun.Device.Write(). There is no internal buffering. Currently, existing implementations are only adjusted to have vectors of length one. Subsequent patches will improve that. Also, as a related fixup, use the unix and windows packages rather than the syscall package when possible. Co-authored-by: James Tucker <james@tailscale.com> Signed-off-by: James Tucker <james@tailscale.com> Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
1051 lines
25 KiB
Go
1051 lines
25 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package netstack
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"golang.zx2c4.com/wireguard/tun"
|
|
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
"gvisor.dev/gvisor/pkg/bufferv2"
|
|
"gvisor.dev/gvisor/pkg/tcpip"
|
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
|
"gvisor.dev/gvisor/pkg/waiter"
|
|
)
|
|
|
|
type netTun struct {
|
|
ep *channel.Endpoint
|
|
stack *stack.Stack
|
|
events chan tun.Event
|
|
incomingPacket chan *bufferv2.View
|
|
mtu int
|
|
dnsServers []netip.Addr
|
|
hasV4, hasV6 bool
|
|
}
|
|
|
|
// Net is the dialing/listening/resolving view of a netTun. CreateNetTUN
// returns it alongside the tun.Device; both refer to the same underlying value.
type Net netTun
|
|
|
|
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
|
|
opts := stack.Options{
|
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
|
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
|
|
HandleLocal: true,
|
|
}
|
|
dev := &netTun{
|
|
ep: channel.New(1024, uint32(mtu), ""),
|
|
stack: stack.New(opts),
|
|
events: make(chan tun.Event, 10),
|
|
incomingPacket: make(chan *bufferv2.View),
|
|
dnsServers: dnsServers,
|
|
mtu: mtu,
|
|
}
|
|
dev.ep.AddNotify(dev)
|
|
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
|
|
if tcpipErr != nil {
|
|
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
|
}
|
|
for _, ip := range localAddresses {
|
|
var protoNumber tcpip.NetworkProtocolNumber
|
|
if ip.Is4() {
|
|
protoNumber = ipv4.ProtocolNumber
|
|
} else if ip.Is6() {
|
|
protoNumber = ipv6.ProtocolNumber
|
|
}
|
|
protoAddr := tcpip.ProtocolAddress{
|
|
Protocol: protoNumber,
|
|
AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
|
|
}
|
|
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
|
if tcpipErr != nil {
|
|
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
|
|
}
|
|
if ip.Is4() {
|
|
dev.hasV4 = true
|
|
} else if ip.Is6() {
|
|
dev.hasV6 = true
|
|
}
|
|
}
|
|
if dev.hasV4 {
|
|
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
|
|
}
|
|
if dev.hasV6 {
|
|
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
|
|
}
|
|
|
|
dev.events <- tun.EventUp
|
|
return dev, (*Net)(dev), nil
|
|
}
|
|
|
|
func (tun *netTun) Name() (string, error) {
|
|
return "go", nil
|
|
}
|
|
|
|
func (tun *netTun) File() *os.File {
|
|
return nil
|
|
}
|
|
|
|
func (tun *netTun) Events() <-chan tun.Event {
|
|
return tun.events
|
|
}
|
|
|
|
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
|
|
view, ok := <-tun.incomingPacket
|
|
if !ok {
|
|
return 0, os.ErrClosed
|
|
}
|
|
|
|
n, err := view.Read(buf[0][offset:])
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
sizes[0] = n
|
|
return 1, nil
|
|
}
|
|
|
|
func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
|
|
for _, buf := range buf {
|
|
packet := buf[offset:]
|
|
if len(packet) == 0 {
|
|
continue
|
|
}
|
|
|
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
|
|
switch packet[0] >> 4 {
|
|
case 4:
|
|
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
|
case 6:
|
|
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
|
|
default:
|
|
return 0, syscall.EAFNOSUPPORT
|
|
}
|
|
}
|
|
return len(buf), nil
|
|
}
|
|
|
|
func (tun *netTun) WriteNotify() {
|
|
pkt := tun.ep.Read()
|
|
if pkt.IsNil() {
|
|
return
|
|
}
|
|
|
|
view := pkt.ToView()
|
|
pkt.DecRef()
|
|
|
|
tun.incomingPacket <- view
|
|
}
|
|
|
|
func (tun *netTun) Close() error {
|
|
tun.stack.RemoveNIC(1)
|
|
|
|
if tun.events != nil {
|
|
close(tun.events)
|
|
}
|
|
|
|
tun.ep.Close()
|
|
|
|
if tun.incomingPacket != nil {
|
|
close(tun.incomingPacket)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (tun *netTun) MTU() (int, error) {
|
|
return tun.mtu, nil
|
|
}
|
|
|
|
func (tun *netTun) BatchSize() int {
|
|
return 1
|
|
}
|
|
|
|
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
|
var protoNumber tcpip.NetworkProtocolNumber
|
|
if endpoint.Addr().Is4() {
|
|
protoNumber = ipv4.ProtocolNumber
|
|
} else {
|
|
protoNumber = ipv6.ProtocolNumber
|
|
}
|
|
return tcpip.FullAddress{
|
|
NIC: 1,
|
|
Addr: tcpip.Address(endpoint.Addr().AsSlice()),
|
|
Port: endpoint.Port(),
|
|
}, protoNumber
|
|
}
|
|
|
|
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
|
|
fa, pn := convertToFullAddr(addr)
|
|
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
|
|
}
|
|
|
|
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
|
if addr == nil {
|
|
return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
|
|
}
|
|
ip, _ := netip.AddrFromSlice(addr.IP)
|
|
return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
|
|
}
|
|
|
|
func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
|
|
fa, pn := convertToFullAddr(addr)
|
|
return gonet.DialTCP(net.stack, fa, pn)
|
|
}
|
|
|
|
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
|
if addr == nil {
|
|
return net.DialTCPAddrPort(netip.AddrPort{})
|
|
}
|
|
ip, _ := netip.AddrFromSlice(addr.IP)
|
|
return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
|
|
}
|
|
|
|
func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
|
|
fa, pn := convertToFullAddr(addr)
|
|
return gonet.ListenTCP(net.stack, fa, pn)
|
|
}
|
|
|
|
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
|
|
if addr == nil {
|
|
return net.ListenTCPAddrPort(netip.AddrPort{})
|
|
}
|
|
ip, _ := netip.AddrFromSlice(addr.IP)
|
|
return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
|
|
}
|
|
|
|
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
|
|
var lfa, rfa *tcpip.FullAddress
|
|
var pn tcpip.NetworkProtocolNumber
|
|
if laddr.IsValid() || laddr.Port() > 0 {
|
|
var addr tcpip.FullAddress
|
|
addr, pn = convertToFullAddr(laddr)
|
|
lfa = &addr
|
|
}
|
|
if raddr.IsValid() || raddr.Port() > 0 {
|
|
var addr tcpip.FullAddress
|
|
addr, pn = convertToFullAddr(raddr)
|
|
rfa = &addr
|
|
}
|
|
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
|
}
|
|
|
|
func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
|
|
return net.DialUDPAddrPort(laddr, netip.AddrPort{})
|
|
}
|
|
|
|
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
|
var la, ra netip.AddrPort
|
|
if laddr != nil {
|
|
ip, _ := netip.AddrFromSlice(laddr.IP)
|
|
la = netip.AddrPortFrom(ip, uint16(laddr.Port))
|
|
}
|
|
if raddr != nil {
|
|
ip, _ := netip.AddrFromSlice(raddr.IP)
|
|
ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
|
|
}
|
|
return net.DialUDPAddrPort(la, ra)
|
|
}
|
|
|
|
func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
|
return net.DialUDP(laddr, nil)
|
|
}
|
|
|
|
type PingConn struct {
|
|
laddr PingAddr
|
|
raddr PingAddr
|
|
wq waiter.Queue
|
|
ep tcpip.Endpoint
|
|
deadline *time.Timer
|
|
}
|
|
|
|
// PingAddr is the address type used by PingConn. It wraps a netip.Addr and
// has no port. Its String and Network methods are the two that net.Addr requires.
type PingAddr struct{ addr netip.Addr }
|
|
|
|
func (ia PingAddr) String() string {
|
|
return ia.addr.String()
|
|
}
|
|
|
|
func (ia PingAddr) Network() string {
|
|
if ia.addr.Is4() {
|
|
return "ping4"
|
|
} else if ia.addr.Is6() {
|
|
return "ping6"
|
|
}
|
|
return "ping"
|
|
}
|
|
|
|
func (ia PingAddr) Addr() netip.Addr {
|
|
return ia.addr
|
|
}
|
|
|
|
func PingAddrFromAddr(addr netip.Addr) *PingAddr {
|
|
return &PingAddr{addr}
|
|
}
|
|
|
|
func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
|
|
if !laddr.IsValid() && !raddr.IsValid() {
|
|
return nil, errors.New("ping dial: invalid address")
|
|
}
|
|
v6 := laddr.Is6() || raddr.Is6()
|
|
bind := laddr.IsValid()
|
|
if !bind {
|
|
if v6 {
|
|
laddr = netip.IPv6Unspecified()
|
|
} else {
|
|
laddr = netip.IPv4Unspecified()
|
|
}
|
|
}
|
|
|
|
tn := icmp.ProtocolNumber4
|
|
pn := ipv4.ProtocolNumber
|
|
if v6 {
|
|
tn = icmp.ProtocolNumber6
|
|
pn = ipv6.ProtocolNumber
|
|
}
|
|
|
|
pc := &PingConn{
|
|
laddr: PingAddr{laddr},
|
|
deadline: time.NewTimer(time.Hour << 10),
|
|
}
|
|
pc.deadline.Stop()
|
|
|
|
ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
|
|
if tcpipErr != nil {
|
|
return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
|
|
}
|
|
pc.ep = ep
|
|
|
|
if bind {
|
|
fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
|
|
if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
|
|
return nil, fmt.Errorf("ping bind: %s", tcpipErr)
|
|
}
|
|
}
|
|
|
|
if raddr.IsValid() {
|
|
pc.raddr = PingAddr{raddr}
|
|
fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
|
|
if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
|
|
return nil, fmt.Errorf("ping connect: %s", tcpipErr)
|
|
}
|
|
}
|
|
|
|
return pc, nil
|
|
}
|
|
|
|
func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
|
|
return net.DialPingAddr(laddr, netip.Addr{})
|
|
}
|
|
|
|
func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
|
|
var la, ra netip.Addr
|
|
if laddr != nil {
|
|
la = laddr.addr
|
|
}
|
|
if raddr != nil {
|
|
ra = raddr.addr
|
|
}
|
|
return net.DialPingAddr(la, ra)
|
|
}
|
|
|
|
func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
|
|
var la netip.Addr
|
|
if laddr != nil {
|
|
la = laddr.addr
|
|
}
|
|
return net.ListenPingAddr(la)
|
|
}
|
|
|
|
func (pc *PingConn) LocalAddr() net.Addr {
|
|
return pc.laddr
|
|
}
|
|
|
|
func (pc *PingConn) RemoteAddr() net.Addr {
|
|
return pc.raddr
|
|
}
|
|
|
|
func (pc *PingConn) Close() error {
|
|
pc.deadline.Reset(0)
|
|
pc.ep.Close()
|
|
return nil
|
|
}
|
|
|
|
func (pc *PingConn) SetWriteDeadline(t time.Time) error {
|
|
return errors.New("not implemented")
|
|
}
|
|
|
|
func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
var na netip.Addr
|
|
switch v := addr.(type) {
|
|
case *PingAddr:
|
|
na = v.addr
|
|
case *net.IPAddr:
|
|
na, _ = netip.AddrFromSlice(v.IP)
|
|
default:
|
|
return 0, fmt.Errorf("ping write: wrong net.Addr type")
|
|
}
|
|
if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
|
|
return 0, fmt.Errorf("ping write: mismatched protocols")
|
|
}
|
|
|
|
buf := bytes.NewReader(p)
|
|
rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
|
|
// won't block, no deadlines
|
|
n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
|
|
To: &rfa,
|
|
})
|
|
if tcpipErr != nil {
|
|
return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
|
|
}
|
|
|
|
return int(n64), nil
|
|
}
|
|
|
|
func (pc *PingConn) Write(p []byte) (n int, err error) {
|
|
return pc.WriteTo(p, &pc.raddr)
|
|
}
|
|
|
|
func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
|
|
pc.wq.EventRegister(&e)
|
|
defer pc.wq.EventUnregister(&e)
|
|
|
|
select {
|
|
case <-pc.deadline.C:
|
|
return 0, nil, os.ErrDeadlineExceeded
|
|
case <-notifyCh:
|
|
}
|
|
|
|
w := tcpip.SliceWriter(p)
|
|
|
|
res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
|
|
NeedRemoteAddr: true,
|
|
})
|
|
if tcpipErr != nil {
|
|
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
|
|
}
|
|
|
|
remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))
|
|
return res.Count, &PingAddr{remoteAddr}, nil
|
|
}
|
|
|
|
func (pc *PingConn) Read(p []byte) (n int, err error) {
|
|
n, _, err = pc.ReadFrom(p)
|
|
return
|
|
}
|
|
|
|
func (pc *PingConn) SetDeadline(t time.Time) error {
|
|
// pc.SetWriteDeadline is unimplemented
|
|
|
|
return pc.SetReadDeadline(t)
|
|
}
|
|
|
|
func (pc *PingConn) SetReadDeadline(t time.Time) error {
|
|
pc.deadline.Reset(time.Until(t))
|
|
return nil
|
|
}
|
|
|
|
// Sentinel errors used by the resolver and dialer in this file. Callers
// compare against them by identity, so every message is kept as is.
var (
	// DNS answer classification.
	errNoSuchHost    = errors.New("no such host")
	errLameReferral  = errors.New("lame referral")
	errServerMisbehaving = errors.New("server misbehaving")
	// Same text as errServerMisbehaving, but a distinct value so that
	// callers can flag the failure as temporary.
	errServerTemporarilyMisbehaving = errors.New("server misbehaving")
	errNoAnswerFromDNSServer        = errors.New("no answer from DNS server")

	// DNS message encoding and validation.
	errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")
	errCannotMarshalDNSMessage   = errors.New("cannot marshal DNS message")
	errInvalidDNSResponse        = errors.New("invalid DNS response")

	// Dialing.
	errCanceled          = errors.New("operation was canceled")
	errTimeout           = errors.New("i/o timeout")
	errNumericPort       = errors.New("port must be numeric")
	errNoSuitableAddress = errors.New("no suitable address found")
	errMissingAddress    = errors.New("missing address")
)
|
|
|
|
func (net *Net) LookupHost(host string) (addrs []string, err error) {
|
|
return net.LookupContextHost(context.Background(), host)
|
|
}
|
|
|
|
// isDomainName reports whether s is acceptable as a DNS host name.
//
// The rules enforced are:
//   - s is non-empty and at most 254 bytes long, where a 254-byte name must
//     end in a dot;
//   - only ASCII letters, digits, '_', '-' and '.' appear;
//   - no label is empty or longer than 63 bytes, no label starts or ends with
//     '-', and the name does not start with '.';
//   - at least one character is not a digit or dot, so dotted-decimal strings
//     such as "1.2.3.4" are rejected.
//
// A single trailing dot is allowed.
func isDomainName(s string) bool {
	const (
		maxNameLen  = 254
		maxLabelLen = 63
	)
	switch n := len(s); {
	case n == 0, n > maxNameLen:
		return false
	case n == maxNameLen && s[n-1] != '.':
		return false
	}

	var (
		prev        byte = '.' // pretend a dot precedes the name
		labelLen    int
		sawNonDigit bool
	)
	for _, c := range []byte(s) {
		switch {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c == '_':
			sawNonDigit = true
			labelLen++
		case c >= '0' && c <= '9':
			labelLen++
		case c == '-':
			// A hyphen may not open a label.
			if prev == '.' {
				return false
			}
			sawNonDigit = true
			labelLen++
		case c == '.':
			// A dot may not follow another dot or a hyphen.
			if prev == '.' || prev == '-' {
				return false
			}
			if labelLen == 0 || labelLen > maxLabelLen {
				return false
			}
			labelLen = 0
		default:
			return false
		}
		prev = c
	}
	if prev == '-' || labelLen > maxLabelLen {
		return false
	}
	return sawNonDigit
}
|
|
|
|
// randU16 returns a 16-bit value read from crypto/rand. It panics if the
// random source fails.
func randU16() uint16 {
	raw := make([]byte, 2)
	if _, err := rand.Read(raw); err != nil {
		panic(err)
	}
	return binary.LittleEndian.Uint16(raw)
}
|
|
|
|
func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
|
|
id = randU16()
|
|
b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
|
|
b.EnableCompression()
|
|
if err := b.StartQuestions(); err != nil {
|
|
return 0, nil, nil, err
|
|
}
|
|
if err := b.Question(q); err != nil {
|
|
return 0, nil, nil, err
|
|
}
|
|
tcpReq, err = b.Finish()
|
|
udpReq = tcpReq[2:]
|
|
l := len(tcpReq) - 2
|
|
tcpReq[0] = byte(l >> 8)
|
|
tcpReq[1] = byte(l)
|
|
return id, udpReq, tcpReq, err
|
|
}
|
|
|
|
func equalASCIIName(x, y dnsmessage.Name) bool {
|
|
if x.Length != y.Length {
|
|
return false
|
|
}
|
|
for i := 0; i < int(x.Length); i++ {
|
|
a := x.Data[i]
|
|
b := y.Data[i]
|
|
if 'A' <= a && a <= 'Z' {
|
|
a += 0x20
|
|
}
|
|
if 'A' <= b && b <= 'Z' {
|
|
b += 0x20
|
|
}
|
|
if a != b {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
|
|
if !respHdr.Response {
|
|
return false
|
|
}
|
|
if reqID != respHdr.ID {
|
|
return false
|
|
}
|
|
if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
|
|
if _, err := c.Write(b); err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
b = make([]byte, 512)
|
|
for {
|
|
n, err := c.Read(b)
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
var p dnsmessage.Parser
|
|
h, err := p.Start(b[:n])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
q, err := p.Question()
|
|
if err != nil || !checkResponse(id, query, h, q) {
|
|
continue
|
|
}
|
|
return p, h, nil
|
|
}
|
|
}
|
|
|
|
func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
|
|
if _, err := c.Write(b); err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
b = make([]byte, 1280)
|
|
if _, err := io.ReadFull(c, b[:2]); err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
l := int(b[0])<<8 | int(b[1])
|
|
if l > len(b) {
|
|
b = make([]byte, l)
|
|
}
|
|
n, err := io.ReadFull(c, b[:l])
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
var p dnsmessage.Parser
|
|
h, err := p.Start(b[:n])
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
|
|
}
|
|
q, err := p.Question()
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
|
|
}
|
|
if !checkResponse(id, query, h, q) {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
|
|
}
|
|
return p, h, nil
|
|
}
|
|
|
|
func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
|
|
q.Class = dnsmessage.ClassINET
|
|
id, udpReq, tcpReq, err := newRequest(q)
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
|
|
}
|
|
|
|
for _, useUDP := range []bool{true, false} {
|
|
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
|
|
defer cancel()
|
|
|
|
var c net.Conn
|
|
var err error
|
|
if useUDP {
|
|
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
|
|
} else {
|
|
c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
|
|
}
|
|
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
if d, ok := ctx.Deadline(); ok && !d.IsZero() {
|
|
err := c.SetDeadline(d)
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
}
|
|
var p dnsmessage.Parser
|
|
var h dnsmessage.Header
|
|
if useUDP {
|
|
p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
|
|
} else {
|
|
p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
|
|
}
|
|
c.Close()
|
|
if err != nil {
|
|
if err == context.Canceled {
|
|
err = errCanceled
|
|
} else if err == context.DeadlineExceeded {
|
|
err = errTimeout
|
|
}
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
|
|
}
|
|
if h.Truncated {
|
|
continue
|
|
}
|
|
return p, h, nil
|
|
}
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
|
|
}
|
|
|
|
func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
|
|
if h.RCode == dnsmessage.RCodeNameError {
|
|
return errNoSuchHost
|
|
}
|
|
_, err := p.AnswerHeader()
|
|
if err != nil && err != dnsmessage.ErrSectionDone {
|
|
return errCannotUnmarshalDNSMessage
|
|
}
|
|
if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
|
|
return errLameReferral
|
|
}
|
|
if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
|
|
if h.RCode == dnsmessage.RCodeServerFailure {
|
|
return errServerTemporarilyMisbehaving
|
|
}
|
|
return errServerMisbehaving
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
|
|
for {
|
|
h, err := p.AnswerHeader()
|
|
if err == dnsmessage.ErrSectionDone {
|
|
return errNoSuchHost
|
|
}
|
|
if err != nil {
|
|
return errCannotUnmarshalDNSMessage
|
|
}
|
|
if h.Type == qtype {
|
|
return nil
|
|
}
|
|
if err := p.SkipAnswer(); err != nil {
|
|
return errCannotUnmarshalDNSMessage
|
|
}
|
|
}
|
|
}
|
|
|
|
func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
|
|
var lastErr error
|
|
|
|
n, err := dnsmessage.NewName(name)
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
|
|
}
|
|
q := dnsmessage.Question{
|
|
Name: n,
|
|
Type: qtype,
|
|
Class: dnsmessage.ClassINET,
|
|
}
|
|
|
|
for i := 0; i < 2; i++ {
|
|
for _, server := range tnet.dnsServers {
|
|
p, h, err := tnet.exchange(ctx, server, q, time.Second*5)
|
|
if err != nil {
|
|
dnsErr := &net.DNSError{
|
|
Err: err.Error(),
|
|
Name: name,
|
|
Server: server.String(),
|
|
}
|
|
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
|
dnsErr.IsTimeout = true
|
|
}
|
|
if _, ok := err.(*net.OpError); ok {
|
|
dnsErr.IsTemporary = true
|
|
}
|
|
lastErr = dnsErr
|
|
continue
|
|
}
|
|
|
|
if err := checkHeader(&p, h); err != nil {
|
|
dnsErr := &net.DNSError{
|
|
Err: err.Error(),
|
|
Name: name,
|
|
Server: server.String(),
|
|
}
|
|
if err == errServerTemporarilyMisbehaving {
|
|
dnsErr.IsTemporary = true
|
|
}
|
|
if err == errNoSuchHost {
|
|
dnsErr.IsNotFound = true
|
|
return p, server.String(), dnsErr
|
|
}
|
|
lastErr = dnsErr
|
|
continue
|
|
}
|
|
|
|
err = skipToAnswer(&p, qtype)
|
|
if err == nil {
|
|
return p, server.String(), nil
|
|
}
|
|
lastErr = &net.DNSError{
|
|
Err: err.Error(),
|
|
Name: name,
|
|
Server: server.String(),
|
|
}
|
|
if err == errNoSuchHost {
|
|
lastErr.(*net.DNSError).IsNotFound = true
|
|
return p, server.String(), lastErr
|
|
}
|
|
}
|
|
}
|
|
return dnsmessage.Parser{}, "", lastErr
|
|
}
|
|
|
|
func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) {
|
|
if host == "" || (!tnet.hasV6 && !tnet.hasV4) {
|
|
return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
|
|
}
|
|
zlen := len(host)
|
|
if strings.IndexByte(host, ':') != -1 {
|
|
if zidx := strings.LastIndexByte(host, '%'); zidx != -1 {
|
|
zlen = zidx
|
|
}
|
|
}
|
|
if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
|
|
return []string{ip.String()}, nil
|
|
}
|
|
|
|
if !isDomainName(host) {
|
|
return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
|
|
}
|
|
type result struct {
|
|
p dnsmessage.Parser
|
|
server string
|
|
error
|
|
}
|
|
var addrsV4, addrsV6 []netip.Addr
|
|
lanes := 0
|
|
if tnet.hasV4 {
|
|
lanes++
|
|
}
|
|
if tnet.hasV6 {
|
|
lanes++
|
|
}
|
|
lane := make(chan result, lanes)
|
|
var lastErr error
|
|
if tnet.hasV4 {
|
|
go func() {
|
|
p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA)
|
|
lane <- result{p, server, err}
|
|
}()
|
|
}
|
|
if tnet.hasV6 {
|
|
go func() {
|
|
p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA)
|
|
lane <- result{p, server, err}
|
|
}()
|
|
}
|
|
for l := 0; l < lanes; l++ {
|
|
result := <-lane
|
|
if result.error != nil {
|
|
if lastErr == nil {
|
|
lastErr = result.error
|
|
}
|
|
continue
|
|
}
|
|
|
|
loop:
|
|
for {
|
|
h, err := result.p.AnswerHeader()
|
|
if err != nil && err != dnsmessage.ErrSectionDone {
|
|
lastErr = &net.DNSError{
|
|
Err: errCannotMarshalDNSMessage.Error(),
|
|
Name: host,
|
|
Server: result.server,
|
|
}
|
|
}
|
|
if err != nil {
|
|
break
|
|
}
|
|
switch h.Type {
|
|
case dnsmessage.TypeA:
|
|
a, err := result.p.AResource()
|
|
if err != nil {
|
|
lastErr = &net.DNSError{
|
|
Err: errCannotMarshalDNSMessage.Error(),
|
|
Name: host,
|
|
Server: result.server,
|
|
}
|
|
break loop
|
|
}
|
|
addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
|
|
|
|
case dnsmessage.TypeAAAA:
|
|
aaaa, err := result.p.AAAAResource()
|
|
if err != nil {
|
|
lastErr = &net.DNSError{
|
|
Err: errCannotMarshalDNSMessage.Error(),
|
|
Name: host,
|
|
Server: result.server,
|
|
}
|
|
break loop
|
|
}
|
|
addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
|
|
|
|
default:
|
|
if err := result.p.SkipAnswer(); err != nil {
|
|
lastErr = &net.DNSError{
|
|
Err: errCannotMarshalDNSMessage.Error(),
|
|
Name: host,
|
|
Server: result.server,
|
|
}
|
|
break loop
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
|
|
var addrs []netip.Addr
|
|
if tnet.hasV6 {
|
|
addrs = append(addrsV6, addrsV4...)
|
|
} else {
|
|
addrs = append(addrsV4, addrsV6...)
|
|
}
|
|
|
|
if len(addrs) == 0 && lastErr != nil {
|
|
return nil, lastErr
|
|
}
|
|
saddrs := make([]string, 0, len(addrs))
|
|
for _, ip := range addrs {
|
|
saddrs = append(saddrs, ip.String())
|
|
}
|
|
return saddrs, nil
|
|
}
|
|
|
|
func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
|
|
if deadline.IsZero() {
|
|
return deadline, nil
|
|
}
|
|
timeRemaining := deadline.Sub(now)
|
|
if timeRemaining <= 0 {
|
|
return time.Time{}, errTimeout
|
|
}
|
|
timeout := timeRemaining / time.Duration(addrsRemaining)
|
|
const saneMinimum = 2 * time.Second
|
|
if timeout < saneMinimum {
|
|
if timeRemaining < saneMinimum {
|
|
timeout = timeRemaining
|
|
} else {
|
|
timeout = saneMinimum
|
|
}
|
|
}
|
|
return now.Add(timeout), nil
|
|
}
|
|
|
|
// protoSplitter parses a network name for DialContext. Submatch 1 is the
// protocol ("tcp", "udp" or "ping") and submatch 2 is the optional address
// family suffix ("4", "6" or empty).
var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
|
|
|
|
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
if ctx == nil {
|
|
panic("nil context")
|
|
}
|
|
var acceptV4, acceptV6 bool
|
|
matches := protoSplitter.FindStringSubmatch(network)
|
|
if matches == nil {
|
|
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
|
|
} else if len(matches[2]) == 0 {
|
|
acceptV4 = true
|
|
acceptV6 = true
|
|
} else {
|
|
acceptV4 = matches[2][0] == '4'
|
|
acceptV6 = !acceptV4
|
|
}
|
|
var host string
|
|
var port int
|
|
if matches[1] == "ping" {
|
|
host = address
|
|
} else {
|
|
var sport string
|
|
var err error
|
|
host, sport, err = net.SplitHostPort(address)
|
|
if err != nil {
|
|
return nil, &net.OpError{Op: "dial", Err: err}
|
|
}
|
|
port, err = strconv.Atoi(sport)
|
|
if err != nil || port < 0 || port > 65535 {
|
|
return nil, &net.OpError{Op: "dial", Err: errNumericPort}
|
|
}
|
|
}
|
|
allAddr, err := tnet.LookupContextHost(ctx, host)
|
|
if err != nil {
|
|
return nil, &net.OpError{Op: "dial", Err: err}
|
|
}
|
|
var addrs []netip.AddrPort
|
|
for _, addr := range allAddr {
|
|
ip, err := netip.ParseAddr(addr)
|
|
if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
|
|
addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
|
|
}
|
|
}
|
|
if len(addrs) == 0 && len(allAddr) != 0 {
|
|
return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
|
|
}
|
|
|
|
var firstErr error
|
|
for i, addr := range addrs {
|
|
select {
|
|
case <-ctx.Done():
|
|
err := ctx.Err()
|
|
if err == context.Canceled {
|
|
err = errCanceled
|
|
} else if err == context.DeadlineExceeded {
|
|
err = errTimeout
|
|
}
|
|
return nil, &net.OpError{Op: "dial", Err: err}
|
|
default:
|
|
}
|
|
|
|
dialCtx := ctx
|
|
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
|
|
partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
|
|
if err != nil {
|
|
if firstErr == nil {
|
|
firstErr = &net.OpError{Op: "dial", Err: err}
|
|
}
|
|
break
|
|
}
|
|
if partialDeadline.Before(deadline) {
|
|
var cancel context.CancelFunc
|
|
dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
|
|
defer cancel()
|
|
}
|
|
}
|
|
|
|
var c net.Conn
|
|
switch matches[1] {
|
|
case "tcp":
|
|
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
|
|
case "udp":
|
|
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
|
|
case "ping":
|
|
c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
|
|
}
|
|
if err == nil {
|
|
return c, nil
|
|
}
|
|
if firstErr == nil {
|
|
firstErr = err
|
|
}
|
|
}
|
|
if firstErr == nil {
|
|
firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
|
|
}
|
|
return nil, firstErr
|
|
}
|
|
|
|
func (tnet *Net) Dial(network, address string) (net.Conn, error) {
|
|
return tnet.DialContext(context.Background(), network, address)
|
|
}
|