From cc117bcc72251c93838d3c8b3b5c0770d4fb12e1 Mon Sep 17 00:00:00 2001 From: John Children Date: Fri, 26 Apr 2024 10:55:11 +0100 Subject: [PATCH] feat: Add Protocol for API objects Fixes some type errors that have been seen by users and in other internal code. Also makes some methods public that were being used by QuantinuumBackend and fixes some mismatched types. --- .../quantinuum/backends/api_wrappers.py | 80 ++++++++++++++++--- .../quantinuum/backends/credential_storage.py | 3 - .../quantinuum/backends/quantinuum.py | 14 ++-- tests/unit/api1_test.py | 2 + tests/unit/offline_backend_test.py | 6 +- 5 files changed, 85 insertions(+), 20 deletions(-) diff --git a/pytket/extensions/quantinuum/backends/api_wrappers.py b/pytket/extensions/quantinuum/backends/api_wrappers.py index b2c00f5e..fa41700a 100644 --- a/pytket/extensions/quantinuum/backends/api_wrappers.py +++ b/pytket/extensions/quantinuum/backends/api_wrappers.py @@ -18,7 +18,7 @@ import time from http import HTTPStatus -from typing import Optional, Dict, Tuple, List +from typing import Optional, Dict, Protocol, Tuple, List import asyncio import json import getpass @@ -46,7 +46,7 @@ class QuantinuumAPIError(Exception): class _OverrideManager: def __init__( self, - api_handler: "QuantinuumAPI", + api_handler: "QuantinuumAPIProtocol", timeout: Optional[int] = None, retry_timeout: Optional[int] = None, ): @@ -67,7 +67,54 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: # type: ignore self.api_handler.retry_timeout = self._orig_retry -class QuantinuumAPI: +class QuantinuumAPIProtocol(Protocol): + """ + Protocol for classes that can interact with the Quantinuum API. + """ + + online: bool + provider: str | None + timeout: Optional[int] + retry_timeout: Optional[int] + + def override_timeouts( + self, timeout: Optional[int] = None, retry_timeout: Optional[int] = None + ) -> _OverrideManager: + ... + + def full_login(self) -> None: + ... + + def login(self) -> str: + ... + + def delete_authentication(self) -> None: + ... + + def submit_job(self, body: Dict) -> Response: + ... + + def retrieve_job_status( + self, job_id: str, use_websocket: Optional[bool] = None + ) -> Optional[Dict]: + ... + + def retrieve_job( + self, job_id: str, use_websocket: Optional[bool] = None + ) -> Optional[Dict]: + ... + + def status(self, machine: str) -> str: + ... + + def cancel(self, job_id: str) -> dict: + ... + + def get_calendar(self, start_date: str, end_date: str) -> List[Dict[str, str]]: + ... + + +class QuantinuumAPI(QuantinuumAPIProtocol): """ Interface to the Quantinuum online remote API. """ @@ -302,7 +349,7 @@ def delete_authentication(self) -> None: """Remove stored credentials and tokens""" self._cred_store.delete_credential() - def _submit_job(self, body: Dict) -> Response: + def submit_job(self, body: Dict) -> Response: id_token = self.login() # send job request return self.session.post( @@ -585,7 +632,7 @@ def get_calendar(self, start_date: str, end_date: str) -> List[Dict[str, str]]: ] -class QuantinuumAPIOffline: +class QuantinuumAPIOffline(QuantinuumAPIProtocol): """ Offline copy of the interface to the Quantinuum remote API. """ @@ -618,6 +665,8 @@ def __init__(self, machine_list: Optional[list] = None): self.machine_list = machine_list self._cred_store = None self.submitted: list = [] + self.timeout = None + self.retry_timeout = None def _get_machine_list(self) -> Optional[list]: """returns the given list of the avilable machines @@ -626,6 +675,11 @@ def _get_machine_list(self) -> Optional[list]: return self.machine_list + def override_timeouts( + self, timeout: Optional[int] = None, retry_timeout: Optional[int] = None + ) -> _OverrideManager: + return _OverrideManager(self, timeout=timeout, retry_timeout=retry_timeout) + def full_login(self) -> None: """No login offline with the offline API""" @@ -636,7 +690,10 @@ def login(self) -> str: return an empty api token""" return "" - def _submit_job(self, body: Dict) -> None: + def delete_authentication(self) -> None: + return None + + def submit_job(self, body: Dict) -> Response: """The function will take the submitted job and store it for later :param body: submitted job @@ -644,7 +701,7 @@ def _submit_job(self, body: Dict) -> None: :return: None """ self.submitted.append(body) - return None + return Response() def get_jobs(self) -> Optional[list]: """The function will return all the jobs that have been submitted @@ -666,7 +723,7 @@ def _response_check(self, res: Response, description: str) -> None: def retrieve_job_status( self, job_id: str, use_websocket: Optional[bool] = None - ) -> None: + ) -> Optional[Dict]: """No retrieve_job_status offline""" raise QuantinuumAPIError( ( @@ -675,7 +732,9 @@ def retrieve_job_status( ) ) - def retrieve_job(self, job_id: str, use_websocket: Optional[bool] = None) -> None: + def retrieve_job( + self, job_id: str, use_websocket: Optional[bool] = None + ) -> Optional[Dict]: """No retrieve_job_status offline""" raise QuantinuumAPIError( ( @@ -692,3 +751,6 @@ def status(self, machine: str) -> str: def cancel(self, job_id: str) -> dict: """No cancel offline""" raise QuantinuumAPIError((f"Can't cancel job offline: job_id {job_id}.")) + + def get_calendar(self, start_date: str, end_date: str) -> List[Dict[str, str]]: + return [] diff --git a/pytket/extensions/quantinuum/backends/credential_storage.py b/pytket/extensions/quantinuum/backends/credential_storage.py index 93064da1..5cdb292f 100644 --- a/pytket/extensions/quantinuum/backends/credential_storage.py +++ b/pytket/extensions/quantinuum/backends/credential_storage.py @@ -20,7 +20,6 @@ class CredentialStorage(ABC): - """Storage for Quantinuum username and tokens""" def __init__( @@ -65,7 +64,6 @@ def user_name(self) -> Optional[str]: class MemoryCredentialStorage(CredentialStorage): - """In memory credential storage. Intended use is only to store id tokens, refresh tokens and user_name. Password storage is only included for debug purposes.""" @@ -138,7 +136,6 @@ def delete_credential(self) -> None: class QuantinuumConfigCredentialStorage(CredentialStorage): - """Store tokens in the default pytket configuration file.""" def __init__( diff --git a/pytket/extensions/quantinuum/backends/quantinuum.py b/pytket/extensions/quantinuum/backends/quantinuum.py index f9c5fe17..650c090a 100644 --- a/pytket/extensions/quantinuum/backends/quantinuum.py +++ b/pytket/extensions/quantinuum/backends/quantinuum.py @@ -75,7 +75,7 @@ ) from pytket.extensions.quantinuum.backends.leakage_gadget import get_detection_circuit -from .api_wrappers import QuantinuumAPIError, QuantinuumAPI +from .api_wrappers import QuantinuumAPIError, QuantinuumAPI, QuantinuumAPIProtocol try: @@ -281,7 +281,7 @@ def __init__( group: Optional[str] = None, provider: Optional[str] = None, machine_debug: bool = False, - api_handler: QuantinuumAPI = DEFAULT_API_HANDLER, + api_handler: QuantinuumAPIProtocol = DEFAULT_API_HANDLER, compilation_config: Optional[QuantinuumBackendCompilationConfig] = None, **kwargs: QuumKwargTypes, ): @@ -878,7 +878,7 @@ def submit_program( body.update(request_options or {}) try: - res = self.api_handler._submit_job(body) + res = self.api_handler.submit_job(body) if self.api_handler.online: jobdict = res.json() if res.status_code != HTTPStatus.OK: @@ -1026,9 +1026,11 @@ def process_circuits( quantinuum_circ = circuit_to_qasm_str( c0, header="hqslib1", - maxwidth=self.backend_info.misc["cl_reg_width"] - if self.backend_info - else 32, + maxwidth=( + self.backend_info.misc["cl_reg_width"] + if self.backend_info + else 32 + ), ) used_scratch_regs = _used_scratch_registers(quantinuum_circ) for name, count in Counter( diff --git a/tests/unit/api1_test.py b/tests/unit/api1_test.py index 1a09fa50..1aae5ae0 100644 --- a/tests/unit/api1_test.py +++ b/tests/unit/api1_test.py @@ -295,6 +295,7 @@ def match_normal_request(request: requests.PreparedRequest) -> bool: assert normal_login_route.called_once # Check that the mfa login has been invoked assert mfa_login_route.called_once + assert isinstance(backend.api_handler, QuantinuumAPI) assert backend.api_handler._cred_store.id_token is not None assert backend.api_handler._cred_store.refresh_token is not None @@ -331,6 +332,7 @@ def test_federated_login( mock_microsoft_login.assert_called_once() assert login_route.called_once + assert isinstance(backend.api_handler, QuantinuumAPI) assert backend.api_handler._cred_store.id_token is not None assert backend.api_handler._cred_store.refresh_token is not None diff --git a/tests/unit/offline_backend_test.py b/tests/unit/offline_backend_test.py index 7811616a..865b310c 100644 --- a/tests/unit/offline_backend_test.py +++ b/tests/unit/offline_backend_test.py @@ -35,7 +35,7 @@ def test_quantinuum_offline(language: Language) -> None: qapioffline = QuantinuumAPIOffline() backend = QuantinuumBackend( - device_name="H1-1", machine_debug=False, api_handler=qapioffline # type: ignore + device_name="H1-1", machine_debug=False, api_handler=qapioffline ) c = Circuit(4, 4, "test 1") c.H(0) @@ -72,7 +72,7 @@ def test_quantinuum_offline(language: Language) -> None: def test_max_classical_register_ii() -> None: qapioffline = QuantinuumAPIOffline() backend = QuantinuumBackend( - device_name="H1-1", machine_debug=False, api_handler=qapioffline # type: ignore + device_name="H1-1", machine_debug=False, api_handler=qapioffline ) c = Circuit(4, 4, "test 1") @@ -167,4 +167,6 @@ def test_custom_api_handler(device_name: str) -> None: backend_2 = QuantinuumBackend(device_name, api_handler=handler_2) assert backend_1.api_handler is not backend_2.api_handler + assert isinstance(backend_1.api_handler, QuantinuumAPI) + assert isinstance(backend_2.api_handler, QuantinuumAPI) assert backend_1.api_handler._cred_store is not backend_2.api_handler._cred_store