Skip to content

Commit

Permalink
refined error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Jan 24, 2025
1 parent 5954be2 commit 18196d1
Showing 1 changed file with 5 additions and 16 deletions.
21 changes: 5 additions & 16 deletions src/MarkovChainMonteCarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,20 +219,14 @@ end
## -------- Autodifferentiation procedures ------ ##

function autodiff_gradient(model::AdvancedMH.DensityModel, params, autodiff_protocol)
if autodiff_protocol == GradFreeProtocol
throw(
ArgumentError(
"Calling gradient on a Sampling with `GradFreeProtocol`. \n Please construct the `XYZSampling{<:AutodiffProtocol}()` object with `AutodiffProtocol` from the implemented options that is compatible with the chosen Emulator (e.g., `AbstractGP` emulators are compatible with `ForwardDiffProtocol`).",
),
)
elseif autodiff_protocol == ForwardDiffProtocol
if autodiff_protocol == ForwardDiffProtocol
return ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), params)
elseif autodiff_protocol == ReverseDiffProtocol
return ReverseDiff.gradient(x -> AdvancedMH.logdensity(model, x), params)
else
throw(
ArgumentError(
"autodifferentiation protocol $(autodiff_protocol) has no `autodiff_gradient` method implemented.",
"Calling `autodiff_gradient(...)` on a sampler with protocol $(autodiff_protocol) that has *no* gradient implementation.\n Please select from a protocol with a gradient implementation (e.g., `ForwardDiffProtocol`).",
),
)
end
Expand All @@ -243,22 +237,17 @@ autodiff_gradient(model::AdvancedMH.DensityModel, params, sampler::MH) where {MH
autodiff_gradient(model::AdvancedMH.DensityModel, params, typeof(sampler).parameters[2]) # hacky way of getting the "AutodiffProtocol"

function autodiff_hessian(model::AdvancedMH.DensityModel, params, autodiff_protocol)
if autodiff_protocol == GradFreeProtocol
throw(
ArgumentError(
"Calling hessian on a Sampling with `GradFreeProtocol`. \n Please construct the `XYZSampling{<:AutodiffProtocol}()` object with `AutodiffProtocol` from the implemented options that is compatible with the chosen Emulator (e.g., `AbstractGP` emulators are compatible with `ForwardDiffProtocol`).",
),
)
elseif autodiff_protocol == ForwardDiffProtocol
if autodiff_protocol == ForwardDiffProtocol
return Symmetric(ForwardDiff.hessian(x -> AdvancedMH.logdensity(model, x), params))
elseif autodiff_protocol == ReverseDiffProtocol
return Symmetric(ReverseDiff.hessian(x -> AdvancedMH.logdensity(model, x), params))

Check warning on line 243 in src/MarkovChainMonteCarlo.jl

View check run for this annotation

Codecov / codecov/patch

src/MarkovChainMonteCarlo.jl#L239-L243

Added lines #L239 - L243 were not covered by tests
else
throw(

Check warning on line 245 in src/MarkovChainMonteCarlo.jl

View check run for this annotation

Codecov / codecov/patch

src/MarkovChainMonteCarlo.jl#L245

Added line #L245 was not covered by tests
ArgumentError(
"autodifferentiation protocol $(autodiff_protocol) has no `autodiff_hessian` method implemented.",
"Calling `autodiff_hessian(...)` on a sampler with protocol $(autodiff_protocol) that has *no* hessian implementation.\n Please select from a protocol with a hessian implementation (e.g., `ForwardDiffProtocol`).",
),
)

end
end

Expand Down

0 comments on commit 18196d1

Please sign in to comment.