# blob: 06d5c37609b0e9f1286c9a540eb2d8c3d20e0662 [file] [log] [blame]
# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for flagsaver."""
from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
# Two plain string flags that the tests mutate and then expect to see reset.
flags.DEFINE_string(
    'flagsaver_test_flag0', default='unchanged0', help='flag to test with'
)
flags.DEFINE_string(
    'flagsaver_test_flag1', default='unchanged1', help='flag to test with'
)

# A string flag whose validator only accepts falsy values, so overriding it
# with a non-empty string fails validation.
flags.DEFINE_string(
    'flagsaver_test_validated_flag', default=None, help='flag to test with'
)
flags.register_validator('flagsaver_test_validated_flag', lambda x: not x)

# Two string flags that are validated together by validate_test_flags.
flags.DEFINE_string(
    'flagsaver_test_validated_flag1', default=None, help='flag to test with'
)
flags.DEFINE_string(
    'flagsaver_test_validated_flag2', default=None, help='flag to test with'
)

# Flags defined through the holder-returning form; the tests refer to these
# via the holder objects rather than by name.
INT_FLAG = flags.DEFINE_integer(
    'flagsaver_test_int_flag', default=1, help='help'
)
STR_FLAG = flags.DEFINE_string(
    'flagsaver_test_str_flag', default='str default', help='help'
)
MULTI_INT_FLAG = flags.DEFINE_multi_integer(
    'flagsaver_test_multi_int_flag', default=None, help='flag to test with'
)
@flags.multi_flags_validator(
    ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2'))
def validate_test_flags(flag_dict):
  """Accept the two cross-validated flags only when their values are equal.

  Args:
    flag_dict: mapping from flag name to current value for the two flags
      named in the decorator.

  Returns:
    True if both flags hold the same value, False otherwise.
  """
  first = flag_dict['flagsaver_test_validated_flag1']
  second = flag_dict['flagsaver_test_validated_flag2']
  return first == second
# Shorthand for absl's flags.FLAGS object, through which the tests below read
# and assign flags by name.
FLAGS = flags.FLAGS
@flags.validator('flagsaver_test_flag0')
def check_no_upper_case(value):
  """Accept `value` only if lower-casing it leaves it unchanged."""
  lowered = value.lower()
  return lowered == value
class _TestError(Exception):
  """Exception the tests raise on purpose to simulate a failing test body."""
class CommonUsageTest(absltest.TestCase):
  """These test cases cover the most common usages of flagsaver.

  Each test doubles as a usage example: flagsaver.as_parsed() as a context
  manager and as a decorator, and flagsaver.flagsaver() as a decorator.
  """

  def test_as_parsed_context_manager(self):
    # This test assigns flagsaver_test_flag1 outside of any flagsaver, so
    # snapshot every flag up front and restore the snapshot on cleanup;
    # otherwise that assignment would leak into whichever test runs next.
    self.addCleanup(
        flagsaver.restore_flag_values, flagsaver.save_flag_values())

    # Precondition check, we expect all the flags to start as their default.
    self.assertEqual('str default', STR_FLAG.value)
    self.assertFalse(STR_FLAG.present)
    self.assertEqual(1, INT_FLAG.value)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

    # Flagsaver will also save the state of flags that have been modified.
    FLAGS.flagsaver_test_flag1 = 'outside flagsaver'

    # Save all existing flag state, and set some flags as if they were parsed
    # on the command line. Because of this, the new values must be provided as
    # str, even if the flag type is something other than string.
    with flagsaver.as_parsed(
        (STR_FLAG, 'new string value'),  # Override using flagholder object.
        (INT_FLAG, '123'),  # Override an int flag (NOTE: must specify as str).
        flagsaver_test_flag0='new value',  # Override using flag name.
    ):
      # All the flags have their overridden values.
      self.assertEqual('new string value', STR_FLAG.value)
      self.assertTrue(STR_FLAG.present)
      self.assertEqual(123, INT_FLAG.value)
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
      # Even if we change other flags, they will reset on context exit.
      FLAGS.flagsaver_test_flag1 = 'new value 1'

    # The flags have all reset to their pre-flagsaver values.
    self.assertEqual('str default', STR_FLAG.value)
    self.assertFalse(STR_FLAG.present)
    self.assertEqual(1, INT_FLAG.value)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('outside flagsaver', FLAGS.flagsaver_test_flag1)

  def test_as_parsed_decorator(self):
    # flagsaver.as_parsed can also be used as a decorator.
    @flagsaver.as_parsed((INT_FLAG, '123'))
    def do_something_with_flags():
      self.assertEqual(123, INT_FLAG.value)
      self.assertTrue(INT_FLAG.present)

    do_something_with_flags()
    self.assertEqual(1, INT_FLAG.value)
    self.assertFalse(INT_FLAG.present)

  def test_flagsaver_flagsaver(self):
    # If you don't want the flags to go through parsing, you can instead use
    # flagsaver.flagsaver(). With this method, you provide the native python
    # value you'd like the flags to take on. Otherwise it functions similar to
    # flagsaver.as_parsed().
    @flagsaver.flagsaver((INT_FLAG, 345))
    def do_something_with_flags():
      self.assertEqual(345, INT_FLAG.value)
      # Note that because this flag was never parsed, it will not register as
      # .present unless you manually set that attribute.
      self.assertFalse(INT_FLAG.present)
      # If you do choose to modify things about the flag (such as .present)
      # those changes will still be cleaned up when flagsaver.flagsaver()
      # exits. The attribute is set on the underlying Flag object because the
      # FlagHolder only exposes .present as a read-only property.
      FLAGS[INT_FLAG.name].present = True
      self.assertTrue(INT_FLAG.present)

    # The decorated function must actually be invoked; otherwise none of the
    # assertions above run and the checks below pass trivially.
    do_something_with_flags()

    self.assertEqual(1, INT_FLAG.value)
    # flagsaver.flagsaver() restored INT_FLAG.present to the state it was in
    # before entering the context.
    self.assertFalse(INT_FLAG.present)
class SaveFlagValuesTest(absltest.TestCase):
  """Exercises flagsaver.save_flag_values() and restore_flag_values().

  These tests ensure that *all* properties of a flag get restored: its value,
  its default, its parse state and its validators. Other tests in this file
  only try changing the flag value.
  """

  def test_assign_value(self):
    snapshot = flagsaver.save_flag_values()

    # Overwrite the value and confirm the write took effect.
    FLAGS.flagsaver_test_flag0 = 'new value'
    self.assertEqual('new value', FLAGS.flagsaver_test_flag0)

    # Restoring the snapshot brings the original value back.
    flagsaver.restore_flag_values(snapshot)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_set_default(self):
    flag_name = 'flagsaver_test_flag0'
    snapshot = flagsaver.save_flag_values()

    # Replace the flag's default and confirm the change is visible.
    FLAGS.set_default(flag_name, 'new_default')
    self.assertEqual('new_default', FLAGS[flag_name].default)

    # Restoring the snapshot resets the default as well.
    flagsaver.restore_flag_values(snapshot)
    self.assertEqual('unchanged0', FLAGS[flag_name].default)

  def test_parse(self):
    flag_name = 'flagsaver_test_flag0'
    snapshot = flagsaver.save_flag_values()

    # Sanity check (would fail if called with --flagsaver_test_flag0).
    self.assertEqual(0, FLAGS[flag_name].present)

    # Parse a new value into the flag; both value and .present must change.
    FLAGS[flag_name].parse('new value')
    self.assertEqual('new value', FLAGS[flag_name].value)
    self.assertEqual(1, FLAGS[flag_name].present)

    # Restoring the snapshot resets both the value and .present.
    flagsaver.restore_flag_values(snapshot)
    self.assertEqual('unchanged0', FLAGS[flag_name].value)
    self.assertEqual(0, FLAGS[flag_name].present)

  def test_assign_validators(self):
    flag_name = 'flagsaver_test_flag0'
    snapshot = flagsaver.save_flag_values()

    # Sanity check that a validator already exists, and keep a copy of the
    # validator list to compare against after the restore.
    self.assertLen(FLAGS[flag_name].validators, 1)
    validators_before = list(FLAGS[flag_name].validators)

    # Register a second validator on the same flag.
    flags.register_validator(flag_name, lambda value: ' ' not in value)
    self.assertLen(FLAGS[flag_name].validators, 2)

    # Restoring the snapshot drops the newly registered validator.
    flagsaver.restore_flag_values(snapshot)
    self.assertEqual(validators_before, FLAGS[flag_name].validators)
@parameterized.named_parameters(
    {
        'testcase_name': 'flagsaver.flagsaver',
        'flagsaver_method': flagsaver.flagsaver,
    },
    {
        'testcase_name': 'flagsaver.as_parsed',
        'flagsaver_method': flagsaver.as_parsed,
    },
)
class NoOverridesTest(parameterized.TestCase):
  """Test flagsaver.flagsaver and flagsaver.as_parsed without overrides.

  Every test assigns flagsaver_test_flag0 inside the saver's scope and then
  checks that the flag is back to 'unchanged0' once the scope has ended.
  """

  def test_context_manager_with_call(self, flagsaver_method):
    with flagsaver_method():
      FLAGS.flagsaver_test_flag0 = 'new value'

    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_context_manager_with_exception(self, flagsaver_method):
    with self.assertRaises(_TestError), flagsaver_method():
      FLAGS.flagsaver_test_flag0 = 'new value'
      # Simulate a failed test.
      raise _TestError('something happened')

    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_without_call(self, flagsaver_method):
    # Bare decorator form: no parentheses after the saver.
    @flagsaver_method
    def assign_flag():
      FLAGS.flagsaver_test_flag0 = 'new value'

    assign_flag()
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_with_call(self, flagsaver_method):
    # Called decorator form: the saver is invoked with no arguments.
    @flagsaver_method()
    def assign_flag():
      FLAGS.flagsaver_test_flag0 = 'new value'

    assign_flag()
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_with_exception(self, flagsaver_method):
    @flagsaver_method()
    def assign_flag_then_fail():
      FLAGS.flagsaver_test_flag0 = 'new value'
      # Simulate a failed test.
      raise _TestError('something happened')

    # Merely defining the decorated function must not touch the flag.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    with self.assertRaises(_TestError):
      assign_flag_then_fail()
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
@parameterized.named_parameters(
    {
        'testcase_name': 'flagsaver.flagsaver',
        'flagsaver_method': flagsaver.flagsaver,
    },
    {
        'testcase_name': 'flagsaver.as_parsed',
        'flagsaver_method': flagsaver.as_parsed,
    },
)
class TestStringFlagOverrides(parameterized.TestCase):
  """Test flagsaver.flagsaver and flagsaver.as_parsed with string overrides.

  Note that these tests can be parameterized because both .flagsaver and
  .as_parsed expect a str input when overriding a string flag. For non-string
  flags these two flagsaver methods have separate tests elsewhere in this file.

  Each test is one class of overrides, executed twice: once as a context
  manager, and once as a decorator on a locally defined function.
  """

  def test_keyword_overrides(self, flagsaver_method):
    by_name = {'flagsaver_test_flag0': 'new value'}

    # Context manager:
    with flagsaver_method(**by_name):
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

    # Decorator:
    @flagsaver_method(**by_name)
    def check_overridden():
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)

    check_overridden()
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_flagholder_overrides(self, flagsaver_method):
    by_holder = (STR_FLAG, 'new value')

    # Context manager:
    with flagsaver_method(by_holder):
      self.assertEqual('new value', STR_FLAG.value)
    self.assertEqual('str default', STR_FLAG.value)

    # Decorator:
    @flagsaver_method(by_holder)
    def check_overridden():
      self.assertEqual('new value', STR_FLAG.value)

    check_overridden()
    self.assertEqual('str default', STR_FLAG.value)

  def test_keyword_and_flagholder_overrides(self, flagsaver_method):
    by_holder = (STR_FLAG, 'another value')
    by_name = {'flagsaver_test_flag0': 'new value'}

    # Context manager:
    with flagsaver_method(by_holder, **by_name):
      self.assertEqual('another value', STR_FLAG.value)
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
    self.assertEqual('str default', STR_FLAG.value)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

    # Decorator:
    @flagsaver_method(by_holder, **by_name)
    def check_overridden():
      self.assertEqual('another value', STR_FLAG.value)
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)

    check_overridden()
    self.assertEqual('str default', STR_FLAG.value)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_cross_validated_overrides_set_together(self, flagsaver_method):
    # When the flags are set in the same flagsaver call their validators will
    # be triggered only once the setting is done.
    matching = {
        'flagsaver_test_validated_flag1': 'new_value',
        'flagsaver_test_validated_flag2': 'new_value',
    }

    # Context manager:
    with flagsaver_method(**matching):
      self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
      self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

    # Decorator:
    @flagsaver_method(**matching)
    def check_overridden():
      self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
      self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)

    check_overridden()
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_cross_validated_overrides_set_badly(self, flagsaver_method):
    # Different values should violate the validator.
    mismatched = {
        'flagsaver_test_validated_flag1': 'new_value',
        'flagsaver_test_validated_flag2': 'other_value',
    }

    # Context manager:
    with self.assertRaisesRegex(
        flags.IllegalFlagValueError, 'Flag validation failed'
    ), flagsaver_method(**mismatched):
      pass
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

    # Decorator:
    @flagsaver_method(**mismatched)
    def do_nothing():
      pass

    with self.assertRaisesRegex(
        flags.IllegalFlagValueError, 'Flag validation failed'
    ):
      do_nothing()
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_cross_validated_overrides_set_separately(self, flagsaver_method):
    # Setting just one flag will trip the validator as well.
    only_first = {'flagsaver_test_validated_flag1': 'new_value'}

    # Context manager:
    with self.assertRaisesRegex(
        flags.IllegalFlagValueError, 'Flag validation failed'
    ), flagsaver_method(**only_first):
      pass
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

    # Decorator:
    @flagsaver_method(**only_first)
    def do_nothing():
      pass

    with self.assertRaisesRegex(
        flags.IllegalFlagValueError, 'Flag validation failed'
    ):
      do_nothing()
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_validation_exception(self, flagsaver_method):
    # The second override is rejected by that flag's validator; the first
    # override must not survive the failure either.
    overrides = {
        'flagsaver_test_flag0': 'new value',
        'flagsaver_test_validated_flag': 'new value',
    }

    # Context manager:
    with self.assertRaises(flags.IllegalFlagValueError), flagsaver_method(
        **overrides
    ):
      pass
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag)

    # Decorator:
    @flagsaver_method(**overrides)
    def do_nothing():
      pass

    with self.assertRaises(flags.IllegalFlagValueError):
      do_nothing()
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag)

  def test_unknown_flag_raises_exception(self, flagsaver_method):
    self.assertNotIn('this_flag_does_not_exist', FLAGS)
    overrides = {
        'flagsaver_test_flag0': 'new value',
        'this_flag_does_not_exist': 'new value',
    }

    # Flagsaver raises an error when trying to override a non-existent flag.
    with self.assertRaises(flags.UnrecognizedFlagError), flagsaver_method(
        **overrides
    ):
      pass
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

    # Decorator:
    @flagsaver_method(**overrides)
    def do_nothing():
      pass

    with self.assertRaises(flags.UnrecognizedFlagError):
      do_nothing()
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

    # Make sure flagsaver didn't create the flag at any point.
    self.assertNotIn('this_flag_does_not_exist', FLAGS)
class AsParsedTest(absltest.TestCase):
  """Tests behaviour that is specific to flagsaver.as_parsed()."""

  def _assert_int_and_str_flags_untouched(self):
    """Assert INT_FLAG and STR_FLAG are not present and use their defaults."""
    self.assertFalse(INT_FLAG.present)
    self.assertFalse(STR_FLAG.present)
    # Note that .using_default_value isn't available on the FlagHolder
    # directly, so it is read from the flag looked up by name in FLAGS.
    self.assertTrue(FLAGS[INT_FLAG.name].using_default_value)
    self.assertTrue(FLAGS[STR_FLAG.name].using_default_value)

  def _assert_int_and_str_flags_parsed(self):
    """Assert INT_FLAG and STR_FLAG are present and not using defaults."""
    self.assertTrue(INT_FLAG.present)
    self.assertTrue(STR_FLAG.present)
    self.assertFalse(FLAGS[INT_FLAG.name].using_default_value)
    self.assertFalse(FLAGS[STR_FLAG.name].using_default_value)

  def test_parse_context_manager_sets_present_and_using_default(self):
    self._assert_int_and_str_flags_untouched()

    with flagsaver.as_parsed(
        (INT_FLAG, '123'), flagsaver_test_str_flag='new value'
    ):
      self._assert_int_and_str_flags_parsed()

    self._assert_int_and_str_flags_untouched()

  def test_parse_decorator_sets_present_and_using_default(self):
    self._assert_int_and_str_flags_untouched()

    @flagsaver.as_parsed((INT_FLAG, '123'), flagsaver_test_str_flag='new value')
    def check_parsed():
      self._assert_int_and_str_flags_parsed()

    check_parsed()
    self._assert_int_and_str_flags_untouched()

  def test_parse_decorator_with_multi_int_flag(self):
    self.assertFalse(MULTI_INT_FLAG.present)
    self.assertIsNone(MULTI_INT_FLAG.value)

    # A multi flag is overridden with a list of strings, one per occurrence.
    @flagsaver.as_parsed((MULTI_INT_FLAG, ['123', '456']))
    def check_parsed():
      self.assertTrue(MULTI_INT_FLAG.present)
      self.assertCountEqual([123, 456], MULTI_INT_FLAG.value)

    check_parsed()
    self.assertFalse(MULTI_INT_FLAG.present)
    self.assertIsNone(MULTI_INT_FLAG.value)

  def test_parse_raises_type_error(self):
    # Passing a native int instead of a str is expected to be rejected as soon
    # as as_parsed() is called, before the result is ever entered.
    expected_message = (
        r'flagsaver\.as_parsed\(\) cannot parse flagsaver_test_int_flag\. '
        r'Expected a single string or sequence of strings but .*int.* was '
        r'provided\.'
    )
    with self.assertRaisesRegex(TypeError, expected_message):
      manager = flagsaver.as_parsed(flagsaver_test_int_flag=123)  # pytype: disable=wrong-arg-types
      del manager
class SetUpTearDownTest(absltest.TestCase):
  """Example of saving flags once in setUp and restoring them in tearDown."""

  def setUp(self):
    super().setUp()
    # Snapshot every flag before each test method runs.
    self.saved_flag_values = flagsaver.save_flag_values()

  def tearDown(self):
    super().tearDown()
    # Put back the snapshot taken in setUp.
    flagsaver.restore_flag_values(self.saved_flag_values)

  def _check_unchanged_then_mutate(self):
    """Assert the flag holds its default, then overwrite it."""
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  def test_mutate1(self):
    self._check_unchanged_then_mutate()

  def test_mutate2(self):
    # Deliberately identical to test_mutate1: whichever of the two runs second
    # only passes if the first one's mutation was undone.
    self._check_unchanged_then_mutate()
@parameterized.named_parameters(
    {
        'testcase_name': 'flagsaver.flagsaver',
        'flagsaver_method': flagsaver.flagsaver,
    },
    {
        'testcase_name': 'flagsaver.as_parsed',
        'flagsaver_method': flagsaver.as_parsed,
    },
)
class BadUsageTest(parameterized.TestCase):
  """Tests that improper usage (such as decorating a class) raises errors."""

  def test_flag_saver_on_class(self, flagsaver_method):
    with self.assertRaises(TypeError):
      # WRONG. Don't do this.
      # See SetUpTearDownTest in this file for the correct class-level usage.
      @flagsaver_method
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_flag_saver_call_on_class(self, flagsaver_method):
    with self.assertRaises(TypeError):
      # WRONG. Don't do this.
      # See SetUpTearDownTest in this file for the correct class-level usage.
      @flagsaver_method()
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_flag_saver_with_overrides_on_class(self, flagsaver_method):
    with self.assertRaises(TypeError):
      # WRONG. Don't do this.
      # See SetUpTearDownTest in this file for the correct class-level usage.
      @flagsaver_method(foo='bar')
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_multiple_positional_parameters(self, flagsaver_method):
    # Two callables passed positionally.
    def first_callable():
      return None

    def second_callable():
      return None

    with self.assertRaises(ValueError):
      flagsaver_method(first_callable, second_callable)

  def test_both_positional_and_keyword_parameters(self, flagsaver_method):
    # A callable combined with a keyword override.
    def some_callable():
      return None

    with self.assertRaises(ValueError):
      flagsaver_method(some_callable, flagsaver_test_flag0='new value')

  def test_duplicate_holder_parameters(self, flagsaver_method):
    # The same flag holder overridden twice.
    with self.assertRaises(ValueError):
      flagsaver_method((INT_FLAG, 45), (INT_FLAG, 45))

  def test_duplicate_holder_and_kw_parameter(self, flagsaver_method):
    # The same flag overridden once via its holder and once via its name.
    with self.assertRaises(ValueError):
      flagsaver_method((INT_FLAG, 45), **{INT_FLAG.name: 45})

  def test_both_positional_and_holder_parameters(self, flagsaver_method):
    # A callable combined with a (holder, value) pair.
    def some_callable():
      return None

    with self.assertRaises(ValueError):
      flagsaver_method(some_callable, (INT_FLAG, 45))

  def test_holder_parameters_wrong_shape(self, flagsaver_method):
    # A bare holder instead of a (holder, value) tuple.
    with self.assertRaises(ValueError):
      flagsaver_method(INT_FLAG)

  def test_holder_parameters_tuple_too_long(self, flagsaver_method):
    # A three-element tuple instead of a (holder, value) pair.
    with self.assertRaises(ValueError):
      flagsaver_method((INT_FLAG, 4, 5))

  def test_holder_parameters_tuple_wrong_type(self, flagsaver_method):
    # A pair whose elements are in the wrong order: (value, holder).
    with self.assertRaises(ValueError):
      flagsaver_method((4, INT_FLAG))

  def test_both_wrong_positional_parameters(self, flagsaver_method):
    # A callable followed by a holder and a value passed as separate arguments.
    def some_callable():
      return None

    with self.assertRaises(ValueError):
      flagsaver_method(some_callable, STR_FLAG, '45')

  def test_context_manager_no_call(self, flagsaver_method):
    # The exact exception that's raised appears to be system specific.
    with self.assertRaises((AttributeError, TypeError)):
      # Wrong. You must call the flagsaver method before using it as a CM.
      with flagsaver_method:
        # We don't expect to get here. An error should happen when
        # attempting to enter the context manager.
        pass
# Hand control to absltest's runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()