pw_tokenizer: Improve Base64 detokenization
- Support on-the-fly detokenization of Base64 messages immediately
  followed by additional valid Base64 characters (e.g. $abcdef==abc).
- Use re.sub instead of custom logic for replacements (see the sketch
  after this list).
- Expand type annotations.
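
The replacement path now centers on one regular expression plus re.sub
with a callable, as implemented in _base64_message_regex and
_detokenize_prefixed_base64 below. A minimal, self-contained sketch of
the same pattern (TOKEN_DB and the bare 4-byte token lookup are
hypothetical stand-ins for the real Detokenizer, which also decodes
arguments):

    import base64
    import binascii
    import re
    from typing import Dict, Match, Pattern

    PREFIX = b'$'

    # Hypothetical token database; maps a 4-byte token to its string.
    TOKEN_DB: Dict[bytes, bytes] = {b'\x01\x02\x03\x04': b'Jello, world!'}

    def _message_regex(prefix: bytes) -> Pattern[bytes]:
        # Zero or more 4-character Base64 blocks, optionally ending in
        # one block with one or two padding characters.
        return re.compile(
            re.escape(prefix) + br'(?:[A-Za-z0-9+/\-_]{4})*'
            br'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?')

    def _replace(match: Match[bytes]) -> bytes:
        original = match.group(0)
        try:
            raw = base64.b64decode(original[1:], validate=True)
        except binascii.Error:
            return original  # Not decodable; keep the text unchanged.
        # Look up the token, leaving unrecognized messages as-is.
        return TOKEN_DB.get(raw[:4], original)

    def detokenize(data: bytes) -> bytes:
        return _message_regex(PREFIX).sub(_replace, data)

    # b'AQIDBA==' encodes the token 01 02 03 04; the trailing 'abc' is
    # too short to form another Base64 block, so it passes through.
    assert detokenize(b'Hello $AQIDBA==abc') == b'Hello Jello, world!abc'

Because '=' is excluded from the repeated character class, a match
always ends on a complete Base64 message, which is what lets trailing
characters such as 'abc' survive the substitution (see the new test
cases).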
Change-Id: Id0f165f57d11da9652ba4d7ad15fcb4d49fcd83c
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/23620
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Anthony DiGirolamo <tonymd@google.com>
Reviewed-by: Keir Mierle <keir@google.com>
diff --git a/pw_tokenizer/py/detokenize_test.py b/pw_tokenizer/py/detokenize_test.py
index 7158102..3039695 100755
--- a/pw_tokenizer/py/detokenize_test.py
+++ b/pw_tokenizer/py/detokenize_test.py
@@ -492,11 +492,17 @@
TEST_CASES = (
(b'', b''),
+ (b'nothing here', b'nothing here'),
(JELLO, b'Jello, world!'),
+ (JELLO + b'a', b'Jello, world!a'),
+ (JELLO + b'abc', b'Jello, world!abc'),
+ (JELLO + b'abc=', b'Jello, world!abc='),
+ (b'$a' + JELLO + b'a', b'$aJello, world!a'),
(b'Hello ' + JELLO + b'?', b'Hello Jello, world!?'),
(b'$' + JELLO, b'$Jello, world!'),
(JELLO + JELLO, b'Jello, world!Jello, world!'),
(JELLO + b'$' + JELLO, b'Jello, world!$Jello, world!'),
+ (JELLO + b'$a' + JELLO + b'bcd', b'Jello, world!$aJello, world!bcd'),
(b'$3141', b'$3141'),
(JELLO + b'$3141', b'Jello, world!$3141'),
(RECURSION, b'The secret message is "Jello, world!"'),
diff --git a/pw_tokenizer/py/pw_tokenizer/detokenize.py b/pw_tokenizer/py/pw_tokenizer/detokenize.py
index cbc76c8..b5a0b7b 100755
--- a/pw_tokenizer/py/pw_tokenizer/detokenize.py
+++ b/pw_tokenizer/py/pw_tokenizer/detokenize.py
@@ -44,7 +44,8 @@
import struct
import sys
import time
-from typing import Dict, List, Iterable, NamedTuple, Optional, Tuple
+from typing import (BinaryIO, Callable, Dict, List, Iterable, Iterator, Match,
+ NamedTuple, Optional, Pattern, Tuple, Union)
try:
from pw_tokenizer import database, decode, tokens
@@ -234,7 +235,9 @@
except FileNotFoundError:
return database.load_token_database()
- def __init__(self, *paths_or_files, min_poll_period_s: float = 1.0):
+ def __init__(self,
+ *paths_or_files,
+ min_poll_period_s: float = 1.0) -> None:
self.paths = tuple(self._DatabasePath(path) for path in paths_or_files)
self.min_poll_period_s = min_poll_period_s
self._last_checked_time: float = time.time()
@@ -255,13 +258,13 @@
class PrefixedMessageDecoder:
"""Parses messages that start with a prefix character from a byte stream."""
- def __init__(self, prefix, chars):
+ def __init__(self, prefix: Union[str, bytes], chars: Union[str, bytes]):
"""Parses prefixed messages.
- Args:
- prefix: str or bytes; one character that signifies the start of a message
- chars: str or bytes; characters allowed in a message
- """
+ Args:
+ prefix: one character that signifies the start of a message
+ chars: characters allowed in a message
+ """
self._prefix = prefix.encode() if isinstance(prefix, str) else prefix
if isinstance(chars, str):
@@ -279,14 +282,15 @@
self.data = bytearray()
- def _read_next(self, fd):
+ def _read_next(self, fd: BinaryIO) -> Tuple[bytes, int]:
"""Returns the next character and its index."""
char = fd.read(1)
index = len(self.data)
self.data += char
return char, index
- def read_messages(self, binary_fd):
+ def read_messages(self,
+ binary_fd: BinaryIO) -> Iterator[Tuple[bool, bytes]]:
"""Parses prefixed messages; yields (is_message, contents) chunks."""
message_start = None
@@ -312,21 +316,26 @@
else:
yield False, char
- def transform(self, binary_fd, transform):
+ def transform(self, binary_fd: BinaryIO,
+ transform: Callable[[bytes], bytes]) -> Iterator[bytes]:
"""Yields the file with a transformation applied to the messages."""
for is_message, chunk in self.read_messages(binary_fd):
yield transform(chunk) if is_message else chunk
-def _detokenize_prefixed_base64(detokenizer, prefix, recursion):
+def _detokenize_prefixed_base64(
+ detokenizer: Detokenizer, prefix: bytes,
+ recursion: int) -> Callable[[Match[bytes]], bytes]:
"""Returns a function that decodes prefixed Base64 with the detokenizer."""
- def decode_and_detokenize(original):
+ def decode_and_detokenize(match: Match[bytes]) -> bytes:
"""Decodes prefixed base64 with the provided detokenizer."""
+ original = match.group(0)
+
try:
- result = detokenizer.detokenize(
+ detokenized_string = detokenizer.detokenize(
base64.b64decode(original[1:], validate=True))
- if result.matches():
- result = str(result).encode()
+ if detokenized_string.matches():
+ result = str(detokenized_string).encode()
if recursion > 0 and original != result:
                result = detokenize_base64(detokenizer, result, prefix,
                                           recursion - 1)
@@ -345,13 +354,31 @@
DEFAULT_RECURSION = 9
-def detokenize_base64_live(detokenizer,
- input_file,
- output,
- prefix=BASE64_PREFIX,
- recursion=DEFAULT_RECURSION):
+def _base64_message_regex(prefix: bytes) -> Pattern[bytes]:
+ """Returns a regular expression for prefixed base64 tokenized strings."""
+ return re.compile(
+ # Base64 tokenized strings start with the prefix character ($)
+ re.escape(prefix) + (
+ # Tokenized strings contain 0 or more blocks of four Base64 chars.
+ br'(?:[A-Za-z0-9+/\-_]{4})*'
+ # The last block of 4 chars may have one or two padding chars (=).
+ br'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?'))
+
+
+def detokenize_base64_live(detokenizer: Detokenizer,
+ input_file: BinaryIO,
+ output: BinaryIO,
+ prefix: Union[str, bytes] = BASE64_PREFIX,
+ recursion: int = DEFAULT_RECURSION) -> None:
"""Reads chars one-at-a-time and decodes messages; SLOW for big files."""
- transform = _detokenize_prefixed_base64(detokenizer, prefix, recursion)
+ prefix_bytes = prefix.encode() if isinstance(prefix, str) else prefix
+
+ base64_message = _base64_message_regex(prefix_bytes)
+
+ def transform(data: bytes) -> bytes:
+ return base64_message.sub(
+ _detokenize_prefixed_base64(detokenizer, prefix_bytes, recursion),
+ data)
for message in PrefixedMessageDecoder(
prefix, string.ascii_letters + string.digits + '+/-_=').transform(
@@ -363,50 +390,40 @@
output.flush()
-def detokenize_base64_to_file(detokenizer,
- data,
- output,
- prefix=BASE64_PREFIX,
- recursion=DEFAULT_RECURSION):
+def detokenize_base64_to_file(detokenizer: Detokenizer,
+ data: bytes,
+ output: BinaryIO,
+ prefix: Union[str, bytes] = BASE64_PREFIX,
+ recursion: int = DEFAULT_RECURSION) -> None:
"""Decodes prefixed Base64 messages in data; decodes to an output file."""
- transform = _detokenize_prefixed_base64(detokenizer, prefix, recursion)
-
- messages = re.compile(
- re.escape(prefix.encode() if isinstance(prefix, str) else prefix) +
- (br'(?:[A-Za-z0-9+/\-_]{4})*'
- br'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?'))
-
- index = 0
-
- for match in messages.finditer(data):
- output.write(data[index:match.start()])
- output.write(transform(match.group(0)))
- index = match.end()
-
- output.write(data[index:])
+ prefix = prefix.encode() if isinstance(prefix, str) else prefix
+ output.write(
+ _base64_message_regex(prefix).sub(
+ _detokenize_prefixed_base64(detokenizer, prefix, recursion), data))
-def detokenize_base64(detokenizer,
- data,
- prefix=BASE64_PREFIX,
- recursion=DEFAULT_RECURSION):
+def detokenize_base64(detokenizer: Detokenizer,
+ data: bytes,
+ prefix: Union[str, bytes] = BASE64_PREFIX,
+ recursion: int = DEFAULT_RECURSION) -> bytes:
"""Decodes and replaces prefixed Base64 messages in the provided data.
- Args:
- detokenizer: the detokenizer with which to decode messages
- data: the binary data to decode
- prefix: one-character byte string that signals the start of a message
- recursion: how many levels to recursively decode
+ Args:
+ detokenizer: the detokenizer with which to decode messages
+ data: the binary data to decode
+ prefix: one-character byte string that signals the start of a message
+ recursion: how many levels to recursively decode
- Returns:
- copy of the data with all recognized tokens decoded
- """
+ Returns:
+ copy of the data with all recognized tokens decoded
+ """
output = io.BytesIO()
detokenize_base64_to_file(detokenizer, data, output, prefix, recursion)
return output.getvalue()
-def _handle_base64(databases, input_file, output, prefix, show_errors):
+def _handle_base64(databases, input_file: BinaryIO, output: BinaryIO,
+ prefix: str, show_errors: bool) -> None:
"""Handles the base64 command line option."""
# argparse.FileType doesn't correctly handle - for binary files.
if input_file is sys.stdin:
@@ -426,7 +443,7 @@
detokenize_base64_live(detokenizer, input_file, output, prefix)
-def _parse_args():
+def _parse_args() -> argparse.Namespace:
"""Parses and return command line arguments."""
parser = argparse.ArgumentParser(
@@ -472,13 +489,14 @@
return parser.parse_args()
-def main():
+def main() -> int:
args = _parse_args()
handler = args.handler
del args.handler
handler(**vars(args))
+ return 0
if __name__ == '__main__':
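
For reference, the updated functions can be exercised roughly as
follows. The database path is hypothetical, and Detokenizer is assumed
to accept the same token database sources as before:

    import io
    import sys

    from pw_tokenizer import detokenize

    # Hypothetical token database source.
    detokenizer = detokenize.Detokenizer('tokens.csv')

    # One-shot: returns a copy of the data with recognized tokens decoded.
    decoded = detokenize.detokenize_base64(detokenizer, b'Log: $AQIDBA==end')

    # Streaming: reads one char at a time and decodes messages on the fly.
    detokenize.detokenize_base64_live(
        detokenizer, io.BytesIO(b'Log: $AQIDBA==\n'), sys.stdout.buffer)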