Skip to content

Commit

Permalink
Merge pull request #44 from scil-vital/atheb/fix-docker
Browse files Browse the repository at this point in the history
FIX: docker install
  • Loading branch information
AntoineTheb authored Jun 19, 2024
2 parents ddbaa2a + f42e50b commit 7f3e4bc
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 33 deletions.
5 changes: 3 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime
FROM pytorch/pytorch:2.2.0-cuda11.8-cudnn8-runtime

WORKDIR /

Expand All @@ -11,5 +11,6 @@ RUN git clone https://github.com/scil-vital/TrackToLearn.git

WORKDIR /TrackToLearn

RUN pip install Cython numpy packaging
RUN pip install Cython==0.29.* numpy==1.25.* packaging --quiet
RUN pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --extra-index-url https://download.pytorch.org/whl/cu118 --quiet
RUN pip install -e .
8 changes: 7 additions & 1 deletion TrackToLearn/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ def set_sh_order_basis(
sh_order, full_basis = get_sh_order_and_fullness(n_coefs)
sh_order = int(sh_order)

is_legacy = False
if '_legacy' in sh_basis:
sh_basis = sh_basis[:-7]
is_legacy = True

# If SH in full basis, convert them
if full_basis is True:
print('SH coefficients are in "full" basis, only even coefficients '
Expand All @@ -174,6 +179,7 @@ def set_sh_order_basis(
print('SH coefficients are in the {} basis, '
'converting them to {}.'.format(sh_basis, target_basis))
sh = convert_sh_basis(
sh, sphere, input_basis=sh_basis, nbr_processes=1)
sh, sphere, input_basis=sh_basis, nbr_processes=1,
is_input_legacy=is_legacy)

return sh
33 changes: 4 additions & 29 deletions TrackToLearn/environments/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,12 @@
import nibabel as nib
import numpy as np
import torch
from dipy.core.sphere import HemiSphere
from dipy.data import get_sphere
from dipy.direction.peaks import reshape_peaks_for_visualization
from dipy.tracking import utils as track_utils
from dwi_ml.data.processing.volume.interpolation import \
interpolate_volume_in_neighborhood
from dwi_ml.data.processing.space.neighborhood import \
get_neighborhood_vectors_axes
from scilpy.reconst.utils import (find_order_from_nb_coeff,
get_maximas)
from dipy.reconst.shm import sh_to_sf_matrix
from torch.utils.data import DataLoader

from TrackToLearn.datasets.SubjectDataset import SubjectDataset
Expand All @@ -34,9 +29,11 @@

# from dipy.io.utils import get_reference_info


def collate_fn(data):
return data


class BaseEnv(object):
"""
Abstract tracking environment. This class should not be used directly.
Expand Down Expand Up @@ -158,7 +155,7 @@ def load_subject(
self.loader_iter = iter(self.loader)
(sub_id, input_volume, tracking_mask, seeding_mask,
peaks, reference) = next(self.loader_iter)[0]

self.subject_id = sub_id
# Affines
self.reference = reference
Expand Down Expand Up @@ -407,28 +404,6 @@ def _load_files(
npeaks = 5
odf_shape_3d = data.shape[:-1]
peak_dirs = np.zeros((odf_shape_3d + (npeaks, 3)))
peak_values = np.zeros((odf_shape_3d + (npeaks, )))

sphere = HemiSphere.from_sphere(get_sphere("repulsion724")
).subdivide(0)

b_matrix, _ = sh_to_sf_matrix(sphere, find_order_from_nb_coeff(data), "descoteaux07")

for idx in np.argwhere(np.sum(data, axis=-1)):
idx = tuple(idx)
directions, values, indices = get_maximas(data[idx],
sphere, b_matrix,
0.1, 0)
if values.shape[0] != 0:
n = min(npeaks, values.shape[0])
peak_dirs[idx][:n] = directions[:n]
peak_values[idx][:n] = values[:n]

X, Y, Z, N, P = peak_dirs.shape
peak_values = np.divide(peak_values, peak_values[..., 0, None],
out=np.zeros_like(peak_values),
where=peak_values[..., 0, None] != 0)
peak_dirs[...] *= peak_values[..., :, None]
peak_dirs = reshape_peaks_for_visualization(peak_dirs)

# Load rest of volumes
Expand Down Expand Up @@ -467,7 +442,7 @@ def get_action_size(self):
"""

return 3

def get_target_sh_order(self):
""" Returns the target SH order. For tracking, this is based on the hyperparameters.json if it's specified.
Otherwise, it's extracted from the data directly.
Expand Down
2 changes: 1 addition & 1 deletion TrackToLearn/runners/ttl_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
self.max_length = track_dto['max_length']

self.compress = track_dto['compress'] or 0.0
self.sh_basis = track_dto['sh_basis']
self.sh_basis = track_dto['sh_basis'][0]
self.save_seeds = track_dto['save_seeds']

# Tractometer parameters
Expand Down
35 changes: 35 additions & 0 deletions examples/ismrm2015_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"training": {
"ismrm2015": {
"inputs": [
"ismrm2015/fodfs/ismrm2015_fodf.nii.gz"
],
"peaks": "ismrm2015/fodfs/ismrm2015_peaks.nii.gz",
"tracking": "ismrm2015/masks/ismrm2015_wm.nii.gz",
"seeding": "ismrm2015/maps/interface.nii.gz",
"anat": "ismrm2015/anat/ismrm2015_T1.nii.gz"
}
},
"validation": {
"ismrm2015": {
"inputs": [
"ismrm2015/fodfs/ismrm2015_fodf.nii.gz"
],
"peaks": "ismrm2015/fodfs/ismrm2015_peaks.nii.gz",
"tracking": "ismrm2015/masks/ismrm2015_wm.nii.gz",
"seeding": "ismrm2015/maps/interface.nii.gz",
"anat": "ismrm2015/anat/ismrm2015_T1.nii.gz"
}
},
"testing": {
"ismrm2015": {
"inputs": [
"ismrm2015/fodfs/ismrm2015_fodf.nii.gz"
],
"peaks": "ismrm2015/fodfs/ismrm2015_peaks.nii.gz",
"tracking": "ismrm2015/masks/ismrm2015_wm.nii.gz",
"seeding": "ismrm2015/maps/interface.nii.gz",
"anat": "ismrm2015/anat/ismrm2015_T1.nii.gz"
}
}
}

0 comments on commit 7f3e4bc

Please sign in to comment.