From 4311f635126f31966e6146cd891fbe42410e6040 Mon Sep 17 00:00:00 2001 From: Angela Date: Tue, 30 Jul 2024 16:54:57 -0700 Subject: [PATCH] cleanlab studio beta api --- cleanlab_studio/errors.py | 8 ++ cleanlab_studio/internal/api/api.py | 118 ++++++---------- cleanlab_studio/internal/api/api_helper.py | 52 ++++++- cleanlab_studio/internal/api/beta_api.py | 95 +++++++++++++ cleanlab_studio/internal/studio_base.py | 37 +++++ cleanlab_studio/internal/upload_helpers.py | 10 +- cleanlab_studio/studio/studio.py | 55 ++------ cleanlab_studio/studio_beta/__init__.py | 1 + cleanlab_studio/studio_beta/beta_dataset.py | 68 +++++++++ cleanlab_studio/studio_beta/beta_job.py | 145 ++++++++++++++++++++ cleanlab_studio/studio_beta/studio_beta.py | 30 ++++ 11 files changed, 497 insertions(+), 122 deletions(-) create mode 100644 cleanlab_studio/internal/api/beta_api.py create mode 100644 cleanlab_studio/internal/studio_base.py create mode 100644 cleanlab_studio/studio_beta/__init__.py create mode 100644 cleanlab_studio/studio_beta/beta_dataset.py create mode 100644 cleanlab_studio/studio_beta/beta_job.py create mode 100644 cleanlab_studio/studio_beta/studio_beta.py diff --git a/cleanlab_studio/errors.py b/cleanlab_studio/errors.py index c7829ecf..42d8aaeb 100644 --- a/cleanlab_studio/errors.py +++ b/cleanlab_studio/errors.py @@ -152,3 +152,11 @@ def __init__(self, filepath: Union[str, pathlib.Path] = "") -> None: if isinstance(filepath, pathlib.Path): filepath = str(filepath) super().__init__(f"File could not be found at {filepath}. Please check the file path.") + + +class BetaJobError(HandledError): + pass + + +class DownloadResultsError(HandledError): + pass diff --git a/cleanlab_studio/internal/api/api.py b/cleanlab_studio/internal/api/api.py index 9e749205..24011763 100644 --- a/cleanlab_studio/internal/api/api.py +++ b/cleanlab_studio/internal/api/api.py @@ -40,52 +40,22 @@ pyspark_exists = False from cleanlab_studio.errors import NotInstalledError -from cleanlab_studio.internal.api.api_helper import check_uuid_well_formed +from cleanlab_studio.internal.api.api_helper import ( + check_uuid_well_formed, + construct_headers, + handle_api_error, +) from cleanlab_studio.internal.types import JSONDict, SchemaOverride from cleanlab_studio.version import __version__ -base_url = os.environ.get("CLEANLAB_API_BASE_URL", "https://api.cleanlab.ai/api") -cli_base_url = f"{base_url}/cli/v0" -upload_base_url = f"{base_url}/upload/v1" -dataset_base_url = f"{base_url}/datasets" -project_base_url = f"{base_url}/projects" -cleanset_base_url = f"{base_url}/cleansets" -model_base_url = f"{base_url}/v1/deployment" -tlm_base_url = f"{base_url}/v0/trustworthy_llm" - - -def _construct_headers( - api_key: Optional[str], content_type: Optional[str] = "application/json" -) -> JSONDict: - retval = dict() - if api_key: - retval["Authorization"] = f"bearer {api_key}" - if content_type: - retval["Content-Type"] = content_type - retval["Client-Type"] = "python-api" - return retval - - -def handle_api_error(res: requests.Response) -> None: - handle_api_error_from_json(res.json(), res.status_code) - - -def handle_api_error_from_json(res_json: JSONDict, status_code: Optional[int] = None) -> None: - if "code" in res_json and "description" in res_json: # AuthError or UserQuotaError format - if res_json["code"] == "user_soft_quota_exceeded": - pass # soft quota limit is going away soon, so ignore it - else: - raise APIError(res_json["description"]) - - if res_json.get("error", None) is not None: - error = res_json["error"] - if ( - 
status_code == 422 - and isinstance(error, dict) - and error.get("code", None) == "UNSUPPORTED_PROJECT_CONFIGURATION" - ): - raise InvalidProjectConfiguration(error["description"]) - raise APIError(res_json["error"]) +API_BASE_URL = os.environ.get("CLEANLAB_API_BASE_URL", "https://api.cleanlab.ai/api") +cli_base_url = f"{API_BASE_URL}/cli/v0" +upload_base_url = f"{API_BASE_URL}/upload/v1" +dataset_base_url = f"{API_BASE_URL}/datasets" +project_base_url = f"{API_BASE_URL}/projects" +cleanset_base_url = f"{API_BASE_URL}/cleansets" +model_base_url = f"{API_BASE_URL}/v1/deployment" +tlm_base_url = f"{API_BASE_URL}/v0/trustworthy_llm" def handle_rate_limit_error_from_resp(resp: aiohttp.ClientResponse) -> None: @@ -134,7 +104,7 @@ def validate_api_key(api_key: str) -> bool: res = requests.get( cli_base_url + "/validate", json=dict(api_key=api_key), - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) valid: bool = res.json()["valid"] @@ -154,7 +124,7 @@ def initialize_upload( res = requests.post( f"{upload_base_url}/file/initialize", json=dict(size_in_bytes=str(file_size), filename=filename, file_type=file_type), - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) upload_id: str = res.json()["upload_id"] @@ -169,7 +139,7 @@ def complete_file_upload(api_key: str, upload_id: str, upload_parts: List[JSONDi res = requests.post( f"{upload_base_url}/file/complete", json=request_json, - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) @@ -184,7 +154,7 @@ def confirm_upload( res = requests.post( f"{upload_base_url}/confirm", json=request_json, - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) @@ -199,7 +169,7 @@ def update_schema( res = requests.patch( f"{upload_base_url}/schema", json=request_json, - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) @@ -209,7 +179,7 @@ def get_ingestion_status(api_key: str, upload_id: str) -> JSONDict: res = requests.get( f"{upload_base_url}/total_progress", params=dict(upload_id=upload_id), - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) res_json: JSONDict = res.json() @@ -221,7 +191,7 @@ def get_dataset_id(api_key: str, upload_id: str) -> JSONDict: res = requests.get( f"{upload_base_url}/dataset_id", params=dict(upload_id=upload_id), - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) res_json: JSONDict = res.json() @@ -232,7 +202,7 @@ def get_project_of_cleanset(api_key: str, cleanset_id: str) -> str: check_uuid_well_formed(cleanset_id, "cleanset ID") res = requests.get( cli_base_url + f"/cleansets/{cleanset_id}/project", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) project_id: str = res.json()["project_id"] @@ -243,7 +213,7 @@ def get_label_column_of_project(api_key: str, project_id: str) -> str: check_uuid_well_formed(project_id, "project ID") res = requests.get( cli_base_url + f"/projects/{project_id}/label_column", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) label_column: str = res.json()["label_column"] @@ -274,7 +244,7 @@ def download_cleanlab_columns( include_cleanlab_columns=include_cleanlab_columns, include_project_details=include_project_details, ), - headers=_construct_headers(api_key), + 
headers=construct_headers(api_key), ) handle_api_error(res) id_col = get_id_column(api_key, cleanset_id) @@ -306,7 +276,7 @@ def download_array( check_uuid_well_formed(cleanset_id, "cleanset ID") res = requests.get( cli_base_url + f"/cleansets/{cleanset_id}/{name}", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) res_json: JSONDict = res.json() @@ -323,7 +293,7 @@ def get_id_column(api_key: str, cleanset_id: str) -> str: check_uuid_well_formed(cleanset_id, "cleanset ID") res = requests.get( cli_base_url + f"/cleansets/{cleanset_id}/id_column", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) id_column: str = res.json()["id_column"] @@ -334,7 +304,7 @@ def get_dataset_of_project(api_key: str, project_id: str) -> str: check_uuid_well_formed(project_id, "project ID") res = requests.get( cli_base_url + f"/projects/{project_id}/dataset", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) dataset_id: str = res.json()["dataset_id"] @@ -345,7 +315,7 @@ def get_dataset_schema(api_key: str, dataset_id: str) -> JSONDict: check_uuid_well_formed(dataset_id, "dataset ID") res = requests.get( cli_base_url + f"/datasets/{dataset_id}/schema", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) schema: JSONDict = res.json()["schema"] @@ -357,7 +327,7 @@ def get_dataset_details(api_key: str, dataset_id: str, task_type: Optional[str]) res = requests.get( project_base_url + f"/dataset_details/{dataset_id}", params=dict(tasktype=task_type), - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) dataset_details: JSONDict = res.json() @@ -368,7 +338,7 @@ def check_column_diversity(api_key: str, dataset_id: str, column_name: str) -> J check_uuid_well_formed(dataset_id, "dataset ID") res = requests.get( dataset_base_url + f"/diversity/{dataset_id}/{column_name}", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) column_diversity: JSONDict = res.json() @@ -379,7 +349,7 @@ def is_valid_multilabel_column(api_key: str, dataset_id: str, column_name: str) check_uuid_well_formed(dataset_id, "dataset ID") res = requests.get( dataset_base_url + f"/check_valid_multilabel/{dataset_id}/{column_name}", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) multilabel_column: JSONDict = res.json() @@ -410,7 +380,7 @@ def clean_dataset( ) res = requests.post( project_base_url + f"/clean", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), json=request_json, ) handle_api_error(res) @@ -422,7 +392,7 @@ def get_latest_cleanset_id(api_key: str, project_id: str) -> str: check_uuid_well_formed(project_id, "project ID") res = requests.get( cleanset_base_url + f"/project/{project_id}/latest_cleanset_id", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) cleanset_id = res.json()["cleanset_id"] @@ -448,7 +418,7 @@ def get_dataset_id_for_name( res = requests.get( dataset_base_url + f"/dataset_id_for_name", params=dict(dataset_name=dataset_name), - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) return cast(Optional[str], res.json().get("dataset_id", None)) @@ -458,7 +428,7 @@ def get_cleanset_status(api_key: str, cleanset_id: str) -> JSONDict: 
check_uuid_well_formed(cleanset_id, "cleanset ID") res = requests.get( cleanset_base_url + f"/{cleanset_id}/status", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) status: JSONDict = res.json() @@ -467,13 +437,13 @@ def get_cleanset_status(api_key: str, cleanset_id: str) -> JSONDict: def delete_dataset(api_key: str, dataset_id: str) -> None: check_uuid_well_formed(dataset_id, "dataset ID") - res = requests.delete(dataset_base_url + f"/{dataset_id}", headers=_construct_headers(api_key)) + res = requests.delete(dataset_base_url + f"/{dataset_id}", headers=construct_headers(api_key)) handle_api_error(res) def delete_project(api_key: str, project_id: str) -> None: check_uuid_well_formed(project_id, "project ID") - res = requests.delete(project_base_url + f"/{project_id}", headers=_construct_headers(api_key)) + res = requests.delete(project_base_url + f"/{project_id}", headers=construct_headers(api_key)) handle_api_error(res) @@ -528,7 +498,7 @@ def deploy_model(api_key: str, cleanset_id: str, model_name: str) -> str: check_uuid_well_formed(cleanset_id, "cleanset ID") res = requests.post( model_base_url, - headers=_construct_headers(api_key), + headers=construct_headers(api_key), json=dict(cleanset_id=cleanset_id, deployment_name=model_name), ) @@ -542,7 +512,7 @@ def get_deployment_status(api_key: str, model_id: str) -> str: check_uuid_well_formed(model_id, "model ID") res = requests.get( f"{model_base_url}/{model_id}", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) deployment: JSONDict = res.json() @@ -555,7 +525,7 @@ def upload_predict_batch(api_key: str, model_id: str, batch: io.StringIO) -> str url = f"{model_base_url}/{model_id}/upload" res = requests.post( url, - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) @@ -573,7 +543,7 @@ def start_prediction(api_key: str, model_id: str, query_id: str) -> None: check_uuid_well_formed(query_id, "query ID") res = requests.post( f"{model_base_url}/{model_id}/predict/{query_id}", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) @@ -584,7 +554,7 @@ def get_prediction_status(api_key: str, query_id: str) -> Dict[str, str]: check_uuid_well_formed(query_id, "query ID") res = requests.get( f"{model_base_url}/predict/{query_id}", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) @@ -596,7 +566,7 @@ def get_deployed_model_info(api_key: str, model_id: str) -> Dict[str, str]: check_uuid_well_formed(model_id, "model ID") res = requests.get( f"{model_base_url}/{model_id}", - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) handle_api_error(res) @@ -672,7 +642,7 @@ async def tlm_prompt( res = await client_session.post( f"{tlm_base_url}/prompt", json=dict(prompt=prompt, quality=quality_preset, options=options or {}), - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) res_json = await res.json() @@ -733,7 +703,7 @@ async def tlm_get_confidence_score( quality=quality_preset, options=options or {}, ), - headers=_construct_headers(api_key), + headers=construct_headers(api_key), ) res_json = await res.json() diff --git a/cleanlab_studio/internal/api/api_helper.py b/cleanlab_studio/internal/api/api_helper.py index a531e5e2..2864657f 100644 --- a/cleanlab_studio/internal/api/api_helper.py +++ b/cleanlab_studio/internal/api/api_helper.py @@ -1,6 +1,22 @@ 
import uuid +from typing import List, Optional, TypedDict -from cleanlab_studio.errors import InvalidUUIDError +import requests + +from cleanlab_studio.errors import ( + APIError, + InvalidProjectConfiguration, + InvalidUUIDError, +) +from cleanlab_studio.internal.types import JSONDict + + +class UploadPart(TypedDict): + ETag: str + PartNumber: int + + +UploadParts = List[UploadPart] def check_uuid_well_formed(uuid_string: str, id_name: str) -> None: @@ -10,3 +26,37 @@ def check_uuid_well_formed(uuid_string: str, id_name: str) -> None: raise InvalidUUIDError( f"{uuid_string} is not a well-formed {id_name}, please double check and try again." ) + + +def construct_headers( + api_key: Optional[str], content_type: Optional[str] = "application/json" +) -> JSONDict: + retval = dict() + if api_key: + retval["Authorization"] = f"bearer {api_key}" + if content_type: + retval["Content-Type"] = content_type + retval["Client-Type"] = "python-api" + return retval + + +def handle_api_error(res: requests.Response) -> None: + handle_api_error_from_json(res.json(), res.status_code) + + +def handle_api_error_from_json(res_json: JSONDict, status_code: Optional[int] = None) -> None: + if "code" in res_json and "description" in res_json: # AuthError or UserQuotaError format + if res_json["code"] == "user_soft_quota_exceeded": + pass # soft quota limit is going away soon, so ignore it + else: + raise APIError(res_json["description"]) + + if res_json.get("error", None) is not None: + error = res_json["error"] + if ( + status_code == 422 + and isinstance(error, dict) + and error.get("code", None) == "UNSUPPORTED_PROJECT_CONFIGURATION" + ): + raise InvalidProjectConfiguration(error["description"]) + raise APIError(res_json["error"]) diff --git a/cleanlab_studio/internal/api/beta_api.py b/cleanlab_studio/internal/api/beta_api.py new file mode 100644 index 00000000..311ac9be --- /dev/null +++ b/cleanlab_studio/internal/api/beta_api.py @@ -0,0 +1,95 @@ +from typing import Any, Dict, List + +import requests + +from .api import API_BASE_URL, construct_headers +from .api_helper import JSONDict, UploadParts, handle_api_error + +experimental_jobs_base_url = f"{API_BASE_URL}/v0/experimental_jobs" + + +def initialize_upload( + api_key: str, filename: str, file_type: str, file_size: int +) -> Dict[str, Any]: + url = f"{experimental_jobs_base_url}/upload/initialize" + headers = construct_headers(api_key) + data = { + "filename": filename, + "file_type": file_type, + "size_in_bytes": file_size, + } + resp = requests.post(url, headers=headers, json=data) + resp.raise_for_status() + return resp.json() + + +def complete_upload(api_key: str, dataset_id: str, upload_parts: UploadParts) -> JSONDict: + url = f"{experimental_jobs_base_url}/upload/complete" + headers = construct_headers(api_key) + data = { + "dataset_id": dataset_id, + "upload_parts": upload_parts, + } + resp = requests.post(url, headers=headers, json=data) + handle_api_error(resp) + return resp.json() + + +def get_dataset(api_key: str, dataset_id: str) -> JSONDict: + url = f"{experimental_jobs_base_url}/datasets/{dataset_id}" + headers = construct_headers(api_key) + resp = requests.get(url, headers=headers) + handle_api_error(resp) + return resp.json() + + +def run_job(api_key: str, dataset_id: str, job_definition_name: str) -> JSONDict: + url = f"{experimental_jobs_base_url}/run" + headers = construct_headers(api_key) + data = { + "dataset_id": dataset_id, + "job_definition_name": job_definition_name, + } + resp = requests.post(url, headers=headers, json=data) + 
handle_api_error(resp)
+    return resp.json()
+
+
+def get_job(api_key: str, job_id: str) -> JSONDict:
+    url = f"{experimental_jobs_base_url}/{job_id}"
+    headers = construct_headers(api_key)
+    resp = requests.get(url, headers=headers)
+    handle_api_error(resp)
+    return resp.json()
+
+
+def get_job_status(api_key: str, job_id: str) -> JSONDict:
+    url = f"{experimental_jobs_base_url}/{job_id}/status"
+    headers = construct_headers(api_key)
+    resp = requests.get(url, headers=headers)
+    resp.raise_for_status()
+    return resp.json()
+
+
+def get_results(api_key: str, job_id: str) -> JSONDict:
+    url = f"{experimental_jobs_base_url}/{job_id}/results"
+    headers = construct_headers(api_key)
+    resp = requests.get(url, headers=headers)
+    resp.raise_for_status()
+    return resp.json()
+
+
+def list_datasets(api_key: str) -> List[JSONDict]:
+    url = f"{experimental_jobs_base_url}/datasets"
+    headers = construct_headers(api_key)
+    resp = requests.get(url, headers=headers)
+    handle_api_error(resp)
+    return resp.json()["datasets"]
+
+
+def list_jobs(api_key: str) -> List[JSONDict]:
+    url = f"{experimental_jobs_base_url}/jobs"
+    headers = construct_headers(api_key)
+    resp = requests.get(url, headers=headers)
+    handle_api_error(resp)
+    return resp.json()["jobs"]
diff --git a/cleanlab_studio/internal/studio_base.py b/cleanlab_studio/internal/studio_base.py
new file mode 100644
index 00000000..2d6c6d2b
--- /dev/null
+++ b/cleanlab_studio/internal/studio_base.py
@@ -0,0 +1,37 @@
+from typing import Optional
+
+from cleanlab_studio.errors import MissingAPIKeyError, VersionError
+from cleanlab_studio.internal.api import api
+from cleanlab_studio.internal.settings import CleanlabSettings
+
+
+class StudioBase:
+    _api_key: str
+
+    def __init__(self, api_key: Optional[str]):
+        """
+        Creates a Cleanlab Studio client.
+
+        Args:
+            api_key: You can find your API key on your [account page](https://app.cleanlab.ai/account) in Cleanlab Studio. Instead of specifying the API key here, you can also log in with `cleanlab login` on the command-line.
+
+        """
+        if not api.is_valid_client_version():
+            raise VersionError(
+                "CLI is out of date and must be updated. Run 'pip install --upgrade cleanlab-studio'." 
+ ) + if api_key is None: + try: + api_key = CleanlabSettings.load().api_key + if api_key is None: + raise ValueError + except (FileNotFoundError, KeyError, ValueError): + raise MissingAPIKeyError( + "No API key found; either specify API key or log in with 'cleanlab login' first" + ) + if not api.validate_api_key(api_key): + raise ValueError( + f"Invalid API key, please check if it is properly specified: {api_key}" + ) + + self._api_key = api_key diff --git a/cleanlab_studio/internal/upload_helpers.py b/cleanlab_studio/internal/upload_helpers.py index 2a70190d..5cade0c4 100644 --- a/cleanlab_studio/internal/upload_helpers.py +++ b/cleanlab_studio/internal/upload_helpers.py @@ -2,17 +2,19 @@ import functools import json from typing import Any, Dict, List, Optional -from tqdm import tqdm import aiohttp -from multidict import CIMultiDictProxy import requests +from multidict import CIMultiDictProxy from requests.adapters import HTTPAdapter, Retry +from tqdm import tqdm + +from cleanlab_studio.errors import InvalidSchemaTypeError from .api import api +from .api.api_helper import UploadParts from .dataset_source import DatasetSource from .types import JSONDict, SchemaOverride -from cleanlab_studio.errors import InvalidSchemaTypeError def upload_dataset( @@ -64,7 +66,7 @@ async def upload_file_parts_async( def upload_file_parts( dataset_source: DatasetSource, part_sizes: List[int], presigned_posts: List[str] -) -> List[JSONDict]: +) -> UploadParts: session = requests.Session() session.mount("https://", adapter=HTTPAdapter(max_retries=Retry(total=3, backoff_factor=1))) diff --git a/cleanlab_studio/studio/studio.py b/cleanlab_studio/studio/studio.py index 6f41dd7f..1a508ffa 100644 --- a/cleanlab_studio/studio/studio.py +++ b/cleanlab_studio/studio/studio.py @@ -2,30 +2,29 @@ Python API for Cleanlab Studio. """ -from typing import Any, List, Literal, Optional, Union -from types import FunctionType import warnings +from types import FunctionType +from typing import Any, List, Literal, Optional, Union import numpy as np import numpy.typing as npt import pandas as pd -from . import inference -from . import trustworthy_language_model -from cleanlab_studio.utils import tlm_hybrid -from cleanlab_studio.errors import CleansetError +from cleanlab_studio.errors import CleansetError, InvalidDatasetError from cleanlab_studio.internal import clean_helpers, deploy_helpers, upload_helpers from cleanlab_studio.internal.api import api +from cleanlab_studio.internal.studio_base import StudioBase +from cleanlab_studio.internal.types import SchemaOverride, TLMQualityPreset from cleanlab_studio.internal.util import ( - init_dataset_source, - telemetry, + apply_corrections_pd_df, apply_corrections_snowpark_df, apply_corrections_spark_df, - apply_corrections_pd_df, + init_dataset_source, + telemetry, ) -from cleanlab_studio.internal.settings import CleanlabSettings -from cleanlab_studio.internal.types import SchemaOverride, TLMQualityPreset -from cleanlab_studio.errors import VersionError, MissingAPIKeyError, InvalidDatasetError +from cleanlab_studio.utils import tlm_hybrid + +from . import inference, trustworthy_language_model _snowflake_exists = api.snowflake_exists if _snowflake_exists: @@ -36,37 +35,7 @@ import pyspark.sql -class Studio: - _api_key: str - - def __init__(self, api_key: Optional[str]): - """ - Creates a Cleanlab Studio client. - - Args: - api_key: You can find your API key on your [account page](https://app.cleanlab.ai/account) in Cleanlab Studio. 
Instead of specifying the API key here, you can also log in with `cleanlab login` on the command-line.
-
-        """
-        if not api.is_valid_client_version():
-            raise VersionError(
-                "CLI is out of date and must be updated. Run 'pip install --upgrade cleanlab-studio'."
-            )
-        if api_key is None:
-            try:
-                api_key = CleanlabSettings.load().api_key
-                if api_key is None:
-                    raise ValueError
-            except (FileNotFoundError, KeyError, ValueError):
-                raise MissingAPIKeyError(
-                    "No API key found; either specify API key or log in with 'cleanlab login' first"
-                )
-        if not api.validate_api_key(api_key):
-            raise ValueError(
-                f"Invalid API key, please check if it is properly specified: {api_key}"
-            )
-
-        self._api_key = api_key
-
+class Studio(StudioBase):
     def upload_dataset(
         self,
         dataset: Any,
diff --git a/cleanlab_studio/studio_beta/__init__.py b/cleanlab_studio/studio_beta/__init__.py
new file mode 100644
index 00000000..18721a11
--- /dev/null
+++ b/cleanlab_studio/studio_beta/__init__.py
@@ -0,0 +1 @@
+from cleanlab_studio.studio_beta.studio_beta import StudioBeta as StudioBeta
diff --git a/cleanlab_studio/studio_beta/beta_dataset.py b/cleanlab_studio/studio_beta/beta_dataset.py
new file mode 100644
index 00000000..34d3c393
--- /dev/null
+++ b/cleanlab_studio/studio_beta/beta_dataset.py
@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+import pathlib
+from dataclasses import dataclass
+from typing import List
+
+from cleanlab_studio.internal.api.beta_api import (
+    complete_upload,
+    get_dataset,
+    initialize_upload,
+    list_datasets,
+)
+from cleanlab_studio.internal.dataset_source import FilepathDatasetSource
+from cleanlab_studio.internal.upload_helpers import upload_file_parts
+
+
+@dataclass
+class BetaDataset:
+    id: str
+    filename: str
+    upload_complete: bool
+    upload_date: int
+
+    @classmethod
+    def from_id(cls, api_key: str, dataset_id: str) -> "BetaDataset":
+        dataset = get_dataset(api_key, dataset_id)
+        return cls(
+            id=dataset_id,
+            filename=dataset["filename"],
+            upload_complete=dataset["complete"],
+            upload_date=dataset["upload_date"],
+        )
+
+    @classmethod
+    def from_filepath(cls, api_key: str, filepath: str) -> "BetaDataset":
+        dataset_source = FilepathDatasetSource(filepath=pathlib.Path(filepath))
+        initialize_response = initialize_upload(
+            api_key,
+            dataset_source.get_filename(),
+            dataset_source.file_type,
+            dataset_source.file_size,
+        )
+        dataset_id = initialize_response["id"]
+        part_sizes = initialize_response["part_sizes"]
+        presigned_posts = initialize_response["presigned_posts"]
+
+        # upload the file parts to the presigned URLs, then finalize the upload
+        upload_parts = upload_file_parts(dataset_source, part_sizes, presigned_posts)
+        dataset = complete_upload(api_key, dataset_id, upload_parts)
+        return cls(
+            id=dataset_id,
+            filename=dataset["filename"],
+            upload_complete=dataset["complete"],
+            upload_date=dataset["upload_date"],
+        )
+
+    @classmethod
+    def list(cls, api_key: str) -> List[BetaDataset]:
+        datasets = list_datasets(api_key)
+        return [
+            cls(
+                id=dataset["id"],
+                filename=dataset["filename"],
+                upload_complete=dataset["complete"],
+                upload_date=dataset["upload_date"],
+            )
+            for dataset in datasets
+        ]
diff --git a/cleanlab_studio/studio_beta/beta_job.py b/cleanlab_studio/studio_beta/beta_job.py
new file mode 100644
index 00000000..c32be862
--- /dev/null
+++ b/cleanlab_studio/studio_beta/beta_job.py
@@ -0,0 +1,145 @@
+from __future__ import annotations
+
+import enum
+import itertools
+import pathlib
+import time
+from dataclasses import dataclass
+from typing import List, Optional
+
+import requests
+from tqdm import tqdm
+
+from cleanlab_studio.errors import BetaJobError, DownloadResultsError
+from cleanlab_studio.internal.api.beta_api import (
+    get_job,
+    get_job_status,
+    get_results,
+    list_jobs,
+    run_job,
+)
+
+
+class JobStatus(enum.Enum):
+    CREATED = 0
+    RUNNING = 1
+    READY = 2
+    FAILED = -1
+
+    @classmethod
+    def from_name(cls, name: str) -> "JobStatus":
+        return cls[name.upper()]
+
+
+@dataclass
+class BetaJob:
+    id: str
+    status: JobStatus
+    dataset_id: str
+    job_definition_name: str
+    created_at: int
+    _api_key: str
+
+    @classmethod
+    def from_id(cls, api_key: str, job_id: str) -> "BetaJob":
+        """Loads an existing job by ID."""
+        job_resp = get_job(api_key, job_id)
+        job = cls(
+            _api_key=api_key,
+            id=job_resp["id"],
+            dataset_id=job_resp["dataset_id"],
+            job_definition_name=job_resp["job_definition_name"],
+            status=JobStatus.from_name(job_resp["status"]),
+            created_at=job_resp["created_at"],
+        )
+        return job
+
+    @classmethod
+    def run(cls, api_key: str, dataset_id: str, job_definition_name: str) -> "BetaJob":
+        """Creates and runs a new job with the given dataset and job definition."""
+        job_resp = run_job(api_key, dataset_id, job_definition_name)
+        job = cls(
+            _api_key=api_key,
+            id=job_resp["id"],
+            dataset_id=dataset_id,
+            job_definition_name=job_definition_name,
+            status=JobStatus.from_name(job_resp["status"]),
+            created_at=job_resp["created_at"],
+        )
+        return job
+
+    def wait_until_ready(self, timeout: Optional[int] = None) -> None:
+        """Blocks until a job is ready or the timeout is reached.
+
+        Args:
+            timeout (Optional[int], optional): timeout for polling, in seconds. Defaults to None.
+
+        Raises:
+            TimeoutError: if job is not ready by end of timeout
+            BetaJobError: if job fails
+        """
+        start_time = time.time()
+        res = get_job_status(self._api_key, self.id)
+        self.status = JobStatus.from_name(res["status"])
+        spinner = itertools.cycle("|/-\\")
+
+        with tqdm(
+            total=JobStatus.READY.value,
+            desc="Job Progress: \\",
+            bar_format="{desc} {postfix}",
+        ) as pbar:
+            while self.status != JobStatus.READY and self.status != JobStatus.FAILED:
+                pbar.set_postfix_str(self.status.name.capitalize())
+                pbar.update(int(self.status.value) - pbar.n)
+
+                if timeout is not None and time.time() - start_time > timeout:
+                    raise TimeoutError("Result not ready before timeout")
+
+                for _ in range(50):
+                    time.sleep(0.1)
+                    pbar.set_description_str(f"Job Progress: {next(spinner)}")
+
+                res = get_job_status(self._api_key, self.id)
+                self.status = JobStatus.from_name(res["status"])
+
+            if self.status == JobStatus.READY:
+                pbar.update(pbar.total - pbar.n)
+                pbar.set_postfix_str(self.status.name.capitalize())
+                return
+
+            if self.status == JobStatus.FAILED:
+                pbar.set_postfix_str(self.status.name.capitalize())
+                raise BetaJobError(f"Experimental job {self.id} failed to complete")
+
+    def download_results(self, output_filepath: str) -> None:
+        output_path = pathlib.Path(output_filepath)
+        if self.status == JobStatus.FAILED:
+            raise BetaJobError("Job failed, cannot download results")
+
+        if self.status != JobStatus.READY:
+            raise BetaJobError("Job must be ready to download results")
+
+        results = get_results(self._api_key, self.id)
+        if output_path.suffix != results["result_file_type"]:
+            raise DownloadResultsError(
+                f"Output file extension does not match result file type {results['result_file_type']}"
+            )
+
+        resp = requests.get(results["result_url"])
+        resp.raise_for_status()
+        output_path.write_bytes(resp.content)
+
+    @classmethod
+    def list(cls, api_key: str) -> List[BetaJob]:
+        jobs = list_jobs(api_key)
+        return [
+            cls(
+                _api_key=api_key,
+                id=job["id"],
+                dataset_id=job["dataset_id"],
+                job_definition_name=job["job_definition_name"],
+                status=JobStatus.from_name(job["status"]),
+                created_at=job["created_at"],
+            )
+            for job in jobs
+        ]
diff --git a/cleanlab_studio/studio_beta/studio_beta.py b/cleanlab_studio/studio_beta/studio_beta.py
new file mode 100644
index 00000000..551f4256
--- /dev/null
+++ b/cleanlab_studio/studio_beta/studio_beta.py
@@ -0,0 +1,30 @@
+from typing import List
+
+from cleanlab_studio.internal.studio_base import StudioBase
+from cleanlab_studio.studio_beta.beta_dataset import BetaDataset
+from cleanlab_studio.studio_beta.beta_job import BetaJob
+
+
+class StudioBeta(StudioBase):
+    def upload_dataset(
+        self,
+        filepath: str,
+    ) -> BetaDataset:
+        """Uploads an experimental dataset from the given filepath."""
+        return BetaDataset.from_filepath(self._api_key, filepath)
+
+    def run_job(self, dataset_id: str, job_definition_name: str) -> BetaJob:
+        """Runs an experimental job with the given dataset and job definition."""
+        return BetaJob.run(self._api_key, dataset_id, job_definition_name)
+
+    def download_results(self, job_id: str, output_filename: str) -> None:
+        """Downloads the results of an experimental job to the given output filename."""
+        BetaJob.from_id(self._api_key, job_id).download_results(output_filename)
+
+    def list_datasets(self) -> List[BetaDataset]:
+        """Lists all datasets you have uploaded through the beta API."""
+        return BetaDataset.list(self._api_key)
+
+    def list_jobs(self) -> List[BetaJob]:
+        """Lists all jobs you have run through the beta API."""
+        return BetaJob.list(self._api_key)
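
Reviewer note: below is a minimal end-to-end sketch of how the pieces added in this patch fit together, written against only the classes introduced above. The API key, file path, and job definition name are illustrative placeholders (this patch does not create any job definition), and the ".csv" extension assumes the job's result_file_type is ".csv":

    from cleanlab_studio.studio_beta import StudioBeta

    studio = StudioBeta("<your API key>")  # key is validated in StudioBase.__init__

    # upload a local file: parts are posted to presigned URLs, then the upload is finalized
    dataset = studio.upload_dataset("data/my_dataset.csv")

    # run a job against the uploaded dataset and poll its status until it finishes
    job = studio.run_job(dataset.id, "my-job-definition")
    job.wait_until_ready(timeout=600)  # raises BetaJobError if the job fails

    # the output extension must match the job's result_file_type
    job.download_results("results.csv")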