Skip to content

Commit

Permalink
Improve Interval type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
rahuliyer95 committed May 8, 2024
1 parent 3e3fec6 commit 16a6295
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 127 deletions.
2 changes: 1 addition & 1 deletion src/pendulum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def duration(
)


def interval(start: DateTime, end: DateTime, absolute: bool = False) -> Interval:
def interval(start: DateTime, end: DateTime, absolute: bool = False) -> Interval[DateTime]:
"""
Create an Interval instance.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/pendulum/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,10 @@ def __sub__(self, __dt: datetime) -> NoReturn:
...

@overload
def __sub__(self, __dt: Self) -> Interval:
def __sub__(self, __dt: Self) -> Interval[Date]:
...

def __sub__(self, other: timedelta | date) -> Self | Interval:
def __sub__(self, other: timedelta | date) -> Self | Interval[Date]:
if isinstance(other, timedelta):
return self._subtract_timedelta(other)

Expand All @@ -281,7 +281,7 @@ def __sub__(self, other: timedelta | date) -> Self | Interval:

# DIFFERENCES

def diff(self, dt: date | None = None, abs: bool = True) -> Interval:
def diff(self, dt: date | None = None, abs: bool = True) -> Interval[Date]:
"""
Returns the difference between two Date objects as an Interval.
Expand Down
10 changes: 6 additions & 4 deletions src/pendulum/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ def _subtract_timedelta(self, delta: datetime.timedelta) -> Self:

def diff( # type: ignore[override]
self, dt: datetime.datetime | None = None, abs: bool = True
) -> Interval:
) -> Interval[datetime.datetime]:
"""
Returns the difference between two DateTime objects represented as an Interval.
"""
Expand Down Expand Up @@ -1190,10 +1190,12 @@ def __sub__(self, other: datetime.timedelta) -> Self:
...

@overload
def __sub__(self, other: DateTime) -> Interval:
def __sub__(self, other: DateTime) -> Interval[datetime.datetime]:
...

def __sub__(self, other: datetime.datetime | datetime.timedelta) -> Self | Interval:
def __sub__(
self, other: datetime.datetime | datetime.timedelta
) -> Self | Interval[datetime.datetime]:
if isinstance(other, datetime.timedelta):
return self._subtract_timedelta(other)

Expand All @@ -1216,7 +1218,7 @@ def __sub__(self, other: datetime.datetime | datetime.timedelta) -> Self | Inter

return other.diff(self, False)

def __rsub__(self, other: datetime.datetime) -> Interval:
def __rsub__(self, other: datetime.datetime) -> Interval[datetime.datetime]:
if not isinstance(other, datetime.datetime):
return NotImplemented

Expand Down
200 changes: 82 additions & 118 deletions src/pendulum/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Generic
from typing import Iterator
from typing import Union
from typing import TypeVar
from typing import cast
from typing import overload

Expand All @@ -26,35 +27,15 @@
from pendulum.locales.locale import Locale


class Interval(Duration):
_T = TypeVar("_T", bound=date)


class Interval(Duration, Generic[_T]):
"""
An interval of time between two datetimes.
"""

@overload
def __new__(
cls,
start: pendulum.DateTime | datetime,
end: pendulum.DateTime | datetime,
absolute: bool = False,
) -> Self:
...

@overload
def __new__(
cls,
start: pendulum.Date | date,
end: pendulum.Date | date,
absolute: bool = False,
) -> Self:
...

def __new__(
cls,
start: pendulum.DateTime | pendulum.Date | datetime | date,
end: pendulum.DateTime | pendulum.Date | datetime | date,
absolute: bool = False,
) -> Self:
def __new__(cls, start: _T, end: _T, absolute: bool = False) -> Self:
if (
isinstance(start, datetime)
and not isinstance(end, datetime)
Expand Down Expand Up @@ -83,34 +64,40 @@ def __new__(
_start = start
_end = end
if isinstance(start, pendulum.DateTime):
_start = datetime(
start.year,
start.month,
start.day,
start.hour,
start.minute,
start.second,
start.microsecond,
tzinfo=start.tzinfo,
fold=start.fold,
_start = cast(
_T,
datetime(
start.year,
start.month,
start.day,
start.hour,
start.minute,
start.second,
start.microsecond,
tzinfo=start.tzinfo,
fold=start.fold,
),
)
elif isinstance(start, pendulum.Date):
_start = date(start.year, start.month, start.day)
_start = cast(_T, date(start.year, start.month, start.day))

if isinstance(end, pendulum.DateTime):
_end = datetime(
end.year,
end.month,
end.day,
end.hour,
end.minute,
end.second,
end.microsecond,
tzinfo=end.tzinfo,
fold=end.fold,
_end = cast(
_T,
datetime(
end.year,
end.month,
end.day,
end.hour,
end.minute,
end.second,
end.microsecond,
tzinfo=end.tzinfo,
fold=end.fold,
),
)
elif isinstance(end, pendulum.Date):
_end = date(end.year, end.month, end.day)
_end = cast(_T, date(end.year, end.month, end.day))

# Fixing issues with datetime.__sub__()
# not handling offsets if the tzinfo is the same
Expand All @@ -121,69 +108,70 @@ def __new__(
):
if _start.tzinfo is not None:
offset = cast(timedelta, cast(datetime, start).utcoffset())
_start = (_start - offset).replace(tzinfo=None)
_start = cast(_T, (_start - offset).replace(tzinfo=None))

if isinstance(end, datetime) and _end.tzinfo is not None:
offset = cast(timedelta, end.utcoffset())
_end = (_end - offset).replace(tzinfo=None)
_end = cast(_T, (_end - offset).replace(tzinfo=None))

delta: timedelta = _end - _start # type: ignore[operator]
delta: timedelta = _end - _start

return super().__new__(cls, seconds=delta.total_seconds())

def __init__(
self,
start: pendulum.DateTime | pendulum.Date | datetime | date,
end: pendulum.DateTime | pendulum.Date | datetime | date,
absolute: bool = False,
) -> None:
def __init__(self, start: _T, end: _T, absolute: bool = False) -> None:
super().__init__()

_start: pendulum.DateTime | pendulum.Date | datetime | date
_start: _T
if not isinstance(start, pendulum.Date):
if isinstance(start, datetime):
start = pendulum.instance(start)
start = cast(_T, pendulum.instance(start))
else:
start = pendulum.date(start.year, start.month, start.day)
start = cast(_T, pendulum.date(start.year, start.month, start.day))

_start = start
else:
if isinstance(start, pendulum.DateTime):
_start = datetime(
start.year,
start.month,
start.day,
start.hour,
start.minute,
start.second,
start.microsecond,
tzinfo=start.tzinfo,
_start = cast(
_T,
datetime(
start.year,
start.month,
start.day,
start.hour,
start.minute,
start.second,
start.microsecond,
tzinfo=start.tzinfo,
),
)
else:
_start = date(start.year, start.month, start.day)
_start = cast(_T, date(start.year, start.month, start.day))

_end: pendulum.DateTime | pendulum.Date | datetime | date
_end: _T
if not isinstance(end, pendulum.Date):
if isinstance(end, datetime):
end = pendulum.instance(end)
end = cast(_T, pendulum.instance(end))
else:
end = pendulum.date(end.year, end.month, end.day)
end = cast(_T, pendulum.date(end.year, end.month, end.day))

_end = end
else:
if isinstance(end, pendulum.DateTime):
_end = datetime(
end.year,
end.month,
end.day,
end.hour,
end.minute,
end.second,
end.microsecond,
tzinfo=end.tzinfo,
_end = cast(
_T,
datetime(
end.year,
end.month,
end.day,
end.hour,
end.minute,
end.second,
end.microsecond,
tzinfo=end.tzinfo,
),
)
else:
_end = date(end.year, end.month, end.day)
_end = cast(_T, date(end.year, end.month, end.day))

self._invert = False
if start > end:
Expand All @@ -194,8 +182,8 @@ def __init__(
_end, _start = _start, _end

self._absolute = absolute
self._start: pendulum.DateTime | pendulum.Date = start
self._end: pendulum.DateTime | pendulum.Date = end
self._start: _T = start
self._end: _T = end
self._delta: PreciseDiff = precise_diff(_start, _end)

@property
Expand Down Expand Up @@ -227,11 +215,11 @@ def minutes(self) -> int:
return self._delta.minutes

@property
def start(self) -> pendulum.DateTime | pendulum.Date | datetime | date:
def start(self) -> _T:
return self._start

@property
def end(self) -> pendulum.DateTime | pendulum.Date | datetime | date:
def end(self) -> _T:
return self._end

def in_years(self) -> int:
Expand Down Expand Up @@ -301,9 +289,7 @@ def in_words(self, locale: str | None = None, separator: str = " ") -> str:

return separator.join(parts)

def range(
self, unit: str, amount: int = 1
) -> Iterator[pendulum.DateTime | pendulum.Date]:
def range(self, unit: str, amount: int = 1) -> Iterator[_T]:
method = "add"
op = operator.le
if not self._absolute and self.invert:
Expand All @@ -314,7 +300,7 @@ def range(

i = amount
while op(start, end):
yield cast(Union[pendulum.DateTime, pendulum.Date], start)
yield start

start = getattr(self.start, method)(**{unit: i})

Expand All @@ -326,12 +312,10 @@ def as_duration(self) -> Duration:
"""
return Duration(seconds=self.total_seconds())

def __iter__(self) -> Iterator[pendulum.DateTime | pendulum.Date]:
def __iter__(self) -> Iterator[_T]:
return self.range("days")

def __contains__(
self, item: datetime | date | pendulum.DateTime | pendulum.Date
) -> bool:
def __contains__(self, item: _T) -> bool:
return self.start <= item <= self.end

def __add__(self, other: timedelta) -> Duration: # type: ignore[override]
Expand Down Expand Up @@ -400,13 +384,7 @@ def _cmp(self, other: timedelta) -> int:

return 0 if td == other else 1 if td > other else -1

def _getstate(
self, protocol: SupportsIndex = 3
) -> tuple[
pendulum.DateTime | pendulum.Date | datetime | date,
pendulum.DateTime | pendulum.Date | datetime | date,
bool,
]:
def _getstate(self, protocol: SupportsIndex = 3) -> tuple[_T, _T, bool]:
start, end = self.start, self.end

if self._invert and self._absolute:
Expand All @@ -416,26 +394,12 @@ def _getstate(

def __reduce__(
self,
) -> tuple[
type[Self],
tuple[
pendulum.DateTime | pendulum.Date | datetime | date,
pendulum.DateTime | pendulum.Date | datetime | date,
bool,
],
]:
) -> tuple[type[Self], tuple[_T, _T, bool]]:
return self.__reduce_ex__(2)

def __reduce_ex__(
self, protocol: SupportsIndex
) -> tuple[
type[Self],
tuple[
pendulum.DateTime | pendulum.Date | datetime | date,
pendulum.DateTime | pendulum.Date | datetime | date,
bool,
],
]:
) -> tuple[type[Self], tuple[_T, _T, bool]]:
return self.__class__, self._getstate(protocol)

def __hash__(self) -> int:
Expand Down
4 changes: 3 additions & 1 deletion src/pendulum/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def parse(text: str, **options: t.Any) -> Date | Time | DateTime | Duration:
return _parse(text, **options)


def _parse(text: str, **options: t.Any) -> Date | DateTime | Time | Duration | Interval:
def _parse(
text: str, **options: t.Any
) -> Date | DateTime | Time | Duration | Interval[DateTime]:
"""
Parses a string with the given options.
Expand Down

0 comments on commit 16a6295

Please sign in to comment.