Skip to content

Commit

Permalink
Fix training-processing bug
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Nov 29, 2024
1 parent f5de1ab commit 0c78f74
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 2 additions & 0 deletions darts-acquisition/src/darts_acquisition/s2.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def load_s2_scene(fpath: str | Path) -> xr.Dataset:
ds_s2 = xr.merge(datasets)
planet_crop_id = fpath.stem
s2_tile_id = "_".join(s2_image.stem.split("_")[:3])
ds_s2.attrs["planet_crop_id"] = planet_crop_id
ds_s2.attrs["s2_tile_id"] = s2_tile_id
ds_s2.attrs["tile_id"] = f"{planet_crop_id}_{s2_tile_id}"
logger.debug(f"Loaded Sentinel 2 scene in {time.time() - start_time} seconds.")
return ds_s2
Expand Down
4 changes: 2 additions & 2 deletions darts/src/darts/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def preprocess_s2_train_data(
device,
)

labels = gpd.read_file(fpath / f"{optical.attrs['tile_id']}.shp")
tile_id = optical.attrs["tile_id"]
labels = gpd.read_file(fpath / f"{optical.attrs['s2_tile_id']}.shp")

# Save the patches
tile_id = optical.attrs["tile_id"]
gen = create_training_patches(tile, labels, bands, patch_size, overlap, include_allzero, include_nan_edges)
for patch_id, (x, y) in enumerate(gen):
torch.save(x, outpath_x / f"{tile_id}_pid{patch_id}.pt")
Expand Down

0 comments on commit 0c78f74

Please sign in to comment.