Add ZeroCrossingTrigger and call it from CorrelationTrigger

pull/357/head
nyanpasu64 2018-07-15 03:11:58 -07:00
rodzic 9b9288f7ad
commit d7f15b233b
2 zmienionych plików z 122 dodań i 1 usunięć

Wyświetl plik

@ -6,6 +6,7 @@ import numpy as np
from scipy import signal
from ovgenpy.renderer import MatplotlibRenderer, RendererConfig
from ovgenpy.util import find
if TYPE_CHECKING:
from ovgenpy.wave import Wave
@ -53,6 +54,9 @@ class CorrelationTrigger(Trigger):
def __call__(self, wave: 'Wave', scan_nsamp: int):
return CorrelationTrigger(wave, scan_nsamp, cfg=self)
# get_trigger postprocessing: self._zero_trigger
ZERO_CROSSING_SCAN = 256
def __init__(self, wave: 'Wave', scan_nsamp: int, cfg: Config):
"""
Correlation-based trigger which looks at a window of `scan_nsamp` samples.
@ -72,6 +76,9 @@ class CorrelationTrigger(Trigger):
# Create correlation buffer (containing a series of old data)
self._buffer = np.zeros(scan_nsamp)
# Create zero crossing trigger, for postprocessing results
self._zero_trigger = ZeroCrossingTrigger(wave, self.ZERO_CROSSING_SCAN)
def get_trigger(self, index: int) -> int:
"""
:param index: sample index
@ -125,7 +132,8 @@ class CorrelationTrigger(Trigger):
aligned = self._wave.get_around(trigger, self._buffer_nsamp)
self._update_buffer(aligned)
return trigger
trigger2 = self._zero_trigger.get_trigger(trigger)
return trigger2
def _update_buffer(self, data: np.ndarray) -> None:
"""
@ -180,3 +188,48 @@ def get_period(data: np.ndarray) -> int:
crossX = zero_crossings[0]
peakX = crossX + np.argmax(corr[crossX:])
return peakX
class ZeroCrossingTrigger(Trigger):
def __init__(self, wave: 'Wave', scan_nsamp: int):
super().__init__(wave, scan_nsamp)
def get_trigger(self, index: int):
scan_nsamp = self._scan_nsamp
if index not in range(len(self._wave.data)):
return index
if self._wave[index] < 0:
direction = 1
test = lambda a: a >= 0
elif self._wave[index] > 0:
direction = -1
test = lambda a: a <= 0
else: # self._wave[sample] == 0
return index + 1
data = self._wave[index : index + (direction * scan_nsamp) : direction]
intercepts = find(data, test)
try:
(delta,), value = next(intercepts)
return index + (delta * direction) + int(value <= 0)
except StopIteration: # No zero-intercepts
return index
# noinspection PyUnreachableCode
"""
`value <= 0` produces poor results on on sine waves, since it erroneously
increments the exact idx of the zero-crossing sample.
`value < 0` produces poor results on impulse24000, since idx = 23999 which
doesn't match CorrelationTrigger. (scans left looking for a zero-crossing)
CorrelationTrigger tries to maximize @trigger - @(trigger-1). I think always
incrementing zeros (impulse24000 = 24000) is acceptable.
- To be consistent, we should increment zeros whenever we *start* there.
"""

Wyświetl plik

@ -1,2 +1,70 @@
from typing import Callable, Tuple, TypeVar, Iterator
import numpy as np
from itertools import chain
def ceildiv(n, d):
return -(-n // d)
T = TypeVar('T')
# Adapted from https://github.com/numpy/numpy/issues/2269#issuecomment-14436725
def find(a: 'np.ndarray[T]', predicate: 'Callable[[np.ndarray[T]], np.ndarray[bool]]',
chunk_size=1024) -> Iterator[Tuple[Tuple[int], T]]:
"""
Find the indices of array elements that match the predicate.
Parameters
----------
a : array_like
Input data, must be 1D.
predicate : function
A function which operates on sections of the given array, returning
element-wise True or False for each data value.
chunk_size : integer
The length of the chunks to use when searching for matching indices.
For high probability predicates, a smaller number will make this
function quicker, similarly choose a larger number for low
probabilities.
Returns
-------
index_generator : generator
A generator of (indices, data value) tuples which make the predicate
True.
See Also
--------
where, nonzero
Notes
-----
This function is best used for finding the first, or first few, data values
which match the predicate.
Examples
--------
>>> a = np.sin(np.linspace(0, np.pi, 200))
>>> result = find(a, lambda arr: arr > 0.9)
>>> next(result)
((71, ), 0.900479032457)
>>> np.where(a > 0.9)[0][0]
71
"""
if a.ndim != 1:
raise ValueError('The array must be 1D, not {}.'.format(a.ndim))
i0 = 0
chunk_inds = chain(range(chunk_size, a.size, chunk_size), [None])
for i1 in chunk_inds:
chunk = a[i0:i1]
for idx in predicate(chunk).nonzero()[0]:
yield (idx + i0, ), chunk[idx]
i0 = i1