diff --git a/CHANGELOG.md b/CHANGELOG.md index 092412a27..ddac9905c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # Changelog +## v0.3.9 (14 May 2024) + +- Add `ollama` local model provider. +- Add `multi_scorer()` and `majority_vote()` functions for combining multiple scorers into a single score. +- Add support for multiple model graders in `model_graded_qa()`. +- Raise `TypeError` for solvers and scorers not declared as `async`. +- Fall back to standard parse if `NaN` or `Inf` is encountered while reading log file header. +- Remove deprecated support for matching partial model names (e.g. "gpt" or "claude"). + ## v0.3.8 (07 May 2024) - Exclude null config values from listings in log viewer. diff --git a/README.md b/README.md index a98b0d5fe..ce1b86244 100644 --- a/README.md +++ b/README.md @@ -18,4 +18,4 @@ $ cd inspect_ai $ pip install -e ".[dev]" ``` -If you use VS Code, you should be sure to have installed the recommended extensions (Python, Ruff, and MyPy). Note that you'll be promoted to install these when you open the project in VS Code. +If you use VS Code, be sure to install the recommended extensions (Python, Ruff, and MyPy). Note that you'll be prompted to install these when you open the project in VS Code. diff --git a/docs/eval-suites.qmd b/docs/eval-suites.qmd index dce742dc0..d811e38ed 100644 --- a/docs/eval-suites.qmd +++ b/docs/eval-suites.qmd @@ -46,7 +46,7 @@ if __name__ == "__main__": eval(security_guide, model="google/gemini-1.0-pro") ``` -Doing this allows your source file to be both a Python script that is convenient to run during development as well as be a Python module that tasks can be read from without executing the eval. There is no real downside to this, and it's a good way in general to write all of your eval scripts and notebooks (see the docs on [\_\_main\_\_](https://docs.python.org/3/library/main.html) for additional details.) +Doing this allows your source file to be both a Python script that is convenient to run during development and a Python module that tasks can be read from without executing the eval. There is no real downside to this, and it's a good way in general to write all of your eval scripts and notebooks (see the docs on [\_\_main\_\_](https://docs.python.org/3/library/__main__.html) for additional details). ## Use Cases diff --git a/docs/index.qmd b/docs/index.qmd index 51cf2230f..5f124ca1c 100644 --- a/docs/index.qmd +++ b/docs/index.qmd @@ -75,7 +75,7 @@ $ inspect eval ctf.py --model together/Qwen/Qwen1.5-72B-Chat ``` ::: -In addition to the model providers shown above, Inspect also supports models hosted on Azure AI, AWS Bedrock, and Cloudflare. See the documentation on [Models](#sec-models) for additional details. +In addition to the model providers shown above, Inspect also supports models hosted on Azure AI, AWS Bedrock, and Cloudflare, as well as local models with Ollama. See the documentation on [Models](#sec-models) for additional details. 
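For example, assuming you have Ollama running locally and have pulled the `llama3` model, you could select it in the same way as any other provider:

``` bash
$ inspect eval ctf.py --model ollama/llama3
```
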
## Hello, Inspect {#sec-hello-inspect} diff --git a/docs/models.qmd b/docs/models.qmd index daff06653..b281483f5 100644 --- a/docs/models.qmd +++ b/docs/models.qmd @@ -11,6 +11,7 @@ Inspect has built in support for a variety of language model API providers and c | Google | `pip install google-generativeai` | `GOOGLE_API_KEY` | | Mistral | `pip install mistralai` | `MISTRAL_API_KEY` | | Hugging Face | `pip install transformers` | `HF_TOKEN` | +| Ollama | `pip install openai` | None required | | TogetherAI | `pip install openai` | `TOGETHER_API_KEY` | | AWS Bedrock | `pip install boto3` | `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_DEFAULT_REGION` | | Azure AI | None required | `AZURE_API_KEY` and `INSPECT_EVAL_MODEL_BASE_URL` | @@ -18,6 +19,10 @@ Inspect has built in support for a variety of language model API providers and c : {tbl-colwidths="\[18,45,37\]"} +::: {.callout-note appearance="minimal"} +Note that some providers ([Ollama](https://github.com/ollama/ollama/blob/main/docs/openai.md) and [TogetherAI](https://docs.together.ai/docs/openai-api-compatibility)) support the OpenAI Python package as a client, which is why you need to `pip install openai` for these providers even though you aren't actually interacting with the OpenAI service when you use them. +::: + ## Using Models To select a model for use in an evaluation task you specify it using a *model name*. Model names include their API provider and the specific model to use (e.g. `openai/gpt-4`) Here are the supported providers along with example model names and links to documentation on all available models: @@ -29,6 +34,7 @@ To select a model for use in an evaluation task you specify it using a *model na | Google | `google/gemini-1.0-pro` | [Google Models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models) | | Mistral | `mistral/mistral-large-latest` | [Mistral Models](https://docs.mistral.ai/platform/endpoints/) | | Hugging Face | `hf/openai-community/gpt2` | [Hugging Face Models](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) | +| Ollama | `ollama/llama3` | [Ollama Models](https://ollama.com/library) | | TogetherAI | `together/lmsys/vicuna-13b-v1.5` | [TogetherAI Models](https://docs.together.ai/docs/inference-models#chat-models) | | AWS Bedrock | `bedrock/meta.llama2-70b-chat-v1` | [AWS Bedrock Models](https://aws.amazon.com/bedrock/) | | Azure AI | `azureai/azure-deployment-name` | [Azure AI Models](https://ai.azure.com/explore/models) | @@ -71,6 +77,7 @@ Each model also can use a different base URL than the default (e.g. if running t | Google | `GOOGLE_BASE_URL` | | Mistral | `MISTRAL_BASE_URL` | | TogetherAI | `TOGETHER_BASE_URL` | +| Ollama | `OLLAMA_BASE_URL` | | AWS Bedrock | `BEDROCK_BASE_URL` | | Azure AI | `AZUREAI_BASE_URL` | | Cloudflare | `CLOUDFLARE_BASE_URL` | @@ -310,13 +317,14 @@ The additional `model_args` are forwarded as follows for the various providers: | Google | `genai.configure` | | Mistral | `MistralAsyncClient` | | Hugging Face | `AutoModelForCausalLM.from_pretrained` | +| Ollama | `AsyncOpenAI` | | TogetherAI | `AsyncOpenAI` | | AzureAI | Chat HTTP Post Body | | Cloudflare | Chat HTTP Post Body | : {tbl-colwidths="\[30,70\]"} -See the OpenAI, Anthropic, Google, Mistral, Hugging Face, TogetherAI, Azure AI, and Cloudflare provider documentation for more information on the additional options available. 
+See the OpenAI, Anthropic, Google, Mistral, Hugging Face, Ollama, TogetherAI, Azure AI, and Cloudflare provider documentation for more information on the additional options available. ## Custom Models @@ -358,4 +366,4 @@ model = get_model("custom/name-of-model") eval(math, model = "custom/name-of-model") ``` -In this example, the `model_name` argument passed to `__init__()` will be "name-of-model". +In this example, the `model_name` argument passed to `__init__()` will be "name-of-model". \ No newline at end of file diff --git a/docs/scorers.qmd b/docs/scorers.qmd index 93d19fb22..92a3437d1 100644 --- a/docs/scorers.qmd +++ b/docs/scorers.qmd @@ -60,7 +60,7 @@ Task( ) ``` -### Model Graded +## Model Graded Model graded scorers are well suited to assessing open ended answers as well as factual answers that are embedded in a longer narrative. The built-in model graded scorers can be customised in several ways—you can also create entirely new model scorers (see the model graded example below for a starting point). @@ -87,7 +87,25 @@ The default model graded QA scorer is tuned to grade answers to open ended quest The `model_graded_fact()` scorer works identically to `model_graded_qa()`, and simply provides an alternate `template` oriented around judging whether a fact is included in the model output. -If you want to understand how the default templates for `model_graded_qa()` and `model_graded_fact()` work, see their [source code](https://github.com/AI-Safety-Institute/inspect_ai/blob/main/src/inspect_ai/scorer/_model.py). +If you want to understand how the default templates for `model_graded_qa()` and `model_graded_fact()` work, see their [source code](https://github.com/UKGovernmentBEIS/inspect_ai/blob/main/src/inspect_ai/scorer/_model.py). + +### Multiple Models + +The built-in model graded scorers also support using multiple grader models (whereby the final grade is chosen by majority vote). For example, here we specify that 3 models should be used for grading: + +```python +model_graded_qa( + models = [ + "google/gemini-1.0-pro", + "anthropic/claude-3-opus-20240229", + "together/meta-llama/Llama-3-70b-chat-hf", + ] +) +``` + +The implementation of multiple grader models takes advantage of the `multi_scorer()` and `majority_vote()` functions, both of which can be used in your own scorers (as described in the [Multi Scorer](#sec-multi-scorer) section below). + + ## Custom Scorers @@ -239,6 +257,43 @@ Note that the call to `model_grader.generate()` is done with `await`—this is c Note also we use the `input_text` property of the `TaskState` to access a string version of the original user input to substitute it into the grading template. Using the `input_text` has two benefits: (1) It is guaranteed to cover the original input from the dataset (rather than a transformed prompt in `messages`); and (2) It normalises the input to a string (as it could have been a message list). +## Multi Scorer {#sec-multi-scorer} + +It's possible to use multiple scorers in parallel, then combine their output into a final overall score. This is done using the `multi_scorer()` function. For example, this is roughly how the built-in model graders use multiple models for grading: + +```python +multi_scorer( + scorers = [model_graded_qa(model=model) for model in models], + reducer = majority_vote +) +``` + +Use of `multi_scorer()` requires both a list of scorers and a _reducer_, which is a function that takes a list of scores and turns it into a single score. 
In this case we use the built-in `majority_vote()` reducer, which returns the score that appeared most frequently in the answers. + +You can imagine a variety of different strategies for reducing scores (take the average, take the high or low, majority vote, etc.). For example, here's a reducer that computes the average score: + +```python +import numpy as np + +def average_score(scores: list[Score]) -> Score: + values = [score.as_float() for score in scores] + avg = np.mean(values).item() + return Score( + value=avg, + explanation=f"average of {', '.join(str(value) for value in values)}" + ) +``` + +This reducer might then be used like this: + +```python +multi_scorer( + scorers = [model_graded_qa(model=model) for model in models], + reducer = average_score +) +``` + + ## Metrics Each scorer provides one or more built-in metrics (typically `accuracy` and `bootstrap_std`). In addition, you can specify other metrics (either built-in or custom) to compute when defining a `Task`: diff --git a/docs/tools.qmd b/docs/tools.qmd index 573d32d9e..f6b7ce3b1 100644 --- a/docs/tools.qmd +++ b/docs/tools.qmd @@ -9,7 +9,7 @@ Inspect natively supports registering Python functions as tools and providing th ::: {.callout-note} ### Tools and Agents -One application of tools is to run them within an agent scaffold that pursues an objective over multiple interactions with a model. The scaffold uses the model to help make decisions about which tools to use and when, and orchestrates calls to the model to use the tools. We'll cover how to use agent scaffolds in [Agent Solvers](#agents) below. +One application of tools is to run them within an agent scaffold that pursues an objective over multiple interactions with a model. The scaffold uses the model to help make decisions about which tools to use and when, and orchestrates calls to the model to use the tools. We'll cover how to use agent scaffolds in [Agent Solvers](#agent-solvers) below. ::: ## Tool Basics @@ -149,7 +149,7 @@ plan = [ ## Web Search -Inspect has a built in `web_search()` tool that provides models with the ability to enhance their context window by performing a search. By default web searches retrieves 10 results from a provider, uses a model to determine if the contents is relevant then returns the top 3 relevant search results to the main model. Here is the definition of the `web_search()` function: +Inspect has a built-in `web_search()` tool that provides models with the ability to enhance their context window by performing a search. By default, web searches retrieve 10 results from a provider, use a model to determine whether the content is relevant, and then return the top 3 relevant search results to the main model. Here is the definition of the `web_search()` function: ``` python def web_search( diff --git a/docs/workflow.qmd b/docs/workflow.qmd index 82d7ff673..95c97a293 100644 --- a/docs/workflow.qmd +++ b/docs/workflow.qmd @@ -234,7 +234,7 @@ if __name__ == "__main__" ``` ::: {.callout-note appearance="minimal"} -If you aren't familiar with the `__name__ == "__main__"` idiom, see the docs on [\_\_main\_\_](https://docs.python.org/3/library/main.html) for additional details. +If you aren't familiar with the `__name__ == "__main__"` idiom, see the docs on [\_\_main\_\_](https://docs.python.org/3/library/__main__.html) for additional details. ::: Now we can take the same script and use it with `inspect eval` (while leaving our exploratory code intact and protected by the `__main__` check): @@ -255,7 +255,7 @@ We refer to notebooks above but show scripts in all of the examples. 
Everything 1. You can use the `__name__ == "__main__"` check to protect cells that should only be run in exploratory mode. -2. You can pass a notebook to `insect eval` just the same as a script (including passing task parameters) +2. You can pass a notebook to `inspect eval` just the same as a script (including passing task parameters) For example, imagine that all of the code shown above for `security.py` was in `security.ipynb`. You could run the eval and optionally pass a task parameter as follows: diff --git a/src/inspect_ai/__init__.py b/src/inspect_ai/__init__.py index f8d7d947e..1fc015999 100644 --- a/src/inspect_ai/__init__.py +++ b/src/inspect_ai/__init__.py @@ -6,7 +6,7 @@ from inspect_ai._eval.list import list_tasks from inspect_ai._eval.registry import task from inspect_ai._eval.score import score, score_async -from inspect_ai._eval.task import Task, TaskInfo, Tasks +from inspect_ai._eval.types import Task, TaskInfo, Tasks from inspect_ai._util.constants import PKG_NAME __version__ = importlib_version(PKG_NAME) diff --git a/src/inspect_ai/_cli/list.py b/src/inspect_ai/_cli/list.py index 75fb1755f..eb9489cb6 100644 --- a/src/inspect_ai/_cli/list.py +++ b/src/inspect_ai/_cli/list.py @@ -11,7 +11,7 @@ from inspect_ai._cli.common import CommonOptions, common_options, resolve_common_options from inspect_ai._cli.util import parse_cli_args from inspect_ai._eval.list import list_tasks -from inspect_ai._eval.task import TaskInfo +from inspect_ai._eval.types import TaskInfo from inspect_ai.log import list_eval_logs diff --git a/src/inspect_ai/_cli/score.py b/src/inspect_ai/_cli/score.py index 7057553e4..3eb119af4 100644 --- a/src/inspect_ai/_cli/score.py +++ b/src/inspect_ai/_cli/score.py @@ -6,6 +6,7 @@ from inspect_ai._display import display from inspect_ai._display.logger import init_logger from inspect_ai._eval.loader import load_tasks +from inspect_ai._eval.score import task_score from inspect_ai._util.constants import SCORED_SUFFIX from inspect_ai._util.dotenv import init_dotenv from inspect_ai.log._file import JSONRecorder @@ -76,7 +77,7 @@ async def score( score_task = load_tasks([task], model)[0] # re-score the task - eval_log = await score_task.score(eval_log) + eval_log = await task_score(score_task, eval_log) # re-write the log (w/ a -score suffix if requested) scored = f"{SCORED_SUFFIX}.json" diff --git a/src/inspect_ai/_eval/eval.py b/src/inspect_ai/_eval/eval.py index c5599f599..ffb23d130 100644 --- a/src/inspect_ai/_eval/eval.py +++ b/src/inspect_ai/_eval/eval.py @@ -26,8 +26,10 @@ from inspect_ai.util._context import init_async_context from .loader import resolve_tasks -from .log import EvalLogger -from .task import Tasks, TaskSpec, task_file, task_run_dir +from .task.log import TaskLogger +from .task.run import task_run +from .task.util import task_file, task_run_dir +from .types import Tasks, TaskSpec log = logging.getLogger(__name__) @@ -130,34 +132,35 @@ async def eval_async( ) -> list[EvalLog]: r"""Evaluate tasks using a Model (async). - tasks: (Tasks): Task(s) to evaluate. If None, attempt - to evaluate a task in the current working directory - model (str | Model | None): Model for evaluation. If not - specified uses the current eval's model, or failing that - the value of the INSPECT_EVAL_MODEL environment variable. - model_base_url: (str | None): Base URL for communicating - with the model API. 
- model_args (dict[str,Any]): Model creation parameters - task_args (dict[str,Any]): Task arguments - plan (Solver | list[Solver] | None): Alternative plan - for evaluating task(s). Optional (uses task plan by default). - log_level (str | None): "debug", "http", "info", "warning", "error", - or "critical" (defaults to "info") - log_dir (str | None): Output path for logging results - (defaults to file log in ./logs directory). - limit (int | tuple[int, int] | None): Limit evaluated samples - (defaults to all samples). - epochs (int | None): Number of times to repeat evaluation of - samples (defaults to 1) - max_messages (int | None): Maximum number of messages to allow - in a task conversation. - max_subprocesses (int | None): Maximum number of subprocesses to - run in parallel (default is os.cpu_count()) - log_samples: (bool | None): Log detailed samples and scores (defaults to True) - log_images: (bool | None): Log base64 encoded version of images, - even if specified as a filename or URL (defaults to True) - score (bool): Score output (defaults to True) - **kwargs (GenerateConfigArgs): Model generation options. + Args: + tasks: (Tasks): Task(s) to evaluate. If None, attempt + to evaluate a task in the current working directory + model (str | Model | None): Model for evaluation. If not + specified uses the current eval's model, or failing that + the value of the INSPECT_EVAL_MODEL environment variable. + model_base_url: (str | None): Base URL for communicating + with the model API. + model_args (dict[str,Any]): Model creation parameters + task_args (dict[str,Any]): Task arguments + plan (Solver | list[Solver] | None): Alternative plan + for evaluating task(s). Optional (uses task plan by default). + log_level (str | None): "debug", "http", "info", "warning", "error", + or "critical" (defaults to "info") + log_dir (str | None): Output path for logging results + (defaults to file log in ./logs directory). + limit (int | tuple[int, int] | None): Limit evaluated samples + (defaults to all samples). + epochs (int | None): Number of times to repeat evaluation of + samples (defaults to 1) + max_messages (int | None): Maximum number of messages to allow + in a task conversation. + max_subprocesses (int | None): Maximum number of subprocesses to + run in parallel (default is os.cpu_count()) + log_samples: (bool | None): Log detailed samples and scores (defaults to True) + log_images: (bool | None): Log base64 encoded version of images, + even if specified as a filename or URL (defaults to True) + score (bool): Score output (defaults to True) + **kwargs (GenerateConfigArgs): Model generation options. 
Returns: List of EvalLog (one for each task) @@ -214,7 +217,7 @@ async def eval_async( ) run_id = uuid() - loggers: list[EvalLogger] = [] + loggers: list[TaskLogger] = [] results: list[EvalLog] = [] for index, name, version, task in zip( range(0, len(task_names)), task_names, task_versions, eval_tasks @@ -227,10 +230,10 @@ async def eval_async( task_eval_config.max_messages = task.max_messages # create and track the logger - logger = EvalLogger( + logger = TaskLogger( task_name=name, task_version=version, - task_file=task_file(task, True), + task_file=task_file(task, relative=True), task_run_dir=task_run_dir(task), task_id=task_id if task_id else uuid(), run_id=run_id, @@ -245,7 +248,8 @@ async def eval_async( loggers.append(logger) # run the eval - result = await task.run( + result = await task_run( + task=task, sequence=(index + 1, len(task_names)), model=model, logger=logger, diff --git a/src/inspect_ai/_eval/list.py b/src/inspect_ai/_eval/list.py index 975d907da..3d269773f 100644 --- a/src/inspect_ai/_eval/list.py +++ b/src/inspect_ai/_eval/list.py @@ -1,23 +1,14 @@ import ast -import inspect import os import re -from importlib.machinery import SourceFileLoader -from importlib.util import module_from_spec, spec_from_loader from logging import getLogger from pathlib import Path -from types import ModuleType from typing import Any, Callable -from inspect_ai._util.dotenv import dotenv_environ from inspect_ai._util.error import exception_message from inspect_ai._util.file import file -from inspect_ai._util.path import chdir_python -from inspect_ai._util.registry import RegistryInfo, is_registry_object, registry_info -from inspect_ai.model import ModelName -from .registry import task_create -from .task import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR, Task, TaskInfo +from .types import TaskInfo logger = getLogger(__name__) @@ -59,39 +50,6 @@ def list_tasks( return sorted(tasks, key=lambda t: f"{t.file}@{t.name}") -def create_tasks( - globs: list[str], - model: ModelName, - task_args: dict[str, Any] = {}, - root_dir: Path | None = None, -) -> list[Task]: - tasks: list[Task] = [] - - root_dir = root_dir if root_dir is not None else Path.cwd() - - for glob in globs: - # sometimes globs are direct references to files - # that include an @ index. 
for this case directly - # create the task (we also need to load the file - # so the task is registered before we create it) - spec_split = split_task_spec(glob) - if len(spec_split[1]) > 0: - task_path = Path(spec_split[0]) - load_file_tasks(task_path.absolute()) - tasks.extend( - create_file_tasks(task_path, model, [spec_split[1]], task_args) - ) - else: - # if the glob is the root dir then set it to empty (will result in - # enumeration of the root dir) - target = [] if Path(glob).resolve() == root_dir.resolve() else [glob] - files = task_files(target, root_dir) - files = sorted(files, key=lambda f: f.as_posix()) - for file in files: - tasks.extend(create_file_tasks(file, model, None, task_args)) - return tasks - - def task_files(globs: list[str] = [], root_dir: Path | None = None) -> list[Path]: # root dir root_dir = root_dir if root_dir else Path.cwd() @@ -137,54 +95,6 @@ def tasks_in_dir(path: Path) -> list[Path]: return paths -def load_file_tasks(file: Path) -> list[RegistryInfo]: - with chdir_python(file.parent.as_posix()), dotenv_environ(): - return _load_task_specs(file) - - -def create_file_tasks( - file: Path, - model: ModelName, - task_specs: list[str] | list[RegistryInfo] | None = None, - task_args: dict[str, Any] = {}, -) -> list[Task]: - with chdir_python(file.parent.as_posix()), dotenv_environ(): - # if we don't have task specs then go get them (also, - # turn them into plain names) - if task_specs is None: - task_specs = _load_task_specs(file) - # convert to plain names - task_specs = [ - spec if isinstance(spec, str) else spec.name for spec in task_specs - ] - - tasks: list[Task] = [] - for task_spec in task_specs: - # create the task from the loaded source file and - # note that it was loaded from this directory - # (will be used later to ensure it runs in the directory) - task = task_create(task_spec, model, **task_args) - setattr(task, TASK_FILE_ATTR, file.as_posix()) - setattr(task, TASK_RUN_DIR_ATTR, file.parent.as_posix()) - tasks.append(task) - return tasks - - -# don't call this function directly, rather, call one of the -# higher level loading functions above (those functions -# change the working directory, this one does not b/c it is -# intended as a helper function) -def _load_task_specs(task_path: Path) -> list[RegistryInfo]: - # load the module - module = load_task_module(task_path) - if module: - # find the tasks in the module - tasks = inspect.getmembers(module, lambda m: is_registry_object(m, "task")) - return [registry_info(task[1]) for task in tasks] - else: - return [] - - excluded_pattern = re.compile("^[_\\.].*$") @@ -203,69 +113,6 @@ def is_task_path(path: Path) -> bool: ) and not is_task_path_excluded(path.name) -def split_task_spec(task_spec: str) -> tuple[str, str]: - parts = task_spec.rsplit("@", 1) - if len(parts) == 2: - return parts[0], parts[1] - else: - return task_spec, "" - - -def load_task_module(task_path: Path) -> ModuleType | None: - if task_path.suffix == ".py": - # bail if the code doesn't have a task - with open(task_path, "r", encoding="utf-8") as file: - if not code_has_task(file.read()): - return None - - module_name = task_path.as_posix() - loader = SourceFileLoader(module_name, task_path.absolute().as_posix()) - spec = spec_from_loader(loader.name, loader) - if not spec: - raise ModuleNotFoundError(f"Module {module_name} not found") - module = module_from_spec(spec) - loader.exec_module(module) - return module - - elif task_path.suffix == ".ipynb": - try: - from inspect_ai._util.notebook import NotebookLoader - except 
ImportError: - return None - - # bail if the code doesn't have a task - def exec_filter(cells: list[str]) -> bool: - code = "\n\n".join(cells) - return code_has_task(code) - - notebook_loader = NotebookLoader(exec_filter) - return notebook_loader.load_module(task_path.as_posix()) - - else: - raise ModuleNotFoundError( - f"Invalid extension for task file: {task_path.suffix}" - ) - - -def code_has_task(code: str) -> bool: - try: - tree = ast.parse(code) - for node in ast.iter_child_nodes(tree): - if isinstance(node, ast.FunctionDef): - for decorator in node.decorator_list: - if isinstance(decorator, ast.Name): - if str(decorator.id) == "task": - return True - elif isinstance(decorator, ast.Call): - if isinstance(decorator.func, ast.Name): - if str(decorator.func.id) == "task": - return True - except SyntaxError: - pass - - return False - - def parse_tasks(path: Path, root_dir: Path, absolute: bool) -> list[TaskInfo]: # read code from python source file if path.suffix.lower() == ".py": diff --git a/src/inspect_ai/_eval/loader.py b/src/inspect_ai/_eval/loader.py index bab3ac01c..67654a92d 100644 --- a/src/inspect_ai/_eval/loader.py +++ b/src/inspect_ai/_eval/loader.py @@ -1,15 +1,25 @@ +import ast +import inspect +from importlib.machinery import SourceFileLoader +from importlib.util import module_from_spec, spec_from_loader from pathlib import Path +from types import ModuleType from typing import Any, cast +from inspect_ai._util.dotenv import dotenv_environ +from inspect_ai._util.path import chdir_python from inspect_ai._util.registry import ( + RegistryInfo, + is_registry_object, registry_info, registry_lookup, ) from inspect_ai.model import Model, ModelName -from .list import create_tasks +from .list import task_files from .registry import task_create -from .task import Task, TaskInfo, Tasks +from .task.constants import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR +from .types import Task, TaskInfo, Tasks def resolve_tasks( @@ -71,3 +81,147 @@ def load_task_spec( else: # load tasks from glob return create_tasks([task_spec], model, task_args) + + +def create_tasks( + globs: list[str], + model: ModelName, + task_args: dict[str, Any] = {}, + root_dir: Path | None = None, +) -> list[Task]: + tasks: list[Task] = [] + + root_dir = root_dir if root_dir is not None else Path.cwd() + + for glob in globs: + # sometimes globs are direct references to files + # that include an @ index. 
for this case directly + # create the task (we also need to load the file + # so the task is registered before we create it) + spec_split = split_task_spec(glob) + if len(spec_split[1]) > 0: + task_path = Path(spec_split[0]) + load_file_tasks(task_path.absolute()) + tasks.extend( + create_file_tasks(task_path, model, [spec_split[1]], task_args) + ) + else: + # if the glob is the root dir then set it to empty (will result in + # enumeration of the root dir) + target = [] if Path(glob).resolve() == root_dir.resolve() else [glob] + files = task_files(target, root_dir) + files = sorted(files, key=lambda f: f.as_posix()) + for file in files: + tasks.extend(create_file_tasks(file, model, None, task_args)) + return tasks + + +def load_file_tasks(file: Path) -> list[RegistryInfo]: + with chdir_python(file.parent.as_posix()), dotenv_environ(): + return _load_task_specs(file) + + +def create_file_tasks( + file: Path, + model: ModelName, + task_specs: list[str] | list[RegistryInfo] | None = None, + task_args: dict[str, Any] = {}, +) -> list[Task]: + with chdir_python(file.parent.as_posix()), dotenv_environ(): + # if we don't have task specs then go get them (also, + # turn them into plain names) + if task_specs is None: + task_specs = _load_task_specs(file) + # convert to plain names + task_specs = [ + spec if isinstance(spec, str) else spec.name for spec in task_specs + ] + + tasks: list[Task] = [] + for task_spec in task_specs: + # create the task from the loaded source file and + # note that it was loaded from this directory + # (will be used later to ensure it runs in the directory) + task = task_create(task_spec, model, **task_args) + setattr(task, TASK_FILE_ATTR, file.as_posix()) + setattr(task, TASK_RUN_DIR_ATTR, file.parent.as_posix()) + tasks.append(task) + return tasks + + +# don't call this function directly, rather, call one of the +# higher level loading functions above (those functions +# change the working directory, this one does not b/c it is +# intended as a helper function) +def _load_task_specs(task_path: Path) -> list[RegistryInfo]: + # load the module + module = load_task_module(task_path) + if module: + # find the tasks in the module + tasks = inspect.getmembers(module, lambda m: is_registry_object(m, "task")) + return [registry_info(task[1]) for task in tasks] + else: + return [] + + +def split_task_spec(task_spec: str) -> tuple[str, str]: + parts = task_spec.rsplit("@", 1) + if len(parts) == 2: + return parts[0], parts[1] + else: + return task_spec, "" + + +def load_task_module(task_path: Path) -> ModuleType | None: + if task_path.suffix == ".py": + # bail if the code doesn't have a task + with open(task_path, "r", encoding="utf-8") as file: + if not code_has_task(file.read()): + return None + + module_name = task_path.as_posix() + loader = SourceFileLoader(module_name, task_path.absolute().as_posix()) + spec = spec_from_loader(loader.name, loader) + if not spec: + raise ModuleNotFoundError(f"Module {module_name} not found") + module = module_from_spec(spec) + loader.exec_module(module) + return module + + elif task_path.suffix == ".ipynb": + try: + from inspect_ai._util.notebook import NotebookLoader + except ImportError: + return None + + # bail if the code doesn't have a task + def exec_filter(cells: list[str]) -> bool: + code = "\n\n".join(cells) + return code_has_task(code) + + notebook_loader = NotebookLoader(exec_filter) + return notebook_loader.load_module(task_path.as_posix()) + + else: + raise ModuleNotFoundError( + f"Invalid extension for task file: 
{task_path.suffix}" + ) + + +def code_has_task(code: str) -> bool: + try: + tree = ast.parse(code) + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.FunctionDef): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name): + if str(decorator.id) == "task": + return True + elif isinstance(decorator, ast.Call): + if isinstance(decorator.func, ast.Name): + if str(decorator.func.id) == "task": + return True + except SyntaxError: + pass + + return False diff --git a/src/inspect_ai/_eval/registry.py b/src/inspect_ai/_eval/registry.py index 26350e812..c0b676b4d 100644 --- a/src/inspect_ai/_eval/registry.py +++ b/src/inspect_ai/_eval/registry.py @@ -14,7 +14,7 @@ ) from inspect_ai.model import ModelName -from .task import Task +from .types import Task MODEL_PARAM = "model" diff --git a/src/inspect_ai/_eval/score.py b/src/inspect_ai/_eval/score.py index 09ca4b97d..98df02e3f 100644 --- a/src/inspect_ai/_eval/score.py +++ b/src/inspect_ai/_eval/score.py @@ -1,23 +1,26 @@ import asyncio -import re from copy import deepcopy from typing import Callable, cast from inspect_ai._display import display +from inspect_ai._util.dotenv import dotenv_environ +from inspect_ai._util.path import chdir_python from inspect_ai._util.platform import platform_init from inspect_ai._util.registry import ( registry_create, - registry_info, - registry_log_name, - registry_params, - registry_unqualified_name, ) -from inspect_ai.log import EvalLog, EvalMetric, EvalResults, EvalScorer +from inspect_ai.log import ( + EvalLog, + EvalMetric, +) from inspect_ai.model import ModelName from inspect_ai.scorer import Metric, Score, Scorer, Target -from inspect_ai.scorer._scorer import SCORER_METRICS, scorer_metrics from inspect_ai.solver import TaskState +from .task.results import eval_results +from .task.util import task_run_dir +from .types import Task + def score(log: EvalLog, scorer: Scorer) -> EvalLog: """Score an evaluation log. @@ -96,6 +99,33 @@ def progress() -> None: return log +async def task_score(task: Task, log: EvalLog) -> EvalLog: + with chdir_python(task_run_dir(task)), dotenv_environ(): + # confirm we have a scorer + if task.scorer is None: + raise ValueError("You must specify a scorer for evals to be scored.") + + # confirm we have samples + if log.samples is None or len(log.samples) == 0: + raise ValueError("There are no samples to score in the log.") + + task_name = task.name + display().print(f"Scoring {len(log.samples)} samples for task: {task_name}") + + # perform scoring + log = await score_async(log, task.scorer) + + # compute and log metrics + display().print(f"Aggregating scores for task: {task_name}") + if task.scorer and log.samples: + log.results = eval_results( + [sample.score for sample in log.samples if isinstance(sample.score, Score)], + task.scorer, + task.metrics, + ) + return log + + async def run_score_task( state: TaskState, target: Target, @@ -107,67 +137,6 @@ async def run_score_task( return result -def eval_results( - scores: list[Score], scorer: Scorer | None, metrics: list[Metric] = [] -) -> EvalResults: - # record scorer - results = EvalResults() - if scorer: - # extract non-metrics metadata - metadata = deepcopy(registry_info(scorer).metadata) - del metadata[SCORER_METRICS] - - # build results - results.scorer = EvalScorer( - name=registry_log_name(scorer), - params=registry_params(scorer), - metadata=metadata if len(metadata.keys()) > 0 else None, - ) - - # we want to use simple names for metrics in the metrics dict - # (i.e. without package prefixes). 
we do this by getting the - # unqualified name, then appending a suffix if there are duplicates - # this keeps the code straightforward and intuitive for users - # programming against the log (e.g. metrics["accuracy"]) vs. - # metrics["pkgname/accuracy"]) - for metric in target_metrics(scorer, metrics): - key = metrics_unique_key( - registry_unqualified_name(metric), list(results.metrics.keys()) - ) - results.metrics[key] = EvalMetric( - name=registry_log_name(metric), value=metric(scores) - ) - return results - - -def metrics_unique_key(key: str, existing: list[str]) -> str: - if key not in existing: - return key - else: - key_index = 2 - pattern = re.compile(f"{re.escape(key)}(\\d+)") - for existing_key in existing: - match = pattern.match(existing_key) - index = int(match.group(1)) if match else None - if index and (index >= key_index): - key_index = index + 1 - return f"{key}{key_index}" - - -# build a list of metrics (scorer built-in metrics + de-duplicated additional metrics) -def target_metrics(scorer: Scorer, metrics: list[Metric]) -> list[Metric]: - target_metrics = scorer_metrics(scorer) - target_metrics_names = [registry_log_name(metric) for metric in target_metrics] - target_metrics.extend( - [ - metric - for metric in metrics - if registry_log_name(metric) not in target_metrics_names - ] - ) - return target_metrics - - def metrics_from_log(log: EvalLog) -> list[Metric]: return ( [metric_from_log(metric) for metric in log.results.metrics.values()] diff --git a/src/inspect_ai/_eval/task.py b/src/inspect_ai/_eval/task.py deleted file mode 100644 index dd22c57a9..000000000 --- a/src/inspect_ai/_eval/task.py +++ /dev/null @@ -1,668 +0,0 @@ -import asyncio -import os -import sys -from copy import deepcopy -from dataclasses import dataclass -from logging import getLogger -from typing import Any, Callable, Sequence, cast - -from pydantic import BaseModel -from typing_extensions import Unpack - -from inspect_ai._display import display -from inspect_ai._display._display import TaskProfile -from inspect_ai._util.constants import DEFAULT_EPOCHS -from inspect_ai._util.datetime import iso_now -from inspect_ai._util.dotenv import dotenv_environ -from inspect_ai._util.error import exception_message -from inspect_ai._util.path import chdir_python, cwd_relative_path -from inspect_ai._util.registry import ( - is_registry_object, - registry_info, - registry_log_name, - registry_params, -) -from inspect_ai.dataset import Dataset, MemoryDataset, Sample -from inspect_ai.log import ( - EvalConfig, - EvalError, - EvalLog, - EvalPlan, - EvalPlanStep, - EvalStats, - LoggingMessage, -) -from inspect_ai.log._log import eval_error -from inspect_ai.model import ( - ChatMessage, - ChatMessageTool, - ChatMessageUser, - GenerateConfig, - GenerateConfigArgs, - Model, - ModelName, - ToolCall, - ToolFunction, - ToolInfo, -) -from inspect_ai.model._model import collect_model_usage -from inspect_ai.scorer import Metric, Score, Scorer, Target -from inspect_ai.solver import Generate, Plan, Solver, TaskState, Tool, generate -from inspect_ai.solver._tool.tool import TOOL_PARAMS -from inspect_ai.solver._tool.tool_def import ToolDef, tool_defs -from inspect_ai.util._context.logger import collect_logger_records - -from .images import ( - messages_with_base64_images, - samples_with_base64_images, -) -from .log import EvalLogger -from .score import eval_results, score_async - -logger = getLogger(__name__) - -TASK_FILE_ATTR = "__task_file__" -TASK_RUN_DIR_ATTR = "__task_run_dir__" - - -class Task: - r"""Evaluation task. 
- - Tasks are the basis for defining and running evaluations. Tasks - are parameterized with a dataset, a scorer, and metrics. Tasks - also may optionally provide a default plan for execution. - - Args: - dataset (Dataset | Sequence[Sample]): Dataset to evaluate - plan: (Plan | Solver | list[Solver]): Default plan. If not specified - defaults to generate(), a normal call to the model. - scorer: (Scorer | None): Scorer used to evaluate model output. - metrics (list[Metric]): Additional metrics to compute beyond - the base metrics provided by the scorer. - config (GenerateConfig): Model generation config. - epochs (int): Default number of epochs to run for. - max_messages (int | None): Limit on total messages in the conversation. - name: (str | None): Task name. If not specified is automatically - determined based on the name of the task directory (or "task") - if its anonymous task (e.g. created in a notebook and passed to - eval() directly) - version: (int): Version of task (to distinguish evolutions - of the task spec or breaking changes to it) - """ - - def __init__( - self, - dataset: Dataset | Sequence[Sample], - plan: Plan | Solver | list[Solver] = generate(), - scorer: Scorer | None = None, - metrics: list[Metric] = [], - config: GenerateConfig = GenerateConfig(), - epochs: int | None = None, - max_messages: int | None = None, - name: str | None = None, - version: int = 0, - ) -> None: - self.dataset = ( - dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset)) - ) - self.plan = plan if isinstance(plan, Plan) else Plan(plan) - self.scorer = scorer - self.metrics = metrics - self.config = config - self.epochs = epochs - self.max_messages = max_messages - self.version = version - self._name = name - - @property - def name(self) -> str: - if self._name is not None: - return self._name - elif is_registry_object(self): - return registry_info(self).name - else: - return "task" - - @property - def attribs(self) -> dict[str, Any]: - if is_registry_object(self): - return cast(dict[str, Any], registry_info(self).metadata.get("attribs", {})) - else: - return dict() - - async def run( - self, - sequence: tuple[int, int], - model: Model, - logger: EvalLogger, - config: EvalConfig = EvalConfig(), - plan: Plan | Solver | list[Solver] | None = None, - score: bool = True, - **kwargs: Unpack[GenerateConfigArgs], - ) -> EvalLog: - r"""Run the task. - - Run the task with the passed model and configuration, using the - samples, scorer, metrics and solver(s) specified for the task. - - Args: - sequence (int): Sequence of the run within a larger set of runs - model (Model): Model used to generate output - logger (EvalLogger): Logger for recording results. - config (EvalConfig): Config (sample range/epochs, logging options) - plan:(Plan | Solver | list[Solver] | None): Override of - task default plan. - score (bool | None): Score model output. If not specified - is determined automatically based on whether the task - has a solver and metrics defined. - **kwargs (GenerateConfigArgs): Generation config options - - Returns: - EvalLog for executed task. 
- - """ - with chdir_python(task_run_dir(self)), dotenv_environ(): - # track stats and error - stats = EvalStats(started_at=iso_now()) - error: EvalError | None = None - - # see if we are scoring - score = score and self.scorer is not None - - # evaluate the task (accumulate scores for metrics) - model_name = ModelName(model) - - # apply limit to dataset - dataset_limit = ( - slice(0, len(self.dataset)) - if config.limit is None - else ( - slice(*config.limit) - if isinstance(config.limit, tuple) - else slice(0, config.limit) - ) - ) - dataset = self.dataset[dataset_limit] if dataset_limit else self.dataset - - # add sample ids to dataset if they aren't there (start at 1 not 0) - for id, sample in zip( - range(dataset_limit.start, dataset_limit.stop), dataset - ): - if sample.id is None: - sample.id = id + 1 - - # resolve the plan and scorer - plan = ( - plan - if isinstance(plan, Plan) - else Plan(plan) - if plan is not None - else self.plan - ) - scorer: Scorer | None = self.scorer if (score and self.scorer) else None - - # compute the generate() config. we start with the base task config, - # then merge any deltas provided by the **kwargs for this call to run() - generate_config = self.config.merge(GenerateConfigArgs(**kwargs)) - - # log the plan - self._log_plan(logger, plan, generate_config) - - # provide solvers a function that they can use to generate output - async def generate( - state: TaskState, **kwargs: Unpack[GenerateConfigArgs] - ) -> TaskState: - return await self._generate( - model=model, - state=state, - config=generate_config.merge(kwargs), - max_messages=config.max_messages, - ) - - # apply epochs (deepcopy the samples so they remain independent) - epochs = config.epochs if config.epochs else DEFAULT_EPOCHS - samples: list[Sample] = [] - for _ in range(0, epochs): - samples.extend([deepcopy(sample) for sample in dataset]) - - # if we are logging images then resolve sample images here - log_images = config.log_images is not False - if log_images: - samples = await samples_with_base64_images(samples) - - # prime the eval tasks (deep copy so they share no state w/ sample) - sample_epochs: list[int] = [] - for e in range(0, epochs): - sample_epochs.extend([e + 1] * len(dataset)) - states = [ - deepcopy( - TaskState( - sample_id=sample.id or 0, - epoch=epoch, - model=model_name, - input=sample.input, - choices=sample.choices, - messages=sample_messages(sample), - completed=False, - metadata=sample.metadata if sample.metadata else {}, - ) - ) - for epoch, sample in zip(sample_epochs, samples) - ] - - # create task profile for display - profile = TaskProfile( - name=self.name, - sequence=sequence, - model=model_name, - dataset=self.dataset.name or "(samples)", - scorer=( - registry_log_name(self.scorer) - if is_registry_object(self.scorer) - else "(none)" - ), - samples=len(samples), - eval_config=config, - task_args=logger.eval.task_args, - generate_config=generate_config, - log_location=logger.location, - ) - - with display().task(profile) as td: - try: - # run w/ progress (steps = samples * steps in plan + 1 for scorer) - total_steps = len(samples) * ( - len(plan.steps) + (1 if plan.finish else 0) + (1) # scorer - ) - with td.progress(total=total_steps) as p: - - def progress() -> None: - p.update(1) - - tasks = [ - self.run_eval_task( - sample=sample, - state=state, - plan=plan, - max_messages=config.max_messages, - scorer=scorer, - generate=generate, - progress=progress, - ) - for (sample, state) in zip(samples, states) - ] - - # run them in parallel - scores = await 
asyncio.gather(*tasks) - - # log output by epoch - if config.log_samples is not False: - # if we are logging images then be sure to base64 images injected by solvers - if log_images: - states = await states_with_base64_images(states) - - for e in range(0, epochs): - sl = slice(e * len(dataset), (e + 1) * (len(dataset))) - self._log_output( - logger, e + 1, samples[sl], states[sl], scores[sl] - ) - - # compute and record metrics if we have scores (don't compute metrics on errors) - completed_scores = [ - score for score in scores if isinstance(score, Score) - ] - if len(completed_scores) > 0: - results = eval_results( - completed_scores, - self.scorer, - self.metrics, - ) - logger.log_results(results) - - # collect eval data - collect_eval_data(stats, logger) - - # display task summary - td.summary(results, stats) - - except asyncio.CancelledError as ex: - raise ex - - except BaseException as ex: - # mark completed - stats.completed_at = iso_now() - - # get exception info - type, value, traceback = sys.exc_info() - type = type if type else BaseException - value = value if value else ex - - # build eval error - error = eval_error(ex, type, value, traceback) - - # collect eval data - collect_eval_data(stats, logger) - - # display it - td.error(error, type, value, traceback) - - # log as appropriate - if error: - return logger.log_failure(stats, error) - else: - return logger.log_success(stats) - - async def score(self, log: EvalLog) -> EvalLog: - with chdir_python(task_run_dir(self)), dotenv_environ(): - # confirm we have a scorer - if self.scorer is None: - raise ValueError("You must specify a scorer for evals to be scored.") - - # confirm we have samples - if log.samples is None or len(log.samples) == 0: - raise ValueError("There are no samples to score in the log.") - - task_name = self.name - display().print(f"Scoring {len(log.samples)} samples for task: {task_name}") - - # perform scoring - log = await score_async(log, self.scorer) - - # compute and log metrics - display().print(f"Aggregating scores for task: {task_name}") - if self.scorer and log.samples: - log.results = eval_results( - [ - sample.score - for sample in log.samples - if isinstance(sample.score, Score) - ], - self.scorer, - self.metrics, - ) - return log - - async def run_eval_task( - self, - sample: Sample, - state: TaskState, - plan: Plan, - max_messages: int | None, - scorer: Scorer | None, - generate: Generate, - progress: Callable[..., None], - ) -> Score | None: - # solver loop - try: - # run plan steps (checking for early termination) - for index, solver in enumerate(plan.steps): - # run the solver - state = await solver(state, generate) - progress() - - # check for early termination (tick remaining progress) - if state.completed or has_max_messages(state, max_messages): - for _ in range(index + 1, len(plan.steps)): - progress() - break - - # run finishing step them mark completed - if plan.finish: - state = await plan.finish(state, generate) - progress() - state.completed = True - - finally: - # safely run cleanup function if there is one - if plan.cleanup: - try: - await plan.cleanup(state) - except Exception as ex: - logger.warning( - f"Exception occurred during plan cleanup for task {self.name}: " - + f"{exception_message(ex)}" - ) - pass - - # score it - result = await scorer(state, Target(sample.target)) if scorer else None - progress() - - # return - return result - - async def _generate( - self, - model: Model, - state: TaskState, - config: GenerateConfig, - max_messages: int | None, - ) -> TaskState: - # 
track tool_choice (revert to "none" after first forced call of a tool) - tool_choice = state.tool_choice - - while True: - # call the model - output = await model.generate( - state.messages, - tools_info(state.tools), - tool_choice, - config, - ) - - # append the assistant message - message = output.choices[0].message - state.messages.append(message) - - # check for max messages - if has_max_messages(state, max_messages): - state.output = output - return state - - # resolve tool calls if necessary - tdefs = tool_defs(state.tools) - if message.tool_calls and len(message.tool_calls) > 0: - for tool_call in message.tool_calls: - tool_error: str | None = None - try: - result = await call_tool(tdefs, tool_call, state.metadata) - except Exception as ex: - result = "" - tool_error = exception_message(ex) - - if isinstance(result, tuple): - result, metadata = result - state.metadata.update(metadata) - - state.messages.append( - ChatMessageTool( - content=str(result), - tool_error=tool_error, - tool_call_id=tool_call.id, - ) - ) - - # check for max messages - if has_max_messages(state, max_messages): - state.output = output - return state - - # if a tool_call was forced set tool_choice to 'none' - # (otherwise it will get forced over and over again) - if isinstance(tool_choice, ToolFunction): - tool_choice = "none" - - # no tool calls, we are done! - else: - state.output = output - return state - - def _log_output( - self, - logger: EvalLogger, - epoch: int, - samples: list[Sample], - states: list[TaskState], - scores: list[Score | None], - ) -> None: - for i in range(len(samples)): - logger.log_sample(epoch, samples[i], states[i], scores[i]) - - def _log_plan( - self, - logger: EvalLogger, - plan: Plan, - config: GenerateConfig, - ) -> None: - def eval_plan_step(solver: Solver) -> EvalPlanStep: - return EvalPlanStep( - solver=registry_log_name(solver), params=registry_params(solver) - ) - - eval_plan = EvalPlan( - name=plan.name, - steps=[eval_plan_step(solver) for solver in plan.steps], - finish=eval_plan_step(plan.finish) if plan.finish else None, - config=config, - ) - if plan.finish: - eval_plan.steps.append(eval_plan_step(plan.finish)) - - logger.log_event("plan", eval_plan) - - -class TaskInfo(BaseModel): - """Task information (file, name, and attributes).""" - - file: str - """File path where task was loaded from.""" - - name: str - """Task name (defaults to function name)""" - - attribs: dict[str, Any] - """Task attributes (arguments passed to `@task`)""" - - def __str__(self) -> str: - return f"{self.file}@{self.name}" - - def __hash__(self) -> int: - return hash( - (self.file, self.name) - + tuple(self.attribs.keys()) - + tuple(self.attribs.values()) - ) - - -@dataclass -class TaskSpec: - id: str - task: str - - -Tasks = ( - str - | TaskSpec - | TaskInfo - | Task - | Callable[..., Task] - | type[Task] - | list[str] - | list[TaskInfo] - | list[Task] - | list[Callable[..., Task]] - | list[type[Task]] - | None -) -r"""One or more tasks. - -Tasks to be evaluated. Many forms of task specification are -supported including directory names, task functions, task -classes, and task instances (a single task or list of tasks -can be specified). None is a request to read a task out -of the current working directory. 
-""" - - -def task_file(task: Task, relative: bool = False) -> str | None: - file = cast(str | None, getattr(task, TASK_FILE_ATTR, None)) - if file: - if relative: - return cwd_relative_path(file) - else: - return file - else: - return None - - -def task_run_dir(task: Task) -> str: - return getattr(task, TASK_RUN_DIR_ATTR, os.getcwd()) - - -def sample_messages(sample: Sample) -> list[ChatMessage]: - if isinstance(sample.input, str): - return [ChatMessageUser(content=sample.input, source="input")] - else: - messages = deepcopy(sample.input) - for message in messages: - message.source = "input" - return messages - - -def has_max_messages(state: TaskState, max_messages: int | None) -> bool: - return max_messages is not None and (len(state.messages) >= max_messages) - - -async def states_with_base64_images(states: list[TaskState]) -> list[TaskState]: - return await asyncio.gather(*[state_with_base64_images(state) for state in states]) - - -async def state_with_base64_images(state: TaskState) -> TaskState: - state.messages = await messages_with_base64_images(state.messages) - return state - - -def collect_eval_data(stats: EvalStats, logger: EvalLogger) -> None: - # collect stats - stats.completed_at = iso_now() - stats.model_usage = collect_model_usage() - - # collect log output - for record in collect_logger_records(): - logger.log_event("logging", LoggingMessage.from_log_record(record)) - - -def tools_info(tools: list[Tool]) -> list[ToolInfo]: - tdefs = tool_defs(tools) - return [ - ToolInfo(name=tool.name, description=tool.description, params=tool.params) - for tool in tdefs - ] - - -async def call_tool( - tools: list[ToolDef], call: ToolCall, metadata: dict[str, Any] -) -> Any: - # find the tool - tool_def = next((tool for tool in tools if tool.name == call.function), None) - if tool_def is None: - return f"Tool {call.function} not found" - - # resolve metadata params and prepend to arguments - tool_params: dict[str, str] = registry_info(tool_def.tool).metadata.get( - TOOL_PARAMS, {} - ) - resolved_params: dict[str, Any] = {} - for name, value in tool_params.items(): - key = value.removeprefix("metadata.") - resolved = metadata.get(key, None) - if resolved is None: - raise ValueError(f"Metadata value '{key}' not found for tool parameter") - resolved_params[name] = resolved - arguments = resolved_params | call.arguments - - # call the tool - try: - return await tool_def.tool(**arguments) - except Exception as e: - return f"Error: {exception_message(e)}" diff --git a/src/inspect_ai/_eval/task/constants.py b/src/inspect_ai/_eval/task/constants.py new file mode 100644 index 000000000..e1e7ca594 --- /dev/null +++ b/src/inspect_ai/_eval/task/constants.py @@ -0,0 +1,2 @@ +TASK_FILE_ATTR = "__task_file__" +TASK_RUN_DIR_ATTR = "__task_run_dir__" diff --git a/src/inspect_ai/_eval/task/generate.py b/src/inspect_ai/_eval/task/generate.py new file mode 100644 index 000000000..379fba594 --- /dev/null +++ b/src/inspect_ai/_eval/task/generate.py @@ -0,0 +1,121 @@ +from typing import Any + +from inspect_ai._util.error import exception_message +from inspect_ai._util.registry import ( + registry_info, +) +from inspect_ai.model import ( + ChatMessageTool, + GenerateConfig, + Model, + ToolCall, + ToolFunction, + ToolInfo, +) +from inspect_ai.solver import TaskState, Tool +from inspect_ai.solver._tool.tool import TOOL_PARAMS +from inspect_ai.solver._tool.tool_def import ToolDef, tool_defs + +from .util import has_max_messages + + +async def task_generate( + model: Model, + state: TaskState, + config: 
GenerateConfig, + max_messages: int | None, +) -> TaskState: + # track tool_choice (revert to "none" after first forced call of a tool) + tool_choice = state.tool_choice + + while True: + # call the model + output = await model.generate( + state.messages, + tools_info(state.tools), + tool_choice, + config, + ) + + # append the assistant message + message = output.choices[0].message + state.messages.append(message) + + # check for max messages + if has_max_messages(state, max_messages): + state.output = output + return state + + # resolve tool calls if necessary + tdefs = tool_defs(state.tools) + if message.tool_calls and len(message.tool_calls) > 0: + for tool_call in message.tool_calls: + tool_error: str | None = None + try: + result = await call_tool(tdefs, tool_call, state.metadata) + except Exception as ex: + result = "" + tool_error = exception_message(ex) + + if isinstance(result, tuple): + result, metadata = result + state.metadata.update(metadata) + + state.messages.append( + ChatMessageTool( + content=str(result), + tool_error=tool_error, + tool_call_id=tool_call.id, + ) + ) + + # check for max messages + if has_max_messages(state, max_messages): + state.output = output + return state + + # if a tool_call was forced set tool_choice to 'none' + # (otherwise it will get forced over and over again) + if isinstance(tool_choice, ToolFunction): + tool_choice = "none" + + # no tool calls, we are done! + else: + state.output = output + return state + + +def tools_info(tools: list[Tool]) -> list[ToolInfo]: + tdefs = tool_defs(tools) + return [ + ToolInfo(name=tool.name, description=tool.description, params=tool.params) + for tool in tdefs + ] + + +async def call_tool( + tools: list[ToolDef], call: ToolCall, metadata: dict[str, Any] +) -> Any: + # find the tool + tool_def = next((tool for tool in tools if tool.name == call.function), None) + if tool_def is None: + return f"Tool {call.function} not found" + + # resolve metadata params and prepend to arguments + tool_params: dict[str, str] = registry_info(tool_def.tool).metadata.get( + TOOL_PARAMS, {} + ) + resolved_params: dict[str, Any] = {} + for name, value in tool_params.items(): + key = value.removeprefix("metadata.") + resolved = metadata.get(key, None) + if resolved is None: + raise ValueError(f"Metadata value '{key}' not found for tool parameter") + resolved_params[name] = resolved + arguments = resolved_params | call.arguments + + # call the tool + try: + return await tool_def.tool(**arguments) + except Exception as e: + return f"Error: {exception_message(e)}" diff --git a/src/inspect_ai/_eval/images.py b/src/inspect_ai/_eval/task/images.py similarity index 82% rename from src/inspect_ai/_eval/images.py rename to src/inspect_ai/_eval/task/images.py index a87623a7e..75ae6e4ff 100644 --- a/src/inspect_ai/_eval/images.py +++ b/src/inspect_ai/_eval/task/images.py @@ -3,6 +3,16 @@ from inspect_ai._util.images import image_as_data_uri from inspect_ai.dataset import Sample from inspect_ai.model import ChatMessage, ChatMessageUser, Content, ContentImage +from inspect_ai.solver import TaskState + + +async def states_with_base64_images(states: list[TaskState]) -> list[TaskState]: + return await asyncio.gather(*[state_with_base64_images(state) for state in states]) + + +async def state_with_base64_images(state: TaskState) -> TaskState: + state.messages = await messages_with_base64_images(state.messages) + return state async def samples_with_base64_images(samples: list[Sample]) -> list[Sample]: diff --git a/src/inspect_ai/_eval/log.py 
b/src/inspect_ai/_eval/task/log.py similarity index 67% rename from src/inspect_ai/_eval/log.py rename to src/inspect_ai/_eval/task/log.py index bb12e92ad..d1a11818a 100644 --- a/src/inspect_ai/_eval/log.py +++ b/src/inspect_ai/_eval/task/log.py @@ -1,4 +1,5 @@ from importlib import metadata as importlib_metadata +from logging import LogRecord from typing import Any from shortuuid import uuid @@ -7,6 +8,10 @@ from inspect_ai._util.datetime import iso_now from inspect_ai._util.git import git_context from inspect_ai._util.path import cwd_relative_path +from inspect_ai._util.registry import ( + registry_log_name, + registry_params, +) from inspect_ai.dataset import Dataset, Sample from inspect_ai.log import ( EvalConfig, @@ -14,6 +19,7 @@ EvalError, EvalLog, EvalPlan, + EvalPlanStep, EvalResults, EvalRevision, EvalSample, @@ -22,12 +28,18 @@ LoggingMessage, ) from inspect_ai.log._log import LogEvent, Recorder -from inspect_ai.model import Model, ModelName +from inspect_ai.model import ( + GenerateConfig, + Model, + ModelName, +) +from inspect_ai.model._model import collect_model_usage from inspect_ai.scorer import Score -from inspect_ai.solver import TaskState +from inspect_ai.solver import Plan, Solver, TaskState +from inspect_ai.util._context.logger import collect_logger_records -class EvalLogger: +class TaskLogger: def __init__( self, task_name: str, @@ -123,3 +135,50 @@ def log_success(self, stats: EvalStats) -> EvalLog: def log_failure(self, stats: EvalStats, error: EvalError) -> EvalLog: return self.recorder.log_failure(self.eval, stats, error) + + +def log_output( + logger: TaskLogger, + epoch: int, + samples: list[Sample], + states: list[TaskState], + scores: list[Score | None], +) -> None: + for i in range(len(samples)): + logger.log_sample(epoch, samples[i], states[i], scores[i]) + + +def log_plan( + logger: TaskLogger, + plan: Plan, + config: GenerateConfig, +) -> None: + def eval_plan_step(solver: Solver) -> EvalPlanStep: + return EvalPlanStep( + solver=registry_log_name(solver), params=registry_params(solver) + ) + + eval_plan = EvalPlan( + name=plan.name, + steps=[eval_plan_step(solver) for solver in plan.steps], + finish=eval_plan_step(plan.finish) if plan.finish else None, + config=config, + ) + if plan.finish: + eval_plan.steps.append(eval_plan_step(plan.finish)) + + logger.log_event("plan", eval_plan) + + +def collect_eval_data(stats: EvalStats, logger: TaskLogger) -> None: + # collect stats + stats.completed_at = iso_now() + stats.model_usage = collect_model_usage() + + # collect log output + log_logger_records(logger, collect_logger_records()) + + +def log_logger_records(logger: TaskLogger, records: list[LogRecord]) -> None: + for record in records: + logger.log_event("logging", LoggingMessage.from_log_record(record)) diff --git a/src/inspect_ai/_eval/task/results.py b/src/inspect_ai/_eval/task/results.py new file mode 100644 index 000000000..193ea669f --- /dev/null +++ b/src/inspect_ai/_eval/task/results.py @@ -0,0 +1,77 @@ +import re +from copy import deepcopy + +from inspect_ai._util.registry import ( + registry_info, + registry_log_name, + registry_params, + registry_unqualified_name, +) +from inspect_ai.log import ( + EvalMetric, + EvalResults, + EvalScorer, +) +from inspect_ai.scorer import Metric, Score, Scorer +from inspect_ai.scorer._scorer import SCORER_METRICS, scorer_metrics + + +def eval_results( + scores: list[Score], scorer: Scorer | None, metrics: list[Metric] = [] +) -> EvalResults: + # record scorer + results = EvalResults() + if scorer: + # extract 
non-metrics metadata + metadata = deepcopy(registry_info(scorer).metadata) + del metadata[SCORER_METRICS] + + # build results + results.scorer = EvalScorer( + name=registry_log_name(scorer), + params=registry_params(scorer), + metadata=metadata if len(metadata.keys()) > 0 else None, + ) + + # we want to use simple names for metrics in the metrics dict + # (i.e. without package prefixes). we do this by getting the + # unqualified name, then appending a suffix if there are duplicates + # this keeps the code straightforward and intuitive for users + # programming against the log (e.g. metrics["accuracy"]) vs. + # metrics["pkgname/accuracy"]) + for metric in target_metrics(scorer, metrics): + key = metrics_unique_key( + registry_unqualified_name(metric), list(results.metrics.keys()) + ) + results.metrics[key] = EvalMetric( + name=registry_log_name(metric), value=metric(scores) + ) + return results + + +def metrics_unique_key(key: str, existing: list[str]) -> str: + if key not in existing: + return key + else: + key_index = 2 + pattern = re.compile(f"{re.escape(key)}(\\d+)") + for existing_key in existing: + match = pattern.match(existing_key) + index = int(match.group(1)) if match else None + if index and (index >= key_index): + key_index = index + 1 + return f"{key}{key_index}" + + +# build a list of metrics (scorer built-in metrics + de-duplicated additional metrics) +def target_metrics(scorer: Scorer, metrics: list[Metric]) -> list[Metric]: + target_metrics = scorer_metrics(scorer) + target_metrics_names = [registry_log_name(metric) for metric in target_metrics] + target_metrics.extend( + [ + metric + for metric in metrics + if registry_log_name(metric) not in target_metrics_names + ] + ) + return target_metrics diff --git a/src/inspect_ai/_eval/task/run.py b/src/inspect_ai/_eval/task/run.py new file mode 100644 index 000000000..bd2f122f3 --- /dev/null +++ b/src/inspect_ai/_eval/task/run.py @@ -0,0 +1,323 @@ +import asyncio +import sys +from copy import deepcopy +from logging import getLogger +from typing import Callable + +from typing_extensions import Unpack + +from inspect_ai._display import display +from inspect_ai._display._display import TaskProfile +from inspect_ai._util.constants import DEFAULT_EPOCHS +from inspect_ai._util.datetime import iso_now +from inspect_ai._util.dotenv import dotenv_environ +from inspect_ai._util.error import exception_message +from inspect_ai._util.path import chdir_python +from inspect_ai._util.registry import ( + is_registry_object, + registry_log_name, +) +from inspect_ai.dataset import Dataset, Sample +from inspect_ai.log import ( + EvalConfig, + EvalError, + EvalLog, + EvalStats, +) +from inspect_ai.log._log import eval_error +from inspect_ai.model import ( + GenerateConfigArgs, + Model, + ModelName, +) +from inspect_ai.scorer import Score, Scorer, Target +from inspect_ai.solver import Generate, Plan, Solver, TaskState + +from ..types import Task +from .generate import task_generate +from .images import samples_with_base64_images, states_with_base64_images +from .log import TaskLogger, collect_eval_data, log_output, log_plan +from .results import eval_results +from .util import has_max_messages, sample_messages, task_run_dir + +logger = getLogger(__name__) + + +async def task_run( + task: Task, + sequence: tuple[int, int], + model: Model, + logger: TaskLogger, + config: EvalConfig = EvalConfig(), + plan: Plan | Solver | list[Solver] | None = None, + score: bool = True, + **kwargs: Unpack[GenerateConfigArgs], +) -> EvalLog: + r"""Run the task. 
+ + Run the task with the passed model and configuration, using the + samples, scorer, metrics and solver(s) specified for the task. + + Args: + task (Task): Task to run. + sequence (int): Sequence of the run within a larger set of runs + model (Model): Model used to generate output + logger (TaskLogger): Logger for recording results. + config (EvalConfig): Config (sample range/epochs, logging options) + plan:(Plan | Solver | list[Solver] | None): Override of + task default plan. + score (bool | None): Score model output. If not specified + is determined automatically based on whether the task + has a solver and metrics defined. + **kwargs (GenerateConfigArgs): Generation config options + + Returns: + EvalLog for executed task. + + """ + with chdir_python(task_run_dir(task)), dotenv_environ(): + # track stats and error + stats = EvalStats(started_at=iso_now()) + error: EvalError | None = None + + # resolve some config + model_name = ModelName(model) + epochs = config.epochs if config.epochs else DEFAULT_EPOCHS + log_images = config.log_images is not False + log_samples = config.log_samples is not False + generate_config = task.config.merge(GenerateConfigArgs(**kwargs)) + + # resolve dataset + dataset, samples, states = await resolve_dataset( + dataset=task.dataset, + model_name=model_name, + limit=config.limit, + epochs=epochs, + log_images=log_images, + ) + + # resolve the plan and scorer + plan = ( + plan + if isinstance(plan, Plan) + else Plan(plan) if plan is not None else task.plan + ) + score = score and task.scorer is not None + scorer: Scorer | None = task.scorer if (score and task.scorer) else None + + # create task profile for display + profile = TaskProfile( + name=task.name, + sequence=sequence, + model=model_name, + dataset=task.dataset.name or "(samples)", + scorer=( + registry_log_name(scorer) if is_registry_object(scorer) else "(none)" + ), + samples=len(samples), + eval_config=config, + task_args=logger.eval.task_args, + generate_config=generate_config, + log_location=logger.location, + ) + + with display().task(profile) as td: + try: + # log the plan + log_plan(logger, plan, generate_config) + + # run w/ progress (steps = samples * steps in plan + 1 for scorer) + total_steps = len(samples) * ( + len(plan.steps) + (1 if plan.finish else 0) + (1) # scorer + ) + with td.progress(total=total_steps) as p: + + # forward progress + def progress() -> None: + p.update(1) + + # provide solvers a function that they can use to generate output + async def generate( + state: TaskState, **kwargs: Unpack[GenerateConfigArgs] + ) -> TaskState: + return await task_generate( + model=model, + state=state, + config=generate_config.merge(kwargs), + max_messages=config.max_messages, + ) + + tasks = [ + task_run_sample( + task_name=task.name, + sample=sample, + state=state, + plan=plan, + max_messages=config.max_messages, + scorer=scorer, + generate=generate, + progress=progress, + ) + for (sample, state) in zip(samples, states) + ] + + # run them in parallel + scores = await asyncio.gather(*tasks) + + # log output by epoch + if log_samples is not False: + # if we are logging images then be sure to base64 images injected by solvers + if log_images: + states = await states_with_base64_images(states) + + for e in range(0, epochs): + sl = slice(e * len(dataset), (e + 1) * (len(dataset))) + log_output(logger, e + 1, samples[sl], states[sl], scores[sl]) + + # compute and record metrics if we have scores + completed_scores = [ + score for score in scores if isinstance(score, Score) + ] + if 
len(completed_scores) > 0: + results = eval_results( + scores=completed_scores, + scorer=scorer, + metrics=task.metrics, + ) + logger.log_results(results) + + # collect eval data + collect_eval_data(stats, logger) + + # display task summary + td.summary(results, stats) + + except asyncio.CancelledError as ex: + raise ex + + except BaseException as ex: + # mark completed + stats.completed_at = iso_now() + + # get exception info + type, value, traceback = sys.exc_info() + type = type if type else BaseException + value = value if value else ex + + # build eval error + error = eval_error(ex, type, value, traceback) + + # collect eval data + collect_eval_data(stats, logger) + + # display it + td.error(error, type, value, traceback) + + # log as appropriate + if error: + return logger.log_failure(stats, error) + else: + return logger.log_success(stats) + + +async def task_run_sample( + task_name: str, + sample: Sample, + state: TaskState, + plan: Plan, + max_messages: int | None, + scorer: Scorer | None, + generate: Generate, + progress: Callable[..., None], +) -> Score | None: + # solver loop + try: + # run plan steps (checking for early termination) + for index, solver in enumerate(plan.steps): + # run the solver + state = await solver(state, generate) + progress() + + # check for early termination (tick remaining progress) + if state.completed or has_max_messages(state, max_messages): + for _ in range(index + 1, len(plan.steps)): + progress() + break + + # run finishing step them mark completed + if plan.finish: + state = await plan.finish(state, generate) + progress() + state.completed = True + + finally: + # safely run cleanup function if there is one + if plan.cleanup: + try: + await plan.cleanup(state) + except Exception as ex: + logger.warning( + f"Exception occurred during plan cleanup for task {task_name}: " + + f"{exception_message(ex)}" + ) + pass + + # score it + result = await scorer(state, Target(sample.target)) if scorer else None + progress() + + # return + return result + + +async def resolve_dataset( + dataset: Dataset, + model_name: ModelName, + limit: int | tuple[int, int] | None, + epochs: int, + log_images: bool, +) -> tuple[Dataset, list[Sample], list[TaskState]]: + + # apply limit to dataset + dataset_limit = ( + slice(0, len(dataset)) + if limit is None + else (slice(*limit) if isinstance(limit, tuple) else slice(0, limit)) + ) + dataset = dataset[dataset_limit] + + # add sample ids to dataset if they aren't there (start at 1 not 0) + for id, sample in zip(range(dataset_limit.start, dataset_limit.stop), dataset): + if sample.id is None: + sample.id = id + 1 + + # apply epochs (deepcopy the samples so they remain independent) + samples: list[Sample] = [] + for _ in range(0, epochs): + samples.extend([deepcopy(sample) for sample in dataset]) + + # if we are logging images then resolve sample images here + if log_images: + samples = await samples_with_base64_images(samples) + + # prime the eval tasks (deep copy so they share no state w/ sample) + sample_epochs: list[int] = [] + for e in range(0, epochs): + sample_epochs.extend([e + 1] * len(dataset)) + states = [ + deepcopy( + TaskState( + sample_id=sample.id or 0, + epoch=epoch, + model=model_name, + input=sample.input, + choices=sample.choices, + messages=sample_messages(sample), + completed=False, + metadata=sample.metadata if sample.metadata else {}, + ) + ) + for epoch, sample in zip(sample_epochs, samples) + ] + + return (dataset, samples, states) diff --git a/src/inspect_ai/_eval/task/util.py 
b/src/inspect_ai/_eval/task/util.py new file mode 100644 index 000000000..093fa3583 --- /dev/null +++ b/src/inspect_ai/_eval/task/util.py @@ -0,0 +1,40 @@ +import os +from copy import deepcopy +from typing import cast + +from inspect_ai._util.path import cwd_relative_path +from inspect_ai.dataset import Sample +from inspect_ai.model import ChatMessage, ChatMessageUser +from inspect_ai.solver import TaskState + +from ..types import Task +from .constants import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR + + +def sample_messages(sample: Sample) -> list[ChatMessage]: + if isinstance(sample.input, str): + return [ChatMessageUser(content=sample.input, source="input")] + else: + messages = deepcopy(sample.input) + for message in messages: + message.source = "input" + return messages + + +def has_max_messages(state: TaskState, max_messages: int | None) -> bool: + return max_messages is not None and (len(state.messages) >= max_messages) + + +def task_run_dir(task: Task) -> str: + return getattr(task, TASK_RUN_DIR_ATTR, os.getcwd()) + + +def task_file(task: Task, relative: bool = False) -> str | None: + file = cast(str | None, getattr(task, TASK_FILE_ATTR, None)) + if file: + if relative: + return cwd_relative_path(file) + else: + return file + else: + return None diff --git a/src/inspect_ai/_eval/types.py b/src/inspect_ai/_eval/types.py new file mode 100644 index 000000000..97ba42f0c --- /dev/null +++ b/src/inspect_ai/_eval/types.py @@ -0,0 +1,132 @@ +from dataclasses import dataclass +from logging import getLogger +from typing import Any, Callable, Sequence, cast + +from pydantic import BaseModel + +from inspect_ai._util.registry import is_registry_object, registry_info +from inspect_ai.dataset import Dataset, MemoryDataset, Sample +from inspect_ai.model import GenerateConfig +from inspect_ai.scorer import Metric, Scorer +from inspect_ai.solver import Plan, Solver, generate + +logger = getLogger(__name__) + + +class Task: + r"""Evaluation task. + + Tasks are the basis for defining and running evaluations. Tasks + are parameterized with a dataset, a scorer, and metrics. Tasks + also may optionally provide a default plan for execution. + + Args: + dataset (Dataset | Sequence[Sample]): Dataset to evaluate + plan: (Plan | Solver | list[Solver]): Default plan. If not specified + defaults to generate(), a normal call to the model. + scorer: (Scorer | None): Scorer used to evaluate model output. + metrics (list[Metric]): Additional metrics to compute beyond + the base metrics provided by the scorer. + config (GenerateConfig): Model generation config. + epochs (int): Default number of epochs to run for. + max_messages (int | None): Limit on total messages in the conversation. + name: (str | None): Task name. If not specified is automatically + determined based on the name of the task directory (or "task") + if its anonymous task (e.g. 
created in a notebook and passed to + eval() directly) + version: (int): Version of task (to distinguish evolutions + of the task spec or breaking changes to it) + """ + + def __init__( + self, + dataset: Dataset | Sequence[Sample], + plan: Plan | Solver | list[Solver] = generate(), + scorer: Scorer | None = None, + metrics: list[Metric] = [], + config: GenerateConfig = GenerateConfig(), + epochs: int | None = None, + max_messages: int | None = None, + name: str | None = None, + version: int = 0, + ) -> None: + self.dataset: Dataset = ( + dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset)) + ) + self.plan = plan if isinstance(plan, Plan) else Plan(plan) + self.scorer = scorer + self.metrics = metrics + self.config = config + self.epochs = epochs + self.max_messages = max_messages + self.version = version + self._name = name + + @property + def name(self) -> str: + if self._name is not None: + return self._name + elif is_registry_object(self): + return registry_info(self).name + else: + return "task" + + @property + def attribs(self) -> dict[str, Any]: + if is_registry_object(self): + return cast(dict[str, Any], registry_info(self).metadata.get("attribs", {})) + else: + return dict() + + +class TaskInfo(BaseModel): + """Task information (file, name, and attributes).""" + + file: str + """File path where task was loaded from.""" + + name: str + """Task name (defaults to function name)""" + + attribs: dict[str, Any] + """Task attributes (arguments passed to `@task`)""" + + def __str__(self) -> str: + return f"{self.file}@{self.name}" + + def __hash__(self) -> int: + return hash( + (self.file, self.name) + + tuple(self.attribs.keys()) + + tuple(self.attribs.values()) + ) + + +@dataclass +class TaskSpec: + id: str + task: str + + +Tasks = ( + str + | TaskSpec + | TaskInfo + | Task + | Callable[..., Task] + | type[Task] + | list[str] + | list[TaskInfo] + | list[Task] + | list[Callable[..., Task]] + | list[type[Task]] + | None +) +r"""One or more tasks. + +Tasks to be evaluated. Many forms of task specification are +supported including directory names, task functions, task +classes, and task instances (a single task or list of tasks +can be specified). None is a request to read a task out +of the current working directory. +""" diff --git a/src/inspect_ai/_util/_async.py b/src/inspect_ai/_util/_async.py new file mode 100644 index 000000000..34a80864d --- /dev/null +++ b/src/inspect_ai/_util/_async.py @@ -0,0 +1,10 @@ +import asyncio +from typing import Any + + +def is_callable_coroutine(func_or_cls: Any) -> bool: + if asyncio.iscoroutinefunction(func_or_cls): + return True + elif callable(func_or_cls): + return asyncio.iscoroutinefunction(func_or_cls.__call__) + return False diff --git a/src/inspect_ai/log/_file.py b/src/inspect_ai/log/_file.py index 34689ab54..37efe2579 100644 --- a/src/inspect_ai/log/_file.py +++ b/src/inspect_ai/log/_file.py @@ -1,4 +1,3 @@ -import json import os import re from pathlib import Path @@ -7,7 +6,7 @@ import json_stream # type: ignore from pydantic import BaseModel, Field -from pydantic_core import to_json +from pydantic_core import from_json, to_json from inspect_ai._util.file import FileInfo, file, filesystem @@ -116,38 +115,69 @@ def read_eval_log(log_file: str | FileInfo, header_only: bool = False) -> EvalLo Returns: EvalLog object read from file. 
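As a quick point of reference for the `Task` class introduced above in `types.py`, here is a minimal usage sketch. It relies only on the constructor signature shown in the diff (a dataset of `Sample`s, the default `generate()` plan, an optional scorer); the sample contents and task name are hypothetical.

```python
# Hedged usage sketch of the Task class defined in types.py above.
# The dataset contents and the task name are invented for illustration.
from inspect_ai import Task
from inspect_ai.dataset import Sample
from inspect_ai.solver import generate

toy_task = Task(
    dataset=[
        Sample(input="What is 2 + 2?", target="4"),
        Sample(input="Name the capital of France.", target="Paris"),
    ],
    plan=generate(),  # the default plan: a single call to the model
    name="toy_task",  # optional; otherwise resolved via the registry or "task"
)
```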
""" + # resolve to file path log_file = log_file if isinstance(log_file, str) else log_file.name + + # verify we know about this version of the log file format + def validate_version(ver: int) -> None: + if ver > 1: + raise ValueError(f"Unable to read version {ver} of log format.") + + # header-only uses json-stream + if header_only: + with file(log_file, "r") as f: + try: + data = json_stream.load(f, persistent=True) + + def read_field(field: str) -> Any: + if field in data.keys(): + return json_stream.to_standard_types(data[field]) + else: + return None + + # fail for unknown version + version = read_field("version") + validate_version(version) + + results = read_field("results") + error = read_field("error") + + return EvalLog( + version=version, + status=read_field("status"), + eval=EvalSpec(**read_field("eval")), + plan=EvalPlan(**read_field("plan")), + results=EvalResults(**results) if results else None, + stats=EvalStats(**read_field("stats")), + error=EvalError(**error) if error else None, + ) + # The Python JSON serializer supports NaN and Inf, however + # this isn't technically part of the JSON spec. The json-stream + # library shares this limitation, so if we fail with an + # invalid character then we move on and and parse w/ pydantic + # (which does support NaN and Inf by default) + except ValueError as ex: + if str(ex).find("Invalid JSON character") != -1: + pass + else: + raise ex + + # parse full log (also used as a fallback for header_only encountering NaN or Inf) with file(log_file, "r") as f: - # header-only uses json-stream + # parse w/ pydantic + raw_data = from_json(f.read()) + log = EvalLog(**raw_data) + + # fail for unknown version + validate_version(log.version) + + # prune if header_only if header_only: - data = json_stream.load(f, persistent=True) + log.samples = None + log.logging.clear() - def read_field(field: str) -> Any: - if field in data.keys(): - return json_stream.to_standard_types(data[field]) - else: - return None - - results = read_field("results") - error = read_field("error") - - return EvalLog( - version=read_field("version"), - status=read_field("status"), - eval=EvalSpec(**read_field("eval")), - plan=EvalPlan(**read_field("plan")), - results=EvalResults(**results) if results else None, - stats=EvalStats(**read_field("stats")), - error=EvalError(**error) if error else None, - ) - - # otherwise normal json parse - else: - raw_data = json.load(f) - log = EvalLog(**raw_data) - if log.version > 1: - raise ValueError(f"Unable to read version {log.version} of log format.") - return log + # return log + return log def read_eval_log_headers(log_files: list[str] | list[FileInfo]) -> list[EvalLog]: diff --git a/src/inspect_ai/model/_model.py b/src/inspect_ai/model/_model.py index b7131086b..fbd075ece 100644 --- a/src/inspect_ai/model/_model.py +++ b/src/inspect_ai/model/_model.py @@ -304,6 +304,8 @@ class ModelUsage(BaseModel): class TopLogprob(BaseModel): + """List of the most likely tokens and their log probability, at this token position.""" + token: str """The top-kth token represented as a string.""" @@ -315,6 +317,8 @@ class TopLogprob(BaseModel): class Logprob(BaseModel): + """Log probability for a token.""" + token: str """The predicted token represented as a string.""" @@ -329,6 +333,8 @@ class Logprob(BaseModel): class Logprobs(BaseModel): + """Log probability information for a completion choice.""" + content: list[Logprob] """a (num_generated_tokens,) length list containing the individual log probabilities for each generated token.""" @@ -759,32 
+765,25 @@ def get_model( model = "/".join(parts[1:]) # predicate to match model - def match_model(info: RegistryInfo) -> bool: + def match_modelapi_type(info: RegistryInfo) -> bool: # strip package name (we use the 'api' as the namespace, we will # introduce package scoping if it proves necessary) - if info.type == "modelapi": - # model patterns for this provider - models = info.metadata.get("models", []) - - # if there is an api_name explicitly specified that - # matches the registered api then trust the model name - # TODO: this is ugly, we need to clarify the relationship - # and registration semantics of pkg -> provider -> model - if ( - info.name == api_name - or info.name.replace(f"{PKG_NAME}/", "") == api_name - ): - return True - # otherwise check for a name match - else: - return len([name for name in models if name in model]) > 0 + + # if there is an api_name explicitly specified that + # matches the registered api then trust the model name + # TODO: this is ugly, we need to clarify the relationship + # and registration semantics of pkg -> provider -> model + if info.type == "modelapi" and ( + info.name == api_name or info.name.replace(f"{PKG_NAME}/", "") == api_name + ): + return True else: return False # find a matching model type - model_types = registry_find(match_model) - if len(model_types) > 0: - modelapi_type = cast(type[ModelAPI], model_types[0]) + modelapi_types = registry_find(match_modelapi_type) + if len(modelapi_types) > 0: + modelapi_type = cast(type[ModelAPI], modelapi_types[0]) modelapi_instance = modelapi_type( model_name=model, base_url=base_url, config=config, **model_args ) @@ -879,7 +878,7 @@ def consecutive_message_reducer( def combine_messages( a: ChatMessage, b: ChatMessage, message_type: Type[ChatMessage] -) -> ChatMessageUser: +) -> ChatMessage: if isinstance(a.content, str) and isinstance(b.content, str): return message_type(content=f"{a.content}\n{b.content}") elif isinstance(a.content, list) and isinstance(b.content, list): diff --git a/src/inspect_ai/model/_providers/hf.py b/src/inspect_ai/model/_providers/hf.py index 5aecf6add..8766d29f5 100644 --- a/src/inspect_ai/model/_providers/hf.py +++ b/src/inspect_ai/model/_providers/hf.py @@ -153,7 +153,7 @@ async def generate( if config.logprobs is not None: final_logprobs = extract_logprobs( response=response, - top_logprobs=config.top_logprobs, + top=config.top_logprobs, tokenizer=self.tokenizer, ) @@ -163,9 +163,9 @@ async def generate( content=response.output, source="generate", ), - logprobs=Logprobs(content=final_logprobs) - if final_logprobs is not None - else None, + logprobs=( + Logprobs(content=final_logprobs) if final_logprobs is not None else None + ), ) # return output @@ -214,6 +214,12 @@ def set_random_seeds(seed: int | None = None) -> None: set_seed(seed) +# return value from generate as a result of specifying return_dict_in_generate +class ModelGenerateOutput: + sequences: Tensor + logits: tuple[Tensor] + + class Tokenizer(Protocol): def __call__( self, input: list[str] @@ -315,8 +321,9 @@ def process_batches() -> None: # generate with torch.inference_mode(): - generation_outputs = generator( - input_ids=input_ids, attention_mask=attention_mask + generation_outputs = cast( + ModelGenerateOutput, + generator(input_ids=input_ids, attention_mask=attention_mask), ) generate_ids = generation_outputs.sequences logits = generation_outputs.logits @@ -324,8 +331,8 @@ def process_batches() -> None: # get logprobs from logits logprobs = None if logits is not None: - logits = 
torch.stack(logits).transpose(0, 1) - logprobs = torch.nn.functional.log_softmax(logits, dim=-1) + stacked_logits = torch.stack(logits).transpose(0, 1) + logprobs = torch.nn.functional.log_softmax(stacked_logits, dim=-1) # decode generated_tokens = generate_ids[:, input_ids.size(dim=1) :] @@ -361,15 +368,15 @@ def process_batches() -> None: def extract_logprobs( response: GenerateOutput, - top_logprobs: int | None, + top: int | None, tokenizer: PreTrainedTokenizerBase, ) -> list[Logprob]: assert response.logprobs is not None - k = top_logprobs or 1 + k = top or 1 topk_values, topk_inds = response.logprobs.topk(k=k, dim=-1) final_logprobs = [] for toks, vals in zip(topk_inds, topk_values): - top_logprobs = [] + top_logprobs: list[TopLogprob] = [] for tok, val in zip(toks, vals): # TODO: you get byte artifacts converting single ids to tokens like this... # but `tokenizer.decode` strips spaces. There must be a better way to do this. diff --git a/src/inspect_ai/model/_providers/ollama.py b/src/inspect_ai/model/_providers/ollama.py new file mode 100644 index 000000000..d782c58b4 --- /dev/null +++ b/src/inspect_ai/model/_providers/ollama.py @@ -0,0 +1,18 @@ +from inspect_ai.model._providers.util import model_base_url + +from .._model import GenerateConfig +from .openai import OpenAIAPI + + +class OllamaAPI(OpenAIAPI): + def __init__( + self, + model_name: str, + base_url: str | None = None, + config: GenerateConfig = GenerateConfig(), + ) -> None: + base_url = model_base_url(base_url, "OLLAMA_BASE_URL") + base_url = base_url if base_url else "http://localhost:11434/v1" + super().__init__( + model_name=model_name, base_url=base_url, config=config, api_key="ollama" + ) diff --git a/src/inspect_ai/model/_providers/providers.py b/src/inspect_ai/model/_providers/providers.py index 548f68312..2c60e66e1 100644 --- a/src/inspect_ai/model/_providers/providers.py +++ b/src/inspect_ai/model/_providers/providers.py @@ -10,7 +10,7 @@ # strictly require this treatment but we do it anyway for uniformity, -@modelapi(name="openai", models=["gpt"]) +@modelapi(name="openai") def openai() -> type[ModelAPI]: # validate validate_openai_client("OpenAI API") @@ -21,7 +21,7 @@ def openai() -> type[ModelAPI]: return OpenAIAPI -@modelapi(name="anthropic", models=["claude"]) +@modelapi(name="anthropic") def anthropic() -> type[ModelAPI]: FEATURE = "Anthropic API" PACKAGE = "anthropic" @@ -42,7 +42,7 @@ def anthropic() -> type[ModelAPI]: return AnthropicAPI -@modelapi(name="google", models=["gemini", "bison", "gdm"]) +@modelapi(name="google") def google() -> type[ModelAPI]: FEATURE = "Google API" PACKAGE = "google-generativeai" @@ -68,7 +68,9 @@ def hf() -> type[ModelAPI]: try: from .hf import HuggingFaceAPI except ImportError: - raise pip_dependency_error("Hugging Face Models", ["torch", "transformers"]) + raise pip_dependency_error( + "Hugging Face Models", ["torch", "transformers", "accelerate"] + ) return HuggingFaceAPI @@ -112,6 +114,17 @@ def together() -> type[ModelAPI]: return TogetherAIAPI +@modelapi(name="ollama") +def ollama() -> type[ModelAPI]: + # validate + validate_openai_client("Ollama API") + + # in the clear + from .ollama import OllamaAPI + + return OllamaAPI + + @modelapi(name="azureai") def azureai() -> type[ModelAPI]: from .azureai import AzureAIAPI diff --git a/src/inspect_ai/model/_registry.py b/src/inspect_ai/model/_registry.py index fab4a9da2..7219bca6e 100644 --- a/src/inspect_ai/model/_registry.py +++ b/src/inspect_ai/model/_registry.py @@ -10,34 +10,28 @@ from ._model import ModelAPI -def 
modelapi_register( - model_type: type[ModelAPI], name: str, models: list[str] -) -> type[ModelAPI]: +def modelapi_register(model_type: type[ModelAPI], name: str) -> type[ModelAPI]: r"""Register a model api. Args: model_type (type[Model]): Class deriving from Model name (str): API serving this model - models (list[str]): Model names by this API Returns: Model API with registry attributes. """ registry_add( model_type, - RegistryInfo(type="modelapi", name=name, metadata=dict(models=models)), + RegistryInfo(type="modelapi", name=name), ) return model_type -def modelapi(name: str, models: list[str] = []) -> Callable[..., type[ModelAPI]]: +def modelapi(name: str) -> Callable[..., type[ModelAPI]]: r"""Decorator for registering model APIs. Args: name (str): Name of API - models (list[str]): Model names that should match this API. - If no `models` are provided then this model type will always - require an API prefix (e.g. "hf/openai-community/gpt2") Returns: Model API with registry attributes. @@ -66,14 +60,13 @@ def model_wrapper(*args: Any, **kwargs: Any) -> ModelAPI: RegistryInfo( type="modelapi", name=model_api, - metadata=dict(models=models), ), *args, **kwargs, ) return model - return modelapi_register(cast(type[ModelAPI], model_wrapper), model_api, models) + return modelapi_register(cast(type[ModelAPI], model_wrapper), model_api) def wrapper( model_type: type[ModelAPI] | Callable[..., type[ModelAPI]], diff --git a/src/inspect_ai/scorer/__init__.py b/src/inspect_ai/scorer/__init__.py index 143d2083d..6cbeb82ca 100644 --- a/src/inspect_ai/scorer/__init__.py +++ b/src/inspect_ai/scorer/__init__.py @@ -16,6 +16,7 @@ from ._metrics.mean import mean from ._metrics.std import bootstrap_std from ._model import model_graded_fact, model_graded_qa +from ._multi import ScoreReducer, majority_vote, multi_scorer from ._pattern import pattern from ._scorer import ( Scorer, @@ -47,4 +48,7 @@ "INCORRECT", "PARTIAL", "NOANSWER", + "multi_scorer", + "majority_vote", + "ScoreReducer", ] diff --git a/src/inspect_ai/scorer/_model.py b/src/inspect_ai/scorer/_model.py index 81d7665da..eaaa48125 100644 --- a/src/inspect_ai/scorer/_model.py +++ b/src/inspect_ai/scorer/_model.py @@ -1,4 +1,5 @@ import re +from functools import partial from inspect_ai.model import ChatMessageUser, Model, get_model from inspect_ai.solver import TaskState @@ -6,6 +7,7 @@ from ._metric import INCORRECT, Score from ._metrics import accuracy, bootstrap_std +from ._multi import majority_vote, multi_scorer from ._scorer import Scorer, Target, scorer @@ -15,7 +17,7 @@ def model_graded_fact( instructions: str | None = None, grade_pattern: str | None = None, partial_credit: bool = False, - model: str | Model | None = None, + model: list[str | Model] | str | Model | None = None, ) -> Scorer: """Score a question/answer task with a fact response using a model. @@ -38,8 +40,7 @@ def model_graded_fact( to `False`. Note that this parameter is only used with the default `instructions` (as custom instructions provide their own prompts for grades). - model (str | Model | none): Model to use for grading - (by default the model being evaluated is used). + model (list[str | Model] | str | Model | None): Model or Models to use for grading. If multiple models are passed, a majority vote of their grade will be returned. By default the model being evaluated is used. 
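The multi-grader support described in the docstring above can be exercised by passing a list of models. This is a hedged sketch only; the grader model names are merely examples, not a recommendation.

```python
# Sketch of the multi-model grading documented above: passing a list makes
# model_graded_qa() build one grader per model and reduce their grades with a
# majority vote (via multi_scorer/majority_vote). Model names are examples.
from inspect_ai.scorer import model_graded_qa

graded_with_vote = model_graded_qa(
    model=[
        "openai/gpt-4",
        "openai/gpt-3.5-turbo",
        "anthropic/claude-2.1",
    ]
)
```

Using an odd number of graders avoids ties; otherwise `majority_vote()` simply returns the most common grade as counted.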
""" return model_graded_qa( template=template if template else DEFAULT_MODEL_GRADED_FACT_TEMPLATE, @@ -56,7 +57,7 @@ def model_graded_qa( instructions: str | None = None, grade_pattern: str | None = None, partial_credit: bool = False, - model: str | Model | None = None, + model: list[str | Model] | str | Model | None = None, ) -> Scorer: """Score a question/answer task using a model. @@ -79,9 +80,32 @@ def model_graded_qa( to `False`. Note that this parameter is only used with the default `instructions` (as custom instructions provide their own prompts for grades). - model (str | Model | None): Model to use for grading - (by default the model being evaluated is used). + model (list[str | Model] | str | Model | None): Model or Models to use for grading. If multiple models are passed, a majority vote of their grade will be returned. By default the model being evaluated is used. """ + # bind variables + get_scorer = partial( + _model_graded_qa_single, template, instructions, grade_pattern, partial_credit + ) + # if only a single model is passed, return a single scorer + if model is None or not isinstance(model, list): + return get_scorer(model) + + # otherwise, use multi scorer + assert isinstance(model, list) + scorers = [get_scorer(model) for model in model] + return multi_scorer(scorers, majority_vote) + + +@scorer(metrics=[accuracy(), bootstrap_std()]) +def _model_graded_qa_single( + template: str | None = None, + instructions: str | None = None, + grade_pattern: str | None = None, + partial_credit: bool = False, + model: str | Model | None = None, +) -> Scorer: + # returns a scorer that does model graded qa for a single model + # resolve model grader_model = get_model(model) diff --git a/src/inspect_ai/scorer/_multi.py b/src/inspect_ai/scorer/_multi.py new file mode 100644 index 000000000..86ac9cfb5 --- /dev/null +++ b/src/inspect_ai/scorer/_multi.py @@ -0,0 +1,50 @@ +import asyncio +from collections import Counter +from typing import ( + Protocol, + runtime_checkable, +) + +from inspect_ai.solver import TaskState + +from ._metric import Score +from ._scorer import Scorer, Target + + +@runtime_checkable +class ScoreReducer(Protocol): + def __call__(self, scores: list[Score]) -> Score: ... + + +def multi_scorer(scorers: list[Scorer], reducer: ScoreReducer) -> Scorer: + r"""Returns a Scorer that runs multiple Scorers in parallel and aggregates their results into a single Score using the provided reducer function. + + Args: + scorers: a list of Scorers. + reducer: a function which takes in a list of Scores and returns a single Score. + """ + + async def score(state: TaskState, target: Target) -> Score: + scores = await asyncio.gather(*[_scorer(state, target) for _scorer in scorers]) + return reducer(scores) + + return score + + +def majority_vote(scores: list[Score]) -> Score: + r"""A utility function for taking a majority vote over a list of scores. + + Args: + scores: a list of Scores. 
+ """ + counts: Counter[str | int | float | bool] = Counter() + for score in scores: + counts[score._as_scalar()] += 1 + return Score( + value=counts.most_common(1)[0][0], + answer=scores[0].answer, + explanation=scores[0].explanation, + metadata={ + "individual_scores": scores + }, # TODO: massage into format better for display + ) diff --git a/src/inspect_ai/scorer/_scorer.py b/src/inspect_ai/scorer/_scorer.py index 982d799fe..85622bfd6 100644 --- a/src/inspect_ai/scorer/_scorer.py +++ b/src/inspect_ai/scorer/_scorer.py @@ -10,6 +10,7 @@ runtime_checkable, ) +from inspect_ai._util._async import is_callable_coroutine from inspect_ai._util.registry import ( RegistryInfo, registry_add, @@ -134,6 +135,11 @@ def wrapper(scorer_type: ScorerType) -> ScorerType: def scorer_wrapper(*args: Any, **kwargs: Any) -> Scorer: scorer = scorer_type(*args, **kwargs) + if not is_callable_coroutine(scorer): + raise TypeError( + f"'{scorer_name}' is not declared as an async callable." + ) + registry_tag( scorer_type, scorer, diff --git a/src/inspect_ai/solver/_solver.py b/src/inspect_ai/solver/_solver.py index 0469d378c..b1d1c4e32 100644 --- a/src/inspect_ai/solver/_solver.py +++ b/src/inspect_ai/solver/_solver.py @@ -10,6 +10,7 @@ from typing_extensions import Unpack +from inspect_ai._util._async import is_callable_coroutine from inspect_ai._util.registry import ( RegistryInfo, registry_add, @@ -256,6 +257,9 @@ def create_solver_wrapper( def solver_wrapper(*args: Any, **kwargs: dict[str, Any]) -> Solver: solver = solver_type(*args, **kwargs) + if not is_callable_coroutine(solver): + raise TypeError(f"'{solver}' is not declared as an async callable.") + registry_tag( solver_type, solver, diff --git a/tests/test_anthropic.py b/tests/test_anthropic.py index e373f4bb5..d9987036b 100644 --- a/tests/test_anthropic.py +++ b/tests/test_anthropic.py @@ -8,7 +8,7 @@ @skip_if_no_anthropic async def test_anthropic_api() -> None: model = get_model( - "claude-2.1", + "anthropic/claude-2.1", config=GenerateConfig( frequency_penalty=0.0, stop_seqs=None, diff --git a/tests/test_cloudlfare.py b/tests/test_cloudflare.py similarity index 100% rename from tests/test_cloudlfare.py rename to tests/test_cloudflare.py diff --git a/tests/test_eval_log.py b/tests/test_eval_log.py index 190aa6444..fad0de3b0 100644 --- a/tests/test_eval_log.py +++ b/tests/test_eval_log.py @@ -1,8 +1,13 @@ +import math +import os + +import pytest from pydantic_core import PydanticSerializationError from utils import skip_if_no_openai from inspect_ai import Task, eval from inspect_ai.dataset import Sample +from inspect_ai.log import read_eval_log from inspect_ai.solver import ( Generate, Plan, @@ -12,6 +17,11 @@ ) +def log_path(file: str) -> str: + # use .txt extension so vscode linter doesn't complain about invalid json + return os.path.join("tests", "test_eval_log", f"{file}.txt") + + class NotSerializable: name: str @@ -35,3 +45,27 @@ async def solve(state: TaskState, generate: Generate): eval(task, model="openai/gpt-4") except PydanticSerializationError: assert False, "Eval raised Pydantic serialization error." 
+ + +def test_read_nan(): + def check_for_nan(log): + assert math.isnan(log.results.metrics.get("accuracy").value) + + log_file = log_path("log_with_nan") + check_for_nan(read_eval_log(log_file)) + check_for_nan(read_eval_log(log_file, header_only=True)) + + +def test_fail_invalid(): + check_log_raises(log_path("log_invalid")) + + +def test_fail_version(): + check_log_raises(log_path("log_version_2")) + + +def check_log_raises(log_file): + with pytest.raises(ValueError): + read_eval_log(log_file) + with pytest.raises(ValueError): + read_eval_log(log_file, header_only=True) diff --git a/tests/test_eval_log/log_invalid.txt b/tests/test_eval_log/log_invalid.txt new file mode 100644 index 000000000..0cd7323f4 --- /dev/null +++ b/tests/test_eval_log/log_invalid.txt @@ -0,0 +1,2 @@ +{ + "version": 1, "status": \ No newline at end of file diff --git a/tests/test_eval_log/log_version_2.txt b/tests/test_eval_log/log_version_2.txt new file mode 100644 index 000000000..63e92fb01 --- /dev/null +++ b/tests/test_eval_log/log_version_2.txt @@ -0,0 +1,63 @@ +{ + "version": 2, + "status": "success", + "eval": { + "task": "wikipedia", + "task_version": 0, + "task_file": "examples/agents/langchain/wikipedia.py", + "task_id": "YAdbKczyeSb6mEgPd3R9Qs", + "run_id": "i5LyrzaUdD9K4EW5WTAd5t", + "created": "2024-05-05T07:59:35", + "dataset": { + "name": "wikipedia", + "location": "wikipedia.jsonl" + }, + "model": "openai/gpt-4", + "task_attribs": {}, + "task_args": {}, + "model_args": {}, + "config": { + "limit": 20 + } + }, + "plan": { + "name": "plan", + "steps": [ + { + "solver": "wikipedia_search", + "params": {} + } + ], + "config": {} + }, + "results": { + "scorer": { + "name": "model_graded_fact", + "params": {} + }, + "metrics": { + "accuracy": { + "name": "accuracy", + "value": 1, + "options": {} + }, + "bootstrap_std": { + "name": "bootstrap_std", + "value": 0.0, + "options": {} + } + } + }, + "stats": { + "started_at": "2024-05-05T07:59:35", + "completed_at": "2024-05-05T08:00:03", + "model_usage": { + "openai/gpt-4": { + "input_tokens": 8868, + "output_tokens": 1351, + "total_tokens": 10219 + } + } + }, + "logging": [] +} \ No newline at end of file diff --git a/tests/test_eval_log/log_with_nan.txt b/tests/test_eval_log/log_with_nan.txt new file mode 100644 index 000000000..bbf5692e6 --- /dev/null +++ b/tests/test_eval_log/log_with_nan.txt @@ -0,0 +1,63 @@ +{ + "version": 1, + "status": "success", + "eval": { + "task": "wikipedia", + "task_version": 0, + "task_file": "examples/agents/langchain/wikipedia.py", + "task_id": "YAdbKczyeSb6mEgPd3R9Qs", + "run_id": "i5LyrzaUdD9K4EW5WTAd5t", + "created": "2024-05-05T07:59:35", + "dataset": { + "name": "wikipedia", + "location": "wikipedia.jsonl" + }, + "model": "openai/gpt-4", + "task_attribs": {}, + "task_args": {}, + "model_args": {}, + "config": { + "limit": 20 + } + }, + "plan": { + "name": "plan", + "steps": [ + { + "solver": "wikipedia_search", + "params": {} + } + ], + "config": {} + }, + "results": { + "scorer": { + "name": "model_graded_fact", + "params": {} + }, + "metrics": { + "accuracy": { + "name": "accuracy", + "value": NaN, + "options": {} + }, + "bootstrap_std": { + "name": "bootstrap_std", + "value": 0.0, + "options": {} + } + } + }, + "stats": { + "started_at": "2024-05-05T07:59:35", + "completed_at": "2024-05-05T08:00:03", + "model_usage": { + "openai/gpt-4": { + "input_tokens": 8868, + "output_tokens": 1351, + "total_tokens": 10219 + } + } + }, + "logging": [] +} \ No newline at end of file diff --git a/tests/test_images.py 
b/tests/test_images.py index 297934f3f..f4843ee14 100644 --- a/tests/test_images.py +++ b/tests/test_images.py @@ -36,7 +36,7 @@ def test_google_images(): @skip_if_no_openai def test_openai_images(): - check_images("opeanai/gpt-4") + check_images("openai/gpt-4") @skip_if_no_anthropic diff --git a/tests/test_list_task.py b/tests/test_list_task.py index 93841846c..b425a9bbb 100644 --- a/tests/test_list_task.py +++ b/tests/test_list_task.py @@ -1,8 +1,7 @@ from pathlib import Path from typing import Callable -from inspect_ai._eval.list import list_tasks -from inspect_ai._eval.task import TaskInfo +from inspect_ai import TaskInfo, list_tasks TEST_TASKS_DIR = Path("tests/test_task_list") diff --git a/tests/test_scorer.py b/tests/test_scorer.py index 7db7e2fe6..964458eac 100644 --- a/tests/test_scorer.py +++ b/tests/test_scorer.py @@ -1,3 +1,4 @@ +import pytest from utils import run_example, skip_if_no_openai from inspect_ai import Task, eval, score @@ -24,6 +25,41 @@ def test_scorer_lookup(): assert scorer +def test_invalid_scorers_error(): + def not_async(): + def inner(state: TaskState, target: Target) -> Score: + return Score(value="C") + + return inner + + class NotCallable: + async def inner(self, state: TaskState, target: Target) -> Score: + return Score(value="C") + + class NotAsyncCallable: + def __call__(self, state: TaskState, target: Target) -> Score: + return Score(value="C") + + for f in [not_async, NotCallable, NotAsyncCallable]: + with pytest.raises(TypeError): + scorer(metrics=[accuracy()], name=f.__name__)(f)() + + +def test_valid_scorers_succeed(): + def is_async(): + async def inner(state: TaskState, target: Target) -> Score: + return Score(value="C") + + return inner + + class IsAsyncCallable: + async def __call__(self, state: TaskState, target: Target) -> Score: + return Score(value="C") + + for f in [is_async, IsAsyncCallable]: + scorer(metrics=[accuracy()], name=f.__name__)(f)() + + @skip_if_no_openai def test_no_scorer(): task = Task( diff --git a/tests/test_solver.py b/tests/test_solver.py index 6d6d26f57..6b9cf1fa0 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -1,3 +1,4 @@ +import pytest from utils import skip_if_no_openai from inspect_ai import Task, eval @@ -67,3 +68,38 @@ async def solve(state: TaskState, generate: Generate): log = eval(task, model=model, max_messages=2)[0] assert len(log.samples[0].messages) == 2 assert log.samples[0].output.completion == "finished" + + +def test_invalid_solvers_error(): + def not_async(): + def inner(state: TaskState, generate: Generate) -> TaskState: + return state + + return inner + + class NotCallable: + async def inner(self, state: TaskState, generate: Generate) -> TaskState: + return state + + class NotAsyncCallable: + def __call__(self, state: TaskState, target: Generate) -> TaskState: + return state + + for f in [not_async, NotCallable, NotAsyncCallable]: + with pytest.raises(TypeError): + solver(name=f.__name__)(f)() + + +def test_valid_solvers_succeed(): + def is_async(): + async def inner(self, state: TaskState, generate: Generate) -> TaskState: + return state + + return inner + + class IsAsyncCallable: + async def __call__(self, state: TaskState, generate: Generate) -> TaskState: + return state + + for f in [is_async, IsAsyncCallable]: + solver(name=f.__name__)(f)() diff --git a/tools/vscode/CHANGELOG.md b/tools/vscode/CHANGELOG.md new file mode 100644 index 000000000..6b10d4996 --- /dev/null +++ b/tools/vscode/CHANGELOG.md @@ -0,0 +1,14 @@ +# Changelog + +## 0.3.13 + +- Ensure that inspect CLI is in 
the path for terminals using a global Python environment +- Add 'Show Logs' command to the environment panel. +- Improve models in the environment panel + - Display literal provider names (rather than pretty names) + - Remember the last used model for each provider + - Allow free-form provide in model + - Add autocomplete for Ollama +- Fix 'Restart' when debugging to properly restart the Inspect debugging session +- Improve performance loading task tree, selecting tasks within outline, and navigating to tasks +- Improve task selection behavior when the activity bar is first shown \ No newline at end of file diff --git a/tools/vscode/package.json b/tools/vscode/package.json index 64f8569d5..9155c321f 100644 --- a/tools/vscode/package.json +++ b/tools/vscode/package.json @@ -7,7 +7,7 @@ "author": { "name": "UK AI Safety Institute" }, - "version": "0.3.12", + "version": "0.3.13", "license": "MIT", "homepage": "https://ukgovernmentbeis.github.io/inspect_ai/", "repository": { @@ -42,6 +42,7 @@ { "command": "inspect.showLogview", "title": "Show Logs", + "icon": "$(code-oss)", "category": "Inspect", "enablement": "workspaceFolderCount != 0" }, @@ -239,6 +240,11 @@ "command": "inspect.taskOutlineTree", "when": "view == inspect_ai.task-outline-view && config.inspect_ai.taskListView == 'list'" }, + { + "command": "inspect.showLogview", + "when": "view == inspect_ai.task-outline-view", + "group": "navigation" + }, { "command": "inspect.debugConfigTask", "when": "view == inspect_ai.task-configuration && inspect_ai.activeTask", diff --git a/tools/vscode/src/components/debugger.ts b/tools/vscode/src/components/debugger.ts deleted file mode 100644 index aae5d5a14..000000000 --- a/tools/vscode/src/components/debugger.ts +++ /dev/null @@ -1,72 +0,0 @@ -import { DebugSession, debug } from "vscode"; -import { activeWorkspaceFolder } from "../core/workspace"; -import { sleep } from "../core/wait"; -import { isServerListening } from "../core/port"; - -export class DebuggerManager { - constructor(private readonly sessionName_: string) { - debug.onDidStartDebugSession((debugSession: DebugSession) => { - if (this.activeSession_) { - throw new Error( - "Unexpectedly tried to start a debug session when one is already running!" 
- ); - } - if (debugSession.configuration.name === sessionName_) { - this.activeSession_ = debugSession; - } - }); - debug.onDidTerminateDebugSession((debugSession: DebugSession) => { - if ( - this.activeSession_ && - debugSession.configuration.name === this.activeSession_?.name - ) { - this.activeSession_ = undefined; - } - }); - } - - public async attach(port: number) { - // Stop any active sessions - if (this.activeSession_) { - await debug.stopDebugging(this.activeSession_); - this.activeSession_ = undefined; - } - - // Start a new session - const debugConfiguration = { - name: this.sessionName_, - type: "python", - request: "attach", - connect: { - host: "localhost", - port, - }, - }; - - // Retry attaching a few times - const maxRetries = 15; - const waitMs = 250; - let retries = 0; - - let launched = false; - do { - await sleep(waitMs); - const listening = await isServerListening( - debugConfiguration.connect.port, - debugConfiguration.connect.host - ); - if (listening) { - launched = await debug.startDebugging( - activeWorkspaceFolder(), - debugConfiguration - ); - } - retries++; - } while (!launched && retries < maxRetries); - - if (!launched) { - throw new Error("Timed out while waiting for debugger to attach"); - } - } - private activeSession_: DebugSession | undefined; -} diff --git a/tools/vscode/src/components/document.ts b/tools/vscode/src/components/document.ts index 9c9141127..3a0fcd75c 100644 --- a/tools/vscode/src/components/document.ts +++ b/tools/vscode/src/components/document.ts @@ -1,61 +1,41 @@ -import { DocumentSymbol, Position, Selection, TextDocument, Uri, commands, workspace } from "vscode"; -import { symbolIsTask } from "./symbol"; +import { Position, Selection, TextDocument, Uri, workspace } from "vscode"; +import { readTaskData } from "./task"; // Provides a Selection for a task with a document export const taskRangeForDocument = async (task: string, documentUri: Uri) => { - const taskSymbols = await tasksForDocument(documentUri); + const taskDatas = await tasksForDocument(documentUri); // Find the task that matches the name (or just select the first task) - const taskSymbol = taskSymbols.find((symbol) => { - return symbol.name === task; + const taskData = taskDatas.find((data) => { + return data.name === task; }); // If the task is within this document, find its position - if (taskSymbol) { - const position = new Position(taskSymbol.range.start.line + 1, 0); + if (taskData) { + const position = new Position(taskData.line + 1, 0); return new Selection(position, position); } }; export const firstTaskRangeForDocument = async (documentUri: Uri) => { - const symbols = await commands.executeCommand( - "vscode.executeDocumentSymbolProvider", - documentUri - ); - const document = await workspace.openTextDocument(documentUri); - const taskSymbol = symbols.find((pred) => { - return symbolIsTask(document, pred); - }); - - if (taskSymbol) { - const position = new Position(taskSymbol.range.start.line + 1, 0); + const taskDatas = await tasksForDocument(documentUri); + if (taskDatas.length > 0) { + const position = new Position(taskDatas[0].line + 1, 0); return new Selection(position, position); } }; // Provides a list of task DocumentSymbols for a document const tasksForDocument = async (documentUri: Uri) => { - const symbols = await commands.executeCommand( - "vscode.executeDocumentSymbolProvider", - documentUri - ); - const document = await workspace.openTextDocument(documentUri); - return symbols.filter((pred) => { - return symbolIsTask(document, pred); - }); + const 
tasks = readTaskData(document); + return tasks; }; -export const documentHasTasks = async (document: TextDocument) => { - const symbols = await commands.executeCommand( - "vscode.executeDocumentSymbolProvider", - document.uri - ); - - return symbols.some((pred) => { - return symbolIsTask(document, pred); - }); +export const documentHasTasks = (document: TextDocument) => { + const tasks = readTaskData(document); + return tasks.length > 0; }; diff --git a/tools/vscode/src/components/notebook.ts b/tools/vscode/src/components/notebook.ts index 6c849675f..531961c06 100644 --- a/tools/vscode/src/components/notebook.ts +++ b/tools/vscode/src/components/notebook.ts @@ -1,6 +1,6 @@ import { extname } from "path"; -import { DocumentSymbol, NotebookCellKind, NotebookDocument, NotebookRange, Position, Range, Selection, Uri, commands } from "vscode"; -import { symbolIsTask } from "./symbol"; +import { NotebookCellKind, NotebookDocument, NotebookRange, Position, Range, Selection, Uri } from "vscode"; +import { TaskData, readTaskData } from "./task"; export interface NotebookCellSelection { cell: NotebookRange, @@ -9,94 +9,78 @@ export interface NotebookCellSelection { // Tests whether a given Uri is a notebook export const isNotebook = (uri: Uri) => { - const ext = extname(uri.fsPath); + return isNotebookPath(uri.path); +}; + +export const isNotebookPath = (path: string) => { + const ext = extname(path); return ext === ".ipynb"; }; // Find the cell selection for a task within a notebook // Note that this provides both the cell range and the selection // within the cell -export const taskRangeForNotebook = async (task: string, document: NotebookDocument): Promise => { - const cellSelections = await taskCellsForNotebook(document); +export const taskRangeForNotebook = (task: string, document: NotebookDocument): NotebookCellSelection | undefined => { + const cells = cellTasks(document); // Find the cell that contains the task - const cellSelection = cellSelections.find((selection) => { - return selection.symbols.find((symbol) => { - return symbol.name === task; + const cellTask = cells.find((cell) => { + return cell.tasks.find((child) => { + return child.name === task; }); }); - // If there is a cell with this task, compute its range - if (cellSelection) { - const symbol = cellSelection.symbols.find((sym) => { - return sym.name === task; + if (cellTask) { + const taskDetail = cellTask.tasks.find((child) => { + return child.name === task; }); - - if (symbol) { - const position = new Position(symbol.range.start.line + 1, 0); + if (taskDetail) { + const index = cellTask.cellIndex; + const position = new Position(taskDetail.line + 1, 0); return { - cell: new NotebookRange(cellSelection.cellIndex, cellSelection.cellIndex), + cell: new NotebookRange(index, index), selection: new Selection(position, position) }; } } }; -export const firstTaskRangeForNotebook = async (document: NotebookDocument) => { - const cellSelections = await taskCellsForNotebook(document); +export const firstTaskRangeForNotebook = (document: NotebookDocument) => { + const cells = cellTasks(document); // Find a cell that contains a task - const cellSelection = cellSelections.find((selection) => { - return selection.symbols.find((symbol) => { - return symbolIsTask(document.cellAt(selection.cellIndex).document, symbol); - }); + const cellTask = cells.find((cell) => { + return cell.tasks.length > 0; }); // If there is a cell with a task, compute its range - if (cellSelection) { + if (cellTask) { // Just take the first task in the cell - const 
symbol = cellSelection.symbols.find((sym) => { - return symbolIsTask(document.cellAt(cellSelection.cellIndex).document, sym); - }); - - if (symbol) { - const position = new Position(symbol.range.start.line + 1, 0); - return { - cell: new NotebookRange(cellSelection.cellIndex, cellSelection.cellIndex), - selection: new Selection(position, position) - }; - } + const task = cellTask.tasks[0]; + const index = cellTask.cellIndex; + const position = new Position(task.line + 1, 0); + return { + cell: new NotebookRange(index, index), + selection: new Selection(position, position) + }; } }; - // Describes a cell position and the symbols within a cell -interface NotebookSymbols { +interface CellTasks { cellIndex: number; - symbols: DocumentSymbol[]; + tasks: TaskData[]; } // Provides a list of cell and DocumentSymbols within the cells which contain tasks -const taskCellsForNotebook = async (document: NotebookDocument): Promise => { - const ranges: NotebookSymbols[] = []; +export const cellTasks = (document: NotebookDocument): CellTasks[] => { + const ranges: CellTasks[] = []; for (const cell of document.getCells()) { if (cell.kind === NotebookCellKind.Code) { - - // Execute the document symbol provider for the cell document - const symbols: DocumentSymbol[] = await commands.executeCommand( - 'vscode.executeDocumentSymbolProvider', cell.document.uri) || []; - - // Find the function symbol in the cell - const taskSymbols = symbols.filter(symbol => { - // Check for the `@task` decorator before the function - const textRange = new Range(symbol.range.start, symbol.range.end); - const textBeforeFunction = cell.document.getText(textRange); - return textBeforeFunction.startsWith('@task'); - }); - - if (taskSymbols.length > 0) { - ranges.push({ cellIndex: cell.index, symbols: taskSymbols }); + const tasks = readTaskData(cell.document); + if (tasks.length > 0) { + ranges.push({ cellIndex: cell.index, tasks }); } } } diff --git a/tools/vscode/src/components/task.ts b/tools/vscode/src/components/task.ts new file mode 100644 index 000000000..691524533 --- /dev/null +++ b/tools/vscode/src/components/task.ts @@ -0,0 +1,102 @@ + +import { + TextDocument, + Uri, + +} from "vscode"; +import { lines } from "../core/text"; + + +// Task information for a document +export interface DocumentTaskInfo { + document: Uri; + tasks: TaskData[]; + activeTask?: TaskData; +} + +// Describes the current active task +export interface TaskData { + name: string; + params: string[]; + line: number +} + +// Reads tasks from a TextDocument +// Quickly reads the default task using text based parsing +// This can't properly deal with things like selection, so this should only +// be used when no selection behavior is warranted +const kTaskPattern = /@task/; +const kFunctionNamePattern = /def\s+(.*)\((.*)$/; + +const kFunctionEndPattern = /\s*\):\s*/; +const kParamsPattern = /^(.*?)(\):)?$/; + +export function readTaskData(document: TextDocument): TaskData[] { + const tasks: TaskData[] = []; + const docLines = lines(document.getText()); + + let state: "seeking-task" | "seeking-function" | "reading-params" = "seeking-task"; + let startLine = -1; + docLines.forEach((line, idx) => { + switch (state) { + case "seeking-task": + if (kTaskPattern.test(line)) { + startLine = idx; + state = "seeking-function"; + } + break; + case "seeking-function": { + const match = line.match(kFunctionNamePattern); + if (match) { + const fnName = match[1]; + const task: TaskData = { + name: fnName, + params: [], + line: startLine + }; + tasks.push(task); + + const 
restOfLine = match[2]; + const keepReading = readParams(restOfLine, task); + if (keepReading) { + state = "reading-params"; + } else { + // We've read the complete function, go + // back to seeking tasks + state = "seeking-task"; + } + } + } + break; + case "reading-params": { + const keepReading = readParams(line, tasks[tasks.length - 1]); + if (keepReading) { + state = "reading-params"; + } else { + // We've read the complete function, go + // back to seeking tasks + state = "seeking-task"; + } + } + } + }); + return tasks; +} + + +const readParams = (line: string, task: TaskData) => { + const paramsMatch = line.match(kParamsPattern); + if (paramsMatch) { + const paramsStr = paramsMatch[1]; + if (paramsStr) { + const params = paramsStr.split(","); + params.forEach((param) => { + const name = param.split("=")[0].trim(); + if (name) { + task.params.push(name); + } + }); + } + } + return !kFunctionEndPattern.test(line); +}; diff --git a/tools/vscode/src/providers/active-task/active-task-provider.ts b/tools/vscode/src/providers/active-task/active-task-provider.ts index 4b163414f..0c84d42fb 100644 --- a/tools/vscode/src/providers/active-task/active-task-provider.ts +++ b/tools/vscode/src/providers/active-task/active-task-provider.ts @@ -1,25 +1,24 @@ import { - DocumentSymbol, Event, EventEmitter, ExtensionContext, Position, - Range, Selection, - SymbolKind, TextDocument, Uri, commands, window, + workspace, } from "vscode"; -import { sleep } from "../../core/wait"; import { DebugActiveTaskCommand, RunActiveTaskCommand, } from "./active-task-command"; import { InspectEvalManager } from "../inspect/inspect-eval"; import { Command } from "../../core/command"; +import { DocumentTaskInfo, readTaskData } from "../../components/task"; +import { cellTasks, isNotebook } from "../../components/notebook"; // Activates the provider which tracks the currently active task (document and task name) export function activateActiveTaskProvider( @@ -35,22 +34,9 @@ export function activateActiveTaskProvider( return [commands, activeTaskManager]; } -// Task information for a document -export interface ActiveTaskInfo { - document: Uri; - tasks: TaskData[]; - activeTask?: TaskData; -} - -// Describes the current active task -export interface TaskData { - name: string; - params: string[]; -} - // Fired when the active task changes export interface ActiveTaskChangedEvent { - activeTaskInfo?: ActiveTaskInfo; + activeTaskInfo?: DocumentTaskInfo; } // Tracks task information for the current editor @@ -60,7 +46,7 @@ export class ActiveTaskManager { // when there is a new selection context.subscriptions.push( window.onDidChangeTextEditorSelection(async (event) => { - await this.updateActiveTask( + await this.updateActiveTaskWithDocument( event.textEditor.document, event.selections[0] ); @@ -71,7 +57,7 @@ export class ActiveTaskManager { window.onDidChangeActiveNotebookEditor(async (event) => { if (window.activeNotebookEditor?.selection.start) { const cell = event?.notebook.cellAt(window.activeNotebookEditor.selection.start); - await this.updateActiveTask( + await this.updateActiveTaskWithDocument( cell?.document, new Selection(new Position(0, 0), new Position(0, 0)) ); @@ -81,20 +67,20 @@ export class ActiveTaskManager { context.subscriptions.push(window.onDidChangeActiveTextEditor(async (event) => { if (event) { - await this.updateActiveTask( + await this.updateActiveTaskWithDocument( event.document ); } })); } - private activeTaskInfo_: ActiveTaskInfo | undefined; + private activeTaskInfo_: DocumentTaskInfo | undefined; 
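As a rough illustration of the text-based parser added in `task.ts` above: `readTaskData` scans for an `@task` decorator, records the decorator's line, then collects parameter names until it reaches the closing `):`. The sketch below is not part of the patch; the import path, the helper name, and the sample cell contents are assumptions.

```ts
import { window } from "vscode";
import { readTaskData } from "./components/task"; // import path is an assumption

// Given an editor containing, for example:
//   @task
//   def security_guide(system_prompt="devops.txt",
//                       grader_model="openai/gpt-4"):
//       ...
// readTaskData should return:
//   [{ name: "security_guide", params: ["system_prompt", "grader_model"], line: 0 }]
export function logTasksInActiveEditor(): void {
  const editor = window.activeTextEditor;
  if (!editor) {
    return;
  }
  for (const task of readTaskData(editor.document)) {
    console.log(`${task.name}(${task.params.join(", ")}) at line ${task.line + 1}`);
  }
}
```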
private readonly onActiveTaskChanged_ = new EventEmitter(); // Event to be notified when task information changes public readonly onActiveTaskChanged: Event = this.onActiveTaskChanged_.event; // Get the task information for the current selection - public getActiveTaskInfo(): ActiveTaskInfo | undefined { + public getActiveTaskInfo(): DocumentTaskInfo | undefined { return this.activeTaskInfo_; } @@ -102,29 +88,51 @@ export class ActiveTaskManager { public async refresh() { const currentSelection = window.activeTextEditor?.selection; const currentDocument = window.activeTextEditor?.document; - await this.updateActiveTask(currentDocument, currentSelection); + await this.updateActiveTaskWithDocument(currentDocument, currentSelection); } - async updateActiveTask(document?: TextDocument, selection?: Selection) { + private async updateTask(activeTaskInfo?: DocumentTaskInfo) { let taskActive = false; + if (activeTaskInfo) { + this.setActiveTaskInfo(activeTaskInfo); + taskActive = activeTaskInfo !== undefined; + } + await commands.executeCommand( + "setContext", + "inspect_ai.activeTask", + taskActive + ); + } + + async updateActiveTaskWithDocument(document?: TextDocument, selection?: Selection) { if (document && selection) { - if (document.languageId === "python") { - const activeTaskInfo = await getTaskInfo(document, selection); - if (activeTaskInfo) { - this.setActiveTaskInfo(activeTaskInfo); - taskActive = activeTaskInfo !== undefined; - } + const activeTaskInfo = document.languageId === "python" ? getTaskInfoFromDocument(document, selection) : undefined; + await this.updateTask(activeTaskInfo); + } + } + + async updateActiveTask(documentUri: Uri, task: string) { + if (isNotebook(documentUri)) { + // Compute the cell and position of the task + const notebookDocument = await workspace.openNotebookDocument(documentUri); + const cells = cellTasks(notebookDocument); + const cellTask = cells.find((c) => { + return c.tasks.find((t) => { return t.name === task; }); + }); + if (cellTask) { + const cell = notebookDocument.cellAt(cellTask?.cellIndex); + const taskInfo = getTaskInfo(cell.document, task); + await this.updateTask(taskInfo); } - await commands.executeCommand( - "setContext", - "inspect_ai.activeTask", - taskActive - ); + } else { + const document = await workspace.openTextDocument(documentUri); + const taskInfo = getTaskInfo(document, task); + await this.updateTask(taskInfo); } } // Set the task information - setActiveTaskInfo(task?: ActiveTaskInfo) { + setActiveTaskInfo(task?: DocumentTaskInfo) { if (this.activeTaskInfo_ !== task) { this.activeTaskInfo_ = task; this.onActiveTaskChanged_.fire({ activeTaskInfo: this.activeTaskInfo_ }); @@ -132,103 +140,56 @@ export class ActiveTaskManager { } } -async function getTaskInfo( +function getTaskInfoFromDocument( document: TextDocument, - selection: Selection -): Promise { + selection?: Selection +): DocumentTaskInfo | undefined { // Try to get symbols to read task info for this document // Note that the retry is here since the symbol provider // has latency in loading and there wasn't a way to wait // on it specifically (waiting on the Python extension didn't work) - let symbols: DocumentSymbol[] | undefined; - let count = 0; - do { - symbols = await commands.executeCommand( - "vscode.executeDocumentSymbolProvider", - document.uri - ); - count++; - if (!symbols) { - await sleep(500); - } - - } while (count <= 5 && !symbols); - - if (symbols) { - const functionSymbols = symbols.filter( - (symbol) => symbol.kind === SymbolKind.Function - ); + const 
tasks = readTaskData(document); - const tasks: TaskData[] = []; - let activeTask = undefined; + // If there are no tasks in this document, return undefined + if (tasks.length === 0) { + return undefined; + } - if (functionSymbols) { - functionSymbols.forEach((symbol) => { - if (isTask(document, symbol)) { - const signatureRange = getSignatureRange(document, symbol); - const variables = symbol.children.filter( - (child) => child.kind === SymbolKind.Variable - ); - const params = variables - .filter((variable) => { - return signatureRange?.contains(variable.range); - }) - .map((variable) => variable.name); - - const task = { - name: symbol.name, - params, - }; - tasks.push(task); - - if (symbol.range.contains(selection.start)) { - activeTask = task; - } - } - }); - } + const selectionLine = selection?.start.line || 0; - // If there are tasks in this file, just consider the first task active - if (tasks.length > 0 && !activeTask) { - activeTask = tasks[0]; - } + // Find the first task that appears before the selection + // or otherwise the first task - return { - document: document.uri, - tasks, - activeTask, - }; - } + const activeTask = [...tasks].reverse().find((task) => { + return task.line <= selectionLine; + }); + return { + document: document.uri, + tasks, + activeTask: activeTask || (tasks.length > 0 ? tasks[0] : undefined) + }; } -function isTask(document: TextDocument, symbol: DocumentSymbol) { - const functionText = document.getText(symbol.range); - if (functionText.startsWith("@task")) { - return true; - } +function getTaskInfo( + document: TextDocument, + task: string +): DocumentTaskInfo | undefined { + // Try to get symbols to read task info for this document + // Note that the retry is here since the symbol provider + // has latency in loading and there wasn't a way to wait + // on it specifically (waiting on the Python extension didn't work) + const tasks = readTaskData(document); + + // Find the first task that appears before the selection + // or otherwise the first task + + const activeTask = [...tasks].reverse().find((t) => { + return t.name === task; + }); + return { + document: document.uri, + tasks, + activeTask: activeTask || (tasks.length > 0 ? 
tasks[0] : undefined) + }; } - -function getSignatureRange(document: TextDocument, symbol: DocumentSymbol) { - const functionText = document.getText(symbol.range); - - // Split the text into lines to process it line-by-line - const lines = functionText.split("\n"); - for (let i = 0; i < lines.length; i++) { - const line = lines[i]; - if (line.trim().startsWith("def") && line.trim().endsWith(":")) { - // Calculate the end position of the function definition line - const startLine = symbol.range.start.line + i; - const endLine = startLine; - - // End at the end of the definition line - const endCharacter = line.length; - - // Create a new range for just the definition - return new Range( - new Position(startLine, 0), - new Position(endLine, endCharacter) - ); - } - } -} \ No newline at end of file diff --git a/tools/vscode/src/providers/activity-bar/activity-bar-provider.ts b/tools/vscode/src/providers/activity-bar/activity-bar-provider.ts index 78d711402..7bc31cbbe 100644 --- a/tools/vscode/src/providers/activity-bar/activity-bar-provider.ts +++ b/tools/vscode/src/providers/activity-bar/activity-bar-provider.ts @@ -30,9 +30,9 @@ export async function activateActivityBar( window.registerWebviewViewProvider(EnvConfigurationProvider.viewType, envProvider) ); - const taskConfigPRovider = new TaskConfigurationProvider(context.extensionUri, workspaceStateMgr, activeTaskManager, inspectManager); + const taskConfigProvider = new TaskConfigurationProvider(context.extensionUri, workspaceStateMgr, activeTaskManager, inspectManager); context.subscriptions.push( - window.registerWebviewViewProvider(TaskConfigurationProvider.viewType, taskConfigPRovider) + window.registerWebviewViewProvider(TaskConfigurationProvider.viewType, taskConfigProvider) ); const taskConfigCommands = [ new RunConfigTaskCommand(activeTaskManager, inspectEvalMgr), diff --git a/tools/vscode/src/providers/activity-bar/env-config-provider.ts b/tools/vscode/src/providers/activity-bar/env-config-provider.ts index 71d2e42ec..45e21d4fa 100644 --- a/tools/vscode/src/providers/activity-bar/env-config-provider.ts +++ b/tools/vscode/src/providers/activity-bar/env-config-provider.ts @@ -43,6 +43,34 @@ export interface EnvConfiguration { modelBaseUrl?: string; } +// A list of the auto-complete models and the minimum +// version required in order to support the model +const kInspectModels: Record = { + "openai": "0.3.8", + "anthropic": "0.3.8", + "google": "0.3.8", + "mistral": "0.3.8", + "hf": "0.3.8", + "together": "0.3.8", + "bedrock": "0.3.8", + "ollama": "0.3.9", + "azureai": "0.3.8", + "cf": "0.3.8" +}; + +const inspectModels = () => { + const currentVersion = inspectVersion(); + return Object.keys(kInspectModels).filter((key) => { + const ver = kInspectModels[key]; + if (currentVersion !== null) { + return currentVersion.compare(ver) > -1; + } else { + return false; + } + }); +}; + + export class EnvConfigurationProvider implements WebviewViewProvider { public static readonly viewType = "inspect_ai.env-configuration-view"; @@ -78,9 +106,41 @@ export class EnvConfigurationProvider implements WebviewViewProvider { break; } case "setEnvValue": - setConfiguration(data.default, data.value, this.env); - this.envManager_.setValues(configToEnv(this.env)); - break; + { + // Set the value + setConfiguration(data.default, data.value, this.env); + + // Special case for provider, potentially restoring the + // previously used model + let updateWebview = false; + if (data.default === "provider") { + const modelState = 
this.stateManager_.getModelState(data.value); + setConfiguration("model", modelState.lastModel || "", this.env); + updateWebview = true; + } + + // Save the most recently used model for this provider + if (this.env.provider && data.default === "model") { + const modelState = this.stateManager_.getModelState(this.env.provider); + modelState.lastModel = data.value; + await this.stateManager_.setModelState(this.env.provider, modelState); + } + + // Save the env + this.envManager_.setValues(configToEnv(this.env)); + + + if (updateWebview) { + await webviewView.webview.postMessage({ + type: kEnvChanged, + message: { + env: this.env, + }, + }); + } + + break; + } case "openUrl": await env.openExternal(Uri.parse(data.url)); break; @@ -168,6 +228,10 @@ export class EnvConfigurationProvider implements WebviewViewProvider { // Use a nonce to only allow a specific script to be run. const nonce = getNonce(); + const modelOptions = inspectModels().map((model) => { + return `${model}`; + }); + return ` @@ -206,15 +270,7 @@ export class EnvConfigurationProvider implements WebviewViewProvider {
                         None
-                        OpenAI
-                        Anthropic
-                        Google
-                        Mistral
-                        Hugging Face
-                        TogetherAI
-                        AWS Bedrock
-                        Azure AI
-                        Cloudflare
+                        ${modelOptions.join("\n")}
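To make the version gating above concrete, here is a minimal standalone sketch of the same idea, assuming the `semver` package for comparisons (the real code uses `inspectVersion()` and the `kInspectModels` table shown earlier): a provider only appears in the dropdown once the installed `inspect-ai` version supports it.

```ts
import * as semver from "semver";

// Minimum inspect-ai version required for each provider (abbreviated).
const kMinVersionByProvider: Record<string, string> = {
  openai: "0.3.8",
  bedrock: "0.3.8",
  ollama: "0.3.9",
};

// Return the providers supported by the installed version.
function availableProviders(installedVersion: string): string[] {
  return Object.keys(kMinVersionByProvider).filter((provider) =>
    semver.gte(installedVersion, kMinVersionByProvider[provider])
  );
}

// availableProviders("0.3.8") -> ["openai", "bedrock"]
// availableProviders("0.3.9") -> ["openai", "bedrock", "ollama"]
```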
diff --git a/tools/vscode/src/providers/activity-bar/task-config-provider.ts b/tools/vscode/src/providers/activity-bar/task-config-provider.ts index e4e6cbf0d..907dbe8db 100644 --- a/tools/vscode/src/providers/activity-bar/task-config-provider.ts +++ b/tools/vscode/src/providers/activity-bar/task-config-provider.ts @@ -13,16 +13,16 @@ import { } from "../workspace/workspace-state-provider"; import { ActiveTaskChangedEvent, - ActiveTaskInfo, ActiveTaskManager, } from "../active-task/active-task-provider"; import { basename } from "path"; import { inspectVersion } from "../../inspect"; import { InspectManager } from "../inspect/inspect-manager"; +import { DocumentTaskInfo } from "../../components/task"; export type SetActiveTaskCommand = { type: "setActiveTask"; - task: ActiveTaskInfo; + task: DocumentTaskInfo; state: DocumentState; } & Record; @@ -74,6 +74,14 @@ export class TaskConfigurationProvider implements WebviewViewProvider { }); }; + const updateTaskInfo = async (activeTaskInfo?: DocumentTaskInfo) => { + if (activeTaskInfo) { + await postActiveTaskMsg(activeTaskInfo); + } else { + webviewView.description = ""; + } + }; + this.disposables_.push( this.taskManager_.onActiveTaskChanged( async (e: ActiveTaskChangedEvent) => { @@ -94,9 +102,7 @@ export class TaskConfigurationProvider implements WebviewViewProvider { this.disposables_.push(webviewView.onDidChangeVisibility(async () => { const activeTask = this.taskManager_.getActiveTaskInfo(); - if (activeTask) { - await postActiveTaskMsg(activeTask); - } + await updateTaskInfo(activeTask); })); webviewView.webview.html = this.htmlForWebview(webviewView.webview); @@ -104,7 +110,8 @@ export class TaskConfigurationProvider implements WebviewViewProvider { await initMsg(); - const postActiveTaskMsg = async (activeTaskInfo: ActiveTaskInfo) => { + + const postActiveTaskMsg = async (activeTaskInfo: DocumentTaskInfo) => { // Ignore active task changes that don't include a task (e.g. 
they are files) if ( !activeTaskInfo.activeTask || @@ -130,19 +137,16 @@ export class TaskConfigurationProvider implements WebviewViewProvider { this.disposables_.push( this.taskManager_.onActiveTaskChanged(async (e) => { await updateSidebarState(e.activeTaskInfo); - if (e.activeTaskInfo) { - await postActiveTaskMsg(e.activeTaskInfo); - } + await updateTaskInfo(e.activeTaskInfo); }) ); if (inspectVersion() === null) { await noInspectMsg(); } + const activeTask = this.taskManager_.getActiveTaskInfo(); - if (activeTask) { - await postActiveTaskMsg(activeTask); - } + await updateTaskInfo(activeTask); // Process UI messages webviewView.webview.onDidReceiveMessage( @@ -287,7 +291,7 @@ export class TaskConfigurationProvider implements WebviewViewProvider { } } -const updateSidebarState = async (taskInfo?: ActiveTaskInfo) => { +const updateSidebarState = async (taskInfo?: DocumentTaskInfo) => { await commands.executeCommand( "setContext", "inspect_ai.task-configuration.task-active", diff --git a/tools/vscode/src/providers/activity-bar/task-outline-commands.ts b/tools/vscode/src/providers/activity-bar/task-outline-commands.ts index eb7ef8302..941c223c7 100644 --- a/tools/vscode/src/providers/activity-bar/task-outline-commands.ts +++ b/tools/vscode/src/providers/activity-bar/task-outline-commands.ts @@ -6,7 +6,7 @@ import { ConfigurationTarget, ExtensionContext, ViewColumn, - commands + commands, } from "vscode"; import { Command } from "../../core/command"; @@ -20,10 +20,19 @@ import { writeFileSync } from "fs"; import { readTemplate, templates } from "../../components/templates"; import { isValidPythonFnName } from "../../core/python"; -import { documentHasTasks, firstTaskRangeForDocument, taskRangeForDocument } from "../../components/document"; -import { firstTaskRangeForNotebook, isNotebook, taskRangeForNotebook } from "../../components/notebook"; +import { + documentHasTasks, + firstTaskRangeForDocument, + taskRangeForDocument, +} from "../../components/document"; +import { + firstTaskRangeForNotebook, + isNotebook, + taskRangeForNotebook, +} from "../../components/notebook"; import { scheduleReturnFocus } from "../../components/focus"; import { InspectLogviewManager } from "../logview/logview-manager"; +import { ActiveTaskManager } from "../active-task/active-task-provider"; export class ShowTaskTree implements Command { constructor(private readonly provider_: TaskOutLineTreeDataProvider) { } @@ -56,7 +65,11 @@ export class RunSelectedEvalCommand implements Command { const task = treeItem.taskPath.type === "task" ? 
treeItem.taskPath.name : undefined; - const evalPromise = this.inspectEvalMgr_.startEval(toAbsolutePath(path), task, false); + const evalPromise = this.inspectEvalMgr_.startEval( + toAbsolutePath(path), + task, + false + ); const resumeFocusCommand = "inspect_ai.task-outline-view.focus"; scheduleReturnFocus(resumeFocusCommand); await evalPromise; @@ -78,7 +91,11 @@ export class DebugSelectedEvalCommand implements Command { } export class EditSelectedTaskCommand implements Command { - constructor(private readonly tree_: TreeView, private inspectLogviewManager_: InspectLogviewManager) { } + constructor( + private readonly tree_: TreeView, + private inspectLogviewManager_: InspectLogviewManager, + private activeTaskManager_: ActiveTaskManager + ) { } async execute() { if (this.tree_.selection.length > 1) { throw new Error("Expected only a single selector for the task tree"); @@ -95,32 +112,42 @@ export class EditSelectedTaskCommand implements Command { // If this is a specific task, go right to that const task = - treeItem.taskPath.type === "task" ? treeItem.taskPath.name : undefined; + treeItem.taskPath.type === "task" ? treeItem.taskPath.name : treeItem.taskPath.children?.[0].name; // Note if/where the logview is showing const logViewColumn = this.inspectLogviewManager_.viewColumn(); + // Update the active task + if (task) { + await this.activeTaskManager_.updateActiveTask(fileUri, task); + } + if (isNotebook(fileUri)) { // Compute the cell and position of the task const notebookDocument = await workspace.openNotebookDocument(fileUri); - const findTaskSelection = async (task: string | undefined) => { + const findTaskSelection = (task: string | undefined) => { if (task) { - return await taskRangeForNotebook(task, notebookDocument); + return taskRangeForNotebook(task, notebookDocument); } else { - return await firstTaskRangeForNotebook(notebookDocument); + return firstTaskRangeForNotebook(notebookDocument); } }; - const taskSelection = await findTaskSelection(task); + const taskSelection = findTaskSelection(task); // Open the notebook to the specified cell, if any - const selections = taskSelection?.cell ? [taskSelection?.cell] : undefined; - await window.showNotebookDocument(notebookDocument, { selections, preview: false, viewColumn: findTargetViewColumn(logViewColumn) }); + const selections = taskSelection?.cell + ? 
[taskSelection?.cell] + : undefined; + await window.showNotebookDocument(notebookDocument, { + selections, + preview: false, + viewColumn: findTargetViewColumn(logViewColumn), + }); if (selections) { - await commands.executeCommand('notebook.cell.edit'); + await commands.executeCommand("notebook.cell.edit"); } } else { - // Find the task selection for the document (if any) const findTaskSelection = async (task: string | undefined) => { if (task) { @@ -132,7 +159,11 @@ export class EditSelectedTaskCommand implements Command { const selection = await findTaskSelection(task); // Show the document - await window.showTextDocument(fileUri, { selection, viewColumn: findTargetViewColumn(logViewColumn), preview: false }); + await window.showTextDocument(fileUri, { + selection, + viewColumn: findTargetViewColumn(logViewColumn), + preview: false, + }); } } } @@ -156,7 +187,7 @@ export const findTargetViewColumn = (logViewColumn?: ViewColumn) => { return targetEditor.viewColumn; } - // There are no editors with tasks, but if there is any active + // There are no editors with tasks, but if there is any active // editor, let's use that if (visibleEditors.length > 0) { return visibleEditors[0].viewColumn; @@ -165,7 +196,7 @@ export const findTargetViewColumn = (logViewColumn?: ViewColumn) => { // If we get here, there are no editors open at all // so as a last ditch, just open next to the logview - // (if it showing) or in the first column + // (if it showing) or in the first column // (who knows what it showing in this case) if (logViewColumn) { // The log view is open, but not any editors @@ -195,7 +226,6 @@ export class CreateTaskCommand implements Command { }); if (taskName) { - // force the task name to lower case const taskNameLower = taskName.toLowerCase(); @@ -212,7 +242,6 @@ export class CreateTaskCommand implements Command { // Create empty document const document = await workspace.openTextDocument(absPath.path); await window.showTextDocument(document); - } } diff --git a/tools/vscode/src/providers/activity-bar/task-outline-provider.ts b/tools/vscode/src/providers/activity-bar/task-outline-provider.ts index cdd34982d..306220ff4 100644 --- a/tools/vscode/src/providers/activity-bar/task-outline-provider.ts +++ b/tools/vscode/src/providers/activity-bar/task-outline-provider.ts @@ -13,6 +13,7 @@ import { Disposable, TreeView, TreeViewVisibilityChangeEvent, + Uri, } from "vscode"; import { RunSelectedEvalCommand, @@ -32,13 +33,13 @@ import { } from "../workspace/workspace-task-provider"; import { basename, relative, sep } from "path"; import { - ActiveTaskInfo, ActiveTaskManager, } from "../active-task/active-task-provider"; import { throttle } from "lodash"; import { inspectVersion } from "../../inspect"; import { InspectManager } from "../inspect/inspect-manager"; import { InspectLogviewManager } from "../logview/logview-manager"; +import { DocumentTaskInfo } from "../../components/task"; // Activation function for the task outline export async function activateTaskOutline( @@ -70,9 +71,11 @@ export async function activateTaskOutline( }; // If the interpreter changes, refresh the tasks - context.subscriptions.push(inspectManager.onInspectChanged(async () => { - await checkInspect(); - })); + context.subscriptions.push( + inspectManager.onInspectChanged(async () => { + await checkInspect(); + }) + ); await checkInspect(); const tree = window.createTreeView(TaskOutLineTreeDataProvider.viewType, { @@ -81,11 +84,34 @@ export async function activateTaskOutline( canSelectMany: false, }); - // Activate 
task tracking context.subscriptions.push( - ...await activateTaskTracking(treeDataProvider, tree, activeTaskManager) + tree.onDidChangeVisibility(async (e: TreeViewVisibilityChangeEvent) => { + // If the tree becomes visible with nothing selected, try selecting + if (e.visible && tree.selection.length === 0) { + const activeTask = activeTaskManager.getActiveTaskInfo(); + if (activeTask) { + await showTreeItem(treeDataProvider, tree, activeTask); + } else { + const first = await findFirstTask(treeDataProvider); + if (first) { + await activeTaskManager.updateActiveTask(Uri.file(first.taskPath.path), first.taskPath.name); + } + } + } else if (e.visible) { + // Tree just became visible, be sure selection matches the active task + await showTreeItem( + treeDataProvider, + tree, + activeTaskManager.getActiveTaskInfo() + ); + } + }) ); + // Activate task tracking + context.subscriptions.push( + ...(await activateTaskTracking(treeDataProvider, tree, activeTaskManager)) + ); return [ [ @@ -93,53 +119,62 @@ export async function activateTaskOutline( new ShowTaskTree(treeDataProvider), new RunSelectedEvalCommand(inspectEvalMgr), new DebugSelectedEvalCommand(inspectEvalMgr), - new EditSelectedTaskCommand(tree, inspectLogviewManager), + new EditSelectedTaskCommand(tree, inspectLogviewManager, activeTaskManager), new CreateTaskCommand(context), ], treeDataProvider, ]; } - const activateTaskTracking = async ( treeDataProvider: TaskOutLineTreeDataProvider, tree: TreeView, activeTaskManager: ActiveTaskManager ) => { - - const showTreeItem = async (activeTask?: ActiveTaskInfo) => { - if (!activeTask) { - return; - } - - const treeItem = await findTreeItem(activeTask, treeDataProvider); - if (treeItem && treeItem.taskPath.type === "task") { - // Don't reveal the item if the tree isn't visible (this will force the - // activity bar containing the tree to become visible, which is very jarring) - if (tree.visible) { - await tree.reveal(treeItem, { select: true, focus: false }); - } - } - }; - // Listen for changes to the active task and drive the tree to the item const activeTaskChanged = activeTaskManager.onActiveTaskChanged(async (e) => { - await showTreeItem(e.activeTaskInfo); + await showTreeItem(treeDataProvider, tree, e.activeTaskInfo); }); - const treeVisibilityChanged = tree.onDidChangeVisibility(async (e: TreeViewVisibilityChangeEvent) => { - // Since we may have been ignoring activation events for tasks, when the tree reveals - // itself, ensure that we select the currently activate task, if any - if (e.visible) { - await showTreeItem(activeTaskManager.getActiveTaskInfo()); + const currentlyActive = activeTaskManager.getActiveTaskInfo(); + await showTreeItem(treeDataProvider, tree, currentlyActive); + return [activeTaskChanged]; +}; + +const showTreeItem = async ( + treeDataProvider: TaskOutLineTreeDataProvider, + tree: TreeView, + activeTask?: DocumentTaskInfo +) => { + if (!activeTask) { + return; + } + + const treeItem = await findTreeItem(activeTask, treeDataProvider); + if (treeItem && treeItem.taskPath.type === "task") { + // Don't reveal the item if the tree isn't visible (this will force the + // activity bar containing the tree to become visible, which is very jarring) + if (tree.visible) { + await tree.reveal(treeItem, { select: true, focus: false }); } - }); + } +}; - const currentlyActive = activeTaskManager.getActiveTaskInfo(); - await showTreeItem(currentlyActive); - return [activeTaskChanged, treeVisibilityChanged]; +const findFirstTask = async ( + treeDataProvider: 
TaskOutLineTreeDataProvider, + element?: TaskTreeItem +): Promise => { + const children = await treeDataProvider.getChildren(element); + for (const child of children) { + if (child.taskPath.type === "task") { + return child; + } else { + return await findFirstTask(treeDataProvider, child); + } + } }; + // A tree item for a task, file, or folder export class TaskTreeItem extends TreeItem { constructor( @@ -199,23 +234,28 @@ export class TaskOutLineTreeDataProvider private readonly workspaceMgr: WorkspaceTaskManager, private readonly command_?: VsCodeCommand ) { - this.disposables_.push( - this.workspaceMgr.onTasksChanged(throttle(async (e: TasksChangedEvent) => { - this.setTasks(e.tasks || []); - await commands.executeCommand( - "setContext", - "inspect_ai.task-outline-view.tasksLoaded", - true - ); - await commands.executeCommand( - "setContext", - "inspect_ai.task-outline-view.noTasks", - e.tasks?.length === 0 - ); - }, 500, { leading: true, trailing: true }))); + this.workspaceMgr.onTasksChanged( + throttle( + async (e: TasksChangedEvent) => { + this.setTasks(e.tasks || []); + await commands.executeCommand( + "setContext", + "inspect_ai.task-outline-view.tasksLoaded", + true + ); + await commands.executeCommand( + "setContext", + "inspect_ai.task-outline-view.noTasks", + e.tasks?.length === 0 + ); + }, + 500, + { leading: true, trailing: true } + ) + ) + ); this.workspaceMgr.refresh(); - } private disposables_: Disposable[] = []; dispose() { @@ -315,7 +355,7 @@ function describeRelativePath(path: string, workspacePath?: AbsolutePath) { // Find a task in the tree based upon its // path (e.g. document path and task name) async function findTreeItem( - activeTask: ActiveTaskInfo, + activeTask: DocumentTaskInfo, treeDataProvider: TaskOutLineTreeDataProvider ) { const filePath = activeTask.document.fsPath; diff --git a/tools/vscode/src/providers/activity-bar/webview/env-config-webview.ts b/tools/vscode/src/providers/activity-bar/webview/env-config-webview.ts index b886f786b..686d5ba91 100644 --- a/tools/vscode/src/providers/activity-bar/webview/env-config-webview.ts +++ b/tools/vscode/src/providers/activity-bar/webview/env-config-webview.ts @@ -87,9 +87,7 @@ function getProviderEl() { function getProviderText() { const providerEl = getProviderEl(); - - const index = providerEl.selectedIndex; - return index > -1 ? providerEl.options[index].value : providerEl.value; + return providerEl.value; } function resetModel() { @@ -157,10 +155,7 @@ const attachListeners = () => { getProviderEl().value = ""; } if (e.target) { - const index = el.selectedIndex; - const value = index > -1 ? 
el.options[index].value : el.value; - setEnvValue("provider", value); - setEnvValue("model", ""); + setEnvValue("provider", txt); resetModel(); showProviderHelp(); } @@ -169,7 +164,6 @@ const attachListeners = () => { const el = document.getElementById("provider") as HTMLSelectElement; el.addEventListener("change", providerChanged); - el.addEventListener("keyup", debounce(providerChanged, kBounceInterval)); setEnvWhenKeyup("model", "model"); diff --git a/tools/vscode/src/providers/inspect/inspect-eval.ts b/tools/vscode/src/providers/inspect/inspect-eval.ts index 6f7b01529..72cdb4caf 100644 --- a/tools/vscode/src/providers/inspect/inspect-eval.ts +++ b/tools/vscode/src/providers/inspect/inspect-eval.ts @@ -1,36 +1,37 @@ -import { window, workspace } from "vscode"; -import { findOpenPort } from "../../core/port"; -import { DebuggerManager } from "../../components/debugger"; +import { DebugConfiguration, debug, window, workspace } from "vscode"; import { inspectEvalCommands } from "./inspect-eval-commands"; import { Command } from "../../core/command"; -import { AbsolutePath, activeWorkspacePath, workspaceRelativePath } from "../../core/path"; +import { + AbsolutePath, + activeWorkspacePath, + workspaceRelativePath, +} from "../../core/path"; import { WorkspaceStateManager } from "../workspace/workspace-state-provider"; import { inspectVersion } from "../../inspect"; import { inspectBinPath } from "../../inspect/props"; +import { activeWorkspaceFolder } from "../../core/workspace"; +import { findOpenPort } from "../../core/port"; - -const kDebugSessionName = "Inspect Eval"; - -export function activateEvalManager(stateManager: WorkspaceStateManager): [Command[], InspectEvalManager] { +export function activateEvalManager( + stateManager: WorkspaceStateManager +): [Command[], InspectEvalManager] { const inspectEvalMgr = new InspectEvalManager(stateManager); return [inspectEvalCommands(inspectEvalMgr), inspectEvalMgr]; } export class InspectEvalManager { constructor(private readonly stateManager_: WorkspaceStateManager) { - this.debuggerManager_ = new DebuggerManager(kDebugSessionName); } - private debuggerManager_: DebuggerManager; public async startEval(file: AbsolutePath, task?: string, debug = false) { - // if we don't have inspect bail and let the user know if (!inspectVersion()) { await window.showWarningMessage( - `Unable to ${debug ? "Debug" : "Run"} Eval (Inspect Package Not Installed)`, + `Unable to ${debug ? 
"Debug" : "Run" + } Eval (Inspect Package Not Installed)`, { modal: true, - detail: "pip install --upgrade inspect-ai" + detail: "pip install --upgrade inspect-ai", } ); return; @@ -48,7 +49,10 @@ export class InspectEvalManager { // Forward the various doc state args const limit = docState.limit; - if (debug === true && workspace.getConfiguration("inspect_ai").get("debugSingleSample")) { + if ( + debug === true && + workspace.getConfiguration("inspect_ai").get("debugSingleSample") + ) { args.push(...["--limit", "1"]); } else if (limit) { args.push(...["--limit", limit]); @@ -88,33 +92,28 @@ export class InspectEvalManager { }); } - // Handle debugging - let debugPort = 5678; - if (debug === true) { - // Provision a port - debugPort = await findOpenPort(debugPort); + // If we're debugging, launch using the debugger + if (debug) { - args.push("--debug"); - args.push("--debug-port"); - args.push(debugPort.toString()); - } + // Handle debugging + let debugPort = 5678; + if (debug === true) { + // Provision a port + debugPort = await findOpenPort(debugPort); - // Run the command - runEvalCmd(args, workspaceDir.path); + args.push("--debug-port"); + args.push(debugPort.toString()); + } - // If we're debugging, attach the debugger - if (debug) { - await this.debuggerManager_.attach(debugPort); + await runDebugger(inspectBinPath()?.path || "inspect", args, workspaceDir.path, debugPort); + } else { + // Run the command + runEvalCmd(args, workspaceDir.path); } } } const runEvalCmd = (args: string[], cwd: string) => { - // Figure out which version of inspect to use - rely on path - // or form a full path - const binPath = inspectBinPath(); - const inspect = process.platform === "linux" && binPath && binPath.path.includes(".local/bin") ? binPath.path : "inspect"; - // See if there a non-busy terminal that we can re-use const name = "Inspect Eval"; let terminal = window.terminals.find((t) => { @@ -124,5 +123,20 @@ const runEvalCmd = (args: string[], cwd: string) => { terminal = window.createTerminal({ name, cwd }); } terminal.show(); - terminal.sendText([inspect, ...args].join(" ")); + terminal.sendText(["inspect", ...args].join(" ")); +}; + +const runDebugger = async (program: string, args: string[], cwd: string, port: number) => { + const name = "Inspect Eval"; + const debugConfiguration: DebugConfiguration = { + name, + type: "python", + request: "launch", + program, + args, + console: "internalConsole", + cwd, + port + }; + await debug.startDebugging(activeWorkspaceFolder(), debugConfiguration); }; diff --git a/tools/vscode/src/providers/inspect/inspect-manager.ts b/tools/vscode/src/providers/inspect/inspect-manager.ts index 78104018e..dad072eb8 100644 --- a/tools/vscode/src/providers/inspect/inspect-manager.ts +++ b/tools/vscode/src/providers/inspect/inspect-manager.ts @@ -1,16 +1,28 @@ import { Disposable, Event, EventEmitter, ExtensionContext } from "vscode"; import { pythonInterpreter } from "../../core/python"; import { inspectBinPath } from "../../inspect/props"; +import { AbsolutePath } from "../../core/path"; +import { delimiter } from "path"; // Activates the provider which tracks the availability of Inspect export function activateInspectManager(context: ExtensionContext) { const inspectManager = new InspectManager(context); + + // Initialize the terminal with the inspect bin path + // on the path (if needed) + const terminalEnv = terminalEnvironment(context); + context.subscriptions.push(inspectManager.onInspectChanged((e: InspectChangedEvent) => { + terminalEnv.update(e.binPath); + })); 
+ terminalEnv.update(inspectBinPath()); + return inspectManager; } // Fired when the active task changes export interface InspectChangedEvent { available: boolean; + binPath: AbsolutePath | null; } export class InspectManager implements Disposable { @@ -23,18 +35,19 @@ export class InspectManager implements Disposable { ); this.updateInspectAvailable(); } - private inspectAvailable_: boolean = false; + private inspectBinPath_: string | undefined = undefined; get available(): boolean { - return !!this.inspectAvailable_; + return this.inspectBinPath_ !== null; } private updateInspectAvailable() { - const available = inspectBinPath() !== null; - const valueChanged = this.inspectAvailable_ !== available; - this.inspectAvailable_ = available; + const binPath = inspectBinPath(); + const available = binPath !== null; + const valueChanged = this.inspectBinPath_ !== binPath?.path; if (valueChanged) { - this.onInspectChanged_.fire({ available: this.inspectAvailable_ }); + this.inspectBinPath_ = binPath?.path; + this.onInspectChanged_.fire({ available: !!this.inspectBinPath_, binPath }); } if (!available) { this.watchForInspect(); @@ -67,3 +80,38 @@ export class InspectManager implements Disposable { public readonly onInspectChanged: Event = this.onInspectChanged_.event; } + +// Configures the terminal environment to support inspect. We do this +// to ensure the the 'inspect' command will work from within the +// terminal (especially in cases where the global interpreter is being used) +const terminalEnvironment = (context: ExtensionContext) => { + const filter = (binPath: AbsolutePath | null) => { + switch (process.platform) { + case "win32": + { + const localPath = process.env['LocalAppData']; + if (localPath) { + return binPath?.path.startsWith(localPath); + } + return false; + } + case "linux": + return binPath && binPath.path.includes(".local/bin"); + default: + return false; + } + }; + + return { + update: (binPath: AbsolutePath | null) => { + // The path info + const env = context.environmentVariableCollection; + env.delete('PATH'); + // Actually update the path + const binDir = binPath?.dirname(); + if (binDir && filter(binPath)) { + env.append('PATH', `${delimiter}${binDir.path}`); + } + } + }; +}; diff --git a/tools/vscode/src/providers/workspace/workspace-state-provider.ts b/tools/vscode/src/providers/workspace/workspace-state-provider.ts index ad6ab014a..09fc6d24a 100644 --- a/tools/vscode/src/providers/workspace/workspace-state-provider.ts +++ b/tools/vscode/src/providers/workspace/workspace-state-provider.ts @@ -18,6 +18,10 @@ export interface DocumentState { params?: Record; } +export interface ModelState { + lastModel?: string; +} + export class WorkspaceStateManager { constructor(private readonly context_: ExtensionContext) { } @@ -37,6 +41,14 @@ export class WorkspaceStateManager { public async setTaskState(taskFilePath: string, state: DocumentState, taskName?: string) { await this.context_.workspaceState.update(taskKey(taskFilePath, taskName), state); } + + public getModelState(provider: string): ModelState { + return this.context_.workspaceState.get(modelKey(provider)) || {}; + } + + public async setModelState(provider: string, state: ModelState) { + await this.context_.workspaceState.update(modelKey(provider), state); + } } function taskKey(file: string, task?: string) { @@ -47,3 +59,7 @@ function taskKey(file: string, task?: string) { } } +function modelKey(provider: string) { + return `provider-${provider}`; +} +
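Finally, a small usage sketch for the per-provider model memory added to `WorkspaceStateManager` above (the helper below, its import path, and the example values are illustrative, not part of the patch): the last model chosen for a provider is stored under a `provider-<name>` workspace key and restored when that provider is selected again.

```ts
import { WorkspaceStateManager } from "./providers/workspace/workspace-state-provider"; // path is an assumption

// Remember the model the user just picked for a provider.
async function rememberLastModel(
  state: WorkspaceStateManager,
  provider: string,
  model: string
): Promise<void> {
  const modelState = state.getModelState(provider);
  modelState.lastModel = model;
  await state.setModelState(provider, modelState);
}

// Later, when the user switches back to that provider:
//   state.getModelState("ollama").lastModel  // e.g. "llama3"
```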