This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Description
Hi, thanks for adding functional features to PyTorch. I want to use an nn.Module converted into functional form inside an ordinary stateful nn.Module. However, the code below does not correctly register the parameters of the functional module. Is there currently a way to do this?
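```python
import torch
import torch.nn as nn
from functorch import make_functional

x = torch.randn(4, 10)


class TinyModel(torch.nn.Module):
    def __init__(self):
        super(TinyModel, self).__init__()
        # Convert a Linear layer into (stateless function, tuple of parameters).
        self.func_l, self.params_l = make_functional(nn.Linear(10, 10))
        # Register each parameter on the outer module under the names "0", "1", ...
        for i, ele in enumerate(self.params_l):
            self.register_parameter(str(i), ele)

    def forward(self, inputs):
        # Call the functional layer with the parameter tuple stored at construction.
        return self.func_l(self.params_l, inputs)


model = TinyModel()
func, params = make_functional(model)
```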
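One possible workaround, sketched here assuming PyTorch 2.x, where functorch's transforms are available under `torch.func` and `torch.func.functional_call` is the suggested replacement for `make_functional`: register the inner module as an ordinary submodule and call it functionally inside `forward`. The attribute name `inner` is illustrative only; this is not an official recipe.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning a module as an attribute registers it as a submodule, so its
        # weight and bias appear in TinyModel.parameters() with no manual
        # register_parameter calls.
        self.inner = nn.Linear(10, 10)

    def forward(self, inputs):
        # Build a parameter dict from the registered parameters and run the
        # inner module functionally. Any dict with the same keys could be
        # passed here instead (e.g. perturbed parameters).
        params = dict(self.inner.named_parameters())
        return functional_call(self.inner, params, (inputs,))


model = TinyModel()
x = torch.randn(4, 10)
out = model(x)
print([name for name, _ in model.named_parameters()])  # ['inner.weight', 'inner.bias']

# Gradients from the functional call flow back to the registered parameters.
out.sum().backward()
print(model.inner.weight.grad.shape)  # torch.Size([10, 10])
```

Because the parameters live on an ordinary submodule, the outer model can in turn be passed to `functional_call` like any other module.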
This is useful to me because I want to apply functional transforms (such as vmap, jvp, and vjp) to an inner nn.Module inside the forward pass of an outer nn.Module. The idea is to have lifted versions of vjp, jvp, etc., similar to Flax (https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.vjp.html).
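A rough approximation of such lifted transforms, as a minimal sketch assuming PyTorch 2.x's `torch.func`: inside the outer `forward`, wrap the inner module as a pure function of `(params, x)` via `functional_call`, then apply `torch.func.jvp` and `torch.func.vjp` to it. `LiftedOuter` and `_inner_fn` are hypothetical names, and this only covers parameters (not buffers or other module state), so it is not equivalent to Flax's lifted transforms.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jvp, vjp


class LiftedOuter(nn.Module):
    # Hypothetical outer module that applies torch.func transforms to an
    # inner module whose parameters are registered normally.
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(10, 10)

    def _inner_fn(self, params, x):
        # Pure function of (params, x): the inner module run with `params`.
        return functional_call(self.inner, params, (x,))

    def forward(self, inputs, tangent, cotangent):
        params = dict(self.inner.named_parameters())

        # JVP w.r.t. the input only; parameters are held fixed by the closure.
        out, out_tangent = jvp(
            lambda x: self._inner_fn(params, x), (inputs,), (tangent,)
        )

        # VJP w.r.t. both the parameter dict and the input.
        _, vjp_fn = vjp(self._inner_fn, params, inputs)
        param_cot, input_cot = vjp_fn(cotangent)
        return out, out_tangent, param_cot, input_cot


torch.manual_seed(0)
model = LiftedOuter()
x, t, c = torch.randn(4, 10), torch.randn(4, 10), torch.randn(4, 10)
out, out_t, p_cot, x_cot = model(x, t, c)

W = model.inner.weight.detach()
# Compare against the closed-form results for y = x @ W.T + b.
print(torch.allclose(out_t, t @ W.T, atol=1e-5))            # JVP along t
print(torch.allclose(x_cot, c @ W, atol=1e-5))              # input cotangent
print(torch.allclose(p_cot["weight"], c.T @ x, atol=1e-5))  # weight cotangent
print(torch.allclose(p_cot["bias"], c.sum(0), atol=1e-5))   # bias cotangent
```

For a linear layer the transforms have simple closed forms (the JVP is t W^T, the input cotangent is c W, the weight cotangent is c^T x, and the bias cotangent is the column sum of c), which makes it easy to check that the sketch wires the transforms up correctly.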