diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 5c2606fe..03ede05f 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -10,7 +10,6 @@ apt update -y && apt install -y --no-install-recommends \ g++ \ cmake && rm -rf /var/lib/apt/lists/* && -cd ktransformers && pip install ninja pyproject numpy cpufeature && pip install flash-attn && cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/ diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 10e3a668..88960c70 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -459,9 +459,9 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype) self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype) - self.up = torch.cat(self.up, dim=0) - self.gate = torch.cat(self.gate, dim=0) - self.down = torch.cat(self.down, dim=0) + self.up = torch.stack(self.up, dim=0) + self.gate = torch.stack(self.gate, dim=0) + self.down = torch.stack(self.down, dim=0) return def unload(self):