[jax-inference-offloading] consolidate definitions for default tensor dtype#1816
[jax-inference-offloading] consolidate definitions for default tensor dtype#1816
Conversation
There was a problem hiding this comment.
Pull request overview
This PR refactors the default dtype handling by moving the default value from Python code to the protobuf definition. Instead of using fallback logic (param.vllm_param.dtype or 'bfloat16') in the Python code, the default is now specified directly in the proto file, simplifying the code and making the default more explicit.
Key changes:
- Added default value
'bfloat16'to thedtypefield in theVllmParammessage definition - Removed all fallback
or 'bfloat16'logic from the Python code in four locations
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto | Added default value for the dtype field in VllmParam message |
| jax-inference-offloading/jax_inference_offloading/vllm/extension.py | Removed fallback logic for dtype in update_weights and update_weights_grouped methods |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto
Outdated
Show resolved
Hide resolved
|
Why? |
Would it be better to have a single source of truth for model's default dtype? |
|
Default values in protos are a bit of an anti pattern (they even removed the feature completely in proto3). Once you put them in, you can never remove or change them again. I think they're acceptable when there's an obviously meaningful default, but I don't think that's the case here. I'd keep it in the application logic. |
That makes sense. I have removed the default value for |
|
|
||
| def make_mapping( | ||
| jax_name, vllm_name, vllm_shape, *, transform=None, jax_prefix="model", vllm_prefix="model" | ||
| jax_name, vllm_name, vllm_shape, *, transform=None, jax_prefix="model", vllm_prefix="model", dtype="bfloat16" |
There was a problem hiding this comment.
At the moment we don’t support any dtype conversion between the JAX and vLLM sides, so only vllm_param carries a dtype field, and the dtypes are expected to match between JAX and vLLM. Once we add conversion support, it may even make sense to stop specifying dtype in make_mapping altogether and instead rely on the handshake to discover the dtype at runtime.
Consolidates the definition for the default tensor dtype in refitting specs, via the
dtype="bfloat16"keyword argument ofmake_mapping(...)inmodels/__init__.py. Since all current refitting specs are defined viamake_mapping, this gives us a single source of truth for the default tensor dtype. This change should not introduce any visiblel functional changes for existing models.