Hadamard Multitask GP revisited #1979
prockenschaub asked this question in Q&A (unanswered)
Hi,
I am currently trying to implement a Hadamard Multitask GP model. The tutorial in the docs was a great help in getting something up and running quickly, and the modular nature of gpytorch is a joy to work with!
That being said, what I ultimately ended up with feels a bit hacky. Unlike the tutorial, I would like my GP to support different means and noise variables per task. I also wanted to wrap the kernel in a subclass of `gpytorch.kernels.Kernel` to make it more readable and to easily allow for multiple kernels with `gpytorch.kernels.AdditiveKernel`. Similar to #654, I ran into difficulties communicating `task_indices` to all model parts that needed them. I thought I'd share what I ended up with here.
Mean
As in #654, the main problem is that `gpytorch.means.Mean.__call__()` only takes `x` as the input. The solution for me was to reserve `x[..., 0]` for the task indices. It would also have been possible to override the `__call__()` function, as in #654. However, this would not have been as straightforward for other objects like likelihoods, as we'll see below.
Kernel

Since `gpytorch.kernels.Kernel.__call__()` takes a `**params` argument, the kernel could in theory be passed separate `task_indices1` and `task_indices2` parameters. However, to stay compatible with the rest, I again used a part of `x` to carry the task indices.
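A sketch of such a wrapper kernel, assuming the same `x[..., 0]` convention (names are my own): it peels the task index off the inputs and multiplies a data kernel with an `IndexKernel` over tasks, mirroring the Hadamard multitask tutorial.

```python
import torch
import gpytorch

class HadamardTaskKernel(gpytorch.kernels.Kernel):
    """Hypothetical sketch: wraps a data kernel and an IndexKernel,
    with the task index carried in the first column of x."""

    def __init__(self, data_kernel, num_tasks, **kwargs):
        super().__init__(**kwargs)
        self.data_kernel = data_kernel
        self.task_kernel = gpytorch.kernels.IndexKernel(num_tasks=num_tasks)

    def forward(self, x1, x2, diag=False, **params):
        # first column is the task index, the rest are the actual inputs
        i1 = x1[..., 0].long().unsqueeze(-1)
        i2 = x2[..., 0].long().unsqueeze(-1)
        k_data = self.data_kernel(x1[..., 1:], x2[..., 1:], diag=diag, **params)
        k_task = self.task_kernel(i1, i2, diag=diag, **params)
        return k_data.mul(k_task)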
Likelihood

Here is where things get hacky. In order to get one noise parameter per task, I think I need to create my own `Noise` class, which will need to know the task indices `i`. To make sure that `i` is always passed to the noise, I also need to subclass the likelihood, relying on the data being passed as the first element of `*args`. Although there is a `**kwargs` in the call to the likelihood, this was the only quick way I could find to pass indices during a call to `gpytorch.mlls.ExactMarginalLogLikelihood` (which does not pass `**kwargs` on to the likelihood) while still being able to use `gpytorch.models.exact_prediction_strategies.DefaultPredictionStrategy` (which I did not feel comfortable amending without breaking anything), and which would otherwise throw an error at `gpytorch/gpytorch/models/exact_prediction_strategies.py`, line 60 (commit e64cb6f).
My implementation of the likelihood looks like this:
Putting it all together
Putting all of the above together in a subclass of
ExactGP, I can fit a slightly extended version of the tutorial example. The full code for this can be found below (Python 3.7, `pytorch==1.11.0`, `gpytorch==1.6.0`).
I am sure this can be done better, so any pointers are greatly appreciated.