diff --git a/corrscope/corrscope.py b/corrscope/corrscope.py index fc6bcdb..22fc013 100644 --- a/corrscope/corrscope.py +++ b/corrscope/corrscope.py @@ -10,7 +10,7 @@ from multiprocessing.shared_memory import SharedMemory from pathlib import Path from queue import Queue, Empty from threading import Thread -from typing import Iterator, Optional, List, Callable, Tuple, Dict, Union +from typing import Iterator, Optional, List, Callable, Tuple, Dict, Union, Any import attr @@ -178,6 +178,33 @@ def template_config(**kwargs) -> Config: return attr.evolve(cfg, **kwargs) +class PropagatingThread(Thread): + # Based off https://stackoverflow.com/a/31614591 and Thread source code. + def run(self): + self.exc = None + try: + if self._target is not None: + self.ret = self._target(*self._args, **self._kwargs) + except BaseException as e: + self.exc = e + finally: + # Avoid a refcycle if the thread is running a function with + # an argument that has a member that points to the thread. + del self._target, self._args, self._kwargs + + def join(self, timeout=None) -> Any: + try: + super(PropagatingThread, self).join(timeout) + if self.exc: + raise RuntimeError(f"exception from {self.name}") from self.exc + + return self.ret + finally: + # If join() raises, set `self = None` to avoid a reference cycle with the + # backtrace, because concurrent.futures.Future.result() does it. + self = None + + BeginFunc = Callable[[float, float], None] ProgressFunc = Callable[[int], None] IsAborted = Callable[[], bool] @@ -448,7 +475,7 @@ class CorrScope: } # type: Dict[str, SharedMemory] # TODO https://stackoverflow.com/questions/2829329/catch-a-threads-exception-in-the-caller-thread - def render_thread(): + def _render_thread(): end_frame = thread_shared.end_frame prev = -1 @@ -519,6 +546,14 @@ class CorrScope: render_to_output.put(None) print("exit render") + def render_thread(): + try: + _render_thread() + except BaseException as e: + abort_from_thread.set() + render_to_output.put(None) + raise e + global worker_render_frame # hack to allow pickling function def worker_render_frame( @@ -534,7 +569,7 @@ class CorrScope: shmem = SHMEMS[shmem_name] shmem.buf[:] = frame_data - def output_thread(): + def _output_thread(): while True: if is_aborted(): for output in self.outputs: @@ -568,34 +603,50 @@ class CorrScope: avail_shmems.put(render_msg.shmem) if is_aborted(): - # If is_aborted() is True but render_thread() is blocked on - # render_to_output.put(), then we need to clear the queue so - # render_thread() can return from put(), then check is_aborted() - # = True and terminate. - while True: - try: - render_msg = render_to_output.get(block=False) - if render_msg is None: - continue # probably empty? - - avail_shmems.put(render_msg.shmem) - except Empty: - break + output_on_error() print("exit output") + def output_on_error(): + """If is_aborted() is True but render_thread() is blocked on + render_to_output.put(), then we need to clear the queue so + render_thread() can return from put(), then check is_aborted() = True + and terminate.""" + while True: + try: + render_msg = render_to_output.get(block=False) + if render_msg is None: + continue # probably empty? + + avail_shmems.put(render_msg.shmem) + except Empty: + break + + def output_thread(): + try: + _output_thread() + except BaseException as e: + abort_from_thread.set() + output_on_error() + raise e + shmem_names: List[str] = [shmem.name for shmem in all_shmems] with ProcessPoolExecutor( nthread, initializer=worker_create_renderer, initargs=(renderer, shmem_names), ) as pool: - render_handle = Thread(target=render_thread, name="render_thread") - output_handle = Thread(target=output_thread, name="output_thread") + render_handle = PropagatingThread( + target=render_thread, name="render_thread" + ) + output_handle = PropagatingThread( + target=output_thread, name="output_thread" + ) render_handle.start() output_handle.start() + # throws render_handle.join() output_handle.join()