
import torch
from linear_operator import to_dense, to_linear_operator
-from linear_operator.operators import DiagLinearOperator, LinearOperator, RootLinearOperator
+from linear_operator.operators import (DiagLinearOperator, LinearOperator,
+                                       RootLinearOperator)
from torch import Tensor
from torch.distributions import MultivariateNormal as TMultivariateNormal
from torch.distributions.kl import register_kl
from torch.distributions.utils import _standard_normal, lazy_property

from .. import settings
-from ..utils.warnings import NumericalWarning
from .distribution import Distribution


@@ -42,27 +42,36 @@ class MultivariateNormal(TMultivariateNormal, Distribution):
    :ivar torch.Tensor variance: The variance.
    """

-    def __init__(self, mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], validate_args: bool = False):
-        self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
-        if self._islazy:
-            if validate_args:
-                ms = mean.size(-1)
-                cs1 = covariance_matrix.size(-1)
-                cs2 = covariance_matrix.size(-2)
-                if not (ms == cs1 and ms == cs2):
-                    raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
-            self.loc = mean
-            self._covar = covariance_matrix
-            self.__unbroadcasted_scale_tril = None
-            self._validate_args = validate_args
-            batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
-
-            event_shape = self.loc.shape[-1:]
-
-            # TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
-            super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
-        else:
-            super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
+    def __init__(
+        self,
+        mean: Union[Tensor, LinearOperator],
+        covariance_matrix: Union[Tensor, LinearOperator],
+        validate_args: bool = False,
+    ):
+        self._islazy = True
+        # casting Tensor to DenseLinearOperator because the super constructor calls cholesky, which
+        # will fail if the covariance matrix is semi-definite, whereas DenseLinearOperator ends up
+        # calling _psd_safe_cholesky, which factorizes semi-definite matrices by adding to the diagonal.
+        if isinstance(covariance_matrix, Tensor):
+            self._islazy = False  # to allow _unbroadcasted_scale_tril setter
+            covariance_matrix = to_linear_operator(covariance_matrix)
+
+        if validate_args:
+            ms = mean.size(-1)
+            cs1 = covariance_matrix.size(-1)
+            cs2 = covariance_matrix.size(-2)
+            if not (ms == cs1 and ms == cs2):
+                raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
+        self.loc = mean
+        self._covar = covariance_matrix
+        self.__unbroadcasted_scale_tril = None
+        self._validate_args = validate_args
+        batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
+
+        event_shape = self.loc.shape[-1:]
+
+        # TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
+        super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)

    def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
        """
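A minimal sketch of what the new constructor path enables, assuming this file is gpytorch's `MultivariateNormal` (the relative imports above suggest it); the import path and variable names below are illustrative, not part of the diff. Because a dense covariance is now wrapped via `to_linear_operator`, the Cholesky factor comes from the jitter-adding PSD-safe routine rather than torch's plain `cholesky`, so a merely positive semi-definite covariance no longer breaks construction:

```python
import torch

from gpytorch.distributions import MultivariateNormal  # assumed import path

mean = torch.zeros(3)
x = torch.randn(3, 2)
cov = x @ x.t()  # rank <= 2, so only positive *semi*-definite

# Before this change, the non-lazy branch called torch's MVN constructor, whose
# torch.linalg.cholesky would typically fail on a singular covariance like this one.
mvn = MultivariateNormal(mean, cov)

print(mvn.scale_tril)  # Cholesky via the PSD-safe path (jitter added as needed)
print(mvn.rsample())   # sampling works as usual
```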
@@ -81,16 +90,16 @@ def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size
    def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator]) -> str:
        return f"MultivariateNormal(loc: {mean.size()}, scale: {covariance_matrix.size()})"

-    @property
+    @property  # not using lazy_property here, because it does not allow for setter below
    def _unbroadcasted_scale_tril(self) -> Tensor:
-        if self.islazy and self.__unbroadcasted_scale_tril is None:
+        if self.__unbroadcasted_scale_tril is None:
            # cache root decomposition
            ust = to_dense(self.lazy_covariance_matrix.cholesky())
            self.__unbroadcasted_scale_tril = ust
        return self.__unbroadcasted_scale_tril

    @_unbroadcasted_scale_tril.setter
-    def _unbroadcasted_scale_tril(self, ust: Tensor):
+    def _unbroadcasted_scale_tril(self, ust: Tensor) -> None:
        if self.islazy:
            raise NotImplementedError("Cannot set _unbroadcasted_scale_tril for lazy MVN distributions")
        else:
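As a hedged illustration of why a plain property with a setter is kept here instead of `lazy_property`: a distribution built from a dense tensor keeps `islazy == False` and may have its Cholesky factor assigned, while a lazy one rejects it. The import paths below are assumptions, as above:

```python
import torch
from linear_operator.operators import DiagLinearOperator

from gpytorch.distributions import MultivariateNormal  # assumed import path

dense_mvn = MultivariateNormal(torch.zeros(2), torch.eye(2))  # islazy is False
dense_mvn._unbroadcasted_scale_tril = torch.eye(2)            # setter accepts the factor

lazy_mvn = MultivariateNormal(torch.zeros(2), DiagLinearOperator(torch.ones(2)))
try:
    lazy_mvn._unbroadcasted_scale_tril = torch.eye(2)
except NotImplementedError as err:
    print(err)  # "Cannot set _unbroadcasted_scale_tril for lazy MVN distributions"
```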
@@ -114,10 +123,7 @@ def base_sample_shape(self) -> torch.Size:

    @lazy_property
    def covariance_matrix(self) -> Tensor:
-        if self.islazy:
-            return self._covar.to_dense()
-        else:
-            return super().covariance_matrix
+        return self._covar.to_dense()

    def confidence_region(self) -> Tuple[Tensor, Tensor]:
        """
@@ -157,10 +163,7 @@ def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:

    @lazy_property
    def lazy_covariance_matrix(self) -> LinearOperator:
-        if self.islazy:
-            return self._covar
-        else:
-            return to_linear_operator(super().covariance_matrix)
+        return self._covar

    def log_prob(self, value: Tensor) -> Tensor:
        r"""
@@ -304,13 +307,10 @@ def to_data_independent_dist(self) -> torch.distributions.Normal:

    @property
    def variance(self) -> Tensor:
-        if self.islazy:
-            # overwrite this since torch MVN uses unbroadcasted_scale_tril for this
-            diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
-            diag = diag.view(diag.shape[:-1] + self._event_shape)
-            variance = diag.expand(self._batch_shape + self._event_shape)
-        else:
-            variance = super().variance
+        # overwrite this since torch MVN uses unbroadcasted_scale_tril for this
+        diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
+        diag = diag.view(diag.shape[:-1] + self._event_shape)
+        variance = diag.expand(self._batch_shape + self._event_shape)

        # Check to make sure that variance isn't lower than minimum allowed value (default 1e-6).
        # This ensures that all variances are positive
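A short sketch of the now branch-free lazy path, again assuming gpytorch's `MultivariateNormal`; values and import paths are illustrative. Since every distribution stores a `LinearOperator`, `variance` is read straight from the diagonal of the lazy covariance and `covariance_matrix` densifies it on demand, without forming a Cholesky factor:

```python
import torch
from linear_operator.operators import DiagLinearOperator

from gpytorch.distributions import MultivariateNormal  # assumed import path

diag = torch.tensor([1.0, 0.5, 0.25])
mvn = MultivariateNormal(torch.zeros(3), DiagLinearOperator(diag))

print(mvn.variance)           # tensor([1.0000, 0.5000, 0.2500]); no factorization needed
print(mvn.covariance_matrix)  # dense 3x3 diagonal matrix from self._covar.to_dense()
```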