Skip to content

Commit

Permalink
Implement SL(n) (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
lezcano authored Mar 1, 2022
1 parent 0d69425 commit 652f6f9
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 9 deletions.
6 changes: 6 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ The following constraints are implemented and may be used as in the example abov
- |grassmannian|_. Skew-symmetric matrices
- |almost_orthogonal|_. Matrices with singular values in the interval ``[1-λ, 1+λ]``
- |invertible|_. Invertible matrices with positive determinant
- |sln|_. Matrices of determinant equal to ``1``
- |low_rank|_. Matrices of rank at most ``r``
- |fixed_rank|_. Matrices of rank ``r``
- |positive_definite|_. Positive definite matrices
Expand All @@ -75,6 +76,8 @@ The following constraints are implemented and may be used as in the example abov
.. _almost_orthogonal: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.almost_orthogonal
.. |invertible| replace:: ``geotorch.invertible``
.. _invertible: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.invertible
.. |sln| replace:: ``geotorch.sln``
.. _sln: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.sln
.. |low_rank| replace:: ``geotorch.low_rank(r)``
.. _low_rank: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.low_rank
.. |fixed_rank| replace:: ``geotorch.fixed_rank(r)``
Expand Down Expand Up @@ -111,6 +114,7 @@ GeoTorch currently supports the following spaces:
- |almost|_: Manifold of ``n×k`` matrices with singular values in the interval ``[1-λ, 1+λ]``
- |grass|_: Manifold of ``k``-dimensional subspaces in ``Rⁿ``
- |glp|_: Manifold of invertible ``n×n`` matrices with positive determinant
- |sl|_: Manifold of ``n×n`` matrices with determinant equal to `1`
- |low|_: Variety of ``n×k`` matrices of rank ``r`` or less
- |fixed|_: Manifold of ``n×k`` matrices of rank ``r``
- |psd|_: Cone of ``n×n`` symmetric positive definite matrices
Expand Down Expand Up @@ -170,6 +174,8 @@ As every space in GeoTorch is, at its core, a map from a flat space into a manif
.. _grass: https://geotorch.readthedocs.io/en/latest/orthogonal/grassmannian.html
.. |glp| replace:: ``GLp(n)``
.. _glp: https://geotorch.readthedocs.io/en/latest/invertibility/glp.html
.. |sl| replace:: ``SL(n)``
.. _sl: https://geotorch.readthedocs.io/en/latest/invertibility/sl.html
.. |low| replace:: ``LowRank(n,k,r)``
.. _low: https://geotorch.readthedocs.io/en/latest/lowrank/lowrank.html
.. |fixed| replace:: ``FixedRank(n,k,r)``
Expand Down
4 changes: 2 additions & 2 deletions docs/source/invertibility/glp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ It is realized via an SVD-like factorization:
(U, \Sigma, V) &\mapsto Uf(\Sigma)V^\intercal
\end{align*}
where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{n \times n}`. The function :math:`f\colon \mathbb{R} \to (0, \infty)` is applied element-wise to the diagonal. By default, the `softmax` function is used
where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{n \times n}`. The function :math:`f\colon \mathbb{R} \to (0, \infty)` is applied element-wise to the diagonal. By default, the `softplus` function is used

.. math::
\begin{align*}
\operatorname{softmax} \colon \mathbb{R} &\to (0, \infty) \\
\operatorname{softplus} \colon \mathbb{R} &\to (0, \infty) \\
x &\mapsto \log(1+\exp(x)) + \varepsilon
\end{align*}
Expand Down
45 changes: 45 additions & 0 deletions docs/source/invertibility/sl.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
General Linear Group
====================

.. currentmodule:: geotorch

:math:`\operatorname{SL}(n)` is the manifold of matrices of determinant equal to 1.

.. math::
\operatorname{SL}(n) = \{X \in \mathbb{R}^{n\times n}\:\mid\:\det(X) = 1\}
It is realized via an SVD-like factorization:

.. math::
\begin{align*}
\pi \colon \operatorname{SO}(n) \times \mathbb{R}^n \times \operatorname{SO}(n)
&\to \operatorname{SL}(n) \\
(U, \Sigma, V) &\mapsto Uf(\Sigma)V^\intercal
\end{align*}
where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{n \times n}`. The function :math:`f\colon \mathbb{R} \to (\varepsilon, \infty)` is applied element-wise to the diagonal for a small :math:`\varepsilon > 0`. By default, a combination of the `softplus` function

.. math::
\begin{align*}
\operatorname{softplus} \colon \mathbb{R} &\to (\varepsilon, \infty) \\
x &\mapsto \log(1+\exp(x)) + \varepsilon
\end{align*}
composed with the normalization function

.. math::
\begin{align*}
\operatorname{g} \colon \mathbb{R}^n &\to (\varepsilon, \infty)^n \\
(x_1, \dots, x_n) &\mapsto \left(\frac{x_i}{\sqrt[\leftroot{-2}\uproot{2}n]{\prod_i x_i}}\right)_i
\end{align*}
to ensure that the product of all the singular values is equal to 1.

.. autoclass:: SL

.. automethod:: sample
.. automethod:: in_manifold
4 changes: 4 additions & 0 deletions geotorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
low_rank,
fixed_rank,
invertible,
sln,
positive_definite,
positive_semidefinite,
positive_semidefinite_low_rank,
Expand All @@ -25,6 +26,7 @@
from .lowrank import LowRank
from .fixedrank import FixedRank
from .glp import GLp
from .sl import SL
from .psd import PSD
from .pssd import PSSD
from .pssdfixedrank import PSSDFixedRank
Expand All @@ -48,6 +50,7 @@
"Stiefel",
"AlmostOrthogonal",
"GLp",
"SL",
"FixedRank",
"PSD",
"PSSD",
Expand All @@ -62,6 +65,7 @@
"fixed_rank",
"almost_orthogonal",
"invertible",
"sln",
"positive_definite",
"positive_semidefinite",
"positive_semidefinite_low_rank",
Expand Down
42 changes: 41 additions & 1 deletion geotorch/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .lowrank import LowRank
from .fixedrank import FixedRank
from .glp import GLp
from .sl import SL
from .psd import PSD
from .pssd import PSSD
from .pssdlowrank import PSSDLowRank
Expand Down Expand Up @@ -327,7 +328,7 @@ def invertible(module, tensor_name="weight", f="softplus", triv="expm"):
Examples::
>>> layer = nn.Linear(20, 20)
>>> geotorch.invertible(layer, "weight", 5)
>>> geotorch.invertible(layer, "weight")
>>> torch.det(layer.weight) > 0.0
True
Expand All @@ -354,6 +355,45 @@ def invertible(module, tensor_name="weight", f="softplus", triv="expm"):
return _register_manifold(module, tensor_name, GLp, f, triv)


def sln(module, tensor_name="weight", f="softplus", triv="expm"):
r"""Adds a constraint of having determinant one to the tensor ``module.tensor_name``.
When accessing ``module.tensor_name``, the module will return the
parametrized version :math:`X` which will have determinant equal to 1.
If the tensor has more than two dimensions, the parametrization will be
applied to the last two dimensions.
Examples::
>>> layer = nn.Linear(20, 20)
>>> geotorch.sln(layer, "weight")
>>> torch.det(layer.weight)
tensor(1.0000)
Args:
module (nn.Module): module on which to register the parametrization
tensor_name (string): name of the parameter, buffer, or parametrization
on which the parametrization will be applied. Default: ``"weight"``
f (str or callable or pair of callables): Optional. Either:
- ``"softplus"``
- A callable that maps real numbers to the interval :math:`(0, \infty)`
- A pair of callables such that the first maps the real numbers to
:math:`(0, \infty)` and the second is a (right) inverse of the first
Default: ``"softplus"``
triv (str or callable): Optional.
A map that maps skew-symmetric matrices onto the orthogonal matrices
surjectively. This is used to optimize the :math:`U` and :math:`V` in the
SVD. It can be one of ``["expm", "cayley"]`` or a custom
callable. Default: ``"expm"``
"""
return _register_manifold(module, tensor_name, SL, f, triv)


def positive_definite(module, tensor_name="weight", f="softplus", triv="expm"):
r"""Adds a positive definiteness constraint to the tensor ``module.tensor_name``.
Expand Down
14 changes: 9 additions & 5 deletions geotorch/fixedrank.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def in_manifold_singular_values(self, S, eps=1e-5):
infty_norm = D.abs().max(dim=-1).values
return (infty_norm > eps).all().item()

def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6):
def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6, factorized=False):
r"""
Returns a randomly sampled matrix on the manifold by sampling a matrix according
to ``init_`` and projecting it onto the manifold.
Expand Down Expand Up @@ -117,7 +117,11 @@ def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6):
with torch.no_grad():
# S >= 0, as given by torch.linalg.eigvalsh()
S[S < eps] = eps
X = (U * S.unsqueeze(-2)) @ V.transpose(-2, -1)
if self.transposed:
X = X.transpose(-2, -1)
return X
if factorized:
return U, S, V
else:
# Compute U S V^T efficiently
if self.transposed:
return (U * S.unsqueeze(-2)) @ V.transpose(-2, -1)
else:
return U @ (S.unsqueeze(-1) * V.transpose(-2, -1))
5 changes: 4 additions & 1 deletion geotorch/lowrank.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=False):
U, S, Vt = torch.linalg.svd(X, full_matrices=False)
U, S, Vt = U[..., : self.rank], S[..., : self.rank], Vt[..., : self.rank, :]
if factorized:
return U, S, Vt.transpose(-2, -1)
if self.transposed:
return Vt.transpose(-2, -1), S, U
else:
return U, S, Vt.transpose(-2, -1)
else:
X = (U * S.unsqueeze(-2)) @ Vt
if self.transposed:
Expand Down
89 changes: 89 additions & 0 deletions geotorch/sl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import torch
from .glp import GLp
from .fixedrank import FixedRank


class SL(GLp):
def __init__(self, size, f="softplus", triv="expm"):
r"""
Manifold of special linear matrices
Args:
size (torch.size): Size of the tensor to be parametrized
f (str or callable or pair of callables): Optional. Either:
- ``"softplus"``
- A callable that maps real numbers to the interval :math:`(0, \infty)`
- A pair of callables such that the first maps the real numbers to
:math:`(0, \infty)` and the second is a (right) inverse of the first
Default: ``"softplus"``
triv (str or callable): Optional.
A map that maps skew-symmetric matrices onto the orthogonal matrices
surjectively. This is used to optimize the :math:`U` and :math:`V` in the
SVD. It can be one of ``["expm", "cayley"]`` or a custom
callable. Default: ``"expm"``
"""
super().__init__(size, SL.parse_f(f), triv)

@staticmethod
def parse_f(f_name):
if f_name in FixedRank.fs.keys():
f, inv = FixedRank.parse_f(f_name)

def f_sl(x):
y = f(x)
return y / y.prod(dim=-1, keepdim=True).pow(1.0 / y.shape[-1])

return (f_sl, inv)
else:
return f_name

def in_manifold_singular_values(self, S, eps=5e-3):
if not super().in_manifold_singular_values(S, eps):
return False
# We compute the \infty-norm of the determinant minus 1 and should be about zero
infty_norm = (S.prod(dim=-1) - 1).abs().max(dim=-1).values
return (infty_norm < eps).all().item()

def in_manifold(self, X, eps=5e-3):
r"""
Checks that a given matrix is in the manifold.
Args:
X (torch.Tensor or tuple): The input matrix or matrices of shape ``(*, n, k)``.
eps (float): Optional. Threshold at which the singular values are
considered to be zero
Default: ``5e-3``
"""
# The purpose of this function is just to have a more lax default eps value
return super().in_manifold(X, eps)

def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6, factorized=False):
r"""
Returns a randomly sampled matrix on the manifold by sampling a matrix according
to ``init_`` and projecting it onto the manifold.
The output of this method can be used to initialize a parametrized tensor
that has been parametrized with this or any other manifold as::
>>> layer = nn.Linear(20, 20)
>>> M = SL(layer.weight.size(), rank=6)
>>> geotorch.register_parametrization(layer, "weight", M)
>>> layer.weight = M.sample()
Args:
init\_ (callable): Optional. A function that takes a tensor and fills it
in place according to some distribution. See
`torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
Default: ``torch.nn.init.xavier_normal_``
eps (float): Optional. Minimum singular value of the sampled matrix.
Default: ``5e-6``
"""
U, S, V = super().sample(factorized=True, init_=init_)
with torch.no_grad():
# S >= 0, as given by torch.linalg.eigvalsh()
S = S / S.prod(dim=-1, keepdim=True).pow(1.0 / S.shape[-1])
return (U * S.unsqueeze(-2)) @ V.transpose(-2, -1)
3 changes: 3 additions & 0 deletions test/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from geotorch.pssdlowrank import PSSDLowRank
from geotorch.pssdfixedrank import PSSDFixedRank
from geotorch.glp import GLp
from geotorch.sl import SL
from geotorch.almostorthogonal import AlmostOrthogonal
from geotorch.sphere import Sphere, SphereEmbedded

Expand Down Expand Up @@ -112,9 +113,11 @@ def test_psd_and_glp(self):
PSD,
PSSD,
GLp,
SL,
geotorch.positive_definite,
geotorch.positive_semidefinite,
geotorch.invertible,
geotorch.sln,
],
[{}],
[{}],
Expand Down
13 changes: 13 additions & 0 deletions test/test_sl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from unittest import TestCase

from geotorch.sl import SL


class TestSL(TestCase):
def test_SL_errors(self):
# Non square
with self.assertRaises(ValueError):
SL(size=(4, 3))
# Try to instantiate it in a vector rather than a matrix
with self.assertRaises(ValueError):
SL(size=(5,))

0 comments on commit 652f6f9

Please sign in to comment.