From e6ea96f12b0db56701b5a5206fafa4d857a7dd06 Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Thu, 9 Jan 2025 19:08:34 +0100 Subject: [PATCH] Variable attrs --- modelskill/comparison/_comparison.py | 28 +++++++++++++++++++--------- tests/test_comparer.py | 7 +++++++ 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/modelskill/comparison/_comparison.py b/modelskill/comparison/_comparison.py index 366cfcbd..26be14ce 100644 --- a/modelskill/comparison/_comparison.py +++ b/modelskill/comparison/_comparison.py @@ -1265,7 +1265,15 @@ def save(self, filename: Union[str, Path]) -> None: con = duckdb.connect(filename) # TODO figure out how to save the x, y, z coordinates and other attributes later df = ds.to_dataframe().drop(columns=["x", "y", "z"]).reset_index() # noqa - duckdb.sql("CREATE TABLE data AS SELECT * FROM df", connection=con) + duckdb.sql("CREATE TABLE matched_data AS SELECT * FROM df", connection=con) + + attr_dict = {key: str(ds[key].attrs) for key in ds.data_vars} + attr_df = pd.DataFrame(attr_dict.items(), columns=["key", "value"]) # noqa + + # attr_df["global", "key"] = str(ds.attrs) + + duckdb.sql("CREATE TABLE attrs AS SELECT * FROM attr_df", connection=con) + con.close() elif ext == ".nc": if self.gtype == "point": @@ -1304,18 +1312,20 @@ def load(filename: Union[str, Path]) -> "Comparer": import duckdb con = duckdb.connect(filename) - df = duckdb.sql("SELECT * FROM data", connection=con).df().set_index("time") + df = ( + duckdb.sql("SELECT * FROM matched_data", connection=con) + .df() + .set_index("time") + ) # convert pandas dataframe to xarray dataset ds = xr.Dataset.from_dataframe(df) - # set observation attribute - ds.Observation.attrs["kind"] = "observation" - - # set model attributes - for key in ds.data_vars: - if key != "Observation": - ds[key].attrs["kind"] = "model" + attrs = duckdb.sql("SELECT * FROM attrs", connection=con).df() + for row in attrs.iterrows(): + key = row[1]["key"] + value = row[1]["value"] + ds[key].attrs = eval(value) # TODO figure out aux variables diff --git a/tests/test_comparer.py b/tests/test_comparer.py index 81bfbd90..7f5b15ab 100644 --- a/tests/test_comparer.py +++ b/tests/test_comparer.py @@ -962,3 +962,10 @@ def test_save_load(pc, tmp_path) -> None: assert "m1" in pc2.mod_names assert "m2" in pc2.mod_names assert pc2.n_points == 5 + assert pc2.data.m1.attrs["kind"] == "model" + assert pc2.data.m2.attrs["kind"] == "model" + assert pc2.data.Observation.attrs["kind"] == "observation" + + # TODO global attrs + + # TODO raw_mod_data