Skip to content

Commit 8433c0b

Browse files
authored
Merge pull request #2664 from sdaulton/hadamard
Supporting passing full train inputs, rather than only supporting task indices to in `HadamardGaussianLikelihood`
2 parents 7807ab2 + 25b35b4 commit 8433c0b

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

gpytorch/likelihoods/hadamard_gaussian_likelihood.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class HadamardGaussianLikelihood(_GaussianLikelihoodBase):
4545
type of prior.
4646
noise_constraint: Constraint on the noise value.
4747
batch_shape: The batch shape of the learned noise parameter (default: []).
48+
task_feature_index: The index of the task feature in the input data (default: None).
4849
"""
4950

5051
def __init__(
@@ -53,6 +54,7 @@ def __init__(
5354
noise_prior: Optional[Prior] = None,
5455
noise_constraint: Optional[Interval] = None,
5556
batch_shape: torch.Size = torch.Size(),
57+
task_feature_index: Optional[int] = None,
5658
**kwargs,
5759
):
5860
noise_covar = MultitaskHomoskedasticNoise(
@@ -62,6 +64,7 @@ def __init__(
6264
batch_shape=batch_shape,
6365
)
6466
self.num_tasks = num_tasks
67+
self.task_feature_index = task_feature_index
6568
super().__init__(noise_covar=noise_covar, **kwargs)
6669

6770
@property
@@ -81,17 +84,26 @@ def raw_noise(self, value: torch.Tensor) -> None:
8184
self.noise_covar.initialize(raw_noise=value)
8285

8386
def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: Any) -> LinearOperator:
84-
# params contains task indexes, shape (*task_batch_shape, num_data, 1)
85-
if not params or not params[0]:
87+
# params contains input data, shape (*task_batch_shape, num_data, d)
88+
if len(params) == 0 or len(params[0]) == 0:
8689
raise ValueError("Task indices must be provided.")
87-
task_idcs = params[0][-1]
90+
task_idcs = params[0] if torch.is_tensor(params[0]) else params[0][0]
91+
if task_idcs.shape[-1] > 1:
92+
if self.task_feature_index is None:
93+
raise ValueError("Task indices must be a single dimension if task_feature_index is not provided.")
94+
# handle case where full inputs are passed in,
95+
# rather than just the task indices
96+
task_idcs = task_idcs[..., self.task_feature_index].unsqueeze(-1)
97+
task_idcs = task_idcs.long()
8898
_check_task_indices(task_idcs, base_shape)
89-
90-
# squeeze to remove the `1` dimension returned by `MultitaskHomoskedasticNoise`
91-
noise_base_covar_matrix = self.noise_covar(*params, shape=base_shape, **kwargs).squeeze(
92-
-4
93-
) # (*batch_shape, num_tasks, num_data, num_data)
94-
all_tasks = torch.arange(self.num_tasks).unsqueeze(-1) # (num_tasks, 1)
99+
noise_base_covar_matrix = self.noise_covar(task_idcs, *params[1:], shape=base_shape, **kwargs)
100+
if self.num_tasks > 1:
101+
# squeeze to remove the `1` dimension returned by `MultitaskHomoskedasticNoise`
102+
# when there is more than 1 task
103+
noise_base_covar_matrix = noise_base_covar_matrix.squeeze(
104+
-4
105+
) # (*batch_shape, num_tasks, num_data, num_data)
106+
all_tasks = torch.arange(self.num_tasks, device=task_idcs.device).unsqueeze(-1) # (num_tasks, 1)
95107
diag = torch.eq(all_tasks, task_idcs.mT) # (num_tasks, num_data)
96108
mask = DiagLinearOperator(diag) # (num_tasks, num_data, num_data)
97109
return (noise_base_covar_matrix @ mask).sum(dim=-3)

test/likelihoods/test_hadamard_multitask_gaussian_likelihood.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,11 @@ def test_marginal_variance(self):
111111
with self.assertRaises(ValueError):
112112
likelihood(input, [])
113113

114-
with self.assertRaises(ValueError):
115-
likelihood(input, [task_idcs.float()])
116-
117114
with self.assertRaises(ValueError):
118115
likelihood(input, [task_idcs.squeeze(-1)])
116+
117+
# test with task feature and full input
118+
likelihood = HadamardGaussianLikelihood(num_tasks=4, task_feature_index=1)
119+
likelihood.noise = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
120+
X = torch.cat([torch.zeros_like(task_idcs), task_idcs], dim=-1)
121+
self.assertEqual(variance, likelihood(input, [X]).variance)

0 commit comments

Comments
 (0)