Rollback InvokeRequestMessage when AddResponseData fails (#33849)

* Rollback InvokeRequestMessage when AddResponseData fails

* Restyled by clang-format

* Fix rebase issues with TestCommandInteraction

* Move new test to pigweed style unit test after rebase

* Address PR comments

* Restyled by clang-format

* Address PR comments

* Address PR comments

* Small fixes

* Small fixes

* Restyled by clang-format

* Address PR comment

---------

Co-authored-by: Restyled.io <commits@restyled.io>
diff --git a/src/app/CommandSender.cpp b/src/app/CommandSender.cpp
index a91d0fd..2f61f94 100644
--- a/src/app/CommandSender.cpp
+++ b/src/app/CommandSender.cpp
@@ -555,13 +555,18 @@
 CHIP_ERROR CommandSender::AddRequestData(const CommandPathParams & aCommandPath, const DataModel::EncodableToTLV & aEncodable,
                                          AddRequestDataParameters & aAddRequestDataParams)
 {
+    ReturnErrorOnFailure(AllocateBuffer());
+
+    RollbackInvokeRequest rollback(*this);
     PrepareCommandParameters prepareCommandParams(aAddRequestDataParams);
     ReturnErrorOnFailure(PrepareCommand(aCommandPath, prepareCommandParams));
     TLV::TLVWriter * writer = GetCommandDataIBTLVWriter();
     VerifyOrReturnError(writer != nullptr, CHIP_ERROR_INCORRECT_STATE);
     ReturnErrorOnFailure(aEncodable.EncodeTo(*writer, TLV::ContextTag(CommandDataIB::Tag::kFields)));
     FinishCommandParameters finishCommandParams(aAddRequestDataParams);
-    return FinishCommand(finishCommandParams);
+    ReturnErrorOnFailure(FinishCommand(finishCommandParams));
+    rollback.DisableAutomaticRollback();
+    return CHIP_NO_ERROR;
 }
 
 CHIP_ERROR CommandSender::FinishCommandInternal(FinishCommandParameters & aFinishCommandParams)
@@ -672,5 +677,34 @@
     ChipLogDetail(DataManagement, "ICR moving to [%10.10s]", GetStateStr());
 }
 
+CommandSender::RollbackInvokeRequest::RollbackInvokeRequest(CommandSender & aCommandSender) : mCommandSender(aCommandSender)
+{
+    VerifyOrReturn(mCommandSender.mBufferAllocated);
+    VerifyOrReturn(mCommandSender.mState == State::Idle || mCommandSender.mState == State::AddedCommand);
+    VerifyOrReturn(mCommandSender.mInvokeRequestBuilder.GetInvokeRequests().GetError() == CHIP_NO_ERROR);
+    VerifyOrReturn(mCommandSender.mInvokeRequestBuilder.GetError() == CHIP_NO_ERROR);
+    mCommandSender.mInvokeRequestBuilder.Checkpoint(mBackupWriter);
+    mBackupState          = mCommandSender.mState;
+    mRollbackInDestructor = true;
+}
+
+CommandSender::RollbackInvokeRequest::~RollbackInvokeRequest()
+{
+    VerifyOrReturn(mRollbackInDestructor);
+    VerifyOrReturn(mCommandSender.mState == State::AddingCommand);
+    ChipLogDetail(DataManagement, "Rolling back response");
+    // TODO(#30453): Rollback of mInvokeRequestBuilder should handle resetting
+    // InvokeRequests.
+    mCommandSender.mInvokeRequestBuilder.GetInvokeRequests().ResetError();
+    mCommandSender.mInvokeRequestBuilder.Rollback(mBackupWriter);
+    mCommandSender.MoveToState(mBackupState);
+    mRollbackInDestructor = false;
+}
+
+void CommandSender::RollbackInvokeRequest::DisableAutomaticRollback()
+{
+    mRollbackInDestructor = false;
+}
+
 } // namespace app
 } // namespace chip
diff --git a/src/app/CommandSender.h b/src/app/CommandSender.h
index e34dbe7..30cf55b 100644
--- a/src/app/CommandSender.h
+++ b/src/app/CommandSender.h
@@ -216,6 +216,12 @@
 
         AddRequestDataParameters(const Optional<uint16_t> & aTimedInvokeTimeoutMs) : timedInvokeTimeoutMs(aTimedInvokeTimeoutMs) {}
 
+        AddRequestDataParameters & SetCommandRef(uint16_t aCommandRef)
+        {
+            commandRef.SetValue(aCommandRef);
+            return *this;
+        }
+
         // When a value is provided for timedInvokeTimeoutMs, this invoke becomes a timed
         // invoke. CommandSender will use the minimum of all provided timeouts for execution.
         const Optional<uint16_t> timedInvokeTimeoutMs;
@@ -512,6 +518,34 @@
         AwaitingDestruction, ///< The object has completed its work and is awaiting destruction by the application.
     };
 
+    /**
+     * Class to help backup CommandSender's buffer containing InvokeRequestMessage when adding InvokeRequest
+     * in case there is a failure to add InvokeRequest. Intended usage is as follows:
+     *  - Allocate RollbackInvokeRequest on the stack.
+     *  - Attempt adding InvokeRequest into InvokeRequestMessage buffer.
+     *  - If modification is added successfully, call DisableAutomaticRollback() to prevent destructor from
+     *    rolling back InvokeReqestMessage.
+     *  - If there is an issue adding InvokeRequest, destructor will take care of rolling back
+     *    InvokeRequestMessage to previously saved state.
+     */
+    class RollbackInvokeRequest
+    {
+    public:
+        explicit RollbackInvokeRequest(CommandSender & aCommandSender);
+        ~RollbackInvokeRequest();
+
+        /**
+         * Disables rolling back to previously saved state for InvokeRequestMessage.
+         */
+        void DisableAutomaticRollback();
+
+    private:
+        CommandSender & mCommandSender;
+        TLV::TLVWriter mBackupWriter;
+        State mBackupState;
+        bool mRollbackInDestructor = false;
+    };
+
     union CallbackHandle
     {
         CallbackHandle(Callback * apCallback) : legacyCallback(apCallback) {}
diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp
index 0573db4..8c2537a 100644
--- a/src/app/tests/TestCommandInteraction.cpp
+++ b/src/app/tests/TestCommandInteraction.cpp
@@ -109,10 +109,9 @@
     kSizeGreaterThan255,
 };
 
-struct ForcedSizeBuffer
+class ForcedSizeBuffer : public app::DataModel::EncodableToTLV
 {
-    chip::Platform::ScopedMemoryBufferWithSize<uint8_t> mBuffer;
-
+public:
     ForcedSizeBuffer(uint32_t size)
     {
         if (mBuffer.Alloc(size))
@@ -124,7 +123,7 @@
 
     // No significance with using 0x12 as the CommandId, just using a value.
     static constexpr chip::CommandId GetCommandId() { return 0x12; }
-    CHIP_ERROR Encode(TLV::TLVWriter & aWriter, TLV::Tag aTag) const
+    CHIP_ERROR EncodeTo(TLV::TLVWriter & aWriter, TLV::Tag aTag) const override
     {
         VerifyOrReturnError(mBuffer, CHIP_ERROR_NO_MEMORY);
 
@@ -133,6 +132,9 @@
         ReturnErrorOnFailure(app::DataModel::Encode(aWriter, TLV::ContextTag(1), ByteSpan(mBuffer.Get(), mBuffer.AllocatedSize())));
         return aWriter.EndContainer(outerContainerType);
     }
+
+private:
+    chip::Platform::ScopedMemoryBufferWithSize<uint8_t> mBuffer;
 };
 
 struct Fields
@@ -387,6 +389,7 @@
     void TestCommandSender_WithProcessReceivedMsg();
     void TestCommandSender_ExtendableApiWithProcessReceivedMsg();
     void TestCommandSender_ExtendableApiWithProcessReceivedMsgContainingInvalidCommandRef();
+    void TestCommandSender_ValidateSecondLargeAddRequestDataRollbacked();
     void TestCommandHandler_WithoutResponderCallingAddStatus();
     void TestCommandHandler_WithoutResponderCallingAddResponse();
     void TestCommandHandler_WithoutResponderCallingDirectPrepareFinishCommandApis();
@@ -632,7 +635,8 @@
     // When ForcedSizeBuffer exceeds 255, an extra byte is needed for length, affecting the overhead size required by
     // AddResponseData. In order to have this accounted for in overhead calculation we set the length to be 256.
     uint32_t sizeOfForcedSizeBuffer = aBufferSizeHint == ForcedSizeBufferLengthHint::kSizeGreaterThan255 ? 256 : 0;
-    EXPECT_EQ(commandHandler.AddResponseData(aRequestCommandPath, ForcedSizeBuffer(sizeOfForcedSizeBuffer)), CHIP_NO_ERROR);
+    ForcedSizeBuffer responseData(sizeOfForcedSizeBuffer);
+    EXPECT_EQ(commandHandler.AddResponseData(aRequestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
     uint32_t remainingSizeAfter = commandHandler.mInvokeResponseBuilder.GetWriter()->GetRemainingFreeLength();
     uint32_t delta              = remainingSizeBefore - remainingSizeAfter - sizeOfForcedSizeBuffer;
 
@@ -657,7 +661,8 @@
     // Validating assumption. If this fails, it means overheadSizeNeededForAddingResponse is likely too large.
     EXPECT_GE(sizeToFill, 256u);
 
-    EXPECT_EQ(apCommandHandler->AddResponseData(aRequestCommandPath, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
+    ForcedSizeBuffer responseData(sizeToFill);
+    EXPECT_EQ(apCommandHandler->AddResponseData(aRequestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
 }
 
 void TestCommandInteraction::ValidateCommandHandlerEncodeInvokeResponseMessage(bool aNeedStatusCode)
@@ -1087,6 +1092,47 @@
     EXPECT_EQ(mockCommandSenderExtendedDelegate.onErrorCalledTimes, 0);
 }
 
+TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandSender_ValidateSecondLargeAddRequestDataRollbacked)
+{
+    mockCommandSenderExtendedDelegate.ResetCounter();
+    PendingResponseTrackerImpl pendingResponseTracker;
+    app::CommandSender commandSender(kCommandSenderTestOnlyMarker, &mockCommandSenderExtendedDelegate,
+                                     &mpTestContext->GetExchangeManager(), &pendingResponseTracker);
+
+    app::CommandSender::AddRequestDataParameters addRequestDataParams;
+
+    CommandSender::ConfigParameters config;
+    config.SetRemoteMaxPathsPerInvoke(2);
+    EXPECT_EQ(commandSender.SetCommandSenderConfig(config), CHIP_NO_ERROR);
+
+    // The specific values chosen here are arbitrary.
+    uint16_t firstCommandRef  = 1;
+    uint16_t secondCommandRef = 2;
+    auto commandPathParams    = MakeTestCommandPath();
+    SimpleTLVPayload simplePayloadWriter;
+    addRequestDataParams.SetCommandRef(firstCommandRef);
+
+    EXPECT_EQ(commandSender.AddRequestData(commandPathParams, simplePayloadWriter, addRequestDataParams), CHIP_NO_ERROR);
+
+    uint32_t remainingSize = commandSender.mInvokeRequestBuilder.GetWriter()->GetRemainingFreeLength();
+    // Because request is made of both request data and request path (commandPathParams), using
+    // `remainingSize` is large enough fail.
+    ForcedSizeBuffer requestData(remainingSize);
+
+    addRequestDataParams.SetCommandRef(secondCommandRef);
+    EXPECT_EQ(commandSender.AddRequestData(commandPathParams, requestData, addRequestDataParams), CHIP_ERROR_NO_MEMORY);
+
+    // Confirm that we can still send out a request with the first command.
+    EXPECT_EQ(commandSender.SendCommandRequest(mpTestContext->GetSessionBobToAlice()), CHIP_NO_ERROR);
+    EXPECT_EQ(commandSender.GetInvokeResponseMessageCount(), 0u);
+
+    mpTestContext->DrainAndServiceIO();
+
+    EXPECT_EQ(mockCommandSenderExtendedDelegate.onResponseCalledTimes, 1);
+    EXPECT_EQ(mockCommandSenderExtendedDelegate.onFinalCalledTimes, 1);
+    EXPECT_EQ(mockCommandSenderExtendedDelegate.onErrorCalledTimes, 0);
+}
+
 TEST_F(TestCommandInteraction, TestCommandHandlerEncodeSimpleCommandData)
 {
     // Send response which has simple command data and command path
@@ -1188,7 +1234,8 @@
     CommandHandlerImpl commandHandler(&mockCommandHandlerDelegate);
 
     uint32_t sizeToFill = 50; // This is an arbitrary number, we need to select a non-zero value.
-    EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
+    ForcedSizeBuffer responseData(sizeToFill);
+    EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
 
     // Since calling AddResponseData is supposed to be a no-operation when there is no responder, it is
     // hard to validate. Best way is to check that we are still in an Idle state afterwards
@@ -1813,7 +1860,8 @@
     EXPECT_EQ(remainingSize, sizeToLeave);
 
     uint32_t sizeToFill = 50;
-    EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath2, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
+    ForcedSizeBuffer responseData(sizeToFill);
+    EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath2, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
 
     remainingSize = commandHandler.mInvokeResponseBuilder.GetWriter()->GetRemainingFreeLength();
     EXPECT_GT(remainingSize, sizeToLeave);