diff --git a/darts-acquisition/src/darts_acquisition/s2.py b/darts-acquisition/src/darts_acquisition/s2.py index b2eb753..6a7597c 100644 --- a/darts-acquisition/src/darts_acquisition/s2.py +++ b/darts-acquisition/src/darts_acquisition/s2.py @@ -58,6 +58,8 @@ def load_s2_scene(fpath: str | Path) -> xr.Dataset: ds_s2 = xr.merge(datasets) planet_crop_id = fpath.stem s2_tile_id = "_".join(s2_image.stem.split("_")[:3]) + ds_s2.attrs["planet_crop_id"] = planet_crop_id + ds_s2.attrs["s2_tile_id"] = s2_tile_id ds_s2.attrs["tile_id"] = f"{planet_crop_id}_{s2_tile_id}" logger.debug(f"Loaded Sentinel 2 scene in {time.time() - start_time} seconds.") return ds_s2 diff --git a/darts/src/darts/training.py b/darts/src/darts/training.py index 59b2a9e..c2b3a73 100644 --- a/darts/src/darts/training.py +++ b/darts/src/darts/training.py @@ -107,10 +107,10 @@ def preprocess_s2_train_data( device, ) - labels = gpd.read_file(fpath / f"{optical.attrs['tile_id']}.shp") - tile_id = optical.attrs["tile_id"] + labels = gpd.read_file(fpath / f"{optical.attrs['s2_tile_id']}.shp") # Save the patches + tile_id = optical.attrs["tile_id"] gen = create_training_patches(tile, labels, bands, patch_size, overlap, include_allzero, include_nan_edges) for patch_id, (x, y) in enumerate(gen): torch.save(x, outpath_x / f"{tile_id}_pid{patch_id}.pt")