Skip to content

Commit

Permalink
Fix get_model_signature
Browse files Browse the repository at this point in the history
The get_model_signature method needs the actual name of the model - which _resolve_model was not returning (only the dict). _resolve_model now returns the new model_name as well. This should have the added benefit of making the logging of the new name easier in the future.
  • Loading branch information
Rastislav Turanyi committed Jan 24, 2025
1 parent 00fa259 commit 400d6f8
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/resolution_functions/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def _get_model_data(self, model_name: Optional[str] = None, **kwargs) -> tuple[M
if model_name is None:
model_name = self.default_model

model = self._resolve_model(model_name)
model, model_name = self._resolve_model(model_name)

available_configurations = model['configurations']

Expand All @@ -448,7 +448,7 @@ def _get_model_data(self, model_name: Optional[str] = None, **kwargs) -> tuple[M
**ChainMap(*configurations, model['parameters']))
return model, model_name

def _resolve_model(self, model_name: str) -> dict:
def _resolve_model(self, model_name: str) -> tuple[dict, str]:
"""
Returns the data for the `model_name` model, taking into account aliases.
Expand All @@ -468,12 +468,13 @@ def _resolve_model(self, model_name: str) -> dict:
raise InvalidModelError(model_name, self)

if isinstance(model, str):
model_name = model
try:
model = self._models[model]
except KeyError:
raise InvalidModelError(model, self)

return model
return model, model_name

def get_resolution_function(self, model_name: Optional[str] = None, **kwargs) -> InstrumentModel:
"""
Expand Down Expand Up @@ -813,7 +814,7 @@ def possible_configurations_for_model(self, model_name: str) -> list[str]:
InvalidModelError
If the provided `model_name` is not supported for this version of this instrument.
"""
return list(self._resolve_model(model_name)['configurations'])
return list(self._resolve_model(model_name)[0]['configurations'])

def possible_options_for_model(self, model_name: str) -> dict[str, list[str]]:
"""
Expand All @@ -836,7 +837,7 @@ def possible_options_for_model(self, model_name: str) -> dict[str, list[str]]:
InvalidModelError
If the provided `model_name` is not supported for this version of this instrument.
"""
model = self._resolve_model(model_name)
model, _ = self._resolve_model(model_name)

return {config: self._get_options(value)
for config, value in model['configurations'].items()}
Expand Down Expand Up @@ -868,7 +869,7 @@ def possible_options_for_model_and_configuration(self,
If the provided `configuration` is not supported for the `model_name` model of this
instrument.
"""
configurations = self._resolve_model(model_name)['configurations']
configurations = self._resolve_model(model_name)[0]['configurations']

try:
configurations = configurations[configuration]
Expand Down Expand Up @@ -924,7 +925,7 @@ def default_option_for_configuration(self, model_name: str, configuration: str)
If the provided `configuration` is not supported for the `model_name` model of this
instrument.
"""
configurations = self._resolve_model(model_name)['configurations']
configurations = self._resolve_model(model_name)[0]['configurations']

try:
configurations = configurations[configuration]
Expand Down

0 comments on commit 400d6f8

Please sign in to comment.