|
4 | 4 |
|
5 | 5 | import warnings |
6 | 6 | from abc import abstractmethod |
| 7 | +from collections import defaultdict, OrderedDict |
7 | 8 | from copy import deepcopy |
8 | | -from typing import Callable, Dict, Iterable, Optional, Tuple, Union |
| 9 | +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union |
9 | 10 |
|
10 | 11 | import torch |
11 | 12 | from linear_operator import to_dense, to_linear_operator |
12 | | -from linear_operator.operators import LinearOperator, ZeroLinearOperator |
| 13 | +from linear_operator.operators import KernelLinearOperator, LinearOperator, ZeroLinearOperator |
13 | 14 | from torch import Tensor |
14 | 15 | from torch.nn import ModuleList |
15 | 16 |
|
@@ -81,6 +82,44 @@ def _dist(self, x1, x2, x1_eq_x2=False, postprocess=False): |
81 | 82 | return self._postprocess(res) if postprocess else res |
82 | 83 |
|
83 | 84 |
|
| 85 | +class _autograd_kernel_hack: |
| 86 | + """ |
| 87 | + Helper class. |
| 88 | +
|
| 89 | + When using KernelLinearOperator, the `covar_func` cannot close over any Tensors that require gradients. |
| 90 | + (Any Tensor that `covar_func` closes over will not backpropagate gradients.) |
| 91 | + Unfortunately, for most kernels, `covar_func=self.forward`, which closes over all of the kernel's parameters. |
| 92 | +
|
| 93 | +    This context manager temporarily replaces a kernel's (and its submodules') parameter assignments with an |
| 94 | + external set of references to these parameters. |
| 95 | + The external set of references will be passed in by KernelLinearOperator. |
| 96 | +
|
| 97 | + This way, when calling self.forward, no parameter references are closed over, and so all parameters |
| 98 | + will receive the appropriate gradients. |
| 99 | + """ |
| 100 | + |
| 101 | + def __init__( |
| 102 | + self, |
| 103 | + kernel: Kernel, |
| 104 | +        params: Dict[str, torch.nn.Parameter], |
| 105 | + module_params: Dict[torch.nn.Module, Iterable[str]], |
| 106 | + ): |
| 107 | + self.temp_module_param_dicts = defaultdict(OrderedDict) |
| 108 | + for module, param_names in module_params.items(): |
| 109 | + self.temp_module_param_dicts[module] = OrderedDict( |
| 110 | + (param_name.rsplit(".", 1)[-1], params[param_name]) for param_name in param_names |
| 111 | + ) |
| 112 | +        self.orig_module_param_dicts = dict((module, module._parameters) for module in self.temp_module_param_dicts) |
| 113 | + |
| 114 | + def __enter__(self): |
| 115 | + for module, temp_param_dict in self.temp_module_param_dicts.items(): |
| 116 | + object.__setattr__(module, "_parameters", temp_param_dict) |
| 117 | + |
| 118 | + def __exit__(self, type, value, traceback): |
| 119 | +        for module, orig_param_dict in self.orig_module_param_dicts.items(): |
| 120 | + object.__setattr__(module, "_parameters", orig_param_dict) |
| 121 | + |
| 122 | + |
84 | 123 | class Kernel(Module): |
85 | 124 | r""" |
86 | 125 | Kernels in GPyTorch are implemented as a :class:`gpytorch.Module` that, when called on two :class:`torch.Tensor` |
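
For intuition, the trick that the `_autograd_kernel_hack` helper above relies on can be demonstrated in isolation. The following standalone sketch is not part of the diff; it uses a plain `torch.nn.Linear` as a stand-in for a kernel module and swaps its `_parameters` dict for an externally supplied tensor, so that a call which closes over the module still backpropagates into the external reference.

    import torch
    from collections import OrderedDict

    lin = torch.nn.Linear(2, 1)
    external_weight = lin.weight.detach().clone().requires_grad_(True)

    orig = lin._parameters
    temp = OrderedDict(orig)                 # keep the other entries (e.g. the bias)
    temp["weight"] = external_weight         # swap in the external reference
    object.__setattr__(lin, "_parameters", temp)
    try:
        out = lin(torch.randn(3, 2)).sum()   # closes over `lin`, but actually uses external_weight
    finally:
        object.__setattr__(lin, "_parameters", orig)

    out.backward()
    print(external_weight.grad is not None)  # True: gradients reach the external reference
    print(lin.weight.grad)                   # None: the original parameter was bypassed

The context manager automates exactly this swap for the kernel and each of its submodules, using the parameter references that KernelLinearOperator passes into `covar_func`, and restores the original assignments on exit.
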
@@ -212,6 +251,45 @@ def __init__( |
212 | 251 | # TODO: Remove this on next official PyTorch release. |
213 | 252 | self.__pdist_supports_batch = True |
214 | 253 |
|
| 254 | + @property |
| 255 | + def _lazily_evaluate(self) -> bool: |
| 256 | + r""" |
| 257 | + Determines whether or not the kernel is lazily evaluated. |
| 258 | +
|
| 259 | + If False, kernel(x1, x2) produces a Tensor/LinearOperator where the covariance function has been evaluated |
| 260 | + over x1 and x2. |
| 261 | +
|
| 262 | + If True, kernel(x1, x2) produces a KernelLinearOperator that delays evaluation of the kernel function. |
| 263 | + The kernel function will only be evaluated when either |
| 264 | +        - A mathematical operation is performed on the kernel matrix (e.g. solves, logdets, etc.), or |
| 265 | + - An indexing operation is performed on the kernel matrix to select specific covariance entries. |
| 266 | +
|
| 267 | +        In general, _lazily_evaluate should return True (this option is more efficient), unless lazy |
| 268 | +        evaluation offers no gains and would discard specific structure that the kernel exploits |
| 269 | +        (e.g. low-rank/Nystrom approximations). |
| 270 | + """ |
| 271 | + return True |
| 272 | + |
| 273 | + def _kernel_linear_operator_covar_func( |
| 274 | + self, |
| 275 | + x1: Tensor, |
| 276 | + x2: Tensor, |
| 277 | + non_param_kwargs: Dict[str, Any], |
| 278 | + module_params: Dict[torch.nn.Module, Iterable[str]], |
| 279 | + **params: torch.nn.Parameter, |
| 280 | + ) -> Union[Tensor, LinearOperator]: |
| 281 | +        # This is the `covar_func` that is passed into KernelLinearOperator. |
| 282 | +        # It calls self.forward, but in a way that ensures no parameters are closed over |
| 283 | +        # (by using the _autograd_kernel_hack context manager), so that every parameter |
| 284 | +        # receives the appropriate gradients. |
| 285 | +        if any(param.requires_grad for param in params.values()): |
| 286 | +            # Temporarily swap the kernel's (and submodules') parameter assignments for the |
| 287 | +            # external references that KernelLinearOperator passes in as **params. |
| 288 | +            with _autograd_kernel_hack(self, params, module_params): |
| 289 | +                return self.forward(x1, x2, **non_param_kwargs) |
| 290 | +        else: |
| 291 | +            return self.forward(x1, x2, **non_param_kwargs) |
| 292 | + |
215 | 293 | def _lengthscale_param(self, m: Kernel) -> Tensor: |
216 | 294 | # Used by the lengthscale_prior |
217 | 295 | return m.lengthscale |
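
As a usage-level illustration of what `_lazily_evaluate` controls, here is a sketch (not part of the diff; it assumes the stock `RBFKernel`, which presumably keeps the base class default of `True`):

    import torch
    import gpytorch

    kernel = gpytorch.kernels.RBFKernel()
    x1, x2 = torch.randn(100, 3), torch.randn(100, 3)

    lazy_covar = kernel(x1, x2)      # no O(n^2) evaluation yet: a lazily evaluated LinearOperator
    print(lazy_covar.shape)          # torch.Size([100, 100])
    dense = lazy_covar.to_dense()    # the covariance function is evaluated on demand

    with gpytorch.settings.lazily_evaluate_kernels(False):
        eager = kernel(x1, x2)       # evaluated immediately, wrapped as a LinearOperator
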
@@ -501,8 +579,63 @@ def __call__( |
501 | 579 | return res |
502 | 580 |
|
503 | 581 | else: |
504 | | - if settings.lazily_evaluate_kernels.on(): |
505 | | - res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, **params) |
| 582 | + if settings.lazily_evaluate_kernels.on() and self._lazily_evaluate: |
| 583 | + num_outputs_per_input = self.num_outputs_per_input(x1_, x2_) |
| 584 | + if isinstance(num_outputs_per_input, int): |
| 585 | + num_outputs_per_input = (num_outputs_per_input, num_outputs_per_input) |
| 586 | + |
| 587 | + def _get_parameter_parent_module_and_batch_shape(module): |
| 588 | + num_module_batch_dimension = len(module.batch_shape) if isinstance(module, Kernel) else 0 |
| 589 | + for name, param in module._parameters.items(): |
| 590 | + yield name, (param, module, param.dim() - num_module_batch_dimension) |
| 591 | + |
| 592 | +                # The following returns a tuple of entries, one per parameter of this kernel and its sub-modules: |
| 593 | +                # (param_name, (param_val, param_parent_module, param_num_nonbatch_dimensions)) |
| 594 | + named_parameters_parent_modules_and_batch_dimensions = tuple( |
| 595 | + self._named_members( |
| 596 | + _get_parameter_parent_module_and_batch_shape, |
| 597 | + prefix="", |
| 598 | + recurse=True, |
| 599 | + ) |
| 600 | + ) |
| 601 | + |
| 602 | + if len(named_parameters_parent_modules_and_batch_dimensions): |
| 603 | + # Information we need for the KernelLinearOperator, as well as the autograd hack: |
| 604 | + # - the names/values of all parameters |
| 605 | + # - the parent module associated with each parameter |
| 606 | + # - the number of non-batch dimensions associated with each parameter |
| 607 | +                    # We get this information from the list constructed in the previous step |
| 608 | +                    kernel_params = dict() |
| 609 | + module_params = defaultdict(list) |
| 610 | + num_nonbatch_dimensions = dict() |
| 611 | + for name, ( |
| 612 | + param, |
| 613 | + parent_module, |
| 614 | + num_nonbatch_dimension, |
| 615 | + ) in named_parameters_parent_modules_and_batch_dimensions: |
| 616 | +                        kernel_params[name] = param |
| 617 | + module_params[parent_module].append(name) |
| 618 | + num_nonbatch_dimensions[name] = num_nonbatch_dimension |
| 619 | + |
| 620 | + # Construct the KernelLinearOperator |
| 621 | + res = KernelLinearOperator( |
| 622 | + x1_, |
| 623 | + x2_, |
| 624 | + covar_func=self._kernel_linear_operator_covar_func, |
| 625 | + num_outputs_per_input=num_outputs_per_input, |
| 626 | + num_nonbatch_dimensions=num_nonbatch_dimensions, |
| 627 | + module_params=module_params, # params for _kernel_linear_operator_covar_func |
| 628 | +                        non_param_kwargs=dict(**params),  # kwargs for self.forward |
| 629 | +                        **kernel_params, |
| 630 | + ) |
| 631 | + else: |
| 632 | + res = KernelLinearOperator( |
| 633 | + x1_, |
| 634 | + x2_, |
| 635 | + covar_func=self.forward, |
| 636 | + num_outputs_per_input=num_outputs_per_input, |
| 637 | + non_param_kwargs=dict(**params), # params for forward |
| 638 | + ) |
506 | 639 | else: |
507 | 640 | res = to_linear_operator(super(Kernel, self).__call__(x1_, x2_, **params)) |
508 | 641 | return res |
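
For reference, a minimal sketch of the `KernelLinearOperator` contract that `__call__` now builds on: hyperparameters are passed as explicit keyword arguments and forwarded to `covar_func`, so nothing is closed over. The `covar_func` and `lengthscale` below are illustrative stand-ins rather than GPyTorch's own, and the sketch assumes `linear_operator` behaves as documented.

    import torch
    from linear_operator.operators import KernelLinearOperator

    def covar_func(x1, x2, lengthscale):
        # A plain RBF-style covariance. The hyperparameter arrives as an explicit argument,
        # so nothing is closed over and gradients flow through `lengthscale`.
        dist2 = torch.cdist(x1 / lengthscale, x2 / lengthscale).pow(2)
        return dist2.div(-2.0).exp()

    x1, x2 = torch.randn(20, 3), torch.randn(30, 3)
    lengthscale = torch.nn.Parameter(torch.ones(1, 3))

    covar = KernelLinearOperator(x1, x2, lengthscale=lengthscale, covar_func=covar_func)
    covar.to_dense().sum().backward()
    print(lengthscale.grad.shape)    # torch.Size([1, 3]): gradients reach the hyperparameter

In `__call__`, the hyperparameters are the kernel's (and submodules') named parameters, and `covar_func` is `_kernel_linear_operator_covar_func`, which routes back to `self.forward`.
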
@@ -575,13 +708,17 @@ class AdditiveKernel(Kernel): |
575 | 708 | :param kernels: Kernels to add together. |
576 | 709 | """ |
577 | 710 |
|
| 711 | + def __init__(self, *kernels: Iterable[Kernel]): |
| 712 | + super(AdditiveKernel, self).__init__() |
| 713 | + self.kernels = ModuleList(kernels) |
| 714 | + |
578 | 715 | @property |
579 | 716 | def is_stationary(self) -> bool: |
580 | 717 | return all(k.is_stationary for k in self.kernels) |
581 | 718 |
|
582 | | - def __init__(self, *kernels: Iterable[Kernel]): |
583 | | - super(AdditiveKernel, self).__init__() |
584 | | - self.kernels = ModuleList(kernels) |
| 719 | + @property |
| 720 | + def _lazily_evaluate(self) -> bool: |
| 721 | + return all(k._lazily_evaluate for k in self.kernels) |
585 | 722 |
|
586 | 723 | def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]: |
587 | 724 | res = ZeroLinearOperator() if not diag else 0 |
@@ -617,13 +754,17 @@ class ProductKernel(Kernel): |
617 | 754 | :param kernels: Kernels to multiply together. |
618 | 755 | """ |
619 | 756 |
|
| 757 | + def __init__(self, *kernels: Iterable[Kernel]): |
| 758 | + super(ProductKernel, self).__init__() |
| 759 | + self.kernels = ModuleList(kernels) |
| 760 | + |
620 | 761 | @property |
621 | 762 | def is_stationary(self) -> bool: |
622 | 763 | return all(k.is_stationary for k in self.kernels) |
623 | 764 |
|
624 | | - def __init__(self, *kernels: Iterable[Kernel]): |
625 | | - super(ProductKernel, self).__init__() |
626 | | - self.kernels = ModuleList(kernels) |
| 765 | + @property |
| 766 | + def _lazily_evaluate(self) -> bool: |
| 767 | + return False |
627 | 768 |
|
628 | 769 | def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]: |
629 | 770 | x1_eq_x2 = torch.equal(x1, x2) |
|
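
The practical effect of the two `_lazily_evaluate` properties, sketched below (not part of the diff; assumes the component kernels keep the base class default of `True`):

    import gpytorch

    rbf = gpytorch.kernels.RBFKernel()
    matern = gpytorch.kernels.MaternKernel(nu=2.5)

    print((rbf + matern)._lazily_evaluate)   # True:  sums stay lazy when every component is lazy
    print((rbf * matern)._lazily_evaluate)   # False: products are always evaluated eagerly
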