@@ -14,6 +14,7 @@
 from torch import Tensor
 
 import helion
+from helion._testing import DEVICE
 import helion.language as hl
 
 
@@ -137,17 +138,17 @@ def check(m: int, k: int, n: int) -> None:
         k (int): Shared dimension.
         n (int): Number of cols.
     """
-    x = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
-    w = torch.randint(-(2**15), 2**15 - 1, (k, n), device="cuda", dtype=torch.int16)
+    x = torch.randn([m, k], device=DEVICE, dtype=torch.bfloat16)
+    w = torch.randint(-(2**15), 2**15 - 1, (k, n), device=DEVICE, dtype=torch.int16)
 
     result = bf16xint16_gemm(x, w, transpose=False)
     expected = reference_bf16xint16_pytorch(x, w, transpose=False)
     torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
 
     x_int16 = torch.randint(
-        -(2**15), 2**15 - 1, (m, k), device="cuda", dtype=torch.int16
+        -(2**15), 2**15 - 1, (m, k), device=DEVICE, dtype=torch.int16
     )
-    w_bf16 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)
+    w_bf16 = torch.randn([k, n], device=DEVICE, dtype=torch.bfloat16)
 
     result = bf16xint16_gemm(x_int16, w_bf16, transpose=True)
     expected = reference_bf16xint16_pytorch(x_int16, w_bf16, transpose=True)

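The diff replaces the hard-coded "cuda" device strings with a shared DEVICE constant imported from helion._testing, so the check routine targets whichever backend the test environment resolves. Below is a minimal sketch of how such a device-selection constant can be defined; the fallback order and helper names here are assumptions for illustration, not the actual helion._testing implementation.

# Hypothetical sketch of a DEVICE constant in the spirit of helion._testing.DEVICE.
# The real definition may differ; this only illustrates resolving the device once
# so example code avoids hard-coded "cuda" strings.
import torch

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    DEVICE = torch.device("xpu")
else:
    DEVICE = torch.device("cpu")

# Example tensors are then allocated against the resolved device:
x = torch.randn([4, 8], device=DEVICE, dtype=torch.bfloat16)

Centralizing the device choice keeps each example's tensor construction identical across backends and makes the check functions runnable on machines without CUDA.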