1+ import onnxruntime as ort
2+ import numpy as np
3+ from collections import defaultdict
4+ import traceback
5+ import sys
6+ sys .path .append ('.' )
7+ from pathlib import Path as P
8+
9+ class FrenchG2p :
10+ graphemes = ["" , "" , "" , "" , "-" , "a" , "b" , "c" , "d" , "e" , "f" , "g" ,
11+ "h" , "i" , "j" , "k" , "l" , "m" , "n" , "o" , "p" , "q" , "r" ,
12+ "s" , "t" , "u" , "v" , "w" , "x" , "y" , "z" , "à" , "á" , "â" , "ä" ,
13+ "æ" , "ç" , "è" , "é" , "ê" , "ë" , "î" , "ï" , "ñ" , "ô" , "ö" ,
14+ "ù" , "ú" , "û" , "ü" , "ÿ" ]
15+
16+ phonemes = ["" , "" , "" , "" , "aa" , "ai" , "an" , "au" , "bb" , "ch" , "dd" ,
17+ "ee" , "ei" , "eu" , "ff" , "gg" , "gn" , "ii" , "in" , "jj" , "kk" ,
18+ "ll" , "mm" , "nn" , "oe" , "on" , "oo" , "ou" , "pp" , "rr" , "ss" ,
19+ "tt" , "un" , "uu" , "uy" , "vv" , "ww" , "yy" , "zz" ]
20+
21+ def __init__ (self ):
22+ self .lock = None # Placeholder for thread safety if needed
23+ self .dict = {}
24+ self .grapheme_indexes = {}
25+ self .pred_cache = defaultdict (list )
26+ self .session = None
27+ self .phonemes = self .phonemes [4 :]
28+ self .load_pack ()
29+
30+ def load_pack (self ):
31+ dict_path = P ('./Assets/G2p/g2p-fr/dict.txt' )
32+ with open (dict_path , 'r' , encoding = 'utf-8' ) as f :
33+ for line in f :
34+ parts = line .strip ().split (' ' )
35+ if len (parts ) >= 2 :
36+ grapheme = parts [0 ].lower ()
37+ phoneme_parts = parts [1 :]
38+ phonemes = '' .join (phoneme_parts )
39+ self .dict [grapheme ] = phonemes .split ()
40+ else :
41+ print (f"Ignoring line: { line .strip ()} " )
42+
43+ # Create grapheme indexes (skip the first four graphemes)
44+ self .grapheme_indexes = {g : i + 4 for i , g in enumerate (self .graphemes [4 :])}
45+
46+ onnx_path = P ('./Assets/G2p/g2p-fr/g2p.onnx' )
47+ self .session = ort .InferenceSession (onnx_path )
48+
49+ def predict (self , input_text ):
50+ words = input_text .strip ().split ()
51+ predicted_phonemes = []
52+ for word in words :
53+ word_lower = word .lower ()
54+ if word_lower in self .dict :
55+ predicted_phonemes .append (' ' .join (self .dict [word_lower ]))
56+ else :
57+ cached_phoneme = self .pred_cache .get (word_lower )
58+ if cached_phoneme :
59+ predicted_phonemes .append (' ' .join (cached_phoneme ))
60+ else :
61+ predicted_phoneme = self .predict_with_model (word )
62+ self .pred_cache [word_lower ] = predicted_phoneme .split ()
63+ predicted_phonemes .append (predicted_phoneme )
64+ return ' ' .join (predicted_phonemes )
65+
66+ def predict_with_model (self , word ):
67+ # Encode input word as indices of graphemes
68+ word_with_dash = "-" + word # funny workaround for that first skipped phoneme
69+
70+ input_ids = np .array ([self .grapheme_indexes .get (c , 0 ) for c in word_with_dash ], dtype = np .int32 ) # equvilant to `Tensor<int> src = EncodeWord(grapheme);`
71+ input_length = len (input_ids )
72+
73+ if len (input_ids .shape ) == 1 :
74+ input_ids = np .expand_dims (input_ids , axis = 0 )
75+
76+ t = np .ones ((1 ,), dtype = np .int32 )
77+
78+ src = input_ids
79+ tgt = np .array ([2 , ], dtype = np .int32 )
80+ if len (tgt .shape ) == 1 :
81+ tgt = np .expand_dims (tgt , axis = 0 )
82+ print (tgt )
83+
84+ try :
85+ while t [0 ] < input_length and len (tgt ) < 48 :
86+ input_feed = {'src' : src , 'tgt' : tgt , 't' : t }
87+
88+ outputs = self .session .run (['pred' ], input_feed )
89+ pred = outputs [0 ].flatten ().astype (int )
90+ if pred != 2 :
91+ new_tgt_shape = (tgt .shape [0 ], tgt .shape [1 ] + 1 )
92+
93+ new_tgt = np .zeros (new_tgt_shape , dtype = np .int32 )
94+
95+ for i in range (tgt .shape [1 ]):
96+ new_tgt [:, i ] = tgt [:, i ]
97+
98+ new_tgt [:, tgt .shape [1 ]] = pred
99+ print (pred - 4 )
100+
101+ tgt = new_tgt
102+ else :
103+ t [0 ] += 1
104+
105+ # these lines are equivalent to `var phonemes = DecodePhonemes(tgt.Skip(1).ToArray());`
106+ predicted_phonemes = []
107+ for id in tgt .flatten ().astype (int ):
108+ if id != 2 : # skip the first phone (workaround) cuz of the np.array initial phoneme
109+ predicted_phonemes .append (self .phonemes [id - 4 ])
110+
111+ print (predicted_phonemes )
112+
113+ predicted_phonemes_str = ' ' .join (predicted_phonemes )
114+ return predicted_phonemes_str
115+ except Exception as e :
116+
117+ print ("Error in prediction" , traceback .format_exc ())
0 commit comments