Commit 3cba0d2

fix diffusers test/models on 910B
1 parent 684ed17 commit 3cba0d2

9 files changed: +81 -53 lines changed

mindnlp/patch/safetensors/common.py

Lines changed: 1 addition & 1 deletion
@@ -254,7 +254,7 @@ def safe_save_file(tensor_dict, filename, metadata=None):
     return safetensors.numpy.save_file(tensor_dict, filename, metadata)
 
 
-def safe_load_file(filename, device):
+def safe_load_file(filename, device = 'cpu'):
     """
     Loads a safetensors file into torch format.

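The change gives safe_load_file a 'cpu' default so call sites that pass only a filename keep working. Below is a minimal, self-contained sketch of the same idea against the public safetensors.numpy API; demo_safe_load_file and the file name are illustrative, not the patched mindnlp function.

    import numpy as np
    from safetensors.numpy import save_file, load_file

    def demo_safe_load_file(filename, device='cpu'):
        # 'device' is kept only for call-site compatibility; numpy tensors are host-only.
        return load_file(filename)

    save_file({"w": np.ones((2, 2), dtype=np.float32)}, "demo.safetensors")
    tensors = demo_safe_load_file("demo.safetensors")  # device argument can now be omitted
    print(tensors["w"].shape)                          # (2, 2)
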
mindtorch/_apis/npu_910a.py

Lines changed: 2 additions & 0 deletions
@@ -1192,6 +1192,8 @@ def avg_pool2d(input, kernel_size, stride, padding=0, ceil_mode=False, count_inc
     return pyboost.avg_pool2d_op(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
 
 def avg_pool3d(input, kernel_size, stride, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None):
+    if divisor_override is None:
+        divisor_override = 0
     return legacy.avg_pool3_d(
         input,
         kernel_size,

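The 910A backend now maps divisor_override=None to 0 before calling the legacy kernel; the assumption here is that the legacy avg_pool3d op expects an integer and treats 0 as "divide by the pooling window size". A tiny sketch of that normalization in plain Python, separate from the backend code:

    def normalize_divisor_override(divisor_override):
        # Assumption: the legacy kernel takes an int and interprets 0 as "use the default divisor".
        return 0 if divisor_override is None else divisor_override

    print(normalize_divisor_override(None))  # 0 -> default divisor (window volume)
    print(normalize_divisor_override(6))     # 6 -> caller-supplied divisor
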
mindtorch/_apis/npu_910b.py

Lines changed: 12 additions & 3 deletions
@@ -1,5 +1,6 @@
 import math
 import numbers
+import warnings
 import mindspore
 import mindtorch
 import numpy as np
@@ -1380,6 +1381,8 @@ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_paddi
     return out
 
 def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+    warnings.warn('conv_transpose3d only supports float16 on MindSpore; mindtorch will autocast and apply nan_to_num to avoid inf/nan. Please check the precision if the result is not good.')
+
     in_channel, out_channel = weight.shape[0], weight.shape[1]
     kernel_size = weight.shape[2:]
     # conv_transpose3d_op = ops.Conv3DTranspose(
@@ -1438,6 +1441,7 @@ def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_paddi
     out = cast(out, input_dtype)
     if bias is not None:
         out = add(out, bias)
+    out = nan_to_num(out, 0., 0., 0.)
     return out
 
 def relu(input):
@@ -1880,6 +1884,8 @@ def dynamic_rnn(x, w, b, seq_length, init_h, init_c):
         'LSTM', 'UNIDIRECTIONAL', 1, False, 1.0, -1.0, 0, True, 'tanh', 0.0, True)
 
 def nan_to_num(input, nan=0.0, posinf=None, neginf=None):
+    if ENABLE_PYBOOST:
+        return pyboost.nan_to_num_impl(input, nan, posinf, neginf)
     return legacy.nan_to_num(input, nan, posinf, neginf)
 
 def round(input, decimals):
@@ -2131,14 +2137,14 @@ def sdpa_manual(query, key, value, attn_mask=None, dropout_p=0.0,
 
 def sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
          is_causal=False, scale=None, enable_gqa=False):
-    if ENABLE_FLASH_ATTENTION:
+    if not ENABLE_FLASH_ATTENTION:
         return sdpa_manual(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
 
     scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
 
     if attn_mask is not None and not is_causal:
         if FLASH_ATTN_MASK_VALID == 1:
-            attn_mask = bitwise_not(attn_mask)
+            attn_mask = attn_mask == 0.0
         else:
             attn_mask = cast(attn_mask, mindspore.bool_)
 
@@ -2500,4 +2506,7 @@ def raw_adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, be
     return legacy.adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, False, False)
 
 def inplace_sub(input, other):
-    return pyboost.inplace_sub_ext_op(input, other)
+    return pyboost.inplace_sub_ext_op(input, other)
+
+def isfinite(input):
+    return pyboost.isfinite_op(input)

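Two of the 910B changes are behavioral: sdpa now falls back to the manual implementation only when flash attention is disabled (the previous condition was inverted), and the attention-mask conversion uses an equality test instead of bitwise_not. A small numpy illustration of why the equality form is more forgiving, assuming a mask in which True/1 marks positions that may be attended; numpy stands in for the backend tensors here.

    import numpy as np

    attn_mask = np.array([[True, True, False]])   # True = this key position may be attended

    print(~attn_mask)                             # [[False False  True]]
    print(attn_mask == 0.0)                       # [[False False  True]] -> same positions flagged

    float_mask = np.array([[1.0, 1.0, 0.0]])      # the same mask stored as floats
    print(float_mask == 0.0)                      # still works; np.bitwise_not(float_mask) would raise TypeError
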
mindtorch/_tensor.py

Lines changed: 4 additions & 2 deletions
@@ -715,7 +715,9 @@ def byte(self):
     return self.to(mindspore.uint8)
 
 # Tensor.broadcast_to
-def broadcast_to(self, shape):
+def broadcast_to(self, *shape):
+    if isinstance(shape[0], (tuple, list)):
+        shape = shape[0]
     return ops.broadcast_to(self, shape)
 
 # Tensor.cauchy_
@@ -1309,7 +1311,7 @@ def is_complex(self):
 
 # Tensor.is_floating_point
 def is_floating_point(self):
-    return isinstance(self.dtype, typing.Float)
+    return isinstance(self.dtype, (typing.Float, typing.BFloat))
 
 # Tensor.is_inference

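broadcast_to now accepts both the unpacked form t.broadcast_to(2, 3) and the sequence form t.broadcast_to((2, 3)), and is_floating_point also recognizes bfloat16 dtypes. A standalone sketch of the shape-resolution logic the diff adds, in plain Python and independent of mindtorch:

    def resolve_shape(*shape):
        # Mirrors the new Tensor.broadcast_to logic: accept either a single
        # tuple/list or an unpacked sequence of ints.
        if isinstance(shape[0], (tuple, list)):
            shape = shape[0]
        return tuple(shape)

    print(resolve_shape(2, 3))     # (2, 3)
    print(resolve_shape((2, 3)))   # (2, 3)
    print(resolve_shape([2, 3]))   # (2, 3)
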
mindtorch/configs.py

Lines changed: 1 addition & 0 deletions
@@ -47,3 +47,4 @@ def parse_flag_from_env(key, default=False):
 ENABLE_PYBOOST = parse_flag_from_env('ENABLE_PYBOOST', True)
 CPU_USE_NUMPY_OP = parse_flag_from_env('CPU_USE_NUMPY', False)
 ENABLE_FLASH_ATTENTION = parse_flag_from_env('ENABLE_FLASH_ATTENTION', False)
+CAPTURE_INF_NAN = parse_flag_from_env('CAPTURE_INF_NAN', False)

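CAPTURE_INF_NAN joins the other boolean feature flags read from the environment. A hedged sketch of how such a flag is typically parsed; mindtorch's own parse_flag_from_env may differ in detail, and the accepted truthy strings below are assumptions:

    import os

    def parse_flag_from_env_sketch(key, default=False):
        # Assumed behavior: missing -> default, otherwise common truthy strings enable the flag.
        value = os.environ.get(key)
        if value is None:
            return default
        return value.strip().lower() in ('1', 'true', 'yes', 'on')

    CAPTURE_INF_NAN = parse_flag_from_env_sketch('CAPTURE_INF_NAN', False)
    print(CAPTURE_INF_NAN)  # False unless e.g. `export CAPTURE_INF_NAN=1` was set
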
mindtorch/executor.py

Lines changed: 16 additions & 1 deletion
@@ -1,5 +1,6 @@
+
 from ._apis import cpu, gpu, meta, numpy, npu_910a, npu_910b, npu_310b, npu_310p
-from .configs import CPU_USE_NUMPY_OP, SOC, ENABLE_DISPATCH, DEVICE_TARGET
+from .configs import CPU_USE_NUMPY_OP, SOC, ENABLE_DISPATCH, DEVICE_TARGET, CAPTURE_INF_NAN
 
 if SOC == 'ascend910':
     npu = npu_910a
@@ -24,6 +25,7 @@
 }
 
 DISPATCH_WHITE_LIST = ['inplace_zero', 'inplace_fill_scalar']
+SKIP_NAN_CHECK = ['empty', 'empty_like']
 
 if ENABLE_DISPATCH:
     def execute(func_name, *args, **kwargs):
@@ -65,4 +67,17 @@ def execute(func_name, *args, **kwargs):
         raise RuntimeError(
             f"No implementation for function: {func_name} on {device_type}."
         )
+        if CAPTURE_INF_NAN:
+            outs = func(*args, **kwargs)
+            if func_name in SKIP_NAN_CHECK:
+                return outs
+
+            isfinite_op = getattr(api_map[device_type], 'isfinite')
+            if isinstance(outs, tuple):
+                for out in outs:
+                    assert isfinite_op(out).asnumpy().all()
+            else:
+                assert isfinite_op(outs).asnumpy().all()
+            return outs
+
         return func(*args, **kwargs)

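With CAPTURE_INF_NAN set, execute() runs the dispatched op and then asserts its outputs are finite, so the first kernel that emits inf/nan fails loudly instead of corrupting results downstream; 'empty'/'empty_like' are skipped, presumably because uninitialized buffers may legitimately hold non-finite bit patterns. A self-contained sketch of the same check with numpy standing in for the backend isfinite kernel (checked_execute is illustrative, not the mindtorch function):

    import numpy as np

    SKIP_NAN_CHECK = ['empty', 'empty_like']  # uninitialized buffers are allowed to be non-finite

    def checked_execute(func_name, func, *args, **kwargs):
        outs = func(*args, **kwargs)
        if func_name in SKIP_NAN_CHECK:
            return outs
        # Normalize single outputs and tuples, then assert every element is finite.
        for out in outs if isinstance(outs, tuple) else (outs,):
            assert np.isfinite(out).all(), f"{func_name} produced inf/nan"
        return outs

    checked_execute('add', np.add, np.ones(3), np.ones(3))            # passes silently
    # checked_execute('divide', np.divide, np.ones(3), np.zeros(3))   # would trip the assertion
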
mindtorch/nn/functional.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def avg_pool1d(input, kernel_size, stride, padding=0, ceil_mode=False, count_inc
 def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None):
     return execute('avg_pool2d', input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
 
-def avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=0):
+def avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None):
     return execute('avg_pool3d', input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
 
 def adaptive_avg_pool1d(input, output_size):

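The Python-level default for divisor_override is now None, matching the avg_pool2d signature, while the npu_910a change above maps None to 0 at the kernel boundary. For reference, a numeric example of what the parameter means for average pooling, using one 2x2x2 window in plain numpy:

    import numpy as np

    window = np.arange(8, dtype=np.float32).reshape(2, 2, 2)  # one pooling window, sum = 28

    print(window.sum() / window.size)  # 3.5 -> divisor_override=None: divide by the window volume
    print(window.sum() / 4)            # 7.0 -> divisor_override=4: divide by the caller's value
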
mindtorch/nn/utils/weight_norm.py

Lines changed: 39 additions & 43 deletions
@@ -14,7 +14,8 @@
 # ============================================================================
 r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""
 from typing import Any, TypeVar
-from ..parameter import Parameter
+from typing_extensions import deprecated
+from ..parameter import Parameter, UninitializedParameter
 from ..modules import Module
 from ... import ops
 
@@ -43,12 +44,6 @@ def _weight_norm(weight_v, weight_g, dim):
 
 
 class WeightNorm:
-
-    r"""
-    The 'WeightNorm' class implements weight normalization for neural network modules. It provides methods to compute normalized weights, apply weight normalization to a cell, wrap a function, and remove
-    weight bias from a cell. The class also contains an initializer for the name and dimension of the weight parameters, as well as a method to compute the weight using the normalized parameters. Additionally, it
-    includes a method to remove the weight bias and a wrapper function for transposing cell_id to cell.
-    """
     name: str
     dim: int
 
@@ -60,63 +55,64 @@ def __init__(self, name: str, dim: int) -> None:
 
     # TODO Make return type more specific
     def compute_weight(self, module: Module) -> Any:
-        g = getattr(module, self.name + '_g')
-        v = getattr(module, self.name + '_v')
-        return Parameter(_weight_norm(v, g, self.dim))
+        g = getattr(module, self.name + "_g")
+        v = getattr(module, self.name + "_v")
+        return _weight_norm(v, g, self.dim)
 
     @staticmethod
-    def apply(module, name: str, dim: int) -> 'WeightNorm':
-        for k, hook in module._forward_pre_hooks.items():
+    @deprecated(
+        "`torch.nn.utils.weight_norm` is deprecated "
+        "in favor of `torch.nn.utils.parametrizations.weight_norm`.",
+        category=FutureWarning,
+    )
+    def apply(module, name: str, dim: int) -> "WeightNorm":
+        for hook in module._forward_pre_hooks.values():
             if isinstance(hook, WeightNorm) and hook.name == name:
-                raise RuntimeError("Cannot register two weight_norm hooks on "
-                                   "the same parameter {}".format(name))
+                raise RuntimeError(
+                    f"Cannot register two weight_norm hooks on the same parameter {name}"
+                )
 
         if dim is None:
             dim = -1
 
         fn = WeightNorm(name, dim)
 
         weight = getattr(module, name)
-        # if isinstance(weight, UninitializedParameter):
-        #     raise ValueError(
-        #         'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
-        #         'Make sure to run the dummy forward before applying weight normalization')
+        if isinstance(weight, UninitializedParameter):
+            raise ValueError(
+                "The module passed to `WeightNorm` can't have uninitialized parameters. "
+                "Make sure to run the dummy forward before applying weight normalization"
+            )
         # remove w from parameter list
        del module._parameters[name]
 
         # add g and v as new parameters and express w as g/||v|| * v
-        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim)))
-        module.register_parameter(name + '_v', Parameter(weight))
+        module.register_parameter(
+            name + "_g", Parameter(norm_except_dim(weight, 2, dim).data)
+        )
+        module.register_parameter(name + "_v", Parameter(weight.data))
         setattr(module, name, fn.compute_weight(module))
 
         # recompute weight before every forward()
         module.register_forward_pre_hook(fn)
 
         return fn
 
-    def wrapper_func(self, cell, func):
-        r"""
-        wrapper_func where used to transpose cell_id to cell
-        """
-        def new_func(_, inputs):
-            nonlocal cell
-            return func(cell, inputs)
-        return new_func
-
     def remove(self, module: Module) -> None:
         weight = self.compute_weight(module)
         delattr(module, self.name)
-        del module._parameters[self.name + '_g']
-        del module._parameters[self.name + '_v']
-        setattr(module, self.name, weight)
+        del module._parameters[self.name + "_g"]
+        del module._parameters[self.name + "_v"]
+        setattr(module, self.name, Parameter(weight.data))
 
     def __call__(self, module: Module, inputs: Any) -> None:
         setattr(module, self.name, self.compute_weight(module))
 
 
-T_module = TypeVar('T_module', bound=Module)
+T_module = TypeVar("T_module", bound=Module)
+
 
-def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
+def weight_norm(module: T_module, name: str = "weight", dim: int = 0) -> T_module:
     r"""Apply weight normalization to a parameter in the given module.
 
     .. math::
@@ -138,7 +134,7 @@ def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_modul
 
     .. warning::
 
-        This function is deprecated. Use :func:`mindtorch.nn.utils.parametrizations.weight_norm`
+        This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm`
         which uses the modern parametrization API. The new ``weight_norm`` is compatible
         with ``state_dict`` generated from old ``weight_norm``.
 
@@ -150,11 +146,11 @@ def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_modul
       https://github.com/pytorch/pytorch/issues/102999
 
     * To remove the weight normalization reparametrization, use
-      :func:`mindtorch.nn.utils.parametrize.remove_parametrizations`.
+      :func:`torch.nn.utils.parametrize.remove_parametrizations`.
 
     * The weight is no longer recomputed once at module forward; instead, it will
       be recomputed on every access. To restore the old behavior, use
-      :func:`mindtorch.nn.utils.parametrize.cached` before invoking the module
+      :func:`torch.nn.utils.parametrize.cached` before invoking the module
       in question.
 
     Args:
@@ -171,16 +167,17 @@ def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_modul
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
-        mindtorch.Size([40, 1])
+        torch.Size([40, 1])
        >>> m.weight_v.size()
-        mindtorch.Size([40, 20])
+        torch.Size([40, 20])
 
     """
     WeightNorm.apply(module, name, dim)
     return module
 
-def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
-    r"""Removes the weight normalization reparameterization from a module.
+
+def remove_weight_norm(module: T_module, name: str = "weight") -> T_module:
+    r"""Remove the weight normalization reparameterization from a module.
 
     Args:
         module (Module): containing module
@@ -196,5 +193,4 @@ def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
             del module._forward_pre_hooks[k]
             return module
 
-    raise ValueError("weight_norm of '{}' not found in {}"
-                     .format(name, module))
+    raise ValueError(f"weight_norm of '{name}' not found in {module}")

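The rewritten WeightNorm follows the upstream PyTorch implementation: the weight is stored as a magnitude weight_g and a direction weight_v and recomputed as w = g * v / ||v|| in a forward pre-hook. A numeric sketch of that decomposition for the Linear(20, 40), dim=0 case from the docstring, in numpy only:

    import numpy as np

    v = np.random.randn(40, 20).astype(np.float32)        # weight_v: direction, same shape as the weight
    g = np.linalg.norm(v, axis=1, keepdims=True)          # weight_g: per-row magnitude, shape (40, 1) for dim=0

    w = g * v / np.linalg.norm(v, axis=1, keepdims=True)  # weight recomputed before each forward
    print(g.shape, w.shape)                               # (40, 1) (40, 20)
    print(np.allclose(w, v))                              # True: g is initialized to ||v||, so w reproduces the original weight
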
mindtorch/ops/creation.py

Lines changed: 5 additions & 2 deletions
@@ -96,6 +96,9 @@ def arange(start=0, end=None, step=1, *, out=None, dtype=None, layout=None, devi
         start, end = 0, int(start)
     if dtype is None:
         dtype = mindtorch.py2dtype[type(start)]
+
+        if dtype == mindtorch.float64:
+            dtype = mindtorch.float32
 
     device = check_device(device)
 
@@ -181,8 +184,8 @@ def empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=Fal
 
 # full
 def full(size, fill_value, *, out=None, dtype=None, layout=None, device=None, requires_grad=False):
-    # if dtype is None:
-    #     dtype = get_default_dtype()
+    if dtype is None:
+        dtype = get_default_dtype()
     device = check_device(device)
     if not isinstance(device, str):
         device = device.type

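arange() now downcasts an inferred float64 dtype to float32, and full() falls back to the global default dtype when none is given; a likely motivation is limited float64 support on the NPU backends, which is an assumption rather than something the diff states. A standalone sketch of the dtype-defaulting logic, with numpy dtypes used purely for illustration:

    import numpy as np

    def pick_arange_dtype(start, dtype=None):
        # Only the inferred default is adjusted; explicitly requested dtypes pass through.
        if dtype is None:
            dtype = np.float64 if isinstance(start, float) else np.int64
            if dtype == np.float64:      # mirror of the new downcast
                dtype = np.float32
        return np.dtype(dtype)

    print(pick_arange_dtype(0.0))               # float32 (inferred float64 is downcast)
    print(pick_arange_dtype(0))                 # int64
    print(pick_arange_dtype(0.0, np.float64))   # float64 (explicit dtype is honored)
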