diff --git a/inference/model.py b/inference/model.py index 88684997..5bd7b724 100644 --- a/inference/model.py +++ b/inference/model.py @@ -188,7 +188,7 @@ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtyp else: self.register_parameter("scale", None) if bias: - self.bias = nn.Parameter(torch.empty(out_features)) + self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.float32)) else: self.register_parameter("bias", None)