[ENH] Add resample nb points + dps #244

Merged
1 change: 1 addition & 0 deletions .github/workflows/test_package.yml
@@ -25,6 +25,7 @@ jobs:

- name: Install dependencies
run: |
export SETUPTOOLS_USE_DISTUTILS=stdlib
pip install --upgrade pip
pip install pytest
pip install -e .
77 changes: 54 additions & 23 deletions dwi_ml/data/hdf5/hdf5_creation.py
@@ -136,11 +136,17 @@ class HDF5Creator:
def __init__(self, root_folder: Path, out_hdf_filename: Path,
training_subjs: List[str], validation_subjs: List[str],
testing_subjs: List[str], groups_config: dict,
step_size: float = None, compress: float = None,
dps_keys: List[str] = [],
step_size: float = None,
nb_points: int = None,
compress: float = None,
remove_invalid: bool = False,
enforce_files_presence: bool = True,
save_intermediate: bool = False,
intermediate_folder: Path = None):
"""
Params step_size, nb_points and compress are mutually exclusive.

Params
------
root_folder: Path
@@ -154,10 +160,16 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
List of subject names for each data set.
groups_config: dict
Information from json file loaded as a dict.
dps_keys: List[str]
List of keys to keep in data_per_streamline. Default: [].
step_size: float
Step size to resample streamlines. Default: None.
nb_points: int
Number of points per streamline. Default: None.
compress: float
Compress streamlines. Default: None.
remove_invalid: bool
Remove invalid streamlines. Default: False.
enforce_files_presence: bool
If true, will stop if some files are not available for a subject.
Default: True.
@@ -174,8 +186,11 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
self.validation_subjs = validation_subjs
self.testing_subjs = testing_subjs
self.groups_config = groups_config
self.dps_keys = dps_keys
self.step_size = step_size
self.nb_points = nb_points
self.compress = compress
self.remove_invalid = remove_invalid

# Optional
self.save_intermediate = save_intermediate
@@ -188,7 +203,7 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
self._analyse_config_file()

# -------- Performing checks

self._check_streamlines_operations()
# Check that all subjects exist.
logging.debug("Preparing hdf5 creator for \n"
" training subjs {}, \n"
@@ -340,6 +355,19 @@ def flatten_list(a_list):
self.enforce_files_presence,
folder=subj_input_dir)

def _check_streamlines_operations(self):
valid = True
if self.step_size and self.nb_points:
valid = False
elif self.step_size and self.compress:
valid = False
elif self.nb_points and self.compress:
valid = False
if not valid:
raise ValueError(
    "Only one option can be chosen: resampling to step_size, "
    "resampling to nb_points, or compressing.")

def create_database(self):
"""
Generates a hdf5 dataset from a group of subjects. Hdf5 dataset will
@@ -359,6 +387,8 @@ def create_database(self):
hdf_handle.attrs['testing_subjs'] = self.testing_subjs
hdf_handle.attrs['step_size'] = self.step_size if \
self.step_size is not None else 'Not defined by user'
hdf_handle.attrs['nb_points'] = self.nb_points if \
self.nb_points is not None else 'Not defined by user'
hdf_handle.attrs['compress'] = self.compress if \
self.compress is not None else 'Not defined by user'

@@ -598,9 +628,16 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,

if len(sft.data_per_point) > 0:
logging.debug('sft contained data_per_point. Data not kept.')
if len(sft.data_per_streamline) > 0:
logging.debug('sft contained data_per_streamlines. Data not '
'kept.')

for dps_key in self.dps_keys:
if dps_key not in sft.data_per_streamline:
raise ValueError(
"The data_per_streamline key '{}' was not found in "
"the sft. Check your tractogram file.".format(dps_key))

logging.debug(" Include dps \"{}\" in the HDF5.".format(dps_key))
Review comment (Collaborator): Maybe a check that the dps_key exists in the tractogram? Else, error, or else, warning?

streamlines_group.create_dataset('dps_' + dps_key,
data=sft.data_per_streamline[dps_key])

# Accessing private Dipy values, but necessary.
# We need to deconstruct the streamlines into arrays with
@@ -632,6 +669,8 @@ def _process_one_streamline_group(
Reference used to load and send the streamlines in voxel space and
to create final merged SFT. If the file is a .trk, 'same' is used
instead.
remove_invalid : bool
If True, invalid streamlines will be removed.

Returns
-------
@@ -642,11 +681,6 @@
"""
tractograms = self.groups_config[group]['files']

if self.step_size and self.compress:
raise ValueError(
"Only one option can be chosen: either resampling to "
"step_size or compressing, not both.")

# Silencing SFT's logger if our logging is in DEBUG mode, because it
# typically produces a lot of outputs!
set_sft_logger_level('WARNING')
@@ -679,19 +713,12 @@
if self.save_intermediate:
output_fname = self.intermediate_folder.joinpath(
subj_id + '_' + group + '.trk')
logging.debug(' *Saving intermediate streamline group {} '
'into {}.'.format(group, output_fname))
logging.debug(" *Saving intermediate streamline group {} "
"into {}.".format(group, output_fname))
# Note. Do not remove the str below. Does not work well
# with Path.
save_tractogram(final_sft, str(output_fname))

# Removing invalid streamlines
logging.debug(' *Total: {:,.0f} streamlines. Now removing '
'invalid streamlines.'.format(len(final_sft)))
final_sft.remove_invalid_streamlines()
logging.info(" Final number of streamlines: {:,.0f}."
.format(len(final_sft)))

conn_matrix = None
conn_info = None
if 'connectivity_matrix' in self.groups_config[group]:
@@ -735,9 +762,10 @@ def _load_and_process_sft(self, tractogram_file, header):
"We do not support file's type: {}. We only support .trk "
"and .tck files.".format(tractogram_file))
if file_extension == '.trk':
if not is_header_compatible(str(tractogram_file), header):
raise ValueError("Streamlines group is not compatible with "
"volume groups\n ({})"
if header and not is_header_compatible(str(tractogram_file),
header):
raise ValueError("Streamlines group is not compatible "
"with volume groups\n ({})"
.format(tractogram_file))
# overriding given header.
header = 'same'
@@ -748,6 +776,9 @@
sft = load_tractogram(str(tractogram_file), header)

# Resample or compress streamlines
sft = resample_or_compress(sft, self.step_size, self.compress)
sft = resample_or_compress(sft, self.step_size,
self.nb_points,
self.compress,
self.remove_invalid)

return sft
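
Taken together, the hdf5_creation.py changes add two things: optional dps_keys that are copied into the HDF5 (one dataset per key, prefixed with dps_), and a nb_points resampling option that is mutually exclusive with step_size and compress. A minimal usage sketch — the paths, subject id, dps key and per-subject group layout below are assumptions for illustration, not taken from this diff:

import json
from pathlib import Path

import h5py

from dwi_ml.data.hdf5.hdf5_creation import HDF5Creator

with open('config.json') as f:                      # hypothetical config file
    groups_config = json.load(f)

creator = HDF5Creator(Path('dwi_ml_ready'),         # hypothetical root folder
                      Path('my_dataset.hdf5'),      # hypothetical output file
                      training_subjs=['sub-01'], validation_subjs=[],
                      testing_subjs=[], groups_config=groups_config,
                      dps_keys=['commit_weights'],  # hypothetical dps key
                      nb_points=128)                # exclusive with step_size
                                                    # and compress
creator.create_database()

# Reading the stored dps back; 'sub-01/streamlines' is an assumed example
# of a subject/group path.
with h5py.File('my_dataset.hdf5', 'r') as hdf:
    print(hdf.attrs['nb_points'])                   # -> 128
    weights = hdf['sub-01/streamlines/dps_commit_weights'][:]
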
16 changes: 10 additions & 6 deletions dwi_ml/data/hdf5/utils.py
@@ -59,11 +59,11 @@ def add_hdf5_creation_args(p: ArgumentParser):
help="A txt file containing the list of subjects ids to "
"use for training. \n(Can be an empty file.)")
p.add_argument('validation_subjs',
help="A txt file containing the list of subjects ids to use "
"for validation. \n(Can be an empty file.)")
help="A txt file containing the list of subjects ids"
" to use for validation. \n(Can be an empty file.)")
p.add_argument('testing_subjs',
help="A txt file containing the list of subjects ids to use "
"for testing. \n(Can be an empty file.)")
help="A txt file containing the list of subjects ids"
" to use for testing. \n(Can be an empty file.)")

# Optional arguments
p.add_argument('--enforce_files_presence', type=bool, default=True,
@@ -76,9 +76,13 @@
"each subject inside the \nhdf5 folder, in sub-"
"folders named subjid_intermediate.\n"
"(Final concatenated standardized volumes and \n"
"final concatenated resampled/compressed streamlines.)")
"final concatenated resampled/compressed "
"streamlines.)")
p.add_argument('--dps_keys', type=str, nargs='+', default=[],
help="List of keys to keep in data_per_streamline. "
"Default: Empty.")


def add_streamline_processing_args(p: ArgumentParser):
g = p.add_argument_group('Streamlines processing options:')
g = p.add_argument_group('Streamlines processing options')
add_resample_or_compress_arg(g)
25 changes: 20 additions & 5 deletions dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -9,18 +9,33 @@
import numpy as np

from scilpy.tractograms.streamline_operations import \
resample_streamlines_step_size, compress_sft
resample_streamlines_num_points, resample_streamlines_step_size, \
compress_sft


def resample_or_compress(sft, step_size_mm: float = None,
compress: float = None):
nb_points: int = None,
compress: float = None,
remove_invalid: bool = False):
if step_size_mm is not None:
# Note. No matter the chosen space, resampling is done in mm.
logging.debug(" Resampling: {}".format(step_size_mm))
logging.debug(" Resampling (step size): {}mm".format(step_size_mm))
sft = resample_streamlines_step_size(sft, step_size=step_size_mm)
if compress is not None:
logging.debug(" Compressing: {}".format(compress))
elif nb_points is not None:
logging.debug(" Resampling: " +
"{} points per streamline".format(nb_points))
sft = resample_streamlines_num_points(sft, nb_points)
elif compress is not None:
logging.debug(" Compressing: {}".format(compress))
sft = compress_sft(sft, compress)

if remove_invalid:
logging.debug(" Total: {:,.0f} streamlines. Now removing "
"invalid streamlines.".format(len(sft)))
sft.remove_invalid_streamlines()
logging.info(" Final number of streamlines: {:,.0f}."
.format(len(sft)))

return sft


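The updated resample_or_compress applies at most one of the three geometry options (step_size_mm first, then nb_points, then compress) and can now also remove invalid streamlines afterwards. A minimal sketch of the new signature, with a hypothetical tractogram file:

from dipy.io.streamline import load_tractogram

from dwi_ml.data.processing.streamlines.data_augmentation import \
    resample_or_compress

sft = load_tractogram('bundle.trk', 'same')  # hypothetical file

# Resample every streamline to 256 points, then drop invalid ones.
sft = resample_or_compress(sft, nb_points=256, remove_invalid=True)
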
11 changes: 7 additions & 4 deletions dwi_ml/io_utils.py
@@ -2,20 +2,23 @@
import os
from argparse import ArgumentParser

from scilpy.io.utils import add_processes_arg


def add_resample_or_compress_arg(p: ArgumentParser):
p.add_argument("--remove_invalid", action='store_true',
help="If set, remove invalid streamlines.")
g = p.add_mutually_exclusive_group()
g.add_argument(
'--step_size', type=float, metavar='s',
help="Step size to resample the data (in mm). Default: None")
g.add_argument('--nb_points', type=int, metavar='n',
help='Number of points per streamline in the output.'
'Default: None')
g.add_argument(
'--compress', type=float, metavar='r', const=0.01, nargs='?',
dest='compress_th',
help="Compression ratio. Default: None. Default if set: 0.01.\n"
"If neither step_size nor compress are chosen, streamlines "
"will be kept \nas they are.")
"If neither step_size, nb_points nor compress "
"are chosen, \nstreamlines will be kept as they are.")


def add_arg_existing_experiment_path(p: ArgumentParser):
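Because --step_size, --nb_points and --compress sit in the same mutually exclusive group, argparse rejects any combination of them before dwi_ml code runs; --remove_invalid can be combined with any of them. A small sketch of the resulting namespace:

from argparse import ArgumentParser

from dwi_ml.io_utils import add_resample_or_compress_arg

p = ArgumentParser()
add_resample_or_compress_arg(p)

args = p.parse_args(['--nb_points', '128', '--remove_invalid'])
print(args.nb_points, args.remove_invalid, args.compress_th)  # 128 True None

# Mixing two options exits with an argparse error, e.g.:
# p.parse_args(['--step_size', '0.5', '--nb_points', '128'])
# -> error: argument --nb_points: not allowed with argument --step_size
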
13 changes: 8 additions & 5 deletions dwi_ml/models/main_models.py
@@ -17,8 +17,7 @@
prepare_neighborhood_vectors
from dwi_ml.experiment_utils.prints import format_dict_to_str
from dwi_ml.io_utils import add_resample_or_compress_arg
from dwi_ml.models.direction_getter_models import keys_to_direction_getters, \
AbstractDirectionGetterModel
from dwi_ml.models.direction_getter_models import keys_to_direction_getters
from dwi_ml.models.embeddings import (keys_to_embeddings, NNEmbedding,
NoEmbedding)
from dwi_ml.models.utils.direction_getters import add_direction_getter_args
@@ -37,6 +36,7 @@ class MainModelAbstract(torch.nn.Module):
def __init__(self, experiment_name: str,
# Target preprocessing params for the batch loader + tracker
step_size: float = None,
nb_points: int = None,
compress_lines: float = False,
# Other
log_level=logging.root.level):
@@ -74,17 +74,20 @@ def __init__(self, experiment_name: str,

# To tell our batch loader how to resample streamlines during training
# (should also be the step size during tractography).
if step_size and compress_lines:
raise ValueError("You may choose either resampling or compressing,"
"but not both.")
if (step_size and compress_lines) or (step_size and nb_points) \
        or (nb_points and compress_lines):
    raise ValueError("You may choose either resampling (step_size or "
                     "nb_points) or compressing, but not more than one "
                     "of these options.")
elif step_size and step_size <= 0:
raise ValueError("Step size can't be 0 or less!")
elif nb_points and nb_points <= 0:
raise ValueError("Number of points can't be 0 or less!")
# Note. When using
# scilpy.tracking.tools.resample_streamlines_step_size, a warning
# is shown if step_size < 0.1 or > np.max(sft.voxel_sizes), saying
# that the value is suspicious. Not raising the same warnings here
# as you may be wanting to test weird things to understand better
# your model.
self.nb_points = nb_points
self.step_size = step_size
self.compress_lines = compress_lines

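The same exclusivity is enforced again at model construction, so a model built outside the CLI path cannot carry conflicting preprocessing settings. A sketch, assuming MainModelAbstract can be instantiated directly and using a hypothetical experiment name:

from dwi_ml.models.main_models import MainModelAbstract

model = MainModelAbstract('my_experiment', nb_points=128)

# Passing two options raises immediately:
# MainModelAbstract('my_experiment', step_size=0.5, nb_points=128)
# -> ValueError: You may choose either resampling (step_size or nb_points)
#    or compressing, but not more than one of these options.
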
1 change: 1 addition & 0 deletions dwi_ml/testing/testers.py
@@ -109,6 +109,7 @@ def run_model_on_sft(self, sft, compute_loss=False):
The mean eos error per line.
"""
sft = resample_or_compress(sft, self.model.step_size,
self.model.nb_points,
self.model.compress_lines)
sft.to_vox()
sft.to_corner()
1 change: 1 addition & 0 deletions dwi_ml/training/batch_loaders.py
@@ -199,6 +199,7 @@ def _data_augmentation_sft(self, sft):
"the hdf5 dataset. Not compressing again.")
else:
sft = resample_or_compress(sft, self.model.step_size,
self.model.nb_points,
self.model.compress_lines)

# Splitting streamlines
7 changes: 6 additions & 1 deletion scripts_python/dwiml_create_hdf5_dataset.py
@@ -87,7 +87,12 @@ def prepare_hdf5_creator(args):
# Instantiate a creator and perform checks
creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file,
training_subjs, validation_subjs, testing_subjs,
groups_config, args.step_size, args.compress_th,
groups_config,
args.dps_keys,
args.step_size,
args.nb_points,
args.compress_th,
args.remove_invalid,
args.enforce_files_presence,
args.save_intermediate, intermediate_subdir)

4 changes: 3 additions & 1 deletion scripts_python/dwiml_visualize_noise_on_streamlines.py
@@ -68,7 +68,9 @@ def main():
subj_sft_data = subj_data.sft_data_list[streamline_group_idx]
sft = subj_sft_data.as_sft()

sft = resample_or_compress(sft, args.step_size, args.compress_th)
sft = resample_or_compress(sft, args.step_size,
args.nb_points,
args.compress_th)
sft.to_vox()
sft.to_corner()
