diff --git a/src/pendulum/__init__.py b/src/pendulum/__init__.py index 3863b760..21d1dfe7 100644 --- a/src/pendulum/__init__.py +++ b/src/pendulum/__init__.py @@ -331,7 +331,7 @@ def duration( ) -def interval(start: DateTime, end: DateTime, absolute: bool = False) -> Interval: +def interval(start: DateTime, end: DateTime, absolute: bool = False) -> Interval[DateTime]: """ Create an Interval instance. """ diff --git a/src/pendulum/date.py b/src/pendulum/date.py index e7b862c5..633bb6da 100644 --- a/src/pendulum/date.py +++ b/src/pendulum/date.py @@ -265,10 +265,10 @@ def __sub__(self, __dt: datetime) -> NoReturn: ... @overload - def __sub__(self, __dt: Self) -> Interval: + def __sub__(self, __dt: Self) -> Interval[Date]: ... - def __sub__(self, other: timedelta | date) -> Self | Interval: + def __sub__(self, other: timedelta | date) -> Self | Interval[Date]: if isinstance(other, timedelta): return self._subtract_timedelta(other) @@ -281,7 +281,7 @@ def __sub__(self, other: timedelta | date) -> Self | Interval: # DIFFERENCES - def diff(self, dt: date | None = None, abs: bool = True) -> Interval: + def diff(self, dt: date | None = None, abs: bool = True) -> Interval[Date]: """ Returns the difference between two Date objects as an Interval. diff --git a/src/pendulum/datetime.py b/src/pendulum/datetime.py index 752a9ac8..c1babbf8 100644 --- a/src/pendulum/datetime.py +++ b/src/pendulum/datetime.py @@ -704,7 +704,7 @@ def _subtract_timedelta(self, delta: datetime.timedelta) -> Self: def diff( # type: ignore[override] self, dt: datetime.datetime | None = None, abs: bool = True - ) -> Interval: + ) -> Interval[datetime.datetime]: """ Returns the difference between two DateTime objects represented as an Interval. """ @@ -1190,10 +1190,12 @@ def __sub__(self, other: datetime.timedelta) -> Self: ... @overload - def __sub__(self, other: DateTime) -> Interval: + def __sub__(self, other: DateTime) -> Interval[datetime.datetime]: ... - def __sub__(self, other: datetime.datetime | datetime.timedelta) -> Self | Interval: + def __sub__( + self, other: datetime.datetime | datetime.timedelta + ) -> Self | Interval[datetime.datetime]: if isinstance(other, datetime.timedelta): return self._subtract_timedelta(other) @@ -1216,7 +1218,7 @@ def __sub__(self, other: datetime.datetime | datetime.timedelta) -> Self | Inter return other.diff(self, False) - def __rsub__(self, other: datetime.datetime) -> Interval: + def __rsub__(self, other: datetime.datetime) -> Interval[datetime.datetime]: if not isinstance(other, datetime.datetime): return NotImplemented diff --git a/src/pendulum/interval.py b/src/pendulum/interval.py index 19c91a6a..f233312d 100644 --- a/src/pendulum/interval.py +++ b/src/pendulum/interval.py @@ -6,8 +6,9 @@ from datetime import datetime from datetime import timedelta from typing import TYPE_CHECKING +from typing import Generic from typing import Iterator -from typing import Union +from typing import TypeVar from typing import cast from typing import overload @@ -26,35 +27,15 @@ from pendulum.locales.locale import Locale -class Interval(Duration): +_T = TypeVar("_T", bound=date) + + +class Interval(Duration, Generic[_T]): """ An interval of time between two datetimes. """ - @overload - def __new__( - cls, - start: pendulum.DateTime | datetime, - end: pendulum.DateTime | datetime, - absolute: bool = False, - ) -> Self: - ... - - @overload - def __new__( - cls, - start: pendulum.Date | date, - end: pendulum.Date | date, - absolute: bool = False, - ) -> Self: - ... - - def __new__( - cls, - start: pendulum.DateTime | pendulum.Date | datetime | date, - end: pendulum.DateTime | pendulum.Date | datetime | date, - absolute: bool = False, - ) -> Self: + def __new__(cls, start: _T, end: _T, absolute: bool = False) -> Self: if ( isinstance(start, datetime) and not isinstance(end, datetime) @@ -83,34 +64,40 @@ def __new__( _start = start _end = end if isinstance(start, pendulum.DateTime): - _start = datetime( - start.year, - start.month, - start.day, - start.hour, - start.minute, - start.second, - start.microsecond, - tzinfo=start.tzinfo, - fold=start.fold, + _start = cast( + _T, + datetime( + start.year, + start.month, + start.day, + start.hour, + start.minute, + start.second, + start.microsecond, + tzinfo=start.tzinfo, + fold=start.fold, + ), ) elif isinstance(start, pendulum.Date): - _start = date(start.year, start.month, start.day) + _start = cast(_T, date(start.year, start.month, start.day)) if isinstance(end, pendulum.DateTime): - _end = datetime( - end.year, - end.month, - end.day, - end.hour, - end.minute, - end.second, - end.microsecond, - tzinfo=end.tzinfo, - fold=end.fold, + _end = cast( + _T, + datetime( + end.year, + end.month, + end.day, + end.hour, + end.minute, + end.second, + end.microsecond, + tzinfo=end.tzinfo, + fold=end.fold, + ), ) elif isinstance(end, pendulum.Date): - _end = date(end.year, end.month, end.day) + _end = cast(_T, date(end.year, end.month, end.day)) # Fixing issues with datetime.__sub__() # not handling offsets if the tzinfo is the same @@ -121,69 +108,70 @@ def __new__( ): if _start.tzinfo is not None: offset = cast(timedelta, cast(datetime, start).utcoffset()) - _start = (_start - offset).replace(tzinfo=None) + _start = cast(_T, (_start - offset).replace(tzinfo=None)) if isinstance(end, datetime) and _end.tzinfo is not None: offset = cast(timedelta, end.utcoffset()) - _end = (_end - offset).replace(tzinfo=None) + _end = cast(_T, (_end - offset).replace(tzinfo=None)) - delta: timedelta = _end - _start # type: ignore[operator] + delta: timedelta = _end - _start return super().__new__(cls, seconds=delta.total_seconds()) - def __init__( - self, - start: pendulum.DateTime | pendulum.Date | datetime | date, - end: pendulum.DateTime | pendulum.Date | datetime | date, - absolute: bool = False, - ) -> None: + def __init__(self, start: _T, end: _T, absolute: bool = False) -> None: super().__init__() - _start: pendulum.DateTime | pendulum.Date | datetime | date + _start: _T if not isinstance(start, pendulum.Date): if isinstance(start, datetime): - start = pendulum.instance(start) + start = cast(_T, pendulum.instance(start)) else: - start = pendulum.date(start.year, start.month, start.day) + start = cast(_T, pendulum.date(start.year, start.month, start.day)) _start = start else: if isinstance(start, pendulum.DateTime): - _start = datetime( - start.year, - start.month, - start.day, - start.hour, - start.minute, - start.second, - start.microsecond, - tzinfo=start.tzinfo, + _start = cast( + _T, + datetime( + start.year, + start.month, + start.day, + start.hour, + start.minute, + start.second, + start.microsecond, + tzinfo=start.tzinfo, + ), ) else: - _start = date(start.year, start.month, start.day) + _start = cast(_T, date(start.year, start.month, start.day)) - _end: pendulum.DateTime | pendulum.Date | datetime | date + _end: _T if not isinstance(end, pendulum.Date): if isinstance(end, datetime): - end = pendulum.instance(end) + end = cast(_T, pendulum.instance(end)) else: - end = pendulum.date(end.year, end.month, end.day) + end = cast(_T, pendulum.date(end.year, end.month, end.day)) _end = end else: if isinstance(end, pendulum.DateTime): - _end = datetime( - end.year, - end.month, - end.day, - end.hour, - end.minute, - end.second, - end.microsecond, - tzinfo=end.tzinfo, + _end = cast( + _T, + datetime( + end.year, + end.month, + end.day, + end.hour, + end.minute, + end.second, + end.microsecond, + tzinfo=end.tzinfo, + ), ) else: - _end = date(end.year, end.month, end.day) + _end = cast(_T, date(end.year, end.month, end.day)) self._invert = False if start > end: @@ -194,8 +182,8 @@ def __init__( _end, _start = _start, _end self._absolute = absolute - self._start: pendulum.DateTime | pendulum.Date = start - self._end: pendulum.DateTime | pendulum.Date = end + self._start: _T = start + self._end: _T = end self._delta: PreciseDiff = precise_diff(_start, _end) @property @@ -227,11 +215,11 @@ def minutes(self) -> int: return self._delta.minutes @property - def start(self) -> pendulum.DateTime | pendulum.Date | datetime | date: + def start(self) -> _T: return self._start @property - def end(self) -> pendulum.DateTime | pendulum.Date | datetime | date: + def end(self) -> _T: return self._end def in_years(self) -> int: @@ -301,9 +289,7 @@ def in_words(self, locale: str | None = None, separator: str = " ") -> str: return separator.join(parts) - def range( - self, unit: str, amount: int = 1 - ) -> Iterator[pendulum.DateTime | pendulum.Date]: + def range(self, unit: str, amount: int = 1) -> Iterator[_T]: method = "add" op = operator.le if not self._absolute and self.invert: @@ -314,7 +300,7 @@ def range( i = amount while op(start, end): - yield cast(Union[pendulum.DateTime, pendulum.Date], start) + yield start start = getattr(self.start, method)(**{unit: i}) @@ -326,12 +312,10 @@ def as_duration(self) -> Duration: """ return Duration(seconds=self.total_seconds()) - def __iter__(self) -> Iterator[pendulum.DateTime | pendulum.Date]: + def __iter__(self) -> Iterator[_T]: return self.range("days") - def __contains__( - self, item: datetime | date | pendulum.DateTime | pendulum.Date - ) -> bool: + def __contains__(self, item: _T) -> bool: return self.start <= item <= self.end def __add__(self, other: timedelta) -> Duration: # type: ignore[override] @@ -400,13 +384,7 @@ def _cmp(self, other: timedelta) -> int: return 0 if td == other else 1 if td > other else -1 - def _getstate( - self, protocol: SupportsIndex = 3 - ) -> tuple[ - pendulum.DateTime | pendulum.Date | datetime | date, - pendulum.DateTime | pendulum.Date | datetime | date, - bool, - ]: + def _getstate(self, protocol: SupportsIndex = 3) -> tuple[_T, _T, bool]: start, end = self.start, self.end if self._invert and self._absolute: @@ -416,26 +394,12 @@ def _getstate( def __reduce__( self, - ) -> tuple[ - type[Self], - tuple[ - pendulum.DateTime | pendulum.Date | datetime | date, - pendulum.DateTime | pendulum.Date | datetime | date, - bool, - ], - ]: + ) -> tuple[type[Self], tuple[_T, _T, bool]]: return self.__reduce_ex__(2) def __reduce_ex__( self, protocol: SupportsIndex - ) -> tuple[ - type[Self], - tuple[ - pendulum.DateTime | pendulum.Date | datetime | date, - pendulum.DateTime | pendulum.Date | datetime | date, - bool, - ], - ]: + ) -> tuple[type[Self], tuple[_T, _T, bool]]: return self.__class__, self._getstate(protocol) def __hash__(self) -> int: diff --git a/src/pendulum/parser.py b/src/pendulum/parser.py index 5f9a0f78..a329a321 100644 --- a/src/pendulum/parser.py +++ b/src/pendulum/parser.py @@ -30,7 +30,9 @@ def parse(text: str, **options: t.Any) -> Date | Time | DateTime | Duration: return _parse(text, **options) -def _parse(text: str, **options: t.Any) -> Date | DateTime | Time | Duration | Interval: +def _parse( + text: str, **options: t.Any +) -> Date | DateTime | Time | Duration | Interval[DateTime]: """ Parses a string with the given options.