11
11
import argparse
12
12
import ast
13
13
import inspect
14
+ import sys
14
15
from dataclasses import dataclass
15
- from typing import Callable , cast , Dict , List , Optional , Tuple
16
+ from typing import Callable , Dict , List , Optional , Tuple
16
17
17
18
from docstring_parser import parse
18
19
from torchx .util .io import read_conf_file
@@ -98,7 +99,7 @@ class LinterMessage:
98
99
severity : str = "error"
99
100
100
101
101
- class TorchxFunctionValidator (abc .ABC ):
102
+ class ComponentFunctionValidator (abc .ABC ):
102
103
@abc .abstractmethod
103
104
def validate (self , app_specs_func_def : ast .FunctionDef ) -> List [LinterMessage ]:
104
105
"""
@@ -116,7 +117,55 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage:
116
117
)
117
118
118
119
119
- class TorchxFunctionArgsValidator (TorchxFunctionValidator ):
120
+ def OK () -> list [LinterMessage ]:
121
+ return [] # empty linter error means validation passes
122
+
123
+
124
+ def is_primitive (arg : ast .expr ) -> bool :
125
+ # whether the arg is a primitive type (e.g. int, float, str, bool)
126
+ return isinstance (arg , ast .Name )
127
+
128
+
129
+ def get_generic_type (arg : ast .expr ) -> ast .expr :
130
+ # returns the slice expr of a subscripted type
131
+ # `arg` must be an instance of ast.Subscript (caller checks)
132
+ # in this validator's context, this is the generic type of a container type
133
+ # e.g. for Optional[str] returns the expr for str
134
+
135
+ assert isinstance (arg , ast .Subscript ) # e.g. arg = C[T]
136
+
137
+ if isinstance (arg .slice , ast .Index ): # python>=3.10
138
+ return arg .slice .value
139
+ else : # python-3.9
140
+ return arg .slice
141
+
142
+
143
+ def get_optional_type (arg : ast .expr ) -> Optional [ast .expr ]:
144
+ """
145
+ Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if `arg``
146
+ is not an ``Optional``. Handles both:
147
+ 1. ``typing.Optional[T]`` (python<3.10)
148
+ 2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604)
149
+ """
150
+ # case 1: 'a: Optional[T]'
151
+ if isinstance (arg , ast .Subscript ) and arg .value .id == "Optional" :
152
+ return get_generic_type (arg )
153
+
154
+ # case 2: 'a: T | None' or 'a: None | T'
155
+ if sys .version_info >= (3 , 10 ): # PEP 604 introduced in python-3.10
156
+ if isinstance (arg , ast .BinOp ) and isinstance (arg .op , ast .BitOr ):
157
+ if isinstance (arg .right , ast .Constant ) and arg .right .value is None :
158
+ return arg .left
159
+ if isinstance (arg .left , ast .Constant ) and arg .left .value is None :
160
+ return arg .right
161
+
162
+ # case 3: is not optional
163
+ return None
164
+
165
+
166
+ class ArgTypeValidator (ComponentFunctionValidator ):
167
+ """Validates component function's argument types."""
168
+
120
169
def validate (self , app_specs_func_def : ast .FunctionDef ) -> List [LinterMessage ]:
121
170
linter_errors = []
122
171
for arg_def in app_specs_func_def .args .args :
@@ -133,53 +182,68 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
133
182
return linter_errors
134
183
135
184
def _validate_arg_def (
136
- self , function_name : str , arg_def : ast .arg
185
+ self , function_name : str , arg : ast .arg
137
186
) -> List [LinterMessage ]:
138
- if not arg_def .annotation :
139
- return [
140
- self ._gen_linter_message (
141
- f"Arg { arg_def .arg } missing type annotation" , arg_def .lineno
142
- )
143
- ]
144
- if isinstance (arg_def .annotation , ast .Name ):
187
+ arg_type = arg .annotation # type hint
188
+
189
+ def ok () -> list [LinterMessage ]:
190
+ # return value when validation passes (e.g. no linter errors)
145
191
return []
146
- complex_type_def = cast (ast .Subscript , none_throws (arg_def .annotation ))
147
- if complex_type_def .value .id == "Optional" :
148
- # ast module in python3.9 does not have ast.Index wrapper
149
- if isinstance (complex_type_def .slice , ast .Index ):
150
- complex_type_def = complex_type_def .slice .value
151
- else :
152
- complex_type_def = complex_type_def .slice
153
- # Check if type is Optional[primitive_type]
154
- if isinstance (complex_type_def , ast .Name ):
155
- return []
156
- # Check if type is Union[Dict,List]
157
- type_name = complex_type_def .value .id
158
- if type_name != "Dict" and type_name != "List" :
159
- desc = (
160
- f"`{ function_name } ` allows only Dict, List as complex types."
161
- f"Argument `{ arg_def .arg } ` has: { type_name } "
162
- )
163
- return [self ._gen_linter_message (desc , arg_def .lineno )]
164
- linter_errors = []
165
- # ast module in python3.9 does not have objects wrapped in ast.Index
166
- if isinstance (complex_type_def .slice , ast .Index ):
167
- sub_type = complex_type_def .slice .value
192
+
193
+ def err (reason : str ) -> list [LinterMessage ]:
194
+ msg = f"{ reason } for argument { ast .unparse (arg )!r} in function { function_name !r} "
195
+ return [self ._gen_linter_message (msg , arg .lineno )]
196
+
197
+ if not arg_type :
198
+ return err ("Missing type annotation" )
199
+
200
+ # Case 1: optional
201
+ if T := get_optional_type (arg_type ):
202
+ # NOTE: optional types can be primitives or any of the allowed container types
203
+ # so check if arg is an optional, and if so, run the rest of the validation with the unpacked type
204
+ arg_type = T
205
+
206
+ # Case 2: int, float, str, bool
207
+ if is_primitive (arg_type ):
208
+ return ok ()
209
+ # Case 3: Containers (Dict, List, Tuple)
210
+ elif isinstance (arg_type , ast .Subscript ):
211
+ container_type = arg_type .value .id
212
+
213
+ if container_type in ["Dict" , "dict" ]:
214
+ KV = get_generic_type (arg_type )
215
+
216
+ assert isinstance (KV , ast .Tuple ) # dict[K,V] has ast.Tuple slice
217
+
218
+ K , V = KV .elts
219
+ if not is_primitive (K ):
220
+ return err (f"Non-primitive key type { ast .unparse (K )!r} " )
221
+ if not is_primitive (V ):
222
+ return err (f"Non-primitive value type { ast .unparse (V )!r} " )
223
+ return ok ()
224
+ elif container_type in ["List" , "list" ]:
225
+ T = get_generic_type (arg_type )
226
+ if is_primitive (T ):
227
+ return ok ()
228
+ else :
229
+ return err (f"Non-primitive element type { ast .unparse (T )!r} " )
230
+ elif container_type in ["Tuple" , "tuple" ]:
231
+ E_N = get_generic_type (arg_type )
232
+ assert isinstance (E_N , ast .Tuple ) # tuple[...] has ast.Tuple slice
233
+
234
+ for e in E_N .elts :
235
+ if not is_primitive (e ):
236
+ return err (f"Non-primitive element type '{ ast .unparse (e )!r} '" )
237
+
238
+ return ok ()
239
+
240
+ return err (f"Unsupported container type { container_type !r} " )
168
241
else :
169
- sub_type = complex_type_def .slice
170
- if type_name == "Dict" :
171
- sub_type_tuple = cast (ast .Tuple , sub_type )
172
- for el in sub_type_tuple .elts :
173
- if not isinstance (el , ast .Name ):
174
- desc = "Dict can only have primitive types"
175
- linter_errors .append (self ._gen_linter_message (desc , arg_def .lineno ))
176
- elif not isinstance (sub_type , ast .Name ):
177
- desc = "List can only have primitive types"
178
- linter_errors .append (self ._gen_linter_message (desc , arg_def .lineno ))
179
- return linter_errors
242
+ return err (f"Unsupported argument type { ast .unparse (arg_type )!r} " )
180
243
181
244
182
- class TorchxReturnValidator (TorchxFunctionValidator ):
245
+ class ReturnTypeValidator (ComponentFunctionValidator ):
246
+ """Validates that component functions always return AppDef type"""
183
247
184
248
def __init__ (self , supported_return_type : str ) -> None :
185
249
super ().__init__ ()
@@ -231,7 +295,7 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
231
295
return linter_errors
232
296
233
297
234
- class TorchFunctionVisitor (ast .NodeVisitor ):
298
+ class ComponentFnVisitor (ast .NodeVisitor ):
235
299
"""
236
300
Visitor that finds the component_function and runs registered validators on it.
237
301
Current registered validators:
@@ -252,12 +316,12 @@ class TorchFunctionVisitor(ast.NodeVisitor):
252
316
def __init__ (
253
317
self ,
254
318
component_function_name : str ,
255
- validators : Optional [List [TorchxFunctionValidator ]],
319
+ validators : Optional [List [ComponentFunctionValidator ]],
256
320
) -> None :
257
321
if validators is None :
258
- self .validators : List [TorchxFunctionValidator ] = [
259
- TorchxFunctionArgsValidator (),
260
- TorchxReturnValidator ("AppDef" ),
322
+ self .validators : List [ComponentFunctionValidator ] = [
323
+ ArgTypeValidator (),
324
+ ReturnTypeValidator ("AppDef" ),
261
325
]
262
326
else :
263
327
self .validators = validators
@@ -279,7 +343,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
279
343
def validate (
280
344
path : str ,
281
345
component_function : str ,
282
- validators : Optional [List [TorchxFunctionValidator ]] ,
346
+ validators : Optional [List [ComponentFunctionValidator ]] = None ,
283
347
) -> List [LinterMessage ]:
284
348
"""
285
349
Validates the function to make sure it complies the component standard.
@@ -309,7 +373,7 @@ def validate(
309
373
severity = "error" ,
310
374
)
311
375
return [linter_message ]
312
- visitor = TorchFunctionVisitor (component_function , validators )
376
+ visitor = ComponentFnVisitor (component_function , validators )
313
377
visitor .visit (module )
314
378
linter_errors = visitor .linter_errors
315
379
if not visitor .visited_function :
0 commit comments