Add a way to capture message payloads in tests. (#21774)
One use in TestReadInteraction shows how this works. More uses in
that file will need to be added to address the other TODOs.
diff --git a/src/app/tests/TestReadInteraction.cpp b/src/app/tests/TestReadInteraction.cpp
index 2b78cb8..660cd8e 100644
--- a/src/app/tests/TestReadInteraction.cpp
+++ b/src/app/tests/TestReadInteraction.cpp
@@ -2696,6 +2696,20 @@
ctx.CreateSessionBobToAlice();
}
+namespace {
+
+void CheckForInvalidAction(nlTestSuite * apSuite, Test::MessageCapturer & messageLog)
+{
+ NL_TEST_ASSERT(apSuite, messageLog.MessageCount() == 1);
+ NL_TEST_ASSERT(apSuite, messageLog.IsMessageType(0, Protocols::InteractionModel::MsgType::StatusResponse));
+ CHIP_ERROR status;
+ NL_TEST_ASSERT(apSuite,
+ StatusResponse::ProcessStatusResponse(std::move(messageLog.MessagePayload(0)), status) == CHIP_NO_ERROR);
+ NL_TEST_ASSERT(apSuite, status == CHIP_IM_GLOBAL_STATUS(InvalidAction));
+}
+
+} // anonymous namespace
+
// Read Client sends the read request, Read Handler drops the response, then test injects unknown status reponse message for Read
// Client.
void TestReadInteraction::TestReadClientReceiveInvalidMessage(nlTestSuite * apSuite, void * apContext)
@@ -2750,6 +2764,9 @@
payloadHeader.SetExchangeID(0);
payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::StatusResponse);
+ Test::MessageCapturer messageLog(ctx);
+ messageLog.mCaptureStandaloneAcks = false;
+
rm->ClearRetransTable(readClient.mExchange.Get());
NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1);
@@ -2760,12 +2777,12 @@
readClient.OnMessageReceived(readClient.mExchange.Get(), payloadHeader, std::move(msgBuf));
ctx.DrainAndServiceIO();
- // TODO: Need to validate what status is being sent to the ReadHandler
// The ReadHandler closed its exchange when it sent the Report Data (which we dropped).
// Since we synthesized the StatusResponse to the ReadClient, instead of sending it from the ReadHandler,
// the only messages here are the ReadClient's StatusResponse to the unexpected message and an MRP ack.
- NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
NL_TEST_ASSERT(apSuite, delegate.mError == CHIP_IM_GLOBAL_STATUS(Busy));
+
+ CheckForInvalidAction(apSuite, messageLog);
}
engine->Shutdown();
diff --git a/src/messaging/tests/MessagingContext.cpp b/src/messaging/tests/MessagingContext.cpp
index 127b1a1..c31e66b 100644
--- a/src/messaging/tests/MessagingContext.cpp
+++ b/src/messaging/tests/MessagingContext.cpp
@@ -20,6 +20,7 @@
#include <credentials/tests/CHIPCert_unit_test_vectors.h>
#include <lib/support/CodeUtils.h>
#include <lib/support/ErrorStr.h>
+#include <protocols/secure_channel/Constants.h>
namespace chip {
namespace Test {
@@ -204,5 +205,16 @@
return mExchangeManager.NewContext(GetSessionAliceToBob(), delegate);
}
+void MessageCapturer::OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader,
+ const SessionHandle & session, DuplicateMessage isDuplicate,
+ System::PacketBufferHandle && msgBuf)
+{
+ if (mCaptureStandaloneAcks || !payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::StandaloneAck))
+ {
+ mCapturedMessages.emplace_back(Message{ packetHeader, payloadHeader, isDuplicate, msgBuf.CloneData() });
+ }
+ mOriginalDelegate.OnMessageReceived(packetHeader, payloadHeader, session, isDuplicate, std::move(msgBuf));
+}
+
} // namespace Test
} // namespace chip
diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h
index 8c13fdd..559a046 100644
--- a/src/messaging/tests/MessagingContext.h
+++ b/src/messaging/tests/MessagingContext.h
@@ -30,6 +30,8 @@
#include <nlunit-test.h>
+#include <vector>
+
namespace chip {
namespace Test {
@@ -65,7 +67,7 @@
};
/**
- * @brief The context of test cases for messaging layer. It wil initialize network layer and system layer, and create
+ * @brief The context of test cases for messaging layer. It will initialize network layer and system layer, and create
* two secure sessions, connected with each other. Exchanges can be created for each secure session.
*/
class MessagingContext : public PlatformMemoryUser
@@ -213,5 +215,53 @@
using LoopbackTransportManager::GetSystemLayer;
};
+// Class that can be used to capture decrypted message traffic in tests using
+// MessagingContext.
+class MessageCapturer : public SessionMessageDelegate
+{
+public:
+ MessageCapturer(MessagingContext & aContext) :
+ mSessionManager(aContext.GetSecureSessionManager()), mOriginalDelegate(aContext.GetExchangeManager())
+ {
+ // Interpose ourselves into the message flow.
+ mSessionManager.SetMessageDelegate(this);
+ }
+
+ ~MessageCapturer()
+ {
+ // Restore the normal message flow.
+ mSessionManager.SetMessageDelegate(&mOriginalDelegate);
+ }
+
+ struct Message
+ {
+ PacketHeader mPacketHeader;
+ PayloadHeader mPayloadHeader;
+ DuplicateMessage mIsDuplicate;
+ System::PacketBufferHandle mPayload;
+ };
+
+ size_t MessageCount() const { return mCapturedMessages.size(); }
+
+ template <typename MessageType, typename = std::enable_if_t<std::is_enum<MessageType>::value>>
+ bool IsMessageType(size_t index, MessageType type)
+ {
+ return mCapturedMessages[index].mPayloadHeader.HasMessageType(type);
+ }
+
+ System::PacketBufferHandle & MessagePayload(size_t index) { return mCapturedMessages[index].mPayload; }
+
+ bool mCaptureStandaloneAcks = true;
+
+private:
+ // SessionMessageDelegate implementation.
+ void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const SessionHandle & session,
+ DuplicateMessage isDuplicate, System::PacketBufferHandle && msgBuf) override;
+
+ SessionManager & mSessionManager;
+ SessionMessageDelegate & mOriginalDelegate;
+ std::vector<Message> mCapturedMessages;
+};
+
} // namespace Test
} // namespace chip
diff --git a/src/transport/raw/MessageHeader.h b/src/transport/raw/MessageHeader.h
index f253aa4..c1ccf54 100644
--- a/src/transport/raw/MessageHeader.h
+++ b/src/transport/raw/MessageHeader.h
@@ -430,6 +430,7 @@
{
public:
constexpr PayloadHeader() { SetProtocol(Protocols::NotSpecified); }
+ constexpr PayloadHeader(const PayloadHeader &) = default;
PayloadHeader & operator=(const PayloadHeader &) = default;
/** Get the Session ID from this header. */