Skip to content

Commit

Permalink
Merge pull request #387 from DHI/comparer-attrs
Browse files Browse the repository at this point in the history
Comparer attrs property
  • Loading branch information
jsmariegaard authored Jan 10, 2024
2 parents a6c1ca0 + a9eaea0 commit 1a4ee6c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
13 changes: 13 additions & 0 deletions modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Expand Down Expand Up @@ -45,6 +46,8 @@
if TYPE_CHECKING:
from ._collection import ComparerCollection

Serializable = Union[str, int, float]


class Scoreable(Protocol):
def score(self, metric: str | Callable, **kwargs) -> Dict[str, float]:
Expand Down Expand Up @@ -613,6 +616,7 @@ def mod_names(self) -> List[str]:
@property
def aux_names(self) -> List[str]:
"""List of auxiliary data names"""
# we don't require the kind attribute to be "auxiliary"
return list(
[
k
Expand Down Expand Up @@ -640,6 +644,15 @@ def _unit_text(self) -> str:
# Quantity name and unit as text suitable for plot labels
return f"{self.quantity.name} [{self.quantity.unit}]"

@property
def attrs(self) -> dict[str, Any]:
"""Attributes of the observation"""
return self.data.attrs

@attrs.setter
def attrs(self, value: dict[str, Serializable]) -> None:
self.data.attrs = value

# TODO: is this the best way to copy (self.data.copy.. )
def __copy__(self):
return deepcopy(self)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,21 @@ def test_tc_properties(tc):
assert np.all(tc.raw_mod_data["m1"].x == [10.1, 10.2, 10.3, 10.4, 10.5, 10.6])


def test_attrs(pc):
pc.attrs["a2"] = "v2"
assert pc.attrs["a2"] == "v2"

pc.data.attrs["version"] = 42
assert pc.attrs["version"] == 42

pc.attrs["version"] = 43
assert pc.attrs["version"] == 43

# remove all attributes and add a new one
pc.attrs = {"version": 44}
assert pc.attrs["version"] == 44


def test_pc_sel_time(pc):
pc2 = pc.sel(time=slice("2019-01-03", "2019-01-04"))
assert pc2.n_points == 2
Expand Down

0 comments on commit 1a4ee6c

Please sign in to comment.