From 928946b517ffdb51cd357cd6fda7e3aef23c2877 Mon Sep 17 00:00:00 2001 From: Joel Ye Date: Mon, 6 May 2024 22:36:49 -0400 Subject: [PATCH] continual without enforcement --- decoder_demos/ndt2_sample.py | 8 +++++++- falcon_challenge/evaluator.py | 6 ++++-- setup.py | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/decoder_demos/ndt2_sample.py b/decoder_demos/ndt2_sample.py index 6498cbf..bb23b69 100644 --- a/decoder_demos/ndt2_sample.py +++ b/decoder_demos/ndt2_sample.py @@ -37,11 +37,17 @@ def main(): parser.add_argument( '--batch-size', type=int, default=1 ) + parser.add_argument( + '--continual', action='store_true', default=False + ) + args = parser.parse_args() evaluator = FalconEvaluator( eval_remote=args.evaluation == "remote", - split=args.split) + split=args.split, + continual=args.continual + ) task = getattr(FalconTask, args.split) config = FalconConfig(task=task) diff --git a/falcon_challenge/evaluator.py b/falcon_challenge/evaluator.py index 30ad5a0..fa5e8f4 100644 --- a/falcon_challenge/evaluator.py +++ b/falcon_challenge/evaluator.py @@ -277,9 +277,10 @@ def simple_collater(batch, task): class FalconEvaluator: - def __init__(self, eval_remote=False, split='h1'): + def __init__(self, eval_remote=False, split='h1', continual=False): self.eval_remote = eval_remote assert split in ['h1', 'h2', 'm1', 'm2'], "Split must be h1, h2, m1, or m2." + self.continual = continual self.dataset: FalconTask = getattr(FalconTask, split) self.cfg = FalconConfig(self.dataset) @@ -374,7 +375,8 @@ def predict_files(self, decoder: BCIDecoder, eval_files: List): if trial_delta_obs[0]: trial_preds.append(decoder.on_done(trial_delta_obs)) else: - decoder.on_done(trial_delta_obs) + if not self.continual: + decoder.on_done(trial_delta_obs) step_prediction = decoder.predict(neural_observations) assert step_prediction.shape[1] == self.cfg.out_dim, f"Prediction shape mismatch: {step_prediction.shape[1]} vs {self.cfg.out_dim}." trial_preds.append(step_prediction) diff --git a/setup.py b/setup.py index 2f82318..51d2b2a 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='falcon_challenge', - version='0.3.3', + version='0.3.4', url='https://github.com/snel-repo/stability-benchmark', author='Joel Ye',