@@ -2,12 +2,34 @@
 
 import torch
 
-from ..distributions import MultitaskMultivariateNormal
-from ..lazy import KroneckerProductLazyTensor, MatmulLazyTensor
+from .. import settings
+from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
+from ..lazy import KroneckerProductLazyTensor, RootLazyTensor
 from ..module import Module
+from ..utils.broadcasting import _mul_broadcast_shape
+from ..utils.interpolation import left_interp
 from ._variational_strategy import _VariationalStrategy
 
 
+def _select_lmc_coefficients(lmc_coefficients: torch.Tensor, indices: torch.LongTensor) -> torch.Tensor:
+    """
+    Given a list of indices for ... x N datapoints,
+    select the row of lmc_coefficients that corresponds to each datapoint.
+
+    lmc_coefficients: torch.Tensor ... x num_latents x ... x num_tasks
+    indices: torch.LongTensor ... x N
+    """
+    batch_shape = _mul_broadcast_shape(lmc_coefficients.shape[:-1], indices.shape[:-1])
+
+    # We will use the left_interp helper to do the indexing
+    lmc_coefficients = lmc_coefficients.expand(*batch_shape, lmc_coefficients.shape[-1])[..., None]
+    indices = indices.expand(*batch_shape, indices.shape[-1])[..., None]
+    res = left_interp(
+        indices, torch.ones(indices.shape, dtype=torch.long, device=indices.device), lmc_coefficients,
+    ).squeeze(-1)
+    return res
+
+
 class LMCVariationalStrategy(_VariationalStrategy):
     r"""
     LMCVariationalStrategy is an implementation of the "Linear Model of Coregionalization"
@@ -20,8 +42,11 @@ class LMCVariationalStrategy(_VariationalStrategy):
 
         f_{\text{task } i}( \mathbf x) = \sum_{q=1}^Q a_i^{(q)} g^{(q)} ( \mathbf x )
 
-    LMCVariationalStrategy wraps an existing :obj:`~gpytorch.variational.VariationalStrategy`
-    to produce a :obj:`~gpytorch.variational.MultitaskMultivariateNormal` distribution.
+    LMCVariationalStrategy wraps an existing :obj:`~gpytorch.variational.VariationalStrategy`.
+    The output will either be a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution
+    (if we wish to evaluate all tasks for each input) or a :obj:`~gpytorch.distributions.MultivariateNormal`
+    (if we wish to evaluate a single task for each input).
+
     The base variational strategy is assumed to operate on a multi-batch of GPs, where one
     of the batch dimensions corresponds to the latent function dimension.
 
@@ -35,13 +60,6 @@ class LMCVariationalStrategy(_VariationalStrategy):
     batch shape. This would correspond to each of the latent functions having different kernels
     or the same kernel, respectively.
 
-    :param ~gpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
-    :param int num_tasks: The total number of tasks (output functions)
-    :param int num_latents: The total number of latent functions in each group
-    :param latent_dim: (Default: -1) Which batch dimension corresponds to the latent function batch.
-        **Must be negative indexed**
-    :type latent_dim: `int` < 0
-
     Example:
         >>> class LMCMultitaskGP(gpytorch.models.ApproximateGP):
         >>>     '''
@@ -74,7 +92,13 @@ class LMCVariationalStrategy(_VariationalStrategy):
         >>>             batch_shape=torch.Size([3]),
         >>>         )
         >>>
-        >>> # Model output: n x 5
+
+    :param ~gpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
+    :param int num_tasks: The total number of tasks (output functions)
+    :param int num_latents: The total number of latent functions in each group
+    :param latent_dim: (Default: -1) Which batch dimension corresponds to the latent function batch.
+        **Must be negative indexed**
+    :type latent_dim: `int` < 0
     """
 
     def __init__(
@@ -120,28 +144,84 @@ def variational_params_initialized(self):
     def kl_divergence(self):
         return super().kl_divergence().sum(dim=self.latent_dim)
 
-    def __call__(self, x, prior=False, **kwargs):
-        function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
-        lmc_coefficients = self.lmc_coefficients.expand(*function_dist.batch_shape, self.lmc_coefficients.size(-1))
-        num_batch = len(function_dist.batch_shape)
-        num_dim = num_batch + len(function_dist.event_shape)
-        latent_dim = num_batch + self.latent_dim if self.latent_dim is not None else None
-
-        # Mean
-        mean = function_dist.mean.permute(*range(0, latent_dim), *range(latent_dim + 1, num_dim), latent_dim)
-        mean = mean @ lmc_coefficients.permute(
-            *range(0, latent_dim), *range(latent_dim + 1, num_dim - 1), latent_dim, -1
-        )
-
-        # Covar
-        covar = function_dist.lazy_covariance_matrix
-        lmc_factor = MatmulLazyTensor(lmc_coefficients.unsqueeze(-1), lmc_coefficients.unsqueeze(-2))
-        covar = KroneckerProductLazyTensor(covar, lmc_factor)
-        covar = covar.sum(latent_dim)
-
-        # Add a bit of jitter to make the covar PD
-        covar = covar.add_jitter(1e-6)
-
-        # Done!
-        function_dist = MultitaskMultivariateNormal(mean, covar)
+    def __call__(self, x, task_indices=None, prior=False, **kwargs):
+        r"""
+        Computes the variational (or prior) distribution
+        :math:`q( \mathbf f \mid \mathbf X)` (or :math:`p( \mathbf f \mid \mathbf X)`).
+        There are two modes:
+
+        1. Compute **all tasks** for all inputs.
+           If this is the case, the :attr:`task_indices` argument should be None.
+           The return type will be a (... x N x num_tasks)
+           :class:`~gpytorch.distributions.MultitaskMultivariateNormal`.
+        2. Compute **one task** per input.
+           If this is the case, the (... x N) :attr:`task_indices` tensor should contain
+           the indices of each input's assigned task.
+           The return type will be a (... x N)
+           :class:`~gpytorch.distributions.MultivariateNormal`.
+
+        :param x: Input locations to evaluate variational strategy
+        :type x: torch.Tensor (... x N x D)
+        :param task_indices: (Default: None) Task index associated with each input.
+            If this **is not** provided, then the returned distribution evaluates every input on every task
+            (returns :class:`~gpytorch.distributions.MultitaskMultivariateNormal`).
+            If this **is** provided, then the returned distribution evaluates each input only on its assigned task
+            (returns :class:`~gpytorch.distributions.MultivariateNormal`).
+        :type task_indices: torch.Tensor (... x N), optional
+        :param prior: (Default: False) If False, returns the variational distribution
+            :math:`q( \mathbf f \mid \mathbf X)`.
+            If True, returns the prior distribution
+            :math:`p( \mathbf f \mid \mathbf X)`.
+        :type prior: bool
+        :return: :math:`q( \mathbf f \mid \mathbf X)` (or the prior),
+            either for all tasks (if `task_indices == None`)
+            or for a specific task (if `task_indices != None`).
+        :rtype: ~gpytorch.distributions.MultitaskMultivariateNormal (... x N x num_tasks)
+            or ~gpytorch.distributions.MultivariateNormal (... x N)
+        """
+        latent_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
+        num_batch = len(latent_dist.batch_shape)
+        latent_dim = num_batch + self.latent_dim
+
+        if task_indices is None:
+            num_dim = num_batch + len(latent_dist.event_shape)
+
+            # Every data point will get an output for each task
+            # Therefore, we will set up the lmc_coefficients shape for a matmul
+            lmc_coefficients = self.lmc_coefficients.expand(*latent_dist.batch_shape, self.lmc_coefficients.size(-1))
+
+            # Mean: ... x N x num_tasks
+            latent_mean = latent_dist.mean.permute(*range(0, latent_dim), *range(latent_dim + 1, num_dim), latent_dim)
+            mean = latent_mean @ lmc_coefficients.permute(
+                *range(0, latent_dim), *range(latent_dim + 1, num_dim - 1), latent_dim, -1
+            )
+
+            # Covar: ... x (N x num_tasks) x (N x num_tasks)
+            latent_covar = latent_dist.lazy_covariance_matrix
+            lmc_factor = RootLazyTensor(lmc_coefficients.unsqueeze(-1))
+            covar = KroneckerProductLazyTensor(latent_covar, lmc_factor).sum(latent_dim)
+            # Add a bit of jitter to make the covar PD
+            covar = covar.add_jitter(settings.cholesky_jitter.value(dtype=mean.dtype))
+
+            # Done!
+            function_dist = MultitaskMultivariateNormal(mean, covar)
+
+        else:
+            # Each data point will get a single output corresponding to a single task
+            # Therefore, we will select the appropriate lmc coefficients for each task
+            lmc_coefficients = _select_lmc_coefficients(self.lmc_coefficients, task_indices)
+
+            # Mean: ... x N
+            mean = (latent_dist.mean * lmc_coefficients).sum(latent_dim)
+
+            # Covar: ... x N x N
+            latent_covar = latent_dist.lazy_covariance_matrix
+            lmc_factor = RootLazyTensor(lmc_coefficients.unsqueeze(-1))
+            covar = (latent_covar * lmc_factor).sum(latent_dim)
+            # Add a bit of jitter to make the covar PD
+            covar = covar.add_jitter(settings.cholesky_jitter.value(dtype=mean.dtype))
+
+            # Done!
+            function_dist = MultivariateNormal(mean, covar)
+
         return function_dist
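
For illustration, a minimal sketch of the two calling modes this change introduces, assuming the LMCMultitaskGP model from the class docstring example above (`model`, `x`, and `task_indices` are placeholder names, and kwargs such as `task_indices` are assumed to be forwarded to the variational strategy by `ApproximateGP.__call__`):

    import torch

    model = LMCMultitaskGP()
    x = torch.randn(10, 2)  # 10 inputs with d=2, matching the docstring example

    # Mode 1: task_indices=None -- every input is evaluated on all 5 tasks
    y_all = model(x)  # MultitaskMultivariateNormal, event shape 10 x 5

    # Mode 2: one task per input -- pass a (... x N) LongTensor of task indices
    task_indices = torch.randint(0, 5, (10,))
    y_one = model(x, task_indices=task_indices)  # MultivariateNormal, event shape 10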
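
The `task_indices is None` branch assembles the multitask covariance as `sum_q K_q kron (a_q a_q^T)`, keeping everything lazy. A dense-tensor sketch of the same computation (made-up sizes, illustrative only) makes the shapes concrete:

    import torch

    num_latents, N, num_tasks = 3, 4, 5
    a = torch.randn(num_latents, num_tasks)  # plays the role of lmc_coefficients
    L = torch.randn(num_latents, N, N)
    K = L @ L.transpose(-1, -2)  # latent covariances K_q, PSD by construction

    # covar = sum_q K_q kron (a_q a_q^T): (N * num_tasks) x (N * num_tasks)
    covar = sum(torch.kron(K[q], torch.outer(a[q], a[q])) for q in range(num_latents))

    # mean[n, t] = sum_q latent_mean[q, n] * a[q, t]: N x num_tasks
    latent_mean = torch.randn(num_latents, N)
    mean = latent_mean.transpose(-1, -2) @ a

In the lazy version, `RootLazyTensor(lmc_coefficients.unsqueeze(-1))` represents each rank-one factor `a_q a_q^T`, and `.sum(latent_dim)` performs the reduction over q without materializing the Kronecker products.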