Commit fef6dd9

add accuracy and performance test (#1643)
1 parent e7d1c1d commit fef6dd9

9 files changed (+2110, -1 lines)

.github/workflows/kt-kernel-tests.yml

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ jobs:
           bash install.sh build

       - name: Run KT-Kernel CPU tests
-        timeout-minutes: 30
+        timeout-minutes: 60
         run: |
           cd kt-kernel/test
           python3 run_suite.py --hw cpu --suite default
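
To reproduce the CPU suite locally, the test step's commands can be run directly from a built checkout (build via install.sh as in the earlier step; the 60-minute timeout only applies in CI):

cd kt-kernel/test
python3 run_suite.py --hw cpu --suite default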
Lines changed: 212 additions & 0 deletions (new file)
@@ -0,0 +1,212 @@
#!/usr/bin/env python
# coding=utf-8
"""AMX MOE INT4 accuracy tests for KT-Kernel.

Tests accuracy of AMX-accelerated INT4 MOE operations against torch reference.
"""

import os
import sys
import pytest

# Add parent directory to path for CI registration
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from ci.ci_register import register_cpu_ci

# Register this test for CPU CI with estimated runtime of 120 seconds
register_cpu_ci(est_time=120, suite="default")

# Check if dependencies are available
try:
    import torch
    import kt_kernel_ext
    HAS_DEPS = True
except ImportError as e:
    HAS_DEPS = False
    import_error = str(e)

# Test parameters (from original test_moe_amx.py)
expert_num = 256
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
num_experts_per_tok = 8
qlen = 1
layer_num = 1
validation_iter = 2
physical_to_logical_map = None


def act_fn(x):
    """Activation function for MoE."""
    return x / (1.0 + torch.exp(-x))


def mlp_torch(input, gate_proj, up_proj, down_proj):
    """PyTorch reference implementation of MLP."""
    gate_buf = torch.mm(input, gate_proj.t())
    up_buf = torch.mm(input, up_proj.t())
    intermediate = act_fn(gate_buf) * up_buf
    ret = torch.mm(intermediate, down_proj.t())
    return ret


def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
    """PyTorch reference implementation of MoE."""
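    # Routing bookkeeping: scatter_ marks each token's selected experts, so summing
    # over the token dimension yields tokens_per_expert; argsort over the flattened
    # expert_ids groups token slots by expert id, giving each expert a contiguous
    # slice of sorted_tokens; per-expert MLP outputs are written back to their
    # original slots via new_x[idxs] and combined with the routing weights.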
    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
    cnts.scatter_(1, expert_ids, 1)
    tokens_per_expert = cnts.sum(dim=0)
    idxs = expert_ids.view(-1).argsort()
    sorted_tokens = input[idxs // expert_ids.shape[1]]

    outputs = []
    start_idx = 0

    for i, num_tokens in enumerate(tokens_per_expert):
        end_idx = start_idx + num_tokens
        if num_tokens == 0:
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
        expert_out = mlp_torch(
            tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]
        )
        outputs.append(expert_out)
        start_idx = end_idx

    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

    new_x = torch.empty_like(outs)
    new_x[idxs] = outs
    t_output = (
        new_x.view(*expert_ids.shape, -1)
        .type(weights.dtype)
        .mul_(weights.unsqueeze(dim=-1))
        .sum(dim=1)
        .type(new_x.dtype)
    )

    return t_output


@pytest.mark.cpu
def test_moe_amx_int4_accuracy():
    """Test AMX INT4 MOE accuracy against PyTorch reference implementation."""
    if not HAS_DEPS:
        pytest.skip(f"Dependencies not available: {import_error}")

    global physical_to_logical_map
    physical_to_logical_map = torch.tensor(
        data=range(expert_num), device="cpu", dtype=torch.int64
    ).contiguous()

    CPUInfer = kt_kernel_ext.CPUInfer(90)

    with torch.inference_mode(mode=True):
        # Initialize MoE layers
        gate_proj = (
            torch.randn(
                (expert_num, intermediate_size, hidden_size),
                dtype=torch.bfloat16,
                device="cuda",
            )
            .to("cpu")
            .contiguous()
        )
        up_proj = (
            torch.randn(
                (expert_num, intermediate_size, hidden_size),
                dtype=torch.bfloat16,
                device="cuda",
            )
            .to("cpu")
            .contiguous()
        )
        down_proj = (
            torch.randn(
                (expert_num, hidden_size, intermediate_size),
                dtype=torch.bfloat16,
                device="cuda",
            )
            .to("cpu")
            .contiguous()
        )

        # Create MOE config
        config = kt_kernel_ext.moe.MOEConfig(
            expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
        )
        config.max_len = max_len
        config.gate_proj = gate_proj.data_ptr()
        config.up_proj = up_proj.data_ptr()
        config.down_proj = down_proj.data_ptr()
        config.gate_scale = 0
        config.pool = CPUInfer.backend_

        # Initialize INT4 MOE
        moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
        CPUInfer.sync()
        CPUInfer.submit(moe.warm_up_task())
        CPUInfer.sync()

        # Run validation iterations
        for i in range(validation_iter):
            bsz_tensor = torch.tensor([qlen], device="cpu")
            expert_ids = torch.stack(
                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
            ).contiguous()
            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
            input_data = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
            input_data = input_data / 100

            # Run AMX MOE
            CPUInfer.submit(
                moe.forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids.data_ptr(),
                    weights.data_ptr(),
                    input_data.data_ptr(),
                    output.data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()

            # Run torch reference
            t_output = moe_torch(
                input_data, expert_ids, weights, gate_proj, up_proj, down_proj
            )

            # Calculate relative difference
            diff = torch.mean(torch.abs(output - t_output)) / torch.mean(
                torch.abs(t_output)
            )
            print(f"Iteration {i}, diff = {diff:.6f}")

            # INT4 should have diff < 0.35
            assert diff < 0.35, f"INT4 accuracy test failed: diff={diff:.6f} >= 0.35"


def run_all_tests():
    """Run all tests in this file (for standalone execution)."""
    if not HAS_DEPS:
        print(f"⚠ Dependencies not available: {import_error}")
        print("Skipping AMX MOE INT4 accuracy tests")
        return

    try:
        print("Running AMX MOE INT4 accuracy test...")
        test_moe_amx_int4_accuracy()
        print("✓ AMX MOE INT4 accuracy test passed")
        print("\n✓ All tests passed!")
    except Exception as e:
        print(f"\n✗ Test failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    run_all_tests()
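
The pass criterion is a mean relative error, diff = mean(|output - t_output|) / mean(|t_output|), with a 0.35 tolerance for INT4 quantization. A minimal standalone sketch of the same metric on toy tensors (the values below are illustrative, not taken from the test):

import torch

kernel_out = torch.tensor([1.02, -0.49, 0.24])  # stand-in for the AMX INT4 output
reference = torch.tensor([1.00, -0.50, 0.25])   # stand-in for the torch reference

# Mean absolute error normalized by the mean magnitude of the reference
diff = torch.mean(torch.abs(kernel_out - reference)) / torch.mean(torch.abs(reference))
print(f"diff = {diff:.6f}")  # the test asserts diff < 0.35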
