Skip to content

Commit

Permalink
microssim now handles an array of gt,pred pairs where individual gt c…
Browse files Browse the repository at this point in the history
…ould have different shape
  • Loading branch information
ashesh-0 committed Aug 22, 2024
1 parent 85b8623 commit 8246b4e
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 2,018 deletions.
23 changes: 22 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,23 @@
# Objective
This is the official implementation of Range invariant SSIM metric.
This is the official implementation of [MicroSSIM](https://arxiv.org/abs/2408.08747), accepted at BIC, ECCV 2024.
## Installation
We will soon release the package on PyPI. For now, you can install the package by cloning the repository and running the following command:
```bash
git clone [email protected]:juglab/MicroSSIM.git
cd MicroSSIM
pip install -e .
```
## Usage
```python
from microssim import MicroSSIM, MicroMS3IM
gt: N x H x W
pred: N x H x W

ssim = MicroSSIM() # or MicroMS3IM()
ssim.fit(gt, pred)

for i in range(N):
score = ssim.score(gt[i], pred[i])
print('SSIM score for', i, 'th image:', score)

```
43 changes: 32 additions & 11 deletions microssim/_micro_ssim_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,38 @@ def get_transformation_params(gt, pred, **ssim_kwargs):
ssim_dict["C1"],
ssim_dict["C2"],
)
ux_arr.append(ux)
uy_arr.append(uy)
vx_arr.append(vx)
vy_arr.append(vy)
vxy_arr.append(vxy)

ux = np.concatenate(ux_arr, axis=0)
uy = np.concatenate(uy_arr, axis=0)
vx = np.concatenate(vx_arr, axis=0)
vy = np.concatenate(vy_arr, axis=0)
vxy = np.concatenate(vxy_arr, axis=0)
# reshape allows handling differently sized images.
ux_arr.append(
ux.reshape(
-1,
)
)
uy_arr.append(
uy.reshape(
-1,
)
)
vx_arr.append(
vx.reshape(
-1,
)
)
vy_arr.append(
vy.reshape(
-1,
)
)
vxy_arr.append(
vxy.reshape(
-1,
)
)

ux = np.concatenate(ux_arr)
uy = np.concatenate(uy_arr)
vx = np.concatenate(vx_arr)
vy = np.concatenate(vy_arr)
vxy = np.concatenate(vxy_arr)

other_args = (
ux,
Expand Down
2 changes: 1 addition & 1 deletion microssim/micro_ms3im.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def score(
if ms_ssim_kwargs is None:
ms_ssim_kwargs = {}

if not self._fit_called:
if not self._initialized:
raise ValueError(
"fit method was not called before score method. Expected behaviour is to call fit \
with ALL DATA and then call score(), with individual images.\
Expand Down
64 changes: 56 additions & 8 deletions microssim/micro_ssim.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Union

import numpy as np

from microssim._micro_ssim_internal import get_transformation_params, micro_SSIM
Expand Down Expand Up @@ -30,16 +32,29 @@ def __init__(
self._offset_gt = offset_gt
self._max_val = max_val
self._ri_factor = ri_factor
self._fit_called = self._ri_factor is not None
if self._fit_called:
self._initialized = self._ri_factor is not None
if self._initialized:
assert (
self._offset_gt is not None
and self._offset_pred is not None
and self._max_val is not None
), "If ri_factor is provided, offset_pred, offset_gt and max_val must be provided as well."

def fit(self, gt: np.ndarray, pred: np.ndarray):
assert self._fit_called is False, "fit method can be called only once."
def get_init_params_dict(self):
"""
Returns the initialization parameters of the measure. This can be used to save the model and
reload it later or to initialize other SSIM variants with the same parameters.
"""
assert self._initialized is True, "model is not initialized."
return {
"bkg_percentile": self._bkg_percentile,
"offset_pred": self._offset_pred,
"offset_gt": self._offset_gt,
"max_val": self._max_val,
"ri_factor": self._ri_factor,
}

def _set_hparams(self, gt: np.ndarray, pred: np.ndarray):
if self._offset_gt is None:
self._offset_gt = np.percentile(gt, self._bkg_percentile, keepdims=False)

Expand All @@ -51,13 +66,46 @@ def fit(self, gt: np.ndarray, pred: np.ndarray):
if self._max_val is None:
self._max_val = (gt - self._offset_gt).max()

def fit(self, gt: np.ndarray, pred: np.ndarray):
assert self._initialized is False, "fit method can be called only once."

if isinstance(gt, np.ndarray):
self._set_hparams(gt, pred)

elif isinstance(gt, list):
gt_squished = np.concatenate(
[
x.reshape(
-1,
)
for x in gt
]
)
pred_squished = np.concatenate(
[
x.reshape(
-1,
)
for x in pred
]
)
self._set_hparams(gt_squished, pred_squished)

self._fit(gt, pred)
self._fit_called = True
self._initialized = True

def normalize_prediction(self, pred: Union[List[np.ndarray], np.ndarray]):
if isinstance(pred, list):
assert isinstance(pred[0], np.ndarray), "List must contain numpy arrays."
return [self.normalize_prediction(x) for x in pred]

def normalize_prediction(self, pred: np.ndarray):
return (pred - self._offset_pred) / self._max_val

def normalize_gt(self, gt: np.ndarray):
def normalize_gt(self, gt: Union[List[np.ndarray], np.ndarray]):
if isinstance(gt, list):
assert isinstance(gt[0], np.ndarray), "List must contain numpy arrays."
return [self.normalize_gt(x) for x in gt]

return (gt - self._offset_gt) / self._max_val

def _fit(self, gt: np.ndarray, pred: np.ndarray):
Expand All @@ -71,7 +119,7 @@ def score(
pred: np.ndarray,
return_individual_components: bool = False,
):
if not self._fit_called:
if not self._initialized:
raise ValueError(
"fit method was not called before score method. Expected behaviour is to call fit \
with ALL DATA and then call score(), with individual images.\
Expand Down
15 changes: 8 additions & 7 deletions notebooks/Evaluate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 25/25 [00:16<00:00, 1.47it/s]\n"
"100%|██████████| 25/25 [00:17<00:00, 1.40it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 30.2 s, sys: 12.2 s, total: 42.4 s\n",
"Wall time: 42.6 s\n"
"CPU times: user 32.6 s, sys: 24.8 s, total: 57.4 s\n",
"Wall time: 59 s\n"
]
},
{
Expand Down Expand Up @@ -191,22 +191,23 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 25/25 [00:17<00:00, 1.43it/s]\n"
"100%|██████████| 25/25 [00:17<00:00, 1.41it/s]\n"
]
}
],
"source": [
"from microssim.micro_ms3im import MicroMS3IM\n",
"msssim = MicroSSIM(offset_pred=BACKGROUND_OFFSET_PREDICTION, offset_gt=BACKGROUND_OFFSET_TARGET, max_val=all_max)\n",
"msssim.fit(gt[::4], pred[::4])\n",
"ms3im = MicroMS3IM(offset_pred=BACKGROUND_OFFSET_PREDICTION, offset_gt=BACKGROUND_OFFSET_TARGET, max_val=all_max, ri_factor=msssim._ri_factor)\n",
"# msssim.fit(gt[::4], pred[::4])\n",
"msssim.fit([x for x in gt[::4]], [x for x in pred[::4]])\n",
"m3im = MicroMS3IM(**msssim.get_init_params_dict())\n",
"\n"
]
},
Expand Down
Loading

0 comments on commit 8246b4e

Please sign in to comment.