From 652f6f9a55ae32ecee2bd00368f307a54bc92b9b Mon Sep 17 00:00:00 2001 From: Lezcano Date: Tue, 1 Mar 2022 18:59:20 -0300 Subject: [PATCH] Implement SL(n) (#31) --- README.rst | 6 +++ docs/source/invertibility/glp.rst | 4 +- docs/source/invertibility/sl.rst | 45 ++++++++++++++++ geotorch/__init__.py | 4 ++ geotorch/constraints.py | 42 ++++++++++++++- geotorch/fixedrank.py | 14 +++-- geotorch/lowrank.py | 5 +- geotorch/sl.py | 89 +++++++++++++++++++++++++++++++ test/test_integration.py | 3 ++ test/test_sl.py | 13 +++++ 10 files changed, 216 insertions(+), 9 deletions(-) create mode 100644 docs/source/invertibility/sl.rst create mode 100644 geotorch/sl.py create mode 100644 test/test_sl.py diff --git a/README.rst b/README.rst index f01e46f3..024ceb87 100644 --- a/README.rst +++ b/README.rst @@ -54,6 +54,7 @@ The following constraints are implemented and may be used as in the example abov - |grassmannian|_. Skew-symmetric matrices - |almost_orthogonal|_. Matrices with singular values in the interval ``[1-λ, 1+λ]`` - |invertible|_. Invertible matrices with positive determinant +- |sln|_. Matrices of determinant equal to ``1`` - |low_rank|_. Matrices of rank at most ``r`` - |fixed_rank|_. Matrices of rank ``r`` - |positive_definite|_. Positive definite matrices @@ -75,6 +76,8 @@ The following constraints are implemented and may be used as in the example abov .. _almost_orthogonal: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.almost_orthogonal .. |invertible| replace:: ``geotorch.invertible`` .. _invertible: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.invertible +.. |sln| replace:: ``geotorch.sln`` +.. _sln: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.sln .. |low_rank| replace:: ``geotorch.low_rank(r)`` .. _low_rank: https://geotorch.readthedocs.io/en/latest/constraints.html#geotorch.low_rank .. |fixed_rank| replace:: ``geotorch.fixed_rank(r)`` @@ -111,6 +114,7 @@ GeoTorch currently supports the following spaces: - |almost|_: Manifold of ``n×k`` matrices with singular values in the interval ``[1-λ, 1+λ]`` - |grass|_: Manifold of ``k``-dimensional subspaces in ``Rⁿ`` - |glp|_: Manifold of invertible ``n×n`` matrices with positive determinant +- |sl|_: Manifold of ``n×n`` matrices with determinant equal to `1` - |low|_: Variety of ``n×k`` matrices of rank ``r`` or less - |fixed|_: Manifold of ``n×k`` matrices of rank ``r`` - |psd|_: Cone of ``n×n`` symmetric positive definite matrices @@ -170,6 +174,8 @@ As every space in GeoTorch is, at its core, a map from a flat space into a manif .. _grass: https://geotorch.readthedocs.io/en/latest/orthogonal/grassmannian.html .. |glp| replace:: ``GLp(n)`` .. _glp: https://geotorch.readthedocs.io/en/latest/invertibility/glp.html +.. |sl| replace:: ``SL(n)`` +.. _sl: https://geotorch.readthedocs.io/en/latest/invertibility/sl.html .. |low| replace:: ``LowRank(n,k,r)`` .. _low: https://geotorch.readthedocs.io/en/latest/lowrank/lowrank.html .. |fixed| replace:: ``FixedRank(n,k,r)`` diff --git a/docs/source/invertibility/glp.rst b/docs/source/invertibility/glp.rst index db892776..171137dc 100644 --- a/docs/source/invertibility/glp.rst +++ b/docs/source/invertibility/glp.rst @@ -19,12 +19,12 @@ It is realized via an SVD-like factorization: (U, \Sigma, V) &\mapsto Uf(\Sigma)V^\intercal \end{align*} -where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{n \times n}`. The function :math:`f\colon \mathbb{R} \to (0, \infty)` is applied element-wise to the diagonal. By default, the `softmax` function is used +where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{n \times n}`. The function :math:`f\colon \mathbb{R} \to (0, \infty)` is applied element-wise to the diagonal. By default, the `softplus` function is used .. math:: \begin{align*} - \operatorname{softmax} \colon \mathbb{R} &\to (0, \infty) \\ + \operatorname{softplus} \colon \mathbb{R} &\to (0, \infty) \\ x &\mapsto \log(1+\exp(x)) + \varepsilon \end{align*} diff --git a/docs/source/invertibility/sl.rst b/docs/source/invertibility/sl.rst new file mode 100644 index 00000000..bd912fdb --- /dev/null +++ b/docs/source/invertibility/sl.rst @@ -0,0 +1,45 @@ +General Linear Group +==================== + +.. currentmodule:: geotorch + +:math:`\operatorname{SL}(n)` is the manifold of matrices of determinant equal to 1. + +.. math:: + + \operatorname{SL}(n) = \{X \in \mathbb{R}^{n\times n}\:\mid\:\det(X) = 1\} + +It is realized via an SVD-like factorization: + +.. math:: + + \begin{align*} + \pi \colon \operatorname{SO}(n) \times \mathbb{R}^n \times \operatorname{SO}(n) + &\to \operatorname{SL}(n) \\ + (U, \Sigma, V) &\mapsto Uf(\Sigma)V^\intercal + \end{align*} + +where we have identified the vector :math:`\Sigma` with a diagonal matrix in :math:`\mathbb{R}^{n \times n}`. The function :math:`f\colon \mathbb{R} \to (\varepsilon, \infty)` is applied element-wise to the diagonal for a small :math:`\varepsilon > 0`. By default, a combination of the `softplus` function + +.. math:: + + \begin{align*} + \operatorname{softplus} \colon \mathbb{R} &\to (\varepsilon, \infty) \\ + x &\mapsto \log(1+\exp(x)) + \varepsilon + \end{align*} + +composed with the normalization function + +.. math:: + + \begin{align*} + \operatorname{g} \colon \mathbb{R}^n &\to (\varepsilon, \infty)^n \\ + (x_1, \dots, x_n) &\mapsto \left(\frac{x_i}{\sqrt[\leftroot{-2}\uproot{2}n]{\prod_i x_i}}\right)_i + \end{align*} + +to ensure that the product of all the singular values is equal to 1. + +.. autoclass:: SL + + .. automethod:: sample + .. automethod:: in_manifold diff --git a/geotorch/__init__.py b/geotorch/__init__.py index 81685898..02943273 100644 --- a/geotorch/__init__.py +++ b/geotorch/__init__.py @@ -8,6 +8,7 @@ low_rank, fixed_rank, invertible, + sln, positive_definite, positive_semidefinite, positive_semidefinite_low_rank, @@ -25,6 +26,7 @@ from .lowrank import LowRank from .fixedrank import FixedRank from .glp import GLp +from .sl import SL from .psd import PSD from .pssd import PSSD from .pssdfixedrank import PSSDFixedRank @@ -48,6 +50,7 @@ "Stiefel", "AlmostOrthogonal", "GLp", + "SL", "FixedRank", "PSD", "PSSD", @@ -62,6 +65,7 @@ "fixed_rank", "almost_orthogonal", "invertible", + "sln", "positive_definite", "positive_semidefinite", "positive_semidefinite_low_rank", diff --git a/geotorch/constraints.py b/geotorch/constraints.py index 16efe44c..4535a97d 100644 --- a/geotorch/constraints.py +++ b/geotorch/constraints.py @@ -10,6 +10,7 @@ from .lowrank import LowRank from .fixedrank import FixedRank from .glp import GLp +from .sl import SL from .psd import PSD from .pssd import PSSD from .pssdlowrank import PSSDLowRank @@ -327,7 +328,7 @@ def invertible(module, tensor_name="weight", f="softplus", triv="expm"): Examples:: >>> layer = nn.Linear(20, 20) - >>> geotorch.invertible(layer, "weight", 5) + >>> geotorch.invertible(layer, "weight") >>> torch.det(layer.weight) > 0.0 True @@ -354,6 +355,45 @@ def invertible(module, tensor_name="weight", f="softplus", triv="expm"): return _register_manifold(module, tensor_name, GLp, f, triv) +def sln(module, tensor_name="weight", f="softplus", triv="expm"): + r"""Adds a constraint of having determinant one to the tensor ``module.tensor_name``. + + When accessing ``module.tensor_name``, the module will return the + parametrized version :math:`X` which will have determinant equal to 1. + + If the tensor has more than two dimensions, the parametrization will be + applied to the last two dimensions. + + Examples:: + + >>> layer = nn.Linear(20, 20) + >>> geotorch.sln(layer, "weight") + >>> torch.det(layer.weight) + tensor(1.0000) + + Args: + module (nn.Module): module on which to register the parametrization + tensor_name (string): name of the parameter, buffer, or parametrization + on which the parametrization will be applied. Default: ``"weight"`` + f (str or callable or pair of callables): Optional. Either: + + - ``"softplus"`` + + - A callable that maps real numbers to the interval :math:`(0, \infty)` + + - A pair of callables such that the first maps the real numbers to + :math:`(0, \infty)` and the second is a (right) inverse of the first + + Default: ``"softplus"`` + triv (str or callable): Optional. + A map that maps skew-symmetric matrices onto the orthogonal matrices + surjectively. This is used to optimize the :math:`U` and :math:`V` in the + SVD. It can be one of ``["expm", "cayley"]`` or a custom + callable. Default: ``"expm"`` + """ + return _register_manifold(module, tensor_name, SL, f, triv) + + def positive_definite(module, tensor_name="weight", f="softplus", triv="expm"): r"""Adds a positive definiteness constraint to the tensor ``module.tensor_name``. diff --git a/geotorch/fixedrank.py b/geotorch/fixedrank.py index 51fc3e17..36547f1a 100644 --- a/geotorch/fixedrank.py +++ b/geotorch/fixedrank.py @@ -89,7 +89,7 @@ def in_manifold_singular_values(self, S, eps=1e-5): infty_norm = D.abs().max(dim=-1).values return (infty_norm > eps).all().item() - def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6): + def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6, factorized=False): r""" Returns a randomly sampled matrix on the manifold by sampling a matrix according to ``init_`` and projecting it onto the manifold. @@ -117,7 +117,11 @@ def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6): with torch.no_grad(): # S >= 0, as given by torch.linalg.eigvalsh() S[S < eps] = eps - X = (U * S.unsqueeze(-2)) @ V.transpose(-2, -1) - if self.transposed: - X = X.transpose(-2, -1) - return X + if factorized: + return U, S, V + else: + # Compute U S V^T efficiently + if self.transposed: + return (U * S.unsqueeze(-2)) @ V.transpose(-2, -1) + else: + return U @ (S.unsqueeze(-1) * V.transpose(-2, -1)) diff --git a/geotorch/lowrank.py b/geotorch/lowrank.py index fde12af6..6664cb16 100644 --- a/geotorch/lowrank.py +++ b/geotorch/lowrank.py @@ -156,7 +156,10 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=False): U, S, Vt = torch.linalg.svd(X, full_matrices=False) U, S, Vt = U[..., : self.rank], S[..., : self.rank], Vt[..., : self.rank, :] if factorized: - return U, S, Vt.transpose(-2, -1) + if self.transposed: + return Vt.transpose(-2, -1), S, U + else: + return U, S, Vt.transpose(-2, -1) else: X = (U * S.unsqueeze(-2)) @ Vt if self.transposed: diff --git a/geotorch/sl.py b/geotorch/sl.py new file mode 100644 index 00000000..b9532a8a --- /dev/null +++ b/geotorch/sl.py @@ -0,0 +1,89 @@ +import torch +from .glp import GLp +from .fixedrank import FixedRank + + +class SL(GLp): + def __init__(self, size, f="softplus", triv="expm"): + r""" + Manifold of special linear matrices + + Args: + size (torch.size): Size of the tensor to be parametrized + f (str or callable or pair of callables): Optional. Either: + + - ``"softplus"`` + + - A callable that maps real numbers to the interval :math:`(0, \infty)` + + - A pair of callables such that the first maps the real numbers to + :math:`(0, \infty)` and the second is a (right) inverse of the first + + Default: ``"softplus"`` + triv (str or callable): Optional. + A map that maps skew-symmetric matrices onto the orthogonal matrices + surjectively. This is used to optimize the :math:`U` and :math:`V` in the + SVD. It can be one of ``["expm", "cayley"]`` or a custom + callable. Default: ``"expm"`` + """ + super().__init__(size, SL.parse_f(f), triv) + + @staticmethod + def parse_f(f_name): + if f_name in FixedRank.fs.keys(): + f, inv = FixedRank.parse_f(f_name) + + def f_sl(x): + y = f(x) + return y / y.prod(dim=-1, keepdim=True).pow(1.0 / y.shape[-1]) + + return (f_sl, inv) + else: + return f_name + + def in_manifold_singular_values(self, S, eps=5e-3): + if not super().in_manifold_singular_values(S, eps): + return False + # We compute the \infty-norm of the determinant minus 1 and should be about zero + infty_norm = (S.prod(dim=-1) - 1).abs().max(dim=-1).values + return (infty_norm < eps).all().item() + + def in_manifold(self, X, eps=5e-3): + r""" + Checks that a given matrix is in the manifold. + + Args: + X (torch.Tensor or tuple): The input matrix or matrices of shape ``(*, n, k)``. + eps (float): Optional. Threshold at which the singular values are + considered to be zero + Default: ``5e-3`` + """ + # The purpose of this function is just to have a more lax default eps value + return super().in_manifold(X, eps) + + def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6, factorized=False): + r""" + Returns a randomly sampled matrix on the manifold by sampling a matrix according + to ``init_`` and projecting it onto the manifold. + + The output of this method can be used to initialize a parametrized tensor + that has been parametrized with this or any other manifold as:: + + >>> layer = nn.Linear(20, 20) + >>> M = SL(layer.weight.size(), rank=6) + >>> geotorch.register_parametrization(layer, "weight", M) + >>> layer.weight = M.sample() + + Args: + init\_ (callable): Optional. A function that takes a tensor and fills it + in place according to some distribution. See + `torch.init `_. + Default: ``torch.nn.init.xavier_normal_`` + eps (float): Optional. Minimum singular value of the sampled matrix. + Default: ``5e-6`` + """ + U, S, V = super().sample(factorized=True, init_=init_) + with torch.no_grad(): + # S >= 0, as given by torch.linalg.eigvalsh() + S = S / S.prod(dim=-1, keepdim=True).pow(1.0 / S.shape[-1]) + return (U * S.unsqueeze(-2)) @ V.transpose(-2, -1) diff --git a/test/test_integration.py b/test/test_integration.py index 6af0291a..1943289a 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -20,6 +20,7 @@ from geotorch.pssdlowrank import PSSDLowRank from geotorch.pssdfixedrank import PSSDFixedRank from geotorch.glp import GLp +from geotorch.sl import SL from geotorch.almostorthogonal import AlmostOrthogonal from geotorch.sphere import Sphere, SphereEmbedded @@ -112,9 +113,11 @@ def test_psd_and_glp(self): PSD, PSSD, GLp, + SL, geotorch.positive_definite, geotorch.positive_semidefinite, geotorch.invertible, + geotorch.sln, ], [{}], [{}], diff --git a/test/test_sl.py b/test/test_sl.py new file mode 100644 index 00000000..58f0d0ba --- /dev/null +++ b/test/test_sl.py @@ -0,0 +1,13 @@ +from unittest import TestCase + +from geotorch.sl import SL + + +class TestSL(TestCase): + def test_SL_errors(self): + # Non square + with self.assertRaises(ValueError): + SL(size=(4, 3)) + # Try to instantiate it in a vector rather than a matrix + with self.assertRaises(ValueError): + SL(size=(5,))