
Commit e613790

SebastianAment authored and facebook-github-bot committed

InputTransform list broadcasted over batch shapes (#2558)
Summary:
Pull Request resolved: #2558

This commit adds `BatchBroadcastedInputTransform`, which broadcasts a list of input transforms across the first batch dimension of the input X, thereby enabling batch models in cases where only the input transforms are structurally different for each batch.

Reviewed By: Balandat
Differential Revision: D63660807
fbshipit-source-id: 7980f1a03732e83275a38c76693f1591e3c2d891
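In essence, the transform unbinds the input along the broadcast dimension, routes each slice through its own transform, and stacks the results back along the same dimension. A minimal sketch of these mechanics (not the BoTorch implementation itself; the stand-in transforms and shapes are illustrative):

import torch

# X has shape batch x n x d; the broadcast dimension sits at index -3.
X = torch.rand(2, 4, 3)
# Stand-ins for per-batch input transforms (illustrative only).
transforms = [lambda x: 2.0 * x, lambda x: x + 1.0]
# Unbind along the broadcast dimension, transform each slice, re-stack.
Xs = X.unbind(dim=-3)
X_tf = torch.stack([t(Xi) for Xi, t in zip(Xs, transforms)], dim=-3)
assert X_tf.shape == X.shape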
1 parent: 98c1504

2 files changed: +252 −1 lines

botorch/models/transforms/input.py

Lines changed: 125 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
 from warnings import warn
 
 import numpy as np
@@ -155,6 +155,130 @@ def preprocess_transform(self, X: Tensor) -> Tensor:
         return X
 
 
+class BatchBroadcastedInputTransform(InputTransform, ModuleDict):
+    r"""An input transform representing a list of transforms to be broadcasted."""
+
+    def __init__(
+        self,
+        transforms: List[InputTransform],
+        broadcast_index: int = -3,
+    ) -> None:
+        r"""A transform list that is broadcasted across a batch dimension specified
+        by `broadcast_index`. This allows using a batched Gaussian process model
+        when only the input transforms differ between the batches.
+
+        Args:
+            transforms: The transforms to broadcast across the batch dimension
+                specified by `broadcast_index`. The transform at position i in
+                the list will be applied to the i-th slice of the input tensor
+                `X` along that dimension in the forward pass.
+            broadcast_index: The tensor index at which the transforms are broadcasted.
+
+        Example:
+            >>> tf1 = Normalize(d=2)
+            >>> tf2 = InputStandardize(d=2)
+            >>> tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2])
+        """
+        super().__init__()
+        self.transform_on_train = False
+        self.transform_on_eval = False
+        self.transform_on_fantasize = False
+        self.transforms = transforms
+        if broadcast_index in (-2, -1):
+            raise ValueError(
+                "The broadcast index cannot be -2 or -1, as these indices are"
+                " reserved for the non-batch data and input dimensions."
+            )
+        self.broadcast_index = broadcast_index
+        self.is_one_to_many = self.transforms[0].is_one_to_many
+        if not all(tf.is_one_to_many == self.is_one_to_many for tf in self.transforms):
+            raise ValueError(  # output shapes of all transforms must be the same
+                "All transforms must have the same is_one_to_many property."
+            )
+        for tf in self.transforms:
+            self.transform_on_train |= tf.transform_on_train
+            self.transform_on_eval |= tf.transform_on_eval
+            self.transform_on_fantasize |= tf.transform_on_fantasize
+
+    def transform(self, X: Tensor) -> Tensor:
+        r"""Transform the inputs to a model.
+
+        The individual transforms are applied to their associated slices of `X`
+        along the broadcast dimension, and the results are returned as a batched
+        tensor.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of inputs.
+
+        Returns:
+            A `batch_shape x n x d`-dim tensor of transformed inputs.
+        """
+        return torch.stack(
+            [t.forward(Xi) for Xi, t in self._Xs_and_transforms(X)],
+            dim=self.broadcast_index,
+        )
+
+    def untransform(self, X: Tensor) -> Tensor:
+        r"""Un-transform the inputs to a model.
+
+        The un-transforms of the individual transforms are applied to their
+        associated slices of `X` along the broadcast dimension.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of transformed inputs.
+
+        Returns:
+            A `batch_shape x n x d`-dim tensor of un-transformed inputs.
+        """
+        return torch.stack(
+            [t.untransform(Xi) for Xi, t in self._Xs_and_transforms(X)],
+            dim=self.broadcast_index,
+        )
+
+    def equals(self, other: InputTransform) -> bool:
+        r"""Check if another input transform is equivalent.
+
+        Args:
+            other: Another input transform.
+
+        Returns:
+            A boolean indicating if the other transform is equivalent.
+        """
+        return (
+            super().equals(other=other)
+            and all(t1.equals(t2) for t1, t2 in zip(self.transforms, other.transforms))
+            and (self.broadcast_index == other.broadcast_index)
+        )
+
+    def preprocess_transform(self, X: Tensor) -> Tensor:
+        r"""Apply transforms for preprocessing inputs.
+
+        The main use cases for this method are 1) to preprocess training data
+        before calling `set_train_data` and 2) to preprocess `X_baseline` for
+        noisy acquisition functions so that `X_baseline` is "preprocessed" with
+        the same transformations as the cached training inputs.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of inputs.
+
+        Returns:
+            A `batch_shape x n x d`-dim tensor of (transformed) inputs.
+        """
+        return torch.stack(
+            [t.preprocess_transform(Xi) for Xi, t in self._Xs_and_transforms(X)],
+            dim=self.broadcast_index,
+        )
+
+    def _Xs_and_transforms(self, X: Tensor) -> Iterable[Tuple[Tensor, InputTransform]]:
+        r"""Returns an iterable of sub-tensors of X and their associated transforms.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of inputs.
+
+        Returns:
+            An iterable containing tuples of sub-tensors of X and their transforms.
+        """
+        Xs = X.unbind(dim=self.broadcast_index)
+        return zip(Xs, self.transforms)
+
+
 class ChainedInputTransform(InputTransform, ModuleDict):
     r"""An input transform representing the chaining of individual transforms."""
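A hedged usage sketch based on the docstring example above (the feature dimension `d=2` and the tensor shapes are illustrative):

import torch
from botorch.models.transforms.input import (
    BatchBroadcastedInputTransform,
    InputStandardize,
    Normalize,
)

# Batch 0 of X is normalized, batch 1 is standardized.
tf = BatchBroadcastedInputTransform(
    transforms=[Normalize(d=2), InputStandardize(d=2)]
)
X = torch.rand(2, 8, 2)  # 2 batches x 8 points x 2 features
X_tf = tf(X)  # in train mode, transforms[i] is applied to X[i]
X_rec = tf.untransform(X_tf)  # approximately recovers X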

test/models/transforms/test_input.py

Lines changed: 127 additions & 0 deletions
@@ -14,6 +14,7 @@
 from botorch.models.transforms.input import (
     AffineInputTransform,
     AppendFeatures,
+    BatchBroadcastedInputTransform,
     ChainedInputTransform,
     FilterFeatures,
     InputPerturbation,
@@ -652,6 +653,132 @@ def test_chained_input_transform(self) -> None:
         tf = ChainedInputTransform(stz=tf1, pert=tf2)
         self.assertTrue(tf.is_one_to_many)
 
+    def test_batch_broadcasted_input_transform(self) -> None:
+        ds = (1, 2)
+        batch_args = [
+            (torch.Size([2]), {}),
+            (torch.Size([3, 2]), {}),
+            (torch.Size([2, 3]), {"broadcast_index": 0}),
+            (torch.Size([5, 2, 3]), {"broadcast_index": 1}),
+        ]
+        dtypes = (torch.float, torch.double)
+        # set the seed to a range where this is known not to be flaky
+        torch.manual_seed(randint(0, 1000))
+
+        for d, (batch_shape, kwargs), dtype in itertools.product(
+            ds, batch_args, dtypes
+        ):
+            bounds = torch.tensor(
+                [[-2.0] * d, [2.0] * d], device=self.device, dtype=dtype
+            )
+            # when the batch_shape is (2, 3), the transform list is broadcasted
+            # across the first dimension, whereas each individual transform gets
+            # broadcasted over the remaining batch dimensions.
+            if "broadcast_index" not in kwargs:
+                broadcast_index = -3
+                tf_batch_shape = batch_shape[:-1]
+            else:
+                broadcast_index = kwargs["broadcast_index"]
+                # if the broadcast index is negative, we need to adjust the index
+                # when indexing into the batch shape tuple
+                i = broadcast_index + 2 if broadcast_index < 0 else broadcast_index
+                tf_batch_shape = list(batch_shape[:i])
+                tf_batch_shape.extend(list(batch_shape[i + 1 :]))
+                tf_batch_shape = torch.Size(tf_batch_shape)
+
+            tf1 = Normalize(d=d, bounds=bounds, batch_shape=tf_batch_shape)
+            tf2 = InputStandardize(d=d, batch_shape=tf_batch_shape)
+            transforms = [tf1, tf2]
+            tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
+            # make copies for validation below
+            transforms_ = [deepcopy(tf_i) for tf_i in transforms]
+            self.assertTrue(tf.training)
+            self.assertEqual(tf.transforms[0], tf1)
+            self.assertEqual(tf.transforms[1], tf2)
+            self.assertFalse(tf.is_one_to_many)
+
+            X = torch.rand(*batch_shape, 4, d, device=self.device, dtype=dtype)
+            X_tf = tf(X)
+            Xs = X.unbind(dim=broadcast_index)
+
+            X_tf_ = torch.stack(
+                [tf_i_(Xi) for tf_i_, Xi in zip(transforms_, Xs)], dim=broadcast_index
+            )
+            self.assertTrue(tf1.training)
+            self.assertTrue(tf2.training)
+            self.assertTrue(torch.equal(X_tf, X_tf_))
+            X_utf = tf.untransform(X_tf)
+            self.assertAllClose(X_utf, X, atol=1e-4, rtol=1e-4)
+
+            # test not transformed on eval
+            for tf_i in transforms:
+                tf_i.transform_on_eval = False
+
+            tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
+            tf.eval()
+            self.assertTrue(torch.equal(tf(X), X))
+
+            # test transformed on eval
+            for tf_i in transforms:
+                tf_i.transform_on_eval = True
+
+            tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
+            tf.eval()
+            self.assertTrue(torch.equal(tf(X), X_tf))
+
+            # test not transformed on train
+            for tf_i in transforms:
+                tf_i.transform_on_train = False
+
+            tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
+            tf.train()
+            self.assertTrue(torch.equal(tf(X), X))
+
+            # test __eq__
+            other_tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
+            self.assertTrue(tf.equals(other_tf))
+            # change the order
+            other_tf = BatchBroadcastedInputTransform(
+                transforms=list(reversed(transforms))
+            )
+            self.assertFalse(tf.equals(other_tf))
+            # identical transforms, but different objects
+            other_tf = BatchBroadcastedInputTransform(
+                transforms=deepcopy(transforms), **kwargs
+            )
+            self.assertTrue(tf.equals(other_tf))
+
+            # test preprocess_transform
+            transforms[-1].transform_on_train = False
+            transforms[0].transform_on_train = True
+            tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
+            self.assertTrue(
+                torch.equal(
+                    tf.preprocess_transform(X).unbind(dim=broadcast_index)[0],
+                    transforms[0].transform(Xs[0]),
+                )
+            )
+
+            # test one-to-many
+            tf2 = InputPerturbation(perturbation_set=2 * bounds)
+            with self.assertRaisesRegex(ValueError, r".*one_to_many.*"):
+                tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2], **kwargs)
+
+            # these could technically be batched internally, but we are testing the
+            # generic batch-broadcasted transform list here; the test could be
+            # changed to use AppendFeatures instead.
+            tf1 = InputPerturbation(perturbation_set=bounds)
+            tf2 = InputPerturbation(perturbation_set=2 * bounds)
+            tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2], **kwargs)
+            self.assertTrue(tf.is_one_to_many)
+
+            with self.assertRaisesRegex(
+                ValueError, r"The broadcast index cannot be -2 or -1"
+            ):
+                tf = BatchBroadcastedInputTransform(
+                    transforms=[tf1, tf2], broadcast_index=-2
+                )
+
     def test_round_transform_init(self) -> None:
         # basic init
         int_idcs = [0, 4]