Skip to content

Conversation

@zijianshen
Copy link

@zijianshen zijianshen commented Sep 3, 2025

…o overwrite

Extract the model forward logic so that we could apply plug-in into it for NE debugging

Purpose

Refactor to make the NE debugging plug-in could be easily added during running the model forward.

Test Plan

Tested locally


Essential Elements of an Effective PR Description Checklist

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the execute_model method in GPUModelRunner by extracting the core model forward pass into a new _forward method. This is a good change for modularity and extensibility, as it allows plugins to more easily hook into the model execution logic, which aligns with the stated purpose of enabling NE debugging. The implementation is straightforward. I've identified one issue with an incorrect type hint in the new _forward method that should be addressed to ensure type safety and clarity for future extensions.

intermediate_tensors: IntermediateTensors,
inputs_embeds: list[torch.Tensor],
model_kwargs: dict[str, Any],
) -> Tuple[torch.Tensor, Optional[KVConnectorOutput]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The return type hint for the model output is torch.Tensor, but the model can also return a tuple of tensors (e.g., when use_aux_hidden_state_outputs is true). This incorrect type hint can cause issues with static type checkers and mislead developers who might extend this class, especially since this method is designed to be overridden by plugins. Using Any will make the type hint correct for all possible return types from the model.

Suggested change
) -> Tuple[torch.Tensor, Optional[KVConnectorOutput]]:
) -> Tuple[Any, Optional[KVConnectorOutput]]:

@zijianshen zijianshen closed this Sep 3, 2025
@zijianshen zijianshen deleted the zijiansh branch September 3, 2025 23:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant