Skip to content

Commit

Permalink
Add Instance Definitions for log1p, expm1, log1pexp, and `log1m…
Browse files Browse the repository at this point in the history
…exp` (#111)

* add Numeric imports

* add instance definitions for log1p, expm1, log1pexp, and log1mexp

* add regression tests for extra floating point functions

* add extra edge case test for log1mexp

* exclude test case for old versions of base with faulty log1mexp
  • Loading branch information
julmb authored Mar 12, 2024
1 parent 2e37c7d commit 56a2ef5
Show file tree
Hide file tree
Showing 14 changed files with 65 additions and 1 deletion.
5 changes: 5 additions & 0 deletions include/instances.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ instance BODY1(Floating a) Floating HEAD where
acosh = lift1 acosh $ \x -> recip (sqrt (join (*) x - 1))
atanh = lift1 atanh $ \x -> recip (1 - join (*) x)

log1p = lift1 log1p $ recip . (+) 1
expm1 = lift1 expm1 exp
log1pexp = lift1 log1pexp $ recip . (+) 1 . exp . negate
log1mexp = lift1 log1mexp $ recip . negate . expm1 . negate

instance BODY2(Num a, Enum a) Enum HEAD where
succ = lift1 succ (const 1)
pred = lift1 pred (const 1)
Expand Down
1 change: 1 addition & 0 deletions include/internal_kahn.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ import System.IO.Unsafe (unsafePerformIO)
import Data.Data (Data)
import Data.Typeable (Typeable)
import qualified GHC.Exts as Exts
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
Expand Down
1 change: 1 addition & 0 deletions src/Numeric/AD/Internal/Dense.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import Data.Typeable ()
import Data.Traversable (mapAccumL)
import Data.Data ()
import Data.Number.Erf
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
Expand Down
1 change: 1 addition & 0 deletions src/Numeric/AD/Internal/Dense/Representable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import Data.Functor.Rep
import Data.Typeable ()
import Data.Data ()
import Data.Number.Erf
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
Expand Down
1 change: 1 addition & 0 deletions src/Numeric/AD/Internal/Forward.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import Data.Foldable (toList)
import Data.Traversable (mapAccumL)
import Data.Data
import Data.Number.Erf
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
Expand Down
1 change: 1 addition & 0 deletions src/Numeric/AD/Internal/Forward/Double.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import Data.Foldable (toList)
import Data.Traversable (mapAccumL)
import Control.Monad (join)
import Data.Number.Erf
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
Expand Down
1 change: 1 addition & 0 deletions src/Numeric/AD/Internal/Kahn.hs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import qualified Data.Reify.Graph as Reified
import System.IO.Unsafe (unsafePerformIO)
import Data.Data (Data)
import Data.Typeable (Typeable)
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
Expand Down
1 change: 1 addition & 0 deletions src/Numeric/AD/Internal/Reverse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ import Data.Proxy
import Data.Reflection
import Data.Traversable (mapM)
import Data.Typeable
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
Expand Down
1 change: 1 addition & 0 deletions src/Numeric/AD/Internal/Reverse/Double.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import Data.Proxy
import Data.Reflection
import Data.Traversable (mapM)
import Data.Typeable
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
Expand Down
1 change: 1 addition & 0 deletions src/Numeric/AD/Internal/Sparse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import qualified Data.IntMap as IntMap
import Data.Number.Erf
import Data.Traversable
import Data.Typeable ()
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Sparse.Common
import Numeric.AD.Jacobian
Expand Down
1 change: 1 addition & 0 deletions src/Numeric/AD/Internal/Sparse/Double.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import qualified Data.IntMap as IntMap
import Data.Number.Erf
import Data.Traversable
import Data.Typeable ()
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Sparse.Common
import Numeric.AD.Jacobian
Expand Down
1 change: 1 addition & 0 deletions src/Numeric/AD/Internal/Tower.hs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import Data.Foldable
import Data.Data (Data)
import Data.Number.Erf
import Data.Typeable (Typeable)
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Jacobian
import Numeric.AD.Mode
Expand Down
1 change: 1 addition & 0 deletions src/Numeric/AD/Internal/Tower/Double.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import Data.Foldable
import Data.Data (Data)
import Data.Number.Erf
import Data.Typeable (Typeable)
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Jacobian
import Numeric.AD.Mode
Expand Down
49 changes: 48 additions & 1 deletion tests/Regression.hs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE RankNTypes #-}

module Main (main) where

import Numeric
import qualified Numeric.AD.Mode.Forward as F
import qualified Numeric.AD.Mode.Forward.Double as FD
import qualified Numeric.AD.Mode.Reverse as R
Expand Down Expand Up @@ -30,7 +32,11 @@ tests = testGroup "tests" [
mode "reverse-double" (\ f -> RD.diff' f) (\ f -> RD.grad f) (\ f -> RD.jacobian f) (\ f -> RD.hessian f)]

mode :: String -> Diff' -> Grad -> Jacobian -> Hessian -> TestTree
mode name diff grad jacobian hessian = testGroup name [basic diff grad jacobian hessian, issue97 diff, issue104 diff grad]
mode name diff grad jacobian hessian = testGroup name [
basic diff grad jacobian hessian,
issue97 diff,
issue104 diff grad,
issue108 diff]

basic :: Diff' -> Grad -> Jacobian -> Hessian -> TestTree
basic diff grad jacobian hessian = testGroup "basic" [tdiff, tgrad, tjacobian, thessian] where
Expand Down Expand Up @@ -92,6 +98,47 @@ issue104 diff grad = testGroup "issue-104" [inside, outside] where
f x y = sqrt x * sqrt y -- grad f [x, y] = [sqrt y / 2 sqrt x, sqrt x / 2 sqrt y]
binary f [x, y] = f x y

issue108 :: Diff' -> TestTree
issue108 diff = testGroup "issue-108" [tlog1p, texpm1, tlog1pexp, tlog1mexp] where
tlog1p = testCase "log1p" $ do
equal (-inf, inf) $ diff log1p (-1)
equal (-1.0000000000000007e-15, 1.000000000000001) $ diff log1p (-1e-15)
equal (-1e-20, 1) $ diff log1p (-1e-20)
equal (0, 1) $ diff log1p 0
equal (1e-20, 1) $ diff log1p 1e-20
equal (9.999999999999995e-16, 0.9999999999999989) $ diff log1p 1e-15
equal (0.6931471805599453, 0.5) $ diff log1p 1
texpm1 = testCase "expm1" $ do
equal (-0.6321205588285577, 0.36787944117144233) $ diff expm1 (-1)
equal (-9.999999999999995e-16, 0.999999999999999) $ diff expm1 (-1e-15)
equal (-1e-20, 1) $ diff expm1 (-1e-20)
equal (0, 1) $ diff expm1 0
equal (1e-20, 1) $ diff expm1 1e-20
equal (1.0000000000000007e-15, 1.000000000000001) $ diff expm1 1e-15
equal (1.718281828459045, 2.718281828459045) $ diff expm1 1
tlog1pexp = testCase "log1pexp" $ do
equal (0, 0) $ diff log1pexp (-1000)
equal (3.720075976020836e-44, 3.7200759760208356e-44) $ diff log1pexp (-100)
equal (0.31326168751822286, 0.2689414213699951) $ diff log1pexp (-1)
equal (0.6931471805599453, 0.5) $ diff log1pexp 0
equal (1.3132616875182228, 0.7310585786300049) $ diff log1pexp 1
equal (100, 1) $ diff log1pexp 100
equal (1000, 1) $ diff log1pexp 1000
tlog1mexp = testCase "log1mexp" $ do
equal (-0, -0) $ diff log1mexp (-1000)
-- old versions of base have a faulty implementation of log1mexp, causing this case to fail
-- see also https://gitlab.haskell.org/ghc/ghc/-/issues/17125
#if MIN_VERSION_base(4, 13, 0)
equal (-3.720075976020836e-44, -3.7200759760208356e-44) $ diff log1mexp (-100)
#endif
equal (-0.45867514538708193, -0.5819767068693265) $ diff log1mexp (-1)
equal (-0.9327521295671886, -1.5414940825367982) $ diff log1mexp (-0.5)
equal (-2.3521684610440907, -9.50833194477505) $ diff log1mexp (-0.1)
equal (-34.538776394910684, -9.999999999999994e14) $ diff log1mexp (-1e-15)
equal (-46.051701859880914, -1e20) $ diff log1mexp (-1e-20)
equal (-inf, -inf) $ diff log1mexp (-0)
equal = expect $ \ (a, b) (c, d) -> eq a c && eq b d

-- TODO: ideally, we would consider `0` and `-0` to be different
-- however, zero signedness is currently not reliably propagated through some modes
-- see also https://github.com/ekmett/ad/issues/109 and https://github.com/ekmett/ad/pull/110
Expand Down

0 comments on commit 56a2ef5

Please sign in to comment.