Allow ReliableMessageMgr to perform address relookup on first message fail (#21768)

* OnFirstMessageDeliveryFailed is removed from SessionDelegate
* CASESessionManager will implement delegate to allow updating Session transport peer address.
*During CASESessionManager::Init it will register itself with ReliableMessageMgr, so that ReliableMessageMgr has a callback.
* CASESessionManager create a different OperationalSessionSetup to perform address lookup and update that is separate
  from OperationalSessionSetup used for establishing a connection.

Co-authored-by: Boris Zbarsky <bzbarsky@apple.com>
diff --git a/src/app/CASESessionManager.cpp b/src/app/CASESessionManager.cpp
index 7ffa85d..7e38ccd 100644
--- a/src/app/CASESessionManager.cpp
+++ b/src/app/CASESessionManager.cpp
@@ -25,6 +25,7 @@
 {
     ReturnErrorOnFailure(params.sessionInitParams.Validate());
     mConfig = params;
+    params.sessionInitParams.exchangeMgr->GetReliableMessageMgr()->RegisterSessionUpdateDelegate(this);
     return AddressResolve::Resolver::Instance().Init(systemLayer);
 }
 
@@ -34,7 +35,8 @@
     ChipLogDetail(CASESessionManager, "FindOrEstablishSession: PeerId = [%d:" ChipLogFormatX64 "]", peerId.GetFabricIndex(),
                   ChipLogValueX64(peerId.GetNodeId()));
 
-    OperationalSessionSetup * session = FindExistingSessionSetup(peerId);
+    bool forAddressUpdate             = false;
+    OperationalSessionSetup * session = FindExistingSessionSetup(peerId, forAddressUpdate);
     if (session == nullptr)
     {
         ChipLogDetail(CASESessionManager, "FindOrEstablishSession: No existing OperationalSessionSetup instance found");
@@ -78,9 +80,28 @@
     return CHIP_NO_ERROR;
 }
 
-OperationalSessionSetup * CASESessionManager::FindExistingSessionSetup(const ScopedNodeId & peerId) const
+void CASESessionManager::UpdatePeerAddress(ScopedNodeId peerId)
 {
-    return mConfig.sessionSetupPool->FindSessionSetup(peerId);
+    bool forAddressUpdate             = true;
+    OperationalSessionSetup * session = FindExistingSessionSetup(peerId, forAddressUpdate);
+    if (session == nullptr)
+    {
+        ChipLogDetail(CASESessionManager, "UpdatePeerAddress: No existing OperationalSessionSetup instance found");
+
+        session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, peerId, this);
+        if (session == nullptr)
+        {
+            ChipLogDetail(CASESessionManager, "UpdatePeerAddress: Failed to allocate OperationalSessionSetup instance");
+            return;
+        }
+    }
+
+    session->PerformAddressUpdate();
+}
+
+OperationalSessionSetup * CASESessionManager::FindExistingSessionSetup(const ScopedNodeId & peerId, bool forAddressUpdate) const
+{
+    return mConfig.sessionSetupPool->FindSessionSetup(peerId, forAddressUpdate);
 }
 
 Optional<SessionHandle> CASESessionManager::FindExistingSession(const ScopedNodeId & peerId) const
@@ -89,7 +110,7 @@
                                                                               MakeOptional(Transport::SecureSession::Type::kCASE));
 }
 
-void CASESessionManager::ReleaseSession(OperationalSessionSetup * session) const
+void CASESessionManager::ReleaseSession(OperationalSessionSetup * session)
 {
     if (session != nullptr)
     {
diff --git a/src/app/CASESessionManager.h b/src/app/CASESessionManager.h
index ddc68b3..d5243a9 100644
--- a/src/app/CASESessionManager.h
+++ b/src/app/CASESessionManager.h
@@ -26,6 +26,7 @@
 #include <lib/support/Pool.h>
 #include <platform/CHIPDeviceLayer.h>
 #include <transport/SessionDelegate.h>
+#include <transport/SessionUpdateDelegate.h>
 
 #include <lib/dnssd/ResolverProxy.h>
 
@@ -45,11 +46,17 @@
  * 4. During session establishment, trigger node ID resolution (if needed), and update the DNS-SD cache (if resolution is
  * successful)
  */
-class CASESessionManager : public OperationalSessionReleaseDelegate
+class CASESessionManager : public OperationalSessionReleaseDelegate, public SessionUpdateDelegate
 {
 public:
     CASESessionManager() = default;
-    virtual ~CASESessionManager() {}
+    virtual ~CASESessionManager()
+    {
+        if (mConfig.sessionInitParams.Validate() == CHIP_NO_ERROR)
+        {
+            mConfig.sessionInitParams.exchangeMgr->GetReliableMessageMgr()->RegisterSessionUpdateDelegate(nullptr);
+        }
+    }
 
     CHIP_ERROR Init(chip::System::Layer * systemLayer, const CASESessionManagerConfig & params);
     void Shutdown() {}
@@ -71,9 +78,9 @@
     void FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback<OnDeviceConnected> * onConnection,
                                 Callback::Callback<OnDeviceConnectionFailure> * onFailure);
 
-    OperationalSessionSetup * FindExistingSessionSetup(const ScopedNodeId & peerId) const;
+    OperationalSessionSetup * FindExistingSessionSetup(const ScopedNodeId & peerId, bool forAddressUpdate = false) const;
 
-    void ReleaseSession(const ScopedNodeId & peerId) override;
+    void ReleaseSession(const ScopedNodeId & peerId);
 
     void ReleaseSessionsForFabric(FabricIndex fabricIndex);
 
@@ -89,11 +96,15 @@
      */
     CHIP_ERROR GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr);
 
+    //////////// OperationalSessionReleaseDelegate Implementation ///////////////
+    void ReleaseSession(OperationalSessionSetup * device) override;
+
+    //////////// SessionUpdateDelegate Implementation ///////////////
+    void UpdatePeerAddress(ScopedNodeId peerId) override;
+
 private:
     Optional<SessionHandle> FindExistingSession(const ScopedNodeId & peerId) const;
 
-    void ReleaseSession(OperationalSessionSetup * device) const;
-
     CASESessionManagerConfig mConfig;
 };
 
diff --git a/src/app/OperationalSessionSetup.cpp b/src/app/OperationalSessionSetup.cpp
index ab58a58..8cafbe9 100644
--- a/src/app/OperationalSessionSetup.cpp
+++ b/src/app/OperationalSessionSetup.cpp
@@ -191,27 +191,29 @@
     if (mState == State::ResolvingAddress)
     {
         MoveToState(State::HasAddress);
-        err = EstablishConnection();
-        if (err != CHIP_NO_ERROR)
+        mInitParams.sessionManager->UpdateAllSessionsPeerAddress(mPeerId, addr);
+        if (!mPerformingAddressUpdate)
         {
-            DequeueConnectionCallbacks(err);
-            // Do not touch `this` instance anymore; it has been destroyed in DequeueConnectionCallbacks.
-            return;
-        }
-    }
-    else
-    {
-        if (!mSecureSession)
-        {
-            // Nothing needs to be done here.  It's not an error to not have a
-            // secureSession.  For one thing, we could have gotten a different
-            // UpdateAddress already and that caused connections to be torn down and
-            // whatnot.
+            err = EstablishConnection();
+            if (err != CHIP_NO_ERROR)
+            {
+                DequeueConnectionCallbacks(err);
+                // Do not touch `this` instance anymore; it has been destroyed in DequeueConnectionCallbacks.
+                return;
+            }
+            // We expect to get a callback via OnSessionEstablished or OnSessionEstablishmentError to continue
+            // the state machine forward.
             return;
         }
 
-        mSecureSession.Get().Value()->AsSecureSession()->SetPeerAddress(addr);
+        DequeueConnectionCallbacks(CHIP_NO_ERROR);
+        // Do not touch `this` instance anymore; it has been destroyed in DequeueConnectionCallbacks.
+        return;
     }
+
+    ChipLogError(Controller, "Received UpdateDeviceData in incorrect state");
+    DequeueConnectionCallbacks(CHIP_ERROR_INCORRECT_STATE);
+    // Do not touch `this` instance anymore; it has been destroyed in DequeueConnectionCallbacks.
 }
 
 CHIP_ERROR OperationalSessionSetup::EstablishConnection()
@@ -297,7 +299,7 @@
         }
     }
 
-    releaseDelegate->ReleaseSession(peerId);
+    releaseDelegate->ReleaseSession(this);
 }
 
 void OperationalSessionSetup::OnSessionEstablishmentError(CHIP_ERROR error)
@@ -364,11 +366,6 @@
     MoveToState(State::HasAddress);
 }
 
-void OperationalSessionSetup::OnFirstMessageDeliveryFailed()
-{
-    LookupPeerAddress();
-}
-
 void OperationalSessionSetup::OnSessionHang()
 {
     Disconnect();
@@ -423,6 +420,32 @@
     return Resolver::Instance().LookupNode(request, mAddressLookupHandle);
 }
 
+void OperationalSessionSetup::PerformAddressUpdate()
+{
+    if (mPerformingAddressUpdate)
+    {
+        // We are already in the middle of a lookup from a previous call to
+        // PerformAddressUpdate. In that case we will just exit right away as
+        // we are already looking to update the results from the previous lookup.
+        return;
+    }
+
+    // We must be newly-allocated to handle this address lookup, so must be in the NeedsAddress state.
+    VerifyOrDie(mState == State::NeedsAddress);
+
+    // We are doing an address lookup whether we have an active session for this peer or not.
+    mPerformingAddressUpdate = true;
+    MoveToState(State::ResolvingAddress);
+    CHIP_ERROR err = LookupPeerAddress();
+    if (err != CHIP_NO_ERROR)
+    {
+        ChipLogError(Controller, "PerformAddressUpdate could not perform lookup");
+        DequeueConnectionCallbacks(err);
+        // Do not touch `this` instance anymore; it has been destroyed in DequeueConnectionCallbacks.
+        return;
+    }
+}
+
 void OperationalSessionSetup::OnNodeAddressResolved(const PeerId & peerId, const ResolveResult & result)
 {
     UpdateDeviceData(result.address, result.mrpRemoteConfig);
@@ -438,6 +461,7 @@
         MoveToState(State::NeedsAddress);
     }
 
+    // No need to set mPerformingAddressUpdate to false since call below releases `this`.
     DequeueConnectionCallbacks(reason);
     // Do not touch `this` instance anymore; it has been destroyed in DequeueConnectionCallbacks.
 }
diff --git a/src/app/OperationalSessionSetup.h b/src/app/OperationalSessionSetup.h
index a821d04..ccea484 100644
--- a/src/app/OperationalSessionSetup.h
+++ b/src/app/OperationalSessionSetup.h
@@ -70,6 +70,8 @@
     }
 };
 
+class OperationalSessionSetup;
+
 /**
  * @brief Delegate provided when creating OperationalSessionSetup.
  *
@@ -80,10 +82,8 @@
 class OperationalSessionReleaseDelegate
 {
 public:
-    virtual ~OperationalSessionReleaseDelegate() = default;
-    // TODO Issue #20452: Once cleanup from #20452 takes place we can provide OperationalSessionSetup *
-    // instead of ScopedNodeId here.
-    virtual void ReleaseSession(const ScopedNodeId & peerId) = 0;
+    virtual ~OperationalSessionReleaseDelegate()                        = default;
+    virtual void ReleaseSession(OperationalSessionSetup * sessionSetup) = 0;
 };
 
 /**
@@ -202,6 +202,8 @@
 
     bool IsConnecting() const { return mState == State::Connecting; }
 
+    bool IsForAddressUpdate() const { return mPerformingAddressUpdate; }
+
     /**
      * IsResolvingAddress returns true if we are doing an address resolution
      * that needs to happen before we can establish CASE.  We can be in the
@@ -219,8 +221,6 @@
 
     // Called when a connection is closing. The object releases all resources associated with the connection.
     void OnSessionReleased() override;
-    // Called when a message is not acked within first retrans timer, try to refresh the peer address
-    void OnFirstMessageDeliveryFailed() override;
     // Called when a connection is hanging. Try to re-establish another session, and shift to the new session when done, the
     // original session won't be touched during the period.
     void OnSessionHang() override;
@@ -261,6 +261,8 @@
      */
     CHIP_ERROR LookupPeerAddress();
 
+    void PerformAddressUpdate();
+
     // AddressResolve::NodeListener - notifications when dnssd finds a node IP address
     void OnNodeAddressResolved(const PeerId & peerId, const AddressResolve::ResolveResult & result) override;
     void OnNodeAddressResolutionFailed(const PeerId & peerId, CHIP_ERROR reason) override;
@@ -304,6 +306,8 @@
 
     ReliableMessageProtocolConfig mRemoteMRPConfig = GetDefaultMRPConfig();
 
+    bool mPerformingAddressUpdate = false;
+
     CHIP_ERROR EstablishConnection();
 
     /*
diff --git a/src/app/OperationalSessionSetupPool.h b/src/app/OperationalSessionSetupPool.h
index bd80276..50d1fbd 100644
--- a/src/app/OperationalSessionSetupPool.h
+++ b/src/app/OperationalSessionSetupPool.h
@@ -32,7 +32,7 @@
 
     virtual void Release(OperationalSessionSetup * device) = 0;
 
-    virtual OperationalSessionSetup * FindSessionSetup(ScopedNodeId peerId) = 0;
+    virtual OperationalSessionSetup * FindSessionSetup(ScopedNodeId peerId, bool forAddressUpdate) = 0;
 
     virtual void ReleaseAllSessionSetupsForFabric(FabricIndex fabricIndex) = 0;
 
@@ -55,11 +55,11 @@
 
     void Release(OperationalSessionSetup * device) override { mSessionSetupPool.ReleaseObject(device); }
 
-    OperationalSessionSetup * FindSessionSetup(ScopedNodeId peerId) override
+    OperationalSessionSetup * FindSessionSetup(ScopedNodeId peerId, bool forAddressUpdate) override
     {
         OperationalSessionSetup * foundDevice = nullptr;
         mSessionSetupPool.ForEachActiveObject([&](auto * activeSetup) {
-            if (activeSetup->GetPeerId() == peerId)
+            if (activeSetup->GetPeerId() == peerId && activeSetup->IsForAddressUpdate() == forAddressUpdate)
             {
                 foundDevice = activeSetup;
                 return Loop::Break;
diff --git a/src/messaging/ReliableMessageMgr.cpp b/src/messaging/ReliableMessageMgr.cpp
index 96a8563..d37c186 100644
--- a/src/messaging/ReliableMessageMgr.cpp
+++ b/src/messaging/ReliableMessageMgr.cpp
@@ -320,9 +320,9 @@
         if (exchangeMgr)
         {
             // After the first failure notify session manager to refresh device data
-            if (entry->sendCount == 1)
+            if (entry->sendCount == 1 && mSessionUpdateDelegate != nullptr)
             {
-                entry->ec->GetSessionHandle()->DispatchSessionEvent(&SessionDelegate::OnFirstMessageDeliveryFailed);
+                mSessionUpdateDelegate->UpdatePeerAddress(entry->ec->GetSessionHandle()->GetPeer());
             }
         }
     }
@@ -409,6 +409,11 @@
     mSystemLayer->CancelTimer(Timeout, this);
 }
 
+void ReliableMessageMgr::RegisterSessionUpdateDelegate(SessionUpdateDelegate * sessionUpdateDelegate)
+{
+    mSessionUpdateDelegate = sessionUpdateDelegate;
+}
+
 #if CHIP_CONFIG_TEST
 int ReliableMessageMgr::TestGetCountRetransTable()
 {
diff --git a/src/messaging/ReliableMessageMgr.h b/src/messaging/ReliableMessageMgr.h
index 889d236..9b9809f 100644
--- a/src/messaging/ReliableMessageMgr.h
+++ b/src/messaging/ReliableMessageMgr.h
@@ -33,6 +33,7 @@
 #include <messaging/ReliableMessageProtocolConfig.h>
 #include <system/SystemLayer.h>
 #include <system/SystemPacketBuffer.h>
+#include <transport/SessionUpdateDelegate.h>
 #include <transport/raw/MessageHeader.h>
 
 namespace chip {
@@ -171,6 +172,16 @@
      */
     void StopTimer();
 
+    /**
+     *  Registers a delegate to perform an address lookup and update all active sessions.
+     *
+     *  @param[in] sessionUpdateDelegate - Pointer to delegate to perform address lookup
+     *             that will update all active session. A null pointer is allowed if you
+     *             no longer have a valid delegate.
+     *
+     */
+    void RegisterSessionUpdateDelegate(SessionUpdateDelegate * sessionUpdateDelegate);
+
 #if CHIP_CONFIG_TEST
     // Functions for testing
     int TestGetCountRetransTable();
@@ -194,6 +205,8 @@
 
     // ReliableMessageProtocol Global tables for timer context
     ObjectPool<RetransTableEntry, CHIP_CONFIG_RMP_RETRANS_TABLE_SIZE> mRetransTable;
+
+    SessionUpdateDelegate * mSessionUpdateDelegate = nullptr;
 };
 
 } // namespace Messaging
diff --git a/src/transport/SessionDelegate.h b/src/transport/SessionDelegate.h
index 5eedd6a..b7dbb4b 100644
--- a/src/transport/SessionDelegate.h
+++ b/src/transport/SessionDelegate.h
@@ -62,14 +62,6 @@
 
     /**
      * @brief
-     *   Called when the first message delivery in an exchange fails, so actions aiming to recover connection can be performed.
-     *
-     *   Note: the implementation must not do anything that will destroy the session or change the SessionHolder.
-     */
-    virtual void OnFirstMessageDeliveryFailed() {}
-
-    /**
-     * @brief
      *   Called when a session is unresponsive for a while (detected by MRP)
      *
      *   Note: the implementation must not do anything that will destroy the session or change the SessionHolder.
diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp
index 5fa99e0..e388f13 100644
--- a/src/transport/SessionManager.cpp
+++ b/src/transport/SessionManager.cpp
@@ -452,7 +452,7 @@
         if (session->IsActiveSession() && session->GetPeer() == node &&
             (!type.HasValue() || type.Value() == session->GetSecureSessionType()))
         {
-            session->AsSecureSession()->MarkAsDefunct();
+            session->MarkAsDefunct();
             found = true;
         }
         return Loop::Continue;
@@ -461,6 +461,19 @@
     return found;
 }
 
+void SessionManager::UpdateAllSessionsPeerAddress(const ScopedNodeId & node, const Transport::PeerAddress & addr)
+{
+    mSecureSessions.ForEachSession([&node, &addr](auto session) {
+        // Arguably we should only be updating active and defunct sessions, but there is no harm
+        // in updating evicted sessions.
+        if (session->GetPeer() == node && Transport::SecureSession::Type::kCASE == session->GetSecureSessionType())
+        {
+            session->SetPeerAddress(addr);
+        }
+        return Loop::Continue;
+    });
+}
+
 Optional<SessionHandle> SessionManager::AllocateSession(SecureSession::Type secureSessionType,
                                                         const ScopedNodeId & sessionEvictionHint)
 {
diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h
index 9454555..c40a8af 100644
--- a/src/transport/SessionManager.h
+++ b/src/transport/SessionManager.h
@@ -364,6 +364,15 @@
 
     /**
      * @brief
+     *   Update all CASE sessions that match `node` with the provided transport peer address.
+     *
+     * @param node    Scoped node ID of the active sessions we want to update.
+     * @param addr    Transport peer address that we want to update to.
+     */
+    void UpdateAllSessionsPeerAddress(const ScopedNodeId & node, const Transport::PeerAddress & addr);
+
+    /**
+     * @brief
      *   Return the System Layer pointer used by current SessionManager.
      */
     System::Layer * SystemLayer() { return mSystemLayer; }
diff --git a/src/transport/SessionUpdateDelegate.h b/src/transport/SessionUpdateDelegate.h
new file mode 100644
index 0000000..3dc01f8
--- /dev/null
+++ b/src/transport/SessionUpdateDelegate.h
@@ -0,0 +1,36 @@
+/*
+ *    Copyright (c) 2022 Project CHIP Authors
+ *
+ *    Licensed under the Apache License, Version 2.0 (the "License");
+ *    you may not use this file except in compliance with the License.
+ *    You may obtain a copy of the License at
+ *
+ *        http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *    Unless required by applicable law or agreed to in writing, software
+ *    distributed under the License is distributed on an "AS IS" BASIS,
+ *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *    See the License for the specific language governing permissions and
+ *    limitations under the License.
+ */
+
+#pragma once
+
+#include <lib/core/ScopedNodeId.h>
+
+namespace chip {
+
+/**
+ * @brief
+ *   Delegate interface that will be notified by ReliableMessageMgr when an exchange
+ *   fails to deliver the first message.
+ */
+class DLL_EXPORT SessionUpdateDelegate
+{
+public:
+    virtual ~SessionUpdateDelegate() {}
+
+    virtual void UpdatePeerAddress(ScopedNodeId peerId) = 0;
+};
+
+} // namespace chip