
Commit 4cf4dd7

Draft
1 parent fdf794e commit 4cf4dd7


41 files changed: +4382 −753 lines

.github/build_windows_packages.ps1

Lines changed: 6 additions & 1 deletion
@@ -115,12 +115,17 @@ Remove-Item $ffDir.FullName -Recurse -Force
 Write-Host "[INFO] Installing PyTorch..."
 & ".\runtime\python.exe" -m ensurepip
 & ".\runtime\python.exe" -m pip install --upgrade pip --no-warn-script-location
+
 switch ($cuda) {
     "cu124" {
-        & ".\runtime\python.exe" -m pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/cu124 --no-warn-script-location
+        & ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location
+        & ".\runtime\python.exe" -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu124 --no-warn-script-location
+        & ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
    }
    "cu128" {
+        & ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location
         & ".\runtime\python.exe" -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location
+        & ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
    }
    default {
        Write-Error "Unsupported CUDA version: $cuda"
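
The script now pulls build tooling (psutil, ninja, packaging, wheel, setuptools) and then installs flash-attn from a custom pip index with --no-build-isolation inside the portable runtime. A minimal post-install sanity check one could run with that runtime (e.g. ".\runtime\python.exe check_install.py"); the check itself is not part of the packaging script, and the file name is only illustrative:

    # Minimal sketch (assumption): run with the portable interpreter after packaging.
    import importlib.metadata as md

    import torch

    print("torch", torch.__version__, "CUDA available:", torch.cuda.is_available())
    # Raises PackageNotFoundError if the flash-attn wheel did not install.
    print("flash-attn", md.version("flash-attn"))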

.github/workflows/build_windows_packages.yaml

Lines changed: 9 additions & 0 deletions
@@ -31,6 +31,15 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

+      - name: Install Windows CUDA 12.9
+        if: ${{ runner.os == 'Windows' && matrix.torch_cuda == '12.8' }}
+        uses: Jimver/cuda-toolkit
+        id: cuda-toolkit-win-129
+        with:
+          cuda: 12.9.1
+          method: "network"
+          sub-packages: '["nvcc", "cudart", "visual_studio_integration"]'
+
       - name: Run Build and Upload Script
         shell: pwsh
         run: |

Docker/miniconda_install.sh

Lines changed: 20 additions & 2 deletions
@@ -23,8 +23,10 @@ fi

 if [ "$TARGETPLATFORM" = "linux/amd64" ]; then
     "${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-x86_64.sh
+    SYSROOT_PKG="sysroot_linux-64>=2.28"
 elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then
     "${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-aarch64.sh
+    SYSROOT_PKG="sysroot_linux-aarch64>=2.28"
 else
     exit 1
 fi

@@ -45,20 +47,36 @@ rm miniconda.sh

 source "$HOME/miniconda3/etc/profile.d/conda.sh"

+"$HOME/miniconda3/bin/conda" init bash
+
+source "$HOME/.bashrc"
+
 "$HOME/miniconda3/bin/conda" config --add channels conda-forge

 "$HOME/miniconda3/bin/conda" update -q --all -y 1>/dev/null

 "$HOME/miniconda3/bin/conda" install python=3.11 -q -y

-"$HOME/miniconda3/bin/conda" install gcc=14 gxx ffmpeg cmake make unzip -q -y
+"$HOME/miniconda3/bin/conda" install gcc=11 gxx ffmpeg cmake make unzip $SYSROOT_PKG "libstdcxx-ng>=11" -q -y

 if [ "$CUDA_VERSION" = "12.8" ]; then
     "$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
+    "$HOME/miniconda3/bin/conda" install cuda-nvcc=12.8 -c nvidia
 elif [ "$CUDA_VERSION" = "12.6" ]; then
-    "$HOME/miniconda3/bin/pip" install torch==2.6 torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
+    "$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
+    "$HOME/miniconda3/bin/conda" install cuda-nvcc=12.6 -c nvidia
 fi

+CUDA_PATH=$(echo "$HOME/miniconda3/targets/"*-linux | awk '{print $1}')
+
+export CUDA_HOME=$CUDA_PATH
+export PATH="$HOME/miniconda3/bin:$PATH"
+export PATH="$CUDA_HOME/bin:$PATH"
+export PATH="$CUDA_HOME/nvvm/bin:$PATH"
+
+"$HOME/miniconda3/bin/pip" install psutil ninja packaging wheel "setuptools>=42"
+"$HOME/miniconda3/bin/pip" install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
+
 "$HOME/miniconda3/bin/pip" cache purge

 rm $LOG_PATH
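
Because flash-attn is installed with --no-build-isolation, any source build depends on the conda-provided nvcc and the CUDA_HOME/PATH exports just above. A minimal sanity check that could be run at the same point in the image build, in the same shell layer so the exports are still in effect; this check is not part of the commit:

    # Minimal sketch (assumption): run right after the exports above, in the same layer.
    import os
    import shutil
    import subprocess

    print("CUDA_HOME =", os.environ.get("CUDA_HOME"))
    print("nvcc found at:", shutil.which("nvcc"))
    subprocess.run(["nvcc", "--version"], check=True)  # fails loudly if cuda-nvcc is missing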
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
import importlib.util

if importlib.util.find_spec("mlx") is not None:
    from .sample_funcs_mlx import sample_naive as sample_naive_mlx
    from .t2s_engine_mlx import T2SEngine as T2SEngineMLX

    backends = ["MLX"]
else:
    backends = []

__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends"]
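
This module gates the MLX engine behind an importlib probe, so downstream code can branch on the exported backends list instead of wrapping imports in try/except. A minimal consumption sketch; the package path below is a guess and purely illustrative:

    # Minimal sketch (assumption): "AR.models.MLX" is a hypothetical import path, used
    # only to show how a caller might branch on `backends`.
    import importlib

    mlx_pkg = importlib.import_module("AR.models.MLX")
    if "MLX" in mlx_pkg.backends:
        engine_cls = mlx_pkg.T2SEngineMLX  # Apple-silicon path
    else:
        engine_cls = None  # fall back to the PyTorch engine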
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
from functools import partial
from typing import Protocol, cast

import mlx.core as mx

Array = mx.array


class SampleProtocolMLX(Protocol):
    @staticmethod
    def __call__(
        logits: Array,
        previous_tokens: Array,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
    ) -> Array: ...


class sample_naive(SampleProtocolMLX):
    @partial(mx.compile, shapeless=True)
    @staticmethod
    def __call__(
        logits,
        previous_tokens,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
    ):
        if temperature <= 1e-5:
            probs = mx.softmax(logits, axis=-1)
            return mx.argmax(probs, axis=-1, keepdims=True)

        if repetition_penalty != 1.0:
            batch_idx = mx.arange(cast(tuple[int, ...], previous_tokens.shape)[0])
            previous_tokens = previous_tokens.astype(mx.int64)
            selected_logits = logits[batch_idx, previous_tokens]
            selected_logits = mx.where(
                selected_logits < 0, selected_logits * repetition_penalty, selected_logits / repetition_penalty
            )
            logits[batch_idx, previous_tokens] = selected_logits

        sorted_indices = mx.argsort(-logits, axis=-1)
        sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
        cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[:, 0] = False
        indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
        batch_indices = mx.arange(cast(tuple[int, ...], logits.shape)[0])[:, None]
        indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
        logits = mx.where(indices_to_remove, -mx.inf, logits)

        logits = logits / temperature

        v = mx.topk(logits, top_k)
        pivot = mx.expand_dims(v[:, -1], -1)
        logits = mx.where(logits < pivot, -mx.inf, logits)

        gumbel_noise = mx.random.gumbel(shape=cast(tuple[int, ...], logits.shape), dtype=logits.dtype)
        idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)

        return idx_next
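
The final two lines implement the Gumbel-max trick: adding independent Gumbel noise to the already penalized, top-p/top-k filtered, temperature-scaled logits and taking the argmax draws a sample from softmax(logits). A standalone illustration of just that step, with made-up batch and vocabulary sizes:

    # Standalone sketch of the Gumbel-max sampling step above (shapes are made up).
    import mlx.core as mx

    logits = mx.random.normal(shape=(2, 8))  # hypothetical: batch of 2, vocabulary of 8
    gumbel_noise = mx.random.gumbel(shape=logits.shape, dtype=logits.dtype)
    idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True)
    print(idx_next)  # one sampled token id per row, distributed as softmax(logits)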
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Callable, List, MutableSequence, Protocol, Type, cast

import mlx.core as mx
import torch

from ..PyTorch.structs import T2SRequest, T2SResult
from .sample_funcs_mlx import SampleProtocolMLX, sample_naive

Tensor = torch.Tensor
Array = mx.array


@dataclass(slots=True)
class T2SRequestMLX:
    x: List[Array]
    x_lens: Array
    prompts: Array
    bert_feature: List[Array]
    valid_length: int
    top_k: int = 5
    top_p: float = 1
    early_stop_num: int = -1
    temperature: float = 1.0
    repetition_penalty: float = 1.35

    @classmethod
    def from_torch(cls, request: T2SRequest) -> T2SRequestMLX:
        x = list(map(lambda tensor: mx.array(tensor.cpu()), request.x))
        x_lens = mx.array(request.x_lens.cpu())
        prompts = mx.array(request.prompts.cpu())
        bert_feature = list(map(lambda tensor: mx.array(tensor.cpu()), request.bert_feature))

        return cls(
            x,
            x_lens,
            prompts,
            bert_feature,
            request.valid_length,
            request.top_k,
            request.top_p,
            request.early_stop_num,
            request.temperature,
            request.repetition_penalty,
        )


class KVCacheProtocol(Protocol):
    k_cache: Array
    v_cache: Array

    def empty(self) -> None: ...

    def update_cache(self, input_pos: Array, k_val: Array, v_val: Array, *args, **kwds) -> tuple[Array, Array]: ...

    def prefill_kv(self, k_val: Array, v_val: Array) -> None: ...

    def sync_cache(self, kv_cache: KVCacheProtocol) -> None: ...


class T2SDecoderProtocol(Protocol):
    max_seq_length: int
    EOS: int
    n_head: int

    def embed(self, x: list[Array], y: Array, bert_features: list[Array]) -> Array: ...


class T2SEngineProtocol(Protocol):
    def _handle_request(self, request: T2SRequest) -> tuple[list[Array], float]: ...

    def generate(self, request: T2SRequest) -> T2SResult: ...

    @staticmethod
    def load_decoder(
        weights_path: os.PathLike, max_batch_size: int = 1, implement: str = "MLX"
    ) -> T2SDecoderProtocol: ...


class T2SSessionMLX:
    def __init__(
        self,
        decoder: T2SDecoderProtocol,
        request_torch: T2SRequest,
        sample_func: Type[SampleProtocolMLX] = sample_naive,
        device: mx.Device = mx.Device(mx.cpu),
        dtype: mx.Dtype = mx.float32,
    ):
        with mx.stream(device):
            request = T2SRequestMLX.from_torch(request_torch)

            self.decoder = decoder
            self.request = request
            self.device = device
            self.dtype = dtype

            bsz = len(request.x)
            y_len: int = cast(tuple[int, ...], request.prompts.shape)[-1]
            self.bsz = bsz
            self.y_len = y_len

            # Cache
            self.kv_cache: MutableSequence[KVCacheProtocol]
            self.sample = sample_func()

            # Forward args
            self.x = [i.astype(mx.int32) for i in request.x]
            self.x_lens = request.x_lens.astype(mx.int32)
            self.y = mx.zeros((bsz, decoder.max_seq_length)).astype(mx.int32)
            self.y[:, : cast(tuple[int, ...], request.prompts.shape)[-1]] = request.prompts.astype(mx.int32)
            self.bert_feature = [i.astype(dtype) for i in request.bert_feature]

            self.prefill_len = self.x_lens + cast(tuple[int, ...], request.prompts.shape)[1]

            self.input_pos = mx.zeros_like(self.prefill_len)
            self.input_pos += self.prefill_len

            # EOS
            self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)
            self.y_results: List[Array] = [None] * len(self.x)  # type: ignore

            self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)

            max_len = int(self.prefill_len.max(-1))
            attn_mask = mx.zeros(shape=(bsz, max_len, max_len), dtype=mx.bool_)

            for bs in range(bsz):
                pos = int(self.x_lens[bs])
                seq_len = pos + y_len

                attn_mask[bs, :seq_len, :pos] = True

                ar_mask = ~mx.triu(
                    x=mx.ones(
                        shape=(
                            y_len,
                            y_len,
                        ),
                        dtype=mx.bool_,
                    ),
                    k=1,
                )
                attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask

            attn_mask = mx.repeat(mx.expand_dims(attn_mask, 1), decoder.n_head, 1)
            self.attn_mask = attn_mask

            mx.eval(self.attn_mask)
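
KVCacheProtocol above only pins down the attribute and method surface; the concrete cache class is not part of this diff. A minimal sketch of an object that would satisfy the protocol, assuming a simple preallocated buffer layout (batch, heads, max_seq, head_dim); the class name and layout are illustrative only:

    # Minimal sketch (assumption): a preallocated cache satisfying KVCacheProtocol.
    # The commit's real cache implementation is not shown in this diff.
    import mlx.core as mx


    class SimpleKVCache:
        def __init__(self, batch: int, heads: int, max_seq: int, head_dim: int, dtype=mx.float32):
            self.k_cache = mx.zeros((batch, heads, max_seq, head_dim), dtype=dtype)
            self.v_cache = mx.zeros((batch, heads, max_seq, head_dim), dtype=dtype)

        def empty(self) -> None:
            # Reset both buffers to zeros.
            self.k_cache = mx.zeros_like(self.k_cache)
            self.v_cache = mx.zeros_like(self.v_cache)

        def update_cache(self, input_pos, k_val, v_val, *args, **kwds):
            # Write new keys/values at the decode positions, return the full buffers.
            self.k_cache[:, :, input_pos] = k_val
            self.v_cache[:, :, input_pos] = v_val
            return self.k_cache, self.v_cache

        def prefill_kv(self, k_val, v_val) -> None:
            # Bulk-write the prompt's keys/values at the start of the buffers.
            self.k_cache[:, :, : k_val.shape[2]] = k_val
            self.v_cache[:, :, : v_val.shape[2]] = v_val

        def sync_cache(self, kv_cache) -> None:
            # Adopt another cache's buffers.
            self.k_cache = kv_cache.k_cache
            self.v_cache = kv_cache.v_cache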
