Description
How would you describe the priority of this documentation request?
Low (would be nice)
Is this for new documentation, or an update to existing docs?
Update
Describe the incorrect/future/missing documentation
Currently, `_StaticCapture.__call__` decorates a user-defined function that is typed with a generic `Callable`. The issue is that the interface expected of the user function changes depending on whether you are using `StaticCaptureTraining` or `StaticCaptureEvaluateNoGrad`: the former expects the decorated function to return something with a `backward()` method, whereas the latter imposes no such requirement and really does accept any `Callable`.
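
To make the difference concrete, here is a rough sketch of the two user-function shapes. The decorator lines are commented out and their constructor arguments are illustrative, not necessarily the exact API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters())

# StaticCaptureTraining expects the wrapped function to return a loss,
# i.e. something exposing .backward(), which the capture then drives.
# @StaticCaptureTraining(model=model, optim=optimizer)  # illustrative args
def training_step(invar: torch.Tensor, outvar: torch.Tensor) -> torch.Tensor:
    pred = model(invar)
    return F.mse_loss(pred, outvar)

# StaticCaptureEvaluateNoGrad imposes no such requirement; any return value is fine.
# @StaticCaptureEvaluateNoGrad(model=model)  # illustrative args
def eval_step(invar: torch.Tensor) -> torch.Tensor:
    return model(invar)
```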
The suggestion is to add `typing.Protocol` classes to express this difference, so that it is both better documented and ideally picked up by static analysis tools/LSPs. Users would then be able to see clearly what is expected of them when using the `_StaticCapture` workflows. Something like:
```python
from typing import Any, Callable, Protocol

class HasBackward(Protocol):
    # nominally *could* be torch.Tensor, but this is more general
    def backward(self, *args: Any, **kwds: Any) -> None:
        ...

class LossEmitter(Protocol):
    # this replaces `Callable` for training
    def __call__(self, *args: Any, **kwds: Any) -> HasBackward:
        ...

# decorated __call__ signature for training vs. for EvaluateNoGrad
class StaticCaptureTraining:
    def __call__(self, fn: LossEmitter) -> LossEmitter: ...

class StaticCaptureEvaluateNoGrad:
    def __call__(self, fn: Callable) -> Callable: ...
```
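
As a follow-up, a minimal sketch of what a type checker could then catch. It reuses `HasBackward`/`LossEmitter` from the snippet above, and `decorate_training` is a hypothetical stand-in for `StaticCaptureTraining.__call__`:

```python
from typing import Any

class ScalarLoss:
    # minimal stand-in that structurally satisfies HasBackward
    def backward(self, *args: Any, **kwds: Any) -> None:
        pass

def good_step(*args: Any, **kwds: Any) -> ScalarLoss:
    return ScalarLoss()

def bad_step(*args: Any, **kwds: Any) -> None:
    return None

def decorate_training(fn: LossEmitter) -> LossEmitter:
    # hypothetical stand-in for StaticCaptureTraining.__call__
    return fn

decorate_training(good_step)  # OK: ScalarLoss structurally satisfies HasBackward
decorate_training(bad_step)   # flagged by mypy/pyright: None has no backward()
```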