Skip to content

Commit

Permalink
Make preprocess function
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Oct 11, 2024
1 parent 44e9b97 commit c121655
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 33 deletions.
48 changes: 48 additions & 0 deletions darts-preprocessing/src/darts_preprocessing/preprocess_tobi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""PLANET scene based preprocessing."""

from pathlib import Path

import xarray as xr

from darts_preprocessing.utils.data_pre_processing import (
calculate_ndvi,
load_auxiliary,
load_data_masks,
load_planet_scene,
)


def load_and_preprocess_planet_scene(planet_scene_path: Path, elevation_path: Path, slope_path: Path) -> xr.Dataset:
"""Load and preprocess a Planet Scene (PSOrthoTile or PSScene) into an xr.Dataset.
Args:
planet_scene_path (Path): path to the Planet Scene
elevation_path (Path): path to the elevation data
slope_path (Path): path to the slope data
Returns:
xr.Dataset: preprocessed Planet Scene
"""
# load planet scene
ds_planet = load_planet_scene(planet_scene_path)

# calculate xr.dataset ndvi
ds_ndvi = calculate_ndvi(ds_planet)

# get xr.dataset for elevation
ds_elevation = load_auxiliary(planet_scene_path, elevation_path, xr_dataset_name="relative_elevation")

# get xr.dataset for slope
ds_slope = load_auxiliary(planet_scene_path, slope_path, xr_dataset_name="slope")

# # get xr.dataset for tcvis
# ds_tcvis = load_auxiliary(planet_scene_path, tcvis_path)

# load udm2
ds_data_masks = load_data_masks(planet_scene_path)

# merge to final dataset
ds_merged = xr.merge([ds_planet, ds_ndvi, ds_elevation, ds_slope, ds_data_masks])

return ds_merged
24 changes: 0 additions & 24 deletions darts-segmentation/src/darts_segmentation/hardcoded_stuff.py

This file was deleted.

2 changes: 1 addition & 1 deletion darts-segmentation/src/darts_segmentation/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def predict_in_patches(

# Infer logits with model and turn into probabilities with sigmoid
patched_logits = model(patches)
patched_probabilities = torch.sigmoid(patched_logits) # TODO: check if this is the correct function
patched_probabilities = torch.sigmoid(patched_logits)

# Reconstruct the image from the patches
prediction = torch.zeros(bs, h, w, device=tensor_tiles.device)
Expand Down
75 changes: 67 additions & 8 deletions notebooks/model-convert.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,23 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"{'architecture': 'UnetPlusPlus',\n",
" 'encoder': 'resnet34',\n",
" 'encoder_weights': 'random',\n",
" 'input_channels': 7}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load config and checkpoint\n",
"path = \"../models/old/RTS_v6_notcvis/checkpoints/41.pt\"\n",
Expand All @@ -35,9 +49,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Try loading\n",
"model = smp.create_model(\n",
Expand Down Expand Up @@ -83,6 +108,8 @@
" \"relative_elevation\": 1 / 30000,\n",
" \"slope\": 1 / 90,\n",
" },\n",
" \"patch_size\": 1024,\n",
" \"model_framework\": \"smp\",\n",
" },\n",
" \"statedict\": model.module.state_dict(),\n",
" },\n",
Expand All @@ -92,9 +119,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test it\n",
"checkpoint = torch.load(\"../models/RTS_v6_notcvis.pt\")\n",
Expand All @@ -111,9 +149,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'architecture': 'UnetPlusPlus', 'encoder': 'resnet34', 'encoder_weights': 'random', 'input_channels': 10}\n",
"<All keys matched successfully>\n"
]
},
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load config and checkpoint\n",
"path = \"../models/old/RTS_v6_tcvis/checkpoints/14.pt\"\n",
Expand Down Expand Up @@ -166,6 +223,8 @@
" \"tc_greenness\": 1 / 255,\n",
" \"tc_wetness\": 1 / 255,\n",
" },\n",
" \"patch_size\": 1024,\n",
" \"model_framework\": \"smp\",\n",
" },\n",
" \"statedict\": model.module.state_dict(),\n",
" },\n",
Expand Down
85 changes: 85 additions & 0 deletions notebooks/test-e2e.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import xarray as xr\n",
"from darts_postprocessing.prepare_export import prepare_export\n",
"from darts_preprocessing.preprocess_tobi import load_and_preprocess_planet_scene\n",
"from darts_segmentation.segment import SMPSegmenter\n",
"from lovely_tensors import monkey_patch\n",
"from rich import traceback\n",
"\n",
"xr.set_options(display_expand_data=False)\n",
"\n",
"monkey_patch()\n",
"traceback.install(show_locals=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DATA_ROOT = Path(\"../data/input\")\n",
"\n",
"fpath = DATA_ROOT / \"planet/PSOrthoTile/4372514/5790392_4372514_2022-07-16_2459\"\n",
"scene_id = fpath.parent.name\n",
"\n",
"# TODO: change to vrt\n",
"elevation_path = DATA_ROOT / \"ArcticDEM\" / \"relative_elevation\" / f\"{scene_id}_relative_elevation_100.tif\"\n",
"slope_path = DATA_ROOT / \"ArcticDEM\" / \"slope\" / f\"{scene_id}_slope.tif\"\n",
"\n",
"tile = load_and_preprocess_planet_scene(fpath, elevation_path, slope_path)\n",
"tile\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = SMPSegmenter(\"../models/RTS_v6_notcvis.pt\")\n",
"tile = model.segment_tile(tile)\n",
"final = prepare_export(tile)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tile"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit c121655

Please sign in to comment.