2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/sc_primitives.py
@@ -511,7 +511,7 @@ def _lax_cumsum_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, axis,
     raise NotImplementedError("SC cumsum: reverse=True is not yet supported")
   i1t = ir.IntegerType.get_signless(1)
   c1 = arith.constant(i1t, ir.IntegerAttr.get(i1t, 1))
-  c1v = vector.splat(ir.VectorType.get(x.type.shape, c1.type), c1)
+  c1v = vector.broadcast(ir.VectorType.get(x.type.shape, c1.type), c1)
   return tpu.scan(
       x.type, x, ir.Attribute.parse("#tpu.reduction_kind<sum>"), mask=c1v)
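For context: upstream MLIR has deprecated `vector.splat` in favor of `vector.broadcast`, which subsumes it, since broadcasting a scalar source produces the same splatted vector. The following standalone sketch (not part of the PR; it assumes the upstream MLIR Python bindings are installed as the `mlir` package, whereas JAX bundles its own copy) shows the replacement pattern used throughout this change.

  from mlir import ir
  from mlir.dialects import arith, func, vector

  with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):

      @func.FuncOp.from_py_func()
      def make_ones():
        f32 = ir.F32Type.get()
        vec_ty = ir.VectorType.get((8,), f32)
        one = arith.constant(f32, ir.FloatAttr.get(f32, 1.0))
        # Deprecated form: vector.splat(vec_ty, one)
        return vector.broadcast(vec_ty, one)  # splats the scalar across all lanes

    print(module)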
4 changes: 2 additions & 2 deletions jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -592,9 +592,9 @@ def store_untiled(optimized: bool):
   return []


-@_register_lowering(vector.SplatOp)
+@_register_lowering(vector.BroadcastOp)
 def _vector_splat_op_lowering_rule(
-    _: LoweringContext, vector_splat_op: vector.SplatOp
+    _: LoweringContext, vector_splat_op: vector.BroadcastOp
 ) -> Sequence[ir.Value]:

   out_vec_ty = ir.VectorType(vector_splat_op.aggregate.type)
12 changes: 7 additions & 5 deletions jax/experimental/mosaic/gpu/fragmented_array.py
@@ -981,7 +981,9 @@ def splat(
       case WGSplatFragLayout():
         pass
       case WGStridedFragLayout() | TiledLayout():
-        value = vector.splat(layout.registers_element_type(value.type), value)
+        value = vector.broadcast(
+            layout.registers_element_type(value.type), value
+        )
       case _:
         raise NotImplementedError(layout)

@@ -1847,7 +1849,7 @@ def packed_registers() -> Iterable[tuple[Sequence[int], ir.Value]]:
         for part in range(max(group_size // 4, 1))
     ]
     out_vec_int = utils.vector_concat([
-        vector.splat(ir.VectorType.get((1,), i32), out_i32_reg)
+        vector.broadcast(ir.VectorType.get((1,), i32), out_i32_reg)
         for out_i32_reg in out_i32_regs
     ])
     out_vector_len = len(out_i32_regs) * 4
@@ -1933,7 +1935,7 @@ def upcast_i4_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
       offset += group_size
     assert offset == vector_len
     out_vec_int = utils.vector_concat([
-        vector.splat(ir.VectorType.get((1,), i32), reg)
+        vector.broadcast(ir.VectorType.get((1,), i32), reg)
         for reg in out_int_regs
     ])
     new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
@@ -2262,7 +2264,7 @@ def reduce(
       scalar_out_reg = (
           scalar if scalar_out_reg is None else op(scalar_out_reg, scalar)
       )
-    out_reg = vector.splat(
+    out_reg = vector.broadcast(
         ir.VectorType.get((1,), out_reg.type.element_type), scalar_out_reg
     )
     # Reduce across warp lanes, if necessary (using warp shuffles).
@@ -2701,7 +2703,7 @@ def load_tiled(
     tiling = Tiling((tiled_shape[len(tiled_shape) // 2 :],))
     shape = tiling.untile_shape(tiled_shape)
     reg_ty = ir.VectorType.get((layout.vector_length,), dtype)
-    zero = vector.splat(reg_ty, c(0, dtype))
+    zero = vector.broadcast(reg_ty, c(0, dtype))
     registers = np.full(layout.registers_shape(shape), zero, dtype=object)
     is_f8 = ir.FloatType.isinstance(dtype) and utils.bitwidth(dtype) == 8
     i8 = ir.IntegerType.get_signless(8)
2 changes: 1 addition & 1 deletion jax/experimental/mosaic/gpu/layout_inference.py
@@ -686,7 +686,7 @@ def _optimization_barrier_equation_system(
   return eqns.EquationSystem(), operand_or_results_for_variable, []


-@_add_equation_system_derivation_rule(vector.SplatOp)
+@_add_equation_system_derivation_rule(vector.BroadcastOp)
 def _vector_splat_equation_system(
     ctx: DerivationContext,
     op: ir.OpView,
6 changes: 4 additions & 2 deletions jax/experimental/mosaic/gpu/utils.py
@@ -125,7 +125,7 @@ def c(val: int | float, ty):
   elif ir.FloatType.isinstance(ty):
     attr = ir.FloatAttr.get(ty, val)
   elif ir.VectorType.isinstance(ty):
-    return vector.splat(ty, c(val, ir.VectorType(ty).element_type))
+    return vector.broadcast(ty, c(val, ir.VectorType(ty).element_type))
   else:
     raise NotImplementedError(ty)
   return arith.constant(ty, attr)
@@ -1690,7 +1690,9 @@ def bitcast(x: ir.Value, new_type: ir.Type):
     new_type = ir.VectorType(new_type)
     x_ty = ir.IntegerType(x.type)
     assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
-    return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
+    return vector.bitcast(
+        new_type, vector.broadcast(ir.VectorType.get((1,), x_ty), x)
+    )
   if ir.VectorType.isinstance(x.type) and ir.VectorType.isinstance(new_type):
     x_ty = ir.VectorType(x.type)
     new_ty = ir.VectorType(new_type)
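The `bitcast` hunk above relies on an idiom worth noting: `vector.bitcast` only operates on vectors, so a scalar integer is first broadcast into a single-element vector and then reinterpreted, with the assert guaranteeing the total bit widths match. A minimal sketch of the same idiom, under the same `mlir` package assumption as above:

  from mlir import ir
  from mlir.dialects import arith, func, vector

  with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):

      @func.FuncOp.from_py_func()
      def bytes_of_i32():
        i32 = ir.IntegerType.get_signless(32)
        i8 = ir.IntegerType.get_signless(8)
        x = arith.constant(i32, ir.IntegerAttr.get(i32, 0x01020304))
        v1 = vector.broadcast(ir.VectorType.get((1,), i32), x)  # scalar -> vector<1xi32>
        # Reinterpret the bits; total width must match: 1 * 32 == 4 * 8.
        return vector.bitcast(ir.VectorType.get((4,), i8), v1)

    print(module)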
2 changes: 1 addition & 1 deletion jax/experimental/mosaic/gpu/wgmma.py
@@ -492,5 +492,5 @@ def _llvm_add(x, y):
 def _unpack_i32(vec_ty, r):
   i32 = ir.IntegerType.get_signless(32)
   return vector.bitcast(
-      vec_ty, vector.splat(ir.VectorType.get((1,), i32), r)
+      vec_ty, vector.broadcast(ir.VectorType.get((1,), i32), r)
   )
4 changes: 2 additions & 2 deletions jaxlib/mosaic/dialect/tpu/vreg_util_test.cc
@@ -175,7 +175,7 @@ TEST_F(VregUtilTest, GetFullVector) {

 TEST_F(VregUtilTest, GetFullLikeVector) {
   VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
-  TypedValue<VectorType> in_vec = Builder().create<vector::SplatOp>(
+  TypedValue<VectorType> in_vec = Builder().create<vector::BroadcastOp>(
       vty, Builder().create<arith::ConstantOp>(
                vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
   TypedValue<VectorType> vec =
@@ -193,7 +193,7 @@ TEST_F(VregUtilTest, GetZerosVector) {

 TEST_F(VregUtilTest, GetZerosLikeVector) {
   VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
-  TypedValue<VectorType> in_vec = Builder().create<vector::SplatOp>(
+  TypedValue<VectorType> in_vec = Builder().create<vector::BroadcastOp>(
       vty, Builder().create<arith::ConstantOp>(
                vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
   TypedValue<VectorType> vec = getZerosLikeVector(Builder(), in_vec);
5 changes: 3 additions & 2 deletions tests/mosaic/gpu_dialect_test.py
@@ -1281,8 +1281,9 @@ def test_lowering_for(self):
     with ir.InsertionPoint(self.module.body):
       i1 = arith.constant(ir.IndexType.get(), 1)
       c1 = arith.constant(i32, 1)
-      splat = vector.SplatOp(
-          ir.VectorType.get(shape, i32), arith.constant(i32, 1234),
+      splat = vector.BroadcastOp(
+          ir.VectorType.get(shape, i32),
+          arith.constant(i32, 1234),
       )
       splat.attributes["out_layouts"] = ir.ArrayAttr.get([
           splat_layout_attr
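Note that the test constructs `vector.BroadcastOp` (the op class) rather than calling `vector.broadcast` (the value builder): the op view is what exposes `.attributes`, which the test needs to attach `out_layouts`. A sketch of the distinction, with a hypothetical `note` attribute standing in for the Mosaic GPU layout attribute, under the same `mlir` package assumption as above:

  from mlir import ir
  from mlir.dialects import arith, func, vector

  with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):

      @func.FuncOp.from_py_func()
      def broadcast_with_attr():
        i32 = ir.IntegerType.get_signless(32)
        vec_ty = ir.VectorType.get((4,), i32)
        c = arith.constant(i32, ir.IntegerAttr.get(i32, 1234))
        op = vector.BroadcastOp(vec_ty, c)  # op view, exposes .attributes
        op.attributes["note"] = ir.StringAttr.get("stand-in for out_layouts")
        return op.result  # vector.broadcast(...) would return this value directly

    print(module)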
2 changes: 1 addition & 1 deletion tests/mosaic/gpu_layout_inference_test.py
@@ -227,7 +227,7 @@ def test_infer_splat_layout_for_vector_splat(self):
       ty = ir.VectorType.get(shape, bf16)
       lhs, rhs = undefs(bf16, ty)
       rhs = layout_cast(rhs, splat_layout)
-      splat = vector.SplatOp(rhs.type, lhs)
+      splat = vector.BroadcastOp(rhs.type, lhs)
       add = arith.AddFOp(splat.result, rhs)

       mgpu.infer_layout(self.module)
14 changes: 8 additions & 6 deletions tests/mosaic/gpu_test.py
@@ -4290,7 +4290,7 @@ def body(ctx, result_gmem_ref, scratch):
       f32 = ir.F32Type.get()
       x_type = ir.VectorType.get(input_shape, f32)
       c = arith.constant(f32, element_value)
-      x = vector.splat(x_type, c)
+      x = vector.broadcast(x_type, c)

       # Computation
       out_type = ir.VectorType.get(output_shape, f32)
@@ -4321,7 +4321,9 @@ def test_bad_layout_cast_raises_in_inference(self):
     def body(ctx, out, _):
       del ctx, out
       f32 = ir.F32Type.get()
-      x = vector.splat(ir.VectorType.get(shape, f32), arith.constant(f32, 0.0))
+      x = vector.broadcast(
+          ir.VectorType.get(shape, f32), arith.constant(f32, 0.0)
+      )
       wgmma_layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT)
       wgmma_row_layout = layouts.to_layout_attr(fa.WGMMA_ROW_LAYOUT)
       lc1 = mgpu_dialect.layout_cast(x, wgmma_layout)
@@ -4372,12 +4374,12 @@ def body(ctx, result_gmem_ref, scratch):
       # Create source in registers
       source_type = ir.VectorType.get(input_shape, el_type)
       c = arith.constant(el_type, input_value)
-      source = vector.splat(source_type, c)
+      source = vector.broadcast(source_type, c)

       # Create accumulator in registers
       acc_type = ir.VectorType.get(output_shape, el_type)
       c = arith.constant(el_type, init_value)
-      acc = vector.splat(acc_type, c)
+      acc = vector.broadcast(acc_type, c)

       # Cast inputs
       source = mgpu_dialect.layout_cast(
@@ -4425,7 +4427,7 @@ def body(ctx, result_gmem_ref, smem):
       f32 = ir.F32Type.get()
       x_type = ir.VectorType.get(shape, f32)
       c = arith.constant(f32, element_value)
-      x = vector.splat(x_type, c)
+      x = vector.broadcast(x_type, c)
       cast = mgpu_dialect.layout_cast(x, layouts.to_layout_attr(in_layout))

       # Registers -> SMEM
@@ -4736,7 +4738,7 @@ def matmul(
     zero_acc = arith.constant(
         result_elt_type, ir.FloatAttr.get(acc_elt_type, 0.0)
     )
-    accumulator = vector.splat(acc_type, zero_acc)
+    accumulator = vector.broadcast(acc_type, zero_acc)

     if transpose_lhs:
       lhs_smem_ref = utils.memref_transpose(lhs_smem_ref, (1, 0))
2 changes: 1 addition & 1 deletion tests/mosaic/gpu_test_distributed.py
@@ -171,7 +171,7 @@ def kernel(ctx, sem, out, _):
       other_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(other_dst))
       with mgpu.when(arith.cmpi(arith.CmpIPredicate.eq, my_device, arith.constant(i32, 0))):
         c = arith.constant(i32, 1)
-        vc = vector.splat(ir.VectorType.get((vector_length,), i32), c)
+        vc = vector.broadcast(ir.VectorType.get((vector_length,), i32), c)
         multicast_ref = ctx.to_remote_multicast(out)
         multicast_ref.store(vc, [arith.constant(index, 0)])
         other_sem.signal(arith.constant(i32, 1))