Skip to content

Commit c76e884

Browse files
yiyixuxua-r-r-o-w
andauthored
update get_parameter_dtype (#9526)
* up * Update src/diffusers/models/modeling_utils.py Co-authored-by: Aryan <[email protected]> --------- Co-authored-by: Aryan <[email protected]>
1 parent d9c9691 commit c76e884

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -93,24 +93,20 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
9393

9494
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
9595
try:
96-
params = tuple(parameter.parameters())
97-
if len(params) > 0:
98-
return params[0].dtype
99-
100-
buffers = tuple(parameter.buffers())
101-
if len(buffers) > 0:
102-
return buffers[0].dtype
103-
96+
return next(parameter.parameters()).dtype
10497
except StopIteration:
105-
# For torch.nn.DataParallel compatibility in PyTorch 1.5
106-
107-
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
108-
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
109-
return tuples
110-
111-
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
112-
first_tuple = next(gen)
113-
return first_tuple[1].dtype
98+
try:
99+
return next(parameter.buffers()).dtype
100+
except StopIteration:
101+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
102+
103+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
104+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
105+
return tuples
106+
107+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
108+
first_tuple = next(gen)
109+
return first_tuple[1].dtype
114110

115111

116112
class ModelMixin(torch.nn.Module, PushToHubMixin):

0 commit comments

Comments
 (0)