Skip to content

Commit

Permalink
feat: Add Protocol for API objects
Browse files Browse the repository at this point in the history
Fixes some type errors that have been seen by users and in other
internal code. Also makes some methods public that were being used by
QuantinuumBackend and fixes some mismatched types.
  • Loading branch information
johnchildren committed Apr 26, 2024
1 parent 057610f commit 3f13d43
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 20 deletions.
80 changes: 71 additions & 9 deletions pytket/extensions/quantinuum/backends/api_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import time
from http import HTTPStatus
from typing import Optional, Dict, Tuple, List
from typing import Optional, Dict, Protocol, Tuple, List
import asyncio
import json
import getpass
Expand Down Expand Up @@ -46,7 +46,7 @@ class QuantinuumAPIError(Exception):
class _OverrideManager:
def __init__(
self,
api_handler: "QuantinuumAPI",
api_handler: "QuantinuumAPIProtocol",
timeout: Optional[int] = None,
retry_timeout: Optional[int] = None,
):
Expand All @@ -67,7 +67,54 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: # type: ignore
self.api_handler.retry_timeout = self._orig_retry


class QuantinuumAPI:
class QuantinuumAPIProtocol(Protocol):
"""
Protocol for classes that can interact with the Quantinuum API.
"""

online: bool
provider: str | None
timeout: Optional[int]
retry_timeout: Optional[int]

def override_timeouts(
self, timeout: Optional[int] = None, retry_timeout: Optional[int] = None
) -> _OverrideManager:
...

def full_login(self) -> None:
...

def login(self) -> str:
...

def delete_authentication(self) -> None:
...

def submit_job(self, body: Dict) -> Response:
...

def retrieve_job_status(
self, job_id: str, use_websocket: Optional[bool] = None
) -> Optional[Dict]:
...

def retrieve_job(
self, job_id: str, use_websocket: Optional[bool] = None
) -> Optional[Dict]:
...

def status(self, machine: str) -> str:
...

def cancel(self, job_id: str) -> dict:
...

def get_calendar(self, start_date: str, end_date: str) -> List[Dict[str, str]]:
...


class QuantinuumAPI(QuantinuumAPIProtocol):
"""
Interface to the Quantinuum online remote API.
"""
Expand Down Expand Up @@ -302,7 +349,7 @@ def delete_authentication(self) -> None:
"""Remove stored credentials and tokens"""
self._cred_store.delete_credential()

def _submit_job(self, body: Dict) -> Response:
def submit_job(self, body: Dict) -> Response:
id_token = self.login()
# send job request
return self.session.post(
Expand Down Expand Up @@ -545,7 +592,7 @@ def get_calendar(self, start_date: str, end_date: str) -> List[Dict[str, str]]:
]


class QuantinuumAPIOffline:
class QuantinuumAPIOffline(QuantinuumAPIProtocol):
"""
Offline copy of the interface to the Quantinuum remote API.
"""
Expand Down Expand Up @@ -578,6 +625,8 @@ def __init__(self, machine_list: Optional[list] = None):
self.machine_list = machine_list
self._cred_store = None
self.submitted: list = []
self.timeout = None
self.retry_timeout = None

def _get_machine_list(self) -> Optional[list]:
"""returns the given list of the avilable machines
Expand All @@ -586,6 +635,11 @@ def _get_machine_list(self) -> Optional[list]:

return self.machine_list

def override_timeouts(
self, timeout: Optional[int] = None, retry_timeout: Optional[int] = None
) -> _OverrideManager:
return _OverrideManager(self, timeout=timeout, retry_timeout=retry_timeout)

def full_login(self) -> None:
"""No login offline with the offline API"""

Expand All @@ -596,15 +650,18 @@ def login(self) -> str:
return an empty api token"""
return ""

def _submit_job(self, body: Dict) -> None:
def delete_authentication(self) -> None:
return None

def submit_job(self, body: Dict) -> Response:
"""The function will take the submitted job and store it for later
:param body: submitted job
:return: None
"""
self.submitted.append(body)
return None
return Response()

def get_jobs(self) -> Optional[list]:
"""The function will return all the jobs that have been submitted
Expand All @@ -626,7 +683,7 @@ def _response_check(self, res: Response, description: str) -> None:

def retrieve_job_status(
self, job_id: str, use_websocket: Optional[bool] = None
) -> None:
) -> Optional[Dict]:
"""No retrieve_job_status offline"""
raise QuantinuumAPIError(
(
Expand All @@ -635,7 +692,9 @@ def retrieve_job_status(
)
)

def retrieve_job(self, job_id: str, use_websocket: Optional[bool] = None) -> None:
def retrieve_job(
self, job_id: str, use_websocket: Optional[bool] = None
) -> Optional[Dict]:
"""No retrieve_job_status offline"""
raise QuantinuumAPIError(
(
Expand All @@ -652,3 +711,6 @@ def status(self, machine: str) -> str:
def cancel(self, job_id: str) -> dict:
"""No cancel offline"""
raise QuantinuumAPIError((f"Can't cancel job offline: job_id {job_id}."))

def get_calendar(self, start_date: str, end_date: str) -> List[Dict[str, str]]:
return []
3 changes: 0 additions & 3 deletions pytket/extensions/quantinuum/backends/credential_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


class CredentialStorage(ABC):

"""Storage for Quantinuum username and tokens"""

def __init__(
Expand Down Expand Up @@ -65,7 +64,6 @@ def user_name(self) -> Optional[str]:


class MemoryCredentialStorage(CredentialStorage):

"""In memory credential storage. Intended use is only to store id tokens,
refresh tokens and user_name. Password storage is only included for debug
purposes."""
Expand Down Expand Up @@ -138,7 +136,6 @@ def delete_credential(self) -> None:


class QuantinuumConfigCredentialStorage(CredentialStorage):

"""Store tokens in the default pytket configuration file."""

def __init__(
Expand Down
14 changes: 8 additions & 6 deletions pytket/extensions/quantinuum/backends/quantinuum.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
)

from pytket.extensions.quantinuum.backends.leakage_gadget import get_detection_circuit
from .api_wrappers import QuantinuumAPIError, QuantinuumAPI
from .api_wrappers import QuantinuumAPIError, QuantinuumAPI, QuantinuumAPIProtocol


try:
Expand Down Expand Up @@ -269,7 +269,7 @@ def __init__(
group: Optional[str] = None,
provider: Optional[str] = None,
machine_debug: bool = False,
api_handler: QuantinuumAPI = DEFAULT_API_HANDLER,
api_handler: QuantinuumAPIProtocol = DEFAULT_API_HANDLER,
compilation_config: Optional[QuantinuumBackendCompilationConfig] = None,
**kwargs: QuumKwargTypes,
):
Expand Down Expand Up @@ -867,7 +867,7 @@ def submit_program(
body.update(request_options or {})

try:
res = self.api_handler._submit_job(body)
res = self.api_handler.submit_job(body)
if self.api_handler.online:
jobdict = res.json()
if res.status_code != HTTPStatus.OK:
Expand Down Expand Up @@ -1014,9 +1014,11 @@ def process_circuits(
quantinuum_circ = circuit_to_qasm_str(
c0,
header="hqslib1",
maxwidth=self.backend_info.misc["cl_reg_width"]
if self.backend_info
else 32,
maxwidth=(
self.backend_info.misc["cl_reg_width"]
if self.backend_info
else 32
),
)
used_scratch_regs = _used_scratch_registers(quantinuum_circ)
for name, count in Counter(
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/api1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def match_normal_request(request: requests.PreparedRequest) -> bool:
assert normal_login_route.called_once
# Check that the mfa login has been invoked
assert mfa_login_route.called_once
assert isinstance(backend.api_handler, QuantinuumAPI)
assert backend.api_handler._cred_store.id_token is not None
assert backend.api_handler._cred_store.refresh_token is not None

Expand Down Expand Up @@ -331,6 +332,7 @@ def test_federated_login(

mock_microsoft_login.assert_called_once()
assert login_route.called_once
assert isinstance(backend.api_handler, QuantinuumAPI)
assert backend.api_handler._cred_store.id_token is not None
assert backend.api_handler._cred_store.refresh_token is not None

Expand Down
6 changes: 4 additions & 2 deletions tests/unit/offline_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
def test_quantinuum_offline(language: Language) -> None:
qapioffline = QuantinuumAPIOffline()
backend = QuantinuumBackend(
device_name="H1-1", machine_debug=False, api_handler=qapioffline # type: ignore
device_name="H1-1", machine_debug=False, api_handler=qapioffline
)
c = Circuit(4, 4, "test 1")
c.H(0)
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_quantinuum_offline(language: Language) -> None:
def test_max_classical_register_ii() -> None:
qapioffline = QuantinuumAPIOffline()
backend = QuantinuumBackend(
device_name="H1-1", machine_debug=False, api_handler=qapioffline # type: ignore
device_name="H1-1", machine_debug=False, api_handler=qapioffline
)

c = Circuit(4, 4, "test 1")
Expand Down Expand Up @@ -167,4 +167,6 @@ def test_custom_api_handler(device_name: str) -> None:
backend_2 = QuantinuumBackend(device_name, api_handler=handler_2)

assert backend_1.api_handler is not backend_2.api_handler
assert isinstance(backend_1.api_handler, QuantinuumAPI)
assert isinstance(backend_2.api_handler, QuantinuumAPI)
assert backend_1.api_handler._cred_store is not backend_2.api_handler._cred_store

0 comments on commit 3f13d43

Please sign in to comment.