Skip to content

Commit

Permalink
Rearchitect reg_measure to handle forward and backward similarity mea…
Browse files Browse the repository at this point in the history
…sure values #92
  • Loading branch information
onurulgen committed Jul 31, 2023
1 parent 76efc9f commit 4a98c08
Show file tree
Hide file tree
Showing 20 changed files with 1,071 additions and 1,301 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
295
296
4 changes: 2 additions & 2 deletions reg-apps/reg_tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ int main(int argc, char **argv)
outputImage->data = malloc(outputImage->nvox * outputImage->nbyper);
// Compute the MIND descriptor
int *mask = (int *)calloc(image->nvox, sizeof(int));
GetMINDImageDescriptor(image, outputImage, mask, 1, 0);
GetMindImageDescriptor(image, outputImage, mask, 1, 0);
free(mask);
// Save the MIND descriptor image
if(flag->outputImageFlag)
Expand All @@ -1064,7 +1064,7 @@ int main(int argc, char **argv)
outputImage->data = malloc(outputImage->nvox * outputImage->nbyper);
// Compute the MIND-SSC descriptor
int *mask = (int *)calloc(image->nvox, sizeof(int));
GetMINDSSCImageDescriptor(image, outputImage, mask, 1, 0);
GetMindSscImageDescriptor(image, outputImage, mask, 1, 0);
free(mask);
// Save the MIND descriptor image
if(flag->outputImageFlag)
Expand Down
582 changes: 252 additions & 330 deletions reg-lib/cpu/_reg_dti.cpp

Large diffs are not rendered by default.

14 changes: 8 additions & 6 deletions reg-lib/cpu/_reg_dti.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ class reg_dti: public reg_measure {
nifti_image *warpedImgBw = nullptr,
nifti_image *warpedGradBw = nullptr,
nifti_image *voxelBasedGradBw = nullptr) override;
/// @brief Returns the value
virtual double GetSimilarityMeasureValue() override;
/// @brief Returns the dti value forwards
virtual double GetSimilarityMeasureValueFw() override;
/// @brief Returns the dti value backwards
virtual double GetSimilarityMeasureValueBw() override;
/// @brief Compute the voxel based gradient for DTI images
virtual void GetVoxelBasedSimilarityMeasureGradient(int currentTimepoint) override;

Expand All @@ -57,10 +59,10 @@ class reg_dti: public reg_measure {
* @return Returns an L2 measure of the distance between the anisotropic components of the diffusion tensors
*/
extern "C++" template <class DataType>
double reg_getDTIMeasureValue(nifti_image *referenceImage,
nifti_image *warpedImage,
int *mask,
unsigned *dtIndicies);
double reg_getDTIMeasureValue(const nifti_image *referenceImage,
const nifti_image *warpedImage,
const int *mask,
const unsigned *dtIndicies);
/* *************************************************************** */
/**
* @brief Compute a voxel based gradient of the sum squared difference.
Expand Down
154 changes: 52 additions & 102 deletions reg-lib/cpu/_reg_kld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@

#include "_reg_kld.h"

/* *************************************************************** */
/* *************************************************************** */
reg_kld::reg_kld(): reg_measure() {
#ifndef NDEBUG
reg_print_msg_debug("reg_kld constructor called");
#endif
}
/* *************************************************************** */
/* *************************************************************** */
void reg_kld::InitialiseMeasure(nifti_image *refImg,
nifti_image *floImg,
int *refMask,
Expand Down Expand Up @@ -55,33 +53,33 @@ void reg_kld::InitialiseMeasure(nifti_image *refImg,
// are meant to be probabilities
for (int t = 0; t < this->referenceImage->nt; ++t) {
if (this->timePointWeight[t] > 0) {
float min_ref = reg_tools_getMinValue(this->referenceImage, t);
float max_ref = reg_tools_getMaxValue(this->referenceImage, t);
float min_flo = reg_tools_getMinValue(this->floatingImage, t);
float max_flo = reg_tools_getMaxValue(this->floatingImage, t);
if (min_ref < 0.f || min_flo < 0.f || max_ref>1.f || max_flo>1.f) {
const float minRef = reg_tools_getMinValue(this->referenceImage, t);
const float maxRef = reg_tools_getMaxValue(this->referenceImage, t);
const float minFlo = reg_tools_getMinValue(this->floatingImage, t);
const float maxFlo = reg_tools_getMaxValue(this->floatingImage, t);
if (minRef < 0.f || minFlo < 0.f || maxRef > 1.f || maxFlo > 1.f) {
reg_print_fct_error("reg_kld::InitialiseMeasure");
reg_print_msg_error("The input images are expected to be probabilities to use the kld measure");
reg_exit();
}
}
}
#ifndef NDEBUG
char text[255];
reg_print_msg_debug("reg_kld::InitialiseMeasure().");
reg_print_msg_debug("reg_kld::InitialiseMeasure()");
for (int i = 0; i < this->referenceImage->nt; ++i) {
sprintf(text, "Weight for timepoint %i: %f", i, this->timePointWeight[i]);
reg_print_msg_debug(text);
}
#endif
}
/* *************************************************************** */
/* *************************************************************** */
template <class DataType>
double reg_getKLDivergence(nifti_image *referenceImage,
nifti_image *warpedImage,
double *timePointWeight,
nifti_image *jacobianDetImg,
int *mask) {
double reg_getKLDivergence(const nifti_image *referenceImage,
const nifti_image *warpedImage,
const double *timePointWeight,
const nifti_image *jacobianDetImg,
const int *mask) {
#ifdef _WIN32
long voxel;
const long voxelNumber = (long)NiftiImage::calcVoxelNumber(referenceImage, 3);
Expand All @@ -90,119 +88,77 @@ double reg_getKLDivergence(nifti_image *referenceImage,
const size_t voxelNumber = NiftiImage::calcVoxelNumber(referenceImage, 3);
#endif

DataType *refPtr = static_cast<DataType*>(referenceImage->data);
DataType *warPtr = static_cast<DataType*>(warpedImage->data);
int *maskPtr = nullptr;
bool MrClean = false;
if (mask == nullptr) {
maskPtr = (int*)calloc(voxelNumber, sizeof(int));
MrClean = true;
} else maskPtr = &mask[0];

DataType *jacPtr = nullptr;
const DataType *refPtr = static_cast<DataType*>(referenceImage->data);
const DataType *warPtr = static_cast<DataType*>(warpedImage->data);
const DataType *jacPtr = nullptr;
if (jacobianDetImg != nullptr)
jacPtr = static_cast<DataType*>(jacobianDetImg->data);
double measure = 0, measure_tp = 0, num = 0, tempRefValue, tempWarValue, tempValue;

double measure = 0, measureTp = 0, num = 0, tempRefValue, tempWarValue, tempValue;

for (int time = 0; time < referenceImage->nt; ++time) {
if (timePointWeight[time] > 0) {
DataType *currentRefPtr = &refPtr[time * voxelNumber];
DataType *currentWarPtr = &warPtr[time * voxelNumber];
const DataType *currentRefPtr = &refPtr[time * voxelNumber];
const DataType *currentWarPtr = &warPtr[time * voxelNumber];
#ifdef _OPENMP
#pragma omp parallel for default(none) \
shared(voxelNumber,currentRefPtr, currentWarPtr, \
maskPtr, jacobianDetImg, jacPtr) \
shared(voxelNumber,currentRefPtr, currentWarPtr, mask, jacobianDetImg, jacPtr) \
private(tempRefValue, tempWarValue, tempValue) \
reduction(+:measure_tp, num)
reduction(+:measureTp, num)
#endif
for (voxel = 0; voxel < voxelNumber; ++voxel) {
if (maskPtr[voxel] > -1) {
if (mask[voxel] > -1) {
tempRefValue = currentRefPtr[voxel] + 1e-16;
tempWarValue = currentWarPtr[voxel] + 1e-16;
tempValue = tempRefValue * log(tempRefValue / tempWarValue);
if (tempValue == tempValue &&
tempValue != std::numeric_limits<double>::infinity()) {
if (jacobianDetImg == nullptr) {
measure_tp -= tempValue;
measureTp -= tempValue;
num++;
} else {
measure_tp -= tempValue * jacPtr[voxel];
measureTp -= tempValue * jacPtr[voxel];
num += jacPtr[voxel];
}
}
}
}
measure += measure_tp * timePointWeight[time] / num;
measure += measureTp * timePointWeight[time] / num;
}
}
if (MrClean) free(maskPtr);
return measure;
}
template double reg_getKLDivergence<float>(nifti_image*, nifti_image*, double*, nifti_image*, int*);
template double reg_getKLDivergence<double>(nifti_image*, nifti_image*, double*, nifti_image*, int*);
/* *************************************************************** */
double GetSimilarityMeasureValue(const nifti_image *referenceImage,
const nifti_image *warpedImage,
const double *timePointWeight,
const nifti_image *jacobianDetImg,
const int *mask) {
return std::visit([&](auto&& refImgDataType) {
using RefImgDataType = std::decay_t<decltype(refImgDataType)>;
return reg_getKLDivergence<RefImgDataType>(referenceImage,
warpedImage,
timePointWeight,
jacobianDetImg,
mask);
}, NiftiImage::getFloatingDataType(referenceImage));
}
/* *************************************************************** */
double reg_kld::GetSimilarityMeasureValue() {
// Check that all the specified image are of the same datatype
if (this->warpedImage->datatype != this->referenceImage->datatype) {
reg_print_fct_error("reg_kld::GetSimilarityMeasureValue");
reg_print_msg_error("Both input images are expected to have the same type");
reg_exit();
}
double KLDValue;
switch (this->referenceImage->datatype) {
case NIFTI_TYPE_FLOAT32:
KLDValue = reg_getKLDivergence<float>(this->referenceImage,
this->warpedImage,
this->timePointWeight,
nullptr, // TODO this->forwardJacDetImagePointer,
this->referenceMask);
break;
case NIFTI_TYPE_FLOAT64:
KLDValue = reg_getKLDivergence<double>(this->referenceImage,
this->warpedImage,
this->timePointWeight,
nullptr, // TODO this->forwardJacDetImagePointer,
this->referenceMask);
break;
default:
reg_print_fct_error("reg_kld::GetSimilarityMeasureValue");
reg_print_msg_error("Warped pixel type unsupported");
reg_exit();
}

// Backward computation
if (this->isSymmetric) {
// Check that all the specified image are of the same datatype
if (this->warpedImageBw->datatype != this->floatingImage->datatype) {
reg_print_fct_error("reg_kld::GetSimilarityMeasureValue");
reg_print_msg_error("Both input images are expected to have the same type");
reg_exit();
}
switch (this->floatingImage->datatype) {
case NIFTI_TYPE_FLOAT32:
KLDValue += reg_getKLDivergence<float>(this->floatingImage,
this->warpedImageBw,
this->timePointWeight,
nullptr, // TODO this->backwardJacDetImagePointer,
this->floatingMask);
break;
case NIFTI_TYPE_FLOAT64:
KLDValue += reg_getKLDivergence<double>(this->floatingImage,
this->warpedImageBw,
this->timePointWeight,
nullptr, // TODO this->backwardJacDetImagePointer,
this->floatingMask);
break;
default:
reg_print_fct_error("reg_kld::GetSimilarityMeasureValue");
reg_print_msg_error("Warped pixel type unsupported");
reg_exit();
}
}
return KLDValue;
double reg_kld::GetSimilarityMeasureValueFw() {
return ::GetSimilarityMeasureValue(this->referenceImage,
this->warpedImage,
this->timePointWeight,
nullptr, // TODO this->forwardJacDetImagePointer,
this->referenceMask);
}
/* *************************************************************** */
double reg_kld::GetSimilarityMeasureValueBw() {
return ::GetSimilarityMeasureValue(this->floatingImage,
this->warpedImageBw,
this->timePointWeight,
nullptr, // TODO this->backwardJacDetImagePointer,
this->floatingMask);
}
/* *************************************************************** */
template <class DataType>
void reg_getKLDivergenceVoxelBasedGradient(nifti_image *referenceImage,
Expand Down Expand Up @@ -313,11 +269,6 @@ void reg_getKLDivergenceVoxelBasedGradient(nifti_image *referenceImage,
}
if (MrClean) free(maskPtr);
}
template void reg_getKLDivergenceVoxelBasedGradient<float>
(nifti_image*, nifti_image*, nifti_image*, nifti_image*, nifti_image*, int*, int, double);
template void reg_getKLDivergenceVoxelBasedGradient<double>
(nifti_image*, nifti_image*, nifti_image*, nifti_image*, nifti_image*, int*, int, double);
/* *************************************************************** */
/* *************************************************************** */
void reg_kld::GetVoxelBasedSimilarityMeasureGradient(int currentTimepoint) {
// Check if the specified time point exists and is active
Expand Down Expand Up @@ -401,4 +352,3 @@ void reg_kld::GetVoxelBasedSimilarityMeasureGradient(int currentTimepoint) {
}
}
/* *************************************************************** */
/* *************************************************************** */
18 changes: 10 additions & 8 deletions reg-lib/cpu/_reg_kld.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ class reg_kld: public reg_measure {
nifti_image *warpedImgBw = nullptr,
nifti_image *warpedGradBw = nullptr,
nifti_image *voxelBasedGradBw = nullptr) override;
/// @brief Returns the kld value
virtual double GetSimilarityMeasureValue() override;
/// @brief Returns the kld value forwards
virtual double GetSimilarityMeasureValueFw() override;
/// @brief Returns the kld value backwards
virtual double GetSimilarityMeasureValueBw() override;
/// @brief Compute the voxel based kld gradient
virtual void GetVoxelBasedSimilarityMeasureGradient(int currentTimepoint) override;
};
Expand All @@ -50,15 +52,15 @@ class reg_kld: public reg_measure {
* image is used to modulate the KLD. The argument is ignored if the
* pointer is set to nullptr
* @param mask Array that contains a mask to specify which voxel
* should be considered. If set to nullptr, all voxels are considered
* should be considered
* @return Returns the computed sum squared difference
*/
extern "C++" template <class DataType>
double reg_getKLDivergence(nifti_image *reference,
nifti_image *warped,
double *timePointWeight,
nifti_image *jacobianDeterminantImage,
int *mask);
double reg_getKLDivergence(const nifti_image *reference,
const nifti_image *warped,
const double *timePointWeight,
const nifti_image *jacobianDeterminantImage,
const int *mask);
/* *************************************************************** */

/** @brief Compute a voxel based gradient of the sum squared difference.
Expand Down
Loading

0 comments on commit 4a98c08

Please sign in to comment.