Update parameterized.CoopTestCase to work with python3 metaclasses.
By setting and accessing __metaclass__, we were not achieving the goals of this
function, which is to make it possible to have both the functionality of
paramterized.TestCase and another TestCase instance with its own metaclass.
Because there are client classes that are using CoopTestCase to combine parameterized.TestCase with another class that isn't using a metaclass we need to support that use-case. We do it by returning a simple multiple inheritance subclass without a combined metaclass, and emit a warning telling the author they don't need to use this method.
PiperOrigin-RevId: 548732415
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e4ec771..34840af 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,6 +26,11 @@
`enum_values` is provided as a single string value. Additionally,
`EnumParser.enum_values` is now stored as a list copy of the provided
`enum_values` parameter.
+* (tesing) Updated `paramaterized.CoopTestCase()` to use Python 3 metaclass
+ idioms. Most uses of this function continued working during the Python 3
+ migration still worked because a Python 2 compatibility `__metaclass__`
+ variables also existed. Now pure Python 3 base classes without backwards
+ compatibility will work as intended.
## 1.4.0 (2023-01-11)
diff --git a/absl/testing/parameterized.py b/absl/testing/parameterized.py
index 3eb773c..d3d2c2b 100644
--- a/absl/testing/parameterized.py
+++ b/absl/testing/parameterized.py
@@ -217,6 +217,7 @@
import re
import types
import unittest
+import warnings
from absl.testing import absltest
@@ -697,10 +698,27 @@
Returns:
A new class object.
"""
- metaclass = type(
- 'CoopMetaclass',
- (other_base_class.__metaclass__,
- TestGeneratorMetaclass), {})
- return metaclass(
- 'CoopTestCase',
- (other_base_class, TestCase), {})
+ # If the other base class has a metaclass of 'type' then trying to combine
+ # the metaclasses will result in an MRO error. So simply combine them and
+ # return.
+ if type(other_base_class) == type: # pylint: disable=unidiomatic-typecheck
+ warnings.warn(
+ 'CoopTestCase is only necessary when combining with a class that uses'
+ ' a metaclass. Use multiple inheritance like this instead: class'
+ f' ExampleTest(paramaterized.TestCase, {other_base_class.__name__}):',
+ stacklevel=2,
+ )
+
+ class CoopTestCaseBase(other_base_class, TestCase):
+ pass
+
+ return CoopTestCaseBase
+ else:
+
+ class CoopMetaclass(type(other_base_class), TestGeneratorMetaclass): # pylint: disable=unused-variable
+ pass
+
+ class CoopTestCaseBase(other_base_class, TestCase, metaclass=CoopMetaclass):
+ pass
+
+ return CoopTestCaseBase
diff --git a/absl/testing/tests/parameterized_test.py b/absl/testing/tests/parameterized_test.py
index 8acbd93..731cc76 100644
--- a/absl/testing/tests/parameterized_test.py
+++ b/absl/testing/tests/parameterized_test.py
@@ -15,6 +15,7 @@
"""Tests for absl.testing.parameterized."""
from collections import abc
+import os
import sys
import unittest
@@ -27,7 +28,6 @@
def dummy_decorator(method):
-
def decorated(*args, **kwargs):
return method(*args, **kwargs)
@@ -48,6 +48,7 @@
Returns:
The test decorator
"""
+
def decorator(test_method):
# If decorating result of another dict_decorator
if isinstance(test_method, abc.Iterable):
@@ -62,10 +63,11 @@
test_method.testcases = actual_tests
return test_method
else:
- test_suffix = ('_%s_%s') % (key, value)
+ test_suffix = '_%s_%s' % (key, value)
tests_to_make = ((test_suffix, {key: value}),)
# 'test_method' here is the original test method
return parameterized.named_parameters(*tests_to_make)(test_method)
+
return decorator
@@ -75,9 +77,7 @@
class GoodAdditionParams(parameterized.TestCase):
- @parameterized.parameters(
- (1, 2, 3),
- (4, 5, 9))
+ @parameterized.parameters((1, 2, 3), (4, 5, 9))
def test_addition(self, op1, op2, result):
self.arguments = (op1, op2, result)
self.assertEqual(result, op1 + op2)
@@ -85,17 +85,13 @@
# This class does not inherit from TestCase.
class BadAdditionParams(absltest.TestCase):
- @parameterized.parameters(
- (1, 2, 3),
- (4, 5, 9))
+ @parameterized.parameters((1, 2, 3), (4, 5, 9))
def test_addition(self, op1, op2, result):
pass # Always passes, but not called w/out TestCase.
class MixedAdditionParams(parameterized.TestCase):
- @parameterized.parameters(
- (1, 2, 1),
- (4, 5, 9))
+ @parameterized.parameters((1, 2, 1), (4, 5, 9))
def test_addition(self, op1, op2, result):
self.arguments = (op1, op2, result)
self.assertEqual(result, op1 + op2)
@@ -103,8 +99,8 @@
class DictionaryArguments(parameterized.TestCase):
@parameterized.parameters(
- {'op1': 1, 'op2': 2, 'result': 3},
- {'op1': 4, 'op2': 5, 'result': 9})
+ {'op1': 1, 'op2': 2, 'result': 3}, {'op1': 4, 'op2': 5, 'result': 9}
+ )
def test_addition(self, op1, op2, result):
self.assertEqual(result, op1 + op2)
@@ -238,13 +234,13 @@
@dict_decorator('cone', 'waffle')
@dict_decorator('flavor', 'strawberry')
def test_chained(self, dictionary):
- self.assertDictEqual(dictionary, {'cone': 'waffle',
- 'flavor': 'strawberry'})
+ self.assertDictEqual(
+ dictionary, {'cone': 'waffle', 'flavor': 'strawberry'}
+ )
class SingletonListExtraction(parameterized.TestCase):
- @parameterized.parameters(
- (i, i * 2) for i in range(10))
+ @parameterized.parameters((i, i * 2) for i in range(10))
def test_something(self, unused_1, unused_2):
pass
@@ -264,9 +260,7 @@
def test_something(self, op1, op2):
del op1, op2
- @parameterized.parameters(
- (1, 2, 3),
- (4, 5, 9))
+ @parameterized.parameters((1, 2, 3), (4, 5, 9))
class DecoratedClass(parameterized.TestCase):
def test_add(self, arg1, arg2, arg3):
@@ -276,7 +270,8 @@
self.assertEqual(arg3 + arg2, arg1)
@parameterized.parameters(
- (a, b, a+b) for a in range(1, 5) for b in range(1, 5))
+ (a, b, a + b) for a in range(1, 5) for b in range(1, 5)
+ )
class GeneratorDecoratedClass(parameterized.TestCase):
def test_add(self, arg1, arg2, arg3):
@@ -322,14 +317,14 @@
@dummy_decorator
@parameterized.named_parameters(
- {'testcase_name': 'a', 'arg1': 1},
- {'testcase_name': 'b', 'arg1': 2})
+ {'testcase_name': 'a', 'arg1': 1}, {'testcase_name': 'b', 'arg1': 2}
+ )
def test_other_then_parameterized(self, arg1):
pass
@parameterized.named_parameters(
- {'testcase_name': 'a', 'arg1': 1},
- {'testcase_name': 'b', 'arg1': 2})
+ {'testcase_name': 'a', 'arg1': 1}, {'testcase_name': 'b', 'arg1': 2}
+ )
@dummy_decorator
def test_parameterized_then_other(self, arg1):
pass
@@ -380,7 +375,8 @@
@unittest.skipIf(
(sys.version_info[:2] == (3, 7) and sys.version_info[2] in {0, 1, 2}),
'Python 3.7.0 to 3.7.2 have a bug that breaks this test, see '
- 'https://bugs.python.org/issue35767')
+ 'https://bugs.python.org/issue35767',
+ )
def test_missing_inheritance(self):
ts = unittest.makeSuite(self.BadAdditionParams)
self.assertEqual(1, ts.countTestCases())
@@ -407,9 +403,7 @@
ts = unittest.makeSuite(self.GoodAdditionParams)
res = unittest.TestResult()
- params = set([
- (1, 2, 3),
- (4, 5, 9)])
+ params = set([(1, 2, 3), (4, 5, 9)])
for test in ts:
test(res)
self.assertIn(test.arguments, params)
@@ -432,38 +426,46 @@
short_desc = list(ts)[0].shortDescription()
location = unittest.util.strclass(self.GoodAdditionParams).replace(
- '__main__.', '')
- expected = ('{}.test_addition0 (1, 2, 3)\n'.format(location) +
- 'test_addition(1, 2, 3)')
+ '__main__.', ''
+ )
+ expected = (
+ '{}.test_addition0 (1, 2, 3)\n'.format(location)
+ + 'test_addition(1, 2, 3)'
+ )
self.assertEqual(expected, short_desc)
def test_short_description_addresses_removed(self):
ts = unittest.makeSuite(self.ArgumentsWithAddresses)
short_desc = list(ts)[0].shortDescription().split('\n')
- self.assertEqual(
- 'test_something(<object>)', short_desc[1])
+ self.assertEqual('test_something(<object>)', short_desc[1])
short_desc = list(ts)[1].shortDescription().split('\n')
- self.assertEqual(
- 'test_something(<__main__.MyOwnClass>)', short_desc[1])
+ self.assertEqual('test_something(<__main__.MyOwnClass>)', short_desc[1])
def test_id(self):
ts = unittest.makeSuite(self.ArgumentsWithAddresses)
self.assertEqual(
- (unittest.util.strclass(self.ArgumentsWithAddresses) +
- '.test_something0 (<object>)'),
- list(ts)[0].id())
+ (
+ unittest.util.strclass(self.ArgumentsWithAddresses)
+ + '.test_something0 (<object>)'
+ ),
+ list(ts)[0].id(),
+ )
ts = unittest.makeSuite(self.GoodAdditionParams)
self.assertEqual(
- (unittest.util.strclass(self.GoodAdditionParams) +
- '.test_addition0 (1, 2, 3)'),
- list(ts)[0].id())
+ (
+ unittest.util.strclass(self.GoodAdditionParams)
+ + '.test_addition0 (1, 2, 3)'
+ ),
+ list(ts)[0].id(),
+ )
def test_str(self):
ts = unittest.makeSuite(self.GoodAdditionParams)
test = list(ts)[0]
expected = 'test_addition0 (1, 2, 3) ({})'.format(
- unittest.util.strclass(self.GoodAdditionParams))
+ unittest.util.strclass(self.GoodAdditionParams)
+ )
self.assertEqual(expected, str(test))
def test_dict_parameters(self):
@@ -486,17 +488,13 @@
'{}.testNormal'.format(full_class_name),
'{}.test_normal'.format(full_class_name),
],
- short_descs)
+ short_descs,
+ )
def test_successful_product_test_testgrid(self):
-
class GoodProductTestCase(parameterized.TestCase):
- @parameterized.product(
- num=(0, 20, 80),
- modulo=(2, 4),
- expected=(0,)
- )
+ @parameterized.product(num=(0, 20, 80), modulo=(2, 4), expected=(0,))
def testModuloResult(self, num, modulo, expected):
self.assertEqual(expected, num % modulo)
@@ -508,12 +506,13 @@
self.assertTrue(res.wasSuccessful())
def test_successful_product_test_kwarg_seqs(self):
-
class GoodProductTestCase(parameterized.TestCase):
- @parameterized.product((dict(num=0), dict(num=20), dict(num=0)),
- (dict(modulo=2), dict(modulo=4)),
- (dict(expected=0),))
+ @parameterized.product(
+ (dict(num=0), dict(num=20), dict(num=0)),
+ (dict(modulo=2), dict(modulo=4)),
+ (dict(expected=0),),
+ )
def testModuloResult(self, num, modulo, expected):
self.assertEqual(expected, num % modulo)
@@ -525,12 +524,15 @@
self.assertTrue(res.wasSuccessful())
def test_successful_product_test_kwarg_seq_and_testgrid(self):
-
class GoodProductTestCase(parameterized.TestCase):
- @parameterized.product((dict(
- num=5, modulo=3, expected=2), dict(num=7, modulo=4, expected=3)),
- dtype=(int, float))
+ @parameterized.product(
+ (
+ dict(num=5, modulo=3, expected=2),
+ dict(num=7, modulo=4, expected=3),
+ ),
+ dtype=(int, float),
+ )
def testModuloResult(self, num, dtype, modulo, expected):
self.assertEqual(expected, dtype(num) % modulo)
@@ -546,8 +548,9 @@
class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable
- @parameterized.product((dict(num=5, modulo=3), dict(num=7, modula=2)),
- dtype=(int, float))
+ @parameterized.product(
+ (dict(num=5, modulo=3), dict(num=7, modula=2)), dtype=(int, float)
+ )
def test_something(self):
pass # not called because argnames are not the same
@@ -556,9 +559,11 @@
class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable
- @parameterized.product((dict(num=5, modulo=3), dict(num=7, modulo=4)),
- (dict(foo='bar', num=5), dict(foo='baz', num=7)),
- dtype=(int, float))
+ @parameterized.product(
+ (dict(num=5, modulo=3), dict(num=7, modulo=4)),
+ (dict(foo='bar', num=5), dict(foo='baz', num=7)),
+ dtype=(int, float),
+ )
def test_something(self):
pass # not called because `num` is specified twice
@@ -577,14 +582,9 @@
pass # not called because `foo` is specified twice
def test_product_recorded_failures(self):
-
class MixedProductTestCase(parameterized.TestCase):
- @parameterized.product(
- num=(0, 10, 20),
- modulo=(2, 4),
- expected=(0,)
- )
+ @parameterized.product(num=(0, 10, 20), modulo=(2, 4), expected=(0,))
def testModuloResult(self, num, modulo, expected):
self.assertEqual(expected, num % modulo)
@@ -599,13 +599,9 @@
self.assertEmpty(res.errors)
def test_mismatched_product_parameter(self):
-
class MismatchedProductParam(parameterized.TestCase):
- @parameterized.product(
- a=(1, 2),
- mismatch=(1, 2)
- )
+ @parameterized.product(a=(1, 2), mismatch=(1, 2))
# will fail because of mismatch in parameter names.
def test_something(self, a, b):
pass
@@ -637,6 +633,7 @@
def test_generator_tests_disallowed(self):
with self.assertRaisesRegex(RuntimeError, 'generated.*without'):
+
class GeneratorTests(parameterized.TestCase): # pylint: disable=unused-variable
test_generator_method = (lambda self: None for _ in range(10))
@@ -649,74 +646,53 @@
self.assertTrue(res.wasSuccessful())
def test_named_parameters_id(self):
- ts = sorted(unittest.makeSuite(self.CamelCaseNamedTests),
- key=lambda t: t.id())
+ ts = sorted(
+ unittest.makeSuite(self.CamelCaseNamedTests), key=lambda t: t.id()
+ )
self.assertLen(ts, 9)
full_class_name = unittest.util.strclass(self.CamelCaseNamedTests)
+ self.assertEqual(full_class_name + '.testDictSingleInteresting', ts[0].id())
+ self.assertEqual(full_class_name + '.testDictSomethingBoring', ts[1].id())
self.assertEqual(
- full_class_name + '.testDictSingleInteresting',
- ts[0].id())
+ full_class_name + '.testDictSomethingInteresting', ts[2].id()
+ )
+ self.assertEqual(full_class_name + '.testMixedSomethingBoring', ts[3].id())
self.assertEqual(
- full_class_name + '.testDictSomethingBoring',
- ts[1].id())
- self.assertEqual(
- full_class_name + '.testDictSomethingInteresting',
- ts[2].id())
- self.assertEqual(
- full_class_name + '.testMixedSomethingBoring',
- ts[3].id())
- self.assertEqual(
- full_class_name + '.testMixedSomethingInteresting',
- ts[4].id())
- self.assertEqual(
- full_class_name + '.testSingleInteresting',
- ts[5].id())
- self.assertEqual(
- full_class_name + '.testSomethingBoring',
- ts[6].id())
- self.assertEqual(
- full_class_name + '.testSomethingInteresting',
- ts[7].id())
- self.assertEqual(
- full_class_name + '.testWithoutParameters',
- ts[8].id())
+ full_class_name + '.testMixedSomethingInteresting', ts[4].id()
+ )
+ self.assertEqual(full_class_name + '.testSingleInteresting', ts[5].id())
+ self.assertEqual(full_class_name + '.testSomethingBoring', ts[6].id())
+ self.assertEqual(full_class_name + '.testSomethingInteresting', ts[7].id())
+ self.assertEqual(full_class_name + '.testWithoutParameters', ts[8].id())
def test_named_parameters_id_with_underscore_case(self):
- ts = sorted(unittest.makeSuite(self.NamedTests),
- key=lambda t: t.id())
+ ts = sorted(unittest.makeSuite(self.NamedTests), key=lambda t: t.id())
self.assertLen(ts, 9)
full_class_name = unittest.util.strclass(self.NamedTests)
self.assertEqual(
- full_class_name + '.test_dict_single_interesting',
- ts[0].id())
+ full_class_name + '.test_dict_single_interesting', ts[0].id()
+ )
self.assertEqual(
- full_class_name + '.test_dict_something_boring',
- ts[1].id())
+ full_class_name + '.test_dict_something_boring', ts[1].id()
+ )
self.assertEqual(
- full_class_name + '.test_dict_something_interesting',
- ts[2].id())
+ full_class_name + '.test_dict_something_interesting', ts[2].id()
+ )
self.assertEqual(
- full_class_name + '.test_mixed_something_boring',
- ts[3].id())
+ full_class_name + '.test_mixed_something_boring', ts[3].id()
+ )
self.assertEqual(
- full_class_name + '.test_mixed_something_interesting',
- ts[4].id())
+ full_class_name + '.test_mixed_something_interesting', ts[4].id()
+ )
+ self.assertEqual(full_class_name + '.test_single_interesting', ts[5].id())
+ self.assertEqual(full_class_name + '.test_something_boring', ts[6].id())
self.assertEqual(
- full_class_name + '.test_single_interesting',
- ts[5].id())
- self.assertEqual(
- full_class_name + '.test_something_boring',
- ts[6].id())
- self.assertEqual(
- full_class_name + '.test_something_interesting',
- ts[7].id())
- self.assertEqual(
- full_class_name + '.test_without_parameters',
- ts[8].id())
+ full_class_name + '.test_something_interesting', ts[7].id()
+ )
+ self.assertEqual(full_class_name + '.test_without_parameters', ts[8].id())
def test_named_parameters_short_description(self):
- ts = sorted(unittest.makeSuite(self.NamedTests),
- key=lambda t: t.id())
+ ts = sorted(unittest.makeSuite(self.NamedTests), key=lambda t: t.id())
actual = {t._testMethodName: t.shortDescription() for t in ts}
expected = {
'test_dict_single_interesting': 'case=0',
@@ -734,8 +710,11 @@
def test_load_tuple_named_test(self):
loader = unittest.TestLoader()
- ts = list(loader.loadTestsFromName('NamedTests.test_something_interesting',
- module=self))
+ ts = list(
+ loader.loadTestsFromName(
+ 'NamedTests.test_something_interesting', module=self
+ )
+ )
self.assertLen(ts, 1)
self.assertEndsWith(ts[0].id(), '.test_something_interesting')
@@ -743,7 +722,9 @@
loader = unittest.TestLoader()
ts = list(
loader.loadTestsFromName(
- 'NamedTests.test_dict_something_interesting', module=self))
+ 'NamedTests.test_dict_something_interesting', module=self
+ )
+ )
self.assertLen(ts, 1)
self.assertEndsWith(ts[0].id(), '.test_dict_something_interesting')
@@ -751,7 +732,9 @@
loader = unittest.TestLoader()
ts = list(
loader.loadTestsFromName(
- 'NamedTests.test_mixed_something_interesting', module=self))
+ 'NamedTests.test_mixed_something_interesting', module=self
+ )
+ )
self.assertLen(ts, 1)
self.assertEndsWith(ts[0].id(), '.test_mixed_something_interesting')
@@ -886,7 +869,6 @@
pass
def test_double_class_decorations_not_supported(self):
-
@parameterized.parameters('foo', 'bar')
class SuperclassWithClassDecorator(parameterized.TestCase):
@@ -968,7 +950,6 @@
del test_something
def test_unique_descriptive_names(self):
-
class RecordSuccessTestsResult(unittest.TestResult):
def __init__(self, *args, **kwargs):
@@ -1009,10 +990,12 @@
def test_subclass_inherits_superclass_test_params_reprs(self):
self.assertEqual(
{'test_name0': "('foo')", 'test_name1': "('bar')"},
- self.SuperclassTestCase._test_params_reprs)
+ self.SuperclassTestCase._test_params_reprs,
+ )
self.assertEqual(
{'test_name0': "('foo')", 'test_name1': "('bar')"},
- self.SubclassTestCase._test_params_reprs)
+ self.SubclassTestCase._test_params_reprs,
+ )
def _decorate_with_side_effects(func, self):
@@ -1022,7 +1005,19 @@
class CoopMetaclassCreationTest(absltest.TestCase):
- class TestBase(absltest.TestCase):
+ class TestBaseMetaclass(type):
+
+ def __init__(cls, name, bases, dct):
+ type.__init__(cls, name, bases, dct)
+ for member_name, obj in dct.items():
+ if member_name.startswith('test'):
+ setattr(
+ cls,
+ member_name,
+ lambda self, f=obj: _decorate_with_side_effects(f, self),
+ )
+
+ class TestBase(absltest.TestCase, metaclass=TestBaseMetaclass):
# This test simulates a metaclass that sets some attribute ('sideeffect')
# on each member of the class that starts with 'test'. The test code then
@@ -1033,21 +1028,11 @@
# since the TestGeneratorMetaclass already overrides __new__. Only one
# base metaclass can override __new__, but all can provide custom __init__
# methods.
-
- class __metaclass__(type): # pylint: disable=g-bad-name
-
- def __init__(cls, name, bases, dct):
- type.__init__(cls, name, bases, dct)
- for member_name, obj in dct.items():
- if member_name.startswith('test'):
- setattr(cls, member_name,
- lambda self, f=obj: _decorate_with_side_effects(f, self))
+ pass
class MyParams(parameterized.CoopTestCase(TestBase)):
- @parameterized.parameters(
- (1, 2, 3),
- (4, 5, 9))
+ @parameterized.parameters((1, 2, 3), (4, 5, 9))
def test_addition(self, op1, op2, result):
self.assertEqual(result, op1 + op2)
@@ -1072,6 +1057,20 @@
ts.run(res)
self.assertTrue(list(ts)[0].sideeffect)
+ def test_no_metaclass(self):
+ class SimpleMixinTestCase(absltest.TestCase):
+ pass
+
+ with self.assertWarnsRegex(
+ UserWarning,
+ 'CoopTestCase is only necessary when combining with a class that uses a'
+ ' metaclass',
+ ) as warning:
+ parameterized.CoopTestCase(SimpleMixinTestCase)
+ self.assertEqual(
+ os.path.basename(warning.filename), 'parameterized_test.py'
+ )
+
if __name__ == '__main__':
absltest.main()