-
Notifications
You must be signed in to change notification settings - Fork 265
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove AdapterSpec from metrics #2244
base: main
Are you sure you want to change the base?
Conversation
ba53f57
to
e23220b
Compare
Converting to draft because this requires some manual testing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR is also removing ScenarioState
from metrics. Consider updating the name/description to mention that.
reference_stats: Dict[ReferenceKey, ReferenceStat] = {} | ||
for request_state in reference_request_states: | ||
assert request_state.reference_index is not None and request_state.request_mode is not None | ||
reference_key = ReferenceKey(request_state.reference_index, request_state.request_mode) | ||
reference_stats[reference_key] = compute_logprob_and_length(request_state, window_service) | ||
|
||
if adapter_spec.method in [ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL, ADAPT_RANKING_BINARY]: | ||
is_calibrated = any([request_state.request_mode == "calibration" for request_state in reference_request_states]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why "any" here but using reference_request_states[0] to decide model_deployment?
If we are asserting in both cases that they are universal values, maybe we should write a helper to do that assertion?
@@ -294,20 +280,14 @@ def compute_request_state_metrics( | |||
stats: List[Stat] = [] | |||
|
|||
stats.append(Stat(MetricName("num_references")).add(len(request_state.instance.references))) | |||
|
|||
# Copy from adapter spec | |||
stats.append(Stat(MetricName("num_train_trials")).add(adapter_spec.num_train_trials)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this Stat not needed?
for context, request_states in grouped_request_states.items(): | ||
for stat in self.evaluate_instances(request_states): | ||
for request_state in trial_request_states: | ||
grouped_request_states[MetricContext.from_instance(request_state.instance)].append(request_state) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This has potential behavior change since it can include request_states that have non-None reference_index.
if request_state.reference_index is None: | ||
instance_to_request_state_set[instance].generation_states.append(request_state) | ||
else: | ||
instance_to_request_state_set[instance].references_states.append(request_state) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previously the reference_states were ordered by reference_index. Is that still guaranteed? Does it matter if the order changes?
@@ -166,7 +149,7 @@ def evaluate( | |||
|
|||
# Compute per-instance stats | |||
per_instance_stats: List[PerInstanceStats] = [] | |||
for instance, stats in zip(scenario_state.instances, results): | |||
for instance, stats in zip(instances, results): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think switching this to zip(request_state_sets, results)
would make it less fragile and more clear that we are putting the input and output of the parallel map back together.
@@ -352,3 +333,19 @@ def add_context(stat: Stat, context: MetricContext) -> Stat: | |||
return Stat( | |||
replace(stat.name, split=context.split, sub_split=context.sub_split, perturbation=context.perturbation) | |||
).merge(stat) | |||
|
|||
|
|||
def get_num_train_trials(request_states: List[RequestState]) -> int: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There appears to be no method calling this? Is it left over from a previous iteration?
instance_to_request_state_set[instance].generation_states.append(request_state) | ||
else: | ||
instance_to_request_state_set[instance].references_states.append(request_state) | ||
request_state_sets: List[RequestStateSet] = list(instance_to_request_state_set.values()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Order here can also change. Maybe we want an OrderedDict?
This removes the coupling between the adapter and the metrics, allowing the metrics to be computed only using the requests and results from the model clients.