From 8909faf5aca0ac6c06b7bdc951c250bc603120b8 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Thu, 14 Nov 2024 21:34:54 -0500 Subject: [PATCH] ClockSync runs as a singleton thread --- src/ezmsg/lsl/inlet.py | 13 +----- src/ezmsg/lsl/outlet.py | 10 +---- src/ezmsg/lsl/util.py | 97 +++++++++++++++++++++-------------------- tests/test_inlet.py | 35 +++++++-------- tests/test_util.py | 13 +++--- 5 files changed, 77 insertions(+), 91 deletions(-) diff --git a/src/ezmsg/lsl/inlet.py b/src/ezmsg/lsl/inlet.py index fcd9f83..933e840 100644 --- a/src/ezmsg/lsl/inlet.py +++ b/src/ezmsg/lsl/inlet.py @@ -75,7 +75,6 @@ class LSLInletSettings(ez.Settings): class LSLInletState(ez.State): resolver: typing.Optional[pylsl.ContinuousResolver] = None inlet: typing.Optional[pylsl.StreamInlet] = None - clock_offset: float = 0.0 class LSLInletUnit(ez.Unit): @@ -94,9 +93,6 @@ class LSLInletUnit(ez.Unit): INPUT_SETTINGS = ez.InputStream(LSLInletSettings) OUTPUT_SIGNAL = ez.OutputStream(AxisArray) - # Share clock correction across all instances - clock_sync = ClockSync() - def __init__(self, *args, **kwargs) -> None: """ Handle deprecated arguments. Whereas previously stream_name and stream_type were in the @@ -107,6 +103,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._msg_template: typing.Optional[AxisArray] = None self._fetch_buffer: typing.Optional[npt.NDArray] = None + self._clock_sync = ClockSync() def _reset_resolver(self) -> None: self.STATE.resolver = pylsl.ContinuousResolver(pred=None, forget_after=30.0) @@ -211,7 +208,6 @@ def _reset_inlet(self) -> None: async def initialize(self) -> None: self._reset_resolver() self._reset_inlet() - await self.clock_sync.update(force=True, burst=1000) def shutdown(self) -> None: if self.STATE.inlet is not None: @@ -222,11 +218,6 @@ def shutdown(self) -> None: del self.STATE.resolver self.STATE.resolver = None - @ez.task - async def clock_sync_task(self) -> None: - while True: - await self.clock_sync.update(force=False, burst=4) - @ez.subscriber(INPUT_SETTINGS) async def on_settings(self, msg: LSLInletSettings) -> None: # The message may be full LSLInletSettings, a dict of settings, just the info, or dict of just info. @@ -268,7 +259,7 @@ async def lsl_pull(self) -> typing.AsyncGenerator: if self.SETTINGS.use_arrival_time: timestamps = time.time() - (timestamps - timestamps[0]) else: - timestamps = self.clock_sync.lsl2system(timestamps) + timestamps = self._clock_sync.lsl2system(timestamps) if self.SETTINGS.info.nominal_srate <= 0.0: # Irregular rate stream uses CoordinateAxis for time so each sample has a timestamp. diff --git a/src/ezmsg/lsl/outlet.py b/src/ezmsg/lsl/outlet.py index 7226e69..74d8b31 100644 --- a/src/ezmsg/lsl/outlet.py +++ b/src/ezmsg/lsl/outlet.py @@ -48,22 +48,14 @@ class LSLOutletUnit(ez.Unit): SETTINGS = LSLOutletSettings STATE = LSLOutletState - # Share clock correction across all instances - clock_sync = ClockSync() - async def initialize(self) -> None: self._stream_created = False - await self.clock_sync.update(force=True, burst=1000) + self._clock_sync = ClockSync() def shutdown(self) -> None: del self.STATE.outlet self.STATE.outlet = None - @ez.task - async def clock_sync_task(self) -> None: - while True: - await self.clock_sync.update(force=False, burst=4) - @ez.subscriber(INPUT_SIGNAL, zero_copy=True) async def lsl_outlet(self, msg: AxisArray) -> None: if self.STATE.outlet is None: diff --git a/src/ezmsg/lsl/util.py b/src/ezmsg/lsl/util.py index 6b0a3be..8038675 100644 --- a/src/ezmsg/lsl/util.py +++ b/src/ezmsg/lsl/util.py @@ -1,4 +1,4 @@ -import asyncio +import threading import time import typing @@ -7,7 +7,7 @@ import pylsl -async def collect_timestamp_pairs( +def collect_timestamp_pairs( npairs: int = 4, ) -> typing.Tuple[np.ndarray, np.ndarray]: xs = [] @@ -19,67 +19,70 @@ async def collect_timestamp_pairs( x, y = pylsl.local_clock(), time.time() xs.append(x) ys.append(y) - await asyncio.sleep(0.001) + time.sleep(0.001) return np.array(xs), np.array(ys) class ClockSync: - def __init__(self, alpha: float = 0.1, min_interval: float = 0.5): - self.alpha = alpha - self.min_interval = min_interval + _instance = None + _lock = threading.Lock() - self._offset: typing.Optional[float] = None - self._last_update = 0.0 - self.count = 0 + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance - @property - def offset(self) -> float: - return self._offset + def __init__(self, alpha: float = 0.1, min_interval: float = 0.1): + if not hasattr(self, "_initialized"): + self._alpha = alpha + self._interval = min_interval + self._offset: typing.Optional[float] = None - @property - def last_updated(self) -> float: - return self._last_update - - async def update(self, force: bool = False, burst: int = 4) -> None: - """ - Update the clock offset estimate. This should be called ~regularly in a long-running - asynchronous task. - - Args: - force: Whether to force an update even if the minimum interval hasn't passed. - burst: The number of pairs to collect for this update. - """ - dur_since_last = time.time() - self._last_update - dur_until_next = self.min_interval - dur_since_last - if force or dur_until_next <= 0: - xs, ys = await collect_timestamp_pairs(burst) - self.count += burst - if burst > 0: - self._last_update = ys[-1] + self._running = True + self._initialized = True + self._thread = threading.Thread(target=self._run) + self._thread.daemon = True + self._thread.start() + + def _run(self): + xs, ys = collect_timestamp_pairs(100) + self._offset = np.mean(ys - xs) + while self._running: + time.sleep(self._interval) + xs, ys = collect_timestamp_pairs(4) offset = np.mean(ys - xs) - if self.offset is not None: - # Do exponential smoothing update. - self._offset = (1 - self.alpha) * self._offset + self.alpha * offset - else: - # First iteration. No smoothing. - self._offset = offset - else: - await asyncio.sleep(dur_until_next) + self._offset = (1 - self._alpha) * self._offset + self._alpha * offset + + @property + def offset(self) -> float: + with self._lock: + return self._offset @typing.overload - def lsl2system(self, lsl_timestamp: float) -> float: ... + def lsl2system(self, lsl_timestamp: float) -> float: + ... + @typing.overload - def lsl2system(self, lsl_timestamp: npt.NDArray[float]) -> npt.NDArray[float]: ... + def lsl2system(self, lsl_timestamp: npt.NDArray[float]) -> npt.NDArray[float]: + ... + def lsl2system(self, lsl_timestamp): # offset = system - lsl --> system = lsl + offset - return lsl_timestamp + self._offset + with self._lock: + return lsl_timestamp + self._offset @typing.overload - def system2lsl(self, system_timestamp: float) -> float: ... + def system2lsl(self, system_timestamp: float) -> float: + ... + @typing.overload def system2lsl( - self, system_timestamp: npt.NDArray[float] - ) -> npt.NDArray[float]: ... + self, system_timestamp: npt.NDArray[float] + ) -> npt.NDArray[float]: + ... + def system2lsl(self, system_timestamp): # offset = system - lsl --> lsl = system - offset - return system_timestamp - self._offset + with self._lock: + return system_timestamp - self._offset diff --git a/tests/test_inlet.py b/tests/test_inlet.py index 43c37a5..936b45a 100644 --- a/tests/test_inlet.py +++ b/tests/test_inlet.py @@ -26,27 +26,10 @@ def test_inlet_init_defaults(): assert True -async def dummy_outlet(rate: float = 100.0, n_chans: int = 8): - info = pylsl.StreamInfo( - name="dummy", type="dummy", channel_count=n_chans, nominal_srate=rate - ) - outlet = pylsl.StreamOutlet(info) - eff_rate = rate or 100.0 - n_interval = int(eff_rate / 10) - n_pushed = 0 - t0 = pylsl.local_clock() - while True: - t_next = t0 + (n_pushed + n_interval) / (rate or 100.0) - t_now = pylsl.local_clock() - await asyncio.sleep(t_next - t_now) - data = np.random.random((n_interval, n_chans)) - outlet.push_chunk(data) - n_pushed += n_interval - - class DummyOutletSettings(ez.Settings): rate: float = 100.0 n_chans: int = 8 + running: bool = True class DummyOutlet(ez.Unit): @@ -54,7 +37,21 @@ class DummyOutlet(ez.Unit): @ez.task async def run_dummy(self) -> None: - await dummy_outlet(rate=self.SETTINGS.rate, n_chans=self.SETTINGS.n_chans) + info = pylsl.StreamInfo( + name="dummy", type="dummy", channel_count=self.SETTINGS.n_chans, nominal_srate=self.SETTINGS.rate + ) + outlet = pylsl.StreamOutlet(info) + eff_rate = self.SETTINGS.rate or 100.0 + n_interval = int(eff_rate / 10) + n_pushed = 0 + t0 = pylsl.local_clock() + while self.SETTINGS.running: + t_next = t0 + (n_pushed + n_interval) / (self.SETTINGS.rate or 100.0) + t_now = pylsl.local_clock() + await asyncio.sleep(t_next - t_now) + data = np.random.random((n_interval, self.SETTINGS.n_chans)) + outlet.push_chunk(data) + n_pushed += n_interval @pytest.mark.parametrize("rate", [100.0, 0.0]) diff --git a/tests/test_util.py b/tests/test_util.py index 5fa4538..333eeed 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,4 +1,4 @@ -import asyncio +import time import numpy as np @@ -8,17 +8,20 @@ def test_clock_sync(): tol = 10e-3 # 1 msec - # Run a few updates to get a stable estimate. clock_sync = ClockSync() - asyncio.run(clock_sync.update(force=True, burst=1000)) - asyncio.run(clock_sync.update(force=True, burst=10)) + # Let it run a bit to get a stable estimate. + time.sleep(1.0) offsets = [] for _ in range(10): - xs, ys = asyncio.run(collect_timestamp_pairs(100)) + xs, ys = collect_timestamp_pairs(100) offsets.append(np.mean(ys - xs)) est_diff = np.abs(np.mean(offsets) - clock_sync.offset) print(est_diff) assert est_diff < tol + + # Assert singleton-ness + clock_sync2 = ClockSync() + assert clock_sync is clock_sync2