Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Minor change of bch_terms param of compose_svfs() #130

Merged
merged 1 commit into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 41 additions & 19 deletions src/deepali/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,27 +71,44 @@ def compose_svfs(
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple[int]] = None,
bch_terms: int = 4,
bch_terms: int = 3,
) -> Tensor:
r"""Approximate stationary velocity field (SVF) of composite deformation.

The output velocity field is ``w = log(exp(v) o exp(u))``, where ``exp`` is the exponential map
of a stationary velocity field, and ``log`` its inverse. The velocity field ``w`` is given by the
`Baker-Campbell-Hausdorff (BCH) formula <https://en.wikipedia.org/wiki/Baker%E2%80%93Campbell%E2%80%93Hausdorff_formula>`_.

The BCH formula with 5 Lie bracket terms (cf. ``bch_terms`` parameter) is

.. math::

w = v + u + \frac{1}{2} [v, u]
+ \frac{1}{12} ([v, [v, u]] - [u, [v, u]])
+ \frac{1}{48} ([[v, [v, u]], u] - [v, [u, [v, u]]])

where

.. math::

[[v, [v, u]], u] - [v, [u, [v, u]]] = -2 [u, [v, [v, u]]]

References:
- Vercauteren, 2008. Symmetric Log-Domain Diffeomorphic Registration: A Demons-based Approach.
doi:10.1007/978-3-540-85988-8_90
- Bossa & Olmos, 2008. A new algorithm for the computation of the group logarithm of diffeomorphisms.
https://inria.hal.science/inria-00629873
- Vercauteren et al., 2008. Symmetric log-domain diffeomorphic registration: A Demons-based approach.
https://doi.org/10.1007/978-3-540-85988-8_90

Args:
u: First applied stationary velocity field as tensor of shape ``(N, D, ..., X)``.
v: Second applied stationary velocity field as tensor of shape ``(N, D, ..., X)``.
bch_terms: Number of terms of the BCH formula to consider. Must be at least 2.
When 2, the returned velocity field is the sum of ``u`` and ``v``.
bch_terms: Number of Lie bracket terms of the BCH formula to consider.
When 0, the returned velocity field is the sum of ``u`` and ``v``.
This approximation is accurate if the input velocity fields commute, i.e.,
the Lie bracket [v, u] = 0. When ``bch_terms=3``, the approximation is given by
``w = v + u + 1/2 [v, u]`` (note that deformation ``exp(u)`` is applied first),
and when ``bch_terms=4``, it is ``w = v + u + 1/2 [v, u] + 1/12 [v, [v, u]]``.
the Lie bracket [v, u] = 0. When ``bch_terms=1``, the approximation is given by
``w = v + u + 1/2 [v, u]`` (note ``exp(u)`` is applied before ``exp(v)``). Formula
``w = v + u + \frac{1}{2} [v, u] + \frac{1}{12} ([v, [v, u]] - [u, [v, u]])`` is
used by default, i.e., ``bch_terms=3``.
mode: Mode of :func:`flow_derivatives()` approximation.
sigma: Standard deviation of Gaussian used for computing spatial derivatives.
spacing: Physical size of image voxels used to compute spatial derivatives.
Expand All @@ -114,25 +131,30 @@ def lb(a: Tensor, b: Tensor) -> Tensor:
raise ValueError(f"compose_svfs() '{name}' must have shape (N, D, ..., X)")
if u.shape != v.shape:
raise ValueError("compose_svfs() 'u' and 'v' must have the same shape")
if bch_terms < 2:
raise ValueError("compose_svfs() 'bch_terms' must be at least 2")
elif bch_terms > 6:
if bch_terms < 0:
raise ValueError("compose_svfs() 'bch_terms' must not be negative")
elif bch_terms > 5:
raise NotImplementedError("compose_svfs() 'bch_terms' of more than 6 not implemented")

# w = v + u
w = v.add(u)
if bch_terms >= 3:
if bch_terms >= 1:
# + 1/2 [v, u]
vu = lb(v, u)
w = w.add(vu.mul(0.5))
if bch_terms >= 4:
if bch_terms >= 2:
# + 1/12 [v, [v, u]]
vvu = lb(v, vu)
w = w.add(vvu.mul(1 / 12))
if bch_terms >= 5:
uv = lb(u, v)
uuv = lb(u, uv)
w = w.add(uuv.mul(1 / 12))
if bch_terms >= 6:
if bch_terms >= 3:
# - 1/12 [u, [v, u]]
uvu = lb(u, vu)
w = w.sub(uvu.mul(1 / 12))
if bch_terms >= 4:
# + 1/48 [[v, [v, u]], u] = - 1/48 [u, [v, [v, u]]]
# - 1/48 [v, [u, [v, u]]] = - 1/48 [u, [v, [v, u]]]
uvvu = lb(u, vvu)
w = w.sub(uvvu.mul(1 / 24))
w = w.sub(uvvu.mul((1 if bch_terms == 4 else 2) / 48))

return w

Expand Down
2 changes: 1 addition & 1 deletion tests/_test_compose_svfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def visualize_flow(ax, flow: Tensor) -> None:

# %%
# Approximate velocity field of composite displacement field
flow_w = U.expv(U.compose_svfs(u, v, bch_terms=6))
flow_w = U.expv(U.compose_svfs(u, v, bch_terms=3))


# %%
Expand Down
20 changes: 9 additions & 11 deletions tests/test_core_flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,34 +438,32 @@ def test_flow_compose_svfs() -> None:

with pytest.raises(ValueError):
U.compose_svfs(p, p, bch_terms=-1)
with pytest.raises(ValueError):
U.compose_svfs(p, p, bch_terms=0)
with pytest.raises(ValueError):
U.compose_svfs(p, p, bch_terms=1)
with pytest.raises(NotImplementedError):
U.compose_svfs(p, p, bch_terms=7)
U.compose_svfs(p, p, bch_terms=6)

# u = [yz, xz, xy] and v = u
u = v = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1)

w = U.compose_svfs(u, v, bch_terms=0)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=1)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=2)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=3)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=4)
assert torch.allclose(w, u.add(v))
assert torch.allclose(w, u.add(v), atol=1e-5)
w = U.compose_svfs(u, v, bch_terms=5)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=6)
assert torch.allclose(w, u.add(v), atol=1e-5)

# u = [yz, xz, xy] and v = [x, y, z]
u = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1)
v = torch.cat([x, y, z], dim=1)

w = U.compose_svfs(u, v, bch_terms=2)
w = U.compose_svfs(u, v, bch_terms=0)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=3)
w = U.compose_svfs(u, v, bch_terms=1)
assert torch.allclose(w, u.mul(0.5).add(v), atol=1e-6)

# u = random_svf(), u -> 0 at boundary
Expand All @@ -474,7 +472,7 @@ def test_flow_compose_svfs() -> None:
generator = torch.Generator().manual_seed(42)
u = random_svf(size, stride=4, generator=generator).mul_(0.1)
v = random_svf(size, stride=4, generator=generator).mul_(0.05)
w = U.compose_svfs(u, v, bch_terms=6)
w = U.compose_svfs(u, v, bch_terms=5)

flow_u = U.expv(u)
flow_v = U.expv(v)
Expand Down