blob: 3e468f3e6238a301be68e778e25b3401076fb45a [file] [log] [blame]
/*
*
* Copyright (c) 2021 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <lib/dnssd/minimal_mdns/ResponseSender.h>
#include <chrono>
#include <condition_variable>
#include <string>
#include <vector>
#include <lib/dnssd/MinimalMdnsServer.h>
#include <lib/dnssd/minimal_mdns/RecordData.h>
#include <lib/dnssd/minimal_mdns/Server.h>
#include <lib/dnssd/minimal_mdns/records/Ptr.h>
#include <lib/dnssd/minimal_mdns/records/Srv.h>
#include <lib/dnssd/minimal_mdns/records/Txt.h>
#include <lib/support/CHIPMemString.h>
#include <lib/support/UnitTestRegistration.h>
#include <system/SystemMutex.h>
#include <nlunit-test.h>
namespace mdns {
namespace Minimal {
namespace test {
constexpr QNamePart kIgnoreQNameParts[] = { "IGNORE", "THIS" };
namespace {
bool StringMatches(const BytesRange & br, const char * str)
{
return br.Size() == strlen(str) && memcmp(str, br.Start(), br.Size()) == 0;
}
template <size_t N>
void MakePrintableName(char (&location)[N], SerializedQNameIterator name)
{
auto buf = chip::Encoding::BigEndian::BufferWriter(reinterpret_cast<uint8_t *>(&location[0]), N);
while (name.Next())
{
buf.Put(name.Value());
buf.Put(".");
}
buf.Put('\0');
}
template <size_t N>
void MakePrintableName(char (&location)[N], FullQName name)
{
auto buf = chip::Encoding::BigEndian::BufferWriter(reinterpret_cast<uint8_t *>(&location[0]), N);
for (size_t i = 0; i < name.nameCount; ++i)
{
buf.Put(name.names[i]);
buf.Put(".");
}
buf.Put('\0');
}
} // namespace
class CheckOnlyServer : private chip::PoolImpl<ServerBase::EndpointInfo, 0, chip::ObjectPoolMem::kInline,
ServerBase::EndpointInfoPoolType::Interface>,
public ServerBase,
public ParserDelegate,
public TxtRecordDelegate
{
public:
CheckOnlyServer(nlTestSuite * inSuite) : ServerBase(*static_cast<ServerBase::EndpointInfoPoolType *>(this)), mInSuite(inSuite)
{
Reset();
}
CheckOnlyServer() : ServerBase(*static_cast<ServerBase::EndpointInfoPoolType *>(this)), mInSuite(nullptr) { Reset(); }
~CheckOnlyServer() {}
// Parser delegates
void OnHeader(ConstHeaderRef & header) override
{
NL_TEST_ASSERT(mInSuite, header.GetFlags().IsResponse());
NL_TEST_ASSERT(mInSuite, header.GetFlags().IsValidMdns());
mTotalRecords += header.GetAnswerCount() + header.GetAdditionalCount();
if (!header.GetFlags().IsTruncated())
{
NL_TEST_ASSERT(mInSuite, mTotalRecords == GetNumExpectedRecords());
if (mTotalRecords != GetNumExpectedRecords())
{
ChipLogError(Discovery, "Received %d records, expected %d", mTotalRecords, GetNumExpectedRecords());
}
mHeaderFound = true;
}
}
void OnResource(ResourceType type, const ResourceData & data) override
{
SerializedQNameIterator target;
switch (data.GetType())
{
case QType::PTR:
ParsePtrRecord(data.GetData(), mPacketData, &target);
break;
case QType::SRV: {
SrvRecord srv;
bool srvParseOk = srv.Parse(data.GetData(), mPacketData);
NL_TEST_ASSERT(mInSuite, srvParseOk);
if (!srvParseOk)
{
return;
}
target = srv.GetName();
break;
}
default:
break;
}
bool recordIsExpected = false;
for (auto & info : mExpectedRecordInfo)
{
if (info.record == nullptr || info.found)
{
continue;
}
if (data.GetType() == info.record->GetType() &&
(info.record->GetName() == kIgnoreQname || data.GetName() == info.record->GetName()) &&
(info.target == kIgnoreQname || target == info.target))
{
if (data.GetType() == QType::TXT)
{
// First parse out the expected record to see what keys/values we have.
ClearTxtRecords();
const TxtResourceRecord * expectedTxt = static_cast<const TxtResourceRecord *>(info.record);
for (size_t t = 0; t < expectedTxt->GetNumEntries(); ++t)
{
bool ok = AddExpectedTxtRecord(expectedTxt->GetEntries()[t]);
NL_TEST_ASSERT(mInSuite, ok);
}
ParseTxtRecord(data.GetData(), this);
if (CheckTxtRecordMatches())
{
info.found = true;
recordIsExpected = true;
break;
}
}
else
{
info.found = true;
recordIsExpected = true;
break;
}
}
}
NL_TEST_ASSERT(mInSuite, recordIsExpected);
if (!recordIsExpected)
{
char nameStr[64];
char targetStr[64];
SerializedQNameIterator dataTarget;
SerializedQNameIterator it = data.GetName();
MakePrintableName(nameStr, it);
switch (data.GetType())
{
case QType::PTR:
ParsePtrRecord(data.GetData(), data.GetData(), &dataTarget);
break;
case QType::SRV: {
SrvRecord srv;
if (srv.Parse(data.GetData(), data.GetData()))
{
dataTarget = srv.GetName();
}
break;
}
default:
break;
}
MakePrintableName(targetStr, dataTarget);
ChipLogError(Discovery, "Received unexpected record of type %u: %s %s", static_cast<uint16_t>(data.GetType()), nameStr,
targetStr);
}
}
void OnQuery(const QueryData & data) override {}
// TxtRecordDelegate
void OnRecord(const BytesRange & name, const BytesRange & value) override
{
for (size_t i = 0; i < mNumExpectedTxtRecords; ++i)
{
if (StringMatches(name, mExpectedTxt[i].key) && StringMatches(value, mExpectedTxt[i].val))
{
mExpectedTxt[i].found = true;
break;
}
}
mNumReceivedTxtRecords++;
}
bool CheckTxtRecordMatches()
{
if (mNumReceivedTxtRecords != mNumExpectedTxtRecords)
{
return false;
}
for (size_t i = 0; i < mNumExpectedTxtRecords; ++i)
{
if (!mExpectedTxt[i].found)
{
return false;
}
}
return true;
}
// ServerBase overrides
CHIP_ERROR
DirectSend(chip::System::PacketBufferHandle && data, const chip::Inet::IPAddress & addr, uint16_t port,
chip::Inet::InterfaceId interface) override
{
mPacketData = BytesRange(data->Start(), data->Start() + data->TotalLength());
ParsePacket(mPacketData, this);
if (mHeaderFound)
{
TestGotAllExpectedPackets();
}
mSendCalled = true;
return CHIP_NO_ERROR;
}
// Functions used for controlling testing.
void AddExpectedRecord(PtrResourceRecord * ptr)
{
RecordInfo * info = AddExpectedRecordBase(ptr);
if (info == nullptr)
{
return;
}
info->target = ptr->GetPtr();
}
void AddExpectedRecord(SrvResourceRecord * srv)
{
RecordInfo * info = AddExpectedRecordBase(srv);
NL_TEST_ASSERT(mInSuite, info != nullptr);
if (info == nullptr)
{
return;
}
info->target = srv->GetServerName();
}
void AddExpectedRecord(TxtResourceRecord * txt)
{
RecordInfo * info = AddExpectedRecordBase(txt);
NL_TEST_ASSERT(mInSuite, info != nullptr);
if (info == nullptr)
{
return;
}
info->target = kIgnoreQname;
}
bool GetSendCalled() { return mSendCalled; }
bool GetHeaderFound() { return mHeaderFound; }
void SetTestSuite(nlTestSuite * suite) { mInSuite = suite; }
void Reset()
{
for (auto & info : mExpectedRecordInfo)
{
info.record = nullptr;
info.found = false;
}
mHeaderFound = false;
mSendCalled = false;
mTotalRecords = 0;
ClearTxtRecords();
}
private:
nlTestSuite * mInSuite;
static constexpr size_t kMaxExpectedRecords = 10;
struct RecordInfo
{
ResourceRecord * record;
bool found = false;
FullQName target;
};
RecordInfo mExpectedRecordInfo[kMaxExpectedRecords];
struct KV
{
static constexpr size_t kMaxKey = 10;
static constexpr size_t kMaxVal = 128; // max pairing instruction len + 1
char key[kMaxKey + 1] = "";
char val[kMaxVal + 1] = "";
bool found = false;
bool operator==(const KV & rhs) const { return strcmp(key, rhs.key) == 0 && strcmp(val, rhs.val) == 0; }
void Clear()
{
memset(key, 0, sizeof(key));
memset(val, 0, sizeof(val));
found = false;
}
};
static constexpr size_t kMaxExpectedTxt = 11;
KV mExpectedTxt[kMaxExpectedTxt];
size_t mNumExpectedTxtRecords = 0;
size_t mNumReceivedTxtRecords = 0;
bool mHeaderFound = false;
bool mSendCalled = false;
int mTotalRecords = 0;
FullQName kIgnoreQname = FullQName(kIgnoreQNameParts);
BytesRange mPacketData;
int GetNumExpectedRecords() const
{
int num = 0;
for (auto & info : mExpectedRecordInfo)
{
if (info.record != nullptr)
{
++num;
}
}
return num;
}
void TestGotAllExpectedPackets()
{
if (mInSuite == nullptr)
{
return;
}
for (auto & info : mExpectedRecordInfo)
{
if (info.record == nullptr)
{
continue;
}
NL_TEST_ASSERT(mInSuite, info.found == true);
if (!info.found)
{
char name[64];
char target[64];
MakePrintableName(name, info.record->GetName());
MakePrintableName(target, info.target);
ChipLogError(Discovery, "Did not receive expected record of type %u : %s %s",
static_cast<uint16_t>(info.record->GetType()), name, target);
}
}
}
void ClearTxtRecords()
{
for (auto & kv : mExpectedTxt)
{
kv.Clear();
}
mNumExpectedTxtRecords = 0;
mNumReceivedTxtRecords = 0;
}
bool AddExpectedTxtRecord(const char * const entry)
{
if (mNumExpectedTxtRecords == kMaxExpectedTxt)
{
return false;
}
size_t lenKey = strlen(entry);
const char * equal = strchr(entry, '=');
if (equal != nullptr)
{
chip::Platform::CopyString(mExpectedTxt[mNumExpectedTxtRecords].val, (equal + 1));
lenKey = static_cast<size_t>(equal - entry);
}
chip::ByteSpan key = chip::ByteSpan(reinterpret_cast<const uint8_t *>(entry), lenKey);
chip::Platform::CopyString(mExpectedTxt[mNumExpectedTxtRecords++].key, key);
return true;
}
RecordInfo * AddExpectedRecordBase(ResourceRecord * record)
{
for (auto & info : mExpectedRecordInfo)
{
if (info.record == nullptr)
{
info.record = record;
info.found = false;
return &info;
}
}
return nullptr;
}
};
struct ServerSwapper
{
ServerSwapper(CheckOnlyServer * server)
{
chip::Dnssd::GlobalMinimalMdnsServer::Instance().Server().Shutdown();
chip::Dnssd::GlobalMinimalMdnsServer::Instance().SetReplacementServer(server);
}
~ServerSwapper() { chip::Dnssd::GlobalMinimalMdnsServer::Instance().SetReplacementServer(nullptr); }
};
} // namespace test
} // namespace Minimal
} // namespace mdns