diff --git a/ovgenpy/triggers.py b/ovgenpy/triggers.py index bedf256..facf1ce 100644 --- a/ovgenpy/triggers.py +++ b/ovgenpy/triggers.py @@ -6,6 +6,7 @@ import numpy as np from scipy import signal from ovgenpy.renderer import MatplotlibRenderer, RendererConfig +from ovgenpy.util import find if TYPE_CHECKING: from ovgenpy.wave import Wave @@ -53,6 +54,9 @@ class CorrelationTrigger(Trigger): def __call__(self, wave: 'Wave', scan_nsamp: int): return CorrelationTrigger(wave, scan_nsamp, cfg=self) + # get_trigger postprocessing: self._zero_trigger + ZERO_CROSSING_SCAN = 256 + def __init__(self, wave: 'Wave', scan_nsamp: int, cfg: Config): """ Correlation-based trigger which looks at a window of `scan_nsamp` samples. @@ -72,6 +76,9 @@ class CorrelationTrigger(Trigger): # Create correlation buffer (containing a series of old data) self._buffer = np.zeros(scan_nsamp) + # Create zero crossing trigger, for postprocessing results + self._zero_trigger = ZeroCrossingTrigger(wave, self.ZERO_CROSSING_SCAN) + def get_trigger(self, index: int) -> int: """ :param index: sample index @@ -125,7 +132,8 @@ class CorrelationTrigger(Trigger): aligned = self._wave.get_around(trigger, self._buffer_nsamp) self._update_buffer(aligned) - return trigger + trigger2 = self._zero_trigger.get_trigger(trigger) + return trigger2 def _update_buffer(self, data: np.ndarray) -> None: """ @@ -180,3 +188,48 @@ def get_period(data: np.ndarray) -> int: crossX = zero_crossings[0] peakX = crossX + np.argmax(corr[crossX:]) return peakX + + +class ZeroCrossingTrigger(Trigger): + def __init__(self, wave: 'Wave', scan_nsamp: int): + super().__init__(wave, scan_nsamp) + + def get_trigger(self, index: int): + scan_nsamp = self._scan_nsamp + + if index not in range(len(self._wave.data)): + return index + + if self._wave[index] < 0: + direction = 1 + test = lambda a: a >= 0 + + elif self._wave[index] > 0: + direction = -1 + test = lambda a: a <= 0 + + else: # self._wave[sample] == 0 + return index + 1 + + data = self._wave[index : index + (direction * scan_nsamp) : direction] + intercepts = find(data, test) + try: + (delta,), value = next(intercepts) + return index + (delta * direction) + int(value <= 0) + + except StopIteration: # No zero-intercepts + return index + + # noinspection PyUnreachableCode + """ + `value <= 0` produces poor results on on sine waves, since it erroneously + increments the exact idx of the zero-crossing sample. + + `value < 0` produces poor results on impulse24000, since idx = 23999 which + doesn't match CorrelationTrigger. (scans left looking for a zero-crossing) + + CorrelationTrigger tries to maximize @trigger - @(trigger-1). I think always + incrementing zeros (impulse24000 = 24000) is acceptable. + + - To be consistent, we should increment zeros whenever we *start* there. + """ diff --git a/ovgenpy/util.py b/ovgenpy/util.py index c45ef07..b3945fc 100644 --- a/ovgenpy/util.py +++ b/ovgenpy/util.py @@ -1,2 +1,70 @@ +from typing import Callable, Tuple, TypeVar, Iterator + +import numpy as np +from itertools import chain + + def ceildiv(n, d): return -(-n // d) + + +T = TypeVar('T') + +# Adapted from https://github.com/numpy/numpy/issues/2269#issuecomment-14436725 +def find(a: 'np.ndarray[T]', predicate: 'Callable[[np.ndarray[T]], np.ndarray[bool]]', + chunk_size=1024) -> Iterator[Tuple[Tuple[int], T]]: + """ + Find the indices of array elements that match the predicate. + + Parameters + ---------- + a : array_like + Input data, must be 1D. + + predicate : function + A function which operates on sections of the given array, returning + element-wise True or False for each data value. + + chunk_size : integer + The length of the chunks to use when searching for matching indices. + For high probability predicates, a smaller number will make this + function quicker, similarly choose a larger number for low + probabilities. + + Returns + ------- + index_generator : generator + A generator of (indices, data value) tuples which make the predicate + True. + + See Also + -------- + where, nonzero + + Notes + ----- + This function is best used for finding the first, or first few, data values + which match the predicate. + + Examples + -------- + >>> a = np.sin(np.linspace(0, np.pi, 200)) + >>> result = find(a, lambda arr: arr > 0.9) + >>> next(result) + ((71, ), 0.900479032457) + >>> np.where(a > 0.9)[0][0] + 71 + + + """ + if a.ndim != 1: + raise ValueError('The array must be 1D, not {}.'.format(a.ndim)) + + i0 = 0 + chunk_inds = chain(range(chunk_size, a.size, chunk_size), [None]) + + for i1 in chunk_inds: + chunk = a[i0:i1] + for idx in predicate(chunk).nonzero()[0]: + yield (idx + i0, ), chunk[idx] + i0 = i1