
Commit e75917e
clean evaluator code
joel99 committed Apr 5, 2024
1 parent 4adb60d commit e75917e
Showing 7 changed files with 15 additions and 78 deletions.
19 changes: 3 additions & 16 deletions decoder_demos/ndt2_decoder.py
@@ -18,8 +18,7 @@
from context_general_bci.config import RootConfig, propagate_config, DataKey, MetaKey
from context_general_bci.dataset import DataAttrs, ContextAttrs
from context_general_bci.subjects import SubjectName
from context_general_bci.contexts.context_registry import context_registry
from context_general_bci.contexts.context_info import FalconContextInfo, ExperimentalTask
from context_general_bci.contexts.context_info import ExperimentalTask
from context_general_bci.model import load_from_checkpoint
from context_general_bci.model_slim import transfer_model

@@ -38,24 +37,14 @@ def __init__(
model_ckpt_path: str,
model_cfg_stem: str,
zscore_path: str,
dataset_handles: List[str] = []
):
r"""
Loading NDT2 requires both weights and model config. Weight loading through a checkpoint is standard.
Model config is typically stored on wandb, but this is not portable enough. Instead, directly reference the model config file.
"""
self._task_config = task_config
self.exp_task = getattr(ExperimentalTask, f'falcon_{task_config.task.name}')

context_registry.register([
*FalconContextInfo.build_from_dir(
f'./data/{task_config.task.name}/eval',
task=self.exp_task,
suffix='eval'),
*FalconContextInfo.build_from_dir(
f'./data/{task_config.task.name}/minival',
task=self.exp_task,
suffix='minival')])

try:
initialize_config_module(
config_module="context_general_bci.config",
@@ -76,9 +65,7 @@
context_idx = {
MetaKey.array.name: [format_array_name(self.subject)],
MetaKey.subject.name: [self.subject],
MetaKey.session.name: sorted([
self._task_config.hash_dataset(handle) for handle in task_config.dataset_handles
]),
MetaKey.session.name: sorted([self._task_config.hash_dataset(handle) for handle in dataset_handles]),
MetaKey.task.name: [self.exp_task],
}
data_attrs = DataAttrs.from_config(cfg.dataset, context=ContextAttrs(**context_idx))
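
Taken together, the two hunks above change the decoder's interface: eval/minival datasets are no longer registered through context_registry at construction time, and the caller now supplies the dataset handles directly. A minimal sketch of the resulting constructor shape (import paths assumed; the parameters ahead of model_ckpt_path, namely self and task_config, are inferred from the call site in ndt2_sample.py below; the body is elided):

from typing import List

from falcon_challenge.config import FalconConfig


class NDT2Decoder:
    def __init__(
        self,
        task_config: FalconConfig,
        model_ckpt_path: str,
        model_cfg_stem: str,
        zscore_path: str,
        dataset_handles: List[str] = [],  # new in this commit: handles are injected by the caller
    ):
        # Session keys are now derived directly from the supplied handles via
        # task_config.hash_dataset, replacing the removed context_registry lookup.
        ...
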
3 changes: 2 additions & 1 deletion decoder_demos/ndt2_sample.Dockerfile
@@ -24,7 +24,7 @@ ENV EVALUATION_LOC remote

# Add ckpt
# Note that Docker cannot easily import across symlinks, make sure data is not symlinked
ADD ./local_data/ndt2_h1_sample.pth data/decoder.pth
ADD ./local_data/ndt2_h1_sample_nokey.pth data/decoder.pth
ADD ./local_data/ndt2_zscore_h1.pt data/zscore.pt

# Add runfile
@@ -38,6 +38,7 @@ ENV PHASE "test"

# Make sure this matches the mounted data volume path. Generally leave as is.
ENV EVAL_DATA_PATH "/dataset/evaluation_data"
ADD ./falcon_challenge falcon_challenge

# CMD specifies a default command to run when the container is launched.
# It can be overridden with any cmd e.g. sudo docker run -it my_image /bin/bash
8 changes: 3 additions & 5 deletions decoder_demos/ndt2_sample.py
@@ -41,16 +41,14 @@ def main():
split=args.split)

task = getattr(FalconTask, args.split)
config = FalconConfig(
task=task,
dataset_handles=[x.stem for x in evaluator.get_eval_files(phase=args.phase)]
)
config = FalconConfig(task=task)

decoder = NDT2Decoder(
task_config=config,
model_ckpt_path=args.model_path,
model_cfg_stem=args.config_stem,
zscore_path=args.zscore_path
zscore_path=args.zscore_path,
dataset_handles=[x.stem for x in evaluator.get_eval_files(phase=args.phase)]
)


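
The net effect on the sample runner is that dataset_handles moves off FalconConfig and into the NDT2Decoder call. A sketch of the revised wiring under stated assumptions: argument parsing is collapsed into a plain function, the eval_remote keyword on FalconEvaluator is an assumption, and the closing evaluator.evaluate call is taken from the signature shown in evaluator.py below.

from falcon_challenge.config import FalconConfig, FalconTask
from falcon_challenge.evaluator import FalconEvaluator
from decoder_demos.ndt2_decoder import NDT2Decoder


def run(split: str, phase: str, model_path: str, config_stem: str, zscore_path: str):
    # Hypothetical wrapper mirroring ndt2_sample.main(); argparse handling is omitted.
    evaluator = FalconEvaluator(eval_remote=False, split=split)  # eval_remote kwarg is an assumption

    task = getattr(FalconTask, split)
    config = FalconConfig(task=task)  # dataset_handles no longer lives on the config

    decoder = NDT2Decoder(
        task_config=config,
        model_ckpt_path=model_path,
        model_cfg_stem=config_stem,
        zscore_path=zscore_path,
        dataset_handles=[x.stem for x in evaluator.get_eval_files(phase=phase)],
    )
    evaluator.evaluate(decoder, phase=phase)
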
4 changes: 2 additions & 2 deletions decoder_demos/sklearn_sample.Dockerfile
@@ -36,8 +36,8 @@ ADD ./decoder_demos/sklearn_sample.py decode.py
ADD ./preproc/filtering.py filtering.py

ENV SPLIT "h1"
ENV PHASE "minival"
# ENV PHASE "test"
# ENV PHASE "minival"
ENV PHASE "test"

# Make sure this matches the mounted data volume path. Generally leave as is.
ENV EVAL_DATA_PATH "/dataset/evaluation_data"
3 changes: 2 additions & 1 deletion falcon_challenge/config.py
@@ -1,4 +1,5 @@
import enum
from typing import Union
from pathlib import Path
from dataclasses import dataclass, field

@@ -52,7 +53,7 @@ def out_dim(self):
return 2
raise NotImplementedError(f"Task {self.task} not implemented.")

def hash_dataset(self, handle: str | Path):
def hash_dataset(self, handle: Union[str, Path]):
r"""
handle - path.stem of a datafile.
Convenience function to help identify what "session" a datafile belongs to. If there are multiple files per session in real-world time, this may _not_ uniquely identify the runfile.
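
The annotation change from str | Path to Union[str, Path] (with the matching typing import added at the top of the file) presumably keeps the module importable on Python versions below 3.10, where the PEP 604 union syntax raises a TypeError when the annotation is evaluated at function definition. A minimal sketch of calling hash_dataset with either type after the change; import paths and the handle value are assumptions, and the returned session identifier depends on the implementation, which is not shown here.

from pathlib import Path

from falcon_challenge.config import FalconConfig, FalconTask

cfg = FalconConfig(task=FalconTask.h1)

# Both call forms are accepted under the Union[str, Path] annotation; the handle is
# a datafile's path.stem per the docstring above, and "S0_set_1" is hypothetical.
key_from_str = cfg.hash_dataset("S0_set_1")
key_from_path = cfg.hash_dataset(Path("S0_set_1"))
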
54 changes: 2 additions & 52 deletions falcon_challenge/evaluator.py
@@ -80,51 +80,6 @@
'held_out': "Held Out",
}

# def evaluate(
# test_annotation_file: str, # The annotation file for the phase - but our labels are pulled from eval data.
# user_submission_file: str, # * JY: This appears to always be /submission/submission.csv on EvalAI. No matter - load it as a pickle.
# phase_codename: str, # e.g. minival or test
# **kwargs
# ):
# r"""
# Evaluate payloads with potentially multiple splits worth of data
# - Low pri: can I provide all results or just one split's worth entry? Currently providing 1, examples just provide 1, but in general would be nice to provide all. User shouldn't be able to submit more than 1, though.
# """
# # ! Want: Locally, test_annotation should be somewhere safe (tmp)
# # ! Remotely, it shoudl be /submission/submission.csv exactly.
# # Ignore explicit annotations provided and directly search for concatenated answers
# logger.info(f"Evaluation: Docker side")
# logger.info(f"Loading GT from {test_annotation_file}")
# logger.info(f"Loading submission from {user_submission_file}")
# logger.info(f"Phase: {phase_codename}")

# result = []
# # Load pickles
# with open(test_annotation_file, 'rb') as test_annotation_file, open(user_submission_file, 'rb') as user_submission_file:
# test_annotations = pickle.load(test_annotation_file)
# user_submission = pickle.load(user_submission_file)
# for datasplit in user_submission: # datasplit e.g. h1, m1
# if datasplit not in test_annotations:
# raise ValueError(f"Missing {datasplit} in GT labels.")
# split_annotations = test_annotations[datasplit]
# split_result = {}
# split_result["Normalized Latency"] = user_submission[datasplit]["normalized_latency"]
# for in_or_out in split_annotations.keys():
# if f'{in_or_out}_pred' in user_submission[datasplit]:
# pred = user_submission[datasplit][f'{in_or_out}_pred']
# mask = user_submission[datasplit][f'{in_or_out}_eval_mask']
# # User submission should be in an expected format because we force predictions through our codepack interface... right? They could hypothetically spoof. But we see dockerfile.
# eval_fn = FalconEvaluator.compute_metrics_classification if 'h2' in datasplit else FalconEvaluator.compute_metrics_regression
# metrics_held_in = eval_fn(pred, split_annotations[in_or_out], mask)
# for k in metrics_held_in:
# split_result[f'{HELDIN_OR_OUT_MAP[in_or_out]} {k}'] = metrics_held_in[k]
# result.append({datasplit: split_result})

# print(f"Returning result from phase: {phase_codename}: {result}")
# # Out struct according to https://evalai.readthedocs.io/en/latest/evaluation_scripts.html
# return {"result": result, 'submission_result': result[0]}


def evaluate(
test_annotation_file: str, # The annotation file for the phase
user_submission_file: str, # * JY: This appears to always be /submission/submission.csv on EvalAI. No matter - load it as a pickle.
@@ -341,17 +296,12 @@ def evaluate(self, decoder: BCIDecoder, phase: str, held_out_only: bool = False,
truth_payload = {self.dataset.name: inner_tgt_spoof}
else:
pass
# TODO restore
# metrics_held_in = self.compute_metrics(all_preds_held_in, all_targets_held_in, all_eval_mask_held_in)
# metrics_held_out = self.compute_metrics(all_preds_held_out, all_targets_held_out, all_eval_mask_held_out)
# for k, v in metrics_held_in.items():
# metrics[f'{HELDIN_OR_OUT_MAP["held_in"]} {k}'] = v
# for k, v in metrics_held_out.items():
# metrics[f'{HELDIN_OR_OUT_MAP["held_out"]} {k}'] = v

if USE_PKLS:
Path(prediction_path).parent.mkdir(parents=True, exist_ok=True)
with open(prediction_path, 'wb') as f:
pickle.dump(pred_payload, f)
Path(gt_path).parent.mkdir(parents=True, exist_ok=True)
with open(gt_path, 'wb') as f:
pickle.dump(truth_payload, f)
import time
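
For orientation, the pred/truth pickles written above (and consumed by the module-level evaluate()) appear to follow the payload shape of the removed commented-out reference implementation: one entry per datasplit, holding per-split predictions, eval masks, and a normalized latency. A sketch of that assumed structure, with placeholder values and a key set that may not be exhaustive:

import pickle

user_submission = {
    "h1": {                          # datasplit, e.g. h1 or m1
        "normalized_latency": 0.95,  # hypothetical latency figure
        "held_in_pred": None,        # predictions for held-in sessions
        "held_in_eval_mask": None,   # boolean mask of evaluated timesteps
        "held_out_pred": None,
        "held_out_eval_mask": None,
    }
}

# On EvalAI the submission file is /submission/submission.csv but is loaded as a
# pickle, per the user_submission_file comment above.
with open("submission.csv", "wb") as f:
    pickle.dump(user_submission, f)
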
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name='falcon_challenge',
version='0.2.6',
version='0.2.7',

url='https://github.com/snel-repo/stability-benchmark',
author='Joel Ye',
