diff --git a/corrscope/corrscope.py b/corrscope/corrscope.py index f2b1f21..74b1a3a 100644 --- a/corrscope/corrscope.py +++ b/corrscope/corrscope.py @@ -1,12 +1,15 @@ # -*- coding: utf-8 -*- import os.path +import threading import time +from concurrent.futures import ProcessPoolExecutor, Future from contextlib import ExitStack, contextmanager from enum import unique from fractions import Fraction from pathlib import Path -from typing import Iterator, Optional, List, Callable -from multiprocessing.pool import Pool +from queue import Queue +from threading import Thread +from typing import Iterator, Optional, List, Callable, Tuple import attr @@ -15,12 +18,17 @@ from corrscope.channel import Channel, ChannelConfig, DefaultLabel from corrscope.config import KeywordAttrs, DumpEnumAsStr, CorrError, with_units from corrscope.layout import LayoutConfig from corrscope.outputs import FFmpegOutputConfig, IOutputConfig -from corrscope.renderer import Renderer, RendererConfig, RendererFrontend, RenderInput +from corrscope.renderer import ( + Renderer, + RendererConfig, + RendererFrontend, + RenderInput, + ByteBuffer, +) from corrscope.triggers import ( CorrelationTriggerConfig, PerFrameCache, SpectrumConfig, - MainTrigger, ) from corrscope.util import pushd, coalesce from corrscope.wave import Wave, Flatten, FlattenOrStr @@ -351,9 +359,132 @@ class CorrScope: thread_shared.end_frame = frame + 1 break + # Multiprocess + def play_parallel(): + ncores = len(os.sched_getaffinity(0)) + + abort_from_thread = threading.Event() + # self.arg.is_aborted() from GUI, abort_from_thread.is_set() from thread + is_aborted = lambda: self.arg.is_aborted() or abort_from_thread.is_set() + + # Same size as ProcessPoolExecutor, so threads won't starve if they all + # finish a job at the same time. + render_to_output: "Queue[Tuple[int, Future[ByteBuffer]] | None]" = Queue( + ncores + ) + + def worker_create_renderer(renderer: RendererFrontend): + global WORKER_RENDERER + # TODO del self.renderer and recreate Renderer if it can't be pickled? + WORKER_RENDERER = renderer + + # TODO https://stackoverflow.com/questions/2829329/catch-a-threads-exception-in-the-caller-thread + def render_thread(): + end_frame = thread_shared.end_frame + prev = -1 + + # TODO gather trigger points from triggering threads + # For each frame, render each wave + for frame in range(begin_frame, end_frame): + if is_aborted(): + # Only count output-displayed frames, not rendered. + # # Used for FPS calculation + # thread_shared.end_frame = frame + + for output in self.outputs: + output.terminate() + break + + time_seconds = frame / fps + should_render = (frame - begin_frame) % render_subfps == ahead + + rounded = int(time_seconds) + if PRINT_TIMESTAMP and rounded != prev: + self.arg.progress(rounded) + prev = rounded + + render_inputs = [] + trigger_samples = [] + # Get render-data from each wave. + for render_wave, channel in zip(self.render_waves, self.channels): + sample = round(render_wave.smp_s * time_seconds) + + # Get trigger. + if not_benchmarking or benchmark_mode == BenchmarkMode.TRIGGER: + cache = PerFrameCache() + + result = channel.trigger.get_trigger(sample, cache) + trigger_sample = result.result + freq_estimate = result.freq_estimate + + else: + trigger_sample = sample + freq_estimate = 0 + + # Get render data. + if should_render: + trigger_samples.append(trigger_sample) + data = channel.get_render_around(trigger_sample) + render_inputs.append(RenderInput(data, freq_estimate)) + + if not should_render: + continue + + # blocking + render_to_output.put( + ( + frame, + pool.submit( + worker_render_frame, render_inputs, trigger_samples + ), + ) + ) + + render_to_output.put(None) + + def worker_render_frame( + render_inputs: List[RenderInput], trigger_samples: List[int] + ) -> ByteBuffer: + global WORKER_RENDERER + renderer = WORKER_RENDERER + renderer.update_main_lines(render_inputs, trigger_samples) + frame_data = renderer.get_frame() + return frame_data + + def output_thread(): + while True: + msg = render_to_output.get() # blocking + if msg is None: + break + frame, render_future = msg + frame_data: ByteBuffer = render_future.result() + + if not_benchmarking or benchmark_mode == BenchmarkMode.OUTPUT: + # Output frame + for output in self.outputs: + if output.write_frame(frame_data) is outputs_.Stop: + abort_from_thread.set() + break + if is_aborted(): + # Outputting frame happens after most computation finished. + thread_shared.end_frame = frame + 1 + break + + with ProcessPoolExecutor( + ncores, initializer=worker_create_renderer, initargs=(renderer,) + ) as pool: + render_handle = Thread(target=render_thread, name="render_thread") + output_handle = Thread(target=output_thread, name="output_thread") + + render_handle.start() + output_handle.start() + + render_handle.join() + output_handle.join() + with self._load_outputs(): if self.arg.parallel: - pass + play_parallel() else: play_impl()