Unit test for CRMP to discard entries with missing retained buffer (#7068)

* Unit test for CRMP to discard entries with missing retained buffer

* remove unnecessary comment

* make operator= public only for tests

* cleanup

* restyled

* address review comments

* address review comments

* Update src/messaging/ReliableMessageMgr.cpp

Co-authored-by: Boris Zbarsky <bzbarsky@apple.com>

* restyled

Co-authored-by: Boris Zbarsky <bzbarsky@apple.com>
diff --git a/src/messaging/ReliableMessageMgr.cpp b/src/messaging/ReliableMessageMgr.cpp
index 9745725..e5a8c20 100644
--- a/src/messaging/ReliableMessageMgr.cpp
+++ b/src/messaging/ReliableMessageMgr.cpp
@@ -133,6 +133,17 @@
         if (!rc || entry.nextRetransTimeTick != 0)
             continue;
 
+        if (entry.retainedBuf.IsNull())
+        {
+            // We generally try to prevent entries with a null buffer being in a table, but it could happen
+            // if the message dispatch (which is supposed to fill in the buffer) fails to do so _and_ returns
+            // success (so its caller doesn't clear out the bogus table entry).
+            //
+            // If that were to happen, we would crash in the code below.  Guard against it, just in case.
+            ClearRetransTable(entry);
+            continue;
+        }
+
         uint8_t sendCount = entry.sendCount;
         uint32_t msgId    = entry.retainedBuf.GetMsgId();
 
diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp
index 818335e..6b7d371 100644
--- a/src/messaging/tests/TestReliableMessageProtocol.cpp
+++ b/src/messaging/tests/TestReliableMessageProtocol.cpp
@@ -102,6 +102,49 @@
     bool IsOnMessageReceivedCalled = false;
 };
 
+class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDispatch
+{
+public:
+    CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message,
+                               EncryptedPacketBufferHandle * retainedMessage) override
+    {
+        PacketHeader packetHeader;
+
+        ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message));
+        ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message));
+
+        if (retainedMessage != nullptr && mRetainMessageOnSend)
+        {
+            *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain());
+        }
+        return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(message));
+    }
+
+    bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; }
+
+    bool mRetainMessageOnSend = true;
+};
+
+class MockSessionEstablishmentDelegate : public ExchangeDelegate
+{
+public:
+    void OnMessageReceived(ExchangeContext * ec, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader,
+                           System::PacketBufferHandle && buffer) override
+    {
+        IsOnMessageReceivedCalled = true;
+    }
+
+    void OnResponseTimeout(ExchangeContext * ec) override {}
+
+    virtual ExchangeMessageDispatch * GetMessageDispatch(ReliableMessageMgr * rmMgr, SecureSessionMgr * sessionMgr) override
+    {
+        return &mMessageDispatch;
+    }
+
+    bool IsOnMessageReceivedCalled = false;
+    MockSessionEstablishmentExchangeDispatch mMessageDispatch;
+};
+
 void test_os_sleep_ms(uint64_t millisecs)
 {
     struct timespec sleep_time;
@@ -291,6 +334,61 @@
     rm->ClearRetransTable(rc);
 }
 
+void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext)
+{
+    TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
+
+    ctx.GetInetLayer().SystemLayer()->Init(nullptr);
+
+    chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD));
+    NL_TEST_ASSERT(inSuite, !buffer.IsNull());
+
+    CHIP_ERROR err = CHIP_NO_ERROR;
+
+    MockSessionEstablishmentDelegate mockSender;
+    ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender);
+    NL_TEST_ASSERT(inSuite, exchange != nullptr);
+
+    ReliableMessageMgr * rm     = ctx.GetExchangeManager().GetReliableMessageMgr();
+    ReliableMessageContext * rc = exchange->GetReliableMessageContext();
+    NL_TEST_ASSERT(inSuite, rm != nullptr);
+    NL_TEST_ASSERT(inSuite, rc != nullptr);
+
+    rc->SetConfig({
+        1, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL
+        1, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL
+    });
+
+    err = mockSender.mMessageDispatch.Init(rm);
+    NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
+
+    mockSender.mMessageDispatch.mRetainMessageOnSend = false;
+
+    // Let's drop the initial message
+    gLoopback.mSendMessageCount    = 0;
+    gLoopback.mNumMessagesToDrop   = 1;
+    gLoopback.mDroppedMessageCount = 0;
+
+    // Ensure the retransmit table is empty right now
+    NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0);
+    err = exchange->SendMessage(Echo::MsgType::EchoRequest, std::move(buffer));
+    NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
+
+    // Ensure the message was dropped
+    NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 1);
+
+    // 1 tick is 64 ms, sleep 65 ms to trigger first re-transmit
+    test_os_sleep_ms(65);
+    ReliableMessageMgr::Timeout(&ctx.GetSystemLayer(), rm, CHIP_SYSTEM_NO_ERROR);
+
+    // Ensure the retransmit table is empty, as we did not provide a message to retain
+    NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0);
+
+    exchange->Close();
+
+    rm->ClearRetransTable(rc);
+}
+
 void CheckSendStandaloneAckMessage(nlTestSuite * inSuite, void * inContext)
 {
     TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
@@ -323,6 +421,7 @@
     NL_TEST_DEF("Test ReliableMessageMgr::CheckFailRetrans", CheckFailRetrans),
     NL_TEST_DEF("Test ReliableMessageMgr::CheckResendApplicationMessage", CheckResendApplicationMessage),
     NL_TEST_DEF("Test ReliableMessageMgr::CheckCloseExchangeAndResendApplicationMessage", CheckCloseExchangeAndResendApplicationMessage),
+    NL_TEST_DEF("Test ReliableMessageMgr::CheckFailedMessageRetainOnSend", CheckFailedMessageRetainOnSend),
     NL_TEST_DEF("Test ReliableMessageMgr::CheckSendStandaloneAckMessage", CheckSendStandaloneAckMessage),
 
     NL_TEST_SENTINEL()
diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp
index 788296a..85cc40c 100644
--- a/src/transport/SecureSessionMgr.cpp
+++ b/src/transport/SecureSessionMgr.cpp
@@ -187,7 +187,7 @@
     // Retain the packet buffer in case it's needed for retransmissions.
     if (bufferRetainSlot != nullptr)
     {
-        (*bufferRetainSlot) = msgBuf.Retain();
+        *bufferRetainSlot = EncryptedPacketBufferHandle::MarkEncrypted(msgBuf.Retain());
     }
 
     ChipLogProgress(Inet, "Sending msg from 0x" ChipLogFormatX64 " to 0x" ChipLogFormatX64 " at utc time: %" PRId64 " msec",
diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h
index a716a40..f1a4ae7 100644
--- a/src/transport/SecureSessionMgr.h
+++ b/src/transport/SecureSessionMgr.h
@@ -53,7 +53,7 @@
  *  EncryptedPacketBufferHandle is a kind of PacketBufferHandle class and used to hold a packet buffer
  *  object whose payload has already been encrypted.
  */
-class EncryptedPacketBufferHandle final : private System::PacketBufferHandle
+class EncryptedPacketBufferHandle final : public System::PacketBufferHandle
 {
 public:
     EncryptedPacketBufferHandle() {}
@@ -92,13 +92,13 @@
     CHIP_ERROR InsertPacketHeader(const PacketHeader & aPacketHeader) { return aPacketHeader.EncodeBeforeData(*this); }
 #endif // CHIP_ENABLE_TEST_ENCRYPTED_BUFFER_API
 
+    static EncryptedPacketBufferHandle MarkEncrypted(PacketBufferHandle && aBuffer)
+    {
+        return EncryptedPacketBufferHandle(std::move(aBuffer));
+    }
+
 private:
-    // Allow SecureSessionMgr to assign or construct us from a PacketBufferHandle
-    friend class SecureSessionMgr;
-
     EncryptedPacketBufferHandle(PacketBufferHandle && aBuffer) : PacketBufferHandle(std::move(aBuffer)) {}
-
-    void operator=(PacketBufferHandle && aBuffer) { PacketBufferHandle::operator=(std::move(aBuffer)); }
 };
 
 /**