From 9b6315a2ba8c9274f09bec7bf2aeb2a4d2732f98 Mon Sep 17 00:00:00 2001 From: Andrew Leech Date: Tue, 3 May 2022 17:07:48 +1000 Subject: [PATCH] unittest: Add exception capturing for subTest. --- python-stdlib/unittest/test_unittest.py | 8 +++ python-stdlib/unittest/unittest.py | 75 +++++++++++++++++++------ 2 files changed, 67 insertions(+), 16 deletions(-) diff --git a/python-stdlib/unittest/test_unittest.py b/python-stdlib/unittest/test_unittest.py index 690fb40d..1b02056d 100644 --- a/python-stdlib/unittest/test_unittest.py +++ b/python-stdlib/unittest/test_unittest.py @@ -148,6 +148,14 @@ class TestUnittestAssertions(unittest.TestCase): assert global_context is None global_context = True + def test_subtest_even(self): + """ + Test that numbers between 0 and 5 are all even. + """ + for i in range(0, 10, 2): + with self.subTest("Should only pass for even numbers", i=i): + self.assertEqual(i % 2, 0) + if __name__ == "__main__": unittest.main() diff --git a/python-stdlib/unittest/unittest.py b/python-stdlib/unittest/unittest.py index dd4bb607..8d20acd7 100644 --- a/python-stdlib/unittest/unittest.py +++ b/python-stdlib/unittest/unittest.py @@ -37,11 +37,35 @@ class AssertRaisesContext: return False +# These are used to provide required context to things like subTest +__current_test__ = None +__test_result__ = None + + +class SubtestContext: + def __enter__(self): + pass + + def __exit__(self, *exc_info): + if exc_info[0] is not None: + # Exception raised + global __test_result__, __current_test__ + handle_test_exception( + __current_test__, + __test_result__, + exc_info + ) + # Suppress the exception as we've captured it above + return True + + + + class NullContext: def __enter__(self): pass - def __exit__(self, a, b, c): + def __exit__(self, exc_type, exc_value, traceback): pass @@ -61,7 +85,7 @@ class TestCase: func(*args, **kwargs) def subTest(self, msg=None, **params): - return NullContext() + return SubtestContext(msg=msg, params=params) def skipTest(self, reason): raise SkipTest(reason) @@ -298,15 +322,29 @@ class TestResult: return self -def capture_exc(e): +def capture_exc(exc, traceback): buf = io.StringIO() if hasattr(sys, "print_exception"): - sys.print_exception(e, buf) + sys.print_exception(exc, buf) elif traceback is not None: - traceback.print_exception(None, e, sys.exc_info()[2], file=buf) + traceback.print_exception(None, exc, traceback, file=buf) return buf.getvalue() +def handle_test_exception(current_test: tuple, test_result: TestResult, exc_info: tuple): + exc = exc_info[1] + traceback = exc_info[2] + ex_str = capture_exc(exc, traceback) + if isinstance(exc, AssertionError): + test_result.failuresNum += 1 + test_result.failures.append((current_test, ex_str)) + print(" FAIL") + else: + test_result.errorsNum += 1 + test_result.errors.append((current_test, ex_str)) + print(" ERROR") + + def run_suite(c, test_result, suite_name=""): if isinstance(c, TestSuite): c.run(test_result) @@ -324,29 +362,34 @@ def run_suite(c, test_result, suite_name=""): except AttributeError: pass - def run_one(m): + def run_one(test_function): + global __test_result__, __current_test__ print("%s (%s) ..." % (name, suite_name), end="") set_up() + __test_result__ = test_result + test_container = f"({suite_name})" + __current_test__ = (name, test_container) try: test_result.testsRun += 1 - m() + test_globals = dict(**globals()) + test_globals["test_function"] = test_function + exec("test_function()", test_globals, test_globals) + # No exception occurred, test passed print(" ok") except SkipTest as e: print(" skipped:", e.args[0]) test_result.skippedNum += 1 except Exception as ex: - ex_str = capture_exc(ex) - if isinstance(ex, AssertionError): - test_result.failuresNum += 1 - test_result.failures.append(((name, c), ex_str)) - print(" FAIL") - else: - test_result.errorsNum += 1 - test_result.errors.append(((name, c), ex_str)) - print(" ERROR") + handle_test_exception( + current_test=(name, c), + test_result=test_result, + exc_info=sys.exc_info() + ) # Uncomment to investigate failure in detail # raise finally: + __test_result__ = None + __current_test__ = None tear_down() try: o.doCleanups()