Rollback InvokeRequestMessage when AddResponseData fails (#33849)
* Rollback InvokeRequestMessage when AddResponseData fails
* Restyled by clang-format
* Fix rebase issues with TestCommandInteraction
* Move new test to pigweed style unit test after rebase
* Address PR comments
* Restyled by clang-format
* Address PR comments
* Address PR comments
* Small fixes
* Small fixes
* Restyled by clang-format
* Address PR comment
---------
Co-authored-by: Restyled.io <commits@restyled.io>
diff --git a/src/app/CommandSender.cpp b/src/app/CommandSender.cpp
index a91d0fd..2f61f94 100644
--- a/src/app/CommandSender.cpp
+++ b/src/app/CommandSender.cpp
@@ -555,13 +555,18 @@
CHIP_ERROR CommandSender::AddRequestData(const CommandPathParams & aCommandPath, const DataModel::EncodableToTLV & aEncodable,
AddRequestDataParameters & aAddRequestDataParams)
{
+ ReturnErrorOnFailure(AllocateBuffer());
+
+ RollbackInvokeRequest rollback(*this);
PrepareCommandParameters prepareCommandParams(aAddRequestDataParams);
ReturnErrorOnFailure(PrepareCommand(aCommandPath, prepareCommandParams));
TLV::TLVWriter * writer = GetCommandDataIBTLVWriter();
VerifyOrReturnError(writer != nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorOnFailure(aEncodable.EncodeTo(*writer, TLV::ContextTag(CommandDataIB::Tag::kFields)));
FinishCommandParameters finishCommandParams(aAddRequestDataParams);
- return FinishCommand(finishCommandParams);
+ ReturnErrorOnFailure(FinishCommand(finishCommandParams));
+ rollback.DisableAutomaticRollback();
+ return CHIP_NO_ERROR;
}
CHIP_ERROR CommandSender::FinishCommandInternal(FinishCommandParameters & aFinishCommandParams)
@@ -672,5 +677,34 @@
ChipLogDetail(DataManagement, "ICR moving to [%10.10s]", GetStateStr());
}
+CommandSender::RollbackInvokeRequest::RollbackInvokeRequest(CommandSender & aCommandSender) : mCommandSender(aCommandSender)
+{
+ VerifyOrReturn(mCommandSender.mBufferAllocated);
+ VerifyOrReturn(mCommandSender.mState == State::Idle || mCommandSender.mState == State::AddedCommand);
+ VerifyOrReturn(mCommandSender.mInvokeRequestBuilder.GetInvokeRequests().GetError() == CHIP_NO_ERROR);
+ VerifyOrReturn(mCommandSender.mInvokeRequestBuilder.GetError() == CHIP_NO_ERROR);
+ mCommandSender.mInvokeRequestBuilder.Checkpoint(mBackupWriter);
+ mBackupState = mCommandSender.mState;
+ mRollbackInDestructor = true;
+}
+
+CommandSender::RollbackInvokeRequest::~RollbackInvokeRequest()
+{
+ VerifyOrReturn(mRollbackInDestructor);
+ VerifyOrReturn(mCommandSender.mState == State::AddingCommand);
+ ChipLogDetail(DataManagement, "Rolling back response");
+ // TODO(#30453): Rollback of mInvokeRequestBuilder should handle resetting
+ // InvokeRequests.
+ mCommandSender.mInvokeRequestBuilder.GetInvokeRequests().ResetError();
+ mCommandSender.mInvokeRequestBuilder.Rollback(mBackupWriter);
+ mCommandSender.MoveToState(mBackupState);
+ mRollbackInDestructor = false;
+}
+
+void CommandSender::RollbackInvokeRequest::DisableAutomaticRollback()
+{
+ mRollbackInDestructor = false;
+}
+
} // namespace app
} // namespace chip
diff --git a/src/app/CommandSender.h b/src/app/CommandSender.h
index e34dbe7..30cf55b 100644
--- a/src/app/CommandSender.h
+++ b/src/app/CommandSender.h
@@ -216,6 +216,12 @@
AddRequestDataParameters(const Optional<uint16_t> & aTimedInvokeTimeoutMs) : timedInvokeTimeoutMs(aTimedInvokeTimeoutMs) {}
+ AddRequestDataParameters & SetCommandRef(uint16_t aCommandRef)
+ {
+ commandRef.SetValue(aCommandRef);
+ return *this;
+ }
+
// When a value is provided for timedInvokeTimeoutMs, this invoke becomes a timed
// invoke. CommandSender will use the minimum of all provided timeouts for execution.
const Optional<uint16_t> timedInvokeTimeoutMs;
@@ -512,6 +518,34 @@
AwaitingDestruction, ///< The object has completed its work and is awaiting destruction by the application.
};
+ /**
+ * Class to help backup CommandSender's buffer containing InvokeRequestMessage when adding InvokeRequest
+ * in case there is a failure to add InvokeRequest. Intended usage is as follows:
+ * - Allocate RollbackInvokeRequest on the stack.
+ * - Attempt adding InvokeRequest into InvokeRequestMessage buffer.
+ * - If modification is added successfully, call DisableAutomaticRollback() to prevent destructor from
+ * rolling back InvokeReqestMessage.
+ * - If there is an issue adding InvokeRequest, destructor will take care of rolling back
+ * InvokeRequestMessage to previously saved state.
+ */
+ class RollbackInvokeRequest
+ {
+ public:
+ explicit RollbackInvokeRequest(CommandSender & aCommandSender);
+ ~RollbackInvokeRequest();
+
+ /**
+ * Disables rolling back to previously saved state for InvokeRequestMessage.
+ */
+ void DisableAutomaticRollback();
+
+ private:
+ CommandSender & mCommandSender;
+ TLV::TLVWriter mBackupWriter;
+ State mBackupState;
+ bool mRollbackInDestructor = false;
+ };
+
union CallbackHandle
{
CallbackHandle(Callback * apCallback) : legacyCallback(apCallback) {}
diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp
index 0573db4..8c2537a 100644
--- a/src/app/tests/TestCommandInteraction.cpp
+++ b/src/app/tests/TestCommandInteraction.cpp
@@ -109,10 +109,9 @@
kSizeGreaterThan255,
};
-struct ForcedSizeBuffer
+class ForcedSizeBuffer : public app::DataModel::EncodableToTLV
{
- chip::Platform::ScopedMemoryBufferWithSize<uint8_t> mBuffer;
-
+public:
ForcedSizeBuffer(uint32_t size)
{
if (mBuffer.Alloc(size))
@@ -124,7 +123,7 @@
// No significance with using 0x12 as the CommandId, just using a value.
static constexpr chip::CommandId GetCommandId() { return 0x12; }
- CHIP_ERROR Encode(TLV::TLVWriter & aWriter, TLV::Tag aTag) const
+ CHIP_ERROR EncodeTo(TLV::TLVWriter & aWriter, TLV::Tag aTag) const override
{
VerifyOrReturnError(mBuffer, CHIP_ERROR_NO_MEMORY);
@@ -133,6 +132,9 @@
ReturnErrorOnFailure(app::DataModel::Encode(aWriter, TLV::ContextTag(1), ByteSpan(mBuffer.Get(), mBuffer.AllocatedSize())));
return aWriter.EndContainer(outerContainerType);
}
+
+private:
+ chip::Platform::ScopedMemoryBufferWithSize<uint8_t> mBuffer;
};
struct Fields
@@ -387,6 +389,7 @@
void TestCommandSender_WithProcessReceivedMsg();
void TestCommandSender_ExtendableApiWithProcessReceivedMsg();
void TestCommandSender_ExtendableApiWithProcessReceivedMsgContainingInvalidCommandRef();
+ void TestCommandSender_ValidateSecondLargeAddRequestDataRollbacked();
void TestCommandHandler_WithoutResponderCallingAddStatus();
void TestCommandHandler_WithoutResponderCallingAddResponse();
void TestCommandHandler_WithoutResponderCallingDirectPrepareFinishCommandApis();
@@ -632,7 +635,8 @@
// When ForcedSizeBuffer exceeds 255, an extra byte is needed for length, affecting the overhead size required by
// AddResponseData. In order to have this accounted for in overhead calculation we set the length to be 256.
uint32_t sizeOfForcedSizeBuffer = aBufferSizeHint == ForcedSizeBufferLengthHint::kSizeGreaterThan255 ? 256 : 0;
- EXPECT_EQ(commandHandler.AddResponseData(aRequestCommandPath, ForcedSizeBuffer(sizeOfForcedSizeBuffer)), CHIP_NO_ERROR);
+ ForcedSizeBuffer responseData(sizeOfForcedSizeBuffer);
+ EXPECT_EQ(commandHandler.AddResponseData(aRequestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
uint32_t remainingSizeAfter = commandHandler.mInvokeResponseBuilder.GetWriter()->GetRemainingFreeLength();
uint32_t delta = remainingSizeBefore - remainingSizeAfter - sizeOfForcedSizeBuffer;
@@ -657,7 +661,8 @@
// Validating assumption. If this fails, it means overheadSizeNeededForAddingResponse is likely too large.
EXPECT_GE(sizeToFill, 256u);
- EXPECT_EQ(apCommandHandler->AddResponseData(aRequestCommandPath, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
+ ForcedSizeBuffer responseData(sizeToFill);
+ EXPECT_EQ(apCommandHandler->AddResponseData(aRequestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
}
void TestCommandInteraction::ValidateCommandHandlerEncodeInvokeResponseMessage(bool aNeedStatusCode)
@@ -1087,6 +1092,47 @@
EXPECT_EQ(mockCommandSenderExtendedDelegate.onErrorCalledTimes, 0);
}
+TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandSender_ValidateSecondLargeAddRequestDataRollbacked)
+{
+ mockCommandSenderExtendedDelegate.ResetCounter();
+ PendingResponseTrackerImpl pendingResponseTracker;
+ app::CommandSender commandSender(kCommandSenderTestOnlyMarker, &mockCommandSenderExtendedDelegate,
+ &mpTestContext->GetExchangeManager(), &pendingResponseTracker);
+
+ app::CommandSender::AddRequestDataParameters addRequestDataParams;
+
+ CommandSender::ConfigParameters config;
+ config.SetRemoteMaxPathsPerInvoke(2);
+ EXPECT_EQ(commandSender.SetCommandSenderConfig(config), CHIP_NO_ERROR);
+
+ // The specific values chosen here are arbitrary.
+ uint16_t firstCommandRef = 1;
+ uint16_t secondCommandRef = 2;
+ auto commandPathParams = MakeTestCommandPath();
+ SimpleTLVPayload simplePayloadWriter;
+ addRequestDataParams.SetCommandRef(firstCommandRef);
+
+ EXPECT_EQ(commandSender.AddRequestData(commandPathParams, simplePayloadWriter, addRequestDataParams), CHIP_NO_ERROR);
+
+ uint32_t remainingSize = commandSender.mInvokeRequestBuilder.GetWriter()->GetRemainingFreeLength();
+ // Because request is made of both request data and request path (commandPathParams), using
+ // `remainingSize` is large enough fail.
+ ForcedSizeBuffer requestData(remainingSize);
+
+ addRequestDataParams.SetCommandRef(secondCommandRef);
+ EXPECT_EQ(commandSender.AddRequestData(commandPathParams, requestData, addRequestDataParams), CHIP_ERROR_NO_MEMORY);
+
+ // Confirm that we can still send out a request with the first command.
+ EXPECT_EQ(commandSender.SendCommandRequest(mpTestContext->GetSessionBobToAlice()), CHIP_NO_ERROR);
+ EXPECT_EQ(commandSender.GetInvokeResponseMessageCount(), 0u);
+
+ mpTestContext->DrainAndServiceIO();
+
+ EXPECT_EQ(mockCommandSenderExtendedDelegate.onResponseCalledTimes, 1);
+ EXPECT_EQ(mockCommandSenderExtendedDelegate.onFinalCalledTimes, 1);
+ EXPECT_EQ(mockCommandSenderExtendedDelegate.onErrorCalledTimes, 0);
+}
+
TEST_F(TestCommandInteraction, TestCommandHandlerEncodeSimpleCommandData)
{
// Send response which has simple command data and command path
@@ -1188,7 +1234,8 @@
CommandHandlerImpl commandHandler(&mockCommandHandlerDelegate);
uint32_t sizeToFill = 50; // This is an arbitrary number, we need to select a non-zero value.
- EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
+ ForcedSizeBuffer responseData(sizeToFill);
+ EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
// Since calling AddResponseData is supposed to be a no-operation when there is no responder, it is
// hard to validate. Best way is to check that we are still in an Idle state afterwards
@@ -1813,7 +1860,8 @@
EXPECT_EQ(remainingSize, sizeToLeave);
uint32_t sizeToFill = 50;
- EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath2, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
+ ForcedSizeBuffer responseData(sizeToFill);
+ EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath2, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
remainingSize = commandHandler.mInvokeResponseBuilder.GetWriter()->GetRemainingFreeLength();
EXPECT_GT(remainingSize, sizeToLeave);