StandaloneAck handling: Untangle code path for UMH and StandaloneAck (#19443)
* StandaloneAck handling: Untangle code path for UMH and StandaloneAck
* Fix comments from Boris
* Resolve comments
diff --git a/src/messaging/EphemeralExchangeDispatch.h b/src/messaging/EphemeralExchangeDispatch.h
new file mode 100644
index 0000000..2f58587
--- /dev/null
+++ b/src/messaging/EphemeralExchangeDispatch.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) 2022 Project CHIP Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <messaging/ExchangeMessageDispatch.h>
+#include <protocols/secure_channel/Constants.h>
+
+namespace chip {
+namespace Messaging {
+
+class EphemeralExchangeDispatch : public ExchangeMessageDispatch
+{
+public:
+ static ExchangeMessageDispatch & Instance()
+ {
+ static EphemeralExchangeDispatch instance;
+ return instance;
+ }
+
+ EphemeralExchangeDispatch() {}
+ ~EphemeralExchangeDispatch() override {}
+
+protected:
+ bool MessagePermitted(Protocols::Id protocol, uint8_t type) override
+ {
+ // Only permit StandaloneAck
+ return (protocol == Protocols::SecureChannel::Id &&
+ type == to_underlying(Protocols::SecureChannel::MsgType::StandaloneAck));
+ }
+
+ bool IsEncryptionRequired() const override
+ {
+ // This function should not be called at all
+ VerifyOrDie(false);
+ return false;
+ }
+};
+
+} // namespace Messaging
+} // namespace chip
diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp
index cb67cdb..e5008e4 100644
--- a/src/messaging/ExchangeContext.cpp
+++ b/src/messaging/ExchangeContext.cpp
@@ -39,6 +39,7 @@
#include <lib/support/TypeTraits.h>
#include <lib/support/logging/CHIPLogging.h>
#include <messaging/ApplicationExchangeDispatch.h>
+#include <messaging/EphemeralExchangeDispatch.h>
#include <messaging/ExchangeContext.h>
#include <messaging/ExchangeMgr.h>
#include <protocols/Protocols.h>
@@ -290,8 +291,8 @@
}
ExchangeContext::ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, const SessionHandle & session, bool Initiator,
- ExchangeDelegate * delegate) :
- mDispatch((delegate != nullptr) ? delegate->GetMessageDispatch() : ApplicationExchangeDispatch::Instance()),
+ ExchangeDelegate * delegate, bool isEphemeralExchange) :
+ mDispatch(GetMessageDispatch(isEphemeralExchange, delegate)),
mSession(*this)
{
VerifyOrDie(mExchangeMgr == nullptr);
@@ -300,6 +301,7 @@
mExchangeId = ExchangeId;
mSession.Grab(session);
mFlags.Set(Flags::kFlagInitiator, Initiator);
+ mFlags.Set(Flags::kFlagEphemeralExchange, isEphemeralExchange);
mDelegate = delegate;
SetAckPending(false);
@@ -494,6 +496,12 @@
return CHIP_NO_ERROR;
}
+ if (IsEphemeralExchange())
+ {
+ // The EphemeralExchange has done its job, since StandaloneAck is sent in previous FlushAcks() call.
+ return CHIP_NO_ERROR;
+ }
+
// Since we got the response, cancel the response timer.
CancelResponseTimer();
@@ -527,5 +535,16 @@
Close();
}
+ExchangeMessageDispatch & ExchangeContext::GetMessageDispatch(bool isEphemeralExchange, ExchangeDelegate * delegate)
+{
+ if (isEphemeralExchange)
+ return EphemeralExchangeDispatch::Instance();
+
+ if (delegate != nullptr)
+ return delegate->GetMessageDispatch();
+
+ return ApplicationExchangeDispatch::Instance();
+}
+
} // namespace Messaging
} // namespace chip
diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h
index d1a8c94..db3d4e1 100644
--- a/src/messaging/ExchangeContext.h
+++ b/src/messaging/ExchangeContext.h
@@ -66,7 +66,7 @@
typedef System::Clock::Timeout Timeout; // Type used to express the timeout in this ExchangeContext
ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, const SessionHandle & session, bool Initiator,
- ExchangeDelegate * delegate);
+ ExchangeDelegate * delegate, bool isEphemeralExchange = false);
~ExchangeContext() override;
@@ -154,8 +154,6 @@
ReliableMessageContext * GetReliableMessageContext() { return static_cast<ReliableMessageContext *>(this); };
- ExchangeMessageDispatch & GetMessageDispatch() { return mDispatch; }
-
SessionHandle GetSessionHandle() const
{
VerifyOrDie(mSession);
@@ -273,6 +271,8 @@
* exchange nor other component requests the active mode.
*/
void UpdateSEDIntervalMode(bool activeMode);
+
+ static ExchangeMessageDispatch & GetMessageDispatch(bool isEphemeralExchange, ExchangeDelegate * delegate);
};
} // namespace Messaging
diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp
index 1a7a3ca..2c1242a 100644
--- a/src/messaging/ExchangeMgr.cpp
+++ b/src/messaging/ExchangeMgr.cpp
@@ -270,72 +270,94 @@
return;
}
- // If we found a handler or we need to send an ack, create an exchange to
- // handle the message.
- if (matchingUMH != nullptr || payloadHeader.NeedsAck())
+ // If we found a handler, create an exchange to handle the message.
+ if (matchingUMH != nullptr)
{
ExchangeDelegate * delegate = nullptr;
// Fetch delegate from the handler
- if (matchingUMH != nullptr)
+ CHIP_ERROR err = matchingUMH->Handler->OnUnsolicitedMessageReceived(payloadHeader, delegate);
+ if (err != CHIP_NO_ERROR)
{
- CHIP_ERROR err = matchingUMH->Handler->OnUnsolicitedMessageReceived(payloadHeader, delegate);
- if (err != CHIP_NO_ERROR)
- {
- // Using same error message for all errors to reduce code size.
- ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(err));
- return;
- }
+ // Using same error message for all errors to reduce code size.
+ ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(err));
+ SendStandaloneAckIfNeeded(packetHeader, payloadHeader, session, msgFlags, std::move(msgBuf));
+ return;
}
- // If rcvd msg is from initiator then this exchange is created as not Initiator.
- // If rcvd msg is not from initiator then this exchange is created as Initiator.
- // Note that if matchingUMH is not null then rcvd msg if from initiator.
- // TODO: Figure out which channel to use for the received message
- ExchangeContext * ec =
- mContextPool.CreateObject(this, payloadHeader.GetExchangeID(), session, !payloadHeader.IsInitiator(), delegate);
+ ExchangeContext * ec = mContextPool.CreateObject(this, payloadHeader.GetExchangeID(), session, false, delegate);
if (ec == nullptr)
{
- if (matchingUMH != nullptr && delegate != nullptr)
+ if (delegate != nullptr)
{
matchingUMH->Handler->OnExchangeCreationFailed(delegate);
}
// Using same error message for all errors to reduce code size.
ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(CHIP_ERROR_NO_MEMORY));
+ // No resource for creating new exchange, SendStandaloneAckIfNeeded probably also fails, so do not try it here
return;
}
ChipLogDetail(ExchangeManager, "Handling via exchange: " ChipLogFormatExchange ", Delegate: %p", ChipLogValueExchange(ec),
ec->GetDelegate());
- // Make sure the exchange stays alive through the code below even if we
- // close it before calling HandleMessage.
- ExchangeHandle ref(*ec);
-
- // Ignore encryption-required mismatches for emphemeral exchanges,
- // because those never have delegates anyway.
- if (matchingUMH != nullptr && ec->IsEncryptionRequired() != packetHeader.IsEncrypted())
+ if (ec->IsEncryptionRequired() != packetHeader.IsEncrypted())
{
- // We want to still to do MRP processing for this message, but we do
- // not want to deliver it to the application. Just close the
- // exchange (which will notify the delegate, null it out, etc), then
- // go ahead and call HandleMessage() on it to do the MRP
- // processing.null out the delegate on the exchange, pretend to
- // matchingUMH that exchange creation failed, so it cleans up the
- // delegate, then tell the exchagne to handle the message.
- ChipLogProgress(ExchangeManager, "OnMessageReceived encryption mismatch");
+ ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(CHIP_ERROR_INVALID_MESSAGE_TYPE));
ec->Close();
+ SendStandaloneAckIfNeeded(packetHeader, payloadHeader, session, msgFlags, std::move(msgBuf));
+ return;
}
- CHIP_ERROR err = ec->HandleMessage(packetHeader.GetMessageCounter(), payloadHeader, msgFlags, std::move(msgBuf));
+ err = ec->HandleMessage(packetHeader.GetMessageCounter(), payloadHeader, msgFlags, std::move(msgBuf));
if (err != CHIP_NO_ERROR)
{
// Using same error message for all errors to reduce code size.
ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(err));
}
+ return;
}
+
+ SendStandaloneAckIfNeeded(packetHeader, payloadHeader, session, msgFlags, std::move(msgBuf));
+ return;
+}
+
+void ExchangeManager::SendStandaloneAckIfNeeded(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader,
+ const SessionHandle & session, MessageFlags msgFlags,
+ System::PacketBufferHandle && msgBuf)
+{
+ // If we need to send a StandaloneAck, create a EphemeralExchange for the purpose to send the StandaloneAck
+ if (!payloadHeader.NeedsAck())
+ return;
+
+ // If rcvd msg is from initiator then this exchange is created as not Initiator.
+ // If rcvd msg is not from initiator then this exchange is created as Initiator.
+ // Create a EphemeralExchange to generate a StandaloneAck
+ ExchangeContext * ec = mContextPool.CreateObject(this, payloadHeader.GetExchangeID(), session, !payloadHeader.IsInitiator(),
+ nullptr, true /* IsEphemeralExchange */);
+
+ if (ec == nullptr)
+ {
+ // Using same error message for all errors to reduce code size.
+ ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(CHIP_ERROR_NO_MEMORY));
+ return;
+ }
+
+ ChipLogDetail(ExchangeManager, "Generating StandaloneAck via exchange: " ChipLogFormatExchange, ChipLogValueExchange(ec));
+
+ // No need to verify packet encryption type, the EphemeralExchange can handle both secure and insecure messages.
+
+ CHIP_ERROR err = ec->HandleMessage(packetHeader.GetMessageCounter(), payloadHeader, msgFlags, std::move(msgBuf));
+ if (err != CHIP_NO_ERROR)
+ {
+ // Using same error message for all errors to reduce code size.
+ ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(err));
+ }
+
+ // The exchange should be closed inside HandleMessage function. So don't bother close it here.
+ return;
}
void ExchangeManager::CloseAllContextsForDelegate(const ExchangeDelegate * delegate)
diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h
index 1c0a783..d6b7754 100644
--- a/src/messaging/ExchangeMgr.h
+++ b/src/messaging/ExchangeMgr.h
@@ -240,6 +240,8 @@
void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const SessionHandle & session,
DuplicateMessage isDuplicate, System::PacketBufferHandle && msgBuf) override;
+ void SendStandaloneAckIfNeeded(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader,
+ const SessionHandle & session, MessageFlags msgFlags, System::PacketBufferHandle && msgBuf);
};
} // namespace Messaging
diff --git a/src/messaging/ReliableMessageContext.h b/src/messaging/ReliableMessageContext.h
index 6ebc822..e67b2d5 100644
--- a/src/messaging/ReliableMessageContext.h
+++ b/src/messaging/ReliableMessageContext.h
@@ -121,6 +121,9 @@
/// Determine whether this exchange is requesting Sleepy End Device active mode
bool IsRequestingActiveMode() const;
+ /// Determine whether this exchange is a EphemeralExchange for replying a StandaloneAck
+ bool IsEphemeralExchange() const;
+
/**
* Get the reliable message manager that corresponds to this reliable
* message context.
@@ -158,6 +161,9 @@
/// When set, signifies that the exchange is requesting Sleepy End Device active mode.
kFlagActiveMode = (1u << 8),
+
+ /// When set, signifies that the exchange created sorely for replying a StandaloneAck
+ kFlagEphemeralExchange = (1u << 9),
};
BitFlags<Flags> mFlags; // Internal state flags
@@ -234,5 +240,10 @@
mFlags.Set(Flags::kFlagActiveMode, activeMode);
}
+inline bool ReliableMessageContext::IsEphemeralExchange() const
+{
+ return mFlags.Has(Flags::kFlagEphemeralExchange);
+}
+
} // namespace Messaging
} // namespace chip