Skip to content

Commit

Permalink
Confidence interval: basic implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Yejashi committed Oct 16, 2024
1 parent 1fb1c57 commit c78c796
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 16 deletions.
80 changes: 64 additions & 16 deletions thicket/stats/confidence_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,89 @@
import thicket as th
from ..utils import verify_thicket_structures
from .stats_utils import cache_stats_op
from thicket.stats import mean


@cache_stats_op
def confidence_interval(thicket, columns=None, confidence_level=0.95):
    r"""Calculate the confidence interval for each node in the performance data table.

    Designed to take in a thicket, and append one or more columns to the aggregated
    statistics table for the confidence interval calculation for each node.

    A confidence interval is a range of values, derived from sample data, that is
    likely to contain the true population parameter with a specified level of
    confidence. It provides an estimate of uncertainty around a sample statistic,
    indicating how much variability is expected if the sampling process were
    repeated multiple times.

    Each value stored in an output column is a ``(lower, upper)`` tuple. After the
    new columns are added, the columns of the aggregated statistics table are
    sorted.

    Arguments:
        thicket (thicket): Thicket object
        columns (list): List of hardware/timing metrics to perform confidence
            interval calculation on. Note, if using a columnar_joined thicket a
            list of tuples must be passed in with the format
            (column index, column name).
        confidence_level (int,float): The confidence level (often 0.90, 0.95, or
            0.99) indicates the degree of confidence that the true parameter lies
            within the interval. Must lie strictly between 0 and 1.

    Returns:
        (list): returns a list of output statsframe column names

    Raises:
        ValueError: If ``columns`` is not a list, or if ``confidence_level`` is not
            an int/float or falls outside the open interval (0, 1).

    Equation:
        .. math::

            \text{CI} = \bar{x} \pm z \left( \frac{\sigma}{\sqrt{n}} \right)
    """
    if columns is None or not isinstance(columns, list):
        raise ValueError("Value passed to 'columns' must be of type list.")

    if not isinstance(confidence_level, (int, float)):
        raise ValueError(
            r"Value passed to 'confidence_level' must be of type float or int."
        )

    if confidence_level >= 1 or confidence_level <= 0:
        raise ValueError(
            r"Value passed to 'confidence_level' must be in the range of (0, 1)."
        )

    verify_thicket_structures(thicket.dataframe, columns=columns)

    output_column_names = []

    # Calculate mean and standard deviation; each call returns the names of the
    # statsframe columns it produced, one per requested metric.
    mean_cols = th.stats.mean(thicket, columns=columns)
    std_cols = th.stats.std(thicket, columns=columns)

    # Convert confidence level to the two-sided Z score
    z = stats.norm.ppf((1 + confidence_level) / 2)

    # Get number of profiles per node
    # NOTE(review): the sizes are combined positionally with the statsframe
    # columns below, which assumes graph traversal order matches the statsframe
    # row order -- confirm.
    idx = pd.IndexSlice
    sample_sizes = [
        len(thicket.dataframe.loc[idx[node, :]]) for node in thicket.graph.traverse()
    ]
    # Loop-invariant: the same sqrt(n) vector is used for every column.
    sqrt_n = np.sqrt(sample_sizes)

    # Calculate confidence interval for every column
    for column, mean_col, std_col in zip(columns, mean_cols, std_cols):
        x = thicket.statsframe.dataframe[mean_col]
        s = thicket.statsframe.dataframe[std_col]

        margin = z * (s / sqrt_n)
        c_p = x + margin
        c_m = x - margin

        out = pd.Series(list(zip(c_m, c_p)), index=thicket.statsframe.dataframe.index)

        if thicket.dataframe.columns.nlevels == 1:
            out_col = f"confidence_interval_{confidence_level}_{column}"
        else:
            # Multi-level columns: keep the result under the same first-level key
            out_col = (
                column[0],
                f"confidence_interval_{confidence_level}_{column[1]}",
            )

        output_column_names.append(out_col)
        thicket.statsframe.dataframe[out_col] = out

    thicket.statsframe.dataframe = thicket.statsframe.dataframe.sort_index(axis=1)

    return output_column_names
27 changes: 27 additions & 0 deletions thicket/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import math

import numpy as np
import pytest

import thicket as th

Expand Down Expand Up @@ -1217,3 +1218,29 @@ def test_cache_decorator(rajaperf_seq_O3_1M_cali):
assert (
len(th_1.statsframe_ops_cache[list(th_1.statsframe_ops_cache.keys())[0]]) == 1
)


def test_confidence_interval(thicket_axis_columns):
    """Check that confidence_interval rejects invalid arguments.

    Covers three error paths on a columnar-joined thicket: a non-list
    ``columns`` value, a non-numeric ``confidence_level``, and a
    ``confidence_level`` outside the open interval (0, 1).
    """
    _, _, combined_th = thicket_axis_columns

    # Use the first two top-level column keys of the combined thicket.
    top_levels = combined_th.dataframe.columns.levels[0]
    columns = [(top_levels[0], "Min time/rank"), (top_levels[1], "Min time/rank")]

    # (keyword arguments, regex the ValueError message must match)
    invalid_calls = [
        (
            {"columns": "columns"},
            "Value passed to 'columns' must be of type list.",
        ),
        (
            {"columns": columns, "confidence_level": "0.95"},
            "Value passed to 'confidence_level' must be of type float or int.",
        ),
        (
            {"columns": columns, "confidence_level": 95},
            r"Value passed to 'confidence_level' must be in the range of \(0, 1\).",
        ),
    ]

    for kwargs, message in invalid_calls:
        with pytest.raises(ValueError, match=message):
            th.stats.confidence_interval(combined_th, **kwargs)

0 comments on commit c78c796

Please sign in to comment.