Skip to content

Commit

Permalink
nit on masking eval data to save filesize
Browse files Browse the repository at this point in the history
  • Loading branch information
joel99 committed Apr 3, 2024
1 parent f3db585 commit ed3b58d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
14 changes: 9 additions & 5 deletions falcon_challenge/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def evaluate(self, decoder: BCIDecoder, phase: str, held_out_only: bool = False,
inner_pred = {**all_preds}
inner_tgt_spoof = { # spoof for local mirror of eval ai path, in reality targets are already compiled on eval ai side.
k: {
'data': all_targets[k],
'data': all_targets[k][all_eval_mask[k]],
'mask': all_eval_mask[k],
} for k in all_targets
}
Expand Down Expand Up @@ -341,21 +341,25 @@ def evaluate(self, decoder: BCIDecoder, phase: str, held_out_only: bool = False,
for k, v in metrics.items():
logger.info("{}: {}".format(k, v))


@staticmethod
def compute_metrics_regression(preds, targets, eval_mask):
    """Compute variance-weighted R2 between predictions and targets.

    Args:
        preds: per-timestep predictions over the FULL session length;
            the eval mask is applied here.
        targets: ground-truth values, assumed ALREADY masked upstream
            (targets are stored pre-masked to save answer-key filesize).
        eval_mask: boolean mask over the full length of `preds`,
            selecting the evaluated timesteps.

    Returns:
        dict with "R2" (variance-weighted) and a placeholder "R2 Std.".

    Raises:
        ValueError: if masked predictions and targets disagree in length.
    """
    # assumes targets are already masked — do NOT re-mask targets here;
    # eval_mask is full-length and would misalign with pre-masked targets.
    preds = preds[eval_mask]
    if not targets.shape[0] == preds.shape[0]:
        raise ValueError(f"Targets and predictions have different lengths: {targets.shape[0]} vs {preds.shape[0]}.")
    return {
        "R2": r2_score(targets, preds, multioutput='variance_weighted'),
        "R2 Std.": 0, # TODO Clay
    }

@staticmethod
def compute_metrics_classification(preds, targets, eval_mask):
preds = preds[eval_mask]
if not targets.shape[0] == preds.shape[0]:
raise ValueError(f"Targets and predictions have different lengths: {targets.shape[0]} vs {preds.shape[0]}.")
return {
"CER": 1-(preds == targets)[eval_mask].mean(),
"CER Std.": 0, # TODO Clay
"WER": 1-(preds == targets).mean(),
"WER Std.": 0, # TODO Clay
}

def compute_metrics(self, all_preds, all_targets, all_eval_mask=None):
Expand Down
2 changes: 1 addition & 1 deletion preproc/merge_answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def assemble_phase_answer_key(phase='minival', answer_key_dir='./data/answer_key
print(d.stem)
neural_data, decoding_targets, trial_change, eval_mask = load_nwb(d, dataset=task)
annotations[dataset][config.hash_dataset(d.stem)] = {
'data': decoding_targets,
'data': decoding_targets[eval_mask],
'mask': eval_mask
}
return annotations
Expand Down

0 comments on commit ed3b58d

Please sign in to comment.