|
14 | 14 | from botorch.models.transforms.input import (
|
15 | 15 | AffineInputTransform,
|
16 | 16 | AppendFeatures,
|
| 17 | + BatchBroadcastedInputTransform, |
17 | 18 | ChainedInputTransform,
|
18 | 19 | FilterFeatures,
|
19 | 20 | InputPerturbation,
|
@@ -652,6 +653,132 @@ def test_chained_input_transform(self) -> None:
|
652 | 653 | tf = ChainedInputTransform(stz=tf1, pert=tf2)
|
653 | 654 | self.assertTrue(tf.is_one_to_many)
|
654 | 655 |
|
| 656 | + def test_batch_broadcasted_input_transform(self) -> None: |
| 657 | + ds = (1, 2) |
| 658 | + batch_args = [ |
| 659 | + (torch.Size([2]), {}), |
| 660 | + (torch.Size([3, 2]), {}), |
| 661 | + (torch.Size([2, 3]), {"broadcast_index": 0}), |
| 662 | + (torch.Size([5, 2, 3]), {"broadcast_index": 1}), |
| 663 | + ] |
| 664 | + dtypes = (torch.float, torch.double) |
| 665 | + # set seed to range where this is known to not be flaky |
| 666 | + torch.manual_seed(randint(0, 1000)) |
| 667 | + |
| 668 | + for d, (batch_shape, kwargs), dtype in itertools.product( |
| 669 | + ds, batch_args, dtypes |
| 670 | + ): |
| 671 | + bounds = torch.tensor( |
| 672 | + [[-2.0] * d, [2.0] * d], device=self.device, dtype=dtype |
| 673 | + ) |
| 674 | + # when the batch_shape is (2, 3), the transform list is broadcasted across |
| 675 | + # the first dimension, whereas each individual transform gets broadcasted |
| 676 | + # over the remaining batch dimensions. |
| 677 | + if "broadcast_index" not in kwargs: |
| 678 | + broadcast_index = -3 |
| 679 | + tf_batch_shape = batch_shape[:-1] |
| 680 | + else: |
| 681 | + broadcast_index = kwargs["broadcast_index"] |
| 682 | + # if the broadcast index is negative, we need to adjust the index |
| 683 | + # when indexing into the batch shape tuple |
| 684 | + i = broadcast_index + 2 if broadcast_index < 0 else broadcast_index |
| 685 | + tf_batch_shape = list(batch_shape[:i]) |
| 686 | + tf_batch_shape.extend(list(batch_shape[i + 1 :])) |
| 687 | + tf_batch_shape = torch.Size(tf_batch_shape) |
| 688 | + |
| 689 | + tf1 = Normalize(d=d, bounds=bounds, batch_shape=tf_batch_shape) |
| 690 | + tf2 = InputStandardize(d=d, batch_shape=tf_batch_shape) |
| 691 | + transforms = [tf1, tf2] |
| 692 | + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) |
| 693 | + # make copies for validation below |
| 694 | + transforms_ = [deepcopy(tf_i) for tf_i in transforms] |
| 695 | + self.assertTrue(tf.training) |
| 696 | + # self.assertEqual(sorted(tf.keys()), ["stz_fixed", "stz_learned"]) |
| 697 | + self.assertEqual(tf.transforms[0], tf1) |
| 698 | + self.assertEqual(tf.transforms[1], tf2) |
| 699 | + self.assertFalse(tf.is_one_to_many) |
| 700 | + |
| 701 | + X = torch.rand(*batch_shape, 4, d, device=self.device, dtype=dtype) |
| 702 | + X_tf = tf(X) |
| 703 | + Xs = X.unbind(dim=broadcast_index) |
| 704 | + |
| 705 | + X_tf_ = torch.stack( |
| 706 | + [tf_i_(Xi) for tf_i_, Xi in zip(transforms_, Xs)], dim=broadcast_index |
| 707 | + ) |
| 708 | + self.assertTrue(tf1.training) |
| 709 | + self.assertTrue(tf2.training) |
| 710 | + self.assertTrue(torch.equal(X_tf, X_tf_)) |
| 711 | + X_utf = tf.untransform(X_tf) |
| 712 | + self.assertAllClose(X_utf, X, atol=1e-4, rtol=1e-4) |
| 713 | + |
| 714 | + # test not transformed on eval |
| 715 | + for tf_i in transforms: |
| 716 | + tf_i.transform_on_eval = False |
| 717 | + |
| 718 | + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) |
| 719 | + tf.eval() |
| 720 | + self.assertTrue(torch.equal(tf(X), X)) |
| 721 | + |
| 722 | + # test transformed on eval |
| 723 | + for tf_i in transforms: |
| 724 | + tf_i.transform_on_eval = True |
| 725 | + |
| 726 | + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) |
| 727 | + tf.eval() |
| 728 | + self.assertTrue(torch.equal(tf(X), X_tf)) |
| 729 | + |
| 730 | + # test not transformed on train |
| 731 | + for tf_i in transforms: |
| 732 | + tf_i.transform_on_train = False |
| 733 | + |
| 734 | + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) |
| 735 | + tf.train() |
| 736 | + self.assertTrue(torch.equal(tf(X), X)) |
| 737 | + |
| 738 | + # test __eq__ |
| 739 | + other_tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) |
| 740 | + self.assertTrue(tf.equals(other_tf)) |
| 741 | + # change order |
| 742 | + other_tf = BatchBroadcastedInputTransform( |
| 743 | + transforms=list(reversed(transforms)) |
| 744 | + ) |
| 745 | + self.assertFalse(tf.equals(other_tf)) |
| 746 | + # Identical transforms but different objects. |
| 747 | + other_tf = BatchBroadcastedInputTransform( |
| 748 | + transforms=deepcopy(transforms), **kwargs |
| 749 | + ) |
| 750 | + self.assertTrue(tf.equals(other_tf)) |
| 751 | + |
| 752 | + # test preprocess_transform |
| 753 | + transforms[-1].transform_on_train = False |
| 754 | + transforms[0].transform_on_train = True |
| 755 | + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) |
| 756 | + self.assertTrue( |
| 757 | + torch.equal( |
| 758 | + tf.preprocess_transform(X).unbind(dim=broadcast_index)[0], |
| 759 | + transforms[0].transform(Xs[0]), |
| 760 | + ) |
| 761 | + ) |
| 762 | + |
| 763 | + # test one-to-many |
| 764 | + tf2 = InputPerturbation(perturbation_set=2 * bounds) |
| 765 | + with self.assertRaisesRegex(ValueError, r".*one_to_many.*"): |
| 766 | + tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2], **kwargs) |
| 767 | + |
| 768 | + # these could technically be batched internally, but we're testing the generic |
| 769 | + # batch broadcasted transform list here. Could change test to use AppendFeatures |
| 770 | + tf1 = InputPerturbation(perturbation_set=bounds) |
| 771 | + tf2 = InputPerturbation(perturbation_set=2 * bounds) |
| 772 | + tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2], **kwargs) |
| 773 | + self.assertTrue(tf.is_one_to_many) |
| 774 | + |
| 775 | + with self.assertRaisesRegex( |
| 776 | + ValueError, r"The broadcast index cannot be -2 and -1" |
| 777 | + ): |
| 778 | + tf = BatchBroadcastedInputTransform( |
| 779 | + transforms=[tf1, tf2], broadcast_index=-2 |
| 780 | + ) |
| 781 | + |
655 | 782 | def test_round_transform_init(self) -> None:
|
656 | 783 | # basic init
|
657 | 784 | int_idcs = [0, 4]
|
|
0 commit comments