From 4fca89c28b6985186afe91a1a18adf37f74e2661 Mon Sep 17 00:00:00 2001 From: sarah-paradis Date: Mon, 6 May 2024 22:38:39 +0200 Subject: [PATCH] Fix train excluded for zero buffer radius --- spacv/base_classes.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/spacv/base_classes.py b/spacv/base_classes.py index 45609f0..facadb8 100644 --- a/spacv/base_classes.py +++ b/spacv/base_classes.py @@ -44,14 +44,10 @@ def split(self, XYs): indices = XYs.index.values for test_indices, train_excluded in self._iter_test_indices(XYs): - # Exclude the training indices within buffer + # Combine training indices within buffer with test indices train_excluded = np.concatenate([test_indices, train_excluded]) - train_index = np.setdiff1d( - np.union1d( - indices, - train_excluded - ), np.intersect1d(indices, train_excluded) - ) + # Exclude test indices and training indices within buffer to get final training indices + train_index = np.setdiff1d(indices, train_excluded) if len(train_index) < 1: raise ValueError( "Training set is empty. Try lowering buffer_radius to include more training instances." @@ -68,11 +64,11 @@ def _remove_buffered_indices(self, XYs, test_indices, buffer_radius, geometry_bu geometry_buffer = convert_geodataframe(geometry_buffer) deadzone_points = gpd.sjoin(candidate_deadzone, geometry_buffer) train_exclude = deadzone_points.loc[~deadzone_points.index.isin(test_indices)].index.values - return test_indices, train_exclude else: - # Yield empty array because no training data removed in dead zone when buffer is zero - _ = np.empty([], dtype=np.int) - return test_indices, _ + # Yield empty array (with the same dimensions as test_indices) because no training data removed in dead zone + # when buffer is zero + train_exclude = np.empty([0] * test_indices.ndim, dtype=np.int) + return test_indices, train_exclude @abstractmethod def _iter_test_indices(self, XYs):