unittest: Add support for specifying custom TestRunner.

pull/488/head
Andrew Leech 2022-05-05 21:17:48 +10:00
rodzic 2d61dbdb93
commit 959115d3a9
1 zmienionych plików z 19 dodań i 9 usunięć

Wyświetl plik

@ -263,7 +263,7 @@ class TestSuite:
class TestRunner: class TestRunner:
def run(self, suite): def run(self, suite: TestSuite):
res = TestResult() res = TestResult()
suite.run(res) suite.run(res)
@ -292,7 +292,8 @@ class TestResult:
self.testsRun = 0 self.testsRun = 0
self.errors = [] self.errors = []
self.failures = [] self.failures = []
self.newFailures = 0 self.skipped = []
self._newFailures = 0
def wasSuccessful(self): def wasSuccessful(self):
return self.errorsNum == 0 and self.failuresNum == 0 return self.errorsNum == 0 and self.failuresNum == 0
@ -326,6 +327,7 @@ class TestResult:
self.testsRun += other.testsRun self.testsRun += other.testsRun
self.errors.extend(other.errors) self.errors.extend(other.errors)
self.failures.extend(other.failures) self.failures.extend(other.failures)
self.skipped.extend(other.skipped)
return self return self
@ -354,10 +356,10 @@ def handle_test_exception(
test_result.errors.append((current_test, ex_str)) test_result.errors.append((current_test, ex_str))
if verbose: if verbose:
print(" ERROR") print(" ERROR")
test_result.newFailures += 1 test_result._newFailures += 1
def run_suite(c, test_result, suite_name=""): def run_suite(c, test_result: TestResult, suite_name=""):
if isinstance(c, TestSuite): if isinstance(c, TestSuite):
c.run(test_result) c.run(test_result)
return return
@ -384,19 +386,21 @@ def run_suite(c, test_result, suite_name=""):
test_container = f"({suite_name})" test_container = f"({suite_name})"
__current_test__ = (name, test_container) __current_test__ = (name, test_container)
try: try:
test_result.newFailures = 0 test_result._newFailures = 0
test_result.testsRun += 1 test_result.testsRun += 1
test_globals = dict(**globals()) test_globals = dict(**globals())
test_globals["test_function"] = test_function test_globals["test_function"] = test_function
exec("test_function()", test_globals, test_globals) exec("test_function()", test_globals, test_globals)
# No exception occurred, test passed # No exception occurred, test passed
if test_result.newFailures: if test_result._newFailures:
print(" FAIL") print(" FAIL")
else: else:
print(" ok") print(" ok")
except SkipTest as e: except SkipTest as e:
print(" skipped:", e.args[0]) reason = e.args[0]
print(" skipped:", reason)
test_result.skippedNum += 1 test_result.skippedNum += 1
test_result.skipped.append((name, c, reason))
except Exception as ex: except Exception as ex:
handle_test_exception( handle_test_exception(
current_test=(name, c), test_result=test_result, exc_info=sys.exc_info() current_test=(name, c), test_result=test_result, exc_info=sys.exc_info()
@ -477,7 +481,13 @@ def discover(runner: TestRunner):
return discover(runner=runner) return discover(runner=runner)
def main(module="__main__"): def main(module="__main__", testRunner=None):
if testRunner:
if isinstance(testRunner, type):
runner = testRunner()
else:
runner = testRunner
else:
runner = TestRunner() runner = TestRunner()
if len(sys.argv) <= 1: if len(sys.argv) <= 1: