Skip to content

Commit 47c01b1

Browse files
nghuiqinfacebook-github-bot
authored andcommitted
Improve _get_path_to_function_decl to handle function wrapper with class (#1116)
Summary: We added file decorator support in #1111 **Problem:** This will fail when the function wrapper with dataclass object **Fix:** Determine if decorators found in function before unwrap. Add two test cases to cover: * comp_f using dataclass in g.py => should return __init__.py * comp_g using decorator in h.py => should return g.py Differential Revision: D82346696
1 parent dca99de commit 47c01b1

File tree

5 files changed

+123
-3
lines changed

5 files changed

+123
-3
lines changed

torchx/specs/finder.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
import abc
10+
import ast
1011
import copy
1112
import importlib
1213
import inspect
@@ -278,6 +279,22 @@ def _get_validation_errors(
278279
linter_errors = validate(path, function_name, validators)
279280
return [linter_error.description for linter_error in linter_errors]
280281

282+
def _get_function_decorators_count(
283+
self, function: Callable[..., Any] # pyre-ignore[2]
284+
) -> int:
285+
"""
286+
Returns the count of decorators for the given function.
287+
"""
288+
try:
289+
source = inspect.getsource(function)
290+
tree = ast.parse(source)
291+
for node in ast.walk(tree):
292+
if isinstance(node, ast.FunctionDef):
293+
return len(node.decorator_list)
294+
except (OSError, TypeError):
295+
return 0
296+
return 0
297+
281298
def _get_path_to_function_decl(
282299
self, function: Callable[..., Any] # pyre-ignore[2]
283300
) -> str:
@@ -287,9 +304,10 @@ def _get_path_to_function_decl(
287304
my_component defined in some_file.py, imported in other_file.py
288305
and the component is invoked as other_file.py:my_component
289306
"""
290-
# Unwrap decorated functions to get the original function
291-
unwrapped_function = inspect.unwrap(function)
292-
path_to_function_decl = inspect.getabsfile(unwrapped_function)
307+
# unwrap the function if it has decorators
308+
if self._get_function_decorators_count(function) > 0:
309+
function = inspect.unwrap(function)
310+
path_to_function_decl = inspect.getabsfile(function)
293311
if path_to_function_decl is None or not os.path.isfile(path_to_function_decl):
294312
return self._filepath
295313
return path_to_function_decl
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import functools
11+
12+
from torchx import specs
13+
14+
from .g import Cls
15+
16+
17+
@functools.wraps(Cls)
18+
def comp_f(**kwargs) -> specs.AppDef: # pyre-ignore[2]
19+
return Cls(**kwargs).build()
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
from dataclasses import dataclass
10+
11+
import torchx
12+
from torchx import specs
13+
14+
from .h import fake_decorator
15+
16+
17+
@dataclass
18+
class Args:
19+
name: str
20+
21+
22+
@dataclass
23+
class Cls(Args):
24+
def build(self) -> specs.AppDef:
25+
return specs.AppDef(
26+
name=self.name,
27+
roles=[
28+
specs.Role(
29+
name=self.name,
30+
image=torchx.IMAGE,
31+
entrypoint="echo",
32+
args=["hello world"],
33+
)
34+
],
35+
)
36+
37+
38+
@fake_decorator
39+
def comp_g() -> specs.AppDef:
40+
return specs.AppDef(
41+
name="g",
42+
roles=[
43+
specs.Role(
44+
name="g",
45+
image=torchx.IMAGE,
46+
entrypoint="echo",
47+
args=["hello world"],
48+
)
49+
],
50+
)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
11+
import functools
12+
from typing import Any, Callable
13+
14+
15+
def fake_decorator(
16+
func: Callable[..., Any],
17+
) -> Callable[..., Any]:
18+
@functools.wraps(func)
19+
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
20+
# Fake decorator: just calls the original function
21+
return func(*args, **kwargs)
22+
23+
return wrapper

torchx/specs/test/finder_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
ModuleComponentsFinder,
3131
)
3232
from torchx.specs.test.components.a import comp_a
33+
from torchx.specs.test.components.f import comp_f
34+
from torchx.specs.test.components.f.g import comp_g
3335
from torchx.util.test.entrypoints_test import EntryPoint_from_text
3436
from torchx.util.types import none_throws
3537

@@ -243,6 +245,14 @@ def test_get_component_imported_from_other_file(self) -> None:
243245
component = get_component(f"{current_file_path()}:comp_a")
244246
self.assertListEqual([], component.validation_errors)
245247

248+
def test_get_component_from_dataclass(self) -> None:
249+
component = get_component(f"{current_file_path()}:comp_f")
250+
self.assertListEqual([], component.validation_errors)
251+
252+
def test_get_component_from_decorator(self) -> None:
253+
component = get_component(f"{current_file_path()}:comp_g")
254+
self.assertListEqual([], component.validation_errors)
255+
246256

247257
class GetBuiltinSourceTest(unittest.TestCase):
248258
def setUp(self) -> None:

0 commit comments

Comments
 (0)