Use Optional<ExchangeHandle> variable instead of ExchangeContext in PairingSession. (#32499)
This ensures that the ExchangeContext for the session is automatically
reference counted without explicit manual Retain() and Release() calls.
As a result, the ExchangeContext is held until PairingSession::Clear()
gets called that, in turn, calls ClearValue() on the Optional to
internally call Release() on the underlying target.
Fixes Issue #32498.
diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp
index a9b6248..5fda196 100644
--- a/src/protocols/secure_channel/CASESession.cpp
+++ b/src/protocols/secure_channel/CASESession.cpp
@@ -514,7 +514,7 @@
// We are setting the exchange context specifically before checking for error.
// This is to make sure the exchange will get closed if Init() returned an error.
- mExchangeCtxt = exchangeCtxt;
+ mExchangeCtxt.Emplace(*exchangeCtxt);
// From here onwards, let's go to exit on error, as some state might have already
// been initialized
@@ -527,7 +527,7 @@
mSessionResumptionStorage = sessionResumptionStorage;
mLocalMRPConfig = mrpLocalConfig.ValueOr(GetDefaultMRPConfig());
- mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedSigma1ProcessingTime);
+ mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedSigma1ProcessingTime);
mPeerNodeId = peerScopedNodeId.GetNodeId();
mLocalNodeId = fabricInfo->GetNodeId();
@@ -549,7 +549,8 @@
{
MATTER_TRACE_SCOPE("OnResponseTimeout", "CASESession");
VerifyOrReturn(ec != nullptr, ChipLogError(SecureChannel, "CASESession::OnResponseTimeout was called by null exchange"));
- VerifyOrReturn(mExchangeCtxt == ec, ChipLogError(SecureChannel, "CASESession::OnResponseTimeout exchange doesn't match"));
+ VerifyOrReturn(mExchangeCtxt.HasValue() && (&mExchangeCtxt.Value().Get() == ec),
+ ChipLogError(SecureChannel, "CASESession::OnResponseTimeout exchange doesn't match"));
ChipLogError(SecureChannel, "CASESession timed out while waiting for a response from the peer. Current state was %u",
to_underlying(mState));
MATTER_TRACE_COUNTER("CASETimeout");
@@ -735,8 +736,8 @@
ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msg_R1->Start(), msg_R1->DataLength() }));
// Call delegate to send the msg to peer
- ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma1, std::move(msg_R1),
- SendFlags(SendMessageFlags::kExpectResponse)));
+ ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma1, std::move(msg_R1),
+ SendFlags(SendMessageFlags::kExpectResponse)));
mState = resuming ? State::kSentSigma1Resume : State::kSentSigma1;
@@ -959,8 +960,9 @@
ReturnErrorOnFailure(tlvWriter.Finalize(&msg_R2_resume));
// Call delegate to send the msg to peer
- ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2Resume, std::move(msg_R2_resume),
- SendFlags(SendMessageFlags::kExpectResponse)));
+ ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2Resume,
+ std::move(msg_R2_resume),
+ SendFlags(SendMessageFlags::kExpectResponse)));
mState = State::kSentSigma2Resume;
@@ -1096,8 +1098,8 @@
ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msg_R2->Start(), msg_R2->DataLength() }));
// Call delegate to send the msg to peer
- ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2, std::move(msg_R2),
- SendFlags(SendMessageFlags::kExpectResponse)));
+ ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2, std::move(msg_R2),
+ SendFlags(SendMessageFlags::kExpectResponse)));
mState = State::kSentSigma2;
@@ -1148,7 +1150,8 @@
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(4), tlvReader));
- mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
+ mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
+ GetRemoteSessionParameters());
}
ChipLogDetail(SecureChannel, "Peer assigned session session ID %d", responderSessionId);
@@ -1341,7 +1344,8 @@
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(kTag_Sigma2_ResponderMRPParams), tlvReader));
- mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
+ mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
+ GetRemoteSessionParameters());
}
exit:
@@ -1410,7 +1414,7 @@
{
SuccessOrExit(err = helper->ScheduleWork());
mSendSigma3Helper = helper;
- mExchangeCtxt->WillSendMessage();
+ mExchangeCtxt.Value()->WillSendMessage();
mState = State::kSendSigma3Pending;
}
else
@@ -1537,8 +1541,8 @@
SuccessOrExit(err);
// Call delegate to send the Msg3 to peer
- err = mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma3, std::move(msg_R3),
- SendFlags(SendMessageFlags::kExpectResponse));
+ err = mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma3, std::move(msg_R3),
+ SendFlags(SendMessageFlags::kExpectResponse));
SuccessOrExit(err);
ChipLogProgress(SecureChannel, "Sent Sigma3 msg");
@@ -1704,7 +1708,7 @@
SuccessOrExit(err = helper->ScheduleWork());
mHandleSigma3Helper = helper;
- mExchangeCtxt->WillSendMessage();
+ mExchangeCtxt.Value()->WillSendMessage();
mState = State::kHandleSigma3Pending;
}
@@ -2036,7 +2040,8 @@
if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kInitiatorMRPParamsTag))
{
ReturnErrorOnFailure(DecodeMRPParametersIfPresent(TLV::ContextTag(kInitiatorMRPParamsTag), tlvReader));
- mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
+ mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
+ GetRemoteSessionParameters());
err = tlvReader.Next();
}
@@ -2092,18 +2097,19 @@
// mExchangeCtxt can be nullptr if this is the first message (CASE_Sigma1) received by CASESession
// via UnsolicitedMessageHandler. The exchange context is allocated by exchange manager and provided
// to the handler (CASESession object).
- if (mExchangeCtxt != nullptr)
+ if (mExchangeCtxt.HasValue())
{
- if (mExchangeCtxt != ec)
+ if (&mExchangeCtxt.Value().Get() != ec)
{
ReturnErrorOnFailure(CHIP_ERROR_INVALID_ARGUMENT);
}
}
else
{
- mExchangeCtxt = ec;
+ mExchangeCtxt.SetValue(ExchangeHandle(*ec));
+ mExchangeCtxt.Emplace(*ec);
}
- mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);
+ mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);
VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);
return CHIP_NO_ERROR;
@@ -2128,7 +2134,7 @@
//
// Should you need to resume the CASESession, you could theoretically pass along the msg to a callback that gets
// registered when setting mStopHandshakeAtState.
- mExchangeCtxt->WillSendMessage();
+ mExchangeCtxt.Value()->WillSendMessage();
return CHIP_NO_ERROR;
}
#endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST
@@ -2138,7 +2144,7 @@
msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2Resume ||
msgType == Protocols::SecureChannel::MsgType::CASE_Sigma3)
{
- SuccessOrExit(err = mExchangeCtxt->FlushAcks());
+ SuccessOrExit(err = mExchangeCtxt.Value()->FlushAcks());
}
#endif // CHIP_CONFIG_SLOW_CRYPTO
diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp
index 40dc677..ca2524d 100644
--- a/src/protocols/secure_channel/PASESession.cpp
+++ b/src/protocols/secure_channel/PASESession.cpp
@@ -218,12 +218,12 @@
mRole = CryptoContext::SessionRole::kInitiator;
- mExchangeCtxt = exchangeCtxt;
+ mExchangeCtxt.Emplace(*exchangeCtxt);
// When commissioning starts, the peer is assumed to be active.
- mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->MarkActiveRx();
+ mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->MarkActiveRx();
- mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedLowProcessingTime);
+ mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedLowProcessingTime);
mLocalMRPConfig = mrpLocalConfig.ValueOr(GetDefaultMRPConfig());
@@ -244,7 +244,7 @@
{
MATTER_TRACE_SCOPE("OnResponseTimeout", "PASESession");
VerifyOrReturn(ec != nullptr, ChipLogError(SecureChannel, "PASESession::OnResponseTimeout was called by null exchange"));
- VerifyOrReturn(mExchangeCtxt == nullptr || mExchangeCtxt == ec,
+ VerifyOrReturn(!mExchangeCtxt.HasValue() || &mExchangeCtxt.Value().Get() == ec,
ChipLogError(SecureChannel, "PASESession::OnResponseTimeout exchange doesn't match"));
// If we were waiting for something, mNextExpectedMsg had better have a value.
ChipLogError(SecureChannel, "PASESession timed out while waiting for a response from the peer. Expected message type was %u",
@@ -308,8 +308,8 @@
// Update commissioning hash with the pbkdf2 param request that's being sent.
ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ req->Start(), req->DataLength() }));
- ReturnErrorOnFailure(
- mExchangeCtxt->SendMessage(MsgType::PBKDFParamRequest, std::move(req), SendFlags(SendMessageFlags::kExpectResponse)));
+ ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(MsgType::PBKDFParamRequest, std::move(req),
+ SendFlags(SendMessageFlags::kExpectResponse)));
mNextExpectedMsg.SetValue(MsgType::PBKDFParamResponse);
@@ -364,7 +364,8 @@
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
- mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
+ mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
+ GetRemoteSessionParameters());
}
err = SendPBKDFParamResponse(ByteSpan(initiatorRandom), hasPBKDFParameters);
@@ -428,8 +429,8 @@
ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ resp->Start(), resp->DataLength() }));
ReturnErrorOnFailure(SetupSpake2p());
- ReturnErrorOnFailure(
- mExchangeCtxt->SendMessage(MsgType::PBKDFParamResponse, std::move(resp), SendFlags(SendMessageFlags::kExpectResponse)));
+ ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(MsgType::PBKDFParamResponse, std::move(resp),
+ SendFlags(SendMessageFlags::kExpectResponse)));
ChipLogDetail(SecureChannel, "Sent PBKDF param response");
mNextExpectedMsg.SetValue(MsgType::PASE_Pake1);
@@ -483,7 +484,8 @@
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
- mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
+ mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
+ GetRemoteSessionParameters());
}
// TODO - Add a unit test that exercises mHavePBKDFParameters path
@@ -508,7 +510,8 @@
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
- mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
+ mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
+ GetRemoteSessionParameters());
}
}
@@ -558,7 +561,7 @@
ReturnErrorOnFailure(tlvWriter.Finalize(&msg));
ReturnErrorOnFailure(
- mExchangeCtxt->SendMessage(MsgType::PASE_Pake1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse)));
+ mExchangeCtxt.Value()->SendMessage(MsgType::PASE_Pake1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse)));
ChipLogDetail(SecureChannel, "Sent spake2p msg1");
mNextExpectedMsg.SetValue(MsgType::PASE_Pake2);
@@ -620,7 +623,8 @@
SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType));
SuccessOrExit(err = tlvWriter.Finalize(&msg2));
- err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse));
+ err =
+ mExchangeCtxt.Value()->SendMessage(MsgType::PASE_Pake2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse));
SuccessOrExit(err);
mNextExpectedMsg.SetValue(MsgType::PASE_Pake3);
@@ -696,7 +700,8 @@
SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType));
SuccessOrExit(err = tlvWriter.Finalize(&msg3));
- err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse));
+ err =
+ mExchangeCtxt.Value()->SendMessage(MsgType::PASE_Pake3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse));
SuccessOrExit(err);
mNextExpectedMsg.SetValue(MsgType::StatusReport);
@@ -786,25 +791,25 @@
// mExchangeCtxt can be nullptr if this is the first message (PBKDFParamRequest) received by PASESession
// via UnsolicitedMessageHandler. The exchange context is allocated by exchange manager and provided
// to the handler (PASESession object).
- if (mExchangeCtxt != nullptr)
+ if (mExchangeCtxt.HasValue())
{
- if (mExchangeCtxt != exchange)
+ if (&mExchangeCtxt.Value().Get() != exchange)
{
ReturnErrorOnFailure(CHIP_ERROR_INVALID_ARGUMENT);
}
}
else
{
- mExchangeCtxt = exchange;
+ mExchangeCtxt.Emplace(*exchange);
}
- if (!mExchangeCtxt->GetSessionHandle()->IsUnauthenticatedSession())
+ if (!mExchangeCtxt.Value()->GetSessionHandle()->IsUnauthenticatedSession())
{
ChipLogError(SecureChannel, "PASESession received PBKDFParamRequest over encrypted session. Ignoring.");
return CHIP_ERROR_INCORRECT_STATE;
}
- mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);
+ mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);
VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError((mNextExpectedMsg.HasValue() && payloadHeader.HasMessageType(mNextExpectedMsg.Value())) ||
@@ -833,7 +838,7 @@
if (msgType == MsgType::PBKDFParamRequest || msgType == MsgType::PBKDFParamResponse || msgType == MsgType::PASE_Pake1 ||
msgType == MsgType::PASE_Pake2 || msgType == MsgType::PASE_Pake3)
{
- SuccessOrExit(err = mExchangeCtxt->FlushAcks());
+ SuccessOrExit(err = mExchangeCtxt.Value()->FlushAcks());
}
#endif // CHIP_CONFIG_SLOW_CRYPTO
diff --git a/src/protocols/secure_channel/PairingSession.cpp b/src/protocols/secure_channel/PairingSession.cpp
index 6a6d03f..1f7874b 100644
--- a/src/protocols/secure_channel/PairingSession.cpp
+++ b/src/protocols/secure_channel/PairingSession.cpp
@@ -56,7 +56,7 @@
void PairingSession::Finish()
{
- Transport::PeerAddress address = mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress();
+ Transport::PeerAddress address = mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress();
// Discard the exchange so that Clear() doesn't try closing it. The exchange will handle that.
DiscardExchange();
@@ -79,14 +79,15 @@
void PairingSession::DiscardExchange()
{
- if (mExchangeCtxt != nullptr)
+ if (mExchangeCtxt.HasValue())
{
// Make sure the exchange doesn't try to notify us when it closes,
// since we might be dead by then.
- mExchangeCtxt->SetDelegate(nullptr);
+ mExchangeCtxt.Value()->SetDelegate(nullptr);
+
// Null out mExchangeCtxt so that Clear() doesn't try closing it. The
// exchange will handle that.
- mExchangeCtxt = nullptr;
+ mExchangeCtxt.ClearValue();
}
}
@@ -227,19 +228,18 @@
void PairingSession::Clear()
{
- // Clear acts like the destructor if PairingSession, if it is call during
- // middle of a pairing, means we should terminate the exchange. For normal
- // path, the exchange should already be discarded before calling Clear.
- if (mExchangeCtxt != nullptr)
+ // Clear acts like the destructor of PairingSession. If it is called during
+ // the middle of pairing, that means we should terminate the exchange. For the
+ // normal path, the exchange should already be discarded before calling Clear.
+ if (mExchangeCtxt.HasValue())
{
- // The only time we reach this is if we are getting destroyed in the
- // middle of our handshake. In that case, there is no point trying to
- // do MRP resends of the last message we sent, so abort the exchange
+ // The only time we reach this is when we are getting destroyed in the
+ // middle of our handshake. In that case, there is no point in trying to
+ // do MRP resends of the last message we sent. So, abort the exchange
// instead of just closing it.
- mExchangeCtxt->Abort();
- mExchangeCtxt = nullptr;
+ mExchangeCtxt.Value()->Abort();
+ mExchangeCtxt.ClearValue();
}
-
mSecureSessionHolder.Release();
mPeerSessionId.ClearValue();
mSessionManager = nullptr;
diff --git a/src/protocols/secure_channel/PairingSession.h b/src/protocols/secure_channel/PairingSession.h
index f49dbf7..ea69f65 100644
--- a/src/protocols/secure_channel/PairingSession.h
+++ b/src/protocols/secure_channel/PairingSession.h
@@ -139,14 +139,14 @@
return CHIP_ERROR_INTERNAL;
}
- void SendStatusReport(Messaging::ExchangeContext * exchangeCtxt, uint16_t protocolCode)
+ void SendStatusReport(Optional<Messaging::ExchangeHandle> & exchangeCtxt, uint16_t protocolCode)
{
Protocols::SecureChannel::GeneralStatusCode generalCode = (protocolCode == Protocols::SecureChannel::kProtocolCodeSuccess)
? Protocols::SecureChannel::GeneralStatusCode::kSuccess
: Protocols::SecureChannel::GeneralStatusCode::kFailure;
ChipLogDetail(SecureChannel, "Sending status report. Protocol code %d, exchange %d", protocolCode,
- exchangeCtxt->GetExchangeId());
+ exchangeCtxt.Value()->GetExchangeId());
Protocols::SecureChannel::StatusReport statusReport(generalCode, Protocols::SecureChannel::Id, protocolCode);
@@ -159,7 +159,7 @@
System::PacketBufferHandle msg = bbuf.Finalize();
VerifyOrReturn(!msg.IsNull(), ChipLogError(SecureChannel, "Failed to allocate status report message"));
- CHIP_ERROR err = exchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::StatusReport, std::move(msg));
+ CHIP_ERROR err = exchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::StatusReport, std::move(msg));
if (err != CHIP_NO_ERROR)
{
ChipLogError(SecureChannel, "Failed to send status report message: %" CHIP_ERROR_FORMAT, err.Format());
@@ -238,9 +238,9 @@
SessionHolderWithDelegate mSecureSessionHolder;
// mSessionManager is set if we actually allocate a secure session, so we
// can clean it up later as needed.
- SessionManager * mSessionManager = nullptr;
- Messaging::ExchangeContext * mExchangeCtxt = nullptr;
- SessionEstablishmentDelegate * mDelegate = nullptr;
+ SessionManager * mSessionManager = nullptr;
+ Optional<Messaging::ExchangeHandle> mExchangeCtxt = NullOptional;
+ SessionEstablishmentDelegate * mDelegate = nullptr;
// mLocalMRPConfig is our config which is sent to the other end and used by the peer session.
// mRemoteSessionParams is received from other end and set to our session.