From 7f3d931938989ab226067c7cda3dcaa149c32f2c Mon Sep 17 00:00:00 2001
From: Jeremi Levesque
Date: Mon, 7 Oct 2024 11:45:10 -0400
Subject: [PATCH] Fix pep8

---
 dwi_ml/data/dataset/streamline_containers.py |  9 +-
 dwi_ml/training/projects/ae_trainer.py       | 31 ++++---
 dwi_ml/viz/latent_streamlines.py             | 87 ++++++++++++--------
 3 files changed, 75 insertions(+), 52 deletions(-)

diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py
index d78f01f4..79ab283a 100644
--- a/dwi_ml/data/dataset/streamline_containers.py
+++ b/dwi_ml/data/dataset/streamline_containers.py
@@ -69,17 +69,18 @@ def _load_connectivity_info(hdf_group: h5py.Group):
     return contains_connectivity, connectivity_nb_blocs, connectivity_labels
 
 
-def _load_data_per_streamline(hdf_group, dps_key: str = None) -> Union[np.ndarray, None]:
+def _load_data_per_streamline(hdf_group,
+                              dps_key: str = None) -> Union[np.ndarray, None]:
     dps_dict = defaultdict(list)
     # Load only related data key if specified
-    if not 'data_per_streamline' in hdf_group.keys():
+    if 'data_per_streamline' not in hdf_group.keys():
         return dps_dict
 
     dps_group = hdf_group['data_per_streamline']
     if dps_key is not None:
         # Make sure the related data key is in the hdf5 group
         if not (dps_key in dps_group.keys()):
-            raise KeyError("The key '{}' is not in the hdf5 group. Keys found: {}"
+            raise KeyError("The key '{}' is not in the hdf5 group. Found: {}"
                           .format(dps_key, dps_group.keys()))
 
         # Load the related data per streamline
@@ -352,7 +353,7 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None):
         return self._connectivity_matrix
 
     @classmethod
-    def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group, dps_key: str = None):
+    def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
         """
         Creating class instance from the hdf in cases where data is not
         loaded yet. Non-lazy = loading the data here.
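Note: for context on _load_data_per_streamline above, it reads an optional
data_per_streamline subgroup stored next to the streamlines in the HDF5 file,
either the one requested key or all of them. A minimal sketch of the layout it
expects; the file path and key names here are hypothetical, not taken from
this patch:

    import h5py
    import numpy as np

    # Hypothetical layout: one dataset per data-per-streamline key.
    with h5py.File('example.hdf5', 'w') as f:
        grp = f.create_group('subj1/streamlines')
        dps = grp.create_group('data_per_streamline')
        dps.create_dataset('bundle_index', data=np.arange(100))

    with h5py.File('example.hdf5', 'r') as f:
        grp = f['subj1/streamlines']
        # Roughly what loading with dps_key='bundle_index' would read:
        # only the requested key, or every key when dps_key is None.
        values = grp['data_per_streamline']['bundle_index'][:]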
diff --git a/dwi_ml/training/projects/ae_trainer.py b/dwi_ml/training/projects/ae_trainer.py
index cc873176..e74abba5 100644
--- a/dwi_ml/training/projects/ae_trainer.py
+++ b/dwi_ml/training/projects/ae_trainer.py
@@ -41,12 +41,17 @@ def __init__(self,
                  comet_workspace: str = None, comet_project: str = None,
                  from_checkpoint: bool = False, log_level=logging.root.level,
                  viz_latent_space: bool = False, color_by: str = None,
-                 bundles_mapping_file: str = None, max_viz_subset_size: int = 1000):
-
-        super().__init__(model, experiments_path, experiment_name, batch_sampler, batch_loader,
-                         learning_rates, weight_decay, optimizer, max_epochs, max_batches_per_epoch_training,
-                         max_batches_per_epoch_validation, patience, patience_delta, nb_cpu_processes, use_gpu,
-                         clip_grad, comet_workspace, comet_project, from_checkpoint, log_level)
+                 bundles_mapping_file: str = None,
+                 max_viz_subset_size: int = 1000):
+
+        super().__init__(model, experiments_path, experiment_name,
+                         batch_sampler, batch_loader, learning_rates,
+                         weight_decay, optimizer, max_epochs,
+                         max_batches_per_epoch_training,
+                         max_batches_per_epoch_validation, patience,
+                         patience_delta, nb_cpu_processes, use_gpu,
+                         clip_grad, comet_workspace, comet_project,
+                         from_checkpoint, log_level)
 
         self.color_by = color_by
         self.viz_latent_space = viz_latent_space
@@ -57,10 +62,11 @@ def __init__(self,
             os.makedirs(save_dir, exist_ok=True)
 
             bundle_mapping = parse_bundle_mapping(bundles_mapping_file)
-            self.ls_viz = BundlesLatentSpaceVisualizer(save_dir,
-                                                       prefix_numbering=True,
-                                                       max_subset_size=max_viz_subset_size,
-                                                       bundle_mapping=bundle_mapping)
+            self.ls_viz = BundlesLatentSpaceVisualizer(
+                save_dir,
+                prefix_numbering=True,
+                max_subset_size=max_viz_subset_size,
+                bundle_mapping=bundle_mapping)
             self.warning_printed = False
 
             # Register what to do post encoding.
@@ -71,10 +77,11 @@ def __init__(self,
             def handle_latent_encodings(encoding, data_per_streamline):
                 if self.color_by is None:
                     bundle_index = None
-                elif not self.color_by in data_per_streamline.keys():
+                elif self.color_by not in data_per_streamline.keys():
                     if not self.warning_printed:
                         LOGGER.warning(
-                            f"Coloring by {self.color_by} not found in data_per_streamline.")
+                            f"Coloring by {self.color_by} not found in "
+                            "data_per_streamline.")
                         self.warning_printed = True
                     bundle_index = None
                 else:
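Note: handle_latent_encodings above is registered on the model ("what to do
post encoding") so that each forward pass can hand its latent vectors, plus
any per-streamline metadata, to the visualizer. A minimal sketch of that hook
pattern; the class and method names are illustrative, not the dwi_ml API:

    import numpy as np

    class TinyEncoder:
        def __init__(self):
            self.post_encoding_hooks = []

        def register_hook(self, fn):
            # The trainer registers what to do after encoding.
            self.post_encoding_hooks.append(fn)

        def forward(self, x: np.ndarray) -> np.ndarray:
            encoding = x.mean(axis=1)  # stand-in for a real encoder
            for hook in self.post_encoding_hooks:
                hook(encoding, data_per_streamline={})
            return encoding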
diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py
index 66925cd7..5977923b 100644
--- a/dwi_ml/viz/latent_streamlines.py
+++ b/dwi_ml/viz/latent_streamlines.py
@@ -20,33 +20,40 @@ def __init__(self, max_num_bundles: int = 40):
         self.bundle_color_map = {}
         self.color_map = self._init_colormap(max_num_bundles)
 
-    def _init_colormap(self, number_of_distinct_colors):
+    def _init_colormap(self, nb_distinct_colors: int):
        """
        Create a colormap with a number of distinct colors.
        Needed to have bigger color maps for more bundles.
 
-        Code directly copied from:
-        https://stackoverflow.com/questions/42697933/colormap-with-maximum-distinguishable-colours
+        Code directly copied from:
+        https://stackoverflow.com/questions/42697933
        """
-        if number_of_distinct_colors == 0:
-            number_of_distinct_colors = 80
+        if nb_distinct_colors == 0:
+            nb_distinct_colors = 80
 
-        number_of_shades = 7
-        number_of_distinct_colors_with_multiply_of_shades = int(
-            math.ceil(number_of_distinct_colors / number_of_shades) * number_of_shades)
+        nb_of_shades = 7
+        nb_of_distinct_colors_with_mult_of_shades = int(
+            math.ceil(nb_distinct_colors / nb_of_shades)
+            * nb_of_shades)
 
-        # Create an array with uniformly drawn floats taken from <0, 1) partition
+        # Create an array with uniformly drawn floats taken from the
+        # [0, 1) interval
         linearly_distributed_nums = np.arange(
-            number_of_distinct_colors_with_multiply_of_shades) / number_of_distinct_colors_with_multiply_of_shades
-
-        # We are going to reorganise monotonically growing numbers in such way that there will be single array with saw-like pattern
-        # but each saw tooth is slightly higher than the one before
-        # First divide linearly_distributed_nums into number_of_shades sub-arrays containing linearly distributed numbers
+            nb_of_distinct_colors_with_mult_of_shades) / \
+            nb_of_distinct_colors_with_mult_of_shades
+
+        # We are going to reorganise monotonically growing numbers in such a
+        # way that there will be a single array with a saw-like pattern, but
+        # each saw tooth is slightly higher than the one before. First divide
+        # linearly_distributed_nums into nb_of_shades sub-arrays containing
+        # linearly distributed numbers.
         arr_by_shade_rows = linearly_distributed_nums.reshape(
-            number_of_shades, number_of_distinct_colors_with_multiply_of_shades // number_of_shades)
+            nb_of_shades, nb_of_distinct_colors_with_mult_of_shades //
+            nb_of_shades)
 
-        # Transpose the above matrix (columns become rows) - as a result each row contains saw tooth with values slightly higher than row above
+        # Transpose the above matrix (columns become rows) - as a result each
+        # row contains a saw tooth with values slightly higher than row above
         arr_by_shade_columns = arr_by_shade_rows.T
 
         # Keep number of saw teeth for later
@@ -55,27 +62,31 @@ def _init_colormap(self, number_of_distinct_colors):
         # Flatten the above matrix - join each row into single array
         nums_distributed_like_rising_saw = arr_by_shade_columns.reshape(-1)
 
-        # HSV colour map is cyclic (https://matplotlib.org/tutorials/colors/colormaps.html#cyclic), we'll use this property
+        # HSV colour map is cyclic, we'll use this property
+        # (https://matplotlib.org/tutorials/colors/colormaps.html#cyclic)
         initial_cm = hsv(nums_distributed_like_rising_saw)
 
         lower_partitions_half = number_of_partitions // 2
         upper_partitions_half = number_of_partitions - lower_partitions_half
 
-        # Modify lower half in such way that colours towards beginning of partition are darker
-        # First colours are affected more, colours closer to the middle are affected less
-        lower_half = lower_partitions_half * number_of_shades
+        # Modify lower half in such a way that colours towards the beginning
+        # of the partition are darker. First colours are affected more,
+        # colours closer to the middle are affected less.
+        lower_half = lower_partitions_half * nb_of_shades
         for i in range(3):
             initial_cm[0:lower_half, i] *= np.arange(0.2, 1, 0.8/lower_half)
 
-        # Modify second half in such way that colours towards end of partition are less intense and brighter
-        # Colours closer to the middle are affected less, colours closer to the end are affected more
+        # Modify second half in such a way that colours towards the end of
+        # the partition are less intense and brighter. Colours closer to the
+        # middle are affected less, colours closer to the end are affected
+        # more.
         for i in range(3):
             for j in range(upper_partitions_half):
-                modifier = np.ones(
-                    number_of_shades) - initial_cm[lower_half + j * number_of_shades: lower_half + (j + 1) * number_of_shades, i]
+                modifier = np.ones(nb_of_shades) \
+                    - initial_cm[lower_half + j * nb_of_shades:
+                                 lower_half + (j + 1) * nb_of_shades, i]
                 modifier = j * modifier / upper_partitions_half
-                initial_cm[lower_half + j * number_of_shades: lower_half +
-                           (j + 1) * number_of_shades, i] += modifier
+                initial_cm[lower_half + j * nb_of_shades: lower_half +
+                           (j + 1) * nb_of_shades, i] += modifier
 
         return ListedColormap(initial_cm)
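Note: the rewrapped comments above describe the StackOverflow technique:
sample the cyclic HSV colormap with a saw-tooth index sequence so that
consecutive colours land far apart in hue, then darken the first half and
brighten the second. A compact standalone sketch of the core idea, simplified
relative to the code above (no shade modifiers), with illustrative names:

    import math
    import numpy as np
    from matplotlib.cm import hsv
    from matplotlib.colors import ListedColormap

    def distinct_colormap(n_colors: int, n_shades: int = 7) -> ListedColormap:
        # Round up so the colours divide evenly into shade groups.
        n = int(math.ceil(n_colors / n_shades) * n_shades)
        ramp = np.arange(n) / n
        # Reshape then transpose: each row becomes one rising saw tooth,
        # so consecutive entries jump across the hue circle.
        saw = ramp.reshape(n_shades, n // n_shades).T.reshape(-1)
        return ListedColormap(hsv(saw))

    cmap = distinct_colormap(40)  # e.g. one colour per bundle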
@@ -186,21 +197,24 @@ def add_data_to_plot(self, data: np.ndarray, labels: List[str]):
             self.bundles[DEFAULT_BUNDLE_NAME] = latent_space_streamlines
         else:
             all_labels = np.unique(labels)
-            _remaining_indices = np.arange(len(labels))
+            remaining_indices = np.arange(len(labels))
             for label in all_labels:
-                label_indices = labels[_remaining_indices] == label
-                label_data = latent_space_streamlines[_remaining_indices][label_indices]
+                label_indices = labels[remaining_indices] == label
+                label_data = \
+                    latent_space_streamlines[remaining_indices][label_indices]
                 label_data = self._resample_max_subset_size(label_data)
                 self.bundles[label] = label_data
 
-                _remaining_indices = _remaining_indices[~label_indices]
+                remaining_indices = remaining_indices[~label_indices]
 
-            if len(_remaining_indices) > 0:
+            if len(remaining_indices) > 0:
                 LOGGER.warning(
                     "Some streamlines were not considered in the bundles,"
                     "some labels are missing.\n"
-                    "Added them to the {} bundle.".format(DEFAULT_BUNDLE_NAME))
-                self.bundles[DEFAULT_BUNDLE_NAME] = latent_space_streamlines[_remaining_indices]
+                    "Added them to the {} bundle."
+                    .format(DEFAULT_BUNDLE_NAME))
+                self.bundles[DEFAULT_BUNDLE_NAME] = \
+                    latent_space_streamlines[remaining_indices]
 
     def add_bundle_to_plot(self, data: np.ndarray, label: str = '_'):
         """
@@ -244,7 +258,8 @@ def plot(self, epoch: int, figure_name_prefix: str = 'lt_space'):
         # So that the warning above is only displayed once.
         self.should_call_reset_before_plot = True
 
-        # Start by making sure the number of streamlines doesn't exceed the threshold.
+        # Start by making sure the number of streamlines doesn't
+        # exceed the threshold.
         for (bname, bdata) in self.bundles.items():
             if bdata.shape[0] > self.max_subset_size:
                 self.bundles[bname] = self._resample_max_subset_size(bdata)
@@ -347,8 +362,8 @@ def _resample_max_subset_size(self, data: np.ndarray):
                 "A max_subset_size of an integer value greater"
                 "than 0 is required.")
 
-        # Only sample if we need to reduce the number of latent streamlines
-        # to show on the plot.
+        # Only sample if we need to reduce the number of latent
+        # streamlines to show on the plot.
         if (len(data) > self.max_subset_size):
            sample_indices = np.random.choice(
                len(data),
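Note: the last hunk is truncated here, but _resample_max_subset_size randomly
subsamples the latent streamlines with np.random.choice whenever a bundle
exceeds max_subset_size. A hedged sketch of that pattern; since the remaining
arguments of the call are cut off above, this presumes sampling without
replacement:

    import numpy as np

    def resample_max_subset_size(data: np.ndarray,
                                 max_subset_size: int) -> np.ndarray:
        # Keep at most max_subset_size rows, sampled without replacement.
        if len(data) <= max_subset_size:
            return data
        sample_indices = np.random.choice(
            len(data), max_subset_size, replace=False)
        return data[sample_indices]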