-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathshape_space.py
245 lines (212 loc) · 12.2 KB
/
shape_space.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import torch
import numpy as np
from torch.linalg import norm, eigh
from scipy.linalg import solve_sylvester
from manifold import Manifold, Point, Scalar, Vector
from hypersphere import slerp
class PreShapeSpace(Manifold):
    """Kendall pre-shape space of ``m`` keypoints in ``p`` dimensions.

    A pre-shape is an (m, p) configuration with translation and scale removed: it is
    centered (zero mean over keypoints) and has unit Frobenius norm, so pre-shapes lie
    on a unit hypersphere and the geodesic machinery below mirrors the hypersphere's.
    The Pre Shape Metric is like the Shape Metric but without removing rotations (no
    alignment step).
    """

    def __init__(self, num_keypoints: int, keypoints_dim: int):
        """
        :param num_keypoints: number of keypoints per configuration (m)
        :param keypoints_dim: spatial dimension of each keypoint (p)
        """
        self.m, self.p = num_keypoints, keypoints_dim
        self.shape = torch.Size((self.m, self.p))
        # Centering removes p degrees of freedom and fixing the scale removes one more,
        # leaving a sphere of dimension p*(m-1) - 1. Rotations are NOT quotiented out in
        # the pre-shape space (that is what ShapeSpace adds), so no p*(p-1)/2 term
        # belongs here.
        self.dim = self.p * (self.m - 1) - 1
        self.ambient = self.m * self.p

    def project(self, pt: Point) -> Point:
        """Map an (m, p) configuration onto the pre-shape space.

        :param pt: configuration, assumed to have shape (m, p)
        :return: ``pt`` centered over keypoints and scaled to unit Frobenius norm
        """
        # Remove translation by centering over the keypoint axis
        pt = pt - torch.mean(pt, dim=0)
        # Remove scale by normalizing the point.
        # NOTE(review): if every keypoint coincides, the centered norm is zero and this
        # division produces NaNs.
        pt = pt / norm(pt, ord="fro")
        return pt

    def contains(self, pt: Point, atol: float = 1e-6) -> bool:
        """Check whether ``pt`` is a valid pre-shape.

        :param pt: candidate point
        :param atol: absolute tolerance used by both the centering and unit-norm tests
        :return: True iff ``pt`` has shape (m, p), is centered and has unit Frobenius norm
        """
        # Test shape
        if pt.shape != self.shape:
            return False
        # Test centered
        if not torch.allclose(torch.mean(pt, dim=0), pt.new_zeros(pt.shape[1:]), atol=atol):
            return False
        # Test unit norm, honoring the caller's tolerance just like the centering test
        if not torch.isclose(norm(pt, ord="fro"), pt.new_ones((1,)), atol=atol):
            return False
        # All tests passed - return True
        return True

    def _distance_impl(self, pt_x: Point, pt_y: Point) -> Scalar:
        """Angular (great-circle) distance between two pre-shapes.

        Both inputs are re-normalized by the formula, so the result is the angle between
        them viewed as vectors in the ambient (m*p)-dimensional space.
        """
        cos_ab = torch.sum(pt_x * pt_y) / torch.sqrt(torch.sum(pt_x * pt_x) * torch.sum(pt_y * pt_y))
        # Clipping guards arccos against round-off pushing |cos| slightly above 1
        return torch.arccos(torch.clip(cos_ab, -1.0, +1.0))

    def geodesic(self, pt_x: Point, pt_y: Point, t: float) -> Point:
        """Point at parameter ``t`` along the geodesic from ``pt_x`` to ``pt_y``.

        Delegates to ``hypersphere.slerp``; presumably t=0 gives pt_x and t=1 gives
        pt_y (confirm against slerp's definition).
        """
        return slerp(pt_x, pt_y, t)

    def to_tangent(self, pt_x: Point, vec_w: Vector) -> Vector:
        """Project to tangent in the pre-shape space. Pre-shapes are invariant to
        translation and scale but not to rotation.

        :param pt_x: base point for the tangent vector
        :param vec_w: ambient space vector
        :return: tangent vector with mean-shifts removed, as well as scaling removed
        """
        # Points must be 'centered', so subtract off component of vec that would
        # affect the mean
        vec_w = vec_w - torch.mean(vec_w, dim=0)
        # Subtract off component that would uniformly scale all points (component of
        # the tangent in the direction of pt_x)
        vec_w = vec_w - pt_x * torch.sum(vec_w * pt_x) / torch.sum(pt_x * pt_x)
        return vec_w

    def inner_product(self, pt_x: Point, vec_w: Vector, vec_v: Vector):
        """Frobenius inner product of two tangent vectors (does not depend on ``pt_x``)."""
        return torch.sum(vec_w * vec_v)

    def exp_map(self, pt_x: Point, vec_w: Vector) -> Point:
        """Follow the geodesic from ``pt_x`` with initial velocity ``vec_w`` for unit time.

        Identical to Hypersphere.exp_map.
        See https://math.stackexchange.com/a/1930880
        """
        # Named 'theta' (not 'norm') so the module-level torch.linalg.norm is not shadowed
        theta = self.norm(pt_x, vec_w)
        c1 = torch.cos(theta)
        # torch.sinc is the normalized sinc, so sinc(theta / pi) == sin(theta) / theta,
        # which stays finite as theta -> 0
        c2 = torch.sinc(theta / np.pi)
        return c1 * pt_x + c2 * vec_w

    def log_map(self, pt_x: Point, pt_y: Point) -> Vector:
        """Tangent vector at ``pt_x`` pointing towards ``pt_y`` whose length is their distance.

        Identical to Hypersphere.log_map. ``to_tangent``, ``norm`` and ``distance`` are
        looked up on the instance, so any subclass overrides are picked up here.
        """
        # Projecting pt_y onto the tangent space at pt_x gives the direction...
        unscaled_w = self.to_tangent(pt_x, pt_y)
        # ...which is rescaled to unit length (the clip avoids dividing by zero when the
        # points coincide) and then to the geodesic distance
        norm_w = unscaled_w / torch.clip(self.norm(pt_x, unscaled_w), 1e-7)
        return norm_w * self.distance(pt_x, pt_y)

    def levi_civita(self, pt_x: Point, pt_y: Point, vec_w: Vector) -> Vector:
        """Parallel-transport ``vec_w`` from ``pt_x`` to ``pt_y`` along their geodesic.

        Identical to Hypersphere.levi_civita: the part of ``vec_w`` orthogonal to the
        geodesic direction is unchanged, while the part along it is turned within the
        plane spanned by ``pt_x`` and that direction.
        """
        vec_v = self.log_map(pt_x, pt_y)
        angle = self.distance(pt_x, pt_y)
        unit_v = vec_v / torch.clip(angle, 1e-7)  # the length of tangent vector v *is* the length from x to y
        w_along_v = torch.sum(unit_v * vec_w)
        orth_part = vec_w - w_along_v * unit_v
        return orth_part + w_along_v * (torch.cos(angle) * unit_v - torch.sin(angle) * pt_x)
class ShapeSpace(PreShapeSpace):
    """Kendall shape space: pre-shapes modulo rotation.

    Practical differences between PreShape and Shape:
    - Shape space decomposes the PreShape tangent space into vertical (within equivalence class) and horizontal (across
      equivalence class) parts.
    - Shape.to_tangent is not overridden, so Shape.to_tangent(pt, vec) will in general contain both horz and vert parts
    - Shape.exp_map and Shape.levi_civita both respect the vertical part
    - Shape.log_map returns *only* the horizontal part
    - Shape.inner_product only takes the horizontal part
    This means that exp_map and log_map are not exact inverses up to _equality_. However, they are inverses up to
    _equivalence_.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Pre-shape sphere (dimension p*(m-1) - 1) minus the p*(p-1)/2 dimensions of the
        # rotation group. Written out explicitly rather than derived from the parent's
        # ``dim`` so that the rotational term is subtracted exactly once.
        self.dim = self.p * (self.m - 1) - 1 - (self.p * (self.p - 1)) // 2

    @staticmethod
    def orthogocnal_procrustes_rotation(x, y, anchor="middle"):
        """Provided x and y, each matrix of size (m, p) that are already centered and
        scaled, solve the orthogonal Procrustes problem (rotate x and y into a common
        frame that minimizes distances).

        If anchor="middle" (default) then both are rotated to meet in the middle
        If anchor="x", then x is left unchanged and y is rotated towards it
        If anchor="y", then y is left unchanged and x is rotated towards it
        See https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem

        The misspelled name is kept for backward compatibility; prefer the alias
        ``orthogonal_procrustes_rotation``.

        :return: r_x and r_y, which, when right-multiplied with x and y, gives the
            aligned coordinates, or None for each if no transform is required
        :raises ValueError: if ``anchor`` is not one of 'middle', 'x' or 'y'
        """
        # Validate before paying for the SVD
        if anchor not in ("middle", "x", "y"):
            raise ValueError(f"Invalid 'anchor' argument: {anchor}: "
                             f"(must be 'middle', 'x', or 'y')")
        with torch.no_grad():
            # torch.linalg.svd returns Vh (= V^T), i.e. x.T @ y == u @ diag(s) @ vh
            u, _, vh = torch.linalg.svd(x.T @ y)
        # NOTE(review): no determinant correction is applied, so these orthogonal
        # matrices may contain a reflection (det == -1) rather than being proper
        # rotations in SO(p). Confirm whether mirror-image shapes should be equivalent.
        #
        # Helpful trick to see how these are related: u is the inverse of u.T, and
        # likewise vh is the inverse of vh.T. We get to the anchor=x and anchor=y
        # solutions by right-multiplying both "middle" return values by u.T or by vh,
        # respectively (rotating both return values in the same way preserves the shape).
        if anchor == "middle":
            return u, vh.T
        if anchor == "x":
            return None, vh.T @ u.T
        return u @ vh, None

    # Correctly spelled alias of the method above
    orthogonal_procrustes_rotation = orthogocnal_procrustes_rotation

    @staticmethod
    def align(x, y, anchor="middle"):
        """Provided x and y, each matrix of size (m, p) that are already centered and
        scaled, solve the orthogonal Procrustes problem (rotate x and y into a common
        frame that minimizes distances).

        :param anchor: 'middle', 'x' or 'y'; see orthogonal_procrustes_rotation
        :return: new_x, new_y the rotated versions of x and y, minimizing element-wise
            squared differences
        """
        r_x, r_y = ShapeSpace.orthogonal_procrustes_rotation(x, y, anchor)
        new_x = x if r_x is None else x @ r_x
        new_y = y if r_y is None else y @ r_y
        return new_x, new_y

    def _distance_impl(self, pt_x: Point, pt_y: Point) -> Scalar:
        """Shape distance = pre-shape distance after optimally aligning the two points."""
        pt_x, pt_y = self.align(pt_x, pt_y)
        return super()._distance_impl(pt_x, pt_y)

    def geodesic(self, pt_x: Point, pt_y: Point, t: float) -> Point:
        """Geodesic between the shapes of ``pt_x`` and ``pt_y``, expressed in pt_x's frame."""
        # Choice of anchor here is largely arbitrary, but for local consistency with
        # log_map we set it to 'x'
        pt_x, pt_y = self.align(pt_x, pt_y, anchor="x")
        return super().geodesic(pt_x, pt_y, t)

    def _horizontal_tangent(self, pt_x: Point, vec_w: Vector, *, vert_part: Vector = None) -> Vector:
        """The 'horizontal' part of the tangent space is the part that is actually
        movement in the quotient space, i.e. across equivalence classes. For example,
        east/west movement where equivalence = lines of longitude.

        :param vert_part: precomputed vertical part of ``vec_w`` (computed here if None)
        """
        # Start by ensuring vec_w is a tangent vector in the pre-shape space
        vec_w = super().to_tangent(pt_x, vec_w)
        if vert_part is None:
            # Calculate vertical part
            vert_part = self._vertical_tangent(pt_x, vec_w)
        # The horizontal part is what is left after projecting away the vertical part.
        # Projecting onto the *direction* of vert_part is invariant to its sign, and the
        # clip avoids dividing by zero when there is no vertical component.
        square_vert_norm = torch.clip(torch.sum(vert_part * vert_part), 1e-7)
        horz_part = vec_w - vert_part * torch.sum(vec_w * vert_part) / square_vert_norm
        return horz_part

    def _solve_skew_symmetric_vertical_tangent(self, pt_x: Point, vec_w: Vector):
        """Find skew-symmetric A such that pt_x @ A is the vertical part of vec_w at pt_x.

        pt_x @ A is the orthogonal (Frobenius) projection of vec_w onto the vertical
        space {pt_x @ B : B skew-symmetric}.
        """
        # Start by ensuring vec_w is a tangent vector in the pre-shape space
        vec_w = super().to_tangent(pt_x, vec_w)
        # Minimizing ||vec_w - x @ A||^2 over skew A gives the Sylvester equation
        #     (x.T @ x) @ A + A @ (x.T @ x) = x.T @ w - w.T @ x
        # This is equation (2) in Nava-Yazdani et al (2020) transposed: our equations are
        # transposed from theirs, and transposing their skew matrix flips its sign
        # (A.T == -A), which is why the right-hand side reads x.T @ w - w.T @ x.
        xtx = pt_x.T @ pt_x
        xtw = pt_x.T @ vec_w
        return _solve_sylvester(xtx, xtx, xtw - xtw.T)

    def _vertical_tangent(self, pt_x: Point, vec_w: Vector) -> Vector:
        """The 'vertical' part of the tangent space is the part that doesn't count as
        movement in the quotient space, i.e. within equivalence classes. For example,
        north/south movement where equivalence = lines of longitude.

        After the pre-shape tangent projection has removed shifts and scales, the space
        of 'vertical' tangents at x is {x @ A : A skew-symmetric}, i.e. infinitesimal
        rotations; it is spanned by the 2D rotations, one per pair of axes in our space.
        """
        return pt_x @ self._solve_skew_symmetric_vertical_tangent(pt_x, vec_w)

    def inner_product(self, pt_x: Point, vec_w: Vector, vec_v: Vector):
        """Pre-shape inner product of the horizontal parts of ``vec_w`` and ``vec_v``."""
        # Ensure that we're only measuring the 'horizontal' part of each tangent
        # vector. (We expect distance between two points to be equal to square root
        # norm of the logarithmic map between them).
        h_vec_w = self._horizontal_tangent(pt_x, vec_w)
        h_vec_v = self._horizontal_tangent(pt_x, vec_v)
        return super().inner_product(pt_x, h_vec_w, h_vec_v)

    def exp_map(self, pt_x: Point, vec_w: Vector) -> Point:
        """Apply the vertical part of ``vec_w`` as a rotation, then follow the horizontal part."""
        # Decompose into horizontal and vertical parts. The vertical part specifies a
        # rotation in the sense that skew-symmetric matrices are the tangent space of
        # SO(p), and the vertical part equals x @ A for some skew-symmetric matrix A.
        # We get from skew-symmetry to rotation using the matrix exponential,
        # i.e. rotation_matrix = matrix_exp(skew_symmetric_matrix)
        mat_a = self._solve_skew_symmetric_vertical_tangent(pt_x, vec_w)
        rotation = torch.matrix_exp(mat_a)
        horz_part = self._horizontal_tangent(pt_x, vec_w, vert_part=pt_x @ mat_a)
        # Apply vertical part, and note that rotation is equivariant with respect to
        # horizontal vectors, or horz_Rx(Rw)=Rhorz_x(w). This means that we (1)
        # rotate pt_x to pt_x', and (2) the new horizontal vector at pt_x' is equal
        # to the rotation applied to the original horizontal vector
        pt_x, horz_part = pt_x @ rotation, horz_part @ rotation
        # After applying the vertical part, delegate to the ambient PreShapeSpace for
        # the remaining horizontal part
        return super().exp_map(pt_x, horz_part)

    def log_map(self, pt_x: Point, pt_y: Point) -> Vector:
        """Horizontal tangent vector at ``pt_x`` pointing to the shape of ``pt_y``."""
        # Only returns *horizontal* part of the tangent. Note that this means log_map
        # and exp_map are not inverses from the perspective of the PreShapeSpace,
        # but they are in the ShapeSpace. In other words, if c=exp_map(x,log_map(x,y)),
        # then we'll have length(y,c)=0 but not y==c. Method: align y to x and get the
        # x-->y' horizontal part from the PreShapeSpace's log_map
        _, new_y = ShapeSpace.align(pt_x, pt_y, anchor="x")
        return super().log_map(pt_x, new_y)

    def levi_civita(self, pt_x: Point, pt_y: Point, vec_w: Vector) -> Vector:
        """Parallel-transport ``vec_w`` from ``pt_x`` to ``pt_y``."""
        # Both the horizontal and vertical parts of tangent vectors are equivariant
        # after rotation (Lemma 1b of Nava-Yazdani et al (2020)). This means we can
        # start by aligning x to y as follows to take care of the vertical part,
        # then all that's left is to transport the horizontal part:
        r_x, _ = ShapeSpace.orthogonal_procrustes_rotation(pt_x, pt_y, anchor="y")
        new_pt_x, new_vec_w = pt_x @ r_x, vec_w @ r_x
        return super().levi_civita(new_pt_x, pt_y, new_vec_w)
def _solve_sylvester(x, y, q):
    """Solve the Sylvester equation ``x @ a + a @ y = q`` for ``a``.

    The solve is delegated to ``scipy.linalg.solve_sylvester``, so the inputs are
    detached and copied to CPU numpy arrays: no gradient flows through the result.

    :param x: left coefficient matrix (torch tensor)
    :param y: right coefficient matrix (torch tensor)
    :param q: right-hand side (torch tensor)
    :return: solution tensor with the dtype and device of ``x``
    """
    # TODO - implement natively in pytorch so we don't have to convert to numpy on CPU and back again
    def as_numpy(tensor):
        return tensor.detach().cpu().numpy()

    solution = solve_sylvester(*(as_numpy(t) for t in (x, y, q)))
    return torch.tensor(solution, dtype=x.dtype, device=x.device)