From 124915fce203925b69fa1a295a2ab3025cbe3f3c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 7 May 2024 23:00:50 +0200 Subject: [PATCH] Add `erfinv` and `erfcinv` for `Float16` and generalize `logerfc` and `logerfcx` (#372) * Add `erfinv` and `erfcinv` for `Float16` * Generalize `logerfc` and `logerfcx` * Add tests * Update version number * Update test/erf.jl * Fix test * Simplify branch for `abs(x) >= 1` in `_erfinv` * Fix incomplete function * Simplify implementation * Apply suggested alternative --------- Co-authored-by: Viral B. Shah --- Project.toml | 2 +- src/erf.jl | 89 +++++++++++++++++++++++++++++++++++-------------- test/erf.jl | 93 +++++++++++++++++++++------------------------------- 3 files changed, 104 insertions(+), 80 deletions(-) diff --git a/Project.toml b/Project.toml index 6b128971..245e52ce 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SpecialFunctions" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" +version = "2.4.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/erf.jl b/src/erf.jl index 97dafcb7..dd04e904 100644 --- a/src/erf.jl +++ b/src/erf.jl @@ -253,13 +253,10 @@ erfinv(x::Real) = _erfinv(float(x)) function _erfinv(x::Float64) a = abs(x) - if a >= 1.0 - if x == 1.0 - return Inf - elseif x == -1.0 - return -Inf - end + if a > 1.0 throw(DomainError(a, "`abs(x)` cannot be greater than 1.")) + elseif a == 1.0 + return copysign(Inf, x) elseif a <= 0.75 # Table 17 in Blair et al. t = x*x - 0.5625 return x * @horner(t, 0.16030_49558_44066_229311e2, @@ -321,13 +318,10 @@ end function _erfinv(x::Float32) a = abs(x) - if a >= 1.0f0 - if x == 1.0f0 - return Inf32 - elseif x == -1.0f0 - return -Inf32 - end + if a > 1f0 throw(DomainError(a, "`abs(x)` cannot be greater than 1.")) + elseif a == 1f0 + return copysign(Inf32, x) elseif a <= 0.75f0 # Table 10 in Blair et al. t = x*x - 0.5625f0 return x * @horner(t, -0.13095_99674_22f2, @@ -362,6 +356,42 @@ function _erfinv(x::Float32) end end +function _erfinv(x::Float16) + a = abs(x) + if a > Float16(1) + throw(DomainError(a, "`abs(x)` cannot be greater than 1.")) + elseif a == Float16(1) + return copysign(Inf16, x) + else + # Perform calculations with `Float32` + x32 = Float32(x) + a32 = Float32(a) + if a32 <= 0.75f0 + # Simpler and more accurate alternative to Table 7 in Blair et al. + # Ref: https://github.com/JuliaMath/SpecialFunctions.jl/pull/372#discussion_r1592832735 + t = muladd(-6.73815f1, x32, 1f0) / muladd(-4.18798f0, x32, 4.54263f0) + y = copysign(muladd(0.88622695f0, x32, t), x32) + elseif a32 <= 0.9375f0 # Table 26 in Blair et al. + t = x32^2 - 0.87890625f0 + y = x32 * @horner(t, 0.10178_950f1, + -0.32827_601f1) / + @horner(t, 0.72455_99f0, + -0.33871_553f1, + 0.1f1) + else + # Simpler alternative to Table 47 in Blair et al. + # because of the reduced accuracy requirement + # (it turns out that this branch only covers 128 values). + # Note that the use of log(1-x) rather than log1p is intentional since it will be + # slightly faster and 1-x is exact. + # Ref: https://github.com/JuliaMath/SpecialFunctions.jl/pull/372#discussion_r1592710586 + t = sqrt(-log(1-a32)) + y = copysign(@horner(t, -0.429159f0, 1.04868f0), x32) + end + return Float16(y) + end +end + function _erfinv(y::BigFloat) xfloat = erfinv(Float64(y)) if isfinite(xfloat) @@ -482,6 +512,25 @@ function _erfcinv(y::Float32) end end +function _erfcinv(y::Float16) + if y > Float16(0.0625) + return erfinv(Float16(1) - y) + elseif y <= Float16(0) + if y == Float16(0) + return Inf16 + end + throw(DomainError(y, "`y` must be nonnegative.")) + else # Table 47 in Blair et al. + t = 1.0f0 / sqrt(-log(Float32(y))) + x = @horner(t, 0.98650_088f0, + 0.92601_777f0) / + (t * @horner(t, 0.98424_719f0, + 0.10074_7432f0, + 0.1f0)) + return Float16(x) + end +end + function _erfcinv(y::BigFloat) yfloat = Float64(y) xfloat = erfcinv(yfloat) @@ -526,13 +575,9 @@ See also: [`erfcx(x)`](@ref erfcx). # Implementation Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions. -Currently only implemented for `Float32`, `Float64`, and `BigFloat`. """ -logerfc(x::Real) = _logerfc(float(x)) - -function _logerfc(x::Union{Float32, Float64, BigFloat}) - # Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0 - if x > 0.0 +function logerfc(x::Real) + if x > zero(x) return log(erfcx(x)) - x^2 else return log(erfc(x)) @@ -557,13 +602,9 @@ See also: [`erfcx(x)`](@ref erfcx). # Implementation Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions. -Currently only implemented for `Float32`, `Float64`, and `BigFloat`. """ -logerfcx(x::Real) = _logerfcx(float(x)) - -function _logerfcx(x::Union{Float32, Float64, BigFloat}) - # Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0 - if x < 0.0 +function logerfcx(x::Real) + if x < zero(x) return log(erfc(x)) + x^2 else return log(erfcx(x)) diff --git a/test/erf.jl b/test/erf.jl index b0f57f11..29127343 100644 --- a/test/erf.jl +++ b/test/erf.jl @@ -1,65 +1,48 @@ @testset "error functions" begin @testset "real argument" begin - @test erf(Float16(1)) ≈ 0.84270079294971486934 rtol=2*eps(Float16) - @test erf(Float32(1)) ≈ 0.84270079294971486934 rtol=2*eps(Float32) - @test erf(Float64(1)) ≈ 0.84270079294971486934 rtol=2*eps(Float64) - - @test erfc(Float16(1)) ≈ 0.15729920705028513066 rtol=2*eps(Float16) - @test erfc(Float32(1)) ≈ 0.15729920705028513066 rtol=2*eps(Float32) - @test erfc(Float64(1)) ≈ 0.15729920705028513066 rtol=2*eps(Float64) - - @test erfcx(Float16(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float16) - @test erfcx(Float32(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float32) - @test erfcx(Float64(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float64) - - @test_throws MethodError logerfc(Float16(1)) - @test_throws MethodError logerfc(Float16(-1)) - @test logerfc(Float32(-100)) ≈ 0.6931471805599453 rtol=2*eps(Float32) - @test logerfc(Float64(-100)) ≈ 0.6931471805599453 rtol=2*eps(Float64) - @test logerfc(Float32(1000)) ≈ -1.0000074801207219e6 rtol=2*eps(Float32) - @test logerfc(Float64(1000)) ≈ -1.0000074801207219e6 rtol=2*eps(Float64) - @test logerfc(1000) ≈ -1.0000074801207219e6 rtol=2*eps(Float32) - @test logerfc(Float32(10000)) ≈ log(erfc(BigFloat(10000, precision=100))) rtol=2*eps(Float32) - @test logerfc(Float64(10000)) ≈ log(erfc(BigFloat(10000, precision=100))) rtol=2*eps(Float64) - - @test_throws MethodError logerfcx(Float16(1)) - @test_throws MethodError logerfcx(Float16(-1)) - @test iszero(logerfcx(0)) - @test logerfcx(Float32(1)) ≈ -0.849605509933248248576017509499 rtol=2eps(Float32) - @test logerfcx(Float64(1)) ≈ -0.849605509933248248576017509499 rtol=2eps(Float32) - @test logerfcx(Float32(-1)) ≈ 1.61123231767807049464268192445 rtol=2eps(Float32) - @test logerfcx(Float64(-1)) ≈ 1.61123231767807049464268192445 rtol=2eps(Float32) - @test logerfcx(Float32(-100)) ≈ 10000.6931471805599453094172321 rtol=2eps(Float32) - @test logerfcx(Float64(-100)) ≈ 10000.6931471805599453094172321 rtol=2eps(Float64) - @test logerfcx(Float32(100)) ≈ -5.17758512266433257046678208395 rtol=2eps(Float32) - @test logerfcx(Float64(100)) ≈ -5.17758512266433257046678208395 rtol=2eps(Float64) - @test logerfcx(Float32(-1000)) ≈ 1.00000069314718055994530941723e6 rtol=2eps(Float32) - @test logerfcx(Float64(-1000)) ≈ 1.00000069314718055994530941723e6 rtol=2eps(Float64) - @test logerfcx(Float32(1000)) ≈ -7.48012072190621214066734919080 rtol=2eps(Float32) - @test logerfcx(Float64(1000)) ≈ -7.48012072190621214066734919080 rtol=2eps(Float64) - - @test erfi(Float16(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float16) - @test erfi(Float32(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float32) - @test erfi(Float64(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float64) + for T in (Float16, Float32, Float64) + @test @inferred(erf(T(1))) isa T + @test erf(T(1)) ≈ T(0.84270079294971486934) rtol=2*eps(T) - @test erfinv(Integer(0)) == 0 == erfinv(0//1) - @test_throws MethodError erfinv(Float16(1)) - @test erfinv(Float32(0.84270079294971486934)) ≈ 1 rtol=2*eps(Float32) - @test erfinv(Float64(0.84270079294971486934)) ≈ 1 rtol=2*eps(Float64) + @test @inferred(erfc(T(1))) isa T + @test erfc(T(1)) ≈ T(0.15729920705028513066) rtol=2*eps(T) - @test erfcinv(Integer(1)) == 0 == erfcinv(1//1) - @test_throws MethodError erfcinv(Float16(1)) - @test erfcinv(Float32(0.15729920705028513066)) ≈ 1 rtol=2*eps(Float32) - @test erfcinv(Float64(0.15729920705028513066)) ≈ 1 rtol=2*eps(Float64) + @test @inferred(erfcx(T(1))) isa T + @test erfcx(T(1)) ≈ T(0.42758357615580700442) rtol=2*eps(T) + + @test @inferred(logerfc(T(1))) isa T + @test logerfc(T(-100)) ≈ T(0.6931471805599453) rtol=2*eps(T) + @test logerfc(T(1000)) ≈ T(-1.0000074801207219e6) rtol=2*eps(T) + @test logerfc(T(10000)) ≈ T(log(erfc(BigFloat(10000, precision=100)))) rtol=2*eps(T) + + @test @inferred(logerfcx(T(1))) isa T + @test logerfcx(T(1)) ≈ T(-0.849605509933248248576017509499) rtol=2eps(T) + @test logerfcx(T(-1)) ≈ T(1.61123231767807049464268192445) rtol=2eps(T) + @test logerfcx(T(-100)) ≈ T(10000.6931471805599453094172321) rtol=2eps(T) + @test logerfcx(T(100)) ≈ T(-5.17758512266433257046678208395) rtol=2eps(T) + @test logerfcx(T(-1000)) ≈ T(1.00000069314718055994530941723e6) rtol=2eps(T) + @test logerfcx(T(1000)) ≈ T(-7.48012072190621214066734919080) rtol=2eps(T) + + @test @inferred(erfi(T(1))) isa T + @test erfi(T(1)) ≈ T(1.6504257587975428760) rtol=2*eps(T) + + @test @inferred(erfinv(T(1))) isa T + @test erfinv(T(0.84270079294971486934)) ≈ 1 rtol=2*eps(T) - @test dawson(Float16(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float16) - @test dawson(Float32(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float32) - @test dawson(Float64(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float64) + @test @inferred(erfcinv(T(1))) isa T + @test erfcinv(T(0.15729920705028513066)) ≈ 1 rtol=2*eps(T) + @test @inferred(dawson(T(1))) isa T + @test dawson(T(1)) ≈ T(0.53807950691276841914) rtol=2*eps(T) + + @test @inferred(faddeeva(T(1))) isa Complex{T} + @test faddeeva(T(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(T) + end + + @test logerfc(1000) ≈ -1.0000074801207219e6 rtol=2*eps(Float32) + @test erfinv(Integer(0)) == 0 == erfinv(0//1) + @test erfcinv(Integer(1)) == 0 == erfcinv(1//1) @test faddeeva(0) == faddeeva(0//1) == 1 - @test faddeeva(Float16(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float16) - @test faddeeva(Float32(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float32) - @test faddeeva(Float64(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float64) end @testset "complex arguments" begin