| # Copyright 2019 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Tests for code_template.""" |
| |
| import string |
| import unittest |
| from compiler.back_end.util import code_template |
| |
| |
def _format_template_str(template: str, **kwargs) -> str:
    """Formats a raw template string via code_template.format_template.

    Args:
        template: template text; it is wrapped in a string.Template before
            being handed to format_template.
        **kwargs: substitution values, forwarded unchanged.

    Returns:
        Whatever code_template.format_template returns for the wrapped
        template and the given values.
    """
    wrapped = string.Template(template)
    return code_template.format_template(wrapped, **kwargs)
| |
| |
class FormatTest(unittest.TestCase):
    """Tests for code_template.format_template."""

    def test_no_replacement_fields(self):
        """Templates without `${...}` fields pass through; `$$` becomes `$`."""
        cases = (
            # (template, expected output)
            ("foo", "foo"),
            ("{foo}", "{foo}"),
            ("$${foo}", "${foo}"),
        )
        for template, expected in cases:
            # subTest keeps one failing case from hiding the remaining ones.
            with self.subTest(template=template):
                self.assertEqual(expected, _format_template_str(template))

    def test_one_replacement_field(self):
        """A single `${bar}` is substituted wherever it sits in the text."""
        cases = (
            # (template, expected output with bar="foo")
            ("${bar}", "foo"),
            ("baz${bar}", "bazfoo"),
            ("${bar}baz", "foobaz"),
            ("baz${bar}qux", "bazfooqux"),
        )
        for template, expected in cases:
            with self.subTest(template=template):
                self.assertEqual(
                    expected, _format_template_str(template, bar="foo")
                )

    def test_one_replacement_field_with_formatting(self):
        """A format spec inside the placeholder is rejected with ValueError."""
        # Basic string.Templates don't support formatting values.
        with self.assertRaises(ValueError):
            _format_template_str("${bar:.6f}", bar=1)

    def test_one_replacement_field_value_missing(self):
        """Omitting the value for a placeholder raises KeyError."""
        with self.assertRaises(KeyError):
            _format_template_str("${bar}")

    def test_multiple_replacement_fields(self):
        """Several placeholders in one template are each substituted."""
        self.assertEqual(
            " aaa bbb ",
            _format_template_str(" ${bar} ${baz} ", bar="aaa", baz="bbb"),
        )
| |
| |
class ParseTemplatesTest(unittest.TestCase):
    """Tests for code_template.parse_templates."""

    def assertTemplatesEqual(self, expected, actual):  # pylint:disable=invalid-name
        """Asserts that a parse_templates result holds the expected templates.

        Args:
            expected: dict mapping each template name to its raw text.
            actual: result of parse_templates; read through `_asdict()` and
                the `.template` attribute of every field.
        """
        # Reduce each field of the result tuple to its raw template text so
        # the whole result can be compared against a plain dict.
        raw_by_name = {
            name: parsed.template for name, parsed in actual._asdict().items()
        }
        self.assertEqual(expected, raw_by_name)

    def test_handles_no_template_case(self):
        """Input without any delimiter line produces no templates."""
        for text in ("", "this is not a template"):
            parsed = code_template.parse_templates(text)
            self.assertTemplatesEqual({}, parsed)

    def test_handles_one_template_at_start(self):
        """A delimiter on the very first line starts a template."""
        parsed = code_template.parse_templates("** foo **\nbar")
        self.assertTemplatesEqual({"foo": "bar"}, parsed)

    def test_handles_one_template_after_start(self):
        """Text ahead of the first delimiter is not part of any template."""
        parsed = code_template.parse_templates("text\n** foo **\nbar")
        self.assertTemplatesEqual({"foo": "bar"}, parsed)

    def test_handles_delimiter_with_other_text(self):
        """Comment markers around `** name **` do not affect the result."""
        decorated_inputs = (
            "text\n// ** foo ** ////\nbar",
            "text\n# ** foo ** #####\nbar",
        )
        for text in decorated_inputs:
            parsed = code_template.parse_templates(text)
            self.assertTemplatesEqual({"foo": "bar"}, parsed)

    def test_handles_multiple_delimiters(self):
        """Each delimiter opens its own named template."""
        parsed = code_template.parse_templates("** foo **\nbar\n** baz **\nqux")
        self.assertTemplatesEqual({"foo": "bar", "baz": "qux"}, parsed)

    def test_returns_object_with_attributes(self):
        """Parsed templates are reachable as attributes of the result."""
        parsed = code_template.parse_templates("** foo **\nbar\n** baz **\nqux")
        self.assertEqual("bar", parsed.foo.template)
| |
| |
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()