|
23 | 23 | import re
|
24 | 24 | from copy import deepcopy
|
25 | 25 | from enum import Enum
|
26 |
| -from numbers import Real |
27 | 26 | from textwrap import indent
|
28 |
| -from typing import Any, Callable, Dict, Iterable, List, Optional, Union |
| 27 | +from typing import Any, Callable, Dict, Iterable, List, Optional, Union, get_origin |
29 | 28 |
|
| 29 | +import typeguard |
30 | 30 |
|
31 | 31 | INDENT = " " # doc is indented by four spaces
|
32 | 32 | RAW_ANCHOR = False # whether to use raw html anchors or RST ones
|
@@ -176,7 +176,7 @@ def __eq__(self, other: "Argument") -> bool:
|
176 | 176 | )
|
177 | 177 |
|
178 | 178 | def __repr__(self) -> str:
|
179 |
| - return f"<Argument {self.name}: {' | '.join(dd.__name__ for dd in self.dtype)}>" |
| 179 | + return f"<Argument {self.name}: {' | '.join(self._get_type_name(dd) for dd in self.dtype)}>" |
180 | 180 |
|
181 | 181 | def __getitem__(self, key: str) -> "Argument":
|
182 | 182 | key = key.lstrip("/")
|
@@ -205,10 +205,17 @@ def I(self):
|
205 | 205 | return Argument("_", dict, [self])
|
206 | 206 |
|
207 | 207 | def _reorg_dtype(self):
|
208 |
| - if isinstance(self.dtype, type) or self.dtype is None: |
| 208 | + if ( |
| 209 | + isinstance(self.dtype, type) |
| 210 | + or isinstance(get_origin(self.dtype), type) |
| 211 | + or self.dtype is None |
| 212 | + ): |
209 | 213 | self.dtype = [self.dtype]
|
210 | 214 | # remove duplicate
|
211 |
| - self.dtype = {dt if type(dt) is type else type(dt) for dt in self.dtype} |
| 215 | + self.dtype = { |
| 216 | + dt if type(dt) is type or type(get_origin(dt)) is type else type(dt) |
| 217 | + for dt in self.dtype |
| 218 | + } |
212 | 219 | # check conner cases
|
213 | 220 | if self.sub_fields or self.sub_variants:
|
214 | 221 | self.dtype.add(list if self.repeat else dict)
|
@@ -414,16 +421,19 @@ def _check_exist(self, argdict: dict, path=None):
|
414 | 421 | )
|
415 | 422 |
|
416 | 423 | def _check_data(self, value: Any, path=None):
|
417 |
| - if not ( |
418 |
| - isinstance(value, self.dtype) |
419 |
| - or (float in self.dtype and isinstance(value, Real)) |
420 |
| - ): |
| 424 | + try: |
| 425 | + typeguard.check_type( |
| 426 | + value, |
| 427 | + self.dtype, |
| 428 | + collection_check_strategy=typeguard.CollectionCheckStrategy.ALL_ITEMS, |
| 429 | + ) |
| 430 | + except typeguard.TypeCheckError as e: |
421 | 431 | raise ArgumentTypeError(
|
422 | 432 | path,
|
423 | 433 | f"key `{self.name}` gets wrong value type, "
|
424 |
| - f"requires <{'|'.join(dd.__name__ for dd in self.dtype)}> " |
425 |
| - f"but gets <{type(value).__name__}>", |
426 |
| - ) |
| 434 | + f"requires <{'|'.join(self._get_type_name(dd) for dd in self.dtype)}> " |
| 435 | + f"but " + str(e), |
| 436 | + ) from e |
427 | 437 | if self.extra_check is not None and not self.extra_check(value):
|
428 | 438 | raise ArgumentValueError(
|
429 | 439 | path,
|
@@ -586,7 +596,9 @@ def gen_doc(self, path: Optional[List[str]] = None, **kwargs) -> str:
|
586 | 596 | return "\n".join(filter(None, doc_list))
|
587 | 597 |
|
588 | 598 | def gen_doc_head(self, path: Optional[List[str]] = None, **kwargs) -> str:
|
589 |
| - typesig = "| type: " + " | ".join([f"``{dt.__name__}``" for dt in self.dtype]) |
| 599 | + typesig = "| type: " + " | ".join( |
| 600 | + [f"``{self._get_type_name(dt)}``" for dt in self.dtype] |
| 601 | + ) |
590 | 602 | if self.optional:
|
591 | 603 | typesig += ", optional"
|
592 | 604 | if self.default == "":
|
@@ -632,6 +644,10 @@ def gen_doc_body(self, path: Optional[List[str]] = None, **kwargs) -> str:
|
632 | 644 | body = "\n".join(body_list)
|
633 | 645 | return body
|
634 | 646 |
|
| 647 | + def _get_type_name(self, dd) -> str: |
| 648 | + """Get type name for doc/message generation.""" |
| 649 | + return str(dd) if isinstance(get_origin(dd), type) else dd.__name__ |
| 650 | + |
635 | 651 |
|
636 | 652 | class Variant:
|
637 | 653 | """Define multiple choices of possible argument sets.
|
@@ -993,6 +1009,8 @@ def default(self, obj) -> Dict[str, Union[str, bool, List]]:
|
993 | 1009 | "choice_alias": obj.choice_alias,
|
994 | 1010 | "doc": obj.doc,
|
995 | 1011 | }
|
| 1012 | + elif isinstance(get_origin(obj), type): |
| 1013 | + return get_origin(obj).__name__ |
996 | 1014 | elif isinstance(obj, type):
|
997 | 1015 | return obj.__name__
|
998 | 1016 | return json.JSONEncoder.default(self, obj)
|
0 commit comments