From 128ca79f68b2c863736470f9758b5b9b47f087bc Mon Sep 17 00:00:00 2001
From: jessica-ewald <1jess.ewald@gmail.com>
Date: Tue, 6 Feb 2024 11:51:51 -0500
Subject: [PATCH 1/2] add option to randomly sample negative pairs

Add a `sample_neg` argument to `average_precision`. The default of 1
matches none of the new branches, so every negative pair is kept.

- Values above 1 are used as the number of negative pairs to draw.
- Values strictly between 0 and 1 are treated as a fraction of the
  negative pairs (rounded to a row count).
- Rows are drawn with `np.random.choice(..., replace=False)`.
- `UnpairedException` is raised when `sample_neg` is <= 0, or when a
  value above 1 is not smaller than the number of negative pairs.

This patch also drops one of the two blank lines in front of
`def average_precision(`.

---
 src/copairs/map/average_precision.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/src/copairs/map/average_precision.py b/src/copairs/map/average_precision.py
index c335bfa..3c462f4 100644
--- a/src/copairs/map/average_precision.py
+++ b/src/copairs/map/average_precision.py
@@ -26,9 +26,8 @@ def build_rank_lists(pos_pairs, neg_pairs, pos_sims, neg_sims):
     paired_ix, counts = np.unique(ix, return_counts=True)
     return paired_ix, rel_k_list, counts
 
-
 def average_precision(
-    meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby, batch_size=20000
+    meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby, batch_size=20000, sample_neg: float = 1
 ) -> pd.DataFrame:
     columns = flatten_str_list(pos_sameby, pos_diffby, neg_sameby, neg_diffby)
     validate_pipeline_input(meta, feats, columns)
@@ -60,6 +59,20 @@ def average_precision(
         count=neg_total,
     )
 
+    # if sample_neg not equal to 1, randomly sample negative pairs
+    if (sample_neg > 1) & (sample_neg < neg_pairs.shape[0]):
+        sampled_rows = np.random.choice(neg_pairs.shape[0], size=sample_neg, replace=False)
+        neg_pairs = neg_pairs[sampled_rows]
+    elif (sample_neg > 1) & (sample_neg >= neg_pairs.shape[0]):
+        raise UnpairedException("'sample_neg' must be less than the number of negative pairs. There are " + str(neg_pairs.shape[0]) + " negative pairs in this analysis.")
+    elif (sample_neg > 0) & (sample_neg < 1):
+        sample_size = round(sample_neg*neg_pairs.shape[0])
+        sampled_rows = np.random.choice(neg_pairs.shape[0], size=sample_size, replace=False)
+        neg_pairs = neg_pairs[sampled_rows]
+    elif sample_neg <= 0:
+        raise UnpairedException("'sample_neg' must be greater than 0.")
+
+
     logger.info("Computing positive similarities...")
     pos_sims = compute.pairwise_cosine(feats, pos_pairs, batch_size)
 

From 0518045a8f06c92f475af0ef403408ed7e08322d Mon Sep 17 00:00:00 2001
From: jessica-ewald <1jess.ewald@gmail.com>
Date: Fri, 9 Feb 2024 08:52:05 -0500
Subject: [PATCH 2/2] change negative sampling to factor of pos pairs

Replace the numeric `sample_neg` with a boolean switch (default False)
plus `sample_factor` (default 10).

When `sample_neg` is true, `pos_pairs.shape[0] * sample_factor` negative
pairs are drawn with `np.random.choice(..., replace=False)`. If that
number is not smaller than the number of negative pairs, all negative
pairs are kept and nothing is raised.

The sample size is cast to int: `sample_factor` is annotated as float,
and `np.random.choice` only accepts an integer `size`.

---
Review notes: the diffstat below was recounted by hand. The post-image
blob id on the `index` line was not recomputed.

 src/copairs/map/average_precision.py | 21 +++++++--------------
 1 file changed, 7 insertions(+), 14 deletions(-)

diff --git a/src/copairs/map/average_precision.py b/src/copairs/map/average_precision.py
index 3c462f4..dc0d453 100644
--- a/src/copairs/map/average_precision.py
+++ b/src/copairs/map/average_precision.py
@@ -27,7 +27,7 @@ def build_rank_lists(pos_pairs, neg_pairs, pos_sims, neg_sims):
     return paired_ix, rel_k_list, counts
 
 def average_precision(
-    meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby, batch_size=20000, sample_neg: float = 1
+    meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby, batch_size=20000, sample_neg: bool = False, sample_factor: float = 10
 ) -> pd.DataFrame:
     columns = flatten_str_list(pos_sameby, pos_diffby, neg_sameby, neg_diffby)
     validate_pipeline_input(meta, feats, columns)
@@ -60,18 +60,11 @@ def average_precision(
     )
 
-    # if sample_neg not equal to 1, randomly sample negative pairs
-    if (sample_neg > 1) & (sample_neg < neg_pairs.shape[0]):
-        sampled_rows = np.random.choice(neg_pairs.shape[0], size=sample_neg, replace=False)
-        neg_pairs = neg_pairs[sampled_rows]
-    elif (sample_neg > 1) & (sample_neg >= neg_pairs.shape[0]):
-        raise UnpairedException("'sample_neg' must be less than the number of negative pairs. There are " + str(neg_pairs.shape[0]) + " negative pairs in this analysis.")
-    elif (sample_neg > 0) & (sample_neg < 1):
-        sample_size = round(sample_neg*neg_pairs.shape[0])
-        sampled_rows = np.random.choice(neg_pairs.shape[0], size=sample_size, replace=False)
-        neg_pairs = neg_pairs[sampled_rows]
-    elif sample_neg <= 0:
-        raise UnpairedException("'sample_neg' must be greater than 0.")
-
+    # if sample_neg is set, randomly keep (number of positive pairs * sample_factor) negative pairs
+    if sample_neg:
+        sample_size = int(pos_pairs.shape[0] * sample_factor)
+        if sample_size < neg_pairs.shape[0]:
+            sampled_rows = np.random.choice(neg_pairs.shape[0], size=sample_size, replace=False)
+            neg_pairs = neg_pairs[sampled_rows]
 
     logger.info("Computing positive similarities...")
     pos_sims = compute.pairwise_cosine(feats, pos_pairs, batch_size)