Fix the dnssd code that browses on both the local and srp domains (#32675)
* Fix the dnssd code that browses on both the local and srp domains
- Fixes the UAF issue with the timer
- Resolves critical comments from PR #32631
* Fix GetDomainFromHostName to get the domain name correctly
* Restyled by clang-format
---------
Co-authored-by: Restyled.io <commits@restyled.io>
diff --git a/src/platform/Darwin/DnssdContexts.cpp b/src/platform/Darwin/DnssdContexts.cpp
index c6067a5..a8ae75e 100644
--- a/src/platform/Darwin/DnssdContexts.cpp
+++ b/src/platform/Darwin/DnssdContexts.cpp
@@ -458,27 +458,35 @@
std::shared_ptr<uint32_t> && consumerCounterToUse) :
browseThatCausedResolve(browseCausingResolve)
{
- type = ContextType::Resolve;
- context = cbContext;
- callback = cb;
- protocol = GetProtocol(cbAddressType);
- instanceName = instanceNameToResolve;
- consumerCounter = std::move(consumerCounterToUse);
+ type = ContextType::Resolve;
+ context = cbContext;
+ callback = cb;
+ protocol = GetProtocol(cbAddressType);
+ instanceName = instanceNameToResolve;
+ consumerCounter = std::move(consumerCounterToUse);
+ hasSrpTimerStarted = false;
}
ResolveContext::ResolveContext(CommissioningResolveDelegate * delegate, chip::Inet::IPAddressType cbAddressType,
const char * instanceNameToResolve, std::shared_ptr<uint32_t> && consumerCounterToUse) :
browseThatCausedResolve(nullptr)
{
- type = ContextType::Resolve;
- context = delegate;
- callback = nullptr;
- protocol = GetProtocol(cbAddressType);
- instanceName = instanceNameToResolve;
- consumerCounter = std::move(consumerCounterToUse);
+ type = ContextType::Resolve;
+ context = delegate;
+ callback = nullptr;
+ protocol = GetProtocol(cbAddressType);
+ instanceName = instanceNameToResolve;
+ consumerCounter = std::move(consumerCounterToUse);
+ hasSrpTimerStarted = false;
}
-ResolveContext::~ResolveContext() {}
+ResolveContext::~ResolveContext()
+{
+ if (this->hasSrpTimerStarted)
+ {
+ CancelSrpTimer(this);
+ }
+}
void ResolveContext::DispatchFailure(const char * errorStr, CHIP_ERROR err)
{
@@ -526,8 +534,7 @@
for (auto interfaceIndex : priorityInterfaceIndices)
{
- // Try finding interfaces for domains kLocalDot and kOpenThreadDot and delete them.
- if (TryReportingResultsForInterfaceIndex(static_cast<uint32_t>(interfaceIndex), std::string(kLocalDot)))
+ if (TryReportingResultsForInterfaceIndex(interfaceIndex))
{
if (needDelete)
{
@@ -536,7 +543,7 @@
return;
}
- if (TryReportingResultsForInterfaceIndex(static_cast<uint32_t>(interfaceIndex), std::string(kOpenThreadDot)))
+ if (TryReportingResultsForInterfaceIndex(interfaceIndex))
{
if (needDelete)
{
@@ -548,7 +555,8 @@
for (auto & interface : interfaces)
{
- if (TryReportingResultsForInterfaceIndex(interface.first.first, interface.first.second))
+ auto interfaceId = interface.first.first;
+ if (TryReportingResultsForInterfaceIndex(interfaceId))
{
break;
}
@@ -560,7 +568,7 @@
}
}
-bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex, std::string domainName)
+bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex)
{
if (interfaceIndex == 0)
{
@@ -568,36 +576,44 @@
return false;
}
- std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceIndex, domainName);
- auto & interface = interfaces[interfaceKey];
- auto & ips = interface.addresses;
-
- // Some interface may not have any ips, just ignore them.
- if (ips.size() == 0)
+ std::map<std::pair<uint32_t, std::string>, InterfaceInfo>::iterator iter = interfaces.begin();
+ while (iter != interfaces.end())
{
- return false;
- }
+ std::pair<uint32_t, std::string> key = iter->first;
+ if (key.first == interfaceIndex)
+ {
+ auto & interface = interfaces[key];
+ auto & ips = interface.addresses;
- ChipLogProgress(Discovery, "Mdns: Resolve success on interface %" PRIu32, interfaceIndex);
+ // Some interface may not have any ips, just ignore them.
+ if (ips.size() == 0)
+ {
+ return false;
+ }
- auto & service = interface.service;
- auto addresses = Span<Inet::IPAddress>(ips.data(), ips.size());
- if (nullptr == callback)
- {
- auto delegate = static_cast<CommissioningResolveDelegate *>(context);
- DiscoveredNodeData nodeData;
- service.ToDiscoveredNodeData(addresses, nodeData);
- delegate->OnNodeDiscovered(nodeData);
- }
- else
- {
- callback(context, &service, addresses, CHIP_NO_ERROR);
- }
+ ChipLogProgress(Discovery, "Mdns: Resolve success on interface %" PRIu32, interfaceIndex);
- return true;
+ auto & service = interface.service;
+ auto addresses = Span<Inet::IPAddress>(ips.data(), ips.size());
+ if (nullptr == callback)
+ {
+ auto delegate = static_cast<CommissioningResolveDelegate *>(context);
+ DiscoveredNodeData nodeData;
+ service.ToDiscoveredNodeData(addresses, nodeData);
+ delegate->OnNodeDiscovered(nodeData);
+ }
+ else
+ {
+ callback(context, &service, addresses, CHIP_NO_ERROR);
+ }
+
+ return true;
+ }
+ }
+ return false;
}
-CHIP_ERROR ResolveContext::OnNewAddress(const std::pair<uint32_t, std::string> interfaceKey, const struct sockaddr * address)
+CHIP_ERROR ResolveContext::OnNewAddress(const std::pair<uint32_t, std::string> & interfaceKey, const struct sockaddr * address)
{
// If we don't have any information about this interfaceId, just ignore the
// address, since it won't be usable anyway without things like the port.
@@ -605,7 +621,7 @@
// on the system, because the hostnames we are looking up all end in
// ".local". In other words, we can get regular DNS results in here, not
// just DNS-SD ones.
- uint32_t interfaceId = interfaceKey.first;
+ auto interfaceId = interfaceKey.first;
if (interfaces.find(interfaceKey) == interfaces.end())
{
@@ -720,8 +736,7 @@
}
std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceId, domainFromHostname);
-
- interfaces.insert(std::make_pair(interfaceKey, std::move(interface)));
+ interfaces.insert(std::make_pair(std::move(interfaceKey), std::move(interface)));
}
bool ResolveContext::HasInterface()
diff --git a/src/platform/Darwin/DnssdImpl.cpp b/src/platform/Darwin/DnssdImpl.cpp
index bf0fc96..94bbdc4 100644
--- a/src/platform/Darwin/DnssdImpl.cpp
+++ b/src/platform/Darwin/DnssdImpl.cpp
@@ -32,8 +32,12 @@
namespace {
-// The extra time in milliseconds that we will wait for the resolution on the open thread domain to complete.
-constexpr uint16_t kOpenThreadTimeoutInMsec = 250;
+constexpr char kLocalDot[] = "local.";
+
+constexpr char kSrpDot[] = "default.service.arpa.";
+
+// The extra time in milliseconds that we will wait for the resolution on the srp domain to complete.
+constexpr uint16_t kSrpTimeoutInMsec = 250;
constexpr DNSServiceFlags kRegisterFlags = kDNSServiceFlagsNoAutoRename;
constexpr DNSServiceFlags kBrowseFlags = kDNSServiceFlagsShareConnection;
@@ -144,59 +148,57 @@
{
std::string hostname = std::string(hostnameWithDomain);
- // Find the last occurence of '.'
- size_t last_pos = hostname.find_last_of(".");
- if (last_pos != std::string::npos)
- {
- // Get a substring without last '.'
- std::string substring = hostname.substr(0, last_pos);
+ // Find the first occurence of '.'
+ size_t first_pos = hostname.find(".");
- // Find the last occurence of '.' in the substring created above.
- size_t pos = substring.find_last_of(".");
- if (pos != std::string::npos)
- {
- // Return the domain name between the last 2 occurences of '.' including the trailing dot'.'.
- return std::string(hostname.substr(pos + 1, last_pos));
- }
- }
- return std::string();
+ // if not found, return empty string
+ VerifyOrReturnValue(first_pos != std::string::npos, std::string());
+
+ // Get a substring after the first occurence of '.' to the end of the string
+ return hostname.substr(first_pos + 1, hostname.size());
+}
+
+/**
+ * @brief Callback that is called when the timeout for resolving on the kSrpDot domain has expired.
+ *
+ * @param[in] systemLayer The system layer.
+ * @param[in] callbackContext The context passed to the timer callback.
+ */
+void SrpTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext)
+{
+ ChipLogProgress(Discovery, "Mdns: Timer expired for resolve to complete on the srp domain.");
+ auto sdCtx = static_cast<ResolveContext *>(callbackContext);
+ VerifyOrDie(sdCtx != nullptr);
+ sdCtx->Finalize();
+}
+
+/**
+ * @brief Starts a timer to wait for the resolution on the kSrpDot domain to happen.
+ *
+ * @param[in] timeoutSeconds The timeout in seconds.
+ * @param[in] ResolveContext The resolve context.
+ */
+CHIP_ERROR StartSrpTimer(uint16_t timeoutInMSecs, ResolveContext * ctx)
+{
+ VerifyOrReturnValue(ctx != nullptr, CHIP_ERROR_INCORRECT_STATE);
+ return DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs), SrpTimerExpiredCallback,
+ reinterpret_cast<void *>(ctx));
+}
+
+/**
+ * @brief Cancels the timer that was started to wait for the resolution on the kSrpDot domain to happen.
+ *
+ * @param[in] ResolveContext The resolve context.
+ */
+void CancelSrpTimer(ResolveContext * ctx)
+{
+ DeviceLayer::SystemLayer().CancelTimer(SrpTimerExpiredCallback, reinterpret_cast<void *>(ctx));
}
Global<MdnsContexts> MdnsContexts::sInstance;
namespace {
-/**
- * @brief Callback that is called when the timeout for resolving on the kOpenThreadDot domain has expired.
- *
- * @param[in] systemLayer The system layer.
- * @param[in] callbackContext The context passed to the timer callback.
- */
-void OpenThreadTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext)
-{
- ChipLogProgress(Discovery, "Mdns: Timer expired for resolve to complete on the open thread domain.");
- auto sdCtx = static_cast<ResolveContext *>(callbackContext);
- VerifyOrDie(sdCtx != nullptr);
-
- if (sdCtx->hasOpenThreadTimerStarted)
- {
- sdCtx->Finalize();
- }
-}
-
-/**
- * @brief Starts a timer to wait for the resolution on the kOpenThreadDot domain to happen.
- *
- * @param[in] timeoutSeconds The timeout in seconds.
- * @param[in] ResolveContext The resolve context.
- */
-void StartOpenThreadTimer(uint16_t timeoutInMSecs, ResolveContext * ctx)
-{
- VerifyOrReturn(ctx != nullptr, ChipLogError(Discovery, "Can't schedule open thread timer since context is null"));
- DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs), OpenThreadTimerExpiredCallback,
- reinterpret_cast<void *>(ctx));
-}
-
static void OnRegister(DNSServiceRef sdRef, DNSServiceFlags flags, DNSServiceErrorType err, const char * name, const char * type,
const char * domain, void * context)
{
@@ -248,17 +250,17 @@
auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
- // We will browse on both the local domain and the open thread domain.
+ // We will browse on both the local domain and the srp domain.
ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kLocalDot);
auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceBrowse(&sdRefLocal, kBrowseFlags, interfaceId, type, kLocalDot, OnBrowse, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
- ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kOpenThreadDot);
+ ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kSrpDot);
- DNSServiceRef sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
- err = DNSServiceBrowse(&sdRefOpenThread, kBrowseFlags, interfaceId, type, kOpenThreadDot, OnBrowse, sdCtx);
+ auto sdRefSrp = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
+ err = DNSServiceBrowse(&sdRefSrp, kBrowseFlags, interfaceId, type, kSrpDot, OnBrowse, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
return MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);
@@ -307,25 +309,37 @@
{
VerifyOrReturn(sdCtx->HasAddress(), sdCtx->Finalize(kDNSServiceErr_BadState));
- if (domainName.compare(kOpenThreadDot) == 0)
+ if (domainName.compare(kSrpDot) == 0)
{
- ChipLogProgress(Discovery, "Mdns: Resolve completed on the open thread domain.");
+ ChipLogProgress(Discovery, "Mdns: Resolve completed on the srp domain.");
+
+ // Cancel the timer if one has been started
+ if (sdCtx->hasSrpTimerStarted)
+ {
+ CancelSrpTimer(sdCtx);
+ }
sdCtx->Finalize();
}
else if (domainName.compare(kLocalDot) == 0)
{
- ChipLogProgress(
- Discovery,
- "Mdns: Resolve completed on the local domain. Starting a timer for the open thread resolve to come back");
+ ChipLogProgress(Discovery,
+ "Mdns: Resolve completed on the local domain. Starting a timer for the srp resolve to come back");
- // Usually the resolution on the local domain is quicker than on the open thread domain. We would like to give the
- // resolution on the open thread domain around 250 millisecs more to give it a chance to resolve before finalizing
+ // Usually the resolution on the local domain is quicker than on the srp domain. We would like to give the
+ // resolution on the srp domain around 250 millisecs more to give it a chance to resolve before finalizing
// the resolution.
- if (!sdCtx->hasOpenThreadTimerStarted)
+ if (!sdCtx->hasSrpTimerStarted)
{
- // Schedule a timer to allow the resolve on OpenThread domain to complete.
- StartOpenThreadTimer(kOpenThreadTimeoutInMsec, sdCtx);
- sdCtx->hasOpenThreadTimerStarted = true;
+ // Schedule a timer to allow the resolve on Srp domain to complete.
+ CHIP_ERROR error = StartSrpTimer(kSrpTimeoutInMsec, sdCtx);
+
+ // If the timer fails to start, finalize the context and return.
+ if (error != CHIP_NO_ERROR)
+ {
+ sdCtx->Finalize();
+ return;
+ }
+ sdCtx->hasSrpTimerStarted = true;
}
}
}
@@ -367,8 +381,7 @@
if (!sdCtx->isResolveRequested)
{
GetAddrInfo(sdCtx);
- sdCtx->isResolveRequested = true;
- sdCtx->hasOpenThreadTimerStarted = false;
+ sdCtx->isResolveRequested = true;
}
}
}
@@ -382,13 +395,13 @@
auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
- // Similar to browse, will try to resolve using both the local domain and the open thread domain.
+ // Similar to browse, will try to resolve using both the local domain and the srp domain.
auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceResolve(&sdRefLocal, kResolveFlags, interfaceId, name, type, kLocalDot, OnResolve, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
- auto sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
- err = DNSServiceResolve(&sdRefOpenThread, kResolveFlags, interfaceId, name, type, kOpenThreadDot, OnResolve, sdCtx);
+ auto sdRefSrp = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
+ err = DNSServiceResolve(&sdRefSrp, kResolveFlags, interfaceId, name, type, kSrpDot, OnResolve, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
auto retval = MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);
diff --git a/src/platform/Darwin/DnssdImpl.h b/src/platform/Darwin/DnssdImpl.h
index 713a465..8a98530 100644
--- a/src/platform/Darwin/DnssdImpl.h
+++ b/src/platform/Darwin/DnssdImpl.h
@@ -27,15 +27,17 @@
#include <string>
#include <vector>
-constexpr char kLocalDot[] = "local.";
-
-constexpr char kOpenThreadDot[] = "default.service.arpa.";
-
namespace chip {
namespace Dnssd {
+struct BrowseWithDelegateContext;
+struct RegisterContext;
+struct ResolveContext;
+
std::string GetDomainFromHostName(const char * hostname);
+void CancelSrpTimer(ResolveContext * ctx);
+
enum class ContextType
{
Register,
@@ -62,10 +64,6 @@
CHIP_ERROR FinalizeInternal(const char * errorStr, CHIP_ERROR err);
};
-struct BrowseWithDelegateContext;
-struct RegisterContext;
-struct ResolveContext;
-
class MdnsContexts
{
public:
@@ -239,7 +237,7 @@
bool isResolveRequested = false;
std::shared_ptr<uint32_t> consumerCounter;
BrowseContext * const browseThatCausedResolve; // Can be null
- bool hasOpenThreadTimerStarted = false;
+ bool hasSrpTimerStarted = false;
// browseCausingResolve can be null.
ResolveContext(void * cbContext, DnssdResolveCallback cb, chip::Inet::IPAddressType cbAddressType,
@@ -252,7 +250,7 @@
void DispatchFailure(const char * errorStr, CHIP_ERROR err) override;
void DispatchSuccess() override;
- CHIP_ERROR OnNewAddress(const std::pair<uint32_t, std::string> interfaceKey, const struct sockaddr * address);
+ CHIP_ERROR OnNewAddress(const std::pair<uint32_t, std::string> & interfaceKey, const struct sockaddr * address);
bool HasAddress();
void OnNewInterface(uint32_t interfaceId, const char * fullname, const char * hostname, uint16_t port, uint16_t txtLen,
@@ -266,7 +264,7 @@
* Returns true if information was reported, false if not (e.g. if there
* were no IP addresses, etc).
*/
- bool TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex, std::string domainName);
+ bool TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex);
};
} // namespace Dnssd