Add some unit tests for ExchangeContext and ExchangeMessageDispatch. (#33340)
* Add some unit tests for ExchangeContext and ExchangeMessageDispatch.
* Address review comment.
diff --git a/src/messaging/tests/BUILD.gn b/src/messaging/tests/BUILD.gn
index c597c36..0e56444 100644
--- a/src/messaging/tests/BUILD.gn
+++ b/src/messaging/tests/BUILD.gn
@@ -49,7 +49,7 @@
chip_test_suite_using_nltest("tests") {
output_name = "libMessagingLayerTests"
- test_sources = []
+ test_sources = [ "TestExchange.cpp" ]
if (chip_device_platform != "efr32") {
# TODO(#10447): ReliableMessage Test has HF, and ExchangeMgr hangs on EFR32.
diff --git a/src/messaging/tests/TestExchange.cpp b/src/messaging/tests/TestExchange.cpp
new file mode 100644
index 0000000..67b9ab6
--- /dev/null
+++ b/src/messaging/tests/TestExchange.cpp
@@ -0,0 +1,261 @@
+/*
+ * Copyright (c) 2024 Project CHIP Authors
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <lib/core/CHIPCore.h>
+#include <lib/support/CHIPMem.h>
+#include <lib/support/CodeUtils.h>
+#include <lib/support/UnitTestContext.h>
+#include <lib/support/UnitTestRegistration.h>
+#include <messaging/ExchangeContext.h>
+#include <messaging/ExchangeMgr.h>
+#include <messaging/Flags.h>
+#include <messaging/tests/MessagingContext.h>
+#include <protocols/Protocols.h>
+#include <transport/SessionManager.h>
+#include <transport/TransportMgr.h>
+
+#include <nlbyteorder.h>
+#include <nlunit-test.h>
+
+#include <errno.h>
+#include <utility>
+
+#if CHIP_CRYPTO_PSA
+#include "psa/crypto.h"
+#endif
+
+namespace {
+
+using namespace chip;
+using namespace chip::Inet;
+using namespace chip::Transport;
+using namespace chip::Messaging;
+
+struct TestContext : Test::LoopbackMessagingContext
+{
+ // TODO Add TearDown function when changing test framework to Pigweed to make it more clear how it works.
+ // Currently, the TearDown function is from LoopbackMessagingContext
+ void SetUp() override
+ {
+#if CHIP_CRYPTO_PSA
+ // TODO: use ASSERT_EQ, once transition to pw_unit_test is complete
+ VerifyOrDie(psa_crypto_init() == PSA_SUCCESS);
+#endif
+ chip::Test::LoopbackMessagingContext::SetUp();
+ }
+};
+
+enum : uint8_t
+{
+ kMsgType_TEST1 = 0xf0,
+ kMsgType_TEST2 = 0xf1,
+};
+
+class MockExchangeDelegate : public UnsolicitedMessageHandler, public ExchangeDelegate
+{
+public:
+ CHIP_ERROR OnUnsolicitedMessageReceived(const PayloadHeader & payloadHeader, ExchangeDelegate *& newDelegate) override
+ {
+ newDelegate = this;
+ return CHIP_NO_ERROR;
+ }
+
+ CHIP_ERROR OnMessageReceived(ExchangeContext * ec, const PayloadHeader & payloadHeader,
+ System::PacketBufferHandle && buffer) override
+ {
+ ++mReceivedMessageCount;
+ if (mKeepExchangeAliveOnMessageReceipt)
+ {
+ ec->WillSendMessage();
+ mExchange = ec;
+ }
+ else
+ {
+ // Exchange will be closing, so don't hold on to a reference to it.
+ mExchange = nullptr;
+ }
+ return CHIP_NO_ERROR;
+ }
+
+ void OnResponseTimeout(ExchangeContext * ec) override {}
+
+ ExchangeMessageDispatch & GetMessageDispatch() override
+ {
+ if (mMessageDispatch != nullptr)
+ {
+ return *mMessageDispatch;
+ }
+
+ return ExchangeDelegate::GetMessageDispatch();
+ }
+
+ uint32_t mReceivedMessageCount = 0;
+ bool mKeepExchangeAliveOnMessageReceipt = true;
+ ExchangeContext * mExchange = nullptr;
+ ExchangeMessageDispatch * mMessageDispatch = nullptr;
+};
+
+// Helper used by several tests. Registers delegate2 as an unsolicited message
+// handler, sends a message of type requestMessageType via an exchange that has
+// delegate1 as delegate, responds with responseMessageType.
+template <typename AfterRequestChecker, typename AfterResponseChecker>
+void DoRoundTripTest(nlTestSuite * inSuite, void * inContext, MockExchangeDelegate & delegate1, MockExchangeDelegate & delegate2,
+ uint8_t requestMessageType, uint8_t responseMessageType, AfterRequestChecker && afterRequestChecker,
+ AfterResponseChecker && afterResponseChecker)
+{
+ TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
+
+ ExchangeContext * ec1 = ctx.NewExchangeToBob(&delegate1);
+ NL_TEST_ASSERT(inSuite, ec1 != nullptr);
+
+ CHIP_ERROR err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Protocols::SecureChannel::Id,
+ requestMessageType, &delegate2);
+ NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
+
+ // To simplify things, skip MRP for all our messages, and make sure we are
+ // always expecting responses.
+ constexpr auto sendFlags =
+ SendFlags(Messaging::SendMessageFlags::kNoAutoRequestAck, Messaging::SendMessageFlags::kExpectResponse);
+
+ err = ec1->SendMessage(Protocols::SecureChannel::Id, requestMessageType,
+ System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize), sendFlags);
+ NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
+
+ ctx.DrainAndServiceIO();
+
+ afterRequestChecker();
+
+ ExchangeContext * ec2 = delegate2.mExchange;
+ err = ec2->SendMessage(Protocols::SecureChannel::Id, responseMessageType,
+ System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize), sendFlags);
+ NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
+
+ ctx.DrainAndServiceIO();
+
+ afterResponseChecker();
+
+ ec1->Close();
+ ec2->Close();
+
+ err = ctx.GetExchangeManager().UnregisterUnsolicitedMessageHandlerForType(Protocols::SecureChannel::Id, kMsgType_TEST1);
+ NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
+}
+
+void CheckBasicMessageRoundTrip(nlTestSuite * inSuite, void * inContext)
+{
+ MockExchangeDelegate delegate1;
+ MockExchangeDelegate delegate2;
+ DoRoundTripTest(
+ inSuite, inContext, delegate1, delegate2, kMsgType_TEST1, kMsgType_TEST2,
+ [&] {
+ NL_TEST_ASSERT(inSuite, delegate1.mReceivedMessageCount == 0);
+ NL_TEST_ASSERT(inSuite, delegate2.mReceivedMessageCount == 1);
+ },
+ [&] {
+ NL_TEST_ASSERT(inSuite, delegate1.mReceivedMessageCount == 1);
+ NL_TEST_ASSERT(inSuite, delegate2.mReceivedMessageCount == 1);
+ });
+}
+
+void CheckBasicExchangeMessageDispatch(nlTestSuite * inSuite, void * inContext)
+{
+ class MockMessageDispatch : public ExchangeMessageDispatch
+ {
+ bool MessagePermitted(Protocols::Id protocol, uint8_t type) override
+ {
+ // Only allow TEST1 messages.
+ return protocol == Protocols::SecureChannel::Id && type == kMsgType_TEST1;
+ }
+ };
+
+ MockMessageDispatch dispatch;
+
+ {
+ // Allowed response.
+ MockExchangeDelegate delegate1;
+ delegate1.mMessageDispatch = &dispatch;
+ MockExchangeDelegate delegate2;
+
+ DoRoundTripTest(
+ inSuite, inContext, delegate1, delegate2, kMsgType_TEST1, kMsgType_TEST1,
+ [&] {
+ NL_TEST_ASSERT(inSuite, delegate1.mReceivedMessageCount == 0);
+ NL_TEST_ASSERT(inSuite, delegate2.mReceivedMessageCount == 1);
+ },
+ [&] {
+ NL_TEST_ASSERT(inSuite, delegate1.mReceivedMessageCount == 1);
+ NL_TEST_ASSERT(inSuite, delegate2.mReceivedMessageCount == 1);
+ });
+ }
+
+ {
+ // Disallowed response.
+ MockExchangeDelegate delegate1;
+ delegate1.mMessageDispatch = &dispatch;
+ MockExchangeDelegate delegate2;
+
+ DoRoundTripTest(
+ inSuite, inContext, delegate1, delegate2, kMsgType_TEST1, kMsgType_TEST2,
+ [&] {
+ NL_TEST_ASSERT(inSuite, delegate1.mReceivedMessageCount == 0);
+ NL_TEST_ASSERT(inSuite, delegate2.mReceivedMessageCount == 1);
+ },
+ [&] {
+ NL_TEST_ASSERT(inSuite, delegate1.mReceivedMessageCount == 0);
+ NL_TEST_ASSERT(inSuite, delegate2.mReceivedMessageCount == 1);
+ });
+ }
+}
+
+// Test Suite
+
+/**
+ * Test Suite that lists all the test functions.
+ */
+// clang-format off
+const nlTest sTests[] =
+{
+ NL_TEST_DEF("Test ExchangeContext::SendMessage", CheckBasicMessageRoundTrip),
+ NL_TEST_DEF("Test ExchangeMessageDispatch", CheckBasicExchangeMessageDispatch),
+
+ NL_TEST_SENTINEL()
+};
+// clang-format on
+
+// clang-format off
+nlTestSuite sSuite =
+{
+ "Test-Exchange",
+ &sTests[0],
+ TestContext::nlTestSetUpTestSuite,
+ TestContext::nlTestTearDownTestSuite,
+ TestContext::nlTestSetUp,
+ TestContext::nlTestTearDown,
+};
+// clang-format on
+
+} // namespace
+
+/**
+ * Main
+ */
+int TestExchange()
+{
+ return chip::ExecuteTestsWithContext<TestContext>(&sSuite);
+}
+
+CHIP_REGISTER_TEST_SUITE(TestExchange);