Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 63 additions & 24 deletions tools/extract_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,58 +13,97 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor

import onnxruntime
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from tqdm import tqdm


def extract_embedding(input_list):
    """Compute a speaker embedding for one utterance via an ONNX model.

    Args:
        input_list: tuple of (utt, wav_file, ort_session) where `utt` is the
            utterance id, `wav_file` a path readable by torchaudio, and
            `ort_session` a shared onnxruntime.InferenceSession. Packed as a
            single tuple so the function can be used with executor.map.

    Returns:
        (utt, embedding): the utterance id and the embedding as a flat
        list of floats.
    """
    utt, wav_file, ort_session = input_list

    audio, sample_rate = torchaudio.load(wav_file)
    # The embedding model expects 16 kHz input; resample anything else.
    if sample_rate != 16000:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)

    # 80-dim log-mel filterbank features, dither disabled for determinism.
    feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
    # Cepstral mean normalization over time.
    feat = feat - feat.mean(dim=0, keepdim=True)

    # Run the ONNX model on a (1, T, 80) batch and flatten the output vector.
    embedding = ort_session.run(
        None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
    )[0].flatten().tolist()
    return (utt, embedding)


def main(args):
    """Extract speaker embeddings for all utterances listed in Kaldi-style files.

    Reads `{args.dir}/wav.scp` (utt -> wav path) and `{args.dir}/utt2spk`
    (utt -> speaker), extracts one embedding per utterance with a thread
    pool, averages embeddings per speaker, and saves `utt2embedding.pt`
    and `spk2embedding.pt` into `args.dir`.

    Raises:
        FileNotFoundError: if `args.onnx_path` does not exist.
    """
    utt2wav, utt2spk = {}, {}
    with open("{}/wav.scp".format(args.dir)) as f:
        for l in f:
            l = l.replace("\n", "").split()
            utt2wav[l[0]] = l[1]
    with open("{}/utt2spk".format(args.dir)) as f:
        for l in f:
            l = l.replace("\n", "").split()
            utt2spk[l[0]] = l[1]

    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if not os.path.exists(args.onnx_path):
        raise FileNotFoundError("onnx_path not exists: {}".format(args.onnx_path))

    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    # One intra-op thread per session; parallelism comes from the thread pool below.
    option.intra_op_num_threads = 1
    providers = ["CPUExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)

    # The session is shared across workers; onnxruntime Run() is thread-safe.
    inputs = [(utt, utt2wav[utt], ort_session) for utt in tqdm(utt2wav.keys(), desc="Load data")]
    with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
        results = list(
            tqdm(executor.map(extract_embedding, inputs), total=len(inputs), desc="Process data: ")
        )

    utt2embedding, spk2embedding = {}, {}
    for utt, embedding in results:
        utt2embedding[utt] = embedding
        spk2embedding.setdefault(utt2spk[utt], []).append(embedding)

    # Speaker embedding = mean of that speaker's utterance embeddings.
    for k, v in spk2embedding.items():
        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()

    torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
    torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))


if __name__ == "__main__":
    # CLI entry point: --dir holds wav.scp/utt2spk and receives the output
    # .pt files; --onnx_path is the embedding model; --num_thread sizes the
    # extraction thread pool.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", type=str)
    parser.add_argument("--onnx_path", type=str)
    parser.add_argument("--num_thread", type=int, default=8)
    args = parser.parse_args()
    main(args)