blob: c1330960149696ae857d86f003102655612d9325 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for code_template."""
import string
import unittest
from compiler.back_end.util import code_template
def _format_template_str(template: str, **kwargs) -> str:
return code_template.format_template(string.Template(template), **kwargs)
class FormatTest(unittest.TestCase):
"""Tests for code_template.format."""
def test_no_replacement_fields(self):
self.assertEqual("foo", _format_template_str("foo"))
self.assertEqual("{foo}", _format_template_str("{foo}"))
self.assertEqual("${foo}", _format_template_str("$${foo}"))
def test_one_replacement_field(self):
self.assertEqual("foo", _format_template_str("${bar}", bar="foo"))
self.assertEqual("bazfoo",
_format_template_str("baz${bar}", bar="foo"))
self.assertEqual("foobaz",
_format_template_str("${bar}baz", bar="foo"))
self.assertEqual("bazfooqux",
_format_template_str("baz${bar}qux", bar="foo"))
def test_one_replacement_field_with_formatting(self):
# Basic string.Templates don't support formatting values.
self.assertRaises(ValueError,
_format_template_str, "${bar:.6f}", bar=1)
def test_one_replacement_field_value_missing(self):
self.assertRaises(KeyError, _format_template_str, "${bar}")
def test_multiple_replacement_fields(self):
self.assertEqual(" aaa bbb ",
_format_template_str(" ${bar} ${baz} ",
bar="aaa",
baz="bbb"))
class ParseTemplatesTest(unittest.TestCase):
"""Tests for code_template.parse_templates."""
def assertTemplatesEqual(self, expected, actual): # pylint:disable=invalid-name
"""Compares the results of a parse_templates"""
# Extract the name and template from the result tuple
actual = {
k: v.template for k, v in actual._asdict().items()
}
self.assertEqual(expected, actual)
def test_handles_no_template_case(self):
self.assertTemplatesEqual({}, code_template.parse_templates(""))
self.assertTemplatesEqual({}, code_template.parse_templates(
"this is not a template"))
def test_handles_one_template_at_start(self):
self.assertTemplatesEqual({"foo": "bar"},
code_template.parse_templates("** foo **\nbar"))
def test_handles_one_template_after_start(self):
self.assertTemplatesEqual(
{"foo": "bar"},
code_template.parse_templates("text\n** foo **\nbar"))
def test_handles_delimiter_with_other_text(self):
self.assertTemplatesEqual(
{"foo": "bar"},
code_template.parse_templates("text\n// ** foo ** ////\nbar"))
self.assertTemplatesEqual(
{"foo": "bar"},
code_template.parse_templates("text\n# ** foo ** #####\nbar"))
def test_handles_multiple_delimiters(self):
self.assertTemplatesEqual({"foo": "bar",
"baz": "qux"}, code_template.parse_templates(
"** foo **\nbar\n** baz **\nqux"))
def test_returns_object_with_attributes(self):
self.assertEqual("bar", code_template.parse_templates(
"** foo **\nbar\n** baz **\nqux").foo.template)
if __name__ == "__main__":
unittest.main()