📚[DOC]: _StaticCapture decorated function signature changes based on behavior #1115

@laserkelvin

Description

How would you describe the priority of this documentation request?

Low (would be nice)

Is this for new documentation, or an update to existing docs?

Update

Describe the incorrect/future/missing documentation

Currently, _StaticCapture.__call__ decorates a user-defined function that is typed with a generic Callable.

The issue, however, is that the interface expected of the user function changes depending on whether StaticCaptureTraining or StaticCaptureEvaluateNoGrad is used: the former expects the function to return an object with a backward() method, whereas the latter accepts any Callable.

The suggestion would be to add typing.Protocol classes to express this difference, so that it is both better documented and ideally picked up by static analysis tools/LSPs. Users would then clearly see what is expected when using _StaticCapture workflows. Something like:

from typing import Any, Callable, Protocol


class HasBackward(Protocol):
    # nominally *could* be torch.Tensor, but this is more general
    def backward(self, *args: Any, **kwds: Any) -> None:
        ...


class LossEmitter(Protocol):
    # this replaces `Callable` for training
    def __call__(self, *args: Any, **kwds: Any) -> HasBackward:
        ...


# decorated signature for training vs. for EvaluateNoGrad (illustrative stubs)
class StaticCaptureTraining:
    def __call__(self, fn: LossEmitter) -> LossEmitter: ...


class StaticCaptureEvaluateNoGrad:
    def __call__(self, fn: Callable) -> Callable: ...
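As a minimal sketch of how such a protocol behaves at runtime (FakeLoss and training_step are hypothetical stand-ins for illustration, not part of the library), marking HasBackward as runtime_checkable lets isinstance verify that a returned loss actually exposes backward():

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasBackward(Protocol):
    """Anything exposing a backward() method, e.g. a torch.Tensor loss."""

    def backward(self, *args: Any, **kwds: Any) -> None: ...


class FakeLoss:
    """Hypothetical stand-in for a tensor-like loss value."""

    def backward(self, *args: Any, **kwds: Any) -> None:
        pass  # a real loss would run autograd here


def training_step() -> FakeLoss:
    # user-defined step function satisfying the LossEmitter shape
    return FakeLoss()


loss = training_step()
print(isinstance(loss, HasBackward))          # True: backward() is present
print(isinstance("not a loss", HasBackward))  # False: no backward() method
```

Note that runtime_checkable isinstance checks only verify that the method exists, not its signature; the main benefit remains static, since a type checker would flag a training function whose return value lacks backward().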
