NF: Visualize the latent space #245

Status: Open

Wants to merge 32 commits into base: master. The changes shown below are from all 32 commits.

Commits
5149706  Latent space visualization integration (levje, Sep 20, 2024)
c665cf2  Viz latent space each n epochs (levje, Sep 27, 2024)
8df0c0f  autopep8 pass (levje, Sep 27, 2024)
7d6de7c  Merge branch 'master' into levje/viz-latent-space (levje, Sep 27, 2024)
8968c7f  Merge branch 'master' of github.com:scil-vital/dwi_ml into levje/viz-… (levje, Oct 2, 2024)
cb64063  Fix to cpu (levje, Oct 2, 2024)
f0973ff  Use bundle index within HDF5 for coloring latent space (levje, Oct 2, 2024)
7f41de6  Subplots with best epoch: part I (levje, Oct 2, 2024)
327ca86  Color matching between epochs and save the plot of the best epoch (levje, Oct 3, 2024)
701737a  Fix best epoch legend and colors (levje, Oct 3, 2024)
28f3f3d  Cleaup data_per_streamline retrieval in the HDF5 (levje, Oct 4, 2024)
a228c86  Cleanup: part 1 (levje, Oct 4, 2024)
1d82356  Cleanup: part 2 (levje, Oct 5, 2024)
115a7dc  Cleanup: part 3 (levje, Oct 5, 2024)
1b1499c  Cleanup: part 4 (levje, Oct 6, 2024)
39ecd7b  Cleanup: part 5 (levje, Oct 6, 2024)
2e4a191  Fix dps unpacking (levje, Oct 7, 2024)
2de0a43  Fix tests (levje, Oct 7, 2024)
7f3d931  Fix pep8 (levje, Oct 7, 2024)
e051f19  Move color generation func into separate file (levje, Oct 7, 2024)
a6a9a67  Doc update (levje, Oct 7, 2024)
26ded25  Fix missing import (levje, Oct 7, 2024)
c51c7a3  input francois code (arnaudbore, Oct 7, 2024)
f164a50  fix import (arnaudbore, Oct 7, 2024)
3545c88  pep8 (arnaudbore, Oct 8, 2024)
31c7b71  remove comment - jeremi review (arnaudbore, Oct 8, 2024)
5f0fae5  set higher s range and v range (arnaudbore, Oct 8, 2024)
64efc09  Merge pull request #1 from arnaudbore/update_color_class (levje, Oct 29, 2024)
c9deae2  Merge branch 'master' into levje/viz-latent-space (levje, Nov 8, 2024)
3ea27cb  incomplete: stick to master (levje, Nov 8, 2024)
3669453  incomplete: stick to master part 2 (levje, Nov 8, 2024)
6bbc096  fix: dps batch loading and stick to master 3 (levje, Nov 8, 2024)

Files changed

23 changes: 20 additions & 3 deletions dwi_ml/models/projects/ae_models.py
@@ -35,6 +35,7 @@ def __init__(self,
self.latent_space_dims = 32

self.pad = torch.nn.ReflectionPad1d(1)
self.post_encoding_hooks = []

def pre_pad(m):
return torch.nn.Sequential(self.pad, m)
@@ -104,6 +105,7 @@ def pre_pad(m):

def forward(self,
input_streamlines: List[torch.tensor],
data_per_streamline: dict = None
):
"""Run the model on a batch of sequences.

@@ -113,6 +115,10 @@ def forward(self,
Batch of streamlines. Only used if previous directions are added to
the model. Used to compute directions; its last point will not be
used.
data_per_streamline: dict of lists, optional
Dictionary containing additional data for each streamline. Each
key is the name of a data type, and each value is a list of length
`len(input_streamlines)` containing the data for each streamline.

Returns
-------
@@ -121,12 +127,20 @@
`get_tracking_directions()`.
"""

x = self.decode(self.encode(input_streamlines))
encoded = self.encode(input_streamlines)

for hook in self.post_encoding_hooks:
hook(encoded, data_per_streamline)

x = self.decode(encoded)
return x

def encode(self, x):
# x: list of tensors
x = torch.stack(x)
# X input shape is (batch_size, nb_points, 3)
if isinstance(x, list):
x = torch.stack(x)

# input of the network should be (N, 3, nb_points)
x = torch.swapaxes(x, 1, 2)

h1 = F.relu(self.encod_conv1(x))
@@ -171,3 +185,6 @@ def compute_loss(self, model_outputs, targets, average_results=True):
reconstruction_loss = torch.nn.MSELoss(reduction="sum")
mse = reconstruction_loss(model_outputs, targets)
return mse, 1

def register_hook_post_encoding(self, hook):
self.post_encoding_hooks.append(hook)
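
For context (not part of the diff): each hook registered through register_hook_post_encoding is called inside forward() with the latent tensor and the batch's data_per_streamline dict. Below is a minimal, hypothetical sketch of such a hook; the 'bundle_index' key and all variable names are illustrative, not taken from the repository.

import torch

collected = []


def collect_latent(encoded: torch.Tensor, data_per_streamline: dict):
    # Hypothetical hook: keep a CPU copy of the latent codes and, if present,
    # the per-streamline bundle labels, for later plotting.
    labels = None
    if data_per_streamline and 'bundle_index' in data_per_streamline:
        labels = data_per_streamline['bundle_index'].squeeze(1)
    collected.append((encoded.detach().cpu(), labels))

# With `model` an instance of the autoencoder defined above:
#     model.register_hook_post_encoding(collect_latent)
#     model(input_streamlines, data_per_streamline)  # hook runs right after encode()
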
5 changes: 5 additions & 0 deletions dwi_ml/models/projects/learn2track_model.py
@@ -227,6 +227,7 @@ def computed_params_for_display(self):

def forward(self, x: List[torch.tensor],
input_streamlines: List[torch.tensor] = None,
data_per_streamline: dict = None,
hidden_recurrent_states: List = None, return_hidden=False,
point_idx: int = None):
"""Run the model on a batch of sequences.
@@ -243,6 +244,10 @@ def forward(self, x: List[torch.tensor],
Batch of streamlines. Only used if previous directions are added to
the model. Used to compute directions; its last point will not be
used.
data_per_streamline: dict of lists, optional
Dictionary containing additional data for each streamline. Each
key is the name of a data type, and each value is a list of length
`len(input_streamlines)` containing the data for each streamline.
hidden_recurrent_states : list[states]
The current hidden states of the (stacked) RNN model.
return_hidden: bool
8 changes: 6 additions & 2 deletions dwi_ml/models/projects/transformer_models.py
@@ -358,7 +358,8 @@ def _prepare_masks(self, unpadded_lengths, use_padding, batch_max_len):
return mask_future, mask_padding

def forward(self, inputs: List[torch.tensor],
input_streamlines: List[torch.tensor] = None):
input_streamlines: List[torch.tensor] = None,
data_per_streamline: dict = None):
"""
Params
------
@@ -376,7 +377,10 @@ def forward(self, inputs: List[torch.tensor],
adequately masked to hide future positions. The last direction is
not used.
- As target during training. The whole sequence is used.

data_per_streamline: dict of lists, optional
Dictionary containing additional data for each streamline. Each
key is the name of a data type, and each value is a list of length
`len(input_streamlines)` containing the data for each streamline.
Returns
-------
output: Tensor,
20 changes: 18 additions & 2 deletions dwi_ml/training/batch_loaders.py
@@ -58,6 +58,16 @@
logger = logging.getLogger('batch_loader_logger')


def _dps_to_tensors(dps: dict, device='cpu'):
"""
Convert a dict of data_per_streamline (DPS) lists to a dict of tensors.
"""
dps_tensors = {}
for key, value in dps.items():
dps_tensors[key] = torch.tensor(value, device=device)
return dps_tensors


class DWIMLStreamlinesBatchLoader:
def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
streamline_group_name: str, rng: int,
@@ -157,7 +167,7 @@ def params_for_checkpoint(self):
'noise_gaussian_size_forward': self.noise_gaussian_size_forward,
'noise_gaussian_size_loss': self.noise_gaussian_size_loss,
'reverse_ratio': self.reverse_ratio,
'split_ratio': self.split_ratio,
'split_ratio': self.split_ratio
}
return params

@@ -292,6 +302,7 @@ def load_batch_streamlines(
# the loaded, processed streamlines, not to the ids in the hdf5 file.
final_s_ids_per_subj = defaultdict(slice)
batch_streamlines = []
batch_dps = defaultdict(list)
for subj, s_ids in streamline_ids_per_subj:
logger.debug(
" Data loader: Processing data preparation for "
@@ -322,9 +333,14 @@
sft.to_corner()
batch_streamlines.extend(sft.streamlines)

# Add data per streamline for the batch elements
for key, value in sft.data_per_streamline.items():
batch_dps[key].extend(value)

batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines]
data_per_streamline = _dps_to_tensors(batch_dps)

return batch_streamlines, final_s_ids_per_subj
return batch_streamlines, final_s_ids_per_subj, data_per_streamline

def load_batch_connectivity_matrices(
self, streamline_ids_per_subj: Dict[int, slice]):
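
To illustrate the shape of the new third return value of load_batch_streamlines, here is a minimal, hypothetical sketch of how accumulated per-streamline data becomes tensors, mirroring _dps_to_tensors above; the 'bundle_index' key and the values are made up.

from collections import defaultdict

import torch

# Two streamlines, each tagged with a bundle index in its data_per_streamline.
batch_dps = defaultdict(list)
batch_dps['bundle_index'].extend([[3], [7]])

dps_tensors = {key: torch.tensor(value) for key, value in batch_dps.items()}
print(dps_tensors['bundle_index'].shape)  # torch.Size([2, 1])
# squeeze(1), as done by the trainer below, then gives one label per streamline.
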
110 changes: 110 additions & 0 deletions dwi_ml/training/projects/ae_trainer.py
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
import logging
import os
from typing import Union, List

from dwi_ml.models.main_models import MainModelAbstract
from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader
from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
from dwi_ml.training.trainers import DWIMLAbstractTrainer
from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer

LOGGER = logging.getLogger(__name__)


def parse_bundle_mapping(bundles_mapping_file: str = None):
if bundles_mapping_file is None:
return None

with open(bundles_mapping_file, 'r') as f:
bundle_mapping = {}
for line in f:
bundle_name, bundle_number = line.strip().split()
bundle_mapping[int(bundle_number)] = bundle_name
return bundle_mapping


class TrainerWithBundleDPS(DWIMLAbstractTrainer):

def __init__(self,
model: MainModelAbstract, experiments_path: str,
experiment_name: str, batch_sampler: DWIMLBatchIDSampler,
batch_loader: DWIMLStreamlinesBatchLoader,
learning_rates: Union[List, float] = None,
weight_decay: float = 0.01,
optimizer: str = 'Adam', max_epochs: int = 10,
max_batches_per_epoch_training: int = 1000,
max_batches_per_epoch_validation: Union[int, None] = 1000,
patience: int = None, patience_delta: float = 1e-6,
nb_cpu_processes: int = 0, use_gpu: bool = False,
clip_grad: float = None,
comet_workspace: str = None, comet_project: str = None,
from_checkpoint: bool = False, log_level=logging.root.level,
viz_latent_space: bool = False, color_by: str = None,
bundles_mapping_file: str = None,
max_viz_subset_size: int = 1000):

super().__init__(model, experiments_path, experiment_name,
batch_sampler, batch_loader, learning_rates,
weight_decay, optimizer, max_epochs,
max_batches_per_epoch_training,
max_batches_per_epoch_validation, patience,
patience_delta, nb_cpu_processes, use_gpu,
clip_grad, comet_workspace, comet_project,
from_checkpoint, log_level)

self.color_by = color_by
self.viz_latent_space = viz_latent_space
if self.viz_latent_space:
# Setup to visualize latent space
save_dir = os.path.join(
experiments_path, experiment_name, 'latent_space_plots')
os.makedirs(save_dir, exist_ok=True)

bundle_mapping = parse_bundle_mapping(bundles_mapping_file)
self.ls_viz = BundlesLatentSpaceVisualizer(
save_dir,
prefix_numbering=True,
max_subset_size=max_viz_subset_size,
bundle_mapping=bundle_mapping)
self.warning_printed = False

# Register what to do post encoding.
def handle_latent_encodings(encoding, data_per_streamline):
# Only accumulate data during training
if self.model.context != 'training':
return

if self.color_by is None:
bundle_index = None
elif self.color_by not in data_per_streamline.keys():
if not self.warning_printed:
LOGGER.warning(
f"Coloring by {self.color_by} not found in "
"data_per_streamline.")
self.warning_printed = True
bundle_index = None
else:
bundle_index = \
data_per_streamline[self.color_by].squeeze(1)

self.ls_viz.add_data_to_plot(encoding, labels=bundle_index)
# Execute the above function within the model's forward().
model.register_hook_post_encoding(handle_latent_encodings)

# Plot the latent space after each best epoch.
# Called after running training & validation epochs.
self.best_epoch_monitor.register_on_best_epoch_hook(
self.ls_viz.plot)

def train_one_epoch(self, epoch):
if self.viz_latent_space:
# Before starting another training epoch, make sure the accumulated
# data is cleared. We have to do it here: the on-best-epoch hook runs
# after the validation epoch, so clearing there would be too late, and
# epochs that do not become the new best still need their data cleared.
self.ls_viz.reset_data()

super().train_one_epoch(epoch)
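
As a usage note, parse_bundle_mapping expects one "<bundle_name> <bundle_number>" pair per line. A minimal sketch follows, assuming the module added above is importable as dwi_ml.training.projects.ae_trainer; the file name and bundle labels are purely illustrative.

from dwi_ml.training.projects.ae_trainer import parse_bundle_mapping

# Hypothetical mapping file: bundle name followed by its integer label.
with open('bundles_mapping.txt', 'w') as f:
    f.write('AF_left 0\nAF_right 1\nCST_left 2\n')

mapping = parse_bundle_mapping('bundles_mapping.txt')
print(mapping)  # {0: 'AF_left', 1: 'AF_right', 2: 'CST_left'}
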
9 changes: 5 additions & 4 deletions dwi_ml/training/trainers.py
@@ -1015,7 +1015,7 @@ def run_one_batch(self, data):
"""
# Data interpolation has not been done yet. GPU computations are done
# here in the main thread.
targets, ids_per_subj = data
targets, ids_per_subj, data_per_streamline = data

# Dataloader always works on CPU. Sending to right device.
# (model is already moved).
@@ -1037,7 +1037,7 @@
# but ok, shouldn't be too heavy. Easier to deal with multiple
# projects' requirements by sending whole streamlines rather
# than only directions.
model_outputs = self.model(streamlines_f)
model_outputs = self.model(streamlines_f, data_per_streamline)
del streamlines_f

logger.debug('*** Computing loss')
@@ -1150,7 +1150,7 @@ def run_one_batch(self, data):
"""
# Data interpolation has not been done yet. GPU computations are done
# here in the main thread.
targets, ids_per_subj = data
targets, ids_per_subj, data_per_streamline = data

# Dataloader always works on CPU. Sending to right device.
# (model is already moved).
@@ -1178,7 +1178,8 @@
# (batch loader will do it depending on training / valid)
streamlines_f = self.batch_loader.add_noise_streamlines_forward(
streamlines_f, self.device)
model_outputs = self.model(batch_inputs, streamlines_f)
model_outputs = self.model(
batch_inputs, streamlines_f, data_per_streamline)
del streamlines_f

logger.debug('*** Computing loss')
2 changes: 1 addition & 1 deletion dwi_ml/training/trainers_withGV.py
@@ -242,7 +242,7 @@ def gv_phase_one_batch(self, data, compute_all_scores=False):
seeds and first few segments. Expected results are the batch's
validation streamlines.
"""
real_lines, ids_per_subj = data
real_lines, ids_per_subj, data_per_streamline = data

# Possibly sending again to GPU even if done in the local loss
# computation, but easier with current implementation.
10 changes: 10 additions & 0 deletions dwi_ml/training/utils/monitoring.py
@@ -158,6 +158,14 @@ def __init__(self, name, patience: int, patience_delta: float = 1e-6):
self.best_value = None
self.best_epoch = None
self.n_bad_epochs = None
self.on_best_epoch_hooks = []

def register_on_best_epoch_hook(self, hook):
self.on_best_epoch_hooks.append(hook)

def _call_on_best_epoch_hooks(self, new_best_epoch):
for hook in self.on_best_epoch_hooks:
hook(new_best_epoch)

def update(self, loss, epoch):
"""
@@ -178,12 +186,14 @@ def update(self, loss, epoch):
self.best_value = loss
self.best_epoch = epoch
self.n_bad_epochs = 0
self._call_on_best_epoch_hooks(epoch)
return False
elif loss < self.best_value - self.min_eps:
# Improving from at least eps.
self.best_value = loss
self.best_epoch = epoch
self.n_bad_epochs = 0
self._call_on_best_epoch_hooks(epoch)
return False
else:
# Not improving enough
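
For context, the hooks registered here fire whenever update() records a new best epoch. A minimal sketch, assuming the monitor class in this file is named BestEpochMonitor (the class name is not shown in the hunk above and is an assumption); the loss values are made up.

# The class name BestEpochMonitor is assumed; constructor arguments follow the
# __init__ shown in the hunk above.
from dwi_ml.training.utils.monitoring import BestEpochMonitor

monitor = BestEpochMonitor(name='validation loss', patience=20)
monitor.register_on_best_epoch_hook(
    lambda epoch: print(f'New best epoch: {epoch}'))

monitor.update(1.0, epoch=0)  # first recorded value becomes the best: hook fires
monitor.update(0.5, epoch=1)  # clear improvement: hook fires again
monitor.update(0.9, epoch=2)  # no improvement: hook is not called
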
5 changes: 3 additions & 2 deletions dwi_ml/unit_tests/utils/data_and_models_for_tests.py
@@ -84,7 +84,7 @@ def compute_loss(self, model_outputs, target_streamlines=None,
else:
return torch.zeros(n, device=self.device), 1

def forward(self, inputs: list, streamlines):
def forward(self, inputs: list, streamlines, data_per_streamline):
# Not using streamlines. Pretending to use inputs.
_ = self.fake_parameter
regressed_dir = torch.as_tensor([1., 1., 1.])
@@ -143,7 +143,8 @@ def get_tracking_directions(self, regressed_dirs, algo,
raise ValueError("'algo' should be 'det' or 'prob'.")

def forward(self, inputs: List[torch.tensor],
target_streamlines: List[torch.tensor]):
target_streamlines: List[torch.tensor],
data_per_streamline: dict = None):
# Previous dirs
if self.nb_previous_dirs > 0:
target_dirs = compute_directions(target_streamlines)