diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go index 4912efd..39a7180 100644 --- a/tun/tcp_offload_linux.go +++ b/tun/tcp_offload_linux.go @@ -189,14 +189,29 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet return coalesceUnavailable } } - if pkt[1] != pktTarget[1] { - // cannot coalesce with unequal ToS values - return coalesceUnavailable - } - if pkt[6]>>5 != pktTarget[6]>>5 { - // cannot coalesce with unequal DF or reserved bits. MF is checked - // further up the stack. - return coalesceUnavailable + if pkt[0]>>4 == 6 { + if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 { + // cannot coalesce with unequal Traffic class values + return coalesceUnavailable + } + if pkt[7] != pktTarget[7] { + // cannot coalesce with unequal Hop limit values + return coalesceUnavailable + } + } else { + if pkt[1] != pktTarget[1] { + // cannot coalesce with unequal ToS values + return coalesceUnavailable + } + if pkt[6]>>5 != pktTarget[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return coalesceUnavailable + } + if pkt[8] != pktTarget[8] { + // cannot coalesce with unequal TTL values + return coalesceUnavailable + } } // seq adjacency lhsLen := item.gsoSize @@ -366,7 +381,7 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize } const ( - ipv4FlagMoreFragments = 0x80 + ipv4FlagMoreFragments uint8 = 0x20 ) const ( @@ -409,7 +424,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) return false } if !isV6 { - if pkt[6]&ipv4FlagMoreFragments != 0 || (pkt[6]<<3 != 0 || pkt[7] != 0) { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { // no GRO support for fragmented segments for now return false } diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go index 046e177..9160e18 100644 --- a/tun/tcp_offload_linux_test.go +++ b/tun/tcp_offload_linux_test.go @@ -28,19 +28,23 @@ var ( ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") ) -func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { +func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { totalLen := 40 + segmentSize b := make([]byte, offset+int(totalLen), 65535) ipv4H := header.IPv4(b[offset:]) srcAs4 := srcIPPort.Addr().As4() dstAs4 := dstIPPort.Addr().As4() - ipv4H.Encode(&header.IPv4Fields{ + ipFields := &header.IPv4Fields{ SrcAddr: tcpip.Address(srcAs4[:]), DstAddr: tcpip.Address(dstAs4[:]), Protocol: unix.IPPROTO_TCP, TTL: 64, TotalLength: uint16(totalLen), - }) + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) tcpH := header.TCP(b[offset+20:]) tcpH.Encode(&header.TCPFields{ SrcPort: srcIPPort.Port(), @@ -57,19 +61,27 @@ func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segm return b } -func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { +func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + +func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { totalLen := 60 + segmentSize b := make([]byte, offset+int(totalLen), 65535) ipv6H := header.IPv6(b[offset:]) srcAs16 := srcIPPort.Addr().As16() dstAs16 := dstIPPort.Addr().As16() - ipv6H.Encode(&header.IPv6Fields{ + ipFields := &header.IPv6Fields{ SrcAddr: tcpip.Address(srcAs16[:]), DstAddr: tcpip.Address(dstAs16[:]), TransportProtocol: unix.IPPROTO_TCP, HopLimit: 64, PayloadLength: uint16(segmentSize + 20), - }) + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) tcpH := header.TCP(b[offset+40:]) tcpH.Encode(&header.TCPFields{ SrcPort: srcIPPort.Port(), @@ -85,6 +97,10 @@ func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segm return b } +func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + func Test_handleVirtioRead(t *testing.T) { tests := []struct { name string @@ -245,6 +261,78 @@ func Test_handleGRO(t *testing.T) { []int{340}, false, }, + { + "tcp4 unequal TTL", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + }, + []int{0, 1}, + []int{140, 140}, + false, + }, + { + "tcp4 unequal ToS", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + }, + []int{0, 1}, + []int{140, 140}, + false, + }, + { + "tcp4 unequal flags more fragments set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), + }, + []int{0, 1}, + []int{140, 140}, + false, + }, + { + "tcp4 unequal flags DF set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + }, + []int{0, 1}, + []int{140, 140}, + false, + }, + { + "tcp6 unequal hop limit", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + }, + []int{0, 1}, + []int{160, 160}, + false, + }, + { + "tcp6 unequal traffic class", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + }, + []int{0, 1}, + []int{160, 160}, + false, + }, } for _, tt := range tests {