Skip to content

Commit d0e35e4

Browse files
Casting Tensor to LinearOperator in constructor of MultivariateNormal to ensure PSD-safe factorization
1 parent 442ad98 commit d0e35e4

File tree

3 files changed

+62
-42
lines changed

3 files changed

+62
-42
lines changed

gpytorch/distributions/multitask_multivariate_normal.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
import torch
44
from linear_operator import LinearOperator, to_linear_operator
5-
from linear_operator.operators import BlockDiagLinearOperator, BlockInterleavedLinearOperator, CatLinearOperator
5+
from linear_operator.operators import (BlockDiagLinearOperator,
6+
BlockInterleavedLinearOperator,
7+
CatLinearOperator)
8+
from torch import Tensor
69

710
from .multivariate_normal import MultivariateNormal
811

@@ -24,7 +27,9 @@ class MultitaskMultivariateNormal(MultivariateNormal):
2427
w.r.t. inter-observation covariance for each task.
2528
"""
2629

27-
def __init__(self, mean, covariance_matrix, validate_args=False, interleaved=True):
30+
def __init__(
31+
self, mean: Tensor, covariance_matrix: LinearOperator, validate_args: bool = False, interleaved: bool = True
32+
):
2833
if not torch.is_tensor(mean) and not isinstance(mean, LinearOperator):
2934
raise RuntimeError("The mean of a MultitaskMultivariateNormal must be a Tensor or LinearOperator")
3035

gpytorch/distributions/multivariate_normal.py

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111
from linear_operator import to_dense, to_linear_operator
12-
from linear_operator.operators import DiagLinearOperator, LinearOperator, RootLinearOperator
12+
from linear_operator.operators import (DiagLinearOperator, LinearOperator,
13+
RootLinearOperator)
1314
from torch import Tensor
1415
from torch.distributions import MultivariateNormal as TMultivariateNormal
1516
from torch.distributions.kl import register_kl
@@ -42,27 +43,36 @@ class MultivariateNormal(TMultivariateNormal, Distribution):
4243
:ivar torch.Tensor variance: The variance.
4344
"""
4445

45-
def __init__(self, mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], validate_args: bool = False):
46-
self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
47-
if self._islazy:
48-
if validate_args:
49-
ms = mean.size(-1)
50-
cs1 = covariance_matrix.size(-1)
51-
cs2 = covariance_matrix.size(-2)
52-
if not (ms == cs1 and ms == cs2):
53-
raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
54-
self.loc = mean
55-
self._covar = covariance_matrix
56-
self.__unbroadcasted_scale_tril = None
57-
self._validate_args = validate_args
58-
batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
59-
60-
event_shape = self.loc.shape[-1:]
61-
62-
# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
63-
super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
64-
else:
65-
super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
46+
def __init__(
47+
self,
48+
mean: Union[Tensor, LinearOperator],
49+
covariance_matrix: Union[Tensor, LinearOperator],
50+
validate_args: bool = False,
51+
):
52+
self._islazy = True
53+
# casting Tensor to DenseLinearOperator because the super constructor calls cholesky, which
54+
# will fail if the covariance matrix is semi-definite, whereas DenseLinearOperator ends up
55+
# calling _psd_safe_cholesky, which factorizes semi-definite matrices by adding to the diagonal.
56+
if isinstance(covariance_matrix, Tensor):
57+
self._islazy = False # to allow _unbroadcasted_scale_tril setter
58+
covariance_matrix = to_linear_operator(covariance_matrix)
59+
60+
if validate_args:
61+
ms = mean.size(-1)
62+
cs1 = covariance_matrix.size(-1)
63+
cs2 = covariance_matrix.size(-2)
64+
if not (ms == cs1 and ms == cs2):
65+
raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
66+
self.loc = mean
67+
self._covar = covariance_matrix
68+
self.__unbroadcasted_scale_tril = None
69+
self._validate_args = validate_args
70+
batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
71+
72+
event_shape = self.loc.shape[-1:]
73+
74+
# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
75+
super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
6676

6777
def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
6878
"""
@@ -81,16 +91,16 @@ def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size
8191
def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator]) -> str:
8292
return f"MultivariateNormal(loc: {mean.size()}, scale: {covariance_matrix.size()})"
8393

84-
@property
94+
@property # not using lazy_property here, because it does not allow for setter below
8595
def _unbroadcasted_scale_tril(self) -> Tensor:
86-
if self.islazy and self.__unbroadcasted_scale_tril is None:
96+
if self.__unbroadcasted_scale_tril is None:
8797
# cache root decomposition
8898
ust = to_dense(self.lazy_covariance_matrix.cholesky())
8999
self.__unbroadcasted_scale_tril = ust
90100
return self.__unbroadcasted_scale_tril
91101

92102
@_unbroadcasted_scale_tril.setter
93-
def _unbroadcasted_scale_tril(self, ust: Tensor):
103+
def _unbroadcasted_scale_tril(self, ust: Tensor) -> None:
94104
if self.islazy:
95105
raise NotImplementedError("Cannot set _unbroadcasted_scale_tril for lazy MVN distributions")
96106
else:
@@ -114,10 +124,7 @@ def base_sample_shape(self) -> torch.Size:
114124

115125
@lazy_property
116126
def covariance_matrix(self) -> Tensor:
117-
if self.islazy:
118-
return self._covar.to_dense()
119-
else:
120-
return super().covariance_matrix
127+
return self._covar.to_dense()
121128

122129
def confidence_region(self) -> Tuple[Tensor, Tensor]:
123130
"""
@@ -157,10 +164,7 @@ def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
157164

158165
@lazy_property
159166
def lazy_covariance_matrix(self) -> LinearOperator:
160-
if self.islazy:
161-
return self._covar
162-
else:
163-
return to_linear_operator(super().covariance_matrix)
167+
return self._covar
164168

165169
def log_prob(self, value: Tensor) -> Tensor:
166170
r"""
@@ -304,13 +308,10 @@ def to_data_independent_dist(self) -> torch.distributions.Normal:
304308

305309
@property
306310
def variance(self) -> Tensor:
307-
if self.islazy:
308-
# overwrite this since torch MVN uses unbroadcasted_scale_tril for this
309-
diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
310-
diag = diag.view(diag.shape[:-1] + self._event_shape)
311-
variance = diag.expand(self._batch_shape + self._event_shape)
312-
else:
313-
variance = super().variance
311+
# overwrite this since torch MVN uses unbroadcasted_scale_tril for this
312+
diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
313+
diag = diag.view(diag.shape[:-1] + self._event_shape)
314+
variance = diag.expand(self._batch_shape + self._event_shape)
314315

315316
# Check to make sure that variance isn't lower than minimum allowed value (default 1e-6).
316317
# This ensures that all variances are positive

test/distributions/test_multivariate_normal.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,20 @@ def test_multivariate_normal_non_lazy(self, cuda=False):
4747
self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 3]))
4848
self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 3]))
4949

50+
# testing with semi-definite input
51+
A = torch.randn(len(mean), 1)
52+
covmat = A @ A.T
53+
handles_psd = False
54+
try:
55+
# the regular call fails:
56+
# mvn = TMultivariateNormal(loc=mean, covariance_matrix=covmat, validate_args=True)
57+
mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True)
58+
mvn.sample()
59+
handles_psd = True
60+
except ValueError:
61+
handles_psd = False
62+
self.assertTrue(handles_psd)
63+
5064
def test_multivariate_normal_non_lazy_cuda(self):
5165
if torch.cuda.is_available():
5266
with least_used_cuda_device():

0 commit comments

Comments
 (0)