Skip to content

Commit 456bde8

Browse files
authored
(torchx/specs) component fn python-3.10 compatibility. Support BinOp for optional types and builtin container types.
Differential Revision: D81826386 Pull Request resolved: #1110
1 parent 3bb8b40 commit 456bde8

File tree

4 files changed

+262
-126
lines changed

4 files changed

+262
-126
lines changed

docs/source/specs.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,17 @@ Component Linter
9090
.. autoclass:: LinterMessage
9191
:members:
9292

93-
.. autoclass:: TorchFunctionVisitor
93+
.. autoclass:: ComponentFnVisitor
9494
:members:
9595

9696
.. autoclass:: TorchXArgumentHelpFormatter
9797
:members:
9898

99-
.. autoclass:: TorchxFunctionArgsValidator
99+
.. autoclass:: ArgTypeValidator
100100
:members:
101101

102-
.. autoclass:: TorchxFunctionValidator
102+
.. autoclass:: ComponentFunctionValidator
103103
:members:
104104

105-
.. autoclass:: TorchxReturnValidator
105+
.. autoclass:: ReturnTypeValidator
106106
:members:

torchx/specs/file_linter.py

Lines changed: 116 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
import argparse
1212
import ast
1313
import inspect
14+
import sys
1415
from dataclasses import dataclass
15-
from typing import Callable, cast, Dict, List, Optional, Tuple
16+
from typing import Callable, Dict, List, Optional, Tuple
1617

1718
from docstring_parser import parse
1819
from torchx.util.io import read_conf_file
@@ -98,7 +99,7 @@ class LinterMessage:
9899
severity: str = "error"
99100

100101

101-
class TorchxFunctionValidator(abc.ABC):
102+
class ComponentFunctionValidator(abc.ABC):
102103
@abc.abstractmethod
103104
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
104105
"""
@@ -116,7 +117,55 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage:
116117
)
117118

118119

119-
class TorchxFunctionArgsValidator(TorchxFunctionValidator):
120+
def OK() -> list[LinterMessage]:
121+
return [] # empty linter error means validation passes
122+
123+
124+
def is_primitive(arg: ast.expr) -> bool:
125+
# whether the arg is a primitive type (e.g. int, float, str, bool)
126+
return isinstance(arg, ast.Name)
127+
128+
129+
def get_generic_type(arg: ast.expr) -> ast.expr:
130+
# returns the slice expr of a subscripted type
131+
# `arg` must be an instance of ast.Subscript (caller checks)
132+
# in this validator's context, this is the generic type of a container type
133+
# e.g. for Optional[str] returns the expr for str
134+
135+
assert isinstance(arg, ast.Subscript) # e.g. arg = C[T]
136+
137+
if isinstance(arg.slice, ast.Index): # python>=3.10
138+
return arg.slice.value
139+
else: # python-3.9
140+
return arg.slice
141+
142+
143+
def get_optional_type(arg: ast.expr) -> Optional[ast.expr]:
144+
"""
145+
Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if `arg``
146+
is not an ``Optional``. Handles both:
147+
1. ``typing.Optional[T]`` (python<3.10)
148+
2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604)
149+
"""
150+
# case 1: 'a: Optional[T]'
151+
if isinstance(arg, ast.Subscript) and arg.value.id == "Optional":
152+
return get_generic_type(arg)
153+
154+
# case 2: 'a: T | None' or 'a: None | T'
155+
if sys.version_info >= (3, 10): # PEP 604 introduced in python-3.10
156+
if isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.BitOr):
157+
if isinstance(arg.right, ast.Constant) and arg.right.value is None:
158+
return arg.left
159+
if isinstance(arg.left, ast.Constant) and arg.left.value is None:
160+
return arg.right
161+
162+
# case 3: is not optional
163+
return None
164+
165+
166+
class ArgTypeValidator(ComponentFunctionValidator):
167+
"""Validates component function's argument types."""
168+
120169
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
121170
linter_errors = []
122171
for arg_def in app_specs_func_def.args.args:
@@ -133,53 +182,68 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
133182
return linter_errors
134183

135184
def _validate_arg_def(
136-
self, function_name: str, arg_def: ast.arg
185+
self, function_name: str, arg: ast.arg
137186
) -> List[LinterMessage]:
138-
if not arg_def.annotation:
139-
return [
140-
self._gen_linter_message(
141-
f"Arg {arg_def.arg} missing type annotation", arg_def.lineno
142-
)
143-
]
144-
if isinstance(arg_def.annotation, ast.Name):
187+
arg_type = arg.annotation # type hint
188+
189+
def ok() -> list[LinterMessage]:
190+
# return value when validation passes (e.g. no linter errors)
145191
return []
146-
complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
147-
if complex_type_def.value.id == "Optional":
148-
# ast module in python3.9 does not have ast.Index wrapper
149-
if isinstance(complex_type_def.slice, ast.Index):
150-
complex_type_def = complex_type_def.slice.value
151-
else:
152-
complex_type_def = complex_type_def.slice
153-
# Check if type is Optional[primitive_type]
154-
if isinstance(complex_type_def, ast.Name):
155-
return []
156-
# Check if type is Union[Dict,List]
157-
type_name = complex_type_def.value.id
158-
if type_name != "Dict" and type_name != "List":
159-
desc = (
160-
f"`{function_name}` allows only Dict, List as complex types."
161-
f"Argument `{arg_def.arg}` has: {type_name}"
162-
)
163-
return [self._gen_linter_message(desc, arg_def.lineno)]
164-
linter_errors = []
165-
# ast module in python3.9 does not have objects wrapped in ast.Index
166-
if isinstance(complex_type_def.slice, ast.Index):
167-
sub_type = complex_type_def.slice.value
192+
193+
def err(reason: str) -> list[LinterMessage]:
194+
msg = f"{reason} for argument {ast.unparse(arg)!r} in function {function_name!r}"
195+
return [self._gen_linter_message(msg, arg.lineno)]
196+
197+
if not arg_type:
198+
return err("Missing type annotation")
199+
200+
# Case 1: optional
201+
if T := get_optional_type(arg_type):
202+
# NOTE: optional types can be primitives or any of the allowed container types
203+
# so check if arg is an optional, and if so, run the rest of the validation with the unpacked type
204+
arg_type = T
205+
206+
# Case 2: int, float, str, bool
207+
if is_primitive(arg_type):
208+
return ok()
209+
# Case 3: Containers (Dict, List, Tuple)
210+
elif isinstance(arg_type, ast.Subscript):
211+
container_type = arg_type.value.id
212+
213+
if container_type in ["Dict", "dict"]:
214+
KV = get_generic_type(arg_type)
215+
216+
assert isinstance(KV, ast.Tuple) # dict[K,V] has ast.Tuple slice
217+
218+
K, V = KV.elts
219+
if not is_primitive(K):
220+
return err(f"Non-primitive key type {ast.unparse(K)!r}")
221+
if not is_primitive(V):
222+
return err(f"Non-primitive value type {ast.unparse(V)!r}")
223+
return ok()
224+
elif container_type in ["List", "list"]:
225+
T = get_generic_type(arg_type)
226+
if is_primitive(T):
227+
return ok()
228+
else:
229+
return err(f"Non-primitive element type {ast.unparse(T)!r}")
230+
elif container_type in ["Tuple", "tuple"]:
231+
E_N = get_generic_type(arg_type)
232+
assert isinstance(E_N, ast.Tuple) # tuple[...] has ast.Tuple slice
233+
234+
for e in E_N.elts:
235+
if not is_primitive(e):
236+
return err(f"Non-primitive element type '{ast.unparse(e)!r}'")
237+
238+
return ok()
239+
240+
return err(f"Unsupported container type {container_type!r}")
168241
else:
169-
sub_type = complex_type_def.slice
170-
if type_name == "Dict":
171-
sub_type_tuple = cast(ast.Tuple, sub_type)
172-
for el in sub_type_tuple.elts:
173-
if not isinstance(el, ast.Name):
174-
desc = "Dict can only have primitive types"
175-
linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
176-
elif not isinstance(sub_type, ast.Name):
177-
desc = "List can only have primitive types"
178-
linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
179-
return linter_errors
242+
return err(f"Unsupported argument type {ast.unparse(arg_type)!r}")
180243

181244

182-
class TorchxReturnValidator(TorchxFunctionValidator):
245+
class ReturnTypeValidator(ComponentFunctionValidator):
246+
"""Validates that component functions always return AppDef type"""
183247

184248
def __init__(self, supported_return_type: str) -> None:
185249
super().__init__()
@@ -231,7 +295,7 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
231295
return linter_errors
232296

233297

234-
class TorchFunctionVisitor(ast.NodeVisitor):
298+
class ComponentFnVisitor(ast.NodeVisitor):
235299
"""
236300
Visitor that finds the component_function and runs registered validators on it.
237301
Current registered validators:
@@ -252,12 +316,12 @@ class TorchFunctionVisitor(ast.NodeVisitor):
252316
def __init__(
253317
self,
254318
component_function_name: str,
255-
validators: Optional[List[TorchxFunctionValidator]],
319+
validators: Optional[List[ComponentFunctionValidator]],
256320
) -> None:
257321
if validators is None:
258-
self.validators: List[TorchxFunctionValidator] = [
259-
TorchxFunctionArgsValidator(),
260-
TorchxReturnValidator("AppDef"),
322+
self.validators: List[ComponentFunctionValidator] = [
323+
ArgTypeValidator(),
324+
ReturnTypeValidator("AppDef"),
261325
]
262326
else:
263327
self.validators = validators
@@ -279,7 +343,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
279343
def validate(
280344
path: str,
281345
component_function: str,
282-
validators: Optional[List[TorchxFunctionValidator]],
346+
validators: Optional[List[ComponentFunctionValidator]] = None,
283347
) -> List[LinterMessage]:
284348
"""
285349
Validates the function to make sure it complies the component standard.
@@ -309,7 +373,7 @@ def validate(
309373
severity="error",
310374
)
311375
return [linter_message]
312-
visitor = TorchFunctionVisitor(component_function, validators)
376+
visitor = ComponentFnVisitor(component_function, validators)
313377
visitor.visit(module)
314378
linter_errors = visitor.linter_errors
315379
if not visitor.visited_function:

torchx/specs/finder.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
from types import ModuleType
2020
from typing import Any, Callable, Dict, Generator, List, Optional, Union
2121

22-
from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate
22+
from torchx.specs.file_linter import (
23+
ComponentFunctionValidator,
24+
get_fn_docstring,
25+
validate,
26+
)
2327
from torchx.util import entrypoints
2428
from torchx.util.io import read_conf_file
2529
from torchx.util.types import none_throws
@@ -64,7 +68,7 @@ class _Component:
6468
class ComponentsFinder(abc.ABC):
6569
@abc.abstractmethod
6670
def find(
67-
self, validators: Optional[List[TorchxFunctionValidator]]
71+
self, validators: Optional[List[ComponentFunctionValidator]]
6872
) -> List[_Component]:
6973
"""
7074
Retrieves a set of components. A component is defined as a python
@@ -210,7 +214,7 @@ def _iter_modules_recursive(
210214
yield self._try_import(module_info.name)
211215

212216
def find(
213-
self, validators: Optional[List[TorchxFunctionValidator]]
217+
self, validators: Optional[List[ComponentFunctionValidator]]
214218
) -> List[_Component]:
215219
components = []
216220
for m in self._iter_modules_recursive(self.base_module):
@@ -230,7 +234,7 @@ def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
230234
return module
231235

232236
def _get_components_from_module(
233-
self, module: ModuleType, validators: Optional[List[TorchxFunctionValidator]]
237+
self, module: ModuleType, validators: Optional[List[ComponentFunctionValidator]]
234238
) -> List[_Component]:
235239
functions = getmembers(module, isfunction)
236240
component_defs = []
@@ -269,7 +273,7 @@ def _get_validation_errors(
269273
self,
270274
path: str,
271275
function_name: str,
272-
validators: Optional[List[TorchxFunctionValidator]],
276+
validators: Optional[List[ComponentFunctionValidator]],
273277
) -> List[str]:
274278
linter_errors = validate(path, function_name, validators)
275279
return [linter_error.description for linter_error in linter_errors]
@@ -289,7 +293,7 @@ def _get_path_to_function_decl(
289293
return path_to_function_decl
290294

291295
def find(
292-
self, validators: Optional[List[TorchxFunctionValidator]]
296+
self, validators: Optional[List[ComponentFunctionValidator]]
293297
) -> List[_Component]:
294298

295299
file_source = read_conf_file(self._filepath)
@@ -321,7 +325,7 @@ def find(
321325

322326

323327
def _load_custom_components(
324-
validators: Optional[List[TorchxFunctionValidator]],
328+
validators: Optional[List[ComponentFunctionValidator]],
325329
) -> List[_Component]:
326330
component_modules = {
327331
name: load_fn()
@@ -346,7 +350,7 @@ def _load_custom_components(
346350

347351

348352
def _load_components(
349-
validators: Optional[List[TorchxFunctionValidator]],
353+
validators: Optional[List[ComponentFunctionValidator]],
350354
) -> Dict[str, _Component]:
351355
"""
352356
Loads either the custom component defs from the entrypoint ``[torchx.components]``
@@ -368,7 +372,7 @@ def _load_components(
368372

369373

370374
def _find_components(
371-
validators: Optional[List[TorchxFunctionValidator]],
375+
validators: Optional[List[ComponentFunctionValidator]],
372376
) -> Dict[str, _Component]:
373377
global _components
374378
if not _components:
@@ -381,7 +385,7 @@ def _is_custom_component(component_name: str) -> bool:
381385

382386

383387
def _find_custom_components(
384-
name: str, validators: Optional[List[TorchxFunctionValidator]]
388+
name: str, validators: Optional[List[ComponentFunctionValidator]]
385389
) -> Dict[str, _Component]:
386390
if ":" not in name:
387391
raise ValueError(
@@ -393,7 +397,7 @@ def _find_custom_components(
393397

394398

395399
def get_components(
396-
validators: Optional[List[TorchxFunctionValidator]] = None,
400+
validators: Optional[List[ComponentFunctionValidator]] = None,
397401
) -> Dict[str, _Component]:
398402
"""
399403
Returns all custom components registered via ``[torchx.components]`` entrypoints
@@ -448,7 +452,7 @@ def get_components(
448452

449453

450454
def get_component(
451-
name: str, validators: Optional[List[TorchxFunctionValidator]] = None
455+
name: str, validators: Optional[List[ComponentFunctionValidator]] = None
452456
) -> _Component:
453457
"""
454458
Retrieves components by the provided name.
@@ -477,7 +481,7 @@ def get_component(
477481

478482

479483
def get_builtin_source(
480-
name: str, validators: Optional[List[TorchxFunctionValidator]] = None
484+
name: str, validators: Optional[List[ComponentFunctionValidator]] = None
481485
) -> str:
482486
"""
483487
Returns a string of the the builtin component's function source code

0 commit comments

Comments
 (0)