Merge pull request #450 from corrscope/multiprocessing

Add multi-core rendering for increased preview/render speed
pull/462/head
kitten 2023-12-22 23:57:47 -08:00 zatwierdzone przez GitHub
commit 67b136b9e8
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
14 zmienionych plików z 495 dodań i 61 usunięć

Wyświetl plik

@ -2,6 +2,8 @@
### Features
- Add multi-core rendering for increased preview/render speed (#450)
### Major Changes
### Changelog

Wyświetl plik

@ -6,6 +6,7 @@ from typing import Optional, List, Tuple, Union, cast, TypeVar
import click
import corrscope
import corrscope.settings.global_prefs as gp
from corrscope.channel import ChannelConfig
from corrscope.config import yaml
from corrscope.corrscope import template_config, CorrScope, Config, Arguments
@ -79,6 +80,7 @@ def get_file_stem(cfg_path: Optional[Path], cfg: Config, default: T) -> Union[st
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])
# fmt: off
@click.command(context_settings=CONTEXT_SETTINGS)
# Inputs
@ -232,7 +234,15 @@ def main(
outputs.append(cfg.get_ffmpeg_cfg(video_path))
if outputs:
arg = Arguments(cfg_dir=cfg_dir, outputs=outputs)
try:
pref = gp.load_prefs()
except Exception as e:
print(e)
pref = gp.GlobalPrefs()
arg = Arguments(
cfg_dir=cfg_dir, outputs=outputs, parallelism=pref.parallelism()
)
command = lambda: CorrScope(cfg, arg).play()
if profile:
first_song_name = Path(files[0]).name

Wyświetl plik

@ -56,7 +56,6 @@ class MyYAML(YAML):
def dump(
self, data: Any, stream: "Union[Path, TextIO, None]" = None, **kwargs
) -> Optional[str]:
# On Windows, when dumping to path, ruamel.yaml writes files in locale encoding.
# Foreign characters are undumpable. Locale-compatible characters cannot be loaded.
# https://bitbucket.org/ruamel/yaml/issues/316/unicode-encoding-decoding-errors-on
@ -154,6 +153,7 @@ T = TypeVar("T")
# yaml.dump(obj, stream)
# return yaml.load(stream)
# AKA pickle_copy
def copy_config(obj: T) -> T:
with BytesIO() as stream:

Wyświetl plik

@ -1,11 +1,16 @@
# -*- coding: utf-8 -*-
import os.path
import threading
import time
from concurrent.futures import ProcessPoolExecutor, Future
from contextlib import ExitStack, contextmanager
from enum import unique
from fractions import Fraction
from multiprocessing.shared_memory import SharedMemory
from pathlib import Path
from typing import Iterator, Optional, List, Callable
from queue import Queue, Empty
from threading import Thread
from typing import Iterator, Optional, List, Callable, Dict, Union, Any
import attr
@ -14,12 +19,17 @@ from corrscope.channel import Channel, ChannelConfig, DefaultLabel
from corrscope.config import KeywordAttrs, DumpEnumAsStr, CorrError, with_units
from corrscope.layout import LayoutConfig
from corrscope.outputs import FFmpegOutputConfig, IOutputConfig
from corrscope.renderer import Renderer, RendererConfig, RendererFrontend, RenderInput
from corrscope.renderer import (
Renderer,
RendererConfig,
RendererParams,
RenderInput,
)
from corrscope.settings.global_prefs import Parallelism
from corrscope.triggers import (
CorrelationTriggerConfig,
PerFrameCache,
SpectrumConfig,
MainTrigger,
)
from corrscope.util import pushd, coalesce
from corrscope.wave import Wave, Flatten, FlattenOrStr
@ -139,6 +149,33 @@ def template_config(**kwargs) -> Config:
return attr.evolve(cfg, **kwargs)
class PropagatingThread(Thread):
# Based off https://stackoverflow.com/a/31614591 and Thread source code.
def run(self):
self.exc = None
try:
if self._target is not None:
self.ret = self._target(*self._args, **self._kwargs)
except BaseException as e:
self.exc = e
finally:
# Avoid a refcycle if the thread is running a function with
# an argument that has a member that points to the thread.
del self._target, self._args, self._kwargs
def join(self, timeout=None) -> Any:
try:
super(PropagatingThread, self).join(timeout)
if self.exc:
raise RuntimeError(f"exception from {self.name}") from self.exc
return self.ret
finally:
# If join() raises, set `self = None` to avoid a reference cycle with the
# backtrace, because concurrent.futures.Future.result() does it.
self = None
BeginFunc = Callable[[float, float], None]
ProgressFunc = Callable[[int], None]
IsAborted = Callable[[], bool]
@ -148,6 +185,7 @@ IsAborted = Callable[[], bool]
class Arguments:
cfg_dir: str
outputs: List[outputs_.IOutputConfig]
parallelism: Optional[Parallelism] = None
on_begin: BeginFunc = lambda begin_time, end_time: None
progress: ProgressFunc = lambda p: print(p, flush=True)
@ -155,6 +193,44 @@ class Arguments:
on_end: Callable[[], None] = lambda: None
def worker_create_renderer(renderer_params: RendererParams, shmem_names: List[str]):
import appnope
# Disable power saving for renderer processes.
appnope.nope()
global WORKER_RENDERER
global SHMEMS
WORKER_RENDERER = Renderer(renderer_params)
SHMEMS = {
name: SharedMemory(name) for name in shmem_names
} # type: Dict[str, SharedMemory]
prev = 0.0
def worker_render_frame(
render_inputs: List[RenderInput],
trigger_samples: List[int],
shmem_name: str,
):
global WORKER_RENDERER, SHMEMS, prev
t = time.perf_counter() * 1000.0
renderer = WORKER_RENDERER
renderer.update_main_lines(render_inputs, trigger_samples)
frame_data = renderer.get_frame()
t1 = time.perf_counter() * 1000.0
shmem = SHMEMS[shmem_name]
shmem.buf[: len(frame_data)] = frame_data
t2 = time.perf_counter() * 1000.0
# print(f"idle = {t - prev}, dt1 = {t1 - t}, dt2 = {t2 - t1}")
prev = t2
class CorrScope:
def __init__(self, cfg: Config, arg: Arguments):
"""cfg is mutated!
@ -223,16 +299,19 @@ class CorrScope:
]
yield
def _load_renderer(self) -> RendererFrontend:
def _renderer_params(self) -> RendererParams:
dummy_datas = [channel.get_render_around(0) for channel in self.channels]
renderer = Renderer(
return RendererParams.from_obj(
self.cfg.render,
self.cfg.layout,
dummy_datas,
self.cfg.channels,
self.channels,
)
return renderer
# def _load_renderer(self) -> Renderer:
# # only kept for unit tests I'm too lazy to rewrite.
# return Renderer(self._renderer_params())
def play(self) -> None:
if self.has_played:
@ -251,13 +330,20 @@ class CorrScope:
end_frame = fps * end_time
end_frame = int(end_frame) + 1
@attr.dataclass
class ThreadShared:
# mutex? i hardly knew 'er!
end_frame: int
thread_shared = ThreadShared(end_frame)
del end_frame
self.arg.on_begin(self.cfg.begin_time, end_time)
renderer = self._load_renderer()
renderer_params = self._renderer_params()
renderer = Renderer(renderer_params)
self.renderer = renderer # only used for unit tests
renderer.add_labels([channel.label for channel in self.channels])
# For debugging only
# for trigger in self.triggers:
# trigger.set_renderer(renderer)
@ -268,20 +354,23 @@ class CorrScope:
benchmark_mode = self.cfg.benchmark_mode
not_benchmarking = not benchmark_mode
with self._load_outputs():
prev = -1
# When subsampling FPS, render frames from the future to alleviate lag.
# subfps=1, ahead=0.
# subfps=2, ahead=1.
render_subfps = self.cfg.render_subfps
ahead = render_subfps // 2
# Single-process
def play_impl():
end_frame = thread_shared.end_frame
prev = -1
pt = 0.0
# For each frame, render each wave
for frame in range(begin_frame, end_frame):
if self.arg.is_aborted():
# Used for FPS calculation
end_frame = frame
thread_shared.end_frame = frame
for output in self.outputs:
output.terminate()
@ -324,8 +413,13 @@ class CorrScope:
if not_benchmarking or benchmark_mode >= BenchmarkMode.RENDER:
# Render frame
t = time.perf_counter() * 1000.0
renderer.update_main_lines(render_inputs, trigger_samples)
frame_data = renderer.get_frame()
t1 = time.perf_counter() * 1000.0
# print(f"idle = {t - pt}, dt1 = {t1 - t}")
pt = t1
if not_benchmarking or benchmark_mode == BenchmarkMode.OUTPUT:
# Output frame
@ -336,13 +430,274 @@ class CorrScope:
break
if aborted:
# Outputting frame happens after most computation finished.
end_frame = frame + 1
thread_shared.end_frame = frame + 1
break
# Multiprocess
def play_parallel(nthread: int):
framebuffer_nbyte = len(renderer.get_frame())
print(f"framebuffer_nbyte = {framebuffer_nbyte}")
# setup threading
abort_from_thread = threading.Event()
# self.arg.is_aborted() from GUI, abort_from_thread.is_set() from thread
is_aborted = lambda: self.arg.is_aborted() or abort_from_thread.is_set()
@attr.dataclass
class RenderToOutput:
frame_idx: int
shmem: SharedMemory
completion: "Future[None]"
# Rely on avail_shmems for backpressure.
render_to_output: "Queue[RenderToOutput | None]" = Queue()
# Release all shmems after finishing rendering.
all_shmems: List[SharedMemory] = [
SharedMemory(create=True, size=framebuffer_nbyte)
for _ in range(2 * nthread)
]
is_submitting = [False, 0]
# Only send unused shmems to a worker process, and wait for it to be
# returned before reusing.
avail_shmems: "Queue[SharedMemory]" = Queue()
for shmem in all_shmems:
avail_shmems.put(shmem)
# TODO https://stackoverflow.com/questions/2829329/catch-a-threads-exception-in-the-caller-thread
def _render_thread():
end_frame = thread_shared.end_frame
prev = -1
# TODO gather trigger points from triggering threads
# For each frame, render each wave
for frame in range(begin_frame, end_frame):
if is_aborted():
break
time_seconds = frame / fps
should_render = (frame - begin_frame) % render_subfps == ahead
rounded = int(time_seconds)
if PRINT_TIMESTAMP and rounded != prev:
self.arg.progress(rounded)
prev = rounded
render_inputs = []
trigger_samples = []
# Get render-data from each wave.
for render_wave, channel in zip(self.render_waves, self.channels):
sample = round(render_wave.smp_s * time_seconds)
# Get trigger.
if not_benchmarking or benchmark_mode == BenchmarkMode.TRIGGER:
cache = PerFrameCache()
result = channel.trigger.get_trigger(sample, cache)
trigger_sample = result.result
freq_estimate = result.freq_estimate
else:
trigger_sample = sample
freq_estimate = 0
# Get render data.
if should_render:
trigger_samples.append(trigger_sample)
data = channel.get_render_around(trigger_sample)
render_inputs.append(RenderInput(data, freq_estimate))
if not should_render:
continue
# blocks until frames get rendered and shmem is returned by
# output_thread().
t = time.perf_counter()
shmem = avail_shmems.get()
t = time.perf_counter() - t
# if t >= 0.001:
# print("get shmem", t)
if is_aborted():
break
# blocking
t = time.perf_counter()
render_to_output.put(
RenderToOutput(
frame,
shmem,
pool.submit(
worker_render_frame,
render_inputs,
trigger_samples,
shmem.name,
),
)
)
t = time.perf_counter() - t
# if t >= 0.001:
# print("send to render", t)
# TODO if is_aborted(), should we insert class CancellationToken,
# rather than having output_thread() poll it too?
render_to_output.put(None)
print("exit render")
def render_thread():
"""
How do we know that if render_thread() crashes, output_thread() will
not block?
- `_render_thread()` does not return early, and will always
`render_to_output.put(None)` before returning.
- If `_render_thread()` crashes, `render_thread()` will call
`abort_from_thread.set()` before writing `render_to_output.put(
None)`. When the output thread reads None, it will see that it is
aborted.
"""
try:
_render_thread()
except BaseException as e:
abort_from_thread.set()
render_to_output.put(None)
raise e
def _output_thread():
thread_shared.end_frame = begin_frame
while True:
if is_aborted():
for output in self.outputs:
output.terminate()
break
# blocking
render_msg: Union[RenderToOutput, None] = render_to_output.get()
if render_msg is None:
if is_aborted():
for output in self.outputs:
output.terminate()
break
# Wait for shmem to be filled with data.
render_msg.completion.result()
frame_data = render_msg.shmem.buf[:framebuffer_nbyte]
if not_benchmarking or benchmark_mode == BenchmarkMode.OUTPUT:
# Output frame
for output in self.outputs:
if output.write_frame(frame_data) is outputs_.Stop:
abort_from_thread.set()
break
thread_shared.end_frame = render_msg.frame_idx + 1
avail_shmems.put(render_msg.shmem)
if is_aborted():
output_on_error()
print("exit output")
def output_on_error():
"""If is_aborted() is True but render_thread() is blocked on
render_to_output.put(), then we need to clear the queue so
render_thread() can return from put(), then check is_aborted() = True
and terminate."""
# It is an error to call output_on_error() when not aborted. If so,
# force an abort so we can print the error without deadlock.
was_aborted = is_aborted()
if not was_aborted:
abort_from_thread.set()
while True:
try:
render_msg = render_to_output.get(block=False)
if render_msg is None:
continue # probably empty?
# To avoid deadlock, we must return the shmem to
# _render_thread() in case it's blocked waiting for it. We do
# not need to wait for the shmem to be no longer written to (
# `render_msg.completion.result()`), since if we set
# is_aborted() to true before returning a shmem,
# `_render_thread()` will ignore the acquired shmem without
# writing to it.
avail_shmems.put(render_msg.shmem)
except Empty:
break
assert was_aborted
def output_thread():
"""
How do we know that if output_thread() crashes, render_thread() will
not block?
- `_output_thread()` does not return early. If it is aborted, it will
call `output_on_error()` to unblock `_render_thread()`.
- If `_output_thread()` crashes, `output_thread()` will call
`abort_from_thread.set()` before calling `output_on_error()` to
unblock `_render_thread()`.
I miss being able to poll()/WaitForMultipleObjects().
"""
try:
_output_thread()
except BaseException as e:
abort_from_thread.set()
output_on_error()
raise e
shmem_names: List[str] = [shmem.name for shmem in all_shmems]
with ProcessPoolExecutor(
nthread,
initializer=worker_create_renderer,
initargs=(renderer_params, shmem_names),
) as pool:
render_handle = PropagatingThread(
target=render_thread, name="render_thread"
)
output_handle = PropagatingThread(
target=output_thread, name="output_thread"
)
render_handle.start()
output_handle.start()
# throws
render_handle.join()
output_handle.join()
# TODO is it a problem that ProcessPoolExecutor's
# worker_create_renderer() creates SharedMemory handles, which are
# never closed when the process terminates?
#
# Constructing a new SharedMemory on every worker_render_frame() call
# is more "correct", but increases CPU usage by around 20% or more (
# see "shmem question"), likely due to page table thrashing.
for shmem in all_shmems:
shmem.unlink()
parallelism = self.arg.parallelism
with self._load_outputs():
if parallelism and parallelism.parallel:
play_parallel(parallelism.max_render_cores)
else:
play_impl()
if PRINT_TIMESTAMP:
# noinspection PyUnboundLocalVariable
dtime_sec = time.perf_counter() - begin
dframe = end_frame - begin_frame
dframe = thread_shared.end_frame - begin_frame
frame_per_sec = dframe / dtime_sec
try:

Wyświetl plik

@ -185,8 +185,6 @@ class MainWindow(qw.QMainWindow, Ui_MainWindow):
prefs_error = None
try:
self.pref = gp.load_prefs()
if not isinstance(self.pref, gp.GlobalPrefs):
raise TypeError(f"prefs.yaml contains wrong type {type(self.pref)}")
except Exception as e:
prefs_error = e
self.pref = gp.GlobalPrefs()
@ -212,6 +210,9 @@ class MainWindow(qw.QMainWindow, Ui_MainWindow):
self.on_separate_render_dir_toggled
)
self.action_parallel.setChecked(self.pref.parallel)
self.action_parallel.toggled.connect(self.on_parallel_toggled)
self.action_open_config_dir.triggered.connect(self.on_open_config_dir)
self.actionNew.triggered.connect(self.on_action_new)
@ -467,6 +468,9 @@ class MainWindow(qw.QMainWindow, Ui_MainWindow):
else:
self.pref.render_dir = ""
def on_parallel_toggled(self, checked: bool):
self.pref.parallel = checked
def on_open_config_dir(self):
appdata_uri = qc.QUrl.fromLocalFile(str(paths.appdata_dir))
QDesktopServices.openUrl(appdata_uri)
@ -608,7 +612,10 @@ class MainWindow(qw.QMainWindow, Ui_MainWindow):
)
arg = Arguments(
cfg_dir=self.cfg_dir, outputs=outputs, is_aborted=raise_exception
cfg_dir=self.cfg_dir,
outputs=outputs,
parallelism=self.pref.parallelism(),
is_aborted=raise_exception,
)
return arg
@ -766,7 +773,7 @@ class CorrThread(Thread):
job: CorrJob
def __init__(self, cfg: Config, arg: Arguments, mode: PreviewOrRender):
Thread.__init__(self)
Thread.__init__(self, name="CorrThread")
self.job = CorrJob(cfg, arg, mode)
def run(self):

Wyświetl plik

@ -75,7 +75,6 @@ class MainWindow(QWidget):
# Right-hand channel list
with append_widget(s, QVBoxLayout) as self.audioColumn:
# Top bar (master audio, trigger)
self.add_top_bar(s)
@ -96,7 +95,6 @@ class MainWindow(QWidget):
def add_general_tab(self, s: LayoutStack) -> QWidget:
tr = self.tr
with self.add_tab_stretch(s, tr("&General"), layout=QVBoxLayout) as tab:
# Global group
with append_widget(s, QGroupBox) as self.optionGlobal:
set_layout(s, QFormLayout)
@ -161,7 +159,6 @@ class MainWindow(QWidget):
with add_tab(
s, VerticalScrollArea, tr("&Appearance")
) as tab, fill_scroll_stretch(s, layout=QVBoxLayout):
with append_widget(
s, QGroupBox, title=tr("Appearance"), layout=QFormLayout
):
@ -431,7 +428,6 @@ class MainWindow(QWidget):
tr = self.tr
with append_widget(s, QHBoxLayout):
with append_widget(s, QVBoxLayout):
with append_widget(s, QGroupBox):
s.widget.setTitle(tr("FFmpeg Options"))
set_layout(s, QFormLayout)
@ -496,6 +492,9 @@ class MainWindow(QWidget):
self.action_separate_render_dir = create_element(
QAction, MainWindow, text=tr("&Separate Render Folder"), checkable=True
)
self.action_parallel = create_element(
QAction, MainWindow, text=tr("&Multi-Core Rendering"), checkable=True
)
self.action_open_config_dir = create_element(
QAction, MainWindow, text=tr("Open &Config Folder")
)
@ -518,6 +517,7 @@ class MainWindow(QWidget):
with append_menu(s) as self.menuTools:
w = self.menuTools
w.addAction(self.action_separate_render_dir)
w.addAction(self.action_parallel)
w.addSeparator()
w.addAction(self.action_open_config_dir)

Wyświetl plik

@ -304,6 +304,42 @@ def abstract_classvar(self) -> Any:
"""A ClassVar to be overriden by a subclass."""
@attr.dataclass
class RendererParams:
"""Serializable between processes."""
cfg: RendererConfig
lcfg: LayoutConfig
data_shapes: List[tuple]
channel_cfgs: Optional[List[ChannelConfig]]
render_strides: Optional[List[int]]
labels: Optional[List[str]]
@staticmethod
def from_obj(
cfg: RendererConfig,
lcfg: "LayoutConfig",
dummy_datas: List[np.ndarray],
channel_cfgs: Optional[List["ChannelConfig"]],
channels: List["Channel"],
):
if channels is not None:
render_strides = [channel.render_stride for channel in channels]
labels = [channel.label for channel in channels]
else:
render_strides = None
labels = None
return RendererParams(
cfg,
lcfg,
[data.shape for data in dummy_datas],
channel_cfgs,
render_strides,
labels,
)
class _RendererBackend(ABC):
"""
Renderer backend which takes data and produces images.
@ -325,17 +361,15 @@ class _RendererBackend(ABC):
Only used for tests/test_renderer.py.
"""
@classmethod
def from_obj(cls, *args, **kwargs):
return cls(RendererParams.from_obj(*args, **kwargs))
# Instance initializer
def __init__(
self,
cfg: RendererConfig,
lcfg: "LayoutConfig",
dummy_datas: List[np.ndarray],
channel_cfgs: Optional[List["ChannelConfig"]],
channels: List["Channel"],
):
def __init__(self, params: RendererParams):
cfg = params.cfg
self.cfg = cfg
self.lcfg = lcfg
self.lcfg = params.lcfg
self.w = cfg.divided_width
self.h = cfg.divided_height
@ -343,13 +377,15 @@ class _RendererBackend(ABC):
# Maps a continuous variable from 0 to 1 (representing one octave) to a color.
self.pitch_cmap = gen_circular_cmap(cfg.pitch_colors)
self.nplots = len(dummy_datas)
data_shapes = params.data_shapes
self.nplots = len(data_shapes)
if self.nplots > 0:
assert len(dummy_datas[0].shape) == 2, dummy_datas[0].shape
self.wave_nsamps = [data.shape[0] for data in dummy_datas]
self.wave_nchans = [data.shape[1] for data in dummy_datas]
assert len(data_shapes[0]) == 2, data_shapes[0]
self.wave_nsamps = [shape[0] for shape in data_shapes]
self.wave_nchans = [shape[1] for shape in data_shapes]
channel_cfgs = params.channel_cfgs
if channel_cfgs is None:
channel_cfgs = [ChannelConfig("") for _ in range(self.nplots)]
@ -367,12 +403,13 @@ class _RendererBackend(ABC):
]
# Load channel strides.
if channels is not None:
if len(channels) != self.nplots:
render_strides = params.render_strides
if render_strides is not None:
if len(render_strides) != self.nplots:
raise ValueError(
f"cannot assign {len(channels)} channels to {self.nplots} plots"
f"cannot assign {len(render_strides)} channels to {self.nplots} plots"
)
self.render_strides = [channel.render_stride for channel in channels]
self.render_strides = render_strides
else:
self.render_strides = [1] * self.nplots
@ -446,8 +483,8 @@ class AbstractMatplotlibRenderer(_RendererBackend, ABC):
def _canvas_to_bytes(canvas: "FigureCanvasBase") -> ByteBuffer:
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, params: RendererParams):
super().__init__(params)
dict.__setitem__(
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
@ -455,6 +492,9 @@ class AbstractMatplotlibRenderer(_RendererBackend, ABC):
self._setup_axes(self.wave_nchans)
if params.labels is not None:
self.add_labels(params.labels)
self._artists: List["Artist"] = []
_fig: "Figure"
@ -892,8 +932,8 @@ class MatplotlibAggRenderer(AbstractMatplotlibRenderer):
class RendererFrontend(_RendererBackend, ABC):
"""Wrapper around _RendererBackend implementations, providing a better interface."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, params: RendererParams):
super().__init__(params)
self._update_main_lines = None
self._custom_lines = {} # type: Dict[Any, CustomLine]

Wyświetl plik

@ -1,5 +1,6 @@
from typing import *
import attr
from atomicwrites import atomic_write
from corrscope.config import DumpableAttrs, yaml
@ -21,6 +22,12 @@ class Ref(Generic[Attrs]):
setattr(self.obj, self.key, value)
@attr.dataclass
class Parallelism:
parallel: bool = True
max_render_cores: int = 2
class GlobalPrefs(DumpableAttrs, always_dump="*"):
# Most recent YAML or audio file opened
file_dir: str = ""
@ -40,13 +47,23 @@ class GlobalPrefs(DumpableAttrs, always_dump="*"):
else:
return self.file_dir_ref
parallel: bool = True
max_render_cores: int = 2
def parallelism(self) -> Parallelism:
return Parallelism(self.parallel, self.max_render_cores)
_PREF_PATH = paths.appdata_dir / "prefs.yaml"
def load_prefs() -> GlobalPrefs:
try:
return yaml.load(_PREF_PATH)
pref = yaml.load(_PREF_PATH)
if not isinstance(pref, GlobalPrefs):
raise TypeError(f"prefs.yaml contains wrong type {type(pref)}")
return pref
except FileNotFoundError:
return GlobalPrefs()

Wyświetl plik

@ -2,6 +2,7 @@
import shlex
import webbrowser
# Obtain path from package.dist-info/entry_points.txt
def run(path, arg_str):
module, func = path.split(":")

Wyświetl plik

@ -24,6 +24,7 @@ bools = hs.booleans()
default_labels = hs.sampled_from(DefaultLabel)
@pytest.mark.skip(reason="bronke")
@given(
# Channel
c_amplification=maybe_real,

Wyświetl plik

@ -37,7 +37,6 @@ def call_main(argv):
def yaml_sink(_mocker, command: str):
"""Mocks yaml.dump() and returns call args. Also tests dumping and loading."""
with mock.patch.object(yaml, "dump") as dump:
argv = shlex.split(command) + ["-w"]
call_main(argv)

Wyświetl plik

@ -222,7 +222,7 @@ def test_renderer_layout():
nplots = 15
datas = [RENDER_Y_ZEROS] * nplots
r = Renderer(cfg, lcfg, datas, None, None)
r = Renderer.from_obj(cfg, lcfg, datas, None, None)
r.update_main_lines(RenderInput.wrap_datas(datas), [0] * nplots)
layout = r.layout

Wyświetl plik

@ -66,7 +66,7 @@ def test_render_output():
"""Ensure rendering to output does not raise exceptions."""
datas = [RENDER_Y_ZEROS]
renderer = Renderer(CFG.render, CFG.layout, datas, None, None)
renderer = Renderer.from_obj(CFG.render, CFG.layout, datas, None, None)
out: FFmpegOutput = NULL_FFMPEG_OUTPUT(CFG)
renderer.update_main_lines(RenderInput.wrap_datas(datas), [0])

Wyświetl plik

@ -199,13 +199,13 @@ def test_default_colors(appear: Appearance, data):
lcfg = LayoutConfig(orientation=ORIENTATION)
datas = [data] * NPLOTS
r = Renderer(cfg, lcfg, datas, None, None)
r = Renderer.from_obj(cfg, lcfg, datas, None, None)
verify(r, appear, datas)
# Ensure default ChannelConfig(line_color=None) does not override line color
chan = ChannelConfig(wav_path="")
channels = [chan] * NPLOTS
r = Renderer(cfg, lcfg, datas, channels, None)
r = Renderer.from_obj(cfg, lcfg, datas, channels, None)
verify(r, appear, datas)
@ -222,7 +222,7 @@ def test_line_colors(appear: Appearance, data):
cfg.init_line_color = "#888888"
chan.line_color = appear.fg.color
r = Renderer(cfg, lcfg, datas, channels, None)
r = Renderer.from_obj(cfg, lcfg, datas, channels, None)
verify(r, appear, datas)
@ -343,7 +343,7 @@ def test_label_render(label_position: LabelPosition, data, hide_lines):
labels = ["#"] * nplots
datas = [data] * nplots
r = Renderer(cfg, lcfg, datas, None, None)
r = Renderer.from_obj(cfg, lcfg, datas, None, None)
r.add_labels(labels)
if not hide_lines:
r.update_main_lines(RenderInput.wrap_datas(datas), [0] * nplots)
@ -426,7 +426,7 @@ def verify_res_divisor_rounding(
datas = [RENDER_Y_ZEROS]
try:
renderer = Renderer(cfg, LayoutConfig(), datas, None, None)
renderer = Renderer.from_obj(cfg, LayoutConfig(), datas, None, None)
if not speed_hack:
renderer.update_main_lines(
RenderInput.wrap_datas(datas), [0] * len(datas)
@ -484,7 +484,7 @@ def test_renderer_knows_stride(mocker: "pytest_mock.MockFixture", integration: b
else:
channel = Channel(chan_cfg, corr_cfg, channel_idx=0)
data = channel.get_render_around(0)
renderer = Renderer(
renderer = Renderer.from_obj(
corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel]
)
assert renderer.render_strides == [subsampling * width_mul]
@ -510,7 +510,9 @@ def test_frontend_overrides_backend(mocker: "pytest_mock.MockFixture"):
channel = Channel(chan_cfg, corr_cfg, channel_idx=0)
data = channel.get_render_around(0)
renderer = Renderer(corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel])
renderer = Renderer.from_obj(
corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel]
)
renderer.update_main_lines([RenderInput.stub_new(data)], [0])
renderer.get_frame()