Skip to content

Commit 3912213

Browse files
committed
Skip functions with undefined types.
1 parent 6caa86e commit 3912213

File tree

3 files changed

+128
-6
lines changed

3 files changed

+128
-6
lines changed

taskiq_dependencies/dependency.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,15 @@ def __eq__(self, rhs: object) -> bool:
156156
if not isinstance(rhs, Dependency):
157157
return False
158158
return self._id == rhs._id
159+
160+
def __repr__(self) -> str:
161+
func_name = str(self.dependency)
162+
if self.dependency is not None and hasattr(self.dependency, "__name__"):
163+
func_name = self.dependency.__name__
164+
return (
165+
f"Dependency({func_name}, "
166+
f"use_cache={self.use_cache}, "
167+
f"kwargs={self.kwargs}, "
168+
f"parent={self.parent}"
169+
")"
170+
)

taskiq_dependencies/graph.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import inspect
2+
import os
23
import sys
4+
import warnings
35
from collections import defaultdict, deque
6+
from pathlib import Path
47
from typing import Any, Callable, Dict, List, Optional, TypeVar, get_type_hints
58

69
from graphlib import TopologicalSorter
@@ -171,19 +174,64 @@ def _build_graph(self) -> None: # noqa: C901
171174
if inspect.isclass(origin):
172175
# If this is a class, we need to get signature of
173176
# an __init__ method.
174-
hints = get_type_hints(origin.__init__)
177+
try:
178+
hints = get_type_hints(origin.__init__)
179+
except NameError:
180+
_, src_lineno = inspect.getsourcelines(dep.dependency)
181+
src_file = Path(inspect.getfile(dep.dependency)).relative_to(
182+
Path.cwd(),
183+
)
184+
warnings.warn(
185+
"Cannot resolve type hints for "
186+
f"a class {dep.dependency.__name__} defined "
187+
f"at {src_file}:{src_lineno}.",
188+
RuntimeWarning,
189+
stacklevel=2,
190+
)
191+
continue
175192
sign = inspect.signature(
176193
origin.__init__,
177194
**signature_kwargs,
178195
)
179196
elif inspect.isfunction(dep.dependency):
180197
# If this is function or an instance of a class, we get it's type hints.
181-
hints = get_type_hints(dep.dependency)
198+
try:
199+
hints = get_type_hints(dep.dependency)
200+
except NameError:
201+
_, src_lineno = inspect.getsourcelines(dep.dependency)
202+
src_file = Path(inspect.getfile(dep.dependency)).relative_to(
203+
Path.cwd(),
204+
)
205+
warnings.warn(
206+
"Cannot resolve type hints for "
207+
f"a function {dep.dependency.__name__} defined "
208+
f"at {src_file}:{src_lineno}.",
209+
RuntimeWarning,
210+
stacklevel=2,
211+
)
212+
continue
182213
sign = inspect.signature(origin, **signature_kwargs) # type: ignore
183214
else:
184-
hints = get_type_hints(
185-
dep.dependency.__call__, # type: ignore
186-
)
215+
try:
216+
hints = get_type_hints(
217+
dep.dependency.__call__, # type: ignore
218+
)
219+
except NameError:
220+
_, src_lineno = inspect.getsourcelines(dep.dependency.__class__)
221+
src_file = Path(
222+
inspect.getfile(dep.dependency.__class__),
223+
).relative_to(
224+
Path.cwd(),
225+
)
226+
cls_name = dep.dependency.__class__.__name__
227+
warnings.warn(
228+
"Cannot resolve type hints for "
229+
f"an object of class {cls_name} defined "
230+
f"at {src_file}:{src_lineno}.",
231+
RuntimeWarning,
232+
stacklevel=2,
233+
)
234+
continue
187235
sign = inspect.signature(origin, **signature_kwargs) # type: ignore
188236

189237
# Now we need to iterate over parameters, to

tests/test_graph.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1+
from collections import UserString
12
import re
23
import uuid
34
from contextlib import asynccontextmanager, contextmanager
4-
from typing import Any, AsyncGenerator, Generator, Generic, Tuple, TypeVar
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
AsyncGenerator,
9+
Generator,
10+
Generic,
11+
Tuple,
12+
TypeVar,
13+
)
514

615
import pytest
716

@@ -891,3 +900,56 @@ def target(info: ParamInfo = Depends(inner_dep, use_cache=False)) -> None:
891900
assert info.name == ""
892901
assert info.definition is None
893902
assert info.graph == graph
903+
904+
905+
def test_skip_type_checking_function() -> None:
906+
"""Test if we can skip type only for type checking for the function."""
907+
if TYPE_CHECKING:
908+
909+
class A:
910+
pass
911+
912+
def target(unknown: "A") -> None:
913+
pass
914+
915+
with pytest.warns(RuntimeWarning, match=r"Cannot resolve.*function target.*"):
916+
graph = DependencyGraph(target=target)
917+
with graph.sync_ctx() as ctx:
918+
assert "unknown" not in ctx.resolve_kwargs()
919+
920+
921+
def test_skip_type_checking_class() -> None:
922+
"""Test if we can skip type only for type checking for the function."""
923+
if TYPE_CHECKING:
924+
925+
class A:
926+
pass
927+
928+
class Target:
929+
def __init__(self, unknown: "A") -> None:
930+
pass
931+
932+
with pytest.warns(RuntimeWarning, match=r"Cannot resolve.*class Target.*"):
933+
graph = DependencyGraph(target=Target)
934+
with graph.sync_ctx() as ctx:
935+
assert "unknown" not in ctx.resolve_kwargs()
936+
937+
938+
def test_skip_type_checking_object() -> None:
939+
"""Test if we can skip type only for type checking for the function."""
940+
if TYPE_CHECKING:
941+
942+
class A:
943+
pass
944+
945+
class Target:
946+
def __call__(self, unknown: "A") -> None:
947+
pass
948+
949+
with pytest.warns(
950+
RuntimeWarning,
951+
match=r"Cannot resolve.*object of class Target.*",
952+
):
953+
graph = DependencyGraph(target=Target())
954+
with graph.sync_ctx() as ctx:
955+
assert "unknown" not in ctx.resolve_kwargs()

0 commit comments

Comments
 (0)