tun/netstack: implement ICMP ping
Provide a PacketConn interface for netstack's ICMP endpoint; netstack currently only provides EchoRequest/EchoResponse ICMP support, so this code exposes only an interface for doing ping. Signed-off-by: Thomas Ptacek <thomas@sockpuppet.org> [Jason: rework structure, match std go interfaces, add example code] Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
e0b8f11489
commit
b9669b734e
76
tun/netstack/examples/ping_client.go
Normal file
76
tun/netstack/examples/ping_client.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
//go:build ignore
|
||||||
|
// +build ignore
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"log"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/icmp"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
tun, tnet, err := netstack.CreateNetTUN(
|
||||||
|
[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
|
||||||
|
[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
|
||||||
|
1420)
|
||||||
|
if err != nil {
|
||||||
|
log.Panic(err)
|
||||||
|
}
|
||||||
|
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
|
||||||
|
dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
|
||||||
|
public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
|
||||||
|
endpoint=163.172.161.0:12912
|
||||||
|
allowed_ip=0.0.0.0/0
|
||||||
|
`)
|
||||||
|
err = dev.Up()
|
||||||
|
if err != nil {
|
||||||
|
log.Panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
socket, err := tnet.Dial("ping4", "zx2c4.com")
|
||||||
|
if err != nil {
|
||||||
|
log.Panic(err)
|
||||||
|
}
|
||||||
|
requestPing := icmp.Echo{
|
||||||
|
Seq: rand.Intn(1 << 16),
|
||||||
|
Data: []byte("gopher burrow"),
|
||||||
|
}
|
||||||
|
icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
|
||||||
|
socket.SetReadDeadline(time.Now().Add(time.Second * 10))
|
||||||
|
start := time.Now()
|
||||||
|
_, err = socket.Write(icmpBytes)
|
||||||
|
if err != nil {
|
||||||
|
log.Panic(err)
|
||||||
|
}
|
||||||
|
n, err := socket.Read(icmpBytes[:])
|
||||||
|
if err != nil {
|
||||||
|
log.Panic(err)
|
||||||
|
}
|
||||||
|
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Panic(err)
|
||||||
|
}
|
||||||
|
replyPing, ok := replyPacket.Body.(*icmp.Echo)
|
||||||
|
if !ok {
|
||||||
|
log.Panicf("invalid reply type: %v", replyPacket)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
|
||||||
|
log.Panicf("invalid ping reply: %v", replyPing)
|
||||||
|
}
|
||||||
|
log.Printf("Ping latency: %v", time.Since(start))
|
||||||
|
}
|
@ -14,8 +14,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/go118/netip"
|
"golang.zx2c4.com/go118/netip"
|
||||||
@ -29,8 +31,10 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
)
|
)
|
||||||
|
|
||||||
type netTun struct {
|
type netTun struct {
|
||||||
@ -101,7 +105,7 @@ func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.Network
|
|||||||
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
|
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
|
||||||
opts := stack.Options{
|
opts := stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
|
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
|
||||||
HandleLocal: true,
|
HandleLocal: true,
|
||||||
}
|
}
|
||||||
dev := &netTun{
|
dev := &netTun{
|
||||||
@ -270,6 +274,10 @@ func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, er
|
|||||||
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
|
||||||
|
return net.DialUDPAddrPort(laddr, netip.AddrPort{})
|
||||||
|
}
|
||||||
|
|
||||||
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
||||||
var la, ra netip.AddrPort
|
var la, ra netip.AddrPort
|
||||||
if laddr != nil {
|
if laddr != nil {
|
||||||
@ -281,6 +289,233 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
|||||||
return net.DialUDPAddrPort(la, ra)
|
return net.DialUDPAddrPort(la, ra)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
||||||
|
return net.DialUDP(laddr, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
type PingConn struct {
|
||||||
|
laddr PingAddr
|
||||||
|
raddr PingAddr
|
||||||
|
wq waiter.Queue
|
||||||
|
ep tcpip.Endpoint
|
||||||
|
mu sync.RWMutex
|
||||||
|
deadline time.Time
|
||||||
|
deadlineBreaker chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PingAddr struct{ addr netip.Addr }
|
||||||
|
|
||||||
|
func (ia PingAddr) String() string {
|
||||||
|
return ia.addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ia PingAddr) Network() string {
|
||||||
|
if ia.addr.Is4() {
|
||||||
|
return "ping4"
|
||||||
|
} else if ia.addr.Is6() {
|
||||||
|
return "ping6"
|
||||||
|
}
|
||||||
|
return "ping"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ia PingAddr) Addr() netip.Addr {
|
||||||
|
return ia.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func PingAddrFromAddr(addr netip.Addr) *PingAddr {
|
||||||
|
return &PingAddr{addr}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
|
||||||
|
if !laddr.IsValid() && !raddr.IsValid() {
|
||||||
|
return nil, errors.New("ping dial: invalid address")
|
||||||
|
}
|
||||||
|
v6 := laddr.Is6() || raddr.Is6()
|
||||||
|
bind := laddr.IsValid()
|
||||||
|
if !bind {
|
||||||
|
if v6 {
|
||||||
|
laddr = netip.IPv6Unspecified()
|
||||||
|
} else {
|
||||||
|
laddr = netip.IPv4Unspecified()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tn := icmp.ProtocolNumber4
|
||||||
|
pn := ipv4.ProtocolNumber
|
||||||
|
if v6 {
|
||||||
|
tn = icmp.ProtocolNumber6
|
||||||
|
pn = ipv6.ProtocolNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
pc := &PingConn{
|
||||||
|
laddr: PingAddr{laddr},
|
||||||
|
deadlineBreaker: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
|
||||||
|
if tcpipErr != nil {
|
||||||
|
return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
|
||||||
|
}
|
||||||
|
pc.ep = ep
|
||||||
|
|
||||||
|
if bind {
|
||||||
|
fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
|
||||||
|
if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
|
||||||
|
return nil, fmt.Errorf("ping bind: %s", tcpipErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if raddr.IsValid() {
|
||||||
|
pc.raddr = PingAddr{raddr}
|
||||||
|
fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
|
||||||
|
if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
|
||||||
|
return nil, fmt.Errorf("ping connect: %s", tcpipErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
|
||||||
|
return net.DialPingAddr(laddr, netip.Addr{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
|
||||||
|
var la, ra netip.Addr
|
||||||
|
if laddr != nil {
|
||||||
|
la = laddr.addr
|
||||||
|
}
|
||||||
|
if raddr != nil {
|
||||||
|
ra = raddr.addr
|
||||||
|
}
|
||||||
|
return net.DialPingAddr(la, ra)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
|
||||||
|
var la netip.Addr
|
||||||
|
if laddr != nil {
|
||||||
|
la = laddr.addr
|
||||||
|
}
|
||||||
|
return net.ListenPingAddr(la)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PingConn) LocalAddr() net.Addr {
|
||||||
|
return pc.laddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PingConn) RemoteAddr() net.Addr {
|
||||||
|
return pc.raddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PingConn) Close() error {
|
||||||
|
close(pc.deadlineBreaker)
|
||||||
|
pc.ep.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PingConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
|
var na netip.Addr
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *PingAddr:
|
||||||
|
na = v.addr
|
||||||
|
case *net.IPAddr:
|
||||||
|
na = netip.AddrFromSlice(v.IP)
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("ping write: wrong net.Addr type")
|
||||||
|
}
|
||||||
|
if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
|
||||||
|
return 0, fmt.Errorf("ping write: mismatched protocols")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := buffer.NewViewFromBytes(p)
|
||||||
|
rdr := buf.Reader()
|
||||||
|
rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
|
||||||
|
// won't block, no deadlines
|
||||||
|
n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{
|
||||||
|
To: &rfa,
|
||||||
|
})
|
||||||
|
if tcpipErr != nil {
|
||||||
|
return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(n64), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PingConn) Write(p []byte) (n int, err error) {
|
||||||
|
return pc.WriteTo(p, &pc.raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
|
e, notifyCh := waiter.NewChannelEntry(nil)
|
||||||
|
pc.wq.EventRegister(&e, waiter.EventIn)
|
||||||
|
defer pc.wq.EventUnregister(&e)
|
||||||
|
|
||||||
|
ready := false
|
||||||
|
|
||||||
|
for !ready {
|
||||||
|
pc.mu.RLock()
|
||||||
|
deadlineBreaker := pc.deadlineBreaker
|
||||||
|
deadline := pc.deadline
|
||||||
|
pc.mu.RUnlock()
|
||||||
|
|
||||||
|
if deadline.IsZero() {
|
||||||
|
select {
|
||||||
|
case <-deadlineBreaker:
|
||||||
|
case <-notifyCh:
|
||||||
|
ready = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t := time.NewTimer(deadline.Sub(time.Now()))
|
||||||
|
defer t.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
return 0, nil, os.ErrDeadlineExceeded
|
||||||
|
|
||||||
|
case <-deadlineBreaker:
|
||||||
|
case <-notifyCh:
|
||||||
|
ready = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w := tcpip.SliceWriter(p)
|
||||||
|
|
||||||
|
res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
|
||||||
|
NeedRemoteAddr: true,
|
||||||
|
})
|
||||||
|
if tcpipErr != nil {
|
||||||
|
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr = &PingAddr{netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))}
|
||||||
|
return res.Count, addr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PingConn) Read(p []byte) (n int, err error) {
|
||||||
|
n, _, err = pc.ReadFrom(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PingConn) SetDeadline(t time.Time) error {
|
||||||
|
// pc.SetWriteDeadline is unimplemented
|
||||||
|
|
||||||
|
return pc.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PingConn) SetReadDeadline(t time.Time) error {
|
||||||
|
pc.mu.Lock()
|
||||||
|
defer pc.mu.Unlock()
|
||||||
|
close(pc.deadlineBreaker)
|
||||||
|
pc.deadlineBreaker = make(chan struct{}, 1)
|
||||||
|
pc.deadline = t
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errNoSuchHost = errors.New("no such host")
|
errNoSuchHost = errors.New("no such host")
|
||||||
errLameReferral = errors.New("lame referral")
|
errLameReferral = errors.New("lame referral")
|
||||||
@ -755,34 +990,39 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er
|
|||||||
return now.Add(timeout), nil
|
return now.Add(timeout), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
|
||||||
|
|
||||||
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
panic("nil context")
|
panic("nil context")
|
||||||
}
|
}
|
||||||
var acceptV4, acceptV6, useUDP bool
|
var acceptV4, acceptV6 bool
|
||||||
if len(network) == 3 {
|
matches := protoSplitter.FindStringSubmatch(network)
|
||||||
|
if matches == nil {
|
||||||
|
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
|
||||||
|
} else if len(matches[2]) == 0 {
|
||||||
acceptV4 = true
|
acceptV4 = true
|
||||||
acceptV6 = true
|
acceptV6 = true
|
||||||
} else if len(network) == 4 {
|
} else {
|
||||||
acceptV4 = network[3] == '4'
|
acceptV4 = matches[2][0] == '4'
|
||||||
acceptV6 = network[3] == '6'
|
acceptV6 = !acceptV4
|
||||||
}
|
}
|
||||||
if !acceptV4 && !acceptV6 {
|
var host string
|
||||||
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
|
var port int
|
||||||
}
|
if matches[1] == "ping" {
|
||||||
if network[:3] == "udp" {
|
host = address
|
||||||
useUDP = true
|
} else {
|
||||||
} else if network[:3] != "tcp" {
|
var sport string
|
||||||
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
|
var err error
|
||||||
}
|
host, sport, err = net.SplitHostPort(address)
|
||||||
host, sport, err := net.SplitHostPort(address)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &net.OpError{Op: "dial", Err: err}
|
return nil, &net.OpError{Op: "dial", Err: err}
|
||||||
}
|
}
|
||||||
port, err := strconv.Atoi(sport)
|
port, err = strconv.Atoi(sport)
|
||||||
if err != nil || port < 0 || port > 65535 {
|
if err != nil || port < 0 || port > 65535 {
|
||||||
return nil, &net.OpError{Op: "dial", Err: errNumericPort}
|
return nil, &net.OpError{Op: "dial", Err: errNumericPort}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
allAddr, err := tnet.LookupContextHost(ctx, host)
|
allAddr, err := tnet.LookupContextHost(ctx, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &net.OpError{Op: "dial", Err: err}
|
return nil, &net.OpError{Op: "dial", Err: err}
|
||||||
@ -829,10 +1069,13 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
|
|||||||
}
|
}
|
||||||
|
|
||||||
var c net.Conn
|
var c net.Conn
|
||||||
if useUDP {
|
switch matches[1] {
|
||||||
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
|
case "tcp":
|
||||||
} else {
|
|
||||||
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
|
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
|
||||||
|
case "udp":
|
||||||
|
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
|
||||||
|
case "ping":
|
||||||
|
c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
|
||||||
}
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return c, nil
|
return c, nil
|
||||||
|
Loading…
Reference in New Issue
Block a user