Skip to content

Commit

Permalink
proper type annotation (#409)
Browse files Browse the repository at this point in the history
* remove cast if not neccesary

* remove _from_dict because it is not neccesary, using pydantic methods instead

* Q_co

* extra ignore

* version

* using cast

* version

* version

* coderabbit comments
  • Loading branch information
felipao-mx authored Jan 10, 2025
1 parent b602497 commit b308816
Show file tree
Hide file tree
Showing 31 changed files with 134 additions and 154 deletions.
4 changes: 1 addition & 3 deletions cuenca/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@
'get_balance',
]

from typing import cast

from . import http
from .resources import (
Account,
Expand Down Expand Up @@ -96,5 +94,5 @@


def get_balance(session: http.Session = session) -> int:
balance_entry = cast('BalanceEntry', BalanceEntry.first(session=session))
balance_entry = BalanceEntry.first(session=session)
return balance_entry.rolling_balance if balance_entry else 0
9 changes: 4 additions & 5 deletions cuenca/resources/api_keys.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime as dt
from typing import ClassVar, Optional, cast
from typing import ClassVar, Optional

from cuenca_validations.types import ApiKeyQuery, ApiKeyUpdateRequest

Expand Down Expand Up @@ -36,7 +36,7 @@ def active(self) -> bool:

@classmethod
def create(cls, *, session: Session = global_session) -> 'ApiKey':
return cast('ApiKey', cls._create(session=session))
return cls._create(session=session)

@classmethod
def deactivate(
Expand All @@ -55,7 +55,7 @@ def deactivate(
"""
url = cls._resource + f'/{api_key_id}'
resp = session.delete(url, dict(minutes=minutes))
return cast('ApiKey', cls._from_dict(resp))
return cls(**resp)

@classmethod
def update(
Expand All @@ -74,5 +74,4 @@ def update(
req = ApiKeyUpdateRequest(
metadata=metadata, user_id=user_id, platform_id=platform_id
)
resp = cls._update(api_key_id, **req.dict(), session=session)
return cast('ApiKey', resp)
return cls._update(api_key_id, **req.dict(), session=session)
4 changes: 2 additions & 2 deletions cuenca/resources/arpc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime as dt
from typing import ClassVar, Optional, cast
from typing import ClassVar, Optional

from cuenca_validations.types.requests import ARPCRequest

Expand Down Expand Up @@ -52,4 +52,4 @@ def create(
unique_number=unique_number,
track_data_method=track_data_method,
)
return cast('Arpc', cls._create(session=session, **req.dict()))
return cls._create(session=session, **req.dict())
117 changes: 66 additions & 51 deletions cuenca/resources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime as dt
import json
from io import BytesIO
from typing import ClassVar, Dict, Generator, Optional, Union
from typing import Any, ClassVar, Generator, Optional, Type, TypeVar, cast
from urllib.parse import urlencode

from cuenca_validations.types import (
Expand All @@ -12,34 +12,21 @@
TransactionQuery,
TransactionStatus,
)
from pydantic import BaseModel
from pydantic import BaseModel, Extra

from ..exc import MultipleResultsFound, NoResultFound
from ..http import Session, session as global_session

R_co = TypeVar('R_co', bound='Resource', covariant=True)


class Resource(BaseModel):
_resource: ClassVar[str]

id: str

@classmethod
def _from_dict(cls, obj_dict: Dict[str, Union[str, int]]) -> 'Resource':
cls._filter_excess_fields(obj_dict)
return cls(**obj_dict)

@classmethod
def _filter_excess_fields(cls, obj_dict):
"""
dataclasses don't allow __init__ to be called with excess fields. This
method allows the API to add fields in the response body without
breaking the client
"""
excess = set(obj_dict.keys()) - set(
cls.schema().get("properties").keys()
)
for f in excess:
del obj_dict[f]
class Config:
extra = Extra.ignore

def to_dict(self):
return SantizedDict(self.dict())
Expand All @@ -48,22 +35,30 @@ def to_dict(self):
class Retrievable(Resource):
@classmethod
def retrieve(
cls, id: str, *, session: Session = global_session
) -> Resource:
cls: Type[R_co],
id: str,
*,
session: Session = global_session,
) -> R_co:
resp = session.get(f'/{cls._resource}/{id}')
return cls._from_dict(resp)
return cls(**resp)

def refresh(self, *, session: Session = global_session):
def refresh(self, *, session: Session = global_session) -> None:
new = self.retrieve(self.id, session=session)
for attr, value in new.__dict__.items():
setattr(self, attr, value)


class Creatable(Resource):
@classmethod
def _create(cls, *, session: Session = global_session, **data) -> Resource:
def _create(
cls: Type[R_co],
*,
session: Session = global_session,
**data: Any,
) -> R_co:
resp = session.post(cls._resource, data)
return cls._from_dict(resp)
return cls(**resp)


class Updateable(Resource):
Expand All @@ -72,31 +67,39 @@ class Updateable(Resource):

@classmethod
def _update(
cls, id: str, *, session: Session = global_session, **data
) -> Resource:
cls: Type[R_co],
id: str,
*,
session: Session = global_session,
**data: Any,
) -> R_co:
resp = session.patch(f'/{cls._resource}/{id}', data)
return cls._from_dict(resp)
return cls(**resp)


class Deactivable(Resource):
deactivated_at: Optional[dt.datetime]

@classmethod
def deactivate(
cls, id: str, *, session: Session = global_session, **data
) -> Resource:
cls: Type[R_co],
id: str,
*,
session: Session = global_session,
**data: Any,
) -> R_co:
resp = session.delete(f'/{cls._resource}/{id}', data)
return cls._from_dict(resp)
return cls(**resp)

@property
def is_active(self):
def is_active(self) -> bool:
return not self.deactivated_at


class Downloadable(Resource):
@classmethod
def download(
cls,
cls: Type[R_co],
id: str,
file_format: FileFormat = FileFormat.any,
*,
Expand All @@ -121,13 +124,13 @@ def xml(self) -> bytes:
class Uploadable(Resource):
@classmethod
def _upload(
cls,
cls: Type[R_co],
file: bytes,
user_id: str,
*,
session: Session = global_session,
**data,
) -> Resource:
**data: Any,
) -> R_co:
encoded_file = base64.b64encode(file)
resp = session.request(
'post',
Expand All @@ -138,7 +141,7 @@ def _upload(
**{k: (None, v) for k, v in data.items()},
),
)
return cls._from_dict(json.loads(resp))
return cls(**json.loads(resp))


class Queryable(Resource):
Expand All @@ -148,50 +151,62 @@ class Queryable(Resource):

@classmethod
def one(
cls, *, session: Session = global_session, **query_params
) -> Resource:
q = cls._query_params(limit=2, **query_params)
cls: Type[R_co],
*,
session: Session = global_session,
**query_params: Any,
) -> R_co:
q = cast(Queryable, cls)._query_params(limit=2, **query_params)
resp = session.get(cls._resource, q.dict())
items = resp['items']
len_items = len(items)
if not len_items:
raise NoResultFound
if len_items > 1:
raise MultipleResultsFound
return cls._from_dict(items[0])
return cls(**items[0])

@classmethod
def first(
cls, *, session: Session = global_session, **query_params
) -> Optional[Resource]:
q = cls._query_params(limit=1, **query_params)
cls: Type[R_co],
*,
session: Session = global_session,
**query_params: Any,
) -> Optional[R_co]:
q = cast(Queryable, cls)._query_params(limit=1, **query_params)
resp = session.get(cls._resource, q.dict())
try:
item = resp['items'][0]
except IndexError:
rv = None
else:
rv = cls._from_dict(item)
rv = cls(**item)
return rv

@classmethod
def count(
cls, *, session: Session = global_session, **query_params
cls: Type[R_co],
*,
session: Session = global_session,
**query_params: Any,
) -> int:
q = cls._query_params(count=True, **query_params)
q = cast(Queryable, cls)._query_params(count=True, **query_params)
resp = session.get(cls._resource, q.dict())
return resp['count']

@classmethod
def all(
cls, *, session: Session = global_session, **query_params
) -> Generator[Resource, None, None]:
cls: Type[R_co],
*,
session: Session = global_session,
**query_params: Any,
) -> Generator[R_co, None, None]:
session = session or global_session
q = cls._query_params(**query_params)
q = cast(Queryable, cls)._query_params(**query_params)
next_page_uri = f'{cls._resource}?{urlencode(q.dict())}'
while next_page_uri:
page = session.get(next_page_uri)
yield from (cls._from_dict(item) for item in page['items'])
yield from (cls(**item) for item in page['items'])
next_page_uri = page['next_page_uri']


Expand Down
4 changes: 1 addition & 3 deletions cuenca/resources/card_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def create(
exp_year=exp_year,
cvv2=cvv2,
)
return cast(
'CardActivation', cls._create(session=session, **req.dict())
)
return cls._create(session=session, **req.dict())

@property
def card(self) -> Optional[Card]:
Expand Down
4 changes: 1 addition & 3 deletions cuenca/resources/card_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def create(
pin_block=pin_block,
pin_attempts_exceeded=pin_attempts_exceeded,
)
return cast(
'CardValidation', cls._create(session=session, **req.dict())
)
return cls._create(session=session, **req.dict())

@property
def card(self) -> Card:
Expand Down
9 changes: 4 additions & 5 deletions cuenca/resources/cards.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime as dt
from typing import ClassVar, Optional, cast
from typing import ClassVar, Optional

from cuenca_validations.types import (
CardFundingType,
Expand Down Expand Up @@ -81,7 +81,7 @@ def create(
card_holder_user_id=card_holder_user_id,
is_dynamic_cvv=is_dynamic_cvv,
)
return cast('Card', cls._create(session=session, **req.dict()))
return cls._create(session=session, **req.dict())

@classmethod
def update(
Expand All @@ -106,8 +106,7 @@ def update(
req = CardUpdateRequest(
status=status, pin_block=pin_block, is_dynamic_cvv=is_dynamic_cvv
)
resp = cls._update(card_id, session=session, **req.dict())
return cast('Card', resp)
return cls._update(card_id, session=session, **req.dict())

@classmethod
def deactivate(
Expand All @@ -118,4 +117,4 @@ def deactivate(
"""
url = f'{cls._resource}/{card_id}'
resp = session.delete(url)
return cast('Card', cls._from_dict(resp))
return cls(**resp)
4 changes: 2 additions & 2 deletions cuenca/resources/clabes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, cast
from typing import ClassVar

from ..http import Session, session as global_session
from .base import Creatable, Queryable, Retrievable
Expand All @@ -11,4 +11,4 @@ class Clabe(Creatable, Queryable, Retrievable):

@classmethod
def create(cls, session: Session = global_session):
return cast('Clabe', cls._create(session=session))
return cls._create(session=session)
7 changes: 2 additions & 5 deletions cuenca/resources/curp_validations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime as dt
from typing import ClassVar, Optional, cast
from typing import ClassVar, Optional

from cuenca_validations.types import (
Country,
Expand Down Expand Up @@ -98,7 +98,4 @@ def create(
gender=gender,
manual_curp=manual_curp,
)
return cast(
'CurpValidation',
cls._create(session=session, **req.dict()),
)
return cls._create(session=session, **req.dict())
Loading

0 comments on commit b308816

Please sign in to comment.