Add exception handling?

perf-debug
nyanpasu64 2023-11-27 22:18:23 -08:00
rodzic 2447a3a45d
commit 16ea5e55ac
1 zmienionych plików z 69 dodań i 18 usunięć

Wyświetl plik

@ -10,7 +10,7 @@ from multiprocessing.shared_memory import SharedMemory
from pathlib import Path
from queue import Queue, Empty
from threading import Thread
from typing import Iterator, Optional, List, Callable, Tuple, Dict, Union
from typing import Iterator, Optional, List, Callable, Tuple, Dict, Union, Any
import attr
@ -178,6 +178,33 @@ def template_config(**kwargs) -> Config:
return attr.evolve(cfg, **kwargs)
class PropagatingThread(Thread):
# Based off https://stackoverflow.com/a/31614591 and Thread source code.
def run(self):
self.exc = None
try:
if self._target is not None:
self.ret = self._target(*self._args, **self._kwargs)
except BaseException as e:
self.exc = e
finally:
# Avoid a refcycle if the thread is running a function with
# an argument that has a member that points to the thread.
del self._target, self._args, self._kwargs
def join(self, timeout=None) -> Any:
try:
super(PropagatingThread, self).join(timeout)
if self.exc:
raise RuntimeError(f"exception from {self.name}") from self.exc
return self.ret
finally:
# If join() raises, set `self = None` to avoid a reference cycle with the
# backtrace, because concurrent.futures.Future.result() does it.
self = None
BeginFunc = Callable[[float, float], None]
ProgressFunc = Callable[[int], None]
IsAborted = Callable[[], bool]
@ -448,7 +475,7 @@ class CorrScope:
} # type: Dict[str, SharedMemory]
# TODO https://stackoverflow.com/questions/2829329/catch-a-threads-exception-in-the-caller-thread
def render_thread():
def _render_thread():
end_frame = thread_shared.end_frame
prev = -1
@ -519,6 +546,14 @@ class CorrScope:
render_to_output.put(None)
print("exit render")
def render_thread():
try:
_render_thread()
except BaseException as e:
abort_from_thread.set()
render_to_output.put(None)
raise e
global worker_render_frame # hack to allow pickling function
def worker_render_frame(
@ -534,7 +569,7 @@ class CorrScope:
shmem = SHMEMS[shmem_name]
shmem.buf[:] = frame_data
def output_thread():
def _output_thread():
while True:
if is_aborted():
for output in self.outputs:
@ -568,34 +603,50 @@ class CorrScope:
avail_shmems.put(render_msg.shmem)
if is_aborted():
# If is_aborted() is True but render_thread() is blocked on
# render_to_output.put(), then we need to clear the queue so
# render_thread() can return from put(), then check is_aborted()
# = True and terminate.
while True:
try:
render_msg = render_to_output.get(block=False)
if render_msg is None:
continue # probably empty?
avail_shmems.put(render_msg.shmem)
except Empty:
break
output_on_error()
print("exit output")
def output_on_error():
"""If is_aborted() is True but render_thread() is blocked on
render_to_output.put(), then we need to clear the queue so
render_thread() can return from put(), then check is_aborted() = True
and terminate."""
while True:
try:
render_msg = render_to_output.get(block=False)
if render_msg is None:
continue # probably empty?
avail_shmems.put(render_msg.shmem)
except Empty:
break
def output_thread():
try:
_output_thread()
except BaseException as e:
abort_from_thread.set()
output_on_error()
raise e
shmem_names: List[str] = [shmem.name for shmem in all_shmems]
with ProcessPoolExecutor(
nthread,
initializer=worker_create_renderer,
initargs=(renderer, shmem_names),
) as pool:
render_handle = Thread(target=render_thread, name="render_thread")
output_handle = Thread(target=output_thread, name="output_thread")
render_handle = PropagatingThread(
target=render_thread, name="render_thread"
)
output_handle = PropagatingThread(
target=output_thread, name="output_thread"
)
render_handle.start()
output_handle.start()
# throws
render_handle.join()
output_handle.join()