Skip to content

Commit

Permalink
Variable attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
ecomodeller committed Jan 9, 2025
1 parent 32485ae commit e6ea96f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
28 changes: 19 additions & 9 deletions modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,15 @@ def save(self, filename: Union[str, Path]) -> None:
con = duckdb.connect(filename)
# TODO figure out how to save the x, y, z coordinates and other attributes later
df = ds.to_dataframe().drop(columns=["x", "y", "z"]).reset_index() # noqa
duckdb.sql("CREATE TABLE data AS SELECT * FROM df", connection=con)
duckdb.sql("CREATE TABLE matched_data AS SELECT * FROM df", connection=con)

attr_dict = {key: str(ds[key].attrs) for key in ds.data_vars}
attr_df = pd.DataFrame(attr_dict.items(), columns=["key", "value"]) # noqa

# attr_df["global", "key"] = str(ds.attrs)

duckdb.sql("CREATE TABLE attrs AS SELECT * FROM attr_df", connection=con)

con.close()
elif ext == ".nc":
if self.gtype == "point":
Expand Down Expand Up @@ -1304,18 +1312,20 @@ def load(filename: Union[str, Path]) -> "Comparer":
import duckdb

con = duckdb.connect(filename)
df = duckdb.sql("SELECT * FROM data", connection=con).df().set_index("time")
df = (
duckdb.sql("SELECT * FROM matched_data", connection=con)
.df()
.set_index("time")
)

# convert pandas dataframe to xarray dataset
ds = xr.Dataset.from_dataframe(df)

# set observation attribute
ds.Observation.attrs["kind"] = "observation"

# set model attributes
for key in ds.data_vars:
if key != "Observation":
ds[key].attrs["kind"] = "model"
attrs = duckdb.sql("SELECT * FROM attrs", connection=con).df()
for row in attrs.iterrows():
key = row[1]["key"]
value = row[1]["value"]
ds[key].attrs = eval(value)

# TODO figure out aux variables

Expand Down
7 changes: 7 additions & 0 deletions tests/test_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,10 @@ def test_save_load(pc, tmp_path) -> None:
assert "m1" in pc2.mod_names
assert "m2" in pc2.mod_names
assert pc2.n_points == 5
assert pc2.data.m1.attrs["kind"] == "model"
assert pc2.data.m2.attrs["kind"] == "model"
assert pc2.data.Observation.attrs["kind"] == "observation"

# TODO global attrs

# TODO raw_mod_data

0 comments on commit e6ea96f

Please sign in to comment.