Commit 6a17c25

Moved test and updated import

Signed-off-by: jorisSchaller <[email protected]>
1 parent: 8a14732

File tree

2 files changed: +107 −96 lines changed

tests/linen/linen_test.py

Lines changed: 0 additions & 96 deletions
@@ -40,102 +40,6 @@ def check_eq(xs, ys):
     )


-class PoolTest(parameterized.TestCase):
-  def test_pool_custom_reduce(self):
-    x = jnp.full((1, 3, 3, 1), 2.0)
-    mul_reduce = lambda x, y: x * y
-    y = nn.pooling.pool(x, 1.0, mul_reduce, (2, 2), (1, 1), 'VALID')
-    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4))
-
-  @parameterized.parameters(
-      {'count_include_pad': True}, {'count_include_pad': False}
-  )
-  def test_avg_pool(self, count_include_pad):
-    x = jnp.full((1, 3, 3, 1), 2.0)
-    pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
-    y = pool(x)
-    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0))
-    y_grad = jax.grad(lambda x: pool(x).sum())(x)
-    expected_grad = jnp.array(
-        [
-            [0.25, 0.5, 0.25],
-            [0.5, 1.0, 0.5],
-            [0.25, 0.5, 0.25],
-        ]
-    ).reshape((1, 3, 3, 1))
-    np.testing.assert_allclose(y_grad, expected_grad)
-
-  @parameterized.parameters(
-      {'count_include_pad': True}, {'count_include_pad': False}
-  )
-  def test_avg_pool_no_batch(self, count_include_pad):
-    x = jnp.full((3, 3, 1), 2.0)
-    pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
-    y = pool(x)
-    np.testing.assert_allclose(y, np.full((2, 2, 1), 2.0))
-    y_grad = jax.grad(lambda x: pool(x).sum())(x)
-    expected_grad = jnp.array(
-        [
-            [0.25, 0.5, 0.25],
-            [0.5, 1.0, 0.5],
-            [0.25, 0.5, 0.25],
-        ]
-    ).reshape((3, 3, 1))
-    np.testing.assert_allclose(y_grad, expected_grad)
-
-  def test_max_pool(self):
-    x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
-    pool = lambda x: nn.max_pool(x, (2, 2))
-    expected_y = jnp.array(
-        [
-            [4.0, 5.0],
-            [7.0, 8.0],
-        ]
-    ).reshape((1, 2, 2, 1))
-    y = pool(x)
-    np.testing.assert_allclose(y, expected_y)
-    y_grad = jax.grad(lambda x: pool(x).sum())(x)
-    expected_grad = jnp.array(
-        [
-            [0.0, 0.0, 0.0],
-            [0.0, 1.0, 1.0],
-            [0.0, 1.0, 1.0],
-        ]
-    ).reshape((1, 3, 3, 1))
-    np.testing.assert_allclose(y_grad, expected_grad)
-
-  @parameterized.parameters(
-      {'count_include_pad': True}, {'count_include_pad': False}
-  )
-  def test_avg_pool_padding_same(self, count_include_pad):
-    x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
-    pool = lambda x: nn.avg_pool(
-        x, (2, 2), padding='SAME', count_include_pad=count_include_pad
-    )
-    y = pool(x)
-    if count_include_pad:
-      expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape(
-          (1, 2, 2, 1)
-      )
-    else:
-      expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape(
-          (1, 2, 2, 1)
-      )
-    np.testing.assert_allclose(y, expected_y)
-
-  def test_pooling_variable_batch_dims(self):
-    x = jnp.zeros((1, 8, 32, 32, 3), dtype=jnp.float32)
-    y = nn.max_pool(x, (2, 2), (2, 2))
-
-    assert y.shape == (1, 8, 16, 16, 3)
-
-  def test_pooling_no_batch_dims(self):
-    x = jnp.zeros((32, 32, 3), dtype=jnp.float32)
-    y = nn.max_pool(x, (2, 2), (2, 2))
-
-    assert y.shape == (16, 16, 3)
-
-
 class NormalizationTest(parameterized.TestCase):
   def test_layer_norm_mask(self):
     key = random.key(0)
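
Note: test_pool_custom_reduce exercises the low-level reduction API, pool(inputs, init, reduce_fn, window_shape, strides, padding). With a 3x3 input filled with 2.0, a 2x2 'VALID' window, stride 1, and a multiplicative reduce seeded with 1.0, each output cell is the product of four 2.0s, i.e. 2.0**4 = 16.0. A minimal standalone sketch of that arithmetic, assuming the nn alias in this test module refers to flax.linen:

# Sketch of the expectation in test_pool_custom_reduce; assumes
# `from flax import linen as nn`, matching the removed test above.
import jax.numpy as jnp
import numpy as np
from flax import linen as nn

x = jnp.full((1, 3, 3, 1), 2.0)
# Multiplicative pooling: identity element 1.0, reduce by elementwise product.
y = nn.pooling.pool(x, 1.0, lambda a, b: a * b, (2, 2), (1, 1), 'VALID')
np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4))  # each cell is 16.0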

tests/pooling_test.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+import unittest
+from flax.pooling import pool, avg_pool, max_pool, min_pool
+import numpy as np
+import jax.numpy as jnp
+from absl.testing import absltest, parameterized
+import jax
+
+jax.config.parse_flags_with_absl()
+
+
+class PoolTest(parameterized.TestCase):
+  def test_pool_custom_reduce(self):
+    x = jnp.full((1, 3, 3, 1), 2.0)
+    mul_reduce = lambda x, y: x * y
+    y = pool(x, 1.0, mul_reduce, (2, 2), (1, 1), 'VALID')
+    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4))
+
+  @parameterized.parameters(
+      {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool(self, count_include_pad):
+    x = jnp.full((1, 3, 3, 1), 2.0)
+    pool = lambda x: avg_pool(x, (2, 2), count_include_pad=count_include_pad)
+    y = pool(x)
+    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0))
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+        [
+            [0.25, 0.5, 0.25],
+            [0.5, 1.0, 0.5],
+            [0.25, 0.5, 0.25],
+        ]
+    ).reshape((1, 3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  @parameterized.parameters(
+      {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool_no_batch(self, count_include_pad):
+    x = jnp.full((3, 3, 1), 2.0)
+    pool = lambda x: avg_pool(x, (2, 2), count_include_pad=count_include_pad)
+    y = pool(x)
+    np.testing.assert_allclose(y, np.full((2, 2, 1), 2.0))
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+        [
+            [0.25, 0.5, 0.25],
+            [0.5, 1.0, 0.5],
+            [0.25, 0.5, 0.25],
+        ]
+    ).reshape((3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  def test_max_pool(self):
+    x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
+    pool = lambda x: max_pool(x, (2, 2))
+    expected_y = jnp.array(
+        [
+            [4.0, 5.0],
+            [7.0, 8.0],
+        ]
+    ).reshape((1, 2, 2, 1))
+    y = pool(x)
+    np.testing.assert_allclose(y, expected_y)
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+        [
+            [0.0, 0.0, 0.0],
+            [0.0, 1.0, 1.0],
+            [0.0, 1.0, 1.0],
+        ]
+    ).reshape((1, 3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  @parameterized.parameters(
+      {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool_padding_same(self, count_include_pad):
+    x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
+    pool = lambda x: avg_pool(
+        x, (2, 2), padding='SAME', count_include_pad=count_include_pad
+    )
+    y = pool(x)
+    if count_include_pad:
+      expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape(
+          (1, 2, 2, 1)
+      )
+    else:
+      expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape(
+          (1, 2, 2, 1)
+      )
+    np.testing.assert_allclose(y, expected_y)
+
+  def test_pooling_variable_batch_dims(self):
+    x = jnp.zeros((1, 8, 32, 32, 3), dtype=jnp.float32)
+    y = max_pool(x, (2, 2), (2, 2))
+
+    assert y.shape == (1, 8, 16, 16, 3)
+
+  def test_pooling_no_batch_dims(self):
+    x = jnp.zeros((32, 32, 3), dtype=jnp.float32)
+    y = max_pool(x, (2, 2), (2, 2))
+
+    assert y.shape == (16, 16, 3)
+
+
+if __name__ == '__main__':
+  unittest.main()
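
Note: the expected values in test_avg_pool_padding_same encode the count_include_pad semantics under 'SAME' padding. For the 2x2 input [[1, 2], [3, 4]] with a 2x2 window and stride 1, the bottom-right window overlaps only the value 4.0, so its average is 4/4 = 1.0 when padded zeros count toward the divisor and 4/1 = 4.0 when they do not. A minimal sketch of that check, assuming the new flax.pooling module exposes avg_pool as imported in the diff above:

# Sketch of the count_include_pad behaviour asserted above; assumes
# flax.pooling.avg_pool accepts the same arguments as in the new test.
import jax.numpy as jnp
import numpy as np
from flax.pooling import avg_pool

x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))

# Padded zeros count toward the divisor: every window divides by 4.
y_inc = avg_pool(x, (2, 2), padding='SAME', count_include_pad=True)
# Only real cells count: partially padded windows divide by 2 or 1.
y_exc = avg_pool(x, (2, 2), padding='SAME', count_include_pad=False)

np.testing.assert_allclose(y_inc.reshape(4), [10 / 4, 6 / 4, 7 / 4, 4 / 4])
np.testing.assert_allclose(y_exc.reshape(4), [10 / 4, 6 / 2, 7 / 2, 4 / 1])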
