This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Use functional models inside usual nn.Module #1111

@subho406

Hi, thanks for adding functional features to PyTorch. I want to use an nn.Module converted into functional form inside an ordinary, stateful nn.Module. However, the code below does not correctly register the functional module's parameters on the outer module. Is there currently a way to do this?

import torch
import torch.nn as nn
from functorch import make_functional

x = torch.randn(4, 10)

class TinyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Convert an inner Linear layer into functional form.
        self.func_l, self.params_l = make_functional(nn.Linear(10, 10))
        # Attempt to register the inner parameters on the outer module.
        for i, ele in enumerate(self.params_l):
            self.register_parameter(str(i), ele)

    def forward(self, inputs):
        # Call the inner functional model with the stored parameter tuple.
        return self.func_l(self.params_l, inputs)

model = TinyModel()
func, params = make_functional(model)
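One possible workaround, sketched under the assumption that PyTorch 2.0 or later is available (functorch's APIs were merged into PyTorch as `torch.func`, where `functional_call` takes over the role of `make_functional`): keep the inner layer as an ordinary submodule so its parameters are registered on the outer module in the usual way, and call it functionally only inside `forward` with `torch.func.functional_call`. The attribute name `inner` is just illustrative.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # An ordinary submodule: its parameters are registered automatically
        # on the outer module as "inner.weight" and "inner.bias".
        self.inner = nn.Linear(10, 10)

    def forward(self, inputs):
        # Call the inner module in functional form, passing its parameters
        # explicitly as a {name: tensor} dict.
        params = dict(self.inner.named_parameters())
        return functional_call(self.inner, params, (inputs,))


x = torch.randn(4, 10)
model = TinyModel()
print([name for name, _ in model.named_parameters()])  # ['inner.weight', 'inner.bias']

# The outer model can itself be called functionally with an explicit
# parameter dict (here: detached copies of its current parameters).
outer_params = {k: v.detach() for k, v in model.named_parameters()}
out = functional_call(model, outer_params, (x,))
print(out.shape)  # torch.Size([4, 10])
```

Because the inner parameters are regular registered parameters, `model.parameters()` (and hence any optimizer built from it) sees them without a manual `register_parameter` loop.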

This is useful for me because I want to apply functional transforms (such as vmap, jvp, and vjp) to an inner nn.Module inside the forward pass of an outer nn.Module. The idea is to have lifted versions of vjp, jvp, etc., similar to Flax (https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.vjp.html).
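A minimal sketch of that idea, again assuming PyTorch 2.0+ and `torch.func`: inside the outer module's `forward`, treat the inner module as a pure function of its parameter dict and apply `torch.func.vjp` and `torch.func.jvp` to it. The names `Outer` and `inner`, and the all-ones cotangent and tangent, are illustrative choices; this is not an existing "lifted transform" API in functorch, nor a port of Flax's `nn.vjp`.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jvp, vjp


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        # Ordinary submodule, so outer.parameters() includes its weights.
        self.inner = nn.Linear(10, 10)

    def forward(self, x):
        params = dict(self.inner.named_parameters())

        # The inner module viewed as a pure function of its parameter dict.
        def inner_fn(p):
            return functional_call(self.inner, p, (x,))

        # Reverse mode: pulling back an all-ones cotangent yields
        # d(out.sum())/d(params), a dict with the same keys as params.
        out, vjp_fn = vjp(inner_fn, params)
        (param_grads,) = vjp_fn(torch.ones_like(out))

        # Forward mode: push an all-ones tangent for every parameter
        # through the inner module.
        tangents = {k: torch.ones_like(v) for k, v in params.items()}
        _, out_tangent = jvp(inner_fn, (params,), (tangents,))

        return out, param_grads, out_tangent


torch.manual_seed(0)
x = torch.randn(4, 10)
model = Outer()
out, param_grads, out_tangent = model(x)
print(out.shape, param_grads["weight"].shape, out_tangent.shape)
# torch.Size([4, 10]) torch.Size([10, 10]) torch.Size([4, 10])
```

Since `torch.func` transforms are designed to compose with regular autograd, `out` should remain differentiable with respect to `model.inner`'s parameters, so the outer module can be trained as usual; it is worth confirming this on the PyTorch version in use.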
