-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
106 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,82 @@ | ||
{-# LANGUAGE CPP #-} | ||
{-# HLINT ignore "Avoid restricted function" #-} | ||
{-# LANGUAGE NumDecimals #-} | ||
{-# LANGUAGE RecordWildCards #-} | ||
{-# LANGUAGE ScopedTypeVariables #-} | ||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} | ||
|
||
-- | Find a tipping point between two algorithms. | ||
module Test.Tasty.Bench.Equalize ( | ||
equalize, | ||
mkEqualizeConfig, | ||
EqualizeConfig (..), | ||
) where | ||
|
||
import Control.DeepSeq (NFData) | ||
import Data.List.NonEmpty (NonEmpty) | ||
import qualified Data.List.NonEmpty as NE | ||
import System.IO.Unsafe (unsafeInterleaveIO) | ||
import Test.Tasty (Timeout (..)) | ||
import Test.Tasty.Bench (Benchmarkable, RelStDev (..), measureCpuTimeAndStDev, nf) | ||
import Test.Tasty (Timeout (..), mkTimeout) | ||
import Test.Tasty.Bench (Benchmarkable, RelStDev (..), nf) | ||
import Test.Tasty.Bench.Utils (Measurement (..), getRelStDev, measRelStDev, measure, traceShowM') | ||
|
||
#ifdef DEBUG | ||
import Debug.Trace | ||
#endif | ||
|
||
-- | Configuration for 'fit'. | ||
-- | Configuration for 'equalize'. | ||
data EqualizeConfig = EqualizeConfig | ||
{ eqlFasterOnLow :: Word -> Benchmarkable | ||
-- ^ A benchmark which is faster at 'eqlLow', typically 'nf' @f@. | ||
, eqlFasterOnHigh :: Word -> Benchmarkable | ||
-- ^ A benchmark which is faster at 'eqlHigh', typically 'nf' @g@. | ||
, eqlLow :: Word | ||
-- ^ An argument at which 'eqlFasterOnLow' is faster than 'eqlFasterOnHigh'. | ||
, eqlHigh :: Word | ||
-- ^ An argument at which 'eqlFasterOnHigh' is faster than 'eqlFasterOnLow'. | ||
, eqlTimeout :: Timeout | ||
-- ^ Timeout of individual measurements. | ||
} | ||
|
||
-- | Generate a default 'equalize' configuration. | ||
mkEqualizeConfig | ||
:: (NFData a) | ||
=> (Word -> a) | ||
-- ^ An algorithm which is faster for small arguments, without 'nf'. | ||
-> (Word -> a) | ||
-- ^ An algorithm which is faster for large arguments, without 'nf'. | ||
-> (Word, Word) | ||
-- ^ The smallest and the largest sizes of the input. | ||
-- ^ Small and large arguments. | ||
-> EqualizeConfig | ||
mkEqualizeConfig fLow fHigh (low, high) = | ||
EqualizeConfig | ||
{ eqlFasterOnLow = nf fLow | ||
, eqlFasterOnHigh = nf fHigh | ||
, eqlLow = low | ||
, eqlHigh = high | ||
, eqlTimeout = NoTimeout | ||
, eqlTimeout = mkTimeout 1e8 | ||
} | ||
|
||
equalize :: EqualizeConfig -> IO [(Word, Word)] | ||
equalize EqualizeConfig {..} = go (RelStDev 1.0) eqlLow eqlHigh | ||
equalize :: EqualizeConfig -> IO (NonEmpty (Word, Word)) | ||
equalize EqualizeConfig {..} = NE.fromList <$> go (RelStDev (1 / 3)) eqlLow eqlHigh | ||
where | ||
go std@(RelStDev std') lo hi | ||
| lo + 1 >= hi = pure [(lo, hi)] | ||
| otherwise = | ||
unsafeInterleaveIO $ | ||
((lo, hi) :) <$> do | ||
let mid = (lo + hi) `quot` 2 | ||
traceShowM' $ "mid = " ++ show mid | ||
(mean1, stdev1) <- measureCpuTimeAndStDev eqlTimeout std $ eqlFasterOnLow mid | ||
traceShowM' $ "(mean1, stdev1) = " ++ show (mean1, stdev1) | ||
(mean2, stdev2) <- measureCpuTimeAndStDev eqlTimeout std $ eqlFasterOnHigh mid | ||
traceShowM' $ "(mean2, stdev2) = " ++ show (mean2, stdev2) | ||
if mean1 + 2 * stdev1 < mean2 - 2 * stdev2 | ||
then go std mid hi | ||
else | ||
if mean2 + 2 * stdev2 < mean1 - 2 * stdev1 | ||
then go std lo mid | ||
else go (RelStDev $ std' / 2) lo hi | ||
go targetRelStdDev lo hi = fmap ((lo, hi) :) $ | ||
unsafeInterleaveIO $ do | ||
let mid = (lo + hi) `quot` 2 | ||
measureIt alg k = do | ||
meas <- measure eqlTimeout targetRelStdDev $ alg mid | ||
traceShowM' meas | ||
if getRelStDev (measRelStDev meas) > getRelStDev targetRelStdDev | ||
then pure [] | ||
else k meas | ||
|
||
traceShowM' :: (Applicative m, Show a) => a -> m () | ||
#ifdef DEBUG | ||
traceShowM' = traceShowM | ||
#else | ||
traceShowM' = const (pure ()) | ||
#endif | ||
traceShowM' targetRelStdDev | ||
if mid == lo | ||
then pure [] | ||
else measureIt eqlFasterOnLow $ \(Measurement mean1 stdev1) -> | ||
measureIt eqlFasterOnHigh $ \(Measurement mean2 stdev2) -> do | ||
let (lo', hi') | ||
| mean1 + 2 * stdev1 < mean2 - 2 * stdev2 = (mid, hi) | ||
| mean2 + 2 * stdev2 < mean1 - 2 * stdev1 = (lo, mid) | ||
| otherwise = (lo, hi) | ||
let targetStdDev' = | ||
RelStDev $ | ||
max | ||
(abs (mean1 - mean2) / max mean1 mean2 / 4) | ||
(getRelStDev targetRelStdDev * (sqrt 5 - 1) / 2) | ||
go targetStdDev' lo' hi' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
{-# LANGUAGE CPP #-} | ||
{-# LANGUAGE DeriveGeneric #-} | ||
|
||
module Test.Tasty.Bench.Utils ( | ||
Measurement (..), | ||
measRelStDev, | ||
measure, | ||
getRelStDev, | ||
traceShowM', | ||
) where | ||
|
||
import Control.DeepSeq (NFData) | ||
import GHC.Generics (Generic) | ||
import Test.Tasty (Timeout) | ||
import Test.Tasty.Bench (Benchmarkable, RelStDev (..), measureCpuTimeAndStDev) | ||
import Text.Printf (printf) | ||
|
||
#ifdef DEBUG | ||
import Debug.Trace | ||
#endif | ||
|
||
-- | Represents a time measurement for a given problem's size. | ||
data Measurement = Measurement | ||
{ measTime :: !Double | ||
, measStDev :: !Double | ||
} | ||
deriving (Eq, Ord, Generic) | ||
|
||
instance Show Measurement where | ||
show (Measurement t err) = printf "%.3g ± %.3g" t err | ||
|
||
instance NFData Measurement | ||
|
||
measure :: Timeout -> RelStDev -> Benchmarkable -> IO Measurement | ||
measure x y z = uncurry Measurement <$> measureCpuTimeAndStDev x y z | ||
|
||
measRelStDev :: Measurement -> RelStDev | ||
measRelStDev (Measurement mean stDev) = RelStDev (stDev / mean) | ||
|
||
getRelStDev :: RelStDev -> Double | ||
getRelStDev (RelStDev x) = x | ||
|
||
traceShowM' :: (Applicative m, Show a) => a -> m () | ||
#ifdef DEBUG | ||
traceShowM' = traceShowM | ||
#else | ||
traceShowM' = const (pure ()) | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters