diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go index 39a7180..1afd27e 100644 --- a/tun/tcp_offload_linux.go +++ b/tun/tcp_offload_linux.go @@ -269,11 +269,11 @@ func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { type coalesceResult int const ( - coalesceInsufficientCap coalesceResult = 0 - coalescePSHEnding coalesceResult = 1 - coalesceItemInvalidCSum coalesceResult = 2 - coalescePktInvalidCSum coalesceResult = 3 - coalesceSuccess coalesceResult = 4 + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess ) // coalesceTCPPackets attempts to coalesce pkt with the packet described by @@ -339,42 +339,6 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize if gsoSize > item.gsoSize { item.gsoSize = gsoSize } - hdr := virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb - hdrLen: uint16(headersLen), - gsoSize: uint16(item.gsoSize), - csumStart: uint16(item.iphLen), - csumOffset: 16, - } - - // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the - // (IPv4) header checksum. - if isV6 { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 - binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len - } else { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 - pktHead[10], pktHead[11] = 0, 0 // clear checksum field - binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length - iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum - binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field - } - hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:]) - - // Calculate the pseudo header checksum and place it at the TCP checksum - // offset. Downstream checksum offloading will combine this with computation - // of the tcp header and payload checksum. - addrLen := 4 - addrOffset := ipv4SrcAddrOffset - if isV6 { - addrLen = 16 - addrOffset = ipv6SrcAddrOffset - } - srcAddrAt := bufsOffset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen))) - binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) item.numMerged++ return coalesceSuccess @@ -390,43 +354,52 @@ const ( maxUint16 = 1<<16 - 1 ) +type tcpGROResult int + +const ( + tcpGROResultNoop tcpGROResult = iota + tcpGROResultTableInsert + tcpGROResultCoalesced +) + // tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with -// existing packets tracked in table. It will return false when pktI is not -// coalesced, otherwise true. This indicates to the caller if bufs[pktI] -// should be written to the Device. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) { +// existing packets tracked in table. It returns a tcpGROResultNoop when no +// action was taken, tcpGROResultTableInsert when the evaluated packet was +// inserted into table, and tcpGROResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. - return false + return tcpGROResultNoop } iphLen := int((pkt[0] & 0x0F) * 4) if isV6 { iphLen = 40 ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) if ipv6HPayloadLen != len(pkt)-iphLen { - return false + return tcpGROResultNoop } } else { totalLen := int(binary.BigEndian.Uint16(pkt[2:])) if totalLen != len(pkt) { - return false + return tcpGROResultNoop } } if len(pkt) < iphLen { - return false + return tcpGROResultNoop } tcphLen := int((pkt[iphLen+12] >> 4) * 4) if tcphLen < 20 || tcphLen > 60 { - return false + return tcpGROResultNoop } if len(pkt) < iphLen+tcphLen { - return false + return tcpGROResultNoop } if !isV6 { if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { // no GRO support for fragmented segments for now - return false + return tcpGROResultNoop } } tcpFlags := pkt[iphLen+tcpFlagsOffset] @@ -434,14 +407,14 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) // not a candidate if any non-ACK flags (except PSH+ACK) are set if tcpFlags != tcpFlagACK { if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { - return false + return tcpGROResultNoop } pshSet = true } gsoSize := uint16(len(pkt) - tcphLen - iphLen) // not a candidate if payload len is 0 if gsoSize < 1 { - return false + return tcpGROResultNoop } seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) srcAddrOffset := ipv4SrcAddrOffset @@ -452,7 +425,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) } items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) if !existing { - return false + return tcpGROResultNoop } for i := len(items) - 1; i >= 0; i-- { // In the best case of packets arriving in order iterating in reverse is @@ -470,20 +443,20 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) switch result { case coalesceSuccess: table.updateAt(item, i) - return true + return tcpGROResultCoalesced case coalesceItemInvalidCSum: // delete the item with an invalid csum table.deleteAt(item.key, i) case coalescePktInvalidCSum: // no point in inserting an item that we can't coalesce - return false + return tcpGROResultNoop default: } } } // failed to coalesce with any other packets; store the item in the flow table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - return false + return tcpGROResultTableInsert } func isTCP4NoIPOptions(b []byte) bool { @@ -515,6 +488,64 @@ func isTCP6NoEH(b []byte) bool { return true } +// applyCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + if isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + // handleGRO evaluates bufs for GRO, and writes the indices of the resulting // packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be // empty (but non-nil), and are passed in to save allocs as the caller may reset @@ -524,23 +555,28 @@ func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toW if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") } - var coalesced bool + var result tcpGROResult switch { case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp4Table, false) + result = tcpGRO(bufs, offset, i, tcp4Table, false) case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp6Table, true) + result = tcpGRO(bufs, offset, i, tcp6Table, true) } - if !coalesced { + switch result { + case tcpGROResultNoop: hdr := virtioNetHdr{} err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) if err != nil { return err } + fallthrough + case tcpGROResultTableInsert: *toWrite = append(*toWrite, i) } } - return nil + err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false) + err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true) + return errors.Join(err4, err6) } // tcpTSO splits packets from in into outBuffs, writing the size of each