Fix the dnssd code that browses on both the local and srp domains (#32675)

* Fix the dnssd code that browses on both the local and srp domains

- Fixes the UAF issue with the timer

- Resolves critical comments from PR #32631

* Fix GetDomainFromHostName to get the domain name correctly

* Restyled by clang-format

---------

Co-authored-by: Restyled.io <commits@restyled.io>
diff --git a/src/platform/Darwin/DnssdContexts.cpp b/src/platform/Darwin/DnssdContexts.cpp
index c6067a5..a8ae75e 100644
--- a/src/platform/Darwin/DnssdContexts.cpp
+++ b/src/platform/Darwin/DnssdContexts.cpp
@@ -458,27 +458,35 @@
                                std::shared_ptr<uint32_t> && consumerCounterToUse) :
     browseThatCausedResolve(browseCausingResolve)
 {
-    type            = ContextType::Resolve;
-    context         = cbContext;
-    callback        = cb;
-    protocol        = GetProtocol(cbAddressType);
-    instanceName    = instanceNameToResolve;
-    consumerCounter = std::move(consumerCounterToUse);
+    type               = ContextType::Resolve;
+    context            = cbContext;
+    callback           = cb;
+    protocol           = GetProtocol(cbAddressType);
+    instanceName       = instanceNameToResolve;
+    consumerCounter    = std::move(consumerCounterToUse);
+    hasSrpTimerStarted = false;
 }
 
 ResolveContext::ResolveContext(CommissioningResolveDelegate * delegate, chip::Inet::IPAddressType cbAddressType,
                                const char * instanceNameToResolve, std::shared_ptr<uint32_t> && consumerCounterToUse) :
     browseThatCausedResolve(nullptr)
 {
-    type            = ContextType::Resolve;
-    context         = delegate;
-    callback        = nullptr;
-    protocol        = GetProtocol(cbAddressType);
-    instanceName    = instanceNameToResolve;
-    consumerCounter = std::move(consumerCounterToUse);
+    type               = ContextType::Resolve;
+    context            = delegate;
+    callback           = nullptr;
+    protocol           = GetProtocol(cbAddressType);
+    instanceName       = instanceNameToResolve;
+    consumerCounter    = std::move(consumerCounterToUse);
+    hasSrpTimerStarted = false;
 }
 
-ResolveContext::~ResolveContext() {}
+ResolveContext::~ResolveContext()
+{
+    if (this->hasSrpTimerStarted)
+    {
+        CancelSrpTimer(this);
+    }
+}
 
 void ResolveContext::DispatchFailure(const char * errorStr, CHIP_ERROR err)
 {
@@ -526,8 +534,7 @@
 
     for (auto interfaceIndex : priorityInterfaceIndices)
     {
-        // Try finding interfaces for domains kLocalDot and kOpenThreadDot and delete them.
-        if (TryReportingResultsForInterfaceIndex(static_cast<uint32_t>(interfaceIndex), std::string(kLocalDot)))
+        if (TryReportingResultsForInterfaceIndex(interfaceIndex))
         {
             if (needDelete)
             {
@@ -536,7 +543,7 @@
             return;
         }
 
-        if (TryReportingResultsForInterfaceIndex(static_cast<uint32_t>(interfaceIndex), std::string(kOpenThreadDot)))
+        if (TryReportingResultsForInterfaceIndex(interfaceIndex))
         {
             if (needDelete)
             {
@@ -548,7 +555,8 @@
 
     for (auto & interface : interfaces)
     {
-        if (TryReportingResultsForInterfaceIndex(interface.first.first, interface.first.second))
+        auto interfaceId = interface.first.first;
+        if (TryReportingResultsForInterfaceIndex(interfaceId))
         {
             break;
         }
@@ -560,7 +568,7 @@
     }
 }
 
-bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex, std::string domainName)
+bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex)
 {
     if (interfaceIndex == 0)
     {
@@ -568,36 +576,44 @@
         return false;
     }
 
-    std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceIndex, domainName);
-    auto & interface                              = interfaces[interfaceKey];
-    auto & ips                                    = interface.addresses;
-
-    // Some interface may not have any ips, just ignore them.
-    if (ips.size() == 0)
+    std::map<std::pair<uint32_t, std::string>, InterfaceInfo>::iterator iter = interfaces.begin();
+    while (iter != interfaces.end())
     {
-        return false;
-    }
+        std::pair<uint32_t, std::string> key = iter->first;
+        if (key.first == interfaceIndex)
+        {
+            auto & interface = interfaces[key];
+            auto & ips       = interface.addresses;
 
-    ChipLogProgress(Discovery, "Mdns: Resolve success on interface %" PRIu32, interfaceIndex);
+            // Some interface may not have any ips, just ignore them.
+            if (ips.size() == 0)
+            {
+                return false;
+            }
 
-    auto & service = interface.service;
-    auto addresses = Span<Inet::IPAddress>(ips.data(), ips.size());
-    if (nullptr == callback)
-    {
-        auto delegate = static_cast<CommissioningResolveDelegate *>(context);
-        DiscoveredNodeData nodeData;
-        service.ToDiscoveredNodeData(addresses, nodeData);
-        delegate->OnNodeDiscovered(nodeData);
-    }
-    else
-    {
-        callback(context, &service, addresses, CHIP_NO_ERROR);
-    }
+            ChipLogProgress(Discovery, "Mdns: Resolve success on interface %" PRIu32, interfaceIndex);
 
-    return true;
+            auto & service = interface.service;
+            auto addresses = Span<Inet::IPAddress>(ips.data(), ips.size());
+            if (nullptr == callback)
+            {
+                auto delegate = static_cast<CommissioningResolveDelegate *>(context);
+                DiscoveredNodeData nodeData;
+                service.ToDiscoveredNodeData(addresses, nodeData);
+                delegate->OnNodeDiscovered(nodeData);
+            }
+            else
+            {
+                callback(context, &service, addresses, CHIP_NO_ERROR);
+            }
+
+            return true;
+        }
+    }
+    return false;
 }
 
-CHIP_ERROR ResolveContext::OnNewAddress(const std::pair<uint32_t, std::string> interfaceKey, const struct sockaddr * address)
+CHIP_ERROR ResolveContext::OnNewAddress(const std::pair<uint32_t, std::string> & interfaceKey, const struct sockaddr * address)
 {
     // If we don't have any information about this interfaceId, just ignore the
     // address, since it won't be usable anyway without things like the port.
@@ -605,7 +621,7 @@
     // on the system, because the hostnames we are looking up all end in
     // ".local".  In other words, we can get regular DNS results in here, not
     // just DNS-SD ones.
-    uint32_t interfaceId = interfaceKey.first;
+    auto interfaceId = interfaceKey.first;
 
     if (interfaces.find(interfaceKey) == interfaces.end())
     {
@@ -720,8 +736,7 @@
     }
 
     std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceId, domainFromHostname);
-
-    interfaces.insert(std::make_pair(interfaceKey, std::move(interface)));
+    interfaces.insert(std::make_pair(std::move(interfaceKey), std::move(interface)));
 }
 
 bool ResolveContext::HasInterface()
diff --git a/src/platform/Darwin/DnssdImpl.cpp b/src/platform/Darwin/DnssdImpl.cpp
index bf0fc96..94bbdc4 100644
--- a/src/platform/Darwin/DnssdImpl.cpp
+++ b/src/platform/Darwin/DnssdImpl.cpp
@@ -32,8 +32,12 @@
 
 namespace {
 
-// The extra time in milliseconds that we will wait for the resolution on the open thread domain to complete.
-constexpr uint16_t kOpenThreadTimeoutInMsec = 250;
+constexpr char kLocalDot[] = "local.";
+
+constexpr char kSrpDot[] = "default.service.arpa.";
+
+// The extra time in milliseconds that we will wait for the resolution on the srp domain to complete.
+constexpr uint16_t kSrpTimeoutInMsec = 250;
 
 constexpr DNSServiceFlags kRegisterFlags        = kDNSServiceFlagsNoAutoRename;
 constexpr DNSServiceFlags kBrowseFlags          = kDNSServiceFlagsShareConnection;
@@ -144,59 +148,57 @@
 {
     std::string hostname = std::string(hostnameWithDomain);
 
-    // Find the last occurence of '.'
-    size_t last_pos = hostname.find_last_of(".");
-    if (last_pos != std::string::npos)
-    {
-        // Get a substring without last '.'
-        std::string substring = hostname.substr(0, last_pos);
+    // Find the first occurence of '.'
+    size_t first_pos = hostname.find(".");
 
-        // Find the last occurence of '.' in the substring created above.
-        size_t pos = substring.find_last_of(".");
-        if (pos != std::string::npos)
-        {
-            // Return the domain name between the last 2 occurences of '.' including the trailing dot'.'.
-            return std::string(hostname.substr(pos + 1, last_pos));
-        }
-    }
-    return std::string();
+    // if not found, return empty string
+    VerifyOrReturnValue(first_pos != std::string::npos, std::string());
+
+    // Get a substring after the first occurence of '.' to the end of the string
+    return hostname.substr(first_pos + 1, hostname.size());
+}
+
+/**
+ * @brief Callback that is called when the timeout for resolving on the kSrpDot domain has expired.
+ *
+ * @param[in] systemLayer The system layer.
+ * @param[in] callbackContext The context passed to the timer callback.
+ */
+void SrpTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext)
+{
+    ChipLogProgress(Discovery, "Mdns: Timer expired for resolve to complete on the srp domain.");
+    auto sdCtx = static_cast<ResolveContext *>(callbackContext);
+    VerifyOrDie(sdCtx != nullptr);
+    sdCtx->Finalize();
+}
+
+/**
+ * @brief Starts a timer to wait for the resolution on the kSrpDot domain to happen.
+ *
+ * @param[in] timeoutSeconds The timeout in seconds.
+ * @param[in] ResolveContext The resolve context.
+ */
+CHIP_ERROR StartSrpTimer(uint16_t timeoutInMSecs, ResolveContext * ctx)
+{
+    VerifyOrReturnValue(ctx != nullptr, CHIP_ERROR_INCORRECT_STATE);
+    return DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs), SrpTimerExpiredCallback,
+                                                 reinterpret_cast<void *>(ctx));
+}
+
+/**
+ * @brief Cancels the timer that was started to wait for the resolution on the kSrpDot domain to happen.
+ *
+ * @param[in] ResolveContext The resolve context.
+ */
+void CancelSrpTimer(ResolveContext * ctx)
+{
+    DeviceLayer::SystemLayer().CancelTimer(SrpTimerExpiredCallback, reinterpret_cast<void *>(ctx));
 }
 
 Global<MdnsContexts> MdnsContexts::sInstance;
 
 namespace {
 
-/**
- * @brief Callback that is called when the timeout for resolving on the kOpenThreadDot domain has expired.
- *
- * @param[in] systemLayer The system layer.
- * @param[in] callbackContext The context passed to the timer callback.
- */
-void OpenThreadTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext)
-{
-    ChipLogProgress(Discovery, "Mdns: Timer expired for resolve to complete on the open thread domain.");
-    auto sdCtx = static_cast<ResolveContext *>(callbackContext);
-    VerifyOrDie(sdCtx != nullptr);
-
-    if (sdCtx->hasOpenThreadTimerStarted)
-    {
-        sdCtx->Finalize();
-    }
-}
-
-/**
- * @brief Starts a timer to wait for the resolution on the kOpenThreadDot domain to happen.
- *
- * @param[in] timeoutSeconds The timeout in seconds.
- * @param[in] ResolveContext The resolve context.
- */
-void StartOpenThreadTimer(uint16_t timeoutInMSecs, ResolveContext * ctx)
-{
-    VerifyOrReturn(ctx != nullptr, ChipLogError(Discovery, "Can't schedule open thread timer since context is null"));
-    DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs), OpenThreadTimerExpiredCallback,
-                                          reinterpret_cast<void *>(ctx));
-}
-
 static void OnRegister(DNSServiceRef sdRef, DNSServiceFlags flags, DNSServiceErrorType err, const char * name, const char * type,
                        const char * domain, void * context)
 {
@@ -248,17 +250,17 @@
     auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
     VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
 
-    // We will browse on both the local domain and the open thread domain.
+    // We will browse on both the local domain and the srp domain.
     ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kLocalDot);
 
     auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
     err             = DNSServiceBrowse(&sdRefLocal, kBrowseFlags, interfaceId, type, kLocalDot, OnBrowse, sdCtx);
     VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
 
-    ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kOpenThreadDot);
+    ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kSrpDot);
 
-    DNSServiceRef sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
-    err = DNSServiceBrowse(&sdRefOpenThread, kBrowseFlags, interfaceId, type, kOpenThreadDot, OnBrowse, sdCtx);
+    auto sdRefSrp = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
+    err           = DNSServiceBrowse(&sdRefSrp, kBrowseFlags, interfaceId, type, kSrpDot, OnBrowse, sdCtx);
     VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
 
     return MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);
@@ -307,25 +309,37 @@
     {
         VerifyOrReturn(sdCtx->HasAddress(), sdCtx->Finalize(kDNSServiceErr_BadState));
 
-        if (domainName.compare(kOpenThreadDot) == 0)
+        if (domainName.compare(kSrpDot) == 0)
         {
-            ChipLogProgress(Discovery, "Mdns: Resolve completed on the open thread domain.");
+            ChipLogProgress(Discovery, "Mdns: Resolve completed on the srp domain.");
+
+            // Cancel the timer if one has been started
+            if (sdCtx->hasSrpTimerStarted)
+            {
+                CancelSrpTimer(sdCtx);
+            }
             sdCtx->Finalize();
         }
         else if (domainName.compare(kLocalDot) == 0)
         {
-            ChipLogProgress(
-                Discovery,
-                "Mdns: Resolve completed on the local domain. Starting a timer for the open thread resolve to come back");
+            ChipLogProgress(Discovery,
+                            "Mdns: Resolve completed on the local domain. Starting a timer for the srp resolve to come back");
 
-            // Usually the resolution on the local domain is quicker than on the open thread domain. We would like to give the
-            // resolution on the open thread domain around 250 millisecs more to give it a chance to resolve before finalizing
+            // Usually the resolution on the local domain is quicker than on the srp domain. We would like to give the
+            // resolution on the srp domain around 250 millisecs more to give it a chance to resolve before finalizing
             // the resolution.
-            if (!sdCtx->hasOpenThreadTimerStarted)
+            if (!sdCtx->hasSrpTimerStarted)
             {
-                // Schedule a timer to allow the resolve on OpenThread domain to complete.
-                StartOpenThreadTimer(kOpenThreadTimeoutInMsec, sdCtx);
-                sdCtx->hasOpenThreadTimerStarted = true;
+                // Schedule a timer to allow the resolve on Srp domain to complete.
+                CHIP_ERROR error = StartSrpTimer(kSrpTimeoutInMsec, sdCtx);
+
+                // If the timer fails to start, finalize the context and return.
+                if (error != CHIP_NO_ERROR)
+                {
+                    sdCtx->Finalize();
+                    return;
+                }
+                sdCtx->hasSrpTimerStarted = true;
             }
         }
     }
@@ -367,8 +381,7 @@
         if (!sdCtx->isResolveRequested)
         {
             GetAddrInfo(sdCtx);
-            sdCtx->isResolveRequested        = true;
-            sdCtx->hasOpenThreadTimerStarted = false;
+            sdCtx->isResolveRequested = true;
         }
     }
 }
@@ -382,13 +395,13 @@
     auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
     VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
 
-    // Similar to browse, will try to resolve using both the local domain and the open thread domain.
+    // Similar to browse, will try to resolve using both the local domain and the srp domain.
     auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
     err             = DNSServiceResolve(&sdRefLocal, kResolveFlags, interfaceId, name, type, kLocalDot, OnResolve, sdCtx);
     VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
 
-    auto sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
-    err = DNSServiceResolve(&sdRefOpenThread, kResolveFlags, interfaceId, name, type, kOpenThreadDot, OnResolve, sdCtx);
+    auto sdRefSrp = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
+    err           = DNSServiceResolve(&sdRefSrp, kResolveFlags, interfaceId, name, type, kSrpDot, OnResolve, sdCtx);
     VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
 
     auto retval = MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);
diff --git a/src/platform/Darwin/DnssdImpl.h b/src/platform/Darwin/DnssdImpl.h
index 713a465..8a98530 100644
--- a/src/platform/Darwin/DnssdImpl.h
+++ b/src/platform/Darwin/DnssdImpl.h
@@ -27,15 +27,17 @@
 #include <string>
 #include <vector>
 
-constexpr char kLocalDot[] = "local.";
-
-constexpr char kOpenThreadDot[] = "default.service.arpa.";
-
 namespace chip {
 namespace Dnssd {
 
+struct BrowseWithDelegateContext;
+struct RegisterContext;
+struct ResolveContext;
+
 std::string GetDomainFromHostName(const char * hostname);
 
+void CancelSrpTimer(ResolveContext * ctx);
+
 enum class ContextType
 {
     Register,
@@ -62,10 +64,6 @@
     CHIP_ERROR FinalizeInternal(const char * errorStr, CHIP_ERROR err);
 };
 
-struct BrowseWithDelegateContext;
-struct RegisterContext;
-struct ResolveContext;
-
 class MdnsContexts
 {
 public:
@@ -239,7 +237,7 @@
     bool isResolveRequested = false;
     std::shared_ptr<uint32_t> consumerCounter;
     BrowseContext * const browseThatCausedResolve; // Can be null
-    bool hasOpenThreadTimerStarted = false;
+    bool hasSrpTimerStarted = false;
 
     // browseCausingResolve can be null.
     ResolveContext(void * cbContext, DnssdResolveCallback cb, chip::Inet::IPAddressType cbAddressType,
@@ -252,7 +250,7 @@
     void DispatchFailure(const char * errorStr, CHIP_ERROR err) override;
     void DispatchSuccess() override;
 
-    CHIP_ERROR OnNewAddress(const std::pair<uint32_t, std::string> interfaceKey, const struct sockaddr * address);
+    CHIP_ERROR OnNewAddress(const std::pair<uint32_t, std::string> & interfaceKey, const struct sockaddr * address);
     bool HasAddress();
 
     void OnNewInterface(uint32_t interfaceId, const char * fullname, const char * hostname, uint16_t port, uint16_t txtLen,
@@ -266,7 +264,7 @@
      * Returns true if information was reported, false if not (e.g. if there
      * were no IP addresses, etc).
      */
-    bool TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex, std::string domainName);
+    bool TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex);
 };
 
 } // namespace Dnssd