conn: use netip for std bind

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2022-03-17 22:23:02 -06:00
parent ee1c8e0e87
commit 193cf8d6a5

View File

@ -27,7 +27,7 @@ type StdNetBind struct {
func NewStdNetBind() Bind { return &StdNetBind{} } func NewStdNetBind() Bind { return &StdNetBind{} }
type StdNetEndpoint net.UDPAddr type StdNetEndpoint netip.AddrPort
var ( var (
_ Bind = (*StdNetBind)(nil) _ Bind = (*StdNetBind)(nil)
@ -36,18 +36,13 @@ var (
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
e, err := netip.ParseAddrPort(s) e, err := netip.ParseAddrPort(s)
return (*StdNetEndpoint)(&net.UDPAddr{ return (*StdNetEndpoint)(&e), err
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}), err
} }
func (*StdNetEndpoint) ClearSrc() {} func (*StdNetEndpoint) ClearSrc() {}
func (e *StdNetEndpoint) DstIP() netip.Addr { func (e *StdNetEndpoint) DstIP() netip.Addr {
a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP) return (*netip.AddrPort)(e).Addr()
return a
} }
func (e *StdNetEndpoint) SrcIP() netip.Addr { func (e *StdNetEndpoint) SrcIP() netip.Addr {
@ -55,18 +50,12 @@ func (e *StdNetEndpoint) SrcIP() netip.Addr {
} }
func (e *StdNetEndpoint) DstToBytes() []byte { func (e *StdNetEndpoint) DstToBytes() []byte {
addr := (*net.UDPAddr)(e) b, _ := (*netip.AddrPort)(e).MarshalBinary()
out := addr.IP.To4() return b
if out == nil {
out = addr.IP
}
out = append(out, byte(addr.Port&0xff))
out = append(out, byte((addr.Port>>8)&0xff))
return out
} }
func (e *StdNetEndpoint) DstToString() string { func (e *StdNetEndpoint) DstToString() string {
return (*net.UDPAddr)(e).String() return (*netip.AddrPort)(e).String()
} }
func (e *StdNetEndpoint) SrcToString() string { func (e *StdNetEndpoint) SrcToString() string {
@ -162,18 +151,15 @@ func (bind *StdNetBind) Close() error {
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) { return func(buff []byte) (int, Endpoint, error) {
n, endpoint, err := conn.ReadFromUDP(buff) n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
if endpoint != nil { return n, (*StdNetEndpoint)(&endpoint), err
endpoint.IP = endpoint.IP.To4()
}
return n, (*StdNetEndpoint)(endpoint), err
} }
} }
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) { return func(buff []byte) (int, Endpoint, error) {
n, endpoint, err := conn.ReadFromUDP(buff) n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
return n, (*StdNetEndpoint)(endpoint), err return n, (*StdNetEndpoint)(&endpoint), err
} }
} }
@ -183,11 +169,12 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if !ok { if !ok {
return ErrWrongEndpointType return ErrWrongEndpointType
} }
addrPort := (*netip.AddrPort)(nend)
bind.mu.Lock() bind.mu.Lock()
blackhole := bind.blackhole4 blackhole := bind.blackhole4
conn := bind.ipv4 conn := bind.ipv4
if nend.IP.To4() == nil { if addrPort.Addr().Is6() {
blackhole = bind.blackhole6 blackhole = bind.blackhole6
conn = bind.ipv6 conn = bind.ipv6
} }
@ -199,6 +186,6 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if conn == nil { if conn == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
_, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend)) _, err = conn.WriteToUDPAddrPort(buff, *addrPort)
return err return err
} }