Skip to content

Commit 87ffe03

Browse files
committed
typing: some small typing improvements to API/context modules
1 parent 192a1f4 commit 87ffe03

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

pyinfra/api/inventory.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class Inventory:
3636
"""
3737

3838
state: "State"
39+
groups: dict[str, list[Host]]
3940

4041
@staticmethod
4142
def empty():
@@ -181,7 +182,7 @@ def len_activated_hosts(self) -> int:
181182
"""
182183
return len(self.state.activated_hosts)
183184

184-
def get_host(self, name: str, default=NoHostError):
185+
def get_host(self, name: str, default=NoHostError) -> Host:
185186
"""
186187
Get a single host by name.
187188
"""
@@ -192,9 +193,10 @@ def get_host(self, name: str, default=NoHostError):
192193
if default is NoHostError:
193194
raise NoHostError("No such host: {0}".format(name))
194195

196+
# TODO: remove default here?
195197
return default
196198

197-
def get_group(self, name: str, default=NoGroupError):
199+
def get_group(self, name: str, default=NoGroupError) -> list[Host]:
198200
"""
199201
Get a list of hosts belonging to a group.
200202
"""
@@ -205,6 +207,7 @@ def get_group(self, name: str, default=NoGroupError):
205207
if default is NoGroupError:
206208
raise NoGroupError("No such group: {0}".format(name))
207209

210+
# TODO: remove default here?
208211
return default
209212

210213
def get_data(self):

pyinfra/context.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,22 @@ def __setattr__(self, key, value):
5959
if key in ("_container", "_base_cls"):
6060
return super().__setattr__(key, value)
6161

62-
if self._get_module() is None:
62+
mod = self._get_module()
63+
if mod is None:
6364
raise TypeError("Cannot assign to context base module")
64-
65-
return setattr(self._get_module(), key, value)
65+
return setattr(mod, key, value)
6666

6767
def __iter__(self):
68-
return iter(self._get_module())
68+
mod = self._get_module()
69+
if mod is None:
70+
raise ValueError("Context not set")
71+
return iter(mod)
6972

7073
def __len__(self):
71-
return len(self._get_module())
74+
mod = self._get_module()
75+
if mod is None:
76+
raise ValueError("Context not set")
77+
return len(mod)
7278

7379
@override
7480
def __eq__(self, other):
@@ -105,6 +111,9 @@ def isset(self):
105111
@contextmanager
106112
def use(self, module):
107113
old_module = self.get()
114+
if old_module is module:
115+
yield # if we're double-setting, nothing to do
116+
return
108117
self.set(module)
109118
yield
110119
self.set(old_module)

0 commit comments

Comments
 (0)