From c4f6d40812c12df2e0d5b6e0d4b97307f07e0f49 Mon Sep 17 00:00:00 2001
From: jjallaire
Date: Fri, 10 Jan 2025 00:31:07 +0000
Subject: [PATCH] task_with function for creating task variants (#1099)

* task_with implementation

* proper handling of name

* consolidate task member assignment logic

* require keyword args for task_with

* update changelog

* mypy fix

* correct import for override
---
 CHANGELOG.md                          |   1 +
 src/inspect_ai/__init__.py            |   3 +-
 src/inspect_ai/_eval/task/__init__.py |   4 +-
 src/inspect_ai/_eval/task/task.py     | 165 ++++++++++++++++++++++----
 src/inspect_ai/_util/notgiven.py      |  18 +++
 tests/test_helpers/tasks.py           |   1 +
 tests/test_task_with.py               |  31 +++++
 7 files changed, 195 insertions(+), 28 deletions(-)
 create mode 100644 src/inspect_ai/_util/notgiven.py
 create mode 100644 tests/test_task_with.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e96ab16dd..8b28bf8e2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,7 @@
 
 ## Unreleased
 
+- `task_with()` function for creating task variants.
 - Print model conversations to terminal with `--display=conversation` (was formerly `--trace`, which is now deprecated).
 
 ## v0.3.57 (09 January 2025)
diff --git a/src/inspect_ai/__init__.py b/src/inspect_ai/__init__.py
index 570b6d56c..5f0af8495 100644
--- a/src/inspect_ai/__init__.py
+++ b/src/inspect_ai/__init__.py
@@ -7,7 +7,7 @@
 from inspect_ai._eval.list import list_tasks
 from inspect_ai._eval.registry import task
 from inspect_ai._eval.score import score, score_async
-from inspect_ai._eval.task import Epochs, Task, TaskInfo, Tasks
+from inspect_ai._eval.task import Epochs, Task, TaskInfo, Tasks, task_with
 from inspect_ai._util.constants import PKG_NAME
 from inspect_ai.solver._human_agent.agent import human_agent
 
@@ -29,4 +29,5 @@
     "TaskInfo",
     "Tasks",
     "task",
+    "task_with",
 ]
diff --git a/src/inspect_ai/_eval/task/__init__.py b/src/inspect_ai/_eval/task/__init__.py
index d84618113..8cf723341 100644
--- a/src/inspect_ai/_eval/task/__init__.py
+++ b/src/inspect_ai/_eval/task/__init__.py
@@ -1,4 +1,4 @@
-from .task import Task, TaskInfo, PreviousTask, Tasks  # noqa: I001, F401
+from .task import Task, TaskInfo, PreviousTask, Tasks, task_with  # noqa: I001, F401
 from .epochs import Epochs
 
-__all__ = ["Epochs", "Task", "TaskInfo", "PreviousTask", "Tasks"]
+__all__ = ["Epochs", "Task", "TaskInfo", "PreviousTask", "Tasks", "task_with"]
diff --git a/src/inspect_ai/_eval/task/task.py b/src/inspect_ai/_eval/task/task.py
index ac2e8fbf3..d5b81ea12 100644
--- a/src/inspect_ai/_eval/task/task.py
+++ b/src/inspect_ai/_eval/task/task.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from dataclasses import dataclass
 from logging import getLogger
 from typing import Any, Callable, Sequence, cast
@@ -6,6 +7,7 @@
 from typing_extensions import TypedDict, Unpack
 
 from inspect_ai._util.logger import warn_once
+from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
 from inspect_ai._util.registry import is_registry_object, registry_info
 from inspect_ai.approval._policy import ApprovalPolicy, approval_policies_from_config
 from inspect_ai.dataset import Dataset, MemoryDataset, Sample
@@ -115,35 +117,15 @@ def __init__(
                     f"DEPRECATED: the '{arg}' parameter is deprecated (please use the '{newarg}' parameter instead)",
                 )
 
-        # resolve epochs / epochs_reducer
-        if isinstance(epochs, int):
-            epochs = Epochs(epochs)
-        if epochs is not None and epochs.epochs < 1:
-            raise ValueError("epochs must be a positive integer.")
-
-        # resolve dataset (provide empty sample to bootstrap tasks w/o samples,
-        # which could occur for testing or for an interactive mode eval)
-        dataset = dataset or [Sample(input="prompt")]
-        self.dataset: Dataset = (
-            dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset))
-        )
+        self.dataset = resolve_dataset(dataset)
         self.setup = setup
-        self.solver = chain(solver) if isinstance(solver, list) else solver
-        self.scorer = (
-            scorer
-            if isinstance(scorer, list)
-            else [scorer]
-            if scorer is not None
-            else None
-        )
+        self.solver = resolve_solver(solver)
+        self.scorer = resolve_scorer(scorer)
         self.metrics = metrics
         self.config = config
         self.sandbox = resolve_sandbox_environment(sandbox)
-        self.approval = (
-            approval_policies_from_config(approval)
-            if isinstance(approval, str)
-            else approval
-        )
+        self.approval = resolve_approval(approval)
+        epochs = resolve_epochs(epochs)
         self.epochs = epochs.epochs if epochs else None
         self.epochs_reducer = epochs.reducer if epochs else None
         self.fail_on_error = fail_on_error
@@ -171,6 +153,106 @@ def attribs(self) -> dict[str, Any]:
             return dict()
 
 
+def task_with(
+    task: Task,
+    *,
+    dataset: Dataset | Sequence[Sample] | None | NotGiven = NOT_GIVEN,
+    setup: Solver | list[Solver] | None | NotGiven = NOT_GIVEN,
+    solver: Solver | list[Solver] | NotGiven = NOT_GIVEN,
+    scorer: Scorer | list[Scorer] | None | NotGiven = NOT_GIVEN,
+    metrics: list[Metric] | dict[str, list[Metric]] | None | NotGiven = NOT_GIVEN,
+    config: GenerateConfig | NotGiven = NOT_GIVEN,
+    sandbox: SandboxEnvironmentType | None | NotGiven = NOT_GIVEN,
+    approval: str | list[ApprovalPolicy] | None | NotGiven = NOT_GIVEN,
+    epochs: int | Epochs | None | NotGiven = NOT_GIVEN,
+    fail_on_error: bool | float | None | NotGiven = NOT_GIVEN,
+    message_limit: int | None | NotGiven = NOT_GIVEN,
+    token_limit: int | None | NotGiven = NOT_GIVEN,
+    time_limit: int | None | NotGiven = NOT_GIVEN,
+    name: str | None | NotGiven = NOT_GIVEN,
+    version: int | NotGiven = NOT_GIVEN,
+    metadata: dict[str, Any] | None | NotGiven = NOT_GIVEN,
+) -> Task:
+    """Task adapted with alternate values for one or more options.
+
+    Args:
+        task (Task): Task to adapt (it is deep copied prior to mutating options)
+        dataset (Dataset | Sequence[Sample]): Dataset to evaluate
+        setup: (Solver | list[Solver] | None): Setup step (always run
+          even when the main `solver` is replaced).
+        solver: (Solver | list[Solver]): Solver or list of solvers.
+          Defaults to generate(), a normal call to the model.
+        scorer: (Scorer | list[Scorer] | None): Scorer used to evaluate model output.
+        metrics (list[Metric] | dict[str, list[Metric]] | None):
+          Alternative metrics (overrides the metrics provided by the specified scorer).
+        config (GenerateConfig): Model generation config.
+        sandbox (SandboxEnvironmentType | None): Sandbox environment type
+          (or optionally a str or tuple with a shorthand spec)
+        approval: (str | list[ApprovalPolicy] | None): Tool use approval policies.
+          Either a path to an approval policy config file or a list of approval policies.
+          Defaults to no approval policy.
+        epochs (int | Epochs | None): Epochs to repeat samples for and optional score
+          reducer function(s) used to combine sample scores (defaults to "mean")
+        fail_on_error (bool | float | None): `True` to fail on first sample error
+          (default); `False` to never fail on sample errors; a value between 0 and 1
+          to fail if a proportion of total samples fails; a value greater than 1 to
+          fail the eval if a count of samples fails.
+        message_limit (int | None): Limit on total messages used for each sample.
+        token_limit (int | None): Limit on total tokens used for each sample.
+        time_limit (int | None): Limit on time (in seconds) for execution of each sample.
+        name: (str | None): Task name. If not specified it is automatically
+          determined based on the name of the task directory (or "task"
+          if it is an anonymous task, e.g. created in a notebook and passed to
+          eval() directly).
+        version: (int): Version of task (to distinguish evolutions
+          of the task spec or breaking changes to it)
+        metadata: (dict[str, Any] | None): Additional metadata to associate with the task.
+
+    Returns:
+        Task: Task adapted with alternate options.
+    """
+    # deep copy the task
+    task = deepcopy(task)
+
+    if not isinstance(dataset, NotGiven):
+        task.dataset = resolve_dataset(dataset)
+    if not isinstance(setup, NotGiven):
+        task.setup = setup
+    if not isinstance(solver, NotGiven):
+        task.solver = resolve_solver(solver)
+    if not isinstance(scorer, NotGiven):
+        task.scorer = resolve_scorer(scorer)
+    if not isinstance(metrics, NotGiven):
+        task.metrics = metrics
+    if not isinstance(config, NotGiven):
+        task.config = config
+    if not isinstance(sandbox, NotGiven):
+        task.sandbox = resolve_sandbox_environment(sandbox)
+    if not isinstance(approval, NotGiven):
+        task.approval = resolve_approval(approval)
+    if not isinstance(epochs, NotGiven):
+        epochs = resolve_epochs(epochs)
+        task.epochs = epochs.epochs if epochs else None
+        task.epochs_reducer = epochs.reducer if epochs else None
+    if not isinstance(fail_on_error, NotGiven):
+        task.fail_on_error = fail_on_error
+    if not isinstance(message_limit, NotGiven):
+        task.message_limit = message_limit
+    if not isinstance(token_limit, NotGiven):
+        task.token_limit = token_limit
+    if not isinstance(time_limit, NotGiven):
+        task.time_limit = time_limit
+    if not isinstance(version, NotGiven):
+        task.version = version
+    if not isinstance(name, NotGiven):
+        task._name = name
+    if not isinstance(metadata, NotGiven):
+        task.metadata = metadata
+
+    # return modified task
+    return task
+
+
 class TaskInfo(BaseModel):
     """Task information (file, name, and attributes)."""
 
@@ -225,3 +307,36 @@ class PreviousTask:
 can be specified). None is a request to read a task out
 of the current working directory.
 """
+
+
+def resolve_approval(
+    approval: str | list[ApprovalPolicy] | None,
+) -> list[ApprovalPolicy] | None:
+    return (
+        approval_policies_from_config(approval)
+        if isinstance(approval, str)
+        else approval
+    )
+
+
+def resolve_epochs(epochs: int | Epochs | None) -> Epochs | None:
+    if isinstance(epochs, int):
+        epochs = Epochs(epochs)
+    if epochs is not None and epochs.epochs < 1:
+        raise ValueError("epochs must be a positive integer.")
+    return epochs
+
+
+def resolve_dataset(dataset: Dataset | Sequence[Sample] | None) -> Dataset:
+    dataset = dataset or [Sample(input="prompt")]
+    return dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset))
+
+
+def resolve_solver(solver: Solver | list[Solver]) -> Solver:
+    return chain(solver) if isinstance(solver, list) else solver
+
+
+def resolve_scorer(scorer: Scorer | list[Scorer] | None) -> list[Scorer] | None:
+    return (
+        scorer if isinstance(scorer, list) else [scorer] if scorer is not None else None
+    )
diff --git a/src/inspect_ai/_util/notgiven.py b/src/inspect_ai/_util/notgiven.py
new file mode 100644
index 000000000..7d83a2b22
--- /dev/null
+++ b/src/inspect_ai/_util/notgiven.py
@@ -0,0 +1,18 @@
+# Sentinel class used until PEP 0661 is accepted
+from typing import Literal
+
+from typing_extensions import override
+
+
+class NotGiven:
+    """A sentinel singleton class used to distinguish omitted keyword arguments from those passed in with the value None (which may have different behavior)."""
+
+    def __bool__(self) -> Literal[False]:
+        return False
+
+    @override
+    def __repr__(self) -> str:
+        return "NOT_GIVEN"
+
+
+NOT_GIVEN = NotGiven()
diff --git a/tests/test_helpers/tasks.py b/tests/test_helpers/tasks.py
index 2f0e4ebb9..cab30919e 100644
--- a/tests/test_helpers/tasks.py
+++ b/tests/test_helpers/tasks.py
@@ -17,6 +17,7 @@ def minimal_task() -> Task:
         solver=[generate()],
         scorer=includes(),
         metadata={"task_idx": 1},
+        name="minimal",
     )
 
 
diff --git a/tests/test_task_with.py b/tests/test_task_with.py
new file mode 100644
index 000000000..84a1bbdd9
--- /dev/null
+++ b/tests/test_task_with.py
@@ -0,0 +1,31 @@
+from test_helpers.tasks import minimal_task
+
+from inspect_ai import task_with
+
+
+def test_task_with_add_options():
+    task = task_with(minimal_task(), time_limit=30)
+    assert task.time_limit == 30
+    assert task.metadata is not None
+
+
+def test_task_with_remove_options():
+    task = task_with(
+        minimal_task(),
+        scorer=None,
+    )
+    assert task.scorer is None
+    assert task.metadata is not None
+
+
+def test_task_with_edit_options():
+    task = task_with(
+        minimal_task(),
+        metadata={"foo": "bar"},
+    )
+    assert task.metadata == {"foo": "bar"}
+
+
+def test_task_with_name_option():
+    task = task_with(minimal_task(), name="changed")
+    assert task.name == "changed"