Some more consistency changes to subgrad_distance (#32)
* Fix doc typo and make functions type-proof
* Make statement more precise
* Export subgrad_distance
* Run formatter
* Bump version
* Change eltype to number_eltype
hajg-ijk authored Sep 22, 2023
1 parent cf4a488 commit ef69a8e
Showing 4 changed files with 16 additions and 11 deletions.
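A note on the `eltype` to `number_eltype` item in the message above: `number_eltype` (from ManifoldsBase.jl, which this package builds on) returns the underlying scalar type even when a point is stored as a nested array, whereas plain `eltype` stops at the inner array type. A minimal sketch of the difference, assuming ManifoldsBase.jl is available:

```julia
using ManifoldsBase: number_eltype

# For a flat array the two notions agree:
eltype([1.0, 2.0])        # Float64
number_eltype([1.0, 2.0]) # Float64

# For a point stored as an array of arrays (e.g. on a product manifold),
# eltype yields the inner array type, while number_eltype recurses to the scalar:
p = [[1.0, 0.0], [0.0, 1.0]]
eltype(p)        # Vector{Float64}
number_eltype(p) # Float64
```

This is why expressions like `one(number_eltype(X))` in the diff below stay scalar-valued for any point representation.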
Project.toml (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
name = "ManifoldDiff"
uuid = "af67fdf4-a580-4b9f-bbec-742ef357defd"
authors = ["Seth Axen <[email protected]>", "Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>"]
version = "0.3.6"
version = "0.3.7"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
docs/src/index.md (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@ Providing a derivative, differential or gradient for a given function, this pack
For example

* `grad_f` for a gradient ``\operatorname{grad} f``
-* `subgrad_f` for a subgradient ``\partial f``
+* `subgrad_f` for a subgradient from the subdifferential ``\partial f``
* `differential_f` for ``Df`` (also called pushforward)
* `differential_f_variable` if `f` has multiple variables / parameters, since a usual writing in math is ``f_x`` in this case
* `adjoint_differential_f` for pullbacks
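To make the naming scheme above concrete, a small sketch using two functions from this package; the manifold and points are assumptions for illustration, with the sphere supplied by Manifolds.jl:

```julia
using Manifolds, ManifoldDiff

M = Sphere(2)
p = [1.0, 0.0, 0.0]
q = [0.0, 1.0, 0.0]

# `grad_f` pattern: the Riemannian gradient, here of f(p) = d(p, q)^2 / 2
X = ManifoldDiff.grad_distance(M, q, p)

# `subgrad_f` pattern: an element of the subdifferential ∂f(p), here for
# the distance itself (c = 1), which is nonsmooth at p = q
Y = ManifoldDiff.subgrad_distance(M, q, p, 1)
```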
src/ManifoldDiff.jl (7 changes: 6 additions & 1 deletion)
@@ -241,6 +241,11 @@ include("forward_diff.jl")
include("reverse_diff.jl")
include("zygote.jl")

-export riemannian_gradient, riemannian_gradient!, riemannian_Hessian, riemannian_Hessian!
+export riemannian_gradient,
+riemannian_gradient!,
+riemannian_Hessian,
+riemannian_Hessian!,
+subgrad_distance,
+subgrad_distance!
export set_default_differential_backend!, default_differential_backend
end # module
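The practical effect of the extended export list is that the two subgradient functions no longer need to be qualified. A hypothetical check:

```julia
using ManifoldDiff

# The unqualified names now resolve to the package functions;
# before this commit they required qualification or an explicit import.
@assert subgrad_distance === ManifoldDiff.subgrad_distance
@assert subgrad_distance! === ManifoldDiff.subgrad_distance!
```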
src/subgradients.jl (16 changes: 8 additions & 8 deletions)
@@ -1,7 +1,7 @@

@doc raw"""
-subgrad_distance(M, q, p[, c = 1; atol = 0])
-subgrad_distance!(M, X, q, p[, c = 1; atol = 0])
+subgrad_distance(M, q, p[, c = 2; atol = 0])
+subgrad_distance!(M, X, q, p[, c = 2; atol = 0])
compute the subgradient of the distance (in place of `X`)
@@ -21,10 +21,10 @@ for ``c\neq 1`` or ``p\neq q``. Note that for the remaining case ``c=1``,
# Optional
-* `c` – (`1`) the exponent of the distance, i.e. the default is the distance
+* `c` – (`2`) the exponent of the distance, i.e. the default is the distance
* `atol` – (`0`) the tolerance to use when evaluating the distance between `p` and `q`.
"""
-function subgrad_distance(M, q, p, c::Int = 2; atol = 0)
+function subgrad_distance(M, q, p, c::Int = 2; atol = zero(number_eltype(p)))
if c == 2
return -log(M, p, q)
elseif c == 1 && distance(M, q, p) ≤ atol
@@ -33,10 +33,10 @@ function subgrad_distance(M, q, p, c::Int = 2; atol = 0)
return -distance(M, p, q)^(c - 2) * log(M, p, q)
end
end
-function subgrad_distance!(M, X, q, p, c::Int = 2; atol = 0)
+function subgrad_distance!(M, X, q, p, c::Int = 2; atol = zero(number_eltype(p)))
log!(M, X, p, q)
if c == 2
-X .*= -one(eltype(X))
+X .*= -one(number_eltype(X))
elseif c == 1 && distance(M, q, p) ≤ atol
normal_cone_vector!(M, X, p)
else
Expand All @@ -46,15 +46,15 @@ function subgrad_distance!(M, X, q, p, c::Int = 2; atol = 0)
end
function normal_cone_vector(M, p)
Y = rand(M; vector_at = p)
-if norm(M, p, Y) > 1.0
+if norm(M, p, Y) > one(number_eltype(Y))
Y ./= norm(M, p, Y)
Y .*= rand()
end
return Y
end
function normal_cone_vector!(M, Y, p)
ManifoldsBase.rand!(M, Y; vector_at = p)
-if norm(M, p, Y) > 1.0
+if norm(M, p, Y) > one(number_eltype(Y))
Y ./= norm(M, p, Y)
Y .*= rand()
end
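Taken together, a minimal usage sketch of the changed functions; the sphere and points are assumptions for illustration, and per the docstring above, `subgrad_distance(M, q, p, c)` returns a subgradient of `p ↦ d(p, q)^c / c`:

```julia
using Manifolds, ManifoldDiff

M = Sphere(2)
p = [1.0, 0.0, 0.0]
q = [0.0, 1.0, 0.0]

# Default exponent c = 2: the function is smooth and the subgradient
# is the ordinary gradient -log(M, p, q).
X = subgrad_distance(M, q, p)

# c = 1 with q == p: distance(M, p, p) ≤ atol triggers the normal-cone
# branch, returning a random tangent vector of norm at most one,
# i.e. an element of the subdifferential at the minimizer.
Y = subgrad_distance(M, p, p, 1)

# The in-place variant writes into a preallocated tangent vector.
Z = zero_vector(M, p)
subgrad_distance!(M, Z, q, p)
```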
