net: pkt: introduce net_pkt_remove_tail()

Introduce a helper function for being able to remove any arbitrary
length from tail of packet. This is handy in cases when removing
unneeded data, like CRC once it was verified.

Signed-off-by: Marcin Niestroj <m.niestroj@emb.dev>
diff --git a/include/net/net_pkt.h b/include/net/net_pkt.h
index 9ddc055..e069249 100644
--- a/include/net/net_pkt.h
+++ b/include/net/net_pkt.h
@@ -1706,6 +1706,22 @@
 void net_pkt_trim_buffer(struct net_pkt *pkt);
 
 /**
+ * @brief Remove @a length bytes from tail of packet
+ *
+ * @details This function does not take packet cursor into account. It is a
+ *          helper to remove unneeded bytes from tail of packet (like appended
+ *          CRC). It takes care of buffer deallocation if removed bytes span
+ *          whole buffer(s).
+ *
+ * @param pkt    Network packet
+ * @param length Number of bytes to be removed
+ *
+ * @retval 0       On success.
+ * @retval -EINVAL If packet length is shorter than @a length.
+ */
+int net_pkt_remove_tail(struct net_pkt *pkt, size_t length);
+
+/**
  * @brief Initialize net_pkt cursor
  *
  * @details This will initialize the net_pkt cursor from its buffer.
diff --git a/subsys/net/ip/net_pkt.c b/subsys/net/ip/net_pkt.c
index 14ec127..e295201 100644
--- a/subsys/net/ip/net_pkt.c
+++ b/subsys/net/ip/net_pkt.c
@@ -1089,6 +1089,36 @@
 	}
 }
 
+int net_pkt_remove_tail(struct net_pkt *pkt, size_t length)
+{
+	struct net_buf *buf = pkt->buffer;
+	size_t remaining_len = net_pkt_get_len(pkt);
+
+	if (remaining_len < length) {
+		return -EINVAL;
+	}
+
+	remaining_len -= length;
+
+	while (buf) {
+		if (buf->len >= remaining_len) {
+			buf->len = remaining_len;
+
+			if (buf->frags) {
+				net_pkt_frag_unref(buf->frags);
+				buf->frags = NULL;
+			}
+
+			break;
+		}
+
+		remaining_len -= buf->len;
+		buf = buf->frags;
+	}
+
+	return 0;
+}
+
 #if NET_LOG_LEVEL >= LOG_LEVEL_DBG
 int net_pkt_alloc_buffer_debug(struct net_pkt *pkt,
 			       size_t size,
diff --git a/tests/net/net_pkt/src/main.c b/tests/net/net_pkt/src/main.c
index 39c54cb..43f0fa6 100644
--- a/tests/net/net_pkt/src/main.c
+++ b/tests/net/net_pkt/src/main.c
@@ -974,6 +974,90 @@
 	net_pkt_unref(pkt);
 }
 
+void test_net_pkt_remove_tail(void)
+{
+	struct net_pkt *pkt;
+	int err;
+
+	pkt = net_pkt_alloc_with_buffer(NULL,
+					CONFIG_NET_BUF_DATA_SIZE * 2 + 3,
+					AF_UNSPEC, 0, K_NO_WAIT);
+	zassert_true(pkt != NULL, "Pkt not allocated");
+
+	net_pkt_cursor_init(pkt);
+	net_pkt_write(pkt, small_buffer, CONFIG_NET_BUF_DATA_SIZE * 2 + 3);
+
+	zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE * 2 + 3,
+		      "Pkt length is invalid");
+	zassert_equal(pkt->frags->frags->frags->len, 3,
+		      "3rd buffer length is invalid");
+
+	/* Remove some bytes from last buffer */
+	err = net_pkt_remove_tail(pkt, 2);
+	zassert_equal(err, 0, "Failed to remove tail");
+
+	zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE * 2 + 1,
+		      "Pkt length is invalid");
+	zassert_not_equal(pkt->frags->frags->frags, NULL,
+			  "3rd buffer was removed");
+	zassert_equal(pkt->frags->frags->frags->len, 1,
+		      "3rd buffer length is invalid");
+
+	/* Remove last byte from last buffer */
+	err = net_pkt_remove_tail(pkt, 1);
+	zassert_equal(err, 0, "Failed to remove tail");
+
+	zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE * 2,
+		      "Pkt length is invalid");
+	zassert_equal(pkt->frags->frags->frags, NULL,
+		      "3rd buffer was not removed");
+	zassert_equal(pkt->frags->frags->len, CONFIG_NET_BUF_DATA_SIZE,
+		      "2nd buffer length is invalid");
+
+	/* Remove 2nd buffer and one byte from 1st buffer */
+	err = net_pkt_remove_tail(pkt, CONFIG_NET_BUF_DATA_SIZE + 1);
+	zassert_equal(err, 0, "Failed to remove tail");
+
+	zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE - 1,
+		      "Pkt length is invalid");
+	zassert_equal(pkt->frags->frags, NULL,
+		      "2nd buffer was not removed");
+	zassert_equal(pkt->frags->len, CONFIG_NET_BUF_DATA_SIZE - 1,
+		      "1st buffer length is invalid");
+
+	net_pkt_unref(pkt);
+
+	pkt = net_pkt_rx_alloc_with_buffer(NULL,
+					   CONFIG_NET_BUF_DATA_SIZE * 2 + 3,
+					   AF_UNSPEC, 0, K_NO_WAIT);
+
+	net_pkt_cursor_init(pkt);
+	net_pkt_write(pkt, small_buffer, CONFIG_NET_BUF_DATA_SIZE * 2 + 3);
+
+	zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE * 2 + 3,
+		      "Pkt length is invalid");
+	zassert_equal(pkt->frags->frags->frags->len, 3,
+		      "3rd buffer length is invalid");
+
+	/* Remove bytes spanning 3 buffers */
+	err = net_pkt_remove_tail(pkt, CONFIG_NET_BUF_DATA_SIZE + 5);
+	zassert_equal(err, 0, "Failed to remove tail");
+
+	zassert_equal(net_pkt_get_len(pkt), CONFIG_NET_BUF_DATA_SIZE - 2,
+		      "Pkt length is invalid");
+	zassert_equal(pkt->frags->frags, NULL,
+		      "2nd buffer was not removed");
+	zassert_equal(pkt->frags->len, CONFIG_NET_BUF_DATA_SIZE - 2,
+		      "1st buffer length is invalid");
+
+	/* Try to remove more bytes than packet has */
+	err = net_pkt_remove_tail(pkt, CONFIG_NET_BUF_DATA_SIZE);
+	zassert_equal(err, -EINVAL,
+		      "Removing more bytes than available should fail");
+
+	net_pkt_unref(pkt);
+}
+
 void test_main(void)
 {
 	eth_if = net_if_get_default();
@@ -989,7 +1073,8 @@
 			 ztest_unit_test(test_net_pkt_clone),
 			 ztest_unit_test(test_net_pkt_headroom),
 			 ztest_unit_test(test_net_pkt_headroom_copy),
-			 ztest_unit_test(test_net_pkt_get_contiguous_len)
+			 ztest_unit_test(test_net_pkt_get_contiguous_len),
+			 ztest_unit_test(test_net_pkt_remove_tail)
 		);
 
 	ztest_run_test_suite(net_pkt_tests);