Mirror of https://github.com/corrscope/corrscope

Add pitch invariant trigger, set trigger_diameter=None (improves bass)

parent da480dffe6
commit 3260104df2
@@ -20,6 +20,7 @@ from corrscope.triggers import (
     CorrelationTriggerConfig,
     PerFrameCache,
     CorrelationTrigger,
+    SpectrumConfig,
 )
 from corrscope.util import pushd, coalesce
 from corrscope.wave import Wave, Flatten
@@ -118,6 +119,7 @@ def default_config(**kwargs) -> Config:
             responsiveness=0.5,
             buffer_falloff=0.5,
             use_edge_trigger=False,
+            pitch_invariance=SpectrumConfig()
             # Removed due to speed hit.
             # post=LocalPostTriggerConfig(strength=0.1),
         ),
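Usage sketch (not part of the diff): how the new option is switched on or off when building a config. The module path for default_config and the keyword values here are assumptions for illustration; pitch_invariance=None keeps the old, purely time-domain behavior.

from corrscope.corrscope import default_config  # assumed location of default_config
from corrscope.triggers import CorrelationTriggerConfig, SpectrumConfig

cfg = default_config(
    trigger=CorrelationTriggerConfig(
        edge_strength=2,         # illustrative values
        responsiveness=0.5,
        buffer_falloff=0.5,
        pitch_invariance=SpectrumConfig(),  # or None to disable pitch tracking
    )
)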
@@ -1,6 +1,18 @@
 import warnings
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Type, Tuple, Optional, ClassVar, Callable, Union
+from typing import (
+    TYPE_CHECKING,
+    Type,
+    Tuple,
+    Optional,
+    ClassVar,
+    Callable,
+    Union,
+    NewType,
+    Sequence,
+    List,
+    Any,
+)
 
 import attr
 import numpy as np
@@ -105,10 +117,184 @@ class PerFrameCache:
 # CorrelationTrigger
 
 
-class CorrelationTriggerConfig(ITriggerConfig):
+class SpectrumConfig(KeywordAttrs):
+    """
+    # Rationale:
+    If no basal frequency note-bands are to be truncated,
+    the spectrum must have freq resolution
+        `min_hz * (2 ** 1/notes_per_octave - 1)`.
+
+    At 20hz, 10 octaves, 12 notes/octave, this is 1.19Hz fft freqs.
+    Our highest band must be
+        `min_hz * 2**octaves`,
+    leading to nearly 20K freqs, which produces a somewhat slow FFT.
+
+    So increase min_hz and decrease octaves and notes_per_octave.
+    --------
+    Using a Constant-Q transform may eliminate performance concerns?
+    """
+
+    # Spectrum X density
+    min_hz: float = 20
+    octaves: int = 8
+    notes_per_octave: int = 6
+
+    # Spectrum Y power
+    exponent: float = 1
+    divide_by_freq: bool = True
+
+    # Spectral alignment and resampling
+    pitch_estimate_boost: float = 1.2
+    add_current_to_history: float = 0.1  # FIXME why does this exist?
+    max_octaves_to_resample: float = 1.0
+
+    @property
+    def max_notes_to_resample(self) -> int:
+        return round(self.notes_per_octave * self.max_octaves_to_resample)
+
+    # Time-domain history parameters
+    min_frames_between_recompute: int = 6
+    frames_to_lookbehind: int = 2
+
+
+class DummySpectrum:
+    # noinspection PyMethodMayBeStatic,PyUnusedLocal
+    def calc_spectrum(self, data: np.ndarray) -> np.ndarray:
+        return np.array([])
+
+
+# Indices are linearly spaced in FFT. Notes are exponentially spaced.
+# FFT is grouped into notes.
+FFTIndex = NewType("FFTIndex", int)
+# Very hacky and weird. Maybe it's not worth getting mypy to pass.
+if TYPE_CHECKING:
+    FFTIndexArray = Any  # mypy
+else:
+    FFTIndexArray = "np.ndarray[FFTIndex]"  # pycharm
+
+
+class LogFreqSpectrum(DummySpectrum):
+    """
+    Invariants:
+    - len(note_fenceposts) == n_fencepost
+
+    - rfft()[ : note_fenceposts[0]] is NOT used.
+    - rfft()[note_fenceposts[-1] : ] is NOT used.
+    - rfft()[note_fenceposts[0] : note_fenceposts[1]] becomes a note.
+    """
+
+    n_fftindex: FFTIndex  # Determines frequency resolution, not range.
+    note_fenceposts: FFTIndexArray
+    n_fencepost: int
+
+    def __init__(self, scfg: SpectrumConfig, subsmp_s: float, dummy_data: np.ndarray):
+        self.scfg = scfg
+
+        n_fftindex: FFTIndex = signal.next_fast_len(len(dummy_data))
+
+        # Increase n_fftindex until every note has nonzero width.
+        while True:
+            # Compute parameters
+            self.min_hz = scfg.min_hz
+            self.max_hz = self.min_hz * 2 ** scfg.octaves
+            n_fencepost = scfg.notes_per_octave * scfg.octaves + 1
+
+            note_fenceposts_hz = np.geomspace(
+                self.min_hz, self.max_hz, n_fencepost, dtype=FLOAT
+            )
+
+            # Convert fenceposts to FFTIndex
+            fft_from_hertz = n_fftindex / subsmp_s
+            note_fenceposts: FFTIndexArray = (
+                fft_from_hertz * note_fenceposts_hz
+            ).astype(np.int32)
+            note_widths = np.diff(note_fenceposts)
+
+            if np.any(note_widths == 0):
+                n_fftindex = signal.next_fast_len(n_fftindex + n_fftindex // 5 + 1)
+                continue
+            else:
+                break
+
+        self.n_fftindex = n_fftindex  # Passed to rfft() to automatically zero-pad data.
+        self.note_fenceposts = note_fenceposts
+        self.n_fencepost = len(note_fenceposts)
+
+    def calc_spectrum(self, data: np.ndarray) -> np.ndarray:
+        """ Unfortunately converting to FLOAT (single) adds too much overhead.
+
+        Input: Time-domain signal to be analyzed.
+        Output: Frequency-domain spectrum with exponentially-spaced notes.
+        - ret[note] = nonnegative float.
+        """
+        scfg = self.scfg
+
+        # Compute FFT spectrum[freq]
+        spectrum = np.fft.rfft(data, self.n_fftindex)
+        spectrum = abs(spectrum)
+        if scfg.exponent != 1:
+            spectrum **= scfg.exponent
+
+        # Compute energy of each note
+        # spectrum_per_note[note] = np.ndarray[float]
+        spectrum_per_note: List[np.ndarray] = split(spectrum, self.note_fenceposts)
+
+        # energy_per_note[note] = float
+        energy_per_note: np.ndarray
+
+        # np.add.reduce is much faster than np.sum/mean.
+        if scfg.divide_by_freq:
+            energy_per_note = np.array(
+                [np.add.reduce(region) / len(region) for region in spectrum_per_note]
+            )
+        else:
+            energy_per_note = np.array(
+                [np.add.reduce(region) for region in spectrum_per_note]
+            )
+
+        assert len(energy_per_note) == self.n_fencepost - 1
+        return energy_per_note
+
+
+def split(data: np.ndarray, fenceposts: Sequence[FFTIndex]) -> List[np.ndarray]:
+    """ Based off np.split(), but faster.
+    Unlike np.split, does not include data before fenceposts[0] or after fenceposts[-1].
+    """
+    sub_arys = []
+    ndata = len(data)
+    for i in range(len(fenceposts) - 1):
+        st = fenceposts[i]
+        end = fenceposts[i + 1]
+        if not st < ndata:
+            break
+        region = data[st:end]
+        sub_arys.append(region)
+
+    return sub_arys
+
+
+class CircularArray:
+    def __init__(self, size: int, *dims: int):
+        self.size = size
+        self.buf = np.zeros((size, *dims))
+        self.index = 0
+
+    def push(self, arr: np.ndarray) -> None:
+        if self.size == 0:
+            return
+        self.buf[self.index] = arr
+        self.index = (self.index + 1) % self.size
+
+    def peek(self) -> np.ndarray:
+        """Return is borrowed from self.buf.
+        Do NOT push to self while borrow is alive."""
+        return self.buf[self.index]
+
+
+class CorrelationTriggerConfig(ITriggerConfig, always_dump="pitch_invariance"):
     # get_trigger
     edge_strength: float
-    trigger_diameter: float = 0.5
+    trigger_diameter: Optional[float] = None
 
     trigger_falloff: Tuple[float, float] = (4.0, 1.0)
     recalc_semitones: float = 1.0
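Sanity check (not part of the diff) of the numbers quoted in the SpectrumConfig docstring above, plus the band layout implied by the shipped defaults:

import math

min_hz, octaves, notes_per_octave = 20, 10, 12   # values from the docstring's example
resolution = min_hz * (2 ** (1 / notes_per_octave) - 1)
print(resolution)               # ~1.19 Hz, matching "1.19Hz fft freqs"

max_hz = min_hz * 2 ** octaves  # 20480 Hz, the highest note band
print(max_hz / resolution)      # ~17200 FFT bins, i.e. "nearly 20K freqs"

# The shipped defaults (min_hz=20, octaves=8, notes_per_octave=6) are far cheaper:
print(6 * 8 + 1)                # 49 fenceposts -> 48 note bands spanning 20 Hz to 5120 Hz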
@@ -118,6 +304,9 @@ class CorrelationTriggerConfig(ITriggerConfig):
     responsiveness: float
     buffer_falloff: float  # Gaussian std = wave_period * buffer_falloff
 
+    # Pitch invariance = compute spectrum.
+    pitch_invariance: Optional["SpectrumConfig"] = None
+
     # region Legacy Aliases
     trigger_strength = Alias("edge_strength")
     falloff_width = Alias("buffer_falloff")
@@ -152,6 +341,10 @@ class CorrelationTriggerConfig(ITriggerConfig):
 class CorrelationTrigger(Trigger):
     cfg: CorrelationTriggerConfig
 
+    @property
+    def scfg(self) -> SpectrumConfig:
+        return self.cfg.pitch_invariance
+
     def __init__(self, *args, **kwargs):
         """
         Correlation-based trigger which looks at a window of `trigger_tsamp` samples.
@@ -181,6 +374,24 @@ class CorrelationTrigger(Trigger):
         self._prev_period: Optional[int] = None
         self._prev_window: Optional[np.ndarray] = None
 
+        # (mutable) Log-scaled spectrum
+        self.frames_since_spectrum = 0
+
+        if self.scfg:
+            self._spectrum_calc = LogFreqSpectrum(
+                scfg=self.scfg,
+                subsmp_s=self._wave.smp_s / self._stride,
+                dummy_data=self._buffer,
+            )
+            self._spectrum = self._spectrum_calc.calc_spectrum(self._buffer)
+            self.history = CircularArray(
+                self.scfg.frames_to_lookbehind, self._buffer_nsamp
+            )
+        else:
+            self._spectrum_calc = DummySpectrum()
+            self._spectrum = np.array([0])
+            self.history = CircularArray(0, self._buffer_nsamp)
+
     def _calc_data_taper(self) -> np.ndarray:
         """ Input data window. Zeroes out all data older than 1 frame old.
         See https://github.com/nyanpasu64/corrscope/wiki/Correlation-Trigger
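The `history` buffer built above is a fixed-size ring: push() overwrites the oldest slot and peek() returns the slot that will be overwritten next, i.e. the frame pushed `size` calls ago. A standalone sketch of that contract (not part of the diff; CircularArray is the class added earlier in this commit):

import numpy as np
from corrscope.triggers import CircularArray

history = CircularArray(2, 4)         # 2-frame lookbehind, 4 samples per frame
history.push(np.full(4, 1.0))
history.push(np.full(4, 2.0))
history.push(np.full(4, 3.0))
assert (history.peek() == 2.0).all()  # oldest surviving frame: pushed 2 calls ago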
@@ -242,6 +453,7 @@ class CorrelationTrigger(Trigger):
     # begin per-frame
     def get_trigger(self, index: int, cache: "PerFrameCache") -> int:
         N = self._buffer_nsamp
+        cfg = self.cfg
 
         # Get data
         stride = self._stride
@@ -253,50 +465,39 @@ class CorrelationTrigger(Trigger):
         period = get_period(data)
         cache.period = period * stride
 
-        if self._is_window_invalid(period):
-            diameter, falloff = [round(period * x) for x in self.cfg.trigger_falloff]
+        semitones = self._is_window_invalid(period)
+        # If pitch changed...
+        if semitones:
+            diameter, falloff = [round(period * x) for x in cfg.trigger_falloff]
             falloff_window = cosine_flat(N, diameter, falloff)
             window = np.minimum(falloff_window, self._data_taper)
 
+            # If pitch invariance enabled, rescale buffer to match data's pitch.
+            if self.scfg and (data != 0).any():
+                if isinstance(semitones, float):
+                    peak_semitones = semitones
+                else:
+                    peak_semitones = None
+                self.spectrum_rescale_buffer(data, peak_semitones)
+
             self._prev_period = period
             self._prev_window = window
         else:
             window = self._prev_window
 
+        self.history.push(data)
         data *= window
 
-        # prev_buffer
-        prev_buffer = self._windowed_step + self._buffer
+        prev_buffer: np.ndarray = self._buffer.copy()
+        prev_buffer += self._windowed_step
 
         # Calculate correlation
-        """
-        If offset < optimal, we need to `offset += positive`.
-        - The peak will appear near the right of `data`.
-
-        Either we must slide prev_buffer to the right:
-        - correlate(data, prev_buffer)
-        - trigger = offset + peak_offset
-
-        Or we must slide data to the left (by sliding offset to the right):
-        - correlate(prev_buffer, data)
-        - trigger = offset - peak_offset
-        """
-        corr = signal.correlate(data, prev_buffer)  # returns double, not single/FLOAT
-        assert len(corr) == 2 * N - 1
-
-        # Find optimal offset (within trigger_diameter, default=±N/4)
-        mid = N - 1
-        radius = round(N * self.cfg.trigger_diameter / 2)
-
-        left = mid - radius
-        right = mid + radius + 1
-
-        corr = corr[left:right]
-        mid = mid - left
-
-        # argmax(corr) == mid + peak_offset == (data >> peak_offset)
-        # peak_offset == argmax(corr) - mid
-        peak_offset = np.argmax(corr) - mid  # type: int
+        if self.cfg.trigger_diameter is not None:
+            radius = round(N * self.cfg.trigger_diameter / 2)
+        else:
+            radius = None
+
+        peak_offset = self.correlate_offset(data, prev_buffer, radius)
         trigger = index + (stride * peak_offset)
 
         # Apply post trigger (before updating correlation buffer)
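The inline docstring and argmax code removed above move into the new correlate_offset() helper shown in the next hunk. A standalone sketch of what the returned offset means (numpy/scipy only, illustrative values, not part of the diff); note that the new default trigger_diameter=None leaves the search radius unlimited, whereas the old default of 0.5 limited it to about ±N/4:

import numpy as np
from scipy import signal

N = 8
prev_buffer = np.zeros(N)
prev_buffer[3] = 1.0              # reference peak at index 3
data = np.zeros(N)
data[5] = 1.0                     # same peak, 2 samples later

corr = signal.correlate(data, prev_buffer)   # full correlation, length 2*N - 1
mid = N - 1
peak_offset = int(np.argmax(corr)) - mid
assert peak_offset == 2           # get_trigger() then shifts by stride * peak_offset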
@@ -306,11 +507,108 @@ class CorrelationTrigger(Trigger):
         # Update correlation buffer (distinct from visible area)
         aligned = self._wave.get_around(trigger, self._buffer_nsamp, stride)
         self._update_buffer(aligned, cache)
+        self.frames_since_spectrum += 1
 
         return trigger
 
-    def _is_window_invalid(self, period: int) -> bool:
-        """ Returns True if pitch has changed more than `recalc_semitones`. """
+    def spectrum_rescale_buffer(
+        self, data: np.ndarray, peak_semitones: Optional[float]
+    ) -> None:
+        """Rewrites self._spectrum, and possibly rescales self._buffer."""
+
+        scfg = self.scfg
+        N = self._buffer_nsamp
+
+        if self.frames_since_spectrum < self.scfg.min_frames_between_recompute:
+            return
+        self.frames_since_spectrum = 0
+
+        spectrum = self._spectrum_calc.calc_spectrum(data)
+        normalize_buffer(spectrum)
+
+        # Don't normalize self._spectrum. It was already normalized when being assigned.
+        prev_spectrum = self._spectrum_calc.calc_spectrum(self.history.peek())
+        prev_spectrum += scfg.add_current_to_history * spectrum
+
+        # rewrite spectrum
+        self._spectrum = spectrum
+
+        assert not np.any(np.isnan(spectrum))
+
+        # Find spectral correlation peak,
+        # but prioritize "changing pitch by ???".
+        if peak_semitones is not None:
+            boost_x = int(round(peak_semitones / 12 * scfg.notes_per_octave))
+            boost_y: float = scfg.pitch_estimate_boost
+        else:
+            boost_x = 0
+            boost_y = 1.0
+
+        # If we want to double pitch...
+        resample_notes = self.correlate_offset(
+            spectrum,
+            prev_spectrum,
+            scfg.max_notes_to_resample,
+            boost_x=boost_x,
+            boost_y=boost_y,
+        )
+        if resample_notes != 0:
+            # we must divide sampling rate by 2.
+            new_len = int(round(N / 2 ** (resample_notes / scfg.notes_per_octave)))
+
+            # Copy+resample self._buffer.
+            self._buffer = np.interp(
+                np.linspace(0, 1, new_len), np.linspace(0, 1, N), self._buffer
+            )
+            # assert len(self._buffer) == new_len
+            self._buffer = midpad(self._buffer, N)
+
+    @staticmethod
+    def correlate_offset(
+        data: np.ndarray,
+        prev_buffer: np.ndarray,
+        radius: Optional[int],
+        boost_x: int = 0,
+        boost_y: float = 1.0,
+    ) -> int:
+        """
+        This is confusing.
+
+        If data index < optimal, data will be too far to the right,
+        and we need to `index += positive`.
+        - The peak will appear near the right of `data`.
+
+        Either we must slide prev_buffer to the right,
+        or we must slide data to the left (by sliding index to the right):
+        - correlate(data, prev_buffer)
+        - trigger = index + peak_offset
+        """
+        N = len(data)
+        corr = signal.correlate(data, prev_buffer)  # returns double, not single/FLOAT
+        Ncorr = 2 * N - 1
+        assert len(corr) == Ncorr
+
+        # Find optimal offset
+        mid = N - 1
+
+        if radius is not None:
+            left = max(mid - radius, 0)
+            right = min(mid + radius + 1, Ncorr)
+
+            corr = corr[left:right]
+            mid = mid - left
+
+        # Prioritize part of it.
+        corr[mid + boost_x : mid + boost_x + 1] *= boost_y
+
+        # argmax(corr) == mid + peak_offset == (data >> peak_offset)
+        # peak_offset == argmax(corr) - mid
+        peak_offset = np.argmax(corr) - mid  # type: int
+        return peak_offset
+
+    def _is_window_invalid(self, period: int) -> Union[bool, float]:
+        """ Returns number of semitones,
+        if pitch has changed more than `recalc_semitones`. """
 
         prev = self._prev_period
 
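A worked example of the buffer-rescaling arithmetic above (plain Python, illustrative values, not part of the diff). With the SpectrumConfig defaults, max_notes_to_resample = round(6 * 1.0) = 6, so the spectral shift is capped at one octave; a positive resample_notes shrinks the buffer (raising its apparent pitch) before midpad() re-centers it back to N samples.

N = 4096                       # buffer length in samples
notes_per_octave = 6           # SpectrumConfig default

for resample_notes in (6, 3, -6):
    new_len = int(round(N / 2 ** (resample_notes / notes_per_octave)))
    print(resample_notes, new_len)
# +6 -> 2048 (one octave up: buffer squeezed to half length)
# +3 -> 2896 (half an octave up)
# -6 -> 8192 (one octave down: buffer stretched to double length)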
@@ -319,12 +617,12 @@ class CorrelationTrigger(Trigger):
         elif prev * period == 0:
             return prev != period
         else:
-            semitones = abs(np.log(period / prev) / np.log(2) * 12)
+            # If period doubles, semitones are -12.
+            semitones = np.log(period / prev) / np.log(2) * -12
             # If semitones == recalc_semitones == 0, do NOT recalc.
-            if semitones <= self.cfg.recalc_semitones:
+            if abs(semitones) <= self.cfg.recalc_semitones:
                 return False
-            return True
+            return semitones
 
     def _update_buffer(self, data: np.ndarray, cache: PerFrameCache) -> None:
         """
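Quick numeric check of the new signed-semitone convention (plain Python, not part of the diff): doubling the period means the pitch dropped an octave, and the sign flip (* -12) makes the returned value track pitch rather than period.

import math

def semitones(period, prev):
    # mirrors the new expression: np.log(period / prev) / np.log(2) * -12
    return math.log(period / prev) / math.log(2) * -12

print(semitones(200, 100))   # -12.0  (period doubled -> pitch fell one octave)
print(semitones(100, 200))   # +12.0  (period halved -> pitch rose one octave)
print(semitones(106, 100))   # ~-1.01; abs() just exceeds the default recalc_semitones=1.0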
@@ -1,8 +1,10 @@
 import attr
 import matplotlib.pyplot as plt
+import numpy as np
 import pytest
 from matplotlib.axes import Axes
 from matplotlib.figure import Figure
+from pytest_cases import pytest_fixture_plus
 
 from corrscope import triggers
 from corrscope.triggers import (
@@ -11,6 +13,7 @@ from corrscope.triggers import (
     PerFrameCache,
     ZeroCrossingTriggerConfig,
     LocalPostTriggerConfig,
+    SpectrumConfig,
 )
 from corrscope.wave import Wave
 
@@ -25,10 +28,16 @@ def cfg_template(**kwargs) -> CorrelationTriggerConfig:
     return attr.evolve(cfg, **kwargs)
 
 
-@pytest.fixture(scope="session", params=[False, True])
-def cfg(request):
-    use_edge_trigger = request.param
-    return cfg_template(use_edge_trigger=use_edge_trigger)
+@pytest_fixture_plus
+@pytest.mark.parametrize("use_edge_trigger", [False, True])
+@pytest.mark.parametrize("trigger_diameter", [None, 0.5])
+@pytest.mark.parametrize("pitch_invariance", [None, SpectrumConfig()])
+def cfg(use_edge_trigger, trigger_diameter, pitch_invariance):
+    return cfg_template(
+        use_edge_trigger=use_edge_trigger,
+        trigger_diameter=trigger_diameter,
+        pitch_invariance=pitch_invariance,
+    )
 
 
 @pytest.fixture(
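For readers unfamiliar with pytest_fixture_plus (provided by the pytest-cases package imported above): it lets @pytest.mark.parametrize stack on a fixture, so the three marks above yield 2 x 2 x 2 = 8 cfg variants. A minimal standalone sketch of the same pattern (illustrative names, not part of the diff; assumes pytest and pytest-cases are installed):

import pytest
from pytest_cases import pytest_fixture_plus

@pytest_fixture_plus
@pytest.mark.parametrize("flag", [False, True])
@pytest.mark.parametrize("size", [1, 2])
def thing(flag, size):
    return {"flag": flag, "size": size}

def test_thing(thing):
    # Collected 2 * 2 = 4 times, once per parameter combination.
    assert thing["size"] in (1, 2)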
@@ -177,6 +186,43 @@ def test_trigger_should_recalc_window():
         assert trigger._is_window_invalid(x), x
 
 
+# Test pitch-invariant triggering using spectrum
+def test_correlate_offset():
+    """
+    Catches bug where writing N instead of Ncorr
+    prevented function from returning positive numbers.
+    """
+
+    np.random.seed(31337)
+    correlate_offset = CorrelationTrigger.correlate_offset
+
+    # Ensure autocorrelation on random data returns peak at 0.
+    N = 100
+    spectrum = np.random.random(N)
+    assert correlate_offset(spectrum, spectrum, 12) == 0
+
+    # Ensure cross-correlation of time-shifted impulses works.
+    # Assume wave where y=[i==99].
+    wave = np.eye(N)[::-1]
+    # Taking a slice beginning at index i will produce an impulse at 99-i.
+    left = wave[30]
+    right = wave[40]
+
+    # We need to slide `left` to the right by 10 samples, and vice versa.
+    for radius in [None, 12]:
+        assert correlate_offset(data=left, prev_buffer=right, radius=radius) == 10
+        assert correlate_offset(data=right, prev_buffer=left, radius=radius) == -10
+
+    # The correlation peak at zero-offset is small enough for boost_x to be returned.
+    boost_y = 1.5
+    ones = np.ones(N)
+    for boost_x in [6, -6]:
+        assert (
+            correlate_offset(ones, ones, radius=9, boost_x=boost_x, boost_y=boost_y)
+            == boost_x
+        )
+
+
 # Test the ability to load legacy TriggerConfig