Skip to content

deprecate factory methods that have an implicit Float/.float32 type#391

Open
davidkoski wants to merge 2 commits intomainfrom
fix-dtype
Open

deprecate factory methods that have an implicit Float/.float32 type#391
davidkoski wants to merge 2 commits intomainfrom
fix-dtype

Conversation

@davidkoski
Copy link
Copy Markdown
Collaborator

@davidkoski davidkoski commented Apr 6, 2026

wait just a sec... not sure we want this

specifically this logically deprecates functions like this:

static public func ones(
    _ shape: some Collection<Int>, type: (some HasDType).Type = Float.self,
    stream: StreamOrDevice = .default
) -> MLXArray {
    MLX.ones(shape, type: type, stream: stream)
}

calling with no type will get a deprecation warning

Proposed changes

Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

- fix #390

specifically this logically deprecates functions like this:

    static public func ones(
        _ shape: some Collection<Int>, type: (some HasDType).Type = Float.self,
        stream: StreamOrDevice = .default
    ) -> MLXArray {
        MLX.ones(shape, type: type, stream: stream)
    }

calling with no type will get a deprecation warning
@davidkoski
Copy link
Copy Markdown
Collaborator Author

@DePasqualeOrg -- nice, your doc checker already caught one!

/// - ``ones(_:type:stream:)``
static public func zeros(
_ shape: some Collection<Int>, type: (some HasDType).Type = Float.self,
_ shape: some Collection<Int>, type: (some HasDType).Type,
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the pattern: remove the default value. Now type: is required, which is a breaking change.

static public func zeros(
_ shape: some Collection<Int>,
stream: StreamOrDevice = .default
) -> MLXArray {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a no type: variant that is deprecated. Code that would have used the zeros() with no type now resolves to this and gets a deprecation warning.

_ stop: T, stream: StreamOrDevice = .default
) -> MLXArray {
MLX.arange(stop, stream: stream)
}
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding variants that will infer the integer dtype, e.g. Int16

/// ### See Also
/// - <doc:initialization>
/// - ``arange(_:dtype:stream:)-(Double,_,_)``
static public func arange<T: HasDType & BinaryFloatingPoint>(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And float types

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that Double already has an implementation and produces .float32

stream: StreamOrDevice = .default
) -> MLXArray {
zeros(shape, type: Float.self, stream: stream)
}
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Free function variants.

let a = MLXArray.ones([2, 6, 6, 6])
let b = MLXArray.zeros([3, 4, 4, 4])
let a = MLXArray.ones([2, 6, 6, 6], type: Float.self)
let b = MLXArray.zeros([3, 4, 4, 4], type: Float.self)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update to use the non deprecated forms

XCTAssertEqual(c.shape, [0])
}

func testArangeDTypeInference() {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verify the type inference works as expected

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] many mlx-swift functions incorrectly promote to float32, e.g. zeros, ones, etc.

2 participants