
Use Optim.jl directly for Optim solvers #129

Open
wants to merge 10 commits into main
2 changes: 2 additions & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "0.7.1"
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
@@ -38,6 +39,7 @@ TuringExt = ["DynamicPPL", "MCMCChains", "Turing"]
[compat]
Accessors = "0.1"
Distributions = "0.25"
DocStringExtensions = "0.9"
DynamicPPL = "0.20, 0.21"
Folds = "0.2"
ForwardDiff = "0.10"
3 changes: 3 additions & 0 deletions src/Pathfinder.jl
@@ -1,6 +1,7 @@
module Pathfinder

using Distributions: Distributions
using DocStringExtensions
using Folds: Folds
# ensure that ForwardDiff is conditionally loaded by Optimization
using ForwardDiff: ForwardDiff
@@ -36,6 +37,8 @@ end

include("transducers.jl")
include("woodbury.jl")
include("trace.jl")
include("callbacks.jl")
include("optimize.jl")
include("inverse_hessian.jl")
include("mvnormal.jl")
175 changes: 175 additions & 0 deletions src/callbacks.jl
@@ -0,0 +1,175 @@
# utility macro to annotate a field as const only if supported
@eval macro $(Symbol("const"))(x)
    if VERSION ≥ v"1.8"
        Expr(:const, esc(x))
    else
        esc(x)
    end
end

# callbacks for Optimization.jl

"""
$(TYPEDEF)

Abstract type for Optimization.jl callbacks.

A callback should be a callable with the signature

    (x, f(x), args...) -> Bool

where `x` is the parameter being optimized, and `f` is the objective function. A return
value of `true` signals that optimization should stop.

See the [Optimization.jl docs](https://docs.sciml.ai/Optimization/stable/API/solve/) for
more information.
"""
abstract type AbstractOptimizationCallback end
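To illustrate the interface (this snippet is not part of the diff; the type and field names are hypothetical), a callback that stops after a fixed number of objective evaluations might look like:

# Hypothetical example, not part of the PR:
mutable struct MaxEvalsCallback <: AbstractOptimizationCallback
    max_evals::Int  # maximum number of objective evaluations allowed
    evals::Int      # evaluations observed so far
end
MaxEvalsCallback(max_evals::Int) = MaxEvalsCallback(max_evals, 0)

function (cb::MaxEvalsCallback)(x, fx, args...)
    cb.evals += 1
    return cb.evals ≥ cb.max_evals  # returning `true` stops optimization
end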

"""
$(TYPEDEF)

A sequence of Optimization.jl callbacks to be executed in order.
"""
struct CallbackSequence{C<:Tuple} <: AbstractOptimizationCallback
"Tuple of Optimization.jl callbacks to be called in order"
callbacks::C
end

"""
$(SIGNATURES)

Wrap the sequence `callbacks`.
"""
CallbackSequence(callbacks...) = CallbackSequence(callbacks)

function (callback::CallbackSequence)(args...)
    return mapfoldl(|, callback.callbacks; init=false) do cb
        cb === nothing && return false
        return cb(args...)
    end
end

"""
$(TYPEDEF)

A callback that signals termination if `fail=true` and the objective value is `NaN` or `-Inf`.
"""
struct CheckFiniteValueCallback <: AbstractOptimizationCallback
"Whether to raise an error if the objective function is non-finite"
fail::Bool
end

function (cb::CheckFiniteValueCallback)(x, fx, args...)
    return cb.fail && (isnan(fx) || fx == -Inf)
end
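As an illustrative usage sketch (not part of this diff), combining callbacks behaves like a logical OR over their return values, with `nothing` entries skipped:

# Illustrative example, not part of the PR:
cb = CallbackSequence(CheckFiniteValueCallback(true), nothing)
cb([1.0, 2.0], 3.5)  # false: objective is finite, so optimization continues
cb([1.0, 2.0], NaN)  # true: non-finite objective requests termination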

"""
$(TYPEDEF)

A callback to fill an [`OptimizationTrace`](@ref).

!!! note
    Optimization.jl doesn't have an interface for accessing the gradient trace, so this
    callback recomputes the gradient.

# Fields

$(FIELDS)
"""
struct FillTraceCallback{G,T<:OptimizationTrace} <: AbstractOptimizationCallback
"A function to compute the gradient of the objective function"
grad::G
"An `OptimizationTrace` with empty vectors to be filled."
trace::T
end

function (cb::FillTraceCallback)(x, fx, args...)
    # NOTE: Optimization doesn't have an interface for accessing the gradient trace,
    # so we need to recompute it ourselves
    # see https://github.com/SciML/Optimization.jl/issues/149
    ∇fx = cb.grad(x)
    rmul!(∇fx, -1)

    trace = cb.trace
    # some backends mutate x, so we must copy it
    push!(trace.points, copy(x))
    push!(trace.log_densities, -fx)
    push!(trace.gradients, ∇fx)
    return false
end
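A hedged sketch of how this callback might be wired up (not part of the diff; it assumes `OptimizationTrace`, defined in the new `src/trace.jl` not shown here, is constructed positionally from vectors of points, log-densities, and gradients):

# Hedged sketch, not part of the PR; the OptimizationTrace constructor is assumed.
using ForwardDiff

logp(x) = -sum(abs2, x) / 2           # toy log-density
f(x) = -logp(x)                       # objective handed to the optimizer
grad(x) = ForwardDiff.gradient(f, x)  # gradient of the objective
trace = OptimizationTrace(Vector{Float64}[], Float64[], Vector{Float64}[])
cb = FillTraceCallback(grad, trace)
cb([1.0, 2.0], f([1.0, 2.0]))         # appends the point, its log-density, and ∇logp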

"""
$(TYPEDEF)

A callback to log progress with a `reporter`.

# Fields

$(FIELDS)
"""
Base.@kwdef mutable struct ProgressCallback{R,U} <: AbstractOptimizationCallback
"Reporter function, called with signature `report(progress_id, maxiters, try_id, iter_id)`"
@const reporter::R
"An identifier for the progress bar."
@const progress_id::U
"Maximum number of iterations"
@const maxiters::Int
"Index of the current try"
try_id::Int
"Index of the current iteration"
iter_id::Int
end

function (cb::ProgressCallback)(args...)
    reporter = cb.reporter
    reporter === nothing || reporter(cb.progress_id, cb.maxiters, cb.try_id, cb.iter_id)
    cb.iter_id += 1
    return false
end

"""
$(SIGNATURES)

Report progress using ProgressLogging.jl.
"""
function report_progress(progress_id, maxiters::Int, try_id::Int, iter_id::Int)
    Base.@logmsg ProgressLogging.ProgressLevel "Optimizing (try $(try_id))" progress =
        iter_id / maxiters _id = progress_id
    return nothing
end
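For illustration (not part of the diff), the progress callback could be constructed and driven like this; in practice the `progress_id` is supplied by ProgressLogging's machinery:

# Illustrative example, not part of the PR:
using UUIDs: uuid4

cb = ProgressCallback(;
    reporter=report_progress, progress_id=uuid4(), maxiters=100, try_id=1, iter_id=0
)
cb([1.0], 0.5)  # logs progress for iteration 0, then advances `iter_id` to 1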

# callbacks for Optim.jl

"""
$(TYPEDEF)

Abstract type for Optim.jl callbacks.

A callback should be a callable with the signature

    (states::Vector{<:AbstractOptimizerState}) -> Bool

where `states` is the vector of Optim.jl optimizer states generated so far (the most recent
state is `states[end]`). A return value of `true` signals that optimization should stop.
"""
abstract type AbstractOptimJLCallback end

"""
$(TYPEDEF)

Adaptor for an Optimization.jl callback to be an Optim.jl callback.
"""
struct OptimJLCallbackAdaptor{C} <: AbstractOptimJLCallback
"An Optimization.jl callback to be called."
callback::C
end

function (cb::OptimJLCallbackAdaptor)(states)
    state = states[end]
    md = state.metadata
    x = md["x"]
    fx = state.value
    return cb.callback(x, fx, md["g(x)"])
end
(::OptimJLCallbackAdaptor{Nothing})(states) = false
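A hedged usage sketch (not part of the diff): the adaptor reads `"x"` and `"g(x)"` from the state metadata, which Optim.jl only populates when `extended_trace=true`, so options along these lines would be needed:

# Hedged sketch, not part of the PR; assumes Optim.jl is loaded alongside Pathfinder.
using Optim

f(x) = sum(abs2, x)               # toy objective
stop_cb(x, fx, grad) = fx < 1e-8  # Optimization.jl-style callback: stop once f is tiny
options = Optim.Options(;
    callback=OptimJLCallbackAdaptor(stop_cb), extended_trace=true, store_trace=true
)
Optim.optimize(f, [1.0, 2.0], Optim.LBFGS(), options)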
22 changes: 15 additions & 7 deletions src/multipath.jl
@@ -123,14 +123,23 @@ for approximating expectations with respect to ``p``.
"""
function multipathfinder end

function multipathfinder(ℓ, ndraws::Int; input=ℓ, kwargs...)
function multipathfinder(
    ℓ,
    ndraws::Int;
    input=ℓ,
    history_length::Int=DEFAULT_HISTORY_LENGTH,
    optimizer=default_optimizer(history_length),
    kwargs...,
)
    _check_log_density_problem(ℓ)
    dim = LogDensityProblems.dimension(ℓ)
    optim_fun = build_optim_function(ℓ)
    return multipathfinder(optim_fun, ndraws; input, dim, kwargs...)
    optim_fun = build_optim_function(ℓ, optimizer)
    return multipathfinder(
        optim_fun, ndraws; input, dim, history_length, optimizer, kwargs...
    )
end
function multipathfinder(
    optim_fun::SciMLBase.OptimizationFunction,
    optim_fun::Union{SciMLBase.OptimizationFunction,OptimJLFunction},
    ndraws::Int;
    init=nothing,
    input=optim_fun,
@@ -145,9 +154,8 @@
    importance::Bool=true,
    kwargs...,
)
    if optim_fun.grad === nothing || optim_fun.grad isa Bool
    _defines_gradient(optim_fun) ||
        throw(ArgumentError("optimization function must define a gradient function."))
    end
    if init === nothing
        nruns > 0 || throw(
            ArgumentError("A positive `nruns` must be set or `init` must be provided.")
@@ -159,7 +167,7 @@
    if ndraws > ndraws_per_run * nruns
        @warn "More draws requested than total number of draws across replicas. Draws will not be unique."
    end
    logp(x) = -optim_fun.f(x, nothing)
    logp = get_logp(optim_fun)

    # run pathfinder independently from each starting point
    trans = Transducers.Map() do (init_i)
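Putting this together, the user-facing intent of the PR appears to be that an Optim.jl solver can be passed straight to `multipathfinder`; a hedged end-to-end sketch with a toy log-density problem (the keyword behaviour should be checked against the final API):

# Hedged sketch, not part of the PR; the toy problem and keyword choices are illustrative.
using LogDensityProblems, Optim, Pathfinder

struct IsoNormalProblem end  # standard normal log-density in 5 dimensions
LogDensityProblems.logdensity(::IsoNormalProblem, x) = -sum(abs2, x) / 2
LogDensityProblems.logdensity_and_gradient(::IsoNormalProblem, x) = (-sum(abs2, x) / 2, -x)
LogDensityProblems.dimension(::IsoNormalProblem) = 5
function LogDensityProblems.capabilities(::Type{IsoNormalProblem})
    return LogDensityProblems.LogDensityOrder{1}()
end

# Pass an Optim.jl optimizer directly; per this PR, Pathfinder should then drive Optim.jl
# without going through Optimization.jl's solve machinery.
result = multipathfinder(IsoNormalProblem(), 1_000; nruns=8, optimizer=Optim.LBFGS())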