From 8ca8c7252e0f0dd04ddc9f0fdca6c47f32d3fc8b Mon Sep 17 00:00:00 2001 From: niedev Date: Mon, 9 Feb 2026 14:19:48 +0100 Subject: [PATCH 01/15] Refactor of the decoder execution in Translator.java --- .idea/.gitignore | 2 - .../translation/Translator.java | 419 +++++++++--------- .../neural_networks/voice/Recognizer.java | 2 +- 3 files changed, 216 insertions(+), 207 deletions(-) diff --git a/.idea/.gitignore b/.idea/.gitignore index 2aba50e..046eaf9 100644 --- a/.idea/.gitignore +++ b/.idea/.gitignore @@ -1,8 +1,6 @@ -# Default ignored files /shelf/ /usage.statistics.xml /workspace.xml -#extra ignored files /caches /caches/* /gradle.xml diff --git a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java index 2d7ae19..642c15b 100644 --- a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java +++ b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java @@ -64,7 +64,6 @@ import nie.translator.rtranslator.tools.CustomLocale; import nie.translator.rtranslator.tools.ErrorCodes; import nie.translator.rtranslator.tools.FileTools; -import nie.translator.rtranslator.tools.gui.animations.CustomAnimator; import nie.translator.rtranslator.tools.gui.messages.GuiMessage; import nie.translator.rtranslator.tools.nn.CacheContainerNative; import nie.translator.rtranslator.tools.nn.TensorUtils; @@ -133,17 +132,17 @@ private void initialize(@NonNull Global global, int mode, GeneralListener initLi String cacheInitializerPath; if(mode == NLLB || mode == NLLB_CACHE) { //8 bit - /*encoderPath = global.getFilesDir().getPath() + "/NLLB_encoder.onnx"; + encoderPath = global.getFilesDir().getPath() + "/NLLB_encoder.onnx"; decoderPath = global.getFilesDir().getPath() + "/NLLB_decoder.onnx"; vocabPath = 
global.getFilesDir().getPath() + "/sentencepiece_bpe.model"; embedAndLmHeadPath = global.getFilesDir().getPath() + "/NLLB_embed_and_lm_head.onnx"; - cacheInitializerPath = global.getFilesDir().getPath() + "/NLLB_cache_initializer.onnx";*/ + cacheInitializerPath = global.getFilesDir().getPath() + "/NLLB_cache_initializer.onnx"; //4 bit - encoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_encoder_4bit.onnx"; + /*encoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_encoder_4bit.onnx"; decoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_decoder_4bit.onnx"; vocabPath = global.getFilesDir().getPath() + "/sentencepiece_bpe.model"; embedAndLmHeadPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_embed_and_lm_head_4bit.onnx"; - cacheInitializerPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_cache_initializer_4bit.onnx"; + cacheInitializerPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_cache_initializer_4bit.onnx";*/ }else { //madlad encoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int4Acc4/madlad_encoder_4bit.onnx"; decoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int4Acc4/madlad_decoder_4bit.onnx"; @@ -699,9 +698,6 @@ private void performTextTranslation(final String textToTranslate, final CustomLo return; } //decoder execution - final int eos = tokenizer.PieceToID(""); - ArrayList completeOutput = new ArrayList(); - completeOutput.add(0); //tokenizer.PieceToID("") TranslateListener translateListener = new TranslateListener() { @Override public void onTranslatedText(String textToTranslate, String text, long resultID, boolean isFinal, CustomLocale languageOfText) { @@ -734,11 
+730,11 @@ public void onFailure(int[] reasons, long value) { } }; if (beamSize > 1) { //beam search - executeCacheDecoderBeam(textToTranslate, input, encoderResult, completeBeamOutput, beamsOutputsProbabilities, outputLanguage, beamSize, translateListener); + executeCacheDecoder(textToTranslate, input, encoderResult, completeBeamOutput, beamsOutputsProbabilities, outputLanguage, beamSize, translateListener); } else if (beamSize == 1) { //greedy search (with kv cache) - executeCacheDecoderGreedy(textToTranslate, input, encoderResult, completeOutput, outputLanguage, translateListener); + executeCacheDecoder(textToTranslate, input, encoderResult, completeBeamOutput, null, outputLanguage, 1, translateListener); } - //we convert the ids of completeOutputs into a string and return it + //we convert the ids of completeBeamOutputs into a string and return it encoderResult.close(); int[] completeOutputArray; if ((mode == MADLAD_CACHE || mode == NLLB_CACHE) && beamSize > 1) { @@ -748,7 +744,7 @@ public void onFailure(int[] reasons, long value) { } completeOutputArray = completeBeamOutput[indexMax].stream().mapToInt(k -> k).toArray(); } else { - completeOutputArray = completeOutput.stream().mapToInt(k -> k).toArray(); //converte completeOutput in un array di int + completeOutputArray = completeBeamOutput[0].stream().mapToInt(k -> k).toArray(); //converte completeBeamOutput in un array di int } String finalSplitResult = tokenizer.decode(completeOutputArray); if (joinedStringOutput[0].equals("")) { @@ -828,6 +824,7 @@ private OnnxTensor executeEncoder(int[] inputIDs, int[] attentionMask){ } } + //todo: remove this method in the future public void executeCacheDecoderGreedy(String textToTranslate, TokenizerResult input, OnnxTensor encoderResult, ArrayList completeOutput, final CustomLocale outputLanguage, @Nullable final TranslateListener responseListener){ try { long time = System.currentTimeMillis(); @@ -899,7 +896,6 @@ public void executeCacheDecoderGreedy(String 
textToTranslate, TokenizerResult in embedResult = embedSession.run(embedInput, requestedOutputs); decoderInput.put("embed_matrix", (OnnxTensor) embedResult.get(0)); - //decoderInput.put("encoder_hidden_states", encoderResult); } if(j == 1){ long[] shape = {1, 16, 0, hiddenSize}; @@ -1008,7 +1004,7 @@ public void executeCacheDecoderGreedy(String textToTranslate, TokenizerResult in } } - public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult input, OnnxTensor encoderResult, ArrayList[] completeBeamOutput, double[] beamsOutputsProbabilities, final CustomLocale outputLanguage, int beamSize, @Nullable final TranslateListener responseListener) { + public void executeCacheDecoder(String textToTranslate, TokenizerResult input, OnnxTensor encoderResult, ArrayList[] completeBeamOutput, @Nullable double[] beamsOutputsProbabilities, final CustomLocale outputLanguage, int beamSize, @Nullable final TranslateListener responseListener) { final int eos = tokenizer.PieceToID(""); int nLayers; int hiddenSize; @@ -1031,11 +1027,10 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu inputIDsTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, new int[]{2}); //for the first iteration we use input_ids = 2, with batch_size = 1 } OnnxTensor encoderAttentionMaskTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, input.getAttentionMask()); - int encoderInputIdsLength = input.getInputIDs().length; CacheContainerNative cacheContainer = null; OnnxTensor decoderOutput = null; Map decoderInput = new HashMap(); - float [][][] outputValues = null; + float [][][] logits = null; time = System.currentTimeMillis(); //preparing cache initializer input @@ -1044,61 +1039,11 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu //execution of the cache initializer OrtSession.Result initResult = cacheInitSession.run(initInput); android.util.Log.i("performance", "Cache initialization done in: " + 
(System.currentTimeMillis()-time) + "ms"); + encoderResult.close(); //we close it because from now on we only need initResult - time = System.currentTimeMillis(); //we convert the fixed decoder inputs to have batch_size==beamSize - OnnxTensor encoderResultBatched = null; - if(mode == MADLAD_CACHE) { - float[][] encoderValue = ((float[][][]) encoderResult.getValue())[0]; - float[] encoderValueFlatBatched = TensorUtils.flattenFloatArrayBatched(encoderValue, beamSize); - encoderResultBatched = TensorUtils.createFloatTensor(onnxEnv, encoderValueFlatBatched, new long[]{beamSize, encoderValue.length, encoderValue[0].length}); - encoderValue = null; //free the memory - encoderValueFlatBatched = null; //free the memory - //System.gc(); - android.util.Log.i("performance", "Encoder batch initialization done in: " + (System.currentTimeMillis()-time) + "ms"); - } - time = System.currentTimeMillis(); - OnnxTensor encoderAttentionMaskTensorBatched; - int[] encoderMaskFlatBatched = TensorUtils.flattenIntArrayBatched(input.getAttentionMask(), beamSize); - encoderAttentionMaskTensorBatched = TensorUtils.createIntTensor(onnxEnv, encoderMaskFlatBatched, new long[]{beamSize, input.getAttentionMask().length}); - encoderMaskFlatBatched = null; //free the memory - //System.gc(); - android.util.Log.i("performance", "Mask batch initialization done in: " + (System.currentTimeMillis()-time) + "ms"); - time = System.currentTimeMillis(); - OrtSession.Result initResultBatched; - String[] names = new String[2*nLayers]; - OnnxValue[] values = new OnnxValue[2*nLayers]; - boolean[] ownedByResult = new boolean[2*nLayers]; - Arrays.fill(ownedByResult, true); - String[] suffixes = {"key", "value"}; - long timeExtract = 0; - long timeBatch = 0; - long timeCreate = 0; - int count = 0; - for (int i = 0; i < nLayers; i++) { - for (String suffix: suffixes) { - //System.gc(); - names[count] = "present." 
+ i + ".encoder."+suffix; - long timeInner = System.currentTimeMillis(); - float[][][] keyValue = ((float[][][][]) TensorUtils.extractValue(initResult, "present." + i + ".encoder."+suffix))[0]; - timeExtract += System.currentTimeMillis() - timeInner; - timeInner = System.currentTimeMillis(); - float[][][][] keyValueFlatBatched = TensorUtils.batchTensor(keyValue, beamSize); - timeBatch += System.currentTimeMillis() - timeInner; - timeInner = System.currentTimeMillis(); - values[count] = TensorUtils.createFloatTensorOptimized(onnxEnv, keyValueFlatBatched, new long[]{beamSize, keyValue.length, keyValue[0].length, keyValue[0][0].length});; - timeCreate += System.currentTimeMillis() - timeInner; - count++; - } - } - //the Result constructor is private but this way we can use it anyway - Constructor constructor = OrtSession.Result.class.getDeclaredConstructor(names.getClass(), values.getClass(), ownedByResult.getClass()); - constructor.setAccessible(true); - initResultBatched = constructor.newInstance(names, values, ownedByResult); - android.util.Log.i("performance", "InitResult extract done in: " + timeExtract + "ms"); - android.util.Log.i("performance", "InitResult batch done in: " + timeBatch + "ms"); - android.util.Log.i("performance", "InitResult create done in: " + timeCreate + "ms"); - android.util.Log.i("performance", "InitResult batch initialization done in: " + (System.currentTimeMillis()-time) + "ms"); + OnnxTensor encoderAttentionMaskTensorBatched = beamSize > 1 ? batchEncoderAttentionMask(input.getAttentionMask(), beamSize, true) : null; + OrtSession.Result initResultBatched = beamSize > 1 ? 
batchEncoderKvCache(initResult, nLayers, beamSize, true) : null; //we begin the iterative execution of the decoder String[] partialResults = new String[beamSize]; //used for log @@ -1112,14 +1057,15 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu OnnxTensor emptyInputIds = TensorUtils.createInt64TensorWithSingleValue(onnxEnv, 0, new long[]{EMPTY_BATCH_SIZE, 2}); OnnxTensor emptyInputIdsBatch = TensorUtils.createInt64TensorWithSingleValue(onnxEnv, 0, new long[]{beamSize, 2}); - while(input_ids[0] != eos){ //input_ids[0] should always contain the ultimate value generated from the text with highest probability (to be verified) + while(max[0] != eos){ //max[0] should always contain the ultimate value generated from the text with highest probability (to be verified) initialTime = System.currentTimeMillis(); time = System.currentTimeMillis(); + //we prepare the decoder input decoderInput = new HashMap(); OrtSession.Result embedResult = null; if(mode == NLLB_CACHE){ - //we do the embedding separately and then we pass the result to the encoder + //we do the embedding separately and then we pass the result to the decoder Map embedInput = new HashMap(); embedInput.put("input_ids", inputIDsTensor); embedInput.put("pre_logits", j == 1 ? 
emptyPreLogits : emptyPreLogitsBatch); @@ -1131,6 +1077,7 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu decoderInput.put("embed_matrix", (OnnxTensor) embedResult.get(0)); } if(mode == MADLAD_CACHE) { + //we do the embedding separately and then we pass the result to the decoder Map embedInput = new HashMap(); embedInput.put("input_ids", inputIDsTensor); ArraySet requestedOutputs = new ArraySet<>(); @@ -1138,15 +1085,11 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu embedResult = embedSession.run(embedInput, requestedOutputs); decoderInput.put("embed_matrix", (OnnxTensor) embedResult.get(0)); - //decoderInput.put("encoder_hidden_states", encoderResult); } decoderInput.put("input_ids", inputIDsTensor); - if(j == 1){ //se è la prima iterazione + if(j == 1){ //if it is the first iteration //we run the decoder with a batch_size = 1 decoderInput.put("encoder_attention_mask", encoderAttentionMaskTensor); - if(mode == MADLAD_CACHE) { - //decoderInput.put("encoder_hidden_states", encoderResult); - } long[] shape = {1, 16, 0, hiddenSize}; OnnxTensor decoderPastTensor = TensorUtils.createFloatTensorWithSingleValue(onnxEnv,0, shape); for (int i = 0; i < nLayers; i++) { @@ -1156,26 +1099,23 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu decoderInput.put("past_key_values." + i + ".encoder.value", (OnnxTensor) initResult.get("present." 
+ i + ".encoder.value").get()); } }else { - if(j == 2) { + if(j == 2 && beamSize > 1) { encoderAttentionMaskTensor.close(); //we close it because from now on we only need encoderAttentionMaskTensorBatched - encoderResult.close(); //we close it because from now on we only need encoderResultBatched initResult.close(); //we close it because from now on we only need initResultBatched } //we run the decoder with batch_size = beamSize - decoderInput.put("encoder_attention_mask", encoderAttentionMaskTensorBatched); - if(mode == MADLAD_CACHE) { - //decoderInput.put("encoder_hidden_states", encoderResultBatched); - } + decoderInput.put("encoder_attention_mask", beamSize > 1 ? encoderAttentionMaskTensorBatched : encoderAttentionMaskTensor); for (int i = 0; i < nLayers; i++) { decoderInput.put("past_key_values." + i + ".decoder.key", (OnnxTensor) result.get("present." + i + ".decoder.key").get()); decoderInput.put("past_key_values." + i + ".decoder.value", (OnnxTensor) result.get("present." + i + ".decoder.value").get()); - decoderInput.put("past_key_values." + i + ".encoder.key", (OnnxTensor) initResultBatched.get("present." + i + ".encoder.key").get()); - decoderInput.put("past_key_values." + i + ".encoder.value", (OnnxTensor) initResultBatched.get("present." + i + ".encoder.value").get()); + decoderInput.put("past_key_values." + i + ".encoder.key", (OnnxTensor) (beamSize > 1 ? initResultBatched : initResult).get("present." + i + ".encoder.key").get()); + decoderInput.put("past_key_values." + i + ".encoder.value", (OnnxTensor) (beamSize > 1 ? initResultBatched : initResult).get("present." 
+ i + ".encoder.value").get()); } } oldResult = result; android.util.Log.i("performance", "pre-execution of"+j+"th word done in: " + (System.currentTimeMillis()-time) + "ms"); time = System.currentTimeMillis(); + //decoder execution (with cache) result = decoderSession.run(decoderInput); android.util.Log.i("performance", "execution of"+j+"th word done in: " + (System.currentTimeMillis()-time) + "ms"); @@ -1188,7 +1128,7 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu embedResult.close(); } android.util.Log.i("performance", "release RAM of"+j+"th word done in: " + (System.currentTimeMillis()-time) + "ms"); - //we take the logits and the max value + //we take the logits OrtSession.Result lmHeadResult = null; if(mode == NLLB_CACHE) { //we execute the lmHead separately to get the logits @@ -1204,24 +1144,16 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu decoderOutput = (OnnxTensor) result.get("logits").get(); } //we take the logits and the larger "beamSize" values + logits = (float[][][]) decoderOutput.getValue(); if(j == 1) { //if we are at the first iteration - //decoderOutput = (OnnxTensor) result.get("logits").get(); - outputValues = (float[][][]) decoderOutput.getValue(); - //the "beamSize" words with highest probability are inserted into max and added to completeBeamOutput - ArrayList indexesToAvoid = new ArrayList<>(); - for (int i = 0; i < beamSize; i++) { - max[i] = Utils.getIndexOfLargest(outputValues[0][0], indexesToAvoid); - indexesToAvoid.add(max[i]); - completeBeamOutput[i].add(max[i]); - } - //we insert the initial probabilities of the "beamSize" output strings into beamsOutputsProbabilities - for (int i = 0; i < beamSize; i++) { - float maxLogit = outputValues[0][0][max[i]]; - //old version of probability calculation (softmax) - //beamsOutputsProbabilities[i] = Math.log(Utils.softmax(maxLogit, outputValues[0][0])); - //new version of probability calculation (logSumExp) - 
beamsOutputsProbabilities[i] = maxLogit - Utils.logSumExpFast(outputValues[0][0]); + if(beamSize > 1) { + //based on the logits, we initialize max, completeBeamOutput and beamsOutputsProbabilities + initBeamSearchData(logits, beamSize, max, completeBeamOutput, beamsOutputsProbabilities); + }else{ + max[0] = Utils.getIndexOfLargest(logits[0][0]); + completeBeamOutput[0].add(max[0]); } + //we prepare the inputs of the next iteration if(mode == NLLB_CACHE){ for(int i=0; i indexesToAvoid = new ArrayList<>(); - for (int i = 0; i < beamSize; i++) { - beamMax[k][i] = Utils.getIndexOfLargest(outputValues[k][0], indexesToAvoid); - indexesToAvoid.add(beamMax[k][i]); - } - } - //Now beamMax will contain for each decoder output ("beamSize" outputs) the "beamSize" words with highest probability, - // so for each output we calculate its overall probability for each of its "beamSize" words with highest probability - long timeSoftmax = System.currentTimeMillis(); - double[] beamsOutputsProbabilitiesTemp = new double[beamSize*beamSize]; - for(int k=0; k < beamSize; k++) { - //old version of probability calculation (softmax) - /*for (int i = 0; i < beamSize; i++) { - beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilities[k] + Math.log(Utils.softmax(outputValues[k][0][beamMax[k][i]], outputValues[k][0])); - if(beamMax[k][i] == eos){ - beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilitiesTemp[(k*beamSize)+i]/EOS_PENALTY; - } - }*/ - //new version of probability calculation (logSumExp) - double logSumExp = Utils.logSumExpFast(outputValues[k][0]); - for (int i = 0; i < beamSize; i++) { - float maxLogit = outputValues[k][0][beamMax[k][i]]; - beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilities[k] + maxLogit - logSumExp; - if(beamMax[k][i] == eos){ - beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilitiesTemp[(k*beamSize)+i]/EOS_PENALTY; - } - } - } - android.util.Log.i("performance", "softmax done in: " + 
(System.currentTimeMillis()-timeSoftmax) + "ms"); - // Now we save in maxProbabilities the indices of the "beamSize" words generated by the decoder that have the - // highest overall probability with their respective output sentences and then we will use them as the next inputs - ArrayList indexesToAvoid = new ArrayList<>(); - int[] maxProbabilities = new int[beamSize]; - for(int i=0; i[] oldCompleteBeamOutput = completeBeamOutput.clone(); - for (int i = 0; i < beamSize; i++) { - beamsOutputsProbabilities[i] = beamsOutputsProbabilitiesTemp[maxProbabilities[i]]; - completeBeamOutput[i] = (ArrayList) oldCompleteBeamOutput[maxProbabilities[i]/beamSize].clone(); - completeBeamOutput[i].add(beamMax[maxProbabilities[i]/beamSize][maxProbabilities[i]%beamSize]); + if(beamSize > 1) { + //based on the logits we update beam search data + int[] maxProbabilities = new int[beamSize]; + cacheContainer = updateBeamSearchData(logits, beamSize, eos, result, j, nLayers, 16, hiddenSize, cacheContainer, maxProbabilities, beamMax, max, completeBeamOutput, beamsOutputsProbabilities); + }else{ + max[0] = Utils.getIndexOfLargest(logits[0][0]); + completeBeamOutput[0].add(max[0]); } + //we prepare the inputs of the next iteration - for (int i = 0; i < beamSize; i++) { - input_ids[i] = beamMax[maxProbabilities[i]/beamSize][maxProbabilities[i]%beamSize]; - } + input_ids = max; inputIDsTensor = TensorUtils.createIntTensor(onnxEnv, input_ids, new long[]{beamSize,1}); - long timeCache = System.currentTimeMillis(); - CacheContainerNative oldCache = cacheContainer; - cacheContainer = new CacheContainerNative(onnxEnv, result, nLayers, beamSize, 16, j, hiddenSize); - if(oldCache != null){ - oldCache.close(); - } - android.util.Log.i("performance", "cache creation done in: " + (System.currentTimeMillis()-timeCache) + "ms"); - int[] indexes = new int[beamSize]; - for(int i=0; i 1) { + for (int i = 0; i < beamSize; i++) { + indexMax = Utils.getIndexOfLargest(beamsOutputsProbabilities); + } } - int [] 
outputIDs = completeBeamOutput[indexMax].stream().mapToInt(k -> k).toArray(); + int[] outputIDs = completeBeamOutput[indexMax].stream().mapToInt(k -> k).toArray(); String partialResult = tokenizer.decode(outputIDs); if(responseListener != null) { responseListener.onTranslatedText(textToTranslate, partialResult, currentResultID, false, outputLanguage); @@ -1346,20 +1203,11 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu } } - if(result != null) { - result.close(); - } + if(result != null) result.close(); initResult.close(); - if(cacheContainer != null) { - cacheContainer.close(); - } - if (encoderAttentionMaskTensorBatched != null) { - encoderAttentionMaskTensorBatched.close(); - } - if(encoderResultBatched != null) { - encoderResultBatched.close(); - } - initResultBatched.close(); + if(cacheContainer != null) cacheContainer.close(); + if(encoderAttentionMaskTensorBatched != null) encoderAttentionMaskTensorBatched.close(); + if(initResultBatched != null) initResultBatched.close(); } catch (OrtException | InvocationTargetException | NoSuchMethodException | IllegalAccessException | InstantiationException e) { @@ -1397,6 +1245,169 @@ private String correctText(String text, Locale locale){ return text; } + private OnnxTensor batchEncoderAttentionMask(int[] attentionMask, int batchSize, boolean log) throws OrtException { + long time = System.currentTimeMillis(); + int[] encoderMaskFlatBatched = TensorUtils.flattenIntArrayBatched(attentionMask, batchSize); + OnnxTensor encoderAttentionMaskTensorBatched = TensorUtils.createIntTensor(onnxEnv, encoderMaskFlatBatched, new long[]{batchSize, attentionMask.length}); + encoderMaskFlatBatched = null; //free the memory + //System.gc(); + if(log) { + android.util.Log.i("performance", "Mask batch initialization done in: " + (System.currentTimeMillis() - time) + "ms"); + } + return encoderAttentionMaskTensorBatched; + } + + @NonNull + private OrtSession.Result batchEncoderKvCache(OrtSession.Result 
result, int nLayers, int batchSize, boolean log) throws InvocationTargetException, IllegalAccessException, InstantiationException, NoSuchMethodException, OrtException { + long time = System.currentTimeMillis(); + String[] names = new String[2*nLayers]; + OnnxValue[] values = new OnnxValue[2*nLayers]; + boolean[] ownedByResult = new boolean[2*nLayers]; + Arrays.fill(ownedByResult, true); + String[] suffixes = {"key", "value"}; + long timeExtract = 0; + long timeBatch = 0; + long timeCreate = 0; + int count = 0; + for (int i = 0; i < nLayers; i++) { + for (String suffix: suffixes) { + //System.gc(); + names[count] = "present." + i + ".encoder."+suffix; + long timeInner = System.currentTimeMillis(); + float[][][] keyValue = ((float[][][][]) TensorUtils.extractValue(result, "present." + i + ".encoder."+suffix))[0]; + timeExtract += System.currentTimeMillis() - timeInner; + timeInner = System.currentTimeMillis(); + float[][][][] keyValueFlatBatched = TensorUtils.batchTensor(keyValue, batchSize); + timeBatch += System.currentTimeMillis() - timeInner; + timeInner = System.currentTimeMillis(); + values[count] = TensorUtils.createFloatTensorOptimized(onnxEnv, keyValueFlatBatched, new long[]{batchSize, keyValue.length, keyValue[0].length, keyValue[0][0].length});; + timeCreate += System.currentTimeMillis() - timeInner; + count++; + } + } + //the Result constructor is private but this way we can use it anyway + Constructor constructor = OrtSession.Result.class.getDeclaredConstructor(names.getClass(), values.getClass(), ownedByResult.getClass()); + constructor.setAccessible(true); + OrtSession.Result initResultBatched = constructor.newInstance(names, values, ownedByResult); + if(log) { + android.util.Log.i("performance", "InitResult extract done in: " + timeExtract + "ms"); + android.util.Log.i("performance", "InitResult batch done in: " + timeBatch + "ms"); + android.util.Log.i("performance", "InitResult create done in: " + timeCreate + "ms"); + 
android.util.Log.i("performance", "InitResult batch initialization done in: " + (System.currentTimeMillis() - time) + "ms"); + } + + return initResultBatched; + } + + private OrtSession.Result batchDecoderKvCache(OrtSession.Result result, OnnxTensor decoderOutput, int nLayers, int batchSize, boolean log) throws InvocationTargetException, IllegalAccessException, InstantiationException, NoSuchMethodException, OrtException { + long time = System.currentTimeMillis(); + String[] names = new String[2*nLayers+1]; + OnnxValue[] values = new OnnxValue[2*nLayers+1]; + boolean[] ownedByResult = new boolean[2*nLayers+1]; + Arrays.fill(ownedByResult, true); + names[0] = "logits"; + values[0] = decoderOutput; //result.get("logits").get(); + String[] suffixes = new String[]{"key", "value"}; + int count = 1; + for (int i = 0; i < nLayers; i++) { + for (String suffix: suffixes) { + names[count] = "present." + i + ".decoder."+suffix; + float[][][] keyValue = ((float[][][][]) TensorUtils.extractValue(result, "present." 
+ i + ".decoder."+suffix))[0]; + float[] keyValueFlatBatched = TensorUtils.flattenFloatArrayBatched(keyValue, batchSize); + values[count] = TensorUtils.createFloatTensor(onnxEnv, keyValueFlatBatched, new long[]{batchSize, keyValue.length, keyValue[0].length, keyValue[0][0].length}); //todo: evaluate the use of createFloatTensorOptimized + count++; + } + } + if(log) { + android.util.Log.i("performance", "Decoder kvCache batch initialization done in: " + (System.currentTimeMillis() - time) + "ms"); + } + result.close(); + //the Result constructor is private but this way we can use it anyway + Constructor constructor = OrtSession.Result.class.getDeclaredConstructor(names.getClass(), values.getClass(), ownedByResult.getClass()); + constructor.setAccessible(true); + return constructor.newInstance(names, values, ownedByResult); + } + + private void initBeamSearchData(float [][][] logits, int beamSize, int[] max, ArrayList[] completeBeamOutput, double[] beamsOutputsProbabilities){ + //the "beamSize" words with highest probability are inserted into max and added to completeBeamOutput + ArrayList indexesToAvoid = new ArrayList<>(); + for (int i = 0; i < beamSize; i++) { + max[i] = Utils.getIndexOfLargest(logits[0][0], indexesToAvoid); + indexesToAvoid.add(max[i]); + completeBeamOutput[i].add(max[i]); + } + //we insert the initial probabilities of the "beamSize" output strings into beamsOutputsProbabilities + for (int i = 0; i < beamSize; i++) { + float maxLogit = logits[0][0][max[i]]; + beamsOutputsProbabilities[i] = maxLogit - Utils.logSumExpFast(logits[0][0]); + } + } + + private CacheContainerNative updateBeamSearchData( + float [][][] logits, int beamSize, int eos, OrtSession.Result decoderResult, int sequenceLength, int nLayers, int nHeads, int hiddenSize, + CacheContainerNative cacheContainer, int[] maxProbabilities, int[][] beamMax, int[] max, ArrayList[] completeBeamOutput, double[] beamsOutputsProbabilities + ){ + //for each of the "beamSize" decoder outputs, the 
"beamSize" words with the highest probability are inserted into beamMax + for(int k=0; k < beamSize; k++) { + ArrayList indexesToAvoid = new ArrayList<>(); + for (int i = 0; i < beamSize; i++) { + beamMax[k][i] = Utils.getIndexOfLargest(logits[k][0], indexesToAvoid); + indexesToAvoid.add(beamMax[k][i]); + } + } + //Now beamMax will contain for each decoder output ("beamSize" outputs) the "beamSize" words with highest probability, + // so for each output we calculate its overall probability for each of its "beamSize" words with highest probability + long timeSoftmax = System.currentTimeMillis(); + double[] beamsOutputsProbabilitiesTemp = new double[beamSize*beamSize]; + for(int k=0; k < beamSize; k++) { + //new version of probability calculation (logSumExp) + double logSumExp = Utils.logSumExpFast(logits[k][0]); + for (int i = 0; i < beamSize; i++) { + float maxLogit = logits[k][0][beamMax[k][i]]; + beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilities[k] + maxLogit - logSumExp; + if(beamMax[k][i] == eos){ + beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilitiesTemp[(k*beamSize)+i]/EOS_PENALTY; + } + } + } + android.util.Log.i("performance", "softmax done in: " + (System.currentTimeMillis()-timeSoftmax) + "ms"); + // Now we save in maxProbabilities the indices of the "beamSize" words generated by the decoder that have the + // highest overall probability with their respective output sentences and then we will use them as the next inputs + ArrayList indexesToAvoid = new ArrayList<>(); + for(int i=0; i[] oldCompleteBeamOutput = completeBeamOutput.clone(); + for (int i = 0; i < beamSize; i++) { + beamsOutputsProbabilities[i] = beamsOutputsProbabilitiesTemp[maxProbabilities[i]]; + completeBeamOutput[i] = (ArrayList) oldCompleteBeamOutput[maxProbabilities[i]/beamSize].clone(); + completeBeamOutput[i].add(beamMax[maxProbabilities[i]/beamSize][maxProbabilities[i]%beamSize]); + } + // reorder of the kvCache to match the new selected 
inputs for the next iteration + long timeCache = System.currentTimeMillis(); + CacheContainerNative oldCache = cacheContainer; + cacheContainer = new CacheContainerNative(onnxEnv, decoderResult, nLayers, beamSize, nHeads, sequenceLength, hiddenSize); + if(oldCache != null){ + oldCache.close(); + } + android.util.Log.i("performance", "cache creation done in: " + (System.currentTimeMillis()-timeCache) + "ms"); + int[] indexes = new int[beamSize]; + for(int i=0; i Date: Tue, 10 Feb 2026 14:23:38 +0100 Subject: [PATCH 02/15] Added initial support for HY-MT --- app/build.gradle | 7 +- .../settings/ModelManagerFragment.java | 12 +- .../tools/nn/CacheContainerNative.java | 11 +- .../translation/Tokenizer.java | 221 +++++++++------ .../translation/Translator.java | 267 ++++++++++++------ .../res/layout/fragment_model_manager.xml | 4 +- .../res/raw/hy_mt_supported_languages.xml | 209 ++++++++++++++ 7 files changed, 552 insertions(+), 179 deletions(-) create mode 100644 app/src/main/res/raw/hy_mt_supported_languages.xml diff --git a/app/build.gradle b/app/build.gradle index 2f0c80c..4fb1eee 100644 --- a/app/build.gradle +++ b/app/build.gradle @@ -42,7 +42,7 @@ android { //bergamot flags '-DANDROID_ABI=arm64-v8a', '-DANDROID_PLATFORM=android-28', - '-DANDROID_STL=c++_static', + '-DANDROID_STL=c++_shared', "-DANDROID_PIE=ON", "-DANDROID_CPP_FEATURES=exceptions", '-DCMAKE_BUILD_TYPE=Release', @@ -121,6 +121,11 @@ dependencies { implementation "androidx.lifecycle:lifecycle-extensions:2.2.0" //Download library implementation 'com.github.amitshekhariitbhu:PRDownloader:1.0.2' + // DJL HuggingFace tokenizers wrapper + implementation "ai.djl.android:core:0.33.0" + implementation "ai.djl.huggingface:tokenizers:0.33.0" + implementation "ai.djl.android:tokenizer-native:0.33.0" + //implementation 'androidx.core:core-ktx:1.10.0' implementation 'androidx.work:work-runtime:2.7.1' implementation 'androidx.exifinterface:exifinterface:1.3.7' diff --git 
a/app/src/main/java/nie/translator/rtranslator/settings/ModelManagerFragment.java b/app/src/main/java/nie/translator/rtranslator/settings/ModelManagerFragment.java index 98c0f0b..ce2ed4c 100644 --- a/app/src/main/java/nie/translator/rtranslator/settings/ModelManagerFragment.java +++ b/app/src/main/java/nie/translator/rtranslator/settings/ModelManagerFragment.java @@ -13,13 +13,9 @@ import androidx.annotation.NonNull; import androidx.annotation.Nullable; import androidx.fragment.app.Fragment; -import androidx.recyclerview.widget.DefaultItemAnimator; -import androidx.recyclerview.widget.LinearLayoutManager; -import androidx.recyclerview.widget.RecyclerView; import nie.translator.rtranslator.Global; import nie.translator.rtranslator.R; -import nie.translator.rtranslator.voice_translation.VoiceTranslationActivity; import nie.translator.rtranslator.voice_translation.neural_networks.translation.Translator; public class ModelManagerFragment extends Fragment { @@ -71,8 +67,8 @@ public void onActivityCreated(@Nullable Bundle savedInstanceState) { case Translator.MADLAD_CACHE: radioGroup.check(R.id.madlad_radio); break; - case Translator.GEMMA: - radioGroup.check(R.id.gemma_radio); + case Translator.HY_MT: + radioGroup.check(R.id.hy_radio); break; } @@ -87,8 +83,8 @@ public void onActivityCreated(@Nullable Bundle savedInstanceState) { case R.id.madlad_radio: changeModel(Translator.MADLAD_CACHE); break; - case R.id.gemma_radio: - changeModel(Translator.GEMMA); + case R.id.hy_radio: + changeModel(Translator.HY_MT); break; } }); diff --git a/app/src/main/java/nie/translator/rtranslator/tools/nn/CacheContainerNative.java b/app/src/main/java/nie/translator/rtranslator/tools/nn/CacheContainerNative.java index 336b85f..49676f9 100644 --- a/app/src/main/java/nie/translator/rtranslator/tools/nn/CacheContainerNative.java +++ b/app/src/main/java/nie/translator/rtranslator/tools/nn/CacheContainerNative.java @@ -29,24 +29,26 @@ import ai.onnxruntime.OrtEnvironment; import 
ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; +import nie.translator.rtranslator.voice_translation.neural_networks.translation.Translator; public class CacheContainerNative { private int[] shape; private OnnxTensor[] cacheTensors; private long cacheContainerNativePointer; + private int mode = Translator.NLLB; //Used to load CacheContainerNative.cpp on application startup static { System.loadLibrary("cache_container_native"); } - public CacheContainerNative(OrtEnvironment env, OrtSession.Result cache, int nLevels, int batchSize, int nHeads, int sequenceLength, int hiddenSize){ + public CacheContainerNative(OrtEnvironment env, OrtSession.Result cache, int nLevels, int batchSize, int nHeads, int sequenceLength, int hiddenSize, int mode){ try { cacheTensors = new OnnxTensor[nLevels*2]; cacheContainerNativePointer = initialize(nLevels*2, batchSize, nHeads, sequenceLength, hiddenSize); int count=0; for (int i = 0; i < nLevels; i++) { - cacheTensors[count] = (OnnxTensor) cache.get("present." + i + ".decoder.key").get(); + cacheTensors[count] = (OnnxTensor) cache.get("present." + i + (mode!= Translator.HY_MT ? ".decoder" : "") + ".key").get(); //we use OnnxTensor's private getBuffer method, which returns the data reference without making a copy of it and we pass this reference to the native cache container Method method = cacheTensors[count].getClass().getDeclaredMethod("getBuffer"); method.setAccessible(true); @@ -54,7 +56,7 @@ public CacheContainerNative(OrtEnvironment env, OrtSession.Result cache, int nLe insertValues(cacheContainerNativePointer, count, buffer); count++; - cacheTensors[count] = (OnnxTensor) cache.get("present." + i + ".decoder.value").get(); + cacheTensors[count] = (OnnxTensor) cache.get("present." + i + (mode!= Translator.HY_MT ? 
".decoder" : "") + ".value").get(); //we use OnnxTensor's private getBuffer method, which returns the data reference without making a copy of it and we pass this reference to the native cache container method = cacheTensors[count].getClass().getDeclaredMethod("getBuffer"); method.setAccessible(true); @@ -63,6 +65,7 @@ public CacheContainerNative(OrtEnvironment env, OrtSession.Result cache, int nLe count++; } shape = new int[]{nLevels*2, batchSize, nHeads, sequenceLength, hiddenSize}; + this.mode = mode; } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { e.printStackTrace(); } @@ -86,7 +89,7 @@ public OrtSession.Result getCacheResult(OrtEnvironment env){ int count = 0; for (int i = 0; i < shape[0]/2; i++) { for (String suffix: suffixes) { - names[count] = "present." + i + ".decoder."+suffix; + names[count] = "present." + i + (mode!= Translator.HY_MT ? ".decoder." : ".") + suffix; ByteBuffer buffer = getBuffer(cacheContainerNativePointer, count); values[count] = OnnxTensor.createTensor(env, buffer, new long[]{shape[1], shape[2], shape[3], shape[4]}, OnnxJavaType.FLOAT); count++; diff --git a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Tokenizer.java b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Tokenizer.java index 7da2d20..3843f81 100644 --- a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Tokenizer.java +++ b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Tokenizer.java @@ -16,121 +16,171 @@ package nie.translator.rtranslator.voice_translation.neural_networks.translation; +import androidx.annotation.Nullable; + +import java.io.IOException; +import java.nio.file.Paths; import java.util.Arrays; +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; + public class Tokenizer { public static final int 
NLLB = 0; public static final int NLLB_FIXED = 1; public static final int SEAMLESS = 2; public static final int MADLAD = 3; public static final int MADLAD_FIXED = 4; + public static final int HY_MT = 5; private SentencePieceProcessorJava spProcessor; + private HuggingFaceTokenizer hfTokenizer; private final int mode; //languagesNLLB and languagesSeamless contain the list of all supported languages, sorted by their ID order (their IDs are consecutive) private final String[] languagesNLLB = {"ace_Arab", "ace_Latn", "acm_Arab", "acq_Arab", "aeb_Arab", "afr_Latn", "ajp_Arab", "aka_Latn", "amh_Ethi", "apc_Arab", "arb_Arab", "ars_Arab", "ary_Arab", "arz_Arab", "asm_Beng", "ast_Latn", "awa_Deva", "ayr_Latn", "azb_Arab", "azj_Latn", "bak_Cyrl", "bam_Latn", "ban_Latn", "bel_Cyrl", "bem_Latn", "ben_Beng", "bho_Deva", "bjn_Arab", "bjn_Latn", "bod_Tibt", "bos_Latn", "bug_Latn", "bul_Cyrl", "cat_Latn", "ceb_Latn", "ces_Latn", "cjk_Latn", "ckb_Arab", "crh_Latn", "cym_Latn", "dan_Latn", "deu_Latn", "dik_Latn", "dyu_Latn", "dzo_Tibt", "ell_Grek", "eng_Latn", "epo_Latn", "est_Latn", "eus_Latn", "ewe_Latn", "fao_Latn", "pes_Arab", "fij_Latn", "fin_Latn", "fon_Latn", "fra_Latn", "fur_Latn", "fuv_Latn", "gla_Latn", "gle_Latn", "glg_Latn", "grn_Latn", "guj_Gujr", "hat_Latn", "hau_Latn", "heb_Hebr", "hin_Deva", "hne_Deva", "hrv_Latn", "hun_Latn", "hye_Armn", "ibo_Latn", "ilo_Latn", "ind_Latn", "isl_Latn", "ita_Latn", "jav_Latn", "jpn_Jpan", "kab_Latn", "kac_Latn", "kam_Latn", "kan_Knda", "kas_Arab", "kas_Deva", "kat_Geor", "knc_Arab", "knc_Latn", "kaz_Cyrl", "kbp_Latn", "kea_Latn", "khm_Khmr", "kik_Latn", "kin_Latn", "kir_Cyrl", "kmb_Latn", "kon_Latn", "kor_Hang", "kmr_Latn", "lao_Laoo", "lvs_Latn", "lij_Latn", "lim_Latn", "lin_Latn", "lit_Latn", "lmo_Latn", "ltg_Latn", "ltz_Latn", "lua_Latn", "lug_Latn", "luo_Latn", "lus_Latn", "mag_Deva", "mai_Deva", "mal_Mlym", "mar_Deva", "min_Latn", "mkd_Cyrl", "plt_Latn", "mlt_Latn", "mni_Beng", "khk_Cyrl", "mos_Latn", "mri_Latn", "zsm_Latn", 
"mya_Mymr", "nld_Latn", "nno_Latn", "nob_Latn", "npi_Deva", "nso_Latn", "nus_Latn", "nya_Latn", "oci_Latn", "gaz_Latn", "ory_Orya", "pag_Latn", "pan_Guru", "pap_Latn", "pol_Latn", "por_Latn", "prs_Arab", "pbt_Arab", "quy_Latn", "ron_Latn", "run_Latn", "rus_Cyrl", "sag_Latn", "san_Deva", "sat_Beng", "scn_Latn", "shn_Mymr", "sin_Sinh", "slk_Latn", "slv_Latn", "smo_Latn", "sna_Latn", "snd_Arab", "som_Latn", "sot_Latn", "spa_Latn", "als_Latn", "srd_Latn", "srp_Cyrl", "ssw_Latn", "sun_Latn", "swe_Latn", "swh_Latn", "szl_Latn", "tam_Taml", "tat_Cyrl", "tel_Telu", "tgk_Cyrl", "tgl_Latn", "tha_Thai", "tir_Ethi", "taq_Latn", "taq_Tfng", "tpi_Latn", "tsn_Latn", "tso_Latn", "tuk_Latn", "tum_Latn", "tur_Latn", "twi_Latn", "tzm_Tfng", "uig_Arab", "ukr_Cyrl", "umb_Latn", "urd_Arab", "uzn_Latn", "vec_Latn", "vie_Latn", "war_Latn", "wol_Latn", "xho_Latn", "ydd_Hebr", "yor_Latn", "yue_Hant", "zho_Hans", "zho_Hant", "zul_Latn"}; private final String[] languagesSeamless = {"ace", "ace_Latn", "acm", "acq", "aeb", "afr", "ajp", "aka", "amh", "apc", "arb", "ars", "ary", "arz", "asm", "ast", "awa", "ayr", "azb", "azj", "bak", "bam", "ban", "bel", "bem", "ben", "bho", "bjn", "bjn_Latn", "bod", "bos", "bug", "bul", "cat", "ceb", "ces", "cjk", "ckb", "crh", "cym", "dan", "deu", "dik", "dyu", "dzo", "ell", "eng", "epo", "est", "eus", "ewe", "fao", "pes", "fij", "fin", "fon", "fra", "fur", "fuv", "gla", "gle", "glg", "grn", "guj", "hat", "hau", "heb", "hin", "hne", "hrv", "hun", "hye", "ibo", "ilo", "ind", "isl", "ita", "jav", "jpn", "kab", "kac", "kam", "kan", "kas", "kas_Deva", "kat", "knc", "knc_Latn", "kaz", "kbp", "kea", "khm", "kik", "kin", "kir", "kmb", "kon", "kor", "kmr", "lao", "lvs", "lij", "lim", "lin", "lit", "lmo", "ltg", "ltz", "lua", "lug", "luo", "lus", "mag", "mai", "mal", "mar", "min", "mkd", "plt", "mlt", "mni", "khk", "mos", "mri", "zsm", "mya", "nld", "nno", "nob", "npi", "nso", "nus", "nya", "oci", "gaz", "ory", "pag", "pan", "pap", "pol", "por", "prs", "pbt", "quy", 
"ron", "run", "rus", "sag", "san", "sat", "scn", "shn", "sin", "slk", "slv", "smo", "sna", "snd", "som", "sot", "spa", "als", "srd", "srp", "ssw", "sun", "swe", "swh", "szl", "tam", "tat", "tel", "tgk", "tgl", "tha", "tir", "taq", "taq_Tfng", "tpi", "tsn", "tso", "tuk", "tum", "tur", "twi", "tzm", "uig", "ukr", "umb", "urd", "uzn", "vec", "vie", "war", "wol", "xho", "ydd", "yor", "yue", "cmn", "cmn_Hant", "zul"}; private final int DICTIONARY_LENGTH = 256000; - public Tokenizer(String vocab_file, int mode) { - spProcessor = new SentencePieceProcessorJava(); - spProcessor.Load(vocab_file); + public Tokenizer(String vocab_path, int mode) throws IOException { + if(mode != HY_MT) { + spProcessor = new SentencePieceProcessorJava(); + spProcessor.Load(vocab_path); + }else{ + hfTokenizer = HuggingFaceTokenizer.newInstance(Paths.get(vocab_path)); + } this.mode = mode; } - public TokenizerResult tokenize(String srcLanguage, String tgtLanguage, String text) { - //for madlad we add <2tgtLanguage> at the beginning of the text (srcLanguage is not specified) - if (mode == MADLAD || mode==MADLAD_FIXED){ - text = "<2"+tgtLanguage+"> "+text; - } - //we translate text into ids via sentencepiece - int[] ids = spProcessor.encode(text); + public TokenizerResult tokenize(String text, String srcLanguage, String tgtLanguage){ + return tokenize(text, srcLanguage, tgtLanguage, null, null); + } + + public TokenizerResult tokenize(String text, String srcLanguage, String tgtLanguage, @Nullable Translator.HyLanguageInfo srcLangInfo, @Nullable Translator.HyLanguageInfo tgtLangInfo) { + if(mode != HY_MT) { + //for madlad we add <2tgtLanguage> at the beginning of the text (srcLanguage is not specified) + if (mode == MADLAD || mode == MADLAD_FIXED) { + text = "<2" + tgtLanguage + "> " + text; + } + //we translate text into ids via sentencepiece + int[] ids = spProcessor.encode(text); /* The NLLBTokenizer's dictionary has a different mapping for tokens identified by the first 4 IDs values (from 0 to 
3), also for the other IDs it has a value equal to those of the dictionary we passed to sentencepiece but with an addition of 1 (idNLLB = idSentencePiece + 1), so now we make the necessary adjustments */ - /* For the SeamlessTokenizer's dictionary the same thing is true but the + 1 is also valid for the first 4 values of the IDs (the value 0, in addition, represents the padding) */ - if(mode != MADLAD && mode != MADLAD_FIXED) { //MADLAD has a one-to-one match to the sentencepiece dictionary. - for (int i = 0; i < ids.length; i++) { - //we add 1 to each element of ids - ids[i] = ids[i] + 1; - //we replace the values from 0 to 3 with the correct ones for NLBTokenizer (for Seamless the values are already all correct) - if (mode == NLLB || mode == NLLB_FIXED) { - switch (ids[i]) { - case 1: { - ids[i] = 3; - break; - } - case 2: { - ids[i] = 0; - break; - } - case 3: { - ids[i] = 2; - break; + /* For the SeamlessTokenizer's dictionary the same thing is true but the + 1 is also valid for the first 4 values of the IDs (the value 0, in addition, represents the padding) */ + if (mode != MADLAD && mode != MADLAD_FIXED) { //MADLAD has a one-to-one match to the sentencepiece dictionary. 
+ for (int i = 0; i < ids.length; i++) { + //we add 1 to each element of ids + ids[i] = ids[i] + 1; + //we replace the values from 0 to 3 with the correct ones for NLBTokenizer (for Seamless the values are already all correct) + if (mode == NLLB || mode == NLLB_FIXED) { + switch (ids[i]) { + case 1: { + ids[i] = 3; + break; + } + case 2: { + ids[i] = 0; + break; + } + case 3: { + ids[i] = 2; + break; + } } } } } - } - //add at the end and srcLanguage at the beginning (srcLanguage is not added for Madlad) - int eos = PieceToID(""); - int srcLanguageID = getLanguageID(srcLanguage); - int[] idsExtended; - if(mode != MADLAD && mode != MADLAD_FIXED) { - idsExtended = new int[ids.length + 2]; - System.arraycopy(ids, 0, idsExtended, 1, ids.length); - idsExtended[idsExtended.length - 1] = eos; - idsExtended[0] = srcLanguageID; - }else{ - idsExtended = new int[ids.length + 1]; - System.arraycopy(ids, 0, idsExtended, 0, ids.length); - idsExtended[idsExtended.length - 1] = eos; - } + //add at the end and srcLanguage at the beginning (srcLanguage is not added for Madlad) + int eos = PieceToID(""); + int srcLanguageID = getLanguageID(srcLanguage); + int[] idsExtended; + if (mode != MADLAD && mode != MADLAD_FIXED) { + idsExtended = new int[ids.length + 2]; + System.arraycopy(ids, 0, idsExtended, 1, ids.length); + idsExtended[idsExtended.length - 1] = eos; + idsExtended[0] = srcLanguageID; + } else { + idsExtended = new int[ids.length + 1]; + System.arraycopy(ids, 0, idsExtended, 0, ids.length); + idsExtended[idsExtended.length - 1] = eos; + } - //we create the attention mask - int[] attentionMask = new int[idsExtended.length]; - Arrays.fill(attentionMask, 1); - - if(mode == NLLB || mode == MADLAD) { - return new TokenizerResult(idsExtended, attentionMask); - }else if(mode == SEAMLESS){ - //for seamless the ids must always be 512, we fill the empty ids with padding (0) - int[] idsPadded = new int[512]; - Arrays.fill(idsPadded, 0); - System.arraycopy(idsExtended, 0, idsPadded, 0, 
idsExtended.length); - //for seamless also the attention mask must always have length 512, we fill the rest with 0 - int[] attentionMaskPadded = new int[512]; - Arrays.fill(attentionMaskPadded, 0); - System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length); - return new TokenizerResult(idsPadded, attentionMaskPadded); - }else if(mode == NLLB_FIXED){ - //for NLLB Fixed the ids must always be 256, we fill the empty ids with padding (0) - int[] idsPadded = new int[256]; - Arrays.fill(idsPadded, 0); - System.arraycopy(idsExtended, 0, idsPadded, 0, idsExtended.length); - //also the attention mask must always have length 256, we fill the rest with 0 - int[] attentionMaskPadded = new int[256]; - Arrays.fill(attentionMaskPadded, 0); - System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length); - return new TokenizerResult(idsPadded, attentionMaskPadded); - }else{ - //for Madlad Fixed the ids must always be 128, we fill the empty ids with padding (1) - int[] idsPadded = new int[128]; - Arrays.fill(idsPadded, 1); - System.arraycopy(idsExtended, 0, idsPadded, 0, idsExtended.length); - //also the attention mask must always have length 128, we fill the rest with 0 - int[] attentionMaskPadded = new int[128]; - Arrays.fill(attentionMaskPadded, 0); - System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length); - return new TokenizerResult(idsPadded, attentionMaskPadded); + //we create the attention mask + int[] attentionMask = new int[idsExtended.length]; + Arrays.fill(attentionMask, 1); + + if (mode == NLLB || mode == MADLAD) { + return new TokenizerResult(idsExtended, attentionMask); + } else if (mode == SEAMLESS) { + //for seamless the ids must always be 512, we fill the empty ids with padding (0) + int[] idsPadded = new int[512]; + Arrays.fill(idsPadded, 0); + System.arraycopy(idsExtended, 0, idsPadded, 0, idsExtended.length); + //for seamless also the attention mask must always have length 512, we fill the 
rest with 0 + int[] attentionMaskPadded = new int[512]; + Arrays.fill(attentionMaskPadded, 0); + System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length); + return new TokenizerResult(idsPadded, attentionMaskPadded); + } else if (mode == NLLB_FIXED) { + //for NLLB Fixed the ids must always be 256, we fill the empty ids with padding (0) + int[] idsPadded = new int[256]; + Arrays.fill(idsPadded, 0); + System.arraycopy(idsExtended, 0, idsPadded, 0, idsExtended.length); + //also the attention mask must always have length 256, we fill the rest with 0 + int[] attentionMaskPadded = new int[256]; + Arrays.fill(attentionMaskPadded, 0); + System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length); + return new TokenizerResult(idsPadded, attentionMaskPadded); + } else { + //for Madlad Fixed the ids must always be 128, we fill the empty ids with padding (1) + int[] idsPadded = new int[128]; + Arrays.fill(idsPadded, 1); + System.arraycopy(idsExtended, 0, idsPadded, 0, idsExtended.length); + //also the attention mask must always have length 128, we fill the rest with 0 + int[] attentionMaskPadded = new int[128]; + Arrays.fill(attentionMaskPadded, 0); + System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length); + return new TokenizerResult(idsPadded, attentionMaskPadded); + } + } else { + //tokenization for hy-mt + String tgtLangEnName = tgtLanguage; + String tgtLangZhName = tgtLanguage; + if(tgtLangInfo != null){ + tgtLangEnName = tgtLangInfo.enName; + tgtLangZhName = tgtLangInfo.zhName; + } + String prompt; + if(srcLanguage.equals("zh")) { + prompt = "<|hy_begin▁of▁sentence|><|hy_User|>将以下文本翻译为"+tgtLangZhName+",注意只需要输出翻译后的结果,不要额外解释:\n\n"+text+"<|hy_Assistant|>"; + } else { + prompt = "<|hy_begin▁of▁sentence|><|hy_User|>Translate the following segment into "+tgtLangEnName+", without additional explanation.\n\n"+text+"<|hy_Assistant|>"; + } + Encoding result = hfTokenizer.encode(new String[]{prompt}, false, 
false); + + //we convert inputIds and attention mask from long[] to int[] + long[] inputIdsLong = result.getIds(); + int[] inputIds = new int[inputIdsLong.length]; + for (int i = 0; i < inputIds.length; i++) { + inputIds[i] = (int) inputIdsLong[i]; + } + long[] attentionMaskLong = result.getAttentionMask(); + int[] attentionMask = new int[attentionMaskLong.length]; + for (int i = 0; i < attentionMask.length; i++) { + attentionMask[i] = (int) attentionMaskLong[i]; + } + + return new TokenizerResult(inputIds, attentionMask); } } public int PieceToID(String token){ if(mode == NLLB || mode == NLLB_FIXED || mode==MADLAD || mode==MADLAD_FIXED) { return spProcessor.PieceToID(token); + }else if(mode == HY_MT) { + return (int) hfTokenizer.encode(token).getIds()[0]; }else{ return spProcessor.PieceToID(token)+1; } @@ -155,13 +205,20 @@ public int getLanguageID(String language){ public String decode(int[] ids) { String output = ""; - if(mode != MADLAD && mode != MADLAD_FIXED) { + if(mode == NLLB || mode == NLLB_FIXED) { for (int i = 0; i < ids.length; i++) { if (ids[i] < DICTIONARY_LENGTH && ids[i] > 3) { //This check skips special tokens and those that sentencepiece does not have in the dictionary (such as languages) output = output.concat(spProcessor.IDToPiece(ids[i] - 1)); } } - }else{ + }else if(mode == HY_MT){ + long[] inputIdsLong = new long[ids.length]; + for (int i = 0; i < ids.length; i++) { + inputIdsLong[i] = (long) ids[i]; + } + output = hfTokenizer.decode(inputIdsLong); + output = output.replace("<|hy_place▁holder▁no▁2|>", ""); //we remove the eos token + }else{ //madlad and seamless for (int i = 0; i < ids.length; i++) { if (ids[i] < DICTIONARY_LENGTH && ids[i] > 3) { //This check skips special tokens and those that sentencepiece does not have in the dictionary (such as languages) output = output.concat(spProcessor.IDToPiece(ids[i])); diff --git a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java 
b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java index 642c15b..ca65fbc 100644 --- a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java +++ b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java @@ -79,7 +79,7 @@ public class Translator extends NeuralNetworkApi { public static final int MADLAD = 3; public static final int MADLAD_CACHE = 5; public static final int MOZILLA = 7; - public static final int GEMMA = 8; + public static final int HY_MT = 8; private int mode; private final BergamotModelsIndicator bergamotModelsIndicator = new BergamotModelsIndicator(); private Tokenizer tokenizer; @@ -89,7 +89,8 @@ public class Translator extends NeuralNetworkApi { private OrtSession cacheInitSession; private OrtSession embedAndLmHeadSession; private OrtSession embedSession; - private Map nllbLanguagesCodes = new HashMap(); + private final Map nllbLanguagesCodes = new HashMap<>(); + private final Map hyLanguagesInfo = new HashMap<>(); private static final double EOS_PENALTY = 0.9; @Nullable private GuiMessage lastInputText; @@ -120,16 +121,17 @@ public Translator(@NonNull Global global, int mode, GeneralListener initListener this.mode = mode; mainHandler = new android.os.Handler(Looper.getMainLooper()); initializeNllbLanguagesCodes(global); + initializeHyLanguagesInfo(global); initialize(global, mode, initListener); } private void initialize(@NonNull Global global, int mode, GeneralListener initListener){ - String encoderPath; - String decoderPath; - String vocabPath; - String embedAndLmHeadPath; - String cacheInitializerPath; + String encoderPath = ""; + String decoderPath = ""; + String vocabPath = ""; + String embedAndLmHeadPath = ""; + String cacheInitializerPath = ""; if(mode == NLLB || mode == NLLB_CACHE) { //8 bit encoderPath = global.getFilesDir().getPath() + "/NLLB_encoder.onnx"; @@ -143,12 +145,15 
@@ private void initialize(@NonNull Global global, int mode, GeneralListener initLi vocabPath = global.getFilesDir().getPath() + "/sentencepiece_bpe.model"; embedAndLmHeadPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_embed_and_lm_head_4bit.onnx"; cacheInitializerPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_cache_initializer_4bit.onnx";*/ - }else { //madlad + }else if(mode == MADLAD || mode == MADLAD_CACHE){ //madlad encoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int4Acc4/madlad_encoder_4bit.onnx"; decoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int4Acc4/madlad_decoder_4bit.onnx"; vocabPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/spiece.model"; embedAndLmHeadPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/madlad_embed_8bit.onnx"; cacheInitializerPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int4Acc4/madlad_cache_initializer_4bit.onnx"; + }else { //hy-mt + decoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/HY-MT" + "/model_int8_final.onnx"; + vocabPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/HY-MT" + "/tokenizer.json"; } String finalDecoderPath = decoderPath; @@ -209,13 +214,13 @@ public void onFailure(int[] reasons, long value) { encoderOptions.setMemoryPatternOptimization(arena); encoderOptions.setCPUArenaAllocator(arena); encoderOptions.setOptimizationLevel(optDefaultLevel); - encoderSession = onnxEnv.createSession(finalEncoderPath, encoderOptions); + if(mode != HY_MT) encoderSession = onnxEnv.createSession(finalEncoderPath, encoderOptions); OrtSession.SessionOptions cacheInitOptions = new OrtSession.SessionOptions(); 
cacheInitOptions.setMemoryPatternOptimization(arena); cacheInitOptions.setCPUArenaAllocator(arena); cacheInitOptions.setOptimizationLevel(optDefaultLevel); - cacheInitSession = onnxEnv.createSession(finalCacheInitializerPath, cacheInitOptions); + if(mode != HY_MT) cacheInitSession = onnxEnv.createSession(finalCacheInitializerPath, cacheInitOptions); OrtSession.SessionOptions embedAndLmHeadOptions = new OrtSession.SessionOptions(); embedAndLmHeadOptions.setMemoryPatternOptimization(arena); @@ -223,7 +228,7 @@ public void onFailure(int[] reasons, long value) { embedAndLmHeadOptions.setOptimizationLevel(optDefaultLevel); if (mode == MADLAD_CACHE) { embedSession = onnxEnv.createSession(finalEmbedAndLmHeadPath, embedAndLmHeadOptions); - } else { + } else if(mode != HY_MT){ embedAndLmHeadSession = onnxEnv.createSession(finalEmbedAndLmHeadPath, embedAndLmHeadOptions); } @@ -236,15 +241,18 @@ public void onFailure(int[] reasons, long value) { initListener.onSuccess(); } - } catch (OrtException e) { + if(mode == MADLAD || mode == MADLAD_CACHE) { + tokenizer = new Tokenizer(finalVocabPath, Tokenizer.MADLAD); + }else if(mode == NLLB || mode == NLLB_CACHE) { + tokenizer = new Tokenizer(finalVocabPath, Tokenizer.NLLB); + }else if(mode == HY_MT) { + tokenizer = new Tokenizer(finalVocabPath, Tokenizer.HY_MT); + } + + } catch (OrtException | IOException e) { e.printStackTrace(); mainHandler.post(() -> initListener.onFailure(new int[]{ErrorCodes.ERROR_LOADING_MODEL},0)); } - if(mode == MADLAD_CACHE) { - tokenizer = new Tokenizer(finalVocabPath, Tokenizer.MADLAD); - }else{ - tokenizer = new Tokenizer(finalVocabPath, Tokenizer.NLLB); - } } }; t.start(); @@ -639,6 +647,14 @@ private void performTextTranslation(final String textToTranslate, final CustomLo } if(mode != MOZILLA){ + int maxLength = 200; + if(mode == NLLB || mode == NLLB_CACHE){ + maxLength = 200; + }else if(mode == MADLAD || mode == MADLAD_CACHE){ + maxLength = 200; //todo: research the best value for madlad + }else 
if(mode == HY_MT){ + maxLength = 5000; //todo: research the best value for hy-mt + } //we split the input text in sentences ArrayList textSplit = new ArrayList<>(); BreakIterator iterator = BreakIterator.getSentenceInstance(inputLanguage.getLocale()); @@ -652,9 +668,9 @@ private void performTextTranslation(final String textToTranslate, final CustomLo while (joined) { joined = false; for (int i = 1; i < textSplit.size(); i++) { - int numTokens = tokenizer.tokenize(getNllbLanguageCode(inputLanguage.getCode()), getNllbLanguageCode(outputLanguage.getCode()), textSplit.get(i - 1)).getInputIDs().length; - int numTokens2 = tokenizer.tokenize(getNllbLanguageCode(inputLanguage.getCode()), getNllbLanguageCode(outputLanguage.getCode()), textSplit.get(i)).getInputIDs().length; - if ((numTokens + numTokens2 < 200) || (numTokens2 < 5)) { + int numTokens = tokenize(textSplit.get(i - 1), inputLanguage, outputLanguage).getInputIDs().length; + int numTokens2 = tokenize(textSplit.get(i), inputLanguage, outputLanguage).getInputIDs().length; + if ((numTokens + numTokens2 < maxLength) || (numTokens2 < 5)) { textSplit.set(i - 1, textSplit.get(i - 1) + textSplit.get(i)); textSplit.remove(i); i = i - 1; @@ -679,23 +695,22 @@ private void performTextTranslation(final String textToTranslate, final CustomLo long time = System.currentTimeMillis(); TokenizerResult input = null; String correctedSubText = correctText(textSplit.get(i), inputLanguage.getLocale()); - if (mode == MADLAD_CACHE) { - input = tokenizer.tokenize(inputLanguage.getCode(), outputLanguage.getCode(), correctedSubText); - } else { //if mode == NLLB_CACHE - input = tokenizer.tokenize(getNllbLanguageCode(inputLanguage.getCode()), getNllbLanguageCode(outputLanguage.getCode()), correctedSubText); - } + input = tokenize(correctedSubText, inputLanguage, outputLanguage); android.util.Log.i("performance", "Tokenization done in: " + (System.currentTimeMillis() - time) + "ms"); //encoder execution time = System.currentTimeMillis(); - 
OnnxTensor encoderResult = executeEncoder(input.getInputIDs(), input.getAttentionMask()); - android.util.Log.i("performance", "Encoder done in: " + (System.currentTimeMillis() - time) + "ms"); - if (encoderResult == null) { - if (responseListener != null) { - mainHandler.post(() -> responseListener.onFailure(new int[]{ErrorCodes.ERROR_EXECUTING_MODEL}, 0)); - } else { - mainHandler.post(() -> notifyError(new int[]{ErrorCodes.ERROR_EXECUTING_MODEL}, 0)); + OnnxTensor encoderResult = null; + if(mode != HY_MT) { + encoderResult = executeEncoder(input.getInputIDs(), input.getAttentionMask()); + android.util.Log.i("performance", "Encoder done in: " + (System.currentTimeMillis() - time) + "ms"); + if (encoderResult == null) { + if (responseListener != null) { + mainHandler.post(() -> responseListener.onFailure(new int[]{ErrorCodes.ERROR_EXECUTING_MODEL}, 0)); + } else { + mainHandler.post(() -> notifyError(new int[]{ErrorCodes.ERROR_EXECUTING_MODEL}, 0)); + } + return; } - return; } //decoder execution TranslateListener translateListener = new TranslateListener() { @@ -735,7 +750,7 @@ public void onFailure(int[] reasons, long value) { executeCacheDecoder(textToTranslate, input, encoderResult, completeBeamOutput, null, outputLanguage, 1, translateListener); } //we convert the ids of completeBeamOutputs into a string and return it - encoderResult.close(); + if(encoderResult != null) encoderResult.close(); int[] completeOutputArray; if ((mode == MADLAD_CACHE || mode == NLLB_CACHE) && beamSize > 1) { int indexMax = 0; @@ -1004,16 +1019,32 @@ public void executeCacheDecoderGreedy(String textToTranslate, TokenizerResult in } } - public void executeCacheDecoder(String textToTranslate, TokenizerResult input, OnnxTensor encoderResult, ArrayList[] completeBeamOutput, @Nullable double[] beamsOutputsProbabilities, final CustomLocale outputLanguage, int beamSize, @Nullable final TranslateListener responseListener) { - final int eos = tokenizer.PieceToID(""); + public void 
executeCacheDecoder(String textToTranslate, TokenizerResult input, @Nullable OnnxTensor encoderResult, ArrayList[] completeBeamOutput, @Nullable double[] beamsOutputsProbabilities, final CustomLocale outputLanguage, int beamSize, @Nullable final TranslateListener responseListener) { + int eos; + if(mode == HY_MT){ + eos = tokenizer.PieceToID("<|hy_place▁holder▁no▁2|>"); + }else{ + eos = tokenizer.PieceToID(""); + } int nLayers; int hiddenSize; + int hiddenSizeAttention; + int nHeads; if(mode == MADLAD_CACHE){ nLayers = 32; - hiddenSize = 128; - }else{ //if mode == NLLB_CACHE + hiddenSize = 1024; + hiddenSizeAttention = 128; + nHeads = 16; + }else if(mode == NLLB_CACHE){ nLayers = 12; - hiddenSize = 64; + hiddenSize = 1024; + hiddenSizeAttention = 64; + nHeads = 16; + }else{ //if mode == HY_MT + nLayers = 32; + hiddenSize = 2048; + hiddenSizeAttention = 128; + nHeads = 4; //nHeads in this case refers only to the number of heads used in kvCache, the real number of heads are 16, but this model uses group query attention, with a group size of 4 } try { @@ -1023,10 +1054,13 @@ public void executeCacheDecoder(String textToTranslate, TokenizerResult input, O OnnxTensor inputIDsTensor; if(mode == MADLAD_CACHE){ inputIDsTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, new int[]{0}); //for the first iteration we use input_ids = 0, with batch_size = 1 - }else{ //if mode == NLLB_CACHE + }else if(mode == NLLB_CACHE){ inputIDsTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, new int[]{2}); //for the first iteration we use input_ids = 2, with batch_size = 1 + }else{ // if mode == HY_MT + inputIDsTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, input.getInputIDs()); //for the first iteration we use the input_ids generated by the tokenizer (with the prompt), with batch_size = 1 } - OnnxTensor encoderAttentionMaskTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, input.getAttentionMask()); + int[] attentionMask = input.getAttentionMask(); + OnnxTensor 
attentionMaskTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, attentionMask); CacheContainerNative cacheContainer = null; OnnxTensor decoderOutput = null; Map decoderInput = new HashMap(); @@ -1034,16 +1068,21 @@ public void executeCacheDecoder(String textToTranslate, TokenizerResult input, O time = System.currentTimeMillis(); //preparing cache initializer input - Map initInput = new HashMap(); - initInput.put("encoder_hidden_states", encoderResult); - //execution of the cache initializer - OrtSession.Result initResult = cacheInitSession.run(initInput); - android.util.Log.i("performance", "Cache initialization done in: " + (System.currentTimeMillis()-time) + "ms"); - encoderResult.close(); //we close it because from now on we only need initResult - + OrtSession.Result initResult = null; + OnnxTensor attentionMaskTensorBatched = null; + OrtSession.Result initResultBatched = null; + if(mode != HY_MT) { + Map initInput = new HashMap(); + initInput.put("encoder_hidden_states", encoderResult); + //execution of the cache initializer + initResult = cacheInitSession.run(initInput); + android.util.Log.i("performance", "Cache initialization done in: " + (System.currentTimeMillis() - time) + "ms"); + if (encoderResult != null) + encoderResult.close(); //we close it because from now on we only need initResult + } //we convert the fixed decoder inputs to have batch_size==beamSize - OnnxTensor encoderAttentionMaskTensorBatched = beamSize > 1 ? batchEncoderAttentionMask(input.getAttentionMask(), beamSize, true) : null; - OrtSession.Result initResultBatched = beamSize > 1 ? batchEncoderKvCache(initResult, nLayers, beamSize, true) : null; + attentionMaskTensorBatched = beamSize > 1 ? batchEncoderAttentionMask(attentionMask, beamSize, true) : null; + initResultBatched = initResult != null && beamSize > 1 ? 
batchEncoderKvCache(initResult, nLayers, beamSize, true) : null; //this is not executed for HY_MT //we begin the iterative execution of the decoder String[] partialResults = new String[beamSize]; //used for log @@ -1052,8 +1091,8 @@ public void executeCacheDecoder(String textToTranslate, TokenizerResult input, O int[] max = new int[beamSize]; int[][] beamMax = new int[beamSize][beamSize]; int j = 1; - OnnxTensor emptyPreLogits = TensorUtils.createFloatTensorWithSingleValue(onnxEnv, 0, new long[]{EMPTY_BATCH_SIZE, 1, 1024}); - OnnxTensor emptyPreLogitsBatch = TensorUtils.createFloatTensorWithSingleValue(onnxEnv, 0, new long[]{beamSize, 1, 1024}); + OnnxTensor emptyPreLogits = TensorUtils.createFloatTensorWithSingleValue(onnxEnv, 0, new long[]{EMPTY_BATCH_SIZE, 1, hiddenSize}); + OnnxTensor emptyPreLogitsBatch = TensorUtils.createFloatTensorWithSingleValue(onnxEnv, 0, new long[]{beamSize, 1, hiddenSize}); OnnxTensor emptyInputIds = TensorUtils.createInt64TensorWithSingleValue(onnxEnv, 0, new long[]{EMPTY_BATCH_SIZE, 2}); OnnxTensor emptyInputIdsBatch = TensorUtils.createInt64TensorWithSingleValue(onnxEnv, 0, new long[]{beamSize, 2}); @@ -1089,27 +1128,39 @@ public void executeCacheDecoder(String textToTranslate, TokenizerResult input, O decoderInput.put("input_ids", inputIDsTensor); if(j == 1){ //if it is the first iteration //we run the decoder with a batch_size = 1 - decoderInput.put("encoder_attention_mask", encoderAttentionMaskTensor); - long[] shape = {1, 16, 0, hiddenSize}; + if(mode != HY_MT) { + decoderInput.put("encoder_attention_mask", attentionMaskTensor); + }else{ + decoderInput.put("attention_mask", attentionMaskTensor); + } + long[] shape = {1, nHeads, 0, hiddenSizeAttention}; OnnxTensor decoderPastTensor = TensorUtils.createFloatTensorWithSingleValue(onnxEnv,0, shape); for (int i = 0; i < nLayers; i++) { - decoderInput.put("past_key_values." + i + ".decoder.key", decoderPastTensor); - decoderInput.put("past_key_values." 
+ i + ".decoder.value", decoderPastTensor); - decoderInput.put("past_key_values." + i + ".encoder.key", (OnnxTensor) initResult.get("present." + i + ".encoder.key").get()); - decoderInput.put("past_key_values." + i + ".encoder.value", (OnnxTensor) initResult.get("present." + i + ".encoder.value").get()); + decoderInput.put("past_key_values." + i + (mode != HY_MT ? ".decoder" : "") + ".key", decoderPastTensor); + decoderInput.put("past_key_values." + i + (mode != HY_MT ? ".decoder" : "") + ".value", decoderPastTensor); + if(mode != HY_MT){ + decoderInput.put("past_key_values." + i + ".encoder.key", (OnnxTensor) initResult.get("present." + i + ".encoder.key").get()); + decoderInput.put("past_key_values." + i + ".encoder.value", (OnnxTensor) initResult.get("present." + i + ".encoder.value").get()); + } } }else { if(j == 2 && beamSize > 1) { - encoderAttentionMaskTensor.close(); //we close it because from now on we only need encoderAttentionMaskTensorBatched - initResult.close(); //we close it because from now on we only need initResultBatched + attentionMaskTensor.close(); //we close it because from now on we only need attentionMaskTensorBatched + if(initResult != null) initResult.close(); //we close it because from now on we only need initResultBatched } //we run the decoder with batch_size = beamSize - decoderInput.put("encoder_attention_mask", beamSize > 1 ? encoderAttentionMaskTensorBatched : encoderAttentionMaskTensor); + if(mode != HY_MT) { + decoderInput.put("encoder_attention_mask", beamSize > 1 ? attentionMaskTensorBatched : attentionMaskTensor); + }else{ + decoderInput.put("attention_mask", beamSize > 1 ? attentionMaskTensorBatched : attentionMaskTensor); + } for (int i = 0; i < nLayers; i++) { - decoderInput.put("past_key_values." + i + ".decoder.key", (OnnxTensor) result.get("present." + i + ".decoder.key").get()); - decoderInput.put("past_key_values." + i + ".decoder.value", (OnnxTensor) result.get("present." 
+ i + ".decoder.value").get()); - decoderInput.put("past_key_values." + i + ".encoder.key", (OnnxTensor) (beamSize > 1 ? initResultBatched : initResult).get("present." + i + ".encoder.key").get()); - decoderInput.put("past_key_values." + i + ".encoder.value", (OnnxTensor) (beamSize > 1 ? initResultBatched : initResult).get("present." + i + ".encoder.value").get()); + decoderInput.put("past_key_values." + i + (mode != HY_MT ? ".decoder" : "") + ".key", (OnnxTensor) result.get("present." + i + (mode != HY_MT ? ".decoder" : "") + ".key").get()); + decoderInput.put("past_key_values." + i + (mode != HY_MT ? ".decoder" : "") + ".value", (OnnxTensor) result.get("present." + i + (mode != HY_MT ? ".decoder" : "") + ".value").get()); + if(mode != HY_MT) { + decoderInput.put("past_key_values." + i + ".encoder.key", (OnnxTensor) (beamSize > 1 ? initResultBatched : initResult).get("present." + i + ".encoder.key").get()); + decoderInput.put("past_key_values." + i + ".encoder.value", (OnnxTensor) (beamSize > 1 ? initResultBatched : initResult).get("present." 
+ i + ".encoder.value").get()); + } } } oldResult = result; @@ -1170,7 +1221,7 @@ public void executeCacheDecoder(String textToTranslate, TokenizerResult input, O if(beamSize > 1) { //based on the logits we update beam search data int[] maxProbabilities = new int[beamSize]; - cacheContainer = updateBeamSearchData(logits, beamSize, eos, result, j, nLayers, 16, hiddenSize, cacheContainer, maxProbabilities, beamMax, max, completeBeamOutput, beamsOutputsProbabilities); + cacheContainer = updateBeamSearchData(logits, beamSize, eos, result, j, nLayers, nHeads, hiddenSizeAttention, cacheContainer, maxProbabilities, beamMax, max, completeBeamOutput, beamsOutputsProbabilities); }else{ max[0] = Utils.getIndexOfLargest(logits[0][0]); completeBeamOutput[0].add(max[0]); @@ -1180,6 +1231,16 @@ public void executeCacheDecoder(String textToTranslate, TokenizerResult input, O input_ids = max; inputIDsTensor = TensorUtils.createIntTensor(onnxEnv, input_ids, new long[]{beamSize,1}); } + if(mode == HY_MT) { //todo: make this part more optimized (measure and increase the speed and efficiency) + //we increment the attentionMask size of 1 for the nex iteration + attentionMask = new int[attentionMask.length + 1]; + Arrays.fill(attentionMask, 1); + if(beamSize > 1) { + attentionMaskTensorBatched = batchEncoderAttentionMask(attentionMask, beamSize, true); + }else{ + attentionMaskTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, attentionMask); + } + } android.util.Log.i("performance", "post-execution of" + j + "th word done in: " + (System.currentTimeMillis() - time) + "ms"); android.util.Log.i("performance", "Generation of" + j + "th word done in: " + (System.currentTimeMillis() - initialTime) + "ms"); // we return the partial result with the highest probability @@ -1204,9 +1265,9 @@ public void executeCacheDecoder(String textToTranslate, TokenizerResult input, O } if(result != null) result.close(); - initResult.close(); + if(initResult != null) initResult.close(); if(cacheContainer != 
null) cacheContainer.close(); - if(encoderAttentionMaskTensorBatched != null) encoderAttentionMaskTensorBatched.close(); + if(attentionMaskTensorBatched != null) attentionMaskTensorBatched.close(); if(initResultBatched != null) initResultBatched.close(); } catch (OrtException | InvocationTargetException | NoSuchMethodException | @@ -1311,8 +1372,8 @@ private OrtSession.Result batchDecoderKvCache(OrtSession.Result result, OnnxTens int count = 1; for (int i = 0; i < nLayers; i++) { for (String suffix: suffixes) { - names[count] = "present." + i + ".decoder."+suffix; - float[][][] keyValue = ((float[][][][]) TensorUtils.extractValue(result, "present." + i + ".decoder."+suffix))[0]; + names[count] = "present." + i + (mode!=HY_MT ? ".decoder." : ".") + suffix; + float[][][] keyValue = ((float[][][][]) TensorUtils.extractValue(result, "present." + i + (mode!=HY_MT ? ".decoder." : ".") + suffix))[0]; float[] keyValueFlatBatched = TensorUtils.flattenFloatArrayBatched(keyValue, batchSize); values[count] = TensorUtils.createFloatTensor(onnxEnv, keyValueFlatBatched, new long[]{batchSize, keyValue.length, keyValue[0].length, keyValue[0][0].length}); //todo: evaluate the use of createFloatTensorOptimized count++; @@ -1389,7 +1450,7 @@ private CacheContainerNative updateBeamSearchData( // reorder of the kvCache to match the new selected inputs for the next iteration long timeCache = System.currentTimeMillis(); CacheContainerNative oldCache = cacheContainer; - cacheContainer = new CacheContainerNative(onnxEnv, decoderResult, nLayers, beamSize, nHeads, sequenceLength, hiddenSize); + cacheContainer = new CacheContainerNative(onnxEnv, decoderResult, nLayers, beamSize, nHeads, sequenceLength, hiddenSize, mode); if(oldCache != null){ oldCache.close(); } @@ -1427,6 +1488,16 @@ private static String getSentenceTerminator(Locale locale) { } } + private TokenizerResult tokenize(String text, final CustomLocale inputLanguage, final CustomLocale outputLanguage){ + if (mode == MADLAD_CACHE || 
mode == MADLAD) { + return tokenizer.tokenize(text, inputLanguage.getCode(), outputLanguage.getCode()); + } else if(mode == NLLB_CACHE || mode == NLLB){ + return tokenizer.tokenize(text, getNllbLanguageCode(inputLanguage.getCode()), getNllbLanguageCode(outputLanguage.getCode())); + }else{ //if mode == HY_MT + return tokenizer.tokenize(text, inputLanguage.getCode(), outputLanguage.getCode(), getHyLanguageInfo(inputLanguage.getCode()), getHyLanguageInfo(outputLanguage.getCode())); + } + } + private void initializeNllbLanguagesCodes(Context context){ DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance(); @@ -1443,18 +1514,39 @@ private void initializeNllbLanguagesCodes(Context context){ } } - private String getNllbLanguageCode(String languageCode){ - if(nllbLanguagesCodes != null) { - String nllbCode = nllbLanguagesCodes.get(languageCode); - if (nllbCode == null) { - Log.e("error", "Error Converting Language code " + languageCode + " to NLLB code"); - return languageCode; - } else { - return nllbCode; + private void initializeHyLanguagesInfo(Context context){ + DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance(); + try { + DocumentBuilder documentBuilder = documentBuilderFactory.newDocumentBuilder(); + Document document = documentBuilder.parse(context.getResources().openRawResource(R.raw.hy_mt_supported_languages)); + NodeList listCode = document.getElementsByTagName("code"); + NodeList listEnNames = document.getElementsByTagName("en_name"); + NodeList listZhNames = document.getElementsByTagName("zh_name"); + for (int i = 0; i < listCode.getLength(); i++) { + hyLanguagesInfo.put(listCode.item(i).getTextContent(), new HyLanguageInfo(listEnNames.item(i).getTextContent(), listZhNames.item(i).getTextContent())); } - }else{ - Log.e("error", "Error Converting Language code " + languageCode + " to NLLB code, the NllbLanguagesCodes are not initialized"); + } catch (IOException | SAXException | 
ParserConfigurationException e) { + e.printStackTrace(); + } + } + + private String getNllbLanguageCode(String languageCode){ + String nllbCode = nllbLanguagesCodes.get(languageCode); + if (nllbCode == null) { + Log.e("error", "Error Converting Language code " + languageCode + " to NLLB code"); return languageCode; + } else { + return nllbCode; + } + } + + private HyLanguageInfo getHyLanguageInfo(String languageCode){ + HyLanguageInfo hyLanguageInfo = hyLanguagesInfo.get(languageCode); + if (hyLanguageInfo == null) { + Log.e("error", "Error Converting Language code " + languageCode + " to HY language info"); + return new HyLanguageInfo(languageCode, languageCode); + } else { + return hyLanguageInfo; } } @@ -1470,12 +1562,14 @@ public static ArrayList getSupportedLanguages(Context context, int Document document = null; if (mode == MADLAD || mode == MADLAD_CACHE) { document = documentBuilder.parse(context.getResources().openRawResource(R.raw.madlad_supported_launguages)); - } else if (mode == NLLB || mode == NLLB_CACHE) { //if mode == NLLB + } else if (mode == NLLB || mode == NLLB_CACHE) { if (!qualityLow) { document = documentBuilder.parse(context.getResources().openRawResource(R.raw.nllb_supported_languages)); } else { document = documentBuilder.parse(context.getResources().openRawResource(R.raw.nllb_supported_languages_all)); } + }else if (mode == HY_MT) { + document = documentBuilder.parse(context.getResources().openRawResource(R.raw.hy_mt_supported_languages)); } NodeList list = document.getElementsByTagName("code"); for (int i = 0; i < list.getLength(); i++) { @@ -1492,6 +1586,15 @@ public static ArrayList getSupportedLanguages(Context context, int return languages; } + public static class HyLanguageInfo{ + public String enName; + public String zhName; + + public HyLanguageInfo(String enName, String zhName) { + this.enName = enName; + this.zhName = zhName; + } + } private interface TranslatorListener { void onFailure(int[] reasons, long value); diff --git 
a/app/src/main/res/layout/fragment_model_manager.xml b/app/src/main/res/layout/fragment_model_manager.xml index 8c63700..1eca07e 100644 --- a/app/src/main/res/layout/fragment_model_manager.xml +++ b/app/src/main/res/layout/fragment_model_manager.xml @@ -40,10 +40,10 @@ android:layout_height="wrap_content" android:text="Madlad 400" /> + android:text="HY-MT 1.5" />