# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Builds and manages databases of tokenized strings."""
import collections
import csv
from dataclasses import dataclass
from datetime import datetime
import io
import logging
from pathlib import Path
import re
import struct
from typing import (BinaryIO, Callable, Dict, Iterable, Iterator, List,
NamedTuple, Optional, Pattern, Tuple, Union, ValuesView)
DATE_FORMAT = '%Y-%m-%d'
DEFAULT_DOMAIN = ''
# The default hash length to use. This value only applies when hashing strings
# from a legacy-style ELF with plain strings. New tokenized string entries
# include the token alongside the string.
#
# This MUST match the default value of PW_TOKENIZER_CFG_C_HASH_LENGTH in
# pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_C_HASH_LENGTH = 128
TOKENIZER_HASH_CONSTANT = 65599
_LOG = logging.getLogger('pw_tokenizer')
def _value(char: Union[int, str]) -> int:
return char if isinstance(char, int) else ord(char)
def pw_tokenizer_65599_hash(string: Union[str, bytes],
                            hash_length: Optional[int] = None) -> int:
"""Hashes the provided string.
This hash function is only used when adding tokens from legacy-style
tokenized strings in an ELF, which do not include the token.
"""
hash_value = len(string)
coefficient = TOKENIZER_HASH_CONSTANT
for char in string[:hash_length]:
hash_value = (hash_value + coefficient * _value(char)) % 2**32
coefficient = (coefficient * TOKENIZER_HASH_CONSTANT) % 2**32
return hash_value
def default_hash(string: Union[str, bytes]) -> int:
return pw_tokenizer_65599_hash(string, DEFAULT_C_HASH_LENGTH)
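
# Illustrative usage sketch (not part of the original module): computing the
# token for a format string with the default C hash length. The example string
# is arbitrary; the result is always a 32-bit value.
#
#   token = default_hash('The answer is: %s')
#   assert 0 <= token < 2**32
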
class _EntryKey(NamedTuple):
"""Uniquely refers to an entry."""
token: int
string: str
@dataclass(eq=True, order=False)
class TokenizedStringEntry:
"""A tokenized string with its metadata."""
token: int
string: str
domain: str = DEFAULT_DOMAIN
date_removed: Optional[datetime] = None
def key(self) -> _EntryKey:
"""The key determines uniqueness for a tokenized string."""
return _EntryKey(self.token, self.string)
def update_date_removed(self,
new_date_removed: Optional[datetime]) -> None:
"""Sets self.date_removed if the other date is newer."""
# No removal date (None) is treated as the newest date.
if self.date_removed is None:
return
if new_date_removed is None or new_date_removed > self.date_removed:
self.date_removed = new_date_removed
def __lt__(self, other) -> bool:
"""Sorts the entry by token, date removed, then string."""
if self.token != other.token:
return self.token < other.token
# Sort removal dates in reverse, so the most recently removed (or still
# present) entry appears first.
if self.date_removed != other.date_removed:
return (other.date_removed or datetime.max) < (self.date_removed
or datetime.max)
return self.string < other.string
def __str__(self) -> str:
return self.string
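
# Illustrative usage sketch (not part of the original module): entries are
# keyed by (token, string); the domain and removal date are metadata and do
# not affect uniqueness. The example string is arbitrary.
#
#   entry = TokenizedStringEntry(default_hash('Hello, %s!'), 'Hello, %s!')
#   assert entry.key() == _EntryKey(entry.token, 'Hello, %s!')
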
class Database:
"""Database of tokenized strings stored as TokenizedStringEntry objects."""
def __init__(self, entries: Iterable[TokenizedStringEntry] = ()):
"""Creates a token database."""
# The database dict stores each unique (token, string) entry.
self._database: Dict[_EntryKey, TokenizedStringEntry] = {
entry.key(): entry
for entry in entries
}
# This is a cache for fast token lookup that is built as needed.
self._cache: Optional[Dict[int, List[TokenizedStringEntry]]] = None
@classmethod
def from_strings(
cls,
strings: Iterable[str],
domain: str = DEFAULT_DOMAIN,
tokenize: Callable[[str], int] = default_hash) -> 'Database':
"""Creates a Database from an iterable of strings."""
return cls((TokenizedStringEntry(tokenize(string), string, domain)
for string in strings))
@classmethod
def merged(cls, *databases: 'Database') -> 'Database':
"""Creates a TokenDatabase from one or more other databases."""
db = cls()
db.merge(*databases)
return db
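
    # Illustrative usage sketch (not part of the original class): building
    # small databases from plain strings and combining them. The strings are
    # arbitrary placeholders.
    #
    #   logs = Database.from_strings(['Hello, %s!', 'Battery: %d%%'])
    #   asserts = Database.from_strings(['Assert failed: %s'])
    #   combined = Database.merged(logs, asserts)
    #   assert len(combined) == 3
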
@property
def token_to_entries(self) -> Dict[int, List[TokenizedStringEntry]]:
"""Returns a dict that maps tokens to a list of TokenizedStringEntry."""
        if self._cache is None:  # Build the token -> entries cache.
self._cache = collections.defaultdict(list)
for entry in self._database.values():
self._cache[entry.token].append(entry)
return self._cache
def entries(self) -> ValuesView[TokenizedStringEntry]:
"""Returns iterable over all TokenizedStringEntries in the database."""
return self._database.values()
def collisions(self) -> Iterator[Tuple[int, List[TokenizedStringEntry]]]:
"""Returns tuple of (token, entries_list)) for all colliding tokens."""
for token, entries in self.token_to_entries.items():
if len(entries) > 1:
yield token, entries
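
    # Illustrative usage sketch (not part of the original class): looking up
    # entries by token and reporting collisions. The db variable and token
    # value are hypothetical.
    #
    #   for token, entries in db.collisions():
    #       _LOG.warning('Token %08x collides: %s', token, entries)
    #   matches = db.token_to_entries.get(0x1234abcd, [])
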
def mark_removed(
self,
all_entries: Iterable[TokenizedStringEntry],
removal_date: Optional[datetime] = None
) -> List[TokenizedStringEntry]:
"""Marks entries missing from all_entries as having been removed.
The entries are assumed to represent the complete set of entries for the
database. Entries currently in the database not present in the provided
entries are marked with a removal date but remain in the database.
Entries in all_entries missing from the database are NOT added; call the
add function to add these.
Args:
all_entries: the complete set of strings present in the database
removal_date: the datetime for removed entries; today by default
Returns:
A list of entries marked as removed.
"""
self._cache = None
if removal_date is None:
removal_date = datetime.now()
all_keys = frozenset(entry.key() for entry in all_entries)
removed = []
for entry in self._database.values():
if (entry.key() not in all_keys
and (entry.date_removed is None
or removal_date < entry.date_removed)):
# Add a removal date, or update it to the oldest date.
entry.date_removed = removal_date
removed.append(entry)
return removed
def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
"""Adds new entries and updates date_removed for existing entries."""
self._cache = None
for new_entry in entries:
# Update an existing entry or create a new one.
try:
entry = self._database[new_entry.key()]
entry.domain = new_entry.domain
entry.date_removed = None
except KeyError:
self._database[new_entry.key()] = TokenizedStringEntry(
new_entry.token, new_entry.string, new_entry.domain)
def purge(
self,
date_removed_cutoff: Optional[datetime] = None
) -> List[TokenizedStringEntry]:
"""Removes and returns entries removed on/before date_removed_cutoff."""
self._cache = None
if date_removed_cutoff is None:
date_removed_cutoff = datetime.max
to_delete = [
entry for _, entry in self._database.items()
if entry.date_removed and entry.date_removed <= date_removed_cutoff
]
for entry in to_delete:
del self._database[entry.key()]
return to_delete
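
    # Illustrative usage sketch (not part of the original class): the typical
    # update cycle when refreshing a database from the latest build, where
    # current_entries stands in for entries parsed from the newest binary.
    #
    #   db.add(current_entries)           # Add or revive current strings.
    #   db.mark_removed(current_entries)  # Date-stamp strings that are gone.
    #   db.purge(datetime(2020, 1, 1))    # Drop entries removed on/before then.
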
def merge(self, *databases: 'Database') -> None:
"""Merges two or more databases together, keeping the newest dates."""
self._cache = None
for other_db in databases:
for entry in other_db.entries():
key = entry.key()
if key in self._database:
self._database[key].update_date_removed(entry.date_removed)
else:
self._database[key] = entry
def filter(
self,
include: Iterable[Union[str, Pattern[str]]] = (),
exclude: Iterable[Union[str, Pattern[str]]] = (),
replace: Iterable[Tuple[Union[str, Pattern[str]], str]] = ()
) -> None:
"""Filters the database using regular expressions (strings or compiled).
Args:
include: regexes; only entries matching at least one are kept
exclude: regexes; entries matching any of these are removed
replace: (regex, str) tuples; replaces matching terms in all entries
"""
self._cache = None
to_delete: List[_EntryKey] = []
if include:
include_re = [re.compile(pattern) for pattern in include]
to_delete.extend(
key for key, val in self._database.items()
if not any(rgx.search(val.string) for rgx in include_re))
if exclude:
exclude_re = [re.compile(pattern) for pattern in exclude]
to_delete.extend(key for key, val in self._database.items() if any(
rgx.search(val.string) for rgx in exclude_re))
for key in to_delete:
del self._database[key]
for search, replacement in replace:
search = re.compile(search)
for value in self._database.values():
value.string = search.sub(replacement, value.string)
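
    # Illustrative usage sketch (not part of the original class): keeping only
    # strings that mention "ERR", dropping debug-only strings, and scrubbing a
    # hypothetical prefix from the rest. The db variable is hypothetical.
    #
    #   db.filter(include=[r'ERR'],
    #             exclude=[r'^DBG'],
    #             replace=[(r'internal/', '')])
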
def __len__(self) -> int:
"""Returns the number of entries in the database."""
return len(self.entries())
def __str__(self) -> str:
"""Outputs the database as CSV."""
csv_output = io.BytesIO()
write_csv(self, csv_output)
return csv_output.getvalue().decode()
def parse_csv(fd) -> Iterable[TokenizedStringEntry]:
"""Parses TokenizedStringEntries from a CSV token database file."""
for line in csv.reader(fd):
try:
token_str, date_str, string_literal = line
token = int(token_str, 16)
date = (datetime.strptime(date_str, DATE_FORMAT)
if date_str.strip() else None)
yield TokenizedStringEntry(token, string_literal, DEFAULT_DOMAIN,
date)
except (ValueError, UnicodeDecodeError) as err:
_LOG.error('Failed to parse tokenized string entry %s: %s', line,
err)
def write_csv(database: Database, fd: BinaryIO) -> None:
"""Writes the database as CSV to the provided binary file."""
for entry in sorted(database.entries()):
# Align the CSV output to 10-character columns for improved readability.
# Use \n instead of RFC 4180's \r\n.
fd.write('{:08x},{:10},"{}"\n'.format(
entry.token,
entry.date_removed.strftime(DATE_FORMAT) if entry.date_removed else
'', entry.string.replace('"', '""')).encode()) # escape " as ""
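
# Illustrative usage sketch (not part of the original module): a CSV round
# trip. write_csv emits one "token,date_removed,string" line per entry (token
# as 8 hex digits, blank date for live entries, string quoted); parse_csv
# reads the same format back. The db variable and file path are hypothetical.
#
#   with open('tokens.csv', 'wb') as fd:
#       write_csv(db, fd)
#   with open('tokens.csv', 'r', newline='', encoding='utf-8') as fd:
#       db_copy = Database(parse_csv(fd))
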
class _BinaryFileFormat(NamedTuple):
"""Attributes of the binary token database file format."""
magic: bytes = b'TOKENS\0\0'
header: struct.Struct = struct.Struct('<8sI4x')
entry: struct.Struct = struct.Struct('<IBBH')
BINARY_FORMAT = _BinaryFileFormat()
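
# Illustrative sketch (not part of the original module): the binary format is
# a 16-byte header (8-byte magic, uint32 entry count, 4 padding bytes),
# followed by one 8-byte record per entry (uint32 token, then the day, month,
# and year of removal), followed by the null-terminated strings in order.
#
#   header = BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, 1)
#   record = BINARY_FORMAT.entry.pack(0x12345678, 0xff, 0xff, 0xffff)
#   assert len(header) == 16 and len(record) == 8
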
class DatabaseFormatError(Exception):
"""Failed to parse a token database file."""
def file_is_binary_database(fd: BinaryIO) -> bool:
"""True if the file starts with the binary token database magic string."""
try:
fd.seek(0)
magic = fd.read(len(BINARY_FORMAT.magic))
fd.seek(0)
return BINARY_FORMAT.magic == magic
except IOError:
return False
def _check_that_file_is_csv_database(path: Path) -> None:
"""Raises an error unless the path appears to be a CSV token database."""
try:
with path.open('rb') as fd:
data = fd.read(8) # Read 8 bytes, which should be the first token.
if not data:
return # File is empty, which is valid CSV.
if len(data) != 8:
raise DatabaseFormatError(
f'Attempted to read {path} as a CSV token database, but the '
f'file is too short ({len(data)} B)')
# Make sure the first 8 chars are a valid hexadecimal number.
_ = int(data.decode(), 16)
except (IOError, UnicodeDecodeError, ValueError) as err:
raise DatabaseFormatError(
f'Encountered error while reading {path} as a CSV token database'
) from err
def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
"""Parses TokenizedStringEntries from a binary token database file."""
magic, entry_count = BINARY_FORMAT.header.unpack(
fd.read(BINARY_FORMAT.header.size))
if magic != BINARY_FORMAT.magic:
raise DatabaseFormatError(
f'Binary token database magic number mismatch (found {magic!r}, '
f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}')
entries = []
for _ in range(entry_count):
token, day, month, year = BINARY_FORMAT.entry.unpack(
fd.read(BINARY_FORMAT.entry.size))
try:
date_removed: Optional[datetime] = datetime(year, month, day)
except ValueError:
date_removed = None
entries.append((token, date_removed))
# Read the entire string table and define a function for looking up strings.
string_table = fd.read()
    def read_string(start):
        end = string_table.find(b'\0', start)
        return string_table[start:end].decode(), end + 1
offset = 0
for token, removed in entries:
string, offset = read_string(offset)
yield TokenizedStringEntry(token, string, DEFAULT_DOMAIN, removed)
def write_binary(database: Database, fd: BinaryIO) -> None:
"""Writes the database as packed binary to the provided binary file."""
entries = sorted(database.entries())
fd.write(BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, len(entries)))
string_table = bytearray()
for entry in entries:
if entry.date_removed:
removed_day = entry.date_removed.day
removed_month = entry.date_removed.month
removed_year = entry.date_removed.year
else:
# If there is no removal date, use the special value 0xffffffff for
# the day/month/year. That ensures that still-present tokens appear
# as the newest tokens when sorted by removal date.
removed_day = 0xff
removed_month = 0xff
removed_year = 0xffff
string_table += entry.string.encode()
string_table.append(0)
fd.write(
BINARY_FORMAT.entry.pack(entry.token, removed_day, removed_month,
removed_year))
fd.write(string_table)
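
# Illustrative usage sketch (not part of the original module): writing and
# re-reading a database in the binary format using an in-memory buffer. The
# db variable is hypothetical.
#
#   buffer = io.BytesIO()
#   write_binary(db, buffer)
#   buffer.seek(0)
#   db_copy = Database(parse_binary(buffer))
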
class DatabaseFile(Database):
"""A token database that is associated with a particular file.

    This class adds the write_to_file() method, which writes the database back
    to the file it was loaded from, in the same format (CSV or binary).
"""
def __init__(self, path: Union[Path, str]):
self.path = Path(path)
# Read the path as a packed binary file.
with self.path.open('rb') as fd:
if file_is_binary_database(fd):
super().__init__(parse_binary(fd))
self._export = write_binary
return
# Read the path as a CSV file.
_check_that_file_is_csv_database(self.path)
with self.path.open('r', newline='', encoding='utf-8') as file:
super().__init__(parse_csv(file))
self._export = write_csv
def write_to_file(self, path: Optional[Union[Path, str]] = None) -> None:
"""Exports in the original format to the original or provided path."""
with open(self.path if path is None else path, 'wb') as fd:
self._export(self, fd)
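
# Illustrative usage sketch (not part of the original module): loading an
# existing database file (CSV or binary is detected automatically), adding
# entries, and writing it back in its original format. The path is
# hypothetical.
#
#   db = DatabaseFile('tokens.csv')
#   db.add(Database.from_strings(['New string: %d']).entries())
#   db.write_to_file()
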