Refactor SessionHandle (#13330)
* Fix EPS32 build warning
* app/ReadPrepareParams: store SessionHolder instead of SessionHandle
* Refactor SessionHandle
* Fix GetFabricIndex/GetPeerNodeId calls in general commissioning server
* Fix usage of . instead of -> (tested with esp32 compile)
* Remove the define of global for nxp compilation: conflicts with STL usage and does not seem to be utilized at least for matter builds
* Restyle fixes
* Fix getting peer address compilation for exchange context
Co-authored-by: Andrei Litvin <andy314@gmail.com>
diff --git a/examples/chip-tool/commands/common/CommandInvoker.h b/examples/chip-tool/commands/common/CommandInvoker.h
index 6583cc9..e51f141 100644
--- a/examples/chip-tool/commands/common/CommandInvoker.h
+++ b/examples/chip-tool/commands/common/CommandInvoker.h
@@ -100,33 +100,22 @@
return CHIP_NO_ERROR;
}
- CHIP_ERROR InvokeGroupCommand(DeviceProxy * aDevice, GroupId groupId, const RequestType & aRequestData)
+ CHIP_ERROR InvokeGroupCommand(Messaging::ExchangeManager * exchangeManager, GroupId groupId, const RequestType & aRequestData)
{
app::CommandPathParams commandPath = { 0 /* endpoint */, groupId, RequestType::GetClusterId(), RequestType::GetCommandId(),
(app::CommandPathFlags::kGroupIdValid) };
- auto commandSender = Platform::MakeUnique<app::CommandSender>(this, aDevice->GetExchangeManager());
+ auto commandSender = Platform::MakeUnique<app::CommandSender>(this, exchangeManager);
VerifyOrReturnError(commandSender != nullptr, CHIP_ERROR_NO_MEMORY);
ReturnErrorOnFailure(commandSender->AddRequestData(commandPath, aRequestData));
- if (aDevice->GetSecureSession().HasValue())
+ Optional<SessionHandle> session = exchangeManager->GetSessionManager()->CreateGroupSession(groupId);
+ if (!session.HasValue())
{
- SessionHandle session = aDevice->GetSecureSession().Value();
- session.SetGroupId(groupId);
-
- if (!session.IsGroupSession())
- {
- return CHIP_ERROR_INCORRECT_STATE;
- }
-
- ReturnErrorOnFailure(commandSender->SendCommandRequest(session));
+ return CHIP_ERROR_NO_MEMORY;
}
- else
- {
- // something fishy is going on
- return CHIP_ERROR_INCORRECT_STATE;
- }
+ ReturnErrorOnFailure(commandSender->SendCommandRequest(session));
commandSender.release();
return CHIP_NO_ERROR;
diff --git a/examples/ota-provider-app/ota-provider-common/OTAProviderExample.cpp b/examples/ota-provider-app/ota-provider-common/OTAProviderExample.cpp
index 968bcea..43a9d06 100644
--- a/examples/ota-provider-app/ota-provider-common/OTAProviderExample.cpp
+++ b/examples/ota-provider-app/ota-provider-common/OTAProviderExample.cpp
@@ -114,7 +114,7 @@
{
// TODO: This uses the current node as the provider to supply the OTA image. This can be configurable such that the provider
// supplying the response is not the provider supplying the OTA image.
- FabricIndex fabricIndex = commandObj->GetExchangeContext()->GetSessionHandle().GetFabricIndex();
+ FabricIndex fabricIndex = commandObj->GetExchangeContext()->GetSessionHandle()->AsSecureSession()->GetFabricIndex();
FabricInfo * fabricInfo = Server::GetInstance().GetFabricTable().FindFabricWithIndex(fabricIndex);
NodeId nodeId = fabricInfo->GetPeerId().GetNodeId();
diff --git a/src/app/CASESessionManager.cpp b/src/app/CASESessionManager.cpp
index 5d3ed0e..08d5d42 100644
--- a/src/app/CASESessionManager.cpp
+++ b/src/app/CASESessionManager.cpp
@@ -109,14 +109,6 @@
return CHIP_NO_ERROR;
}
-void CASESessionManager::OnSessionReleased(const SessionHandle & sessionHandle)
-{
- OperationalDeviceProxy * session = FindSession(sessionHandle);
- VerifyOrReturn(session != nullptr, ChipLogDetail(Controller, "OnSessionReleased was called for unknown device, ignoring it."));
-
- session->OnSessionReleased(sessionHandle);
-}
-
OperationalDeviceProxy * CASESessionManager::FindSession(const SessionHandle & session)
{
return mConfig.devicePool->FindDevice(session);
diff --git a/src/app/CASESessionManager.h b/src/app/CASESessionManager.h
index 382fcd4..f22bf99 100644
--- a/src/app/CASESessionManager.h
+++ b/src/app/CASESessionManager.h
@@ -48,7 +48,7 @@
* 4. During session establishment, trigger node ID resolution (if needed), and update the DNS-SD cache (if resolution is
* successful)
*/
-class CASESessionManager : public SessionReleaseDelegate, public Dnssd::ResolverDelegate
+class CASESessionManager : public Dnssd::ResolverDelegate
{
public:
CASESessionManager() = delete;
@@ -105,9 +105,6 @@
*/
CHIP_ERROR GetPeerAddress(PeerId peerId, Transport::PeerAddress & addr);
- //////////// SessionReleaseDelegate Implementation ///////////////
- void OnSessionReleased(const SessionHandle & session) override;
-
//////////// ResolverDelegate Implementation ///////////////
void OnNodeIdResolved(const Dnssd::ResolvedNodeData & nodeData) override;
void OnNodeIdResolutionFailed(const PeerId & peerId, CHIP_ERROR error) override;
diff --git a/src/app/CommandHandler.cpp b/src/app/CommandHandler.cpp
index ce31bfe..7e41b30 100644
--- a/src/app/CommandHandler.cpp
+++ b/src/app/CommandHandler.cpp
@@ -436,7 +436,7 @@
FabricIndex CommandHandler::GetAccessingFabricIndex() const
{
- return mpExchangeCtx->GetSessionHandle().GetFabricIndex();
+ return mpExchangeCtx->GetSessionHandle()->AsSecureSession()->GetFabricIndex();
}
CommandHandler * CommandHandler::Handle::Get()
diff --git a/src/app/InteractionModelEngine.cpp b/src/app/InteractionModelEngine.cpp
index eae1ec8..09d4bc0 100644
--- a/src/app/InteractionModelEngine.cpp
+++ b/src/app/InteractionModelEngine.cpp
@@ -269,8 +269,8 @@
for (auto & readHandler : mReadHandlers)
{
if (!readHandler.IsFree() && readHandler.IsSubscriptionType() &&
- readHandler.GetInitiatorNodeId() == apExchangeContext->GetSessionHandle().GetPeerNodeId() &&
- readHandler.GetAccessingFabricIndex() == apExchangeContext->GetSessionHandle().GetFabricIndex())
+ readHandler.GetInitiatorNodeId() == apExchangeContext->GetSessionHandle()->AsSecureSession()->GetPeerNodeId() &&
+ readHandler.GetAccessingFabricIndex() == apExchangeContext->GetSessionHandle()->AsSecureSession()->GetFabricIndex())
{
bool keepSubscriptions = true;
System::PacketBufferTLVReader reader;
diff --git a/src/app/InteractionModelEngine.h b/src/app/InteractionModelEngine.h
index a306706..04becfc 100644
--- a/src/app/InteractionModelEngine.h
+++ b/src/app/InteractionModelEngine.h
@@ -308,7 +308,7 @@
*
* @param[in] aSubjectDescriptor The subject descriptor for the read.
* @param[in] aPath The concrete path of the data being read.
- * @param[in] aAttributeReport The TLV Builder for Cluter attribute builder.
+ * @param[in] aAttributeReports The TLV Builder for Cluter attribute builder.
*
* @retval CHIP_NO_ERROR on success
*/
diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp
index 9cf2234..0164e5c 100644
--- a/src/app/OperationalDeviceProxy.cpp
+++ b/src/app/OperationalDeviceProxy.cpp
@@ -127,11 +127,7 @@
return CHIP_NO_ERROR;
}
- Transport::SecureSession * secureSession = mInitParams.sessionManager->GetSecureSession(mSecureSession.Get());
- if (secureSession != nullptr)
- {
- secureSession->SetPeerAddress(addr);
- }
+ mSecureSession.Get()->AsSecureSession()->SetPeerAddress(addr);
}
return err;
@@ -262,7 +258,7 @@
return CHIP_NO_ERROR;
}
-void OperationalDeviceProxy::SetConnectedSession(SessionHandle handle)
+void OperationalDeviceProxy::SetConnectedSession(const SessionHandle & handle)
{
mSecureSession.Grab(handle);
mState = State::SecureConnected;
@@ -296,12 +292,9 @@
mSystemLayer->ScheduleWork(CloseCASESessionTask, this);
}
-void OperationalDeviceProxy::OnSessionReleased(const SessionHandle & session)
+void OperationalDeviceProxy::OnSessionReleased()
{
- VerifyOrReturn(mSecureSession.Contains(session),
- ChipLogDetail(Controller, "Connection expired, but it doesn't match the current session"));
mState = State::Initialized;
- mSecureSession.Release();
}
CHIP_ERROR OperationalDeviceProxy::ShutdownSubscriptions()
diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h
index c56a924..91dd127 100644
--- a/src/app/OperationalDeviceProxy.h
+++ b/src/app/OperationalDeviceProxy.h
@@ -80,7 +80,7 @@
{
public:
virtual ~OperationalDeviceProxy();
- OperationalDeviceProxy(DeviceProxyInitParams & params, PeerId peerId)
+ OperationalDeviceProxy(DeviceProxyInitParams & params, PeerId peerId) : mSecureSession(*this)
{
VerifyOrReturn(params.Validate() == CHIP_NO_ERROR);
@@ -124,7 +124,7 @@
* Called when a connection is closing.
* The object releases all resources associated with the connection.
*/
- void OnSessionReleased(const SessionHandle & session) override;
+ void OnSessionReleased() override;
void OnNodeIdResolved(const Dnssd::ResolvedNodeData & nodeResolutionData)
{
@@ -150,7 +150,7 @@
*
* Note: Avoid using this function generally as it is Deprecated
*/
- void SetConnectedSession(SessionHandle handle);
+ void SetConnectedSession(const SessionHandle & handle);
NodeId GetDeviceId() const override { return mPeerId.GetNodeId(); }
@@ -219,7 +219,7 @@
State mState = State::Uninitialized;
- SessionHolder mSecureSession;
+ SessionHolderWithDelegate mSecureSession;
uint8_t mSequenceNumber = 0;
diff --git a/src/app/ReadClient.cpp b/src/app/ReadClient.cpp
index c1b993e..e46654b 100644
--- a/src/app/ReadClient.cpp
+++ b/src/app/ReadClient.cpp
@@ -182,7 +182,7 @@
ReturnErrorOnFailure(writer.Finalize(&msgBuf));
}
- mpExchangeCtx = mpExchangeMgr->NewContext(aReadPrepareParams.mSessionHandle, this);
+ mpExchangeCtx = mpExchangeMgr->NewContext(aReadPrepareParams.mSessionHolder.Get(), this);
VerifyOrReturnError(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY);
mpExchangeCtx->SetResponseTimeout(aReadPrepareParams.mTimeout);
@@ -190,8 +190,8 @@
ReturnErrorOnFailure(mpExchangeCtx->SendMessage(Protocols::InteractionModel::MsgType::ReadRequest, std::move(msgBuf),
Messaging::SendFlags(Messaging::SendMessageFlags::kExpectResponse)));
- mPeerNodeId = aReadPrepareParams.mSessionHandle.GetPeerNodeId();
- mFabricIndex = aReadPrepareParams.mSessionHandle.GetFabricIndex();
+ mPeerNodeId = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetPeerNodeId();
+ mFabricIndex = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetFabricIndex();
MoveToState(ClientState::AwaitingInitialReport);
@@ -576,8 +576,10 @@
CHIP_ERROR err = CHIP_NO_ERROR;
CancelLivenessCheckTimer();
VerifyOrReturnError(mpExchangeCtx != nullptr, err = CHIP_ERROR_INCORRECT_STATE);
+ VerifyOrReturnError(mpExchangeCtx->HasSessionHandle(), err = CHIP_ERROR_INCORRECT_STATE);
- System::Clock::Timeout timeout = System::Clock::Seconds16(mMaxIntervalCeilingSeconds) + mpExchangeCtx->GetAckTimeout();
+ System::Clock::Timeout timeout =
+ System::Clock::Seconds16(mMaxIntervalCeilingSeconds) + mpExchangeCtx->GetSessionHandle()->GetAckTimeout();
// EFR32/MBED/INFINION/K32W's chrono count return long unsinged, but other platform returns unsigned
ChipLogProgress(DataManagement, "Refresh LivenessCheckTime with %lu milliseconds", static_cast<long unsigned>(timeout.count()));
err = InteractionModelEngine::GetInstance()->GetExchangeManager()->GetSessionManager()->SystemLayer()->StartTimer(
@@ -699,21 +701,15 @@
ReturnErrorOnFailure(writer.Finalize(&msgBuf));
- mpExchangeCtx = mpExchangeMgr->NewContext(aReadPrepareParams.mSessionHandle, this);
+ mpExchangeCtx = mpExchangeMgr->NewContext(aReadPrepareParams.mSessionHolder.Get(), this);
VerifyOrReturnError(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY);
-
mpExchangeCtx->SetResponseTimeout(kImMessageTimeout);
- if (mpExchangeCtx->IsBLETransport())
- {
- ChipLogError(DataManagement, "IM Subscribe cannot work with BLE");
- return CHIP_ERROR_INCORRECT_STATE;
- }
ReturnErrorOnFailure(mpExchangeCtx->SendMessage(Protocols::InteractionModel::MsgType::SubscribeRequest, std::move(msgBuf),
Messaging::SendFlags(Messaging::SendMessageFlags::kExpectResponse)));
- mPeerNodeId = aReadPrepareParams.mSessionHandle.GetPeerNodeId();
- mFabricIndex = aReadPrepareParams.mSessionHandle.GetFabricIndex();
+ mPeerNodeId = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetPeerNodeId();
+ mFabricIndex = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetFabricIndex();
MoveToState(ClientState::AwaitingInitialReport);
diff --git a/src/app/ReadHandler.cpp b/src/app/ReadHandler.cpp
index 8e9c936..2d2f429 100644
--- a/src/app/ReadHandler.cpp
+++ b/src/app/ReadHandler.cpp
@@ -58,8 +58,8 @@
mActiveSubscription = false;
mIsChunkedReport = false;
mInteractionType = aInteractionType;
- mInitiatorNodeId = apExchangeContext->GetSessionHandle().GetPeerNodeId();
- mSubjectDescriptor = apExchangeContext->GetSessionHandle().GetSubjectDescriptor();
+ mInitiatorNodeId = apExchangeContext->GetSessionHandle()->AsSecureSession()->GetPeerNodeId();
+ mSubjectDescriptor = apExchangeContext->GetSessionHandle()->GetSubjectDescriptor();
mHoldSync = false;
mLastWrittenEventsBytes = 0;
if (apExchangeContext != nullptr)
@@ -201,12 +201,13 @@
VerifyOrReturnLogError(IsReportable(), CHIP_ERROR_INCORRECT_STATE);
if (IsPriming() || IsChunkedReport())
{
- mSessionHandle.SetValue(mpExchangeCtx->GetSessionHandle());
+ mSessionHandle.Grab(mpExchangeCtx->GetSessionHandle());
}
else
{
VerifyOrReturnLogError(mpExchangeCtx == nullptr, CHIP_ERROR_INCORRECT_STATE);
- mpExchangeCtx = mpExchangeMgr->NewContext(mSessionHandle.Value(), this);
+ VerifyOrReturnLogError(mSessionHandle, CHIP_ERROR_INCORRECT_STATE);
+ mpExchangeCtx = mpExchangeMgr->NewContext(mSessionHandle.Get(), this);
mpExchangeCtx->SetResponseTimeout(kImMessageTimeout);
}
VerifyOrReturnLogError(mpExchangeCtx != nullptr, CHIP_ERROR_INCORRECT_STATE);
diff --git a/src/app/ReadHandler.h b/src/app/ReadHandler.h
index 75ce763..df8e1c4 100644
--- a/src/app/ReadHandler.h
+++ b/src/app/ReadHandler.h
@@ -216,7 +216,7 @@
uint64_t mSubscriptionId = 0;
uint16_t mMinIntervalFloorSeconds = 0;
uint16_t mMaxIntervalCeilingSeconds = 0;
- Optional<SessionHandle> mSessionHandle;
+ SessionHolder mSessionHandle;
bool mHoldReport = false;
bool mDirty = false;
bool mActiveSubscription = false;
diff --git a/src/app/ReadPrepareParams.h b/src/app/ReadPrepareParams.h
index 547da06..7173e27 100644
--- a/src/app/ReadPrepareParams.h
+++ b/src/app/ReadPrepareParams.h
@@ -29,7 +29,7 @@
namespace app {
struct ReadPrepareParams
{
- SessionHandle mSessionHandle;
+ SessionHolder mSessionHolder;
EventPathParams * mpEventPathParamsList = nullptr;
size_t mEventPathParamsListSize = 0;
AttributePathParams * mpAttributePathParamsList = nullptr;
@@ -40,8 +40,8 @@
uint16_t mMaxIntervalCeilingSeconds = 0;
bool mKeepSubscriptions = true;
- ReadPrepareParams(const SessionHandle & sessionHandle) : mSessionHandle(sessionHandle) {}
- ReadPrepareParams(ReadPrepareParams && other) : mSessionHandle(other.mSessionHandle)
+ ReadPrepareParams(const SessionHandle & sessionHandle) { mSessionHolder.Grab(sessionHandle); }
+ ReadPrepareParams(ReadPrepareParams && other) : mSessionHolder(other.mSessionHolder)
{
mKeepSubscriptions = other.mKeepSubscriptions;
mpEventPathParamsList = other.mpEventPathParamsList;
@@ -64,7 +64,7 @@
return *this;
mKeepSubscriptions = other.mKeepSubscriptions;
- mSessionHandle = other.mSessionHandle;
+ mSessionHolder = other.mSessionHolder;
mpEventPathParamsList = other.mpEventPathParamsList;
mEventPathParamsListSize = other.mEventPathParamsListSize;
mpAttributePathParamsList = other.mpAttributePathParamsList;
diff --git a/src/app/WriteClient.cpp b/src/app/WriteClient.cpp
index 13bcbff..fb13647 100644
--- a/src/app/WriteClient.cpp
+++ b/src/app/WriteClient.cpp
@@ -245,7 +245,7 @@
exit:
if (err != CHIP_NO_ERROR)
{
- ChipLogError(DataManagement, "Write client failed to SendWriteRequest");
+ ChipLogError(DataManagement, "Write client failed to SendWriteRequest: %s", ErrorStr(err));
ClearExistingExchangeContext();
}
else
@@ -254,7 +254,7 @@
// handle this object dying (e.g. due to IM enging shutdown) while the
// async bits are pending we'd need to malloc some state bit that we can
// twiddle if we die. For now just do the OnDone callback sync.
- if (session.IsGroupSession())
+ if (session->IsGroupSession())
{
// Always shutdown on Group communication
ChipLogDetail(DataManagement, "Closing on group Communication ");
diff --git a/src/app/WriteClient.h b/src/app/WriteClient.h
index 35a5996..5be883e 100644
--- a/src/app/WriteClient.h
+++ b/src/app/WriteClient.h
@@ -117,7 +117,7 @@
NodeId GetSourceNodeId() const
{
- return mpExchangeCtx != nullptr ? mpExchangeCtx->GetSessionHandle().GetPeerNodeId() : kUndefinedNodeId;
+ return mpExchangeCtx != nullptr ? mpExchangeCtx->GetSessionHandle()->AsSecureSession()->GetPeerNodeId() : kUndefinedNodeId;
}
private:
diff --git a/src/app/WriteHandler.cpp b/src/app/WriteHandler.cpp
index ccb1779..8ec46e5 100644
--- a/src/app/WriteHandler.cpp
+++ b/src/app/WriteHandler.cpp
@@ -113,7 +113,7 @@
CHIP_ERROR err = CHIP_NO_ERROR;
ReturnErrorCodeIf(mpExchangeCtx == nullptr, CHIP_ERROR_INTERNAL);
- const Access::SubjectDescriptor subjectDescriptor = mpExchangeCtx->GetSessionHandle().GetSubjectDescriptor();
+ const Access::SubjectDescriptor subjectDescriptor = mpExchangeCtx->GetSessionHandle()->GetSubjectDescriptor();
while (CHIP_NO_ERROR == (err = aAttributeDataIBsReader.Next()))
{
@@ -279,7 +279,7 @@
FabricIndex WriteHandler::GetAccessingFabricIndex() const
{
- return mpExchangeCtx->GetSessionHandle().GetFabricIndex();
+ return mpExchangeCtx->GetSessionHandle()->AsSecureSession()->GetFabricIndex();
}
const char * WriteHandler::GetStateStr() const
diff --git a/src/app/clusters/general-commissioning-server/general-commissioning-server.cpp b/src/app/clusters/general-commissioning-server/general-commissioning-server.cpp
index 9a5fe3e..6e95ab8 100644
--- a/src/app/clusters/general-commissioning-server/general-commissioning-server.cpp
+++ b/src/app/clusters/general-commissioning-server/general-commissioning-server.cpp
@@ -134,8 +134,9 @@
* This allows device to send messages back to commissioner.
* Once bindings are implemented, this may no longer be needed.
*/
- server->SetFabricIndex(commandObj->GetExchangeContext()->GetSessionHandle().GetFabricIndex());
- server->SetPeerNodeId(commandObj->GetExchangeContext()->GetSessionHandle().GetPeerNodeId());
+ SessionHandle handle = commandObj->GetExchangeContext()->GetSessionHandle();
+ server->SetFabricIndex(handle->AsSecureSession()->GetFabricIndex());
+ server->SetPeerNodeId(handle->AsSecureSession()->GetPeerNodeId());
CheckSuccess(server->CommissioningComplete(), Failure);
diff --git a/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp b/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp
index 5725f72..b95982e 100644
--- a/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp
+++ b/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp
@@ -173,12 +173,8 @@
// TODO: Create an alternative way to retrieve the Attestation Challenge without this huge amount of calls.
// Retrieve attestation challenge
- ByteSpan attestationChallenge = commandObj->GetExchangeContext()
- ->GetExchangeMgr()
- ->GetSessionManager()
- ->GetSecureSession(commandObj->GetExchangeContext()->GetSessionHandle())
- ->GetCryptoContext()
- .GetAttestationChallenge();
+ ByteSpan attestationChallenge =
+ commandObj->GetExchangeContext()->GetSessionHandle()->AsSecureSession()->GetCryptoContext().GetAttestationChallenge();
Hash_SHA256_stream hashStream;
ReturnErrorOnFailure(hashStream.Begin());
@@ -291,7 +287,7 @@
void OnResponseTimeout(chip::Messaging::ExchangeContext * ec) override {}
void OnExchangeClosing(chip::Messaging::ExchangeContext * ec) override
{
- FabricIndex currentFabricIndex = ec->GetSessionHandle().GetFabricIndex();
+ FabricIndex currentFabricIndex = ec->GetSessionHandle()->AsSecureSession()->GetFabricIndex();
ec->GetExchangeMgr()->GetSessionManager()->ExpireAllPairingsForFabric(currentFabricIndex);
}
};
diff --git a/src/app/clusters/ota-requestor/OTARequestor.cpp b/src/app/clusters/ota-requestor/OTARequestor.cpp
index 3feb796..d114e4f 100644
--- a/src/app/clusters/ota-requestor/OTARequestor.cpp
+++ b/src/app/clusters/ota-requestor/OTARequestor.cpp
@@ -200,7 +200,7 @@
}
mProviderNodeId = providerNodeId;
- mProviderFabricIndex = commandObj->GetExchangeContext()->GetSessionHandle().GetFabricIndex();
+ mProviderFabricIndex = commandObj->GetExchangeContext()->GetSessionHandle()->AsSecureSession()->GetFabricIndex();
mProviderEndpointId = providerEndpoint;
ChipLogProgress(SoftwareUpdate, "OTA Requestor received AnnounceOTAProvider");
diff --git a/src/app/clusters/scenes/scenes.cpp b/src/app/clusters/scenes/scenes.cpp
index 126a17f..c2de7e8 100644
--- a/src/app/clusters/scenes/scenes.cpp
+++ b/src/app/clusters/scenes/scenes.cpp
@@ -69,7 +69,7 @@
{
VerifyOrReturnError(nullptr != commandObj, 0);
VerifyOrReturnError(nullptr != commandObj->GetExchangeContext(), 0);
- return commandObj->GetExchangeContext()->GetSessionHandle().GetFabricIndex();
+ return commandObj->GetExchangeContext()->GetSessionHandle()->AsSecureSession()->GetFabricIndex();
}
static bool readServerAttribute(EndpointId endpoint, ClusterId clusterId, AttributeId attributeId, const char * name,
diff --git a/src/app/reporting/Engine.cpp b/src/app/reporting/Engine.cpp
index 79471ae..e432af8 100644
--- a/src/app/reporting/Engine.cpp
+++ b/src/app/reporting/Engine.cpp
@@ -314,7 +314,7 @@
ReportDataMessage::Builder reportDataBuilder;
chip::System::PacketBufferHandle bufHandle = System::PacketBufferHandle::New(chip::app::kMaxSecureSduLengthBytes);
uint16_t reservedSize = 0;
- bool hasMoreChunks;
+ bool hasMoreChunks = false;
// Reserved size for the MoreChunks boolean flag, which takes up 1 byte for the control tag and 1 byte for the context tag.
const uint32_t kReservedSizeForMoreChunksFlag = 1 + 1;
diff --git a/src/app/tests/TestReadInteraction.cpp b/src/app/tests/TestReadInteraction.cpp
index 7ac4043..fb019d3 100644
--- a/src/app/tests/TestReadInteraction.cpp
+++ b/src/app/tests/TestReadInteraction.cpp
@@ -1319,7 +1319,7 @@
readPrepareParams.mAttributePathParamsListSize = 1;
- readPrepareParams.mSessionHandle = ctx.GetSessionBobToAlice();
+ readPrepareParams.mSessionHolder.Grab(ctx.GetSessionBobToAlice());
readPrepareParams.mMinIntervalFloorSeconds = 2;
readPrepareParams.mMaxIntervalCeilingSeconds = 5;
printf("\nSend subscribe request message to Node: %" PRIu64 "\n", chip::kTestDeviceNodeId);
@@ -1406,7 +1406,7 @@
readPrepareParams.mAttributePathParamsListSize = 1;
- readPrepareParams.mSessionHandle = ctx.GetSessionBobToAlice();
+ readPrepareParams.mSessionHolder.Grab(ctx.GetSessionBobToAlice());
readPrepareParams.mMinIntervalFloorSeconds = 6;
readPrepareParams.mMaxIntervalCeilingSeconds = 5;
diff --git a/src/app/tests/TestWriteInteraction.cpp b/src/app/tests/TestWriteInteraction.cpp
index 3cc12e8..ee167fd 100644
--- a/src/app/tests/TestWriteInteraction.cpp
+++ b/src/app/tests/TestWriteInteraction.cpp
@@ -257,7 +257,7 @@
AddAttributeDataIB(apSuite, apContext, writeClientHandle);
SessionHandle groupSession = ctx.GetSessionBobToFriends();
- NL_TEST_ASSERT(apSuite, groupSession.IsGroupSession());
+ NL_TEST_ASSERT(apSuite, groupSession->IsGroupSession());
err = writeClientHandle.SendWriteRequest(groupSession);
diff --git a/src/app/util/af-types.h b/src/app/util/af-types.h
index cd83865..e204924 100644
--- a/src/app/util/af-types.h
+++ b/src/app/util/af-types.h
@@ -324,7 +324,7 @@
*/
struct EmberAfClusterCommand
{
- chip::NodeId SourceNodeId() const { return source->GetSessionHandle().GetPeerNodeId(); }
+ chip::NodeId SourceNodeId() const { return source->GetSessionHandle()->AsSecureSession()->GetPeerNodeId(); }
/**
* APS frame for the incoming message
diff --git a/src/app/util/util.cpp b/src/app/util/util.cpp
index 22f6d37..67621dc 100644
--- a/src/app/util/util.cpp
+++ b/src/app/util/util.cpp
@@ -448,7 +448,7 @@
}
#ifdef EMBER_AF_PLUGIN_GROUPS_SERVER
else if ((cmd->type == EMBER_INCOMING_MULTICAST || cmd->type == EMBER_INCOMING_MULTICAST_LOOPBACK) &&
- !emberAfGroupsClusterEndpointInGroupCallback(cmd->source->GetSessionHandle().GetFabricIndex(),
+ !emberAfGroupsClusterEndpointInGroupCallback(cmd->source->GetSessionHandle()->AsSecureSession()->GetFabricIndex(),
cmd->apsFrame->destinationEndpoint, cmd->apsFrame->groupId))
{
emberAfDebugPrint("Drop cluster " ChipLogFormatMEI " command " ChipLogFormatMEI, ChipLogValueMEI(cmd->apsFrame->clusterId),
diff --git a/src/controller/CHIPCluster.cpp b/src/controller/CHIPCluster.cpp
index 4d2ca1d..34a3f68 100644
--- a/src/controller/CHIPCluster.cpp
+++ b/src/controller/CHIPCluster.cpp
@@ -52,15 +52,14 @@
if (mDevice->GetSecureSession().HasValue())
{
// Local copy to preserve original SessionHandle for future Unicast communication.
- SessionHandle session = mDevice->GetSecureSession().Value();
- session.SetGroupId(groupId);
- mSessionHandle.SetValue(session);
-
+ Optional<SessionHandle> session = mDevice->GetExchangeManager()->GetSessionManager()->CreateGroupSession(groupId);
// Sanity check
- if (!mSessionHandle.Value().IsGroupSession())
+ if (!session.HasValue() || !session.Value()->IsGroupSession())
{
err = CHIP_ERROR_INCORRECT_STATE;
}
+
+ mGroupSession.Grab(session.Value());
}
else
{
diff --git a/src/controller/CHIPCluster.h b/src/controller/CHIPCluster.h
index 4d9452d..ca02eed 100644
--- a/src/controller/CHIPCluster.h
+++ b/src/controller/CHIPCluster.h
@@ -139,9 +139,17 @@
}
};
- return chip::Controller::WriteAttribute<AttrType>(
- (mSessionHandle.HasValue() ? mSessionHandle.Value() : mDevice->GetSecureSession().Value()), mEndpoint, clusterId,
- attributeId, requestData, onSuccessCb, onFailureCb, aTimedWriteTimeoutMs, onDoneCb);
+ if (mGroupSession)
+ {
+ return chip::Controller::WriteAttribute<AttrType>(mGroupSession.Get(), mEndpoint, clusterId, attributeId, requestData,
+ onSuccessCb, onFailureCb, aTimedWriteTimeoutMs, onDoneCb);
+ }
+ else
+ {
+ return chip::Controller::WriteAttribute<AttrType>(mDevice->GetSecureSession().Value(), mEndpoint, clusterId,
+ attributeId, requestData, onSuccessCb, onFailureCb,
+ aTimedWriteTimeoutMs, onDoneCb);
+ }
}
template <typename AttributeInfo>
@@ -262,7 +270,7 @@
const ClusterId mClusterId;
DeviceProxy * mDevice;
EndpointId mEndpoint;
- chip::Optional<SessionHandle> mSessionHandle;
+ SessionHolder mGroupSession;
};
} // namespace Controller
diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp
index b5e77c0..d2fae9c 100644
--- a/src/controller/CHIPDeviceController.cpp
+++ b/src/controller/CHIPDeviceController.cpp
@@ -268,17 +268,11 @@
mCASESessionManager->ReleaseSession(mFabricInfo->GetPeerIdForNode(remoteDeviceId));
}
-void DeviceController::OnSessionReleased(const SessionHandle & session)
-{
- VerifyOrReturn(mState == State::Initialized, ChipLogError(Controller, "OnConnectionExpired was called in incorrect state"));
- mCASESessionManager->OnSessionReleased(session);
-}
-
void DeviceController::OnFirstMessageDeliveryFailed(const SessionHandle & session)
{
VerifyOrReturn(mState == State::Initialized,
ChipLogError(Controller, "OnFirstMessageDeliveryFailed was called in incorrect state"));
- UpdateDevice(session.GetPeerNodeId());
+ UpdateDevice(session->AsSecureSession()->GetPeerNodeId());
}
CHIP_ERROR DeviceController::InitializePairedDeviceList()
@@ -622,7 +616,6 @@
{
ReturnErrorOnFailure(DeviceController::Init(params));
- params.systemState->SessionMgr()->RegisterReleaseDelegate(*this);
params.systemState->SessionMgr()->RegisterRecoveryDelegate(*this);
uint16_t nextKeyID = 0;
@@ -686,16 +679,6 @@
return CHIP_NO_ERROR;
}
-void DeviceCommissioner::OnSessionReleased(const SessionHandle & session)
-{
- VerifyOrReturn(mState == State::Initialized, ChipLogError(Controller, "OnConnectionExpired was called in incorrect state"));
-
- CommissioneeDeviceProxy * device = FindCommissioneeDevice(session);
- VerifyOrReturn(device != nullptr, ChipLogDetail(Controller, "OnConnectionExpired was called for unknown device, ignoring it."));
-
- device->OnSessionReleased(session);
-}
-
CommissioneeDeviceProxy * DeviceCommissioner::FindCommissioneeDevice(const SessionHandle & session)
{
CommissioneeDeviceProxy * foundDevice = nullptr;
@@ -1198,10 +1181,8 @@
DeviceAttestationVerifier * dac_verifier = GetDeviceAttestationVerifier();
// Retrieve attestation challenge
- ByteSpan attestationChallenge = mSystemState->SessionMgr()
- ->GetSecureSession(mDeviceBeingCommissioned->GetSecureSession().Value())
- ->GetCryptoContext()
- .GetAttestationChallenge();
+ ByteSpan attestationChallenge =
+ mDeviceBeingCommissioned->GetSecureSession().Value()->AsSecureSession()->GetCryptoContext().GetAttestationChallenge();
dac_verifier->VerifyAttestationInformation(attestationElements, attestationChallenge, signature,
mDeviceBeingCommissioned->GetPAI(), mDeviceBeingCommissioned->GetDAC(),
@@ -1361,10 +1342,8 @@
ReturnErrorOnFailure(ExtractPubkeyFromX509Cert(device->GetDAC(), dacPubkey));
// Retrieve attestation challenge
- ByteSpan attestationChallenge = mSystemState->SessionMgr()
- ->GetSecureSession(device->GetSecureSession().Value())
- ->GetCryptoContext()
- .GetAttestationChallenge();
+ ByteSpan attestationChallenge =
+ device->GetSecureSession().Value()->AsSecureSession()->GetCryptoContext().GetAttestationChallenge();
// The operational CA should also verify this on its end during NOC generation, if end-to-end attestation is desired.
ReturnErrorOnFailure(dacVerifier->VerifyNodeOperationalCSRInformation(NOCSRElements, attestationChallenge, AttestationSignature,
diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h
index f9437e7..22a6101 100644
--- a/src/controller/CHIPDeviceController.h
+++ b/src/controller/CHIPDeviceController.h
@@ -173,8 +173,7 @@
* and device pairing information for individual devices). Alternatively, this class can retrieve the
* relevant information when the application tries to communicate with the device
*/
-class DLL_EXPORT DeviceController : public SessionReleaseDelegate,
- public SessionRecoveryDelegate,
+class DLL_EXPORT DeviceController : public SessionRecoveryDelegate,
#if CHIP_DEVICE_CONFIG_ENABLE_DNSSD
public AbstractDnssdDiscoveryController,
#endif
@@ -377,9 +376,6 @@
ReliableMessageProtocolConfig mMRPConfig = gDefaultMRPConfig;
- //////////// SessionReleaseDelegate Implementation ///////////////
- void OnSessionReleased(const SessionHandle & session) override;
-
//////////// SessionRecoveryDelegate Implementation ///////////////
void OnFirstMessageDeliveryFailed(const SessionHandle & session) override;
@@ -687,9 +683,6 @@
void OnSessionEstablishmentTimeout();
- //////////// SessionReleaseDelegate Implementation ///////////////
- void OnSessionReleased(const SessionHandle & session) override;
-
static void OnSessionEstablishmentTimeoutCallback(System::Layer * aLayer, void * aAppState);
/* This function sends a Device Attestation Certificate chain request to the device.
diff --git a/src/controller/CommissioneeDeviceProxy.cpp b/src/controller/CommissioneeDeviceProxy.cpp
index 5c33fb2..bebd673 100644
--- a/src/controller/CommissioneeDeviceProxy.cpp
+++ b/src/controller/CommissioneeDeviceProxy.cpp
@@ -58,9 +58,8 @@
{
if (mSecureSession)
{
- Transport::SecureSession * secureSession = mSessionManager->GetSecureSession(mSecureSession.Get());
// Check if the connection state has the correct transport information
- if (secureSession->GetPeerAddress().GetTransportType() == Transport::Type::kUndefined)
+ if (mSecureSession->AsSecureSession()->GetPeerAddress().GetTransportType() == Transport::Type::kUndefined)
{
mState = ConnectionState::NotConnected;
ReturnErrorOnFailure(LoadSecureSessionParameters());
@@ -86,12 +85,9 @@
return commandObj->SendCommandRequest(mSecureSession.Get());
}
-void CommissioneeDeviceProxy::OnSessionReleased(const SessionHandle & session)
+void CommissioneeDeviceProxy::OnSessionReleased()
{
- VerifyOrReturn(mSecureSession.Contains(session),
- ChipLogDetail(Controller, "Connection expired, but it doesn't match the current session"));
mState = ConnectionState::NotConnected;
- mSecureSession.Release();
}
CHIP_ERROR CommissioneeDeviceProxy::CloseSession()
@@ -130,7 +126,7 @@
return CHIP_NO_ERROR;
}
- Transport::SecureSession * secureSession = mSessionManager->GetSecureSession(mSecureSession.Get());
+ Transport::SecureSession * secureSession = mSecureSession.Get()->AsSecureSession();
secureSession->SetPeerAddress(addr);
return CHIP_NO_ERROR;
@@ -155,6 +151,7 @@
CHIP_ERROR CommissioneeDeviceProxy::LoadSecureSessionParameters()
{
CHIP_ERROR err = CHIP_NO_ERROR;
+ SessionHolder sessionHolder;
if (mSessionManager == nullptr || mState == ConnectionState::SecureConnected)
{
diff --git a/src/controller/CommissioneeDeviceProxy.h b/src/controller/CommissioneeDeviceProxy.h
index b8183be..200e591 100644
--- a/src/controller/CommissioneeDeviceProxy.h
+++ b/src/controller/CommissioneeDeviceProxy.h
@@ -84,7 +84,7 @@
{
public:
~CommissioneeDeviceProxy();
- CommissioneeDeviceProxy() {}
+ CommissioneeDeviceProxy() : mSecureSession(*this) {}
CommissioneeDeviceProxy(const CommissioneeDeviceProxy &) = delete;
/**
@@ -164,7 +164,7 @@
*
* @param session A handle to the secure session
*/
- void OnSessionReleased(const SessionHandle & session) override;
+ void OnSessionReleased() override;
/**
* In case there exists an open session to the device, mark it as expired.
@@ -298,7 +298,7 @@
Messaging::ExchangeManager * mExchangeMgr = nullptr;
- SessionHolder mSecureSession;
+ SessionHolderWithDelegate mSecureSession;
Controller::DeviceControllerInteractionModelDelegate * mpIMDelegate = nullptr;
diff --git a/src/controller/WriteInteraction.h b/src/controller/WriteInteraction.h
index a79d820..4f2472b 100644
--- a/src/controller/WriteInteraction.h
+++ b/src/controller/WriteInteraction.h
@@ -111,7 +111,7 @@
// called.
callback.release();
- if (sessionHandle.IsGroupSession())
+ if (sessionHandle->IsGroupSession())
{
ReturnErrorOnFailure(
handle.EncodeAttributeWritePayload(chip::app::AttributePathParams(clusterId, attributeId), requestData));
diff --git a/src/controller/tests/data_model/TestRead.cpp b/src/controller/tests/data_model/TestRead.cpp
index 1c3556c..113586d 100644
--- a/src/controller/tests/data_model/TestRead.cpp
+++ b/src/controller/tests/data_model/TestRead.cpp
@@ -276,7 +276,7 @@
NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 2);
- ctx.GetExchangeManager().ExpireExchangesForSession(ctx.GetSessionBobToAlice());
+ ctx.ExpireSessionBobToAlice();
ctx.DrainAndServiceIO();
@@ -291,7 +291,7 @@
chip::app::InteractionModelEngine::GetInstance()->GetReportingEngine().Run();
ctx.DrainAndServiceIO();
- ctx.GetExchangeManager().ExpireExchangesForSession(ctx.GetSessionAliceToBob());
+ ctx.ExpireSessionAliceToBob();
NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveReadHandlers() == 0);
diff --git a/src/lib/core/CHIPConfig.h b/src/lib/core/CHIPConfig.h
index 40d7f48..aa6edf8 100644
--- a/src/lib/core/CHIPConfig.h
+++ b/src/lib/core/CHIPConfig.h
@@ -2307,6 +2307,15 @@
#endif // CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE
/**
+ * @def CHIP_CONFIG_GROUP_CONNECTION_POOL_SIZE
+ *
+ * @brief Define the size of the pool used for tracking CHIP groups.
+ */
+#ifndef CHIP_CONFIG_GROUP_CONNECTION_POOL_SIZE
+#define CHIP_CONFIG_GROUP_CONNECTION_POOL_SIZE 8
+#endif // CHIP_CONFIG_GROUP_CONNECTION_POOL_SIZE
+
+/**
* @def CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE
*
* @brief Define the size of the pool used for tracking CHIP
diff --git a/src/lib/support/ReferenceCountedHandle.h b/src/lib/support/ReferenceCountedHandle.h
index c433a63..2a54e17 100644
--- a/src/lib/support/ReferenceCountedHandle.h
+++ b/src/lib/support/ReferenceCountedHandle.h
@@ -38,7 +38,7 @@
bool operator==(const ReferenceCountedHandle & that) const { return &mTarget == &that.mTarget; }
bool operator!=(const ReferenceCountedHandle & that) const { return !(*this == that); }
- Target * operator->() { return &mTarget; }
+ Target * operator->() const { return &mTarget; }
Target & Get() const { return mTarget; }
private:
diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp
index 5dc287c..3a5cfb4 100644
--- a/src/messaging/ExchangeContext.cpp
+++ b/src/messaging/ExchangeContext.cpp
@@ -87,7 +87,7 @@
#if CONFIG_DEVICE_LAYER && CHIP_DEVICE_CONFIG_ENABLE_SED
void ExchangeContext::UpdateSEDPollingMode()
{
- if (GetSessionHandle().GetPeerAddress(mExchangeMgr->GetSessionManager())->GetTransportType() != Transport::Type::kBle)
+ if (GetSessionHandle()->AsSecureSession()->GetPeerAddress().GetTransportType() != Transport::Type::kBle)
{
if (!IsResponseExpected() && !IsSendExpected() && (mExchangeMgr->GetNumActiveExchanges() == 1))
{
@@ -125,10 +125,8 @@
// an error arising below. at the end, we have to close it.
ExchangeHandle ref(*this);
- bool isUDPTransport = IsUDPTransport();
-
- // this check is ignored by the ExchangeMsgDispatch if !AutoRequestAck()
- bool reliableTransmissionRequested = isUDPTransport && !sendFlags.Has(SendMessageFlags::kNoAutoRequestAck);
+ // If session requires MRP and NoAutoRequestAck send flag is not specificed, request reliable transmission.
+ bool reliableTransmissionRequested = GetSessionHandle()->RequireMRP() && !sendFlags.Has(SendMessageFlags::kNoAutoRequestAck);
// If a response message is expected...
if (sendFlags.Has(SendMessageFlags::kExpectResponse) && !IsGroupExchangeContext())
@@ -252,7 +250,8 @@
ExchangeContext::ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, const SessionHandle & session, bool Initiator,
ExchangeDelegate * delegate) :
- mDispatch((delegate != nullptr) ? delegate->GetMessageDispatch() : ApplicationExchangeDispatch::Instance())
+ mDispatch((delegate != nullptr) ? delegate->GetMessageDispatch() : ApplicationExchangeDispatch::Instance()),
+ mSession(*this)
{
VerifyOrDie(mExchangeMgr == nullptr);
@@ -267,7 +266,7 @@
SetMsgRcvdFromPeer(false);
// Do not request Ack for multicast
- SetAutoRequestAck(!session.IsGroupSession());
+ SetAutoRequestAck(!session->IsGroupSession());
#if defined(CHIP_EXCHANGE_CONTEXT_DETAIL_LOGGING)
ChipLogDetail(ExchangeManager, "ec++ id: " ChipLogFormatExchange, ChipLogValueExchange(this));
@@ -316,14 +315,8 @@
&& (payloadHeader.IsInitiator() != IsInitiator());
}
-void ExchangeContext::OnConnectionExpired()
+void ExchangeContext::OnSessionReleased()
{
- // Reset our mSession to a default-initialized (hence not matching any
- // connection state) value, because it's still referencing the now-expired
- // connection. This will mean that no more messages can be sent via this
- // exchange, which seems fine given the semantics of connection expiration.
- mSession.Release();
-
if (!IsResponseExpected())
{
// Nothing to do in this case
@@ -488,43 +481,5 @@
Close();
}
-bool ExchangeContext::IsUDPTransport()
-{
- const Transport::PeerAddress * peerAddress = GetSessionHandle().GetPeerAddress(mExchangeMgr->GetSessionManager());
- return peerAddress && peerAddress->GetTransportType() == Transport::Type::kUdp;
-}
-
-bool ExchangeContext::IsTCPTransport()
-{
- const Transport::PeerAddress * peerAddress = GetSessionHandle().GetPeerAddress(mExchangeMgr->GetSessionManager());
- return peerAddress && peerAddress->GetTransportType() == Transport::Type::kTcp;
-}
-
-bool ExchangeContext::IsBLETransport()
-{
- const Transport::PeerAddress * peerAddress = GetSessionHandle().GetPeerAddress(mExchangeMgr->GetSessionManager());
- return peerAddress && peerAddress->GetTransportType() == Transport::Type::kBle;
-}
-
-System::Clock::Milliseconds32 ExchangeContext::GetAckTimeout()
-{
- System::Clock::Timeout timeout;
- if (IsUDPTransport())
- {
- timeout = GetMRPConfig().mIdleRetransTimeout * (CHIP_CONFIG_RMP_DEFAULT_MAX_RETRANS + 1);
- }
- else if (IsTCPTransport())
- {
- // TODO: issue 12009, need actual tcp margin value considering restransmission
- timeout = System::Clock::Seconds16(30);
- }
- return timeout;
-}
-
-const ReliableMessageProtocolConfig & ExchangeContext::GetMRPConfig() const
-{
- return GetSessionHandle().GetMRPConfig(GetExchangeMgr()->GetSessionManager());
-}
-
} // namespace Messaging
} // namespace chip
diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h
index 57c704c..c5ec80a 100644
--- a/src/messaging/ExchangeContext.h
+++ b/src/messaging/ExchangeContext.h
@@ -55,7 +55,9 @@
* It defines methods for encoding and communicating CHIP messages within an ExchangeContext
* over various transport mechanisms, for example, TCP, UDP, or CHIP Reliable Messaging.
*/
-class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public ReferenceCounted<ExchangeContext, ExchangeContextDeletor>
+class DLL_EXPORT ExchangeContext : public ReliableMessageContext,
+ public ReferenceCounted<ExchangeContext, ExchangeContextDeletor>,
+ public SessionReleaseDelegate
{
friend class ExchangeManager;
friend class ExchangeContextDeletor;
@@ -77,7 +79,13 @@
bool IsEncryptionRequired() const { return mDispatch.IsEncryptionRequired(); }
- bool IsGroupExchangeContext() const { return (mSession && mSession.Get().IsGroupSession()); }
+ bool IsGroupExchangeContext() const
+ {
+ return (mSession && mSession->GetSessionType() == Transport::Session::SessionType::kGroup);
+ }
+
+ // Implement SessionReleaseDelegate
+ void OnSessionReleased() override;
/**
* Send a CHIP message on this exchange.
@@ -167,18 +175,6 @@
void SetResponseTimeout(Timeout timeout);
- // TODO: move following 5 functions into SessionHandle once we can access session vars w/o using a SessionManager
- /*
- * Get the overall acknowledge timeout period for the underneath transport(MRP+UDP/TCP)
- */
- System::Clock::Milliseconds32 GetAckTimeout();
-
- bool IsUDPTransport();
- bool IsTCPTransport();
- bool IsBLETransport();
- // Helper function for easily accessing MRP config
- const ReliableMessageProtocolConfig & GetMRPConfig() const;
-
private:
Timeout mResponseTimeout{ 0 }; // Maximum time to wait for response (in milliseconds); 0 disables response timeout.
ExchangeDelegate * mDelegate = nullptr;
@@ -186,8 +182,8 @@
ExchangeMessageDispatch & mDispatch;
- SessionHolder mSession; // The connection state
- uint16_t mExchangeId; // Assigned exchange ID.
+ SessionHolderWithDelegate mSession; // The connection state
+ uint16_t mExchangeId; // Assigned exchange ID.
/**
* Determine whether a response is currently expected for a message that was sent over
@@ -230,11 +226,6 @@
bool MatchExchange(const SessionHandle & session, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader);
/**
- * Notify the exchange that its connection has expired.
- */
- void OnConnectionExpired();
-
- /**
* Notify our delegate, if any, that we have timed out waiting for a
* response.
*/
diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp
index 41963fe..a6c8421 100644
--- a/src/messaging/ExchangeMgr.cpp
+++ b/src/messaging/ExchangeMgr.cpp
@@ -83,7 +83,6 @@
handler.Reset();
}
- sessionManager->RegisterReleaseDelegate(*this);
sessionManager->SetMessageDelegate(this);
mReliableMessageMgr.Init(sessionManager->SystemLayer());
@@ -106,7 +105,6 @@
if (mSessionManager != nullptr)
{
mSessionManager->SetMessageDelegate(nullptr);
- mSessionManager->UnregisterReleaseDelegate(*this);
mSessionManager = nullptr;
}
@@ -313,24 +311,6 @@
}
}
-void ExchangeManager::OnSessionReleased(const SessionHandle & session)
-{
- ExpireExchangesForSession(session);
-}
-
-void ExchangeManager::ExpireExchangesForSession(const SessionHandle & session)
-{
- mContextPool.ForEachActiveObject([&](auto * ec) {
- if (ec->mSession.Contains(session))
- {
- ec->OnConnectionExpired();
- // Continue to iterate because there can be multiple exchanges
- // associated with the connection.
- }
- return Loop::Continue;
- });
-}
-
void ExchangeManager::CloseAllContextsForDelegate(const ExchangeDelegate * delegate)
{
mContextPool.ForEachActiveObject([&](auto * ec) {
diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h
index 5cd2939..e4e5d0f 100644
--- a/src/messaging/ExchangeMgr.h
+++ b/src/messaging/ExchangeMgr.h
@@ -49,7 +49,7 @@
* It works on be behalf of higher layers, creating ExchangeContexts and
* handling the registration/unregistration of unsolicited message handlers.
*/
-class DLL_EXPORT ExchangeManager : public SessionMessageDelegate, public SessionReleaseDelegate
+class DLL_EXPORT ExchangeManager : public SessionMessageDelegate
{
friend class ExchangeContext;
@@ -193,10 +193,6 @@
size_t GetNumActiveExchanges() { return mContextPool.Allocated(); }
- // TODO: this should be test only, after OnSessionReleased is move to SessionHandle within the exchange context
- // Expire all exchanges associated with the given session
- void ExpireExchangesForSession(const SessionHandle & session);
-
private:
enum class State
{
@@ -244,8 +240,6 @@
void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const SessionHandle & session,
const Transport::PeerAddress & source, DuplicateMessage isDuplicate,
System::PacketBufferHandle && msgBuf) override;
-
- void OnSessionReleased(const SessionHandle & session) override;
};
} // namespace Messaging
diff --git a/src/messaging/ReliableMessageMgr.cpp b/src/messaging/ReliableMessageMgr.cpp
index bc92ac8..4c696e8 100644
--- a/src/messaging/ReliableMessageMgr.cpp
+++ b/src/messaging/ReliableMessageMgr.cpp
@@ -141,7 +141,8 @@
" Send Cnt %d",
messageCounter, ChipLogValueExchange(&entry->ec.Get()), entry->sendCount);
// TODO: Choose active/idle timeout corresponding to the activity of exchanges of the session.
- entry->nextRetransTime = System::SystemClock().GetMonotonicTimestamp() + entry->ec->GetMRPConfig().mActiveRetransTimeout;
+ entry->nextRetransTime =
+ System::SystemClock().GetMonotonicTimestamp() + entry->ec->GetSessionHandle()->GetMRPConfig().mActiveRetransTimeout;
SendFromRetransTable(entry);
// For test not using async IO loop, the entry may have been removed after send, do not use entry below
@@ -185,7 +186,8 @@
void ReliableMessageMgr::StartRetransmision(RetransTableEntry * entry)
{
// TODO: Choose active/idle timeout corresponding to the activity of exchanges of the session.
- entry->nextRetransTime = System::SystemClock().GetMonotonicTimestamp() + entry->ec->GetMRPConfig().mIdleRetransTimeout;
+ entry->nextRetransTime =
+ System::SystemClock().GetMonotonicTimestamp() + entry->ec->GetSessionHandle()->GetMRPConfig().mIdleRetransTimeout;
StartTimer();
}
diff --git a/src/messaging/tests/MessagingContext.cpp b/src/messaging/tests/MessagingContext.cpp
index 267cc93..39d1c51 100644
--- a/src/messaging/tests/MessagingContext.cpp
+++ b/src/messaging/tests/MessagingContext.cpp
@@ -36,14 +36,11 @@
ReturnErrorOnFailure(mExchangeManager.Init(&mSessionManager));
ReturnErrorOnFailure(mMessageCounterManager.Init(&mExchangeManager));
- mSessionBobToFriends.Grab(mSessionManager.CreateGroupSession(GetBobKeyId(), GetFriendsGroupId(), GetFabricIndex()).Value());
+ ReturnErrorOnFailure(CreateSessionBobToAlice());
+ ReturnErrorOnFailure(CreateSessionAliceToBob());
+ ReturnErrorOnFailure(CreateSessionBobToFriends());
- ReturnErrorOnFailure(mSessionManager.NewPairing(mSessionBobToAlice, Optional<Transport::PeerAddress>::Value(mAliceAddress),
- GetAliceNodeId(), &mPairingBobToAlice, CryptoContext::SessionRole::kInitiator,
- mSrcFabricIndex));
-
- return mSessionManager.NewPairing(mSessionAliceToBob, Optional<Transport::PeerAddress>::Value(mBobAddress), GetBobNodeId(),
- &mPairingAliceToBob, CryptoContext::SessionRole::kResponder, mDestFabricIndex);
+ return CHIP_NO_ERROR;
}
// Shutdown all layers, finalize operations
@@ -71,6 +68,24 @@
return err;
}
+CHIP_ERROR MessagingContext::CreateSessionBobToAlice()
+{
+ return mSessionManager.NewPairing(mSessionBobToAlice, Optional<Transport::PeerAddress>::Value(mAliceAddress), GetAliceNodeId(),
+ &mPairingBobToAlice, CryptoContext::SessionRole::kInitiator, mSrcFabricIndex);
+}
+
+CHIP_ERROR MessagingContext::CreateSessionAliceToBob()
+{
+ return mSessionManager.NewPairing(mSessionAliceToBob, Optional<Transport::PeerAddress>::Value(mBobAddress), GetBobNodeId(),
+ &mPairingAliceToBob, CryptoContext::SessionRole::kResponder, mDestFabricIndex);
+}
+
+CHIP_ERROR MessagingContext::CreateSessionBobToFriends()
+{
+ mSessionBobToFriends.Grab(mSessionManager.CreateGroupSession(GetFriendsGroupId()).Value());
+ return CHIP_NO_ERROR;
+}
+
SessionHandle MessagingContext::GetSessionBobToAlice()
{
return mSessionBobToAlice.Get();
@@ -86,6 +101,21 @@
return mSessionBobToFriends.Get();
}
+void MessagingContext::ExpireSessionBobToAlice()
+{
+ mSessionManager.ExpirePairing(mSessionBobToAlice.Get());
+}
+
+void MessagingContext::ExpireSessionAliceToBob()
+{
+ mSessionManager.ExpirePairing(mSessionAliceToBob.Get());
+}
+
+void MessagingContext::ExpireSessionBobToFriends()
+{
+ // TODO: expire the group session
+}
+
Messaging::ExchangeContext * MessagingContext::NewUnauthenticatedExchangeToAlice(Messaging::ExchangeDelegate * delegate)
{
return mExchangeManager.NewContext(mSessionManager.CreateUnauthenticatedSession(mAliceAddress, gDefaultMRPConfig).Value(),
@@ -100,13 +130,11 @@
Messaging::ExchangeContext * MessagingContext::NewExchangeToAlice(Messaging::ExchangeDelegate * delegate)
{
- // TODO: temporary create a SessionHandle from node id, will be fix in PR 3602
return mExchangeManager.NewContext(GetSessionBobToAlice(), delegate);
}
Messaging::ExchangeContext * MessagingContext::NewExchangeToBob(Messaging::ExchangeDelegate * delegate)
{
- // TODO: temporary create a SessionHandle from node id, will be fix in PR 3602
return mExchangeManager.NewContext(GetSessionAliceToBob(), delegate);
}
diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h
index ef3a261..be14c0a 100644
--- a/src/messaging/tests/MessagingContext.h
+++ b/src/messaging/tests/MessagingContext.h
@@ -88,6 +88,14 @@
Messaging::ExchangeManager & GetExchangeManager() { return mExchangeManager; }
secure_channel::MessageCounterManager & GetMessageCounterManager() { return mMessageCounterManager; }
+ CHIP_ERROR CreateSessionBobToAlice();
+ CHIP_ERROR CreateSessionAliceToBob();
+ CHIP_ERROR CreateSessionBobToFriends();
+
+ void ExpireSessionBobToAlice();
+ void ExpireSessionAliceToBob();
+ void ExpireSessionBobToFriends();
+
SessionHandle GetSessionBobToAlice();
SessionHandle GetSessionAliceToBob();
SessionHandle GetSessionBobToFriends();
diff --git a/src/messaging/tests/TestExchangeMgr.cpp b/src/messaging/tests/TestExchangeMgr.cpp
index 965cd8f..9ee9be6 100644
--- a/src/messaging/tests/TestExchangeMgr.cpp
+++ b/src/messaging/tests/TestExchangeMgr.cpp
@@ -97,7 +97,7 @@
NL_TEST_ASSERT(inSuite, ec1 != nullptr);
NL_TEST_ASSERT(inSuite, ec1->IsInitiator() == true);
NL_TEST_ASSERT(inSuite, ec1->GetExchangeId() != 0);
- auto sessionPeerToLocal = ctx.GetSecureSessionManager().GetSecureSession(ec1->GetSessionHandle());
+ auto sessionPeerToLocal = ec1->GetSessionHandle()->AsSecureSession();
NL_TEST_ASSERT(inSuite, sessionPeerToLocal->GetPeerNodeId() == ctx.GetBobNodeId());
NL_TEST_ASSERT(inSuite, sessionPeerToLocal->GetPeerSessionId() == ctx.GetBobKeyId());
NL_TEST_ASSERT(inSuite, ec1->GetDelegate() == &mockAppDelegate);
@@ -105,7 +105,7 @@
ExchangeContext * ec2 = ctx.NewExchangeToAlice(&mockAppDelegate);
NL_TEST_ASSERT(inSuite, ec2 != nullptr);
NL_TEST_ASSERT(inSuite, ec2->GetExchangeId() > ec1->GetExchangeId());
- auto sessionLocalToPeer = ctx.GetSecureSessionManager().GetSecureSession(ec2->GetSessionHandle());
+ auto sessionLocalToPeer = ec2->GetSessionHandle()->AsSecureSession();
NL_TEST_ASSERT(inSuite, sessionLocalToPeer->GetPeerNodeId() == ctx.GetAliceNodeId());
NL_TEST_ASSERT(inSuite, sessionLocalToPeer->GetPeerSessionId() == ctx.GetAliceKeyId());
@@ -121,7 +121,7 @@
ExchangeContext * ec1 = ctx.NewExchangeToBob(&sendDelegate);
// Expire the session this exchange is supposedly on.
- ctx.GetExchangeManager().ExpireExchangesForSession(ec1->GetSessionHandle());
+ ctx.GetSecureSessionManager().ExpirePairing(ec1->GetSessionHandle());
MockAppDelegate receiveDelegate;
CHIP_ERROR err =
@@ -138,6 +138,9 @@
err = ctx.GetExchangeManager().UnregisterUnsolicitedMessageHandlerForType(Protocols::BDX::Id, kMsgType_TEST1);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
+
+ // recreate closed session.
+ NL_TEST_ASSERT(inSuite, ctx.CreateSessionAliceToBob() == CHIP_NO_ERROR);
}
void CheckSessionExpirationTimeout(nlTestSuite * inSuite, void * inContext)
@@ -153,10 +156,12 @@
ctx.DrainAndServiceIO();
NL_TEST_ASSERT(inSuite, !sendDelegate.IsOnResponseTimeoutCalled);
- // Expire the session this exchange is supposedly on. This should close the
- // exchange.
- ctx.GetExchangeManager().ExpireExchangesForSession(ec1->GetSessionHandle());
+ // Expire the session this exchange is supposedly on. This should close the exchange.
+ ctx.GetSecureSessionManager().ExpirePairing(ec1->GetSessionHandle());
NL_TEST_ASSERT(inSuite, sendDelegate.IsOnResponseTimeoutCalled);
+
+ // recreate closed session.
+ NL_TEST_ASSERT(inSuite, ctx.CreateSessionAliceToBob() == CHIP_NO_ERROR);
}
void CheckUmhRegistrationTest(nlTestSuite * inSuite, void * inContext)
diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp
index d78d38f..981c019 100644
--- a/src/messaging/tests/TestReliableMessageProtocol.cpp
+++ b/src/messaging/tests/TestReliableMessageProtocol.cpp
@@ -203,11 +203,10 @@
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
NL_TEST_ASSERT(inSuite, rm != nullptr);
- exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(),
- {
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
- });
+ exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
+ });
// Let's drop the initial message
gLoopback.mSentMessageCount = 0;
@@ -269,11 +268,10 @@
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
NL_TEST_ASSERT(inSuite, rm != nullptr);
- exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(),
- {
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
- });
+ exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
+ });
// Let's drop the initial message
gLoopback.mSentMessageCount = 0;
@@ -330,11 +328,10 @@
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
NL_TEST_ASSERT(inSuite, rm != nullptr);
- exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(),
- {
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
- });
+ exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
+ });
mockSender.mMessageDispatch.mRetainMessageOnSend = false;
@@ -421,11 +418,10 @@
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
NL_TEST_ASSERT(inSuite, rm != nullptr);
- exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(),
- {
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
- });
+ exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
+ });
// Let's drop the initial message
gLoopback.mSentMessageCount = 0;
@@ -484,11 +480,10 @@
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
NL_TEST_ASSERT(inSuite, rm != nullptr);
- exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(),
- {
- 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL
- 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL
- });
+ exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({
+ 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL
+ 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL
+ });
// Let's not drop the message. Expectation is that it is received by the peer, but the ack is dropped
gLoopback.mSentMessageCount = 0;
@@ -556,11 +551,10 @@
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
NL_TEST_ASSERT(inSuite, rm != nullptr);
- exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(),
- {
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
- });
+ exchange->GetSessionHandle()->AsUnauthenticatedSession()->SetMRPConfig({
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
+ });
// Let's drop the initial message
gLoopback.mSentMessageCount = 0;
@@ -621,11 +615,10 @@
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
NL_TEST_ASSERT(inSuite, rm != nullptr);
- exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(),
- {
- 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL
- 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL
- });
+ exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({
+ 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL
+ 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL
+ });
// Let's not drop the message. Expectation is that it is received by the peer, but the ack is dropped
gLoopback.mSentMessageCount = 0;
@@ -1120,11 +1113,10 @@
NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0);
// Make sure that we resend our message before the other side does.
- exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(),
- {
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
- });
+ exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
+ });
// We send a message, the other side sends an application-level response
// (which is lost), then we do a retransmit that is acked, then the other
@@ -1158,11 +1150,10 @@
// Make sure receiver resends after sender does, and there's enough of a gap
// that we are very unlikely to actually trigger the resends on the receiver
// when we trigger the resends on the sender.
- mockReceiver.mExchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(),
- {
- 256_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
- 256_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
- });
+ mockReceiver.mExchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({
+ 256_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
+ 256_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
+ });
// Now send a message from the other side, but drop it.
gLoopback.mNumMessagesToDrop = 1;
diff --git a/src/protocols/echo/Echo.h b/src/protocols/echo/Echo.h
index 1f3429b..cd00d68 100644
--- a/src/protocols/echo/Echo.h
+++ b/src/protocols/echo/Echo.h
@@ -103,7 +103,7 @@
Messaging::ExchangeManager * mExchangeMgr = nullptr;
Messaging::ExchangeContext * mExchangeCtx = nullptr;
EchoFunct OnEchoResponseReceived = nullptr;
- Optional<SessionHandle> mSecureSession = Optional<SessionHandle>();
+ SessionHolder mSecureSession;
CHIP_ERROR OnMessageReceived(Messaging::ExchangeContext * ec, const PayloadHeader & payloadHeader,
System::PacketBufferHandle && payload) override;
diff --git a/src/protocols/echo/EchoClient.cpp b/src/protocols/echo/EchoClient.cpp
index e8bf65c..6805dab 100644
--- a/src/protocols/echo/EchoClient.cpp
+++ b/src/protocols/echo/EchoClient.cpp
@@ -39,7 +39,7 @@
return CHIP_ERROR_INCORRECT_STATE;
mExchangeMgr = exchangeMgr;
- mSecureSession.SetValue(session);
+ mSecureSession.Grab(session);
OnEchoResponseReceived = nullptr;
mExchangeCtx = nullptr;
@@ -71,7 +71,7 @@
}
// Create a new exchange context.
- mExchangeCtx = mExchangeMgr->NewContext(mSecureSession.Value(), this);
+ mExchangeCtx = mExchangeMgr->NewContext(mSecureSession.Get(), this);
if (mExchangeCtx == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp
index 6555232..ff9b25b 100644
--- a/src/protocols/secure_channel/CASESession.cpp
+++ b/src/protocols/secure_channel/CASESession.cpp
@@ -231,7 +231,7 @@
mFabricInfo = fabric;
mLocalMRPConfig = mrpConfig;
- mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout + mExchangeCtxt->GetAckTimeout());
+ mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout());
SetPeerAddress(peerAddress);
SetPeerNodeId(peerNodeId);
@@ -1438,7 +1438,7 @@
else
{
mExchangeCtxt = ec;
- mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout + mExchangeCtxt->GetAckTimeout());
+ mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout());
}
VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);
diff --git a/src/protocols/secure_channel/MessageCounterManager.cpp b/src/protocols/secure_channel/MessageCounterManager.cpp
index 9c3a4d9..4ccf8a0 100644
--- a/src/protocols/secure_channel/MessageCounterManager.cpp
+++ b/src/protocols/secure_channel/MessageCounterManager.cpp
@@ -104,15 +104,13 @@
void MessageCounterManager::OnResponseTimeout(Messaging::ExchangeContext * exchangeContext)
{
- Transport::SecureSession * state = mExchangeMgr->GetSessionManager()->GetSecureSession(exchangeContext->GetSessionHandle());
-
- if (state != nullptr)
+ if (exchangeContext->HasSessionHandle())
{
- state->GetSessionMessageCounter().GetPeerMessageCounter().SyncFailed();
+ exchangeContext->GetSessionHandle()->AsSecureSession()->GetSessionMessageCounter().GetPeerMessageCounter().SyncFailed();
}
else
{
- ChipLogError(SecureChannel, "Timed out! Failed to clear message counter synchronization status.");
+ ChipLogError(SecureChannel, "MCSP Timeout! On a already released session.");
}
}
@@ -223,39 +221,27 @@
CHIP_ERROR MessageCounterManager::SendMsgCounterSyncResp(Messaging::ExchangeContext * exchangeContext,
FixedByteSpan<kChallengeSize> challenge)
{
- CHIP_ERROR err = CHIP_NO_ERROR;
- Transport::SecureSession * state = nullptr;
System::PacketBufferHandle msgBuf;
- uint8_t * msg = nullptr;
- state = mExchangeMgr->GetSessionManager()->GetSecureSession(exchangeContext->GetSessionHandle());
- VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED);
+ VerifyOrDie(exchangeContext->HasSessionHandle());
// Allocate new buffer.
msgBuf = MessagePacketBuffer::New(kSyncRespMsgSize);
- VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_NO_MEMORY);
-
- msg = msgBuf->Start();
+ VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_NO_MEMORY);
{
+ uint8_t * msg = msgBuf->Start();
Encoding::LittleEndian::BufferWriter bbuf(msg, kSyncRespMsgSize);
- bbuf.Put32(state->GetSessionMessageCounter().GetLocalMessageCounter().Value());
+ bbuf.Put32(
+ exchangeContext->GetSessionHandle()->AsSecureSession()->GetSessionMessageCounter().GetLocalMessageCounter().Value());
bbuf.Put(challenge.data(), kChallengeSize);
- VerifyOrExit(bbuf.Fit(), err = CHIP_ERROR_NO_MEMORY);
+ VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_NO_MEMORY);
}
msgBuf->SetDataLength(kSyncRespMsgSize);
- err = exchangeContext->SendMessage(Protocols::SecureChannel::MsgType::MsgCounterSyncRsp, std::move(msgBuf),
- Messaging::SendFlags(Messaging::SendMessageFlags::kNoAutoRequestAck));
-
-exit:
- if (err != CHIP_NO_ERROR)
- {
- ChipLogError(SecureChannel, "Failed to send message counter synchronization response with error:%s", ErrorStr(err));
- }
-
- return err;
+ return exchangeContext->SendMessage(Protocols::SecureChannel::MsgType::MsgCounterSyncRsp, std::move(msgBuf),
+ Messaging::SendFlags(Messaging::SendMessageFlags::kNoAutoRequestAck));
}
CHIP_ERROR MessageCounterManager::HandleMsgCounterSyncReq(Messaging::ExchangeContext * exchangeContext,
@@ -288,17 +274,14 @@
{
CHIP_ERROR err = CHIP_NO_ERROR;
- Transport::SecureSession * state = nullptr;
- uint32_t syncCounter = 0;
+ uint32_t syncCounter = 0;
const uint8_t * resp = msgBuf->Start();
size_t resplen = msgBuf->DataLength();
ChipLogDetail(SecureChannel, "Received MsgCounterSyncResp response");
- // Find an active connection to the specified peer node
- state = mExchangeMgr->GetSessionManager()->GetSecureSession(exchangeContext->GetSessionHandle());
- VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED);
+ VerifyOrDie(exchangeContext->HasSessionHandle());
VerifyOrExit(msgBuf->DataLength() == kSyncRespMsgSize, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH);
@@ -310,11 +293,12 @@
// Verify that the response field matches the expected Challenge field for the exchange.
err =
- state->GetSessionMessageCounter().GetPeerMessageCounter().VerifyChallenge(syncCounter, FixedByteSpan<kChallengeSize>(resp));
+ exchangeContext->GetSessionHandle()->AsSecureSession()->GetSessionMessageCounter().GetPeerMessageCounter().VerifyChallenge(
+ syncCounter, FixedByteSpan<kChallengeSize>(resp));
SuccessOrExit(err);
// Process all queued incoming messages after message counter synchronization is completed.
- ProcessPendingMessages(exchangeContext->GetSessionHandle().GetPeerNodeId());
+ ProcessPendingMessages(exchangeContext->GetSessionHandle()->AsSecureSession()->GetPeerNodeId());
exit:
if (err != CHIP_NO_ERROR)
diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp
index 5d1128b..4056ee1 100644
--- a/src/protocols/secure_channel/PASESession.cpp
+++ b/src/protocols/secure_channel/PASESession.cpp
@@ -321,7 +321,7 @@
SuccessOrExit(err);
mExchangeCtxt = exchangeCtxt;
- mExchangeCtxt->SetResponseTimeout(kSpake2p_Response_Timeout + mExchangeCtxt->GetAckTimeout());
+ mExchangeCtxt->SetResponseTimeout(kSpake2p_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout());
SetPeerAddress(peerAddress);
SetPeerNodeId(NodeIdFromPAKEKeyId(mPasscodeID));
@@ -890,7 +890,7 @@
else
{
mExchangeCtxt = exchange;
- mExchangeCtxt->SetResponseTimeout(kSpake2p_Response_Timeout + mExchangeCtxt->GetAckTimeout());
+ mExchangeCtxt->SetResponseTimeout(kSpake2p_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout());
}
VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);
diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp
index 3f50a98..1f2a672 100644
--- a/src/protocols/secure_channel/tests/TestPASESession.cpp
+++ b/src/protocols/secure_channel/tests/TestPASESession.cpp
@@ -181,11 +181,10 @@
NL_TEST_ASSERT(inSuite, rm != nullptr);
NL_TEST_ASSERT(inSuite, rc != nullptr);
- contextCommissioner->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(),
- {
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
- });
+ contextCommissioner->GetSessionHandle()->AsUnauthenticatedSession()->SetMRPConfig({
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
+ });
}
NL_TEST_ASSERT(inSuite,
@@ -309,11 +308,10 @@
NL_TEST_ASSERT(inSuite, rm != nullptr);
NL_TEST_ASSERT(inSuite, rc != nullptr);
- contextCommissioner->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(),
- {
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
- 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
- });
+ contextCommissioner->GetSessionHandle()->AsUnauthenticatedSession()->SetMRPConfig({
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL
+ 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
+ });
NL_TEST_ASSERT(inSuite,
ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(
diff --git a/src/transport/BUILD.gn b/src/transport/BUILD.gn
index 106ac72..7ea6570 100644
--- a/src/transport/BUILD.gn
+++ b/src/transport/BUILD.gn
@@ -24,6 +24,7 @@
sources = [
"CryptoContext.cpp",
"CryptoContext.h",
+ "GroupSession.h",
"MessageCounter.cpp",
"MessageCounter.h",
"MessageCounterManagerInterface.h",
@@ -31,11 +32,16 @@
"PeerMessageCounter.h",
"SecureMessageCodec.cpp",
"SecureMessageCodec.h",
+ "SecureSession.cpp",
"SecureSession.h",
"SecureSessionTable.h",
+ "Session.cpp",
+ "Session.h",
"SessionDelegate.h",
"SessionHandle.cpp",
"SessionHandle.h",
+ "SessionHolder.cpp",
+ "SessionHolder.h",
"SessionManager.cpp",
"SessionManager.h",
"SessionMessageCounter.h",
diff --git a/src/transport/GroupSession.h b/src/transport/GroupSession.h
new file mode 100644
index 0000000..f0fd85a
--- /dev/null
+++ b/src/transport/GroupSession.h
@@ -0,0 +1,127 @@
+/*
+ * Copyright (c) 2021 Project CHIP Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <app/util/basic-types.h>
+#include <lib/core/GroupId.h>
+#include <lib/support/Pool.h>
+#include <transport/Session.h>
+
+namespace chip {
+namespace Transport {
+
+class GroupSession : public Session
+{
+public:
+ GroupSession(GroupId group, FabricIndex fabricIndex) : mGroupId(group), mFabricIndex(fabricIndex) {}
+ ~GroupSession() { NotifySessionReleased(); }
+
+ Session::SessionType GetSessionType() const override { return Session::SessionType::kGroup; }
+#if CHIP_PROGRESS_LOGGING
+ const char * GetSessionTypeString() const override { return "secure"; };
+#endif
+
+ Access::SubjectDescriptor GetSubjectDescriptor() const override
+ {
+ Access::SubjectDescriptor isd;
+ isd.authMode = Access::AuthMode::kGroup;
+ // TODO: fill other group subjects fields
+ return isd; // return an empty ISD for unauthenticated session.
+ }
+
+ bool RequireMRP() const override { return false; }
+
+ const ReliableMessageProtocolConfig & GetMRPConfig() const override
+ {
+ VerifyOrDie(false);
+ return gDefaultMRPConfig;
+ }
+
+ System::Clock::Milliseconds32 GetAckTimeout() const override
+ {
+ VerifyOrDie(false);
+ return System::Clock::Timeout();
+ }
+
+ GroupId GetGroupId() const { return mGroupId; }
+ FabricIndex GetFabricIndex() const { return mFabricIndex; }
+
+private:
+ const GroupId mGroupId;
+ const FabricIndex mFabricIndex;
+};
+
+/*
+ * @brief
+ * An table which manages GroupSessions
+ */
+template <size_t kMaxSessionCount>
+class GroupSessionTable
+{
+public:
+ ~GroupSessionTable() { mEntries.ReleaseAll(); }
+
+ /**
+ * Get a session given the peer address. If the session doesn't exist in the cache, allocate a new entry for it.
+ *
+ * @return the session found or allocated, nullptr if not found and allocation failed.
+ */
+ CHECK_RETURN_VALUE
+ Optional<SessionHandle> AllocEntry(GroupId group, FabricIndex fabricIndex)
+ {
+ GroupSession * entry = mEntries.CreateObject(group, fabricIndex);
+ if (entry != nullptr)
+ {
+ return MakeOptional<SessionHandle>(*entry);
+ }
+ else
+ {
+ return Optional<SessionHandle>::Missing();
+ }
+ }
+
+ /**
+ * Get a session using given GroupId
+ */
+ CHECK_RETURN_VALUE
+ Optional<SessionHandle> FindEntry(GroupId group, FabricIndex fabricIndex)
+ {
+ GroupSession * result = nullptr;
+ mEntries.ForEachActiveObject([&](GroupSession * entry) {
+ if (entry->GetGroupId() == group && entry->GetFabricIndex() == fabricIndex)
+ {
+ result = entry;
+ return Loop::Break;
+ }
+ return Loop::Continue;
+ });
+ if (result != nullptr)
+ {
+ return MakeOptional<SessionHandle>(*result);
+ }
+ else
+ {
+ return Optional<SessionHandle>::Missing();
+ }
+ }
+
+private:
+ BitMapObjectPool<GroupSession, kMaxSessionCount> mEntries;
+};
+
+} // namespace Transport
+} // namespace chip
diff --git a/src/transport/SecureSession.cpp b/src/transport/SecureSession.cpp
new file mode 100644
index 0000000..d9e2736
--- /dev/null
+++ b/src/transport/SecureSession.cpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2021 Project CHIP Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <access/AuthMode.h>
+#include <transport/SecureSession.h>
+
+namespace chip {
+namespace Transport {
+
+Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const
+{
+ Access::SubjectDescriptor subjectDescriptor;
+ if (IsOperationalNodeId(mPeerNodeId))
+ {
+ subjectDescriptor.authMode = Access::AuthMode::kCase;
+ subjectDescriptor.subject = mPeerNodeId;
+ subjectDescriptor.fabricIndex = mFabric;
+ // TODO(#10243): add CATs
+ }
+ else if (IsPAKEKeyId(mPeerNodeId))
+ {
+ subjectDescriptor.authMode = Access::AuthMode::kPase;
+ subjectDescriptor.subject = mPeerNodeId;
+ // TODO(#10242): PASE *can* have fabric in some situations
+ }
+ else
+ {
+ VerifyOrDie(false);
+ }
+ return subjectDescriptor;
+}
+
+} // namespace Transport
+} // namespace chip
diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h
index 84b0dbe..05efecc 100644
--- a/src/transport/SecureSession.h
+++ b/src/transport/SecureSession.h
@@ -25,6 +25,7 @@
#include <credentials/CHIPCert.h>
#include <messaging/ReliableMessageProtocolConfig.h>
#include <transport/CryptoContext.h>
+#include <transport/Session.h>
#include <transport/SessionMessageCounter.h>
#include <transport/raw/Base.h>
#include <transport/raw/MessageHeader.h>
@@ -47,10 +48,8 @@
* - LastActivityTime is a monotonic timestamp of when this connection was
* last used. Inactive connections can expire.
* - CryptoContext contains the encryption context of a connection
- *
- * TODO: to add any message ACK information
*/
-class SecureSession
+class SecureSession : public Session
{
public:
/**
@@ -70,14 +69,37 @@
mPeerNodeId(peerNodeId), mPeerCATs(peerCATs), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId),
mFabric(fabric), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
{}
+ ~SecureSession() { NotifySessionReleased(); }
SecureSession(SecureSession &&) = delete;
SecureSession(const SecureSession &) = delete;
SecureSession & operator=(const SecureSession &) = delete;
SecureSession & operator=(SecureSession &&) = delete;
+ Session::SessionType GetSessionType() const override { return Session::SessionType::kSecure; }
+#if CHIP_PROGRESS_LOGGING
+ const char * GetSessionTypeString() const override { return "secure"; };
+#endif
+
+ Access::SubjectDescriptor GetSubjectDescriptor() const override;
+
+ bool RequireMRP() const override { return GetPeerAddress().GetTransportType() == Transport::Type::kUdp; }
+
+ System::Clock::Milliseconds32 GetAckTimeout() const override
+ {
+ switch (mPeerAddress.GetTransportType())
+ {
+ case Transport::Type::kUdp:
+ return GetMRPConfig().mIdleRetransTimeout * (CHIP_CONFIG_RMP_DEFAULT_MAX_RETRANS + 1);
+ case Transport::Type::kTcp:
+ return System::Clock::Seconds16(30);
+ default:
+ break;
+ }
+ return System::Clock::Timeout();
+ }
+
const PeerAddress & GetPeerAddress() const { return mPeerAddress; }
- PeerAddress & GetPeerAddress() { return mPeerAddress; }
void SetPeerAddress(const PeerAddress & address) { mPeerAddress = address; }
Type GetSecureSessionType() const { return mSecureSessionType; }
@@ -86,7 +108,7 @@
void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; }
- const ReliableMessageProtocolConfig & GetMRPConfig() const { return mMRPConfig; }
+ const ReliableMessageProtocolConfig & GetMRPConfig() const override { return mMRPConfig; }
uint16_t GetLocalSessionId() const { return mLocalSessionId; }
uint16_t GetPeerSessionId() const { return mPeerSessionId; }
diff --git a/src/transport/SecureSessionTable.h b/src/transport/SecureSessionTable.h
index a680ae0..9dce219 100644
--- a/src/transport/SecureSessionTable.h
+++ b/src/transport/SecureSessionTable.h
@@ -60,11 +60,13 @@
* has been reached (with CHIP_ERROR_NO_MEMORY).
*/
CHECK_RETURN_VALUE
- SecureSession * CreateNewSecureSession(SecureSession::Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId,
- CATValues peerCATs, uint16_t peerSessionId, FabricIndex fabric,
- const ReliableMessageProtocolConfig & config)
+ Optional<SessionHandle> CreateNewSecureSession(SecureSession::Type secureSessionType, uint16_t localSessionId,
+ NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId,
+ FabricIndex fabric, const ReliableMessageProtocolConfig & config)
{
- return mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config);
+ SecureSession * result =
+ mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config);
+ return result != nullptr ? MakeOptional<SessionHandle>(*result) : Optional<SessionHandle>::Missing();
}
void ReleaseSession(SecureSession * session) { mEntries.ReleaseObject(session); }
@@ -83,7 +85,7 @@
* @return the state found, nullptr if not found
*/
CHECK_RETURN_VALUE
- SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId)
+ Optional<SessionHandle> FindSecureSessionByLocalKey(uint16_t localSessionId)
{
SecureSession * result = nullptr;
mEntries.ForEachActiveObject([&](auto session) {
@@ -94,7 +96,7 @@
}
return Loop::Continue;
});
- return result;
+ return result != nullptr ? MakeOptional<SessionHandle>(*result) : Optional<SessionHandle>::Missing();
}
/**
diff --git a/src/transport/Session.cpp b/src/transport/Session.cpp
new file mode 100644
index 0000000..768cc46
--- /dev/null
+++ b/src/transport/Session.cpp
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2021 Project CHIP Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <transport/GroupSession.h>
+#include <transport/SecureSession.h>
+#include <transport/Session.h>
+#include <transport/UnauthenticatedSessionTable.h>
+
+namespace chip {
+namespace Transport {
+
+SecureSession * Session::AsSecureSession()
+{
+ VerifyOrDie(GetSessionType() == SessionType::kSecure);
+ return static_cast<SecureSession *>(this);
+}
+
+UnauthenticatedSession * Session::AsUnauthenticatedSession()
+{
+ VerifyOrDie(GetSessionType() == SessionType::kUnauthenticated);
+ return static_cast<UnauthenticatedSession *>(this);
+}
+
+GroupSession * Session::AsGroupSession()
+{
+ VerifyOrDie(GetSessionType() == SessionType::kGroup);
+ return static_cast<GroupSession *>(this);
+}
+
+} // namespace Transport
+} // namespace chip
diff --git a/src/transport/Session.h b/src/transport/Session.h
new file mode 100644
index 0000000..4e343cb
--- /dev/null
+++ b/src/transport/Session.h
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2021 Project CHIP Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <lib/core/CHIPConfig.h>
+#include <messaging/ReliableMessageProtocolConfig.h>
+#include <transport/SessionHolder.h>
+#include <transport/raw/PeerAddress.h>
+
+namespace chip {
+namespace Transport {
+
+class SecureSession;
+class UnauthenticatedSession;
+class GroupSession;
+
+class Session
+{
+public:
+ virtual ~Session() {}
+
+ enum class SessionType : uint8_t
+ {
+ kUndefined = 0,
+ kUnauthenticated = 1,
+ kSecure = 2,
+ kGroup = 3,
+ };
+
+ virtual SessionType GetSessionType() const = 0;
+#if CHIP_PROGRESS_LOGGING
+ virtual const char * GetSessionTypeString() const = 0;
+#endif
+
+ void AddHolder(SessionHolder & holder)
+ {
+ VerifyOrDie(!holder.IsInList());
+ mHolders.PushBack(&holder);
+ }
+
+ void RemoveHolder(SessionHolder & holder)
+ {
+ VerifyOrDie(mHolders.Contains(&holder));
+ mHolders.Remove(&holder);
+ }
+
+ // For types of sessions using reference counter, override these functions, otherwise leave it empty.
+ virtual void Retain() {}
+ virtual void Release() {}
+
+ virtual Access::SubjectDescriptor GetSubjectDescriptor() const = 0;
+ virtual bool RequireMRP() const = 0;
+ virtual const ReliableMessageProtocolConfig & GetMRPConfig() const = 0;
+ virtual System::Clock::Milliseconds32 GetAckTimeout() const = 0;
+
+ SecureSession * AsSecureSession();
+ UnauthenticatedSession * AsUnauthenticatedSession();
+ GroupSession * AsGroupSession();
+
+ bool IsGroupSession() const { return GetSessionType() == SessionType::kGroup; }
+
+protected:
+ // This should be called by sub-classes at the very beginning of the destructor, before any data field is disposed, such that
+ // the session is still functional during the callback.
+ void NotifySessionReleased()
+ {
+ SessionHandle session(*this);
+ while (!mHolders.Empty())
+ {
+ mHolders.begin()->OnSessionReleased(); // OnSessionReleased must remove the item from the linked list
+ }
+ }
+
+private:
+ IntrusiveList<SessionHolder> mHolders;
+};
+
+} // namespace Transport
+} // namespace chip
diff --git a/src/transport/SessionDelegate.h b/src/transport/SessionDelegate.h
index bee1ab6..97e1d66 100644
--- a/src/transport/SessionDelegate.h
+++ b/src/transport/SessionDelegate.h
@@ -31,7 +31,7 @@
*
* @param session The handle to the secure session
*/
- virtual void OnSessionReleased(const SessionHandle & session) = 0;
+ virtual void OnSessionReleased() = 0;
};
class DLL_EXPORT SessionRecoveryDelegate
diff --git a/src/transport/SessionHandle.cpp b/src/transport/SessionHandle.cpp
index eebed4f..3fe0ebf 100644
--- a/src/transport/SessionHandle.cpp
+++ b/src/transport/SessionHandle.cpp
@@ -23,83 +23,4 @@
using namespace Transport;
-using AuthMode = Access::AuthMode;
-using SubjectDescriptor = Access::SubjectDescriptor;
-
-SubjectDescriptor SessionHandle::GetSubjectDescriptor() const
-{
- SubjectDescriptor subjectDescriptor;
- if (IsSecure())
- {
- if (IsOperationalNodeId(mPeerNodeId))
- {
- subjectDescriptor.authMode = AuthMode::kCase;
- subjectDescriptor.subject = mPeerNodeId;
- subjectDescriptor.fabricIndex = mFabric;
- // TODO(#10243): add CATs
- }
- else if (IsPAKEKeyId(mPeerNodeId))
- {
- subjectDescriptor.authMode = AuthMode::kPase;
- subjectDescriptor.subject = mPeerNodeId;
- // TODO(#10242): PASE *can* have fabric in some situations
- }
- else if (mGroupId.HasValue())
- {
- subjectDescriptor.authMode = AuthMode::kGroup;
- subjectDescriptor.subject = NodeIdFromGroupId(mGroupId.Value());
- }
- }
- return subjectDescriptor;
-}
-
-const PeerAddress * SessionHandle::GetPeerAddress(SessionManager * sessionManager) const
-{
- if (IsSecure())
- {
- SecureSession * state = sessionManager->GetSecureSession(*this);
- if (state == nullptr)
- {
- return nullptr;
- }
-
- return &state->GetPeerAddress();
- }
-
- return &GetUnauthenticatedSession()->GetPeerAddress();
-}
-
-const ReliableMessageProtocolConfig & SessionHandle::GetMRPConfig(SessionManager * sessionManager) const
-{
- if (IsSecure())
- {
- SecureSession * secureSession = sessionManager->GetSecureSession(*this);
- if (secureSession == nullptr)
- {
- return gDefaultMRPConfig;
- }
- return secureSession->GetMRPConfig();
- }
- else
- {
- return GetUnauthenticatedSession()->GetMRPConfig();
- }
-}
-
-void SessionHandle::SetMRPConfig(SessionManager * sessionManager, const ReliableMessageProtocolConfig & config)
-{
- if (IsSecure())
- {
- SecureSession * secureSession = sessionManager->GetSecureSession(*this);
- if (secureSession != nullptr)
- {
- secureSession->SetMRPConfig(config);
- }
- }
- else
- {
- return GetUnauthenticatedSession()->SetMRPConfig(config);
- }
-}
-
} // namespace chip
diff --git a/src/transport/SessionHandle.h b/src/transport/SessionHandle.h
index fb9dc31..344dab3 100644
--- a/src/transport/SessionHandle.h
+++ b/src/transport/SessionHandle.h
@@ -17,92 +17,39 @@
#pragma once
-#include <access/AccessControl.h>
-#include <app/util/basic-types.h>
-#include <lib/core/NodeId.h>
-#include <lib/core/Optional.h>
-#include <transport/UnauthenticatedSessionTable.h>
-#include <transport/raw/PeerAddress.h>
+#include <access/SubjectDescriptor.h>
+#include <lib/support/ReferenceCountedHandle.h>
namespace chip {
-class SessionManager;
+namespace Transport {
+class Session;
+} // namespace Transport
+class SessionHolder;
+
+/** @brief
+ * Non-copyable session reference. All SessionHandles are created within SessionManager. SessionHandle is not
+ * reference *counted, hence it is not allowed to store SessionHandle anywhere except for function arguments and
+ * return values. SessionHandle is short-lived as it is only available as stack variable, so it is never dangling. */
class SessionHandle
{
public:
- using SubjectDescriptor = Access::SubjectDescriptor;
+ SessionHandle(Transport::Session & session) : mSession(session) {}
+ ~SessionHandle() {}
- SessionHandle(NodeId peerNodeId, FabricIndex fabric) : mPeerNodeId(peerNodeId), mFabric(fabric) {}
+ SessionHandle(const SessionHandle &) = delete;
+ SessionHandle operator=(const SessionHandle &) = delete;
+ SessionHandle(SessionHandle &&) = default;
+ SessionHandle & operator=(SessionHandle &&) = delete;
- SessionHandle(Transport::UnauthenticatedSessionHandle session) :
- mPeerNodeId(kPlaceholderNodeId), mFabric(kUndefinedFabricIndex), mUnauthenticatedSessionHandle(session)
- {}
+ bool operator==(const SessionHandle & that) const { return &mSession.Get() == &that.mSession.Get(); }
- SessionHandle(Transport::SecureSession & session) : mPeerNodeId(session.GetPeerNodeId()), mFabric(session.GetFabricIndex())
- {
- mLocalSessionId.SetValue(session.GetLocalSessionId());
- mPeerSessionId.SetValue(session.GetPeerSessionId());
- }
-
- SessionHandle(NodeId peerNodeId, GroupId groupId, FabricIndex fabric) : mPeerNodeId(peerNodeId), mFabric(fabric)
- {
- mGroupId.SetValue(groupId);
- }
-
- bool IsSecure() const { return !mUnauthenticatedSessionHandle.HasValue(); }
-
- bool HasFabricIndex() const { return (mFabric != kUndefinedFabricIndex); }
- FabricIndex GetFabricIndex() const { return mFabric; }
- void SetFabricIndex(FabricIndex fabricId) { mFabric = fabricId; }
- void SetGroupId(GroupId groupId) { mGroupId.SetValue(groupId); }
-
- SubjectDescriptor GetSubjectDescriptor() const;
-
- bool operator==(const SessionHandle & that) const
- {
- if (IsSecure())
- {
- return that.IsSecure() && mLocalSessionId.Value() == that.mLocalSessionId.Value();
- }
- else
- {
- return !that.IsSecure() && mUnauthenticatedSessionHandle.Value() == that.mUnauthenticatedSessionHandle.Value();
- }
- }
-
- NodeId GetPeerNodeId() const { return mPeerNodeId; }
- bool IsGroupSession() const { return mGroupId.HasValue(); }
- const Optional<GroupId> & GetGroupId() const { return mGroupId; }
- const Optional<uint16_t> & GetPeerSessionId() const { return mPeerSessionId; }
- const Optional<uint16_t> & GetLocalSessionId() const { return mLocalSessionId; }
-
- // Return the peer address for this session. May return null if the peer
- // address is not known. This can happen for secure sessions that have been
- // torn down, at the very least.
- const Transport::PeerAddress * GetPeerAddress(SessionManager * sessionManager) const;
-
- const ReliableMessageProtocolConfig & GetMRPConfig(SessionManager * sessionManager) const;
- void SetMRPConfig(SessionManager * sessionManager, const ReliableMessageProtocolConfig & config);
-
- Transport::UnauthenticatedSessionHandle GetUnauthenticatedSession() const { return mUnauthenticatedSessionHandle.Value(); }
+ Transport::Session * operator->() const { return mSession.operator->(); }
private:
- friend class SessionManager;
-
- // Fields for secure session
- NodeId mPeerNodeId;
- Optional<uint16_t> mLocalSessionId;
- Optional<uint16_t> mPeerSessionId;
- Optional<GroupId> mGroupId;
- // TODO: Re-evaluate the storing of Fabric ID in SessionHandle
- // The Fabric ID will not be available for PASE and group sessions. So need
- // to identify an approach that'll allow looking up the corresponding information for
- // such sessions.
- FabricIndex mFabric;
-
- // Fields for unauthenticated session
- Optional<Transport::UnauthenticatedSessionHandle> mUnauthenticatedSessionHandle;
+ friend class SessionHolder;
+ ReferenceCountedHandle<Transport::Session> mSession;
};
} // namespace chip
diff --git a/src/transport/SessionHolder.cpp b/src/transport/SessionHolder.cpp
new file mode 100644
index 0000000..290cc18
--- /dev/null
+++ b/src/transport/SessionHolder.cpp
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2021 Project CHIP Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <transport/Session.h>
+#include <transport/SessionHolder.h>
+
+namespace chip {
+
+SessionHolder::~SessionHolder()
+{
+ Release();
+}
+
+SessionHolder::SessionHolder(const SessionHolder & that) : IntrusiveListNodeBase()
+{
+ mSession = that.mSession;
+ if (mSession.HasValue())
+ {
+ mSession.Value()->AddHolder(*this);
+ }
+}
+
+SessionHolder::SessionHolder(SessionHolder && that) : IntrusiveListNodeBase()
+{
+ mSession = that.mSession;
+ if (mSession.HasValue())
+ {
+ mSession.Value()->AddHolder(*this);
+ }
+
+ that.Release();
+}
+
+SessionHolder & SessionHolder::operator=(const SessionHolder & that)
+{
+ Release();
+
+ mSession = that.mSession;
+ if (mSession.HasValue())
+ {
+ mSession.Value()->AddHolder(*this);
+ }
+
+ return *this;
+}
+
+SessionHolder & SessionHolder::operator=(SessionHolder && that)
+{
+ Release();
+
+ mSession = that.mSession;
+ if (mSession.HasValue())
+ {
+ mSession.Value()->AddHolder(*this);
+ }
+
+ that.Release();
+
+ return *this;
+}
+
+void SessionHolder::Grab(const SessionHandle & session)
+{
+ Release();
+ mSession.Emplace(session.mSession);
+ session->AddHolder(*this);
+}
+
+void SessionHolder::Release()
+{
+ if (mSession.HasValue())
+ {
+ mSession.Value()->RemoveHolder(*this);
+ mSession.ClearValue();
+ }
+}
+
+} // namespace chip
diff --git a/src/transport/SessionHolder.h b/src/transport/SessionHolder.h
index 8210f8e..897004f 100644
--- a/src/transport/SessionHolder.h
+++ b/src/transport/SessionHolder.h
@@ -17,6 +17,7 @@
#pragma once
#include <lib/core/Optional.h>
+#include <lib/support/IntrusiveList.h>
#include <transport/SessionDelegate.h>
#include <transport/SessionHandle.h>
@@ -26,42 +27,59 @@
* Managed session reference. The object is used to store a session, the stored session will be automatically
* released when the underlying session is released. One must verify it is available before use. The object can be
* created using SessionHandle.Grab()
- *
- * TODO: release holding session when the session is released. This will be implemented by following PRs
*/
-class SessionHolder : public SessionReleaseDelegate
+class SessionHolder : public SessionReleaseDelegate, public IntrusiveListNodeBase
{
public:
SessionHolder() {}
- SessionHolder(const SessionHandle & session) : mSession(session) {}
- ~SessionHolder() { Release(); }
+ ~SessionHolder();
SessionHolder(const SessionHolder &);
- SessionHolder operator=(const SessionHolder &);
SessionHolder(SessionHolder && that);
- SessionHolder operator=(SessionHolder && that);
+ SessionHolder & operator=(const SessionHolder &);
+ SessionHolder & operator=(SessionHolder && that);
- void Grab(const SessionHandle & sessionHandle)
+ // Implement SessionReleaseDelegate
+ void OnSessionReleased() override { Release(); }
+
+ bool Contains(const SessionHandle & session) const
{
- Release();
- mSession.SetValue(sessionHandle);
+ return mSession.HasValue() && &mSession.Value().Get() == &session.mSession.Get();
}
- void Release() { mSession.ClearValue(); }
-
- // TODO: call this function when the underlying session is released
- // Implement SessionReleaseDelegate
- void OnSessionReleased(const SessionHandle & session) override { Release(); }
-
- // Check whether the SessionHolder contains a session matching given session
- bool Contains(const SessionHandle & session) const { return mSession.HasValue() && mSession.Value() == session; }
+ void Grab(const SessionHandle & session);
+ void Release();
operator bool() const { return mSession.HasValue(); }
- const SessionHandle & Get() const { return mSession.Value(); }
- Optional<SessionHandle> ToOptional() const { return mSession; }
+ SessionHandle Get() const { return SessionHandle{ mSession.Value().Get() }; }
+ Optional<SessionHandle> ToOptional() const
+ {
+ return mSession.HasValue() ? chip::MakeOptional<SessionHandle>(Get()) : chip::Optional<SessionHandle>::Missing();
+ }
+
+ Transport::Session * operator->() const { return &mSession.Value().Get(); }
private:
- Optional<SessionHandle> mSession;
+ Optional<ReferenceCountedHandle<Transport::Session>> mSession;
+};
+
+// @brief Extends SessionHolder to allow propagate OnSessionReleased event to an extra given destination
+class SessionHolderWithDelegate : public SessionHolder
+{
+public:
+ SessionHolderWithDelegate(SessionReleaseDelegate & delegate) : mDelegate(delegate) {}
+ operator bool() const { return SessionHolder::operator bool(); }
+
+ void OnSessionReleased() override
+ {
+ Release();
+
+ // Note, the session is already cleared during mDelegate.OnSessionReleased
+ mDelegate.OnSessionReleased();
+ }
+
+private:
+ SessionReleaseDelegate & mDelegate;
};
} // namespace chip
diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp
index 0e6452e..1edf5c2 100644
--- a/src/transport/SessionManager.cpp
+++ b/src/transport/SessionManager.cpp
@@ -37,6 +37,7 @@
#include <lib/support/logging/CHIPLogging.h>
#include <platform/CHIPDeviceLayer.h>
#include <protocols/secure_channel/Constants.h>
+#include <transport/GroupSession.h>
#include <transport/PairingSession.h>
#include <transport/SecureMessageCodec.h>
#include <transport/TransportMgr.h>
@@ -95,7 +96,6 @@
{
CancelExpiryTimer();
- mSessionReleaseDelegates.ReleaseAll();
mSessionRecoveryDelegates.ReleaseAll();
mMessageCounterManager = nullptr;
@@ -119,48 +119,47 @@
NodeId destination;
FabricIndex fabricIndex;
#endif // CHIP_PROGRESS_LOGGING
- if (sessionHandle.IsSecure())
+
+ switch (sessionHandle->GetSessionType())
{
- if (sessionHandle.IsGroupSession())
- {
- // TODO : #11911
- // For now, just set the packetHeader with the correct data.
- packetHeader.SetDestinationGroupId(sessionHandle.GetGroupId());
- packetHeader.SetFlags(Header::SecFlagValues::kPrivacyFlag);
- packetHeader.SetSessionType(Header::SessionType::kGroupSession);
- // TODO : Replace the PeerNodeId with Our nodeId
- packetHeader.SetSourceNodeId(sessionHandle.GetPeerNodeId());
+ case Transport::Session::SessionType::kGroup: {
+ // TODO : #11911
+ // For now, just set the packetHeader with the correct data.
+ packetHeader.SetDestinationGroupId(sessionHandle->AsGroupSession()->GetGroupId());
+ packetHeader.SetFlags(Header::SecFlagValues::kPrivacyFlag);
+ packetHeader.SetSessionType(Header::SessionType::kGroupSession);
+ // TODO : Replace the PeerNodeId with Our nodeId
+ packetHeader.SetSourceNodeId(kUndefinedNodeId);
- if (!packetHeader.IsValidGroupMsg())
- {
- return CHIP_ERROR_INTERNAL;
- }
- // TODO #11911 Update SecureMessageCodec::Encrypt for Group
- ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message));
+ if (!packetHeader.IsValidGroupMsg())
+ {
+ return CHIP_ERROR_INTERNAL;
+ }
+ // TODO #11911 Update SecureMessageCodec::Encrypt for Group
+ ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message));
#if CHIP_PROGRESS_LOGGING
- destination = sessionHandle.GetPeerNodeId();
- fabricIndex = sessionHandle.GetFabricIndex();
+ destination = kUndefinedNodeId;
+ fabricIndex = kUndefinedFabricIndex;
#endif // CHIP_PROGRESS_LOGGING
- }
- else
- {
- SecureSession * session = GetSecureSession(sessionHandle);
- if (session == nullptr)
- {
- return CHIP_ERROR_NOT_CONNECTED;
- }
- MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *session);
- ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session, payloadHeader, packetHeader, message, counter));
-
-#if CHIP_PROGRESS_LOGGING
- destination = session->GetPeerNodeId();
- fabricIndex = session->GetFabricIndex();
-#endif // CHIP_PROGRESS_LOGGING
- }
}
- else
- {
+ break;
+ case Transport::Session::SessionType::kSecure: {
+ SecureSession * session = sessionHandle->AsSecureSession();
+ if (session == nullptr)
+ {
+ return CHIP_ERROR_NOT_CONNECTED;
+ }
+ MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *session);
+ ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session, payloadHeader, packetHeader, message, counter));
+
+#if CHIP_PROGRESS_LOGGING
+ destination = session->GetPeerNodeId();
+ fabricIndex = session->GetFabricIndex();
+#endif // CHIP_PROGRESS_LOGGING
+ }
+ break;
+ case Transport::Session::SessionType::kUnauthenticated: {
ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message));
MessageCounter & counter = mGlobalUnencryptedMessageCounter;
@@ -174,13 +173,17 @@
fabricIndex = kUndefinedFabricIndex;
#endif // CHIP_PROGRESS_LOGGING
}
+ break;
+ default:
+ return CHIP_ERROR_INTERNAL;
+ }
ChipLogProgress(Inet,
"Prepared %s message %p to 0x" ChipLogFormatX64 " (%u) of type " ChipLogFormatMessageType
" and protocolId " ChipLogFormatProtocolId " on exchange " ChipLogFormatExchangeId
" with MessageCounter:" ChipLogFormatMessageCounter ".",
- sessionHandle.IsSecure() ? "encrypted" : "plaintext", &preparedMessage, ChipLogValueX64(destination),
- fabricIndex, payloadHeader.GetMessageType(), ChipLogValueProtocolId(payloadHeader.GetProtocolID()),
+ sessionHandle->GetSessionTypeString(), &preparedMessage, ChipLogValueX64(destination), fabricIndex,
+ payloadHeader.GetMessageType(), ChipLogValueProtocolId(payloadHeader.GetProtocolID()),
ChipLogValueExchangeIdFromSentHeader(payloadHeader), packetHeader.GetMessageCounter());
ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message));
@@ -197,58 +200,56 @@
const Transport::PeerAddress * destination;
- if (sessionHandle.IsSecure())
+ switch (sessionHandle->GetSessionType())
{
- if (sessionHandle.IsGroupSession())
- {
- chip::Transport::PeerAddress multicastAddress =
- Transport::PeerAddress::Multicast(sessionHandle.GetFabricIndex(), sessionHandle.GetGroupId().Value());
- destination = static_cast<Transport::PeerAddress *>(&multicastAddress);
- char addressStr[Transport::PeerAddress::kMaxToStringSize];
- multicastAddress.ToString(addressStr, Transport::PeerAddress::kMaxToStringSize);
+ case Transport::Session::SessionType::kGroup: {
+ auto groupSession = sessionHandle->AsGroupSession();
+ Transport::PeerAddress multicastAddress =
+ Transport::PeerAddress::Multicast(groupSession->GetFabricIndex(), groupSession->GetGroupId());
+ destination = &multicastAddress; // XXX: this is dangling pointer, must be fixed
+ char addressStr[Transport::PeerAddress::kMaxToStringSize];
+ multicastAddress.ToString(addressStr, Transport::PeerAddress::kMaxToStringSize);
- ChipLogProgress(Inet,
- "Sending %s msg %p with MessageCounter:" ChipLogFormatMessageCounter " to %d"
- " at monotonic time: %" PRId64
- " msec to Multicast IPV6 address : %s with GroupID of %d and fabric Id of %d",
- "encrypted", &preparedMessage, preparedMessage.GetMessageCounter(), sessionHandle.GetGroupId().Value(),
- System::SystemClock().GetMonotonicMilliseconds64().count(), addressStr,
- sessionHandle.GetGroupId().Value(), sessionHandle.GetFabricIndex());
- }
- else
- {
- // Find an active connection to the specified peer node
- SecureSession * session = GetSecureSession(sessionHandle);
- if (session == nullptr)
- {
- ChipLogError(Inet, "Secure transport could not find a valid PeerConnection");
- return CHIP_ERROR_NOT_CONNECTED;
- }
-
- // This marks any connection where we send data to as 'active'
- session->MarkActive();
-
- destination = &session->GetPeerAddress();
-
- ChipLogProgress(Inet,
- "Sending %s msg %p with MessageCounter:" ChipLogFormatMessageCounter " to 0x" ChipLogFormatX64
- " (%u) at monotonic time: %" PRId64 " msec",
- "encrypted", &preparedMessage, preparedMessage.GetMessageCounter(),
- ChipLogValueX64(session->GetPeerNodeId()), session->GetFabricIndex(),
- System::SystemClock().GetMonotonicMilliseconds64().count());
- }
+ ChipLogProgress(Inet,
+ "Sending %s msg %p with MessageCounter:" ChipLogFormatMessageCounter " to %d"
+ " at monotonic time: %" PRId64
+ " msec to Multicast IPV6 address : %s with GroupID of %d and fabric Id of %d",
+ "encrypted", &preparedMessage, preparedMessage.GetMessageCounter(), groupSession->GetGroupId(),
+ System::SystemClock().GetMonotonicMilliseconds64().count(), addressStr, groupSession->GetGroupId(),
+ groupSession->GetFabricIndex());
}
- else
- {
- auto unauthenticated = sessionHandle.GetUnauthenticatedSession();
+ break;
+ case Transport::Session::SessionType::kSecure: {
+ // Find an active connection to the specified peer node
+ SecureSession * secure = sessionHandle->AsSecureSession();
+
+ // This marks any connection where we send data to as 'active'
+ secure->MarkActive();
+
+ destination = &secure->GetPeerAddress();
+
+ ChipLogProgress(Inet,
+ "Sending %s msg %p with MessageCounter:" ChipLogFormatMessageCounter " to 0x" ChipLogFormatX64
+ " (%u) at monotonic time: %" PRId64 " msec",
+ "encrypted", &preparedMessage, preparedMessage.GetMessageCounter(),
+ ChipLogValueX64(secure->GetPeerNodeId()), secure->GetFabricIndex(),
+ System::SystemClock().GetMonotonicMilliseconds64().count());
+ }
+ break;
+ case Transport::Session::SessionType::kUnauthenticated: {
+ auto unauthenticated = sessionHandle->AsUnauthenticatedSession();
unauthenticated->MarkActive();
destination = &unauthenticated->GetPeerAddress();
ChipLogProgress(Inet,
"Sending %s msg %p with MessageCounter:" ChipLogFormatMessageCounter " to 0x" ChipLogFormatX64
" at monotonic time: %" PRId64 " msec",
- "plaintext", &preparedMessage, preparedMessage.GetMessageCounter(), ChipLogValueX64(kUndefinedNodeId),
- System::SystemClock().GetMonotonicMilliseconds64().count());
+ sessionHandle->GetSessionTypeString(), &preparedMessage, preparedMessage.GetMessageCounter(),
+ ChipLogValueX64(kUndefinedNodeId), System::SystemClock().GetMonotonicMilliseconds64().count());
+ }
+ break;
+ default:
+ return CHIP_ERROR_INTERNAL;
}
PacketBufferHandle msgBuf = preparedMessage.CastToWritable();
@@ -268,12 +269,7 @@
void SessionManager::ExpirePairing(const SessionHandle & sessionHandle)
{
- SecureSession * session = GetSecureSession(sessionHandle);
- if (session != nullptr)
- {
- HandleConnectionExpired(*session);
- mSecureSessions.ReleaseSession(session);
- }
+ mSecureSessions.ReleaseSession(sessionHandle->AsSecureSession());
}
void SessionManager::ExpireAllPairings(NodeId peerNodeId, FabricIndex fabric)
@@ -281,7 +277,6 @@
mSecureSessions.ForEachSession([&](auto session) {
if (session->GetPeerNodeId() == peerNodeId && session->GetFabricIndex() == fabric)
{
- HandleConnectionExpired(*session);
mSecureSessions.ReleaseSession(session);
}
return Loop::Continue;
@@ -294,7 +289,6 @@
mSecureSessions.ForEachSession([&](auto session) {
if (session->GetFabricIndex() == fabric)
{
- HandleConnectionExpired(*session);
mSecureSessions.ReleaseSession(session);
}
return Loop::Continue;
@@ -305,30 +299,30 @@
NodeId peerNodeId, PairingSession * pairing, CryptoContext::SessionRole direction,
FabricIndex fabric)
{
- uint16_t peerSessionId = pairing->GetPeerSessionId();
- uint16_t localSessionId = pairing->GetLocalSessionId();
- SecureSession * session = mSecureSessions.FindSecureSessionByLocalKey(localSessionId);
+ uint16_t peerSessionId = pairing->GetPeerSessionId();
+ uint16_t localSessionId = pairing->GetLocalSessionId();
+ Optional<SessionHandle> session = mSecureSessions.FindSecureSessionByLocalKey(localSessionId);
// Find any existing connection with the same local key ID
- if (session)
+ if (session.HasValue())
{
- HandleConnectionExpired(*session);
- mSecureSessions.ReleaseSession(session);
+ mSecureSessions.ReleaseSession(session.Value()->AsSecureSession());
}
ChipLogDetail(Inet, "New secure session created for device 0x" ChipLogFormatX64 ", key %d!!", ChipLogValueX64(peerNodeId),
peerSessionId);
session = mSecureSessions.CreateNewSecureSession(pairing->GetSecureSessionType(), localSessionId, peerNodeId,
pairing->GetPeerCATs(), peerSessionId, fabric, pairing->GetMRPConfig());
- ReturnErrorCodeIf(session == nullptr, CHIP_ERROR_NO_MEMORY);
+ ReturnErrorCodeIf(!session.HasValue(), CHIP_ERROR_NO_MEMORY);
+ Transport::SecureSession * secureSession = session.Value()->AsSecureSession();
if (peerAddr.HasValue() && peerAddr.Value().GetIPAddress() != Inet::IPAddress::Any)
{
- session->SetPeerAddress(peerAddr.Value());
+ secureSession->SetPeerAddress(peerAddr.Value());
}
else if (peerAddr.HasValue() && peerAddr.Value().GetTransportType() == Transport::Type::kBle)
{
- session->SetPeerAddress(peerAddr.Value());
+ secureSession->SetPeerAddress(peerAddr.Value());
}
else if (peerAddr.HasValue() &&
(peerAddr.Value().GetTransportType() == Transport::Type::kTcp ||
@@ -337,11 +331,10 @@
return CHIP_ERROR_INVALID_ARGUMENT;
}
- ReturnErrorOnFailure(pairing->DeriveSecureSession(session->GetCryptoContext(), direction));
+ ReturnErrorOnFailure(pairing->DeriveSecureSession(secureSession->GetCryptoContext(), direction));
- session->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(pairing->GetPeerCounter());
- sessionHolder.Grab(SessionHandle(*session));
-
+ secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(pairing->GetPeerCounter());
+ sessionHolder.Grab(session.Value());
return CHIP_NO_ERROR;
}
@@ -419,19 +412,19 @@
void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress,
System::PacketBufferHandle && msg)
{
- Optional<Transport::UnauthenticatedSessionHandle> optionalSession =
- mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, gDefaultMRPConfig);
+ Optional<SessionHandle> optionalSession = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, gDefaultMRPConfig);
if (!optionalSession.HasValue())
{
ChipLogError(Inet, "UnauthenticatedSession exhausted");
return;
}
- Transport::UnauthenticatedSessionHandle session = optionalSession.Value();
+ const SessionHandle & session = optionalSession.Value();
SessionMessageDelegate::DuplicateMessage isDuplicate = SessionMessageDelegate::DuplicateMessage::No;
// Verify message counter
- CHIP_ERROR err = session->GetPeerMessageCounter().VerifyOrTrustFirst(packetHeader.GetMessageCounter());
+ CHIP_ERROR err =
+ session->AsUnauthenticatedSession()->GetPeerMessageCounter().VerifyOrTrustFirst(packetHeader.GetMessageCounter());
if (err == CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED)
{
isDuplicate = SessionMessageDelegate::DuplicateMessage::Yes;
@@ -439,7 +432,7 @@
}
VerifyOrDie(err == CHIP_NO_ERROR);
- session->MarkActive();
+ session->AsUnauthenticatedSession()->MarkActive();
PayloadHeader payloadHeader;
ReturnOnFailure(payloadHeader.DecodeAndConsume(msg));
@@ -452,11 +445,11 @@
packetHeader.GetMessageCounter(), ChipLogValueExchangeIdFromReceivedHeader(payloadHeader));
}
- session->GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter());
+ session->AsUnauthenticatedSession()->GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter());
if (mCB != nullptr)
{
- mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(session), peerAddress, isDuplicate, std::move(msg));
+ mCB->OnMessageReceived(packetHeader, payloadHeader, optionalSession.Value(), peerAddress, isDuplicate, std::move(msg));
}
}
@@ -465,7 +458,7 @@
{
CHIP_ERROR err = CHIP_NO_ERROR;
- SecureSession * session = mSecureSessions.FindSecureSessionByLocalKey(packetHeader.GetSessionId());
+ Optional<SessionHandle> session = mSecureSessions.FindSecureSessionByLocalKey(packetHeader.GetSessionId());
PayloadHeader payloadHeader;
@@ -477,20 +470,21 @@
return;
}
- if (session == nullptr)
+ if (!session.HasValue())
{
ChipLogError(Inet, "Data received on an unknown connection (%d). Dropping it!!", packetHeader.GetSessionId());
return;
}
+ Transport::SecureSession * secureSession = session.Value()->AsSecureSession();
// Decrypt and verify the message before message counter verification or any further processing.
- if (SecureMessageCodec::Decrypt(session, payloadHeader, packetHeader, msg) != CHIP_NO_ERROR)
+ if (SecureMessageCodec::Decrypt(secureSession, payloadHeader, packetHeader, msg) != CHIP_NO_ERROR)
{
ChipLogError(Inet, "Secure transport received message, but failed to decode/authenticate it, discarding");
return;
}
- err = session->GetSessionMessageCounter().GetPeerMessageCounter().Verify(packetHeader.GetMessageCounter());
+ err = secureSession->GetSessionMessageCounter().GetPeerMessageCounter().Verify(packetHeader.GetMessageCounter());
if (err == CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED)
{
isDuplicate = SessionMessageDelegate::DuplicateMessage::Yes;
@@ -502,7 +496,7 @@
return;
}
- session->MarkActive();
+ secureSession->MarkActive();
if (isDuplicate == SessionMessageDelegate::DuplicateMessage::Yes && !payloadHeader.NeedsAck())
{
@@ -518,20 +512,19 @@
}
}
- session->GetSessionMessageCounter().GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter());
+ secureSession->GetSessionMessageCounter().GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter());
// TODO: once mDNS address resolution is available reconsider if this is required
// This updates the peer address once a packet is received from a new address
// and serves as a way to auto-detect peer changing IPs.
- if (session->GetPeerAddress() != peerAddress)
+ if (secureSession->GetPeerAddress() != peerAddress)
{
- session->SetPeerAddress(peerAddress);
+ secureSession->SetPeerAddress(peerAddress);
}
if (mCB != nullptr)
{
- SessionHandle sessionHandle(*session);
- mCB->OnMessageReceived(packetHeader, payloadHeader, sessionHandle, peerAddress, isDuplicate, std::move(msg));
+ mCB->OnMessageReceived(packetHeader, payloadHeader, session.Value(), peerAddress, isDuplicate, std::move(msg));
}
}
@@ -540,9 +533,19 @@
{
PayloadHeader payloadHeader;
SessionMessageDelegate::DuplicateMessage isDuplicate = SessionMessageDelegate::DuplicateMessage::No;
- FabricIndex fabricIndex = 0; // TODO : remove initialization once GroupDataProvider->Decrypt is implemented
// Credentials::GroupDataProvider * groups = Credentials::GetGroupDataProvider();
+ if (!packetHeader.GetDestinationGroupId().HasValue())
+ {
+ return; // malformed packet
+ }
+
+ Optional<SessionHandle> session = FindGroupSession(packetHeader.GetDestinationGroupId().Value());
+ if (!session.HasValue())
+ {
+ return;
+ }
+
if (msg.IsNull())
{
ChipLogError(Inet, "Secure transport received Groupcast NULL packet, discarding");
@@ -600,50 +603,22 @@
if (mCB != nullptr)
{
- SessionHandle session(packetHeader.GetSourceNodeId().Value(), packetHeader.GetDestinationGroupId().Value(), fabricIndex);
- mCB->OnMessageReceived(packetHeader, payloadHeader, session, peerAddress, isDuplicate, std::move(msg));
+ mCB->OnMessageReceived(packetHeader, payloadHeader, session.Value(), peerAddress, isDuplicate, std::move(msg));
}
}
-void SessionManager::HandleConnectionExpired(Transport::SecureSession & session)
-{
- ChipLogDetail(Inet, "Marking old secure session for device 0x" ChipLogFormatX64 " as expired",
- ChipLogValueX64(session.GetPeerNodeId()));
-
- SessionHandle sessionHandle(session);
- mSessionReleaseDelegates.ForEachActiveObject([&](std::reference_wrapper<SessionReleaseDelegate> * cb) {
- cb->get().OnSessionReleased(sessionHandle);
- return Loop::Continue;
- });
-
- mTransportMgr->Disconnect(session.GetPeerAddress());
-}
-
void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param)
{
SessionManager * mgr = reinterpret_cast<SessionManager *>(param);
#if CHIP_CONFIG_SESSION_REKEYING
// TODO(#2279): session expiration is currently disabled until rekeying is supported
// the #ifdef should be removed after that.
- mgr->mSecureSessions.ExpireInactiveSessions(
- System::SystemClock().GetMonotonicTimestamp(), System::Clock::Milliseconds32(CHIP_PEER_CONNECTION_TIMEOUT_MS),
- [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); });
+ mgr->mSecureSessions.ExpireInactiveSessions(System::SystemClock().GetMonotonicTimestamp(),
+ System::Clock::Milliseconds32(CHIP_PEER_CONNECTION_TIMEOUT_MS));
#endif
mgr->ScheduleExpiryTimer(); // re-schedule the oneshot timer
}
-SecureSession * SessionManager::GetSecureSession(const SessionHandle & session)
-{
- if (session.mLocalSessionId.HasValue())
- {
- return mSecureSessions.FindSecureSessionByLocalKey(session.mLocalSessionId.Value());
- }
- else
- {
- return nullptr;
- }
-}
-
SessionHandle SessionManager::FindSecureSessionForNode(NodeId peerNodeId)
{
SecureSession * found = nullptr;
diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h
index 70add95..62d40e3 100644
--- a/src/transport/SessionManager.h
+++ b/src/transport/SessionManager.h
@@ -34,6 +34,7 @@
#include <messaging/ReliableMessageProtocolConfig.h>
#include <protocols/secure_channel/Constants.h>
#include <transport/CryptoContext.h>
+#include <transport/GroupSession.h>
#include <transport/MessageCounterManagerInterface.h>
#include <transport/SecureSessionTable.h>
#include <transport/SessionDelegate.h>
@@ -143,37 +144,10 @@
*/
CHIP_ERROR SendPreparedMessage(const SessionHandle & session, const EncryptedPacketBufferHandle & preparedMessage);
- Transport::SecureSession * GetSecureSession(const SessionHandle & session);
-
/// @brief Set the delegate for handling incoming messages. There can be only one message delegate (probably the
/// ExchangeManager)
void SetMessageDelegate(SessionMessageDelegate * cb) { mCB = cb; }
- /// @brief Set the delegate for handling session release.
- void RegisterReleaseDelegate(SessionReleaseDelegate & cb)
- {
-#ifndef NDEBUG
- mSessionReleaseDelegates.ForEachActiveObject([&](std::reference_wrapper<SessionReleaseDelegate> * i) {
- VerifyOrDie(std::addressof(cb) != std::addressof(i->get()));
- return Loop::Continue;
- });
-#endif
- std::reference_wrapper<SessionReleaseDelegate> * slot = mSessionReleaseDelegates.CreateObject(cb);
- VerifyOrDie(slot != nullptr);
- }
-
- void UnregisterReleaseDelegate(SessionReleaseDelegate & cb)
- {
- mSessionReleaseDelegates.ForEachActiveObject([&](std::reference_wrapper<SessionReleaseDelegate> * i) {
- if (std::addressof(cb) == std::addressof(i->get()))
- {
- mSessionReleaseDelegates.ReleaseObject(i);
- return Loop::Break;
- }
- return Loop::Continue;
- });
- }
-
void RegisterRecoveryDelegate(SessionRecoveryDelegate & cb);
void UnregisterRecoveryDelegate(SessionRecoveryDelegate & cb);
void RefreshSessionOperationalData(const SessionHandle & sessionHandle);
@@ -232,16 +206,12 @@
Optional<SessionHandle> CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress,
const ReliableMessageProtocolConfig & config)
{
- Optional<Transport::UnauthenticatedSessionHandle> session =
- mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, config);
- return session.HasValue() ? MakeOptional<SessionHandle>(session.Value()) : NullOptional;
+ return mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, config);
}
- // TODO: placeholder function for creating GroupSession. Implements a GroupSession class in the future
- Optional<SessionHandle> CreateGroupSession(NodeId peerNodeId, GroupId groupId, FabricIndex fabricIndex)
- {
- return MakeOptional(SessionHandle(peerNodeId, groupId, fabricIndex));
- }
+ // TODO: implements group sessions
+ Optional<SessionHandle> CreateGroupSession(GroupId group) { return mGroupSessions.AllocEntry(group, kUndefinedFabricIndex); }
+ Optional<SessionHandle> FindGroupSession(GroupId group) { return mGroupSessions.FindEntry(group, kUndefinedFabricIndex); }
// TODO: this is a temporary solution for legacy tests which use nodeId to send packets
// and tv-casting-app that uses the TV's node ID to find the associated secure session
@@ -265,17 +235,12 @@
System::Layer * mSystemLayer = nullptr;
Transport::UnauthenticatedSessionTable<CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE> mUnauthenticatedSessions;
- Transport::SecureSessionTable<CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE> mSecureSessions; // < Active connections to other peers
- State mState; // < Initialization state of the object
+ Transport::SecureSessionTable<CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE> mSecureSessions;
+ Transport::GroupSessionTable<CHIP_CONFIG_GROUP_CONNECTION_POOL_SIZE> mGroupSessions;
+ State mState; // < Initialization state of the object
SessionMessageDelegate * mCB = nullptr;
- // TODO: This is a temporary solution to release sessions, in the near future, SessionReleaseDelegate will be
- // directly associated with the every SessionHolder. Then the callback function is called on over the handle
- // delegate directly, in order to prevent dangling handles.
- BitMapObjectPool<std::reference_wrapper<SessionReleaseDelegate>, CHIP_CONFIG_MAX_SESSION_RELEASE_DELEGATES>
- mSessionReleaseDelegates;
-
BitMapObjectPool<std::reference_wrapper<SessionRecoveryDelegate>, CHIP_CONFIG_MAX_SESSION_RECOVERY_DELEGATES>
mSessionRecoveryDelegates;
@@ -285,6 +250,8 @@
GlobalUnencryptedMessageCounter mGlobalUnencryptedMessageCounter;
GlobalEncryptedMessageCounter mGlobalEncryptedMessageCounter;
+ friend class SessionHandle;
+
/** Schedules a new oneshot timer for checking connection expiry. */
void ScheduleExpiryTimer();
@@ -292,11 +259,6 @@
void CancelExpiryTimer();
/**
- * Called when a specific connection expires.
- */
- void HandleConnectionExpired(Transport::SecureSession & state);
-
- /**
* Callback for timer expiry check
*/
static void ExpiryTimerCallback(System::Layer * layer, void * param);
diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h
index c4ca79f..0b645d0 100644
--- a/src/transport/UnauthenticatedSessionTable.h
+++ b/src/transport/UnauthenticatedSessionTable.h
@@ -26,14 +26,12 @@
#include <system/TimeSource.h>
#include <transport/MessageCounter.h>
#include <transport/PeerMessageCounter.h>
+#include <transport/Session.h>
#include <transport/raw/PeerAddress.h>
namespace chip {
namespace Transport {
-class UnauthenticatedSession;
-using UnauthenticatedSessionHandle = ReferenceCountedHandle<UnauthenticatedSession>;
-
class UnauthenticatedSessionDeleter
{
public:
@@ -45,12 +43,13 @@
* @brief
* An UnauthenticatedSession stores the binding of TransportAddress, and message counters.
*/
-class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter, 0>
+class UnauthenticatedSession : public Session, public ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter, 0>
{
public:
UnauthenticatedSession(const PeerAddress & address, const ReliableMessageProtocolConfig & config) :
mPeerAddress(address), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
{}
+ ~UnauthenticatedSession() { NotifySessionReleased(); }
UnauthenticatedSession(const UnauthenticatedSession &) = delete;
UnauthenticatedSession & operator=(const UnauthenticatedSession &) = delete;
@@ -60,11 +59,41 @@
System::Clock::Timestamp GetLastActivityTime() const { return mLastActivityTime; }
void MarkActive() { mLastActivityTime = System::SystemClock().GetMonotonicTimestamp(); }
+ Session::SessionType GetSessionType() const override { return Session::SessionType::kUnauthenticated; }
+#if CHIP_PROGRESS_LOGGING
+ const char * GetSessionTypeString() const override { return "unauthenticated"; };
+#endif
+
+ void Retain() override { ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter, 0>::Retain(); }
+ void Release() override { ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter, 0>::Release(); }
+
+ Access::SubjectDescriptor GetSubjectDescriptor() const override
+ {
+ return Access::SubjectDescriptor(); // return an empty ISD for unauthenticated session.
+ }
+
+ bool RequireMRP() const override { return GetPeerAddress().GetTransportType() == Transport::Type::kUdp; }
+
+ System::Clock::Milliseconds32 GetAckTimeout() const override
+ {
+ switch (mPeerAddress.GetTransportType())
+ {
+ case Transport::Type::kUdp:
+ return GetMRPConfig().mIdleRetransTimeout * (CHIP_CONFIG_RMP_DEFAULT_MAX_RETRANS + 1);
+ case Transport::Type::kTcp:
+ return System::Clock::Seconds16(30);
+ default:
+ break;
+ }
+ return System::Clock::Timeout();
+ }
+
+ NodeId GetPeerNodeId() const { return kUndefinedNodeId; }
const PeerAddress & GetPeerAddress() const { return mPeerAddress; }
void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; }
- const ReliableMessageProtocolConfig & GetMRPConfig() const { return mMRPConfig; }
+ const ReliableMessageProtocolConfig & GetMRPConfig() const override { return mMRPConfig; }
PeerMessageCounter & GetPeerMessageCounter() { return mPeerMessageCounter; }
@@ -79,9 +108,8 @@
* @brief
* An table which manages UnauthenticatedSessions
*
- * The UnauthenticatedSession entries are rotated using LRU, but entry can be
- * hold by using UnauthenticatedSessionHandle, which increase the reference
- * count by 1. If the reference count is not 0, the entry won't be pruned.
+ * The UnauthenticatedSession entries are rotated using LRU, but entry can be hold by using SessionHandle or
+ * SessionHolder, which increase the reference count by 1. If the reference count is not 0, the entry won't be pruned.
*/
template <size_t kMaxSessionCount>
class UnauthenticatedSessionTable
@@ -95,21 +123,20 @@
* @return the session found or allocated, nullptr if not found and allocation failed.
*/
CHECK_RETURN_VALUE
- Optional<UnauthenticatedSessionHandle> FindOrAllocateEntry(const PeerAddress & address,
- const ReliableMessageProtocolConfig & config)
+ Optional<SessionHandle> FindOrAllocateEntry(const PeerAddress & address, const ReliableMessageProtocolConfig & config)
{
UnauthenticatedSession * result = FindEntry(address);
if (result != nullptr)
- return MakeOptional<UnauthenticatedSessionHandle>(*result);
+ return MakeOptional<SessionHandle>(*result);
CHIP_ERROR err = AllocEntry(address, config, result);
if (err == CHIP_NO_ERROR)
{
- return MakeOptional<UnauthenticatedSessionHandle>(*result);
+ return MakeOptional<SessionHandle>(*result);
}
else
{
- return Optional<UnauthenticatedSessionHandle>::Missing();
+ return Optional<SessionHandle>::Missing();
}
}
diff --git a/src/transport/tests/TestPeerConnections.cpp b/src/transport/tests/TestPeerConnections.cpp
index 259defe..e716f86 100644
--- a/src/transport/tests/TestPeerConnections.cpp
+++ b/src/transport/tests/TestPeerConnections.cpp
@@ -62,7 +62,6 @@
void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext)
{
- SecureSession * statePtr;
SecureSessionTable<2> connections;
System::Clock::Internal::MockClock clock;
System::Clock::ClockBase * realClock = &System::SystemClock();
@@ -71,54 +70,53 @@
CATValues peerCATs;
// Node ID 1, peer key 1, local key 2
- statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */,
- gDefaultMRPConfig);
- NL_TEST_ASSERT(inSuite, statePtr != nullptr);
- NL_TEST_ASSERT(inSuite, statePtr->GetSecureSessionType() == kPeer1SessionType);
- NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer1NodeId);
- peerCATs = statePtr->GetPeerCATs();
+ auto optionalSession = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1,
+ 0 /* fabricIndex */, gDefaultMRPConfig);
+ NL_TEST_ASSERT(inSuite, optionalSession.HasValue());
+ NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetSecureSessionType() == kPeer1SessionType);
+ NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetPeerNodeId() == kPeer1NodeId);
+ peerCATs = optionalSession.Value()->AsSecureSession()->GetPeerCATs();
NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kPeer1CATs, sizeof(CATValues)) == 0);
// Node ID 2, peer key 3, local key 4
- statePtr = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */,
- gDefaultMRPConfig);
- NL_TEST_ASSERT(inSuite, statePtr != nullptr);
- NL_TEST_ASSERT(inSuite, statePtr->GetSecureSessionType() == kPeer2SessionType);
- NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer2NodeId);
- NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTime() == 100_ms64);
- peerCATs = statePtr->GetPeerCATs();
+ optionalSession = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */,
+ gDefaultMRPConfig);
+ NL_TEST_ASSERT(inSuite, optionalSession.HasValue());
+ NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetSecureSessionType() == kPeer2SessionType);
+ NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetPeerNodeId() == kPeer2NodeId);
+ NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetLastActivityTime() == 100_ms64);
+ peerCATs = optionalSession.Value()->AsSecureSession()->GetPeerCATs();
NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kPeer2CATs, sizeof(CATValues)) == 0);
// Insufficient space for new connections. Object is max size 2
- statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
- gDefaultMRPConfig);
- NL_TEST_ASSERT(inSuite, statePtr == nullptr);
+ optionalSession = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
+ gDefaultMRPConfig);
+ NL_TEST_ASSERT(inSuite, !optionalSession.HasValue());
System::Clock::Internal::SetSystemClockForTesting(realClock);
}
void TestFindByKeyId(nlTestSuite * inSuite, void * inContext)
{
- SecureSession * statePtr;
SecureSessionTable<2> connections;
System::Clock::Internal::MockClock clock;
System::Clock::ClockBase * realClock = &System::SystemClock();
System::Clock::Internal::SetSystemClockForTesting(&clock);
// Node ID 1, peer key 1, local key 2
- statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */,
- gDefaultMRPConfig);
- NL_TEST_ASSERT(inSuite, statePtr != nullptr);
+ auto optionalSession = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1,
+ 0 /* fabricIndex */, gDefaultMRPConfig);
+ NL_TEST_ASSERT(inSuite, optionalSession.HasValue());
- NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(1));
- NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2));
+ NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(1).HasValue());
+ NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2).HasValue());
// Node ID 2, peer key 3, local key 4
- statePtr = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */,
- gDefaultMRPConfig);
- NL_TEST_ASSERT(inSuite, statePtr != nullptr);
+ optionalSession = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */,
+ gDefaultMRPConfig);
+ NL_TEST_ASSERT(inSuite, optionalSession.HasValue());
- NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(3));
- NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4));
+ NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(3).HasValue());
+ NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4).HasValue());
System::Clock::Internal::SetSystemClockForTesting(realClock);
}
@@ -133,7 +131,6 @@
void TestExpireConnections(nlTestSuite * inSuite, void * inContext)
{
ExpiredCallInfo callInfo;
- SecureSession * statePtr;
SecureSessionTable<2> connections;
System::Clock::Internal::MockClock clock;
@@ -143,23 +140,23 @@
clock.SetMonotonic(100_ms64);
// Node ID 1, peer key 1, local key 2
- statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */,
- gDefaultMRPConfig);
- NL_TEST_ASSERT(inSuite, statePtr != nullptr);
- statePtr->SetPeerAddress(kPeer1Addr);
+ auto optionalSession = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1,
+ 0 /* fabricIndex */, gDefaultMRPConfig);
+ NL_TEST_ASSERT(inSuite, optionalSession.HasValue());
+ optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer1Addr);
clock.SetMonotonic(200_ms64);
// Node ID 2, peer key 3, local key 4
- statePtr = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */,
- gDefaultMRPConfig);
- NL_TEST_ASSERT(inSuite, statePtr != nullptr);
- statePtr->SetPeerAddress(kPeer2Addr);
+ optionalSession = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */,
+ gDefaultMRPConfig);
+ NL_TEST_ASSERT(inSuite, optionalSession.HasValue());
+ optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer2Addr);
// cannot add before expiry
clock.SetMonotonic(300_ms64);
- statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
- gDefaultMRPConfig);
- NL_TEST_ASSERT(inSuite, statePtr == nullptr);
+ optionalSession = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
+ gDefaultMRPConfig);
+ NL_TEST_ASSERT(inSuite, !optionalSession.HasValue());
// at time 300, this expires ip addr 1
connections.ExpireInactiveSessions(150_ms64, [&callInfo](const SecureSession & state) {
@@ -170,21 +167,22 @@
NL_TEST_ASSERT(inSuite, callInfo.callCount == 1);
NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer1NodeId);
NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer1Addr);
- NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2));
+ NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2).HasValue());
// now that the connections were expired, we can add peer3
clock.SetMonotonic(300_ms64);
// Node ID 3, peer key 5, local key 6
- statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
- gDefaultMRPConfig);
- NL_TEST_ASSERT(inSuite, statePtr != nullptr);
- statePtr->SetPeerAddress(kPeer3Addr);
+ optionalSession = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */,
+ gDefaultMRPConfig);
+ NL_TEST_ASSERT(inSuite, optionalSession.HasValue());
+ optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer3Addr);
clock.SetMonotonic(400_ms64);
- NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSessionByLocalKey(4));
+ optionalSession = connections.FindSecureSessionByLocalKey(4);
+ NL_TEST_ASSERT(inSuite, optionalSession.HasValue());
- statePtr->MarkActive();
- NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTime() == clock.GetMonotonicTimestamp());
+ optionalSession.Value()->AsSecureSession()->MarkActive();
+ NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetLastActivityTime() == clock.GetMonotonicTimestamp());
// At this time:
// Peer 3 active at time 300
@@ -202,17 +200,17 @@
NL_TEST_ASSERT(inSuite, callInfo.callCount == 1);
NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer3NodeId);
NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer3Addr);
- NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2));
- NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4));
- NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6));
+ NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2).HasValue());
+ NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4).HasValue());
+ NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6).HasValue());
// Node ID 1, peer key 1, local key 2
- statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */,
- gDefaultMRPConfig);
- NL_TEST_ASSERT(inSuite, statePtr != nullptr);
- NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2));
- NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4));
- NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6));
+ optionalSession = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */,
+ gDefaultMRPConfig);
+ NL_TEST_ASSERT(inSuite, optionalSession.HasValue());
+ NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2).HasValue());
+ NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4).HasValue());
+ NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6).HasValue());
// peer 1 and 2 are active
clock.SetMonotonic(1000_ms64);
@@ -223,9 +221,9 @@
callInfo.lastCallPeerAddress = state.GetPeerAddress();
});
NL_TEST_ASSERT(inSuite, callInfo.callCount == 2); // everything expired
- NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2));
- NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(4));
- NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6));
+ NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2).HasValue());
+ NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(4).HasValue());
+ NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6).HasValue());
System::Clock::Internal::SetSystemClockForTesting(realClock);
}
diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp
index b39cd4a..acdc76e 100644
--- a/src/transport/tests/TestSessionManager.cpp
+++ b/src/transport/tests/TestSessionManager.cpp
@@ -59,15 +59,20 @@
const char LARGE_PAYLOAD[kMaxAppMessageLen + 1] = "test message";
-class TestSessMgrCallback : public SessionReleaseDelegate, public SessionMessageDelegate
+class TestSessionReleaseCallback : public SessionReleaseDelegate
+{
+public:
+ void OnSessionReleased() override { mOldConnectionDropped = true; }
+ bool mOldConnectionDropped = false;
+};
+
+class TestSessMgrCallback : public SessionMessageDelegate
{
public:
void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, const SessionHandle & session,
const Transport::PeerAddress & source, DuplicateMessage isDuplicate,
System::PacketBufferHandle && msgBuf) override
{
- NL_TEST_ASSERT(mSuite, mRemoteToLocalSession.Contains(session)); // Packet received by remote peer
-
size_t data_len = msgBuf->DataLength();
if (LargeMessageSent)
@@ -84,17 +89,9 @@
ReceiveHandlerCallCount++;
}
- void OnSessionReleased(const SessionHandle & session) override { mOldConnectionDropped = true; }
-
- bool mOldConnectionDropped = false;
-
- nlTestSuite * mSuite = nullptr;
- SessionHolder mRemoteToLocalSession;
- SessionHolder mLocalToRemoteSession;
- int ReceiveHandlerCallCount = 0;
- int NewConnectionHandlerCallCount = 0;
-
- bool LargeMessageSent = false;
+ nlTestSuite * mSuite = nullptr;
+ int ReceiveHandlerCallCount = 0;
+ bool LargeMessageSent = false;
};
void CheckSimpleInitTest(nlTestSuite * inSuite, void * inContext)
@@ -145,19 +142,19 @@
sessionManager.SetMessageDelegate(&callback);
Optional<Transport::PeerAddress> peer(Transport::PeerAddress::UDP(addr, CHIP_PORT));
+ SessionHolder localToRemoteSession;
+ SessionHolder remoteToLocalSession;
SecurePairingUsingTestSecret pairing1(1, 2);
- err = sessionManager.NewPairing(callback.mRemoteToLocalSession, peer, kSourceNodeId, &pairing1,
- CryptoContext::SessionRole::kInitiator, 1);
+ err =
+ sessionManager.NewPairing(localToRemoteSession, peer, kSourceNodeId, &pairing1, CryptoContext::SessionRole::kInitiator, 1);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
SecurePairingUsingTestSecret pairing2(2, 1);
- err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kDestinationNodeId, &pairing2,
+ err = sessionManager.NewPairing(remoteToLocalSession, peer, kDestinationNodeId, &pairing2,
CryptoContext::SessionRole::kResponder, 0);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Get();
-
// Should be able to send a message to itself by just calling send.
callback.ReceiveHandlerCallCount = 0;
@@ -170,10 +167,10 @@
payloadHeader.SetMessageType(chip::Protocols::Echo::MsgType::EchoRequest);
EncryptedPacketBufferHandle preparedMessage;
- err = sessionManager.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage);
+ err = sessionManager.PrepareMessage(localToRemoteSession.Get(), payloadHeader, std::move(buffer), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage);
+ err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1);
@@ -184,10 +181,10 @@
callback.LargeMessageSent = true;
- err = sessionManager.PrepareMessage(localToRemoteSession, payloadHeader, std::move(large_buffer), preparedMessage);
+ err = sessionManager.PrepareMessage(localToRemoteSession.Get(), payloadHeader, std::move(large_buffer), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage);
+ err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2);
@@ -200,7 +197,7 @@
callback.LargeMessageSent = true;
- err = sessionManager.PrepareMessage(localToRemoteSession, payloadHeader, std::move(extra_large_buffer), preparedMessage);
+ err = sessionManager.PrepareMessage(localToRemoteSession.Get(), payloadHeader, std::move(extra_large_buffer), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_MESSAGE_TOO_LONG);
sessionManager.Shutdown();
@@ -237,19 +234,19 @@
sessionManager.SetMessageDelegate(&callback);
Optional<Transport::PeerAddress> peer(Transport::PeerAddress::UDP(addr, CHIP_PORT));
+ SessionHolder localToRemoteSession;
+ SessionHolder remoteToLocalSession;
SecurePairingUsingTestSecret pairing1(1, 2);
- err = sessionManager.NewPairing(callback.mRemoteToLocalSession, peer, kSourceNodeId, &pairing1,
- CryptoContext::SessionRole::kInitiator, 1);
+ err =
+ sessionManager.NewPairing(localToRemoteSession, peer, kSourceNodeId, &pairing1, CryptoContext::SessionRole::kInitiator, 1);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
SecurePairingUsingTestSecret pairing2(2, 1);
- err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kDestinationNodeId, &pairing2,
+ err = sessionManager.NewPairing(remoteToLocalSession, peer, kDestinationNodeId, &pairing2,
CryptoContext::SessionRole::kResponder, 0);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Get();
-
// Should be able to send a message to itself by just calling send.
callback.ReceiveHandlerCallCount = 0;
@@ -264,19 +261,19 @@
payloadHeader.SetInitiator(true);
- err = sessionManager.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage);
+ err = sessionManager.PrepareMessage(localToRemoteSession.Get(), payloadHeader, std::move(buffer), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage);
+ err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
// Reset receive side message counter, or duplicated message will be denied.
- Transport::SecureSession * state = sessionManager.GetSecureSession(callback.mRemoteToLocalSession.Get());
- state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1);
+ Transport::SecureSession * session = remoteToLocalSession.Get()->AsSecureSession();
+ session->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1);
NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1);
- err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage);
+ err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2);
@@ -315,19 +312,19 @@
sessionManager.SetMessageDelegate(&callback);
Optional<Transport::PeerAddress> peer(Transport::PeerAddress::UDP(addr, CHIP_PORT));
+ SessionHolder localToRemoteSession;
+ SessionHolder remoteToLocalSession;
SecurePairingUsingTestSecret pairing1(1, 2);
- err = sessionManager.NewPairing(callback.mRemoteToLocalSession, peer, kSourceNodeId, &pairing1,
- CryptoContext::SessionRole::kInitiator, 1);
+ err =
+ sessionManager.NewPairing(localToRemoteSession, peer, kSourceNodeId, &pairing1, CryptoContext::SessionRole::kInitiator, 1);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
SecurePairingUsingTestSecret pairing2(2, 1);
- err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kDestinationNodeId, &pairing2,
+ err = sessionManager.NewPairing(remoteToLocalSession, peer, kDestinationNodeId, &pairing2,
CryptoContext::SessionRole::kResponder, 0);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Get();
-
// Should be able to send a message to itself by just calling send.
callback.ReceiveHandlerCallCount = 0;
@@ -342,23 +339,21 @@
payloadHeader.SetInitiator(true);
- err = sessionManager.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage);
+ err = sessionManager.PrepareMessage(localToRemoteSession.Get(), payloadHeader, std::move(buffer), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage);
+ err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1);
/* -------------------------------------------------------------------------------------------*/
// Reset receive side message counter, or duplicated message will be denied.
- Transport::SecureSession * state = sessionManager.GetSecureSession(callback.mRemoteToLocalSession.Get());
- state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1);
+ Transport::SecureSession * session = remoteToLocalSession.Get()->AsSecureSession();
+ session->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1);
PacketHeader packetHeader;
- state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1);
-
// Change Message ID
EncryptedPacketBufferHandle badMessageCounterMsg = preparedMessage.CloneData();
NL_TEST_ASSERT(inSuite, badMessageCounterMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR);
@@ -367,13 +362,13 @@
packetHeader.SetMessageCounter(messageCounter + 1);
NL_TEST_ASSERT(inSuite, badMessageCounterMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR);
- err = sessionManager.SendPreparedMessage(localToRemoteSession, badMessageCounterMsg);
+ err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), badMessageCounterMsg);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1);
/* -------------------------------------------------------------------------------------------*/
- state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1);
+ session->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1);
// Change Key ID
EncryptedPacketBufferHandle badKeyIdMsg = preparedMessage.CloneData();
@@ -383,16 +378,16 @@
packetHeader.SetSessionId(3);
NL_TEST_ASSERT(inSuite, badKeyIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR);
- err = sessionManager.SendPreparedMessage(localToRemoteSession, badKeyIdMsg);
+ err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), badKeyIdMsg);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
/* -------------------------------------------------------------------------------------------*/
- state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1);
+ session->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1);
NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1);
// Send the correct encrypted msg
- err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage);
+ err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2);
@@ -418,51 +413,46 @@
err = sessionManager.Init(&ctx.GetSystemLayer(), &transportMgr, &gMessageCounterManager);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- TestSessMgrCallback callback;
- callback.mSuite = inSuite;
-
- sessionManager.RegisterReleaseDelegate(callback);
- sessionManager.SetMessageDelegate(&callback);
-
Optional<Transport::PeerAddress> peer(Transport::PeerAddress::UDP(addr, CHIP_PORT));
+ TestSessionReleaseCallback callback;
+ SessionHolderWithDelegate session1(callback);
+ SessionHolderWithDelegate session2(callback);
+ SessionHolderWithDelegate session3(callback);
+ SessionHolderWithDelegate session4(callback);
+ SessionHolderWithDelegate session5(callback);
// First pairing
SecurePairingUsingTestSecret pairing1(1, 1);
callback.mOldConnectionDropped = false;
- err = sessionManager.NewPairing(callback.mRemoteToLocalSession, peer, kSourceNodeId, &pairing1,
- CryptoContext::SessionRole::kInitiator, 1);
+ err = sessionManager.NewPairing(session1, peer, kSourceNodeId, &pairing1, CryptoContext::SessionRole::kInitiator, 1);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped);
// New pairing with different peer node ID and different local key ID (same peer key ID)
SecurePairingUsingTestSecret pairing2(1, 2);
callback.mOldConnectionDropped = false;
- err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kSourceNodeId, &pairing2,
- CryptoContext::SessionRole::kResponder, 0);
+ err = sessionManager.NewPairing(session2, peer, kSourceNodeId, &pairing2, CryptoContext::SessionRole::kResponder, 0);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped);
// New pairing with undefined node ID and different local key ID (same peer key ID)
SecurePairingUsingTestSecret pairing3(1, 3);
callback.mOldConnectionDropped = false;
- err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kUndefinedNodeId, &pairing3,
- CryptoContext::SessionRole::kResponder, 0);
+ err = sessionManager.NewPairing(session3, peer, kUndefinedNodeId, &pairing3, CryptoContext::SessionRole::kResponder, 0);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped);
// New pairing with same local key ID, and a given node ID
SecurePairingUsingTestSecret pairing4(1, 2);
callback.mOldConnectionDropped = false;
- err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kSourceNodeId, &pairing4,
- CryptoContext::SessionRole::kResponder, 0);
+ err = sessionManager.NewPairing(session4, peer, kSourceNodeId, &pairing4, CryptoContext::SessionRole::kResponder, 0);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, callback.mOldConnectionDropped);
// New pairing with same local key ID, and undefined node ID
SecurePairingUsingTestSecret pairing5(1, 1);
callback.mOldConnectionDropped = false;
- err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kUndefinedNodeId, &pairing5,
- CryptoContext::SessionRole::kResponder, 0);
+ err = sessionManager.NewPairing(session5, peer, kUndefinedNodeId, &pairing5, CryptoContext::SessionRole::kResponder, 0);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, callback.mOldConnectionDropped);
diff --git a/third_party/nxp/k32w0_sdk/sdk_fixes/patch_ble_utils_h.patch b/third_party/nxp/k32w0_sdk/sdk_fixes/patch_ble_utils_h.patch
new file mode 100644
index 0000000..bb2051b
--- /dev/null
+++ b/third_party/nxp/k32w0_sdk/sdk_fixes/patch_ble_utils_h.patch
@@ -0,0 +1,14 @@
+--- a/ble_utils.h 2022-01-07 16:41:34.017433835 +0000
++++ b/ble_utils.h 2022-01-07 16:42:09.797788620 +0000
+@@ -73,11 +73,6 @@
+ #define ALIGN_64BIT #pragma pack(8)
+ #endif
+
+-/*! Marks that this variable is in the interface. */
+-#ifndef global
+-#define global
+-#endif
+-
+ /*! Marks a function that never returns. */
+ #if !defined(__IAR_SYSTEMS_ICC__)
+ #ifndef __noreturn
diff --git a/third_party/nxp/k32w0_sdk/sdk_fixes/patch_k32w_sdk.sh b/third_party/nxp/k32w0_sdk/sdk_fixes/patch_k32w_sdk.sh
index 378feb7..d3f262f 100755
--- a/third_party/nxp/k32w0_sdk/sdk_fixes/patch_k32w_sdk.sh
+++ b/third_party/nxp/k32w0_sdk/sdk_fixes/patch_k32w_sdk.sh
@@ -5,6 +5,9 @@
exit 1
fi
+SOURCE=${BASH_SOURCE[0]}
+SOURCE_DIR=$(cd "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
+
cp ./third_party/nxp/k32w0_sdk/sdk_fixes/gpio_pins.h "$NXP_K32W061_SDK_ROOT"/boards/k32w061dk6/wireless_examples/openthread/reed/bm/
cp ./third_party/nxp/k32w0_sdk/sdk_fixes/pin_mux.c "$NXP_K32W061_SDK_ROOT"/boards/k32w061dk6/wireless_examples/openthread/enablement/
@@ -17,5 +20,7 @@
cp -r ./third_party/nxp/k32w0_sdk/sdk_fixes/OtaUtils.c "$NXP_K32W061_SDK_ROOT"/middleware/wireless/framework/OtaSupport/Source/
+patch -d "$NXP_K32W061_SDK_ROOT"/middleware/wireless/bluetooth/host/interface -p1 <"$SOURCE_DIR/patch_ble_utils_h.patch"
+
echo "K32W SDK MR3 QP1 was patched!"
exit 0