Skip to content

Commit

Permalink
init support for different protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Jan 4, 2025
1 parent 75b3d15 commit 08dae45
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 99 deletions.
2 changes: 0 additions & 2 deletions pioreactor/actions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
4 changes: 1 addition & 3 deletions pioreactor/automations/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import Optional

from msgspec.json import encode

from pioreactor import structs
Expand Down Expand Up @@ -49,7 +47,7 @@ def __init__(self, unit: str, experiment: str) -> None:
def on_init_to_ready(self) -> None:
self.start_passive_listeners()

def execute(self) -> Optional[events.AutomationEvent]:
def execute(self) -> events.AutomationEvent | None:
"""
Overwrite in subclass
"""
Expand Down
48 changes: 28 additions & 20 deletions pioreactor/calibrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from typing import Callable
from typing import Literal
from typing import overload
from typing import Type

from msgspec import ValidationError
from msgspec.yaml import decode as yaml_decode
from msgspec.yaml import encode as yaml_encode

from pioreactor import structs
from pioreactor.types import PumpCalibrationDevices
from pioreactor.utils import local_persistent_storage
from pioreactor.whoami import is_testing_env

Expand All @@ -19,62 +21,72 @@
else:
CALIBRATION_PATH = Path(".pioreactor/storage/calibrations/")

# Lookup table for different calibration assistants
calibration_assistants = {}
# Lookup table for different calibration protocols
calibration_protocols: dict[tuple[str, str], Type[CalibrationProtocol]] = {}


class CalibrationAssistant:
class CalibrationProtocol:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
calibration_assistants[cls.target_device] = cls
calibration_protocols[(cls.target_device, cls.protocol_name)] = cls

def run(self, *args, **kwargs):
raise NotImplementedError("Subclasses must implement this method.")


class ODAssistant(CalibrationAssistant):
class SingleVialODProtocol(CalibrationProtocol):
target_device = "od"
calibration_struct = structs.ODCalibration
protocol_name = "single_vial"

def run(self) -> structs.ODCalibration:
from pioreactor.calibrations.od_calibration import run_od_calibration

return run_od_calibration()


class MediaPumpAssistant(CalibrationAssistant):
class BatchVialODProtocol(CalibrationProtocol):
target_device = "od"
protocol_name = "batch_vial"

def run(self) -> structs.ODCalibration:
from pioreactor.calibrations.od_calibration import run_od_calibration

return run_od_calibration()


class DurationBasedMediaPumpProtocol(CalibrationProtocol):
target_device = "media_pump"
calibration_struct = structs.SimplePeristalticPumpCalibration
protocol_name = "duration_based"

def run(self) -> structs.SimplePeristalticPumpCalibration:
from pioreactor.calibrations.pump_calibration import run_pump_calibration

return run_pump_calibration()


class AltMediaPumpAssistant(CalibrationAssistant):
class DurationBasedAltMediaPumpProtocol(CalibrationProtocol):
target_device = "alt_media_pump"
calibration_struct = structs.SimplePeristalticPumpCalibration
protocol_name = "duration_based"

def run(self) -> structs.SimplePeristalticPumpCalibration:
from pioreactor.calibrations.pump_calibration import run_pump_calibration

return run_pump_calibration()


class WastePumpAssistant(CalibrationAssistant):
class DurationBasedWasteMediaPumpProtocol(CalibrationProtocol):
target_device = "waste_pump"
calibration_struct = structs.SimplePeristalticPumpCalibration
protocol_name = "duration_based"

def run(self) -> structs.SimplePeristalticPumpCalibration:
from pioreactor.calibrations.pump_calibration import run_pump_calibration

return run_pump_calibration()


class StirringAssistant(CalibrationAssistant):
class DCBasedStirringProtocol(CalibrationProtocol):
target_device = "stirring"
calibration_struct = structs.SimpleStirringCalibration
protocol_name = "dc_based"

def run(self, min_dc: str | None = None, max_dc: str | None = None) -> structs.SimpleStirringCalibration:
from pioreactor.calibrations.stirring_calibration import run_stirring_calibration
Expand All @@ -90,9 +102,7 @@ def load_active_calibration(device: Literal["od"]) -> structs.ODCalibration:


@overload
def load_active_calibration(
device: Literal["media_pump", "waste_pump", "alt_media_pump"]
) -> structs.SimplePeristalticPumpCalibration:
def load_active_calibration(device: PumpCalibrationDevices) -> structs.SimplePeristalticPumpCalibration:
pass


Expand All @@ -119,10 +129,8 @@ def load_calibration(device: str, calibration_name: str) -> structs.AnyCalibrati
f"Calibration {calibration_name} was not found in {CALIBRATION_PATH / device}"
)

assistant = calibration_assistants[device]

try:
data = yaml_decode(target_file.read_bytes(), type=assistant.calibration_struct)
data = yaml_decode(target_file.read_bytes(), type=structs.subclass_union(structs.CalibrationBase))
return data
except ValidationError as e:
raise ValidationError(f"Error reading {target_file.stem}: {e}")
35 changes: 7 additions & 28 deletions pioreactor/calibrations/pump_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
from click import echo
from click import prompt
from click import style
from msgspec.json import decode
from msgspec.json import encode
from msgspec.json import format

from pioreactor import structs
from pioreactor.actions.pump import add_alt_media
from pioreactor.actions.pump import add_media
from pioreactor.actions.pump import remove_waste
from pioreactor.calibrations import load_active_calibration
from pioreactor.calibrations.utils import curve_to_callable
from pioreactor.config import config
from pioreactor.hardware import voltage_in_aux
Expand Down Expand Up @@ -106,38 +106,17 @@ def get_metadata_from_user(pump_device: PumpCalibrationDevices) -> str:


def which_pump_are_you_calibrating() -> tuple[PumpCalibrationDevices, Callable]:
with local_persistent_storage("active_calibrations") as cache:
has_media = "media_pump" in cache
has_waste = "waste_pump" in cache
has_alt_media = "alt_media_pump" in cache

if has_media:
media_timestamp = decode(cache["media"], type=structs.SimplePeristalticPumpCalibration).created_at
media_name = decode(
cache["media"], type=structs.SimplePeristalticPumpCalibration
).calibration_name

if has_waste:
waste_timestamp = decode(cache["waste"], type=structs.SimplePeristalticPumpCalibration).created_at
waste_name = decode(
cache["waste"], type=structs.SimplePeristalticPumpCalibration
).calibration_name

if has_alt_media:
alt_media_timestamp = decode(
cache["alt_media"], type=structs.SimplePeristalticPumpCalibration
).created_at
alt_media_name = decode(
cache["alt_media"], type=structs.SimplePeristalticPumpCalibration
).calibration_name
m = load_active_calibration("media_pump")
a = load_active_calibration("alt_media_pump")
w = load_active_calibration("waste_pump")

echo(green(bold("Step 1")))
r = prompt(
green(
f"""Which pump are you calibrating?
1. Media {f'[{media_name}, last ran {media_timestamp:%d %b, %Y}]' if has_media else '[No calibration]'}
2. Alt-media {f'[{alt_media_name}, last ran {alt_media_timestamp:%d %b, %Y}]' if has_alt_media else '[No calibration]'}
3. Waste {f'[{waste_name}, last ran {waste_timestamp:%d %b, %Y}]' if has_waste else '[No calibration]'}
1. Media {f'[{m.calibration_name}, last ran {m.created_at:%d %b, %Y}]' if m else '[No calibration]'}
2. Alt-media {f'[{a.calibration_name}, last ran {a.created_at:%d %b, %Y}]' if a else '[No calibration]'}
3. Waste {f'[{w.calibration_name}, last ran {w.created_at:%d %b, %Y}]' if w else '[No calibration]'}
""",
),
type=click.Choice(["1", "2", "3"]),
Expand Down
59 changes: 36 additions & 23 deletions pioreactor/cli/calibrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from msgspec.yaml import decode as yaml_decode
from msgspec.yaml import encode as yaml_encode

from pioreactor.calibrations import calibration_assistants
from pioreactor import structs
from pioreactor.calibrations import CALIBRATION_PATH
from pioreactor.calibrations import calibration_protocols
from pioreactor.calibrations import load_calibration
from pioreactor.calibrations.utils import curve_to_callable
from pioreactor.calibrations.utils import plot_data
Expand All @@ -16,36 +17,30 @@
@click.group(short_help="calibration utils")
def calibration():
"""
interface for all calibration types.
interface for all calibrations.
"""
pass


@calibration.command(name="list")
@click.option("--device", required=True, help="Filter by calibration type.")
@click.option("--device", required=True)
def list_calibrations(device: str):
"""
List existing calibrations for the given type.
List existing calibrations for the given device.
"""
calibration_dir = CALIBRATION_PATH / device
if not calibration_dir.exists():
click.echo(f"No calibrations found for device '{device}'. Directory does not exist.")
raise click.Abort()

try:
assistant = calibration_assistants[device]
except KeyError:
click.echo(f"No calibrations assistant for type '{device}'.")
raise click.Abort()

header = f"{'Name':<50}{'Created At':<25}{'Active?':<10}{'Location':<75}"
click.echo(header)
click.echo("-" * len(header))

with local_persistent_storage("active_calibrations") as c:
for file in calibration_dir.glob("*.yaml"):
try:
data = yaml_decode(file.read_bytes(), type=assistant.calibration_struct)
data = yaml_decode(file.read_bytes(), type=structs.subclass_union(structs.CalibrationBase))
active = c.get(device) == data.calibration_name
row = f"{data.calibration_name:<50}{data.created_at.strftime('%Y-%m-%d %H:%M:%S'):<25}{'✅' if active else '':<10}{file}"
click.echo(row)
Expand All @@ -55,32 +50,50 @@ def list_calibrations(device: str):


@calibration.command(name="run", context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.option("--device", "device", required=True, help="Type of calibration (e.g. od, pump, stirring).")
@click.option(
"--device", "device", required=True, help="Target device of calibration (e.g. od, pump, stirring)."
)
@click.option("--protocol-name", required=False, help="name of protocol, defaults to basic builtin protocol")
@click.pass_context
def run_calibration(ctx, device: str):
def run_calibration(ctx, device: str, protocol_name: str | None):
"""
Run an interactive calibration assistant for a specific type.
On completion, stores a YAML file in: /home/pioreactor/.pioreactor/storage/calibrations/<type>/<calibration_name>.yaml
Run an interactive calibration assistant for a specific protocol.
On completion, stores a YAML file in: /home/pioreactor/.pioreactor/storage/calibrations/<device>/<calibration_name>.yaml
"""

# Dispatch to the assistant function for that type
assistant = calibration_assistants.get(device)
DEFAULT_PROTOCOLS = {
"od": "single_vial",
"media_pump": "duration_based",
"alt_media_pump": "duration_based",
"waste_pump": "duration_based",
"stirring": "dc_based",
}

# Dispatch to the assistant function for that device
if protocol_name is None and device in DEFAULT_PROTOCOLS:
protocol_name = DEFAULT_PROTOCOLS[device]

assert protocol_name is not None
assistant = calibration_protocols.get((device, protocol_name))
if assistant is None:
click.echo(
f"No assistant found for calibration device '{device}'. Available types: {list(calibration_assistants.keys())}"
f"No protocols found for calibration device '{device}'. Available {device} protocols: {list(c[1] for c in calibration_protocols.keys() if c[0] == device)}"
)
raise click.Abort()

# Run the assistant function to get the final calibration data
calibration_data = assistant().run(
calibration_struct = assistant().run(
**{ctx.args[i][2:].replace("-", "_"): ctx.args[i + 1] for i in range(0, len(ctx.args), 2)},
)
calibration_name = calibration_data.calibration_name

out_file = calibration_data.save_to_disk(device)
calibration_data.set_as_active_calibration_for_device(device)
out_file = calibration_struct.save_to_disk(device)
calibration_struct.set_as_active_calibration_for_device(device)

# post to leader??

click.echo(f"Calibration '{calibration_name}' of device '{device}' saved to {out_file}")
click.echo(
f"Calibration '{calibration_struct.calibration_name}' of device '{device}' saved to {out_file} ✅"
)


@calibration.command(name="display")
Expand Down
23 changes: 0 additions & 23 deletions pioreactor/tests/test_calibrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,9 @@
import pytest
from msgspec import ValidationError

from pioreactor.calibrations import AltMediaPumpAssistant
from pioreactor.calibrations import calibration_assistants
from pioreactor.calibrations import CALIBRATION_PATH
from pioreactor.calibrations import load_active_calibration
from pioreactor.calibrations import load_calibration
from pioreactor.calibrations import MediaPumpAssistant
from pioreactor.calibrations import ODAssistant
from pioreactor.calibrations import StirringAssistant
from pioreactor.calibrations import WastePumpAssistant
from pioreactor.structs import ODCalibration
from pioreactor.utils import local_persistent_storage

Expand All @@ -28,23 +22,6 @@ def temp_calibration_dir():
yield calibrations_dir


def test_calibration_assistants_dict() -> None:
assert "od" in calibration_assistants
assert calibration_assistants["od"] is ODAssistant

assert "media_pump" in calibration_assistants
assert calibration_assistants["media_pump"] is MediaPumpAssistant

assert "alt_media_pump" in calibration_assistants
assert calibration_assistants["alt_media_pump"] is AltMediaPumpAssistant

assert "waste_pump" in calibration_assistants
assert calibration_assistants["waste_pump"] is WastePumpAssistant

assert "stirring" in calibration_assistants
assert calibration_assistants["stirring"] is StirringAssistant


def test_save_and_load_calibration(temp_calibration_dir) -> None:
# 1. Create an ODCalibration object (fully valid).
od_cal = ODCalibration(
Expand Down

0 comments on commit 08dae45

Please sign in to comment.