Skip to content

Commit e06f8da

Browse files
Casting Tensor to LinearOperator in constructor of MultivariateNormal to ensure PSD-safe factorization
1 parent 442ad98 commit e06f8da

File tree

3 files changed

+63
-43
lines changed

3 files changed

+63
-43
lines changed

gpytorch/distributions/multitask_multivariate_normal.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
import torch
44
from linear_operator import LinearOperator, to_linear_operator
5-
from linear_operator.operators import BlockDiagLinearOperator, BlockInterleavedLinearOperator, CatLinearOperator
5+
from linear_operator.operators import (BlockDiagLinearOperator,
6+
BlockInterleavedLinearOperator,
7+
CatLinearOperator)
8+
from torch import Tensor
69

710
from .multivariate_normal import MultivariateNormal
811

@@ -24,7 +27,9 @@ class MultitaskMultivariateNormal(MultivariateNormal):
2427
w.r.t. inter-observation covariance for each task.
2528
"""
2629

27-
def __init__(self, mean, covariance_matrix, validate_args=False, interleaved=True):
30+
def __init__(
31+
self, mean: Tensor, covariance_matrix: LinearOperator, validate_args: bool = False, interleaved: bool = True
32+
):
2833
if not torch.is_tensor(mean) and not isinstance(mean, LinearOperator):
2934
raise RuntimeError("The mean of a MultitaskMultivariateNormal must be a Tensor or LinearOperator")
3035

gpytorch/distributions/multivariate_normal.py

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99

1010
import torch
1111
from linear_operator import to_dense, to_linear_operator
12-
from linear_operator.operators import DiagLinearOperator, LinearOperator, RootLinearOperator
12+
from linear_operator.operators import (DiagLinearOperator, LinearOperator,
13+
RootLinearOperator)
1314
from torch import Tensor
1415
from torch.distributions import MultivariateNormal as TMultivariateNormal
1516
from torch.distributions.kl import register_kl
1617
from torch.distributions.utils import _standard_normal, lazy_property
1718

1819
from .. import settings
19-
from ..utils.warnings import NumericalWarning
2020
from .distribution import Distribution
2121

2222

@@ -42,27 +42,36 @@ class MultivariateNormal(TMultivariateNormal, Distribution):
4242
:ivar torch.Tensor variance: The variance.
4343
"""
4444

45-
def __init__(self, mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], validate_args: bool = False):
46-
self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
47-
if self._islazy:
48-
if validate_args:
49-
ms = mean.size(-1)
50-
cs1 = covariance_matrix.size(-1)
51-
cs2 = covariance_matrix.size(-2)
52-
if not (ms == cs1 and ms == cs2):
53-
raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
54-
self.loc = mean
55-
self._covar = covariance_matrix
56-
self.__unbroadcasted_scale_tril = None
57-
self._validate_args = validate_args
58-
batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
59-
60-
event_shape = self.loc.shape[-1:]
61-
62-
# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
63-
super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
64-
else:
65-
super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
45+
def __init__(
46+
self,
47+
mean: Union[Tensor, LinearOperator],
48+
covariance_matrix: Union[Tensor, LinearOperator],
49+
validate_args: bool = False,
50+
):
51+
self._islazy = True
52+
# casting Tensor to DenseLinearOperator because the super constructor calls cholesky, which
53+
# will fail if the covariance matrix is semi-definite, whereas DenseLinearOperator ends up
54+
# calling _psd_safe_cholesky, which factorizes semi-definite matrices by adding to the diagonal.
55+
if isinstance(covariance_matrix, Tensor):
56+
self._islazy = False # to allow _unbroadcasted_scale_tril setter
57+
covariance_matrix = to_linear_operator(covariance_matrix)
58+
59+
if validate_args:
60+
ms = mean.size(-1)
61+
cs1 = covariance_matrix.size(-1)
62+
cs2 = covariance_matrix.size(-2)
63+
if not (ms == cs1 and ms == cs2):
64+
raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
65+
self.loc = mean
66+
self._covar = covariance_matrix
67+
self.__unbroadcasted_scale_tril = None
68+
self._validate_args = validate_args
69+
batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
70+
71+
event_shape = self.loc.shape[-1:]
72+
73+
# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
74+
super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
6675

6776
def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
6877
"""
@@ -81,16 +90,16 @@ def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size
8190
def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator]) -> str:
8291
return f"MultivariateNormal(loc: {mean.size()}, scale: {covariance_matrix.size()})"
8392

84-
@property
93+
@property # not using lazy_property here, because it does not allow for setter below
8594
def _unbroadcasted_scale_tril(self) -> Tensor:
86-
if self.islazy and self.__unbroadcasted_scale_tril is None:
95+
if self.__unbroadcasted_scale_tril is None:
8796
# cache root decomposition
8897
ust = to_dense(self.lazy_covariance_matrix.cholesky())
8998
self.__unbroadcasted_scale_tril = ust
9099
return self.__unbroadcasted_scale_tril
91100

92101
@_unbroadcasted_scale_tril.setter
93-
def _unbroadcasted_scale_tril(self, ust: Tensor):
102+
def _unbroadcasted_scale_tril(self, ust: Tensor) -> None:
94103
if self.islazy:
95104
raise NotImplementedError("Cannot set _unbroadcasted_scale_tril for lazy MVN distributions")
96105
else:
@@ -114,10 +123,7 @@ def base_sample_shape(self) -> torch.Size:
114123

115124
@lazy_property
116125
def covariance_matrix(self) -> Tensor:
117-
if self.islazy:
118-
return self._covar.to_dense()
119-
else:
120-
return super().covariance_matrix
126+
return self._covar.to_dense()
121127

122128
def confidence_region(self) -> Tuple[Tensor, Tensor]:
123129
"""
@@ -157,10 +163,7 @@ def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
157163

158164
@lazy_property
159165
def lazy_covariance_matrix(self) -> LinearOperator:
160-
if self.islazy:
161-
return self._covar
162-
else:
163-
return to_linear_operator(super().covariance_matrix)
166+
return self._covar
164167

165168
def log_prob(self, value: Tensor) -> Tensor:
166169
r"""
@@ -304,13 +307,10 @@ def to_data_independent_dist(self) -> torch.distributions.Normal:
304307

305308
@property
306309
def variance(self) -> Tensor:
307-
if self.islazy:
308-
# overwrite this since torch MVN uses unbroadcasted_scale_tril for this
309-
diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
310-
diag = diag.view(diag.shape[:-1] + self._event_shape)
311-
variance = diag.expand(self._batch_shape + self._event_shape)
312-
else:
313-
variance = super().variance
310+
# overwrite this since torch MVN uses unbroadcasted_scale_tril for this
311+
diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
312+
diag = diag.view(diag.shape[:-1] + self._event_shape)
313+
variance = diag.expand(self._batch_shape + self._event_shape)
314314

315315
# Check to make sure that variance isn't lower than minimum allowed value (default 1e-6).
316316
# This ensures that all variances are positive

test/distributions/test_multivariate_normal.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from gpytorch.distributions import MultivariateNormal
1212
from gpytorch.test.base_test_case import BaseTestCase
1313
from gpytorch.test.utils import least_used_cuda_device
14+
from gpytorch.utils.warnings import NumericalWarning
1415

1516

1617
class TestMultivariateNormal(BaseTestCase, unittest.TestCase):
@@ -47,6 +48,20 @@ def test_multivariate_normal_non_lazy(self, cuda=False):
4748
self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 3]))
4849
self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 3]))
4950

51+
# testing with semi-definite input
52+
A = torch.randn(len(mean), 1)
53+
covmat = A @ A.T
54+
handles_psd = False
55+
try:
56+
# the regular call fails:
57+
# mvn = TMultivariateNormal(loc=mean, covariance_matrix=covmat, validate_args=True)
58+
mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True)
59+
mvn.sample()
60+
handles_psd = True
61+
except ValueError:
62+
handles_psd = False
63+
self.assertTrue(handles_psd)
64+
5065
def test_multivariate_normal_non_lazy_cuda(self):
5166
if torch.cuda.is_available():
5267
with least_used_cuda_device():

0 commit comments

Comments
 (0)