|
9 | 9 | from pytensor.compile.mode import Mode
|
10 | 10 | from pytensor.configdefaults import config
|
11 | 11 | from pytensor.graph import rewrite_graph
|
12 |
| -from pytensor.graph.basic import Constant, applys_between, equal_computations |
| 12 | +from pytensor.graph.basic import Constant, Variable, applys_between, equal_computations |
13 | 13 | from pytensor.npy_2_compat import old_np_unique
|
14 | 14 | from pytensor.raise_op import Assert
|
15 | 15 | from pytensor.tensor import alloc
|
|
37 | 37 | diff,
|
38 | 38 | fill_diagonal,
|
39 | 39 | fill_diagonal_offset,
|
| 40 | + pack, |
40 | 41 | ravel_multi_index,
|
41 | 42 | repeat,
|
42 | 43 | searchsorted,
|
43 | 44 | squeeze,
|
44 | 45 | to_one_hot,
|
| 46 | + unpack, |
45 | 47 | unravel_index,
|
46 | 48 | )
|
47 | 49 | from pytensor.tensor.type import (
|
@@ -1378,3 +1380,69 @@ def test_concat_with_broadcast():
|
1378 | 1380 | a = pt.tensor("a", shape=(1, 3, 5))
|
1379 | 1381 | b = pt.tensor("b", shape=(3, 5))
|
1380 | 1382 | pt.concat_with_broadcast([a, b], axis=1)
|
| 1383 | + |
| 1384 | + |
@pytest.mark.parametrize(
    "shapes, expected_flat_shape",
    [([(), (5,), (3, 3)], 15), ([(), (None,), (None, None)], None)],
    ids=["static", "symbolic"],
)
def test_pack_all_shapes_known(shapes, expected_flat_shape):
    """Round-trip scalar/vector/matrix tensors through pack -> unpack.

    Checks that pack reports a static flat length only when every input
    shape is fully static, that per-input shapes are recorded literally
    (static case) or as symbolic Variables (symbolic case), and that
    unpack recovers the original values exactly.
    """
    rng = np.random.default_rng()

    x = pt.tensor("x", shape=shapes[0])
    y = pt.tensor("y", shape=shapes[1])
    z = pt.tensor("z", shape=shapes[2])

    has_static_shape = [not any(s is None for s in shape) for shape in shapes]

    flat_packed, packed_shapes = pack(x, y, z)

    # Flat length is the summed element count (0-d counts as 1) when all
    # shapes are static; otherwise it must stay unknown (None).
    assert flat_packed.type.shape[0] == expected_flat_shape

    # Fully static shapes should be stored as-is; any shape with an unknown
    # dim has to come back as a symbolic Variable for use at unpack time.
    for expected_shape, packed_shape, has_static in zip(
        shapes, packed_shapes, has_static_shape
    ):
        if has_static:
            assert packed_shape == expected_shape
        else:
            assert isinstance(packed_shape, Variable)

    new_outputs = unpack(flat_packed, packed_shapes)

    assert len(new_outputs) == 3
    assert all(
        out.type.shape == var.type.shape for out, var in zip(new_outputs, [x, y, z])
    )

    fn = function([x, y, z], new_outputs, mode="FAST_COMPILE")

    # Concrete test values; the symbolic-shape case still evaluates on these.
    input_vals = [rng.normal(size=shape) for shape in [(), (5,), (3, 3)]]
    new_output_vals = fn(*input_vals)
    for input_val, output_val in zip(input_vals, new_output_vals):
        np.testing.assert_allclose(input_val, output_val)
| 1426 | + |
| 1427 | + |
def test_make_replacements_with_pack_unpack():
    """Replace a graph's inputs with the unpacked slices of one flat vector.

    Builds a scalar loss from three tensors, packs the tensors into a single
    flat input, rewrites the loss in terms of that input via graph_replace,
    and checks the compiled function matches the hand-computed value.
    """
    rng = np.random.default_rng()

    x = pt.tensor("x", shape=())
    y = pt.tensor("y", shape=(5,))
    z = pt.tensor("z", shape=(3, 3))

    loss = (x + y.sum() + z.sum()) ** 2

    flat_packed, packed_shapes = pack(x, y, z)
    new_input = flat_packed.type()
    new_outputs = unpack(new_input, packed_shapes)

    # Rewrite the loss so it depends only on the single flat vector.
    loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
    fn = pytensor.function([new_input], loss, mode="FAST_COMPILE")

    input_vals = [rng.normal(size=var.type.shape) for var in [x, y, z]]
    # np.concatenate instead of `np.r_[*list]`: unpacking inside subscripts
    # is only valid syntax on Python >= 3.11.
    flat_inputs = np.concatenate([val.ravel() for val in input_vals])
    output_val = fn(flat_inputs)

    # loss == (sum of every packed element) ** 2
    assert np.allclose(output_val, sum(val.sum() for val in input_vals) ** 2)
0 commit comments