Skip to content

Commit

Permalink
Add zero filter in save functions (#198)
Browse files Browse the repository at this point in the history
* Add zero filter in save functions

* Add docstring

* Linting

* Add test

* Rename test file

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
devsjc and pre-commit-ci[bot] authored Nov 21, 2023
1 parent 390b976 commit 78cd3d6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
23 changes: 23 additions & 0 deletions satip/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,13 +410,30 @@ def get_dataset_from_scene(filename: str, hrv_scaler, use_rescaler: bool, save_d
hrv_dataset.attrs.update(attrs)
log.debug("Converted HRV to DataArray", memory=get_memory())
now_time = pd.Timestamp(hrv_dataset["time"].values[0]).strftime("%Y%m%d%H%M")

# Check for data quality
if not data_quality_filter(hrv_dataset):
del hrv_dataset
gc.collect()
return

save_file = os.path.join(save_dir, f"{'15_' if using_backup else ''}hrv_{now_time}.zarr.zip")
log.debug(f"Saving HRV netcdf in {save_file}", memory=get_memory())
save_to_zarr_to_s3(hrv_dataset, save_file)
del hrv_dataset
gc.collect()
log.debug("Saved HRV to NetCDF", memory=get_memory())

def data_quality_filter(ds: xr.Dataset, threshold_fraction: float = 0.9) -> bool:
    """
    Check whether a dataset is acceptable, rejecting ones dominated by zeros.

    Each data variable is inspected in turn; the dataset is rejected as soon
    as one variable has a fraction of (near-)zero values strictly greater than
    ``threshold_fraction``.

    Args:
        ds: Dataset whose data variables are checked.
        threshold_fraction: Maximum tolerated fraction of values close to zero
            in any single variable.

    Returns:
        False if any variable exceeds the threshold, True otherwise.
    """
    for var in ds.data_vars:
        # np.isclose yields a boolean mask; its mean is the fraction of zeros
        fraction_of_zeros = np.isclose(ds[var], 0.0).mean()
        if fraction_of_zeros > threshold_fraction:
            log.debug(
                f"Ignoring dataset {ds} as {var} has {fraction_of_zeros} fraction of zeros"
                f" (threshold {threshold_fraction})"
            )
            return False

    # Must be an explicit True: callers test `if not data_quality_filter(...)`,
    # so an implicit None here would make every dataset look like a failure.
    return True

def get_nonhrv_dataset_from_scene(
filename: str, scaler, use_rescaler: bool, save_dir, using_backup
Expand Down Expand Up @@ -508,6 +525,12 @@ def get_nonhrv_dataset_from_scene(
dataset.attrs.update(attrs)
log.debug("Deleted return list", memory=get_memory())
now_time = pd.Timestamp(dataset["time"].values[0]).strftime("%Y%m%d%H%M")

if not data_quality_filter(dataset):
del dataset
gc.collect()
return

save_file = os.path.join(save_dir, f"{'15_' if using_backup else ''}{now_time}.zarr.zip")
log.debug(f"Saving non-HRV netcdf in {save_file}", memory=get_memory())
save_to_zarr_to_s3(dataset, save_file)
Expand Down
16 changes: 16 additions & 0 deletions tests/disabled_utils.py → tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
load_cloudmask_to_dataarray,
load_native_to_dataarray,
save_dataarray_to_zarr,
data_quality_filter,
)

USER_KEY = os.environ.get("EUMETSAT_USER_KEY")
Expand Down Expand Up @@ -95,3 +96,18 @@ def test_save_dataarray_to_zarr(self): # noqa D102
zarr_mode="w",
)
self.assertEqual(1, len(list(glob.glob(zarr_path))))

def test_data_quality_filter(self):
    """An all-zero dataset is rejected and an all-ones dataset is accepted."""
    # np.zeros / np.ones take the shape as a single tuple argument.
    test_dataset_zeros = xarray.Dataset({
        "data": (("time", "y", "x"), np.zeros((100, 100, 100)))
    })

    out = data_quality_filter(test_dataset_zeros, 0.9)
    self.assertFalse(out)

    test_dataset_ones = xarray.Dataset({
        "data": (("time", "y", "x"), np.ones((100, 100, 100)))
    })

    out = data_quality_filter(test_dataset_ones, 0.9)
    self.assertTrue(out)

0 comments on commit 78cd3d6

Please sign in to comment.