kopia lustrzana https://github.com/corrscope/corrscope
Merge pull request #450 from corrscope/multiprocessing
Add multi-core rendering for increased preview/render speedpull/462/head
commit
67b136b9e8
|
@ -2,6 +2,8 @@
|
|||
|
||||
### Features
|
||||
|
||||
- Add multi-core rendering for increased preview/render speed (#450)
|
||||
|
||||
### Major Changes
|
||||
|
||||
### Changelog
|
||||
|
|
|
@ -6,6 +6,7 @@ from typing import Optional, List, Tuple, Union, cast, TypeVar
|
|||
import click
|
||||
|
||||
import corrscope
|
||||
import corrscope.settings.global_prefs as gp
|
||||
from corrscope.channel import ChannelConfig
|
||||
from corrscope.config import yaml
|
||||
from corrscope.corrscope import template_config, CorrScope, Config, Arguments
|
||||
|
@ -79,6 +80,7 @@ def get_file_stem(cfg_path: Optional[Path], cfg: Config, default: T) -> Union[st
|
|||
|
||||
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])
|
||||
|
||||
|
||||
# fmt: off
|
||||
@click.command(context_settings=CONTEXT_SETTINGS)
|
||||
# Inputs
|
||||
|
@ -232,7 +234,15 @@ def main(
|
|||
outputs.append(cfg.get_ffmpeg_cfg(video_path))
|
||||
|
||||
if outputs:
|
||||
arg = Arguments(cfg_dir=cfg_dir, outputs=outputs)
|
||||
try:
|
||||
pref = gp.load_prefs()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pref = gp.GlobalPrefs()
|
||||
|
||||
arg = Arguments(
|
||||
cfg_dir=cfg_dir, outputs=outputs, parallelism=pref.parallelism()
|
||||
)
|
||||
command = lambda: CorrScope(cfg, arg).play()
|
||||
if profile:
|
||||
first_song_name = Path(files[0]).name
|
||||
|
|
|
@ -56,7 +56,6 @@ class MyYAML(YAML):
|
|||
def dump(
|
||||
self, data: Any, stream: "Union[Path, TextIO, None]" = None, **kwargs
|
||||
) -> Optional[str]:
|
||||
|
||||
# On Windows, when dumping to path, ruamel.yaml writes files in locale encoding.
|
||||
# Foreign characters are undumpable. Locale-compatible characters cannot be loaded.
|
||||
# https://bitbucket.org/ruamel/yaml/issues/316/unicode-encoding-decoding-errors-on
|
||||
|
@ -154,6 +153,7 @@ T = TypeVar("T")
|
|||
# yaml.dump(obj, stream)
|
||||
# return yaml.load(stream)
|
||||
|
||||
|
||||
# AKA pickle_copy
|
||||
def copy_config(obj: T) -> T:
|
||||
with BytesIO() as stream:
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import os.path
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, Future
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from enum import unique
|
||||
from fractions import Fraction
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional, List, Callable
|
||||
from queue import Queue, Empty
|
||||
from threading import Thread
|
||||
from typing import Iterator, Optional, List, Callable, Dict, Union, Any
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -14,12 +19,17 @@ from corrscope.channel import Channel, ChannelConfig, DefaultLabel
|
|||
from corrscope.config import KeywordAttrs, DumpEnumAsStr, CorrError, with_units
|
||||
from corrscope.layout import LayoutConfig
|
||||
from corrscope.outputs import FFmpegOutputConfig, IOutputConfig
|
||||
from corrscope.renderer import Renderer, RendererConfig, RendererFrontend, RenderInput
|
||||
from corrscope.renderer import (
|
||||
Renderer,
|
||||
RendererConfig,
|
||||
RendererParams,
|
||||
RenderInput,
|
||||
)
|
||||
from corrscope.settings.global_prefs import Parallelism
|
||||
from corrscope.triggers import (
|
||||
CorrelationTriggerConfig,
|
||||
PerFrameCache,
|
||||
SpectrumConfig,
|
||||
MainTrigger,
|
||||
)
|
||||
from corrscope.util import pushd, coalesce
|
||||
from corrscope.wave import Wave, Flatten, FlattenOrStr
|
||||
|
@ -139,6 +149,33 @@ def template_config(**kwargs) -> Config:
|
|||
return attr.evolve(cfg, **kwargs)
|
||||
|
||||
|
||||
class PropagatingThread(Thread):
|
||||
# Based off https://stackoverflow.com/a/31614591 and Thread source code.
|
||||
def run(self):
|
||||
self.exc = None
|
||||
try:
|
||||
if self._target is not None:
|
||||
self.ret = self._target(*self._args, **self._kwargs)
|
||||
except BaseException as e:
|
||||
self.exc = e
|
||||
finally:
|
||||
# Avoid a refcycle if the thread is running a function with
|
||||
# an argument that has a member that points to the thread.
|
||||
del self._target, self._args, self._kwargs
|
||||
|
||||
def join(self, timeout=None) -> Any:
|
||||
try:
|
||||
super(PropagatingThread, self).join(timeout)
|
||||
if self.exc:
|
||||
raise RuntimeError(f"exception from {self.name}") from self.exc
|
||||
|
||||
return self.ret
|
||||
finally:
|
||||
# If join() raises, set `self = None` to avoid a reference cycle with the
|
||||
# backtrace, because concurrent.futures.Future.result() does it.
|
||||
self = None
|
||||
|
||||
|
||||
BeginFunc = Callable[[float, float], None]
|
||||
ProgressFunc = Callable[[int], None]
|
||||
IsAborted = Callable[[], bool]
|
||||
|
@ -148,6 +185,7 @@ IsAborted = Callable[[], bool]
|
|||
class Arguments:
|
||||
cfg_dir: str
|
||||
outputs: List[outputs_.IOutputConfig]
|
||||
parallelism: Optional[Parallelism] = None
|
||||
|
||||
on_begin: BeginFunc = lambda begin_time, end_time: None
|
||||
progress: ProgressFunc = lambda p: print(p, flush=True)
|
||||
|
@ -155,6 +193,44 @@ class Arguments:
|
|||
on_end: Callable[[], None] = lambda: None
|
||||
|
||||
|
||||
def worker_create_renderer(renderer_params: RendererParams, shmem_names: List[str]):
|
||||
import appnope
|
||||
|
||||
# Disable power saving for renderer processes.
|
||||
appnope.nope()
|
||||
|
||||
global WORKER_RENDERER
|
||||
global SHMEMS
|
||||
|
||||
WORKER_RENDERER = Renderer(renderer_params)
|
||||
SHMEMS = {
|
||||
name: SharedMemory(name) for name in shmem_names
|
||||
} # type: Dict[str, SharedMemory]
|
||||
|
||||
|
||||
prev = 0.0
|
||||
|
||||
|
||||
def worker_render_frame(
|
||||
render_inputs: List[RenderInput],
|
||||
trigger_samples: List[int],
|
||||
shmem_name: str,
|
||||
):
|
||||
global WORKER_RENDERER, SHMEMS, prev
|
||||
t = time.perf_counter() * 1000.0
|
||||
|
||||
renderer = WORKER_RENDERER
|
||||
renderer.update_main_lines(render_inputs, trigger_samples)
|
||||
frame_data = renderer.get_frame()
|
||||
t1 = time.perf_counter() * 1000.0
|
||||
|
||||
shmem = SHMEMS[shmem_name]
|
||||
shmem.buf[: len(frame_data)] = frame_data
|
||||
t2 = time.perf_counter() * 1000.0
|
||||
# print(f"idle = {t - prev}, dt1 = {t1 - t}, dt2 = {t2 - t1}")
|
||||
prev = t2
|
||||
|
||||
|
||||
class CorrScope:
|
||||
def __init__(self, cfg: Config, arg: Arguments):
|
||||
"""cfg is mutated!
|
||||
|
@ -223,16 +299,19 @@ class CorrScope:
|
|||
]
|
||||
yield
|
||||
|
||||
def _load_renderer(self) -> RendererFrontend:
|
||||
def _renderer_params(self) -> RendererParams:
|
||||
dummy_datas = [channel.get_render_around(0) for channel in self.channels]
|
||||
renderer = Renderer(
|
||||
return RendererParams.from_obj(
|
||||
self.cfg.render,
|
||||
self.cfg.layout,
|
||||
dummy_datas,
|
||||
self.cfg.channels,
|
||||
self.channels,
|
||||
)
|
||||
return renderer
|
||||
|
||||
# def _load_renderer(self) -> Renderer:
|
||||
# # only kept for unit tests I'm too lazy to rewrite.
|
||||
# return Renderer(self._renderer_params())
|
||||
|
||||
def play(self) -> None:
|
||||
if self.has_played:
|
||||
|
@ -251,13 +330,20 @@ class CorrScope:
|
|||
end_frame = fps * end_time
|
||||
end_frame = int(end_frame) + 1
|
||||
|
||||
@attr.dataclass
|
||||
class ThreadShared:
|
||||
# mutex? i hardly knew 'er!
|
||||
end_frame: int
|
||||
|
||||
thread_shared = ThreadShared(end_frame)
|
||||
del end_frame
|
||||
|
||||
self.arg.on_begin(self.cfg.begin_time, end_time)
|
||||
|
||||
renderer = self._load_renderer()
|
||||
renderer_params = self._renderer_params()
|
||||
renderer = Renderer(renderer_params)
|
||||
self.renderer = renderer # only used for unit tests
|
||||
|
||||
renderer.add_labels([channel.label for channel in self.channels])
|
||||
|
||||
# For debugging only
|
||||
# for trigger in self.triggers:
|
||||
# trigger.set_renderer(renderer)
|
||||
|
@ -268,20 +354,23 @@ class CorrScope:
|
|||
benchmark_mode = self.cfg.benchmark_mode
|
||||
not_benchmarking = not benchmark_mode
|
||||
|
||||
with self._load_outputs():
|
||||
prev = -1
|
||||
# When subsampling FPS, render frames from the future to alleviate lag.
|
||||
# subfps=1, ahead=0.
|
||||
# subfps=2, ahead=1.
|
||||
render_subfps = self.cfg.render_subfps
|
||||
ahead = render_subfps // 2
|
||||
|
||||
# When subsampling FPS, render frames from the future to alleviate lag.
|
||||
# subfps=1, ahead=0.
|
||||
# subfps=2, ahead=1.
|
||||
render_subfps = self.cfg.render_subfps
|
||||
ahead = render_subfps // 2
|
||||
# Single-process
|
||||
def play_impl():
|
||||
end_frame = thread_shared.end_frame
|
||||
prev = -1
|
||||
pt = 0.0
|
||||
|
||||
# For each frame, render each wave
|
||||
for frame in range(begin_frame, end_frame):
|
||||
if self.arg.is_aborted():
|
||||
# Used for FPS calculation
|
||||
end_frame = frame
|
||||
thread_shared.end_frame = frame
|
||||
|
||||
for output in self.outputs:
|
||||
output.terminate()
|
||||
|
@ -324,8 +413,13 @@ class CorrScope:
|
|||
|
||||
if not_benchmarking or benchmark_mode >= BenchmarkMode.RENDER:
|
||||
# Render frame
|
||||
|
||||
t = time.perf_counter() * 1000.0
|
||||
renderer.update_main_lines(render_inputs, trigger_samples)
|
||||
frame_data = renderer.get_frame()
|
||||
t1 = time.perf_counter() * 1000.0
|
||||
# print(f"idle = {t - pt}, dt1 = {t1 - t}")
|
||||
pt = t1
|
||||
|
||||
if not_benchmarking or benchmark_mode == BenchmarkMode.OUTPUT:
|
||||
# Output frame
|
||||
|
@ -336,13 +430,274 @@ class CorrScope:
|
|||
break
|
||||
if aborted:
|
||||
# Outputting frame happens after most computation finished.
|
||||
end_frame = frame + 1
|
||||
thread_shared.end_frame = frame + 1
|
||||
break
|
||||
|
||||
# Multiprocess
|
||||
def play_parallel(nthread: int):
|
||||
framebuffer_nbyte = len(renderer.get_frame())
|
||||
print(f"framebuffer_nbyte = {framebuffer_nbyte}")
|
||||
|
||||
# setup threading
|
||||
abort_from_thread = threading.Event()
|
||||
# self.arg.is_aborted() from GUI, abort_from_thread.is_set() from thread
|
||||
is_aborted = lambda: self.arg.is_aborted() or abort_from_thread.is_set()
|
||||
|
||||
@attr.dataclass
|
||||
class RenderToOutput:
|
||||
frame_idx: int
|
||||
shmem: SharedMemory
|
||||
completion: "Future[None]"
|
||||
|
||||
# Rely on avail_shmems for backpressure.
|
||||
render_to_output: "Queue[RenderToOutput | None]" = Queue()
|
||||
|
||||
# Release all shmems after finishing rendering.
|
||||
all_shmems: List[SharedMemory] = [
|
||||
SharedMemory(create=True, size=framebuffer_nbyte)
|
||||
for _ in range(2 * nthread)
|
||||
]
|
||||
|
||||
is_submitting = [False, 0]
|
||||
|
||||
# Only send unused shmems to a worker process, and wait for it to be
|
||||
# returned before reusing.
|
||||
avail_shmems: "Queue[SharedMemory]" = Queue()
|
||||
for shmem in all_shmems:
|
||||
avail_shmems.put(shmem)
|
||||
|
||||
# TODO https://stackoverflow.com/questions/2829329/catch-a-threads-exception-in-the-caller-thread
|
||||
def _render_thread():
|
||||
end_frame = thread_shared.end_frame
|
||||
prev = -1
|
||||
|
||||
# TODO gather trigger points from triggering threads
|
||||
# For each frame, render each wave
|
||||
for frame in range(begin_frame, end_frame):
|
||||
if is_aborted():
|
||||
break
|
||||
|
||||
time_seconds = frame / fps
|
||||
should_render = (frame - begin_frame) % render_subfps == ahead
|
||||
|
||||
rounded = int(time_seconds)
|
||||
if PRINT_TIMESTAMP and rounded != prev:
|
||||
self.arg.progress(rounded)
|
||||
prev = rounded
|
||||
|
||||
render_inputs = []
|
||||
trigger_samples = []
|
||||
# Get render-data from each wave.
|
||||
for render_wave, channel in zip(self.render_waves, self.channels):
|
||||
sample = round(render_wave.smp_s * time_seconds)
|
||||
|
||||
# Get trigger.
|
||||
if not_benchmarking or benchmark_mode == BenchmarkMode.TRIGGER:
|
||||
cache = PerFrameCache()
|
||||
|
||||
result = channel.trigger.get_trigger(sample, cache)
|
||||
trigger_sample = result.result
|
||||
freq_estimate = result.freq_estimate
|
||||
|
||||
else:
|
||||
trigger_sample = sample
|
||||
freq_estimate = 0
|
||||
|
||||
# Get render data.
|
||||
if should_render:
|
||||
trigger_samples.append(trigger_sample)
|
||||
data = channel.get_render_around(trigger_sample)
|
||||
render_inputs.append(RenderInput(data, freq_estimate))
|
||||
|
||||
if not should_render:
|
||||
continue
|
||||
|
||||
# blocks until frames get rendered and shmem is returned by
|
||||
# output_thread().
|
||||
t = time.perf_counter()
|
||||
shmem = avail_shmems.get()
|
||||
t = time.perf_counter() - t
|
||||
# if t >= 0.001:
|
||||
# print("get shmem", t)
|
||||
if is_aborted():
|
||||
break
|
||||
|
||||
# blocking
|
||||
t = time.perf_counter()
|
||||
render_to_output.put(
|
||||
RenderToOutput(
|
||||
frame,
|
||||
shmem,
|
||||
pool.submit(
|
||||
worker_render_frame,
|
||||
render_inputs,
|
||||
trigger_samples,
|
||||
shmem.name,
|
||||
),
|
||||
)
|
||||
)
|
||||
t = time.perf_counter() - t
|
||||
# if t >= 0.001:
|
||||
# print("send to render", t)
|
||||
|
||||
# TODO if is_aborted(), should we insert class CancellationToken,
|
||||
# rather than having output_thread() poll it too?
|
||||
render_to_output.put(None)
|
||||
print("exit render")
|
||||
|
||||
def render_thread():
|
||||
"""
|
||||
How do we know that if render_thread() crashes, output_thread() will
|
||||
not block?
|
||||
|
||||
- `_render_thread()` does not return early, and will always
|
||||
`render_to_output.put(None)` before returning.
|
||||
|
||||
- If `_render_thread()` crashes, `render_thread()` will call
|
||||
`abort_from_thread.set()` before writing `render_to_output.put(
|
||||
None)`. When the output thread reads None, it will see that it is
|
||||
aborted.
|
||||
"""
|
||||
try:
|
||||
_render_thread()
|
||||
except BaseException as e:
|
||||
abort_from_thread.set()
|
||||
render_to_output.put(None)
|
||||
raise e
|
||||
|
||||
def _output_thread():
|
||||
thread_shared.end_frame = begin_frame
|
||||
|
||||
while True:
|
||||
if is_aborted():
|
||||
for output in self.outputs:
|
||||
output.terminate()
|
||||
break
|
||||
|
||||
# blocking
|
||||
render_msg: Union[RenderToOutput, None] = render_to_output.get()
|
||||
|
||||
if render_msg is None:
|
||||
if is_aborted():
|
||||
for output in self.outputs:
|
||||
output.terminate()
|
||||
break
|
||||
|
||||
# Wait for shmem to be filled with data.
|
||||
render_msg.completion.result()
|
||||
frame_data = render_msg.shmem.buf[:framebuffer_nbyte]
|
||||
|
||||
if not_benchmarking or benchmark_mode == BenchmarkMode.OUTPUT:
|
||||
# Output frame
|
||||
for output in self.outputs:
|
||||
if output.write_frame(frame_data) is outputs_.Stop:
|
||||
abort_from_thread.set()
|
||||
break
|
||||
thread_shared.end_frame = render_msg.frame_idx + 1
|
||||
|
||||
avail_shmems.put(render_msg.shmem)
|
||||
|
||||
if is_aborted():
|
||||
output_on_error()
|
||||
|
||||
print("exit output")
|
||||
|
||||
def output_on_error():
|
||||
"""If is_aborted() is True but render_thread() is blocked on
|
||||
render_to_output.put(), then we need to clear the queue so
|
||||
render_thread() can return from put(), then check is_aborted() = True
|
||||
and terminate."""
|
||||
|
||||
# It is an error to call output_on_error() when not aborted. If so,
|
||||
# force an abort so we can print the error without deadlock.
|
||||
was_aborted = is_aborted()
|
||||
if not was_aborted:
|
||||
abort_from_thread.set()
|
||||
|
||||
while True:
|
||||
try:
|
||||
render_msg = render_to_output.get(block=False)
|
||||
if render_msg is None:
|
||||
continue # probably empty?
|
||||
|
||||
# To avoid deadlock, we must return the shmem to
|
||||
# _render_thread() in case it's blocked waiting for it. We do
|
||||
# not need to wait for the shmem to be no longer written to (
|
||||
# `render_msg.completion.result()`), since if we set
|
||||
# is_aborted() to true before returning a shmem,
|
||||
# `_render_thread()` will ignore the acquired shmem without
|
||||
# writing to it.
|
||||
|
||||
avail_shmems.put(render_msg.shmem)
|
||||
except Empty:
|
||||
break
|
||||
|
||||
assert was_aborted
|
||||
|
||||
def output_thread():
|
||||
"""
|
||||
How do we know that if output_thread() crashes, render_thread() will
|
||||
not block?
|
||||
|
||||
- `_output_thread()` does not return early. If it is aborted, it will
|
||||
call `output_on_error()` to unblock `_render_thread()`.
|
||||
|
||||
- If `_output_thread()` crashes, `output_thread()` will call
|
||||
`abort_from_thread.set()` before calling `output_on_error()` to
|
||||
unblock `_render_thread()`.
|
||||
|
||||
I miss being able to poll()/WaitForMultipleObjects().
|
||||
"""
|
||||
try:
|
||||
_output_thread()
|
||||
except BaseException as e:
|
||||
abort_from_thread.set()
|
||||
output_on_error()
|
||||
raise e
|
||||
|
||||
shmem_names: List[str] = [shmem.name for shmem in all_shmems]
|
||||
|
||||
with ProcessPoolExecutor(
|
||||
nthread,
|
||||
initializer=worker_create_renderer,
|
||||
initargs=(renderer_params, shmem_names),
|
||||
) as pool:
|
||||
render_handle = PropagatingThread(
|
||||
target=render_thread, name="render_thread"
|
||||
)
|
||||
output_handle = PropagatingThread(
|
||||
target=output_thread, name="output_thread"
|
||||
)
|
||||
|
||||
render_handle.start()
|
||||
output_handle.start()
|
||||
|
||||
# throws
|
||||
render_handle.join()
|
||||
output_handle.join()
|
||||
|
||||
# TODO is it a problem that ProcessPoolExecutor's
|
||||
# worker_create_renderer() creates SharedMemory handles, which are
|
||||
# never closed when the process terminates?
|
||||
#
|
||||
# Constructing a new SharedMemory on every worker_render_frame() call
|
||||
# is more "correct", but increases CPU usage by around 20% or more (
|
||||
# see "shmem question"), likely due to page table thrashing.
|
||||
|
||||
for shmem in all_shmems:
|
||||
shmem.unlink()
|
||||
|
||||
parallelism = self.arg.parallelism
|
||||
with self._load_outputs():
|
||||
if parallelism and parallelism.parallel:
|
||||
play_parallel(parallelism.max_render_cores)
|
||||
else:
|
||||
play_impl()
|
||||
|
||||
if PRINT_TIMESTAMP:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
dtime_sec = time.perf_counter() - begin
|
||||
dframe = end_frame - begin_frame
|
||||
dframe = thread_shared.end_frame - begin_frame
|
||||
|
||||
frame_per_sec = dframe / dtime_sec
|
||||
try:
|
||||
|
|
|
@ -185,8 +185,6 @@ class MainWindow(qw.QMainWindow, Ui_MainWindow):
|
|||
prefs_error = None
|
||||
try:
|
||||
self.pref = gp.load_prefs()
|
||||
if not isinstance(self.pref, gp.GlobalPrefs):
|
||||
raise TypeError(f"prefs.yaml contains wrong type {type(self.pref)}")
|
||||
except Exception as e:
|
||||
prefs_error = e
|
||||
self.pref = gp.GlobalPrefs()
|
||||
|
@ -212,6 +210,9 @@ class MainWindow(qw.QMainWindow, Ui_MainWindow):
|
|||
self.on_separate_render_dir_toggled
|
||||
)
|
||||
|
||||
self.action_parallel.setChecked(self.pref.parallel)
|
||||
self.action_parallel.toggled.connect(self.on_parallel_toggled)
|
||||
|
||||
self.action_open_config_dir.triggered.connect(self.on_open_config_dir)
|
||||
|
||||
self.actionNew.triggered.connect(self.on_action_new)
|
||||
|
@ -467,6 +468,9 @@ class MainWindow(qw.QMainWindow, Ui_MainWindow):
|
|||
else:
|
||||
self.pref.render_dir = ""
|
||||
|
||||
def on_parallel_toggled(self, checked: bool):
|
||||
self.pref.parallel = checked
|
||||
|
||||
def on_open_config_dir(self):
|
||||
appdata_uri = qc.QUrl.fromLocalFile(str(paths.appdata_dir))
|
||||
QDesktopServices.openUrl(appdata_uri)
|
||||
|
@ -608,7 +612,10 @@ class MainWindow(qw.QMainWindow, Ui_MainWindow):
|
|||
)
|
||||
|
||||
arg = Arguments(
|
||||
cfg_dir=self.cfg_dir, outputs=outputs, is_aborted=raise_exception
|
||||
cfg_dir=self.cfg_dir,
|
||||
outputs=outputs,
|
||||
parallelism=self.pref.parallelism(),
|
||||
is_aborted=raise_exception,
|
||||
)
|
||||
return arg
|
||||
|
||||
|
@ -766,7 +773,7 @@ class CorrThread(Thread):
|
|||
job: CorrJob
|
||||
|
||||
def __init__(self, cfg: Config, arg: Arguments, mode: PreviewOrRender):
|
||||
Thread.__init__(self)
|
||||
Thread.__init__(self, name="CorrThread")
|
||||
self.job = CorrJob(cfg, arg, mode)
|
||||
|
||||
def run(self):
|
||||
|
|
|
@ -75,7 +75,6 @@ class MainWindow(QWidget):
|
|||
|
||||
# Right-hand channel list
|
||||
with append_widget(s, QVBoxLayout) as self.audioColumn:
|
||||
|
||||
# Top bar (master audio, trigger)
|
||||
self.add_top_bar(s)
|
||||
|
||||
|
@ -96,7 +95,6 @@ class MainWindow(QWidget):
|
|||
def add_general_tab(self, s: LayoutStack) -> QWidget:
|
||||
tr = self.tr
|
||||
with self.add_tab_stretch(s, tr("&General"), layout=QVBoxLayout) as tab:
|
||||
|
||||
# Global group
|
||||
with append_widget(s, QGroupBox) as self.optionGlobal:
|
||||
set_layout(s, QFormLayout)
|
||||
|
@ -161,7 +159,6 @@ class MainWindow(QWidget):
|
|||
with add_tab(
|
||||
s, VerticalScrollArea, tr("&Appearance")
|
||||
) as tab, fill_scroll_stretch(s, layout=QVBoxLayout):
|
||||
|
||||
with append_widget(
|
||||
s, QGroupBox, title=tr("Appearance"), layout=QFormLayout
|
||||
):
|
||||
|
@ -431,7 +428,6 @@ class MainWindow(QWidget):
|
|||
tr = self.tr
|
||||
with append_widget(s, QHBoxLayout):
|
||||
with append_widget(s, QVBoxLayout):
|
||||
|
||||
with append_widget(s, QGroupBox):
|
||||
s.widget.setTitle(tr("FFmpeg Options"))
|
||||
set_layout(s, QFormLayout)
|
||||
|
@ -496,6 +492,9 @@ class MainWindow(QWidget):
|
|||
self.action_separate_render_dir = create_element(
|
||||
QAction, MainWindow, text=tr("&Separate Render Folder"), checkable=True
|
||||
)
|
||||
self.action_parallel = create_element(
|
||||
QAction, MainWindow, text=tr("&Multi-Core Rendering"), checkable=True
|
||||
)
|
||||
self.action_open_config_dir = create_element(
|
||||
QAction, MainWindow, text=tr("Open &Config Folder")
|
||||
)
|
||||
|
@ -518,6 +517,7 @@ class MainWindow(QWidget):
|
|||
with append_menu(s) as self.menuTools:
|
||||
w = self.menuTools
|
||||
w.addAction(self.action_separate_render_dir)
|
||||
w.addAction(self.action_parallel)
|
||||
w.addSeparator()
|
||||
w.addAction(self.action_open_config_dir)
|
||||
|
||||
|
|
|
@ -304,6 +304,42 @@ def abstract_classvar(self) -> Any:
|
|||
"""A ClassVar to be overriden by a subclass."""
|
||||
|
||||
|
||||
@attr.dataclass
|
||||
class RendererParams:
|
||||
"""Serializable between processes."""
|
||||
|
||||
cfg: RendererConfig
|
||||
lcfg: LayoutConfig
|
||||
data_shapes: List[tuple]
|
||||
channel_cfgs: Optional[List[ChannelConfig]]
|
||||
render_strides: Optional[List[int]]
|
||||
labels: Optional[List[str]]
|
||||
|
||||
@staticmethod
|
||||
def from_obj(
|
||||
cfg: RendererConfig,
|
||||
lcfg: "LayoutConfig",
|
||||
dummy_datas: List[np.ndarray],
|
||||
channel_cfgs: Optional[List["ChannelConfig"]],
|
||||
channels: List["Channel"],
|
||||
):
|
||||
if channels is not None:
|
||||
render_strides = [channel.render_stride for channel in channels]
|
||||
labels = [channel.label for channel in channels]
|
||||
else:
|
||||
render_strides = None
|
||||
labels = None
|
||||
|
||||
return RendererParams(
|
||||
cfg,
|
||||
lcfg,
|
||||
[data.shape for data in dummy_datas],
|
||||
channel_cfgs,
|
||||
render_strides,
|
||||
labels,
|
||||
)
|
||||
|
||||
|
||||
class _RendererBackend(ABC):
|
||||
"""
|
||||
Renderer backend which takes data and produces images.
|
||||
|
@ -325,17 +361,15 @@ class _RendererBackend(ABC):
|
|||
Only used for tests/test_renderer.py.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_obj(cls, *args, **kwargs):
|
||||
return cls(RendererParams.from_obj(*args, **kwargs))
|
||||
|
||||
# Instance initializer
|
||||
def __init__(
|
||||
self,
|
||||
cfg: RendererConfig,
|
||||
lcfg: "LayoutConfig",
|
||||
dummy_datas: List[np.ndarray],
|
||||
channel_cfgs: Optional[List["ChannelConfig"]],
|
||||
channels: List["Channel"],
|
||||
):
|
||||
def __init__(self, params: RendererParams):
|
||||
cfg = params.cfg
|
||||
self.cfg = cfg
|
||||
self.lcfg = lcfg
|
||||
self.lcfg = params.lcfg
|
||||
|
||||
self.w = cfg.divided_width
|
||||
self.h = cfg.divided_height
|
||||
|
@ -343,13 +377,15 @@ class _RendererBackend(ABC):
|
|||
# Maps a continuous variable from 0 to 1 (representing one octave) to a color.
|
||||
self.pitch_cmap = gen_circular_cmap(cfg.pitch_colors)
|
||||
|
||||
self.nplots = len(dummy_datas)
|
||||
data_shapes = params.data_shapes
|
||||
self.nplots = len(data_shapes)
|
||||
|
||||
if self.nplots > 0:
|
||||
assert len(dummy_datas[0].shape) == 2, dummy_datas[0].shape
|
||||
self.wave_nsamps = [data.shape[0] for data in dummy_datas]
|
||||
self.wave_nchans = [data.shape[1] for data in dummy_datas]
|
||||
assert len(data_shapes[0]) == 2, data_shapes[0]
|
||||
self.wave_nsamps = [shape[0] for shape in data_shapes]
|
||||
self.wave_nchans = [shape[1] for shape in data_shapes]
|
||||
|
||||
channel_cfgs = params.channel_cfgs
|
||||
if channel_cfgs is None:
|
||||
channel_cfgs = [ChannelConfig("") for _ in range(self.nplots)]
|
||||
|
||||
|
@ -367,12 +403,13 @@ class _RendererBackend(ABC):
|
|||
]
|
||||
|
||||
# Load channel strides.
|
||||
if channels is not None:
|
||||
if len(channels) != self.nplots:
|
||||
render_strides = params.render_strides
|
||||
if render_strides is not None:
|
||||
if len(render_strides) != self.nplots:
|
||||
raise ValueError(
|
||||
f"cannot assign {len(channels)} channels to {self.nplots} plots"
|
||||
f"cannot assign {len(render_strides)} channels to {self.nplots} plots"
|
||||
)
|
||||
self.render_strides = [channel.render_stride for channel in channels]
|
||||
self.render_strides = render_strides
|
||||
else:
|
||||
self.render_strides = [1] * self.nplots
|
||||
|
||||
|
@ -446,8 +483,8 @@ class AbstractMatplotlibRenderer(_RendererBackend, ABC):
|
|||
def _canvas_to_bytes(canvas: "FigureCanvasBase") -> ByteBuffer:
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, params: RendererParams):
|
||||
super().__init__(params)
|
||||
|
||||
dict.__setitem__(
|
||||
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
|
||||
|
@ -455,6 +492,9 @@ class AbstractMatplotlibRenderer(_RendererBackend, ABC):
|
|||
|
||||
self._setup_axes(self.wave_nchans)
|
||||
|
||||
if params.labels is not None:
|
||||
self.add_labels(params.labels)
|
||||
|
||||
self._artists: List["Artist"] = []
|
||||
|
||||
_fig: "Figure"
|
||||
|
@ -892,8 +932,8 @@ class MatplotlibAggRenderer(AbstractMatplotlibRenderer):
|
|||
class RendererFrontend(_RendererBackend, ABC):
|
||||
"""Wrapper around _RendererBackend implementations, providing a better interface."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, params: RendererParams):
|
||||
super().__init__(params)
|
||||
|
||||
self._update_main_lines = None
|
||||
self._custom_lines = {} # type: Dict[Any, CustomLine]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import *
|
||||
|
||||
import attr
|
||||
from atomicwrites import atomic_write
|
||||
|
||||
from corrscope.config import DumpableAttrs, yaml
|
||||
|
@ -21,6 +22,12 @@ class Ref(Generic[Attrs]):
|
|||
setattr(self.obj, self.key, value)
|
||||
|
||||
|
||||
@attr.dataclass
|
||||
class Parallelism:
|
||||
parallel: bool = True
|
||||
max_render_cores: int = 2
|
||||
|
||||
|
||||
class GlobalPrefs(DumpableAttrs, always_dump="*"):
|
||||
# Most recent YAML or audio file opened
|
||||
file_dir: str = ""
|
||||
|
@ -40,13 +47,23 @@ class GlobalPrefs(DumpableAttrs, always_dump="*"):
|
|||
else:
|
||||
return self.file_dir_ref
|
||||
|
||||
parallel: bool = True
|
||||
max_render_cores: int = 2
|
||||
|
||||
def parallelism(self) -> Parallelism:
|
||||
return Parallelism(self.parallel, self.max_render_cores)
|
||||
|
||||
|
||||
_PREF_PATH = paths.appdata_dir / "prefs.yaml"
|
||||
|
||||
|
||||
def load_prefs() -> GlobalPrefs:
|
||||
try:
|
||||
return yaml.load(_PREF_PATH)
|
||||
pref = yaml.load(_PREF_PATH)
|
||||
if not isinstance(pref, GlobalPrefs):
|
||||
raise TypeError(f"prefs.yaml contains wrong type {type(pref)}")
|
||||
return pref
|
||||
|
||||
except FileNotFoundError:
|
||||
return GlobalPrefs()
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import shlex
|
||||
import webbrowser
|
||||
|
||||
|
||||
# Obtain path from package.dist-info/entry_points.txt
|
||||
def run(path, arg_str):
|
||||
module, func = path.split(":")
|
||||
|
|
|
@ -24,6 +24,7 @@ bools = hs.booleans()
|
|||
default_labels = hs.sampled_from(DefaultLabel)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="bronke")
|
||||
@given(
|
||||
# Channel
|
||||
c_amplification=maybe_real,
|
||||
|
|
|
@ -37,7 +37,6 @@ def call_main(argv):
|
|||
def yaml_sink(_mocker, command: str):
|
||||
"""Mocks yaml.dump() and returns call args. Also tests dumping and loading."""
|
||||
with mock.patch.object(yaml, "dump") as dump:
|
||||
|
||||
argv = shlex.split(command) + ["-w"]
|
||||
call_main(argv)
|
||||
|
||||
|
|
|
@ -222,7 +222,7 @@ def test_renderer_layout():
|
|||
nplots = 15
|
||||
|
||||
datas = [RENDER_Y_ZEROS] * nplots
|
||||
r = Renderer(cfg, lcfg, datas, None, None)
|
||||
r = Renderer.from_obj(cfg, lcfg, datas, None, None)
|
||||
r.update_main_lines(RenderInput.wrap_datas(datas), [0] * nplots)
|
||||
layout = r.layout
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ def test_render_output():
|
|||
"""Ensure rendering to output does not raise exceptions."""
|
||||
datas = [RENDER_Y_ZEROS]
|
||||
|
||||
renderer = Renderer(CFG.render, CFG.layout, datas, None, None)
|
||||
renderer = Renderer.from_obj(CFG.render, CFG.layout, datas, None, None)
|
||||
out: FFmpegOutput = NULL_FFMPEG_OUTPUT(CFG)
|
||||
|
||||
renderer.update_main_lines(RenderInput.wrap_datas(datas), [0])
|
||||
|
|
|
@ -199,13 +199,13 @@ def test_default_colors(appear: Appearance, data):
|
|||
lcfg = LayoutConfig(orientation=ORIENTATION)
|
||||
datas = [data] * NPLOTS
|
||||
|
||||
r = Renderer(cfg, lcfg, datas, None, None)
|
||||
r = Renderer.from_obj(cfg, lcfg, datas, None, None)
|
||||
verify(r, appear, datas)
|
||||
|
||||
# Ensure default ChannelConfig(line_color=None) does not override line color
|
||||
chan = ChannelConfig(wav_path="")
|
||||
channels = [chan] * NPLOTS
|
||||
r = Renderer(cfg, lcfg, datas, channels, None)
|
||||
r = Renderer.from_obj(cfg, lcfg, datas, channels, None)
|
||||
verify(r, appear, datas)
|
||||
|
||||
|
||||
|
@ -222,7 +222,7 @@ def test_line_colors(appear: Appearance, data):
|
|||
cfg.init_line_color = "#888888"
|
||||
chan.line_color = appear.fg.color
|
||||
|
||||
r = Renderer(cfg, lcfg, datas, channels, None)
|
||||
r = Renderer.from_obj(cfg, lcfg, datas, channels, None)
|
||||
verify(r, appear, datas)
|
||||
|
||||
|
||||
|
@ -343,7 +343,7 @@ def test_label_render(label_position: LabelPosition, data, hide_lines):
|
|||
labels = ["#"] * nplots
|
||||
datas = [data] * nplots
|
||||
|
||||
r = Renderer(cfg, lcfg, datas, None, None)
|
||||
r = Renderer.from_obj(cfg, lcfg, datas, None, None)
|
||||
r.add_labels(labels)
|
||||
if not hide_lines:
|
||||
r.update_main_lines(RenderInput.wrap_datas(datas), [0] * nplots)
|
||||
|
@ -426,7 +426,7 @@ def verify_res_divisor_rounding(
|
|||
datas = [RENDER_Y_ZEROS]
|
||||
|
||||
try:
|
||||
renderer = Renderer(cfg, LayoutConfig(), datas, None, None)
|
||||
renderer = Renderer.from_obj(cfg, LayoutConfig(), datas, None, None)
|
||||
if not speed_hack:
|
||||
renderer.update_main_lines(
|
||||
RenderInput.wrap_datas(datas), [0] * len(datas)
|
||||
|
@ -484,7 +484,7 @@ def test_renderer_knows_stride(mocker: "pytest_mock.MockFixture", integration: b
|
|||
else:
|
||||
channel = Channel(chan_cfg, corr_cfg, channel_idx=0)
|
||||
data = channel.get_render_around(0)
|
||||
renderer = Renderer(
|
||||
renderer = Renderer.from_obj(
|
||||
corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel]
|
||||
)
|
||||
assert renderer.render_strides == [subsampling * width_mul]
|
||||
|
@ -510,7 +510,9 @@ def test_frontend_overrides_backend(mocker: "pytest_mock.MockFixture"):
|
|||
channel = Channel(chan_cfg, corr_cfg, channel_idx=0)
|
||||
data = channel.get_render_around(0)
|
||||
|
||||
renderer = Renderer(corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel])
|
||||
renderer = Renderer.from_obj(
|
||||
corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel]
|
||||
)
|
||||
renderer.update_main_lines([RenderInput.stub_new(data)], [0])
|
||||
renderer.get_frame()
|
||||
|
||||
|
|
Ładowanie…
Reference in New Issue