Skip to content

Commit

Permalink
Merge pull request #9 from ezmsg-org/dev
Browse files Browse the repository at this point in the history
continuous data: add chan labels and optionally convert to microvolts
  • Loading branch information
cboulay authored Nov 22, 2024
2 parents c69dc16 + 4a6f0d6 commit c3a5163
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 22 deletions.
16 changes: 12 additions & 4 deletions examples/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ def main(
help="Network port to receive packets. This should always be 51002."
),
] = 51002,
recv_bufsize: Annotated[
int, typer.Option(help="UDP socket recv buffer size.")
] = (8 if sys.platform == "win32" else 6) * 1024 * 1024,
recv_bufsize: Annotated[int, typer.Option(help="UDP socket recv buffer size.")] = (
8 if sys.platform == "win32" else 6
)
* 1024
* 1024,
protocol: Annotated[
str, typer.Option(help="Protocol Version. 3.11, 4.0, or 4.1 supported.")
] = "3.11",
Expand All @@ -48,6 +50,12 @@ def main(
help="Duration of buffer for continuous data. Note: buffer may occupy ~15 MB / second."
),
] = 0.5,
microvolts: Annotated[
bool,
typer.Option(
help="Convert continuous data to microvolts (True) or keep raw integers (False)."
),
] = True,
):
source_settings = NSPSourceSettings(
inst_addr,
Expand All @@ -57,6 +65,7 @@ def main(
recv_bufsize,
protocol,
cont_buffer_dur,
microvolts,
)

comps = {
Expand All @@ -75,6 +84,5 @@ def main(
ez.run(components=comps, connections=conns)



if __name__ == "__main__":
typer.run(main)
15 changes: 7 additions & 8 deletions examples/enable_cont.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ def main(
help="Network port to receive packets. This should always be 51002."
),
] = 51002,
recv_bufsize: Annotated[
int, typer.Option(help="UDP socket recv buffer size.")
] = (8 if sys.platform == "win32" else 6) * 1024 * 1024,
recv_bufsize: Annotated[int, typer.Option(help="UDP socket recv buffer size.")] = (
8 if sys.platform == "win32" else 6
)
* 1024
* 1024,
protocol: Annotated[
str, typer.Option(help="Protocol Version. 3.11, 4.0, or 4.1 supported.")
] = "3.11",
):

params = cbsdk.create_params(
inst_addr=inst_addr,
inst_port=inst_port,
Expand All @@ -61,11 +62,9 @@ def main(
k
for k, v in config["channel_infos"].items()
if config["channel_types"][k]
in (CBChannelType.FrontEnd, CBChannelType.AnalogIn)
in (CBChannelType.FrontEnd, CBChannelType.AnalogIn)
]:
_ = cbsdk.set_channel_config(
device, chid, "smpgroup", smp_group
)
_ = cbsdk.set_channel_config(device, chid, "smpgroup", smp_group)
# Refresh config
time.sleep(0.5) # Make sure all the config packets have returned.

Expand Down
43 changes: 33 additions & 10 deletions src/ezmsg/blackrock/nsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import typing

import numpy as np
from pycbsdk import cbsdk
from pycbsdk import cbsdk, cbhw
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray, replace
from ezmsg.event.message import EventMessage
Expand All @@ -24,7 +24,8 @@ class NSPSourceSettings(ez.Settings):
cont_buffer_dur: float = 0.5
"""Duration of continuous buffer to hold recv packets. Up to ~15 MB / second."""

# TODO: convert_uV: bool = False # Returned values are converted to uV (True) or stay raw integers (False)
microvolts: bool = True
"""Convert continuous data to uV (True) or keep raw integers (False)."""


class NSPSourceState(ez.State):
Expand All @@ -39,6 +40,7 @@ class NSPSourceState(ez.State):
template_cont = {
_: AxisArray(data=np.array([[]]), dims=["time", "ch"]) for _ in range(1, 7)
}
scale_cont = {_: np.array([]) for _ in range(1, 7)}
sysfreq: int = 30_000 # Default for pre-Gemini system


Expand Down Expand Up @@ -69,30 +71,50 @@ async def initialize(self) -> None:
_ = cbsdk.register_spk_callback(self.STATE.device, self.on_spike)

for grp_idx in range(1, 7):
n_channels = len(config["group_infos"][grp_idx])
self._reset_buffer(grp_idx, n_channels)
self._reset_buffer(grp_idx, config)
_ = cbsdk.register_group_callback(
self.STATE.device,
grp_idx,
functools.partial(self.on_smp_group, grp_idx=grp_idx),
)

def _reset_buffer(self, grp_idx: int, n_channels: int) -> None:
def _reset_buffer(self, grp_idx: int, config: dict) -> None:
chanset = config["group_infos"][grp_idx]
buff_samples = int(self.SETTINGS.cont_buffer_dur * grp_fs[grp_idx])
n_channels = len(chanset)
self.STATE.cont_buffer[grp_idx] = (
np.zeros((buff_samples,), dtype=int),
np.zeros((buff_samples, n_channels), dtype=np.int16),
)
self.STATE.cont_read_idx[grp_idx] = 0
self.STATE.cont_write_idx[grp_idx] = 0
time_ax = AxisArray.Axis.TimeAxis(grp_fs[grp_idx], offset=0.0)
time_ax = AxisArray.TimeAxis(grp_fs[grp_idx], offset=0.0)

chan_labels = []
scale_factors = []
for ch_idx in chanset:
pkt: cbhw.packet.packets.CBPacketChanInfo = config["channel_infos"][ch_idx]
chan_labels.append(pkt.label.decode("utf-8"))
scale_fac = (pkt.scalin.anamax - pkt.scalin.anamin) / (
pkt.scalin.digmax - pkt.scalin.digmin
)
if pkt.scalin.anaunit.decode("utf-8") == "mV":
scale_fac /= 1000
scale_factors.append(scale_fac)

ch_ax = AxisArray.CoordinateAxis(
data=np.array(chan_labels), dims=["ch"], unit="label"
)
self.STATE.template_cont[grp_idx] = AxisArray(
np.zeros((0, 0)),
dims=["time", "ch"],
axes={"time": time_ax}, # TODO: Ch CoordinateAxis
axes={"time": time_ax, "ch": ch_ax},
key=f"SMP{grp_idx}" if grp_idx < 6 else "RAW",
attrs={"unit": "uV" if self.SETTINGS.microvolts else "raw"},
)

self.STATE.scale_cont[grp_idx] = np.array(scale_factors)

def shutdown(self) -> None:
if hasattr(self.STATE, "device") and self.STATE.device is not None:
self.STATE.device.disconnect()
Expand Down Expand Up @@ -121,15 +143,16 @@ async def pub_cont(self) -> typing.AsyncGenerator:
else:
b_any = True
read_slice = slice(_read_idx, min(buff_len, read_term))
read_view = _buff[1][read_slice]
out_dat = _buff[1][read_slice].copy()
if self.SETTINGS.microvolts:
out_dat = out_dat * self.STATE.scale_cont[grp_idx][None, :]
new_offset: float = _buff[0][_read_idx] / self.STATE.sysfreq
_templ = self.STATE.template_cont[grp_idx]
new_time_ax = replace(_templ.axes["time"], offset=new_offset)
out_msg = replace(
_templ,
data=read_view.copy(), # TODO: Scale to uV. Needs per-channel scale factor.
data=out_dat,
axes={**_templ.axes, **{"time": new_time_ax}},
key=f"SMP{grp_idx}" if grp_idx < 6 else "RAW",
)
self.STATE.cont_read_idx[grp_idx] = read_term % buff_len
yield self.OUTPUT_SIGNAL, out_msg
Expand Down

0 comments on commit c3a5163

Please sign in to comment.