-
Notifications
You must be signed in to change notification settings - Fork 2
PseudoCode Snippets
Juan Emmanuel Johnson edited this page Feb 20, 2024
·
2 revisions
# Attach the band and time coordinates taken from `ds` to `da`.
# `xr.DataArray` has no `.assign` method (only `xr.Dataset` does);
# `assign_coords` works for both, so use it here.
da = da.assign_coords(band=ds.band_id, time=ds.time)
# Temporal subsampling (not a date-range selection): keep only the time steps
# whose hour-of-day is a multiple of `time_resolution_hours`
# (e.g. 6 -> hours 0, 6, 12, 18).
ds = ds.sel(time=ds.time.dt.hour.isin(range(0, 24, time_resolution_hours)))
# Stack every data variable into a single array with axes
# (variable, time, latitude, longitude) and take the raw underlying array.
# NOTE(review): `.data` is a numpy array only for in-memory data; a
# dask-backed dataset yields a dask array here — confirm before relying on
# the np.ndarray annotation.
ds: np.ndarray = (
    ds.to_array(dim="variable")
    .transpose("variable", "time", "latitude", "longitude")
    .data
)
# get nearest lat-lon index
def get_nearest_latlon_index(ds: "xr.Dataset", lat: float, lon: float) -> tuple[int, int]:
    """Return the positional indices of the grid point nearest to (lat, lon).

    Latitude and longitude are matched independently, each by minimum absolute
    difference against the 1-D ``ds.latitude`` / ``ds.longitude`` coordinate
    values, so ascending and descending coordinates both work.

    NOTE(review): there is no longitude wrap-around -- a plain absolute
    difference is used, so ``lon`` must follow the same convention as the grid
    (e.g. both in [-180, 180) or both in [0, 360)).

    Args:
        ds: Dataset exposing ``latitude`` and ``longitude`` coordinates.
        lat: Target latitude.
        lon: Target longitude.

    Returns:
        ``(lat_index, lon_index)`` as plain Python ints, ready for ``isel`` or
        slicing. (The previous annotation claimed floats, but ``argmin``
        returns integer positions.)
    """
    lat_index = int(np.abs(ds.latitude.data - lat).argmin())
    lon_index = int(np.abs(ds.longitude.data - lon).argmin())
    return lat_index, lon_index
# Crop an image_size x image_size window centred (as closely as possible) on
# the grid point nearest to (lat, lon).
lat_index, lon_index = get_nearest_latlon_index(ds, lat, lon)
# Clamp each window start into [0, axis_length - image_size]. Without this, a
# point near the grid edge gives a negative start, and positional (`isel`)
# slicing treats a negative start as counting from the END of the axis,
# silently returning an empty or wrong window. Using `start + image_size` as
# the stop also gives exactly `image_size` cells for odd sizes (the old
# `center +/- image_size // 2` form produced image_size - 1).
lat_start = max(0, min(int(lat_index) - image_size // 2, ds.sizes["latitude"] - image_size))
lon_start = max(0, min(int(lon_index) - image_size // 2, ds.sizes["longitude"] - image_size))
ds = ds.isel(
    latitude=slice(lat_start, lat_start + image_size),
    longitude=slice(lon_start, lon_start + image_size),
)
Plot a map where you highlight a country or region
# Create an axes with a PlateCarree (plain lon/lat) projection. `plt.axes`
# adds it to the current figure, creating a figure if none exists yet.
ax_low = plt.axes(projection=ccrs.PlateCarree())
# Draw coastlines
ax_low.coastlines()
# First-level administrative boundaries (states/provinces).
# NOTE: despite the variable name, this Natural Earth layer is NOT filtered to
# Brazil -- it draws state/province lines for every country in view.
brazil = cfeature.NaturalEarthFeature(
    category='cultural',
    name='admin_1_states_provinces_lines',
    scale='50m',
    facecolor='none',
    edgecolor='white',  # white lines
    linewidth=0.5  # thinner than the 0.7 country borders below
)
ax_low.add_feature(brazil)
# National borders for all countries, drawn slightly thicker than the
# state/province lines so country outlines stand out.
other_countries = cfeature.NaturalEarthFeature(
    category='cultural',
    name='admin_0_countries',
    scale='50m',
    facecolor='none',
    edgecolor='white',  # white, same colour as the state/province lines
    linewidth=0.7  # thicker than the 0.5 state/province lines
)
ax_low.add_feature(other_countries)
# Plot one time step of `data` on the map; `transform` declares that the data
# coordinates are plain lon/lat, and vmin/vmax/cmap fix the colour scale.
data.isel(time=time).plot(ax=ax_low, transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax, cmap=cmap)
# Add labelled gridlines drawn as thin grey lines
gl = ax_low.gridlines(draw_labels=True, color='grey', linewidth=0.3)
# Show the plot
plt.show()
Create an xarray DataArray from a NumPy array in the safest way possible
def create_xarray(
    data: np.ndarray,
    start_date: pd.Timestamp = pd.Timestamp('2007-01-01'),
    bbox=(-35, -75, 5, -35),
    freq: str = "1M",
):
    """Wrap a 3-D (time, latitude, longitude) array in a labelled xarray DataArray.

    Args:
        data: Array ordered (time, latitude, longitude). The latitude and
            longitude axes may have different lengths.
        start_date: Start of the time axis, passed to ``xr.cftime_range``.
        bbox: ``(lat_start, lon_start, lat_end, lon_end)`` in degrees.
            Coordinates are evenly spaced from start to end, inclusive. A tuple
            default is used so it cannot be mutated between calls; any
            indexable sequence still works.
        freq: Frequency of the time axis, passed to ``xr.cftime_range``. The
            default ``"1M"`` is month-END anchored, so the time labels fall on
            month ends; pass ``"1MS"`` for month-start labels.

    Returns:
        xr.DataArray with dims ``("time", "latitude", "longitude")`` and
        coordinates for each dim.

    Raises:
        ValueError: if ``data`` is not 3-D.
    """
    if data.ndim != 3:
        raise ValueError(
            f"expected a 3-D (time, latitude, longitude) array, got shape {data.shape}"
        )
    num_time_steps, n_lat, n_lon = data.shape

    date_range = xr.cftime_range(start=start_date, periods=num_time_steps, freq=freq)
    # No "units" attribute on the time coordinate: xarray derives CF time units
    # itself when encoding and raises ValueError rather than overwrite a
    # user-set "units" attr (e.g. in to_netcdf). "months" was also not a
    # valid CF time unit ("<unit> since <date>").
    time_coords = xr.DataArray(date_range, dims=("time",))

    # Separate lengths for latitude and longitude so non-square grids work.
    latitude_coords = xr.DataArray(
        np.linspace(bbox[0], bbox[2], n_lat),
        dims=("latitude",),
        attrs={"units": "degrees_north"},
    )
    longitude_coords = xr.DataArray(
        np.linspace(bbox[1], bbox[3], n_lon),
        dims=("longitude",),
        attrs={"units": "degrees_east"},
    )

    return xr.DataArray(
        data,
        dims=("time", "latitude", "longitude"),
        coords={"time": time_coords, "latitude": latitude_coords, "longitude": longitude_coords},
    )
A minimal example to load many files and clean them
# Open every NetCDF file in `data_dir` as one dataset, combined by coordinates.
data = xr.open_mfdataset(f'{data_dir}/*.nc', combine='by_coords')
# Multiply by the number of seconds in a day (86400).
# NOTE(review): presumably converts a per-second flux to a per-day amount --
# confirm the input units.
data = data * 86400
# Aggregate to calendar months (labelled at month start) by summing.
data = data.resample(time='1MS').sum()
# Wrap longitudes into [-180, 180).
data = data.assign_coords(lon=(((data.lon + 180) % 360) - 180))
# Re-order so longitude increases monotonically, which the label slicing below
# requires. `sortby` works for any input grid; the previous roll by half the
# axis length assumed a 0..360 grid starting at 0 and scrambled a grid that
# was already in [-180, 180).
data = data.sortby('lon')
# Clip to the study period and bounding box (lat_start, lon_start, lat_end, lon_end).
# NOTE(review): label slicing returns an empty selection if `lat` is stored in
# descending order -- confirm the grid orientation.
data = data.sel(time=slice(time_init, time_end), lat=slice(bbox[0], bbox[2]), lon=slice(bbox[1], bbox[3]))
data = data.rename({'lat': 'latitude', 'lon': 'longitude'})
A minimal example to load a file and convert it to a raster
# Open a single file and clip it to the study period and bounding box
# (lat_start, lon_start, lat_end, lon_end).
# NOTE(review): label slicing returns an empty selection if `latitude` is
# stored in descending order -- confirm the grid orientation.
data = xr.open_dataset(data_dir)
data = data.sel(time=slice(time_init, time_end), latitude=slice(bbox[0], bbox[2]), longitude=slice(bbox[1], bbox[3]))
data = data.rename({'precip': 'pr'})
# Tag the grid with the WGS84 lon/lat CRS so rioxarray can regrid it.
data = data.rio.write_crs("EPSG:4326")
# Despite its name, this is a COARSENING factor: the output grid has
# 1/upscale_factor as many pixels along each axis.
upscale_factor = 5
# Clamp to at least one pixel so a bounding box narrower than
# `upscale_factor` pixels does not request a zero-sized raster.
new_width = max(1, data.rio.width // upscale_factor)
new_height = max(1, data.rio.height // upscale_factor)
# Regrid onto the coarser grid in the same CRS, averaging the contributing
# source pixels.
data = data.rio.reproject(
    data.rio.crs,
    shape=(new_height, new_width),
    resampling=Resampling.average,
)
# The reprojected spatial dims are named y/x; restore latitude/longitude.
data = data.rename({'y': 'latitude', 'x': 'longitude'})
Create slices (custom spatial windows) of xarray data
def slicing(data: "xr.Dataset", size: int):
    """Cut the ``pr`` variable into non-overlapping size x size spatial windows.

    Windows are taken for every time step, stepping through latitude and then
    longitude in strides of ``size``. Only complete windows are kept: when a
    spatial dimension is not a multiple of ``size``, the leftover edge strip is
    dropped. Previously the partial edge windows were kept, which made the
    window shapes ragged, so ``np.array`` either raised (NumPy >= 1.24) or
    built an object array.

    Args:
        data: Dataset with ``time``, ``latitude`` and ``longitude`` dimensions
            and a ``pr`` variable (e.g. CMIP6 data for the given time period
            and bounding box).
        size: Side length, in grid cells, of each square window. Must be >= 1.

    Returns:
        np.ndarray stacking the windows along a new leading axis, ordered by
        time step, then latitude, then longitude. Empty when the grid is
        smaller than ``size`` in either spatial dimension.

    Raises:
        ValueError: if ``size`` < 1.
    """
    if size < 1:
        raise ValueError(f"size must be >= 1, got {size}")
    n_lat = len(data['latitude'])
    n_lon = len(data['longitude'])
    # Stop early enough that every window is complete (size x size).
    lat_starts = range(0, n_lat - size + 1, size)
    lon_starts = range(0, n_lon - size + 1, size)
    data_slices = [
        data.pr.isel(
            time=t,
            latitude=slice(i, i + size),
            longitude=slice(j, j + size),
        ).values
        for t in range(len(data['time']))
        for i in lat_starts
        for j in lon_starts
    ]
    return np.array(data_slices)
This research is funded through a NASA 22-MDRAIT22-0018 award (No 80NSSC23K1045) and managed by Trillium Technologies Inc (trillium.tech).