Skip to content

Commit 51fb4c3

Browse files
authored
Improve Interval type hint (#828)
* Improve Interval type hint * Format files with black
1 parent b84b976 commit 51fb4c3

File tree

5 files changed

+97
-127
lines changed

5 files changed

+97
-127
lines changed

src/pendulum/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,9 @@ def duration(
331331
)
332332

333333

334-
def interval(start: DateTime, end: DateTime, absolute: bool = False) -> Interval:
334+
def interval(
335+
start: DateTime, end: DateTime, absolute: bool = False
336+
) -> Interval[DateTime]:
335337
"""
336338
Create an Interval instance.
337339
"""

src/pendulum/date.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,10 +265,10 @@ def __sub__(self, __dt: datetime) -> NoReturn:
265265
...
266266

267267
@overload
268-
def __sub__(self, __dt: Self) -> Interval:
268+
def __sub__(self, __dt: Self) -> Interval[Date]:
269269
...
270270

271-
def __sub__(self, other: timedelta | date) -> Self | Interval:
271+
def __sub__(self, other: timedelta | date) -> Self | Interval[Date]:
272272
if isinstance(other, timedelta):
273273
return self._subtract_timedelta(other)
274274

@@ -281,7 +281,7 @@ def __sub__(self, other: timedelta | date) -> Self | Interval:
281281

282282
# DIFFERENCES
283283

284-
def diff(self, dt: date | None = None, abs: bool = True) -> Interval:
284+
def diff(self, dt: date | None = None, abs: bool = True) -> Interval[Date]:
285285
"""
286286
Returns the difference between two Date objects as an Interval.
287287

src/pendulum/datetime.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ def _subtract_timedelta(self, delta: datetime.timedelta) -> Self:
704704

705705
def diff( # type: ignore[override]
706706
self, dt: datetime.datetime | None = None, abs: bool = True
707-
) -> Interval:
707+
) -> Interval[datetime.datetime]:
708708
"""
709709
Returns the difference between two DateTime objects represented as an Interval.
710710
"""
@@ -1190,10 +1190,12 @@ def __sub__(self, other: datetime.timedelta) -> Self:
11901190
...
11911191

11921192
@overload
1193-
def __sub__(self, other: DateTime) -> Interval:
1193+
def __sub__(self, other: DateTime) -> Interval[datetime.datetime]:
11941194
...
11951195

1196-
def __sub__(self, other: datetime.datetime | datetime.timedelta) -> Self | Interval:
1196+
def __sub__(
1197+
self, other: datetime.datetime | datetime.timedelta
1198+
) -> Self | Interval[datetime.datetime]:
11971199
if isinstance(other, datetime.timedelta):
11981200
return self._subtract_timedelta(other)
11991201

@@ -1216,7 +1218,7 @@ def __sub__(self, other: datetime.datetime | datetime.timedelta) -> Self | Inter
12161218

12171219
return other.diff(self, False)
12181220

1219-
def __rsub__(self, other: datetime.datetime) -> Interval:
1221+
def __rsub__(self, other: datetime.datetime) -> Interval[datetime.datetime]:
12201222
if not isinstance(other, datetime.datetime):
12211223
return NotImplemented
12221224

src/pendulum/interval.py

Lines changed: 82 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from datetime import datetime
77
from datetime import timedelta
88
from typing import TYPE_CHECKING
9+
from typing import Generic
910
from typing import Iterator
10-
from typing import Union
11+
from typing import TypeVar
1112
from typing import cast
1213
from typing import overload
1314

@@ -26,35 +27,15 @@
2627
from pendulum.locales.locale import Locale
2728

2829

29-
class Interval(Duration):
30+
_T = TypeVar("_T", bound=date)
31+
32+
33+
class Interval(Duration, Generic[_T]):
3034
"""
3135
An interval of time between two datetimes.
3236
"""
3337

34-
@overload
35-
def __new__(
36-
cls,
37-
start: pendulum.DateTime | datetime,
38-
end: pendulum.DateTime | datetime,
39-
absolute: bool = False,
40-
) -> Self:
41-
...
42-
43-
@overload
44-
def __new__(
45-
cls,
46-
start: pendulum.Date | date,
47-
end: pendulum.Date | date,
48-
absolute: bool = False,
49-
) -> Self:
50-
...
51-
52-
def __new__(
53-
cls,
54-
start: pendulum.DateTime | pendulum.Date | datetime | date,
55-
end: pendulum.DateTime | pendulum.Date | datetime | date,
56-
absolute: bool = False,
57-
) -> Self:
38+
def __new__(cls, start: _T, end: _T, absolute: bool = False) -> Self:
5839
if (
5940
isinstance(start, datetime)
6041
and not isinstance(end, datetime)
@@ -83,34 +64,40 @@ def __new__(
8364
_start = start
8465
_end = end
8566
if isinstance(start, pendulum.DateTime):
86-
_start = datetime(
87-
start.year,
88-
start.month,
89-
start.day,
90-
start.hour,
91-
start.minute,
92-
start.second,
93-
start.microsecond,
94-
tzinfo=start.tzinfo,
95-
fold=start.fold,
67+
_start = cast(
68+
_T,
69+
datetime(
70+
start.year,
71+
start.month,
72+
start.day,
73+
start.hour,
74+
start.minute,
75+
start.second,
76+
start.microsecond,
77+
tzinfo=start.tzinfo,
78+
fold=start.fold,
79+
),
9680
)
9781
elif isinstance(start, pendulum.Date):
98-
_start = date(start.year, start.month, start.day)
82+
_start = cast(_T, date(start.year, start.month, start.day))
9983

10084
if isinstance(end, pendulum.DateTime):
101-
_end = datetime(
102-
end.year,
103-
end.month,
104-
end.day,
105-
end.hour,
106-
end.minute,
107-
end.second,
108-
end.microsecond,
109-
tzinfo=end.tzinfo,
110-
fold=end.fold,
85+
_end = cast(
86+
_T,
87+
datetime(
88+
end.year,
89+
end.month,
90+
end.day,
91+
end.hour,
92+
end.minute,
93+
end.second,
94+
end.microsecond,
95+
tzinfo=end.tzinfo,
96+
fold=end.fold,
97+
),
11198
)
11299
elif isinstance(end, pendulum.Date):
113-
_end = date(end.year, end.month, end.day)
100+
_end = cast(_T, date(end.year, end.month, end.day))
114101

115102
# Fixing issues with datetime.__sub__()
116103
# not handling offsets if the tzinfo is the same
@@ -121,69 +108,70 @@ def __new__(
121108
):
122109
if _start.tzinfo is not None:
123110
offset = cast(timedelta, cast(datetime, start).utcoffset())
124-
_start = (_start - offset).replace(tzinfo=None)
111+
_start = cast(_T, (_start - offset).replace(tzinfo=None))
125112

126113
if isinstance(end, datetime) and _end.tzinfo is not None:
127114
offset = cast(timedelta, end.utcoffset())
128-
_end = (_end - offset).replace(tzinfo=None)
115+
_end = cast(_T, (_end - offset).replace(tzinfo=None))
129116

130-
delta: timedelta = _end - _start # type: ignore[operator]
117+
delta: timedelta = _end - _start
131118

132119
return super().__new__(cls, seconds=delta.total_seconds())
133120

134-
def __init__(
135-
self,
136-
start: pendulum.DateTime | pendulum.Date | datetime | date,
137-
end: pendulum.DateTime | pendulum.Date | datetime | date,
138-
absolute: bool = False,
139-
) -> None:
121+
def __init__(self, start: _T, end: _T, absolute: bool = False) -> None:
140122
super().__init__()
141123

142-
_start: pendulum.DateTime | pendulum.Date | datetime | date
124+
_start: _T
143125
if not isinstance(start, pendulum.Date):
144126
if isinstance(start, datetime):
145-
start = pendulum.instance(start)
127+
start = cast(_T, pendulum.instance(start))
146128
else:
147-
start = pendulum.date(start.year, start.month, start.day)
129+
start = cast(_T, pendulum.date(start.year, start.month, start.day))
148130

149131
_start = start
150132
else:
151133
if isinstance(start, pendulum.DateTime):
152-
_start = datetime(
153-
start.year,
154-
start.month,
155-
start.day,
156-
start.hour,
157-
start.minute,
158-
start.second,
159-
start.microsecond,
160-
tzinfo=start.tzinfo,
134+
_start = cast(
135+
_T,
136+
datetime(
137+
start.year,
138+
start.month,
139+
start.day,
140+
start.hour,
141+
start.minute,
142+
start.second,
143+
start.microsecond,
144+
tzinfo=start.tzinfo,
145+
),
161146
)
162147
else:
163-
_start = date(start.year, start.month, start.day)
148+
_start = cast(_T, date(start.year, start.month, start.day))
164149

165-
_end: pendulum.DateTime | pendulum.Date | datetime | date
150+
_end: _T
166151
if not isinstance(end, pendulum.Date):
167152
if isinstance(end, datetime):
168-
end = pendulum.instance(end)
153+
end = cast(_T, pendulum.instance(end))
169154
else:
170-
end = pendulum.date(end.year, end.month, end.day)
155+
end = cast(_T, pendulum.date(end.year, end.month, end.day))
171156

172157
_end = end
173158
else:
174159
if isinstance(end, pendulum.DateTime):
175-
_end = datetime(
176-
end.year,
177-
end.month,
178-
end.day,
179-
end.hour,
180-
end.minute,
181-
end.second,
182-
end.microsecond,
183-
tzinfo=end.tzinfo,
160+
_end = cast(
161+
_T,
162+
datetime(
163+
end.year,
164+
end.month,
165+
end.day,
166+
end.hour,
167+
end.minute,
168+
end.second,
169+
end.microsecond,
170+
tzinfo=end.tzinfo,
171+
),
184172
)
185173
else:
186-
_end = date(end.year, end.month, end.day)
174+
_end = cast(_T, date(end.year, end.month, end.day))
187175

188176
self._invert = False
189177
if start > end:
@@ -194,8 +182,8 @@ def __init__(
194182
_end, _start = _start, _end
195183

196184
self._absolute = absolute
197-
self._start: pendulum.DateTime | pendulum.Date = start
198-
self._end: pendulum.DateTime | pendulum.Date = end
185+
self._start: _T = start
186+
self._end: _T = end
199187
self._delta: PreciseDiff = precise_diff(_start, _end)
200188

201189
@property
@@ -227,11 +215,11 @@ def minutes(self) -> int:
227215
return self._delta.minutes
228216

229217
@property
230-
def start(self) -> pendulum.DateTime | pendulum.Date | datetime | date:
218+
def start(self) -> _T:
231219
return self._start
232220

233221
@property
234-
def end(self) -> pendulum.DateTime | pendulum.Date | datetime | date:
222+
def end(self) -> _T:
235223
return self._end
236224

237225
def in_years(self) -> int:
@@ -301,9 +289,7 @@ def in_words(self, locale: str | None = None, separator: str = " ") -> str:
301289

302290
return separator.join(parts)
303291

304-
def range(
305-
self, unit: str, amount: int = 1
306-
) -> Iterator[pendulum.DateTime | pendulum.Date]:
292+
def range(self, unit: str, amount: int = 1) -> Iterator[_T]:
307293
method = "add"
308294
op = operator.le
309295
if not self._absolute and self.invert:
@@ -314,7 +300,7 @@ def range(
314300

315301
i = amount
316302
while op(start, end):
317-
yield cast(Union[pendulum.DateTime, pendulum.Date], start)
303+
yield start
318304

319305
start = getattr(self.start, method)(**{unit: i})
320306

@@ -326,12 +312,10 @@ def as_duration(self) -> Duration:
326312
"""
327313
return Duration(seconds=self.total_seconds())
328314

329-
def __iter__(self) -> Iterator[pendulum.DateTime | pendulum.Date]:
315+
def __iter__(self) -> Iterator[_T]:
330316
return self.range("days")
331317

332-
def __contains__(
333-
self, item: datetime | date | pendulum.DateTime | pendulum.Date
334-
) -> bool:
318+
def __contains__(self, item: _T) -> bool:
335319
return self.start <= item <= self.end
336320

337321
def __add__(self, other: timedelta) -> Duration: # type: ignore[override]
@@ -400,13 +384,7 @@ def _cmp(self, other: timedelta) -> int:
400384

401385
return 0 if td == other else 1 if td > other else -1
402386

403-
def _getstate(
404-
self, protocol: SupportsIndex = 3
405-
) -> tuple[
406-
pendulum.DateTime | pendulum.Date | datetime | date,
407-
pendulum.DateTime | pendulum.Date | datetime | date,
408-
bool,
409-
]:
387+
def _getstate(self, protocol: SupportsIndex = 3) -> tuple[_T, _T, bool]:
410388
start, end = self.start, self.end
411389

412390
if self._invert and self._absolute:
@@ -416,26 +394,12 @@ def _getstate(
416394

417395
def __reduce__(
418396
self,
419-
) -> tuple[
420-
type[Self],
421-
tuple[
422-
pendulum.DateTime | pendulum.Date | datetime | date,
423-
pendulum.DateTime | pendulum.Date | datetime | date,
424-
bool,
425-
],
426-
]:
397+
) -> tuple[type[Self], tuple[_T, _T, bool]]:
427398
return self.__reduce_ex__(2)
428399

429400
def __reduce_ex__(
430401
self, protocol: SupportsIndex
431-
) -> tuple[
432-
type[Self],
433-
tuple[
434-
pendulum.DateTime | pendulum.Date | datetime | date,
435-
pendulum.DateTime | pendulum.Date | datetime | date,
436-
bool,
437-
],
438-
]:
402+
) -> tuple[type[Self], tuple[_T, _T, bool]]:
439403
return self.__class__, self._getstate(protocol)
440404

441405
def __hash__(self) -> int:

0 commit comments

Comments
 (0)