-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
216 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
General Linear Group | ||
==================== | ||
|
||
.. currentmodule:: geotorch | ||
|
||
:math:`\operatorname{SL}(n)` is the manifold of matrices of determinant equal to 1. | ||
|
||
.. math:: | ||
\operatorname{SL}(n) = \{X \in \mathbb{R}^{n\times n}\:\mid\:\det(X) = 1\} | ||
It is realized via an SVD-like factorization: | ||
|
||
.. math:: | ||
\begin{align*} | ||
\pi \colon \operatorname{SO}(n) \times \mathbb{R}^n \times \operatorname{SO}(n) | ||
&\to \operatorname{SL}(n) \\ | ||
(U, \Sigma, V) &\mapsto Uf(\Sigma)V^\intercal | ||
\end{align*} | ||
where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{n \times n}`. The function :math:`f\colon \mathbb{R} \to (\varepsilon, \infty)` is applied element-wise to the diagonal for a small :math:`\varepsilon > 0`. By default, a combination of the `softplus` function | ||
|
||
.. math:: | ||
\begin{align*} | ||
\operatorname{softplus} \colon \mathbb{R} &\to (\varepsilon, \infty) \\ | ||
x &\mapsto \log(1+\exp(x)) + \varepsilon | ||
\end{align*} | ||
composed with the normalization function | ||
|
||
.. math:: | ||
\begin{align*} | ||
\operatorname{g} \colon \mathbb{R}^n &\to (\varepsilon, \infty)^n \\ | ||
(x_1, \dots, x_n) &\mapsto \left(\frac{x_i}{\sqrt[\leftroot{-2}\uproot{2}n]{\prod_i x_i}}\right)_i | ||
\end{align*} | ||
to ensure that the product of all the singular values is equal to 1. | ||
|
||
.. autoclass:: SL | ||
|
||
.. automethod:: sample | ||
.. automethod:: in_manifold |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import torch | ||
from .glp import GLp | ||
from .fixedrank import FixedRank | ||
|
||
|
||
class SL(GLp): | ||
def __init__(self, size, f="softplus", triv="expm"): | ||
r""" | ||
Manifold of special linear matrices | ||
Args: | ||
size (torch.size): Size of the tensor to be parametrized | ||
f (str or callable or pair of callables): Optional. Either: | ||
- ``"softplus"`` | ||
- A callable that maps real numbers to the interval :math:`(0, \infty)` | ||
- A pair of callables such that the first maps the real numbers to | ||
:math:`(0, \infty)` and the second is a (right) inverse of the first | ||
Default: ``"softplus"`` | ||
triv (str or callable): Optional. | ||
A map that maps skew-symmetric matrices onto the orthogonal matrices | ||
surjectively. This is used to optimize the :math:`U` and :math:`V` in the | ||
SVD. It can be one of ``["expm", "cayley"]`` or a custom | ||
callable. Default: ``"expm"`` | ||
""" | ||
super().__init__(size, SL.parse_f(f), triv) | ||
|
||
@staticmethod | ||
def parse_f(f_name): | ||
if f_name in FixedRank.fs.keys(): | ||
f, inv = FixedRank.parse_f(f_name) | ||
|
||
def f_sl(x): | ||
y = f(x) | ||
return y / y.prod(dim=-1, keepdim=True).pow(1.0 / y.shape[-1]) | ||
|
||
return (f_sl, inv) | ||
else: | ||
return f_name | ||
|
||
def in_manifold_singular_values(self, S, eps=5e-3): | ||
if not super().in_manifold_singular_values(S, eps): | ||
return False | ||
# We compute the \infty-norm of the determinant minus 1 and should be about zero | ||
infty_norm = (S.prod(dim=-1) - 1).abs().max(dim=-1).values | ||
return (infty_norm < eps).all().item() | ||
|
||
def in_manifold(self, X, eps=5e-3): | ||
r""" | ||
Checks that a given matrix is in the manifold. | ||
Args: | ||
X (torch.Tensor or tuple): The input matrix or matrices of shape ``(*, n, k)``. | ||
eps (float): Optional. Threshold at which the singular values are | ||
considered to be zero | ||
Default: ``5e-3`` | ||
""" | ||
# The purpose of this function is just to have a more lax default eps value | ||
return super().in_manifold(X, eps) | ||
|
||
def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6, factorized=False): | ||
r""" | ||
Returns a randomly sampled matrix on the manifold by sampling a matrix according | ||
to ``init_`` and projecting it onto the manifold. | ||
The output of this method can be used to initialize a parametrized tensor | ||
that has been parametrized with this or any other manifold as:: | ||
>>> layer = nn.Linear(20, 20) | ||
>>> M = SL(layer.weight.size(), rank=6) | ||
>>> geotorch.register_parametrization(layer, "weight", M) | ||
>>> layer.weight = M.sample() | ||
Args: | ||
init\_ (callable): Optional. A function that takes a tensor and fills it | ||
in place according to some distribution. See | ||
`torch.init <https://pytorch.org/docs/stable/nn.init.html>`_. | ||
Default: ``torch.nn.init.xavier_normal_`` | ||
eps (float): Optional. Minimum singular value of the sampled matrix. | ||
Default: ``5e-6`` | ||
""" | ||
U, S, V = super().sample(factorized=True, init_=init_) | ||
with torch.no_grad(): | ||
# S >= 0, as given by torch.linalg.eigvalsh() | ||
S = S / S.prod(dim=-1, keepdim=True).pow(1.0 / S.shape[-1]) | ||
return (U * S.unsqueeze(-2)) @ V.transpose(-2, -1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from unittest import TestCase | ||
|
||
from geotorch.sl import SL | ||
|
||
|
||
class TestSL(TestCase): | ||
def test_SL_errors(self): | ||
# Non square | ||
with self.assertRaises(ValueError): | ||
SL(size=(4, 3)) | ||
# Try to instantiate it in a vector rather than a matrix | ||
with self.assertRaises(ValueError): | ||
SL(size=(5,)) |