Skip to content

Commit

Permalink
Fix code styling
Browse files Browse the repository at this point in the history
  • Loading branch information
Oxid15 committed Nov 15, 2022
1 parent eea789d commit ef065b3
Show file tree
Hide file tree
Showing 26 changed files with 81 additions and 92 deletions.
4 changes: 2 additions & 2 deletions cascade/base/meta_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def read(self, path: str) -> Union[Dict, List[Dict]]:
self._raise_io_error(path, e)
return meta

def write(self, path:str, obj: List[Dict], overwrite=True) -> None:
def write(self, path: str, obj: List[Dict], overwrite=True) -> None:
if not overwrite and os.path.exists(path):
return

Expand Down Expand Up @@ -181,7 +181,7 @@ def read(self, path: str) -> Union[Dict, List[Dict]]:
handler = self._get_handler(path)
return handler.read(path)

def write(self, path: str, obj, overwrite:bool = True) -> None:
def write(self, path: str, obj, overwrite: bool = True) -> None:
"""
Writes object to path.
Expand Down
4 changes: 2 additions & 2 deletions cascade/base/traceable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ class Traceable:
Base class for everything that has metadata in cascade.
Handles the logic of getting and updating internal meta prefix.
"""
def __init__(self, *args, meta_prefix:Union[Dict, str] = None, **kwargs) -> None:
def __init__(self, *args, meta_prefix: Union[Dict, str] = None, **kwargs) -> None:
"""
Parameters
----------
meta_prefix: Union[Dict, str], optional
The dictionary that is used to update object's meta in `get_meta` call.
Due to the call of update can overwrite default values.
If str - prefix assumed to be path and loaded using MetaHandler.
See also
--------
cascade.base.MetaHandler
Expand Down
2 changes: 1 addition & 1 deletion cascade/data/apply_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""

from typing import List, Dict, Callable
from typing import Callable
from . import Dataset, Modifier, T


Expand Down
7 changes: 4 additions & 3 deletions cascade/data/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ def __init__(self, datasets: Iterable[Dataset], *args, **kwargs) -> None:
def _validate_input(self, datasets):
lengths = [len(ds) for ds in datasets]
first = lengths[0]
if not all([l == first for l in lengths]):
if not all([ln == first for ln in lengths]):
raise ValueError(
f'The datasets passed should be of the same length\n' \
f'Actual lengths: {lengths}')
f'The datasets passed should be of the same length\n'
f'Actual lengths: {lengths}'
)

def __getitem__(self, index: int) -> Tuple[T]:
return tuple(ds[index] for ds in self._datasets)
Expand Down
2 changes: 1 addition & 1 deletion cascade/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __repr__(self) -> str:
"""
Returns
-------
repr: str
repr: str
Representation of a Dataset. This repr used as a name for get_meta() method
by default gives the name of class from basic repr
Expand Down
2 changes: 1 addition & 1 deletion cascade/data/folder_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, root: str, *args, **kwargs) -> None:
Parameters
----------
root: str
A path to the folder of files
A path to the folder of files
"""
super().__init__(*args, **kwargs)
self._root = os.path.abspath(root)
Expand Down
13 changes: 7 additions & 6 deletions cascade/data/range_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ class RangeSampler(Sampler):
2
3
"""
def __init__(self,
dataset: Dataset,
start:int = None,
stop:int = None,
step:int = 1,
*args, **kwargs) -> None:
def __init__(
self,
dataset: Dataset,
start: int = None,
stop: int = None,
step: int = 1,
*args, **kwargs) -> None:
"""
Parameters
----------
Expand Down
8 changes: 4 additions & 4 deletions cascade/data/sequential_cacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ class SequentialCacher(Modifier):
BruteforceCacher
"""
def __init__(
self,
dataset: Dataset,
batch_size: int = 2,
*args, **kwargs) -> None:
self,
dataset: Dataset,
batch_size: int = 2,
*args, **kwargs) -> None:
"""
Parameters
----------
Expand Down
8 changes: 4 additions & 4 deletions cascade/data/version_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class VersionAssigner(Modifier):
The version consists of two parts, namely major and minor in the format `MAJOR.MINOR` just
like in semantic versioning. The meaning of parts is the following: *major* number changes
if there are changes in the structure of the pipeline e.g. some dataset was added/removed;
*minor* number changes in case of any metadata change e.g. new data arrived and changed
*minor* number changes in case of any metadata change e.g. new data arrived and changed
the length of modifiers on pipeline.
Example
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(self, dataset: Dataset, path: str, verbose=False, *args, **kwargs)
"""
super().__init__(dataset, *args, **kwargs)
self._mh = MetaHandler()
self._assign_path(path)
self._assign_path(path)
self._versions = {}

# get meta for info about pipeline
Expand All @@ -87,7 +87,7 @@ def __init__(self, dataset: Dataset, path: str, verbose=False, *args, **kwargs)
pipe_hash = md5(str.encode(pipeline, 'utf-8')).hexdigest()

if os.path.exists(self._root):
self._versions = self._mh.read(self._root)
self._versions = self._mh.read(self._root)

if pipe_hash in self._versions:
if meta_hash in self._versions[pipe_hash]:
Expand Down Expand Up @@ -124,7 +124,7 @@ def __init__(self, dataset: Dataset, path: str, verbose=False, *args, **kwargs)
'pipeline': pipeline
}
self._mh.write(self._root, self._versions)

if verbose:
print('Dataset version:', self.version)

Expand Down
6 changes: 3 additions & 3 deletions cascade/meta/dataleak_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def __init__(self, train_ds, test_ds, hash_fn=None, **kwargs) -> None:
size = len(train_repeating_idx)
if size > 0:
raise DataValidationException(
f'Train and test datasets have {size} common items\n' \
f'Train indices: {prettify_items(train_repeating_idx)}\n' \
f'Train and test datasets have {size} common items\n'
f'Train indices: {prettify_items(train_repeating_idx)}\n'
f'Test indices: {prettify_items(test_repeating_idx)}'
)
else:
print('OK!')

super().__init__(self, train_ds, **kwargs)
1 change: 0 additions & 1 deletion cascade/meta/hashes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""

from hashlib import md5
import numpy as np


def numpy_md5(x) -> str:
Expand Down
13 changes: 5 additions & 8 deletions cascade/meta/history_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
limitations under the License.
"""

import os
from typing import List
import pendulum
import pandas as pd
Expand All @@ -25,8 +24,6 @@
from plotly import graph_objects as go

from . import MetaViewer
from .. import __version__
from ..data import Dataset


class HistoryViewer:
Expand All @@ -36,10 +33,10 @@ class HistoryViewer:
models with different hyperparameters depend on each other.
"""
def __init__(
self,
repo,
last_lines: int = None,
last_models: int = None) -> None:
self,
repo,
last_lines: int = None,
last_models: int = None) -> None:
"""
Parameters
----------
Expand Down Expand Up @@ -196,7 +193,7 @@ def plot(self, metric: str, show: bool = False) -> plotly.graph_objects.Figure:
def serve(self, metric: str, **kwargs):
"""
Run dash-based server with HistoryViewer, updating plots in real-time.
Note
----
This feature needs `dash` to be installed.
Expand Down
2 changes: 1 addition & 1 deletion cascade/meta/meta_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class MetaViewer:
"""
The class to view all metadata in folders and subfolders.
"""
def __init__(self, root: str, filt: Dict=None) -> None:
def __init__(self, root: str, filt: Dict = None) -> None:
"""
Parameters
----------
Expand Down
11 changes: 5 additions & 6 deletions cascade/meta/metric_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import pandas as pd

from . import MetaViewer
from .. import __version__


class MetricViewer:
Expand Down Expand Up @@ -139,11 +138,11 @@ def get_best_by(self, metric: str, maximize=True):
return self._repo[name][num]

def serve(
self,
page_size: int = 50,
include: List[str] = None,
exclude: List[str] = None,
**kwargs) -> None:
self,
page_size: int = 50,
include: List[str] = None,
exclude: List[str] = None,
**kwargs) -> None:
"""
Runs dash-based server with interactive table of metrics and parameters
Expand Down
10 changes: 3 additions & 7 deletions cascade/models/model_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,16 @@
"""

import os
import glob
import warnings
import itertools
import logging
from typing import List, Dict, Iterable, Union
import shutil

from plotly import graph_objects as go
from flatten_json import flatten
import pandas as pd
import pendulum
from deepdiff.diff import DeepDiff

from ..base import Traceable, MetaHandler, JSONEncoder, supported_meta_formats
from ..meta import MetricViewer
from .model import Model
from .model_line import ModelLine

Expand Down Expand Up @@ -129,7 +124,7 @@ def _load_lines(self):
for name in sorted(os.listdir(self._root))
if os.path.isdir(os.path.join(self._root, name))}

def add_line(self, name:str=None, *args, meta_fmt=None, **kwargs):
def add_line(self, name: str = None, *args, meta_fmt=None, **kwargs):
"""
Adds new line to repo if it doesn't exist and returns it.
If line exists, defines it in repo with parameters provided.
Expand All @@ -139,7 +134,7 @@ def add_line(self, name:str=None, *args, meta_fmt=None, **kwargs):
Parameters:
name: str, optional
Name of the line. It is used to name a folder of line.
Repo prepends it with `self._root` before creating.
Repo prepends it with `self._root` before creating.
Optional argument. If omitted - names new line automatically
using f'{len(self):0>5d}'
meta_fmt: str, optional
Expand Down Expand Up @@ -266,6 +261,7 @@ def get_line_names(self) -> List[str]:
# TODO: write test covering this
return list(self._lines.keys())


class ModelRepoConcatenator(Repo):
"""
The class to concatenate different Repos.
Expand Down
1 change: 1 addition & 0 deletions cascade/tests/test_concatenator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_meta():
assert c.get_meta()[0]['num'] == 1
assert len(c.get_meta()[0]['data']) == 2


# TODO: replace arrs with datasets
@pytest.mark.parametrize(
'arrs', [
Expand Down
2 changes: 0 additions & 2 deletions cascade/tests/test_dataleak_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def test_simple_hash():

cme.DataleakValidator(train_ds, test_ds)


train_ds = cdd.Wrapper(['a', 'b', 'c'])
test_ds = cdd.Wrapper(['c', 'd', 'e'])

Expand All @@ -46,7 +45,6 @@ def test_np_hash():

cme.DataleakValidator(train_ds, test_ds, hash_fn=cme.numpy_md5)


train_ds = cdd.Wrapper(np.zeros(10))
test_ds = cdd.Wrapper(np.zeros(10))

Expand Down
1 change: 1 addition & 0 deletions cascade/tests/test_model_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def test_integer_indices(tmp_path, ext):
assert first_line == repo[0]
assert last_line == repo[-1]


@pytest.mark.parametrize(
'ext', [
'.json',
Expand Down
4 changes: 2 additions & 2 deletions cascade/tests/test_version_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def test(tmp_path, ext):
ds = Wrapper([0, 1, 2, 3, 4])
ds = ApplyModifier(ds, lambda x: x ** 2)
ds = VersionAssigner(ds, filepath)

assert ds.version == '1.0'

ds = Wrapper([0, 1, 2, 3, 4, 5])
ds = ApplyModifier(ds, lambda x: x ** 2)
ds = VersionAssigner(ds, filepath)

assert ds.version == '1.1'
Loading

0 comments on commit ef065b3

Please sign in to comment.