Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,27 @@
*/
public class WordpieceTokenizer implements Tokenizer {

private static final Pattern PUNCTUATION_PATTERN = Pattern.compile("\\p{Punct}+");
private static final String CLASSIFICATION_TOKEN = "[CLS]";
private static final String SEPARATOR_TOKEN = "[SEP]";
private static final String UNKNOWN_TOKEN = "[UNK]";
/** BERT classification token: {@code [CLS]}. */
public static final String BERT_CLS_TOKEN = "[CLS]";
/** BERT separator token: {@code [SEP]}. */
public static final String BERT_SEP_TOKEN = "[SEP]";
/** BERT unknown token: {@code [UNK]}. */
public static final String BERT_UNK_TOKEN = "[UNK]";

/** RoBERTa classification token: {@code <s>}. */
public static final String ROBERTA_CLS_TOKEN = "<s>";
/** RoBERTa separator token. */
public static final String ROBERTA_SEP_TOKEN = "</s>";
/** RoBERTa unknown token. */
public static final String ROBERTA_UNK_TOKEN = "<unk>";

private static final Pattern PUNCTUATION_PATTERN =
Pattern.compile("\\p{Punct}+");

private final Set<String> vocabulary;
private final String classificationToken;
private final String separatorToken;
private final String unknownToken;
private int maxTokenLength = 50;

/**
Expand All @@ -60,7 +75,7 @@ public class WordpieceTokenizer implements Tokenizer {
* @param vocabulary A set of tokens considered the vocabulary.
*/
/**
 * Initializes a {@link WordpieceTokenizer} with a {@code vocabulary},
 * using the BERT default special tokens ({@code [CLS]}, {@code [SEP]},
 * {@code [UNK]}).
 *
 * @param vocabulary A set of tokens considered the vocabulary.
 */
public WordpieceTokenizer(Set<String> vocabulary) {
    // Delegate to the full constructor with the BERT default special tokens.
    // (A leftover pre-delegation field assignment was removed here: a plain
    // assignment cannot precede a this(...) call, which must come first.)
    this(vocabulary, BERT_CLS_TOKEN, BERT_SEP_TOKEN, BERT_UNK_TOKEN);
}

/**
Expand All @@ -75,6 +90,29 @@ public WordpieceTokenizer(Set<String> vocabulary, int maxTokenLength) {
this.maxTokenLength = maxTokenLength;
}

/**
 * Initializes a {@link WordpieceTokenizer} with a {@code vocabulary} and
 * caller-supplied special tokens, enabling models (such as RoBERTa) whose
 * special tokens differ from the BERT defaults.
 *
 * @param vocabulary The vocabulary.
 * @param classificationToken The classification (CLS) token.
 * @param separatorToken The separator (SEP) token.
 * @param unknownToken The unknown (UNK) token.
 */
public WordpieceTokenizer(
        final Set<String> vocabulary,
        final String classificationToken,
        final String separatorToken,
        final String unknownToken) {
    // Store the special tokens first, then the vocabulary; all four are
    // simple reference assignments with no interdependence.
    this.classificationToken = classificationToken;
    this.separatorToken = separatorToken;
    this.unknownToken = unknownToken;
    this.vocabulary = vocabulary;
}

@Override
public Span[] tokenizePos(final String text) {
// TODO: Implement this.
Expand All @@ -85,7 +123,7 @@ public Span[] tokenizePos(final String text) {
public String[] tokenize(final String text) {

final List<String> tokens = new LinkedList<>();
tokens.add(CLASSIFICATION_TOKEN);
tokens.add(classificationToken);

// Put spaces around punctuation.
final String spacedPunctuation = PUNCTUATION_PATTERN.matcher(text).replaceAll(" $0 ");
Expand Down Expand Up @@ -146,7 +184,7 @@ public String[] tokenize(final String text) {
// If the word can't be represented by vocabulary pieces replace
// it with a specified "unknown" token.
if (!found) {
tokens.add(UNKNOWN_TOKEN);
tokens.add(unknownToken);
break;
}

Expand All @@ -157,14 +195,14 @@ public String[] tokenize(final String text) {

} else {

// If the token's length is greater than the max length just add [UNK] instead.
tokens.add(UNKNOWN_TOKEN);
// If the token's length is greater than the max length just add unknown token instead.
tokens.add(unknownToken);

}

}

tokens.add(SEPARATOR_TOKEN);
tokens.add(separatorToken);

return tokens.toArray(new String[0]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,22 @@

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;

/**
* Base class for OpenNLP deep-learning classes using ONNX Runtime.
Expand All @@ -46,21 +50,92 @@ public abstract class AbstractDL implements AutoCloseable {
protected Tokenizer tokenizer;
protected Map<String, Integer> vocab;

private static final Pattern JSON_ENTRY_PATTERN =
Pattern.compile("\"((?:[^\"\\\\]|\\\\.)*)\"\\s*:\\s*(\\d+)");

/**
* Loads a vocabulary {@link File} from disk.
* Supports both plain text files (one token per
* line) and JSON files mapping tokens to integer
* IDs.
*
* @param vocabFile The vocabulary file.
* @return A map of vocabulary words to integer IDs.
* @throws IOException Thrown if the vocabulary file cannot be opened or read.
* @return A map of vocabulary words to IDs.
* @throws IOException Thrown if the vocabulary
* file cannot be opened or read.
*/
public Map<String, Integer> loadVocab(final File vocabFile) throws IOException {
public Map<String, Integer> loadVocab(
final File vocabFile) throws IOException {

final Map<String, Integer> vocab = new HashMap<>();
final AtomicInteger counter = new AtomicInteger(0);
final Path vocabPath =
Path.of(vocabFile.getPath());
final String content = Files.readString(
vocabPath, StandardCharsets.UTF_8);
final String trimmed = content.trim();

// Detect JSON format by leading brace
if (trimmed.startsWith("{")) {
return loadJsonVocab(trimmed);
}

final Map<String, Integer> vocab =
new HashMap<>();
final AtomicInteger counter =
new AtomicInteger(0);

try (Stream<String> lines = Files.lines(
vocabPath, StandardCharsets.UTF_8)) {
lines.forEach(line ->
vocab.put(line, counter.getAndIncrement())
);
}

return vocab;
}

try (Stream<String> lines = Files.lines(Path.of(vocabFile.getPath()))) {
/**
 * Builds a {@link WordpieceTokenizer} whose special tokens match the
 * vocabulary's model family. If the vocabulary contains the RoBERTa-style
 * CLS and SEP tokens those are used (with the RoBERTa UNK token when
 * present, else the BERT UNK token); otherwise the BERT defaults apply.
 *
 * @param vocab The vocabulary map.
 * @return A configured {@link WordpieceTokenizer}.
 */
protected WordpieceTokenizer createTokenizer(final Map<String, Integer> vocab) {

    final boolean robertaStyle =
        vocab.containsKey(WordpieceTokenizer.ROBERTA_CLS_TOKEN)
            && vocab.containsKey(WordpieceTokenizer.ROBERTA_SEP_TOKEN);

    // No RoBERTa markers present: fall back to the BERT defaults.
    if (!robertaStyle) {
        return new WordpieceTokenizer(vocab.keySet());
    }

    // Some RoBERTa vocabularies omit <unk>; keep the BERT [UNK] in that case.
    final String unknownToken =
        vocab.containsKey(WordpieceTokenizer.ROBERTA_UNK_TOKEN)
            ? WordpieceTokenizer.ROBERTA_UNK_TOKEN
            : WordpieceTokenizer.BERT_UNK_TOKEN;

    return new WordpieceTokenizer(
        vocab.keySet(),
        WordpieceTokenizer.ROBERTA_CLS_TOKEN,
        WordpieceTokenizer.ROBERTA_SEP_TOKEN,
        unknownToken);
}

/**
 * Parses a JSON vocabulary string of the form {@code {"token": id, ...}}
 * into a token-to-id map, using the regex {@code JSON_ENTRY_PATTERN}
 * rather than a full JSON parser.
 *
 * @param json The JSON content (expected to start with {@code {}).
 * @return A map of vocabulary tokens to integer IDs.
 */
private Map<String, Integer> loadJsonVocab(final String json) {

final Map<String, Integer> vocab = new HashMap<>();
final Matcher matcher = JSON_ENTRY_PATTERN.matcher(json);

// NOTE(review): the following line appears to be stray diff residue — it
// references `lines` and `counter`, neither of which is in scope here,
// and should be deleted.
lines.forEach(line -> vocab.put(line, counter.getAndIncrement()));
// Each match yields group(1) = escaped token text, group(2) = integer id.
while (matcher.find()) {
// NOTE(review): sequential replace cannot correctly unescape all inputs —
// replacing "\\" before "\n"/"\t" turns an escaped backslash-n ("\\n")
// into a real newline. A single-pass unescape would be safer; verify
// whether vocabulary tokens can actually contain such sequences.
final String token = matcher.group(1)
.replace("\\\"", "\"")
.replace("\\\\", "\\")
.replace("\\/", "/")
.replace("\\n", "\n")
.replace("\\t", "\t");
final int id = Integer.parseInt(matcher.group(2));
// Later duplicate tokens overwrite earlier ids (HashMap.put semantics).
vocab.put(token, id);
}

return vocab;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import opennlp.dl.Tokens;
import opennlp.dl.doccat.scoring.ClassificationScoringStrategy;
import opennlp.tools.doccat.DocumentCategorizer;
import opennlp.tools.tokenize.WordpieceTokenizer;


/**
* An implementation of {@link DocumentCategorizer} that performs document classification
Expand Down Expand Up @@ -90,7 +90,7 @@ public DocumentCategorizerDL(File model, File vocabulary, Map<Integer, String> c

this.session = env.createSession(model.getPath(), sessionOptions);
this.vocab = loadVocab(vocabulary);
this.tokenizer = new WordpieceTokenizer(vocab.keySet());
this.tokenizer = createTokenizer(vocab);
this.categories = categories;
this.classificationScoringStrategy = classificationScoringStrategy;
this.inferenceOptions = inferenceOptions;
Expand Down Expand Up @@ -125,7 +125,7 @@ public DocumentCategorizerDL(File model, File vocabulary, File config,

this.session = env.createSession(model.getPath(), sessionOptions);
this.vocab = loadVocab(vocabulary);
this.tokenizer = new WordpieceTokenizer(vocab.keySet());
this.tokenizer = createTokenizer(vocab);
this.categories = readCategoriesFromFile(config);
this.classificationScoringStrategy = classificationScoringStrategy;
this.inferenceOptions = inferenceOptions;
Expand Down Expand Up @@ -158,11 +158,22 @@ public double[] categorize(String[] strings) {
LongBuffer.wrap(t.types()), new long[] {1, t.types().length}));
}

// The outputs from the model.
final float[][] v = (float[][]) session.run(inputs).get(0).getValue();
// The outputs from the model. Some models return a 2D array (e.g. BERT),
// while others return a 1D array (e.g. RoBERTa).
final Object output = session.run(inputs).get(0).getValue();

final float[] rawScores;
if (output instanceof float[][] v) {
rawScores = v[0];
} else if (output instanceof float[] v) {
rawScores = v;
} else {
throw new IllegalStateException(
"Unexpected model output type: " + output.getClass().getName());
}

// Keep track of all scores.
final double[] categoryScoresForTokens = softmax(v[0]);
final double[] categoryScoresForTokens = softmax(rawScores);
scores.add(categoryScoresForTokens);

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import opennlp.dl.Tokens;
import opennlp.tools.namefind.TokenNameFinder;
import opennlp.tools.sentdetect.SentenceDetector;
import opennlp.tools.tokenize.WordpieceTokenizer;
import opennlp.tools.util.Span;

/**
Expand Down Expand Up @@ -104,7 +103,7 @@ public NameFinderDL(File model, File vocabulary, Map<Integer, String> ids2Labels
this.session = env.createSession(model.getPath(), sessionOptions);
this.ids2Labels = ids2Labels;
this.vocab = loadVocab(vocabulary);
this.tokenizer = new WordpieceTokenizer(vocab.keySet());
this.tokenizer = createTokenizer(vocab);
this.inferenceOptions = inferenceOptions;
this.sentenceDetector = sentenceDetector;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import opennlp.dl.AbstractDL;
import opennlp.dl.Tokens;
import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;


/**
* Facilitates the generation of sentence vectors using
Expand All @@ -55,7 +55,7 @@ public SentenceVectorsDL(final File model, final File vocabulary)
env = OrtEnvironment.getEnvironment();
session = env.createSession(model.getPath(), new OrtSession.SessionOptions());
vocab = loadVocab(new File(vocabulary.getPath()));
tokenizer = new WordpieceTokenizer(vocab.keySet());
tokenizer = createTokenizer(vocab);

}

Expand Down
Loading