Skip to content

Commit d97bd20

Browse files
support types in the typing module (#37)
This will be very useful for all downstream projects to set types like `list[str]`, `list[list[str]]`, etc. ```py >>> ca = Argument("key1", List[float]) >>> ca.check({"key1": [1, 2.0, 3]}) pass >>> ca.check({"key1": [1, 2.0, "3"]}) throw ArgumentTypeError ``` --------- Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 900cf68 commit d97bd20

File tree

5 files changed

+45
-14
lines changed

5 files changed

+45
-14
lines changed

dargs/dargs.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
import re
2424
from copy import deepcopy
2525
from enum import Enum
26-
from numbers import Real
2726
from textwrap import indent
28-
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
27+
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, get_origin
2928

29+
import typeguard
3030

3131
INDENT = " " # doc is indented by four spaces
3232
RAW_ANCHOR = False # whether to use raw html anchors or RST ones
@@ -176,7 +176,7 @@ def __eq__(self, other: "Argument") -> bool:
176176
)
177177

178178
def __repr__(self) -> str:
179-
return f"<Argument {self.name}: {' | '.join(dd.__name__ for dd in self.dtype)}>"
179+
return f"<Argument {self.name}: {' | '.join(self._get_type_name(dd) for dd in self.dtype)}>"
180180

181181
def __getitem__(self, key: str) -> "Argument":
182182
key = key.lstrip("/")
@@ -205,10 +205,17 @@ def I(self):
205205
return Argument("_", dict, [self])
206206

207207
def _reorg_dtype(self):
208-
if isinstance(self.dtype, type) or self.dtype is None:
208+
if (
209+
isinstance(self.dtype, type)
210+
or isinstance(get_origin(self.dtype), type)
211+
or self.dtype is None
212+
):
209213
self.dtype = [self.dtype]
210214
# remove duplicate
211-
self.dtype = {dt if type(dt) is type else type(dt) for dt in self.dtype}
215+
self.dtype = {
216+
dt if type(dt) is type or type(get_origin(dt)) is type else type(dt)
217+
for dt in self.dtype
218+
}
212219
# check conner cases
213220
if self.sub_fields or self.sub_variants:
214221
self.dtype.add(list if self.repeat else dict)
@@ -414,16 +421,19 @@ def _check_exist(self, argdict: dict, path=None):
414421
)
415422

416423
def _check_data(self, value: Any, path=None):
417-
if not (
418-
isinstance(value, self.dtype)
419-
or (float in self.dtype and isinstance(value, Real))
420-
):
424+
try:
425+
typeguard.check_type(
426+
value,
427+
self.dtype,
428+
collection_check_strategy=typeguard.CollectionCheckStrategy.ALL_ITEMS,
429+
)
430+
except typeguard.TypeCheckError as e:
421431
raise ArgumentTypeError(
422432
path,
423433
f"key `{self.name}` gets wrong value type, "
424-
f"requires <{'|'.join(dd.__name__ for dd in self.dtype)}> "
425-
f"but gets <{type(value).__name__}>",
426-
)
434+
f"requires <{'|'.join(self._get_type_name(dd) for dd in self.dtype)}> "
435+
f"but " + str(e),
436+
) from e
427437
if self.extra_check is not None and not self.extra_check(value):
428438
raise ArgumentValueError(
429439
path,
@@ -586,7 +596,9 @@ def gen_doc(self, path: Optional[List[str]] = None, **kwargs) -> str:
586596
return "\n".join(filter(None, doc_list))
587597

588598
def gen_doc_head(self, path: Optional[List[str]] = None, **kwargs) -> str:
589-
typesig = "| type: " + " | ".join([f"``{dt.__name__}``" for dt in self.dtype])
599+
typesig = "| type: " + " | ".join(
600+
[f"``{self._get_type_name(dt)}``" for dt in self.dtype]
601+
)
590602
if self.optional:
591603
typesig += ", optional"
592604
if self.default == "":
@@ -632,6 +644,10 @@ def gen_doc_body(self, path: Optional[List[str]] = None, **kwargs) -> str:
632644
body = "\n".join(body_list)
633645
return body
634646

647+
def _get_type_name(self, dd) -> str:
648+
"""Get type name for doc/message generation."""
649+
return str(dd) if isinstance(get_origin(dd), type) else dd.__name__
650+
635651

636652
class Variant:
637653
"""Define multiple choices of possible argument sets.
@@ -993,6 +1009,8 @@ def default(self, obj) -> Dict[str, Union[str, bool, List]]:
9931009
"choice_alias": obj.choice_alias,
9941010
"doc": obj.doc,
9951011
}
1012+
elif isinstance(get_origin(obj), type):
1013+
return get_origin(obj).__name__
9961014
elif isinstance(obj, type):
9971015
return obj.__name__
9981016
return json.JSONEncoder.default(self, obj)

dargs/sphinx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,5 +192,5 @@ def _test_arguments() -> List[Argument]:
192192
return [
193193
Argument(name="test1", dtype=int, doc="Argument 1"),
194194
Argument(name="test2", dtype=[float, None], doc="Argument 2"),
195-
Argument(name="test3", dtype=list, doc="Argument 3"),
195+
Argument(name="test3", dtype=List[str], doc="Argument 3"),
196196
]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ classifiers = [
1616
"License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
1717
]
1818
dependencies = [
19+
"typeguard>=3",
1920
]
2021
requires-python = ">=3.7"
2122
readme = "README.md"

tests/test_checker.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import List
12
from .context import dargs
23
import unittest
34
from dargs import Argument, Variant
@@ -27,6 +28,11 @@ def test_name_type(self):
2728
# special handel of int and float
2829
ca = Argument("key1", float)
2930
ca.check({"key1": 1})
31+
# list[int]
32+
ca = Argument("key1", List[float])
33+
ca.check({"key1": [1, 2.0, 3]})
34+
with self.assertRaises(ArgumentTypeError):
35+
ca.check({"key1": [1, 2.0, "3"]})
3036
# optional case
3137
ca = Argument("key1", int, optional=True)
3238
ca.check({})

tests/test_docgen.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .context import dargs
22
import unittest
33
import json
4+
from typing import List
45
from dargs import Argument, Variant, ArgumentEncoder
56

67

@@ -22,6 +23,11 @@ def test_sub_fields(self):
2223
[Argument("subsubsub1", int, doc="subsubsub doc." * 5)],
2324
doc="subsub doc." * 5,
2425
),
26+
Argument(
27+
"list_of_float",
28+
List[float],
29+
doc="Check if List[float] works.",
30+
),
2531
],
2632
doc="sub doc." * 5,
2733
),

0 commit comments

Comments
 (0)