blob: e4354fefb86239ce8e86957b077e67c41d75f026 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for code_template."""
import string
import unittest
from compiler.back_end.util import code_template
def _format_template_str(template: str, **kwargs) -> str:
    """Formats a plain template string via code_template.format_template.

    Args:
        template: Text using string.Template `$name` / `${name}` placeholders.
        **kwargs: Values substituted for the placeholders.

    Returns:
        Whatever code_template.format_template returns for the wrapped template.
    """
    wrapped = string.Template(template)
    return code_template.format_template(wrapped, **kwargs)
class FormatTest(unittest.TestCase):
    """Tests for code_template.format_template."""

    def test_no_replacement_fields(self):
        self.assertEqual("foo", _format_template_str("foo"))
        # Braces without a leading `$` are not placeholders.
        self.assertEqual("{foo}", _format_template_str("{foo}"))
        # `$$` is string.Template's escape sequence for a literal `$`.
        self.assertEqual("${foo}", _format_template_str("$${foo}"))

    def test_one_replacement_field(self):
        self.assertEqual("foo", _format_template_str("${bar}", bar="foo"))
        self.assertEqual("bazfoo", _format_template_str("baz${bar}", bar="foo"))
        self.assertEqual("foobaz", _format_template_str("${bar}baz", bar="foo"))
        self.assertEqual(
            "bazfooqux", _format_template_str("baz${bar}qux", bar="foo")
        )

    def test_one_replacement_field_with_formatting(self):
        # Basic string.Templates don't support format specs such as `:.6f`;
        # the placeholder is not a valid identifier, so it is rejected.
        with self.assertRaises(ValueError):
            _format_template_str("${bar:.6f}", bar=1)

    def test_one_replacement_field_value_missing(self):
        # No value supplied for `bar`.
        with self.assertRaises(KeyError):
            _format_template_str("${bar}")

    def test_multiple_replacement_fields(self):
        self.assertEqual(
            " aaa bbb ",
            _format_template_str(" ${bar} ${baz} ", bar="aaa", baz="bbb"),
        )
class ParseTemplatesTest(unittest.TestCase):
    """Tests for code_template.parse_templates."""

    def assertTemplatesEqual(self, expected, actual):  # pylint:disable=invalid-name
        """Asserts that a parse_templates result matches `expected`.

        Args:
            expected: dict mapping each template name to its expected
                template text.
            actual: the value returned by code_template.parse_templates.
        """
        # The result provides `_asdict()` (namedtuple-style) keyed by template
        # name; each value exposes its text via `.template`.
        actual = {k: v.template for k, v in actual._asdict().items()}
        self.assertEqual(expected, actual)

    def test_handles_no_template_case(self):
        self.assertTemplatesEqual({}, code_template.parse_templates(""))
        self.assertTemplatesEqual(
            {}, code_template.parse_templates("this is not a template")
        )

    def test_handles_one_template_at_start(self):
        self.assertTemplatesEqual(
            {"foo": "bar"}, code_template.parse_templates("** foo **\nbar")
        )

    def test_handles_one_template_after_start(self):
        self.assertTemplatesEqual(
            {"foo": "bar"}, code_template.parse_templates("text\n** foo **\nbar")
        )

    def test_handles_delimiter_with_other_text(self):
        # Delimiter lines may carry surrounding comment punctuation.
        self.assertTemplatesEqual(
            {"foo": "bar"},
            code_template.parse_templates("text\n// ** foo ** ////\nbar"),
        )
        self.assertTemplatesEqual(
            {"foo": "bar"},
            code_template.parse_templates("text\n# ** foo ** #####\nbar"),
        )

    def test_handles_multiple_delimiters(self):
        self.assertTemplatesEqual(
            {"foo": "bar", "baz": "qux"},
            code_template.parse_templates("** foo **\nbar\n** baz **\nqux"),
        )

    def test_returns_object_with_attributes(self):
        # Each template is also reachable as an attribute named after it.
        templates = code_template.parse_templates(
            "** foo **\nbar\n** baz **\nqux"
        )
        self.assertEqual("bar", templates.foo.template)
        self.assertEqual("qux", templates.baz.template)
if __name__ == "__main__":
    # When executed directly, run every TestCase defined in this module.
    unittest.main(module="__main__")