# blob: 685d07412a5aec93adee333dcf804e685ed22289 [file] [log] [blame]
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Builds and manages databases of tokenized strings."""
import collections
import csv
from datetime import datetime
import io
import logging
from pathlib import Path
import re
import struct
from typing import BinaryIO, Callable, Dict, Iterable, List, NamedTuple
from typing import Optional, Tuple, Union, ValuesView
# Format of the removal-date column in CSV token databases; passed to
# strptime() when parsing and strftime() when writing.
DATE_FORMAT = '%Y-%m-%d'

# The default hash length to use. This MUST match the default value of
# PW_TOKENIZER_CFG_HASH_LENGTH in pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_HASH_LENGTH = 128

# Multiplier of the 65599 hash: the coefficient starts at this value and is
# multiplied by it again (mod 2**32) after every hashed character.
TOKENIZER_HASH_CONSTANT = 65599

# Module logger; used to report CSV rows that fail to parse.
_LOG = logging.getLogger('pw_tokenizer')
def _value(char: Union[int, str]) -> int:
    """Returns the integer value of one element of a str or bytes object.

    Iterating bytes yields ints, which are passed through unchanged; iterating
    a str yields one-character strings, which are converted with ord().
    """
    if isinstance(char, int):
        return char
    return ord(char)
def pw_tokenizer_65599_fixed_length_hash(string: Union[str, bytes],
                                         hash_length: int) -> int:
    """Computes the 32-bit 65599 hash of the first hash_length characters.

    The hash is seeded with the length of the whole string, so strings that
    share their first hash_length characters but differ in length still hash
    differently.

    Args:
      string: the str or bytes to hash
      hash_length: the maximum number of leading characters to hash

    Returns:
      The hash as an unsigned 32-bit integer.
    """
    modulus = 2**32  # All arithmetic wraps to 32 bits.

    result = len(string)
    multiplier = TOKENIZER_HASH_CONSTANT

    for symbol in string[:hash_length]:
        result = (result + multiplier * _value(symbol)) % modulus
        multiplier = (multiplier * TOKENIZER_HASH_CONSTANT) % modulus

    return result
def default_hash(string: Union[str, bytes]) -> int:
    """Hashes the string with the 65599 hash at DEFAULT_HASH_LENGTH."""
    return pw_tokenizer_65599_fixed_length_hash(
        string, hash_length=DEFAULT_HASH_LENGTH)
# Key for uniquely referring to an entry: the (token, string) pair returned by
# TokenizedStringEntry.key() and used as the Database dictionary key.
_EntryKey = Tuple[int, str]
class TokenizedStringEntry:
    """One tokenized string: its token, its text and when it was removed.

    A date_removed of None means the string has not been removed.
    """
    def __init__(self,
                 token: int,
                 string: str,
                 date_removed: Optional[datetime] = None):
        self.token = token
        self.string = string
        self.date_removed = date_removed

    def key(self) -> _EntryKey:
        """Returns the (token, string) pair that identifies this entry."""
        return (self.token, self.string)

    def update_date_removed(self,
                            new_date_removed: Optional[datetime]) -> None:
        """Moves date_removed forward to new_date_removed if that is newer.

        None (not removed) counts as newer than any date, so an entry that is
        not removed is left alone, and a removed entry is reset to None when
        new_date_removed is None.
        """
        current = self.date_removed

        if current is None:
            return  # Already the newest possible value.

        if new_date_removed is None or new_date_removed > current:
            self.date_removed = new_date_removed

    def __lt__(self, other) -> bool:
        """Orders entries by token, then newest removal date, then string."""
        if self.token == other.token:
            if self.date_removed == other.date_removed:
                return self.string < other.string

            # Removal dates sort in reverse; entries with no removal date are
            # treated as datetime.max so they come first.
            theirs = other.date_removed or datetime.max
            mine = self.date_removed or datetime.max
            return theirs < mine

        return self.token < other.token

    def __str__(self) -> str:
        return self.string

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return '{}({!r})'.format(class_name, self.string)
class Database:
    """Database of tokenized strings stored as TokenizedStringEntry objects.

    Entries are stored in a dict keyed by (token, string), so the same string
    appears at most once per token, while different strings that share a token
    (collisions) are kept as separate entries.
    """
    def __init__(self,
                 entries: Iterable[TokenizedStringEntry] = (),
                 tokenize: Callable[[str], int] = default_hash):
        """Creates a token database.

        Args:
          entries: the initial entries; if several share a key, the last wins
          tokenize: function that calculates the token for a string; used by
              add() for new strings
        """
        # The database dict stores each unique (token, string) entry.
        self._database: Dict[_EntryKey, TokenizedStringEntry] = {
            entry.key(): entry
            for entry in entries
        }
        self.tokenize = tokenize

        # This is a cache for fast token lookup that is built as needed. Every
        # method that changes the database resets it to None.
        self._cache: Optional[Dict[int, List[TokenizedStringEntry]]] = None

    @classmethod
    def from_strings(
            cls,
            strings: Iterable[str],
            tokenize: Callable[[str], int] = default_hash) -> 'Database':
        """Creates a Database from an iterable of strings.

        Args:
          strings: the strings to tokenize and store
          tokenize: function that calculates the token for each string
        """
        return cls((TokenizedStringEntry(tokenize(string), string)
                    for string in strings), tokenize)

    @classmethod
    def merged(cls, *databases: 'Database') -> 'Database':
        """Creates a new database from one or more other databases."""
        db = cls()
        db.merge(*databases)
        return db

    @property
    def token_to_entries(self) -> Dict[int, List[TokenizedStringEntry]]:
        """Returns a dict that maps tokens to a list of TokenizedStringEntry.

        The mapping is built on first access and reused until the database is
        next modified.
        """
        if self._cache is None:  # build the token -> entries cache
            self._cache = collections.defaultdict(list)
            for entry in self._database.values():
                self._cache[entry.token].append(entry)

        return self._cache

    def entries(self) -> ValuesView[TokenizedStringEntry]:
        """Returns iterable over all TokenizedStringEntries in the database."""
        return self._database.values()

    def collisions(self) -> Tuple[Tuple[int, List[TokenizedStringEntry]], ...]:
        """Returns tuple of (token, entries_list) for all colliding tokens.

        A token collides when more than one entry in the database uses it.
        """
        return tuple((token, entries)
                     for token, entries in self.token_to_entries.items()
                     if len(entries) > 1)

    def mark_removals(
            self,
            all_strings: Iterable[str],
            removal_date: Optional[datetime] = None
    ) -> List[TokenizedStringEntry]:
        """Marks strings missing from all_strings as having been removed.

        The strings are assumed to represent the complete set of strings for
        the database. Strings currently in the database not present in the
        provided strings are marked with a removal date but remain in the
        database. Strings in all_strings missing from the database are NOT
        added; call the add function to add these strings.

        An entry that already has a removal date is only updated if
        removal_date is earlier than its current date.

        Args:
          all_strings: the complete set of strings present in the database
          removal_date: the datetime for removed entries; now by default

        Returns:
          A list of the entries whose removal date was set or changed.
        """
        self._cache = None

        if removal_date is None:
            removal_date = datetime.now()

        all_strings = frozenset(all_strings)  # for faster lookup

        removed = []

        for entry in self._database.values():
            if (entry.string not in all_strings
                    and (entry.date_removed is None
                         or removal_date < entry.date_removed)):
                # Add a removal date, or update it to the oldest date.
                entry.date_removed = removal_date
                removed.append(entry)

        return removed

    def add(self, strings: Iterable[str]) -> None:
        """Adds new strings to the database.

        Strings that are already present have their removal date cleared;
        strings that are not present are tokenized with self.tokenize and
        added as new entries.
        """
        self._cache = None

        # Add new and update previously removed entries.
        for string in strings:
            key = self.tokenize(string), string

            try:
                entry = self._database[key]
                if entry.date_removed:
                    entry.date_removed = None
            except KeyError:
                self._database[key] = TokenizedStringEntry(key[0], string)

    def purge(
        self,
        date_removed_cutoff: Optional[datetime] = None
    ) -> List[TokenizedStringEntry]:
        """Removes and returns entries removed on/before date_removed_cutoff.

        Args:
          date_removed_cutoff: entries with a removal date at or before this
              datetime are deleted; by default every removed entry is deleted

        Returns:
          The list of entries that were deleted from the database.
        """
        self._cache = None

        if date_removed_cutoff is None:
            date_removed_cutoff = datetime.max

        # Entries without a removal date are never purged.
        to_delete = [
            entry for entry in self._database.values()
            if entry.date_removed and entry.date_removed <= date_removed_cutoff
        ]

        for entry in to_delete:
            del self._database[entry.key()]

        return to_delete

    def merge(self, *databases: 'Database') -> None:
        """Merges two or more databases together, keeping the newest dates.

        Entries already in this database have their removal date updated with
        update_date_removed(). Entries that are new to this database are
        stored as the same objects held by the source database, not copies.
        """
        self._cache = None

        for other_db in databases:
            for entry in other_db.entries():
                key = entry.key()

                if key in self._database:
                    self._database[key].update_date_removed(entry.date_removed)
                else:
                    self._database[key] = entry

    def filter(self, include: Iterable = (), exclude: Iterable = ()) -> None:
        """Filters the database using regular expressions (strings or compiled).

        Patterns are applied with search(), so they may match anywhere in the
        entry's string.

        Args:
          include: iterable of regexes; only entries matching at least one are
              kept
          exclude: iterable of regexes; entries matching any of these are
              removed
        """
        self._cache = None

        to_delete: List[_EntryKey] = []

        if include:
            include_re = [re.compile(pattern) for pattern in include]
            to_delete.extend(
                key for key, val in self._database.items()
                if not any(rgx.search(val.string) for rgx in include_re))

        if exclude:
            exclude_re = [re.compile(pattern) for pattern in exclude]
            to_delete.extend(
                key for key, val in self._database.items()
                if any(rgx.search(val.string) for rgx in exclude_re))

        # An entry can both fail every include pattern and match an exclude
        # pattern, which lists its key twice. De-duplicate so that the second
        # deletion does not raise KeyError.
        for key in set(to_delete):
            del self._database[key]

    def __len__(self) -> int:
        """Returns the number of entries in the database."""
        return len(self.entries())

    def __str__(self) -> str:
        """Outputs the database as CSV."""
        csv_output = io.BytesIO()
        write_csv(self, csv_output)
        return csv_output.getvalue().decode()
def parse_csv(fd) -> Iterable[TokenizedStringEntry]:
    """Yields a TokenizedStringEntry for each valid row of a CSV database.

    Each row must have three columns: the token in hexadecimal, the removal
    date in DATE_FORMAT (blank if the string has not been removed) and the
    string itself. Rows that cannot be parsed are logged as errors and
    skipped.
    """
    for row in csv.reader(fd):
        try:
            token_field, date_field, text = row

            token = int(token_field, 16)

            if date_field.strip():
                removed = datetime.strptime(date_field, DATE_FORMAT)
            else:
                removed = None

            yield TokenizedStringEntry(token, text, removed)
        except (ValueError, UnicodeDecodeError) as err:
            _LOG.error('Failed to parse tokenized string entry %s: %s', row,
                       err)
def write_csv(database: Database, fd: BinaryIO) -> None:
    """Writes the sorted database entries as CSV to a binary file object.

    Each line holds the token as eight hex digits, the removal date padded to
    ten characters (blank if the entry has not been removed) and the quoted
    string. Lines end with \\n rather than RFC 4180's \\r\\n.
    """
    for entry in sorted(database.entries()):
        if entry.date_removed:
            date_text = entry.date_removed.strftime(DATE_FORMAT)
        else:
            date_text = ''

        # Inside a quoted CSV field, " is escaped as "".
        escaped = entry.string.replace('"', '""')

        line = '{:08x},{:10},"{}"\n'.format(entry.token, date_text, escaped)
        fd.write(line.encode())
class _BinaryFileFormat(NamedTuple):
    """Layout of the packed binary token database file.

    A file is one header, then one fixed-size entry record per token, then a
    table of null-terminated strings in the same order as the records.
    """

    # Magic bytes that identify a binary token database.
    magic: bytes = b'TOKENS\0\0'

    # Little-endian header: 8-byte magic, uint32 entry count, 4 padding bytes.
    header: struct.Struct = struct.Struct('<8sI4x')

    # Little-endian entry: uint32 token, then the removal date as uint8 day,
    # uint8 month and uint16 year.
    entry: struct.Struct = struct.Struct('<IBBH')


# Shared instance holding the default format values.
BINARY_FORMAT = _BinaryFileFormat()
def file_is_binary_database(fd: BinaryIO) -> bool:
    """Checks whether the file begins with the binary database magic bytes.

    The file is rewound to the start both before and after the check. Returns
    False if seeking or reading raises IOError.
    """
    expected = BINARY_FORMAT.magic

    try:
        fd.seek(0)
        prefix = fd.read(len(expected))
        fd.seek(0)
    except IOError:
        return False

    return expected == prefix
def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a binary token database file.

    The file consists of a header, one fixed-size record per entry (token and
    removal date), and a table of null-terminated strings in the same order
    as the records.

    Args:
      fd: binary file object positioned at the start of the database

    Yields:
      A TokenizedStringEntry for each record in the file.

    Raises:
      ValueError: the file does not start with the expected magic bytes.
    """
    magic, entry_count = BINARY_FORMAT.header.unpack(
        fd.read(BINARY_FORMAT.header.size))

    if magic != BINARY_FORMAT.magic:
        raise ValueError(
            'Magic number mismatch (found {!r}, expected {!r})'.format(
                magic, BINARY_FORMAT.magic))

    entries = []

    for _ in range(entry_count):
        token, day, month, year = BINARY_FORMAT.entry.unpack(
            fd.read(BINARY_FORMAT.entry.size))

        # Entries that were never removed are written with an out-of-range
        # date (all 0xff bytes), which datetime() rejects with ValueError.
        try:
            date_removed: Optional[datetime] = datetime(year, month, day)
        except ValueError:
            date_removed = None

        entries.append((token, date_removed))

    # Read the entire string table and define a function for looking up strings.
    string_table = fd.read()

    def read_string(start: int) -> Tuple[str, int]:
        """Returns the string at start and the offset of the next string."""
        end = string_table.find(b'\0', start)

        # find() returns -1 if the terminator is missing (e.g. a truncated
        # file). Treat the end of the table as the terminator; using -1 as a
        # slice bound would drop the last byte and rewind the offset to 0.
        if end == -1:
            end = len(string_table)

        return string_table[start:end].decode(), end + 1

    offset = 0
    for token, removed in entries:
        string, offset = read_string(offset)
        yield TokenizedStringEntry(token, string, removed)
def write_binary(database: Database, fd: BinaryIO) -> None:
    """Writes the sorted database in packed binary form to a binary file.

    The output is the header, then one record per entry, then the table of
    null-terminated strings in the same order as the records.
    """
    ordered = sorted(database.entries())

    fd.write(BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, len(ordered)))

    strings = bytearray()

    for entry in ordered:
        removed = entry.date_removed

        if removed:
            date_fields = (removed.day, removed.month, removed.year)
        else:
            # If there is no removal date, use the special value 0xffffffff
            # for the day/month/year. That ensures that still-present tokens
            # appear as the newest tokens when sorted by removal date.
            date_fields = (0xff, 0xff, 0xffff)

        strings.extend(entry.string.encode())
        strings.append(0)  # null terminator

        fd.write(BINARY_FORMAT.entry.pack(entry.token, *date_fields))

    fd.write(strings)
class DatabaseFile(Database):
    """A token database that is associated with a particular file.

    This class adds the write_to_file() method that writes to file from which
    it was created in the correct format (CSV or binary).
    """
    def __init__(self, path: Union[Path, str]):
        """Loads the database at path, detecting whether it is binary or CSV.

        Args:
          path: path to an existing binary or CSV token database file
        """
        self.path = Path(path)

        # Read the path as a packed binary file.
        with self.path.open('rb') as fd:
            if file_is_binary_database(fd):
                super().__init__(parse_binary(fd))
                self._export = write_binary
                return

        # Read the path as a CSV file. write_csv() encodes with str.encode(),
        # i.e. UTF-8, so decode as UTF-8 here rather than with the platform's
        # default encoding, which may differ.
        with self.path.open('r', newline='', encoding='utf-8') as file:
            super().__init__(parse_csv(file))
            self._export = write_csv

    def write_to_file(self, path: Optional[Union[Path, str]] = None) -> None:
        """Exports in the original format to the original or provided path.

        Args:
          path: where to write the database; the path it was loaded from by
              default
        """
        with open(self.path if path is None else path, 'wb') as fd:
            self._export(self, fd)