Skip to content

Commit

Permalink
continual without enforcement
Browse files Browse the repository at this point in the history
  • Loading branch information
joel99 committed May 7, 2024
1 parent 585152d commit 928946b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
8 changes: 7 additions & 1 deletion decoder_demos/ndt2_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,17 @@ def main():
parser.add_argument(
'--batch-size', type=int, default=1
)
parser.add_argument(
'--continual', action='store_true', default=False
)

args = parser.parse_args()

evaluator = FalconEvaluator(
eval_remote=args.evaluation == "remote",
split=args.split)
split=args.split,
continual=args.continual
)

task = getattr(FalconTask, args.split)
config = FalconConfig(task=task)
Expand Down
6 changes: 4 additions & 2 deletions falcon_challenge/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,10 @@ def simple_collater(batch, task):

class FalconEvaluator:

def __init__(self, eval_remote=False, split='h1'):
def __init__(self, eval_remote=False, split='h1', continual=False):
self.eval_remote = eval_remote
assert split in ['h1', 'h2', 'm1', 'm2'], "Split must be h1, h2, m1, or m2."
self.continual = continual
self.dataset: FalconTask = getattr(FalconTask, split)
self.cfg = FalconConfig(self.dataset)

Expand Down Expand Up @@ -374,7 +375,8 @@ def predict_files(self, decoder: BCIDecoder, eval_files: List):
if trial_delta_obs[0]:
trial_preds.append(decoder.on_done(trial_delta_obs))
else:
decoder.on_done(trial_delta_obs)
if not self.continual:
decoder.on_done(trial_delta_obs)
step_prediction = decoder.predict(neural_observations)
assert step_prediction.shape[1] == self.cfg.out_dim, f"Prediction shape mismatch: {step_prediction.shape[1]} vs {self.cfg.out_dim}."
trial_preds.append(step_prediction)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='falcon_challenge',
version='0.3.3',
version='0.3.4',

url='https://github.com/snel-repo/stability-benchmark',
author='Joel Ye',
Expand Down

0 comments on commit 928946b

Please sign in to comment.