Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Protocol for API objects #426

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 71 additions & 9 deletions pytket/extensions/quantinuum/backends/api_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import time
from http import HTTPStatus
from typing import Optional, Dict, Tuple, List
from typing import Optional, Dict, Protocol, Tuple, List
import asyncio
import json
import getpass
Expand Down Expand Up @@ -46,7 +46,7 @@ class QuantinuumAPIError(Exception):
class _OverrideManager:
def __init__(
self,
api_handler: "QuantinuumAPI",
api_handler: "QuantinuumAPIProtocol",
timeout: Optional[int] = None,
retry_timeout: Optional[int] = None,
):
Expand All @@ -67,7 +67,54 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: # type: ignore
self.api_handler.retry_timeout = self._orig_retry


class QuantinuumAPI:
class QuantinuumAPIProtocol(Protocol):
"""
Protocol for classes that can interact with the Quantinuum API.
"""

online: bool
provider: str | None
timeout: Optional[int]
retry_timeout: Optional[int]

def override_timeouts(
self, timeout: Optional[int] = None, retry_timeout: Optional[int] = None
) -> _OverrideManager:
...

def full_login(self) -> None:
...

def login(self) -> str:
...

def delete_authentication(self) -> None:
...

def submit_job(self, body: Dict) -> Response:
...

def retrieve_job_status(
self, job_id: str, use_websocket: Optional[bool] = None
) -> Optional[Dict]:
...

def retrieve_job(
self, job_id: str, use_websocket: Optional[bool] = None
) -> Optional[Dict]:
...

def status(self, machine: str) -> str:
...

def cancel(self, job_id: str) -> dict:
...

def get_calendar(self, start_date: str, end_date: str) -> List[Dict[str, str]]:
...


class QuantinuumAPI(QuantinuumAPIProtocol):
"""
Interface to the Quantinuum online remote API.
"""
Expand Down Expand Up @@ -302,7 +349,7 @@ def delete_authentication(self) -> None:
"""Remove stored credentials and tokens"""
self._cred_store.delete_credential()

def _submit_job(self, body: Dict) -> Response:
def submit_job(self, body: Dict) -> Response:
id_token = self.login()
# send job request
return self.session.post(
Expand Down Expand Up @@ -585,7 +632,7 @@ def get_calendar(self, start_date: str, end_date: str) -> List[Dict[str, str]]:
]


class QuantinuumAPIOffline:
class QuantinuumAPIOffline(QuantinuumAPIProtocol):
"""
Offline copy of the interface to the Quantinuum remote API.
"""
Expand Down Expand Up @@ -618,6 +665,8 @@ def __init__(self, machine_list: Optional[list] = None):
self.machine_list = machine_list
self._cred_store = None
self.submitted: list = []
self.timeout = None
self.retry_timeout = None

def _get_machine_list(self) -> Optional[list]:
"""returns the given list of the avilable machines
Expand All @@ -626,6 +675,11 @@ def _get_machine_list(self) -> Optional[list]:

return self.machine_list

def override_timeouts(
self, timeout: Optional[int] = None, retry_timeout: Optional[int] = None
) -> _OverrideManager:
return _OverrideManager(self, timeout=timeout, retry_timeout=retry_timeout)

def full_login(self) -> None:
"""No login offline with the offline API"""

Expand All @@ -636,15 +690,18 @@ def login(self) -> str:
return an empty api token"""
return ""

def _submit_job(self, body: Dict) -> None:
def delete_authentication(self) -> None:
return None

def submit_job(self, body: Dict) -> Response:
"""The function will take the submitted job and store it for later

:param body: submitted job

:return: None
"""
self.submitted.append(body)
return None
return Response()

def get_jobs(self) -> Optional[list]:
"""The function will return all the jobs that have been submitted
Expand All @@ -666,7 +723,7 @@ def _response_check(self, res: Response, description: str) -> None:

def retrieve_job_status(
self, job_id: str, use_websocket: Optional[bool] = None
) -> None:
) -> Optional[Dict]:
"""No retrieve_job_status offline"""
raise QuantinuumAPIError(
(
Expand All @@ -675,7 +732,9 @@ def retrieve_job_status(
)
)

def retrieve_job(self, job_id: str, use_websocket: Optional[bool] = None) -> None:
def retrieve_job(
self, job_id: str, use_websocket: Optional[bool] = None
) -> Optional[Dict]:
"""No retrieve_job_status offline"""
raise QuantinuumAPIError(
(
Expand All @@ -692,3 +751,6 @@ def status(self, machine: str) -> str:
def cancel(self, job_id: str) -> dict:
"""No cancel offline"""
raise QuantinuumAPIError((f"Can't cancel job offline: job_id {job_id}."))

def get_calendar(self, start_date: str, end_date: str) -> List[Dict[str, str]]:
return []
3 changes: 0 additions & 3 deletions pytket/extensions/quantinuum/backends/credential_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


class CredentialStorage(ABC):

"""Storage for Quantinuum username and tokens"""

def __init__(
Expand Down Expand Up @@ -65,7 +64,6 @@ def user_name(self) -> Optional[str]:


class MemoryCredentialStorage(CredentialStorage):

"""In memory credential storage. Intended use is only to store id tokens,
refresh tokens and user_name. Password storage is only included for debug
purposes."""
Expand Down Expand Up @@ -138,7 +136,6 @@ def delete_credential(self) -> None:


class QuantinuumConfigCredentialStorage(CredentialStorage):

"""Store tokens in the default pytket configuration file."""

def __init__(
Expand Down
14 changes: 8 additions & 6 deletions pytket/extensions/quantinuum/backends/quantinuum.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
)

from pytket.extensions.quantinuum.backends.leakage_gadget import get_detection_circuit
from .api_wrappers import QuantinuumAPIError, QuantinuumAPI
from .api_wrappers import QuantinuumAPIError, QuantinuumAPI, QuantinuumAPIProtocol


try:
Expand Down Expand Up @@ -281,7 +281,7 @@ def __init__(
group: Optional[str] = None,
provider: Optional[str] = None,
machine_debug: bool = False,
api_handler: QuantinuumAPI = DEFAULT_API_HANDLER,
api_handler: QuantinuumAPIProtocol = DEFAULT_API_HANDLER,
compilation_config: Optional[QuantinuumBackendCompilationConfig] = None,
**kwargs: QuumKwargTypes,
):
Expand Down Expand Up @@ -878,7 +878,7 @@ def submit_program(
body.update(request_options or {})

try:
res = self.api_handler._submit_job(body)
res = self.api_handler.submit_job(body)
if self.api_handler.online:
jobdict = res.json()
if res.status_code != HTTPStatus.OK:
Expand Down Expand Up @@ -1026,9 +1026,11 @@ def process_circuits(
quantinuum_circ = circuit_to_qasm_str(
c0,
header="hqslib1",
maxwidth=self.backend_info.misc["cl_reg_width"]
if self.backend_info
else 32,
maxwidth=(
self.backend_info.misc["cl_reg_width"]
if self.backend_info
else 32
),
)
used_scratch_regs = _used_scratch_registers(quantinuum_circ)
for name, count in Counter(
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/api1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def match_normal_request(request: requests.PreparedRequest) -> bool:
assert normal_login_route.called_once
# Check that the mfa login has been invoked
assert mfa_login_route.called_once
assert isinstance(backend.api_handler, QuantinuumAPI)
assert backend.api_handler._cred_store.id_token is not None
assert backend.api_handler._cred_store.refresh_token is not None

Expand Down Expand Up @@ -331,6 +332,7 @@ def test_federated_login(

mock_microsoft_login.assert_called_once()
assert login_route.called_once
assert isinstance(backend.api_handler, QuantinuumAPI)
assert backend.api_handler._cred_store.id_token is not None
assert backend.api_handler._cred_store.refresh_token is not None

Expand Down
6 changes: 4 additions & 2 deletions tests/unit/offline_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
def test_quantinuum_offline(language: Language) -> None:
qapioffline = QuantinuumAPIOffline()
backend = QuantinuumBackend(
device_name="H1-1", machine_debug=False, api_handler=qapioffline # type: ignore
device_name="H1-1", machine_debug=False, api_handler=qapioffline
)
c = Circuit(4, 4, "test 1")
c.H(0)
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_quantinuum_offline(language: Language) -> None:
def test_max_classical_register_ii() -> None:
qapioffline = QuantinuumAPIOffline()
backend = QuantinuumBackend(
device_name="H1-1", machine_debug=False, api_handler=qapioffline # type: ignore
device_name="H1-1", machine_debug=False, api_handler=qapioffline
)

c = Circuit(4, 4, "test 1")
Expand Down Expand Up @@ -167,4 +167,6 @@ def test_custom_api_handler(device_name: str) -> None:
backend_2 = QuantinuumBackend(device_name, api_handler=handler_2)

assert backend_1.api_handler is not backend_2.api_handler
assert isinstance(backend_1.api_handler, QuantinuumAPI)
assert isinstance(backend_2.api_handler, QuantinuumAPI)
assert backend_1.api_handler._cred_store is not backend_2.api_handler._cred_store