blob: c6717992d510da422af13f3ebd85ceb7b7d55064 [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests decoding a proto with tokenized fields."""
import base64
import unittest
from pw_tokenizer_tests.detokenize_proto_test_pb2 import TheMessage
from pw_tokenizer import detokenize, encode, tokens
from pw_tokenizer.proto import detokenize_fields, decode_optionally_tokenized
# Token database used by TestDetokenizeProtoFields. The string for token
# 0x12345678 embeds "$oeQAAA==", which is the '$'-prefixed Base64 encoding of
# the bytes a1 e4 00 00, i.e. token 0x0000e4a1 in little-endian order. That
# nested token is what the "recursive" tests expect to be expanded as well.
_DATABASE = tokens.Database([
    tokens.TokenizedStringEntry(token, text) for token, text in (
        (0xAABBCCDD, "Luke, we're gonna have %s"),
        (0x12345678, "This string has a $oeQAAA=="),
        (0x0000e4a1, "recursive token"),
    )
])

_DETOKENIZER = detokenize.Detokenizer(_DATABASE)
class TestDetokenizeProtoFields(unittest.TestCase):
    """Exercises detokenize_fields() on TheMessage's message field.

    Every test wraps some bytes in a TheMessage, lets detokenize_fields()
    rewrite the message in place using the module-level _DETOKENIZER, and
    then checks the bytes left in the field.
    """

    @staticmethod
    def _detokenized(message: bytes) -> bytes:
        """Returns the message field after detokenize_fields() has run."""
        proto = TheMessage(message=message)
        detokenize_fields(_DETOKENIZER, proto)
        return proto.message

    def test_plain_text(self) -> None:
        text = b'boring conversation anyway!'
        self.assertEqual(self._detokenized(text), text)

    def test_binary(self) -> None:
        # Token 0xAABBCCDD in little-endian byte order, followed by a length
        # byte (7) and the seven bytes of the %s argument.
        result = self._detokenized(b'\xDD\xCC\xBB\xAA\x07company')
        self.assertEqual(result, b"Luke, we're gonna have company")

    def test_recursive_binary(self) -> None:
        # Token 0x12345678 expands to text containing another prefixed
        # Base64 token, which is expected to be expanded too.
        result = self._detokenized(b'\x78\x56\x34\x12')
        self.assertEqual(result, b"This string has a recursive token")

    def test_base64(self) -> None:
        encoded = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x07company')
        result = self._detokenized(encoded.encode())
        self.assertEqual(result, b"Luke, we're gonna have company")

    def test_recursive_base64(self) -> None:
        encoded = encode.prefixed_base64(b'\x78\x56\x34\x12')
        result = self._detokenized(encoded.encode())
        self.assertEqual(result, b"This string has a recursive token")

    def test_plain_text_with_prefixed_base64(self) -> None:
        # A Base64 token embedded in ordinary text is replaced in place.
        encoded = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x09pancakes!')
        result = self._detokenized(f'Good morning, {encoded}'.encode())
        self.assertEqual(result,
                         b"Good morning, Luke, we're gonna have pancakes!")

    def test_unknown_token_not_utf8(self) -> None:
        # An unknown token that is not valid UTF-8 is expected to come back
        # as its prefixed Base64 text.
        raw = b'\xFE\xED\xF0\x0D'
        result = self._detokenized(raw)
        self.assertEqual(result.decode(), encode.prefixed_base64(raw))

    def test_only_control_characters(self) -> None:
        # Bytes that are all control characters get the same Base64 fallback.
        raw = b'\1\2\3\4'
        result = self._detokenized(raw)
        self.assertEqual(result.decode(), encode.prefixed_base64(raw))
class TestDecodeOptionallyTokenized(unittest.TestCase):
    """Tests decode_optionally_tokenized() directly.

    Covers binary tokens, '$'-prefixed Base64 tokens and a caller-supplied
    prefix, for tokens both present in and absent from a small database.
    """

    # Binary data whose little-endian leading word (0x2AF98AD5) is not one of
    # the database tokens (0-3). It is also not valid UTF-8: the byte 0xF9
    # never occurs in UTF-8.
    _UNKNOWN_TOKEN_DATA = b'\xD5\x8A\xF9\x2A\x8A'

    def setUp(self) -> None:
        self.detok: detokenize.Detokenizer = detokenize.Detokenizer(
            tokens.Database([
                tokens.TokenizedStringEntry(0, 'cheese'),
                tokens.TokenizedStringEntry(1, 'on pizza'),
                tokens.TokenizedStringEntry(2, 'is quite good'),
                tokens.TokenizedStringEntry(3, 'they say'),
            ]))

    def test_found_binary_token(self) -> None:
        self.assertEqual(
            'on pizza',
            decode_optionally_tokenized(self.detok, b'\x01\x00\x00\x00'))

    def test_missing_binary_token(self) -> None:
        # Undecodable binary data is expected back as '$'-prefixed Base64.
        data = self._UNKNOWN_TOKEN_DATA
        self.assertEqual('$' + base64.b64encode(data).decode(),
                         decode_optionally_tokenized(self.detok, data))

    def test_found_b64_token(self) -> None:
        b64_bytes = b'$' + base64.b64encode(b'\x03\x00\x00\x00')
        self.assertEqual('they say',
                         decode_optionally_tokenized(self.detok, b64_bytes))

    def test_missing_b64_token(self) -> None:
        # An unknown Base64 token is expected back unchanged, as text.
        b64_bytes = b'$' + base64.b64encode(self._UNKNOWN_TOKEN_DATA)
        self.assertEqual(b64_bytes.decode(),
                         decode_optionally_tokenized(self.detok, b64_bytes))

    def test_found_alternate_prefix(self) -> None:
        b64_bytes = b'~' + base64.b64encode(b'\x00\x00\x00\x00')
        self.assertEqual(
            'cheese', decode_optionally_tokenized(self.detok, b64_bytes, b'~'))

    def test_missing_alternate_prefix(self) -> None:
        # Token 2 IS in the database, but the data is prefixed with '~' while
        # the caller asks for '^'. The prefix mismatch means the input is
        # expected back unchanged.
        b64_bytes = b'~' + base64.b64encode(b'\x02\x00\x00\x00')
        self.assertEqual(
            b64_bytes.decode(),
            decode_optionally_tokenized(self.detok, b64_bytes, b'^'))
if __name__ == '__main__':
    # Lets the file be executed directly as a script, not only via a runner.
    unittest.main()