add combine_terms option to exact MLL #1863
base: main
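For context, here is a minimal usage sketch of the option this PR adds. The model and likelihood setup below is a generic gpytorch ExactGP example (an assumption for illustration, not code from this PR); the point is that with `combine_terms=False` the MLL call returns the individual terms, which can be summed to recover the usual scalar objective.

```python
import torch
import gpytorch


# Generic exact GP model, assumed here for illustration only.
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


train_x = torch.rand(20, 2)
train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

# New in this PR: combine_terms=False makes the MLL return a list of terms
# (quadratic term, log-determinant, normalization constant, other terms)
# instead of a single combined scalar.
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model, combine_terms=False)

model.train()
likelihood.train()
output = model(train_x)
terms = mll(output, train_y)
loss = -sum(terms)  # summing the terms recovers the usual -MLL training loss
loss.backward()
```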
@@ -48,7 +48,11 @@ def __init__(self, mean, covariance_matrix, validate_args=False):
             # TODO: Integrate argument validation for LazyTensors into torch.distribution validation logic
             super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
         else:
-            super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
+            super().__init__(
+                loc=mean,
+                covariance_matrix=covariance_matrix,
+                validate_args=validate_args,
+            )

     @property
     def _unbroadcasted_scale_tril(self):
@@ -142,7 +146,7 @@ def lazy_covariance_matrix(self):
         else:
             return lazify(super().covariance_matrix)

-    def log_prob(self, value):
+    def log_prob(self, value, combine_terms=True):
         if settings.fast_computations.log_prob.off():
             return super().log_prob(value)
@@ -157,7 +161,10 @@ def log_prob(self, value):
             if len(diff.shape[:-1]) < len(covar.batch_shape):
                 diff = diff.expand(covar.shape[:-1])
             else:
-                padded_batch_shape = (*(1 for _ in range(diff.dim() + 1 - covar.dim())), *covar.batch_shape)
+                padded_batch_shape = (
+                    *(1 for _ in range(diff.dim() + 1 - covar.dim())),
+                    *covar.batch_shape,
+                )
                 covar = covar.repeat(
                     *(diff_size // covar_size for diff_size, covar_size in zip(diff.shape[:-1], padded_batch_shape)),
                     1,
@@ -167,9 +174,13 @@ def log_prob(self, value):
         # Get log determininant and first part of quadratic form
         covar = covar.evaluate_kernel()
         inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
+        norm_const = torch.tensor(diff.size(-1) * math.log(2 * math.pi)).to(inv_quad)
+        split_terms = [inv_quad, logdet, norm_const]

-        res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
-        return res
+        if combine_terms:
+            return -0.5 * sum(split_terms)
+        else:
+            return [-0.5 * term for term in split_terms]

     def rsample(self, sample_shape=torch.Size(), base_samples=None):
         covar = self.lazy_covariance_matrix
@@ -286,7 +297,10 @@ def __mul__(self, other):
             raise RuntimeError("Can only multiply by scalars")
         if other == 1:
             return self
-        return self.__class__(mean=self.mean * other, covariance_matrix=self.lazy_covariance_matrix * (other ** 2))
+        return self.__class__(
+            mean=self.mean * other,
+            covariance_matrix=self.lazy_covariance_matrix * (other ** 2),
+        )

     def __truediv__(self, other):
         return self.__mul__(1.0 / other)
@@ -341,5 +355,12 @@ def kl_mvn_mvn(p_dist, q_dist):
     trace_plus_inv_quad_form, logdet_q_covar = q_covar.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=True)

     # Compute the KL Divergence.
-    res = 0.5 * sum([logdet_q_covar, logdet_p_covar.mul(-1), trace_plus_inv_quad_form, -float(mean_diffs.size(-1))])
+    res = 0.5 * sum(
+        [
+            logdet_q_covar,
+            logdet_p_covar.mul(-1),
+            trace_plus_inv_quad_form,
+            -float(mean_diffs.size(-1)),
+        ]
+    )
     return res
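To make the new split explicit: the three entries of `split_terms` above are the pieces of the standard Gaussian log-density, log p(y) = -0.5 * [(y - mu)^T K^-1 (y - mu) + log det K + n * log(2 * pi)], and `combine_terms=False` returns each piece multiplied by -0.5. A small sketch of that relationship (assumed behavior based on the diff above, not code from the PR):

```python
import torch
import gpytorch

# Simple MVN; with the default fast-computation settings the overridden
# log_prob path shown in the diff above is used.
mvn = gpytorch.distributions.MultivariateNormal(torch.zeros(4), torch.eye(4))
y = torch.randn(4)

terms = mvn.log_prob(y, combine_terms=False)  # [-0.5*inv_quad, -0.5*logdet, -0.5*n*log(2*pi)]
combined = mvn.log_prob(y)

# The separate terms should sum back to the combined log probability
# (up to numerical tolerance).
print(torch.allclose(sum(terms), combined, atol=1e-5))
```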
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3

+import torch
+
 from ..distributions import MultivariateNormal
 from ..likelihoods import _GaussianLikelihoodBase
 from .marginal_log_likelihood import MarginalLogLikelihood
@@ -17,6 +19,7 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):

     :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
     :param ~gpytorch.models.ExactGP model: The exact GP model
+    :param ~bool combine_terms (optional): If `False`, the MLL call returns each MLL term separately
Review comment: Should probably also describe what happens if there are "other terms" (i.e. that they are added to the return elements)

     Example:
         >>> # model is a gpytorch.models.ExactGP
@@ -28,10 +31,10 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):
         >>> loss.backward()
     """

-    def __init__(self, likelihood, model):
+    def __init__(self, likelihood, model, combine_terms=True):
         if not isinstance(likelihood, _GaussianLikelihoodBase):
             raise RuntimeError("Likelihood must be Gaussian for exact inference")
-        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model)
+        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model, combine_terms)

     def _add_other_terms(self, res, params):
         # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
@@ -59,9 +62,17 @@ def forward(self, function_dist, target, *params):

         # Get the log prob of the marginal distribution
         output = self.likelihood(function_dist, *params)
-        res = output.log_prob(target)
-        res = self._add_other_terms(res, params)
+        res = output.log_prob(target, combine_terms=self.combine_terms)

         # Scale by the amount of data we have
         num_data = function_dist.event_shape.numel()
-        return res.div_(num_data)
+
+        if self.combine_terms:
+            res = self._add_other_terms(res, params)
+            return res.div(num_data)
+        else:
+            norm_const = res[-1]
+            other_terms = torch.zeros_like(norm_const)
+            other_terms = self._add_other_terms(other_terms, params)
+            res.append(other_terms)
+            return [term.div(num_data) for term in res]
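In other words, with `combine_terms=False` the forward above returns a list of four elements: the quadratic term, the log-determinant term, the normalization constant, and a final element collecting the "other terms" (priors, added loss terms), each already divided by the number of data points. A short illustration, reusing the assumed model and MLL setup from the sketch near the top of this page:

```python
# Assumes `model`, `likelihood`, `train_x`, `train_y`, and `mll` (built with
# combine_terms=False) from the earlier sketch.
output = model(train_x)
quad_term, logdet_term, norm_const, other_terms = mll(output, train_y)

# Summing the four pieces gives the same value the combined MLL would return,
# so the usual training loss is just the negative sum.
loss = -(quad_term + logdet_term + norm_const + other_terms)
loss.backward()
```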
@@ -219,20 +219,32 @@ def test_log_prob(self, cuda=False):
             var = torch.randn(4, device=device, dtype=dtype).abs_()
             values = torch.randn(4, device=device, dtype=dtype)

-            res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
+            mvn = MultivariateNormal(mean, DiagLazyTensor(var))
+            res = mvn.log_prob(values)
             actual = TMultivariateNormal(mean, torch.eye(4, device=device, dtype=dtype) * var).log_prob(values)
             self.assertLess((res - actual).div(res).abs().item(), 1e-2)

+            res2 = mvn.log_prob(values, combine_terms=False)
+            assert len(res2) == 3
Review comment: Suggested change
Review comment: also in other places in the tests below
+            res2 = sum(res2)
+            self.assertLess((res2 - actual).div(res).abs().item(), 1e-2)

             mean = torch.randn(3, 4, device=device, dtype=dtype)
             var = torch.randn(3, 4, device=device, dtype=dtype).abs_()
             values = torch.randn(3, 4, device=device, dtype=dtype)

-            res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
+            mvn = MultivariateNormal(mean, DiagLazyTensor(var))
+            res = mvn.log_prob(values)
             actual = TMultivariateNormal(
                 mean, var.unsqueeze(-1) * torch.eye(4, device=device, dtype=dtype).repeat(3, 1, 1)
             ).log_prob(values)
             self.assertLess((res - actual).div(res).abs().norm(), 1e-2)

+            res2 = mvn.log_prob(values, combine_terms=False)
+            assert len(res2) == 3
+            res2 = sum(res2)
+            self.assertLess((res2 - actual).div(res).abs().norm(), 1e-2)

     def test_log_prob_cuda(self):
         if torch.cuda.is_available():
             with least_used_cuda_device():
@@ -0,0 +1,57 @@
+import unittest
+
+import torch
+
+import gpytorch
+
+from .test_leave_one_out_pseudo_likelihood import ExactGPModel
+
+
+class TestExactMarginalLogLikelihood(unittest.TestCase):
+    def get_data(self, shapes, combine_terms, dtype=None, device=None):
+        train_x = torch.rand(*shapes, dtype=dtype, device=device, requires_grad=True)
+        train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
+        likelihood = gpytorch.likelihoods.GaussianLikelihood().to(dtype=dtype, device=device)
+        model = ExactGPModel(train_x, train_y, likelihood).to(dtype=dtype, device=device)
+        exact_mll = gpytorch.mlls.ExactMarginalLogLikelihood(
+            likelihood=likelihood, model=model, combine_terms=combine_terms
+        )
+        return train_x, train_y, exact_mll
+
+    def test_smoke(self):
+        """Make sure the exact_mll works without batching."""
+        train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=True)
+        output = exact_mll.model(train_x)
+        loss = -exact_mll(output, train_y)
+        loss.backward()
+        self.assertTrue(train_x.grad is not None)
+
+        train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=False)
+        output = exact_mll.model(train_x)
+        mll_out = exact_mll(output, train_y)
+        loss = -1 * sum(mll_out)
+        loss.backward()
+        assert len(mll_out) == 4
+        self.assertTrue(train_x.grad is not None)
+
+    def test_smoke_batch(self):
+        """Make sure the exact_mll works with batching."""
+        train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=True)
+        output = exact_mll.model(train_x)
+        loss = -exact_mll(output, train_y)
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
+        train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=False)
+        output = exact_mll.model(train_x)
+        mll_out = exact_mll(output, train_y)
+        loss = -1 * sum(mll_out)
+        assert len(mll_out) == 4
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
+
+if __name__ == "__main__":
+    unittest.main()
Review comment: In general, I don't think we want to be adding flags to the standard `log_prob` call here, to maintain compatibility with the MVN API in PyTorch. Let's have this be a `_log_prob` method, with `log_prob` just calling `_log_prob(value=value, combine_terms=True)`?
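A rough sketch of what that suggestion could look like (a hypothetical shape, assumed from the comment above, not code that is part of this PR):

```python
# Hypothetical shape of the suggested refactor: the public log_prob keeps the
# torch.distributions.MultivariateNormal signature, and the combine_terms logic
# moves to a private helper that callers like ExactMarginalLogLikelihood can use.
class MultivariateNormal:  # stands in for gpytorch.distributions.MultivariateNormal
    def _log_prob(self, value, combine_terms=True):
        # Body of log_prob from the diff above, parameterized by combine_terms:
        # return -0.5 * sum(split_terms) when combining, otherwise the list of
        # separately scaled terms.
        raise NotImplementedError("illustrative stub")

    def log_prob(self, value):
        # Public API stays identical to the upstream MVN.
        return self._log_prob(value=value, combine_terms=True)
```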