From f26abacfb1220ceb29054ef6765dd4d7946f5a9f Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Wed, 13 Nov 2024 15:17:41 +0100 Subject: [PATCH] Add plotting to notebook --- test.ipynb | 1076 +++++++++++++++++++++++++++++++++++++- torchgeo/datasets/geo.py | 8 +- 2 files changed, 1073 insertions(+), 11 deletions(-) diff --git a/test.ipynb b/test.ipynb index 530c9412fb0..86276ecaaea 100644 --- a/test.ipynb +++ b/test.ipynb @@ -95,21 +95,1079 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import plotly.express as px\n", + "\n", + "def plot(\n", + " sample: dict,\n", + " indices_to_plot,\n", + " show = False,\n", + " **kwargs,\n", + "):\n", + " \"\"\"Plots the image data from the given sample.\n", + "\n", + " Args:\n", + " sample (dict): A dictionary containing the image data returned by self.__get_item__. Should contain the key \"image\".\n", + " indices_to_plot (list, optional): A list of indices to plot. If not provided, the method will use the RGB bands defined in `self.rgb_bands`.\n", + " show (bool, optional): Whether to display the plot. Defaults to False.\n", + " **kwargs (dict): Additional keyword arguments to be passed to `px.imshow`.\n", + "\n", + " Returns:\n", + " fig: The plotly figure object.\n", + " \"\"\"\n", + " image = sample[\"image\"]\n", + "\n", + " # Reorder and rescale the image\n", + " if (sample[\"image\"].ndim == 4) and (sample[\"image\"].shape[0] > 1):\n", + " # Shape of image = [d, c, h, w]\n", + " image = image[:, indices_to_plot, :, :].permute(0, 2, 3, 1)\n", + " if image.shape[-1] == 1:\n", + " image = image.squeeze(-1)\n", + " image = torch.clamp(image / 10000, min=0, max=1).numpy()\n", + "\n", + " fig = px.imshow(\n", + " image, animation_frame=0, labels={\"animation_frame\": \"Date\"}, **kwargs\n", + " )\n", + " # Todo, currently taking the first date, need to handle multiple dates\n", + " date_labels = [\n", + " dates[0].strftime(\"%m/%d/%Y, %H:%M:%S\") for dates in sample[\"dates\"]\n", + " ]\n", + " for i, label in enumerate(date_labels):\n", + " fig.layout.sliders[0].steps[i].label = label\n", + "\n", + " else:\n", + " image = image[indices_to_plot].permute(1, 2, 0)\n", + " image = torch.clamp(image / 10000, min=0, max=1).numpy()\n", + "\n", + " # Plot the image\n", + " fig = px.imshow(image, **kwargs)\n", + "\n", + " fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)\n", + " if show:\n", + " fig.show()\n", + " return fig" + ] + }, + { + "cell_type": "code", + "execution_count": 12, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "None\n" - ] + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "x: %{x}
y: %{y}", + "name": "0", + "source": "", + "type": "image", + "xaxis": "x", + "yaxis": "y" + } + ], + "frames": [ + { + "data": [ + { + "name": "0", + "source": "", + "type": "image" + } + ], + "layout": { + "margin": { + "t": 60 + } + }, + "name": "0" + }, + { + "data": [ + { + "name": "1", + "source": "", + "type": "image" + } + ], + "layout": { + "margin": { + "t": 60 + } + }, + "name": "1" + } + ], + "layout": { + "margin": { + "t": 60 + }, + "sliders": [ + { + "active": 0, + "currentvalue": { + "prefix": "Date=" + }, + "len": 0.9, + "pad": { + "b": 10, + "t": 60 + }, + "steps": [ + { + "args": [ + [ + "0" + ], + { + "frame": { + "duration": 0, + "redraw": true + }, + "fromcurrent": true, + "mode": "immediate", + "transition": { + "duration": 0, + "easing": "linear" + } + } + ], + "label": "04/12/2019, 16:28:41", + "method": "animate" + }, + { + "args": [ + [ + "1" + ], + { + "frame": { + "duration": 0, + "redraw": true + }, + "fromcurrent": true, + "mode": "immediate", + "transition": { + "duration": 0, + "easing": "linear" + } + } + ], + "label": "04/12/2022, 16:28:41", + "method": "animate" + } + ], + "x": 0.1, + "xanchor": "left", + "y": 0, + "yanchor": "top" + } + ], + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "updatemenus": [ + { + "buttons": [ + { + "args": [ + null, + { + "frame": { + "duration": 500, + "redraw": true + }, + "fromcurrent": true, + "mode": "immediate", + "transition": { + "duration": 500, + "easing": "linear" + } + } + ], + "label": "▶", + "method": "animate" + }, + { + "args": [ + [ + null + ], + { + "frame": { + "duration": 0, + "redraw": true + }, + "fromcurrent": true, + "mode": "immediate", + "transition": { + "duration": 0, + "easing": "linear" + } + } + ], + "label": "◼", + "method": "animate" + } + ], + "direction": "left", + "pad": { + "r": 10, + "t": 70 + }, + "showactive": false, + "type": "buttons", + "x": 0.1, + "xanchor": "right", + "y": 0, + "yanchor": "top" + } + ], + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "showticklabels": false + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "showticklabels": false + } + } + } + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "import rasterio\n", - "with rasterio.open(r\"C:\\Users\\Sieger.Falkena\\Documents\\torchgeo\\tests\\data\\sentinel2\\S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE\\GRANULE\\L2A_T26EMU_A035569_20220414T110747\\IMG_DATA\\R10m\\T26EMU_20190414T110751_B02_10m.jp2\") as src:\n", - " print(src.nodata)" + "plot(sample, show=False, indices_to_plot=[2, 1, 0])" ] } ], diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 98cdef2d4d4..9ea6de07746 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -649,7 +649,9 @@ def _get_regex_groups_as_df(self, filepaths: list[str]) -> pd.DataFrame: return pd.DataFrame(file_metadata) - def __merge_single_bbox(self, query: BoundingBox) -> tuple[torch.Tensor | None, list[str]]: + def __merge_single_bbox( + self, query: BoundingBox + ) -> tuple[torch.Tensor | None, list[str]]: """Merge all files that intersect with a single bounding box. Args: @@ -687,7 +689,9 @@ def __merge_single_bbox(self, query: BoundingBox) -> tuple[torch.Tensor | None, if res_single_bbox is not None: # Check if res_single_date contains nodata values and only append if it doesn't - if not self.drop_nodata or not torch.any(res_single_bbox == self.nodata_value): + if not self.drop_nodata or not torch.any( + res_single_bbox == self.nodata_value + ): return res_single_bbox, dates return None, []