pw_tokenizer: Add entries rather than strings to databases
- Do not recalculate tokens when adding them to a token database
  (see the sketch below).
- Remove the APIs for adding plain strings. They are no longer needed
  and should not be used, since the hash calculations are now done
  entirely in firmware.
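
A sketch of the new call pattern (Database, TokenizedStringEntry, and
default_hash are from pw_tokenizer.tokens, as changed below; the example
strings are hypothetical):

  from pw_tokenizer import tokens

  db = tokens.Database()
  # The caller computes each hash once; add() no longer tokenizes.
  db.add(
      tokens.TokenizedStringEntry(tokens.default_hash(s), s)
      for s in ['hello', 'world'])
  # mark_removals() likewise takes entries rather than strings.
  db.mark_removals(db.entries())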
Change-Id: I8d15088f2ebe60adf6e21d6ccbd0dc99ad255fd7
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/24460
Reviewed-by: Keir Mierle <keir@google.com>
Commit-Queue: Wyatt Hepler <hepler@google.com>
diff --git a/pw_tokenizer/py/database_test.py b/pw_tokenizer/py/database_test.py
index 18530b8..54c890e 100755
--- a/pw_tokenizer/py/database_test.py
+++ b/pw_tokenizer/py/database_test.py
@@ -118,7 +118,7 @@
}


-def run_cli(*args):
+def run_cli(*args) -> None:
    original_argv = sys.argv
    sys.argv = ['database.py', *(str(a) for a in args)]
    # pylint: disable=protected-access
@@ -133,7 +133,7 @@
        sys.argv = original_argv
-def _mock_output():
+def _mock_output() -> io.TextIOWrapper:
    output = io.BytesIO()
    output.name = '<fake stdout>'
    return io.TextIOWrapper(output, write_through=True)
@@ -187,11 +187,15 @@
        self.assertEqual(CSV_DEFAULT_DOMAIN.splitlines(),
                         self._csv.read_text().splitlines())
-    def test_add(self):
-        self._csv.write_text(CSV_ALL_DOMAINS)
+    def test_add_does_not_recalculate_tokens(self):
+        db_with_custom_token = '01234567, ,"hello"'

-        run_cli('add', '--database', self._csv, f'{self._elf}#TEST_DOMAIN')
-        self.assertEqual(CSV_ALL_DOMAINS.splitlines(),
+        to_add = self._dir / 'add_this.csv'
+        to_add.write_text(db_with_custom_token + '\n')
+        self._csv.touch()
+
+        run_cli('add', '--database', self._csv, to_add)
+        self.assertEqual(db_with_custom_token.splitlines(),
                         self._csv.read_text().splitlines())

    def test_mark_removals(self):
diff --git a/pw_tokenizer/py/detokenize_test.py b/pw_tokenizer/py/detokenize_test.py
index 3039695..6613e53 100755
--- a/pw_tokenizer/py/detokenize_test.py
+++ b/pw_tokenizer/py/detokenize_test.py
@@ -445,7 +445,7 @@
        os.unlink(file.name)
-def _next_char(message):
+def _next_char(message: bytes) -> bytes:
    return bytes(b + 1 for b in message)
@@ -514,7 +514,9 @@
        super().setUp()
        db = database.load_token_database(
            io.BytesIO(ELF_WITH_TOKENIZER_SECTIONS))
-        db.add([self.RECURSION_STRING, self.RECURSION_STRING_2])
+        db.add(
+            tokens.TokenizedStringEntry(tokens.default_hash(s), s)
+            for s in [self.RECURSION_STRING, self.RECURSION_STRING_2])
        self.detok = detokenize.Detokenizer(db)

    def test_detokenize_base64_live(self):
diff --git a/pw_tokenizer/py/pw_tokenizer/database.py b/pw_tokenizer/py/pw_tokenizer/database.py
index 43ba5cb..cf21481 100755
--- a/pw_tokenizer/py/pw_tokenizer/database.py
+++ b/pw_tokenizer/py/pw_tokenizer/database.py
@@ -285,7 +285,7 @@
    initial = len(token_database)

    for source in databases:
-        token_database.add((entry.string for entry in source.entries()))
+        token_database.add(source.entries())

    token_database.write_to_file()
@@ -295,8 +295,7 @@
def _handle_mark_removals(token_database, databases, date):
    marked_removed = token_database.mark_removals(
-        (entry.string
-         for entry in tokens.Database.merged(*databases).entries()
+        (entry for entry in tokens.Database.merged(*databases).entries()
         if not entry.date_removed), date)

    token_database.write_to_file()
diff --git a/pw_tokenizer/py/pw_tokenizer/tokens.py b/pw_tokenizer/py/pw_tokenizer/tokens.py
index 570b021..a45f4b0 100644
--- a/pw_tokenizer/py/pw_tokenizer/tokens.py
+++ b/pw_tokenizer/py/pw_tokenizer/tokens.py
@@ -163,19 +163,19 @@
    def mark_removals(
        self,
-        all_strings: Iterable[str],
+        all_entries: Iterable[TokenizedStringEntry],
        removal_date: Optional[datetime] = None
    ) -> List[TokenizedStringEntry]:
- """Marks strings missing from all_strings as having been removed.
+ """Marks entries missing from all_entries as having been removed.
- The strings are assumed to represent the complete set of strings for the
- database. Strings currently in the database not present in the provided
- strings are marked with a removal date but remain in the database.
- Strings in all_strings missing from the database are NOT ; call the
- add function to add these strings.
+ The entries are assumed to represent the complete set of entries for the
+ database. Entries currently in the database not present in the provided
+ entries are marked with a removal date but remain in the database.
+ Entries in all_entries missing from the database are NOT added; call the
+ add function to add these.
Args:
- all_strings: the complete set of strings present in the database
+ all_entries: the complete set of strings present in the database
removal_date: the datetime for removed entries; today by default
Returns:
@@ -186,13 +186,12 @@
        if removal_date is None:
            removal_date = datetime.now()

-        all_strings = frozenset(all_strings)  # for faster lookup
+        all_keys = frozenset(entry.key() for entry in all_entries)
        removed = []

-        # Mark this entry as having been removed from the ELF.
        for entry in self._database.values():
-            if (entry.string not in all_strings
+            if (entry.key() not in all_keys
                    and (entry.date_removed is None
                         or removal_date < entry.date_removed)):
                # Add a removal date, or update it to the oldest date.
@@ -201,29 +200,19 @@
        return removed
-    def add(self,
-            entries: Iterable[Union[str, TokenizedStringEntry]],
-            tokenize: Callable[[str], int] = default_hash) -> None:
-        """Adds new entries or strings to the database."""
+    def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
+        """Adds new entries and updates date_removed for existing entries."""
        self._cache = None

-        # Add new and update previously removed entries.
        for new_entry in entries:
-            # Handle legacy plain string entries, which need to be hashed.
-            if isinstance(new_entry, str):
-                key = _EntryKey(tokenize(new_entry), new_entry)
-                domain = DEFAULT_DOMAIN
-            else:
-                key = _EntryKey(new_entry.token, new_entry.string)
-                domain = new_entry.domain
-
+            # Update an existing entry or create a new one.
            try:
-                entry = self._database[key]
-                if entry.date_removed:
-                    entry.date_removed = None
+                entry = self._database[new_entry.key()]
+                entry.domain = new_entry.domain
+                entry.date_removed = None
            except KeyError:
-                self._database[key] = TokenizedStringEntry(
-                    key.token, key.string, domain)
+                self._database[new_entry.key()] = TokenizedStringEntry(
+                    new_entry.token, new_entry.string, new_entry.domain)

    def purge(
        self,
diff --git a/pw_tokenizer/py/tokens_test.py b/pw_tokenizer/py/tokens_test.py
index 58014e2..8c71a3b 100755
--- a/pw_tokenizer/py/tokens_test.py
+++ b/pw_tokenizer/py/tokens_test.py
@@ -19,6 +19,7 @@
import logging
from pathlib import Path
import tempfile
+from typing import Iterator
import unittest
from pw_tokenizer import tokens
@@ -94,6 +95,11 @@
    return tokens.Database(tokens.parse_csv(csv_db))


+def _entries(*strings: str) -> Iterator[tokens.TokenizedStringEntry]:
+    for string in strings:
+        yield tokens.TokenizedStringEntry(default_hash(string), string)
+
+
class TokenDatabaseTest(unittest.TestCase):
"""Tests the token database class."""
def test_csv(self):
@@ -313,7 +319,7 @@
        self.assertEqual(len(db.token_to_entries), 16)

        # Add two strings with the same hash.
-        db.add(['o000', '0Q1Q'])
+        db.add(_entries('o000', '0Q1Q'))

        self.assertEqual(len(db.entries()), 18)
        self.assertEqual(len(db.token_to_entries), 17)
@@ -327,7 +333,7 @@
            all(entry.date_removed is None for entry in db.entries()))
        date_1 = datetime.datetime(1, 2, 3)

-        db.mark_removals(['apples', 'oranges', 'pears'], date_1)
+        db.mark_removals(_entries('apples', 'oranges', 'pears'), date_1)

        self.assertEqual(
            db.token_to_entries[default_hash('MILK')][0].date_removed, date_1)
@@ -336,7 +342,7 @@
            date_1)
        now = datetime.datetime.now()

-        db.mark_removals(['MILK', 'CHEESE', 'pears'])
+        db.mark_removals(_entries('MILK', 'CHEESE', 'pears'))

        # New strings are not added or re-added in mark_removals().
        self.assertGreaterEqual(
@@ -355,16 +361,16 @@
    def test_add(self):
        db = tokens.Database()
-        db.add(['MILK', 'apples'])
+        db.add(_entries('MILK', 'apples'))
        self.assertEqual({e.string for e in db.entries()}, {'MILK', 'apples'})

-        db.add(['oranges', 'CHEESE', 'pears'])
+        db.add(_entries('oranges', 'CHEESE', 'pears'))
        self.assertEqual(len(db.entries()), 5)

-        db.add(['MILK', 'apples', 'only this one is new'])
+        db.add(_entries('MILK', 'apples', 'only this one is new'))
        self.assertEqual(len(db.entries()), 6)

-        db.add(['MILK'])
+        db.add(_entries('MILK'))
        self.assertEqual({e.string
                          for e in db.entries()}, {
                              'MILK', 'apples', 'oranges', 'CHEESE', 'pears',