 import torch
 from linear_operator import to_dense, to_linear_operator
-from linear_operator.operators import DiagLinearOperator, LinearOperator, RootLinearOperator
+from linear_operator.operators import (DiagLinearOperator, LinearOperator,
+                                       RootLinearOperator)
 from torch import Tensor
 from torch.distributions import MultivariateNormal as TMultivariateNormal
 from torch.distributions.kl import register_kl
@@ -42,27 +43,36 @@ class MultivariateNormal(TMultivariateNormal, Distribution):
     :ivar torch.Tensor variance: The variance.
     """

-    def __init__(self, mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], validate_args: bool = False):
-        self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
-        if self._islazy:
-            if validate_args:
-                ms = mean.size(-1)
-                cs1 = covariance_matrix.size(-1)
-                cs2 = covariance_matrix.size(-2)
-                if not (ms == cs1 and ms == cs2):
-                    raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
-            self.loc = mean
-            self._covar = covariance_matrix
-            self.__unbroadcasted_scale_tril = None
-            self._validate_args = validate_args
-            batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
-
-            event_shape = self.loc.shape[-1:]
-
-            # TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
-            super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
-        else:
-            super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
+    def __init__(
+        self,
+        mean: Union[Tensor, LinearOperator],
+        covariance_matrix: Union[Tensor, LinearOperator],
+        validate_args: bool = False,
+    ):
+        self._islazy = True
+        # Cast a Tensor covariance to a DenseLinearOperator because the super constructor calls
+        # cholesky, which will fail if the covariance matrix is only positive semi-definite,
+        # whereas DenseLinearOperator ends up calling _psd_safe_cholesky, which factorizes
+        # semi-definite matrices by adding jitter to the diagonal.
+        if isinstance(covariance_matrix, Tensor):
+            self._islazy = False  # to allow the _unbroadcasted_scale_tril setter
+            covariance_matrix = to_linear_operator(covariance_matrix)
+
+        if validate_args:
+            ms = mean.size(-1)
+            cs1 = covariance_matrix.size(-1)
+            cs2 = covariance_matrix.size(-2)
+            if not (ms == cs1 and ms == cs2):
+                raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
+        self.loc = mean
+        self._covar = covariance_matrix
+        self.__unbroadcasted_scale_tril = None
+        self._validate_args = validate_args
+        batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
+
+        event_shape = self.loc.shape[-1:]
+
+        # TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
+        super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)

     def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
         """
@@ -81,16 +91,16 @@ def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
     def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator]) -> str:
         return f"MultivariateNormal(loc: {mean.size()}, scale: {covariance_matrix.size()})"

-    @property
+    @property  # not using lazy_property here, because it does not allow for the setter below
     def _unbroadcasted_scale_tril(self) -> Tensor:
-        if self.islazy and self.__unbroadcasted_scale_tril is None:
+        if self.__unbroadcasted_scale_tril is None:
             # cache the root decomposition
             ust = to_dense(self.lazy_covariance_matrix.cholesky())
             self.__unbroadcasted_scale_tril = ust
         return self.__unbroadcasted_scale_tril

     @_unbroadcasted_scale_tril.setter
-    def _unbroadcasted_scale_tril(self, ust: Tensor):
+    def _unbroadcasted_scale_tril(self, ust: Tensor) -> None:
         if self.islazy:
             raise NotImplementedError("Cannot set _unbroadcasted_scale_tril for lazy MVN distributions")
         else:
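
(Illustrative, not part of the diff: the property now computes and caches the Cholesky factor on first access for every instance, while the setter remains off-limits whenever the distribution was built from a LinearOperator, i.e. `islazy` is True.)

```python
import torch
from linear_operator.operators import DenseLinearOperator
from gpytorch.distributions import MultivariateNormal  # assumed import path

mvn = MultivariateNormal(torch.zeros(3), DenseLinearOperator(torch.eye(3)))
tril = mvn._unbroadcasted_scale_tril          # first access: Cholesky computed, then cached
assert mvn._unbroadcasted_scale_tril is tril  # second access: cached value is returned

try:
    mvn._unbroadcasted_scale_tril = tril      # islazy=True, so the setter refuses
except NotImplementedError as err:
    print(err)  # Cannot set _unbroadcasted_scale_tril for lazy MVN distributions
```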
@@ -114,10 +124,7 @@ def base_sample_shape(self) -> torch.Size:
     @lazy_property
     def covariance_matrix(self) -> Tensor:
-        if self.islazy:
-            return self._covar.to_dense()
-        else:
-            return super().covariance_matrix
+        return self._covar.to_dense()

     def confidence_region(self) -> Tuple[Tensor, Tensor]:
         """
@@ -157,10 +164,7 @@ def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
     @lazy_property
     def lazy_covariance_matrix(self) -> LinearOperator:
-        if self.islazy:
-            return self._covar
-        else:
-            return to_linear_operator(super().covariance_matrix)
+        return self._covar

     def log_prob(self, value: Tensor) -> Tensor:
         r"""
@@ -304,13 +308,10 @@ def to_data_independent_dist(self) -> torch.distributions.Normal:
     @property
     def variance(self) -> Tensor:
-        if self.islazy:
-            # overwrite this since torch MVN uses unbroadcasted_scale_tril for this
-            diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
-            diag = diag.view(diag.shape[:-1] + self._event_shape)
-            variance = diag.expand(self._batch_shape + self._event_shape)
-        else:
-            variance = super().variance
+        # Override this since the torch MVN computes it from _unbroadcasted_scale_tril,
+        # which would needlessly trigger a Cholesky factorization.
+        diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
+        diag = diag.view(diag.shape[:-1] + self._event_shape)
+        variance = diag.expand(self._batch_shape + self._event_shape)

        # Check to make sure that variance isn't lower than minimum allowed value (default 1e-6).
        # This ensures that all variances are positive
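
(Illustrative: reading the diagonal keeps the covariance lazy, so e.g. for a `DiagLinearOperator`, as imported above, no cubic-cost factorization is ever formed.)

```python
import torch
from linear_operator.operators import DiagLinearOperator
from gpytorch.distributions import MultivariateNormal  # assumed import path

covar = DiagLinearOperator(torch.tensor([1.0, 4.0, 9.0]))
mvn = MultivariateNormal(torch.zeros(3), covar)
print(mvn.variance)  # tensor([1., 4., 9.]) -- read straight off the lazy diagonal
```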