Fix status reponse in command sender (#21518)

diff --git a/src/app/CommandSender.cpp b/src/app/CommandSender.cpp
index 0b6bad6..7af043c 100644
--- a/src/app/CommandSender.cpp
+++ b/src/app/CommandSender.cpp
@@ -136,16 +136,25 @@
         MoveToState(State::ResponseReceived);
     }
 
-    CHIP_ERROR err = CHIP_NO_ERROR;
+    CHIP_ERROR err          = CHIP_NO_ERROR;
+    bool sendStatusResponse = false;
     VerifyOrExit(apExchangeContext == mExchangeCtx.Get(), err = CHIP_ERROR_INCORRECT_STATE);
+    sendStatusResponse = true;
 
     if (mState == State::AwaitingTimedStatus)
     {
-        VerifyOrExit(aPayloadHeader.HasMessageType(MsgType::StatusResponse), err = CHIP_ERROR_INVALID_MESSAGE_TYPE);
-        CHIP_ERROR statusError = CHIP_NO_ERROR;
-        SuccessOrExit(err = StatusResponse::ProcessStatusResponse(std::move(aPayload), statusError));
-        SuccessOrExit(err = statusError);
-        err = SendInvokeRequest();
+        if (aPayloadHeader.HasMessageType(Protocols::InteractionModel::MsgType::StatusResponse))
+        {
+            CHIP_ERROR statusError = CHIP_NO_ERROR;
+            SuccessOrExit(err = StatusResponse::ProcessStatusResponse(std::move(aPayload), statusError));
+            sendStatusResponse = false;
+            SuccessOrExit(err = statusError);
+            err = SendInvokeRequest();
+        }
+        else
+        {
+            err = CHIP_ERROR_INVALID_MESSAGE_TYPE;
+        }
         // Skip all other processing here (which is for the response to the
         // invoke request), no matter whether err is success or not.
         goto exit;
@@ -155,6 +164,7 @@
     {
         err = ProcessInvokeResponse(std::move(aPayload));
         SuccessOrExit(err);
+        sendStatusResponse = false;
     }
     else if (aPayloadHeader.HasMessageType(MsgType::StatusResponse))
     {
@@ -177,6 +187,11 @@
         }
     }
 
+    if (sendStatusResponse)
+    {
+        StatusResponse::Send(Status::InvalidAction, apExchangeContext, false /*aExpectResponse*/);
+    }
+
     if (mState != State::CommandSent)
     {
         Close();
diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp
index 3821b7f..ef95b10 100644
--- a/src/app/tests/TestCommandInteraction.cpp
+++ b/src/app/tests/TestCommandInteraction.cpp
@@ -50,6 +50,21 @@
 using TestContext = chip::Test::AppContext;
 using namespace chip::Protocols;
 
+namespace {
+
+void CheckForInvalidAction(nlTestSuite * apSuite, chip::Test::MessageCapturer & messageLog)
+{
+    NL_TEST_ASSERT(apSuite, messageLog.MessageCount() == 1);
+    NL_TEST_ASSERT(apSuite, messageLog.IsMessageType(0, chip::Protocols::InteractionModel::MsgType::StatusResponse));
+    CHIP_ERROR status;
+    NL_TEST_ASSERT(apSuite,
+                   chip::app::StatusResponse::ProcessStatusResponse(std::move(messageLog.MessagePayload(0)), status) ==
+                       CHIP_NO_ERROR);
+    NL_TEST_ASSERT(apSuite, status == CHIP_IM_GLOBAL_STATUS(InvalidAction));
+}
+
+} // anonymous namespace
+
 namespace chip {
 
 namespace {
@@ -172,6 +187,7 @@
     {
         ChipLogError(Controller, "OnError happens with %" CHIP_ERROR_FORMAT, aError.Format());
         onErrorCalledTimes++;
+        mError = aError;
     }
     void OnDone(chip::app::CommandSender * apCommandSender) override { onFinalCalledTimes++; }
 
@@ -185,6 +201,8 @@
     int onResponseCalledTimes = 0;
     int onErrorCalledTimes    = 0;
     int onFinalCalledTimes    = 0;
+
+    CHIP_ERROR mError = CHIP_NO_ERROR;
 } mockCommandSenderDelegate;
 
 class MockCommandHandlerCallback : public CommandHandler::Callback
@@ -215,6 +233,10 @@
     static void TestCommandHandlerWithSendSimpleCommandData(nlTestSuite * apSuite, void * apContext);
     static void TestCommandHandlerCommandDataEncoding(nlTestSuite * apSuite, void * apContext);
     static void TestCommandHandlerCommandEncodeFailure(nlTestSuite * apSuite, void * apContext);
+    static void TestCommandInvalidMessage1(nlTestSuite * apSuite, void * apContext);
+    static void TestCommandInvalidMessage2(nlTestSuite * apSuite, void * apContext);
+    static void TestCommandInvalidMessage3(nlTestSuite * apSuite, void * apContext);
+    static void TestCommandInvalidMessage4(nlTestSuite * apSuite, void * apContext);
     static void TestCommandHandlerCommandEncodeExternalFailure(nlTestSuite * apSuite, void * apContext);
     static void TestCommandHandlerWithSendSimpleStatusCode(nlTestSuite * apSuite, void * apContext);
     static void TestCommandHandlerWithSendEmptyResponse(nlTestSuite * apSuite, void * apContext);
@@ -248,6 +270,7 @@
                                        EndpointId aEndpointId = kTestEndpointId);
     static void AddInvokeRequestData(nlTestSuite * apSuite, void * apContext, CommandSender * apCommandSender,
                                      CommandId aCommandId = kTestCommandIdWithData);
+
     static void AddInvokeResponseData(nlTestSuite * apSuite, void * apContext, CommandHandler * apCommandHandler,
                                       bool aNeedStatusCode, CommandId aCommandId = kTestCommandIdWithData);
     static void ValidateCommandHandlerWithSendCommand(nlTestSuite * apSuite, void * apContext, bool aNeedStatusCode);
@@ -628,6 +651,278 @@
 #endif
 }
 
+// Command Sender sends invoke request, command handler drops invoke response, then test injects status response message with
+// busy to client, client sends out a status response with invalid action.
+void TestCommandInteraction::TestCommandInvalidMessage1(nlTestSuite * apSuite, void * apContext)
+{
+    TestContext & ctx = *static_cast<TestContext *>(apContext);
+    CHIP_ERROR err    = CHIP_NO_ERROR;
+    mockCommandSenderDelegate.ResetCounter();
+    app::CommandSender commandSender(&mockCommandSenderDelegate, &ctx.GetExchangeManager());
+
+    AddInvokeRequestData(apSuite, apContext, &commandSender);
+    asyncCommand = false;
+
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 1;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1;
+    err                                                 = commandSender.SendCommandRequest(ctx.GetSessionBobToAlice());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+    ctx.DrainAndServiceIO();
+
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1);
+
+    NL_TEST_ASSERT(apSuite,
+                   mockCommandSenderDelegate.onResponseCalledTimes == 0 && mockCommandSenderDelegate.onFinalCalledTimes == 0 &&
+                       mockCommandSenderDelegate.onErrorCalledTimes == 0);
+
+    NL_TEST_ASSERT(apSuite, GetNumActiveHandlerObjects() == 0);
+
+    System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes);
+    NL_TEST_ASSERT(apSuite, !msgBuf.IsNull());
+    System::PacketBufferTLVWriter writer;
+    writer.Init(std::move(msgBuf));
+    StatusResponseMessage::Builder response;
+    response.Init(&writer);
+    response.Status(Protocols::InteractionModel::Status::Busy);
+    NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR);
+
+    PayloadHeader payloadHeader;
+    payloadHeader.SetExchangeID(0);
+    payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::StatusResponse);
+    Test::MessageCapturer messageLog(ctx);
+    messageLog.mCaptureStandaloneAcks = false;
+
+    Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
+    rm->ClearRetransTable(commandSender.mExchangeCtx.Get());
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 0;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+
+    err = commandSender.OnMessageReceived(commandSender.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf));
+    NL_TEST_ASSERT(apSuite, err == CHIP_IM_GLOBAL_STATUS(Busy));
+    NL_TEST_ASSERT(apSuite, mockCommandSenderDelegate.mError == CHIP_IM_GLOBAL_STATUS(Busy));
+    NL_TEST_ASSERT(apSuite,
+                   mockCommandSenderDelegate.onResponseCalledTimes == 0 && mockCommandSenderDelegate.onFinalCalledTimes == 1 &&
+                       mockCommandSenderDelegate.onErrorCalledTimes == 1);
+
+    ctx.DrainAndServiceIO();
+
+    // Client sent status report with invalid action, server's exchange has been closed, so all it sent is an MRP Ack
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    CheckForInvalidAction(apSuite, messageLog);
+    NL_TEST_ASSERT(apSuite, GetNumActiveHandlerObjects() == 0);
+    ctx.ExpireSessionAliceToBob();
+    ctx.ExpireSessionBobToAlice();
+    ctx.CreateSessionAliceToBob();
+    ctx.CreateSessionBobToAlice();
+}
+
+// Command Sender sends invoke request, command handler drops invoke response, then test injects unknown message to client,
+// client sends out status response with invalid action.
+void TestCommandInteraction::TestCommandInvalidMessage2(nlTestSuite * apSuite, void * apContext)
+{
+    TestContext & ctx = *static_cast<TestContext *>(apContext);
+    CHIP_ERROR err    = CHIP_NO_ERROR;
+    mockCommandSenderDelegate.ResetCounter();
+    app::CommandSender commandSender(&mockCommandSenderDelegate, &ctx.GetExchangeManager());
+
+    AddInvokeRequestData(apSuite, apContext, &commandSender);
+    asyncCommand = false;
+
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 1;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1;
+    err                                                 = commandSender.SendCommandRequest(ctx.GetSessionBobToAlice());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+    ctx.DrainAndServiceIO();
+
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1);
+
+    NL_TEST_ASSERT(apSuite,
+                   mockCommandSenderDelegate.onResponseCalledTimes == 0 && mockCommandSenderDelegate.onFinalCalledTimes == 0 &&
+                       mockCommandSenderDelegate.onErrorCalledTimes == 0);
+
+    NL_TEST_ASSERT(apSuite, GetNumActiveHandlerObjects() == 0);
+
+    System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes);
+    NL_TEST_ASSERT(apSuite, !msgBuf.IsNull());
+    System::PacketBufferTLVWriter writer;
+    writer.Init(std::move(msgBuf));
+    ReportDataMessage::Builder response;
+    response.Init(&writer);
+    NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR);
+
+    PayloadHeader payloadHeader;
+    payloadHeader.SetExchangeID(0);
+    payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::ReportData);
+    Test::MessageCapturer messageLog(ctx);
+    messageLog.mCaptureStandaloneAcks  = false;
+    Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
+    rm->ClearRetransTable(commandSender.mExchangeCtx.Get());
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 0;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+
+    err = commandSender.OnMessageReceived(commandSender.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf));
+    NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_INVALID_MESSAGE_TYPE);
+    NL_TEST_ASSERT(apSuite, mockCommandSenderDelegate.mError == CHIP_ERROR_INVALID_MESSAGE_TYPE);
+    NL_TEST_ASSERT(apSuite,
+                   mockCommandSenderDelegate.onResponseCalledTimes == 0 && mockCommandSenderDelegate.onFinalCalledTimes == 1 &&
+                       mockCommandSenderDelegate.onErrorCalledTimes == 1);
+
+    ctx.DrainAndServiceIO();
+
+    // Client sent status report with invalid action, server's exchange has been closed, so all it sent is an MRP Ack
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    CheckForInvalidAction(apSuite, messageLog);
+    NL_TEST_ASSERT(apSuite, GetNumActiveHandlerObjects() == 0);
+    ctx.ExpireSessionAliceToBob();
+    ctx.ExpireSessionBobToAlice();
+    ctx.CreateSessionAliceToBob();
+    ctx.CreateSessionBobToAlice();
+}
+
+// Command Sender sends invoke request, command handler drops invoke response, then test injects malformed invoke response
+// message to client, client sends out status response with invalid action.
+void TestCommandInteraction::TestCommandInvalidMessage3(nlTestSuite * apSuite, void * apContext)
+{
+    TestContext & ctx = *static_cast<TestContext *>(apContext);
+    CHIP_ERROR err    = CHIP_NO_ERROR;
+    mockCommandSenderDelegate.ResetCounter();
+    app::CommandSender commandSender(&mockCommandSenderDelegate, &ctx.GetExchangeManager());
+
+    AddInvokeRequestData(apSuite, apContext, &commandSender);
+    asyncCommand = false;
+
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 1;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1;
+    err                                                 = commandSender.SendCommandRequest(ctx.GetSessionBobToAlice());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+    ctx.DrainAndServiceIO();
+
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1);
+
+    NL_TEST_ASSERT(apSuite,
+                   mockCommandSenderDelegate.onResponseCalledTimes == 0 && mockCommandSenderDelegate.onFinalCalledTimes == 0 &&
+                       mockCommandSenderDelegate.onErrorCalledTimes == 0);
+
+    NL_TEST_ASSERT(apSuite, GetNumActiveHandlerObjects() == 0);
+
+    System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes);
+    NL_TEST_ASSERT(apSuite, !msgBuf.IsNull());
+    System::PacketBufferTLVWriter writer;
+    writer.Init(std::move(msgBuf));
+    InvokeResponseMessage::Builder response;
+    response.Init(&writer);
+    NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR);
+
+    PayloadHeader payloadHeader;
+    payloadHeader.SetExchangeID(0);
+    payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::InvokeCommandResponse);
+    Test::MessageCapturer messageLog(ctx);
+    messageLog.mCaptureStandaloneAcks = false;
+
+    Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
+    rm->ClearRetransTable(commandSender.mExchangeCtx.Get());
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 0;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+
+    err = commandSender.OnMessageReceived(commandSender.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf));
+    NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_END_OF_TLV);
+    NL_TEST_ASSERT(apSuite, mockCommandSenderDelegate.mError == CHIP_ERROR_END_OF_TLV);
+    NL_TEST_ASSERT(apSuite,
+                   mockCommandSenderDelegate.onResponseCalledTimes == 0 && mockCommandSenderDelegate.onFinalCalledTimes == 1 &&
+                       mockCommandSenderDelegate.onErrorCalledTimes == 1);
+
+    ctx.DrainAndServiceIO();
+
+    // Client sent status report with invalid action, server's exchange has been closed, so all it sent is an MRP Ack
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    CheckForInvalidAction(apSuite, messageLog);
+    NL_TEST_ASSERT(apSuite, GetNumActiveHandlerObjects() == 0);
+    ctx.ExpireSessionAliceToBob();
+    ctx.ExpireSessionBobToAlice();
+    ctx.CreateSessionAliceToBob();
+    ctx.CreateSessionBobToAlice();
+}
+
+// Command Sender sends invoke request, command handler drops invoke response, then test injects malformed status response to
+// client, client responds to the status response with invalid action.
+void TestCommandInteraction::TestCommandInvalidMessage4(nlTestSuite * apSuite, void * apContext)
+{
+    TestContext & ctx = *static_cast<TestContext *>(apContext);
+    CHIP_ERROR err    = CHIP_NO_ERROR;
+    mockCommandSenderDelegate.ResetCounter();
+    app::CommandSender commandSender(&mockCommandSenderDelegate, &ctx.GetExchangeManager());
+
+    AddInvokeRequestData(apSuite, apContext, &commandSender);
+    asyncCommand = false;
+
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 1;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1;
+    err                                                 = commandSender.SendCommandRequest(ctx.GetSessionBobToAlice());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+    ctx.DrainAndServiceIO();
+
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1);
+
+    NL_TEST_ASSERT(apSuite,
+                   mockCommandSenderDelegate.onResponseCalledTimes == 0 && mockCommandSenderDelegate.onFinalCalledTimes == 0 &&
+                       mockCommandSenderDelegate.onErrorCalledTimes == 0);
+
+    NL_TEST_ASSERT(apSuite, GetNumActiveHandlerObjects() == 0);
+
+    System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes);
+    NL_TEST_ASSERT(apSuite, !msgBuf.IsNull());
+    System::PacketBufferTLVWriter writer;
+    writer.Init(std::move(msgBuf));
+    StatusResponseMessage::Builder response;
+    response.Init(&writer);
+    NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR);
+
+    PayloadHeader payloadHeader;
+    payloadHeader.SetExchangeID(0);
+    payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::StatusResponse);
+    Test::MessageCapturer messageLog(ctx);
+    messageLog.mCaptureStandaloneAcks = false;
+
+    Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
+    rm->ClearRetransTable(commandSender.mExchangeCtx.Get());
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 0;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+
+    err = commandSender.OnMessageReceived(commandSender.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf));
+    NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_END_OF_TLV);
+    NL_TEST_ASSERT(apSuite, mockCommandSenderDelegate.mError == CHIP_ERROR_END_OF_TLV);
+    NL_TEST_ASSERT(apSuite,
+                   mockCommandSenderDelegate.onResponseCalledTimes == 0 && mockCommandSenderDelegate.onFinalCalledTimes == 1 &&
+                       mockCommandSenderDelegate.onErrorCalledTimes == 1);
+
+    ctx.DrainAndServiceIO();
+
+    // Client sent status report with invalid action, server's exchange has been closed, so all it sent is an MRP Ack
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    CheckForInvalidAction(apSuite, messageLog);
+    NL_TEST_ASSERT(apSuite, GetNumActiveHandlerObjects() == 0);
+    ctx.ExpireSessionAliceToBob();
+    ctx.ExpireSessionBobToAlice();
+    ctx.CreateSessionAliceToBob();
+    ctx.CreateSessionBobToAlice();
+}
+
 void TestCommandInteraction::TestCommandHandlerCommandEncodeExternalFailure(nlTestSuite * apSuite, void * apContext)
 {
     TestContext & ctx = *static_cast<TestContext *>(apContext);
@@ -959,6 +1254,10 @@
 // clang-format off
 const nlTest sTests[] =
 {
+    NL_TEST_DEF("TestCommandInvalidMessage1", chip::app::TestCommandInteraction::TestCommandInvalidMessage1),
+    NL_TEST_DEF("TestCommandInvalidMessage2", chip::app::TestCommandInteraction::TestCommandInvalidMessage2),
+    NL_TEST_DEF("TestCommandInvalidMessage3", chip::app::TestCommandInteraction::TestCommandInvalidMessage3),
+    NL_TEST_DEF("TestCommandInvalidMessage4", chip::app::TestCommandInteraction::TestCommandInvalidMessage4),
     NL_TEST_DEF("TestCommandSenderWithWrongState", chip::app::TestCommandInteraction::TestCommandSenderWithWrongState),
     NL_TEST_DEF("TestCommandHandlerWithWrongState", chip::app::TestCommandInteraction::TestCommandHandlerWithWrongState),
     NL_TEST_DEF("TestCommandSenderWithSendCommand", chip::app::TestCommandInteraction::TestCommandSenderWithSendCommand),