pw_tokenizer: Improve Base64 detokenization

- Support on-the-fly detokenization of strings followed by valid Base64
  (e.g. $abcdef==abc).
- Use re.sub instead of custom logic for replacements.
- Expand type annotations.

Change-Id: Id0f165f57d11da9652ba4d7ad15fcb4d49fcd83c
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/23620
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Anthony DiGirolamo <tonymd@google.com>
Reviewed-by: Keir Mierle <keir@google.com>
diff --git a/pw_tokenizer/py/detokenize_test.py b/pw_tokenizer/py/detokenize_test.py
index 7158102..3039695 100755
--- a/pw_tokenizer/py/detokenize_test.py
+++ b/pw_tokenizer/py/detokenize_test.py
@@ -492,11 +492,17 @@
 
     TEST_CASES = (
         (b'', b''),
+        (b'nothing here', b'nothing here'),
         (JELLO, b'Jello, world!'),
+        (JELLO + b'a', b'Jello, world!a'),
+        (JELLO + b'abc', b'Jello, world!abc'),
+        (JELLO + b'abc=', b'Jello, world!abc='),
+        (b'$a' + JELLO + b'a', b'$aJello, world!a'),
         (b'Hello ' + JELLO + b'?', b'Hello Jello, world!?'),
         (b'$' + JELLO, b'$Jello, world!'),
         (JELLO + JELLO, b'Jello, world!Jello, world!'),
         (JELLO + b'$' + JELLO, b'Jello, world!$Jello, world!'),
+        (JELLO + b'$a' + JELLO + b'bcd', b'Jello, world!$aJello, world!bcd'),
         (b'$3141', b'$3141'),
         (JELLO + b'$3141', b'Jello, world!$3141'),
         (RECURSION, b'The secret message is "Jello, world!"'),
diff --git a/pw_tokenizer/py/pw_tokenizer/detokenize.py b/pw_tokenizer/py/pw_tokenizer/detokenize.py
index cbc76c8..b5a0b7b 100755
--- a/pw_tokenizer/py/pw_tokenizer/detokenize.py
+++ b/pw_tokenizer/py/pw_tokenizer/detokenize.py
@@ -44,7 +44,8 @@
 import struct
 import sys
 import time
-from typing import Dict, List, Iterable, NamedTuple, Optional, Tuple
+from typing import (BinaryIO, Callable, Dict, List, Iterable, Iterator, Match,
+                    NamedTuple, Optional, Pattern, Tuple, Union)
 
 try:
     from pw_tokenizer import database, decode, tokens
@@ -234,7 +235,9 @@
             except FileNotFoundError:
                 return database.load_token_database()
 
-    def __init__(self, *paths_or_files, min_poll_period_s: float = 1.0):
+    def __init__(self,
+                 *paths_or_files,
+                 min_poll_period_s: float = 1.0) -> None:
         self.paths = tuple(self._DatabasePath(path) for path in paths_or_files)
         self.min_poll_period_s = min_poll_period_s
         self._last_checked_time: float = time.time()
@@ -255,13 +258,13 @@
 
 class PrefixedMessageDecoder:
     """Parses messages that start with a prefix character from a byte stream."""
-    def __init__(self, prefix, chars):
+    def __init__(self, prefix: Union[str, bytes], chars: Union[str, bytes]):
         """Parses prefixed messages.
 
-    Args:
-      prefix: str or bytes; one character that signifies the start of a message
-      chars: str or bytes; characters allowed in a message
-    """
+        Args:
+          prefix: one character that signifies the start of a message
+          chars: characters allowed in a message
+        """
         self._prefix = prefix.encode() if isinstance(prefix, str) else prefix
 
         if isinstance(chars, str):
@@ -279,14 +282,15 @@
 
         self.data = bytearray()
 
-    def _read_next(self, fd):
+    def _read_next(self, fd: BinaryIO) -> Tuple[bytes, int]:
         """Returns the next character and its index."""
         char = fd.read(1)
         index = len(self.data)
         self.data += char
         return char, index
 
-    def read_messages(self, binary_fd):
+    def read_messages(self,
+                      binary_fd: BinaryIO) -> Iterator[Tuple[bool, bytes]]:
         """Parses prefixed messages; yields (is_message, contents) chunks."""
         message_start = None
 
@@ -312,21 +316,26 @@
             else:
                 yield False, char
 
-    def transform(self, binary_fd, transform):
+    def transform(self, binary_fd: BinaryIO,
+                  transform: Callable[[bytes], bytes]) -> Iterator[bytes]:
         """Yields the file with a transformation applied to the messages."""
         for is_message, chunk in self.read_messages(binary_fd):
             yield transform(chunk) if is_message else chunk
 
 
-def _detokenize_prefixed_base64(detokenizer, prefix, recursion):
+def _detokenize_prefixed_base64(
+        detokenizer: Detokenizer, prefix: bytes,
+        recursion: int) -> Callable[[Match[bytes]], bytes]:
     """Returns a function that decodes prefixed Base64 with the detokenizer."""
-    def decode_and_detokenize(original):
+    def decode_and_detokenize(match: Match[bytes]) -> bytes:
         """Decodes prefixed base64 with the provided detokenizer."""
+        original = match.group(0)
+
         try:
-            result = detokenizer.detokenize(
+            detokenized_string = detokenizer.detokenize(
                 base64.b64decode(original[1:], validate=True))
-            if result.matches():
-                result = str(result).encode()
+            if detokenized_string.matches():
+                result = str(detokenized_string).encode()
 
                 if recursion > 0 and original != result:
                     result = detokenize_base64(detokenizer, result, prefix,
@@ -345,13 +354,31 @@
 DEFAULT_RECURSION = 9
 
 
-def detokenize_base64_live(detokenizer,
-                           input_file,
-                           output,
-                           prefix=BASE64_PREFIX,
-                           recursion=DEFAULT_RECURSION):
+def _base64_message_regex(prefix: bytes) -> Pattern[bytes]:
+    """Returns a regular expression for prefixed base64 tokenized strings."""
+    return re.compile(
+        # Base64 tokenized strings start with the prefix character ($)
+        re.escape(prefix) + (
+            # Tokenized strings contain 0 or more blocks of four Base64 chars.
+            br'(?:[A-Za-z0-9+/\-_]{4})*'
+            # The last block of 4 chars may have one or two padding chars (=).
+            br'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?'))
+
+
+def detokenize_base64_live(detokenizer: Detokenizer,
+                           input_file: BinaryIO,
+                           output: BinaryIO,
+                           prefix: Union[str, bytes] = BASE64_PREFIX,
+                           recursion: int = DEFAULT_RECURSION) -> None:
     """Reads chars one-at-a-time and decodes messages; SLOW for big files."""
-    transform = _detokenize_prefixed_base64(detokenizer, prefix, recursion)
+    prefix_bytes = prefix.encode() if isinstance(prefix, str) else prefix
+
+    base64_message = _base64_message_regex(prefix_bytes)
+
+    def transform(data: bytes) -> bytes:
+        return base64_message.sub(
+            _detokenize_prefixed_base64(detokenizer, prefix_bytes, recursion),
+            data)
 
     for message in PrefixedMessageDecoder(
             prefix, string.ascii_letters + string.digits + '+/-_=').transform(
@@ -363,50 +390,40 @@
             output.flush()
 
 
-def detokenize_base64_to_file(detokenizer,
-                              data,
-                              output,
-                              prefix=BASE64_PREFIX,
-                              recursion=DEFAULT_RECURSION):
+def detokenize_base64_to_file(detokenizer: Detokenizer,
+                              data: bytes,
+                              output: BinaryIO,
+                              prefix: Union[str, bytes] = BASE64_PREFIX,
+                              recursion: int = DEFAULT_RECURSION) -> None:
     """Decodes prefixed Base64 messages in data; decodes to an output file."""
-    transform = _detokenize_prefixed_base64(detokenizer, prefix, recursion)
-
-    messages = re.compile(
-        re.escape(prefix.encode() if isinstance(prefix, str) else prefix) +
-        (br'(?:[A-Za-z0-9+/\-_]{4})*'
-         br'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?'))
-
-    index = 0
-
-    for match in messages.finditer(data):
-        output.write(data[index:match.start()])
-        output.write(transform(match.group(0)))
-        index = match.end()
-
-    output.write(data[index:])
+    prefix = prefix.encode() if isinstance(prefix, str) else prefix
+    output.write(
+        _base64_message_regex(prefix).sub(
+            _detokenize_prefixed_base64(detokenizer, prefix, recursion), data))
 
 
-def detokenize_base64(detokenizer,
-                      data,
-                      prefix=BASE64_PREFIX,
-                      recursion=DEFAULT_RECURSION):
+def detokenize_base64(detokenizer: Detokenizer,
+                      data: bytes,
+                      prefix: Union[str, bytes] = BASE64_PREFIX,
+                      recursion: int = DEFAULT_RECURSION) -> bytes:
     """Decodes and replaces prefixed Base64 messages in the provided data.
 
-  Args:
-    detokenizer: the detokenizer with which to decode messages
-    data: the binary data to decode
-    prefix: one-character byte string that signals the start of a message
-    recursion: how many levels to recursively decode
+    Args:
+      detokenizer: the detokenizer with which to decode messages
+      data: the binary data to decode
+      prefix: one-character byte string that signals the start of a message
+      recursion: how many levels to recursively decode
 
-  Returns:
-    copy of the data with all recognized tokens decoded
-  """
+    Returns:
+      copy of the data with all recognized tokens decoded
+    """
     output = io.BytesIO()
     detokenize_base64_to_file(detokenizer, data, output, prefix, recursion)
     return output.getvalue()
 
 
-def _handle_base64(databases, input_file, output, prefix, show_errors):
+def _handle_base64(databases, input_file: BinaryIO, output: BinaryIO,
+                   prefix: str, show_errors: bool) -> None:
     """Handles the base64 command line option."""
     # argparse.FileType doesn't correctly handle - for binary files.
     if input_file is sys.stdin:
@@ -426,7 +443,7 @@
         detokenize_base64_live(detokenizer, input_file, output, prefix)
 
 
-def _parse_args():
+def _parse_args() -> argparse.Namespace:
     """Parses and return command line arguments."""
 
     parser = argparse.ArgumentParser(
@@ -472,13 +489,14 @@
     return parser.parse_args()
 
 
-def main():
+def main() -> int:
     args = _parse_args()
 
     handler = args.handler
     del args.handler
 
     handler(**vars(args))
+    return 0
 
 
 if __name__ == '__main__':