Skip to content

Commit

Permalink
Merge pull request #17 from ezmsg-org/dev
Browse files Browse the repository at this point in the history
ClockSync runs as a singleton thread
  • Loading branch information
cboulay authored Nov 15, 2024
2 parents 23986da + 8909faf commit 395d00a
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 91 deletions.
13 changes: 2 additions & 11 deletions src/ezmsg/lsl/inlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class LSLInletSettings(ez.Settings):
class LSLInletState(ez.State):
resolver: typing.Optional[pylsl.ContinuousResolver] = None
inlet: typing.Optional[pylsl.StreamInlet] = None
clock_offset: float = 0.0


class LSLInletUnit(ez.Unit):
Expand All @@ -94,9 +93,6 @@ class LSLInletUnit(ez.Unit):
INPUT_SETTINGS = ez.InputStream(LSLInletSettings)
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

# Share clock correction across all instances
clock_sync = ClockSync()

def __init__(self, *args, **kwargs) -> None:
"""
Handle deprecated arguments. Whereas previously stream_name and stream_type were in the
Expand All @@ -107,6 +103,7 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._msg_template: typing.Optional[AxisArray] = None
self._fetch_buffer: typing.Optional[npt.NDArray] = None
self._clock_sync = ClockSync()

def _reset_resolver(self) -> None:
self.STATE.resolver = pylsl.ContinuousResolver(pred=None, forget_after=30.0)
Expand Down Expand Up @@ -211,7 +208,6 @@ def _reset_inlet(self) -> None:
async def initialize(self) -> None:
self._reset_resolver()
self._reset_inlet()
await self.clock_sync.update(force=True, burst=1000)

def shutdown(self) -> None:
if self.STATE.inlet is not None:
Expand All @@ -222,11 +218,6 @@ def shutdown(self) -> None:
del self.STATE.resolver
self.STATE.resolver = None

@ez.task
async def clock_sync_task(self) -> None:
while True:
await self.clock_sync.update(force=False, burst=4)

@ez.subscriber(INPUT_SETTINGS)
async def on_settings(self, msg: LSLInletSettings) -> None:
# The message may be full LSLInletSettings, a dict of settings, just the info, or dict of just info.
Expand Down Expand Up @@ -268,7 +259,7 @@ async def lsl_pull(self) -> typing.AsyncGenerator:
if self.SETTINGS.use_arrival_time:
timestamps = time.time() - (timestamps - timestamps[0])
else:
timestamps = self.clock_sync.lsl2system(timestamps)
timestamps = self._clock_sync.lsl2system(timestamps)

if self.SETTINGS.info.nominal_srate <= 0.0:
# Irregular rate stream uses CoordinateAxis for time so each sample has a timestamp.
Expand Down
10 changes: 1 addition & 9 deletions src/ezmsg/lsl/outlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,14 @@ class LSLOutletUnit(ez.Unit):
SETTINGS = LSLOutletSettings
STATE = LSLOutletState

# Share clock correction across all instances
clock_sync = ClockSync()

async def initialize(self) -> None:
self._stream_created = False
await self.clock_sync.update(force=True, burst=1000)
self._clock_sync = ClockSync()

def shutdown(self) -> None:
del self.STATE.outlet
self.STATE.outlet = None

@ez.task
async def clock_sync_task(self) -> None:
while True:
await self.clock_sync.update(force=False, burst=4)

@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
async def lsl_outlet(self, msg: AxisArray) -> None:
if self.STATE.outlet is None:
Expand Down
97 changes: 50 additions & 47 deletions src/ezmsg/lsl/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio
import threading
import time
import typing

Expand All @@ -7,7 +7,7 @@
import pylsl


async def collect_timestamp_pairs(
def collect_timestamp_pairs(
npairs: int = 4,
) -> typing.Tuple[np.ndarray, np.ndarray]:
xs = []
Expand All @@ -19,67 +19,70 @@ async def collect_timestamp_pairs(
x, y = pylsl.local_clock(), time.time()
xs.append(x)
ys.append(y)
await asyncio.sleep(0.001)
time.sleep(0.001)
return np.array(xs), np.array(ys)


class ClockSync:
def __init__(self, alpha: float = 0.1, min_interval: float = 0.5):
self.alpha = alpha
self.min_interval = min_interval
_instance = None
_lock = threading.Lock()

self._offset: typing.Optional[float] = None
self._last_update = 0.0
self.count = 0
def __new__(cls, *args, **kwargs):
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

@property
def offset(self) -> float:
return self._offset
def __init__(self, alpha: float = 0.1, min_interval: float = 0.1):
if not hasattr(self, "_initialized"):
self._alpha = alpha
self._interval = min_interval
self._offset: typing.Optional[float] = None

@property
def last_updated(self) -> float:
return self._last_update

async def update(self, force: bool = False, burst: int = 4) -> None:
"""
Update the clock offset estimate. This should be called ~regularly in a long-running
asynchronous task.
Args:
force: Whether to force an update even if the minimum interval hasn't passed.
burst: The number of pairs to collect for this update.
"""
dur_since_last = time.time() - self._last_update
dur_until_next = self.min_interval - dur_since_last
if force or dur_until_next <= 0:
xs, ys = await collect_timestamp_pairs(burst)
self.count += burst
if burst > 0:
self._last_update = ys[-1]
self._running = True
self._initialized = True
self._thread = threading.Thread(target=self._run)
self._thread.daemon = True
self._thread.start()

def _run(self):
xs, ys = collect_timestamp_pairs(100)
self._offset = np.mean(ys - xs)
while self._running:
time.sleep(self._interval)
xs, ys = collect_timestamp_pairs(4)
offset = np.mean(ys - xs)
if self.offset is not None:
# Do exponential smoothing update.
self._offset = (1 - self.alpha) * self._offset + self.alpha * offset
else:
# First iteration. No smoothing.
self._offset = offset
else:
await asyncio.sleep(dur_until_next)
self._offset = (1 - self._alpha) * self._offset + self._alpha * offset

@property
def offset(self) -> float:
with self._lock:
return self._offset

@typing.overload
def lsl2system(self, lsl_timestamp: float) -> float: ...
def lsl2system(self, lsl_timestamp: float) -> float:
...

@typing.overload
def lsl2system(self, lsl_timestamp: npt.NDArray[float]) -> npt.NDArray[float]: ...
def lsl2system(self, lsl_timestamp: npt.NDArray[float]) -> npt.NDArray[float]:
...

def lsl2system(self, lsl_timestamp):
# offset = system - lsl --> system = lsl + offset
return lsl_timestamp + self._offset
with self._lock:
return lsl_timestamp + self._offset

@typing.overload
def system2lsl(self, system_timestamp: float) -> float: ...
def system2lsl(self, system_timestamp: float) -> float:
...

@typing.overload
def system2lsl(
self, system_timestamp: npt.NDArray[float]
) -> npt.NDArray[float]: ...
self, system_timestamp: npt.NDArray[float]
) -> npt.NDArray[float]:
...

def system2lsl(self, system_timestamp):
# offset = system - lsl --> lsl = system - offset
return system_timestamp - self._offset
with self._lock:
return system_timestamp - self._offset
35 changes: 16 additions & 19 deletions tests/test_inlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,32 @@ def test_inlet_init_defaults():
assert True


async def dummy_outlet(rate: float = 100.0, n_chans: int = 8):
info = pylsl.StreamInfo(
name="dummy", type="dummy", channel_count=n_chans, nominal_srate=rate
)
outlet = pylsl.StreamOutlet(info)
eff_rate = rate or 100.0
n_interval = int(eff_rate / 10)
n_pushed = 0
t0 = pylsl.local_clock()
while True:
t_next = t0 + (n_pushed + n_interval) / (rate or 100.0)
t_now = pylsl.local_clock()
await asyncio.sleep(t_next - t_now)
data = np.random.random((n_interval, n_chans))
outlet.push_chunk(data)
n_pushed += n_interval


class DummyOutletSettings(ez.Settings):
rate: float = 100.0
n_chans: int = 8
running: bool = True


class DummyOutlet(ez.Unit):
SETTINGS = DummyOutletSettings

@ez.task
async def run_dummy(self) -> None:
await dummy_outlet(rate=self.SETTINGS.rate, n_chans=self.SETTINGS.n_chans)
info = pylsl.StreamInfo(
name="dummy", type="dummy", channel_count=self.SETTINGS.n_chans, nominal_srate=self.SETTINGS.rate
)
outlet = pylsl.StreamOutlet(info)
eff_rate = self.SETTINGS.rate or 100.0
n_interval = int(eff_rate / 10)
n_pushed = 0
t0 = pylsl.local_clock()
while self.SETTINGS.running:
t_next = t0 + (n_pushed + n_interval) / (self.SETTINGS.rate or 100.0)
t_now = pylsl.local_clock()
await asyncio.sleep(t_next - t_now)
data = np.random.random((n_interval, self.SETTINGS.n_chans))
outlet.push_chunk(data)
n_pushed += n_interval


@pytest.mark.parametrize("rate", [100.0, 0.0])
Expand Down
13 changes: 8 additions & 5 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio
import time

import numpy as np

Expand All @@ -8,17 +8,20 @@
def test_clock_sync():
tol = 10e-3 # 1 msec

# Run a few updates to get a stable estimate.
clock_sync = ClockSync()
asyncio.run(clock_sync.update(force=True, burst=1000))
asyncio.run(clock_sync.update(force=True, burst=10))
# Let it run a bit to get a stable estimate.
time.sleep(1.0)

offsets = []
for _ in range(10):
xs, ys = asyncio.run(collect_timestamp_pairs(100))
xs, ys = collect_timestamp_pairs(100)
offsets.append(np.mean(ys - xs))

est_diff = np.abs(np.mean(offsets) - clock_sync.offset)
print(est_diff)

assert est_diff < tol

# Assert singleton-ness
clock_sync2 = ClockSync()
assert clock_sync is clock_sync2

0 comments on commit 395d00a

Please sign in to comment.