NF: Visualize the latent space #245

Open
Wants to merge 32 commits into base: master
Changes shown are from 19 of the 32 commits.

Commits (32)
5149706
Latent space visualization integration
levje Sep 20, 2024
c665cf2
Viz latent space each n epochs
levje Sep 27, 2024
8df0c0f
autopep8 pass
levje Sep 27, 2024
7d6de7c
Merge branch 'master' into levje/viz-latent-space
levje Sep 27, 2024
8968c7f
Merge branch 'master' of github.com:scil-vital/dwi_ml into levje/viz-…
levje Oct 2, 2024
cb64063
Fix to cpu
levje Oct 2, 2024
f0973ff
Use bundle index within HDF5 for coloring latent space
levje Oct 2, 2024
7f41de6
Subplots with best epoch: part I
levje Oct 2, 2024
327ca86
Color matching between epochs and save the plot of the best epoch
levje Oct 3, 2024
701737a
Fix best epoch legend and colors
levje Oct 3, 2024
28f3f3d
Cleaup data_per_streamline retrieval in the HDF5
levje Oct 4, 2024
a228c86
Cleanup: part 1
levje Oct 4, 2024
1d82356
Cleanup: part 2
levje Oct 5, 2024
115a7dc
Cleanup: part 3
levje Oct 5, 2024
1b1499c
Cleanup: part 4
levje Oct 6, 2024
39ecd7b
Cleanup: part 5
levje Oct 6, 2024
2e4a191
Fix dps unpacking
levje Oct 7, 2024
2de0a43
Fix tests
levje Oct 7, 2024
7f3d931
Fix pep8
levje Oct 7, 2024
e051f19
Move color generation func into separate file
levje Oct 7, 2024
a6a9a67
Doc update
levje Oct 7, 2024
26ded25
Fix missing import
levje Oct 7, 2024
c51c7a3
input francois code
arnaudbore Oct 7, 2024
f164a50
fix import
arnaudbore Oct 7, 2024
3545c88
pep8
arnaudbore Oct 8, 2024
31c7b71
remove comment - jeremi review
arnaudbore Oct 8, 2024
5f0fae5
set higher s range and v range
arnaudbore Oct 8, 2024
64efc09
Merge pull request #1 from arnaudbore/update_color_class
levje Oct 29, 2024
c9deae2
Merge branch 'master' into levje/viz-latent-space
levje Nov 8, 2024
3ea27cb
incomplete: stick to master
levje Nov 8, 2024
3669453
incomplete: stick to master part 2
levje Nov 8, 2024
6bbc096
fix: dps batch loading and stick to master 3
levje Nov 8, 2024
88 changes: 73 additions & 15 deletions dwi_ml/data/dataset/streamline_containers.py
@@ -8,6 +8,7 @@
import h5py
from nibabel.streamlines import ArraySequence
import numpy as np
from collections import defaultdict


def _load_space_attributes_from_hdf(hdf_group: h5py.Group):
@@ -42,8 +43,9 @@ def _load_all_streamlines_from_hdf(hdf_group: h5py.Group):
streamlines._data = np.array(hdf_group['data'])
streamlines._offsets = np.array(hdf_group['offsets'])
streamlines._lengths = np.array(hdf_group['lengths'])
dps_dict = _load_data_per_streamline(hdf_group)

return streamlines
return streamlines, dps_dict


def _load_connectivity_info(hdf_group: h5py.Group):
@@ -67,6 +69,30 @@ def _load_connectivity_info(hdf_group: h5py.Group):
return contains_connectivity, connectivity_nb_blocs, connectivity_labels


def _load_data_per_streamline(hdf_group,
dps_key: str = None) -> Union[np.ndarray, None]:
dps_dict = defaultdict(list)
# Load only related data key if specified
if 'data_per_streamline' not in hdf_group.keys():
return dps_dict

dps_group = hdf_group['data_per_streamline']
if dps_key is not None:
# Make sure the related data key is in the hdf5 group
if not (dps_key in dps_group.keys()):
raise KeyError("The key '{}' is not in the hdf5 group. Found: {}"
.format(dps_key, dps_group.keys()))

# Load the related data per streamline
dps_dict[dps_key] = dps_group[dps_key][:]
# Otherwise, load every dps.
else:
for dps_key in dps_group.keys():
dps_dict[dps_key] = dps_group[dps_key][:]

return dps_dict


class _LazyStreamlinesGetter(object):
def __init__(self, hdf_group):
self.hdf_group = hdf_group
@@ -81,12 +107,25 @@ def _get_one_streamline(self, idx: int):

def get_array_sequence(self, item=None):
if item is None:
streamlines = _load_all_streamlines_from_hdf(self.hdf_group)
streamlines, data_per_streamline = _load_all_streamlines_from_hdf(
self.hdf_group)
else:
streamlines = ArraySequence()
data_per_streamline = defaultdict(list)

# If data_per_streamline is not in the hdf5, use an empty dict
# so that we don't add anything to the data_per_streamline in the
# following steps.
hdf_dps_group = self.hdf_group['data_per_streamline'] if \
'data_per_streamline' in self.hdf_group.keys() else {}

if isinstance(item, int):
streamlines.append(self._get_one_streamline(item))
data = self._get_one_streamline(item)
streamlines.append(data)

for dps_key in hdf_dps_group.keys():
data_per_streamline[dps_key].append(
hdf_dps_group[dps_key][item])

elif isinstance(item, list) or isinstance(item, np.ndarray):
# Getting a list of value from a hdf5: slow. Uses fancy indexing.
@@ -96,8 +135,13 @@ def get_array_sequence(self, item=None):
# Good also load the whole data and access the indexes after.
# toDo Test speed for the three options.
for i in item:
streamlines.append(self._get_one_streamline(i),
cache_build=True)
data = self._get_one_streamline(i)
streamlines.append(data, cache_build=True)

for dps_key in hdf_dps_group.keys():
data_per_streamline[dps_key].append(
hdf_dps_group[dps_key][item])

streamlines.finalize_append()

elif isinstance(item, slice):
@@ -106,13 +150,17 @@
for offset, length in zip(offsets, lengths):
streamline = self.hdf_group['data'][offset:offset + length]
streamlines.append(streamline, cache_build=True)

for dps_key in hdf_dps_group.keys():
data_per_streamline[dps_key].append(
hdf_dps_group[dps_key][offset:offset + length])
streamlines.finalize_append()

else:
raise ValueError('Item should be either a int, list, '
'np.ndarray or slice but we received {}'
.format(type(item)))
return streamlines
return streamlines, data_per_streamline

@property
def lengths(self):
@@ -160,6 +208,7 @@ class SFTDataAbstract(object):
all information necessary to treat with streamlines: the data itself and
_offset, _lengths, space attributes, etc.
"""

def __init__(self, space_attributes: Tuple, space: Space, origin: Origin,
contains_connectivity: bool,
connectivity_nb_blocs: List = None,
@@ -256,17 +305,19 @@ def as_sft(self,
streamline_ids: Union[List[int], int, slice, None]
List of chosen ids. If None, use all streamlines.
"""
streamlines = self._get_streamlines_as_list(streamline_ids)
streamlines, dps = self._get_streamlines_as_list(streamline_ids)

sft = StatefulTractogram(streamlines, self.space_attributes,
self.space, self.origin)
self.space, self.origin,
data_per_streamline=dps)

return sft


class SFTData(SFTDataAbstract):
def __init__(self, streamlines: ArraySequence,
lengths_mm: List, connectivity_matrix: np.ndarray,
data_per_streamline: np.ndarray = None,
**kwargs):
"""
streamlines: ArraySequence or LazyStreamlinesGetter
@@ -279,6 +330,7 @@ def __init__(self, streamlines: ArraySequence,
self._lengths_mm = lengths_mm
self._connectivity_matrix = connectivity_matrix
self.is_lazy = False
self.data_per_streamline = data_per_streamline

def __len__(self):
return len(self.streamlines)
@@ -306,7 +358,7 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
Creating class instance from the hdf in cases where data is not
loaded yet. Non-lazy = loading the data here.
"""
streamlines = _load_all_streamlines_from_hdf(hdf_group)
streamlines, dps_dict = _load_all_streamlines_from_hdf(hdf_group)
# Adding non-hidden parameters for nicer later access
lengths_mm = hdf_group['euclidean_lengths']

@@ -318,7 +370,8 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
else:
connectivity_matrix = None

space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group)
space_attributes, space, origin = _load_space_attributes_from_hdf(
hdf_group)

# Return an instance of SubjectMRIData instantiated through __init__
# with this loaded data:
@@ -328,13 +381,18 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space=space, origin=origin,
contains_connectivity=contains_connectivity,
connectivity_nb_blocs=connectivity_nb_blocs,
connectivity_labels=connectivity_labels)
connectivity_labels=connectivity_labels,
data_per_streamline=dps_dict)

def _get_streamlines_as_list(self, streamline_ids):
if streamline_ids is not None:
return self.streamlines.__getitem__(streamline_ids)
dps_indexed = {}
for key, value in self.data_per_streamline.items():
dps_indexed[key] = value[streamline_ids]

return self.streamlines.__getitem__(streamline_ids), dps_indexed
else:
return self.streamlines
return self.streamlines, self.data_per_streamline


class LazySFTData(SFTDataAbstract):
@@ -384,6 +442,6 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
connectivity_labels=connectivity_labels)

def _get_streamlines_as_list(self, streamline_ids):
streamlines = self.streamlines_getter.get_array_sequence(
streamlines, dps = self.streamlines_getter.get_array_sequence(
streamline_ids)
return streamlines
return streamlines, dps
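For context, a minimal sketch of how the new return path could be consumed downstream; `subject_sft_data` (an SFTData instance) and the dps key 'bundle_index' are illustrative assumptions, not part of this diff:

# The dps dictionary loaded from the HDF5 is now forwarded to the
# StatefulTractogram constructor by as_sft(), so it can be read directly
# from the resulting tractogram.
sft = subject_sft_data.as_sft(streamline_ids=[0, 5, 12])
bundle_ids = sft.data_per_streamline['bundle_index']  # assumed key; one value per selected streamline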
11 changes: 8 additions & 3 deletions dwi_ml/data/hdf5/hdf5_creation.py
@@ -133,6 +133,7 @@ class HDF5Creator:
See the doc for an example of config file.
https://dwi-ml.readthedocs.io/en/latest/config_file.html
"""

def __init__(self, root_folder: Path, out_hdf_filename: Path,
training_subjs: List[str], validation_subjs: List[str],
testing_subjs: List[str], groups_config: dict,
@@ -629,15 +630,19 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
if len(sft.data_per_point) > 0:
logging.debug('sft contained data_per_point. Data not kept.')

dps_group = streamlines_group.create_group('data_per_streamline')

for dps_key in self.dps_keys:
if dps_key not in sft.data_per_streamline:
raise ValueError(
"The data_per_streamline key '{}' was not found in "
"the sft. Check your tractogram file.".format(dps_key))

logging.debug(" Include dps \"{}\" in the HDF5.".format(dps_key))
streamlines_group.create_dataset('dps_' + dps_key,
data=sft.data_per_streamline[dps_key])
logging.debug(
" Include dps \"{}\" in the HDF5.".format(dps_key))

dps_group.create_dataset(
dps_key, data=sft.data_per_streamline[dps_key])

# Accessing private Dipy values, but necessary.
# We need to deconstruct the streamlines into arrays with
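With this change, each requested dps key is stored as a dataset inside a dedicated 'data_per_streamline' sub-group rather than as a 'dps_'-prefixed dataset. A hedged sketch of reading that layout back (the file name and group path below are hypothetical):

import h5py

with h5py.File('dataset.hdf5', 'r') as hdf:
    # Hypothetical path: subject group -> streamline group -> dps sub-group.
    dps_group = hdf['subj-01/target_streamlines/data_per_streamline']
    for key in dps_group.keys():
        values = dps_group[key][:]  # loads one row per streamline as a numpy array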
20 changes: 17 additions & 3 deletions dwi_ml/models/projects/ae_models.py
@@ -18,6 +18,7 @@ class ModelAE(MainModelAbstract):
deterministic (3D vectors) or probabilistic (based on probability
distribution parameters).
"""

def __init__(self,
experiment_name: str,
step_size: float = None,
@@ -35,6 +36,7 @@ def __init__(self,
self.latent_space_dims = 32

self.pad = torch.nn.ReflectionPad1d(1)
self.post_encoding_hooks = []

def pre_pad(m):
return torch.nn.Sequential(self.pad, m)
@@ -104,6 +106,7 @@ def pre_pad(m):

def forward(self,
input_streamlines: List[torch.tensor],
data_per_streamline: dict = None
):
"""Run the model on a batch of sequences.

@@ -121,12 +124,20 @@
`get_tracking_directions()`.
"""

x = self.decode(self.encode(input_streamlines))
encoded = self.encode(input_streamlines)

for hook in self.post_encoding_hooks:
hook(encoded, data_per_streamline)

x = self.decode(encoded)
return x

def encode(self, x):
# x: list of tensors
x = torch.stack(x)
# X input shape is (batch_size, nb_points, 3)
if isinstance(x, list):
x = torch.stack(x)

# input of the network should be (N, 3, nb_points)
x = torch.swapaxes(x, 1, 2)

h1 = F.relu(self.encod_conv1(x))
@@ -171,3 +182,6 @@ def compute_loss(self, model_outputs, targets, average_results=True):
reconstruction_loss = torch.nn.MSELoss(reduction="sum")
mse = reconstruction_loss(model_outputs, targets)
return mse, 1

def register_hook_post_encoding(self, hook):
self.post_encoding_hooks.append(hook)
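A minimal sketch of how the new post-encoding hook could be used to collect latent vectors for visualization, assuming `model` is a ModelAE instance; the callback body and the 'bundle_index' dps key are illustrative assumptions:

latent_points = []
bundle_labels = []

def collect_latent(encoded, data_per_streamline):
    # Detach and move to CPU so the batch's latent vectors can be
    # accumulated for an end-of-epoch scatter or t-SNE plot.
    latent_points.append(encoded.detach().cpu())
    if data_per_streamline and 'bundle_index' in data_per_streamline:
        bundle_labels.append(data_per_streamline['bundle_index'].cpu())

model.register_hook_post_encoding(collect_latent)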
4 changes: 3 additions & 1 deletion dwi_ml/models/projects/learn2track_model.py
@@ -227,6 +227,7 @@ def computed_params_for_display(self):

def forward(self, x: List[torch.tensor],
input_streamlines: List[torch.tensor] = None,
data_per_streamline: List[torch.tensor] = {},
hidden_recurrent_states: List = None, return_hidden=False,
point_idx: int = None):
"""Run the model on a batch of sequences.
Expand Down Expand Up @@ -284,7 +285,8 @@ def forward(self, x: List[torch.tensor],
unsorted_indices = invert_permutation(sorted_indices)
x = [x[i] for i in sorted_indices]
if input_streamlines is not None:
input_streamlines = [input_streamlines[i] for i in sorted_indices]
input_streamlines = [input_streamlines[i]
for i in sorted_indices]

# ==== 0. Previous dirs.
n_prev_dirs = None
7 changes: 6 additions & 1 deletion dwi_ml/models/projects/transformer_models.py
@@ -117,6 +117,7 @@ class AbstractTransformerModel(ModelWithNeighborhood, ModelWithDirectionGetter,
https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
the embedding probably adapts to leave place for the positional encoding.
"""

def __init__(self,
experiment_name: str,
# Target preprocessing params for the batch loader + tracker
@@ -358,7 +359,9 @@ def _prepare_masks(self, unpadded_lengths, use_padding, batch_max_len):
return mask_future, mask_padding

def forward(self, inputs: List[torch.tensor],
input_streamlines: List[torch.tensor] = None):
input_streamlines: List[torch.tensor] = None,
data_per_streamline: Union[List[torch.tensor],
List[np.ndarray]] = None):
"""
Params
------
@@ -823,6 +826,7 @@ class OriginalTransformerModel(AbstractTransformerModelWithTarget):
emb_choice_x

"""

def __init__(self, input_embedded_size, n_layers_d: int, **kw):
"""
d_model = input_embedded_size = target_embedded_size.
@@ -964,6 +968,7 @@ class TransformerSrcAndTgtModel(AbstractTransformerModelWithTarget):
[ emb_choice_x ; emb_choice_y ]

"""

def __init__(self, **kw):
"""
No additional params. d_model = input size + target size.
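Both recurrent and transformer models now accept the batch's dps in forward(); a hedged call sketch (variable names are illustrative, and only models that use it, such as the autoencoder's post-encoding hooks, consume the extra argument):

outputs = model(batch_inputs, input_streamlines=batch_streamlines,
                data_per_streamline=batch_dps)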
19 changes: 15 additions & 4 deletions dwi_ml/training/batch_loaders.py
@@ -58,6 +58,16 @@
logger = logging.getLogger('batch_loader_logger')


def _dps_to_tensors(dps: dict, device='cpu'):
"""
Convert a list of DPS to a list of tensors.
"""
dps_tensors = {}
for key, value in dps.items():
dps_tensors[key] = torch.tensor(value, device=device)
return dps_tensors


class DWIMLStreamlinesBatchLoader:
def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
streamline_group_name: str, rng: int,
@@ -157,7 +167,7 @@ def params_for_checkpoint(self):
'noise_gaussian_size_forward': self.noise_gaussian_size_forward,
'noise_gaussian_size_loss': self.noise_gaussian_size_loss,
'reverse_ratio': self.reverse_ratio,
'split_ratio': self.split_ratio,
'split_ratio': self.split_ratio
}
return params

@@ -323,8 +333,9 @@ def load_batch_streamlines(
batch_streamlines.extend(sft.streamlines)

batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines]
data_per_streamline = _dps_to_tensors(sft.data_per_streamline)

return batch_streamlines, final_s_ids_per_subj
return batch_streamlines, final_s_ids_per_subj, data_per_streamline

def load_batch_connectivity_matrices(
self, streamline_ids_per_subj: Dict[int, slice]):
@@ -447,8 +458,8 @@ def load_batch_inputs(self, batch_streamlines: List[torch.tensor],
# because in load_batch, we use sft.to_vox and sft.to_corner
# before adding streamline to batch.
subbatch_x_data = self.model.prepare_batch_one_input(
streamlines, self.context_subset, subj,
self.input_group_idx)
streamlines, self.context_subset, subj,
self.input_group_idx)

batch_x_data.extend(subbatch_x_data)

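A small usage sketch of the new `_dps_to_tensors` helper and the extra return value of load_batch_streamlines(); keys and values below are illustrative:

import numpy as np
from dwi_ml.training.batch_loaders import _dps_to_tensors

dps = {'bundle_index': np.array([0, 0, 2]), 'score': np.array([0.2, 0.9, 0.5])}
dps_tensors = _dps_to_tensors(dps)  # each value becomes a torch.Tensor on 'cpu'

# load_batch_streamlines() now returns the dps dict as a third element:
# batch_streamlines, final_s_ids_per_subj, data_per_streamline = \
#     batch_loader.load_batch_streamlines(streamline_ids_per_subj)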