task_with function for creating task variants (#1099)
* task_with implementation

* proper handling of name

* consolidate task member assignment logic

* require keyword args for task_with

* update changelog

* mypy fix

* correct import for override
jjallaire authored Jan 10, 2025
1 parent d6592fd commit c4f6d40
Showing 7 changed files with 195 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@
 
 ## Unreleased
 
+- `task_with()` function for creating task variants.
 - Print model conversations to terminal with `--display=conversation` (was formerly `--trace`, which is now deprecated).
 
 ## v0.3.57 (09 January 2025)
3 changes: 2 additions & 1 deletion src/inspect_ai/__init__.py
@@ -7,7 +7,7 @@
 from inspect_ai._eval.list import list_tasks
 from inspect_ai._eval.registry import task
 from inspect_ai._eval.score import score, score_async
-from inspect_ai._eval.task import Epochs, Task, TaskInfo, Tasks
+from inspect_ai._eval.task import Epochs, Task, TaskInfo, Tasks, task_with
 from inspect_ai._util.constants import PKG_NAME
 from inspect_ai.solver._human_agent.agent import human_agent
 
@@ -29,4 +29,5 @@
     "TaskInfo",
     "Tasks",
     "task",
+    "task_with",
 ]
4 changes: 2 additions & 2 deletions src/inspect_ai/_eval/task/__init__.py
@@ -1,4 +1,4 @@
-from .task import Task, TaskInfo, PreviousTask, Tasks # noqa: I001, F401
+from .task import Task, TaskInfo, PreviousTask, Tasks, task_with # noqa: I001, F401
 from .epochs import Epochs
 
-__all__ = ["Epochs", "Task", "TaskInfo", "PreviousTask", "Tasks"]
+__all__ = ["Epochs", "Task", "TaskInfo", "PreviousTask", "Tasks", "task_with"]
165 changes: 140 additions & 25 deletions src/inspect_ai/_eval/task/task.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from dataclasses import dataclass
 from logging import getLogger
 from typing import Any, Callable, Sequence, cast
@@ -6,6 +7,7 @@
 from typing_extensions import TypedDict, Unpack
 
 from inspect_ai._util.logger import warn_once
+from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
 from inspect_ai._util.registry import is_registry_object, registry_info
 from inspect_ai.approval._policy import ApprovalPolicy, approval_policies_from_config
 from inspect_ai.dataset import Dataset, MemoryDataset, Sample
@@ -115,35 +117,15 @@ def __init__(
                     f"DEPRECATED: the '{arg}' parameter is deprecated (please use the '{newarg}' parameter instead)",
                 )
 
-        # resolve epochs / epochs_reducer
-        if isinstance(epochs, int):
-            epochs = Epochs(epochs)
-        if epochs is not None and epochs.epochs < 1:
-            raise ValueError("epochs must be a positive integer.")
-
-        # resolve dataset (provide empty sample to bootstrap tasks w/o samples,
-        # which could occur for testing or for an interactive mode eval)
-        dataset = dataset or [Sample(input="prompt")]
-        self.dataset: Dataset = (
-            dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset))
-        )
+        self.dataset = resolve_dataset(dataset)
         self.setup = setup
-        self.solver = chain(solver) if isinstance(solver, list) else solver
-        self.scorer = (
-            scorer
-            if isinstance(scorer, list)
-            else [scorer]
-            if scorer is not None
-            else None
-        )
+        self.solver = resolve_solver(solver)
+        self.scorer = resolve_scorer(scorer)
         self.metrics = metrics
         self.config = config
         self.sandbox = resolve_sandbox_environment(sandbox)
-        self.approval = (
-            approval_policies_from_config(approval)
-            if isinstance(approval, str)
-            else approval
-        )
+        self.approval = resolve_approval(approval)
+        epochs = resolve_epochs(epochs)
         self.epochs = epochs.epochs if epochs else None
         self.epochs_reducer = epochs.reducer if epochs else None
         self.fail_on_error = fail_on_error
@@ -171,6 +153,106 @@ def attribs(self) -> dict[str, Any]:
         return dict()
 
 
+def task_with(
+    task: Task,
+    *,
+    dataset: Dataset | Sequence[Sample] | None | NotGiven = NOT_GIVEN,
+    setup: Solver | list[Solver] | None | NotGiven = NOT_GIVEN,
+    solver: Solver | list[Solver] | NotGiven = NOT_GIVEN,
+    scorer: Scorer | list[Scorer] | None | NotGiven = NOT_GIVEN,
+    metrics: list[Metric] | dict[str, list[Metric]] | None | NotGiven = NOT_GIVEN,
+    config: GenerateConfig | NotGiven = NOT_GIVEN,
+    sandbox: SandboxEnvironmentType | None | NotGiven = NOT_GIVEN,
+    approval: str | list[ApprovalPolicy] | None | NotGiven = NOT_GIVEN,
+    epochs: int | Epochs | None | NotGiven = NOT_GIVEN,
+    fail_on_error: bool | float | None | NotGiven = NOT_GIVEN,
+    message_limit: int | None | NotGiven = NOT_GIVEN,
+    token_limit: int | None | NotGiven = NOT_GIVEN,
+    time_limit: int | None | NotGiven = NOT_GIVEN,
+    name: str | None | NotGiven = NOT_GIVEN,
+    version: int | NotGiven = NOT_GIVEN,
+    metadata: dict[str, Any] | None | NotGiven = NOT_GIVEN,
+) -> Task:
+    """Task adapted with alternate values for one or more options.
+
+    Args:
+        task (Task): Task to adapt (it is deep copied prior to mutating options).
+        dataset (Dataset | Sequence[Sample]): Dataset to evaluate.
+        setup (Solver | list[Solver] | None): Setup step (always run
+            even when the main `solver` is replaced).
+        solver (Solver | list[Solver]): Solver or list of solvers.
+            Defaults to generate(), a normal call to the model.
+        scorer (Scorer | list[Scorer] | None): Scorer used to evaluate model output.
+        metrics (list[Metric] | dict[str, list[Metric]] | None):
+            Alternative metrics (overrides the metrics provided by the specified scorer).
+        config (GenerateConfig): Model generation config.
+        sandbox (SandboxEnvironmentType | None): Sandbox environment type
+            (or optionally a str or tuple with a shorthand spec).
+        approval (str | list[ApprovalPolicy] | None): Tool use approval policies.
+            Either a path to an approval policy config file or a list of approval policies.
+            Defaults to no approval policy.
+        epochs (int | Epochs | None): Epochs to repeat samples for and optional score
+            reducer function(s) used to combine sample scores (defaults to "mean").
+        fail_on_error (bool | float | None): `True` to fail on first sample error
+            (default); `False` to never fail on sample errors; value between 0 and 1
+            to fail if a proportion of total samples fails; value greater than 1 to
+            fail the eval if that many samples fail.
+        message_limit (int | None): Limit on total messages used for each sample.
+        token_limit (int | None): Limit on total tokens used for each sample.
+        time_limit (int | None): Limit on time (in seconds) for execution of each sample.
+        name (str | None): Task name. If not specified it is automatically
+            determined based on the name of the task directory (or "task")
+            if it's an anonymous task (e.g. created in a notebook and passed to
+            eval() directly).
+        version (int): Version of task (to distinguish evolutions
+            of the task spec or breaking changes to it).
+        metadata (dict[str, Any] | None): Additional metadata to associate with the task.
+
+    Returns:
+        Task: Task adapted with alternate options.
+    """
+    # deep copy the task so the original is left unmodified
+    task = deepcopy(task)
+
+    if not isinstance(dataset, NotGiven):
+        task.dataset = resolve_dataset(dataset)
+    if not isinstance(setup, NotGiven):
+        task.setup = setup
+    if not isinstance(solver, NotGiven):
+        task.solver = resolve_solver(solver)
+    if not isinstance(scorer, NotGiven):
+        task.scorer = resolve_scorer(scorer)
+    if not isinstance(metrics, NotGiven):
+        task.metrics = metrics
+    if not isinstance(config, NotGiven):
+        task.config = config
+    if not isinstance(sandbox, NotGiven):
+        task.sandbox = resolve_sandbox_environment(sandbox)
+    if not isinstance(approval, NotGiven):
+        task.approval = resolve_approval(approval)
+    if not isinstance(epochs, NotGiven):
+        epochs = resolve_epochs(epochs)
+        task.epochs = epochs.epochs if epochs else None
+        task.epochs_reducer = epochs.reducer if epochs else None
+    if not isinstance(fail_on_error, NotGiven):
+        task.fail_on_error = fail_on_error
+    if not isinstance(message_limit, NotGiven):
+        task.message_limit = message_limit
+    if not isinstance(token_limit, NotGiven):
+        task.token_limit = token_limit
+    if not isinstance(time_limit, NotGiven):
+        task.time_limit = time_limit
+    if not isinstance(version, NotGiven):
+        task.version = version
+    if not isinstance(name, NotGiven):
+        task._name = name
+    if not isinstance(metadata, NotGiven):
+        task.metadata = metadata
+
+    # return modified task
+    return task
+
+
 class TaskInfo(BaseModel):
     """Task information (file, name, and attributes)."""
 
@@ -225,3 +307,36 @@ class PreviousTask:
         can be specified). None is a request to read a task out
         of the current working directory.
     """
+
+
+def resolve_approval(
+    approval: str | list[ApprovalPolicy] | None,
+) -> list[ApprovalPolicy] | None:
+    return (
+        approval_policies_from_config(approval)
+        if isinstance(approval, str)
+        else approval
+    )
+
+
+def resolve_epochs(epochs: int | Epochs | None) -> Epochs | None:
+    if isinstance(epochs, int):
+        epochs = Epochs(epochs)
+    if epochs is not None and epochs.epochs < 1:
+        raise ValueError("epochs must be a positive integer.")
+    return epochs
+
+
+def resolve_dataset(dataset: Dataset | Sequence[Sample] | None) -> Dataset:
+    dataset = dataset or [Sample(input="prompt")]
+    return dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset))
+
+
+def resolve_solver(solver: Solver | list[Solver]) -> Solver:
+    return chain(solver) if isinstance(solver, list) else solver
+
+
+def resolve_scorer(scorer: Scorer | list[Scorer] | None) -> list[Scorer] | None:
+    return (
+        scorer if isinstance(scorer, list) else [scorer] if scorer is not None else None
+    )
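
Taken together, the diff above lets callers derive a variant of an existing task without redefining it. A minimal sketch of how the new function might be used (the base task, sample values, and option values here are illustrative, not part of this commit):

from inspect_ai import Task, task_with
from inspect_ai.dataset import Sample
from inspect_ai.scorer import includes
from inspect_ai.solver import generate

# illustrative base task
base = Task(
    dataset=[Sample(input="What is 2+2?", target="4")],
    solver=[generate()],
    scorer=includes(),
)

# variant with a per-sample time limit added and the scorer removed:
# passing scorer=None clears the scorer, while omitting it would keep the original
variant = task_with(base, time_limit=30, scorer=None)

Note that everything after the task itself is keyword-only (per the "require keyword args for task_with" commit above), so options must be passed by name.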
18 changes: 18 additions & 0 deletions src/inspect_ai/_util/notgiven.py
@@ -0,0 +1,18 @@
+# Sentinel class used until PEP 0661 is accepted
+from typing import Literal
+
+from typing_extensions import override
+
+
+class NotGiven:
+    """A sentinel singleton class used to distinguish omitted keyword arguments from those passed in with the value None (which may have different behavior)."""
+
+    def __bool__(self) -> Literal[False]:
+        return False
+
+    @override
+    def __repr__(self) -> str:
+        return "NOT_GIVEN"
+
+
+NOT_GIVEN = NotGiven()
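
The sentinel is what lets task_with() distinguish an argument that was omitted (keep the task's current value) from one explicitly set to None (clear the value). A small self-contained sketch of the pattern; the demo function is hypothetical and exists only to illustrate the three cases:

from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven

# hypothetical function showing the three cases the sentinel separates
def demo(scorer: str | None | NotGiven = NOT_GIVEN) -> str:
    if isinstance(scorer, NotGiven):
        return "keep existing scorer"  # argument omitted entirely
    elif scorer is None:
        return "remove the scorer"  # explicit None
    else:
        return "replace the scorer"

assert demo() == "keep existing scorer"
assert demo(None) == "remove the scorer"
assert demo("includes") == "replace the scorer"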
1 change: 1 addition & 0 deletions tests/test_helpers/tasks.py
@@ -17,6 +17,7 @@ def minimal_task() -> Task:
         solver=[generate()],
         scorer=includes(),
         metadata={"task_idx": 1},
+        name="minimal",
     )
 
 
31 changes: 31 additions & 0 deletions tests/test_task_with.py
@@ -0,0 +1,31 @@
+from test_helpers.tasks import minimal_task
+
+from inspect_ai import task_with
+
+
+def test_task_with_add_options():
+    task = task_with(minimal_task(), time_limit=30)
+    assert task.time_limit == 30
+    assert task.metadata is not None
+
+
+def test_task_with_remove_options():
+    task = task_with(
+        minimal_task(),
+        scorer=None,
+    )
+    assert task.scorer is None
+    assert task.metadata is not None
+
+
+def test_task_with_edit_options():
+    task = task_with(
+        minimal_task(),
+        metadata={"foo": "bar"},
+    )
+    assert task.metadata == {"foo": "bar"}
+
+
+def test_task_with_name_option():
+    task = task_with(minimal_task(), name="changed")
+    assert task.name == "changed"