
Commit 50ea961

Merge pull request #67 from graphcore-research/conv1d
Conv1d
2 parents: a85d806 + 013e6c9 · commit 50ea961

5 files changed (+227, -2 lines)

unit_scaling/_modules.py

Lines changed: 71 additions & 1 deletion
@@ -8,6 +8,7 @@

 import einops
 import torch
+import torch.nn.functional as F
 from torch import Tensor, nn

 from . import functional as U
@@ -137,7 +138,7 @@ def __init__(
         self.constraint = constraint
         self.weight = Parameter(self.weight.data, mup_type=weight_mup_type)
         if self.bias is not None:
-            self.bias = Parameter(self.bias, mup_type="bias")
+            self.bias = Parameter(self.bias.data, mup_type="bias")

     def reset_parameters(self) -> None:
         nn.init.normal_(self.weight)
@@ -181,6 +182,75 @@ def forward(self, input: Tensor) -> Tensor:
         return U.linear_readout(input, self.weight, self.bias, self.constraint)


+@inherit_docstring(
+    short_description=(
+        "Applies a **unit-scaled** 1D convolution to the incoming data."
+        "\nNote that this layer sets :code:`bias=False` by default."
+        "We also require padding to be supplied as an integer, not a string."
+    ),
+    add_args=[binary_constraint_docstring],
+)
+class Conv1d(nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = False,
+        padding_mode: str = "zeros",
+        device: Any = None,
+        dtype: Any = None,
+        constraint: Optional[str] = "to_output_scale",
+        weight_mup_type: MupType = "weight",
+    ) -> None:
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+            padding_mode,
+            device,
+            dtype,
+        )
+        assert isinstance(padding, int), "only `int` is supported for padding type"
+        self.kernel_size = kernel_size  # type:ignore[assignment]
+        self.stride = stride  # type:ignore[assignment]
+        self.padding = padding  # type:ignore[assignment]
+        self.dilation = dilation  # type:ignore[assignment]
+        self.constraint = constraint
+        self.weight = Parameter(self.weight.data, mup_type=weight_mup_type)
+        if self.bias is not None:
+            self.bias = Parameter(self.bias.data, mup_type="bias")
+
+    def reset_parameters(self) -> None:
+        nn.init.normal_(self.weight)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.padding_mode != "zeros":
+            input = F.pad(
+                input, self._reversed_padding_repeated_twice, mode=self.padding_mode
+            )
+        return U.conv1d(
+            input,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+
+
 @inherit_docstring(
     short_description=(
         "Applies a **unit-scaled** Layer Normalization over a mini-batch of inputs."

unit_scaling/functional.py

Lines changed: 43 additions & 0 deletions
@@ -262,6 +262,49 @@ def linear_readout(
     )


+@docstring_from(
+    F.conv1d,
+    short_description="Applies a **unit-scaled** 1D convolution.",
+    add_args=[
+        binary_constraint_docstring,
+        "scale_power ((float, float, float), optional): scaling power"
+        " for each of (output, grad(input), grad(weight|bias))",
+    ],
+)
+def conv1d(
+    input: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    constraint: Optional[str] = "to_output_scale",
+    scale_power: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+) -> Tensor:
+    fan_out, fan_in, kernel_size = weight.shape
+    seq_len = input.shape[-1]
+    out_size = (seq_len + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
+    batch_size = out_size
+    if len(input.shape) > 2:
+        batch_size *= input.shape[:-2].numel()
+
+    output_scale = 1 / (fan_in * kernel_size) ** scale_power[0]
+    grad_input_scale = (stride * groups / (fan_out * kernel_size)) ** scale_power[1]
+    grad_weight_scale = grad_bias_scale = 1 / batch_size ** scale_power[2]
+
+    output_scale, grad_input_scale = apply_constraint(
+        constraint, output_scale, grad_input_scale
+    )
+
+    input = scale_bwd(input, grad_input_scale)
+    weight = scale_bwd(weight, grad_weight_scale)
+    bias = scale_bwd(bias, grad_bias_scale) if bias is not None else None
+    output = F.conv1d(input, weight, bias, stride, padding, dilation, groups)
+    assert out_size == output.shape[-1]
+    return scale_fwd(output, output_scale)
+
+
 @docstring_from(
     F.layer_norm,
     short_description=(
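As a sanity check on these scaling rules, the arithmetic below plugs in the shapes used by the new tests, under the default scale_power = (0.5, 0.5, 0.5). It is illustrative Python only, not part of the commit:

# Shapes from the new tests: stride = dilation = groups = 1, padding = 0.
batch_size, d_in, d_out = 64, 192, 320
kernel_size, seq_len = 11, 448

out_size = (seq_len + 2 * 0 - 1 * (kernel_size - 1) - 1) // 1 + 1  # 438
output_scale = 1 / (d_in * kernel_size) ** 0.5                     # ~0.0218, i.e. 1/sqrt(fan_in * kernel_size)
grad_input_scale = (1 * 1 / (d_out * kernel_size)) ** 0.5          # ~0.0169
grad_weight_scale = 1 / (batch_size * out_size) ** 0.5             # ~0.0060; the "batch" counts sequence positions too
print(out_size, output_scale, grad_input_scale, grad_weight_scale)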

unit_scaling/optim.py

Lines changed: 3 additions & 1 deletion
@@ -36,8 +36,10 @@ def _get_fan_in(param: ParameterData) -> int:
         return param.shape[0]
     if len(param.shape) == 2:
         return param.shape[1]
+    if len(param.shape) == 3:
+        return param.shape[1] * param.shape[2]
     raise ValueError(
-        f"Cannot get fan_in of `ndim >= 3` param, shape={tuple(param.shape)}"
+        f"Cannot get fan_in of `ndim >= 4` param, shape={tuple(param.shape)}"
     )
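The new three-dimensional branch matches the layout of a Conv1d weight, (out_channels, in_channels // groups, kernel_size). A small illustrative sketch using the test shapes, not library code:

import torch

# A Conv1d weight as used in the tests: (fan_out, fan_in, kernel_size) = (320, 192, 11).
weight = torch.empty(320, 192, 11)
fan_in = weight.shape[1] * weight.shape[2]  # what the new ndim == 3 branch returns
print(fan_in)  # 2112 = in_channels * kernel_size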

unit_scaling/tests/test_functional.py

Lines changed: 87 additions & 0 deletions
@@ -6,6 +6,7 @@

 from ..functional import (
     add,
+    conv1d,
     cross_entropy,
     dropout,
     embedding,
@@ -277,6 +278,92 @@ def test_linear_readout() -> None:
     assert_scale(output, target=2**-5)  # 1/sqrt(fan_in)


+# --- test conv1d() ---
+
+
+def test_conv1d() -> None:
+    batch_size = 2**6
+    d_in = 2**6 * 3
+    d_out = 2**6 * 5
+    kernel_size = 11
+    seq_len = 2**6 * 7
+    input = randn(batch_size, d_in, seq_len, requires_grad=True)
+    weight = randn(d_out, d_in, kernel_size, requires_grad=True)
+    bias = zeros(d_out).requires_grad_()
+    output = conv1d(input, weight, bias, constraint=None)
+    unit_backward(output)
+
+    assert_unit_scaled(output, input.grad, weight.grad, bias.grad)
+
+
+def test_conv1d_stride() -> None:
+    batch_size = 2**6
+    d_in = 2**6 * 3
+    d_out = 2**6 * 5
+    kernel_size = 11
+    seq_len = 2**6 * 7
+    stride = 3
+
+    input = randn(batch_size, d_in, seq_len, requires_grad=True)
+    weight = randn(d_out, d_in, kernel_size, requires_grad=True)
+    bias = zeros(d_out).requires_grad_()
+    output = conv1d(input, weight, bias, stride=stride, constraint=None)
+    unit_backward(output)
+
+    assert_unit_scaled(output, input.grad, weight.grad, bias.grad)
+
+
+def test_conv1d_padding() -> None:
+    batch_size = 2**6
+    d_in = 2**6 * 3
+    d_out = 2**6 * 5
+    kernel_size = 11
+    seq_len = 2**6 * 7
+    padding = 23  # If this is large enough wrt seq_len, test fails
+
+    input = randn(batch_size, d_in, seq_len, requires_grad=True)
+    weight = randn(d_out, d_in, kernel_size, requires_grad=True)
+    bias = zeros(d_out).requires_grad_()
+    output = conv1d(input, weight, bias, padding=padding, constraint=None)
+    unit_backward(output)
+
+    assert_unit_scaled(output, input.grad, weight.grad, bias.grad)
+
+
+def test_conv1d_dilation() -> None:
+    batch_size = 2**6
+    d_in = 2**6 * 3
+    d_out = 2**6 * 5
+    kernel_size = 11
+    seq_len = 2**6 * 7
+    dilation = 8
+
+    input = randn(batch_size, d_in, seq_len, requires_grad=True)
+    weight = randn(d_out, d_in, kernel_size, requires_grad=True)
+    bias = zeros(d_out).requires_grad_()
+    output = conv1d(input, weight, bias, dilation=dilation, constraint=None)
+    unit_backward(output)
+
+    assert_unit_scaled(output, input.grad, weight.grad, bias.grad)
+
+
+def test_conv1d_groups() -> None:
+    batch_size = 2**6
+    d_in = 2**6 * 3
+    d_out = 2**6 * 5
+    kernel_size = 11
+    seq_len = 2**6 * 7
+    groups = 32
+
+    input = randn(batch_size, d_in, seq_len, requires_grad=True)
+    weight = randn(d_out, d_in // groups, kernel_size, requires_grad=True)
+    bias = zeros(d_out).requires_grad_()
+    output = conv1d(input, weight, bias, groups=groups, constraint=None)
+    unit_backward(output)
+
+    assert_unit_scaled(output, input.grad, weight.grad, bias.grad)
+
+
 # --- test layer_norm() ---

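The comment in test_conv1d_padding hints at why very large padding would break the unit-scale assertion: outputs whose receptive field overlaps the zero padding see reduced input variance. A rough, illustrative estimate with the test's numbers (stride = dilation = 1):

seq_len, kernel_size, padding = 448, 11, 23
out_size = seq_len + 2 * padding - (kernel_size - 1)  # 484
edge_outputs = 2 * padding                            # output positions whose window touches the zero padding
print(edge_outputs / out_size)                        # ~0.095: a small fraction, so the test still passes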

unit_scaling/tests/test_modules.py

Lines changed: 23 additions & 0 deletions
@@ -8,6 +8,7 @@
     GELU,
     MHSA,
     MLP,
+    Conv1d,
     CrossEntropyLoss,
     DepthModuleList,
     DepthSequential,
@@ -87,6 +88,28 @@ def test_linear() -> None:
     assert_non_zeros(model.bias)


+def test_conv1d() -> None:
+    batch_size = 2**6
+    d_in = 2**6 * 3
+    d_out = 2**6 * 5
+    kernel_size = 11
+    seq_len = 2**6 * 7
+    input = randn(batch_size, d_in, seq_len, requires_grad=True)
+    model = Conv1d(d_in, d_out, kernel_size, bias=True)
+    output = model(input)
+
+    assert_unit_scaled(model.weight)
+    assert_zeros(model.bias)
+
+    unit_backward(output)
+    SGD(model.parameters(), lr=1).step()
+
+    assert float(output.std()) == pytest.approx(1, abs=0.1)
+
+    assert_not_unit_scaled(model.weight)
+    assert_non_zeros(model.bias)
+
+
 def test_linear_readout() -> None:
     input = randn(2**8, 2**10, requires_grad=True)
     model = LinearReadout(2**10, 2**12)

0 commit comments
