-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Better treatment for numerical errors #330
Conversation
I think I actually prefer the approach in #327 (but throwing a proper error instead of just printing) because it doesn't split
I'm not a big fan of using |
Thanks for taking care of this btw :) |
I did the benchmarks using BenchmarkTools
function scalar_with_try_catch(n::Int, F::Float64, v::Float64)
llk = 0.0
HALF_LOG_2_PI = 0.5 * log(2*pi)
for i in 1:n
try
llk -= (
HALF_LOG_2_PI + 0.5 * (log(F) + v^2 / F)
)
catch
error("Numerical error, F is negative: $F")
end
end
return llk
end
function scalar_with_check(n::Int, F::Float64, v::Float64)
llk = 0.0
HALF_LOG_2_PI = 0.5 * log(2*pi)
for i in 1:n
if F < 0
error("Numerical error, F is negative: $F")
end
llk -= (
HALF_LOG_2_PI + 0.5 * (log(F) + v^2 / F)
)
end
return llk
end
n = 1000
F = 1.0
v = 1.0
julia> @btime scalar_with_try_catch($n, $F, $v)
23.400 μs (0 allocations: 0 bytes)
-1418.9385332046932
julia> @btime scalar_with_check($n, $F, $v)
5.714 μs (0 allocations: 0 bytes)
-1418.9385332046932
function vetorial_with_try_catch(n::Int, F::Matrix{Float64}, v::Vector{Float64})
llk = 0.0
HALF_LOG_2_PI = 0.5 * log(2*pi)
for i in 1:n
try
llk -=
HALF_LOG_2_PI + 0.5 * (logdet(F) +
v' * inv(F) * v)
catch
error("Numerical error, F is negative: $(F[i])")
end
end
return llk
end
function vetorial_with_check(n::Int, F::Matrix{Float64}, v::Vector{Float64})
llk = 0.0
HALF_LOG_2_PI = 0.5 * log(2*pi)
for i in 1:n
detF = det(F)
if detF < 0
error("Numerical error, F is negative: $(F[i])")
end
llk -=
HALF_LOG_2_PI + 0.5 * (log(detF) +
v' * inv(F) * v)
end
return llk
end
function vetorial_with_check_and_logdet(n::Int, F::Matrix{Float64}, v::Vector{Float64})
llk = 0.0
HALF_LOG_2_PI = 0.5 * log(2*pi)
for i in 1:n
detF = det(F)
if detF < 0
error("Numerical error, F is negative: $(F[i])")
end
llk -=
HALF_LOG_2_PI + 0.5 * (logdet(F) +
v' * inv(F) * v)
end
return llk
end
n = 1000
F = Matrix{Float64}(I, 3, 3)
v = ones(3)
julia> GC.gc()
julia> GC.gc()
julia> @btime vetorial_with_try_catch($n, $F, $v)
3.747 ms (4000 allocations: 218.75 KiB)
-2418.9385332047
julia> GC.gc()
julia> GC.gc()
julia> @btime vetorial_with_check($n, $F, $v) # might have been a little unlucky in this one
8.012 ms (6000 allocations: 296.88 KiB)
-2418.9385332047
julia> GC.gc()
julia> GC.gc()
julia> @btime vetorial_with_check_and_logdet($n, $F, $v)
6.133 ms (6000 allocations: 296.88 KiB)
-2418.9385332047 In the scalar case it makes more difference, not using the try catch, in the vetorial case my approach actually hurts performance. I found this post https://discourse.julialang.org/t/try-catch-statement-really-slow-even-if-catch-is-never-executed/91744/6 revealing that each try means a few nanoseconds. I am going to change it bakc to try catch statements |
What about moving the try block outside the for loop? It doesn't matter in which step it fails, it only matters that it fails:
|
This was a test to mimic the package behavior, realistically we would have to put the try-catch blocks in the filter recursions. function filter_recursions!(
kalman_state::MultivariateKalmanState{Fl},
sys::LinearMultivariateTimeInvariant,
steadystate_tol::Fl,
skip_llk_instants::Int,
) where Fl
RQR = sys.R * sys.Q * sys.R'
@inbounds for t in 1:size(sys.y, 1)
update_kalman_state!(
kalman_state,
sys.y[t, :],
sys.Z,
sys.T,
sys.H,
RQR,
sys.d,
sys.c,
skip_llk_instants,
steadystate_tol,
t,
)
end
return kalman_state.llk
end It is also a viable option. |
catch | ||
@error("Numerical error in the log-likelihood calculation. F = $(kalman_state.F), v = $(kalman_state.v). F can only be positive.") | ||
rethrow() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wondering if this is any different than just
catch e
@error("Numerical error... :, $e)
But looks good to me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this one does not stop the program, only logs to screen
julia> try; log(-10); catch; @error("oi"); rethrow() end
┌ Error: oi
└ @ Main REPL[102]:1
ERROR: DomainError with -10.0:
log was called with a negative real argument but will only return a complex result if called with a complex argument. Try log(Complex(x)).
Stacktrace:
[1] throw_complex_domainerror(f::Symbol, x::Float64)
@ Base.Math .\math.jl:33
[2] _log
@ .\special\log.jl:295 [inlined]
[3] log(x::Float64)
@ Base.Math .\special\log.jl:261
[4] log(x::Int64)
@ Base.Math .\math.jl:1531
[5] top-level scope
@ REPL[102]:1
julia> try; log(-10); catch ex; @error("oi", ex) end
┌ Error: oi
│ ex =
│ DomainError with -10.0:
│ log was called with a negative real argument but will only return a complex result if called with a complex argument. Try log(Complex(x)).
└ @ Main REPL[103]:1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, it's been a while since I programmed in julia... forgot @error
just logs! Thanks
@raphaelsaavedra I moved the try-catch blocks to filter recursions. I think this is the best way to go. If there is any error, it will report the entire Kalman state and tell where the error happened. |
replace #327
close #313