diff --git a/examples/mAP_demo.ipynb b/examples/mAP_demo.ipynb
index 23a8478..5f25e41 100644
--- a/examples/mAP_demo.ipynb
+++ b/examples/mAP_demo.ipynb
@@ -623,7 +623,7 @@
"source": [
"Here, we treat different doses of each compound as replicates and assess how well we can retrieve them by similarity against the group of negative controls (DMSO).\n",
"\n",
- "To ensure correct grouping of profiles, we can add a dummy column that is equal to row index for all compound replicates and to -1 for all DMSO replicates. "
+ "To ensure correct grouping of profiles, we can add a dummy column that is equal to the row index for all DMSO replicates and to -1 for all compound replicates. "
]
},
{
@@ -652,7 +652,7 @@
" \n",
" \n",
" \n",
- " Metadata_treatment_index \n",
+ " Metadata_reference_index \n",
" Metadata_broad_sample \n",
" Metadata_mg_per_ml \n",
" Metadata_mmoles_per_liter \n",
@@ -678,7 +678,7 @@
"
\n",
" \n",
" 0 \n",
- " -1 \n",
+ " 0 \n",
" DMSO \n",
" 0.000000 \n",
" 0.000000 \n",
@@ -702,7 +702,7 @@
" \n",
" \n",
" 1 \n",
- " -1 \n",
+ " 1 \n",
" DMSO \n",
" 0.000000 \n",
" 0.000000 \n",
@@ -726,7 +726,7 @@
" \n",
" \n",
" 2 \n",
- " -1 \n",
+ " 2 \n",
" DMSO \n",
" 0.000000 \n",
" 0.000000 \n",
@@ -750,7 +750,7 @@
" \n",
" \n",
" 3 \n",
- " -1 \n",
+ " 3 \n",
" DMSO \n",
" 0.000000 \n",
" 0.000000 \n",
@@ -774,7 +774,7 @@
" \n",
" \n",
" 4 \n",
- " -1 \n",
+ " 4 \n",
" DMSO \n",
" 0.000000 \n",
" 0.000000 \n",
@@ -822,7 +822,7 @@
" \n",
" \n",
" 379 \n",
- " 379 \n",
+ " -1 \n",
" BRD-K82746043-001-15-1 \n",
" 3.248700 \n",
" 3.333300 \n",
@@ -846,7 +846,7 @@
" \n",
" \n",
" 380 \n",
- " 380 \n",
+ " -1 \n",
" BRD-K82746043-001-15-1 \n",
" 1.082900 \n",
" 1.111100 \n",
@@ -870,7 +870,7 @@
" \n",
" \n",
" 381 \n",
- " 381 \n",
+ " -1 \n",
" BRD-K82746043-001-15-1 \n",
" 0.360970 \n",
" 0.370370 \n",
@@ -894,7 +894,7 @@
" \n",
" \n",
" 382 \n",
- " 382 \n",
+ " -1 \n",
" BRD-K82746043-001-15-1 \n",
" 0.120320 \n",
" 0.123460 \n",
@@ -918,7 +918,7 @@
" \n",
" \n",
" 383 \n",
- " 383 \n",
+ " -1 \n",
" BRD-K82746043-001-15-1 \n",
" 0.040108 \n",
" 0.041152 \n",
@@ -946,18 +946,18 @@
""
],
"text/plain": [
- " Metadata_treatment_index Metadata_broad_sample Metadata_mg_per_ml \\\n",
- "0 -1 DMSO 0.000000 \n",
- "1 -1 DMSO 0.000000 \n",
- "2 -1 DMSO 0.000000 \n",
- "3 -1 DMSO 0.000000 \n",
- "4 -1 DMSO 0.000000 \n",
+ " Metadata_reference_index Metadata_broad_sample Metadata_mg_per_ml \\\n",
+ "0 0 DMSO 0.000000 \n",
+ "1 1 DMSO 0.000000 \n",
+ "2 2 DMSO 0.000000 \n",
+ "3 3 DMSO 0.000000 \n",
+ "4 4 DMSO 0.000000 \n",
".. ... ... ... \n",
- "379 379 BRD-K82746043-001-15-1 3.248700 \n",
- "380 380 BRD-K82746043-001-15-1 1.082900 \n",
- "381 381 BRD-K82746043-001-15-1 0.360970 \n",
- "382 382 BRD-K82746043-001-15-1 0.120320 \n",
- "383 383 BRD-K82746043-001-15-1 0.040108 \n",
+ "379 -1 BRD-K82746043-001-15-1 3.248700 \n",
+ "380 -1 BRD-K82746043-001-15-1 1.082900 \n",
+ "381 -1 BRD-K82746043-001-15-1 0.360970 \n",
+ "382 -1 BRD-K82746043-001-15-1 0.120320 \n",
+ "383 -1 BRD-K82746043-001-15-1 0.040108 \n",
"\n",
" Metadata_mmoles_per_liter Metadata_pert_id Metadata_pert_mfc_id \\\n",
"0 0.000000 NaN NaN \n",
@@ -1100,12 +1100,12 @@
"source": [
"df_activity = df.copy()\n",
"# make deafult value equal to row index\n",
- "df_activity[\"Metadata_treatment_index\"] = df_activity.index\n",
- "# make index equal to -1 for all DMSO treatment replicates\n",
- "df_activity.loc[df[\"Metadata_broad_sample\"] == \"DMSO\", \"Metadata_treatment_index\"] = -1\n",
- "# now all treatment replicates differ in the index column, except for DMSO replicates\n",
+ "df_activity[\"Metadata_reference_index\"] = df_activity.index\n",
+ "# make index equal to -1 for all treatment replicates (non-DMSO)\n",
+ "df_activity.loc[df[\"Metadata_broad_sample\"] != \"DMSO\", \"Metadata_reference_index\"] = -1\n",
+ "# now all treatment replicates share the value -1 in the index column, while DMSO replicates keep their row index\n",
"df_activity.insert(\n",
- " 0, \"Metadata_treatment_index\", df_activity.pop(\"Metadata_treatment_index\")\n",
+ " 0, \"Metadata_reference_index\", df_activity.pop(\"Metadata_reference_index\")\n",
")\n",
"df_activity"
]
@@ -1120,7 +1120,7 @@
"\n",
"* In this case, profiles that form a positive pair do not need to be different in any of the metatada columns, so we keep `pos_diffby` empty. Although one could define them as being from different batches, for instance, to account for batch effects.\n",
"\n",
- "* Two profiles are a negative pair when one of them belongs to a group of compound replicates and another to a group of DMSO controls. That means they should be different both in the metadata column that identifies the specific compound and the treatment index columns that we created. The latter is needed to ensure that replicates of compounds are retrieved against only DMSO controls at this stage (and not against replicates of other compounds). We list these columns in `neg_diffby`.\n",
+ "* Two profiles are a negative pair when one of them belongs to a group of compound replicates and another to a group of DMSO controls. That means they should be different both in the metadata column that identifies the specific compound and the reference index column that we created. The latter is needed to ensure that replicates of compounds are retrieved against only DMSO controls at this stage (and not against replicates of other compounds). We list these columns in `neg_diffby`.\n",
"\n",
"* Profiles that form a negative pair do not need to be same in any of the metatada columns, so we keep `neg_sameby` empty."
]
@@ -1137,7 +1137,7 @@
"\n",
"neg_sameby = []\n",
"# negative pairs are replicates of different treatments\n",
- "neg_diffby = [\"Metadata_broad_sample\", \"Metadata_treatment_index\"]"
+ "neg_diffby = [\"Metadata_broad_sample\", \"Metadata_reference_index\"]"
]
},
{
@@ -1157,7 +1157,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "5d583875de81417fa8f66f87e2bfb80b",
+ "model_id": "51509158c2e84267b94e8d0cf5952604",
"version_major": 2,
"version_minor": 0
},
@@ -1171,12 +1171,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "e7687333b8544cc6a2e377ac848e3457",
+ "model_id": "5458f7a2a5904b7685560ac4e20d3dd8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
- " 0%| | 0/4 [00:00, ?it/s]"
+ " 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
@@ -1203,7 +1203,7 @@
" \n",
" \n",
" \n",
- " Metadata_treatment_index \n",
+ " Metadata_reference_index \n",
" Metadata_broad_sample \n",
" Metadata_mg_per_ml \n",
" Metadata_mmoles_per_liter \n",
@@ -1226,7 +1226,7 @@
" \n",
" \n",
" 6 \n",
- " 6 \n",
+ " -1 \n",
" BRD-K74363950-004-01-0 \n",
" 5.655600 \n",
" 10.000000 \n",
@@ -1242,12 +1242,12 @@
" broad_id_20170327 \n",
" A07 \n",
" 5 \n",
- " 383 \n",
- " 0.050922 \n",
+ " 29 \n",
+ " 0.325013 \n",
" \n",
" \n",
" 7 \n",
- " 7 \n",
+ " -1 \n",
" BRD-K74363950-004-01-0 \n",
" 1.885200 \n",
" 3.333300 \n",
@@ -1263,12 +1263,12 @@
" broad_id_20170327 \n",
" A08 \n",
" 5 \n",
- " 383 \n",
- " 0.308904 \n",
+ " 29 \n",
+ " 0.513889 \n",
" \n",
" \n",
" 8 \n",
- " 8 \n",
+ " -1 \n",
" BRD-K74363950-004-01-0 \n",
" 0.628400 \n",
" 1.111100 \n",
@@ -1284,12 +1284,12 @@
" broad_id_20170327 \n",
" A09 \n",
" 5 \n",
- " 383 \n",
- " 0.412513 \n",
+ " 29 \n",
+ " 0.727778 \n",
" \n",
" \n",
" 9 \n",
- " 9 \n",
+ " -1 \n",
" BRD-K74363950-004-01-0 \n",
" 0.209470 \n",
" 0.370370 \n",
@@ -1305,12 +1305,12 @@
" broad_id_20170327 \n",
" A10 \n",
" 5 \n",
- " 383 \n",
- " 0.377730 \n",
+ " 29 \n",
+ " 0.783333 \n",
" \n",
" \n",
" 10 \n",
- " 10 \n",
+ " -1 \n",
" BRD-K74363950-004-01-0 \n",
" 0.069823 \n",
" 0.123460 \n",
@@ -1326,8 +1326,8 @@
" broad_id_20170327 \n",
" A11 \n",
" 5 \n",
- " 383 \n",
- " 0.715591 \n",
+ " 29 \n",
+ " 0.900000 \n",
" \n",
" \n",
" ... \n",
@@ -1352,7 +1352,7 @@
" \n",
" \n",
" 379 \n",
- " 379 \n",
+ " -1 \n",
" BRD-K82746043-001-15-1 \n",
" 3.248700 \n",
" 3.333300 \n",
@@ -1368,12 +1368,12 @@
" broad_id_20170327 \n",
" P20 \n",
" 5 \n",
- " 383 \n",
- " 0.726786 \n",
+ " 29 \n",
+ " 1.000000 \n",
" \n",
" \n",
" 380 \n",
- " 380 \n",
+ " -1 \n",
" BRD-K82746043-001-15-1 \n",
" 1.082900 \n",
" 1.111100 \n",
@@ -1389,12 +1389,12 @@
" broad_id_20170327 \n",
" P21 \n",
" 5 \n",
- " 383 \n",
- " 0.658824 \n",
+ " 29 \n",
+ " 0.966667 \n",
" \n",
" \n",
" 381 \n",
- " 381 \n",
+ " -1 \n",
" BRD-K82746043-001-15-1 \n",
" 0.360970 \n",
" 0.370370 \n",
@@ -1410,12 +1410,12 @@
" broad_id_20170327 \n",
" P22 \n",
" 5 \n",
- " 383 \n",
- " 0.517619 \n",
+ " 29 \n",
+ " 0.942857 \n",
" \n",
" \n",
" 382 \n",
- " 382 \n",
+ " -1 \n",
" BRD-K82746043-001-15-1 \n",
" 0.120320 \n",
" 0.123460 \n",
@@ -1431,12 +1431,12 @@
" broad_id_20170327 \n",
" P23 \n",
" 5 \n",
- " 383 \n",
- " 0.543290 \n",
+ " 29 \n",
+ " 1.000000 \n",
" \n",
" \n",
" 383 \n",
- " 383 \n",
+ " -1 \n",
" BRD-K82746043-001-15-1 \n",
" 0.040108 \n",
" 0.041152 \n",
@@ -1452,8 +1452,8 @@
" broad_id_20170327 \n",
" P24 \n",
" 5 \n",
- " 383 \n",
- " 0.535714 \n",
+ " 29 \n",
+ " 1.000000 \n",
" \n",
" \n",
"\n",
@@ -1461,18 +1461,18 @@
""
],
"text/plain": [
- " Metadata_treatment_index Metadata_broad_sample Metadata_mg_per_ml \\\n",
- "6 6 BRD-K74363950-004-01-0 5.655600 \n",
- "7 7 BRD-K74363950-004-01-0 1.885200 \n",
- "8 8 BRD-K74363950-004-01-0 0.628400 \n",
- "9 9 BRD-K74363950-004-01-0 0.209470 \n",
- "10 10 BRD-K74363950-004-01-0 0.069823 \n",
+ " Metadata_reference_index Metadata_broad_sample Metadata_mg_per_ml \\\n",
+ "6 -1 BRD-K74363950-004-01-0 5.655600 \n",
+ "7 -1 BRD-K74363950-004-01-0 1.885200 \n",
+ "8 -1 BRD-K74363950-004-01-0 0.628400 \n",
+ "9 -1 BRD-K74363950-004-01-0 0.209470 \n",
+ "10 -1 BRD-K74363950-004-01-0 0.069823 \n",
".. ... ... ... \n",
- "379 379 BRD-K82746043-001-15-1 3.248700 \n",
- "380 380 BRD-K82746043-001-15-1 1.082900 \n",
- "381 381 BRD-K82746043-001-15-1 0.360970 \n",
- "382 382 BRD-K82746043-001-15-1 0.120320 \n",
- "383 383 BRD-K82746043-001-15-1 0.040108 \n",
+ "379 -1 BRD-K82746043-001-15-1 3.248700 \n",
+ "380 -1 BRD-K82746043-001-15-1 1.082900 \n",
+ "381 -1 BRD-K82746043-001-15-1 0.360970 \n",
+ "382 -1 BRD-K82746043-001-15-1 0.120320 \n",
+ "383 -1 BRD-K82746043-001-15-1 0.040108 \n",
"\n",
" Metadata_mmoles_per_liter Metadata_pert_id Metadata_pert_mfc_id \\\n",
"6 10.000000 BRD-K74363950 BRD-K74363950-004-01-0 \n",
@@ -1527,17 +1527,17 @@
"383 BCL2|BCL2L1|BCL2L2 broad_id_20170327 P24 \n",
"\n",
" n_pos_pairs n_total_pairs average_precision \n",
- "6 5 383 0.050922 \n",
- "7 5 383 0.308904 \n",
- "8 5 383 0.412513 \n",
- "9 5 383 0.377730 \n",
- "10 5 383 0.715591 \n",
+ "6 5 29 0.325013 \n",
+ "7 5 29 0.513889 \n",
+ "8 5 29 0.727778 \n",
+ "9 5 29 0.783333 \n",
+ "10 5 29 0.900000 \n",
".. ... ... ... \n",
- "379 5 383 0.726786 \n",
- "380 5 383 0.658824 \n",
- "381 5 383 0.517619 \n",
- "382 5 383 0.543290 \n",
- "383 5 383 0.535714 \n",
+ "379 5 29 1.000000 \n",
+ "380 5 29 0.966667 \n",
+ "381 5 29 0.942857 \n",
+ "382 5 29 1.000000 \n",
+ "383 5 29 1.000000 \n",
"\n",
"[360 rows x 18 columns]"
]
@@ -1575,7 +1575,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "7b70f4682ff947ca875777958c499f94",
+ "model_id": "b55cf11c765b4af98dca44f808372955",
"version_major": 2,
"version_minor": 0
},
@@ -1589,7 +1589,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "eed952af896e48ffaa790e3c997714fc",
+ "model_id": "f470ba4162d4425f8983d7f7c52ff982",
"version_major": 2,
"version_minor": 0
},
@@ -1634,102 +1634,102 @@
" \n",
" 0 \n",
" BRD-A69275535-001-01-5 \n",
- " 0.203576 \n",
- " 0.012899 \n",
- " 0.016390 \n",
+ " 0.575629 \n",
+ " 0.017698 \n",
+ " 0.023857 \n",
" True \n",
" True \n",
- " 1.785430 \n",
+ " 1.622390 \n",
" \n",
" \n",
" 1 \n",
" BRD-A69636825-003-04-7 \n",
- " 0.269093 \n",
- " 0.000800 \n",
- " 0.001365 \n",
+ " 0.693806 \n",
+ " 0.003700 \n",
+ " 0.006922 \n",
" True \n",
" True \n",
- " 2.865004 \n",
+ " 2.159775 \n",
" \n",
" \n",
" 2 \n",
" BRD-A69815203-001-07-6 \n",
- " 0.862226 \n",
+ " 1.000000 \n",
" 0.000100 \n",
- " 0.000276 \n",
+ " 0.000341 \n",
" True \n",
" True \n",
- " 3.558835 \n",
+ " 3.467064 \n",
" \n",
" \n",
" 3 \n",
" BRD-A70858459-001-01-7 \n",
- " 0.351816 \n",
- " 0.000200 \n",
- " 0.000400 \n",
+ " 0.777173 \n",
+ " 0.000600 \n",
+ " 0.001289 \n",
" True \n",
" True \n",
- " 3.397983 \n",
+ " 2.889828 \n",
" \n",
" \n",
" 4 \n",
" BRD-A72309220-001-04-1 \n",
- " 0.263986 \n",
- " 0.000900 \n",
- " 0.001491 \n",
+ " 0.716927 \n",
+ " 0.002200 \n",
+ " 0.004253 \n",
" True \n",
" True \n",
- " 2.826441 \n",
+ " 2.371314 \n",
" \n",
" \n",
" 5 \n",
" BRD-A72390365-001-15-2 \n",
- " 0.554667 \n",
+ " 0.934444 \n",
" 0.000100 \n",
- " 0.000276 \n",
+ " 0.000341 \n",
" True \n",
" True \n",
- " 3.558835 \n",
+ " 3.467064 \n",
" \n",
" \n",
" 6 \n",
" BRD-A73368467-003-17-6 \n",
- " 0.788666 \n",
+ " 0.926032 \n",
" 0.000100 \n",
- " 0.000276 \n",
+ " 0.000341 \n",
" True \n",
" True \n",
- " 3.558835 \n",
+ " 3.467064 \n",
" \n",
" \n",
" 7 \n",
" BRD-A74980173-001-11-9 \n",
- " 0.500600 \n",
- " 0.000100 \n",
- " 0.000276 \n",
+ " 0.765931 \n",
+ " 0.000600 \n",
+ " 0.001289 \n",
" True \n",
" True \n",
- " 3.558835 \n",
+ " 2.889828 \n",
" \n",
" \n",
" 8 \n",
" BRD-A81233518-004-16-1 \n",
- " 0.140208 \n",
- " 0.015598 \n",
- " 0.018700 \n",
+ " 0.621183 \n",
+ " 0.009399 \n",
+ " 0.013978 \n",
" True \n",
" True \n",
- " 1.728154 \n",
+ " 1.854552 \n",
" \n",
" \n",
" 9 \n",
" BRD-A82035391-001-02-7 \n",
- " 0.052362 \n",
- " 0.077692 \n",
- " 0.078692 \n",
+ " 0.318066 \n",
+ " 0.260374 \n",
+ " 0.264942 \n",
" False \n",
" False \n",
- " 1.104069 \n",
+ " 0.576849 \n",
" \n",
" \n",
"\n",
@@ -1737,28 +1737,28 @@
],
"text/plain": [
" Metadata_broad_sample mean_average_precision p_value \\\n",
- "0 BRD-A69275535-001-01-5 0.203576 0.012899 \n",
- "1 BRD-A69636825-003-04-7 0.269093 0.000800 \n",
- "2 BRD-A69815203-001-07-6 0.862226 0.000100 \n",
- "3 BRD-A70858459-001-01-7 0.351816 0.000200 \n",
- "4 BRD-A72309220-001-04-1 0.263986 0.000900 \n",
- "5 BRD-A72390365-001-15-2 0.554667 0.000100 \n",
- "6 BRD-A73368467-003-17-6 0.788666 0.000100 \n",
- "7 BRD-A74980173-001-11-9 0.500600 0.000100 \n",
- "8 BRD-A81233518-004-16-1 0.140208 0.015598 \n",
- "9 BRD-A82035391-001-02-7 0.052362 0.077692 \n",
+ "0 BRD-A69275535-001-01-5 0.575629 0.017698 \n",
+ "1 BRD-A69636825-003-04-7 0.693806 0.003700 \n",
+ "2 BRD-A69815203-001-07-6 1.000000 0.000100 \n",
+ "3 BRD-A70858459-001-01-7 0.777173 0.000600 \n",
+ "4 BRD-A72309220-001-04-1 0.716927 0.002200 \n",
+ "5 BRD-A72390365-001-15-2 0.934444 0.000100 \n",
+ "6 BRD-A73368467-003-17-6 0.926032 0.000100 \n",
+ "7 BRD-A74980173-001-11-9 0.765931 0.000600 \n",
+ "8 BRD-A81233518-004-16-1 0.621183 0.009399 \n",
+ "9 BRD-A82035391-001-02-7 0.318066 0.260374 \n",
"\n",
" corrected_p_value below_p below_corrected_p -log10(p-value) \n",
- "0 0.016390 True True 1.785430 \n",
- "1 0.001365 True True 2.865004 \n",
- "2 0.000276 True True 3.558835 \n",
- "3 0.000400 True True 3.397983 \n",
- "4 0.001491 True True 2.826441 \n",
- "5 0.000276 True True 3.558835 \n",
- "6 0.000276 True True 3.558835 \n",
- "7 0.000276 True True 3.558835 \n",
- "8 0.018700 True True 1.728154 \n",
- "9 0.078692 False False 1.104069 "
+ "0 0.023857 True True 1.622390 \n",
+ "1 0.006922 True True 2.159775 \n",
+ "2 0.000341 True True 3.467064 \n",
+ "3 0.001289 True True 2.889828 \n",
+ "4 0.004253 True True 2.371314 \n",
+ "5 0.000341 True True 3.467064 \n",
+ "6 0.000341 True True 3.467064 \n",
+ "7 0.001289 True True 2.889828 \n",
+ "8 0.013978 True True 1.854552 \n",
+ "9 0.264942 False False 0.576849 "
]
},
"execution_count": 9,
@@ -1788,7 +1788,7 @@
"outputs": [
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
""
]
@@ -2479,7 +2479,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "1dfaedfbdea44ec4ba89eb7d49a4d9ea",
+ "model_id": "d6f90c7f26924332b4a4e23ba90dd98e",
"version_major": 2,
"version_minor": 0
},
@@ -2493,7 +2493,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "b7e2c7e2ab214da09395dab1d71b9221",
+ "model_id": "e6891153024b485a92a556545e04a8eb",
"version_major": 2,
"version_minor": 0
},
@@ -2534,43 +2534,43 @@
" \n",
" \n",
" \n",
- " 54 \n",
+ " 52 \n",
" BRD-A69636825-003-04-7 \n",
" 0.500000 \n",
" 1 \n",
- " 46 \n",
+ " 42 \n",
" HTR3A \n",
" \n",
" \n",
- " 34 \n",
+ " 32 \n",
" BRD-A72309220-001-04-1 \n",
- " 0.396412 \n",
+ " 0.406071 \n",
" 4 \n",
- " 46 \n",
+ " 42 \n",
" HTR1A \n",
" \n",
" \n",
- " 39 \n",
+ " 37 \n",
" BRD-A72309220-001-04-1 \n",
" 0.142857 \n",
" 1 \n",
- " 43 \n",
+ " 39 \n",
" HTR1B \n",
" \n",
" \n",
- " 41 \n",
+ " 39 \n",
" BRD-A72309220-001-04-1 \n",
" 0.142857 \n",
" 1 \n",
- " 43 \n",
+ " 39 \n",
" HTR1D \n",
" \n",
" \n",
- " 43 \n",
+ " 41 \n",
" BRD-A72309220-001-04-1 \n",
" 0.142857 \n",
" 1 \n",
- " 43 \n",
+ " 39 \n",
" HTR1E \n",
" \n",
" \n",
@@ -2584,25 +2584,25 @@
" \n",
" 16 \n",
" BRD-K74363950-004-01-0 \n",
- " 0.094538 \n",
+ " 0.105128 \n",
" 2 \n",
- " 46 \n",
+ " 42 \n",
" CHRM3 \n",
" \n",
" \n",
" 19 \n",
" BRD-K74363950-004-01-0 \n",
- " 0.094538 \n",
+ " 0.105128 \n",
" 2 \n",
- " 46 \n",
+ " 42 \n",
" CHRM4 \n",
" \n",
" \n",
" 22 \n",
" BRD-K74363950-004-01-0 \n",
- " 0.094538 \n",
+ " 0.105128 \n",
" 2 \n",
- " 46 \n",
+ " 42 \n",
" CHRM5 \n",
" \n",
" \n",
@@ -2610,50 +2610,50 @@
" BRD-K76908866-001-07-6 \n",
" 0.500000 \n",
" 1 \n",
- " 46 \n",
+ " 42 \n",
" ERBB2 \n",
" \n",
" \n",
- " 63 \n",
+ " 61 \n",
" BRD-K81258678-001-01-0 \n",
" 0.100000 \n",
" 1 \n",
- " 46 \n",
+ " 42 \n",
" RELA \n",
" \n",
" \n",
"\n",
- "66 rows × 5 columns
\n",
+ "64 rows × 5 columns
\n",
""
],
"text/plain": [
" Metadata_broad_sample average_precision n_pos_pairs n_total_pairs \\\n",
- "54 BRD-A69636825-003-04-7 0.500000 1 46 \n",
- "34 BRD-A72309220-001-04-1 0.396412 4 46 \n",
- "39 BRD-A72309220-001-04-1 0.142857 1 43 \n",
- "41 BRD-A72309220-001-04-1 0.142857 1 43 \n",
- "43 BRD-A72309220-001-04-1 0.142857 1 43 \n",
+ "52 BRD-A69636825-003-04-7 0.500000 1 42 \n",
+ "32 BRD-A72309220-001-04-1 0.406071 4 42 \n",
+ "37 BRD-A72309220-001-04-1 0.142857 1 39 \n",
+ "39 BRD-A72309220-001-04-1 0.142857 1 39 \n",
+ "41 BRD-A72309220-001-04-1 0.142857 1 39 \n",
".. ... ... ... ... \n",
- "16 BRD-K74363950-004-01-0 0.094538 2 46 \n",
- "19 BRD-K74363950-004-01-0 0.094538 2 46 \n",
- "22 BRD-K74363950-004-01-0 0.094538 2 46 \n",
- "28 BRD-K76908866-001-07-6 0.500000 1 46 \n",
- "63 BRD-K81258678-001-01-0 0.100000 1 46 \n",
+ "16 BRD-K74363950-004-01-0 0.105128 2 42 \n",
+ "19 BRD-K74363950-004-01-0 0.105128 2 42 \n",
+ "22 BRD-K74363950-004-01-0 0.105128 2 42 \n",
+ "28 BRD-K76908866-001-07-6 0.500000 1 42 \n",
+ "61 BRD-K81258678-001-01-0 0.100000 1 42 \n",
"\n",
" Metadata_target \n",
- "54 HTR3A \n",
- "34 HTR1A \n",
- "39 HTR1B \n",
- "41 HTR1D \n",
- "43 HTR1E \n",
+ "52 HTR3A \n",
+ "32 HTR1A \n",
+ "37 HTR1B \n",
+ "39 HTR1D \n",
+ "41 HTR1E \n",
".. ... \n",
"16 CHRM3 \n",
"19 CHRM4 \n",
"22 CHRM5 \n",
"28 ERBB2 \n",
- "63 RELA \n",
+ "61 RELA \n",
"\n",
- "[66 rows x 5 columns]"
+ "[64 rows x 5 columns]"
]
},
"execution_count": 13,
@@ -2700,7 +2700,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "e4cf611dd48f421a92039c551475d6e8",
+ "model_id": "587b411cad734aa9ab356ee6ba537fd5",
"version_major": 2,
"version_minor": 0
},
@@ -2714,12 +2714,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "36db5ef9e7e24a55899d35ab9b7d746a",
+ "model_id": "c051523be08049089850e2cda87e7b5c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
- " 0%| | 0/27 [00:00, ?it/s]"
+ " 0%| | 0/26 [00:00, ?it/s]"
]
},
"metadata": {},
@@ -2759,102 +2759,102 @@
" \n",
" 0 \n",
" ADRA1A \n",
- " 0.238095 \n",
- " 0.104890 \n",
- " 0.167542 \n",
+ " 0.250000 \n",
+ " 0.113389 \n",
+ " 0.192056 \n",
" False \n",
" False \n",
- " 0.775876 \n",
+ " 0.716573 \n",
" \n",
" \n",
" 1 \n",
" ADRA2A \n",
- " 0.238095 \n",
- " 0.104890 \n",
- " 0.167542 \n",
+ " 0.250000 \n",
+ " 0.113389 \n",
+ " 0.192056 \n",
" False \n",
" False \n",
- " 0.775876 \n",
+ " 0.716573 \n",
" \n",
" \n",
" 2 \n",
" AURKA \n",
" 0.625000 \n",
- " 0.022298 \n",
- " 0.100340 \n",
+ " 0.023398 \n",
+ " 0.101390 \n",
" True \n",
" False \n",
- " 0.998526 \n",
+ " 0.994005 \n",
" \n",
" \n",
" 3 \n",
" BIRC2 \n",
- " 0.051316 \n",
- " 0.413459 \n",
- " 0.483152 \n",
+ " 0.060662 \n",
+ " 0.379062 \n",
+ " 0.469315 \n",
" False \n",
" False \n",
- " 0.315917 \n",
+ " 0.328536 \n",
" \n",
" \n",
" 4 \n",
" CHRM1 \n",
- " 0.091024 \n",
- " 0.483152 \n",
- " 0.483152 \n",
+ " 0.098420 \n",
+ " 0.484752 \n",
+ " 0.484752 \n",
" False \n",
" False \n",
- " 0.315917 \n",
+ " 0.314481 \n",
" \n",
" \n",
" 5 \n",
" CHRM2 \n",
- " 0.091024 \n",
- " 0.483152 \n",
- " 0.483152 \n",
+ " 0.098420 \n",
+ " 0.484752 \n",
+ " 0.484752 \n",
" False \n",
" False \n",
- " 0.315917 \n",
+ " 0.314481 \n",
" \n",
" \n",
" 6 \n",
" CHRM3 \n",
- " 0.091024 \n",
- " 0.483152 \n",
- " 0.483152 \n",
+ " 0.098420 \n",
+ " 0.484752 \n",
+ " 0.484752 \n",
" False \n",
" False \n",
- " 0.315917 \n",
+ " 0.314481 \n",
" \n",
" \n",
" 7 \n",
" CHRM4 \n",
- " 0.091024 \n",
- " 0.483152 \n",
- " 0.483152 \n",
+ " 0.098420 \n",
+ " 0.484752 \n",
+ " 0.484752 \n",
" False \n",
" False \n",
- " 0.315917 \n",
+ " 0.314481 \n",
" \n",
" \n",
" 8 \n",
" CHRM5 \n",
- " 0.091024 \n",
- " 0.483152 \n",
- " 0.483152 \n",
+ " 0.098420 \n",
+ " 0.484752 \n",
+ " 0.484752 \n",
" False \n",
" False \n",
- " 0.315917 \n",
+ " 0.314481 \n",
" \n",
" \n",
" 9 \n",
" DRD2 \n",
" 0.750000 \n",
" 0.000900 \n",
- " 0.006074 \n",
+ " 0.005849 \n",
" True \n",
" True \n",
- " 2.216497 \n",
+ " 2.232888 \n",
" \n",
" \n",
"\n",
@@ -2862,28 +2862,28 @@
],
"text/plain": [
" Metadata_target mean_average_precision p_value corrected_p_value \\\n",
- "0 ADRA1A 0.238095 0.104890 0.167542 \n",
- "1 ADRA2A 0.238095 0.104890 0.167542 \n",
- "2 AURKA 0.625000 0.022298 0.100340 \n",
- "3 BIRC2 0.051316 0.413459 0.483152 \n",
- "4 CHRM1 0.091024 0.483152 0.483152 \n",
- "5 CHRM2 0.091024 0.483152 0.483152 \n",
- "6 CHRM3 0.091024 0.483152 0.483152 \n",
- "7 CHRM4 0.091024 0.483152 0.483152 \n",
- "8 CHRM5 0.091024 0.483152 0.483152 \n",
- "9 DRD2 0.750000 0.000900 0.006074 \n",
+ "0 ADRA1A 0.250000 0.113389 0.192056 \n",
+ "1 ADRA2A 0.250000 0.113389 0.192056 \n",
+ "2 AURKA 0.625000 0.023398 0.101390 \n",
+ "3 BIRC2 0.060662 0.379062 0.469315 \n",
+ "4 CHRM1 0.098420 0.484752 0.484752 \n",
+ "5 CHRM2 0.098420 0.484752 0.484752 \n",
+ "6 CHRM3 0.098420 0.484752 0.484752 \n",
+ "7 CHRM4 0.098420 0.484752 0.484752 \n",
+ "8 CHRM5 0.098420 0.484752 0.484752 \n",
+ "9 DRD2 0.750000 0.000900 0.005849 \n",
"\n",
" below_p below_corrected_p -log10(p-value) \n",
- "0 False False 0.775876 \n",
- "1 False False 0.775876 \n",
- "2 True False 0.998526 \n",
- "3 False False 0.315917 \n",
- "4 False False 0.315917 \n",
- "5 False False 0.315917 \n",
- "6 False False 0.315917 \n",
- "7 False False 0.315917 \n",
- "8 False False 0.315917 \n",
- "9 True True 2.216497 "
+ "0 False False 0.716573 \n",
+ "1 False False 0.716573 \n",
+ "2 True False 0.994005 \n",
+ "3 False False 0.328536 \n",
+ "4 False False 0.314481 \n",
+ "5 False False 0.314481 \n",
+ "6 False False 0.314481 \n",
+ "7 False False 0.314481 \n",
+ "8 False False 0.314481 \n",
+ "9 True True 2.232888 "
]
},
"execution_count": 14,
@@ -2913,7 +2913,7 @@
"outputs": [
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
""
]
@@ -2990,7 +2990,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "copairs",
+ "display_name": "map_benchmark",
"language": "python",
"name": "python3"
},
@@ -3004,7 +3004,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.19"
+ "version": "3.10.13"
}
},
"nbformat": 4,
diff --git a/pyproject.toml b/pyproject.toml
index 2133055..1d0645c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ dependencies = [
]
[project.optional-dependencies]
+dev = ["ruff"]
plot = ["plotly"]
test = ["scikit-learn", "pytest"]
demo = ["notebook", "matplotlib"]
@@ -31,4 +32,4 @@ requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
-where = ["src"]
+where = ["src"]
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 0000000..42ebdad
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,25 @@
+# Unit tests
+
+We use the `pytest` package to implement and run unit tests for copairs.
+
+## Getting started
+
+### Installation
+
+To install copairs with test dependencies, check out the code locally and install it with:
+```bash
+pip install -e .[test]
+```
+
+### Running tests
+To execute all tests, run:
+```bash
+pytest
+```
+
+Each individual `test_filename.py` file implements tests for particular features in the corresponding `copairs/filename.py`.
+
+To run tests for a particular source file, specify its test file:
+```bash
+pytest tests/test_map.py
+```
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29..b185083 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for the copairs package."""
diff --git a/tests/helpers.py b/tests/helpers.py
index a7a4c25..57207a5 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -1,3 +1,5 @@
+"""Helper functions for testing."""
+
from itertools import product
from typing import Dict
@@ -10,7 +12,7 @@
def simulate_plates(n_compounds, n_replicates, plate_size):
- """Round robin creation of platemaps"""
+ """Round robin creation of platemaps."""
total = n_compounds * n_replicates
compounds = []
@@ -35,6 +37,7 @@ def simulate_random_plates(
sameby=ColumnList,
diffby=ColumnList,
):
+ """Simulate random platemaps."""
rng = np.random.default_rng(SEED)
dframe = simulate_plates(n_compounds, n_replicates, plate_size)
# Shuffle values
@@ -52,6 +55,7 @@ def simulate_random_dframe(
diffby: ColumnList,
rng: np.random.Generator,
):
+ """Simulate random dataframe."""
dframe = pd.DataFrame(columns=list(vocab_size.keys()), index=range(length))
for col, size in vocab_size.items():
dframe[col] = rng.integers(1, size + 1, size=length)
@@ -64,9 +68,7 @@ def simulate_random_dframe(
def create_dframe(n_options, n_rows):
- """
- Random permutation of a fix number of elements per column
- """
+ """Create a dataframe with a predefined number of plates, wells, and compounds."""
if isinstance(n_options, int):
n_options = [n_options] * 3
colc = list(f"c{i}" for i in range(n_options[0]))
diff --git a/tests/test_build_rank_multilabel.py b/tests/test_build_rank_multilabel.py
index 49b6f08..210a918 100644
--- a/tests/test_build_rank_multilabel.py
+++ b/tests/test_build_rank_multilabel.py
@@ -1,9 +1,12 @@
+"""Test the concatenation of ranges."""
+
import numpy as np
from copairs.compute import concat_ranges
def naive_concat_ranges(start: np.ndarray, end: np.ndarray):
+ """Concatenate ranges into a mask."""
mask = []
for s, e in zip(start, end):
mask.extend(range(s, e))
@@ -11,6 +14,7 @@ def naive_concat_ranges(start: np.ndarray, end: np.ndarray):
def test_concat_ranges():
+ """Test the concatenation of ranges."""
rng = np.random.default_rng()
num_range = 5, 10
start_range = 2, 10
diff --git a/tests/test_compute.py b/tests/test_compute.py
index 63444c7..dfad755 100644
--- a/tests/test_compute.py
+++ b/tests/test_compute.py
@@ -1,3 +1,5 @@
+"""Test pairwise distance calculation functions."""
+
import numpy as np
from copairs import compute
@@ -7,6 +9,7 @@
def corrcoef_naive(feats, pairs):
+ """Compute correlation coefficient between pairs of features."""
corr = np.empty((len(pairs),))
for pos, (i, j) in enumerate(pairs):
corr[pos] = np.corrcoef(feats[i], feats[j])[0, 1]
@@ -14,6 +17,7 @@ def corrcoef_naive(feats, pairs):
def cosine_naive(feats, pairs):
+ """Compute cosine similarity between pairs of features."""
cosine = np.empty((len(pairs),))
for pos, (i, j) in enumerate(pairs):
a, b = feats[i], feats[j]
@@ -24,6 +28,7 @@ def cosine_naive(feats, pairs):
def euclidean_naive(feats, pairs):
+ """Compute euclidean similarity between pairs of features."""
euclidean_sim = np.empty((len(pairs),))
for pos, (i, j) in enumerate(pairs):
dist = np.linalg.norm(feats[i] - feats[j])
@@ -32,10 +37,12 @@ def euclidean_naive(feats, pairs):
def abs_cosine_naive(feats, pairs):
+ """Compute absolute cosine similarity between pairs of features."""
return np.abs(cosine_naive(feats, pairs))
def test_corrcoef():
+ """Test correlation coefficient computation."""
n_samples = 10
n_pairs = 20
n_feats = 5
@@ -50,6 +57,7 @@ def test_corrcoef():
def test_cosine():
+ """Test cosine similarity computation."""
n_samples = 10
n_pairs = 20
n_feats = 5
@@ -64,6 +72,7 @@ def test_cosine():
def test_euclidean():
+ """Test euclidean similarity computation."""
n_samples = 10
n_pairs = 20
n_feats = 5
@@ -78,6 +87,7 @@ def test_euclidean():
def test_abs_cosine():
+ """Test absolute cosine similarity computation."""
n_samples = 10
n_pairs = 20
n_feats = 5
diff --git a/tests/test_map.py b/tests/test_map.py
index 816d7d8..b18ca9e 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -1,3 +1,5 @@
+"""Tests for (mean) Average Precision calculation."""
+
import numpy as np
import pandas as pd
import pytest
@@ -13,6 +15,7 @@
def test_random_binary_matrix():
+ """Test the random binary matrix generation."""
rng = np.random.default_rng(SEED)
# Test with n=3, m=4, k=2
A = compute.random_binary_matrix(3, 4, 2, rng)
@@ -28,6 +31,7 @@ def test_random_binary_matrix():
def test_compute_ap():
+ """Test the average precision computation."""
num_pos, num_neg, num_perm = 5, 6, 100
total = num_pos + num_neg
@@ -56,6 +60,7 @@ def test_compute_ap():
def test_compute_ap_contiguous():
+ """Test the contiguous average precision computation."""
num_pos_range = [2, 9]
num_neg_range = [10, 20]
num_samples_range = [5, 30]
@@ -88,6 +93,7 @@ def test_compute_ap_contiguous():
def test_pipeline():
+ """Check the implementation for mAP calculation."""
length = 10
vocab_size = {"p": 5, "w": 3, "l": 4}
n_feats = 5
@@ -103,7 +109,7 @@ def test_pipeline():
def test_pipeline_multilabel():
- """Check the multilabel implementation with for mAP calculation"""
+ """Check the multilabel implementation for mAP calculation."""
length = 10
vocab_size = {"p": 3, "w": 5, "l": 4}
n_feats = 8
@@ -124,6 +130,7 @@ def test_pipeline_multilabel():
def test_raise_no_pairs():
+ """Test the exception raised when no pairs are found."""
length = 10
vocab_size = {"p": 3, "w": 3, "l": 10}
n_feats = 5
@@ -143,6 +150,7 @@ def test_raise_no_pairs():
def test_raise_nan_error():
+ """Test the exception raised when there are null values."""
length = 10
vocab_size = {"p": 5, "w": 3, "l": 4}
n_feats = 8
diff --git a/tests/test_map_filter.py b/tests/test_map_filter.py
index 9b1b311..c49e6e2 100644
--- a/tests/test_map_filter.py
+++ b/tests/test_map_filter.py
@@ -1,3 +1,5 @@
+"""Test data filtering by query."""
+
import numpy as np
import pytest
@@ -9,6 +11,7 @@
@pytest.fixture
def mock_dataframe():
+ """Create a mock dataframe."""
length = 10
vocab_size = {"p": 3, "w": 3, "l": 10}
pos_sameby = ["l"]
@@ -20,6 +23,7 @@ def mock_dataframe():
def test_correct(mock_dataframe):
+ """Test correct query."""
df, parsed_cols = evaluate_and_filter(mock_dataframe, ["p == 'p1'", "w > 'w2'"])
assert not df.empty
assert "p" in parsed_cols and "w" in parsed_cols
@@ -27,6 +31,7 @@ def test_correct(mock_dataframe):
def test_invalid_query(mock_dataframe):
+ """Test invalid query."""
with pytest.raises(ValueError) as excinfo:
evaluate_and_filter(mock_dataframe, ['l == "lHello"'])
assert "Invalid combined query expression" in str(excinfo.value)
@@ -34,12 +39,14 @@ def test_invalid_query(mock_dataframe):
def test_empty_result(mock_dataframe):
+ """Test that duplicate queries for the same column raise an error."""
with pytest.raises(ValueError) as excinfo:
evaluate_and_filter(mock_dataframe, ['p == "p1"', 'p == "p2"'])
assert "Duplicate queries for column" in str(excinfo.value)
def test_empty_result_from_valid_query(mock_dataframe):
+ """Test empty result from valid query."""
with pytest.raises(ValueError) as excinfo:
evaluate_and_filter(mock_dataframe, ['p == "p4"'])
assert "No data matched the query" in str(excinfo.value)
diff --git a/tests/test_matching.py b/tests/test_matching.py
index 5bc1132..91c494f 100644
--- a/tests/test_matching.py
+++ b/tests/test_matching.py
@@ -1,4 +1,4 @@
-"""Test functions for Matcher"""
+"""Test functions for Matcher."""
from string import ascii_letters
@@ -13,7 +13,7 @@
def run_stress_sample_null(dframe, num_pairs):
- """Assert every generated null pair does not match any column"""
+ """Assert every generated null pair does not match any column."""
matcher = Matcher(dframe, dframe.columns, seed=SEED)
for _ in range(num_pairs):
id1, id2 = matcher.sample_null_pair(dframe.columns)
@@ -23,19 +23,19 @@ def run_stress_sample_null(dframe, num_pairs):
def test_null_sample_large():
- """Test Matcher guarantees elements with different values"""
+ """Test Matcher guarantees elements with different values."""
dframe = create_dframe(32, 10000)
run_stress_sample_null(dframe, 5000)
def test_null_sample_small():
- """Test Sample with small set"""
+ """Test Sample with small set."""
dframe = create_dframe(3, 10)
run_stress_sample_null(dframe, 100)
def test_null_sample_nan_vals():
- """Test NaN values are ignored"""
+ """Test NaN values are ignored."""
dframe = create_dframe(4, 15)
rng = np.random.default_rng(SEED)
nan_mask = rng.random(dframe.shape) < 0.5
@@ -44,7 +44,7 @@ def test_null_sample_nan_vals():
def get_naive_pairs(dframe: pd.DataFrame, sameby, diffby):
- """Compute valid pairs using cross product from pandas"""
+ """Compute valid pairs using cross product from pandas."""
cross = dframe.reset_index().merge(
dframe.reset_index(), how="cross", suffixes=("_x", "_y")
)
@@ -62,7 +62,7 @@ def get_naive_pairs(dframe: pd.DataFrame, sameby, diffby):
def check_naive(dframe, matcher: Matcher, sameby, diffby):
- """Check Matcher and naive generate same pairs"""
+ """Check Matcher and naive generate same pairs."""
gt_pairs = get_naive_pairs(dframe, sameby, diffby)
vals = matcher.get_all_pairs(sameby, diffby)
vals = sum(vals.values(), [])
@@ -74,14 +74,14 @@ def check_naive(dframe, matcher: Matcher, sameby, diffby):
def check_simulated_data(length, vocab_size, sameby, diffby, rng):
- """Test sample of valid pairs from a simulated dataset"""
+ """Test sample of valid pairs from a simulated dataset."""
dframe = simulate_random_dframe(length, vocab_size, sameby, diffby, rng)
matcher = Matcher(dframe, dframe.columns, seed=SEED)
check_naive(dframe, matcher, sameby, diffby)
def test_stress_simulated_data():
- """Run multiple tests using simulated data"""
+ """Run multiple tests using simulated data."""
rng = np.random.default_rng(SEED)
num_cols_range = [2, 6]
vocab_size_range = [5, 10]
@@ -99,7 +99,7 @@ def test_stress_simulated_data():
def test_empty_sameby():
- """Test query without sameby"""
+ """Test query without sameby."""
dframe = create_dframe(3, 10)
matcher = Matcher(dframe, dframe.columns, seed=SEED)
check_naive(dframe, matcher, sameby=[], diffby=["w", "c"])
@@ -107,7 +107,7 @@ def test_empty_sameby():
def test_empty_diffby():
- """Test query without diffby"""
+ """Test query without diffby."""
dframe = create_dframe(3, 10)
matcher = Matcher(dframe, dframe.columns, seed=SEED)
matcher.get_all_pairs(["c"], [])
@@ -116,7 +116,7 @@ def test_empty_diffby():
def test_raise_distjoint():
- """Test check for disjoint sameby and diffby"""
+ """Test check for disjoint sameby and diffby."""
dframe = create_dframe(3, 10)
matcher = Matcher(dframe, dframe.columns, seed=SEED)
with pytest.raises(ValueError, match="must be disjoint lists"):
@@ -124,7 +124,7 @@ def test_raise_distjoint():
def test_raise_no_params():
- """Test check for at least one of sameby and diffby"""
+ """Test check for at least one of sameby and diffby."""
dframe = create_dframe(3, 10)
matcher = Matcher(dframe, dframe.columns, seed=SEED)
with pytest.raises(ValueError, match="at least one should be provided"):
@@ -132,7 +132,7 @@ def test_raise_no_params():
def assert_sameby_diffby(dframe: pd.DataFrame, pairs_dict: dict, sameby, diffby):
- """Assert the pairs are valid"""
+ """Assert the pairs are valid."""
for _, pairs in pairs_dict.items():
for id1, id2 in pairs:
for col in sameby:
diff --git a/tests/test_matching_any.py b/tests/test_matching_any.py
index 25ccc02..b949613 100644
--- a/tests/test_matching_any.py
+++ b/tests/test_matching_any.py
@@ -1,3 +1,5 @@
+"""Test matching with `any` conditions using simulated data."""
+
from string import ascii_letters
import numpy as np
@@ -10,7 +12,7 @@
def get_naive_pairs(dframe: pd.DataFrame, sameby, diffby):
- """Compute valid pairs using cross product from pandas"""
+ """Compute valid pairs using cross product from pandas."""
cross = dframe.reset_index().merge(
dframe.reset_index(), how="cross", suffixes=("_x", "_y")
)
@@ -39,7 +41,7 @@ def get_naive_pairs(dframe: pd.DataFrame, sameby, diffby):
def check_naive(dframe, matcher: Matcher, sameby, diffby):
- """Check Matcher and naive generate same pairs"""
+ """Check Matcher and naive generate same pairs."""
gt_pairs = get_naive_pairs(dframe, sameby, diffby)
vals = matcher.get_all_pairs(sameby, diffby)
vals = sum(vals.values(), [])
@@ -51,7 +53,7 @@ def check_naive(dframe, matcher: Matcher, sameby, diffby):
def check_simulated_data(length, vocab_size, sameby, diffby, rng):
- """Test sample of valid pairs from a simulated dataset"""
+ """Test sample of valid pairs from a simulated dataset."""
sameby_cols = sameby["all"] + sameby["any"]
diffby_cols = diffby["all"] + diffby["any"]
dframe = simulate_random_dframe(length, vocab_size, sameby_cols, diffby_cols, rng)
@@ -60,7 +62,7 @@ def check_simulated_data(length, vocab_size, sameby, diffby, rng):
def test_stress_simulated_data_any_all():
- """Run multiple tests using simulated data"""
+ """Run multiple tests using simulated data."""
rng = np.random.default_rng(SEED)
num_cols_range = [2, 6]
vocab_size_range = [5, 10]
@@ -78,7 +80,7 @@ def test_stress_simulated_data_any_all():
def test_stress_simulated_data_all_all():
- """Run multiple tests using simulated data"""
+ """Run multiple tests using simulated data."""
rng = np.random.default_rng(SEED)
num_cols_range = [2, 6]
vocab_size_range = [5, 10]
@@ -96,7 +98,7 @@ def test_stress_simulated_data_all_all():
def test_stress_simulated_data_all_any():
- """Run multiple tests using simulated data"""
+ """Run multiple tests using simulated data."""
rng = np.random.default_rng(SEED)
num_cols_range = [2, 6]
vocab_size_range = [5, 10]
@@ -114,7 +116,7 @@ def test_stress_simulated_data_all_any():
def test_stress_simulated_data_any_any():
- """Run multiple tests using simulated data"""
+ """Run multiple tests using simulated data."""
rng = np.random.default_rng(SEED)
num_cols_range = [4, 6]
vocab_size_range = [5, 10]
diff --git a/tests/test_matching_multilabel.py b/tests/test_matching_multilabel.py
index 50f978e..0ee6538 100644
--- a/tests/test_matching_multilabel.py
+++ b/tests/test_matching_multilabel.py
@@ -1,3 +1,5 @@
+"""Tests for the multilabel matching implementation."""
+
import pandas as pd
from copairs.matching import MatcherMultilabel
@@ -7,6 +9,7 @@
def get_naive_pairs(dframe: pd.DataFrame, sameby, diffby, multilabel_col: str):
+ """Get pairs using a naive implementation."""
dframe = dframe.copy()
dframe[multilabel_col] = dframe[multilabel_col].apply(set)
@@ -45,7 +48,7 @@ def any_equal(row):
def check_naive(dframe, matcher: MatcherMultilabel, sameby, diffby, multilabel_col):
- """Check Matcher and naive generate same pairs"""
+ """Check Matcher and naive generate same pairs."""
gt_pairs = get_naive_pairs(dframe, sameby, diffby, multilabel_col)
vals = matcher.get_all_pairs(sameby, diffby)
vals = sum(vals.values(), [])
@@ -57,7 +60,7 @@ def check_naive(dframe, matcher: MatcherMultilabel, sameby, diffby, multilabel_c
def test_sameby():
- """Check the multilabel implementation with sameby"""
+ """Check the multilabel implementation with sameby."""
multilabel_col = "c"
sameby = ["c"]
diffby = ["p", "w"]
@@ -70,7 +73,7 @@ def test_sameby():
def test_diffby():
- """Check the multilabel implementation with sameby"""
+ """Check the multilabel implementation with diffby."""
multilabel_col = "c"
sameby = ["p"]
diffby = ["c", "w"]
@@ -84,7 +87,7 @@ def test_diffby():
def test_only_diffby():
- """Check the multilabel implementation with only diffby being equal to c"""
+ """Check the multilabel implementation with only diffby being equal to c."""
multilabel_col = "c"
sameby = []
diffby = ["c"]
@@ -97,7 +100,7 @@ def test_only_diffby():
def test_only_diffby_many_cols():
- """Check the multilabel implementation with only diffby being equal to c"""
+ """Check the multilabel implementation with only diffby over multiple columns."""
multilabel_col = "c"
sameby = []
diffby = ["c", "w"]
@@ -110,7 +113,7 @@ def test_only_diffby_many_cols():
def test_only_sameby_many_cols():
- """Check the multilabel implementation with only diffby being equal to c"""
+ """Check the multilabel implementation with only sameby over multiple columns."""
multilabel_col = "c"
sameby = ["c", "w"]
diffby = []
diff --git a/tests/test_replicating.py b/tests/test_replicating.py
index a273bbe..db9f157 100644
--- a/tests/test_replicating.py
+++ b/tests/test_replicating.py
@@ -1,3 +1,5 @@
+"""Tests for the replicating module."""
+
from numpy.random import default_rng
from copairs import Matcher
@@ -12,6 +14,7 @@
def test_corr_between_replicates():
+ """Test calculating correlation between replicates."""
rng = default_rng(SEED)
num_samples = 10
X = rng.normal(size=[num_samples, 6])
@@ -20,6 +23,7 @@ def test_corr_between_replicates():
def test_correlation_test():
+ """Test correlation test."""
rng = default_rng(SEED)
num_samples = 10
X = rng.normal(size=[num_samples, 6])
@@ -31,6 +35,7 @@ def test_correlation_test():
def test_corr_from_pairs():
+ """Test calculating correlation from a list of named pairs."""
num_samples = 10
sameby = ["c"]
diffby = ["p", "w"]