merge master
arnaudbore committed Sep 13, 2024
2 parents 346c383 + 93cec1e commit 5411a21
Showing 105 changed files with 6,211 additions and 3,204 deletions.
.gitignore (3 changes: 2 additions & 1 deletion)
@@ -71,4 +71,5 @@ target/
*.swo

# dwi_ml stuff
.ipynb_config/
.ipynb_config/
.ipynb_checkpoints/
bash_utilities/scil_score_ismrm_Renauld2023.sh (6 changes: 3 additions & 3 deletions)
@@ -31,7 +31,7 @@ fi


echo '------------- SEGMENTATION ------------'
scil_score_tractogram.py $tractogram $config_file_segmentation $out_dir --no_empty \
scil_tractogram_segment_and_score.py $tractogram $config_file_segmentation $out_dir --no_empty \
--gt_dir $scoring_data --reference $ref --json_prefix tmp_ --no_bbox_check;

echo '------------- Merging CC sub-bundles ------------'
@@ -54,7 +54,7 @@ then
fi

echo '------------- FINAL SCORING ------------'
scil_score_bundles.py -v $config_file_tractometry $out_dir \
--gt_dir $scoring_data --reference $ref --no_bbox_check
scil_bundle_score_many_bundles_one_tractogram.py $config_file_tractometry $out_dir \
--gt_dir $scoring_data --reference $ref --no_bbox_check -v

cat $out_dir/results.json
dwi_ml/data/dataset/multi_subject_containers.py (14 changes: 12 additions & 2 deletions)
@@ -208,6 +208,16 @@ def get_mri_data(self, subj_idx: int, group_idx: int,
Contrary to get_volume_verify_cache, this does not send data to
cache for later use.
Parameters
----------
subj_idx: int
The subject id.
group_idx: int
The volume group idx.
load_it: bool
If data is lazy, get the volume as a LazyMRIData (False) or load it
as non-lazy (if True).
"""
if self.subjs_data_list.is_lazy:
if load_it:
@@ -271,7 +281,6 @@ def load(self, hdf_handle: h5py.File, subj_id=None):
hdf_handle, subj_id, ref_group_info)

# Add subject to the list
logger.debug(" Adding it to the list of subjects.")
subj_idx = self.subjs_data_list.add_subject(subj_data)

# Arrange streamlines
@@ -280,7 +289,6 @@ def load(self, hdf_handle: h5py.File, subj_id=None):
if subj_data.is_lazy:
subj_data.add_handle(hdf_handle)

logger.debug(" Counting streamlines")
for group in range(len(self.streamline_groups)):
subj_sft_data = subj_data.sft_data_list[group]
n_streamlines = len(subj_sft_data)
@@ -292,6 +300,7 @@ def load(self, hdf_handle: h5py.File, subj_id=None):
subj_data.hdf_handle = None

# Arrange final data properties: Concatenate all subjects
logging.debug("All subjects added. Final verifications.")
self.streamline_lengths_mm = \
[np.concatenate(lengths_mm[group], axis=0)
for group in range(len(self.streamline_groups))]
Expand Down Expand Up @@ -484,6 +493,7 @@ def load_data(self, load_training=True, load_validation=True,
self.streamline_groups = poss_strea_groups
self.streamlines_contain_connectivity = contains_connectivity

self.streamline_groups = list(self.streamline_groups)
group_info = (self.volume_groups, self.nb_features,
self.streamline_groups,
self.streamlines_contain_connectivity)
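
To illustrate the get_mri_data signature documented in the hunk above, here is a minimal usage sketch. It assumes an already loaded multi-subject dataset object exposing this method; the variable names and index values are hypothetical and not part of this commit.

    # Hypothetical sketch: calling get_mri_data with the documented parameters.
    # 'dataset' stands for a loaded multi-subject dataset instance (illustrative name).
    volume = dataset.get_mri_data(subj_idx=0, group_idx=0, load_it=True)        # load as non-lazy data
    lazy_volume = dataset.get_mri_data(subj_idx=0, group_idx=0, load_it=False)  # keep as LazyMRIData
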
dwi_ml/data/dataset/single_subject_containers.py (3 changes: 2 additions & 1 deletion)
@@ -110,7 +110,8 @@ def init_single_subject_from_hdf(
subject_mri_data_list.append(subject_mri_group_data)

for group in streamline_groups:
logger.debug(" Loading subject's streamlines")
logger.debug(" Loading streamlines group '{}'"
.format(group))
sft_data = SFTData.init_sft_data_from_hdf_info(
hdf_file[subject_id][group])
subject_sft_data_list.append(sft_data)
dwi_ml/data/dataset/streamline_containers.py (84 changes: 58 additions & 26 deletions)
@@ -46,6 +46,27 @@ def _load_all_streamlines_from_hdf(hdf_group: h5py.Group):
return streamlines


def _load_connectivity_info(hdf_group: h5py.Group):
connectivity_nb_blocs = None
connectivity_labels = None
if 'connectivity_matrix' in hdf_group:
contains_connectivity = True
if 'connectivity_nb_blocs' in hdf_group.attrs:
connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs']
elif 'connectivity_label_volume' in hdf_group:
connectivity_labels = np.asarray(
hdf_group['connectivity_label_volume'], dtype=int)
else:
raise ValueError(
"Information stored in the hdf5 is that it contains a "
"connectivity matrix, but we don't know how it was "
"created. Either 'connectivity_nb_blocs' or "
"'connectivity_labels' should be set.")
else:
contains_connectivity = False
return contains_connectivity, connectivity_nb_blocs, connectivity_labels


class _LazyStreamlinesGetter(object):
def __init__(self, hdf_group):
self.hdf_group = hdf_group
@@ -141,27 +162,38 @@ class SFTDataAbstract(object):
"""
def __init__(self, space_attributes: Tuple, space: Space, origin: Origin,
contains_connectivity: bool,
connectivity_nb_blocs: List):
connectivity_nb_blocs: List = None,
connectivity_labels: np.ndarray = None):
"""
Params
------
group: str
The current streamlines group id, as loaded in the hdf5 file (it
had type "streamlines"). Probabaly 'streamlines'.
The lazy/non-lazy versions will have more parameters, such as the
streamlines, the connectivity_matrix. In the case of the lazy version,
through the LazyStreamlinesGetter.
Parameters
----------
space_attributes: Tuple
The space attributes consist of a tuple:
(affine, dimensions, voxel_sizes, voxel_order)
space: Space
The space from dipy's Space format.
subject_id: str:
The subject's name
origin: Origin
The origin from dipy's Origin format.
contains_connectivity: bool
If true, will search for either the connectivity_nb_blocs or the
connectivity_from_labels information.
connectivity_nb_blocs: List
The information how to recreate the connectivity matrix.
connectivity_labels: np.ndarray
The 3D volume stating how to recreate the labels.
(toDo: Could be managed to be lazy)
"""
self.space_attributes = space_attributes
self.space = space
self.origin = origin
self.is_lazy = None
self.contains_connectivity = contains_connectivity
self.connectivity_nb_blocs = connectivity_nb_blocs
self.connectivity_labels = connectivity_labels

def __len__(self):
raise NotImplementedError
@@ -195,7 +227,7 @@ def get_connectivity_matrix_and_info(self, ind=None):
(_, ref_volume_shape, _, _) = self.space_attributes

return (self._access_connectivity_matrix(ind), ref_volume_shape,
self.connectivity_nb_blocs)
self.connectivity_nb_blocs, self.connectivity_labels)

def _access_connectivity_matrix(self, ind):
raise NotImplementedError
@@ -277,15 +309,14 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
streamlines = _load_all_streamlines_from_hdf(hdf_group)
# Adding non-hidden parameters for nicer later access
lengths_mm = hdf_group['euclidean_lengths']
if 'connectivity_matrix' in hdf_group:
contains_connectivity = True
connectivity_matrix = np.asarray(hdf_group['connectivity_matrix'],
dtype=int)
connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs']

contains_connectivity, connectivity_nb_blocs, connectivity_labels = \
_load_connectivity_info(hdf_group)
if contains_connectivity:
connectivity_matrix = np.asarray(
hdf_group['connectivity_matrix'], dtype=int) # int or bool?
else:
contains_connectivity = False
connectivity_matrix = None
connectivity_nb_blocs = None

space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group)

@@ -296,7 +327,8 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space_attributes=space_attributes,
space=space, origin=origin,
contains_connectivity=contains_connectivity,
connectivity_nb_blocs=connectivity_nb_blocs)
connectivity_nb_blocs=connectivity_nb_blocs,
connectivity_labels=connectivity_labels)

def _get_streamlines_as_list(self, streamline_ids):
if streamline_ids is not None:
@@ -336,22 +368,22 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None):

@classmethod
def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group)
if 'connectivity_matrix' in hdf_group:
contains_connectivity = True
connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs']
else:
contains_connectivity = False
connectivity_nb_blocs = None
space_attributes, space, origin = _load_space_attributes_from_hdf(
hdf_group)

contains_connectivity, connectivity_nb_blocs, connectivity_labels = \
_load_connectivity_info(hdf_group)

streamlines = _LazyStreamlinesGetter(hdf_group)

return cls(streamlines_getter=streamlines,
space_attributes=space_attributes,
space=space, origin=origin,
contains_connectivity=contains_connectivity,
connectivity_nb_blocs=connectivity_nb_blocs)
connectivity_nb_blocs=connectivity_nb_blocs,
connectivity_labels=connectivity_labels)

def _get_streamlines_as_list(self, streamline_ids):
streamlines = self.streamlines_getter.get_array_sequence(streamline_ids)
streamlines = self.streamlines_getter.get_array_sequence(
streamline_ids)
return streamlines
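
The new _load_connectivity_info helper above accepts two HDF5 layouts next to the 'connectivity_matrix' dataset: a 'connectivity_nb_blocs' attribute, or a 'connectivity_label_volume' dataset. Below is a minimal h5py sketch of writing either layout; the file name, group path, shapes and values are illustrative assumptions, not part of this commit.

    import h5py
    import numpy as np

    # Hedged sketch: write a streamlines group in one of the two layouts
    # that _load_connectivity_info recognizes.
    with h5py.File('example.hdf5', 'w') as f:
        grp = f.create_group('subj1/streamlines')
        grp.create_dataset('connectivity_matrix',
                           data=np.zeros((20, 20), dtype=int))

        # Layout 1: the matrix was built from a fixed number of blocs per axis.
        grp.attrs['connectivity_nb_blocs'] = [4, 4, 4]

        # Layout 2 (alternative): the matrix was built from a 3D label volume.
        # grp.create_dataset('connectivity_label_volume',
        #                    data=np.zeros((10, 10, 10), dtype=int))
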