Fix status reponse in write client (#21511)

diff --git a/src/app/WriteClient.cpp b/src/app/WriteClient.cpp
index 85af809..bbd4d28 100644
--- a/src/app/WriteClient.cpp
+++ b/src/app/WriteClient.cpp
@@ -423,22 +423,30 @@
         MoveToState(State::ResponseReceived);
     }
 
-    CHIP_ERROR err = CHIP_NO_ERROR;
-
+    CHIP_ERROR err          = CHIP_NO_ERROR;
+    bool sendStatusResponse = false;
     // Assert that the exchange context matches the client's current context.
     // This should never fail because even if SendWriteRequest is called
     // back-to-back, the second call will call Close() on the first exchange,
     // which clears the OnMessageReceived callback.
     VerifyOrExit(apExchangeContext == mExchangeCtx.Get(), err = CHIP_ERROR_INCORRECT_STATE);
 
+    sendStatusResponse = true;
+
     if (mState == State::AwaitingTimedStatus)
     {
-        VerifyOrExit(aPayloadHeader.HasMessageType(MsgType::StatusResponse), err = CHIP_ERROR_INVALID_MESSAGE_TYPE);
-        CHIP_ERROR statusError = CHIP_NO_ERROR;
-        SuccessOrExit(err = StatusResponse::ProcessStatusResponse(std::move(aPayload), statusError));
-        SuccessOrExit(err = statusError);
-        err = SendWriteRequest();
-
+        if (aPayloadHeader.HasMessageType(MsgType::StatusResponse))
+        {
+            CHIP_ERROR statusError = CHIP_NO_ERROR;
+            SuccessOrExit(err = StatusResponse::ProcessStatusResponse(std::move(aPayload), statusError));
+            sendStatusResponse = false;
+            SuccessOrExit(err = statusError);
+            err = SendWriteRequest();
+        }
+        else
+        {
+            err = CHIP_ERROR_INVALID_MESSAGE_TYPE;
+        }
         // Skip all other processing here (which is for the response to the
         // write request), no matter whether err is success or not.
         goto exit;
@@ -448,6 +456,7 @@
     {
         err = ProcessWriteResponseMessage(std::move(aPayload));
         SuccessOrExit(err);
+        sendStatusResponse = false;
         if (!mChunks.IsNull())
         {
             // Send the next chunk.
@@ -475,6 +484,11 @@
         }
     }
 
+    if (sendStatusResponse)
+    {
+        StatusResponse::Send(Status::InvalidAction, apExchangeContext, false /*aExpectResponse*/);
+    }
+
     if (mState != State::AwaitingResponse)
     {
         Close();
diff --git a/src/app/tests/TestWriteInteraction.cpp b/src/app/tests/TestWriteInteraction.cpp
index e014994..dcd13e5 100644
--- a/src/app/tests/TestWriteInteraction.cpp
+++ b/src/app/tests/TestWriteInteraction.cpp
@@ -60,6 +60,10 @@
     static void TestWriteClientGroup(nlTestSuite * apSuite, void * apContext);
     static void TestWriteHandler(nlTestSuite * apSuite, void * apContext);
     static void TestWriteRoundtrip(nlTestSuite * apSuite, void * apContext);
+    static void TestWriteInvalidMessage1(nlTestSuite * apSuite, void * apContext);
+    static void TestWriteInvalidMessage2(nlTestSuite * apSuite, void * apContext);
+    static void TestWriteInvalidMessage3(nlTestSuite * apSuite, void * apContext);
+    static void TestWriteInvalidMessage4(nlTestSuite * apSuite, void * apContext);
     static void TestWriteRoundtripWithClusterObjects(nlTestSuite * apSuite, void * apContext);
     static void TestWriteRoundtripWithClusterObjectsVersionMatch(nlTestSuite * apSuite, void * apContext);
     static void TestWriteRoundtripWithClusterObjectsVersionMismatch(nlTestSuite * apSuite, void * apContext);
@@ -98,6 +102,7 @@
     {
         mOnErrorCalled++;
         mLastErrorReason = app::StatusIB(chipError);
+        mError           = chipError;
     }
     void OnDone(WriteClient * apWriteClient) override { mOnDoneCalled++; }
 
@@ -106,6 +111,7 @@
     int mOnDoneCalled    = 0;
     StatusIB mStatus;
     StatusIB mLastErrorReason;
+    CHIP_ERROR mError = CHIP_NO_ERROR;
 };
 
 void TestWriteInteraction::AddAttributeDataIB(nlTestSuite * apSuite, void * apContext, WriteClient & aWriteClient)
@@ -621,6 +627,275 @@
     ctx.CreateSessionBobToAlice();
 }
 #endif
+
+// Write Client sends a write request, receives an unexpected message type, sends a status response to that.
+void TestWriteInteraction::TestWriteInvalidMessage1(nlTestSuite * apSuite, void * apContext)
+{
+    TestContext & ctx = *static_cast<TestContext *>(apContext);
+
+    CHIP_ERROR err = CHIP_NO_ERROR;
+
+    Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
+    // Shouldn't have anything in the retransmit table when starting the test.
+    NL_TEST_ASSERT(apSuite, rm->TestGetCountRetransTable() == 0);
+
+    TestWriteClientCallback callback;
+    auto * engine = chip::app::InteractionModelEngine::GetInstance();
+    err           = engine->Init(&ctx.GetExchangeManager(), &ctx.GetFabricTable());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+
+    app::WriteClient writeClient(engine->GetExchangeManager(), &callback, Optional<uint16_t>::Missing());
+
+    System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize);
+    AddAttributeDataIB(apSuite, apContext, writeClient);
+
+    NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 0 && callback.mOnDoneCalled == 0);
+
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 1;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+    err                                                 = writeClient.SendWriteRequest(ctx.GetSessionBobToAlice());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+    ctx.DrainAndServiceIO();
+
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1);
+
+    System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes);
+    NL_TEST_ASSERT(apSuite, !msgBuf.IsNull());
+    System::PacketBufferTLVWriter writer;
+    writer.Init(std::move(msgBuf));
+    ReportDataMessage::Builder response;
+    response.Init(&writer);
+    NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR);
+    PayloadHeader payloadHeader;
+    payloadHeader.SetExchangeID(0);
+    payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::ReportData);
+
+    rm->ClearRetransTable(writeClient.mExchangeCtx.Get());
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 0;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+    err = writeClient.OnMessageReceived(writeClient.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf));
+    NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_INVALID_MESSAGE_TYPE);
+    ctx.DrainAndServiceIO();
+    NL_TEST_ASSERT(apSuite, callback.mError == CHIP_ERROR_INVALID_MESSAGE_TYPE);
+    NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 1 && callback.mOnDoneCalled == 1);
+
+    // TODO: Check that the server gets the right status.
+    // Client sents status report with invalid action, server's exchange has been closed, so all it sends is an MRP Ack
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+
+    engine->Shutdown();
+    ctx.ExpireSessionAliceToBob();
+    ctx.ExpireSessionBobToAlice();
+    ctx.CreateSessionAliceToBob();
+    ctx.CreateSessionBobToAlice();
+}
+
+// Write Client sends a write request, receives a malformed write response message, sends a Status Report.
+void TestWriteInteraction::TestWriteInvalidMessage2(nlTestSuite * apSuite, void * apContext)
+{
+    TestContext & ctx = *static_cast<TestContext *>(apContext);
+
+    CHIP_ERROR err = CHIP_NO_ERROR;
+
+    Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
+    // Shouldn't have anything in the retransmit table when starting the test.
+    NL_TEST_ASSERT(apSuite, rm->TestGetCountRetransTable() == 0);
+
+    TestWriteClientCallback callback;
+    auto * engine = chip::app::InteractionModelEngine::GetInstance();
+    err           = engine->Init(&ctx.GetExchangeManager(), &ctx.GetFabricTable());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+
+    app::WriteClient writeClient(engine->GetExchangeManager(), &callback, Optional<uint16_t>::Missing());
+
+    System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize);
+    AddAttributeDataIB(apSuite, apContext, writeClient);
+
+    NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 0 && callback.mOnDoneCalled == 0);
+
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 1;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+    err                                                 = writeClient.SendWriteRequest(ctx.GetSessionBobToAlice());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+    ctx.DrainAndServiceIO();
+
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1);
+
+    System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes);
+    NL_TEST_ASSERT(apSuite, !msgBuf.IsNull());
+    System::PacketBufferTLVWriter writer;
+    writer.Init(std::move(msgBuf));
+    WriteResponseMessage::Builder response;
+    response.Init(&writer);
+    NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR);
+    PayloadHeader payloadHeader;
+    payloadHeader.SetExchangeID(0);
+    payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::WriteResponse);
+
+    rm->ClearRetransTable(writeClient.mExchangeCtx.Get());
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 0;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+    err = writeClient.OnMessageReceived(writeClient.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf));
+    NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_END_OF_TLV);
+    ctx.DrainAndServiceIO();
+    NL_TEST_ASSERT(apSuite, callback.mError == CHIP_ERROR_END_OF_TLV);
+    NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 1 && callback.mOnDoneCalled == 1);
+
+    // Client sents status report with invalid action, server's exchange has been closed, so all it sends is an MRP Ack
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+
+    engine->Shutdown();
+    ctx.ExpireSessionAliceToBob();
+    ctx.ExpireSessionBobToAlice();
+    ctx.CreateSessionAliceToBob();
+    ctx.CreateSessionBobToAlice();
+}
+
+// Write Client sends a write request, receives a malformed status response message.
+void TestWriteInteraction::TestWriteInvalidMessage3(nlTestSuite * apSuite, void * apContext)
+{
+    TestContext & ctx = *static_cast<TestContext *>(apContext);
+
+    CHIP_ERROR err = CHIP_NO_ERROR;
+
+    Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
+    // Shouldn't have anything in the retransmit table when starting the test.
+    NL_TEST_ASSERT(apSuite, rm->TestGetCountRetransTable() == 0);
+
+    TestWriteClientCallback callback;
+    auto * engine = chip::app::InteractionModelEngine::GetInstance();
+    err           = engine->Init(&ctx.GetExchangeManager(), &ctx.GetFabricTable());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+
+    app::WriteClient writeClient(engine->GetExchangeManager(), &callback, Optional<uint16_t>::Missing());
+
+    System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize);
+    AddAttributeDataIB(apSuite, apContext, writeClient);
+
+    NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 0 && callback.mOnDoneCalled == 0);
+
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 1;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+    err                                                 = writeClient.SendWriteRequest(ctx.GetSessionBobToAlice());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+    ctx.DrainAndServiceIO();
+
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1);
+
+    System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes);
+    NL_TEST_ASSERT(apSuite, !msgBuf.IsNull());
+    System::PacketBufferTLVWriter writer;
+    writer.Init(std::move(msgBuf));
+    StatusResponseMessage::Builder response;
+    response.Init(&writer);
+    NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR);
+    PayloadHeader payloadHeader;
+    payloadHeader.SetExchangeID(0);
+    payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::StatusResponse);
+
+    rm->ClearRetransTable(writeClient.mExchangeCtx.Get());
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 0;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+    err = writeClient.OnMessageReceived(writeClient.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf));
+    NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_END_OF_TLV);
+    ctx.DrainAndServiceIO();
+    NL_TEST_ASSERT(apSuite, callback.mError == CHIP_ERROR_END_OF_TLV);
+    NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 1 && callback.mOnDoneCalled == 1);
+
+    // TODO: Check that the server gets the right status
+    // Client sents status report with invalid action, server's exchange has been closed, so all it sends is an MRP ack.
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+
+    engine->Shutdown();
+    ctx.ExpireSessionAliceToBob();
+    ctx.ExpireSessionBobToAlice();
+    ctx.CreateSessionAliceToBob();
+    ctx.CreateSessionBobToAlice();
+}
+
+// Write Client sends a write request, receives a busy status response message.
+void TestWriteInteraction::TestWriteInvalidMessage4(nlTestSuite * apSuite, void * apContext)
+{
+    TestContext & ctx = *static_cast<TestContext *>(apContext);
+
+    CHIP_ERROR err = CHIP_NO_ERROR;
+
+    Messaging::ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
+    // Shouldn't have anything in the retransmit table when starting the test.
+    NL_TEST_ASSERT(apSuite, rm->TestGetCountRetransTable() == 0);
+
+    TestWriteClientCallback callback;
+    auto * engine = chip::app::InteractionModelEngine::GetInstance();
+    err           = engine->Init(&ctx.GetExchangeManager(), &ctx.GetFabricTable());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+
+    app::WriteClient writeClient(engine->GetExchangeManager(), &callback, Optional<uint16_t>::Missing());
+
+    System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize);
+    AddAttributeDataIB(apSuite, apContext, writeClient);
+
+    NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 0 && callback.mOnDoneCalled == 0);
+
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 1;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 1;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+    err                                                 = writeClient.SendWriteRequest(ctx.GetSessionBobToAlice());
+    NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+    ctx.DrainAndServiceIO();
+
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1);
+
+    System::PacketBufferHandle msgBuf = System::PacketBufferHandle::New(kMaxSecureSduLengthBytes);
+    NL_TEST_ASSERT(apSuite, !msgBuf.IsNull());
+    System::PacketBufferTLVWriter writer;
+    writer.Init(std::move(msgBuf));
+    StatusResponseMessage::Builder response;
+    response.Init(&writer);
+    response.Status(Protocols::InteractionModel::Status::Busy);
+    NL_TEST_ASSERT(apSuite, writer.Finalize(&msgBuf) == CHIP_NO_ERROR);
+    PayloadHeader payloadHeader;
+    payloadHeader.SetExchangeID(0);
+    payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::StatusResponse);
+
+    rm->ClearRetransTable(writeClient.mExchangeCtx.Get());
+    ctx.GetLoopback().mSentMessageCount                 = 0;
+    ctx.GetLoopback().mNumMessagesToDrop                = 0;
+    ctx.GetLoopback().mNumMessagesToAllowBeforeDropping = 0;
+    ctx.GetLoopback().mDroppedMessageCount              = 0;
+    err = writeClient.OnMessageReceived(writeClient.mExchangeCtx.Get(), payloadHeader, std::move(msgBuf));
+    NL_TEST_ASSERT(apSuite, err == CHIP_IM_GLOBAL_STATUS(Busy));
+    ctx.DrainAndServiceIO();
+    NL_TEST_ASSERT(apSuite, callback.mError == CHIP_IM_GLOBAL_STATUS(Busy));
+    NL_TEST_ASSERT(apSuite, callback.mOnSuccessCalled == 0 && callback.mOnErrorCalled == 1 && callback.mOnDoneCalled == 1);
+
+    // TODO: Check that the server gets the right status..
+    // Client sents status report with invalid action, server's exchange has been closed, so it just sends an MRP ack.
+    NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
+
+    engine->Shutdown();
+    ctx.ExpireSessionAliceToBob();
+    ctx.ExpireSessionBobToAlice();
+    ctx.CreateSessionAliceToBob();
+    ctx.CreateSessionBobToAlice();
+}
+
 } // namespace app
 } // namespace chip
 
@@ -643,6 +918,10 @@
 #if CONFIG_BUILD_FOR_HOST_UNIT_TEST
         NL_TEST_DEF("TestWriteHandlerReceiveInvalidMessage", chip::app::TestWriteInteraction::TestWriteHandlerReceiveInvalidMessage),
 #endif
+        NL_TEST_DEF("TestWriteInvalidMessage1", chip::app::TestWriteInteraction::TestWriteInvalidMessage1),
+        NL_TEST_DEF("TestWriteInvalidMessage2", chip::app::TestWriteInteraction::TestWriteInvalidMessage2),
+        NL_TEST_DEF("TestWriteInvalidMessage3", chip::app::TestWriteInteraction::TestWriteInvalidMessage3),
+        NL_TEST_DEF("TestWriteInvalidMessage4", chip::app::TestWriteInteraction::TestWriteInvalidMessage4),
         NL_TEST_SENTINEL()
 };
 // clang-format on