Skip to content

Commit ba187db

Browse files
committed
add combine_terms option to exact MLL
1 parent 5a0ff6b commit ba187db

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

gpytorch/distributions/multivariate_normal.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def lazy_covariance_matrix(self):
142142
else:
143143
return lazify(super().covariance_matrix)
144144

145-
def log_prob(self, value):
145+
def log_prob(self, value, combine_terms=True):
146146
if settings.fast_computations.log_prob.off():
147147
return super().log_prob(value)
148148

@@ -167,9 +167,13 @@ def log_prob(self, value):
167167
# Get log determininant and first part of quadratic form
168168
covar = covar.evaluate_kernel()
169169
inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
170+
norm_const = diff.size(-1) * math.log(2 * math.pi)
171+
split_terms = [inv_quad, logdet, norm_const]
170172

171-
res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
172-
return res
173+
if combine_terms:
174+
return -0.5 * sum(split_terms)
175+
else:
176+
return [-0.5 * term for term in split_terms]
173177

174178
def rsample(self, sample_shape=torch.Size(), base_samples=None):
175179
covar = self.lazy_covariance_matrix

gpytorch/mlls/exact_marginal_log_likelihood.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
import torch
4+
35
from ..distributions import MultivariateNormal
46
from ..likelihoods import _GaussianLikelihoodBase
57
from .marginal_log_likelihood import MarginalLogLikelihood
@@ -59,9 +61,17 @@ def forward(self, function_dist, target, *params):
5961

6062
# Get the log prob of the marginal distribution
6163
output = self.likelihood(function_dist, *params)
62-
res = output.log_prob(target)
63-
res = self._add_other_terms(res, params)
64+
res = output.log_prob(target, combine_terms=self.combine_terms)
6465

6566
# Scale by the amount of data we have
6667
num_data = function_dist.event_shape.numel()
67-
return res.div_(num_data)
68+
69+
if self.combine_terms:
70+
res = self._add_other_terms(res, params)
71+
return res.div(num_data)
72+
else:
73+
norm_const = res[-1]
74+
other_terms = torch.zeros_like(norm_const)
75+
other_terms = self._add_other_terms(other_terms, params)
76+
res.append(other_terms)
77+
return [term.div(num_data) for term in res]

gpytorch/mlls/marginal_log_likelihood.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class MarginalLogLikelihood(Module):
2525
these functions must be negated for optimization).
2626
"""
2727

28-
def __init__(self, likelihood, model):
28+
def __init__(self, likelihood, model, combine_terms=True):
2929
super(MarginalLogLikelihood, self).__init__()
3030
if not isinstance(model, GP):
3131
raise RuntimeError(
@@ -35,6 +35,7 @@ def __init__(self, likelihood, model):
3535
)
3636
self.likelihood = likelihood
3737
self.model = model
38+
self.combine_terms = combine_terms
3839

3940
def forward(self, output, target, **kwargs):
4041
r"""

0 commit comments

Comments
 (0)