kopia lustrzana https://github.com/corrscope/corrscope
Generate flickering video
rodzic
769e2ed64f
commit
02d17c4e97
corrscope
|
@ -264,9 +264,6 @@ class CorrScope:
|
|||
else:
|
||||
self.output_cfgs = [] # type: List[IOutputConfig]
|
||||
|
||||
if len(self.cfg.channels) == 0:
|
||||
raise CorrError("Config.channels is empty")
|
||||
|
||||
# Check for ffmpeg video recording, then mutate cfg.
|
||||
is_record = False
|
||||
for output in self.output_cfgs:
|
||||
|
@ -278,30 +275,9 @@ class CorrScope:
|
|||
else:
|
||||
self.cfg.before_preview()
|
||||
|
||||
trigger_waves: List[Wave]
|
||||
render_waves: List[Wave]
|
||||
channels: List[Channel]
|
||||
outputs: List[outputs_.Output]
|
||||
nchan: int
|
||||
|
||||
def _load_channels(self) -> None:
|
||||
with pushd(self.arg.cfg_dir):
|
||||
# Tell user if master audio path is invalid.
|
||||
# (Otherwise, only ffmpeg uses the value of master_audio)
|
||||
# Windows likes to raise OSError when path contains *, but we don't care.
|
||||
if self.cfg.master_audio and not Path(self.cfg.master_audio).exists():
|
||||
raise CorrError(
|
||||
f'File not found: master_audio="{self.cfg.master_audio}"'
|
||||
)
|
||||
self.channels = [
|
||||
Channel(ccfg, self.cfg, idx)
|
||||
for idx, ccfg in enumerate(self.cfg.channels)
|
||||
]
|
||||
self.trigger_waves = [channel.trigger_wave for channel in self.channels]
|
||||
self.render_waves = [channel.render_wave for channel in self.channels]
|
||||
self.triggers = [channel.trigger for channel in self.channels]
|
||||
self.nchan = len(self.channels)
|
||||
|
||||
@contextmanager
|
||||
def _load_outputs(self) -> Iterator[None]:
|
||||
with pushd(self.arg.cfg_dir):
|
||||
|
@ -313,13 +289,13 @@ class CorrScope:
|
|||
yield
|
||||
|
||||
def _renderer_params(self) -> RendererParams:
|
||||
dummy_datas = [channel.get_render_around(0) for channel in self.channels]
|
||||
dummy_datas = []
|
||||
return RendererParams.from_obj(
|
||||
self.cfg.render,
|
||||
self.cfg.layout,
|
||||
dummy_datas,
|
||||
self.cfg.channels,
|
||||
self.channels,
|
||||
None,
|
||||
self.arg.cfg_dir,
|
||||
)
|
||||
|
||||
|
@ -332,15 +308,12 @@ class CorrScope:
|
|||
raise ValueError("Cannot call CorrScope.play() more than once")
|
||||
self.has_played = True
|
||||
|
||||
self._load_channels()
|
||||
# Calculate number of frames (TODO master file?)
|
||||
fps = self.cfg.fps
|
||||
|
||||
begin_frame = round(fps * self.cfg.begin_time)
|
||||
|
||||
end_time = coalesce(
|
||||
self.cfg.end_time, max(wave.get_s() for wave in self.render_waves)
|
||||
)
|
||||
end_time = self.cfg.end_time
|
||||
end_frame = fps * end_time
|
||||
end_frame = int(end_frame) + 1
|
||||
|
||||
|
@ -398,50 +371,6 @@ class CorrScope:
|
|||
self.arg.progress(rounded)
|
||||
prev = rounded
|
||||
|
||||
render_inputs = []
|
||||
trigger_samples = []
|
||||
# Get render-data from each wave.
|
||||
for wave_idx, (render_wave, channel) in enumerate(
|
||||
zip(self.render_waves, self.channels)
|
||||
):
|
||||
sample = round(render_wave.smp_s * time_seconds)
|
||||
|
||||
# Get trigger.
|
||||
if not_benchmarking or benchmark_mode == BenchmarkMode.TRIGGER:
|
||||
cache = PerFrameCache()
|
||||
|
||||
result = channel.trigger.get_trigger(sample, cache)
|
||||
trigger_sample = result.result
|
||||
freq_estimate = result.freq_estimate
|
||||
|
||||
else:
|
||||
trigger_sample = sample
|
||||
freq_estimate = 0
|
||||
|
||||
# Get render data.
|
||||
if should_render:
|
||||
trigger_samples.append(trigger_sample)
|
||||
data = channel.get_render_around(trigger_sample)
|
||||
|
||||
stereo_data = None
|
||||
if (
|
||||
renderer.is_stereo_bars(wave_idx)
|
||||
and not channel.stereo_wave.is_mono
|
||||
):
|
||||
stereo_data = data
|
||||
# If stereo track is flattened to mono for rendering,
|
||||
# get raw stereo data.
|
||||
if stereo_data.shape[1] == 1:
|
||||
stereo_data = channel.get_render_stereo(trigger_sample)
|
||||
|
||||
stereo_levels = None
|
||||
if stereo_data is not None:
|
||||
stereo_levels = calc_stereo_levels(stereo_data)
|
||||
|
||||
render_inputs.append(
|
||||
RenderInput(data, stereo_levels, freq_estimate)
|
||||
)
|
||||
|
||||
if not should_render:
|
||||
continue
|
||||
|
||||
|
@ -449,7 +378,6 @@ class CorrScope:
|
|||
# Render frame
|
||||
|
||||
t = time.perf_counter() * 1000.0
|
||||
renderer.update_main_lines(render_inputs, trigger_samples)
|
||||
frame_data = renderer.get_frame()
|
||||
t1 = time.perf_counter() * 1000.0
|
||||
# print(f"idle = {t - pt}, dt1 = {t1 - t}")
|
||||
|
@ -467,287 +395,8 @@ class CorrScope:
|
|||
thread_shared.end_frame = frame + 1
|
||||
break
|
||||
|
||||
# Multiprocess
|
||||
def play_parallel(nthread: int):
|
||||
framebuffer_nbyte = len(renderer.get_frame())
|
||||
|
||||
# setup threading
|
||||
abort_from_thread = threading.Event()
|
||||
# self.arg.is_aborted() from GUI, abort_from_thread.is_set() from thread
|
||||
is_aborted = lambda: self.arg.is_aborted() or abort_from_thread.is_set()
|
||||
|
||||
@attr.dataclass
|
||||
class RenderToOutput:
|
||||
frame_idx: int
|
||||
shmem: SharedMemory
|
||||
completion: "Future[None]"
|
||||
|
||||
# Rely on avail_shmems for backpressure.
|
||||
render_to_output: "Queue[RenderToOutput | None]" = Queue()
|
||||
|
||||
# Release all shmems after finishing rendering.
|
||||
all_shmems: List[SharedMemory] = [
|
||||
SharedMemory(create=True, size=framebuffer_nbyte)
|
||||
for _ in range(2 * nthread)
|
||||
]
|
||||
|
||||
is_submitting = [False, 0]
|
||||
|
||||
# Only send unused shmems to a worker process, and wait for it to be
|
||||
# returned before reusing.
|
||||
avail_shmems: "Queue[SharedMemory]" = Queue()
|
||||
for shmem in all_shmems:
|
||||
avail_shmems.put(shmem)
|
||||
|
||||
# TODO https://stackoverflow.com/questions/2829329/catch-a-threads-exception-in-the-caller-thread
|
||||
def _render_thread():
|
||||
end_frame = thread_shared.end_frame
|
||||
prev = -1
|
||||
|
||||
# TODO gather trigger points from triggering threads
|
||||
# For each frame, render each wave
|
||||
for frame in range(begin_frame, end_frame):
|
||||
if is_aborted():
|
||||
break
|
||||
|
||||
time_seconds = frame / fps
|
||||
should_render = (frame - begin_frame) % render_subfps == ahead
|
||||
|
||||
rounded = int(time_seconds)
|
||||
if PRINT_TIMESTAMP and rounded != prev:
|
||||
self.arg.progress(rounded)
|
||||
prev = rounded
|
||||
|
||||
render_inputs = []
|
||||
trigger_samples = []
|
||||
# Get render-data from each wave.
|
||||
for wave_idx, (render_wave, channel) in enumerate(
|
||||
zip(self.render_waves, self.channels)
|
||||
):
|
||||
sample = round(render_wave.smp_s * time_seconds)
|
||||
|
||||
# Get trigger.
|
||||
if not_benchmarking or benchmark_mode == BenchmarkMode.TRIGGER:
|
||||
cache = PerFrameCache()
|
||||
|
||||
result = channel.trigger.get_trigger(sample, cache)
|
||||
trigger_sample = result.result
|
||||
freq_estimate = result.freq_estimate
|
||||
|
||||
else:
|
||||
trigger_sample = sample
|
||||
freq_estimate = 0
|
||||
|
||||
# Get render data.
|
||||
if should_render:
|
||||
trigger_samples.append(trigger_sample)
|
||||
data = channel.get_render_around(trigger_sample)
|
||||
|
||||
stereo_data = None
|
||||
if (
|
||||
renderer.is_stereo_bars(wave_idx)
|
||||
and not channel.stereo_wave.is_mono
|
||||
):
|
||||
stereo_data = data
|
||||
# If stereo track is flattened to mono for rendering,
|
||||
# get raw stereo data.
|
||||
if stereo_data.shape[1] == 1:
|
||||
stereo_data = channel.get_render_stereo(
|
||||
trigger_sample
|
||||
)
|
||||
|
||||
stereo_levels = None
|
||||
if stereo_data is not None:
|
||||
stereo_levels = calc_stereo_levels(stereo_data)
|
||||
|
||||
render_inputs.append(
|
||||
RenderInput(data, stereo_levels, freq_estimate)
|
||||
)
|
||||
|
||||
if not should_render:
|
||||
continue
|
||||
|
||||
# blocks until frames get rendered and shmem is returned by
|
||||
# output_thread().
|
||||
t = time.perf_counter()
|
||||
shmem = avail_shmems.get()
|
||||
t = time.perf_counter() - t
|
||||
# if t >= 0.001:
|
||||
# print("get shmem", t)
|
||||
if is_aborted():
|
||||
break
|
||||
|
||||
# blocking
|
||||
t = time.perf_counter()
|
||||
render_to_output.put(
|
||||
RenderToOutput(
|
||||
frame,
|
||||
shmem,
|
||||
pool.submit(
|
||||
worker_render_frame,
|
||||
render_inputs,
|
||||
trigger_samples,
|
||||
shmem.name,
|
||||
),
|
||||
)
|
||||
)
|
||||
t = time.perf_counter() - t
|
||||
# if t >= 0.001:
|
||||
# print("send to render", t)
|
||||
|
||||
# TODO if is_aborted(), should we insert class CancellationToken,
|
||||
# rather than having output_thread() poll it too?
|
||||
render_to_output.put(None)
|
||||
print("exit render")
|
||||
|
||||
def render_thread():
|
||||
"""
|
||||
How do we know that if render_thread() crashes, output_thread() will
|
||||
not block?
|
||||
|
||||
- `_render_thread()` does not return early, and will always
|
||||
`render_to_output.put(None)` before returning.
|
||||
|
||||
- If `_render_thread()` crashes, `render_thread()` will call
|
||||
`abort_from_thread.set()` before writing `render_to_output.put(
|
||||
None)`. When the output thread reads None, it will see that it is
|
||||
aborted.
|
||||
"""
|
||||
try:
|
||||
_render_thread()
|
||||
except BaseException as e:
|
||||
abort_from_thread.set()
|
||||
render_to_output.put(None)
|
||||
raise e
|
||||
|
||||
def _output_thread():
|
||||
thread_shared.end_frame = begin_frame
|
||||
|
||||
while True:
|
||||
if is_aborted():
|
||||
for output in self.outputs:
|
||||
output.terminate()
|
||||
break
|
||||
|
||||
# blocking
|
||||
render_msg: Union[RenderToOutput, None] = render_to_output.get()
|
||||
|
||||
if render_msg is None:
|
||||
if is_aborted():
|
||||
for output in self.outputs:
|
||||
output.terminate()
|
||||
break
|
||||
|
||||
# Wait for shmem to be filled with data.
|
||||
render_msg.completion.result()
|
||||
frame_data = render_msg.shmem.buf[:framebuffer_nbyte]
|
||||
|
||||
if not_benchmarking or benchmark_mode == BenchmarkMode.OUTPUT:
|
||||
# Output frame
|
||||
for output in self.outputs:
|
||||
if output.write_frame(frame_data) is outputs_.Stop:
|
||||
abort_from_thread.set()
|
||||
break
|
||||
thread_shared.end_frame = render_msg.frame_idx + 1
|
||||
|
||||
avail_shmems.put(render_msg.shmem)
|
||||
|
||||
if is_aborted():
|
||||
output_on_error()
|
||||
|
||||
print("exit output")
|
||||
|
||||
def output_on_error():
|
||||
"""If is_aborted() is True but render_thread() is blocked on
|
||||
render_to_output.put(), then we need to clear the queue so
|
||||
render_thread() can return from put(), then check is_aborted() = True
|
||||
and terminate."""
|
||||
|
||||
# It is an error to call output_on_error() when not aborted. If so,
|
||||
# force an abort so we can print the error without deadlock.
|
||||
was_aborted = is_aborted()
|
||||
if not was_aborted:
|
||||
abort_from_thread.set()
|
||||
|
||||
while True:
|
||||
try:
|
||||
render_msg = render_to_output.get(block=False)
|
||||
if render_msg is None:
|
||||
continue # probably empty?
|
||||
|
||||
# To avoid deadlock, we must return the shmem to
|
||||
# _render_thread() in case it's blocked waiting for it. We do
|
||||
# not need to wait for the shmem to be no longer written to (
|
||||
# `render_msg.completion.result()`), since if we set
|
||||
# is_aborted() to true before returning a shmem,
|
||||
# `_render_thread()` will ignore the acquired shmem without
|
||||
# writing to it.
|
||||
|
||||
avail_shmems.put(render_msg.shmem)
|
||||
except Empty:
|
||||
break
|
||||
|
||||
assert was_aborted
|
||||
|
||||
def output_thread():
|
||||
"""
|
||||
How do we know that if output_thread() crashes, render_thread() will
|
||||
not block?
|
||||
|
||||
- `_output_thread()` does not return early. If it is aborted, it will
|
||||
call `output_on_error()` to unblock `_render_thread()`.
|
||||
|
||||
- If `_output_thread()` crashes, `output_thread()` will call
|
||||
`abort_from_thread.set()` before calling `output_on_error()` to
|
||||
unblock `_render_thread()`.
|
||||
|
||||
I miss being able to poll()/WaitForMultipleObjects().
|
||||
"""
|
||||
try:
|
||||
_output_thread()
|
||||
except BaseException as e:
|
||||
abort_from_thread.set()
|
||||
output_on_error()
|
||||
raise e
|
||||
|
||||
shmem_names: List[str] = [shmem.name for shmem in all_shmems]
|
||||
|
||||
with ProcessPoolExecutor(
|
||||
nthread,
|
||||
initializer=worker_create_renderer,
|
||||
initargs=(renderer_params, shmem_names),
|
||||
) as pool:
|
||||
render_handle = PropagatingThread(
|
||||
target=render_thread, name="render_thread"
|
||||
)
|
||||
output_handle = PropagatingThread(
|
||||
target=output_thread, name="output_thread"
|
||||
)
|
||||
|
||||
render_handle.start()
|
||||
output_handle.start()
|
||||
|
||||
# throws
|
||||
render_handle.join()
|
||||
output_handle.join()
|
||||
|
||||
# TODO is it a problem that ProcessPoolExecutor's
|
||||
# worker_create_renderer() creates SharedMemory handles, which are
|
||||
# never closed when the process terminates?
|
||||
#
|
||||
# Constructing a new SharedMemory on every worker_render_frame() call
|
||||
# is more "correct", but increases CPU usage by around 20% or more (
|
||||
# see "shmem question"), likely due to page table thrashing.
|
||||
|
||||
for shmem in all_shmems:
|
||||
shmem.unlink()
|
||||
|
||||
parallelism = self.arg.parallelism
|
||||
with self._load_outputs():
|
||||
if parallelism and parallelism.parallel:
|
||||
play_parallel(parallelism.max_render_cores)
|
||||
else:
|
||||
play_impl()
|
||||
play_impl()
|
||||
|
||||
if PRINT_TIMESTAMP:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
|
|
|
@ -449,36 +449,10 @@ class _RendererBase(ABC):
|
|||
|
||||
# Instance functionality
|
||||
|
||||
@abstractmethod
|
||||
def update_main_lines(
|
||||
self, inputs: List[RenderInput], trigger_samples: List[int]
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_frame(self) -> ByteBuffer:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def add_labels(self, labels: List[str]) -> Any:
|
||||
...
|
||||
|
||||
# Primarily used by RendererFrontend, not outside world.
|
||||
@abstractmethod
|
||||
def _update_lines_stereo(self, inputs: List[RenderInput]) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _add_xy_line_mono(
|
||||
self,
|
||||
name: str,
|
||||
wave_idx: int,
|
||||
xs: Sequence[float],
|
||||
ys: Sequence[float],
|
||||
stride: int,
|
||||
) -> CustomLine:
|
||||
...
|
||||
|
||||
|
||||
# See Wave.get_around() and designNotes.md.
|
||||
# Viewport functions
|
||||
|
@ -544,516 +518,18 @@ class AbstractMatplotlibRenderer(_RendererBase, ABC):
|
|||
|
||||
def __init__(self, params: RendererParams):
|
||||
super().__init__(params)
|
||||
|
||||
dict.__setitem__(
|
||||
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
|
||||
)
|
||||
|
||||
self._setup_axes(self.wave_nchans)
|
||||
|
||||
if params.labels is not None:
|
||||
self.add_labels(params.labels)
|
||||
|
||||
self._artists = []
|
||||
|
||||
_fig: "Figure"
|
||||
|
||||
_artists: List["Artist"]
|
||||
|
||||
# [wave][chan] Axes
|
||||
# Primary, used to draw oscilloscope lines and gridlines.
|
||||
_wave_chan_to_axes: List[List["Axes"]] # set by set_layout()
|
||||
|
||||
# _axes_mono[wave] = Axes
|
||||
# Secondary, used for titles and debug plots.
|
||||
_wave_to_mono_axes: List["Axes"]
|
||||
|
||||
# Fields updated by _update_lines_stereo():
|
||||
# [wave][chan] Line2D
|
||||
_wave_chan_to_line: "Optional[List[List[Line2D]]]" = None
|
||||
|
||||
# Only for stereo channels, if stereo bars are enabled.
|
||||
_wave_to_stereo_bar: "List[Optional[StereoBar]]"
|
||||
|
||||
def _setup_axes(self, wave_nchans: List[int]) -> None:
|
||||
"""
|
||||
Creates a flat array of Matplotlib Axes, with the new layout.
|
||||
Sets up each Axes with correct region limits.
|
||||
"""
|
||||
|
||||
# Only read by unit tests.
|
||||
self.layout = RendererLayout(self.lcfg, wave_nchans)
|
||||
layout_mono = RendererLayout(self.lcfg, [1] * self.nplots)
|
||||
|
||||
if hasattr(self, "_fig"):
|
||||
raise Exception("I don't currently expect to call _setup_axes() twice")
|
||||
# plt.close(self.fig)
|
||||
|
||||
cfg = self.cfg
|
||||
|
||||
self._fig = Figure()
|
||||
self._canvas_type(self._fig)
|
||||
|
||||
px_inch = PX_INCH / cfg.res_divisor
|
||||
self._fig.set_dpi(px_inch)
|
||||
|
||||
"""
|
||||
Requirements:
|
||||
- px_inch /= res_divisor (to scale visual elements correctly)
|
||||
- int(set_size_inches * px_inch) == self.w,h
|
||||
- matplotlib uses int instead of round. Who knows why.
|
||||
- round(set_size_inches * px_inch) == self.w,h
|
||||
- just in case matplotlib changes its mind
|
||||
|
||||
Solution:
|
||||
- (set_size_inches * px_inch) == self.w,h + 0.25
|
||||
- set_size_inches == (self.w,h + 0.25) / px_inch
|
||||
"""
|
||||
offset = 0.25
|
||||
self._fig.set_size_inches(
|
||||
(self.w + offset) / px_inch, (self.h + offset) / px_inch
|
||||
)
|
||||
|
||||
real_dims = self._fig.canvas.get_width_height()
|
||||
assert (self.w, self.h) == real_dims, [(self.w, self.h), real_dims]
|
||||
del real_dims
|
||||
|
||||
# Setup background
|
||||
self._fig.set_facecolor(cfg.bg_color)
|
||||
|
||||
if cfg.bg_image:
|
||||
img = mpl.image.imread(cfg.bg_image)
|
||||
|
||||
ax = self._fig.add_axes([0, 0, 1, 1])
|
||||
|
||||
# Hide black borders around screen edge.
|
||||
ax.set_axis_off()
|
||||
|
||||
# Size image to fill screen pixel-perfectly. Somehow, matplotlib requires
|
||||
# showing the image 1 screen-pixel smaller than the full area.
|
||||
|
||||
# Get image dimensions (in ipx).
|
||||
w = img.shape[1]
|
||||
h = img.shape[0]
|
||||
|
||||
# Setup axes to fit image to screen (while maintaining square pixels).
|
||||
# Axes automatically expand their limits to maintain square coordinates,
|
||||
# while imshow() stretches images to the full area supplied.
|
||||
ax.set_xlim(0, w)
|
||||
ax.set_ylim(0, h)
|
||||
|
||||
# Calculate (image pixels per screen pixel). Since we fit the image
|
||||
# on-screen, pick the minimum of the horizontal/vertical zoom factors.
|
||||
zoom = min(self.w / w, self.h / h)
|
||||
ipx_per_spx = 1 / zoom
|
||||
|
||||
# imshow() takes coordinates in axes units (here, ipx) and renders to
|
||||
# screen pixels. To workaround matplotlib stretching images off-screen,
|
||||
# we need an extent 1 spx smaller than full scale. So subtract 1 spx
|
||||
# (converted to ipx) from dimensions.
|
||||
ax.imshow(img, extent=(0, w - ipx_per_spx, 0, h - ipx_per_spx))
|
||||
|
||||
# Create Axes (using self.lcfg, wave_nchans)
|
||||
# [wave][chan] Axes
|
||||
self._wave_chan_to_axes = self.layout.arrange(self._axes_factory)
|
||||
|
||||
# _axes_mono[wave] = Axes
|
||||
self._wave_to_mono_axes = []
|
||||
|
||||
"""
|
||||
When calling _axes_factory() with the same position twice, we should pass a
|
||||
different label to get a different Axes, to avoid warning:
|
||||
|
||||
>>> Adding an axes using the same arguments as a previous axes
|
||||
currently reuses the earlier instance.
|
||||
In a future version, a new instance will always be created and returned.
|
||||
Meanwhile, this warning can be suppressed, and the future behavior ensured,
|
||||
by passing a unique label to each axes instance.
|
||||
|
||||
<<< ax=fig.add_axes(label=) is unused, even if you call ax.legend().
|
||||
"""
|
||||
# Returns 2D list of [self.nplots][1]Axes.
|
||||
axes_mono_2d = layout_mono.arrange(self._axes_factory, label="mono")
|
||||
for axes_list in axes_mono_2d:
|
||||
(axes,) = axes_list # type: Axes
|
||||
|
||||
# Pick colormap used for debug lines (_add_xy_line_mono()).
|
||||
# List of colors at
|
||||
# https://matplotlib.org/gallery/color/colormap_reference.html
|
||||
# Discussion at https://github.com/matplotlib/matplotlib/issues/10840
|
||||
cmap: ListedColormap = matplotlib.colormaps["Accent"]
|
||||
colors = cmap.colors
|
||||
axes.set_prop_cycle(color=colors)
|
||||
|
||||
self._wave_to_mono_axes.append(axes)
|
||||
|
||||
# Setup axes
|
||||
for wave_idx, N in enumerate(self.wave_nsamps):
|
||||
chan_to_axes = self._wave_chan_to_axes[wave_idx]
|
||||
|
||||
# Calculate the bounds of an Axes object to match the scale of calc_xs()
|
||||
# (unless cfg.viewport_width != 1).
|
||||
viewport_stride = self.render_strides[wave_idx] * cfg.viewport_width
|
||||
xlims = calc_limits(N, viewport_stride)
|
||||
ylim = cfg.viewport_height
|
||||
|
||||
def scale_axes(ax: "Axes"):
|
||||
ax.set_xlim(*xlims)
|
||||
ax.set_ylim(-ylim, ylim)
|
||||
|
||||
scale_axes(self._wave_to_mono_axes[wave_idx])
|
||||
|
||||
# When using overlay stereo, all channels map to the same Axes object.
|
||||
for ax in unique_by_id(chan_to_axes):
|
||||
scale_axes(ax)
|
||||
|
||||
# Setup midlines (depends on max_x and wave_data)
|
||||
midline_color = cfg.midline_color
|
||||
midline_width = cfg.grid_line_width
|
||||
|
||||
# Not quite sure if midlines or gridlines draw on top
|
||||
kw = dict(color=midline_color, linewidth=midline_width)
|
||||
if cfg.v_midline:
|
||||
ax.axvline(x=calc_center(viewport_stride), **kw)
|
||||
if cfg.h_midline:
|
||||
ax.axhline(y=0, **kw)
|
||||
|
||||
self._save_background()
|
||||
|
||||
transparent = "#00000000"
|
||||
|
||||
# satisfies RegionFactory
|
||||
def _axes_factory(self, r: RegionSpec, label: str = "") -> "Axes":
|
||||
cfg = self.cfg
|
||||
|
||||
# Calculate plot positions (relative to bottom-left) as fractions of the screen.
|
||||
width = 1 / r.ncol
|
||||
left = r.col / r.ncol
|
||||
assert 0 <= left < 1
|
||||
|
||||
height = 1 / r.nrow
|
||||
# We index rows from top down, but matplotlib positions plots from bottom up.
|
||||
# The final row (row = nrow-1) is located at the bottom of the graph, at y=0.
|
||||
bottom = (r.nrow - (r.row + 1)) / r.nrow
|
||||
assert 0 <= bottom < 1
|
||||
|
||||
# Disabling xticks/yticks is unnecessary, since we hide Axises.
|
||||
ax = self._fig.add_axes(
|
||||
[left, bottom, width, height], xticks=[], yticks=[], label=label
|
||||
)
|
||||
|
||||
grid_color = cfg.grid_color
|
||||
if grid_color:
|
||||
# Initialize borders
|
||||
# Hide Axises
|
||||
# (drawing them is very slow, and we disable ticks+labels anyway)
|
||||
ax.get_xaxis().set_visible(False)
|
||||
ax.get_yaxis().set_visible(False)
|
||||
|
||||
# Background color
|
||||
# ax.patch.set_fill(False) sets _fill=False,
|
||||
# then calls _set_facecolor(...) "alpha = self._alpha if self._fill else 0".
|
||||
# It is no faster than below.
|
||||
ax.set_facecolor(self.transparent)
|
||||
|
||||
# Set border colors
|
||||
for spine in ax.spines.values(): # type: Spine
|
||||
spine.set_linewidth(cfg.grid_line_width)
|
||||
spine.set_color(grid_color)
|
||||
|
||||
def hide(key: str):
|
||||
ax.spines[key].set_visible(False)
|
||||
|
||||
# Hide all borders except bottom-right.
|
||||
hide("top")
|
||||
hide("left")
|
||||
|
||||
# If bottom of screen, hide bottom. If right of screen, hide right.
|
||||
if r.screen_edges & Edges.Bottom:
|
||||
hide("bottom")
|
||||
if r.screen_edges & Edges.Right:
|
||||
hide("right")
|
||||
|
||||
# If our Axes is a stereo track, dim borders between channels. (Show
|
||||
# borders between waves at full opacity.)
|
||||
if cfg.stereo_grid_opacity > 0:
|
||||
dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
|
||||
dim_color[-1] = cfg.stereo_grid_opacity
|
||||
|
||||
def dim(key: str):
|
||||
ax.spines[key].set_color(dim_color)
|
||||
|
||||
else:
|
||||
dim = hide
|
||||
|
||||
# If not bottom of wave, dim bottom. If not right of wave, dim right.
|
||||
if not r.wave_edges & Edges.Bottom:
|
||||
dim("bottom")
|
||||
if not r.wave_edges & Edges.Right:
|
||||
dim("right")
|
||||
|
||||
else:
|
||||
ax.set_axis_off()
|
||||
|
||||
return ax
|
||||
|
||||
# Protected API
|
||||
def __add_lines_stereo(self, inputs: List[RenderInput]):
|
||||
cfg = self.cfg
|
||||
strides = self.render_strides
|
||||
|
||||
# Plot lines over background
|
||||
line_width = cfg.line_width
|
||||
|
||||
# Foreach wave, plot dummy data.
|
||||
lines2d = []
|
||||
wave_to_stereo_bar = []
|
||||
for wave_idx, input in enumerate(inputs):
|
||||
wave_data = input.data
|
||||
line_params = self._line_params[wave_idx]
|
||||
|
||||
# [nsamp][nchan] Amplitude
|
||||
wave_zeros = np.zeros_like(wave_data)
|
||||
|
||||
chan_to_axes = self._wave_chan_to_axes[wave_idx]
|
||||
wave_lines = []
|
||||
|
||||
xs = calc_xs(len(wave_zeros), strides[wave_idx])
|
||||
line_color = line_params.color
|
||||
|
||||
# Foreach chan
|
||||
for chan_idx, chan_zeros in enumerate(wave_zeros.T):
|
||||
ax = chan_to_axes[chan_idx]
|
||||
|
||||
chan_line: Line2D = ax.plot(
|
||||
xs, chan_zeros, color=line_color, linewidth=line_width
|
||||
)[0]
|
||||
|
||||
if cfg.line_outline_width > 0:
|
||||
chan_line.set_path_effects(
|
||||
[
|
||||
mpl.patheffects.Stroke(
|
||||
linewidth=cfg.line_width + 2 * cfg.line_outline_width,
|
||||
foreground=cfg.global_line_outline_color,
|
||||
),
|
||||
mpl.patheffects.Normal(),
|
||||
]
|
||||
)
|
||||
|
||||
wave_lines.append(chan_line)
|
||||
|
||||
lines2d.append(wave_lines)
|
||||
self._artists.extend(wave_lines)
|
||||
|
||||
# Add stereo bars if enabled and track is stereo.
|
||||
if input.stereo_levels:
|
||||
assert self._line_params[wave_idx].stereo_bars
|
||||
ax = self._wave_to_mono_axes[wave_idx]
|
||||
|
||||
viewport_stride = self.render_strides[wave_idx] * cfg.viewport_width
|
||||
x_center = calc_center(viewport_stride)
|
||||
|
||||
xlim = ax.get_xlim()
|
||||
x_range = (xlim[1] - xlim[0]) / 2
|
||||
|
||||
y_bottom = ax.get_ylim()[0]
|
||||
|
||||
h = abs(y_bottom) / 16
|
||||
stereo_rect = Rectangle((x_center, y_bottom - h), 0, 2 * h)
|
||||
stereo_rect.set_color(cfg.stereo_bar_color)
|
||||
stereo_rect.set_linewidth(0)
|
||||
ax.add_patch(stereo_rect)
|
||||
|
||||
stereo_bar = StereoBar(stereo_rect, x_center, x_range)
|
||||
|
||||
wave_to_stereo_bar.append(stereo_bar)
|
||||
self._artists.append(stereo_rect)
|
||||
else:
|
||||
wave_to_stereo_bar.append(None)
|
||||
|
||||
self._wave_chan_to_line = lines2d
|
||||
self._wave_to_stereo_bar = wave_to_stereo_bar
|
||||
|
||||
def _update_lines_stereo(self, inputs: List[RenderInput]) -> None:
|
||||
"""
|
||||
Preconditions:
|
||||
- inputs[wave] = ndarray, [samp][chan] = f32
|
||||
"""
|
||||
if self._wave_chan_to_line is None:
|
||||
self.__add_lines_stereo(inputs)
|
||||
|
||||
lines2d = self._wave_chan_to_line
|
||||
nplots = len(lines2d)
|
||||
ndata = len(inputs)
|
||||
if nplots != ndata:
|
||||
raise ValueError(
|
||||
f"incorrect data to plot: {nplots} plots but {ndata} dummy_datas"
|
||||
)
|
||||
|
||||
# Draw waveform data
|
||||
# Foreach wave
|
||||
for wave_idx, input in enumerate(inputs):
|
||||
wave_data = input.data
|
||||
freq_estimate = input.freq_estimate
|
||||
|
||||
wave_lines = lines2d[wave_idx]
|
||||
|
||||
color_by_pitch = self._line_params[wave_idx].color_by_pitch
|
||||
|
||||
# If we color notes by pitch, then on every frame,
|
||||
# recompute the color based on current pitch.
|
||||
# If no sound is detected, fall back to the default color.
|
||||
# If we don't color notes by pitch,
|
||||
# just keep the initial color and never overwrite it.
|
||||
if color_by_pitch:
|
||||
fallback_color = self._line_params[wave_idx].color
|
||||
color = freq_to_color(self.pitch_cmap, freq_estimate, fallback_color)
|
||||
|
||||
# Foreach chan
|
||||
for chan_idx, chan_data in enumerate(wave_data.T):
|
||||
chan_line = wave_lines[chan_idx]
|
||||
chan_line.set_ydata(chan_data)
|
||||
if color_by_pitch:
|
||||
chan_line.set_color(color)
|
||||
|
||||
stereo_bar = self._wave_to_stereo_bar[wave_idx]
|
||||
stereo_levels = inputs[wave_idx].stereo_levels
|
||||
assert bool(stereo_bar) == bool(
|
||||
stereo_levels
|
||||
), f"wave {wave_idx}: plot={stereo_bar} != values={stereo_levels}"
|
||||
if stereo_bar:
|
||||
stereo_bar.set_range(*stereo_levels)
|
||||
|
||||
def _add_xy_line_mono(
|
||||
self,
|
||||
name: str,
|
||||
wave_idx: int,
|
||||
xs: Sequence[float],
|
||||
ys: Sequence[float],
|
||||
stride: int,
|
||||
) -> CustomLine:
|
||||
"""Add a debug line, which can be repositioned every frame."""
|
||||
cfg = self.cfg
|
||||
|
||||
# Plot lines over background
|
||||
line_width = cfg.line_width
|
||||
|
||||
ax = self._wave_to_mono_axes[wave_idx]
|
||||
mono_line: Line2D = ax.plot(xs, ys, linewidth=line_width)[0]
|
||||
print(f"{name} {wave_idx} has color {mono_line.get_color()}")
|
||||
|
||||
self._artists.append(mono_line)
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
return CustomLine(stride, xs, mono_line.set_xdata, mono_line.set_ydata)
|
||||
|
||||
# Channel labels
|
||||
def add_labels(self, labels: List[str]) -> List["Text"]:
|
||||
"""
|
||||
Updates background, adds text.
|
||||
Do NOT call after calling self.add_lines().
|
||||
"""
|
||||
nlabel = len(labels)
|
||||
if nlabel != self.nplots:
|
||||
raise ValueError(
|
||||
f"incorrect labels: {self.nplots} plots but {nlabel} labels"
|
||||
)
|
||||
|
||||
cfg = self.cfg
|
||||
color = cfg.get_label_color
|
||||
|
||||
size_pt = cfg.label_font.size
|
||||
distance_px = cfg.label_padding_ratio * size_pt
|
||||
|
||||
@attr.dataclass
|
||||
class AxisPosition:
|
||||
pos_axes: float
|
||||
offset_px: float
|
||||
align: str
|
||||
|
||||
xpos = cfg.label_position.x.match(
|
||||
left=AxisPosition(0, distance_px, "left"),
|
||||
right=AxisPosition(1, -distance_px, "right"),
|
||||
)
|
||||
ypos = cfg.label_position.y.match(
|
||||
bottom=AxisPosition(0, distance_px, "bottom"),
|
||||
top=AxisPosition(1, -distance_px, "top"),
|
||||
)
|
||||
|
||||
pos_axes = (xpos.pos_axes, ypos.pos_axes)
|
||||
offset_pt = (xpos.offset_px, ypos.offset_px)
|
||||
|
||||
out: List["Text"] = []
|
||||
for label_text, ax in zip(labels, self._wave_to_mono_axes):
|
||||
# https://matplotlib.org/api/_as_gen/matplotlib.axes.Axes.annotate.html
|
||||
# Annotation subclasses Text.
|
||||
text: "Annotation" = ax.annotate(
|
||||
label_text,
|
||||
# Positioning
|
||||
xy=pos_axes,
|
||||
xycoords="axes fraction",
|
||||
xytext=offset_pt,
|
||||
textcoords="offset points",
|
||||
horizontalalignment=xpos.align,
|
||||
verticalalignment=ypos.align,
|
||||
# Cosmetics
|
||||
color=color,
|
||||
fontsize=px_from_points(size_pt),
|
||||
fontfamily=cfg.label_font.family,
|
||||
fontweight=("bold" if cfg.label_font.bold else "normal"),
|
||||
fontstyle=("italic" if cfg.label_font.italic else "normal"),
|
||||
)
|
||||
out.append(text)
|
||||
|
||||
self._save_background()
|
||||
return out
|
||||
self.color = 0
|
||||
|
||||
# Output frames
|
||||
def get_frame(self) -> ByteBuffer:
|
||||
"""Returns bytes with shape (h, w, self.bytes_per_pixel).
|
||||
The actual return value's shape may be flat.
|
||||
"""
|
||||
self._redraw_over_background()
|
||||
|
||||
canvas = self._fig.canvas
|
||||
|
||||
# Agg is the default noninteractive backend except on OSX.
|
||||
# https://matplotlib.org/faq/usage_faq.html
|
||||
if not isinstance(canvas, self._canvas_type):
|
||||
raise RuntimeError(
|
||||
f"oh shit, cannot read data from {obj_name(canvas)} != {self._canvas_type.__name__}"
|
||||
)
|
||||
|
||||
buffer_rgb = self._canvas_to_bytes(canvas)
|
||||
assert len(buffer_rgb) == self.w * self.h * self.bytes_per_pixel
|
||||
buffer_rgb = bytes([self.color]) * (self.w * self.h * self.bytes_per_pixel)
|
||||
self.color = 255 - self.color
|
||||
|
||||
return buffer_rgb
|
||||
|
||||
# Pre-rendered background
|
||||
bg_cache: Any # "matplotlib.backends._backend_agg.BufferRegion"
|
||||
|
||||
def _save_background(self) -> None:
|
||||
"""Draw static background."""
|
||||
# https://stackoverflow.com/a/8956211
|
||||
# https://matplotlib.org/api/animation_api.html#funcanimation
|
||||
fig = self._fig
|
||||
|
||||
fig.canvas.draw()
|
||||
self.bg_cache = fig.canvas.copy_from_bbox(fig.bbox)
|
||||
|
||||
def _redraw_over_background(self) -> None:
|
||||
"""Redraw animated elements of the image."""
|
||||
|
||||
# Both FigureCanvasAgg and FigureCanvasCairo, but not FigureCanvasBase,
|
||||
# support restore_region().
|
||||
canvas: FigureCanvasAgg = self._fig.canvas
|
||||
canvas.restore_region(self.bg_cache)
|
||||
|
||||
for artist in self._artists:
|
||||
artist.axes.draw_artist(artist)
|
||||
|
||||
# canvas.blit(self._fig.bbox) is unnecessary when drawing off-screen.
|
||||
|
||||
|
||||
class MatplotlibAggRenderer(AbstractMatplotlibRenderer):
|
||||
# implements AbstractMatplotlibRenderer
|
||||
|
@ -1105,17 +581,6 @@ class RendererFrontend(_RendererBase, ABC):
|
|||
line.set_xdata(0 * line.xdata)
|
||||
return out
|
||||
|
||||
# New methods.
|
||||
def update_main_lines(
|
||||
self, inputs: List[RenderInput], trigger_samples: List[int]
|
||||
) -> None:
|
||||
datas = [input.data for input in inputs]
|
||||
|
||||
self._update_lines_stereo(inputs)
|
||||
assert len(datas) == len(trigger_samples)
|
||||
for i, (data, trigger) in enumerate(zip(datas, trigger_samples)):
|
||||
self.move_viewport(i, trigger) # - len(data) / 2
|
||||
|
||||
_absolute: DefaultDict[int, MutableSequence[CustomLine]]
|
||||
|
||||
def update_custom_line(
|
||||
|
|
Ładowanie…
Reference in New Issue