
Commit daf43a3

Commit message: WIP
1 parent dc61fb1 commit daf43a3

17 files changed: +226 −35 lines

gpytorch/kernels/cosine_kernel.py
Lines changed: 4 additions & 2 deletions

@@ -56,8 +56,6 @@ class CosineKernel(Kernel):
         >>> covar = covar_module(x)  # Output: LazyVariable of size (2 x 10 x 10)
     """
 
-    is_stationary = True
-
     def __init__(
         self,
         period_length_prior: Optional[Prior] = None,
@@ -85,6 +83,10 @@ def __init__(
 
         self.register_constraint("raw_period_length", period_length_constraint)
 
+    @property
+    def is_stationary(self):
+        return True
+
     @property
     def period_length(self):
         return self.raw_period_length_constraint.transform(self.raw_period_length)

gpytorch/kernels/cylindrical_kernel.py
Lines changed: 1 addition & 3 deletions

@@ -4,7 +4,6 @@
 
 import torch
 
-from .. import settings
 from ..constraints import Interval, Positive
 from ..priors import Prior
 from .kernel import Kernel
@@ -152,8 +151,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = Fal
         else:
             angular_kernel = angular_kernel + self.angular_weights[..., p, None].mul(gram_mat.pow(p))
 
-        with settings.lazily_evaluate_kernels(False):
-            radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params)
+        radial_kernel = self.radial_base_kernel.forward(self.kuma(r1), self.kuma(r2), diag=diag, **params)
         return radial_kernel.mul(angular_kernel)
 
     def kuma(self, x: torch.Tensor) -> torch.Tensor:
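Side note on this change: calling `self.radial_base_kernel.forward(...)` bypasses `Kernel.__call__` and any lazy-evaluation machinery, returning the evaluated covariance directly, which is what the removed `settings.lazily_evaluate_kernels(False)` context used to guarantee. A minimal sketch of the distinction, using `RBFKernel` purely for illustration:

    import torch
    from gpytorch.kernels import RBFKernel

    x1, x2 = torch.randn(5, 2), torch.randn(7, 2)
    kern = RBFKernel()

    eager = kern.forward(x1, x2)   # evaluated immediately: a 5 x 7 Tensor
    lazy = kern(x1, x2)            # goes through __call__ and may be a lazily evaluated operator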

gpytorch/kernels/grid_interpolation_kernel.py
Lines changed: 7 additions & 0 deletions

@@ -115,6 +115,13 @@ def __init__(
         )
         self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # GridInterpolationKernels should not lazily evaluate; there are few gains (the inducing point kernel
+        # matrix always needs to be evaluated, regardless of the size of x1 and x2), and the
+        # InterpolatedLinearOperator structure is needed for fast predictions.
+        return False
+
     @property
     def _tight_grid_bounds(self):
         grid_spacings = tuple((bound[1] - bound[0]) / self.grid_sizes[i] for i, bound in enumerate(self.grid_bounds))
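The pattern introduced here (and repeated in several kernels below) is an opt-out: kernels whose covariance matrix must be materialized anyway override the new `_lazily_evaluate` property. A minimal sketch of a hypothetical custom kernel using the same opt-out (the class name and its toy dot-product covariance are made up for illustration):

    from gpytorch.kernels import Kernel

    class MyEagerKernel(Kernel):
        @property
        def _lazily_evaluate(self) -> bool:
            return False  # kernel(x1, x2) will call forward() right away

        def forward(self, x1, x2, diag=False, **params):
            res = x1 @ x2.transpose(-2, -1)  # toy covariance: a plain dot product
            return res.diagonal(dim1=-2, dim2=-1) if diag else res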

gpytorch/kernels/grid_kernel.py
Lines changed: 9 additions & 4 deletions

@@ -41,10 +41,6 @@ class GridKernel(Kernel):
     http://www.cs.cmu.edu/~andrewgw/manet.pdf
     """
 
-    # TODO: update doc
-
-    is_stationary = True
-
     def __init__(
         self,
         base_kernel: Kernel,
@@ -76,6 +72,15 @@ def __init__(
         # Also create the full_grid buffer
         self.update_grid(grid)
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # Toeplitz structure is very efficient; no need to lazily evaluate
+        return False
+
+    @property
+    def is_stationary(self) -> bool:
+        return True
+
     def _clear_cache(self):
         if hasattr(self, "_cached_kernel_mat"):
             del self._cached_kernel_mat
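The comment above leans on a known property: a stationary kernel evaluated on an evenly spaced grid produces a (symmetric) Toeplitz matrix, so one row/column determines the whole matrix and lazy evaluation buys little. A quick numerical check of that structure (illustrative only, not part of the commit):

    import torch
    from gpytorch.kernels import RBFKernel

    grid = torch.linspace(0, 1, 6).unsqueeze(-1)       # 6 evenly spaced 1-D points
    with torch.no_grad():
        K = RBFKernel().forward(grid, grid)            # full 6 x 6 covariance matrix
    print(torch.allclose(K[1:, 1:], K[:-1, :-1]))      # True: entries depend only on i - j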

gpytorch/kernels/index_kernel.py
Lines changed: 6 additions & 0 deletions

@@ -76,6 +76,12 @@ def __init__(
 
         self.register_constraint("raw_var", var_constraint)
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # IndexKernel does not need lazy evaluation, since the complete BB^T + D_v is always
+        # computed regardless of x1 and x2
+        return False
+
     @property
     def var(self):
         return self.raw_var_constraint.transform(self.raw_var)

gpytorch/kernels/inducing_point_kernel.py
Lines changed: 6 additions & 0 deletions

@@ -47,6 +47,12 @@ def _clear_cache(self):
         if hasattr(self, "_cached_kernel_inv_root"):
             del self._cached_kernel_inv_root
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # InducingPointKernels should not lazily evaluate; to use the Woodbury formula,
+        # we want the kernel to return a LowRankLinearOperator, not a KernelLinearOperator.
+        return False
+
     @property
     def _inducing_mat(self):
         if not self.training and hasattr(self, "_cached_kernel_mat"):
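The Woodbury reasoning in the new comment: an inducing-point (Nystrom-style) approximation gives the covariance low-rank structure, so inverting it only needs a small m x m solve; that structure would be hidden behind a generic KernelLinearOperator. A quick numerical check of the identity (sigma^2 I + U U^T)^{-1} = sigma^{-2} (I - U (sigma^2 I + U^T U)^{-1} U^T), with made-up sizes:

    import torch

    n, m, sigma2 = 300, 10, 0.1
    U = torch.randn(n, m, dtype=torch.float64)
    I_n, I_m = torch.eye(n, dtype=torch.float64), torch.eye(m, dtype=torch.float64)
    lhs = torch.linalg.inv(sigma2 * I_n + U @ U.T)                              # direct n x n inverse
    rhs = (I_n - U @ torch.linalg.solve(sigma2 * I_m + U.T @ U, U.T)) / sigma2  # only an m x m solve
    print(torch.allclose(lhs, rhs))                                             # True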

gpytorch/kernels/keops/rbf_kernel.py
Lines changed: 0 additions & 1 deletion

@@ -1,6 +1,5 @@
 #!/usr/bin/env python3
 
-# from linear_operator.operators import KeOpsLinearOperator
 from linear_operator.operators import KernelLinearOperator
 
 from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel

gpytorch/kernels/kernel.py
Lines changed: 151 additions & 10 deletions

@@ -4,12 +4,13 @@
 
 import warnings
 from abc import abstractmethod
+from collections import defaultdict, OrderedDict
 from copy import deepcopy
-from typing import Callable, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 from linear_operator import to_dense, to_linear_operator
-from linear_operator.operators import LinearOperator, ZeroLinearOperator
+from linear_operator.operators import KernelLinearOperator, LinearOperator, ZeroLinearOperator
 from torch import Tensor
 from torch.nn import ModuleList
 
@@ -81,6 +82,44 @@ def _dist(self, x1, x2, x1_eq_x2=False, postprocess=False):
         return self._postprocess(res) if postprocess else res
 
 
+class _autograd_kernel_hack:
+    """
+    Helper class.
+
+    When using KernelLinearOperator, the `covar_func` cannot close over any Tensors that require gradients.
+    (Any Tensor that `covar_func` closes over will not backpropagate gradients.)
+    Unfortunately, for most kernels, `covar_func=self.forward`, which closes over all of the kernel's parameters.
+
+    This context manager temporarily replaces a kernel's (and its submodules') parameter assignments with an
+    external set of references to these parameters.
+    The external set of references will be passed in by KernelLinearOperator.
+
+    This way, when calling self.forward, no parameter references are closed over, and so all parameters
+    will receive the appropriate gradients.
+    """
+
+    def __init__(
+        self,
+        kernel: Kernel,
+        params: Dict[str, torch.nn.Parameter],
+        module_params: Dict[torch.nn.Module, Iterable[str]],
+    ):
+        self.temp_module_param_dicts = defaultdict(OrderedDict)
+        for module, param_names in module_params.items():
+            self.temp_module_param_dicts[module] = OrderedDict(
+                (param_name.rsplit(".", 1)[-1], params[param_name]) for param_name in param_names
+            )
+        self.orig_model_param_dicts = dict((module, module._parameters) for module in self.temp_module_param_dicts)
+
+    def __enter__(self):
+        for module, temp_param_dict in self.temp_module_param_dicts.items():
+            object.__setattr__(module, "_parameters", temp_param_dict)
+
+    def __exit__(self, type, value, traceback):
+        for module, orig_param_dict in self.orig_model_param_dicts.items():
+            object.__setattr__(module, "_parameters", orig_param_dict)
+
+
 class Kernel(Module):
     r"""
     Kernels in GPyTorch are implemented as a :class:`gpytorch.Module` that, when called on two :class:`torch.Tensor`
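The hack above works because `torch.nn.Module` resolves parameter attributes through its `_parameters` dict, so temporarily swapping that dict redirects attribute lookups to externally supplied references. A standalone illustration of the mechanism, using a plain `Linear` module and made-up names:

    import torch
    from collections import OrderedDict

    lin = torch.nn.Linear(2, 2, bias=False)
    external = torch.nn.Parameter(torch.zeros(2, 2))

    original = lin._parameters
    object.__setattr__(lin, "_parameters", OrderedDict(weight=external))
    print(lin.weight is external)   # True: lookups now resolve to the external reference
    object.__setattr__(lin, "_parameters", original)
    print(lin.weight is external)   # False: the original parameters are restored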
@@ -212,6 +251,45 @@ def __init__(
         # TODO: Remove this on next official PyTorch release.
         self.__pdist_supports_batch = True
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        r"""
+        Determines whether or not the kernel is lazily evaluated.
+
+        If False, kernel(x1, x2) produces a Tensor/LinearOperator where the covariance function has been evaluated
+        over x1 and x2.
+
+        If True, kernel(x1, x2) produces a KernelLinearOperator that delays evaluation of the kernel function.
+        The kernel function will only be evaluated when either
+        - a mathematical operation is performed on the kernel matrix (e.g. solves, logdets, etc.), or
+        - an indexing operation is performed on the kernel matrix to select specific covariance entries.
+
+        In general, _lazily_evaluate should return True (this option is more efficient), unless lazy evaluation
+        offers no gains and there is specific structure that will be lost with lazy evaluation
+        (e.g. low-rank/Nystrom approximations).
+        """
+        return True
+
+    def _kernel_linear_operator_covar_func(
+        self,
+        x1: Tensor,
+        x2: Tensor,
+        non_param_kwargs: Dict[str, Any],
+        module_params: Dict[torch.nn.Module, Iterable[str]],
+        **params: torch.nn.Parameter,
+    ) -> Union[Tensor, LinearOperator]:
+        # This is the `covar_func` that is passed into KernelLinearOperator.
+        # This function calls self.forward, but does so in a way so that no parameters are closed over
+        # (by using the _autograd_kernel_hack context manager)
+        try:
+            if any(param.requires_grad for param in params.values()):
+                with _autograd_kernel_hack(self, params, module_params):
+                    return self.forward(x1, x2, **non_param_kwargs)
+            else:
+                return self.forward(x1, x2, **non_param_kwargs)
+        except Exception as e:
+            raise e
+
     def _lengthscale_param(self, m: Kernel) -> Tensor:
         # Used by the lengthscale_prior
         return m.lengthscale
@@ -501,8 +579,63 @@ def __call__(
             return res
 
         else:
-            if settings.lazily_evaluate_kernels.on():
-                res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, **params)
+            if settings.lazily_evaluate_kernels.on() and self._lazily_evaluate:
+                num_outputs_per_input = self.num_outputs_per_input(x1_, x2_)
+                if isinstance(num_outputs_per_input, int):
+                    num_outputs_per_input = (num_outputs_per_input, num_outputs_per_input)
+
+                def _get_parameter_parent_module_and_batch_shape(module):
+                    num_module_batch_dimension = len(module.batch_shape) if isinstance(module, Kernel) else 0
+                    for name, param in module._parameters.items():
+                        yield name, (param, module, param.dim() - num_module_batch_dimension)
+
+                # The following returns a list of tuples for each parameter + parameters of sub-modules:
+                # (param_name, (param_val, param_parent_module, param_batch_shape))
+                named_parameters_parent_modules_and_batch_dimensions = tuple(
+                    self._named_members(
+                        _get_parameter_parent_module_and_batch_shape,
+                        prefix="",
+                        recurse=True,
+                    )
+                )
+
+                if len(named_parameters_parent_modules_and_batch_dimensions):
+                    # Information we need for the KernelLinearOperator, as well as the autograd hack:
+                    # - the names/values of all parameters
+                    # - the parent module associated with each parameter
+                    # - the number of non-batch dimensions associated with each parameter
+                    # We get this information from the list constructed in the previous step
+                    params = dict()
+                    module_params = defaultdict(list)
+                    num_nonbatch_dimensions = dict()
+                    for name, (
+                        param,
+                        parent_module,
+                        num_nonbatch_dimension,
+                    ) in named_parameters_parent_modules_and_batch_dimensions:
+                        params[name] = param
+                        module_params[parent_module].append(name)
+                        num_nonbatch_dimensions[name] = num_nonbatch_dimension
+
+                    # Construct the KernelLinearOperator
+                    res = KernelLinearOperator(
+                        x1_,
+                        x2_,
+                        covar_func=self._kernel_linear_operator_covar_func,
+                        num_outputs_per_input=num_outputs_per_input,
+                        num_nonbatch_dimensions=num_nonbatch_dimensions,
+                        module_params=module_params,  # params for _kernel_linear_operator_covar_func
+                        non_param_kwargs=dict(**params),  # params for forward
+                        **params,
+                    )
+                else:
+                    res = KernelLinearOperator(
+                        x1_,
+                        x2_,
+                        covar_func=self.forward,
+                        num_outputs_per_input=num_outputs_per_input,
+                        non_param_kwargs=dict(**params),  # params for forward
+                    )
             else:
                 res = to_linear_operator(super(Kernel, self).__call__(x1_, x2_, **params))
         return res
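What this branch changes in practice: with lazy kernel evaluation enabled (the default), calling a kernel now builds a KernelLinearOperator (previously a LazyEvaluatedKernelTensor) whose entries are computed only when an operation demands them. A hedged sketch of the intended behavior; the exact return types and methods depend on the gpytorch/linear_operator versions in use:

    import torch
    import gpytorch
    from gpytorch.kernels import RBFKernel

    x = torch.randn(100, 3)
    kern = RBFKernel()

    with gpytorch.settings.lazily_evaluate_kernels(True):
        K = kern(x, x)          # lazy operator; forward() has not been called yet
        sub = K[..., :5, :5]    # indexing selects/evaluates only the requested entries
        dense = K.to_dense()    # forces a full evaluation through kern.forward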
@@ -575,13 +708,17 @@ class AdditiveKernel(Kernel):
     :param kernels: Kernels to add together.
     """
 
+    def __init__(self, *kernels: Iterable[Kernel]):
+        super(AdditiveKernel, self).__init__()
+        self.kernels = ModuleList(kernels)
+
     @property
     def is_stationary(self) -> bool:
         return all(k.is_stationary for k in self.kernels)
 
-    def __init__(self, *kernels: Iterable[Kernel]):
-        super(AdditiveKernel, self).__init__()
-        self.kernels = ModuleList(kernels)
+    @property
+    def _lazily_evaluate(self) -> bool:
+        return all(k._lazily_evaluate for k in self.kernels)
 
     def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
         res = ZeroLinearOperator() if not diag else 0
@@ -617,13 +754,17 @@ class ProductKernel(Kernel):
     :param kernels: Kernels to multiply together.
     """
 
+    def __init__(self, *kernels: Iterable[Kernel]):
+        super(ProductKernel, self).__init__()
+        self.kernels = ModuleList(kernels)
+
     @property
     def is_stationary(self) -> bool:
         return all(k.is_stationary for k in self.kernels)
 
-    def __init__(self, *kernels: Iterable[Kernel]):
-        super(ProductKernel, self).__init__()
-        self.kernels = ModuleList(kernels)
+    @property
+    def _lazily_evaluate(self) -> bool:
+        return False
 
     def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
         x1_eq_x2 = torch.equal(x1, x2)

gpytorch/kernels/linear_kernel.py
Lines changed: 6 additions & 0 deletions

@@ -72,6 +72,12 @@ def __init__(
 
         self.register_constraint("raw_variance", variance_constraint)
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        # LinearKernel should not lazily evaluate; to use the Woodbury formula,
+        # we want the kernel to return a LowRankLinearOperator, not a KernelLinearOperator.
+        return False
+
     @property
     def variance(self) -> Tensor:
         return self.raw_variance_constraint.transform(self.raw_variance)

gpytorch/kernels/multi_device_kernel.py
Lines changed: 8 additions & 0 deletions

@@ -42,10 +42,18 @@ def __init__(
         self.__cached_x1 = torch.empty(1)
         self.__cached_x2 = torch.empty(1)
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        return self.base_kernel._lazily_evaluate
+
     @property
     def base_kernel(self):
         return self.module
 
+    @property
+    def is_stationary(self):
+        return self.base_kernel.is_stationary
+
     def forward(self, x1, x2, diag=False, **kwargs):
         if diag:
             return self.module.forward(x1, x2, diag=True, **kwargs).to(self.output_device)
