diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 00000000..f1c588bd
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1 @@
+*.apk filter=lfs diff=lfs merge=lfs -text
diff --git a/.idea/.gitignore b/.idea/.gitignore
index 2aba50ec..69055eae 100644
--- a/.idea/.gitignore
+++ b/.idea/.gitignore
@@ -1,13 +1,11 @@
-# Default ignored files
/shelf/
-/usage.statistics.xml
-/workspace.xml
-#extra ignored files
/caches
/caches/*
+/dictionaries
+/libraries
+/usage.statistics.xml
+/workspace.xml
/gradle.xml
/dataSources.ids
/datasources.xml
-/modules.xml
-/dictionaries
-/libraries
\ No newline at end of file
+/modules.xml
\ No newline at end of file
diff --git a/.idea/caches/deviceStreaming.xml b/.idea/caches/deviceStreaming.xml
deleted file mode 100644
index 6f03899d..00000000
--- a/.idea/caches/deviceStreaming.xml
+++ /dev/null
@@ -1,1041 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/gradle.xml b/.idea/gradle.xml
deleted file mode 100644
index 2fa2eefc..00000000
--- a/.idea/gradle.xml
+++ /dev/null
@@ -1,19 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/apks/README.md b/apks/README.md
new file mode 100644
index 00000000..a3efc28e
--- /dev/null
+++ b/apks/README.md
@@ -0,0 +1,14 @@
+If you want to test this version of the app early, keep in mind that this is a development preview, so it may be quite unstable, and it requires initial setup. To try it you need to:
+
+Download Mozilla.zip, Madlad.zip, and HY-MT.zip from [here](https://github.com/niedev/OnnxModelsEnhancer/releases/tag/v1.0.0-beta).
+After the download, extract these folders (their names should remain "Mozilla", "Madlad" and "HY-MT", with the content of each .zip directly inside the corresponding extracted folder; if you change the structure of these folders the app will not work).
+
+After that create a folder named "models", inside it create a folder named "Translation", and inside it paste all the extracted folders.
+
+Now download one of the apks in this folder, install it, open the app, grant the requested file access, and start the download.
+
+After the download has finished, exit the app and enable all its permissions from the Android Settings (Settings -> Applications -> RTranslator).
+
+Then reopen the app and everything should work.
+
+By default, the models used for translation are the Mozilla ones; to use one of the other supported models, select it from RTranslator's settings, at the bottom (note that to run Madlad and HY-MT you will probably need a phone with at least 12GB of RAM; if the RAM is not enough the app will keep crashing until you wipe its data).
diff --git a/apks/RTranslator_3.0.0_alpha1.apk b/apks/RTranslator_3.0.0_alpha1.apk
new file mode 100644
index 00000000..e5f6d48f
--- /dev/null
+++ b/apks/RTranslator_3.0.0_alpha1.apk
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:59f7525063e66da7efbd9c32dac2a4810fdf0df1d15137af3efaa66d87604553
+size 192110576
diff --git a/app/build.gradle b/app/build.gradle
index 2f0c80ca..4fb1eee8 100644
--- a/app/build.gradle
+++ b/app/build.gradle
@@ -42,7 +42,7 @@ android {
//bergamot flags
'-DANDROID_ABI=arm64-v8a',
'-DANDROID_PLATFORM=android-28',
- '-DANDROID_STL=c++_static',
+ '-DANDROID_STL=c++_shared',
"-DANDROID_PIE=ON",
"-DANDROID_CPP_FEATURES=exceptions",
'-DCMAKE_BUILD_TYPE=Release',
@@ -121,6 +121,11 @@ dependencies {
implementation "androidx.lifecycle:lifecycle-extensions:2.2.0"
//Download library
implementation 'com.github.amitshekhariitbhu:PRDownloader:1.0.2'
+ // DJL HuggingFace tokenizers wrapper
+ implementation "ai.djl.android:core:0.33.0"
+ implementation "ai.djl.huggingface:tokenizers:0.33.0"
+ implementation "ai.djl.android:tokenizer-native:0.33.0"
+
//implementation 'androidx.core:core-ktx:1.10.0'
implementation 'androidx.work:work-runtime:2.7.1'
implementation 'androidx.exifinterface:exifinterface:1.3.7'
diff --git a/app/src/main/java/nie/translator/rtranslator/Global.java b/app/src/main/java/nie/translator/rtranslator/Global.java
index 7b6c6f0f..a95615ba 100644
--- a/app/src/main/java/nie/translator/rtranslator/Global.java
+++ b/app/src/main/java/nie/translator/rtranslator/Global.java
@@ -55,7 +55,7 @@
public class Global extends Application implements DefaultLifecycleObserver {
- public static final boolean ONLY_TEXT_TRANSLATION_MODE = true;
+ public static final boolean ONLY_TEXT_TRANSLATION_MODE = false;
public enum RTranslatorMode {
TEXT_TRANSLATION_MODE,
WALKIE_TALKIE_MODE,
@@ -128,7 +128,18 @@ public void restartTranslator(Translator.GeneralListener listener){
getLanguages(false);
SharedPreferences sharedPreferences = getSharedPreferences("default", Context.MODE_PRIVATE);
int mode = sharedPreferences.getInt("selectedTranslationModel", Translator.MOZILLA);
- translator.restart(mode, listener);
+ translator.restart(mode, new Translator.GeneralListener() {
+ @Override
+ public void onSuccess() {
+ getTranslatorLanguages(false); //refresh languages
+ listener.onSuccess();
+ }
+
+ @Override
+ public void onFailure(int[] reasons, long value) {
+ listener.onFailure(reasons, value);
+ }
+ });
}
@Nullable
diff --git a/app/src/main/java/nie/translator/rtranslator/LoadingActivity.java b/app/src/main/java/nie/translator/rtranslator/LoadingActivity.java
index 0d915336..9a5338e7 100644
--- a/app/src/main/java/nie/translator/rtranslator/LoadingActivity.java
+++ b/app/src/main/java/nie/translator/rtranslator/LoadingActivity.java
@@ -20,9 +20,13 @@
import android.content.DialogInterface;
import android.content.Intent;
import android.content.SharedPreferences;
+import android.net.Uri;
+import android.os.Build;
import android.os.Bundle;
+import android.os.Environment;
import android.os.Handler;
import android.os.Looper;
+import android.provider.Settings;
import androidx.appcompat.app.AlertDialog;
import java.util.ArrayList;
diff --git a/app/src/main/java/nie/translator/rtranslator/access/AccessActivity.java b/app/src/main/java/nie/translator/rtranslator/access/AccessActivity.java
index 3d5ac624..fc460d71 100644
--- a/app/src/main/java/nie/translator/rtranslator/access/AccessActivity.java
+++ b/app/src/main/java/nie/translator/rtranslator/access/AccessActivity.java
@@ -18,9 +18,13 @@
import android.Manifest;
import android.content.Context;
+import android.content.Intent;
import android.content.SharedPreferences;
+import android.net.Uri;
import android.os.Build;
import android.os.Bundle;
+import android.os.Environment;
+import android.provider.Settings;
import android.view.View;
import androidx.fragment.app.Fragment;
@@ -78,6 +82,12 @@ protected void onStart() {
super.onStart(); //called here because otherwise the onStart of the DownloadFragment is called before this onStart, and this could cause problems.
}
+ @Override
+ protected void onResume() {
+ super.onResume();
+ checkAllFilesPermission(); //todo: remove before the final release
+ }
+
@Override
protected void onStop() {
super.onStop();
@@ -151,6 +161,27 @@ public void onBackPressed() {
}
super.onBackPressed();
}
+
+ //todo: remove before the final release
+ private void checkAllFilesPermission() {
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) {
+ if (!Environment.isExternalStorageManager()) {
+ requestAllFilesPermission();
+ }
+ }
+ }
+
+ //todo: remove before the final release
+ private void requestAllFilesPermission() {
+ try {
+ Intent intent = new Intent(Settings.ACTION_MANAGE_APP_ALL_FILES_ACCESS_PERMISSION);
+ intent.setData(Uri.parse("package:" + getPackageName()));
+ startActivityForResult(intent, 100);
+ } catch (Exception e) {
+ Intent intent = new Intent(Settings.ACTION_MANAGE_ALL_FILES_ACCESS_PERMISSION);
+ startActivityForResult(intent, 100);
+ }
+ }
}
diff --git a/app/src/main/java/nie/translator/rtranslator/settings/LanguagePreference.java b/app/src/main/java/nie/translator/rtranslator/settings/LanguagePreference.java
index 6871c196..c78516c5 100644
--- a/app/src/main/java/nie/translator/rtranslator/settings/LanguagePreference.java
+++ b/app/src/main/java/nie/translator/rtranslator/settings/LanguagePreference.java
@@ -127,6 +127,7 @@ public void onClick(View v) {
}
private void showList() {
+ progressBar.setVisibility(View.GONE);
reloadButton.setVisibility(View.GONE);
final ArrayList languages = global.getLanguages(true);
diff --git a/app/src/main/java/nie/translator/rtranslator/settings/ModelManagerFragment.java b/app/src/main/java/nie/translator/rtranslator/settings/ModelManagerFragment.java
index 98c0f0b2..ce2ed4c0 100644
--- a/app/src/main/java/nie/translator/rtranslator/settings/ModelManagerFragment.java
+++ b/app/src/main/java/nie/translator/rtranslator/settings/ModelManagerFragment.java
@@ -13,13 +13,9 @@
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.fragment.app.Fragment;
-import androidx.recyclerview.widget.DefaultItemAnimator;
-import androidx.recyclerview.widget.LinearLayoutManager;
-import androidx.recyclerview.widget.RecyclerView;
import nie.translator.rtranslator.Global;
import nie.translator.rtranslator.R;
-import nie.translator.rtranslator.voice_translation.VoiceTranslationActivity;
import nie.translator.rtranslator.voice_translation.neural_networks.translation.Translator;
public class ModelManagerFragment extends Fragment {
@@ -71,8 +67,8 @@ public void onActivityCreated(@Nullable Bundle savedInstanceState) {
case Translator.MADLAD_CACHE:
radioGroup.check(R.id.madlad_radio);
break;
- case Translator.GEMMA:
- radioGroup.check(R.id.gemma_radio);
+ case Translator.HY_MT:
+ radioGroup.check(R.id.hy_radio);
break;
}
@@ -87,8 +83,8 @@ public void onActivityCreated(@Nullable Bundle savedInstanceState) {
case R.id.madlad_radio:
changeModel(Translator.MADLAD_CACHE);
break;
- case R.id.gemma_radio:
- changeModel(Translator.GEMMA);
+ case R.id.hy_radio:
+ changeModel(Translator.HY_MT);
break;
}
});
diff --git a/app/src/main/java/nie/translator/rtranslator/tools/nn/CacheContainerNative.java b/app/src/main/java/nie/translator/rtranslator/tools/nn/CacheContainerNative.java
index 336b85fe..49676f99 100644
--- a/app/src/main/java/nie/translator/rtranslator/tools/nn/CacheContainerNative.java
+++ b/app/src/main/java/nie/translator/rtranslator/tools/nn/CacheContainerNative.java
@@ -29,24 +29,26 @@
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
+import nie.translator.rtranslator.voice_translation.neural_networks.translation.Translator;
public class CacheContainerNative {
private int[] shape;
private OnnxTensor[] cacheTensors;
private long cacheContainerNativePointer;
+ private int mode = Translator.NLLB;
//Used to load CacheContainerNative.cpp on application startup
static {
System.loadLibrary("cache_container_native");
}
- public CacheContainerNative(OrtEnvironment env, OrtSession.Result cache, int nLevels, int batchSize, int nHeads, int sequenceLength, int hiddenSize){
+ public CacheContainerNative(OrtEnvironment env, OrtSession.Result cache, int nLevels, int batchSize, int nHeads, int sequenceLength, int hiddenSize, int mode){
try {
cacheTensors = new OnnxTensor[nLevels*2];
cacheContainerNativePointer = initialize(nLevels*2, batchSize, nHeads, sequenceLength, hiddenSize);
int count=0;
for (int i = 0; i < nLevels; i++) {
- cacheTensors[count] = (OnnxTensor) cache.get("present." + i + ".decoder.key").get();
+ cacheTensors[count] = (OnnxTensor) cache.get("present." + i + (mode!= Translator.HY_MT ? ".decoder" : "") + ".key").get();
//we use OnnxTensor's private getBuffer method, which returns the data reference without making a copy of it and we pass this reference to the native cache container
Method method = cacheTensors[count].getClass().getDeclaredMethod("getBuffer");
method.setAccessible(true);
@@ -54,7 +56,7 @@ public CacheContainerNative(OrtEnvironment env, OrtSession.Result cache, int nLe
insertValues(cacheContainerNativePointer, count, buffer);
count++;
- cacheTensors[count] = (OnnxTensor) cache.get("present." + i + ".decoder.value").get();
+ cacheTensors[count] = (OnnxTensor) cache.get("present." + i + (mode!= Translator.HY_MT ? ".decoder" : "") + ".value").get();
//we use OnnxTensor's private getBuffer method, which returns the data reference without making a copy of it and we pass this reference to the native cache container
method = cacheTensors[count].getClass().getDeclaredMethod("getBuffer");
method.setAccessible(true);
@@ -63,6 +65,7 @@ public CacheContainerNative(OrtEnvironment env, OrtSession.Result cache, int nLe
count++;
}
shape = new int[]{nLevels*2, batchSize, nHeads, sequenceLength, hiddenSize};
+ this.mode = mode;
} catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
e.printStackTrace();
}
@@ -86,7 +89,7 @@ public OrtSession.Result getCacheResult(OrtEnvironment env){
int count = 0;
for (int i = 0; i < shape[0]/2; i++) {
for (String suffix: suffixes) {
- names[count] = "present." + i + ".decoder."+suffix;
+ names[count] = "present." + i + (mode!= Translator.HY_MT ? ".decoder." : ".") + suffix;
ByteBuffer buffer = getBuffer(cacheContainerNativePointer, count);
values[count] = OnnxTensor.createTensor(env, buffer, new long[]{shape[1], shape[2], shape[3], shape[4]}, OnnxJavaType.FLOAT);
count++;
diff --git a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Tokenizer.java b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Tokenizer.java
index 7da2d209..fba89125 100644
--- a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Tokenizer.java
+++ b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Tokenizer.java
@@ -16,121 +16,171 @@
package nie.translator.rtranslator.voice_translation.neural_networks.translation;
+import androidx.annotation.Nullable;
+
+import java.io.IOException;
+import java.nio.file.Paths;
import java.util.Arrays;
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+
public class Tokenizer {
public static final int NLLB = 0;
public static final int NLLB_FIXED = 1;
public static final int SEAMLESS = 2;
public static final int MADLAD = 3;
public static final int MADLAD_FIXED = 4;
+ public static final int HY_MT = 5;
private SentencePieceProcessorJava spProcessor;
+ private HuggingFaceTokenizer hfTokenizer;
private final int mode;
//languagesNLLB and languagesSeamless contain the list of all supported languages, sorted by their ID order (their IDs are consecutive)
private final String[] languagesNLLB = {"ace_Arab", "ace_Latn", "acm_Arab", "acq_Arab", "aeb_Arab", "afr_Latn", "ajp_Arab", "aka_Latn", "amh_Ethi", "apc_Arab", "arb_Arab", "ars_Arab", "ary_Arab", "arz_Arab", "asm_Beng", "ast_Latn", "awa_Deva", "ayr_Latn", "azb_Arab", "azj_Latn", "bak_Cyrl", "bam_Latn", "ban_Latn", "bel_Cyrl", "bem_Latn", "ben_Beng", "bho_Deva", "bjn_Arab", "bjn_Latn", "bod_Tibt", "bos_Latn", "bug_Latn", "bul_Cyrl", "cat_Latn", "ceb_Latn", "ces_Latn", "cjk_Latn", "ckb_Arab", "crh_Latn", "cym_Latn", "dan_Latn", "deu_Latn", "dik_Latn", "dyu_Latn", "dzo_Tibt", "ell_Grek", "eng_Latn", "epo_Latn", "est_Latn", "eus_Latn", "ewe_Latn", "fao_Latn", "pes_Arab", "fij_Latn", "fin_Latn", "fon_Latn", "fra_Latn", "fur_Latn", "fuv_Latn", "gla_Latn", "gle_Latn", "glg_Latn", "grn_Latn", "guj_Gujr", "hat_Latn", "hau_Latn", "heb_Hebr", "hin_Deva", "hne_Deva", "hrv_Latn", "hun_Latn", "hye_Armn", "ibo_Latn", "ilo_Latn", "ind_Latn", "isl_Latn", "ita_Latn", "jav_Latn", "jpn_Jpan", "kab_Latn", "kac_Latn", "kam_Latn", "kan_Knda", "kas_Arab", "kas_Deva", "kat_Geor", "knc_Arab", "knc_Latn", "kaz_Cyrl", "kbp_Latn", "kea_Latn", "khm_Khmr", "kik_Latn", "kin_Latn", "kir_Cyrl", "kmb_Latn", "kon_Latn", "kor_Hang", "kmr_Latn", "lao_Laoo", "lvs_Latn", "lij_Latn", "lim_Latn", "lin_Latn", "lit_Latn", "lmo_Latn", "ltg_Latn", "ltz_Latn", "lua_Latn", "lug_Latn", "luo_Latn", "lus_Latn", "mag_Deva", "mai_Deva", "mal_Mlym", "mar_Deva", "min_Latn", "mkd_Cyrl", "plt_Latn", "mlt_Latn", "mni_Beng", "khk_Cyrl", "mos_Latn", "mri_Latn", "zsm_Latn", "mya_Mymr", "nld_Latn", "nno_Latn", "nob_Latn", "npi_Deva", "nso_Latn", "nus_Latn", "nya_Latn", "oci_Latn", "gaz_Latn", "ory_Orya", "pag_Latn", "pan_Guru", "pap_Latn", "pol_Latn", "por_Latn", "prs_Arab", "pbt_Arab", "quy_Latn", "ron_Latn", "run_Latn", "rus_Cyrl", "sag_Latn", "san_Deva", "sat_Beng", "scn_Latn", "shn_Mymr", "sin_Sinh", "slk_Latn", "slv_Latn", "smo_Latn", "sna_Latn", "snd_Arab", "som_Latn", "sot_Latn", "spa_Latn", "als_Latn", "srd_Latn", 
"srp_Cyrl", "ssw_Latn", "sun_Latn", "swe_Latn", "swh_Latn", "szl_Latn", "tam_Taml", "tat_Cyrl", "tel_Telu", "tgk_Cyrl", "tgl_Latn", "tha_Thai", "tir_Ethi", "taq_Latn", "taq_Tfng", "tpi_Latn", "tsn_Latn", "tso_Latn", "tuk_Latn", "tum_Latn", "tur_Latn", "twi_Latn", "tzm_Tfng", "uig_Arab", "ukr_Cyrl", "umb_Latn", "urd_Arab", "uzn_Latn", "vec_Latn", "vie_Latn", "war_Latn", "wol_Latn", "xho_Latn", "ydd_Hebr", "yor_Latn", "yue_Hant", "zho_Hans", "zho_Hant", "zul_Latn"};
private final String[] languagesSeamless = {"ace", "ace_Latn", "acm", "acq", "aeb", "afr", "ajp", "aka", "amh", "apc", "arb", "ars", "ary", "arz", "asm", "ast", "awa", "ayr", "azb", "azj", "bak", "bam", "ban", "bel", "bem", "ben", "bho", "bjn", "bjn_Latn", "bod", "bos", "bug", "bul", "cat", "ceb", "ces", "cjk", "ckb", "crh", "cym", "dan", "deu", "dik", "dyu", "dzo", "ell", "eng", "epo", "est", "eus", "ewe", "fao", "pes", "fij", "fin", "fon", "fra", "fur", "fuv", "gla", "gle", "glg", "grn", "guj", "hat", "hau", "heb", "hin", "hne", "hrv", "hun", "hye", "ibo", "ilo", "ind", "isl", "ita", "jav", "jpn", "kab", "kac", "kam", "kan", "kas", "kas_Deva", "kat", "knc", "knc_Latn", "kaz", "kbp", "kea", "khm", "kik", "kin", "kir", "kmb", "kon", "kor", "kmr", "lao", "lvs", "lij", "lim", "lin", "lit", "lmo", "ltg", "ltz", "lua", "lug", "luo", "lus", "mag", "mai", "mal", "mar", "min", "mkd", "plt", "mlt", "mni", "khk", "mos", "mri", "zsm", "mya", "nld", "nno", "nob", "npi", "nso", "nus", "nya", "oci", "gaz", "ory", "pag", "pan", "pap", "pol", "por", "prs", "pbt", "quy", "ron", "run", "rus", "sag", "san", "sat", "scn", "shn", "sin", "slk", "slv", "smo", "sna", "snd", "som", "sot", "spa", "als", "srd", "srp", "ssw", "sun", "swe", "swh", "szl", "tam", "tat", "tel", "tgk", "tgl", "tha", "tir", "taq", "taq_Tfng", "tpi", "tsn", "tso", "tuk", "tum", "tur", "twi", "tzm", "uig", "ukr", "umb", "urd", "uzn", "vec", "vie", "war", "wol", "xho", "ydd", "yor", "yue", "cmn", "cmn_Hant", "zul"};
private final int DICTIONARY_LENGTH = 256000;
- public Tokenizer(String vocab_file, int mode) {
- spProcessor = new SentencePieceProcessorJava();
- spProcessor.Load(vocab_file);
+ public Tokenizer(String vocab_path, int mode) throws IOException {
+ if(mode != HY_MT) {
+ spProcessor = new SentencePieceProcessorJava();
+ spProcessor.Load(vocab_path);
+ }else{
+ hfTokenizer = HuggingFaceTokenizer.newInstance(Paths.get(vocab_path));
+ }
this.mode = mode;
}
- public TokenizerResult tokenize(String srcLanguage, String tgtLanguage, String text) {
- //for madlad we add <2tgtLanguage> at the beginning of the text (srcLanguage is not specified)
- if (mode == MADLAD || mode==MADLAD_FIXED){
- text = "<2"+tgtLanguage+"> "+text;
- }
- //we translate text into ids via sentencepiece
- int[] ids = spProcessor.encode(text);
+ public TokenizerResult tokenize(String text, String srcLanguage, String tgtLanguage){
+ return tokenize(text, srcLanguage, tgtLanguage, null, null, false);
+ }
+
+ public TokenizerResult tokenize(String text, String srcLanguage, String tgtLanguage, @Nullable Translator.HyLanguageInfo srcLangInfo, @Nullable Translator.HyLanguageInfo tgtLangInfo, boolean excludePrompt) {
+ if(mode != HY_MT) {
+ //for madlad we add <2tgtLanguage> at the beginning of the text (srcLanguage is not specified)
+ if (mode == MADLAD || mode == MADLAD_FIXED) {
+ text = "<2" + tgtLanguage + "> " + text;
+ }
+ //we translate text into ids via sentencepiece
+ int[] ids = spProcessor.encode(text);
/* The NLLBTokenizer's dictionary has a different mapping for tokens identified by the first 4 IDs values (from 0 to 3),
also for the other IDs it has a value equal to those of the dictionary we passed to sentencepiece but with an addition of 1
(idNLLB = idSentencePiece + 1), so now we make the necessary adjustments */
- /* For the SeamlessTokenizer's dictionary the same thing is true but the + 1 is also valid for the first 4 values of the IDs (the value 0, in addition, represents the padding) */
- if(mode != MADLAD && mode != MADLAD_FIXED) { //MADLAD has a one-to-one match to the sentencepiece dictionary.
- for (int i = 0; i < ids.length; i++) {
- //we add 1 to each element of ids
- ids[i] = ids[i] + 1;
- //we replace the values from 0 to 3 with the correct ones for NLBTokenizer (for Seamless the values are already all correct)
- if (mode == NLLB || mode == NLLB_FIXED) {
- switch (ids[i]) {
- case 1: {
- ids[i] = 3;
- break;
- }
- case 2: {
- ids[i] = 0;
- break;
- }
- case 3: {
- ids[i] = 2;
- break;
+ /* For the SeamlessTokenizer's dictionary the same thing is true but the + 1 is also valid for the first 4 values of the IDs (the value 0, in addition, represents the padding) */
+ if (mode != MADLAD && mode != MADLAD_FIXED) { //MADLAD has a one-to-one match to the sentencepiece dictionary.
+ for (int i = 0; i < ids.length; i++) {
+ //we add 1 to each element of ids
+ ids[i] = ids[i] + 1;
+ //we replace the values from 0 to 3 with the correct ones for NLBTokenizer (for Seamless the values are already all correct)
+ if (mode == NLLB || mode == NLLB_FIXED) {
+ switch (ids[i]) {
+ case 1: {
+ ids[i] = 3;
+ break;
+ }
+ case 2: {
+ ids[i] = 0;
+ break;
+ }
+ case 3: {
+ ids[i] = 2;
+ break;
+ }
}
}
}
}
- }
- //add at the end and srcLanguage at the beginning (srcLanguage is not added for Madlad)
- int eos = PieceToID("");
- int srcLanguageID = getLanguageID(srcLanguage);
- int[] idsExtended;
- if(mode != MADLAD && mode != MADLAD_FIXED) {
- idsExtended = new int[ids.length + 2];
- System.arraycopy(ids, 0, idsExtended, 1, ids.length);
- idsExtended[idsExtended.length - 1] = eos;
- idsExtended[0] = srcLanguageID;
- }else{
- idsExtended = new int[ids.length + 1];
- System.arraycopy(ids, 0, idsExtended, 0, ids.length);
- idsExtended[idsExtended.length - 1] = eos;
- }
+ //add at the end and srcLanguage at the beginning (srcLanguage is not added for Madlad)
+ int eos = PieceToID("");
+ int srcLanguageID = getLanguageID(srcLanguage);
+ int[] idsExtended;
+ if (mode != MADLAD && mode != MADLAD_FIXED) {
+ idsExtended = new int[ids.length + 2];
+ System.arraycopy(ids, 0, idsExtended, 1, ids.length);
+ idsExtended[idsExtended.length - 1] = eos;
+ idsExtended[0] = srcLanguageID;
+ } else {
+ idsExtended = new int[ids.length + 1];
+ System.arraycopy(ids, 0, idsExtended, 0, ids.length);
+ idsExtended[idsExtended.length - 1] = eos;
+ }
- //we create the attention mask
- int[] attentionMask = new int[idsExtended.length];
- Arrays.fill(attentionMask, 1);
-
- if(mode == NLLB || mode == MADLAD) {
- return new TokenizerResult(idsExtended, attentionMask);
- }else if(mode == SEAMLESS){
- //for seamless the ids must always be 512, we fill the empty ids with padding (0)
- int[] idsPadded = new int[512];
- Arrays.fill(idsPadded, 0);
- System.arraycopy(idsExtended, 0, idsPadded, 0, idsExtended.length);
- //for seamless also the attention mask must always have length 512, we fill the rest with 0
- int[] attentionMaskPadded = new int[512];
- Arrays.fill(attentionMaskPadded, 0);
- System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length);
- return new TokenizerResult(idsPadded, attentionMaskPadded);
- }else if(mode == NLLB_FIXED){
- //for NLLB Fixed the ids must always be 256, we fill the empty ids with padding (0)
- int[] idsPadded = new int[256];
- Arrays.fill(idsPadded, 0);
- System.arraycopy(idsExtended, 0, idsPadded, 0, idsExtended.length);
- //also the attention mask must always have length 256, we fill the rest with 0
- int[] attentionMaskPadded = new int[256];
- Arrays.fill(attentionMaskPadded, 0);
- System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length);
- return new TokenizerResult(idsPadded, attentionMaskPadded);
- }else{
- //for Madlad Fixed the ids must always be 128, we fill the empty ids with padding (1)
- int[] idsPadded = new int[128];
- Arrays.fill(idsPadded, 1);
- System.arraycopy(idsExtended, 0, idsPadded, 0, idsExtended.length);
- //also the attention mask must always have length 128, we fill the rest with 0
- int[] attentionMaskPadded = new int[128];
- Arrays.fill(attentionMaskPadded, 0);
- System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length);
- return new TokenizerResult(idsPadded, attentionMaskPadded);
+ //we create the attention mask
+ int[] attentionMask = new int[idsExtended.length];
+ Arrays.fill(attentionMask, 1);
+
+ if (mode == NLLB || mode == MADLAD) {
+ return new TokenizerResult(idsExtended, attentionMask);
+ } else if (mode == SEAMLESS) {
+ //for seamless the ids must always be 512, we fill the empty ids with padding (0)
+ int[] idsPadded = new int[512];
+ Arrays.fill(idsPadded, 0);
+ System.arraycopy(idsExtended, 0, idsPadded, 0, idsExtended.length);
+ //for seamless also the attention mask must always have length 512, we fill the rest with 0
+ int[] attentionMaskPadded = new int[512];
+ Arrays.fill(attentionMaskPadded, 0);
+ System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length);
+ return new TokenizerResult(idsPadded, attentionMaskPadded);
+ } else if (mode == NLLB_FIXED) {
+ //for NLLB Fixed the ids must always be 256, we fill the empty ids with padding (0)
+ int[] idsPadded = new int[256];
+ Arrays.fill(idsPadded, 0);
+ System.arraycopy(idsExtended, 0, idsPadded, 0, idsExtended.length);
+ //also the attention mask must always have length 256, we fill the rest with 0
+ int[] attentionMaskPadded = new int[256];
+ Arrays.fill(attentionMaskPadded, 0);
+ System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length);
+ return new TokenizerResult(idsPadded, attentionMaskPadded);
+ } else {
+ //for Madlad Fixed the ids must always be 128, we fill the empty ids with padding (1)
+ int[] idsPadded = new int[128];
+ Arrays.fill(idsPadded, 1);
+ System.arraycopy(idsExtended, 0, idsPadded, 0, idsExtended.length);
+ //also the attention mask must always have length 128, we fill the rest with 0
+ int[] attentionMaskPadded = new int[128];
+ Arrays.fill(attentionMaskPadded, 0);
+ System.arraycopy(attentionMask, 0, attentionMaskPadded, 0, attentionMask.length);
+ return new TokenizerResult(idsPadded, attentionMaskPadded);
+ }
+ } else {
+ //tokenization for hy-mt
+ String tgtLangEnName = tgtLanguage;
+ String tgtLangZhName = tgtLanguage;
+ if(tgtLangInfo != null){
+ tgtLangEnName = tgtLangInfo.enName;
+ tgtLangZhName = tgtLangInfo.zhName;
+ }
+ String prompt;
+ if(srcLanguage.equals("zh")) {
+ prompt = "<|hy_begin▁of▁sentence|><|hy_User|>将以下文本翻译为"+tgtLangZhName+",注意只需要输出翻译后的结果,不要额外解释:\n\n"+text+"<|hy_Assistant|>";
+ } else {
+ prompt = "<|hy_begin▁of▁sentence|><|hy_User|>Translate the following segment into "+tgtLangEnName+", without additional explanation.\n\n"+text+"<|hy_Assistant|>";
+ }
+ Encoding result = hfTokenizer.encode(new String[]{excludePrompt ? text : prompt}, false, false);
+
+ //we convert inputIds and attention mask from long[] to int[]
+ long[] inputIdsLong = result.getIds();
+ int[] inputIds = new int[inputIdsLong.length];
+ for (int i = 0; i < inputIds.length; i++) {
+ inputIds[i] = (int) inputIdsLong[i];
+ }
+ long[] attentionMaskLong = result.getAttentionMask();
+ int[] attentionMask = new int[attentionMaskLong.length];
+ for (int i = 0; i < attentionMask.length; i++) {
+ attentionMask[i] = (int) attentionMaskLong[i];
+ }
+
+ return new TokenizerResult(inputIds, attentionMask);
}
}
public int PieceToID(String token){
if(mode == NLLB || mode == NLLB_FIXED || mode==MADLAD || mode==MADLAD_FIXED) {
return spProcessor.PieceToID(token);
+ }else if(mode == HY_MT) {
+ return (int) hfTokenizer.encode(token).getIds()[0];
}else{
return spProcessor.PieceToID(token)+1;
}
@@ -155,13 +205,20 @@ public int getLanguageID(String language){
public String decode(int[] ids) {
String output = "";
- if(mode != MADLAD && mode != MADLAD_FIXED) {
+ if(mode == NLLB || mode == NLLB_FIXED) {
for (int i = 0; i < ids.length; i++) {
if (ids[i] < DICTIONARY_LENGTH && ids[i] > 3) { //This check skips special tokens and those that sentencepiece does not have in the dictionary (such as languages)
output = output.concat(spProcessor.IDToPiece(ids[i] - 1));
}
}
- }else{
+ }else if(mode == HY_MT){
+ long[] inputIdsLong = new long[ids.length];
+ for (int i = 0; i < ids.length; i++) {
+ inputIdsLong[i] = (long) ids[i];
+ }
+ output = hfTokenizer.decode(inputIdsLong);
+ output = output.replace("<|hy_place▁holder▁no▁2|>", ""); //we remove the eos token
+ }else{ //madlad and seamless
for (int i = 0; i < ids.length; i++) {
if (ids[i] < DICTIONARY_LENGTH && ids[i] > 3) { //This check skips special tokens and those that sentencepiece does not have in the dictionary (such as languages)
output = output.concat(spProcessor.IDToPiece(ids[i]));
diff --git a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java
index 2d7ae195..92b20cd3 100644
--- a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java
+++ b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/translation/Translator.java
@@ -64,7 +64,6 @@
import nie.translator.rtranslator.tools.CustomLocale;
import nie.translator.rtranslator.tools.ErrorCodes;
import nie.translator.rtranslator.tools.FileTools;
-import nie.translator.rtranslator.tools.gui.animations.CustomAnimator;
import nie.translator.rtranslator.tools.gui.messages.GuiMessage;
import nie.translator.rtranslator.tools.nn.CacheContainerNative;
import nie.translator.rtranslator.tools.nn.TensorUtils;
@@ -80,7 +79,7 @@ public class Translator extends NeuralNetworkApi {
public static final int MADLAD = 3;
public static final int MADLAD_CACHE = 5;
public static final int MOZILLA = 7;
- public static final int GEMMA = 8;
+ public static final int HY_MT = 8;
private int mode;
private final BergamotModelsIndicator bergamotModelsIndicator = new BergamotModelsIndicator();
private Tokenizer tokenizer;
@@ -90,7 +89,8 @@ public class Translator extends NeuralNetworkApi {
private OrtSession cacheInitSession;
private OrtSession embedAndLmHeadSession;
private OrtSession embedSession;
- private Map nllbLanguagesCodes = new HashMap();
+ private final Map nllbLanguagesCodes = new HashMap<>();
+ private final Map hyLanguagesInfo = new HashMap<>();
private static final double EOS_PENALTY = 0.9;
@Nullable
private GuiMessage lastInputText;
@@ -121,35 +121,46 @@ public Translator(@NonNull Global global, int mode, GeneralListener initListener
this.mode = mode;
mainHandler = new android.os.Handler(Looper.getMainLooper());
initializeNllbLanguagesCodes(global);
+ initializeHyLanguagesInfo(global);
initialize(global, mode, initListener);
}
private void initialize(@NonNull Global global, int mode, GeneralListener initListener){
- String encoderPath;
- String decoderPath;
- String vocabPath;
- String embedAndLmHeadPath;
- String cacheInitializerPath;
+ String encoderPath = "";
+ String decoderPath = "";
+ String vocabPath = "";
+ String embedAndLmHeadPath = "";
+ String cacheInitializerPath = "";
if(mode == NLLB || mode == NLLB_CACHE) {
//8 bit
- /*encoderPath = global.getFilesDir().getPath() + "/NLLB_encoder.onnx";
+ encoderPath = global.getFilesDir().getPath() + "/NLLB_encoder.onnx";
decoderPath = global.getFilesDir().getPath() + "/NLLB_decoder.onnx";
vocabPath = global.getFilesDir().getPath() + "/sentencepiece_bpe.model";
embedAndLmHeadPath = global.getFilesDir().getPath() + "/NLLB_embed_and_lm_head.onnx";
- cacheInitializerPath = global.getFilesDir().getPath() + "/NLLB_cache_initializer.onnx";*/
+ cacheInitializerPath = global.getFilesDir().getPath() + "/NLLB_cache_initializer.onnx";
//4 bit
- encoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_encoder_4bit.onnx";
+ /*encoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_encoder_4bit.onnx";
decoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_decoder_4bit.onnx";
vocabPath = global.getFilesDir().getPath() + "/sentencepiece_bpe.model";
embedAndLmHeadPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_embed_and_lm_head_4bit.onnx";
- cacheInitializerPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_cache_initializer_4bit.onnx";
- }else { //madlad
- encoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int4Acc4/madlad_encoder_4bit.onnx";
- decoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int4Acc4/madlad_decoder_4bit.onnx";
+ cacheInitializerPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/NLLB" + "/nllb_cache_initializer_4bit.onnx";*/
+ }else if(mode == MADLAD || mode == MADLAD_CACHE){ //madlad
+ //8bit
+ encoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int8WO/madlad_encoder_8bit.onnx";
+ decoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int8WO/madlad_decoder_8bit.onnx";
vocabPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/spiece.model";
embedAndLmHeadPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/madlad_embed_8bit.onnx";
- cacheInitializerPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int4Acc4/madlad_cache_initializer_4bit.onnx";
+ cacheInitializerPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int8WO/madlad_cache_initializer_8bit.onnx";
+ //4bit
+ /*encoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int4_16/madlad_encoder_4bit.onnx";
+ decoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int4_16/madlad_decoder_4bit.onnx";
+ vocabPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/spiece.model";
+ embedAndLmHeadPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/madlad_embed_8bit.onnx";
+ cacheInitializerPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/Madlad" + "/Int4_16/madlad_cache_initializer_4bit.onnx";*/
+ }else { //hy-mt
+ decoderPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/HY-MT" + "/model_int8_final.onnx";
+ vocabPath = Environment.getExternalStorageDirectory().getPath() + "/models/Translation/HY-MT" + "/tokenizer.json";
}
String finalDecoderPath = decoderPath;
@@ -197,7 +208,7 @@ public void onFailure(int[] reasons, long value) {
}
});
} else {
- final OrtSession.SessionOptions.OptLevel optDefaultLevel = OrtSession.SessionOptions.OptLevel.BASIC_OPT;
+ final OrtSession.SessionOptions.OptLevel optDefaultLevel = OrtSession.SessionOptions.OptLevel.EXTENDED_OPT;
boolean arena = true;
OrtSession.SessionOptions decoderOptions = new OrtSession.SessionOptions();
@@ -210,13 +221,13 @@ public void onFailure(int[] reasons, long value) {
encoderOptions.setMemoryPatternOptimization(arena);
encoderOptions.setCPUArenaAllocator(arena);
encoderOptions.setOptimizationLevel(optDefaultLevel);
- encoderSession = onnxEnv.createSession(finalEncoderPath, encoderOptions);
+ if(mode != HY_MT) encoderSession = onnxEnv.createSession(finalEncoderPath, encoderOptions);
OrtSession.SessionOptions cacheInitOptions = new OrtSession.SessionOptions();
cacheInitOptions.setMemoryPatternOptimization(arena);
cacheInitOptions.setCPUArenaAllocator(arena);
cacheInitOptions.setOptimizationLevel(optDefaultLevel);
- cacheInitSession = onnxEnv.createSession(finalCacheInitializerPath, cacheInitOptions);
+ if(mode != HY_MT) cacheInitSession = onnxEnv.createSession(finalCacheInitializerPath, cacheInitOptions);
OrtSession.SessionOptions embedAndLmHeadOptions = new OrtSession.SessionOptions();
embedAndLmHeadOptions.setMemoryPatternOptimization(arena);
@@ -224,7 +235,7 @@ public void onFailure(int[] reasons, long value) {
embedAndLmHeadOptions.setOptimizationLevel(optDefaultLevel);
if (mode == MADLAD_CACHE) {
embedSession = onnxEnv.createSession(finalEmbedAndLmHeadPath, embedAndLmHeadOptions);
- } else {
+ } else if(mode != HY_MT){
embedAndLmHeadSession = onnxEnv.createSession(finalEmbedAndLmHeadPath, embedAndLmHeadOptions);
}
@@ -234,18 +245,21 @@ public void onFailure(int[] reasons, long value) {
embedAndLmHeadOptions.close();
//mainHandler.post(() -> initListener.onInitializationFinished());
- initListener.onSuccess();
+ mainHandler.post(() -> initListener.onSuccess());
}
- } catch (OrtException e) {
+ if(mode == MADLAD || mode == MADLAD_CACHE) {
+ tokenizer = new Tokenizer(finalVocabPath, Tokenizer.MADLAD);
+ }else if(mode == NLLB || mode == NLLB_CACHE) {
+ tokenizer = new Tokenizer(finalVocabPath, Tokenizer.NLLB);
+ }else if(mode == HY_MT) {
+ tokenizer = new Tokenizer(finalVocabPath, Tokenizer.HY_MT);
+ }
+
+ } catch (OrtException | IOException e) {
e.printStackTrace();
mainHandler.post(() -> initListener.onFailure(new int[]{ErrorCodes.ERROR_LOADING_MODEL},0));
}
- if(mode == MADLAD_CACHE) {
- tokenizer = new Tokenizer(finalVocabPath, Tokenizer.MADLAD);
- }else{
- tokenizer = new Tokenizer(finalVocabPath, Tokenizer.NLLB);
- }
}
};
t.start();
@@ -640,6 +654,14 @@ private void performTextTranslation(final String textToTranslate, final CustomLo
}
if(mode != MOZILLA){
+ int maxLength = 200;
+ if(mode == NLLB || mode == NLLB_CACHE){
+ maxLength = 200;
+ }else if(mode == MADLAD || mode == MADLAD_CACHE){
+ maxLength = 200; //todo: research the best value for madlad
+ }else if(mode == HY_MT){
+ maxLength = 5000; //todo: research the best value for hy-mt
+ }
//we split the input text in sentences
ArrayList textSplit = new ArrayList<>();
BreakIterator iterator = BreakIterator.getSentenceInstance(inputLanguage.getLocale());
@@ -653,9 +675,9 @@ private void performTextTranslation(final String textToTranslate, final CustomLo
while (joined) {
joined = false;
for (int i = 1; i < textSplit.size(); i++) {
- int numTokens = tokenizer.tokenize(getNllbLanguageCode(inputLanguage.getCode()), getNllbLanguageCode(outputLanguage.getCode()), textSplit.get(i - 1)).getInputIDs().length;
- int numTokens2 = tokenizer.tokenize(getNllbLanguageCode(inputLanguage.getCode()), getNllbLanguageCode(outputLanguage.getCode()), textSplit.get(i)).getInputIDs().length;
- if ((numTokens + numTokens2 < 200) || (numTokens2 < 5)) {
+ int numTokens = tokenize(textSplit.get(i - 1), inputLanguage, outputLanguage, true).getInputIDs().length;
+ int numTokens2 = tokenize(textSplit.get(i), inputLanguage, outputLanguage, true).getInputIDs().length;
+ if ((numTokens + numTokens2 < maxLength) || (numTokens2 < 5)) {
textSplit.set(i - 1, textSplit.get(i - 1) + textSplit.get(i));
textSplit.remove(i);
i = i - 1;
@@ -679,29 +701,25 @@ private void performTextTranslation(final String textToTranslate, final CustomLo
//tokenization
long time = System.currentTimeMillis();
TokenizerResult input = null;
- String correctedSubText = correctText(textSplit.get(i), inputLanguage.getLocale());
- if (mode == MADLAD_CACHE) {
- input = tokenizer.tokenize(inputLanguage.getCode(), outputLanguage.getCode(), correctedSubText);
- } else { //if mode == NLLB_CACHE
- input = tokenizer.tokenize(getNllbLanguageCode(inputLanguage.getCode()), getNllbLanguageCode(outputLanguage.getCode()), correctedSubText);
- }
+ String correctedSubText = correctText(textSplit.get(i), inputLanguage.getLocale()); //input text pre process
+ input = tokenize(correctedSubText, inputLanguage, outputLanguage);
android.util.Log.i("performance", "Tokenization done in: " + (System.currentTimeMillis() - time) + "ms");
//encoder execution
time = System.currentTimeMillis();
- OnnxTensor encoderResult = executeEncoder(input.getInputIDs(), input.getAttentionMask());
- android.util.Log.i("performance", "Encoder done in: " + (System.currentTimeMillis() - time) + "ms");
- if (encoderResult == null) {
- if (responseListener != null) {
- mainHandler.post(() -> responseListener.onFailure(new int[]{ErrorCodes.ERROR_EXECUTING_MODEL}, 0));
- } else {
- mainHandler.post(() -> notifyError(new int[]{ErrorCodes.ERROR_EXECUTING_MODEL}, 0));
+ OnnxTensor encoderResult = null;
+ if(mode != HY_MT) {
+ encoderResult = executeEncoder(input.getInputIDs(), input.getAttentionMask());
+ android.util.Log.i("performance", "Encoder done in: " + (System.currentTimeMillis() - time) + "ms");
+ if (encoderResult == null) {
+ if (responseListener != null) {
+ mainHandler.post(() -> responseListener.onFailure(new int[]{ErrorCodes.ERROR_EXECUTING_MODEL}, 0));
+ } else {
+ mainHandler.post(() -> notifyError(new int[]{ErrorCodes.ERROR_EXECUTING_MODEL}, 0));
+ }
+ return;
}
- return;
}
//decoder execution
- final int eos = tokenizer.PieceToID("");
- ArrayList completeOutput = new ArrayList();
- completeOutput.add(0); //tokenizer.PieceToID("")
TranslateListener translateListener = new TranslateListener() {
@Override
public void onTranslatedText(String textToTranslate, String text, long resultID, boolean isFinal, CustomLocale languageOfText) {
@@ -734,12 +752,12 @@ public void onFailure(int[] reasons, long value) {
}
};
if (beamSize > 1) { //beam search
- executeCacheDecoderBeam(textToTranslate, input, encoderResult, completeBeamOutput, beamsOutputsProbabilities, outputLanguage, beamSize, translateListener);
+ executeCacheDecoder(textToTranslate, input, encoderResult, completeBeamOutput, beamsOutputsProbabilities, inputLanguage, outputLanguage, beamSize, translateListener);
} else if (beamSize == 1) { //greedy search (with kv cache)
- executeCacheDecoderGreedy(textToTranslate, input, encoderResult, completeOutput, outputLanguage, translateListener);
+ executeCacheDecoder(textToTranslate, input, encoderResult, completeBeamOutput, null, inputLanguage, outputLanguage, 1, translateListener);
}
- //we convert the ids of completeOutputs into a string and return it
- encoderResult.close();
+ //we convert the ids of completeBeamOutputs into a string and return it
+ if(encoderResult != null) encoderResult.close();
int[] completeOutputArray;
if ((mode == MADLAD_CACHE || mode == NLLB_CACHE) && beamSize > 1) {
int indexMax = 0;
@@ -748,7 +766,7 @@ public void onFailure(int[] reasons, long value) {
}
completeOutputArray = completeBeamOutput[indexMax].stream().mapToInt(k -> k).toArray();
} else {
- completeOutputArray = completeOutput.stream().mapToInt(k -> k).toArray(); //converte completeOutput in un array di int
+ completeOutputArray = completeBeamOutput[0].stream().mapToInt(k -> k).toArray(); //convert completeBeamOutput into an int array
}
String finalSplitResult = tokenizer.decode(completeOutputArray);
if (joinedStringOutput[0].equals("")) {
@@ -828,6 +846,7 @@ private OnnxTensor executeEncoder(int[] inputIDs, int[] attentionMask){
}
}
+ //todo: remove this method in the future
public void executeCacheDecoderGreedy(String textToTranslate, TokenizerResult input, OnnxTensor encoderResult, ArrayList completeOutput, final CustomLocale outputLanguage, @Nullable final TranslateListener responseListener){
try {
long time = System.currentTimeMillis();
@@ -899,7 +918,6 @@ public void executeCacheDecoderGreedy(String textToTranslate, TokenizerResult in
embedResult = embedSession.run(embedInput, requestedOutputs);
decoderInput.put("embed_matrix", (OnnxTensor) embedResult.get(0));
- //decoderInput.put("encoder_hidden_states", encoderResult);
}
if(j == 1){
long[] shape = {1, 16, 0, hiddenSize};
@@ -1008,17 +1026,34 @@ public void executeCacheDecoderGreedy(String textToTranslate, TokenizerResult in
}
}
- public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult input, OnnxTensor encoderResult, ArrayList[] completeBeamOutput, double[] beamsOutputsProbabilities, final CustomLocale outputLanguage, int beamSize, @Nullable final TranslateListener responseListener) {
- final int eos = tokenizer.PieceToID("");
+ public void executeCacheDecoder(String textToTranslate, TokenizerResult input, @Nullable OnnxTensor encoderResult, ArrayList[] completeBeamOutput, @Nullable double[] beamsOutputsProbabilities, final CustomLocale inputLanguage, final CustomLocale outputLanguage, int beamSize, @Nullable final TranslateListener responseListener) {
+ int eos;
+ if(mode == HY_MT){
+ eos = tokenizer.PieceToID("<|hy_place▁holder▁no▁2|>");
+ }else{
+ eos = tokenizer.PieceToID("");
+ }
int nLayers;
int hiddenSize;
+ int hiddenSizeAttention;
+ int nHeads;
if(mode == MADLAD_CACHE){
nLayers = 32;
- hiddenSize = 128;
- }else{ //if mode == NLLB_CACHE
+ hiddenSize = 1024;
+ hiddenSizeAttention = 128;
+ nHeads = 16;
+ }else if(mode == NLLB_CACHE){
nLayers = 12;
- hiddenSize = 64;
+ hiddenSize = 1024;
+ hiddenSizeAttention = 64;
+ nHeads = 16;
+ }else{ //if mode == HY_MT
+ nLayers = 32;
+ hiddenSize = 2048;
+ hiddenSizeAttention = 128;
+ nHeads = 4; //nHeads in this case refers only to the number of heads used in the kv cache; the real number of heads is 16, but this model uses grouped-query attention, with a group size of 4
}
+ int initialPromptLength = tokenize(" ", inputLanguage, outputLanguage, false).getInputIDs().length;
try {
long initialTime;
@@ -1027,78 +1062,35 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu
OnnxTensor inputIDsTensor;
if(mode == MADLAD_CACHE){
inputIDsTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, new int[]{0}); //for the first iteration we use input_ids = 0, with batch_size = 1
- }else{ //if mode == NLLB_CACHE
+ }else if(mode == NLLB_CACHE){
inputIDsTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, new int[]{2}); //for the first iteration we use input_ids = 2, with batch_size = 1
+ }else{ // if mode == HY_MT
+ inputIDsTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, input.getInputIDs()); //for the first iteration we use the input_ids generated by the tokenizer (with the prompt), with batch_size = 1
}
- OnnxTensor encoderAttentionMaskTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, input.getAttentionMask());
- int encoderInputIdsLength = input.getInputIDs().length;
+ int[] attentionMask = input.getAttentionMask();
+ OnnxTensor attentionMaskTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, attentionMask);
CacheContainerNative cacheContainer = null;
OnnxTensor decoderOutput = null;
Map decoderInput = new HashMap();
- float [][][] outputValues = null;
+ float [][][] logits = null;
time = System.currentTimeMillis();
//preparing cache initializer input
- Map initInput = new HashMap();
- initInput.put("encoder_hidden_states", encoderResult);
- //execution of the cache initializer
- OrtSession.Result initResult = cacheInitSession.run(initInput);
- android.util.Log.i("performance", "Cache initialization done in: " + (System.currentTimeMillis()-time) + "ms");
-
- time = System.currentTimeMillis();
- //we convert the fixed decoder inputs to have batch_size==beamSize
- OnnxTensor encoderResultBatched = null;
- if(mode == MADLAD_CACHE) {
- float[][] encoderValue = ((float[][][]) encoderResult.getValue())[0];
- float[] encoderValueFlatBatched = TensorUtils.flattenFloatArrayBatched(encoderValue, beamSize);
- encoderResultBatched = TensorUtils.createFloatTensor(onnxEnv, encoderValueFlatBatched, new long[]{beamSize, encoderValue.length, encoderValue[0].length});
- encoderValue = null; //free the memory
- encoderValueFlatBatched = null; //free the memory
- //System.gc();
- android.util.Log.i("performance", "Encoder batch initialization done in: " + (System.currentTimeMillis()-time) + "ms");
- }
- time = System.currentTimeMillis();
- OnnxTensor encoderAttentionMaskTensorBatched;
- int[] encoderMaskFlatBatched = TensorUtils.flattenIntArrayBatched(input.getAttentionMask(), beamSize);
- encoderAttentionMaskTensorBatched = TensorUtils.createIntTensor(onnxEnv, encoderMaskFlatBatched, new long[]{beamSize, input.getAttentionMask().length});
- encoderMaskFlatBatched = null; //free the memory
- //System.gc();
- android.util.Log.i("performance", "Mask batch initialization done in: " + (System.currentTimeMillis()-time) + "ms");
- time = System.currentTimeMillis();
- OrtSession.Result initResultBatched;
- String[] names = new String[2*nLayers];
- OnnxValue[] values = new OnnxValue[2*nLayers];
- boolean[] ownedByResult = new boolean[2*nLayers];
- Arrays.fill(ownedByResult, true);
- String[] suffixes = {"key", "value"};
- long timeExtract = 0;
- long timeBatch = 0;
- long timeCreate = 0;
- int count = 0;
- for (int i = 0; i < nLayers; i++) {
- for (String suffix: suffixes) {
- //System.gc();
- names[count] = "present." + i + ".encoder."+suffix;
- long timeInner = System.currentTimeMillis();
- float[][][] keyValue = ((float[][][][]) TensorUtils.extractValue(initResult, "present." + i + ".encoder."+suffix))[0];
- timeExtract += System.currentTimeMillis() - timeInner;
- timeInner = System.currentTimeMillis();
- float[][][][] keyValueFlatBatched = TensorUtils.batchTensor(keyValue, beamSize);
- timeBatch += System.currentTimeMillis() - timeInner;
- timeInner = System.currentTimeMillis();
- values[count] = TensorUtils.createFloatTensorOptimized(onnxEnv, keyValueFlatBatched, new long[]{beamSize, keyValue.length, keyValue[0].length, keyValue[0][0].length});;
- timeCreate += System.currentTimeMillis() - timeInner;
- count++;
- }
+ OrtSession.Result initResult = null;
+ OnnxTensor attentionMaskTensorBatched = null;
+ OrtSession.Result initResultBatched = null;
+ if(mode != HY_MT) {
+ Map initInput = new HashMap();
+ initInput.put("encoder_hidden_states", encoderResult);
+ //execution of the cache initializer
+ initResult = cacheInitSession.run(initInput);
+ android.util.Log.i("performance", "Cache initialization done in: " + (System.currentTimeMillis() - time) + "ms");
+ if (encoderResult != null)
+ encoderResult.close(); //we close it because from now on we only need initResult
}
- //the Result constructor is private but this way we can use it anyway
- Constructor constructor = OrtSession.Result.class.getDeclaredConstructor(names.getClass(), values.getClass(), ownedByResult.getClass());
- constructor.setAccessible(true);
- initResultBatched = constructor.newInstance(names, values, ownedByResult);
- android.util.Log.i("performance", "InitResult extract done in: " + timeExtract + "ms");
- android.util.Log.i("performance", "InitResult batch done in: " + timeBatch + "ms");
- android.util.Log.i("performance", "InitResult create done in: " + timeCreate + "ms");
- android.util.Log.i("performance", "InitResult batch initialization done in: " + (System.currentTimeMillis()-time) + "ms");
+ //we convert the fixed decoder inputs to have batch_size==beamSize
+ attentionMaskTensorBatched = beamSize > 1 ? batchEncoderAttentionMask(attentionMask, beamSize, true) : null;
+ initResultBatched = initResult != null && beamSize > 1 ? batchEncoderKvCache(initResult, nLayers, beamSize, true) : null; //this is not executed for HY_MT
//we begin the iterative execution of the decoder
String[] partialResults = new String[beamSize]; //used for log
@@ -1107,19 +1099,20 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu
int[] max = new int[beamSize];
int[][] beamMax = new int[beamSize][beamSize];
int j = 1;
- OnnxTensor emptyPreLogits = TensorUtils.createFloatTensorWithSingleValue(onnxEnv, 0, new long[]{EMPTY_BATCH_SIZE, 1, 1024});
- OnnxTensor emptyPreLogitsBatch = TensorUtils.createFloatTensorWithSingleValue(onnxEnv, 0, new long[]{beamSize, 1, 1024});
+ OnnxTensor emptyPreLogits = TensorUtils.createFloatTensorWithSingleValue(onnxEnv, 0, new long[]{EMPTY_BATCH_SIZE, 1, hiddenSize});
+ OnnxTensor emptyPreLogitsBatch = TensorUtils.createFloatTensorWithSingleValue(onnxEnv, 0, new long[]{beamSize, 1, hiddenSize});
OnnxTensor emptyInputIds = TensorUtils.createInt64TensorWithSingleValue(onnxEnv, 0, new long[]{EMPTY_BATCH_SIZE, 2});
OnnxTensor emptyInputIdsBatch = TensorUtils.createInt64TensorWithSingleValue(onnxEnv, 0, new long[]{beamSize, 2});
- while(input_ids[0] != eos){ //input_ids[0] should always contain the ultimate value generated from the text with highest probability (to be verified)
+ while(max[0] != eos){ //max[0] should always contain the latest value generated for the sequence with the highest probability (to be verified)
initialTime = System.currentTimeMillis();
time = System.currentTimeMillis();
+
//we prepare the decoder input
decoderInput = new HashMap();
OrtSession.Result embedResult = null;
if(mode == NLLB_CACHE){
- //we do the embedding separately and then we pass the result to the encoder
+ //we do the embedding separately and then we pass the result to the decoder
Map embedInput = new HashMap();
embedInput.put("input_ids", inputIDsTensor);
embedInput.put("pre_logits", j == 1 ? emptyPreLogits : emptyPreLogitsBatch);
@@ -1131,6 +1124,7 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu
decoderInput.put("embed_matrix", (OnnxTensor) embedResult.get(0));
}
if(mode == MADLAD_CACHE) {
+ //we do the embedding separately and then we pass the result to the decoder
Map embedInput = new HashMap();
embedInput.put("input_ids", inputIDsTensor);
ArraySet requestedOutputs = new ArraySet<>();
@@ -1138,44 +1132,49 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu
embedResult = embedSession.run(embedInput, requestedOutputs);
decoderInput.put("embed_matrix", (OnnxTensor) embedResult.get(0));
- //decoderInput.put("encoder_hidden_states", encoderResult);
}
decoderInput.put("input_ids", inputIDsTensor);
- if(j == 1){ //se è la prima iterazione
+ if(j == 1){ //if it is the first iteration
//we run the decoder with a batch_size = 1
- decoderInput.put("encoder_attention_mask", encoderAttentionMaskTensor);
- if(mode == MADLAD_CACHE) {
- //decoderInput.put("encoder_hidden_states", encoderResult);
+ if(mode != HY_MT) {
+ decoderInput.put("encoder_attention_mask", attentionMaskTensor);
+ }else{
+ decoderInput.put("attention_mask", attentionMaskTensor);
}
- long[] shape = {1, 16, 0, hiddenSize};
+ long[] shape = {1, nHeads, 0, hiddenSizeAttention};
OnnxTensor decoderPastTensor = TensorUtils.createFloatTensorWithSingleValue(onnxEnv,0, shape);
for (int i = 0; i < nLayers; i++) {
- decoderInput.put("past_key_values." + i + ".decoder.key", decoderPastTensor);
- decoderInput.put("past_key_values." + i + ".decoder.value", decoderPastTensor);
- decoderInput.put("past_key_values." + i + ".encoder.key", (OnnxTensor) initResult.get("present." + i + ".encoder.key").get());
- decoderInput.put("past_key_values." + i + ".encoder.value", (OnnxTensor) initResult.get("present." + i + ".encoder.value").get());
+ decoderInput.put("past_key_values." + i + (mode != HY_MT ? ".decoder" : "") + ".key", decoderPastTensor);
+ decoderInput.put("past_key_values." + i + (mode != HY_MT ? ".decoder" : "") + ".value", decoderPastTensor);
+ if(mode != HY_MT){
+ decoderInput.put("past_key_values." + i + ".encoder.key", (OnnxTensor) initResult.get("present." + i + ".encoder.key").get());
+ decoderInput.put("past_key_values." + i + ".encoder.value", (OnnxTensor) initResult.get("present." + i + ".encoder.value").get());
+ }
}
}else {
- if(j == 2) {
- encoderAttentionMaskTensor.close(); //we close it because from now on we only need encoderAttentionMaskTensorBatched
- encoderResult.close(); //we close it because from now on we only need encoderResultBatched
- initResult.close(); //we close it because from now on we only need initResultBatched
+ if(j == 2 && beamSize > 1) {
+ attentionMaskTensor.close(); //we close it because from now on we only need attentionMaskTensorBatched
+ if(initResult != null) initResult.close(); //we close it because from now on we only need initResultBatched
}
//we run the decoder with batch_size = beamSize
- decoderInput.put("encoder_attention_mask", encoderAttentionMaskTensorBatched);
- if(mode == MADLAD_CACHE) {
- //decoderInput.put("encoder_hidden_states", encoderResultBatched);
+ if(mode != HY_MT) {
+ decoderInput.put("encoder_attention_mask", beamSize > 1 ? attentionMaskTensorBatched : attentionMaskTensor);
+ }else{
+ decoderInput.put("attention_mask", beamSize > 1 ? attentionMaskTensorBatched : attentionMaskTensor);
}
for (int i = 0; i < nLayers; i++) {
- decoderInput.put("past_key_values." + i + ".decoder.key", (OnnxTensor) result.get("present." + i + ".decoder.key").get());
- decoderInput.put("past_key_values." + i + ".decoder.value", (OnnxTensor) result.get("present." + i + ".decoder.value").get());
- decoderInput.put("past_key_values." + i + ".encoder.key", (OnnxTensor) initResultBatched.get("present." + i + ".encoder.key").get());
- decoderInput.put("past_key_values." + i + ".encoder.value", (OnnxTensor) initResultBatched.get("present." + i + ".encoder.value").get());
+ decoderInput.put("past_key_values." + i + (mode != HY_MT ? ".decoder" : "") + ".key", (OnnxTensor) result.get("present." + i + (mode != HY_MT ? ".decoder" : "") + ".key").get());
+ decoderInput.put("past_key_values." + i + (mode != HY_MT ? ".decoder" : "") + ".value", (OnnxTensor) result.get("present." + i + (mode != HY_MT ? ".decoder" : "") + ".value").get());
+ if(mode != HY_MT) {
+ decoderInput.put("past_key_values." + i + ".encoder.key", (OnnxTensor) (beamSize > 1 ? initResultBatched : initResult).get("present." + i + ".encoder.key").get());
+ decoderInput.put("past_key_values." + i + ".encoder.value", (OnnxTensor) (beamSize > 1 ? initResultBatched : initResult).get("present." + i + ".encoder.value").get());
+ }
}
}
oldResult = result;
android.util.Log.i("performance", "pre-execution of"+j+"th word done in: " + (System.currentTimeMillis()-time) + "ms");
time = System.currentTimeMillis();
+
//decoder execution (with cache)
result = decoderSession.run(decoderInput);
android.util.Log.i("performance", "execution of"+j+"th word done in: " + (System.currentTimeMillis()-time) + "ms");
@@ -1188,7 +1187,7 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu
embedResult.close();
}
android.util.Log.i("performance", "release RAM of"+j+"th word done in: " + (System.currentTimeMillis()-time) + "ms");
- //we take the logits and the max value
+ //we take the logits
OrtSession.Result lmHeadResult = null;
if(mode == NLLB_CACHE) {
//we execute the lmHead separately to get the logits
@@ -1204,24 +1203,17 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu
decoderOutput = (OnnxTensor) result.get("logits").get();
}
//we take the logits and the larger "beamSize" values
+ logits = (float[][][]) decoderOutput.getValue();
if(j == 1) { //if we are at the first iteration
- //decoderOutput = (OnnxTensor) result.get("logits").get();
- outputValues = (float[][][]) decoderOutput.getValue();
- //the "beamSize" words with highest probability are inserted into max and added to completeBeamOutput
- ArrayList indexesToAvoid = new ArrayList<>();
- for (int i = 0; i < beamSize; i++) {
- max[i] = Utils.getIndexOfLargest(outputValues[0][0], indexesToAvoid);
- indexesToAvoid.add(max[i]);
- completeBeamOutput[i].add(max[i]);
- }
- //we insert the initial probabilities of the "beamSize" output strings into beamsOutputsProbabilities
- for (int i = 0; i < beamSize; i++) {
- float maxLogit = outputValues[0][0][max[i]];
- //old version of probability calculation (softmax)
- //beamsOutputsProbabilities[i] = Math.log(Utils.softmax(maxLogit, outputValues[0][0]));
- //new version of probability calculation (logSumExp)
- beamsOutputsProbabilities[i] = maxLogit - Utils.logSumExpFast(outputValues[0][0]);
+ if(beamSize > 1) {
+ //based on the logits, we initialize max, completeBeamOutput and beamsOutputsProbabilities
+ initBeamSearchData(logits, beamSize, max, completeBeamOutput, beamsOutputsProbabilities);
+ }else{
+ int seqLen = logits[0].length;
+ max[0] = Utils.getIndexOfLargest(logits[0][seqLen-1]);
+ completeBeamOutput[0].add(max[0]);
}
+
//we prepare the inputs of the next iteration
if(mode == NLLB_CACHE){
for(int i=0; i indexesToAvoid = new ArrayList<>();
- for (int i = 0; i < beamSize; i++) {
- beamMax[k][i] = Utils.getIndexOfLargest(outputValues[k][0], indexesToAvoid);
- indexesToAvoid.add(beamMax[k][i]);
- }
- }
- //Now beamMax will contain for each decoder output ("beamSize" outputs) the "beamSize" words with highest probability,
- // so for each output we calculate its overall probability for each of its "beamSize" words with highest probability
- long timeSoftmax = System.currentTimeMillis();
- double[] beamsOutputsProbabilitiesTemp = new double[beamSize*beamSize];
- for(int k=0; k < beamSize; k++) {
- //old version of probability calculation (softmax)
- /*for (int i = 0; i < beamSize; i++) {
- beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilities[k] + Math.log(Utils.softmax(outputValues[k][0][beamMax[k][i]], outputValues[k][0]));
- if(beamMax[k][i] == eos){
- beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilitiesTemp[(k*beamSize)+i]/EOS_PENALTY;
- }
- }*/
- //new version of probability calculation (logSumExp)
- double logSumExp = Utils.logSumExpFast(outputValues[k][0]);
- for (int i = 0; i < beamSize; i++) {
- float maxLogit = outputValues[k][0][beamMax[k][i]];
- beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilities[k] + maxLogit - logSumExp;
- if(beamMax[k][i] == eos){
- beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilitiesTemp[(k*beamSize)+i]/EOS_PENALTY;
- }
- }
- }
- android.util.Log.i("performance", "softmax done in: " + (System.currentTimeMillis()-timeSoftmax) + "ms");
- // Now we save in maxProbabilities the indices of the "beamSize" words generated by the decoder that have the
- // highest overall probability with their respective output sentences and then we will use them as the next inputs
- ArrayList indexesToAvoid = new ArrayList<>();
- int[] maxProbabilities = new int[beamSize];
- for(int i=0; i[] oldCompleteBeamOutput = completeBeamOutput.clone();
- for (int i = 0; i < beamSize; i++) {
- beamsOutputsProbabilities[i] = beamsOutputsProbabilitiesTemp[maxProbabilities[i]];
- completeBeamOutput[i] = (ArrayList) oldCompleteBeamOutput[maxProbabilities[i]/beamSize].clone();
- completeBeamOutput[i].add(beamMax[maxProbabilities[i]/beamSize][maxProbabilities[i]%beamSize]);
+ if(beamSize > 1) {
+ //based on the logits we update beam search data
+ int[] maxProbabilities = new int[beamSize];
+ int sequenceLength = mode == HY_MT ? attentionMask.length : j;
+ cacheContainer = updateBeamSearchData(logits, beamSize, eos, result, sequenceLength, nLayers, nHeads, hiddenSizeAttention, cacheContainer, maxProbabilities, beamMax, max, completeBeamOutput, beamsOutputsProbabilities);
+ }else{
+ int seqLen = logits[0].length;
+ max[0] = Utils.getIndexOfLargest(logits[0][seqLen-1]);
+ completeBeamOutput[0].add(max[0]);
}
+
//we prepare the inputs of the next iteration
- for (int i = 0; i < beamSize; i++) {
- input_ids[i] = beamMax[maxProbabilities[i]/beamSize][maxProbabilities[i]%beamSize];
- }
+ input_ids = max;
inputIDsTensor = TensorUtils.createIntTensor(onnxEnv, input_ids, new long[]{beamSize,1});
- long timeCache = System.currentTimeMillis();
- CacheContainerNative oldCache = cacheContainer;
- cacheContainer = new CacheContainerNative(onnxEnv, result, nLayers, beamSize, 16, j, hiddenSize);
- if(oldCache != null){
- oldCache.close();
- }
- android.util.Log.i("performance", "cache creation done in: " + (System.currentTimeMillis()-timeCache) + "ms");
- int[] indexes = new int[beamSize];
- for(int i=0; i 1) {
+ attentionMaskTensorBatched = batchEncoderAttentionMask(attentionMask, beamSize, true);
+ }else{
+ attentionMaskTensor = TensorUtils.convertIntArrayToTensor(onnxEnv, attentionMask);
}
- timeCache = System.currentTimeMillis();
- cacheContainer.reorder(indexes);
- android.util.Log.i("performance", "cache reorder done in: " + (System.currentTimeMillis()-timeCache) + "ms");
}
android.util.Log.i("performance", "post-execution of" + j + "th word done in: " + (System.currentTimeMillis() - time) + "ms");
android.util.Log.i("performance", "Generation of" + j + "th word done in: " + (System.currentTimeMillis() - initialTime) + "ms");
// we return the partial result with the highest probability
int indexMax = 0;
- for(int i=0; i 1) {
+ for (int i = 0; i < beamSize; i++) {
+ indexMax = Utils.getIndexOfLargest(beamsOutputsProbabilities);
+ }
}
- int [] outputIDs = completeBeamOutput[indexMax].stream().mapToInt(k -> k).toArray();
+ int[] outputIDs = completeBeamOutput[indexMax].stream().mapToInt(k -> k).toArray();
String partialResult = tokenizer.decode(outputIDs);
if(responseListener != null) {
responseListener.onTranslatedText(textToTranslate, partialResult, currentResultID, false, outputLanguage);
@@ -1344,22 +1273,32 @@ public void executeCacheDecoderBeam(String textToTranslate, TokenizerResult inpu
partialResults[i] = tokenizer.decode(completeBeamOutput[i].stream().mapToInt(k -> k).toArray());
android.util.Log.i("result "+i, partialResults[i]);
}
- }
- if(result != null) {
- result.close();
- }
- initResult.close();
- if(cacheContainer != null) {
- cacheContainer.close();
- }
- if (encoderAttentionMaskTensorBatched != null) {
- encoderAttentionMaskTensorBatched.close();
- }
- if(encoderResultBatched != null) {
- encoderResultBatched.close();
+ //early stop if the decoder is generating in loop
+ if(input.getInputIDs().length - initialPromptLength > 30){ //if the input is long
+ if(j > 3*input.getInputIDs().length) {
+ break;
+ }
+ }else if(input.getInputIDs().length - initialPromptLength > 20){ //if the input is medium length
+ if(j > 4*input.getInputIDs().length){
+ break;
+ }
+ }else if(input.getInputIDs().length - initialPromptLength > 10){ //if the input is short
+ if(j > 5*input.getInputIDs().length){
+ break;
+ }
+ }else if(input.getInputIDs().length - initialPromptLength > 5){ //if the input is very short
+ if(j > 8*input.getInputIDs().length){
+ break;
+ }
+ }
}
- initResultBatched.close();
+
+ if(result != null) result.close();
+ if(initResult != null) initResult.close();
+ if(cacheContainer != null) cacheContainer.close();
+ if(attentionMaskTensorBatched != null) attentionMaskTensorBatched.close();
+ if(initResultBatched != null) initResultBatched.close();
} catch (OrtException | InvocationTargetException | NoSuchMethodException |
IllegalAccessException | InstantiationException e) {
@@ -1388,13 +1327,186 @@ private String correctText(String text, Locale locale){
if(!language.equals("th")) {
correctedText = correctedText.trim(); //we remove eventual white space from both ends of the text
if(correctedText.length() >= 2) {
- if (!Character.isLetterOrDigit(correctedText.charAt(correctedText.length() - 1))) {
- return correctedText;
+ if (Character.isLetterOrDigit(correctedText.charAt(correctedText.length() - 1))) {
+ correctedText = correctedText + getSentenceTerminator(locale);
+ }
+ }
+ }
+ //for Madlad only, we remove all the control characters (like \n), because those will make the model hallucinate
+ if(mode == MADLAD || mode == MADLAD_CACHE){
+ correctedText = text.replaceAll("\\R", " ") // remove all newlines
+ .replaceAll("\\p{Cntrl}", "") // remove other control chars
+ .trim();
+ }
+ // collapse whitespace
+ correctedText = correctedText.replaceAll("\\s+", " ");
+ return correctedText;
+ }
+
+ private OnnxTensor batchEncoderAttentionMask(int[] attentionMask, int batchSize, boolean log) throws OrtException {
+ long time = System.currentTimeMillis();
+ int[] encoderMaskFlatBatched = TensorUtils.flattenIntArrayBatched(attentionMask, batchSize);
+ OnnxTensor encoderAttentionMaskTensorBatched = TensorUtils.createIntTensor(onnxEnv, encoderMaskFlatBatched, new long[]{batchSize, attentionMask.length});
+ encoderMaskFlatBatched = null; //free the memory
+ //System.gc();
+ if(log) {
+ android.util.Log.i("performance", "Mask batch initialization done in: " + (System.currentTimeMillis() - time) + "ms");
+ }
+ return encoderAttentionMaskTensorBatched;
+ }
+
+ @NonNull
+ private OrtSession.Result batchEncoderKvCache(OrtSession.Result result, int nLayers, int batchSize, boolean log) throws InvocationTargetException, IllegalAccessException, InstantiationException, NoSuchMethodException, OrtException {
+ long time = System.currentTimeMillis();
+ String[] names = new String[2*nLayers];
+ OnnxValue[] values = new OnnxValue[2*nLayers];
+ boolean[] ownedByResult = new boolean[2*nLayers];
+ Arrays.fill(ownedByResult, true);
+ String[] suffixes = {"key", "value"};
+ long timeExtract = 0;
+ long timeBatch = 0;
+ long timeCreate = 0;
+ int count = 0;
+ for (int i = 0; i < nLayers; i++) {
+ for (String suffix: suffixes) {
+ //System.gc();
+ names[count] = "present." + i + ".encoder."+suffix;
+ long timeInner = System.currentTimeMillis();
+ float[][][] keyValue = ((float[][][][]) TensorUtils.extractValue(result, "present." + i + ".encoder."+suffix))[0];
+ timeExtract += System.currentTimeMillis() - timeInner;
+ timeInner = System.currentTimeMillis();
+ float[][][][] keyValueFlatBatched = TensorUtils.batchTensor(keyValue, batchSize);
+ timeBatch += System.currentTimeMillis() - timeInner;
+ timeInner = System.currentTimeMillis();
+ values[count] = TensorUtils.createFloatTensorOptimized(onnxEnv, keyValueFlatBatched, new long[]{batchSize, keyValue.length, keyValue[0].length, keyValue[0][0].length});;
+ timeCreate += System.currentTimeMillis() - timeInner;
+ count++;
+ }
+ }
+ //the Result constructor is private but this way we can use it anyway
+ Constructor constructor = OrtSession.Result.class.getDeclaredConstructor(names.getClass(), values.getClass(), ownedByResult.getClass());
+ constructor.setAccessible(true);
+ OrtSession.Result initResultBatched = constructor.newInstance(names, values, ownedByResult);
+ if(log) {
+ android.util.Log.i("performance", "InitResult extract done in: " + timeExtract + "ms");
+ android.util.Log.i("performance", "InitResult batch done in: " + timeBatch + "ms");
+ android.util.Log.i("performance", "InitResult create done in: " + timeCreate + "ms");
+ android.util.Log.i("performance", "InitResult batch initialization done in: " + (System.currentTimeMillis() - time) + "ms");
+ }
+
+ return initResultBatched;
+ }
+
+ private OrtSession.Result batchDecoderKvCache(OrtSession.Result result, OnnxTensor decoderOutput, int nLayers, int batchSize, boolean log) throws InvocationTargetException, IllegalAccessException, InstantiationException, NoSuchMethodException, OrtException {
+ long time = System.currentTimeMillis();
+ String[] names = new String[2*nLayers+1];
+ OnnxValue[] values = new OnnxValue[2*nLayers+1];
+ boolean[] ownedByResult = new boolean[2*nLayers+1];
+ Arrays.fill(ownedByResult, true);
+ names[0] = "logits";
+ values[0] = decoderOutput; //result.get("logits").get();
+ String[] suffixes = new String[]{"key", "value"};
+ int count = 1;
+ for (int i = 0; i < nLayers; i++) {
+ for (String suffix: suffixes) {
+ names[count] = "present." + i + (mode!=HY_MT ? ".decoder." : ".") + suffix;
+ float[][][] keyValue = ((float[][][][]) TensorUtils.extractValue(result, "present." + i + (mode!=HY_MT ? ".decoder." : ".") + suffix))[0];
+ float[] keyValueFlatBatched = TensorUtils.flattenFloatArrayBatched(keyValue, batchSize);
+ values[count] = TensorUtils.createFloatTensor(onnxEnv, keyValueFlatBatched, new long[]{batchSize, keyValue.length, keyValue[0].length, keyValue[0][0].length}); //todo: evaluate the use of createFloatTensorOptimized
+ count++;
+ }
+ }
+ if(log) {
+ android.util.Log.i("performance", "Decoder kvCache batch initialization done in: " + (System.currentTimeMillis() - time) + "ms");
+ }
+ result.close();
+ //the Result constructor is private but this way we can use it anyway
+ Constructor constructor = OrtSession.Result.class.getDeclaredConstructor(names.getClass(), values.getClass(), ownedByResult.getClass());
+ constructor.setAccessible(true);
+ return constructor.newInstance(names, values, ownedByResult);
+ }
+
+ private void initBeamSearchData(float [][][] logits, int beamSize, int[] max, ArrayList[] completeBeamOutput, double[] beamsOutputsProbabilities){
+ //the "beamSize" words with highest probability are inserted into max and added to completeBeamOutput
+ int seqLen = logits[0].length;
+ ArrayList indexesToAvoid = new ArrayList<>();
+ for (int i = 0; i < beamSize; i++) {
+ max[i] = Utils.getIndexOfLargest(logits[0][seqLen-1], indexesToAvoid);
+ indexesToAvoid.add(max[i]);
+ completeBeamOutput[i].add(max[i]);
+ }
+ //we insert the initial probabilities of the "beamSize" output strings into beamsOutputsProbabilities
+ for (int i = 0; i < beamSize; i++) {
+ float maxLogit = logits[0][seqLen-1][max[i]];
+ beamsOutputsProbabilities[i] = maxLogit - Utils.logSumExpFast(logits[0][seqLen-1]);
+ }
+ }
+
+ private CacheContainerNative updateBeamSearchData(
+ float [][][] logits, int beamSize, int eos, OrtSession.Result decoderResult, int sequenceLength, int nLayers, int nHeads, int hiddenSize,
+ CacheContainerNative cacheContainer, int[] maxProbabilities, int[][] beamMax, int[] max, ArrayList[] completeBeamOutput, double[] beamsOutputsProbabilities
+ ){
+ //for each of the "beamSize" decoder outputs, the "beamSize" words with the highest probability are inserted into beamMax
+ for(int k=0; k < beamSize; k++) {
+ ArrayList indexesToAvoid = new ArrayList<>();
+ for (int i = 0; i < beamSize; i++) {
+ int seqLen = logits[k].length;
+ beamMax[k][i] = Utils.getIndexOfLargest(logits[k][seqLen-1], indexesToAvoid);
+ indexesToAvoid.add(beamMax[k][i]);
+ }
+ }
+ //Now beamMax will contain for each decoder output ("beamSize" outputs) the "beamSize" words with highest probability,
+ // so for each output we calculate its overall probability for each of its "beamSize" words with highest probability
+ long timeSoftmax = System.currentTimeMillis();
+ double[] beamsOutputsProbabilitiesTemp = new double[beamSize*beamSize];
+ for(int k=0; k < beamSize; k++) {
+ //new version of probability calculation (logSumExp)
+ int seqLen = logits[k].length;
+ double logSumExp = Utils.logSumExpFast(logits[k][seqLen-1]);
+ for (int i = 0; i < beamSize; i++) {
+ float maxLogit = logits[k][seqLen-1][beamMax[k][i]];
+ beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilities[k] + maxLogit - logSumExp;
+ if(beamMax[k][i] == eos){
+ beamsOutputsProbabilitiesTemp[(k*beamSize)+i] = beamsOutputsProbabilitiesTemp[(k*beamSize)+i]/EOS_PENALTY;
}
- return correctedText + getSentenceTerminator(locale);
}
}
- return text;
+ android.util.Log.i("performance", "softmax done in: " + (System.currentTimeMillis()-timeSoftmax) + "ms");
+ // Now we save in maxProbabilities the indices of the "beamSize" words generated by the decoder that have the
+ // highest overall probability with their respective output sentences and then we will use them as the next inputs
+ ArrayList indexesToAvoid = new ArrayList<>();
+ for(int i=0; i[] oldCompleteBeamOutput = completeBeamOutput.clone();
+ for (int i = 0; i < beamSize; i++) {
+ beamsOutputsProbabilities[i] = beamsOutputsProbabilitiesTemp[maxProbabilities[i]];
+ completeBeamOutput[i] = (ArrayList) oldCompleteBeamOutput[maxProbabilities[i]/beamSize].clone();
+ completeBeamOutput[i].add(beamMax[maxProbabilities[i]/beamSize][maxProbabilities[i]%beamSize]);
+ }
+ // reorder of the kvCache to match the new selected inputs for the next iteration
+ long timeCache = System.currentTimeMillis();
+ CacheContainerNative oldCache = cacheContainer;
+ cacheContainer = new CacheContainerNative(onnxEnv, decoderResult, nLayers, beamSize, nHeads, sequenceLength, hiddenSize, mode);
+ if(oldCache != null){
+ oldCache.close();
+ }
+ android.util.Log.i("performance", "cache creation done in: " + (System.currentTimeMillis()-timeCache) + "ms");
+ int[] indexes = new int[beamSize];
+ for(int i=0; i getSupportedLanguages(Context context, int
DocumentBuilder documentBuilder = documentBuilderFactory.newDocumentBuilder();
Document document = null;
if (mode == MADLAD || mode == MADLAD_CACHE) {
- document = documentBuilder.parse(context.getResources().openRawResource(R.raw.madlad_supported_launguages));
- } else if (mode == NLLB || mode == NLLB_CACHE) { //if mode == NLLB
+ if (!qualityLow) {
+ document = documentBuilder.parse(context.getResources().openRawResource(R.raw.madlad_supported_launguages));
+ }else{
+ document = documentBuilder.parse(context.getResources().openRawResource(R.raw.madlad_supported_launguages_all));
+ }
+ } else if (mode == NLLB || mode == NLLB_CACHE) {
if (!qualityLow) {
document = documentBuilder.parse(context.getResources().openRawResource(R.raw.nllb_supported_languages));
} else {
document = documentBuilder.parse(context.getResources().openRawResource(R.raw.nllb_supported_languages_all));
}
+ }else if (mode == HY_MT) {
+ document = documentBuilder.parse(context.getResources().openRawResource(R.raw.hy_mt_supported_languages));
}
NodeList list = document.getElementsByTagName("code");
for (int i = 0; i < list.getLength(); i++) {
@@ -1481,6 +1640,15 @@ public static ArrayList getSupportedLanguages(Context context, int
return languages;
}
    /**
     * Holder for the display names of a language supported by the HY-MT model:
     * the English name and the Chinese name of the language.
     */
    public static class HyLanguageInfo{
        //English display name of the language
        public String enName;
        //Chinese display name of the language
        public String zhName;

        public HyLanguageInfo(String enName, String zhName) {
            this.enName = enName;
            this.zhName = zhName;
        }
    }
private interface TranslatorListener {
void onFailure(int[] reasons, long value);
diff --git a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/voice/Recognizer.java b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/voice/Recognizer.java
index 666cd49e..c1bdfd70 100644
--- a/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/voice/Recognizer.java
+++ b/app/src/main/java/nie/translator/rtranslator/voice_translation/neural_networks/voice/Recognizer.java
@@ -291,7 +291,7 @@ private void recognize() {
DataContainer data = dataToRecognize.pollFirst();
if (initSession != null && encoderSession != null && cacheInitSession != null && decoderSession != null && detokenizerSession != null) {
if (data != null) {
- //we convert data in un audioTensor and start the transcription
+ //we convert data in an audioTensor and start the transcription
try {
FloatBuffer floatAudioDataBuffer = FloatBuffer.wrap(data.data);
OnnxTensor audioTensor = OnnxTensor.createTensor(onnxEnv, floatAudioDataBuffer, TensorUtils.tensorShape(1L, (long) data.data.length));
diff --git a/app/src/main/res/layout/fragment_model_manager.xml b/app/src/main/res/layout/fragment_model_manager.xml
index 8c63700a..1eca07e4 100644
--- a/app/src/main/res/layout/fragment_model_manager.xml
+++ b/app/src/main/res/layout/fragment_model_manager.xml
@@ -40,10 +40,10 @@
android:layout_height="wrap_content"
android:text="Madlad 400" />
+ android:text="HY-MT 1.5" />