Official PyTorch implementation of the paper: Disagreement attention: Let us agree to disagree on computed tomography segmentation.
Schematic of the proposed Mixed-Embedded Disagreement Attention (MEDA)[^1].
The proposed Alternating Deep Supervision (ADSV) framework[^1].
In general terms, the application contains:

- The standard attention gate (`AttentionBlock`) and the novel disagreement-based attention modules (`PureDABlock`, `EmbeddedDABlock` and `MixedEmbeddedDABlock`).
- `XAttentionUNet`: a modified Attention UNet that receives the attention class to be employed as an argument (see the sketch after this list).
- `XAttentionUNet_ADSV`: a modified `XAttentionUNet` that employs the proposed Alternating Deep Supervision (ADSV).
- `UNet2D`, `UNet3D`, Attention UNet, and UNet with grid attention.
- The dataset processors and model managers for the CT-82 and LiTS17 datasets, which have been moved to github.com/giussepi/gtorch_utils/tree/main/gtorch_utils/datasets/segmentation/datasets/ct82 and github.com/giussepi/gtorch_utils/tree/main/gtorch_utils/datasets/segmentation/datasets/lits17, respectively.
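For illustration, here is a minimal sketch of how the attention class is selected (the import paths are assumptions; the constructor arguments mirror the full training example later in this README):

```python
# Sketch only: import paths are assumptions; the constructor arguments
# mirror the ModelMGR training example shown later in this README.
from nns.models import XAttentionUNet                             # hypothetical path
from nns.models.layers.disagreement_attention import intra_model  # hypothetical path

# The attention behaviour is selected solely through da_block_cls
model = XAttentionUNet(
    n_channels=1, n_classes=1, bilinear=False, dsv=True,
    da_block_cls=intra_model.MixedEmbeddedDABlock,  # or PureDABlock, EmbeddedDABlock, AttentionBlock
    # remaining kwargs (batchnorm_cls, init_type, data_dimensions) as in the training example
)
```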
- Clone this repository.
- [OPTIONAL] Create your virtual environment and activate it.
- Install PyTorch 1.10.0 following the instructions provided at pytorch.org/get-started/previous-versions/#v1100.
- Install OpenSlide.
- Install the requirements:

  ```bash
  pip install -r requirements.txt
  ```

- Make a copy of the configuration file, review it thoroughly and update it properly (especially `PROJECT_PATH`, `CT82_SAVING_PATH`, `LITS17_SAVING_PATH` and `LITS17_CONFIG`; the sketch after this list shows the kind of values to set):

  ```bash
  cp settings.py.template settings.py
  ```
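For instance, the relevant entries in `settings.py` could end up looking like this (all paths below are placeholders, not values taken from the template):

```python
# settings.py: placeholder values, adjust them to your own environment
PROJECT_PATH = '/home/<user>/disagreement_attention'  # hypothetical path
CT82_SAVING_PATH = '/data/CT-82/processed'            # hypothetical path
LITS17_SAVING_PATH = '/data/LiTS17/processed'         # hypothetical path
# LITS17_CONFIG selects the liver or lesion configuration; see settings.py.template
```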
The main rule is to run everything from `main.py`; thus, the code for processing the datasets and training the models is provided in `main.py`. Review it carefully, follow the instructions, and only uncomment the code you need to execute.
- Make `get_test_datasets.sh` executable and download the testing datasets:

  ```bash
  chmod +x get_test_datasets.sh
  ./get_test_datasets.sh
  ```
- Make `run_tests.sh` executable and run it:

  ```bash
  chmod +x run_tests.sh
  ./run_tests.sh
  ```
- Make `run_tensorboard.sh` executable and run it:

  ```bash
  chmod +x run_tensorboard.sh
  ./run_tensorboard.sh
  ```
Just open your `settings.py` and set `DEBUG = True`. This will set the log level to debug, and your dataloader will not use workers, so you can use `pdb.set_trace()` without any problem.
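That is, the whole change amounts to a single line (sketch):

```python
# settings.py
DEBUG = True  # log level -> DEBUG; no dataloader workers, so pdb.set_trace() works
```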
Note: Always check the class or function definition to pass the correct parameters and see all the available options.
CT-82 dataset[^2][^3][^4]
The instructions to get and process the dataset are available at github.com/giussepi/gtorch_utils/blob/main/gtorch_utils/datasets/segmentation/datasets/ct82/README.md.
Remember: all the code must always be executed from `main.py`.
Before training: do not forget to configure the `ModelMGR` to employ CT-82:
```python
model = ModelMGR(
    ...
    labels_data=CT82Labels,
    ...
    dataset=CT82Dataset,
    ...
    dataset_kwargs={
        'train_path': settings.CT82_TRAIN_PATH,
        'val_path': settings.CT82_VAL_PATH,
        'test_path': settings.CT82_TEST_PATH,
        'cotraining': settings.COTRAINING,
        'cache': settings.DB_CACHE,
    },
    ...
)
```
LiTS17 dataset[^5]
The instructions to get and process the dataset are available at github.com/giussepi/gtorch_utils/blob/main/gtorch_utils/datasets/segmentation/datasets/lits17/README.md.
IMPORTANT: If the LiTS17 lesion and liver datasets are, or will be, at different locations, update `LITS17_SAVING_PATH` appropriately.
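For example, with hypothetical locations, the value could be toggled like this:

```python
# settings.py: hypothetical paths; switch the active value before processing each dataset
LITS17_SAVING_PATH = '/data/LiTS17/liver'
# LITS17_SAVING_PATH = '/data/LiTS17/lesion'
```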
Remember: all the code must always be executed from `main.py`.
Before training: do not forget to configure the `ModelMGR` to employ the LiTS17 liver or lesion datasets:
- Using LiTS17 Liver:

  ```python
  model = ModelMGR(
      ...
      labels_data=LiTS17OnlyLiverLabels,
      ...
      dataset=LiTS17Dataset,
      ...
      dataset_kwargs={
          'train_path': settings.LITS17_TRAIN_PATH,
          'val_path': settings.LITS17_VAL_PATH,
          'test_path': settings.LITS17_TEST_PATH,
          'cotraining': settings.COTRAINING,
          'cache': settings.DB_CACHE,
      },
      ...
  )
  ```
- Using LiTS17 Lesion crops:

  ```python
  model = ModelMGR(
      ...
      labels_data=LiTS17OnlyLesionLabels,
      ...
      dataset=LiTS17CropDataset,
      ...
      dataset_kwargs={
          'train_path': settings.LITS17_TRAIN_PATH,
          'val_path': settings.LITS17_VAL_PATH,
          'test_path': settings.LITS17_TEST_PATH,
          'cotraining': settings.COTRAINING,
          'cache': settings.DB_CACHE,
      },
      ...
  )
  ```
- Modify your model manager to support this feature using the `CT3DNIfTIMixin`. E.g.:

  ```python
  class CTModelMGR(CT3DNIfTIMixin, ModelMGR):
      pass
  ```
- Replace your old `ModelMGR` with the new one and provide the initial weights:

  ```python
  mymodel = CTModelMGR(
      ...
      ini_checkpoint='<path to your best checkpoint>',
      ...
  )
  ```
- Make the 3D mask prediction:

  ```python
  mymodel.predict('<path to your CT folder>/CT_119.nii.gz')
  ```
- Visualize all the 2D masks:

  ```python
  id_ = '119'
  mymodel.plot_2D_ct_gt_preds(
      ct_path=f'<path to your CT folder>/CT_{id_}.nii.gz',
      gt_path=f'<path to your CT folder>/label_{id_}.nii.gz',
      pred_path=f'pred_CT_{id_}.nii.gz'
  )
  ```
Use the `ModelMGR` to train models and make predictions:
```python
# NOTE: for XAttentionUNet_ADSV employ ADSVModelMGR instead of ModelMGR
class CTModelMGR(CT3DNIfTIMixin, ModelMGR):
    pass


model7 = CTModelMGR(
    # UNet3D ##############################################################
    # model=UNet3D,
    # model_kwargs=dict(feature_scale=1, n_channels=1, n_classes=1, is_batchnorm=True),
    # XAttentionUNet & XAttentionUNet_ADSV ###############################
    model=XAttentionUNet,
    model_kwargs=dict(
        n_channels=1, n_classes=1, bilinear=False, batchnorm_cls=get_batchnormxd_class(),
        init_type=UNet3InitMethod.KAIMING, data_dimensions=settings.DATA_DIMENSIONS,
        da_block_cls=intra_model.MixedEmbeddedDABlock,  # EmbeddedDABlock, PureDABlock, AttentionBlock
        dsv=True,
    ),
    # UNet_Att_DSV ########################################################
    # model=UNet_Att_DSV,
    # model_kwargs=dict(
    #     feature_scale=1, n_classes=1, n_channels=1, is_batchnorm=True,
    #     attention_block_cls=SingleAttentionBlock, data_dimensions=settings.DATA_DIMENSIONS
    # ),
    # UNet_Grid_Attention #################################################
    # model=UNet_Grid_Attention,
    # model_kwargs=dict(
    #     feature_scale=1, n_classes=1, n_channels=1, is_batchnorm=True,
    #     data_dimensions=settings.DATA_DIMENSIONS
    # ),
    # remaining configuration #############################################
    cuda=settings.CUDA,
    multigpus=settings.MULTIGPUS,
    patch_replication_callback=settings.PATCH_REPLICATION_CALLBACK,
    epochs=settings.EPOCHS,
    intrain_val=2,
    optimizer=torch.optim.Adam,
    optimizer_kwargs=dict(lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-6),
    sanity_checks=False,
    labels_data=LiTS17OnlyLiverLabels,  # LiTS17OnlyLesionLabels, # CT82Labels
    data_dimensions=settings.DATA_DIMENSIONS,
    dataset=LiTS17Dataset,  # LiTS17CropDataset, # CT82Dataset
    dataset_kwargs={
        'train_path': settings.LITS17_TRAIN_PATH,  # settings.CT82_TRAIN_PATH
        'val_path': settings.LITS17_VAL_PATH,  # settings.CT82_VAL_PATH
        'test_path': settings.LITS17_TEST_PATH,  # settings.CT82_TEST_PATH
        'cotraining': settings.COTRAINING,
        'cache': settings.DB_CACHE,
    },
    train_dataloader_kwargs={
        'batch_size': settings.TOTAL_BATCH_SIZE, 'shuffle': True, 'num_workers': settings.NUM_WORKERS,
        'pin_memory': False
    },
    testval_dataloader_kwargs={
        'batch_size': settings.TOTAL_BATCH_SIZE, 'shuffle': False, 'num_workers': settings.NUM_WORKERS,
        'pin_memory': False, 'drop_last': True
    },
    lr_scheduler=torch.optim.lr_scheduler.StepLR,
    lr_scheduler_kwargs={'step_size': 250, 'gamma': 0.5},
    lr_scheduler_track=LrShedulerTrack.NO_ARGS,
    criterions=[
        loss_functions.BceDiceLoss(with_logits=True),
    ],
    mask_threshold=0.5,
    metrics=settings.get_metrics(),
    metric_mode=MetricEvaluatorMode.MAX,
    earlystopping_kwargs=dict(min_delta=1e-3, patience=np.inf, metric=True),  # patience=10
    checkpoint_interval=0,
    train_eval_chkpt=False,
    last_checkpoint=True,
    ini_checkpoint='',
    dir_checkpoints=os.path.join(settings.DIR_CHECKPOINTS, 'exp1'),
    tensorboard=False,
    # TODO: there is a bug that appeared once when plotting to disk after a long training;
    # anyway, I can always plot from the checkpoints :)
    plot_to_disk=False,
    plot_dir=settings.PLOT_DIRECTORY,
    memory_print=dict(epochs=settings.EPOCHS//2),
)
model7()
```
Use the `ModelMGR.print_data_logger_summary` method to print a summary of the logged training and validation data:
```python
model = ModelMGR(<your settings>, ini_checkpoint='chkpt_X.pth.tar', dir_checkpoints=settings.DIR_CHECKPOINTS)
model.print_data_logger_summary()
```
The summary will be a table like this one:

| key | Validation | corresponding training value |
|---|---|---|
| Best metric | 0.7495 | 0.7863 |
| Min loss | 0.2170 | 0.1691 |
| Max LR | 0.001 | |
| Min LR | 1e-07 | |
This application employs logzero; thus, some functionalities can print extra data. To enable this, just open your `settings.py` and set `DEBUG = True`. By default, the log level is set to `logging.INFO`.
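As a small sketch of how logzero emits that extra output (the messages are illustrative and not taken from this codebase):

```python
from logzero import logger

logger.debug('only printed when the log level is DEBUG (DEBUG = True in settings.py)')
logger.info('printed at the default logging.INFO level')
```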
You are free to utilise this program or any of its components. If so, please reference the following paper: Disagreement attention: Let us agree to disagree on computed tomography segmentation[^1].
Footnotes

[^1]: Lopez Molina, E. G., Huang, X., & Zhang, Q. (2023). Disagreement attention: Let us agree to disagree on computed tomography segmentation. Biomedical Signal Processing and Control, 84, 104769. https://doi.org/10.1016/j.bspc.2023.104769
[^2]: Roth, H. R., Farag, A., Turkbey, E. B., Lu, L., Liu, J., & Summers, R. M. (2016). Data From Pancreas-CT. The Cancer Imaging Archive. https://doi.org/10.7937/K9/TCIA.2016.tNB1kqBU
[^3]: Roth, H. R., Lu, L., Farag, A., Shin, H.-C., Liu, J., Turkbey, E. B., & Summers, R. M. (2015). DeepOrgan: Multi-level deep convolutional networks for automated pancreas segmentation. In N. Navab et al. (Eds.), MICCAI 2015, Part I, LNCS 9349, pp. 556-564.
[^4]: Clark, K., Vendt, B., Smith, K., Freymann, J., Kirby, J., Koppel, P., Moore, S., Phillips, S., Maffitt, D., Pringle, M., Tarbox, L., & Prior, F. (2013). The Cancer Imaging Archive (TCIA): Maintaining and operating a public information repository. Journal of Digital Imaging, 26(6), 1045-1057. https://doi.org/10.1007/s10278-013-9622-7
[^5]: Bilic, P., et al. (2019). The liver tumor segmentation benchmark (LiTS). arXiv:1901.04056. https://arxiv.org/abs/1901.04056