[BUG] many mlx-swift functions incorrectly promote to float32, e.g. zeros, ones, etc. #390

@davidkoski

From ml-explore/mlx-swift-lm#124

@arkavo-com found the following (see the comment on ml-explore/mlx-swift-lm#124):

Swift's MLXArray.ones(), .zeros(), and Swift Float literals default to float32. Python's MLX inherits dtype from context. When a float32 value enters a bfloat16 computation graph, MLX's C++ engine inserts AsType cast nodes. Each AsType forces a new Metal buffer allocation (different bit-width = buffer donation is physically impossible).

Oh that is very interesting! Amazing find! MLXArray.ones([dim]) resolves to:

    static public func ones(
        _ shape: some Collection<Int>, type: (some HasDType).Type = Float.self,
        stream: StreamOrDevice = .default
    ) -> MLXArray {
        MLX.ones(shape, type: type, stream: stream)
    }

How is this handled on the Python side (the LLM side)? There is no "safe" dtype to use, though there are potentially cheaper ones.

I wonder if we should:

  • add an overload with no type/dtype and deprecate it
  • remove the default on the type here

The dtype variant has no default.
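A minimal sketch of the difference between the two overloads, assuming the `ones` signatures quoted above and MLX's usual promotion of bfloat16 combined with float32 to float32. The variable names are illustrative only:

```swift
import MLX

// A bfloat16 array standing in for model weights or activations.
let weights = MLXArray.ones([4], dtype: .bfloat16)

// No type/dtype argument: the `type: Float.self` default applies,
// so this array is float32.
let defaulted = MLXArray.ones([4])

// Mixing the two promotes the result to float32, which is where
// the extra cast in the graph comes from.
let mixed = weights + defaulted

// Passing the dtype explicitly keeps the whole expression in bfloat16.
let explicit = MLXArray.ones([4], dtype: weights.dtype)
let kept = weights + explicit

print(defaulted.dtype, mixed.dtype, kept.dtype)
```

Removing the default would turn the `defaulted` line into a compile-time decision for the caller instead of a silent float32.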

It looks like these functions all have this flaw (search for type: (some HasDType).Type = Float.self):

  • zeros
  • ones
  • eye
  • full
  • identity
  • tri

It seems that Python does not have functions like this with a default dtype.

These all have a default dtype and should not (search for dtype: DType =):

  • arange
  • uniform
  • normal
  • multivariateNormal
  • truncatedNormal
  • gumbel
  • laplace

In Python these infer types based on the inputs -- mlx-swift should do the same.
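One possible shape for input-based inference is sketched below. `onesMatching` is a hypothetical helper, not an existing mlx-swift API; it assumes only the `dtype:` overloads of `ones` and `zeros` and the `dtype` property on `MLXArray`. The same pattern would have to be worked out separately for each function listed above.

```swift
import MLX

// Hypothetical helper (not part of mlx-swift): take the dtype from an
// existing array instead of falling back to a float32 default.
func onesMatching(_ reference: MLXArray, shape: [Int]) -> MLXArray {
    MLXArray.ones(shape, dtype: reference.dtype)
}

let activations = MLXArray.zeros([2, 3], dtype: .float16)
let bias = onesMatching(activations, shape: [3])

// Both operands are float16, so the sum stays float16 and no
// promotion to float32 is involved.
let shifted = activations + bias
print(bias.dtype, shifted.dtype)
```

Callers with no reference array to hand would still need to pass a dtype explicitly, which is consistent with the proposal to drop the defaults.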
