diff --git a/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java b/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java index 1cf9aa0c5..4d92c86bd 100644 --- a/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java +++ b/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java @@ -45,12 +45,27 @@ */ public class WordpieceTokenizer implements Tokenizer { - private static final Pattern PUNCTUATION_PATTERN = Pattern.compile("\\p{Punct}+"); - private static final String CLASSIFICATION_TOKEN = "[CLS]"; - private static final String SEPARATOR_TOKEN = "[SEP]"; - private static final String UNKNOWN_TOKEN = "[UNK]"; + /** BERT classification token: {@code [CLS]}. */ + public static final String BERT_CLS_TOKEN = "[CLS]"; + /** BERT separator token: {@code [SEP]}. */ + public static final String BERT_SEP_TOKEN = "[SEP]"; + /** BERT unknown token: {@code [UNK]}. */ + public static final String BERT_UNK_TOKEN = "[UNK]"; + + /** RoBERTa classification token: {@code <s>}. */ + public static final String ROBERTA_CLS_TOKEN = "<s>"; + /** RoBERTa separator token. */ + public static final String ROBERTA_SEP_TOKEN = "</s>"; + /** RoBERTa unknown token. */ + public static final String ROBERTA_UNK_TOKEN = "<unk>"; + + private static final Pattern PUNCTUATION_PATTERN = + Pattern.compile("\\p{Punct}+"); private final Set<String> vocabulary; + private final String classificationToken; + private final String separatorToken; + private final String unknownToken; private int maxTokenLength = 50; /** @@ -60,7 +75,7 @@ public class WordpieceTokenizer { * @param vocabulary A set of tokens considered the vocabulary. 
*/ public WordpieceTokenizer(Set<String> vocabulary) { - this.vocabulary = vocabulary; + this(vocabulary, BERT_CLS_TOKEN, BERT_SEP_TOKEN, BERT_UNK_TOKEN); } /** @@ -75,6 +90,29 @@ public WordpieceTokenizer(Set<String> vocabulary, int maxTokenLength) { this.maxTokenLength = maxTokenLength; } + /** + * Initializes a {@link WordpieceTokenizer} with a + * {@code vocabulary} and custom special tokens. + * This allows support for models like RoBERTa that + * use different special tokens instead of the BERT + * defaults. + * + * @param vocabulary The vocabulary. + * @param classificationToken The CLS token. + * @param separatorToken The SEP token. + * @param unknownToken The UNK token. + */ + public WordpieceTokenizer( + final Set<String> vocabulary, + final String classificationToken, + final String separatorToken, + final String unknownToken) { + this.vocabulary = vocabulary; + this.classificationToken = classificationToken; + this.separatorToken = separatorToken; + this.unknownToken = unknownToken; + } + @Override public Span[] tokenizePos(final String text) { // TODO: Implement this. @@ -85,7 +123,7 @@ public Span[] tokenizePos(final String text) { public String[] tokenize(final String text) { final List<String> tokens = new LinkedList<>(); - tokens.add(CLASSIFICATION_TOKEN); + tokens.add(classificationToken); // Put spaces around punctuation. final String spacedPunctuation = PUNCTUATION_PATTERN.matcher(text).replaceAll(" $0 "); @@ -146,7 +184,7 @@ public String[] tokenize(final String text) { // If the word can't be represented by vocabulary pieces replace // it with a specified "unknown" token. if (!found) { - tokens.add(UNKNOWN_TOKEN); + tokens.add(unknownToken); break; } @@ -157,14 +195,14 @@ public String[] tokenize(final String text) { } else { - // If the token's length is greater than the max length just add [UNK] instead. - tokens.add(UNKNOWN_TOKEN); + // If the token's length is greater than the max length just add unknown token instead. 
+ tokens.add(unknownToken); } } - tokens.add(SEPARATOR_TOKEN); + tokens.add(separatorToken); return tokens.toArray(new String[0]); diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java index eb8c41cef..d46d68a6f 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java @@ -19,11 +19,14 @@ import java.io.File; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Stream; import ai.onnxruntime.OrtEnvironment; @@ -31,6 +34,7 @@ import ai.onnxruntime.OrtSession; import opennlp.tools.tokenize.Tokenizer; +import opennlp.tools.tokenize.WordpieceTokenizer; /** * Base class for OpenNLP deep-learning classes using ONNX Runtime. @@ -46,21 +50,92 @@ public abstract class AbstractDL implements AutoCloseable { protected Tokenizer tokenizer; protected Map<String, Integer> vocab; + private static final Pattern JSON_ENTRY_PATTERN = + Pattern.compile("\"((?:[^\"\\\\]|\\\\.)*)\"\\s*:\\s*(\\d+)"); + /** * Loads a vocabulary {@link File} from disk. + * Supports both plain text files (one token per + * line) and JSON files mapping tokens to integer + * IDs. * * @param vocabFile The vocabulary file. - * @return A map of vocabulary words to integer IDs. - * @throws IOException Thrown if the vocabulary file cannot be opened or read. + * @return A map of vocabulary words to IDs. + * @throws IOException Thrown if the vocabulary + * file cannot be opened or read. 
*/ - public Map<String, Integer> loadVocab(final File vocabFile) throws IOException { + public Map<String, Integer> loadVocab( + final File vocabFile) throws IOException { - final Map<String, Integer> vocab = new HashMap<>(); - final AtomicInteger counter = new AtomicInteger(0); + final Path vocabPath = + Path.of(vocabFile.getPath()); + final String content = Files.readString( + vocabPath, StandardCharsets.UTF_8); + final String trimmed = content.trim(); + + // Detect JSON format by leading brace + if (trimmed.startsWith("{")) { + return loadJsonVocab(trimmed); + } + + final Map<String, Integer> vocab = + new HashMap<>(); + final AtomicInteger counter = + new AtomicInteger(0); + + try (Stream<String> lines = Files.lines( + vocabPath, StandardCharsets.UTF_8)) { + lines.forEach(line -> + vocab.put(line, counter.getAndIncrement()) + ); + } + + return vocab; + } - try (Stream<String> lines = Files.lines(Path.of(vocabFile.getPath()))) { + /** + * Creates a {@link WordpieceTokenizer} that uses the + * appropriate special tokens based on the vocabulary. + * If the vocabulary contains RoBERTa-style tokens, + * those are used. Otherwise, the BERT defaults are + * used. + * + * @param vocab The vocabulary map. + * @return A configured {@link WordpieceTokenizer}. + */ + protected WordpieceTokenizer createTokenizer( + final Map<String, Integer> vocab) { + if (vocab.containsKey( + WordpieceTokenizer.ROBERTA_CLS_TOKEN) + && vocab.containsKey( + WordpieceTokenizer.ROBERTA_SEP_TOKEN)) { + final String unk = vocab.containsKey( + WordpieceTokenizer.ROBERTA_UNK_TOKEN) + ? 
WordpieceTokenizer.ROBERTA_UNK_TOKEN + : WordpieceTokenizer.BERT_UNK_TOKEN; + return new WordpieceTokenizer( + vocab.keySet(), + WordpieceTokenizer.ROBERTA_CLS_TOKEN, + WordpieceTokenizer.ROBERTA_SEP_TOKEN, + unk); + } + return new WordpieceTokenizer(vocab.keySet()); + } + + private Map<String, Integer> loadJsonVocab(final String json) { + + final Map<String, Integer> vocab = new HashMap<>(); + final Matcher matcher = JSON_ENTRY_PATTERN.matcher(json); - lines.forEach(line -> vocab.put(line, counter.getAndIncrement())); + while (matcher.find()) { + final String token = matcher.group(1) + .replace("\\\"", "\"") + .replace("\\\\", "\\") + .replace("\\/", "/") + .replace("\\n", "\n") + .replace("\\t", "\t"); + final int id = Integer.parseInt(matcher.group(2)); + vocab.put(token, id); } return vocab; diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java index 9173f30e4..a0c9ede77 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java @@ -45,7 +45,7 @@ import opennlp.dl.Tokens; import opennlp.dl.doccat.scoring.ClassificationScoringStrategy; import opennlp.tools.doccat.DocumentCategorizer; -import opennlp.tools.tokenize.WordpieceTokenizer; + /** * An implementation of {@link DocumentCategorizer} that performs document classification @@ -90,7 +90,7 @@ public DocumentCategorizerDL(File model, File vocabulary, Map c this.session = env.createSession(model.getPath(), sessionOptions); this.vocab = loadVocab(vocabulary); - this.tokenizer = new WordpieceTokenizer(vocab.keySet()); + this.tokenizer = createTokenizer(vocab); this.categories = categories; this.classificationScoringStrategy = classificationScoringStrategy; this.inferenceOptions = inferenceOptions; @@ -125,7 +125,7 @@ public DocumentCategorizerDL(File 
model, File vocabulary, File config, this.session = env.createSession(model.getPath(), sessionOptions); this.vocab = loadVocab(vocabulary); - this.tokenizer = new WordpieceTokenizer(vocab.keySet()); + this.tokenizer = createTokenizer(vocab); this.categories = readCategoriesFromFile(config); this.classificationScoringStrategy = classificationScoringStrategy; this.inferenceOptions = inferenceOptions; @@ -158,11 +158,22 @@ public double[] categorize(String[] strings) { LongBuffer.wrap(t.types()), new long[] {1, t.types().length})); } - // The outputs from the model. - final float[][] v = (float[][]) session.run(inputs).get(0).getValue(); + // The outputs from the model. Some models return a 2D array (e.g. BERT), + // while others return a 1D array (e.g. RoBERTa). + final Object output = session.run(inputs).get(0).getValue(); + + final float[] rawScores; + if (output instanceof float[][] v) { + rawScores = v[0]; + } else if (output instanceof float[] v) { + rawScores = v; + } else { + throw new IllegalStateException( + "Unexpected model output type: " + output.getClass().getName()); + } // Keep track of all scores. 
- final double[] categoryScoresForTokens = softmax(v[0]); + final double[] categoryScoresForTokens = softmax(rawScores); scores.add(categoryScoresForTokens); } diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java index 3cbf0e2a0..74e5a1aac 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java @@ -39,7 +39,6 @@ import opennlp.dl.Tokens; import opennlp.tools.namefind.TokenNameFinder; import opennlp.tools.sentdetect.SentenceDetector; -import opennlp.tools.tokenize.WordpieceTokenizer; import opennlp.tools.util.Span; /** @@ -104,7 +103,7 @@ public NameFinderDL(File model, File vocabulary, Map ids2Labels this.session = env.createSession(model.getPath(), sessionOptions); this.ids2Labels = ids2Labels; this.vocab = loadVocab(vocabulary); - this.tokenizer = new WordpieceTokenizer(vocab.keySet()); + this.tokenizer = createTokenizer(vocab); this.inferenceOptions = inferenceOptions; this.sentenceDetector = sentenceDetector; diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java index 805b41188..85abfbe4f 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java @@ -32,7 +32,7 @@ import opennlp.dl.AbstractDL; import opennlp.dl.Tokens; import opennlp.tools.tokenize.Tokenizer; -import opennlp.tools.tokenize.WordpieceTokenizer; + /** * Facilitates the generation of sentence vectors using @@ -55,7 +55,7 @@ public SentenceVectorsDL(final File model, final File vocabulary) env = OrtEnvironment.getEnvironment(); session = 
env.createSession(model.getPath(), new OrtSession.SessionOptions()); vocab = loadVocab(new File(vocabulary.getPath())); - tokenizer = new WordpieceTokenizer(vocab.keySet()); + tokenizer = createTokenizer(vocab); } diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java new file mode 100644 index 000000000..d8554c3fb --- /dev/null +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package opennlp.dl; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; +import java.util.Map; +import java.util.Objects; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class LoadVocabTest { + + private final AbstractDL dl = new AbstractDL() { + @Override + public void close() { + } + }; + + private File getResource(String name) throws IOException { + try (InputStream is = Objects.requireNonNull( + getClass().getResourceAsStream("/opennlp/dl/" + name))) { + final File tempFile = File.createTempFile("vocab-test-", "-" + name); + tempFile.deleteOnExit(); + Files.copy(is, tempFile.toPath(), StandardCopyOption.REPLACE_EXISTING); + return tempFile; + } + } + + @Test + void testLoadPlainTextVocab() throws IOException { + final Map<String, Integer> vocab = dl.loadVocab(getResource("vocab-plain.txt")); + + assertNotNull(vocab); + assertEquals(6, vocab.size()); + assertEquals(0, vocab.get("[CLS]")); + assertEquals(1, vocab.get("[SEP]")); + assertEquals(2, vocab.get("[UNK]")); + assertEquals(3, vocab.get("hello")); + assertEquals(4, vocab.get("world")); + assertEquals(5, vocab.get("##ing")); + } + + @Test + void testLoadJsonVocab() throws IOException { + final Map<String, Integer> vocab = dl.loadVocab(getResource("vocab.json")); + + assertNotNull(vocab); + assertEquals(6, vocab.size()); + assertEquals(0, vocab.get("[CLS]")); + assertEquals(1, vocab.get("[SEP]")); + assertEquals(2, vocab.get("[UNK]")); + assertEquals(3, vocab.get("hello")); + assertEquals(4, vocab.get("world")); + assertEquals(5, vocab.get("##ing")); + } + + @Test + void testJsonVocabWithEscapedCharacters() throws IOException { + final File tempFile = File.createTempFile("vocab-escaped", ".json"); + tempFile.deleteOnExit(); + + Files.writeString(tempFile.toPath(), + "{\"hello\\\"world\": 0, \"back\\\\slash\": 
1}"); + + final Map<String, Integer> vocab = dl.loadVocab(tempFile); + + assertNotNull(vocab); + assertEquals(2, vocab.size()); + assertEquals(0, vocab.get("hello\"world")); + assertEquals(1, vocab.get("back\\slash")); + } + + @Test + void testJsonAndPlainTextVocabProduceSameResult() throws IOException { + final Map<String, Integer> plainVocab = dl.loadVocab(getResource("vocab-plain.txt")); + final Map<String, Integer> jsonVocab = dl.loadVocab(getResource("vocab.json")); + + assertEquals(plainVocab, jsonVocab); + } +} diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt new file mode 100644 index 000000000..2f458c144 --- /dev/null +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt @@ -0,0 +1,6 @@ +[CLS] +[SEP] +[UNK] +hello +world +##ing \ No newline at end of file diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json new file mode 100644 index 000000000..2e62cee38 --- /dev/null +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json @@ -0,0 +1,8 @@ +{ + "[CLS]": 0, + "[SEP]": 1, + "[UNK]": 2, + "hello": 3, + "world": 4, + "##ing": 5 +} \ No newline at end of file