@@ -45,6 +45,7 @@ class HadamardGaussianLikelihood(_GaussianLikelihoodBase):
4545 type of prior.
4646 noise_constraint: Constraint on the noise value.
4747 batch_shape: The batch shape of the learned noise parameter (default: []).
48+ task_feature_index: The index of the task feature in the input data (default: None).
4849 """
4950
5051 def __init__ (
@@ -53,6 +54,7 @@ def __init__(
5354 noise_prior : Optional [Prior ] = None ,
5455 noise_constraint : Optional [Interval ] = None ,
5556 batch_shape : torch .Size = torch .Size (),
57+ task_feature_index : Optional [int ] = None ,
5658 ** kwargs ,
5759 ):
5860 noise_covar = MultitaskHomoskedasticNoise (
@@ -62,6 +64,7 @@ def __init__(
6264 batch_shape = batch_shape ,
6365 )
6466 self .num_tasks = num_tasks
67+ self .task_feature_index = task_feature_index
6568 super ().__init__ (noise_covar = noise_covar , ** kwargs )
6669
6770 @property
@@ -81,17 +84,26 @@ def raw_noise(self, value: torch.Tensor) -> None:
8184 self .noise_covar .initialize (raw_noise = value )
8285
8386 def _shaped_noise_covar (self , base_shape : torch .Size , * params : Any , ** kwargs : Any ) -> LinearOperator :
84- # params contains task indexes , shape (*task_batch_shape, num_data, 1 )
85- if not params or not params [0 ]:
87+ # params contains input data , shape (*task_batch_shape, num_data, d )
88+ if len ( params ) == 0 or len ( params [0 ]) == 0 :
8689 raise ValueError ("Task indices must be provided." )
87- task_idcs = params [0 ][- 1 ]
90+ task_idcs = params [0 ] if torch .is_tensor (params [0 ]) else params [0 ][0 ]
91+ if task_idcs .shape [- 1 ] > 1 :
92+ if self .task_feature_index is None :
93+ raise ValueError ("Task indices must be a single dimension if task_feature_index is not provided." )
94+ # handle case where full inputs are passed in,
95+ # rather than just the task indices
96+ task_idcs = task_idcs [..., self .task_feature_index ].unsqueeze (- 1 )
97+ task_idcs = task_idcs .long ()
8898 _check_task_indices (task_idcs , base_shape )
89-
90- # squeeze to remove the `1` dimension returned by `MultitaskHomoskedasticNoise`
91- noise_base_covar_matrix = self .noise_covar (* params , shape = base_shape , ** kwargs ).squeeze (
92- - 4
93- ) # (*batch_shape, num_tasks, num_data, num_data)
94- all_tasks = torch .arange (self .num_tasks ).unsqueeze (- 1 ) # (num_tasks, 1)
99+ noise_base_covar_matrix = self .noise_covar (task_idcs , * params [1 :], shape = base_shape , ** kwargs )
100+ if self .num_tasks > 1 :
101+ # squeeze to remove the `1` dimension returned by `MultitaskHomoskedasticNoise`
102+ # when there is more than 1 task
103+ noise_base_covar_matrix = noise_base_covar_matrix .squeeze (
104+ - 4
105+ ) # (*batch_shape, num_tasks, num_data, num_data)
106+ all_tasks = torch .arange (self .num_tasks , device = task_idcs .device ).unsqueeze (- 1 ) # (num_tasks, 1)
95107 diag = torch .eq (all_tasks , task_idcs .mT ) # (num_tasks, num_data)
96108 mask = DiagLinearOperator (diag ) # (num_tasks, num_data, num_data)
97109 return (noise_base_covar_matrix @ mask ).sum (dim = - 3 )
0 commit comments