Fix pep8
levje committed Oct 7, 2024
1 parent 2de0a43 commit 7f3d931
Showing 3 changed files with 75 additions and 52 deletions.
9 changes: 5 additions & 4 deletions dwi_ml/data/dataset/streamline_containers.py
@@ -69,17 +69,18 @@ def _load_connectivity_info(hdf_group: h5py.Group):
return contains_connectivity, connectivity_nb_blocs, connectivity_labels


def _load_data_per_streamline(hdf_group, dps_key: str = None) -> Union[np.ndarray, None]:
def _load_data_per_streamline(hdf_group,
dps_key: str = None) -> Union[np.ndarray, None]:
dps_dict = defaultdict(list)
# Load only related data key if specified
if not 'data_per_streamline' in hdf_group.keys():
if 'data_per_streamline' not in hdf_group.keys():
return dps_dict

dps_group = hdf_group['data_per_streamline']
if dps_key is not None:
# Make sure the related data key is in the hdf5 group
if not (dps_key in dps_group.keys()):
raise KeyError("The key '{}' is not in the hdf5 group. Keys found: {}"
raise KeyError("The key '{}' is not in the hdf5 group. Found: {}"
.format(dps_key, dps_group.keys()))

# Load the related data per streamline
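
For context, the changes in this hunk apply two recurring PEP8 fixes: membership tests written as "x not in y" instead of "not x in y" (pycodestyle E713), and long signatures wrapped so the continuation argument lines up under the opening parenthesis. A minimal, self-contained sketch of both idioms, using hypothetical names rather than the repository's:

from typing import Optional, Union

import numpy as np


def load_data_key(group: dict,
                  key: Optional[str] = None) -> Union[np.ndarray, None]:
    # PEP8 (E713): write "key not in group", not "not key in group".
    if key is not None and key not in group:
        raise KeyError("The key '{}' is not in the group. Found: {}"
                       .format(key, list(group.keys())))
    return np.asarray(group[key]) if key is not None else None


print(load_data_key({'dps1': [1.0, 2.0]}, 'dps1'))  # -> array([1., 2.])
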
@@ -352,7 +353,7 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None):
return self._connectivity_matrix

@classmethod
def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group, dps_key: str = None):
def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
"""
Creating class instance from the hdf in cases where data is not
loaded yet. Non-lazy = loading the data here.
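
For reference, the classmethod above is meant to receive an already-open h5py group. The sketch below illustrates that call; the class name SFTData, the import path and the hdf5 layout are assumptions for illustration only, since the diff does not show them:

import h5py

# Assumed import path and class name; the commit does not show the
# enclosing class, so adjust to the actual container class.
from dwi_ml.data.dataset.streamline_containers import SFTData

# Hypothetical hdf5 layout: one group per subject holding a 'streamlines' group.
with h5py.File('dataset.hdf5', 'r') as hdf_handle:
    streamlines_group = hdf_handle['subj1/streamlines']
    sft_data = SFTData.init_sft_data_from_hdf_info(streamlines_group)
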
31 changes: 19 additions & 12 deletions dwi_ml/training/projects/ae_trainer.py
@@ -41,12 +41,17 @@ def __init__(self,
comet_workspace: str = None, comet_project: str = None,
from_checkpoint: bool = False, log_level=logging.root.level,
viz_latent_space: bool = False, color_by: str = None,
bundles_mapping_file: str = None, max_viz_subset_size: int = 1000):

super().__init__(model, experiments_path, experiment_name, batch_sampler, batch_loader,
learning_rates, weight_decay, optimizer, max_epochs, max_batches_per_epoch_training,
max_batches_per_epoch_validation, patience, patience_delta, nb_cpu_processes, use_gpu,
clip_grad, comet_workspace, comet_project, from_checkpoint, log_level)
bundles_mapping_file: str = None,
max_viz_subset_size: int = 1000):

super().__init__(model, experiments_path, experiment_name,
batch_sampler, batch_loader, learning_rates,
weight_decay, optimizer, max_epochs,
max_batches_per_epoch_training,
max_batches_per_epoch_validation, patience,
patience_delta, nb_cpu_processes, use_gpu,
clip_grad, comet_workspace, comet_project,
from_checkpoint, log_level)

self.color_by = color_by
self.viz_latent_space = viz_latent_space
@@ -57,10 +62,11 @@ def __init__(self,
os.makedirs(save_dir, exist_ok=True)

bundle_mapping = parse_bundle_mapping(bundles_mapping_file)
self.ls_viz = BundlesLatentSpaceVisualizer(save_dir,
prefix_numbering=True,
max_subset_size=max_viz_subset_size,
bundle_mapping=bundle_mapping)
self.ls_viz = BundlesLatentSpaceVisualizer(
save_dir,
prefix_numbering=True,
max_subset_size=max_viz_subset_size,
bundle_mapping=bundle_mapping)
self.warning_printed = False

# Register what to do post encoding.
@@ -71,10 +77,11 @@ def handle_latent_encodings(encoding, data_per_streamline):

if self.color_by is None:
bundle_index = None
elif not self.color_by in data_per_streamline.keys():
elif self.color_by not in data_per_streamline.keys():
if not self.warning_printed:
LOGGER.warning(
f"Coloring by {self.color_by} not found in data_per_streamline.")
f"Coloring by {self.color_by} not found in "
"data_per_streamline.")
self.warning_printed = True
bundle_index = None
else:
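
Taken together, the trainer above feeds latent encodings to a BundlesLatentSpaceVisualizer, whose implementation follows in the next file. A rough usage sketch with fabricated encodings and bundle labels; only the constructor and method signatures come from this commit, while the import path and the data below are assumptions:

import os

import numpy as np

# Import path assumed from the file layout shown in this commit.
from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer

# Fabricated latent encodings: 200 streamlines in a 32-D latent space,
# each tagged with a made-up bundle label.
encodings = np.random.randn(200, 32)
labels = np.random.choice(['AF_L', 'AF_R', 'PYT_L', 'PYT_R'], size=200)

save_dir = 'latent_space_plots'
os.makedirs(save_dir, exist_ok=True)  # mirrors what the trainer does above

ls_viz = BundlesLatentSpaceVisualizer(
    save_dir,
    prefix_numbering=True,
    max_subset_size=1000)

ls_viz.add_data_to_plot(encodings, labels)
ls_viz.plot(epoch=0, figure_name_prefix='lt_space')
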
87 changes: 51 additions & 36 deletions dwi_ml/viz/latent_streamlines.py
@@ -20,33 +20,40 @@ def __init__(self, max_num_bundles: int = 40):
self.bundle_color_map = {}
self.color_map = self._init_colormap(max_num_bundles)

def _init_colormap(self, number_of_distinct_colors):
def _init_colormap(self, nb_distinct_colors: int):
"""
Create a colormap with a number of distinct colors.
Needed to have bigger color maps for more bundles.
Code directly copied from:
https://stackoverflow.com/questions/42697933/colormap-with-maximum-distinguishable-colours
Code directly copied from:
https://stackoverflow.com/questions/42697933
"""
if number_of_distinct_colors == 0:
number_of_distinct_colors = 80
if nb_distinct_colors == 0:
nb_distinct_colors = 80

number_of_shades = 7
number_of_distinct_colors_with_multiply_of_shades = int(
math.ceil(number_of_distinct_colors / number_of_shades) * number_of_shades)
nb_of_shades = 7
nb_of_distinct_colors_with_mult_of_shades = int(
math.ceil(nb_distinct_colors / nb_of_shades)
* nb_of_shades)

# Create an array with uniformly drawn floats taken from <0, 1) partition
# Create an array with uniformly drawn floats taken from <0, 1)
# partition
linearly_distributed_nums = np.arange(
number_of_distinct_colors_with_multiply_of_shades) / number_of_distinct_colors_with_multiply_of_shades

# We are going to reorganise monotonically growing numbers in such way that there will be single array with saw-like pattern
# but each saw tooth is slightly higher than the one before
# First divide linearly_distributed_nums into number_of_shades sub-arrays containing linearly distributed numbers
nb_of_distinct_colors_with_mult_of_shades) / \
nb_of_distinct_colors_with_mult_of_shades

# We are going to reorganise monotonically growing numbers in such way
# that there will be single array with saw-like pattern but each saw
# tooth is slightly higher than the one before. First divide
# linearly_distributed_nums into nb_of_shades sub-arrays containing
# linearly distributed numbers.
arr_by_shade_rows = linearly_distributed_nums.reshape(
number_of_shades, number_of_distinct_colors_with_multiply_of_shades // number_of_shades)
nb_of_shades, nb_of_distinct_colors_with_mult_of_shades //
nb_of_shades)

# Transpose the above matrix (columns become rows) - as a result each row contains saw tooth with values slightly higher than row above
# Transpose the above matrix (columns become rows) - as a result each
# row contains saw tooth with values slightly higher than row above
arr_by_shade_columns = arr_by_shade_rows.T

# Keep number of saw teeth for later
@@ -55,27 +62,31 @@ def _init_colormap(self, number_of_distinct_colors):
# Flatten the above matrix - join each row into single array
nums_distributed_like_rising_saw = arr_by_shade_columns.reshape(-1)

# HSV colour map is cyclic (https://matplotlib.org/tutorials/colors/colormaps.html#cyclic), we'll use this property
        # HSV colour map is cyclic, we'll use this property
# (https://matplotlib.org/tutorials/colors/colormaps.html#cyclic)
initial_cm = hsv(nums_distributed_like_rising_saw)

lower_partitions_half = number_of_partitions // 2
upper_partitions_half = number_of_partitions - lower_partitions_half

# Modify lower half in such way that colours towards beginning of partition are darker
# First colours are affected more, colours closer to the middle are affected less
lower_half = lower_partitions_half * number_of_shades
# Modify lower half in such way that colours towards beginning of
        # partition are darker. First colours are affected more, colours
# closer to the middle are affected less
lower_half = lower_partitions_half * nb_of_shades
for i in range(3):
initial_cm[0:lower_half, i] *= np.arange(0.2, 1, 0.8/lower_half)

# Modify second half in such way that colours towards end of partition are less intense and brighter
# Colours closer to the middle are affected less, colours closer to the end are affected more
# Modify second half in such way that colours towards end of partition
# are less intense and brighter. Colours closer to the middle are
# affected less, colours closer to the end are affected more
for i in range(3):
for j in range(upper_partitions_half):
modifier = np.ones(
number_of_shades) - initial_cm[lower_half + j * number_of_shades: lower_half + (j + 1) * number_of_shades, i]
modifier = np.ones(nb_of_shades) \
- initial_cm[lower_half + j * nb_of_shades:
lower_half + (j + 1) * nb_of_shades, i]
modifier = j * modifier / upper_partitions_half
initial_cm[lower_half + j * number_of_shades: lower_half +
(j + 1) * number_of_shades, i] += modifier
initial_cm[lower_half + j * nb_of_shades: lower_half +
(j + 1) * nb_of_shades, i] += modifier

return ListedColormap(initial_cm)
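
The function above, adapted from the linked StackOverflow answer, walks the cyclic HSV colormap in a saw-tooth order and darkens or brightens shades so that many bundles remain visually distinguishable. For intuition only, here is a much simpler approximation of the same idea; it samples evenly spaced hues and therefore separates neighbouring colours less well:

import numpy as np
from matplotlib.cm import hsv
from matplotlib.colors import ListedColormap


def simple_distinct_colormap(n_colors: int = 40) -> ListedColormap:
    # Sample the cyclic HSV map at evenly spaced hues. Unlike the saw-tooth
    # and shading scheme above, adjacent colours stay similar, so this is
    # only a rough stand-in, not a replacement for the code in this file.
    hues = np.linspace(0.0, 1.0, n_colors, endpoint=False)
    return ListedColormap(hsv(hues))


cmap = simple_distinct_colormap(40)
print(cmap(0), cmap(39))  # RGBA tuples of the first and last colours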

@@ -186,21 +197,24 @@ def add_data_to_plot(self, data: np.ndarray, labels: List[str]):
self.bundles[DEFAULT_BUNDLE_NAME] = latent_space_streamlines
else:
all_labels = np.unique(labels)
_remaining_indices = np.arange(len(labels))
remaining_indices = np.arange(len(labels))
for label in all_labels:
label_indices = labels[_remaining_indices] == label
label_data = latent_space_streamlines[_remaining_indices][label_indices]
label_indices = labels[remaining_indices] == label
label_data = \
latent_space_streamlines[remaining_indices][label_indices]
label_data = self._resample_max_subset_size(label_data)
self.bundles[label] = label_data

_remaining_indices = _remaining_indices[~label_indices]
remaining_indices = remaining_indices[~label_indices]

if len(_remaining_indices) > 0:
if len(remaining_indices) > 0:
                LOGGER.warning(
                    "Some streamlines were not considered in the bundles, "
                    "some labels are missing.\n"
"Added them to the {} bundle.".format(DEFAULT_BUNDLE_NAME))
self.bundles[DEFAULT_BUNDLE_NAME] = latent_space_streamlines[_remaining_indices]
"Added them to the {} bundle."
.format(DEFAULT_BUNDLE_NAME))
self.bundles[DEFAULT_BUNDLE_NAME] = \
latent_space_streamlines[remaining_indices]
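
The loop above partitions the latent streamlines by bundle label, shrinking the pool of remaining indices as each label is consumed. A standalone sketch of the same grouping pattern on placeholder data:

import numpy as np

latent = np.random.randn(10, 2)                 # placeholder encodings
labels = np.array(['AF', 'CC', 'AF', 'CC', 'AF',
                   'PYT', 'CC', 'AF', 'PYT', 'CC'])

bundles = {}
remaining = np.arange(len(labels))
for label in np.unique(labels):
    mask = labels[remaining] == label           # boolean mask over what's left
    bundles[label] = latent[remaining][mask]
    remaining = remaining[~mask]                # drop the rows just assigned

assert len(remaining) == 0                      # every streamline got a bundle
print({name: data.shape for name, data in bundles.items()})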

def add_bundle_to_plot(self, data: np.ndarray, label: str = '_'):
"""
@@ -244,7 +258,8 @@ def plot(self, epoch: int, figure_name_prefix: str = 'lt_space'):
# So that the warning above is only displayed once.
self.should_call_reset_before_plot = True

# Start by making sure the number of streamlines doesn't exceed the threshold.
# Start by making sure the number of streamlines doesn't
# exceed the threshold.
for (bname, bdata) in self.bundles.items():
if bdata.shape[0] > self.max_subset_size:
self.bundles[bname] = self._resample_max_subset_size(bdata)
@@ -347,8 +362,8 @@ def _resample_max_subset_size(self, data: np.ndarray):
"A max_subset_size of an integer value greater"
"than 0 is required.")

# Only sample if we need to reduce the number of latent streamlines
# to show on the plot.
# Only sample if we need to reduce the number of latent
# streamlines to show on the plot.
if (len(data) > self.max_subset_size):
sample_indices = np.random.choice(
len(data),
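
The truncated call above subsamples each bundle down to max_subset_size before plotting. A self-contained sketch of that kind of subsampling; the function and parameter names here are illustrative, not the repository's:

import numpy as np


def resample_to_max(data: np.ndarray,
                    max_subset_size: int = 1000) -> np.ndarray:
    if max_subset_size <= 0:
        raise ValueError("max_subset_size must be a positive integer.")
    if len(data) <= max_subset_size:
        return data                              # small enough, keep everything
    # Draw row indices without replacement so no streamline is duplicated.
    idx = np.random.choice(len(data), size=max_subset_size, replace=False)
    return data[idx]


print(resample_to_max(np.random.randn(5000, 32), 1000).shape)  # (1000, 32)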
