diff --git a/examples/demo.py b/examples/demo.py index 9663ddf..dd75296 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -36,9 +36,11 @@ def main( help="Network port to receive packets. This should always be 51002." ), ] = 51002, - recv_bufsize: Annotated[ - int, typer.Option(help="UDP socket recv buffer size.") - ] = (8 if sys.platform == "win32" else 6) * 1024 * 1024, + recv_bufsize: Annotated[int, typer.Option(help="UDP socket recv buffer size.")] = ( + 8 if sys.platform == "win32" else 6 + ) + * 1024 + * 1024, protocol: Annotated[ str, typer.Option(help="Protocol Version. 3.11, 4.0, or 4.1 supported.") ] = "3.11", @@ -48,6 +50,12 @@ def main( help="Duration of buffer for continuous data. Note: buffer may occupy ~15 MB / second." ), ] = 0.5, + microvolts: Annotated[ + bool, + typer.Option( + help="Convert continuous data to microvolts (True) or keep raw integers (False)." + ), + ] = True, ): source_settings = NSPSourceSettings( inst_addr, @@ -57,6 +65,7 @@ def main( recv_bufsize, protocol, cont_buffer_dur, + microvolts, ) comps = { @@ -75,6 +84,5 @@ def main( ez.run(components=comps, connections=conns) - if __name__ == "__main__": typer.run(main) diff --git a/examples/enable_cont.py b/examples/enable_cont.py index 4b1898e..189190a 100644 --- a/examples/enable_cont.py +++ b/examples/enable_cont.py @@ -36,14 +36,15 @@ def main( help="Network port to receive packets. This should always be 51002." ), ] = 51002, - recv_bufsize: Annotated[ - int, typer.Option(help="UDP socket recv buffer size.") - ] = (8 if sys.platform == "win32" else 6) * 1024 * 1024, + recv_bufsize: Annotated[int, typer.Option(help="UDP socket recv buffer size.")] = ( + 8 if sys.platform == "win32" else 6 + ) + * 1024 + * 1024, protocol: Annotated[ str, typer.Option(help="Protocol Version. 3.11, 4.0, or 4.1 supported.") ] = "3.11", ): - params = cbsdk.create_params( inst_addr=inst_addr, inst_port=inst_port, @@ -61,11 +62,9 @@ def main( k for k, v in config["channel_infos"].items() if config["channel_types"][k] - in (CBChannelType.FrontEnd, CBChannelType.AnalogIn) + in (CBChannelType.FrontEnd, CBChannelType.AnalogIn) ]: - _ = cbsdk.set_channel_config( - device, chid, "smpgroup", smp_group - ) + _ = cbsdk.set_channel_config(device, chid, "smpgroup", smp_group) # Refresh config time.sleep(0.5) # Make sure all the config packets have returned. diff --git a/src/ezmsg/blackrock/nsp.py b/src/ezmsg/blackrock/nsp.py index d55fae5..be28667 100644 --- a/src/ezmsg/blackrock/nsp.py +++ b/src/ezmsg/blackrock/nsp.py @@ -4,7 +4,7 @@ import typing import numpy as np -from pycbsdk import cbsdk +from pycbsdk import cbsdk, cbhw import ezmsg.core as ez from ezmsg.util.messages.axisarray import AxisArray, replace from ezmsg.event.message import EventMessage @@ -24,7 +24,8 @@ class NSPSourceSettings(ez.Settings): cont_buffer_dur: float = 0.5 """Duration of continuous buffer to hold recv packets. Up to ~15 MB / second.""" - # TODO: convert_uV: bool = False # Returned values are converted to uV (True) or stay raw integers (False) + microvolts: bool = True + """Convert continuous data to uV (True) or keep raw integers (False).""" class NSPSourceState(ez.State): @@ -39,6 +40,7 @@ class NSPSourceState(ez.State): template_cont = { _: AxisArray(data=np.array([[]]), dims=["time", "ch"]) for _ in range(1, 7) } + scale_cont = {_: np.array([]) for _ in range(1, 7)} sysfreq: int = 30_000 # Default for pre-Gemini system @@ -69,30 +71,50 @@ async def initialize(self) -> None: _ = cbsdk.register_spk_callback(self.STATE.device, self.on_spike) for grp_idx in range(1, 7): - n_channels = len(config["group_infos"][grp_idx]) - self._reset_buffer(grp_idx, n_channels) + self._reset_buffer(grp_idx, config) _ = cbsdk.register_group_callback( self.STATE.device, grp_idx, functools.partial(self.on_smp_group, grp_idx=grp_idx), ) - def _reset_buffer(self, grp_idx: int, n_channels: int) -> None: + def _reset_buffer(self, grp_idx: int, config: dict) -> None: + chanset = config["group_infos"][grp_idx] buff_samples = int(self.SETTINGS.cont_buffer_dur * grp_fs[grp_idx]) + n_channels = len(chanset) self.STATE.cont_buffer[grp_idx] = ( np.zeros((buff_samples,), dtype=int), np.zeros((buff_samples, n_channels), dtype=np.int16), ) self.STATE.cont_read_idx[grp_idx] = 0 self.STATE.cont_write_idx[grp_idx] = 0 - time_ax = AxisArray.Axis.TimeAxis(grp_fs[grp_idx], offset=0.0) + time_ax = AxisArray.TimeAxis(grp_fs[grp_idx], offset=0.0) + + chan_labels = [] + scale_factors = [] + for ch_idx in chanset: + pkt: cbhw.packet.packets.CBPacketChanInfo = config["channel_infos"][ch_idx] + chan_labels.append(pkt.label.decode("utf-8")) + scale_fac = (pkt.scalin.anamax - pkt.scalin.anamin) / ( + pkt.scalin.digmax - pkt.scalin.digmin + ) + if pkt.scalin.anaunit.decode("utf-8") == "mV": + scale_fac /= 1000 + scale_factors.append(scale_fac) + + ch_ax = AxisArray.CoordinateAxis( + data=np.array(chan_labels), dims=["ch"], unit="label" + ) self.STATE.template_cont[grp_idx] = AxisArray( np.zeros((0, 0)), dims=["time", "ch"], - axes={"time": time_ax}, # TODO: Ch CoordinateAxis + axes={"time": time_ax, "ch": ch_ax}, key=f"SMP{grp_idx}" if grp_idx < 6 else "RAW", + attrs={"unit": "uV" if self.SETTINGS.microvolts else "raw"}, ) + self.STATE.scale_cont[grp_idx] = np.array(scale_factors) + def shutdown(self) -> None: if hasattr(self.STATE, "device") and self.STATE.device is not None: self.STATE.device.disconnect() @@ -121,15 +143,16 @@ async def pub_cont(self) -> typing.AsyncGenerator: else: b_any = True read_slice = slice(_read_idx, min(buff_len, read_term)) - read_view = _buff[1][read_slice] + out_dat = _buff[1][read_slice].copy() + if self.SETTINGS.microvolts: + out_dat = out_dat * self.STATE.scale_cont[grp_idx][None, :] new_offset: float = _buff[0][_read_idx] / self.STATE.sysfreq _templ = self.STATE.template_cont[grp_idx] new_time_ax = replace(_templ.axes["time"], offset=new_offset) out_msg = replace( _templ, - data=read_view.copy(), # TODO: Scale to uV. Needs per-channel scale factor. + data=out_dat, axes={**_templ.axes, **{"time": new_time_ax}}, - key=f"SMP{grp_idx}" if grp_idx < 6 else "RAW", ) self.STATE.cont_read_idx[grp_idx] = read_term % buff_len yield self.OUTPUT_SIGNAL, out_msg