net: icmp: Verify the address family before calling the callback

Check the address family of the packet before passing it to a ICMP
handler, to avoid scenarios where ICMPv4 packet is paseed to a ICMPv6
handler and vice versa.

Signed-off-by: Robert Lubos <robert.lubos@nordicsemi.no>
(cherry picked from commit 587d9e6a4a2d20e77093bc507a5d940f1b8fec9f)
diff --git a/include/zephyr/net/icmp.h b/include/zephyr/net/icmp.h
index 9cfcd38..edca65f 100644
--- a/include/zephyr/net/icmp.h
+++ b/include/zephyr/net/icmp.h
@@ -93,6 +93,9 @@
 	/** Opaque user supplied data */
 	void *user_data;
 
+	/** Address family the handler is registered for */
+	uint8_t family;
+
 	/** ICMP type of the response we are waiting */
 	uint8_t type;
 
@@ -157,12 +160,13 @@
  *        system.
  *
  * @param ctx ICMP context used in this request.
+ * @param family Address family the context is using.
  * @param type Type of ICMP message we are handling.
  * @param code Code of ICMP message we are handling.
  * @param handler Callback function that is called when a response is received.
  */
-int net_icmp_init_ctx(struct net_icmp_ctx *ctx, uint8_t type, uint8_t code,
-		      net_icmp_handler_t handler);
+int net_icmp_init_ctx(struct net_icmp_ctx *ctx, uint8_t family, uint8_t type,
+		      uint8_t code, net_icmp_handler_t handler);
 
 /**
  * @brief Cleanup the ICMP context structure. This will unregister the ICMP handler
diff --git a/subsys/net/ip/icmp.c b/subsys/net/ip/icmp.c
index 12d2f02..b4769ee 100644
--- a/subsys/net/ip/icmp.c
+++ b/subsys/net/ip/icmp.c
@@ -50,16 +50,22 @@
 
 #define PKT_WAIT_TIME K_SECONDS(1)
 
-int net_icmp_init_ctx(struct net_icmp_ctx *ctx, uint8_t type, uint8_t code,
-		      net_icmp_handler_t handler)
+int net_icmp_init_ctx(struct net_icmp_ctx *ctx, uint8_t family, uint8_t type,
+		      uint8_t code, net_icmp_handler_t handler)
 {
 	if (ctx == NULL || handler == NULL) {
 		return -EINVAL;
 	}
 
+	if (family != AF_INET && family != AF_INET6) {
+		NET_ERR("Wrong address family");
+		return -EINVAL;
+	}
+
 	memset(ctx, 0, sizeof(struct net_icmp_ctx));
 
 	ctx->handler = handler;
+	ctx->family = family;
 	ctx->type = type;
 	ctx->code = code;
 
@@ -511,6 +517,10 @@
 	k_mutex_lock(&lock, K_FOREVER);
 
 	SYS_SLIST_FOR_EACH_CONTAINER(&handlers, ctx, node) {
+		if (ip_hdr->family != ctx->family) {
+			continue;
+		}
+
 		if (ctx->type == icmp_hdr->type &&
 		    (ctx->code == icmp_hdr->code || ctx->code == 0U)) {
 			/* Do not use a handler that is expecting data from different
diff --git a/subsys/net/ip/icmpv4.c b/subsys/net/ip/icmpv4.c
index 6c5f6ab..7a2d5ca 100644
--- a/subsys/net/ip/icmpv4.c
+++ b/subsys/net/ip/icmpv4.c
@@ -766,14 +766,15 @@
 	static struct net_icmp_ctx ctx;
 	int ret;
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV4_ECHO_REQUEST, 0, icmpv4_handle_echo_request);
+	ret = net_icmp_init_ctx(&ctx, AF_INET, NET_ICMPV4_ECHO_REQUEST, 0,
+				icmpv4_handle_echo_request);
 	if (ret < 0) {
 		NET_ERR("Cannot register %s handler (%d)", STRINGIFY(NET_ICMPV4_ECHO_REQUEST),
 			ret);
 	}
 
 #if defined(CONFIG_NET_IPV4_PMTU)
-	ret = net_icmp_init_ctx(&dst_unreach_ctx, NET_ICMPV4_DST_UNREACH, 0,
+	ret = net_icmp_init_ctx(&dst_unreach_ctx, AF_INET, NET_ICMPV4_DST_UNREACH, 0,
 				icmpv4_handle_dst_unreach);
 	if (ret < 0) {
 		NET_ERR("Cannot register %s handler (%d)", STRINGIFY(NET_ICMPV4_DST_UNREACH),
diff --git a/subsys/net/ip/icmpv6.c b/subsys/net/ip/icmpv6.c
index b1d68a6..c9cebea 100644
--- a/subsys/net/ip/icmpv6.c
+++ b/subsys/net/ip/icmpv6.c
@@ -394,7 +394,8 @@
 	static struct net_icmp_ctx ctx;
 	int ret;
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV6_ECHO_REQUEST, 0, icmpv6_handle_echo_request);
+	ret = net_icmp_init_ctx(&ctx, AF_INET6, NET_ICMPV6_ECHO_REQUEST, 0,
+				icmpv6_handle_echo_request);
 	if (ret < 0) {
 		NET_ERR("Cannot register %s handler (%d)", STRINGIFY(NET_ICMPV6_ECHO_REQUEST),
 			ret);
diff --git a/subsys/net/ip/ipv6_mld.c b/subsys/net/ip/ipv6_mld.c
index 900dfc1..8444e36 100644
--- a/subsys/net/ip/ipv6_mld.c
+++ b/subsys/net/ip/ipv6_mld.c
@@ -471,7 +471,7 @@
 	static struct net_icmp_ctx ctx;
 	int ret;
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV6_MLD_QUERY, 0, handle_mld_query);
+	ret = net_icmp_init_ctx(&ctx, AF_INET6, NET_ICMPV6_MLD_QUERY, 0, handle_mld_query);
 	if (ret < 0) {
 		NET_ERR("Cannot register %s handler (%d)", STRINGIFY(NET_ICMPV6_MLD_QUERY),
 			ret);
diff --git a/subsys/net/ip/ipv6_nbr.c b/subsys/net/ip/ipv6_nbr.c
index c653545..25232bd 100644
--- a/subsys/net/ip/ipv6_nbr.c
+++ b/subsys/net/ip/ipv6_nbr.c
@@ -2869,13 +2869,13 @@
 	int ret;
 
 #if defined(CONFIG_NET_IPV6_NBR_CACHE)
-	ret = net_icmp_init_ctx(&ns_ctx, NET_ICMPV6_NS, 0, handle_ns_input);
+	ret = net_icmp_init_ctx(&ns_ctx, AF_INET6, NET_ICMPV6_NS, 0, handle_ns_input);
 	if (ret < 0) {
 		NET_ERR("Cannot register %s handler (%d)", STRINGIFY(NET_ICMPV6_NS),
 			ret);
 	}
 
-	ret = net_icmp_init_ctx(&na_ctx, NET_ICMPV6_NA, 0, handle_na_input);
+	ret = net_icmp_init_ctx(&na_ctx, AF_INET6, NET_ICMPV6_NA, 0, handle_na_input);
 	if (ret < 0) {
 		NET_ERR("Cannot register %s handler (%d)", STRINGIFY(NET_ICMPV6_NA),
 			ret);
@@ -2884,7 +2884,7 @@
 	k_work_init_delayable(&ipv6_ns_reply_timer, ipv6_ns_reply_timeout);
 #endif
 #if defined(CONFIG_NET_IPV6_ND)
-	ret = net_icmp_init_ctx(&ra_ctx, NET_ICMPV6_RA, 0, handle_ra_input);
+	ret = net_icmp_init_ctx(&ra_ctx, AF_INET6, NET_ICMPV6_RA, 0, handle_ra_input);
 	if (ret < 0) {
 		NET_ERR("Cannot register %s handler (%d)", STRINGIFY(NET_ICMPV6_RA),
 			ret);
@@ -2895,7 +2895,7 @@
 #endif
 
 #if defined(CONFIG_NET_IPV6_PMTU)
-	ret = net_icmp_init_ctx(&ptb_ctx, NET_ICMPV6_PACKET_TOO_BIG, 0, handle_ptb_input);
+	ret = net_icmp_init_ctx(&ptb_ctx, AF_INET6, NET_ICMPV6_PACKET_TOO_BIG, 0, handle_ptb_input);
 	if (ret < 0) {
 		NET_ERR("Cannot register %s handler (%d)", STRINGIFY(NET_ICMPV6_PACKET_TOO_BIG),
 			ret);
diff --git a/subsys/net/lib/dhcpv4/dhcpv4_server.c b/subsys/net/lib/dhcpv4/dhcpv4_server.c
index 43ef774..8e135f8 100644
--- a/subsys/net/lib/dhcpv4/dhcpv4_server.c
+++ b/subsys/net/lib/dhcpv4/dhcpv4_server.c
@@ -879,7 +879,7 @@
 
 static int dhcpv4_server_probing_init(struct dhcpv4_server_ctx *ctx)
 {
-	return net_icmp_init_ctx(&ctx->probe_ctx.icmp_ctx,
+	return net_icmp_init_ctx(&ctx->probe_ctx.icmp_ctx, AF_INET,
 				 NET_ICMPV4_ECHO_REPLY, 0,
 				 echo_reply_handler);
 }
diff --git a/subsys/net/lib/shell/ping.c b/subsys/net/lib/shell/ping.c
index 230da3b..3ccc62a 100644
--- a/subsys/net/lib/shell/ping.c
+++ b/subsys/net/lib/shell/ping.c
@@ -456,7 +456,7 @@
 	    net_addr_pton(AF_INET6, host, &ping_ctx.addr6.sin6_addr) == 0) {
 		ping_ctx.addr6.sin6_family = AF_INET6;
 
-		ret = net_icmp_init_ctx(&ping_ctx.icmp, NET_ICMPV6_ECHO_REPLY, 0,
+		ret = net_icmp_init_ctx(&ping_ctx.icmp, AF_INET6, NET_ICMPV6_ECHO_REPLY, 0,
 					handle_ipv6_echo_reply);
 		if (ret < 0) {
 			PR_WARNING("Cannot initialize ICMP context for %s\n", "IPv6");
@@ -466,7 +466,7 @@
 		   net_addr_pton(AF_INET, host, &ping_ctx.addr4.sin_addr) == 0) {
 		ping_ctx.addr4.sin_family = AF_INET;
 
-		ret = net_icmp_init_ctx(&ping_ctx.icmp, NET_ICMPV4_ECHO_REPLY, 0,
+		ret = net_icmp_init_ctx(&ping_ctx.icmp, AF_INET, NET_ICMPV4_ECHO_REPLY, 0,
 					handle_ipv4_echo_reply);
 		if (ret < 0) {
 			PR_WARNING("Cannot initialize ICMP context for %s\n", "IPv4");
diff --git a/subsys/net/lib/zperf/zperf_shell.c b/subsys/net/lib/zperf/zperf_shell.c
index 1ad9755..3847107 100644
--- a/subsys/net/lib/zperf/zperf_shell.c
+++ b/subsys/net/lib/zperf/zperf_shell.c
@@ -771,7 +771,7 @@
 	struct net_icmp_ctx ctx;
 	int ret;
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV6_ECHO_REPLY, 0, ping_handler);
+	ret = net_icmp_init_ctx(&ctx, AF_INET6, NET_ICMPV6_ECHO_REPLY, 0, ping_handler);
 	if (ret < 0) {
 		shell_fprintf(sh, SHELL_WARNING, "Cannot send ping (%d)\n", ret);
 		return;
diff --git a/tests/boards/espressif/ethernet/src/main.c b/tests/boards/espressif/ethernet/src/main.c
index e9c397d..6fbd2e6 100644
--- a/tests/boards/espressif/ethernet/src/main.c
+++ b/tests/boards/espressif/ethernet/src/main.c
@@ -97,7 +97,7 @@
 	gw_addr_4 = net_if_ipv4_get_gw(iface);
 	zassert_not_equal(gw_addr_4.s_addr, 0, "Gateway address is not set");
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV4_ECHO_REPLY, 0, icmp_event);
+	ret = net_icmp_init_ctx(&ctx, AF_INET, NET_ICMPV4_ECHO_REPLY, 0, icmp_event);
 	zassert_equal(ret, 0, "Cannot init ICMP (%d)", ret);
 
 	dst4.sin_family = AF_INET;
diff --git a/tests/boards/espressif/wifi/src/main.c b/tests/boards/espressif/wifi/src/main.c
index 931927e..c711414 100644
--- a/tests/boards/espressif/wifi/src/main.c
+++ b/tests/boards/espressif/wifi/src/main.c
@@ -282,7 +282,7 @@
 	gw_addr_4 = net_if_ipv4_get_gw(wifi_ctx.iface);
 	zassert_not_equal(gw_addr_4.s_addr, 0, "Gateway address is not set");
 
-	ret = net_icmp_init_ctx(&icmp_ctx, NET_ICMPV4_ECHO_REPLY, 0, icmp_event);
+	ret = net_icmp_init_ctx(&icmp_ctx, AF_INET, NET_ICMPV4_ECHO_REPLY, 0, icmp_event);
 	zassert_equal(ret, 0, "Cannot init ICMP (%d)", ret);
 
 	dst4.sin_family = AF_INET;
diff --git a/tests/net/checksum_offload/src/main.c b/tests/net/checksum_offload/src/main.c
index 296853b..52f8dfa 100644
--- a/tests/net/checksum_offload/src/main.c
+++ b/tests/net/checksum_offload/src/main.c
@@ -874,7 +874,7 @@
 
 	test_icmp_init(family, offloaded, &dst_addr, &iface);
 
-	ret = net_icmp_init_ctx(&ctx, 0, 0, dummy_icmp_handler);
+	ret = net_icmp_init_ctx(&ctx, family, 0, 0, dummy_icmp_handler);
 	zassert_equal(ret, 0, "Cannot init ICMP (%d)", ret);
 
 	test_started = true;
@@ -1210,7 +1210,7 @@
 
 	test_icmp_init(family, offloaded, &dst_addr, &iface);
 
-	ret = net_icmp_init_ctx(&ctx,
+	ret = net_icmp_init_ctx(&ctx, family,
 				family == AF_INET6 ? NET_ICMPV6_ECHO_REPLY :
 						     NET_ICMPV4_ECHO_REPLY,
 				0, icmp_handler);
@@ -1267,7 +1267,7 @@
 
 	test_icmp_init(family, offloaded, &dst_addr, &iface);
 
-	ret = net_icmp_init_ctx(&ctx,
+	ret = net_icmp_init_ctx(&ctx, family,
 				family == AF_INET6 ? NET_ICMPV6_ECHO_REPLY :
 						     NET_ICMPV4_ECHO_REPLY,
 				0, icmp_handler);
diff --git a/tests/net/icmp/src/main.c b/tests/net/icmp/src/main.c
index b377a75..366a59a 100644
--- a/tests/net/icmp/src/main.c
+++ b/tests/net/icmp/src/main.c
@@ -465,7 +465,7 @@
 		return;
 	}
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV6_ECHO_REPLY, 0, icmp_handler);
+	ret = net_icmp_init_ctx(&ctx, AF_INET6, NET_ICMPV6_ECHO_REPLY, 0, icmp_handler);
 	zassert_equal(ret, 0, "Cannot init ICMP (%d)", ret);
 
 	dst6.sin6_family = AF_INET6;
@@ -508,7 +508,7 @@
 		return;
 	}
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV4_ECHO_REPLY, 0, icmp_handler);
+	ret = net_icmp_init_ctx(&ctx, AF_INET, NET_ICMPV4_ECHO_REPLY, 0, icmp_handler);
 	zassert_equal(ret, 0, "Cannot init ICMP (%d)", ret);
 
 	dst4.sin_family = AF_INET;
@@ -549,7 +549,7 @@
 	struct net_icmp_ctx ctx;
 	int ret;
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV4_ECHO_REPLY, 0, icmp_handler);
+	ret = net_icmp_init_ctx(&ctx, AF_INET, NET_ICMPV4_ECHO_REPLY, 0, icmp_handler);
 	zassert_equal(ret, 0, "Cannot init ICMP (%d)", ret);
 
 	dst4.sin_family = AF_INET;
@@ -588,7 +588,7 @@
 	struct net_icmp_ctx ctx;
 	int ret;
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV6_ECHO_REPLY, 0, icmp_handler);
+	ret = net_icmp_init_ctx(&ctx, AF_INET6, NET_ICMPV6_ECHO_REPLY, 0, icmp_handler);
 	zassert_equal(ret, 0, "Cannot init ICMP (%d)", ret);
 
 	dst6.sin6_family = AF_INET6;
diff --git a/tests/net/icmpv4/src/main.c b/tests/net/icmpv4/src/main.c
index 36e9436..24d0eac 100644
--- a/tests/net/icmpv4/src/main.c
+++ b/tests/net/icmpv4/src/main.c
@@ -464,7 +464,7 @@
 	struct net_pkt *pkt;
 	int ret;
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV4_ECHO_REPLY,
+	ret = net_icmp_init_ctx(&ctx, AF_INET, NET_ICMPV4_ECHO_REPLY,
 				0, handle_reply_msg);
 	zassert_equal(ret, 0, "Cannot register %s handler (%d)",
 		      STRINGIFY(NET_ICMPV4_ECHO_REPLY), ret);
diff --git a/tests/net/icmpv6/src/main.c b/tests/net/icmpv6/src/main.c
index 0ee9d3b..11a7a7e 100644
--- a/tests/net/icmpv6/src/main.c
+++ b/tests/net/icmpv6/src/main.c
@@ -182,12 +182,12 @@
 	struct net_pkt *pkt;
 	int ret;
 
-	ret = net_icmp_init_ctx(&ctx1, NET_ICMPV6_ECHO_REPLY,
+	ret = net_icmp_init_ctx(&ctx1, AF_INET6, NET_ICMPV6_ECHO_REPLY,
 				0, handle_test_msg);
 	zassert_equal(ret, 0, "Cannot register %s handler (%d)",
 		      STRINGIFY(NET_ICMPV6_ECHO_REPLY), ret);
 
-	ret = net_icmp_init_ctx(&ctx2, NET_ICMPV6_ECHO_REQUEST,
+	ret = net_icmp_init_ctx(&ctx2, AF_INET6, NET_ICMPV6_ECHO_REQUEST,
 				0, handle_test_msg);
 	zassert_equal(ret, 0, "Cannot register %s handler (%d)",
 		      STRINGIFY(NET_ICMPV6_ECHO_REQUEST), ret);
diff --git a/tests/net/ipv6_fragment/src/main.c b/tests/net/ipv6_fragment/src/main.c
index 49d9622..6df732b 100644
--- a/tests/net/ipv6_fragment/src/main.c
+++ b/tests/net/ipv6_fragment/src/main.c
@@ -2299,7 +2299,7 @@
 	int ret;
 	struct net_icmp_ctx ctx;
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV6_ECHO_REPLY,
+	ret = net_icmp_init_ctx(&ctx, AF_INET6, NET_ICMPV6_ECHO_REPLY,
 				0, handle_ipv6_echo_reply);
 	zassert_equal(ret, 0, "Cannot register %s handler (%d)",
 		      STRINGIFY(NET_ICMPV6_ECHO_REPLY), ret);
diff --git a/tests/net/mld/src/main.c b/tests/net/mld/src/main.c
index 0036b58..ecfc441 100644
--- a/tests/net/mld/src/main.c
+++ b/tests/net/mld/src/main.c
@@ -557,7 +557,7 @@
 
 	is_query_received = false;
 
-	ret = net_icmp_init_ctx(&ctx, NET_ICMPV6_MLD_QUERY,
+	ret = net_icmp_init_ctx(&ctx, AF_INET6, NET_ICMPV6_MLD_QUERY,
 				0, handle_mld_query);
 	zassert_equal(ret, 0, "Cannot register %s handler (%d)",
 		      STRINGIFY(NET_ICMPV6_MLD_QUERY), ret);