From ml-explore/mlx-swift-lm#124
@arkavo-com found this in ml-explore/mlx-swift-lm#124 (comment):
> Swift's `MLXArray.ones()`, `.zeros()`, and Swift `Float` literals default to float32. Python's MLX inherits dtype from context. When a float32 value enters a bfloat16 computation graph, MLX's C++ engine inserts `AsType` cast nodes. Each `AsType` forces a new Metal buffer allocation (the bit widths differ, so buffer donation is physically impossible).
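A toy model of the mechanism in the quote, for readers who want to see the chain of consequences. It assumes a simplified promotion rule covering only `float16`, `bfloat16` and `float32` (identical types stay put, any mismatch widens to `float32`). `DType`, `promote`, `castsNeeded` and `canDonate` are hypothetical names; this is not MLX's C++ engine or the mlx-swift API.

```swift
// Toy model (assumption): three float dtypes and a simplified promotion rule.
// Not MLX's implementation; it only traces the argument in the quote.
enum DType: String {
    case float16, bfloat16, float32

    /// Bits per element; reusing a buffer needs matching widths.
    var bits: Int {
        switch self {
        case .float16, .bfloat16: return 16
        case .float32: return 32
        }
    }
}

/// Assumed result dtype of a binary op: identical types stay, any mismatch widens to float32.
func promote(_ a: DType, _ b: DType) -> DType {
    a == b ? a : .float32
}

/// Each operand whose dtype differs from the result needs a cast (AsType) node first.
func castsNeeded(_ a: DType, _ b: DType) -> [DType] {
    let out = promote(a, b)
    return [a, b].filter { $0 != out }
}

/// A cast can only write into its input's buffer when element widths match.
func canDonate(input: DType, output: DType) -> Bool {
    input.bits == output.bits
}

let activations = DType.bfloat16  // e.g. hidden states in a bf16 model
let defaultOnes = DType.float32   // e.g. MLXArray.ones([dim]) with the default type
let out = promote(activations, defaultOnes)
print(out)                                                   // float32
print(castsNeeded(activations, defaultOnes).map(\.rawValue)) // ["bfloat16"]
print(canDonate(input: activations, output: out))            // false
```

In this model a single float32 constant is enough to widen the bf16 operand, and the widened copy cannot reuse the 16-bit buffer.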
Oh, that is very interesting! Amazing find! `MLXArray.ones([dim])` resolves to:

```swift
static public func ones(
    _ shape: some Collection<Int>, type: (some HasDType).Type = Float.self,
    stream: StreamOrDevice = .default
) -> MLXArray {
    MLX.ones(shape, type: type, stream: stream)
}
```
How is this handled on the Python side (LLM side)? There is no "safe" default dtype to use, though there are potentially cheaper ones than float32.
I wonder if we should:
- add an overload with no `type`/`dtype` parameter and deprecate it
- remove the default on the `type` parameter here

The `dtype:` variant has no default.
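Both options can be sketched in plain Swift. `DType`, `HasDType` and `Tensor` below are hypothetical stand-ins named after the mlx-swift types so the sketch compiles without MLX; only the overload pattern matters. The generic `type:` variant loses its `Float.self` default, and a shape-only overload is kept but marked deprecated so existing call sites still build, with a warning.

```swift
// Hypothetical stand-ins named after mlx-swift types; not the real API.
enum DType { case float32, bfloat16, int32 }

protocol HasDType { static var dtype: DType { get } }
extension Float: HasDType { static var dtype: DType { .float32 } }
extension Int32: HasDType { static var dtype: DType { .int32 } }

struct Tensor {
    let shape: [Int]
    let dtype: DType

    // Option 2: the generic element-type variant no longer defaults to Float.self.
    static func ones<T: HasDType>(_ shape: [Int], type: T.Type) -> Tensor {
        Tensor(shape: shape, dtype: T.dtype)
    }

    // The DType variant already requires an explicit dtype.
    static func ones(_ shape: [Int], dtype: DType) -> Tensor {
        Tensor(shape: shape, dtype: dtype)
    }

    // Option 1: keep the old spelling compiling, but warn at every call site.
    @available(*, deprecated, message: "pass dtype: explicitly, e.g. dtype: x.dtype")
    static func ones(_ shape: [Int]) -> Tensor {
        Tensor(shape: shape, dtype: .float32)
    }
}

let x = Tensor.ones([4], dtype: .bfloat16)      // e.g. activations in a bf16 model
let scale = Tensor.ones([4], dtype: x.dtype)    // follows the graph's dtype
let counts = Tensor.ones([4], type: Int32.self) // element type spelled out
print(scale.dtype) // bfloat16
```

With this shape, `Tensor.ones([4])` still compiles but emits a deprecation warning, which points callers toward passing `dtype: x.dtype`.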
It looks like these functions all have this flaw (search for `type: (some HasDType).Type = Float.self`):
- `zeros`
- `ones`
- `eye`
- `full`
- `identity`
- `tri`

It seems that Python does not have functions like this with a default dtype.
These all have a default dtype and should not (search for `dtype: DType =`):
- `arange`
- `uniform`
- `normal`
- `multivariateNormal`
- `truncatedNormal`
- `gumbel`
- `laplace`

In Python these infer their types from the inputs; mlx-swift should do the same.
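A minimal sketch of deriving the output dtype from the inputs instead of a `dtype: DType =` default. It assumes integer bounds give `int32` and floating-point bounds give `float32` for `arange`, and that a sampler takes the dtype of its bounds. `Tensor`, `arange` and `uniform` here are hypothetical and do not reproduce the mlx-swift or Python MLX signatures.

```swift
// Hypothetical sketch, not the mlx-swift or Python MLX API: there is no
// `dtype: DType =` default; the result type follows the arguments.
enum DType { case int32, bfloat16, float32 }

struct Tensor {
    let shape: [Int]
    let dtype: DType
}

/// Integer bounds produce an integer array (assumed int32).
func arange(_ start: Int, _ stop: Int) -> Tensor {
    Tensor(shape: [max(0, stop - start)], dtype: .int32)
}

/// Floating-point bounds produce a float32 array with ceil(stop - start) elements.
func arange(_ start: Double, _ stop: Double) -> Tensor {
    let count = max(0, Int((stop - start).rounded(.up)))
    return Tensor(shape: [count], dtype: .float32)
}

/// A sampler takes its dtype from its bounds. For simplicity this sketch
/// requires both bounds to agree rather than modeling type promotion.
func uniform(low: Tensor, high: Tensor, shape: [Int]) -> Tensor {
    precondition(low.dtype == high.dtype, "low and high must share a dtype")
    return Tensor(shape: shape, dtype: low.dtype)
}

let ints = arange(0, 5)       // Int literals pick the integer overload
let floats = arange(0.0, 2.5) // Double literals pick the floating-point overload
let bound = Tensor(shape: [], dtype: .bfloat16)
let noise = uniform(low: bound, high: bound, shape: [2, 3])
print(ints.dtype, floats.dtype, noise.dtype) // int32 float32 bfloat16
```

Because Swift resolves the overload from the argument types, the caller never names a dtype, yet a bf16 model that passes bf16 bounds gets bf16 samples back.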