Skip to content

Commit

Permalink
Check feature input for NaNs (#60)
Browse files Browse the repository at this point in the history
* add check for nan in feature array
* add test for NaNs in feats and meta
  • Loading branch information
jessica-ewald authored Mar 5, 2024
1 parent ab8b688 commit 200d24d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/copairs/map/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
from typing import List, Tuple

import pandas as pd
import numpy as np


def validate_pipeline_input(meta, feats, columns):
if meta[columns].isna().any(axis=None):
raise ValueError("metadata columns should not have null values.")
if len(meta) != len(feats):
raise ValueError("meta and feats have different number of rows")
if np.isnan(feats).any():
raise ValueError("features should not have null values.")


def flatten_str_list(*args):
Expand Down
27 changes: 26 additions & 1 deletion tests/test_map.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.metrics import average_precision_score
import numpy as np

from copairs import compute
from copairs.map import average_precision
Expand Down Expand Up @@ -140,3 +140,28 @@ def test_raise_no_pairs():
average_precision(meta, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby)
with pytest.raises(UnpairedException, match="Unable to find negative pairs."):
average_precision(meta, feats, pos_diffby, [], pos_sameby, [])


def test_raise_nan_error():
length = 10
vocab_size = {"p": 5, "w": 3, "l": 4}
n_feats = 8
pos_sameby = ["l"]
pos_diffby = ["p"]
neg_sameby = []
neg_diffby = ["l"]
rng = np.random.default_rng(SEED)
meta = simulate_random_dframe(length, vocab_size, pos_sameby, pos_diffby, rng)
length = len(meta)
feats = rng.uniform(size=(length, n_feats))

# add null values
feats_nan = feats.copy()
feats_nan[2,2] = None
meta_nan = meta.copy()
meta_nan.loc[1,"p"] = None

with pytest.raises(ValueError, match="features should not have null values."):
average_precision(meta, feats_nan, pos_sameby, pos_diffby, neg_sameby, neg_diffby)
with pytest.raises(ValueError, match="metadata columns should not have null values."):
average_precision(meta_nan, feats, pos_sameby, pos_diffby, neg_sameby, neg_diffby)

0 comments on commit 200d24d

Please sign in to comment.