
Commit ffb86c6

⚡ fix experts torch
1 parent de082f1 commit ffb86c6

File tree

1 file changed: +3 −3 lines changed

ktransformers/operators/experts.py

Lines changed: 3 additions & 3 deletions
@@ -459,9 +459,9 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None
             self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
             self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)

-        self.up = torch.cat(self.up, dim=0)
-        self.gate = torch.cat(self.gate, dim=0)
-        self.down = torch.cat(self.down, dim=0)
+        self.up = torch.stack(self.up, dim=0)
+        self.gate = torch.stack(self.gate, dim=0)
+        self.down = torch.stack(self.down, dim=0)
         return

    def unload(self):
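
For context, a minimal sketch (with hypothetical tensor shapes, not the project's real weights) of why torch.stack rather than torch.cat is the right call when combining per-expert weight tensors: cat along dim=0 flattens all experts into one large matrix, whereas stack adds a leading expert dimension so each expert's weight can still be indexed as self.up[i] afterwards.

# Minimal sketch, assuming hypothetical sizes (num_experts=4, hidden=8, inter=16).
import torch

num_experts, hidden, inter = 4, 8, 16
up = [torch.randn(inter, hidden) for _ in range(num_experts)]

stacked = torch.stack(up, dim=0)   # shape (4, 16, 8): keeps a per-expert dimension
catted = torch.cat(up, dim=0)      # shape (64, 8): expert boundaries are lost

# With stack, expert i's weight remains directly addressable:
assert torch.equal(stacked[2], up[2])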
