Skip to content

Commit

Permalink
Merge pull request #580 from mlte-team/feature/store-improvements
Browse files Browse the repository at this point in the history
Feature/store improvements
  • Loading branch information
sei-secheverria authored Jan 15, 2025
2 parents 05278ce + c317443 commit d9849e2
Show file tree
Hide file tree
Showing 25 changed files with 219 additions and 196 deletions.
2 changes: 1 addition & 1 deletion mlte/backend/api/endpoints/catalog_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def create_catalog_entry(
)

return catalog_session.entry_mapper.create_with_header(
entry, current_user.username
entry, user=current_user.username
)
except errors.ErrorNotFound as e:
raise HTTPException(
Expand Down
6 changes: 3 additions & 3 deletions mlte/backend/api/endpoints/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mlte.backend.api.auth.authorization import AuthorizedUser
from mlte.backend.api.error_handlers import raise_http_internal_error
from mlte.backend.core import state_stores
from mlte.context.model import Model, ModelCreate, Version, VersionCreate
from mlte.context.model import Model, Version
from mlte.store.user.policy import Policy
from mlte.user.model import ResourceType

Expand All @@ -26,7 +26,7 @@
@router.post("")
def create_model(
*,
model: ModelCreate,
model: Model,
current_user: AuthorizedUser,
) -> Model:
"""
Expand Down Expand Up @@ -147,7 +147,7 @@ def delete_model(
def create_version(
*,
model_id: str,
version: VersionCreate,
version: Version,
current_user: AuthorizedUser,
) -> Version:
"""
Expand Down
20 changes: 1 addition & 19 deletions mlte/context/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,18 @@
from mlte.model import BaseModel


class VersionCreate(BaseModel):
"""The model that defines the data necessary to create a MLTE version."""

identifier: str
"""The identifier for the version."""


class Version(BaseModel):
"""Model implementation for MLTE model version."""

identifier: str
"""The identifier for the model version."""

# TODO(Kyle): In the future, we may implement new endpoints
# that allow one to GET /version to get all artifacts associated
# with a (model, version) duple


class ModelCreate(BaseModel):
"""The model that defines the data necessary to create a MLTE model."""

identifier: str
"""The identifier for the model."""


class Model(BaseModel):
"""Model implementation for MLTE model identifier."""

identifier: str
"""The identifier for the model."""

versions: List[Version]
versions: List[Version] = []
"""A collection of the model versions."""
6 changes: 3 additions & 3 deletions mlte/store/artifact/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import List, Optional, cast

from mlte.artifact.model import ArtifactModel
from mlte.context.model import Model, ModelCreate, Version, VersionCreate
from mlte.context.model import Model, Version
from mlte.store.base import ManagedSession, Store, StoreSession
from mlte.store.query import Query

Expand Down Expand Up @@ -47,7 +47,7 @@ class ArtifactStoreSession(StoreSession):
# Interface: Context
# -------------------------------------------------------------------------

def create_model(self, model: ModelCreate) -> Model:
def create_model(self, model: Model) -> Model:
"""
Create a MLTE model.
:param model: The model data to create the model
Expand Down Expand Up @@ -86,7 +86,7 @@ def delete_model(self, model_id: str) -> Model:
"Cannot invoke method on abstract ArtifactStoreSession."
)

def create_version(self, model_id: str, version: VersionCreate) -> Version:
def create_version(self, model_id: str, version: Version) -> Version:
"""
Create a MLTE model version.
:param model_id: The identifier for the model
Expand Down
6 changes: 3 additions & 3 deletions mlte/store/artifact/underlying/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import mlte.store.artifact.util as storeutil
import mlte.store.error as errors
from mlte.artifact.model import ArtifactModel
from mlte.context.model import Model, ModelCreate, Version, VersionCreate
from mlte.context.model import Model, Version
from mlte.store.artifact.store import ArtifactStore, ArtifactStoreSession
from mlte.store.base import StoreURI
from mlte.store.common.fs_storage import FileSystemStorage
Expand Down Expand Up @@ -65,7 +65,7 @@ def close(self) -> None:
# Structural Elements
# -------------------------------------------------------------------------

def create_model(self, model: ModelCreate) -> Model:
def create_model(self, model: Model) -> Model:
try:
self.storage.create_folder(
Path(self.storage.base_path, model.identifier)
Expand All @@ -91,7 +91,7 @@ def delete_model(self, model_id: str) -> Model:
self.storage.delete_folder(Path(self.storage.base_path, model_id))
return model

def create_version(self, model_id: str, version: VersionCreate) -> Version:
def create_version(self, model_id: str, version: Version) -> Version:
self._ensure_model_exists(model_id)

try:
Expand Down
6 changes: 3 additions & 3 deletions mlte/store/artifact/underlying/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mlte.artifact.model import ArtifactModel
from mlte.backend.api.models.artifact_model import WriteArtifactRequest
from mlte.backend.core.config import settings
from mlte.context.model import Model, ModelCreate, Version, VersionCreate
from mlte.context.model import Model, Version
from mlte.store.artifact.store import ArtifactStore, ArtifactStoreSession
from mlte.store.base import StoreURI
from mlte.store.common.http_clients import OAuthHttpClient
Expand Down Expand Up @@ -72,7 +72,7 @@ def close(self):
# Structural Elements
# -------------------------------------------------------------------------

def create_model(self, model: ModelCreate) -> Model:
def create_model(self, model: Model) -> Model:
url = f"{self.url}{API_PREFIX}/model"
res = self.client.post(url, json=model.to_json())
self.client.raise_for_response(res)
Expand Down Expand Up @@ -100,7 +100,7 @@ def delete_model(self, model_id: str) -> Model:

return Model(**res.json())

def create_version(self, model_id: str, version: VersionCreate) -> Version:
def create_version(self, model_id: str, version: Version) -> Version:
url = f"{self.url}{API_PREFIX}/model/{model_id}/version"
res = self.client.post(url, json=version.to_json())
self.client.raise_for_response(res)
Expand Down
6 changes: 3 additions & 3 deletions mlte/store/artifact/underlying/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import mlte.store.artifact.util as storeutil
import mlte.store.error as errors
from mlte.artifact.model import ArtifactModel
from mlte.context.model import Model, ModelCreate, Version, VersionCreate
from mlte.context.model import Model, Version
from mlte.store.artifact.store import ArtifactStore, ArtifactStoreSession
from mlte.store.base import StoreURI
from mlte.store.query import Query
Expand Down Expand Up @@ -72,7 +72,7 @@ def close(self) -> None:
# Structural Elements
# -------------------------------------------------------------------------

def create_model(self, model: ModelCreate) -> Model:
def create_model(self, model: Model) -> Model:
if model.identifier in self.storage.models:
raise errors.ErrorAlreadyExists(f"Model {model.identifier}")
self.storage.models[model.identifier] = ModelWithVersions(
Expand All @@ -96,7 +96,7 @@ def delete_model(self, model_id: str) -> Model:
del self.storage.models[model_id]
return popped

def create_version(self, model_id: str, version: VersionCreate) -> Version:
def create_version(self, model_id: str, version: Version) -> Version:
if model_id not in self.storage.models:
raise errors.ErrorNotFound(f"Model {model_id}")

Expand Down
6 changes: 3 additions & 3 deletions mlte/store/artifact/underlying/rdbs/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import mlte.store.artifact.util as storeutil
import mlte.store.error as errors
from mlte.artifact.model import ArtifactModel
from mlte.context.model import Model, ModelCreate, Version, VersionCreate
from mlte.context.model import Model, Version
from mlte.store.artifact.store import ArtifactStore, ArtifactStoreSession
from mlte.store.artifact.underlying.rdbs import factory
from mlte.store.artifact.underlying.rdbs.metadata import (
Expand Down Expand Up @@ -88,7 +88,7 @@ def close(self) -> None:
# Structural Elements
# -------------------------------------------------------------------------

def create_model(self, model: ModelCreate) -> Model:
def create_model(self, model: Model) -> Model:
with Session(self.storage.engine) as session:
try:
_, _ = DBReader.get_model(model.identifier, session)
Expand Down Expand Up @@ -125,7 +125,7 @@ def delete_model(self, model_id: str) -> Model:
session.commit()
return model

def create_version(self, model_id: str, version: VersionCreate) -> Version:
def create_version(self, model_id: str, version: Version) -> Version:
with Session(self.storage.engine) as session:
try:
_, _ = DBReader.get_version(
Expand Down
6 changes: 3 additions & 3 deletions mlte/store/artifact/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import mlte.store.error as errors
from mlte.context.model import ModelCreate, VersionCreate
from mlte.context.model import Model, Version
from mlte.store.artifact.store import ArtifactStoreSession


Expand All @@ -21,11 +21,11 @@ def create_parents(
:param version_id: The version identifier
"""
try:
session.create_model(ModelCreate(identifier=model_id))
session.create_model(Model(identifier=model_id))
except errors.ErrorAlreadyExists:
pass

try:
session.create_version(model_id, VersionCreate(identifier=version_id))
session.create_version(model_id, Version(identifier=version_id))
except errors.ErrorAlreadyExists:
pass
41 changes: 27 additions & 14 deletions mlte/store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Protocol

Expand Down Expand Up @@ -163,8 +164,10 @@ def __exit__(self, exc_type, exc_value, exc_tb) -> None:
self.session.close()


class ResourceMapper:
"""A generic interface for mapping CRUD actions to store specific resources."""
class ResourceMapper(ABC):
"""
A generic interface for mapping CRUD actions to store specific resources.
"""

NOT_IMPLEMENTED_ERROR_MSG = (
"Cannot invoke method that has not been implemented for this mapper."
Expand All @@ -174,70 +177,80 @@ class ResourceMapper:
DEFAULT_LIST_LIMIT = 100
"""Default limit for lists."""

def create(self, new_resource: Any) -> Any:
@abstractmethod
def create(self, new_resource: Any, context: Any = None) -> Any:
"""
Create a new resource.
:param new_resource: The data to create the resource
:param context: Any additional context needed for this resource.
:return: The created resource
"""
raise NotImplementedError(self.NOT_IMPLEMENTED_ERROR_MSG)

def edit(self, updated_resource: Any) -> Any:
@abstractmethod
def edit(self, updated_resource: Any, context: Any = None) -> Any:
"""
Edit an existing resource.
:param updated_resource: The data to edit the resource
:param context: Any additional context needed for this resource.
:return: The edited resource
"""
raise NotImplementedError(self.NOT_IMPLEMENTED_ERROR_MSG)

def read(self, resource_identifier: str) -> Any:
@abstractmethod
def read(self, resource_identifier: str, context: Any = None) -> Any:
"""
Read a resource.
:param resource_identifier: The identifier for the resource
:param context: Any additional context needed for this resource.
:return: The resource
"""
raise NotImplementedError(self.NOT_IMPLEMENTED_ERROR_MSG)

def list(self) -> List[str]:
@abstractmethod
def list(self, context: Any = None) -> List[str]:
"""
List all resources of this type in the store.
:param context: Any additional context needed for this resource.
:return: A collection of identifiers for all resources of this type
"""
raise NotImplementedError(self.NOT_IMPLEMENTED_ERROR_MSG)

def delete(self, resource_identifier: str) -> Any:
@abstractmethod
def delete(self, resource_identifier: str, context: Any = None) -> Any:
"""
Delete a resource.
:param resource_identifier: The identifier for the resource
:param context: Any additional context needed for this resource.
:return: The deleted resource
"""
raise NotImplementedError(self.NOT_IMPLEMENTED_ERROR_MSG)

def list_details(
self,
context: Any = None,
limit: int = DEFAULT_LIST_LIMIT,
offset: int = 0,
) -> List[Any]:
"""
Read details of resources within limit and offset.
:param context: Any additional context needed for this resource.
:param limit: The limit on resources to read
:param offset: The offset on resources to read
:return: The read resources
"""
entry_ids = self.list()
return [self.read(entry_id) for entry_id in entry_ids][
entry_ids = self.list(context)
return [self.read(entry_id, context) for entry_id in entry_ids][
offset : offset + limit
]

def search(
self,
query: Query = Query(),
) -> List[Any]:
def search(self, query: Query = Query(), context: Any = None) -> List[Any]:
"""
Read a collection of resources, optionally filtered.
:param query: The resource query to apply
:param context: Any additional context needed for this resource.
:return: A collection of resources that satisfy the filter
"""
# TODO: not the most efficient way, since it loads all items first, before filtering.
entries = self.list_details()
entries = self.list_details(context)
return [entry for entry in entries if query.filter.match(entry)]
6 changes: 4 additions & 2 deletions mlte/store/catalog/catalog_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,15 @@ def list_details(
)

catalog_session = self.sessions[catalog_id]
return catalog_session.entry_mapper.list_details(limit, offset)
return catalog_session.entry_mapper.list_details(
limit=limit, offset=offset
)
else:
# Go over all catalogs, reading from each one, and grouping results.
results: List[CatalogEntry] = []
for catalog_id, session in self.sessions.items():
partial_results = session.entry_mapper.list_details(
limit, offset
limit=limit, offset=offset
)
results.extend(partial_results)
return results[offset : offset + limit]
Expand Down
Loading

0 comments on commit d9849e2

Please sign in to comment.