Merge pull request #139 from Oxid15/develop
v0.9.0 - Stability update
Oxid15 authored Dec 16, 2022
2 parents 96bd86b + 271933d commit 203f8ba
Showing 62 changed files with 1,084 additions and 363 deletions.
2 changes: 1 addition & 1 deletion cascade/__init__.py
@@ -15,7 +15,7 @@
"""


__version__ = '0.8.0'
__version__ = '0.9.0'
__author__ = 'Ilia Moiseev'
__author_email__ = '[email protected]'

22 changes: 21 additions & 1 deletion cascade/base/__init__.py
@@ -1,3 +1,23 @@
"""
Copyright 2022 Ilia Moiseev
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import List, Dict, Any

Meta = List[Dict[Any, Any]]

from .meta_handler import MetaHandler, supported_meta_formats
from .traceable import Traceable
from .traceable import Traceable, Meta
from .meta_handler import CustomEncoder as JSONEncoder
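For orientation, a small usage sketch of the new Meta alias in annotations (the describe function and the 'name' key are invented for illustration and are not part of this commit):

from cascade.base import Meta  # Meta = List[Dict[Any, Any]]

def describe(meta: Meta) -> None:
    # Each element of a meta list is a plain dict describing one object in a pipeline
    for block in meta:
        print(block.get('name', '<unnamed>'))
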
40 changes: 24 additions & 16 deletions cascade/base/meta_handler.py
@@ -17,18 +17,19 @@
import os
import json
import datetime
from typing import Union, List, Dict
from typing import NoReturn, Union, List, Dict, Any
from json import JSONEncoder

import yaml
import numpy as np

from . import Meta

supported_meta_formats = ('.json', '.yml')


class CustomEncoder(JSONEncoder):
def default(self, obj):
def default(self, obj: Any) -> Any:
if isinstance(obj, type):
return str(obj)

@@ -60,26 +61,29 @@ def default(self, obj):

return super(CustomEncoder, self).default(obj)

def obj_to_dict(self, obj) -> Dict:
def obj_to_dict(self, obj: Any) -> Dict:
return json.loads(self.encode(obj))
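
For context, a minimal sketch of what the encoder is for, using only the behaviour visible above (Python types are rendered as strings); the payload dict is invented for illustration:

import json

from cascade.base import JSONEncoder  # CustomEncoder, re-exported under this name

payload = {'dataset_class': list}            # a type is not JSON-serializable by default
print(json.dumps(payload, cls=JSONEncoder))  # -> {"dataset_class": "<class 'list'>"}
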


class BaseHandler:
def read(self, path: str) -> Union[Dict, List[Dict]]:
def read(self, path: str) -> Union[List[Any], Dict[Any, Any]]:
raise NotImplementedError()

def write(self, path: str, obj, overwrite=True) -> None:
def write(self, path: str, obj: Any, overwrite: bool = True) -> None:
raise NotImplementedError()

def _raise_io_error(self, path, exc):
def _raise_io_error(self, path: str, exc: Union[Exception, None] = None) -> NoReturn:
# Any file decoding error is prepended with the file path
# so that the user can identify the broken file
raise IOError(f'Error while reading file `{path}`') from exc
if exc is not None:
raise IOError(f'Error while reading file `{path}`') from exc
else:
raise IOError(f'Error while reading file `{path}`')


class JSONHandler(BaseHandler):
def read(self, path: str) -> Union[Dict, List[Dict]]:
def read(self, path: str) -> Union[List[Any], Dict[Any, Any]]:
_, ext = os.path.splitext(path)
if ext == '':
path += '.json'
@@ -93,7 +97,7 @@ def read(self, path: str) -> Union[Dict, List[Dict]]:
self._raise_io_error(path, e)
return meta

def write(self, path: str, obj: List[Dict], overwrite=True) -> None:
def write(self, path: str, obj: List[Dict], overwrite: bool = True) -> None:
if not overwrite and os.path.exists(path):
return

@@ -102,19 +106,23 @@ def write(self, path: str, obj: List[Dict], overwrite=True) -> None:


class YAMLHandler(BaseHandler):
def read(self, path: str) -> Union[Dict, List[Dict]]:
def read(self, path: str) -> Union[List[Any], Dict[Any, Any]]:
_, ext = os.path.splitext(path)
if ext == '':
path += '.yml'

with open(path, 'r') as meta_file:
try:
meta = yaml.safe_load(meta_file)

# Safe load may return None if something is wrong
if meta is None:
self._raise_io_error(path)
except yaml.YAMLError as e:
self._raise_io_error(path, e)
return meta

def write(self, path: str, obj, overwrite=True) -> None:
def write(self, path: str, obj: Any, overwrite: bool = True) -> None:
if not overwrite and os.path.exists(path):
return

@@ -139,7 +147,7 @@ def read(self, path: str) -> Dict:
meta = {path: ''.join(meta_file.readlines())}
return meta

def write(self, path, obj, overwrite=True) -> None:
def write(self, path: str, obj: Any, overwrite: bool = True) -> NoReturn:
raise NotImplementedError(
'MetaHandler does not write text files, only reads')

@@ -160,7 +168,7 @@ class MetaHandler:
>>> mh.write('meta.yml', {'hello': 'world'})
>>> obj = mh.read('meta.yml')
"""
def read(self, path: str) -> Union[Dict, List[Dict]]:
def read(self, path: str) -> Union[List[Any], Dict[Any, Any]]:
"""
Reads object from path.
@@ -171,7 +179,7 @@ def read(self, path: str) -> Union[Dict, List[Dict]]:
Returns
-------
obj: Union[Dict, List[Dict]]
obj: Union[List[Any], Dict[Any, Any]]
Raises
------
@@ -181,7 +189,7 @@ def read(self, path: str) -> Union[Dict, List[Dict]]:
handler = self._get_handler(path)
return handler.read(path)

def write(self, path: str, obj, overwrite: bool = True) -> None:
def write(self, path: str, obj: Any, overwrite: bool = True) -> None:
"""
Writes object to path.
@@ -203,7 +211,7 @@ def write(self, path: str, obj, overwrite: bool = True) -> None:
handler = self._get_handler(path)
return handler.write(path, obj, overwrite=overwrite)

def _get_handler(self, path) -> BaseHandler:
def _get_handler(self, path: str) -> BaseHandler:
ext = os.path.splitext(path)[-1]
if ext == '.json':
return JSONHandler()
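For reference, a short usage sketch based on the MetaHandler docstring above; the broken-file part exercises the new error path and every file name here is invented:

from cascade.base import MetaHandler

mh = MetaHandler()
mh.write('meta.yml', {'hello': 'world'})  # the handler is chosen by file extension
obj = mh.read('meta.yml')

with open('broken_meta.yml', 'w') as f:
    f.write('{unbalanced')                # not valid YAML

try:
    mh.read('broken_meta.yml')
except IOError as e:
    print(e)                              # the message includes the offending path
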
36 changes: 33 additions & 3 deletions cascade/base/traceable.py
@@ -1,13 +1,37 @@
"""
Copyright 2022 Ilia Moiseev
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""


import warnings
from typing import List, Dict, Union
from typing import List, Dict, Union, Any

from . import Meta


class Traceable:
"""
Base class for everything that has metadata in cascade.
Handles the logic of getting and updating internal meta prefix.
"""
def __init__(self, *args, meta_prefix: Union[Dict, str] = None, **kwargs) -> None:
def __init__(
self,
*args: Any,
meta_prefix: Union[Meta, str, None] = None,
**kwargs: Any
) -> None:
"""
Parameters
----------
@@ -27,7 +51,7 @@ def __init__(self, *args, meta_prefix: Union[Dict, str] = None, **kwargs) -> None:
self._meta_prefix = meta_prefix

@staticmethod
def _read_meta_from_file(path: str) -> Union[List[Dict], Dict]:
def _read_meta_from_file(path: str) -> Union[List[Any], Dict[Any, Any]]:
from . import MetaHandler
return MetaHandler().read(path)

@@ -58,6 +82,12 @@ def update_meta(self, obj: Union[Dict, str]) -> None:
if isinstance(obj, str):
obj = self._read_meta_from_file(obj)

if isinstance(obj, list):
raise RuntimeError(
'Object that was passed or read from path is a list. '
'There is no clear way to update this object\'s meta '
'using a list')

if hasattr(self, '_meta_prefix'):
self._meta_prefix.update(obj)
else:
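A sketch of the stricter update_meta behaviour introduced here (the author and comment fields are arbitrary examples):

from cascade.base import Traceable

tr = Traceable(meta_prefix={'author': 'example author'})
tr.update_meta({'comment': 'dicts are merged into the meta prefix'})

# Lists, passed directly or read from a meta file, are now rejected explicitly
try:
    tr.update_meta([{'comment': 'ambiguous'}])
except RuntimeError as e:
    print(e)
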
2 changes: 1 addition & 1 deletion cascade/data/__init__.py
@@ -14,7 +14,7 @@
limitations under the License.
"""

from .dataset import Dataset, Modifier, Sampler, T, Wrapper, Iterator
from .dataset import Dataset, Modifier, Sampler, T, Wrapper, Iterator, SizedDataset

from .apply_modifier import ApplyModifier
from .bruteforce_cacher import BruteforceCacher
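SizedDataset is now exported alongside Dataset. Judging by its use in Composer and BruteforceCacher below, it is the map-style flavour that also defines __len__; a minimal illustrative subclass under that assumption (SquaresDataset is made up):

from cascade.data import SizedDataset

class SquaresDataset(SizedDataset):
    def __init__(self, n: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._n = n

    def __getitem__(self, i: int) -> int:
        return i * i  # items are computed on the fly

    def __len__(self) -> int:
        return self._n
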
7 changes: 4 additions & 3 deletions cascade/data/apply_modifier.py
@@ -14,15 +14,16 @@
limitations under the License.
"""

from typing import Callable
from typing import Callable, Any, List, Dict
from . import Dataset, Modifier, T


class ApplyModifier(Modifier):
"""
Modifier that lazily maps a function to the given dataset's items.
"""
def __init__(self, dataset: Dataset, func: Callable, *args, **kwargs) -> None:
def __init__(self, dataset: Dataset[T], func: Callable[[T], Any],
*args: List[Any], **kwargs: Dict[Any, Any]) -> None:
"""
Parameters
----------
@@ -45,6 +46,6 @@ def __init__(self, dataset: Dataset, func: Callable, *args, **kwargs) -> None:
super().__init__(dataset, *args, **kwargs)
self._func = func

def __getitem__(self, index: int) -> T:
def __getitem__(self, index: int) -> Any:
item = self._dataset[index]
return self._func(item)
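A usage sketch of the lazy mapping (assuming Wrapper wraps a plain in-memory sequence and Modifier forwards __len__, as they are used elsewhere in the library):

from cascade.data import ApplyModifier, Wrapper

ds = Wrapper([0, 1, 2, 3])
doubled = ApplyModifier(ds, lambda x: x * 2)

# The function is applied at access time, not when the modifier is built
assert doubled[2] == 4
assert len(doubled) == 4
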
12 changes: 8 additions & 4 deletions cascade/data/bruteforce_cacher.py
@@ -14,6 +14,7 @@
limitations under the License.
"""

from typing import Any
from tqdm import tqdm, trange
from . import Dataset, Modifier, T

@@ -51,7 +52,8 @@ class BruteforceCacher(Modifier):
cascade.data.SequentialCacher
cascade.data.Pickler
"""
def __init__(self, dataset: Dataset, *args, **kwargs) -> None:
def __init__(self, dataset: Dataset[T],
*args: Any, **kwargs: Any) -> None:
"""
Loads every item of the dataset into an internal list.
"""
@@ -62,10 +64,12 @@ def __init__(self, dataset: Dataset, *args, **kwargs) -> None:
elif hasattr(self._dataset, '__iter__'):
self._data = [item for item in tqdm(self._dataset)]
else:
raise AttributeError('Input dataset must provide Mapping or Iterable interface')
raise AttributeError(
'Input dataset must provide __len__ and __getitem__ or __iter__'
)

def __getitem__(self, index) -> T:
def __getitem__(self, index: int) -> T:
return self._data[index]

def __len__(self):
def __len__(self) -> int:
return len(self._data)
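And a short sketch of the eager caching (again assuming Wrapper over an in-memory list):

from cascade.data import BruteforceCacher, Wrapper

ds = Wrapper([0, 1, 2])
cached = BruteforceCacher(ds)  # reads every item once, at construction time
assert len(cached) == 3
assert cached[1] == 1
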
33 changes: 26 additions & 7 deletions cascade/data/composer.py
@@ -1,8 +1,27 @@
"""
Copyright 2022 Ilia Moiseev
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import List, Iterable, Tuple, Dict
from typing import List, Tuple, Dict, Any

from . import T, Dataset
from . import SizedDataset
from ..base import Meta

class Composer(Dataset):
class Composer(SizedDataset):
"""
Unifies two or more datasets element-wise.
@@ -14,7 +33,7 @@ class Composer(Dataset):
>>> ds = cdd.Composer((items, labels))
>>> assert ds[0] == (0, 1)
"""
def __init__(self, datasets: Iterable[Dataset], *args, **kwargs) -> None:
def __init__(self, datasets: List[SizedDataset[Any]], *args: Any, **kwargs: Any) -> None:
"""
Parameters
----------
@@ -28,7 +47,7 @@ def __init__(self, datasets: Iterable[Dataset], *args, **kwargs) -> None:
# set the length of any dataset as the length of Composer
self._len = len(self._datasets[0])

def _validate_input(self, datasets):
def _validate_input(self, datasets: List[SizedDataset[Any]]) -> None:
lengths = [len(ds) for ds in datasets]
first = lengths[0]
if not all([ln == first for ln in lengths]):
@@ -37,13 +56,13 @@ def _validate_input(self, datasets):
f'Actual lengths: {lengths}'
)

def __getitem__(self, index: int) -> Tuple[T]:
def __getitem__(self, index: int) -> Tuple[Any]:
return tuple(ds[index] for ds in self._datasets)

def __len__(self) -> int:
return self._len

def get_meta(self) -> List[Dict]:
def get_meta(self) -> Meta:
"""
Composer calls `get_meta()` of all its datasets
"""
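Expanding the docstring example above into a runnable sketch (the values are arbitrary):

import cascade.data as cdd

items = cdd.Wrapper([0, 1, 2])
labels = cdd.Wrapper([1, 0, 0])

ds = cdd.Composer([items, labels])
assert ds[0] == (0, 1)  # element-wise tuple of the composed datasets
meta = ds.get_meta()    # aggregates get_meta() of every composed dataset
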
