From c78c796529b38198751044d05330ed2456c1c51d Mon Sep 17 00:00:00 2001 From: yejashi Date: Sun, 6 Oct 2024 20:20:09 -0700 Subject: [PATCH] Confidence interval: basic implementation --- thicket/stats/confidence_interval.py | 80 ++++++++++++++++++++++------ thicket/tests/test_stats.py | 27 ++++++++++ 2 files changed, 91 insertions(+), 16 deletions(-) diff --git a/thicket/stats/confidence_interval.py b/thicket/stats/confidence_interval.py index 2c014644..12741cfd 100644 --- a/thicket/stats/confidence_interval.py +++ b/thicket/stats/confidence_interval.py @@ -10,41 +10,89 @@ import thicket as th from ..utils import verify_thicket_structures from .stats_utils import cache_stats_op -from thicket.stats import mean @cache_stats_op -def confidence_interval(thicket, columns=None, confidence_value=0.95): +def confidence_interval(thicket, columns=None, confidence_level=0.95): + r"""Calculate the confidence interval for each node in the performance data table. + + Designed to take in a thicket, and append one or more columns to the aggregated + statistics table for the confidence interval calculation for each node. + + A confidence interval is a range of values, derived from sample data, that is + likely to contain the true population parameter with a specified level of confidence. + It provides an estimate of uncertainty around a sample statistic, indicating how much + variability is expected if the sampling process were repeated multiple times. + + Arguments: + thicket (thicket): Thicket object + columns (list): List of hardware/timing metrics to perform confidence interval + calculation on. Note, if using a columnar_joined thicket a list of tuples + must be passed in with the format (column index, column name). + confidence_level (int,float): The confidence level (often 0.90, 0.95, or 0.99) + indicates the degree of confidence that the true parameter lies within the interval. + + Returns: + (list): returns a list of output statsframe column names + + Equation: + .. math:: + + \text{CI} = \bar{x} \pm z \left( \frac{\sigma}{\sqrt{n}} \right) + """ + if columns is None or isinstance(columns, list) is False: + raise ValueError("Value passed to 'columns' must be of type list.") + + if isinstance(confidence_level, (int, float)) is False: + raise ValueError( + r"Value passed to 'confidence_level' must be of type float or int." + ) + + if confidence_level >= 1 or confidence_level <= 0: + raise ValueError( + r"Value passed to 'confidence_level' must be in the range of (0, 1)." + ) + + verify_thicket_structures(thicket.dataframe, columns=columns) + output_column_names = [] - + + sample_sizes = [] + + # Calculate mean and standard deviation mean_cols = th.stats.mean(thicket, columns=columns) std_cols = th.stats.std(thicket, columns=columns) - sample_sizes = [] - z = stats.norm.ppf((1 + confidence_value) / 2) + # Convert confidence level to Z score + z = stats.norm.ppf((1 + confidence_level) / 2) + + # Get number of profiles per node idx = pd.IndexSlice for node in thicket.graph.traverse(): node_df = thicket.dataframe.loc[idx[node, :]] sample_sizes.append(len(node_df)) + # Calculate confidence interval for every column for i in range(0, len(columns)): x = thicket.statsframe.dataframe[mean_cols[i]] s = thicket.statsframe.dataframe[std_cols[i]] - n = sample_sizes - c_p = x + (z * (s / np.sqrt(n))) - c_m = x - (z * (s / np.sqrt(n))) - - out = list(zip(c_m, c_p)) - out = pd.Series(out, index=thicket.statsframe.dataframe.index) + c_p = x + (z * (s / np.sqrt(sample_sizes))) + c_m = x - (z * (s / np.sqrt(sample_sizes))) + + out = pd.Series(list(zip(c_m, c_p)), index=thicket.statsframe.dataframe.index) + + if thicket.dataframe.columns.nlevels == 1: + out_col = f"confidence_interval_{confidence_level}_{columns[i]}" + else: + out_col = ( + columns[i][0], + f"confidence_interval_{confidence_level}_{columns[i][1]}", + ) - # If multi index, place below first level - out_col = f"confidence_interval_{confidence_value}_{columns[i]}" output_column_names.append(out_col) thicket.statsframe.dataframe[out_col] = out - break thicket.statsframe.dataframe = thicket.statsframe.dataframe.sort_index(axis=1) - return output_column_names - + return output_column_names diff --git a/thicket/tests/test_stats.py b/thicket/tests/test_stats.py index 868cf6ef..e315567c 100644 --- a/thicket/tests/test_stats.py +++ b/thicket/tests/test_stats.py @@ -6,6 +6,7 @@ import math import numpy as np +import pytest import thicket as th @@ -1217,3 +1218,29 @@ def test_cache_decorator(rajaperf_seq_O3_1M_cali): assert ( len(th_1.statsframe_ops_cache[list(th_1.statsframe_ops_cache.keys())[0]]) == 1 ) + + +def test_confidence_interval(thicket_axis_columns): + thicket_list, thicket_list_cp, combined_th = thicket_axis_columns + + idx = list(combined_th.dataframe.columns.levels[0][0:2]) + columns = [(idx[0], "Min time/rank"), (idx[1], "Min time/rank")] + + with pytest.raises( + ValueError, match="Value passed to 'columns' must be of type list." + ): + th.stats.confidence_interval(combined_th, columns="columns") + + with pytest.raises( + ValueError, + match="Value passed to 'confidence_level' must be of type float or int.", + ): + th.stats.confidence_interval( + combined_th, columns=columns, confidence_level="0.95" + ) + + with pytest.raises( + ValueError, + match=r"Value passed to 'confidence_level' must be in the range of \(0, 1\).", + ): + th.stats.confidence_interval(combined_th, columns=columns, confidence_level=95)