From 00a7522a03641925a7dbffa1d721251ea90078cf Mon Sep 17 00:00:00 2001
From: Roman Garnett
Date: Mon, 19 Sep 2022 18:23:53 -0500
Subject: [PATCH 1/2] slight improvements to IndexKernel for the rank = 0
 (diagonal) case

---
 gpytorch/kernels/index_kernel.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/gpytorch/kernels/index_kernel.py b/gpytorch/kernels/index_kernel.py
index 7fa5e01f3..85c14461f 100644
--- a/gpytorch/kernels/index_kernel.py
+++ b/gpytorch/kernels/index_kernel.py
@@ -65,9 +65,11 @@ def __init__(
         if var_constraint is None:
             var_constraint = Positive()
 
-        self.register_parameter(
-            name="covar_factor", parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks, rank))
-        )
+        self.rank = rank
+        if self.rank > 0:
+            self.register_parameter(
+                name="covar_factor", parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks, self.rank))
+            )
         self.register_parameter(name="raw_var", parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks)))
         if prior is not None:
             if not isinstance(prior, Prior):
@@ -94,7 +96,10 @@ def _eval_covar_matrix(self):
     @property
     def covar_matrix(self):
         var = self.var
-        res = PsdSumLinearOperator(RootLinearOperator(self.covar_factor), DiagLinearOperator(var))
+        if self.rank > 0:
+            res = PsdSumLinearOperator(RootLinearOperator(self.covar_factor), DiagLinearOperator(var))
+        else:
+            res = DiagLinearOperator(var)
         return res
 
     def forward(self, i1, i2, **params):

From a4a622b1ebce18fa52e653514c1271d688a885cf Mon Sep 17 00:00:00 2001
From: Roman Garnett
Date: Mon, 19 Sep 2022 22:17:56 -0500
Subject: [PATCH 2/2] adding rank = 0 (diagonal) branch to _eval_covar_matrix()
 in IndexKernel

---
 gpytorch/kernels/index_kernel.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/gpytorch/kernels/index_kernel.py b/gpytorch/kernels/index_kernel.py
index 85c14461f..7fde6497e 100644
--- a/gpytorch/kernels/index_kernel.py
+++ b/gpytorch/kernels/index_kernel.py
@@ -90,8 +90,11 @@ def _set_var(self, value):
         self.initialize(raw_var=self.raw_var_constraint.inverse_transform(value))
 
     def _eval_covar_matrix(self):
-        cf = self.covar_factor
-        return cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)
+        if self.rank > 0:
+            cf = self.covar_factor
+            return cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)
+        else:
+            return torch.diag_embed(self.var)
 
     @property
     def covar_matrix(self):
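
Usage sketch (not part of the patch series): a minimal example of how the rank = 0 case could be exercised once both patches are applied. The constructor arguments, the var property, and _eval_covar_matrix() already exist in IndexKernel; only the rank = 0 behavior shown in the comments is introduced by these patches.

    import torch
    from gpytorch.kernels import IndexKernel

    num_tasks = 4

    # rank = 0: no covar_factor parameter is registered, and the task
    # covariance matrix reduces to diag(var), i.e. tasks are a priori
    # uncorrelated.
    diag_kernel = IndexKernel(num_tasks=num_tasks, rank=0)
    B_diag = diag_kernel._eval_covar_matrix()  # (num_tasks x num_tasks), diagonal
    assert torch.equal(B_diag, torch.diag_embed(diag_kernel.var))

    # rank > 0: behavior is unchanged,
    # B = covar_factor @ covar_factor^T + diag(var).
    full_kernel = IndexKernel(num_tasks=num_tasks, rank=2)
    B_full = full_kernel._eval_covar_matrix()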