Skip to content

Commit 6c90998

Browse files
committed
Support injecting annotated types with Inject[]
This was missing in #263.
1 parent 47e9f97 commit 6c90998

File tree

2 files changed

+138
-1
lines changed

2 files changed

+138
-1
lines changed

injector/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,16 @@ def _is_injection_annotation(annotation: Any) -> bool:
12131213
_inject_marker in annotation.__metadata__ or _noinject_marker in annotation.__metadata__
12141214
)
12151215

1216+
def _recreate_annotated_origin(annotated_type: Any) -> Any:
1217+
# Creates `Annotated[type, annotation]` from `Inject[Annotated[type, annotation]]`,
1218+
# to support the injection of annotated types with the `Inject[]` annotation.
1219+
origin = annotated_type.__origin__
1220+
for metadata in annotated_type.__metadata__:
1221+
if metadata in (_inject_marker, _noinject_marker):
1222+
break
1223+
origin = Annotated[origin, metadata]
1224+
return origin
1225+
12161226
spec = inspect.getfullargspec(callable)
12171227

12181228
try:
@@ -1245,7 +1255,7 @@ def _is_injection_annotation(annotation: Any) -> bool:
12451255
for k, v in list(bindings.items()):
12461256
# extract metadata only from Inject and NonInject
12471257
if _is_injection_annotation(v):
1248-
v, metadata = v.__origin__, v.__metadata__
1258+
v, metadata = _recreate_annotated_origin(v), v.__metadata__
12491259
bindings[k] = v
12501260
else:
12511261
metadata = tuple()

injector_test.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1700,6 +1700,67 @@ def function(a: int) -> 'InvalidForwardReference':
17001700
assert get_bindings(function) == {'a': int}
17011701

17021702

1703+
def test_gets_bindings_for_annotated_type_with_inject_decorator() -> None:
1704+
UserID = Annotated[int, 'user_id']
1705+
1706+
@inject
1707+
def function(a: UserID, b: str) -> None:
1708+
pass
1709+
1710+
assert get_bindings(function) == {'a': UserID, 'b': str}
1711+
1712+
1713+
def test_gets_bindings_of_annotated_type_with_inject_annotation() -> None:
1714+
UserID = Annotated[int, 'user_id']
1715+
1716+
def function(a: Inject[UserID], b: Inject[str]) -> None:
1717+
pass
1718+
1719+
assert get_bindings(function) == {'a': UserID, 'b': str}
1720+
1721+
1722+
def test_gets_bindings_of_new_type_with_inject_annotation() -> None:
1723+
Name = NewType('Name', str)
1724+
1725+
@inject
1726+
def function(a: Name, b: str) -> None:
1727+
pass
1728+
1729+
assert get_bindings(function) == {'a': Name, 'b': str}
1730+
1731+
1732+
def test_gets_bindings_of_inject_annotation_with_new_type() -> None:
1733+
def function(a: Inject[Name], b: str) -> None:
1734+
pass
1735+
1736+
assert get_bindings(function) == {'a': Name}
1737+
1738+
1739+
def test_get_bindings_of_nested_noinject_inject_annotation() -> None:
1740+
# This is not how this is intended to be used
1741+
def function(a: Inject[NoInject[int]], b: NoInject[Inject[str]]) -> None:
1742+
pass
1743+
1744+
assert get_bindings(function) == {}
1745+
1746+
1747+
def test_get_bindings_of_nested_noinject_inject_annotation_and_inject_decorator() -> None:
1748+
# This is not how this is intended to be used
1749+
@inject
1750+
def function(a: Inject[NoInject[int]], b: NoInject[Inject[str]]) -> None:
1751+
pass
1752+
1753+
assert get_bindings(function) == {}
1754+
1755+
1756+
def test_get_bindings_of_nested_inject_annotations() -> None:
1757+
# This is not how this is intended to be used
1758+
def function(a: Inject[Inject[int]]) -> None:
1759+
pass
1760+
1761+
assert get_bindings(function) == {'a': int}
1762+
1763+
17031764
# Tests https://github.com/alecthomas/injector/issues/202
17041765
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+")
17051766
def test_get_bindings_for_pep_604():
@@ -1800,6 +1861,56 @@ def configure(binder):
18001861
assert test_class.user_id == 123
18011862

18021863

1864+
def test_inject_annotation_with_annotated_type():
1865+
UserID = Annotated[int, 'user_id']
1866+
1867+
class TestClass:
1868+
def __init__(self, user_id: Inject[UserID]):
1869+
self.user_id = user_id
1870+
1871+
def configure(binder):
1872+
binder.bind(UserID, to=123)
1873+
1874+
injector = Injector([configure])
1875+
1876+
test_class = injector.get(TestClass)
1877+
assert test_class.user_id == 123
1878+
1879+
1880+
def test_inject_annotation_with_nested_annotated_type():
1881+
UserID = Annotated[int, 'user_id']
1882+
SpecialUserID = Annotated[UserID, 'special_user_id']
1883+
1884+
class TestClass:
1885+
def __init__(self, user_id: Inject[SpecialUserID]):
1886+
self.user_id = user_id
1887+
1888+
def configure(binder):
1889+
binder.bind(SpecialUserID, to=123)
1890+
1891+
injector = Injector([configure])
1892+
1893+
test_class = injector.get(TestClass)
1894+
assert test_class.user_id == 123
1895+
1896+
1897+
def test_noinject_annotation_with_annotated_type():
1898+
UserID = Annotated[int, 'user_id']
1899+
1900+
@inject
1901+
class TestClass:
1902+
def __init__(self, user_id: NoInject[UserID] = None):
1903+
self.user_id = user_id
1904+
1905+
def configure(binder):
1906+
binder.bind(UserID, to=123)
1907+
1908+
injector = Injector([configure])
1909+
1910+
test_class = injector.get(TestClass)
1911+
assert test_class.user_id is None
1912+
1913+
18031914
def test_newtype_integration_with_annotated():
18041915
UserID = NewType('UserID', int)
18051916

@@ -1817,6 +1928,22 @@ def configure(binder):
18171928
assert test_class.user_id == 123
18181929

18191930

1931+
def test_newtype_with_injection_annotation():
1932+
UserID = NewType('UserID', int)
1933+
1934+
class TestClass:
1935+
def __init__(self, user_id: Inject[UserID]):
1936+
self.user_id = user_id
1937+
1938+
def configure(binder):
1939+
binder.bind(UserID, to=123)
1940+
1941+
injector = Injector([configure])
1942+
1943+
test_class = injector.get(TestClass)
1944+
assert test_class.user_id == 123
1945+
1946+
18201947
def test_dataclass_annotated_parameter():
18211948
Foo = Annotated[int, object()]
18221949

0 commit comments

Comments
 (0)