Skip to content

Commit

Permalink
[core] Add vector field logarithm
Browse files Browse the repository at this point in the history
  • Loading branch information
aschuh-hf committed Nov 13, 2023
1 parent 3980252 commit 7fed696
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 0 deletions.
55 changes: 55 additions & 0 deletions src/deepali/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def expv(
sampling: Union[Sampling, str] = Sampling.LINEAR,
padding: Union[PaddingMode, str] = PaddingMode.BORDER,
align_corners: bool = ALIGN_CORNERS,
inverse: bool = False,
) -> Tensor:
r"""Group exponential maps of flow fields computed using scaling and squaring.
Expand All @@ -336,13 +337,17 @@ def expv(
padding: Flow field extrapolation mode.
align_corners: Whether ``flow`` vectors are defined with respect to
``Axes.CUBE`` (False) or ``Axes.CUBE_CORNERS`` (True).
inverse: Whether to negate scaled velocity field. Setting this to ``True``
is equivalent to negating the ``scale`` (e.g., ``scale=-1``).
Returns:
Exponential map of input flow field. If ``steps=0``, a reference to ``flow`` is returned.
"""
if scale is None:
scale = 1
if inverse:
scale = -scale
if steps is None:
steps = 5
if not isinstance(steps, int):
Expand Down Expand Up @@ -699,6 +704,56 @@ def normalize_flow(
return data


def logv(
flow: Tensor,
num_iters: int = 5,
bch_terms: int = 1,
sigma: Optional[float] = 1.0,
spacing: Optional[Union[Scalar, Array]] = None,
exp_steps: Optional[int] = None,
sampling: Union[Sampling, str] = Sampling.LINEAR,
padding: Union[PaddingMode, str] = PaddingMode.BORDER,
align_corners: bool = ALIGN_CORNERS,
) -> Tensor:
r"""Group logarithmic maps of flow fields computed using algorithm by Bossa & Olsom (2008).
References:
- Bossa & Olmos, 2008. A new algorithm for the computation of the group logarithm of diffeomorphisms.
https://inria.hal.science/inria-00629873
Args:
num_iters: Number of iterations.
bch_terms: Number of Lie bracket terms of the Baker-Campbell-Hausdorff (BCH) formula to use
when computing the composite of current velocity field with the correction field.
sigma: Standard deviation of Gaussian kernel used as low-pass filter when computing spatial
derivatives required for evaluation of Lie brackets during application of BCH formula.
spacing: Physical size of image voxels used to compute spatial derivatives.
exp_steps: Number of exponentiation steps to evaluate current inverse displacement field.
sampling: Flow field interpolation mode when computing inverse displacement field.
padding: Flow field extrapolation mode when computing inverse displacement field.
align_corners: Whether ``flow`` vectors are defined with respect to
``Axes.CUBE`` (False) or ``Axes.CUBE_CORNERS`` (True).
Returns:
Approximate stationary velocity field which when exponentiated (cf. :func:`expv`) results
in the given input ``flow`` field.
"""
v = flow
for _ in range(num_iters):
u = expv(
v,
steps=exp_steps,
sampling=sampling,
padding=padding,
align_corners=align_corners,
inverse=True,
)
u = compose_flows(flow, u)
v = compose_svfs(u, v, bch_terms=bch_terms, sigma=sigma, spacing=spacing)
return v


def denormalize_flow(
data: Tensor,
size: Optional[Union[Tensor, torch.Size]] = None,
Expand Down
2 changes: 2 additions & 0 deletions src/deepali/core/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
from .flow import jacobian_dict
from .flow import jacobian_matrix
from .flow import lie_bracket
from .flow import logv
from .flow import normalize_flow
from .flow import sample_flow
from .flow import warp_grid
Expand Down Expand Up @@ -220,6 +221,7 @@
"jacobian_dict",
"jacobian_matrix",
"lie_bracket",
"logv",
"max_pool",
"min_pool",
"normalize_flow",
Expand Down
79 changes: 79 additions & 0 deletions tests/_test_core_flow_logv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# %%
# Imports
from typing import Optional, Sequence

import matplotlib.pyplot as plt

import torch
from torch import Tensor
from torch.random import Generator

from deepali.core import Grid
import deepali.core.bspline as B
import deepali.core.functional as U


# %%
# Auxiliary functions
def random_svf(
size: Sequence[int],
stride: int = 1,
generator: Optional[Generator] = None,
) -> Tensor:
cp_grid_size = B.cubic_bspline_control_point_grid_size(size, stride=stride)
data = torch.randn((1, 3) + cp_grid_size, generator=generator)
data = U.fill_border(data, margin=3, value=0, inplace=True)
return B.evaluate_cubic_bspline(data, size=size, stride=stride)


def visualize_flow(ax, flow: Tensor, label: Optional[str] = None) -> None:
grid = Grid(shape=flow.shape[2:], align_corners=True)
x = grid.coords(channels_last=False, dtype=u.dtype, device=u.device)
x = U.move_dim(x.unsqueeze(0).add_(flow), 1, -1)
target_grid = U.grid_image(shape=flow.shape[2:], inverted=True, stride=(5, 5))
warped_grid = U.warp_image(target_grid, x)
ax.imshow(warped_grid[0, 0, flow.shape[2] // 2], cmap="gray")
if label:
ax.set_title(label, fontsize=24)


# %%
# Random velocity fields
size = (128, 128, 128)
generator = torch.Generator().manual_seed(42)
v = random_svf(size, stride=8, generator=generator).mul_(0.1)


# %%
# Compute logarithm of exponentiated velocity field
bch_terms = 3
exp_steps = 5
log_steps = 5

u = U.expv(v, steps=exp_steps)
w = U.logv(u, num_iters=log_steps, bch_terms=bch_terms, exp_steps=exp_steps, sigma=1.0)

fig, axes = plt.subplots(1, 4, figsize=(40, 10))

ax = axes[0]
ax.set_title("v", fontsize=32, pad=20)
visualize_flow(ax, v)

ax = axes[1]
ax.set_title("u = exp(v)", fontsize=32, pad=20)
visualize_flow(ax, u)

ax = axes[2]
ax.set_title("log(u)", fontsize=32, pad=20)
visualize_flow(ax, w)

error = w.sub(v).norm(dim=1, keepdim=True)

ax = axes[3]
ax.set_title("|log(u) - v|", fontsize=32, pad=20)
_ = ax.imshow(error[0, 0, error.shape[2] // 2], cmap="jet", vmin=0, vmax=0.1)

print(f"Mean error: {error.mean():.5f}")
print(f"Maximium error: {error.max():.5f}")

# %%
11 changes: 11 additions & 0 deletions tests/test_core_flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,17 @@ def test_flow_lie_bracket() -> None:
assert error.max().lt(0.134)


def test_flow_logv() -> None:
size = (128, 128, 128)
generator = torch.Generator().manual_seed(42)
v = random_svf(size, stride=8, generator=generator).mul_(0.1)
u = U.expv(v)
w = U.logv(u)
error = w.sub(v).norm(dim=1, keepdim=True)
assert error.mean().lt(0.001)
assert error.max().lt(0.02)


def test_flow_compose_svfs() -> None:
# 3D flow fields
p = U.move_dim(Grid(size=(64, 32, 16)).coords().unsqueeze_(0), -1, 1)
Expand Down

0 comments on commit 7fed696

Please sign in to comment.