Make preprocess easier to read
relativityhd committed Dec 21, 2024
1 parent 5130c9b commit 9a2b7c8
Showing 1 changed file with 78 additions and 53 deletions.
131 changes: 78 additions & 53 deletions darts/src/darts/legacy_training/preprocess.py
@@ -12,6 +12,79 @@
logger = logging.getLogger(__name__)


def split_dataset_paths(
    s2_paths: list[Path], train_data_dir: Path, test_val_split: float, test_regions: list[str] | None
):
    """Split the dataset into a cross-val, a val-test and a test dataset.

    Returns a generator with: input-path, output-path and split/mode.

    The test set is split off first by the given regions and is meant to be used to evaluate the regional value shift.
    The val-test set is then split off randomly at the given ratio to evaluate the variance value shift.

    Args:
        s2_paths (list[Path]): All paths found with tiffs.
        train_data_dir (Path): Output path.
        test_val_split (float): The val-test split ratio.
        test_regions (list[str] | None): The regions to use for the test set.

    Returns:
        zip[tuple[Path, Path, str]]: A generator of (input-path, output-path, split/mode) tuples.
    """
    # Import here to avoid long loading times when running other commands
    import geopandas as gpd
    from darts_acquisition.s2 import parse_s2_tile_id
    from sklearn.model_selection import train_test_split

    train_data_dir.mkdir(exist_ok=True, parents=True)
    output_dir_cross_val = train_data_dir / "cross-val"
    output_dir_val_test = train_data_dir / "val-test"
    output_dir_test = train_data_dir / "test"

    # 1. Split regions
    test_paths: list[Path] = []
    training_paths: list[Path] = []
    if test_regions:
        for fpath in s2_paths:
            _, s2_tile_id, _ = parse_s2_tile_id(fpath)
            labels = gpd.read_file(fpath / f"{s2_tile_id}.shp")
            # If any of the regions is in the test region, add to the test set
            if labels["region"].isin(test_regions).any():
                test_paths.append(fpath)
            else:
                training_paths.append(fpath)
    else:
        training_paths = s2_paths

    # 2. Split by random sampling
    cross_val_paths: list[Path]
    val_test_paths: list[Path]
    if len(training_paths) > 0:
        cross_val_paths, val_test_paths = train_test_split(training_paths, test_size=test_val_split, random_state=42)
    else:
        cross_val_paths, val_test_paths = [], []
        logger.warning("No left over training samples found. Skipping train-val split.")

    logger.info(
        f"Split the data into {len(cross_val_paths)} cross-val (train + val), "
        f"{len(val_test_paths)} val-test (variance) and {len(test_paths)} test (region) samples."
    )

    fpathgen = chain(cross_val_paths, val_test_paths, test_paths)
    outpathgen = chain(
        repeat(output_dir_cross_val, len(cross_val_paths)),
        repeat(output_dir_val_test, len(val_test_paths)),
        repeat(output_dir_test, len(test_paths)),
    )
    modegen = chain(
        repeat("cross-val", len(cross_val_paths)),
        repeat("val-test", len(val_test_paths)),
        repeat("test", len(test_paths)),
    )

    return zip(fpathgen, outpathgen, modegen)


def preprocess_s2_train_data(
    *,
    bands: list[str],
@@ -32,7 +105,7 @@ def preprocess_s2_train_data(
    exclude_nan: bool = True,
    mask_erosion_size: int = 10,
    test_val_split: float = 0.05,
    test_region: list[str] | str | None = None,
    test_regions: list[str] | None = None,
):
"""Preprocess Sentinel 2 data for training.
@@ -65,7 +138,7 @@ def preprocess_s2_train_data(
        mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping.
            Defaults to 10.
        test_val_split (float, optional): The split ratio for the test and validation set. Defaults to 0.05.
        test_region (list[str] | str, optional): The region to use for the test set. Defaults to None.
        test_regions (list[str] | None, optional): The regions to use for the test set. Defaults to None.
    """
# Import here to avoid long loading times when running other commands
@@ -81,7 +154,6 @@ def preprocess_s2_train_data(
    from dask.distributed import Client, LocalCluster
    from lovely_tensors import monkey_patch
    from odc.stac import configure_rio
    from sklearn.model_selection import train_test_split

    from darts.utils.cuda import debug_info, decide_device
    from darts.utils.earthengine import init_ee
@@ -114,61 +186,14 @@ def preprocess_s2_train_data(
    norm_factors = {k: v for k, v in norm_factors.items() if k in bands}

    train_data_dir.mkdir(exist_ok=True, parents=True)
    output_dir_cross_val = train_data_dir / "cross-val"
    output_dir_val_test = train_data_dir / "val-test"
    output_dir_test = train_data_dir / "test"

    # Find all Sentinel 2 scenes and split into train+val (cross-val), val-test (variance) and test (region)
    n_patches = 0
    joint_lables = []
    s2_paths = sorted(sentinel2_dir.glob("*/"))
    logger.info(f"Found {len(s2_paths)} Sentinel 2 scenes in {sentinel2_dir}")

    # 1. Split regions
    test_paths: list[Path] = []
    training_paths: list[Path] = []
    if test_region:
        test_region = [test_region] if isinstance(test_region, str) else test_region
        for fpath in s2_paths:
            _, s2_tile_id, _ = parse_s2_tile_id(fpath)
            labels = gpd.read_file(fpath / f"{s2_tile_id}.shp")
            # If any of the regions is in the test region, add to the test set
            if labels["region"].isin(test_region).any():
                test_paths.append(fpath)
            else:
                training_paths.append(fpath)
    else:
        training_paths = s2_paths

    # 2. Split by random sampling
    cross_val_paths: list[Path]
    val_test_paths: list[Path]
    if len(training_paths) > 0:
        cross_val_paths, val_test_paths = train_test_split(
            training_paths, test_size=test_val_split, random_state=42
        )
    else:
        cross_val_paths, val_test_paths = [], []
        logger.warning("No left over training samples found. Skipping train-val split.")

    logger.info(
        f"Split the data into {len(cross_val_paths)} cross-val (train + val), "
        f"{len(val_test_paths)} val-test (variance) and {len(test_paths)} test (region) samples."
    )

    fpathgen = chain(cross_val_paths, val_test_paths, test_paths)
    outpathgen = chain(
        repeat(output_dir_cross_val, len(cross_val_paths)),
        repeat(output_dir_val_test, len(val_test_paths)),
        repeat(output_dir_test, len(test_paths)),
    )
    modegen = chain(
        repeat("cross-val", len(cross_val_paths)),
        repeat("val-test", len(val_test_paths)),
        repeat("test", len(test_paths)),
    )

    joint_lables = []
    for i, (fpath, output_dir, mode) in enumerate(zip(fpathgen, outpathgen, modegen)):
    path_gen = split_dataset_paths(s2_paths, train_data_dir, test_val_split, test_regions)
    for i, (fpath, output_dir, mode) in enumerate(path_gen):
        try:
            _, s2_tile_id, tile_id = parse_s2_tile_id(fpath)

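As a rough, self-contained sketch of the splitting logic the new helper encapsulates: scenes whose labels touch a test region are set aside first, and only the remainder is split randomly with scikit-learn's train_test_split. The scene names, region names and the dict standing in for the per-scene label shapefiles below are invented for illustration and are not taken from the repository.

from pathlib import Path

from sklearn.model_selection import train_test_split

# Dummy stand-in for reading each scene's label shapefile: scene path -> regions covered by its labels.
scene_regions = {
    Path("scene_a"): {"region_a"},
    Path("scene_b"): {"region_b"},
    Path("scene_c"): {"region_a", "region_b"},
    Path("scene_d"): {"region_c"},
}
test_regions = ["region_c"]

# 1. Region split: any overlap with a test region sends the whole scene to the test set.
test_paths = [p for p, regions in scene_regions.items() if regions & set(test_regions)]
training_paths = [p for p in scene_regions if p not in test_paths]

# 2. Random split of the remaining scenes into cross-val and val-test (the diff uses a default ratio of 0.05).
cross_val_paths, val_test_paths = train_test_split(training_paths, test_size=0.25, random_state=42)
print(len(cross_val_paths), len(val_test_paths), len(test_paths))  # 2 1 1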

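And a minimal illustration of the chain/repeat/zip pattern the helper uses to hand the caller ready-made (input-path, output-path, split) triples; the paths below are again made up:

from itertools import chain, repeat
from pathlib import Path

# Hypothetical results of the two-stage split above.
cross_val_paths = [Path("scene_a"), Path("scene_b")]
val_test_paths = [Path("scene_c")]
test_paths = [Path("scene_d")]

train_data_dir = Path("train_data")
output_dir_cross_val = train_data_dir / "cross-val"
output_dir_val_test = train_data_dir / "val-test"
output_dir_test = train_data_dir / "test"

# One flat stream of input paths, one stream of matching output dirs, one stream of split names.
fpathgen = chain(cross_val_paths, val_test_paths, test_paths)
outpathgen = chain(
    repeat(output_dir_cross_val, len(cross_val_paths)),
    repeat(output_dir_val_test, len(val_test_paths)),
    repeat(output_dir_test, len(test_paths)),
)
modegen = chain(
    repeat("cross-val", len(cross_val_paths)),
    repeat("val-test", len(val_test_paths)),
    repeat("test", len(test_paths)),
)

# The caller only ever sees (input-path, output-path, split) triples.
for fpath, output_dir, mode in zip(fpathgen, outpathgen, modegen):
    print(f"{fpath} -> {output_dir} ({mode})")

Returning the zipped generator keeps the preprocessing loop in preprocess_s2_train_data down to a single for statement over path_gen, which is where the "easier to read" of the commit message comes from.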