Remove default constructor of SessionHandle, reduce dangling session handle (#9225)
diff --git a/src/app/CommandSender.cpp b/src/app/CommandSender.cpp
index 381a1fc..cfae386 100644
--- a/src/app/CommandSender.cpp
+++ b/src/app/CommandSender.cpp
@@ -34,7 +34,7 @@
namespace chip {
namespace app {
-CHIP_ERROR CommandSender::SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * secureSession,
+CHIP_ERROR CommandSender::SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional<SessionHandle> secureSession,
uint32_t timeout)
{
CHIP_ERROR err = CHIP_NO_ERROR;
@@ -50,14 +50,7 @@
AbortExistingExchangeContext();
// Create a new exchange context.
- if (secureSession == nullptr)
- {
- mpExchangeCtx = mpExchangeMgr->NewContext(SessionHandle(aNodeId, 0, 0, aFabricIndex), this);
- }
- else
- {
- mpExchangeCtx = mpExchangeMgr->NewContext(*secureSession, this);
- }
+ mpExchangeCtx = mpExchangeMgr->NewContext(secureSession.ValueOr(SessionHandle(aNodeId, 0, 0, aFabricIndex)), this);
VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY);
mpExchangeCtx->SetResponseTimeout(timeout);
diff --git a/src/app/CommandSender.h b/src/app/CommandSender.h
index 316ffe7..6d7ce63 100644
--- a/src/app/CommandSender.h
+++ b/src/app/CommandSender.h
@@ -55,7 +55,7 @@
//
// If SendCommandRequest is never called, or the call fails, the API
// consumer is responsible for calling Shutdown on the CommandSender.
- CHIP_ERROR SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * secureSession,
+ CHIP_ERROR SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional<SessionHandle> secureSession,
uint32_t timeout = kImMessageTimeoutMsec);
private:
diff --git a/src/app/ReadHandler.cpp b/src/app/ReadHandler.cpp
index 102220b..3bc7bc7 100644
--- a/src/app/ReadHandler.cpp
+++ b/src/app/ReadHandler.cpp
@@ -132,7 +132,6 @@
if (IsInitialReport())
{
VerifyOrReturnLogError(mpExchangeCtx != nullptr, CHIP_ERROR_INCORRECT_STATE);
- mSecureHandle = mpExchangeCtx->GetSecureSession();
}
VerifyOrReturnLogError(mpExchangeCtx != nullptr, CHIP_ERROR_INCORRECT_STATE);
MoveToState(HandlerState::Reporting);
diff --git a/src/app/ReadHandler.h b/src/app/ReadHandler.h
index ca6d40a..9080bb2 100644
--- a/src/app/ReadHandler.h
+++ b/src/app/ReadHandler.h
@@ -166,7 +166,6 @@
Messaging::ExchangeManager * mpExchangeMgr = nullptr;
InteractionModelDelegate * mpDelegate = nullptr;
bool mInitialReport = false;
- SessionHandle mSecureHandle;
};
} // namespace app
} // namespace chip
diff --git a/src/app/ReadPrepareParams.h b/src/app/ReadPrepareParams.h
index 04f2e73..2403103 100644
--- a/src/app/ReadPrepareParams.h
+++ b/src/app/ReadPrepareParams.h
@@ -39,10 +39,9 @@
uint16_t mMinIntervalSeconds = 0;
uint16_t mMaxIntervalSeconds = 0;
- ReadPrepareParams() {}
- ReadPrepareParams(ReadPrepareParams && other)
+ ReadPrepareParams(SessionHandle sessionHandle) : mSessionHandle(sessionHandle) {}
+ ReadPrepareParams(ReadPrepareParams && other) : mSessionHandle(other.mSessionHandle)
{
- mSessionHandle = other.mSessionHandle;
mpEventPathParamsList = other.mpEventPathParamsList;
mEventPathParamsListSize = other.mEventPathParamsListSize;
mpAttributePathParamsList = other.mpAttributePathParamsList;
diff --git a/src/app/WriteClient.cpp b/src/app/WriteClient.cpp
index bca0126..6c34114 100644
--- a/src/app/WriteClient.cpp
+++ b/src/app/WriteClient.cpp
@@ -247,7 +247,7 @@
MoveToState(State::Uninitialized);
}
-CHIP_ERROR WriteClient::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession,
+CHIP_ERROR WriteClient::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional<SessionHandle> apSecureSession,
uint32_t timeout)
{
CHIP_ERROR err = CHIP_NO_ERROR;
@@ -263,14 +263,7 @@
ClearExistingExchangeContext();
// Create a new exchange context.
- if (apSecureSession == nullptr)
- {
- mpExchangeCtx = mpExchangeMgr->NewContext(SessionHandle(aNodeId, 0, 0, aFabricIndex), this);
- }
- else
- {
- mpExchangeCtx = mpExchangeMgr->NewContext(*apSecureSession, this);
- }
+ mpExchangeCtx = mpExchangeMgr->NewContext(apSecureSession.ValueOr(SessionHandle(aNodeId, 0, 0, aFabricIndex)), this);
VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY);
mpExchangeCtx->SetResponseTimeout(timeout);
@@ -397,7 +390,7 @@
return err;
}
-CHIP_ERROR WriteClientHandle::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession,
+CHIP_ERROR WriteClientHandle::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional<SessionHandle> apSecureSession,
uint32_t timeout)
{
CHIP_ERROR err = mpWriteClient->SendWriteRequest(aNodeId, aFabricIndex, apSecureSession, timeout);
diff --git a/src/app/WriteClient.h b/src/app/WriteClient.h
index e757d2a..e731cb1 100644
--- a/src/app/WriteClient.h
+++ b/src/app/WriteClient.h
@@ -94,7 +94,8 @@
* If SendWriteRequest is never called, or the call fails, the API
* consumer is responsible for calling Shutdown on the WriteClient.
*/
- CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, uint32_t timeout);
+ CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional<SessionHandle> apSecureSession,
+ uint32_t timeout);
/**
* Initialize the client object. Within the lifetime
@@ -175,7 +176,7 @@
* Finalize the message and send it to the desired node. The underlying write object will always be released, and the user
* should not use this object after calling this function.
*/
- CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession,
+ CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional<SessionHandle> apSecureSession,
uint32_t timeout = kImMessageTimeoutMsec);
/**
diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp
index 219fc88..364a626 100644
--- a/src/app/tests/TestCommandInteraction.cpp
+++ b/src/app/tests/TestCommandInteraction.cpp
@@ -231,7 +231,7 @@
err = commandSender.Init(&gExchangeManager, nullptr);
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
- err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, nullptr);
+ err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, Optional<SessionHandle>::Missing());
NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_INCORRECT_STATE);
}
@@ -270,7 +270,7 @@
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
AddCommandDataElement(apSuite, apContext, &commandSender, false);
- err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, nullptr);
+ err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, Optional<SessionHandle>::Missing());
NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_NOT_CONNECTED);
GenerateReceivedCommand(apSuite, apContext, buf, true /*aNeedCommandData*/);
diff --git a/src/app/tests/TestReadInteraction.cpp b/src/app/tests/TestReadInteraction.cpp
index d57a502..8a1fe00 100644
--- a/src/app/tests/TestReadInteraction.cpp
+++ b/src/app/tests/TestReadInteraction.cpp
@@ -303,9 +303,8 @@
System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize);
err = readClient.Init(&ctx.GetExchangeManager(), &delegate, 0 /* application identifier */);
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
- ReadPrepareParams readPrepareParams;
- readPrepareParams.mSessionHandle = ctx.GetSessionLocalToPeer();
- err = readClient.SendReadRequest(readPrepareParams);
+ ReadPrepareParams readPrepareParams(ctx.GetSessionLocalToPeer());
+ err = readClient.SendReadRequest(readPrepareParams);
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
GenerateReportData(apSuite, apContext, buf);
@@ -329,9 +328,8 @@
auto * engine = chip::app::InteractionModelEngine::GetInstance();
err = engine->Init(&ctx.GetExchangeManager(), &delegate);
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
- Messaging::ExchangeManager exchangeManager;
- Messaging::ExchangeContext * exchangeCtx = exchangeManager.NewContext(SessionHandle(), nullptr);
- readHandler.Init(&exchangeManager, nullptr, exchangeCtx);
+ Messaging::ExchangeContext * exchangeCtx = ctx.NewExchangeToPeer(nullptr);
+ readHandler.Init(&ctx.GetExchangeManager(), nullptr, exchangeCtx);
GenerateReportData(apSuite, apContext, reportDatabuf);
err = readHandler.SendReportData(std::move(reportDatabuf));
@@ -364,6 +362,7 @@
err = readHandler.OnReadInitialRequest(std::move(readRequestbuf));
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
+ exchangeCtx->Close();
engine->Shutdown();
}
@@ -431,9 +430,8 @@
err = readClient.Init(&ctx.GetExchangeManager(), &delegate, 0 /* application identifier */);
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
- ReadPrepareParams readPrepareParams;
- readPrepareParams.mSessionHandle = ctx.GetSessionLocalToPeer();
- err = readClient.SendReadRequest(readPrepareParams);
+ ReadPrepareParams readPrepareParams(ctx.GetSessionLocalToPeer());
+ err = readClient.SendReadRequest(readPrepareParams);
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
GenerateReportData(apSuite, apContext, buf, true /*aNeedInvalidReport*/);
@@ -458,9 +456,8 @@
auto * engine = chip::app::InteractionModelEngine::GetInstance();
err = engine->Init(&ctx.GetExchangeManager(), &delegate);
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
- Messaging::ExchangeManager exchangeManager;
- Messaging::ExchangeContext * exchangeCtx = exchangeManager.NewContext(SessionHandle(), nullptr);
- readHandler.Init(&exchangeManager, nullptr, exchangeCtx);
+ Messaging::ExchangeContext * exchangeCtx = ctx.NewExchangeToPeer(nullptr);
+ readHandler.Init(&ctx.GetExchangeManager(), nullptr, exchangeCtx);
GenerateReportData(apSuite, apContext, reportDatabuf);
err = readHandler.SendReportData(std::move(reportDatabuf));
@@ -488,6 +485,8 @@
err = readHandler.OnReadInitialRequest(std::move(readRequestbuf));
NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_IM_MALFORMED_ATTRIBUTE_PATH);
+
+ exchangeCtx->Close();
engine->Shutdown();
}
@@ -638,8 +637,7 @@
attributePathParams[1].mFlags.Set(chip::app::AttributePathParams::Flags::kFieldIdValid);
attributePathParams[1].mFlags.Set(chip::app::AttributePathParams::Flags::kListIndexValid);
- ReadPrepareParams readPrepareParams;
- readPrepareParams.mSessionHandle = ctx.GetSessionLocalToPeer();
+ ReadPrepareParams readPrepareParams(ctx.GetSessionLocalToPeer());
readPrepareParams.mpEventPathParamsList = eventPathParams;
readPrepareParams.mEventPathParamsListSize = 2;
readPrepareParams.mpAttributePathParamsList = attributePathParams;
@@ -684,8 +682,7 @@
attributePathParams[0].mListIndex = 0;
attributePathParams[0].mFlags.Set(chip::app::AttributePathParams::Flags::kFieldIdValid);
- ReadPrepareParams readPrepareParams;
- readPrepareParams.mSessionHandle = ctx.GetSessionLocalToPeer();
+ ReadPrepareParams readPrepareParams(ctx.GetSessionLocalToPeer());
readPrepareParams.mpAttributePathParamsList = attributePathParams;
readPrepareParams.mAttributePathParamsListSize = 1;
err = chip::app::InteractionModelEngine::GetInstance()->SendReadRequest(readPrepareParams);
diff --git a/src/app/tests/TestWriteInteraction.cpp b/src/app/tests/TestWriteInteraction.cpp
index dec03ce..466c9e9 100644
--- a/src/app/tests/TestWriteInteraction.cpp
+++ b/src/app/tests/TestWriteInteraction.cpp
@@ -217,7 +217,8 @@
AddAttributeDataElement(apSuite, apContext, writeClientHandle);
SessionHandle session = ctx.GetSessionLocalToPeer();
- err = writeClientHandle.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), &session);
+ err = writeClientHandle.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(),
+ Optional<SessionHandle>::Value(session));
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
// The internal WriteClient should be nullptr once we SendWriteRequest.
NL_TEST_ASSERT(apSuite, nullptr == writeClientHandle.mpWriteClient);
@@ -306,7 +307,7 @@
SessionHandle session = ctx.GetSessionLocalToPeer();
- err = writeClient.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), &session);
+ err = writeClient.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), Optional<SessionHandle>::Value(session));
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(apSuite, delegate.mGotResponse);
diff --git a/src/app/tests/integration/chip_im_initiator.cpp b/src/app/tests/integration/chip_im_initiator.cpp
index 362d063..ffa576f 100644
--- a/src/app/tests/integration/chip_im_initiator.cpp
+++ b/src/app/tests/integration/chip_im_initiator.cpp
@@ -126,7 +126,8 @@
err = commandSender->FinishCommand();
SuccessOrExit(err);
- err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, nullptr, gMessageTimeoutMsec);
+ err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, chip::Optional<chip::SessionHandle>::Missing(),
+ gMessageTimeoutMsec);
SuccessOrExit(err);
exit:
@@ -163,7 +164,8 @@
err = commandSender->FinishCommand();
SuccessOrExit(err);
- err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, nullptr, gMessageTimeoutMsec);
+ err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, chip::Optional<chip::SessionHandle>::Missing(),
+ gMessageTimeoutMsec);
SuccessOrExit(err);
exit:
@@ -181,7 +183,6 @@
CHIP_ERROR SendReadRequest()
{
CHIP_ERROR err = CHIP_NO_ERROR;
- chip::app::ReadPrepareParams readPrepareParams;
chip::app::EventPathParams eventPathParams[2];
eventPathParams[0].mNodeId = kTestNodeId;
eventPathParams[0].mEndpointId = kTestEndpointId;
@@ -198,7 +199,7 @@
printf("\nSend read request message to Node: %" PRIu64 "\n", chip::kTestDeviceNodeId);
- readPrepareParams.mSessionHandle = chip::SessionHandle(chip::kTestDeviceNodeId, 0, 0, gFabricIndex);
+ chip::app::ReadPrepareParams readPrepareParams(chip::SessionHandle(chip::kTestDeviceNodeId, 0, 0, gFabricIndex));
readPrepareParams.mTimeout = gMessageTimeoutMsec;
readPrepareParams.mpAttributePathParamsList = &attributePathParams;
readPrepareParams.mAttributePathParamsListSize = 1;
@@ -241,7 +242,8 @@
SuccessOrExit(err = writer->PutBoolean(chip::TLV::ContextTag(chip::app::AttributeDataElement::kCsTag_Data), true));
SuccessOrExit(err = apWriteClient->FinishAttribute());
- SuccessOrExit(err = apWriteClient.SendWriteRequest(chip::kTestDeviceNodeId, gFabricIndex, nullptr, gMessageTimeoutMsec));
+ SuccessOrExit(err = apWriteClient.SendWriteRequest(chip::kTestDeviceNodeId, gFabricIndex,
+ chip::Optional<chip::SessionHandle>::Missing(), gMessageTimeoutMsec));
gWriteCount++;
diff --git a/src/channel/ChannelContext.cpp b/src/channel/ChannelContext.cpp
index 607ff8a..72523b7 100644
--- a/src/channel/ChannelContext.cpp
+++ b/src/channel/ChannelContext.cpp
@@ -258,7 +258,8 @@
auto & prepare = GetPrepareVars();
prepare.mCasePairingSession = Platform::New<CASESession>();
- ExchangeContext * ctxt = mExchangeManager->NewContext(SessionHandle(), prepare.mCasePairingSession);
+ ExchangeContext * ctxt =
+ mExchangeManager->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), prepare.mCasePairingSession);
VerifyOrReturn(ctxt != nullptr);
// TODO: currently only supports IP/UDP paring
diff --git a/src/controller/CHIPDevice.cpp b/src/controller/CHIPDevice.cpp
index fd1a2bd..78b9069 100644
--- a/src/controller/CHIPDevice.cpp
+++ b/src/controller/CHIPDevice.cpp
@@ -73,7 +73,7 @@
ReturnErrorOnFailure(LoadSecureSessionParametersIfNeeded(loadedSecureSession));
- Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(mSecureSession, nullptr);
+ Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(mSecureSession.Value(), nullptr);
VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_NO_MEMORY);
if (!loadedSecureSession)
@@ -129,10 +129,18 @@
}
else
{
- Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession);
-
- // Check if the connection state has the correct transport information
- if (connectionState == nullptr || connectionState->GetPeerAddress().GetTransportType() == Transport::Type::kUndefined)
+ if (mSecureSession.HasValue())
+ {
+ Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value());
+ // Check if the connection state has the correct transport information
+ if (connectionState->GetPeerAddress().GetTransportType() == Transport::Type::kUndefined)
+ {
+ mState = ConnectionState::NotConnected;
+ ReturnErrorOnFailure(LoadSecureSessionParameters(ResetTransport::kNo));
+ didLoad = true;
+ }
+ }
+ else
{
mState = ConnectionState::NotConnected;
ReturnErrorOnFailure(LoadSecureSessionParameters(ResetTransport::kNo));
@@ -148,7 +156,7 @@
bool loadedSecureSession = false;
ReturnErrorOnFailure(LoadSecureSessionParametersIfNeeded(loadedSecureSession));
VerifyOrReturnError(commandObj != nullptr, CHIP_ERROR_INVALID_ARGUMENT);
- return commandObj->SendCommandRequest(mDeviceId, mFabricIndex, &mSecureSession);
+ return commandObj->SendCommandRequest(mDeviceId, mFabricIndex, mSecureSession);
}
CHIP_ERROR Device::Serialize(SerializedDevice & output)
@@ -166,14 +174,13 @@
serializable.mDevicePort = Encoding::LittleEndian::HostSwap16(mDeviceAddress.GetPort());
serializable.mFabricIndex = Encoding::LittleEndian::HostSwap16(mFabricIndex);
- Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession);
-
// The connection state could be null if the device is moving from PASE connection to CASE connection.
// The device parameters (e.g. mDeviceOperationalCertProvisioned) are updated during this transition.
// The state during this transistion is being persisted so that the next access of the device will
// trigger the CASE based secure session.
- if (connectionState != nullptr)
+ if (mSecureSession.HasValue())
{
+ Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value());
const uint32_t localMessageCounter = connectionState->GetSessionMessageCounter().GetLocalMessageCounter().Value();
const uint32_t peerMessageCounter = connectionState->GetSessionMessageCounter().GetPeerMessageCounter().GetCounter();
@@ -303,13 +310,13 @@
void Device::OnNewConnection(SessionHandle session)
{
- mState = ConnectionState::SecureConnected;
- mSecureSession = session;
+ mState = ConnectionState::SecureConnected;
+ mSecureSession.SetValue(session);
// Reset the message counters here because this is the first time we get a handle to the secure session.
// Since CHIPDevices can be serialized/deserialized in the middle of what is conceptually a single PASE session
// we need to restore the session counters along with the session information.
- Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession);
+ Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value());
VerifyOrReturn(connectionState != nullptr);
MessageCounter & localCounter = connectionState->GetSessionMessageCounter().GetLocalMessageCounter();
if (localCounter.SetCounter(mLocalMessageCounter) != CHIP_NO_ERROR)
@@ -322,10 +329,10 @@
void Device::OnConnectionExpired(SessionHandle session)
{
- VerifyOrReturn(session == mSecureSession,
+ VerifyOrReturn(mSecureSession.HasValue() && mSecureSession.Value() == session,
ChipLogDetail(Controller, "Connection expired, but it doesn't match the current session"));
- mState = ConnectionState::NotConnected;
- mSecureSession = SessionHandle{};
+ mState = ConnectionState::NotConnected;
+ mSecureSession.ClearValue();
}
CHIP_ERROR Device::OnMessageReceived(Messaging::ExchangeContext * exchange, const PacketHeader & header,
@@ -407,7 +414,10 @@
CHIP_ERROR Device::CloseSession()
{
ReturnErrorCodeIf(mState != ConnectionState::SecureConnected, CHIP_ERROR_INCORRECT_STATE);
- mSessionManager->ExpirePairing(mSecureSession);
+ if (mSecureSession.HasValue())
+ {
+ mSessionManager->ExpirePairing(mSecureSession.Value());
+ }
mState = ConnectionState::NotConnected;
return CHIP_NO_ERROR;
}
@@ -420,8 +430,7 @@
ReturnErrorOnFailure(LoadSecureSessionParametersIfNeeded(didLoad));
- Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession);
- if (connectionState == nullptr)
+ if (!mSecureSession.HasValue())
{
// Nothing needs to be done here. It's not an error to not have a
// connectionState. For one thing, we could have gotten an different
@@ -430,6 +439,7 @@
return CHIP_NO_ERROR;
}
+ Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value());
connectionState->SetPeerAddress(addr);
return CHIP_NO_ERROR;
@@ -440,8 +450,8 @@
if (IsActive() && mStorageDelegate != nullptr && mSessionManager != nullptr)
{
// If a session can be found, persist the device so that we track the newest message counter values
- Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession);
- if (connectionState != nullptr)
+
+ if (mSecureSession.HasValue())
{
Persist();
}
@@ -549,7 +559,8 @@
VerifyOrReturnError(mDeviceOperationalCertProvisioned, CHIP_ERROR_INCORRECT_STATE);
VerifyOrReturnError(mState == ConnectionState::NotConnected, CHIP_NO_ERROR);
- Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(SessionHandle(), &mCASESession);
+ Messaging::ExchangeContext * exchange =
+ mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mCASESession);
VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL);
ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager->GetTransportManager()));
@@ -687,7 +698,6 @@
CHIP_ERROR Device::SendReadAttributeRequest(app::AttributePathParams aPath, Callback::Cancelable * onSuccessCallback,
Callback::Cancelable * onFailureCallback, app::TLVDataFilter aTlvDataFilter)
{
- chip::app::ReadPrepareParams readPrepareParams;
bool loadedSecureSession = false;
uint8_t seqNum = GetNextSequenceNumber();
aPath.mNodeId = GetDeviceId();
@@ -700,7 +710,7 @@
}
// The application context is used to identify different requests from client applicaiton the type of it is intptr_t, here we
// use the seqNum.
- readPrepareParams.mSessionHandle = mSecureSession;
+ chip::app::ReadPrepareParams readPrepareParams(mSecureSession.Value());
readPrepareParams.mpAttributePathParamsList = &aPath;
readPrepareParams.mAttributePathParamsListSize = 1;
CHIP_ERROR err =
@@ -726,7 +736,7 @@
{
AddResponseHandler(seqNum, onSuccessCallback, onFailureCallback);
}
- if ((err = aHandle.SendWriteRequest(GetDeviceId(), 0, &mSecureSession)) != CHIP_NO_ERROR)
+ if ((err = aHandle.SendWriteRequest(GetDeviceId(), 0, mSecureSession)) != CHIP_NO_ERROR)
{
CancelResponseHandler(seqNum);
}
diff --git a/src/controller/CHIPDevice.h b/src/controller/CHIPDevice.h
index ee7a09b..3c15aa0 100644
--- a/src/controller/CHIPDevice.h
+++ b/src/controller/CHIPDevice.h
@@ -366,9 +366,9 @@
NodeId GetDeviceId() const { return mDeviceId; }
- bool MatchesSession(SessionHandle session) const { return mSecureSession == session; }
+ bool MatchesSession(SessionHandle session) const { return mSecureSession.HasValue() && mSecureSession.Value() == session; }
- SessionHandle GetSecureSession() const { return mSecureSession; }
+ SessionHandle GetSecureSession() const { return mSecureSession.Value(); }
void SetAddress(const Inet::IPAddress & deviceAddr) { mDeviceAddress.SetIPAddress(deviceAddr); }
@@ -474,7 +474,7 @@
Messaging::ExchangeManager * mExchangeMgr = nullptr;
- SessionHandle mSecureSession = {};
+ Optional<SessionHandle> mSecureSession = Optional<SessionHandle>::Missing();
uint8_t mSequenceNumber = 0;
diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp
index 2587614..eb7c466 100644
--- a/src/controller/CHIPDeviceController.cpp
+++ b/src/controller/CHIPDeviceController.cpp
@@ -883,7 +883,7 @@
}
}
#endif
- exchangeCtxt = mExchangeMgr->NewContext(SessionHandle(), &mPairingSession);
+ exchangeCtxt = mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mPairingSession);
VerifyOrExit(exchangeCtxt != nullptr, err = CHIP_ERROR_INTERNAL);
err = mIDAllocator.Allocate(keyID);
diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp
index 673fbc2..93def28 100644
--- a/src/messaging/ExchangeContext.cpp
+++ b/src/messaging/ExchangeContext.cpp
@@ -93,6 +93,7 @@
Transport::PeerConnectionState * state = nullptr;
VerifyOrReturnError(mExchangeMgr != nullptr, CHIP_ERROR_INTERNAL);
+ VerifyOrReturnError(mSecureSession.HasValue(), CHIP_ERROR_CONNECTION_ABORTED);
// Don't let method get called on a freed object.
VerifyOrDie(mExchangeMgr != nullptr && GetReferenceCount() > 0);
@@ -104,7 +105,7 @@
bool reliableTransmissionRequested = true;
- state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(mSecureSession);
+ state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(mSecureSession.Value());
// If sending via UDP and NoAutoRequestAck send flag is not specificed, request reliable transmission.
if (state != nullptr && state->GetPeerAddress().GetTransportType() != Transport::Type::kUdp)
{
@@ -141,7 +142,7 @@
{
// Create a new scope for `err`, to avoid shadowing warning previous `err`.
- CHIP_ERROR err = mDispatch->SendMessage(mSecureSession, mExchangeId, IsInitiator(), GetReliableMessageContext(),
+ CHIP_ERROR err = mDispatch->SendMessage(mSecureSession.Value(), mExchangeId, IsInitiator(), GetReliableMessageContext(),
reliableTransmissionRequested, protocolId, msgType, std::move(msgBuf));
if (err != CHIP_NO_ERROR && IsResponseExpected())
{
@@ -233,9 +234,9 @@
{
VerifyOrDie(mExchangeMgr == nullptr);
- mExchangeMgr = em;
- mExchangeId = ExchangeId;
- mSecureSession = session;
+ mExchangeMgr = em;
+ mExchangeId = ExchangeId;
+ mSecureSession.SetValue(session);
mFlags.Set(Flags::kFlagInitiator, Initiator);
mDelegate = delegate;
@@ -305,7 +306,7 @@
(mExchangeId == payloadHeader.GetExchangeID())
// AND The Session ID associated with the incoming message matches the Session ID associated with the exchange.
- && (mSecureSession.MatchIncomingSession(session))
+ && (mSecureSession.HasValue() && mSecureSession.Value().MatchIncomingSession(session))
// AND The message was sent by an initiator and the exchange context is a responder (IsInitiator==false)
// OR The message was sent by a responder and the exchange context is an initiator (IsInitiator==true) (for the broadcast
@@ -320,7 +321,7 @@
// connection state) value, because it's still referencing the now-expired
// connection. This will mean that no more messages can be sent via this
// exchange, which seems fine given the semantics of connection expiration.
- mSecureSession = SessionHandle();
+ mSecureSession.ClearValue();
if (!IsResponseExpected())
{
diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h
index 4a99c18..854a8c9 100644
--- a/src/messaging/ExchangeContext.h
+++ b/src/messaging/ExchangeContext.h
@@ -156,7 +156,7 @@
{
if (mExchangeACL == nullptr)
{
- Transport::FabricInfo * fabric = table.FindFabricWithIndex(mSecureSession.GetFabricIndex());
+ Transport::FabricInfo * fabric = table.FindFabricWithIndex(mSecureSession.Value().GetFabricIndex());
if (fabric != nullptr)
{
mExchangeACL = chip::Platform::New<CASEExchangeACL>(fabric);
@@ -166,7 +166,7 @@
return mExchangeACL;
}
- SessionHandle GetSecureSession() { return mSecureSession; }
+ SessionHandle GetSecureSession() { return mSecureSession.Value(); }
uint16_t GetExchangeId() const { return mExchangeId; }
@@ -188,8 +188,8 @@
ExchangeMessageDispatch * mDispatch = nullptr;
- SessionHandle mSecureSession; // The connection state
- uint16_t mExchangeId; // Assigned exchange ID.
+ Optional<SessionHandle> mSecureSession; // The connection state
+ uint16_t mExchangeId; // Assigned exchange ID.
/**
* Determine whether a response is currently expected for a message that was sent over
diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp
index dd69e2c..58be315 100644
--- a/src/messaging/ExchangeMgr.cpp
+++ b/src/messaging/ExchangeMgr.cpp
@@ -315,7 +315,7 @@
}
mContextPool.ForEachActiveObject([&](auto * ec) {
- if (ec->mSecureSession == session)
+ if (ec->mSecureSession.HasValue() && ec->mSecureSession.Value() == session)
{
ec->OnConnectionExpired();
// Continue to iterate because there can be multiple exchanges
diff --git a/src/protocols/echo/Echo.h b/src/protocols/echo/Echo.h
index 11bf7e9..3ee26e0 100644
--- a/src/protocols/echo/Echo.h
+++ b/src/protocols/echo/Echo.h
@@ -103,7 +103,7 @@
Messaging::ExchangeManager * mExchangeMgr = nullptr;
Messaging::ExchangeContext * mExchangeCtx = nullptr;
EchoFunct OnEchoResponseReceived = nullptr;
- SessionHandle mSecureSession;
+ Optional<SessionHandle> mSecureSession = Optional<SessionHandle>();
CHIP_ERROR OnMessageReceived(Messaging::ExchangeContext * ec, const PacketHeader & packetHeader,
const PayloadHeader & payloadHeader, System::PacketBufferHandle && payload) override;
diff --git a/src/protocols/echo/EchoClient.cpp b/src/protocols/echo/EchoClient.cpp
index e1bb097..cb76a34 100644
--- a/src/protocols/echo/EchoClient.cpp
+++ b/src/protocols/echo/EchoClient.cpp
@@ -38,8 +38,8 @@
if (mExchangeMgr != nullptr)
return CHIP_ERROR_INCORRECT_STATE;
- mExchangeMgr = exchangeMgr;
- mSecureSession = session;
+ mExchangeMgr = exchangeMgr;
+ mSecureSession.SetValue(session);
OnEchoResponseReceived = nullptr;
mExchangeCtx = nullptr;
@@ -71,7 +71,7 @@
}
// Create a new exchange context.
- mExchangeCtx = mExchangeMgr->NewContext(mSecureSession, this);
+ mExchangeCtx = mExchangeMgr->NewContext(mSecureSession.Value(), this);
if (mExchangeCtx == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp
index 5229624..d5ffcb2 100644
--- a/src/transport/SecureSessionMgr.cpp
+++ b/src/transport/SecureSessionMgr.cpp
@@ -296,7 +296,7 @@
{
PayloadHeader payloadHeader;
ReturnOnFailure(payloadHeader.DecodeAndConsume(msg));
- mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(), peerAddress,
+ mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle::TemporaryUnauthenticatedSession(), peerAddress,
SecureSessionMgrDelegate::DuplicateMessage::No, std::move(msg));
}
}
diff --git a/src/transport/SessionHandle.h b/src/transport/SessionHandle.h
index 238d82f..fe9f0ca 100644
--- a/src/transport/SessionHandle.h
+++ b/src/transport/SessionHandle.h
@@ -24,8 +24,6 @@
class SessionHandle
{
public:
- SessionHandle() : mPeerNodeId(kPlaceholderNodeId), mFabric(Transport::kUndefinedFabricIndex) {}
-
SessionHandle(NodeId peerNodeId, FabricIndex fabric) : mPeerNodeId(peerNodeId), mFabric(fabric) {}
SessionHandle(NodeId peerNodeId, uint16_t localKeyId, uint16_t peerKeyId, FabricIndex fabric) :
@@ -64,6 +62,12 @@
const Optional<uint16_t> & GetPeerKeyId() const { return mPeerKeyId; }
const Optional<uint16_t> & GetLocalKeyId() const { return mLocalKeyId; }
+ // TODO: currently SessionHandle is not able to identify a unauthenticated session, create an empty handle for it
+ static SessionHandle TemporaryUnauthenticatedSession()
+ {
+ return SessionHandle(kPlaceholderNodeId, Transport::kUndefinedFabricIndex);
+ }
+
private:
friend class SecureSessionMgr;
NodeId mPeerNodeId;
diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp
index 2092c29..23b4377 100644
--- a/src/transport/tests/TestSecureSessionMgr.cpp
+++ b/src/transport/tests/TestSecureSessionMgr.cpp
@@ -66,7 +66,7 @@
const Transport::PeerAddress & source, DuplicateMessage isDuplicate,
System::PacketBufferHandle && msgBuf) override
{
- NL_TEST_ASSERT(mSuite, session == mRemoteToLocalSession); // Packet received by remote peer
+ NL_TEST_ASSERT(mSuite, session == mRemoteToLocalSession.Value()); // Packet received by remote peer
size_t data_len = msgBuf->DataLength();
@@ -88,20 +88,20 @@
{
// Preset the MessageCounter
if (NewConnectionHandlerCallCount == 0)
- mRemoteToLocalSession = session;
+ mRemoteToLocalSession.SetValue(session);
if (NewConnectionHandlerCallCount == 1)
- mLocalToRemoteSession = session;
+ mLocalToRemoteSession.SetValue(session);
NewConnectionHandlerCallCount++;
}
void OnConnectionExpired(SessionHandle session) override { mOldConnectionDropped = true; }
bool mOldConnectionDropped = false;
- nlTestSuite * mSuite = nullptr;
- SessionHandle mRemoteToLocalSession;
- SessionHandle mLocalToRemoteSession;
- int ReceiveHandlerCallCount = 0;
- int NewConnectionHandlerCallCount = 0;
+ nlTestSuite * mSuite = nullptr;
+ Optional<SessionHandle> mRemoteToLocalSession = Optional<SessionHandle>::Missing();
+ Optional<SessionHandle> mLocalToRemoteSession = Optional<SessionHandle>::Missing();
+ int ReceiveHandlerCallCount = 0;
+ int NewConnectionHandlerCallCount = 0;
bool LargeMessageSent = false;
};
@@ -166,7 +166,7 @@
err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing2, SecureSession::SessionRole::kResponder, 0);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- SessionHandle localToRemoteSession = callback.mLocalToRemoteSession;
+ SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Value();
// Should be able to send a message to itself by just calling send.
callback.ReceiveHandlerCallCount = 0;
@@ -256,7 +256,7 @@
err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing2, SecureSession::SessionRole::kResponder, 0);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- SessionHandle localToRemoteSession = callback.mLocalToRemoteSession;
+ SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Value();
// Should be able to send a message to itself by just calling send.
callback.ReceiveHandlerCallCount = 0;
@@ -279,7 +279,7 @@
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
// Reset receive side message counter, or duplicated message will be denied.
- Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession);
+ Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession.Value());
state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1);
NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1);
@@ -330,7 +330,7 @@
err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing2, SecureSession::SessionRole::kResponder, 0);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- SessionHandle localToRemoteSession = callback.mLocalToRemoteSession;
+ SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Value();
// Should be able to send a message to itself by just calling send.
callback.ReceiveHandlerCallCount = 0;
@@ -356,7 +356,7 @@
/* -------------------------------------------------------------------------------------------*/
// Reset receive side message counter, or duplicated message will be denied.
- Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession);
+ Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession.Value());
state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1);
PacketHeader packetHeader;
diff --git a/src/transport/tests/TestSessionHandle.cpp b/src/transport/tests/TestSessionHandle.cpp
index 8a31c8e..a4793c6 100644
--- a/src/transport/tests/TestSessionHandle.cpp
+++ b/src/transport/tests/TestSessionHandle.cpp
@@ -37,24 +37,8 @@
using namespace chip;
-void TestInitialState(nlTestSuite * inSuite, void * inContext)
-{
- SessionHandle session;
-
- NL_TEST_ASSERT(inSuite, session.GetPeerNodeId() == kPlaceholderNodeId);
- NL_TEST_ASSERT(inSuite, session.GetFabricIndex() == Transport::kUndefinedFabricIndex);
- NL_TEST_ASSERT(inSuite, !session.HasFabricIndex());
- NL_TEST_ASSERT(inSuite, !session.GetLocalKeyId().HasValue());
- NL_TEST_ASSERT(inSuite, !session.GetPeerKeyId().HasValue());
-}
-
void TestMatchSession(nlTestSuite * inSuite, void * inContext)
{
- SessionHandle session1;
- SessionHandle session2;
- NL_TEST_ASSERT(inSuite, session1 == session2);
- NL_TEST_ASSERT(inSuite, session1.MatchIncomingSession(session2));
-
SessionHandle session3(chip::kTestDeviceNodeId, 1, 1, 0);
SessionHandle session4(chip::kTestDeviceNodeId, 1, 2, 0);
NL_TEST_ASSERT(inSuite, !(session3 == session4));
@@ -69,7 +53,6 @@
// clang-format off
static const nlTest sTests[] =
{
- NL_TEST_DEF("InitialState", TestInitialState),
NL_TEST_DEF("MatchSession", TestMatchSession),
NL_TEST_SENTINEL()
};