Skip to content

Commit ce4e097

Browse files
SEUNGHWAN.JUNG authored and ralla0405 committed
Fix TokenTextSplitter to prevent mini chunks at end
The previous implementation advanced the position by the optimized chunk size, which could create a series of mini chunks when boundary optimization aggressively shrank chunks near the end of the text.

Changes:
- Implement a minimum-advance guarantee to prevent consecutive mini chunks
- Add tests to verify that no mini chunks are generated at the end
- Clean up test code by removing redundant comments

Signed-off-by: logan-mac <[email protected]>
1 parent e0ccc13 commit ce4e097

File tree

2 files changed

+324
-63
lines changed

2 files changed

+324
-63
lines changed

spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java

Lines changed: 88 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,14 @@
3333
* @author Raphael Yu
3434
* @author Christian Tzolov
3535
* @author Ricken Bazolo
36+
* @author Seunghwan Jung
3637
*/
3738
public class TokenTextSplitter extends TextSplitter {
3839

3940
private static final int DEFAULT_CHUNK_SIZE = 800;
4041

42+
private static final int DEFAULT_CHUNK_OVERLAP = 50;
43+
4144
private static final int MIN_CHUNK_SIZE_CHARS = 350;
4245

4346
private static final int MIN_CHUNK_LENGTH_TO_EMBED = 5;
@@ -53,6 +56,9 @@ public class TokenTextSplitter extends TextSplitter {
5356
// The target size of each text chunk in tokens
5457
private final int chunkSize;
5558

59+
// The overlap size of each text chunk in tokens
60+
private final int chunkOverlap;
61+
5662
// The minimum size of each text chunk in characters
5763
private final int minChunkSizeChars;
5864

@@ -65,16 +71,20 @@ public class TokenTextSplitter extends TextSplitter {
6571
private final boolean keepSeparator;
6672

6773
/**
 * Creates a splitter that uses the class-level defaults for every setting:
 * {@code DEFAULT_CHUNK_SIZE}, {@code DEFAULT_CHUNK_OVERLAP},
 * {@code MIN_CHUNK_SIZE_CHARS}, {@code MIN_CHUNK_LENGTH_TO_EMBED},
 * {@code MAX_NUM_CHUNKS} and {@code KEEP_SEPARATOR}.
 */
public TokenTextSplitter() {
	this(
			DEFAULT_CHUNK_SIZE,
			DEFAULT_CHUNK_OVERLAP,
			MIN_CHUNK_SIZE_CHARS,
			MIN_CHUNK_LENGTH_TO_EMBED,
			MAX_NUM_CHUNKS,
			KEEP_SEPARATOR);
}
7077

7178
/**
 * Creates a splitter that uses the class-level defaults for every setting except
 * the separator handling.
 * @param keepSeparator value passed through as the {@code keepSeparator} setting of
 * the full constructor
 */
public TokenTextSplitter(boolean keepSeparator) {
	this(
			DEFAULT_CHUNK_SIZE,
			DEFAULT_CHUNK_OVERLAP,
			MIN_CHUNK_SIZE_CHARS,
			MIN_CHUNK_LENGTH_TO_EMBED,
			MAX_NUM_CHUNKS,
			keepSeparator);
}
7482

75-
public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
76-
boolean keepSeparator) {
83+
public TokenTextSplitter(int chunkSize, int chunkOverlap, int minChunkSizeChars, int minChunkLengthToEmbed,
84+
int maxNumChunks, boolean keepSeparator) {
85+
Assert.isTrue(chunkOverlap < chunkSize, "chunk overlap must be less than chunk size");
7786
this.chunkSize = chunkSize;
87+
this.chunkOverlap = chunkOverlap;
7888
this.minChunkSizeChars = minChunkSizeChars;
7989
this.minChunkLengthToEmbed = minChunkLengthToEmbed;
8090
this.maxNumChunks = maxNumChunks;
@@ -87,57 +97,89 @@ public static Builder builder() {
8797

8898
@Override
8999
protected List<String> splitText(String text) {
90-
return doSplit(text, this.chunkSize);
100+
return doSplit(text, this.chunkSize, this.chunkOverlap);
91101
}
92102

93-
protected List<String> doSplit(String text, int chunkSize) {
103+
protected List<String> doSplit(String text, int chunkSize, int chunkOverlap) {
94104
if (text == null || text.trim().isEmpty()) {
95105
return new ArrayList<>();
96106
}
97107

98108
List<Integer> tokens = getEncodedTokens(text);
99-
List<String> chunks = new ArrayList<>();
100-
int num_chunks = 0;
101-
while (!tokens.isEmpty() && num_chunks < this.maxNumChunks) {
102-
List<Integer> chunk = tokens.subList(0, Math.min(chunkSize, tokens.size()));
103-
String chunkText = decodeTokens(chunk);
104-
105-
// Skip the chunk if it is empty or whitespace
106-
if (chunkText.trim().isEmpty()) {
107-
tokens = tokens.subList(chunk.size(), tokens.size());
108-
continue;
109-
}
110-
111-
// Find the last period or punctuation mark in the chunk
112-
int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'),
113-
Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n'))));
109+
// If text is smaller than chunk size, return as a single chunk
110+
if (tokens.size() <= chunkSize) {
111+
String processedText = this.keepSeparator ? text.trim() : text.replace(System.lineSeparator(), " ").trim();
114112

115-
if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
116-
// Truncate the chunk text at the punctuation mark
117-
chunkText = chunkText.substring(0, lastPunctuation + 1);
113+
if (processedText.length() > this.minChunkLengthToEmbed) {
114+
return List.of(processedText);
118115
}
116+
return new ArrayList<>();
117+
}
118+
List<String> chunks = new ArrayList<>();
119119

120-
String chunkTextToAppend = (this.keepSeparator) ? chunkText.trim()
121-
: chunkText.replace(System.lineSeparator(), " ").trim();
122-
if (chunkTextToAppend.length() > this.minChunkLengthToEmbed) {
123-
chunks.add(chunkTextToAppend);
120+
int position = 0;
121+
int num_chunks = 0;
122+
while (position < tokens.size() && num_chunks < this.maxNumChunks) {
123+
int chunkEnd = Math.min(position + chunkSize, tokens.size());
124+
125+
// Extract tokens for this chunk
126+
List<Integer> chunkTokens = tokens.subList(position, chunkEnd);
127+
String chunkText = decodeTokens(chunkTokens);
128+
129+
// Apply sentence boundary optimization
130+
String optimizedText = optimizeChunkBoundary(chunkText);
131+
int optimizedTokenCount = getEncodedTokens(optimizedText).size();
132+
133+
// Use optimized chunk
134+
String finalChunkText = optimizedText;
135+
int finalChunkTokenCount = optimizedTokenCount;
136+
137+
// Advance position with minimum advance guarantee
138+
// This prevents creating a series of mini chunks when boundary optimization
139+
// aggressively shrinks chunks
140+
int naturalAdvance = finalChunkTokenCount - chunkOverlap;
141+
int minAdvance = Math.max(1, (chunkSize - chunkOverlap) / 2);
142+
int advance = Math.max(naturalAdvance, minAdvance);
143+
position += advance;
144+
145+
// Format according to keepSeparator setting
146+
String formattedChunk = this.keepSeparator ? finalChunkText.trim()
147+
: finalChunkText.replace(System.lineSeparator(), " ").trim();
148+
149+
// Add chunk if it meets minimum length
150+
if (formattedChunk.length() > this.minChunkLengthToEmbed) {
151+
chunks.add(formattedChunk);
152+
num_chunks++;
124153
}
154+
}
125155

126-
// Remove the tokens corresponding to the chunk text from the remaining tokens
127-
tokens = tokens.subList(getEncodedTokens(chunkText).size(), tokens.size());
156+
return chunks;
157+
}
128158

129-
num_chunks++;
159+
private String optimizeChunkBoundary(String chunkText) {
160+
if (chunkText.length() <= this.minChunkSizeChars) {
161+
return chunkText;
130162
}
131163

132-
// Handle the remaining tokens
133-
if (!tokens.isEmpty()) {
134-
String remaining_text = decodeTokens(tokens).replace(System.lineSeparator(), " ").trim();
135-
if (remaining_text.length() > this.minChunkLengthToEmbed) {
136-
chunks.add(remaining_text);
164+
// Look for sentence endings: . ! ? \n
165+
int bestCutPoint = -1;
166+
167+
// Check in reverse order to find the last sentence ending
168+
for (int i = chunkText.length() - 1; i >= this.minChunkSizeChars; i--) {
169+
char c = chunkText.charAt(i);
170+
if (c == '.' || c == '!' || c == '?' || c == '\n') {
171+
bestCutPoint = i + 1; // Include the punctuation
172+
break;
137173
}
138174
}
139175

140-
return chunks;
176+
// If we found a good cut point, use it
177+
if (bestCutPoint > 0) {
178+
return chunkText.substring(0, bestCutPoint);
179+
}
180+
181+
// Otherwise return the original chunk
182+
return chunkText;
141183
}
142184

143185
private List<Integer> getEncodedTokens(String text) {
@@ -156,6 +198,8 @@ public static final class Builder {
156198

157199
private int chunkSize = DEFAULT_CHUNK_SIZE;
158200

201+
private int chunkOverlap = DEFAULT_CHUNK_OVERLAP;
202+
159203
private int minChunkSizeChars = MIN_CHUNK_SIZE_CHARS;
160204

161205
private int minChunkLengthToEmbed = MIN_CHUNK_LENGTH_TO_EMBED;
@@ -172,6 +216,11 @@ public Builder withChunkSize(int chunkSize) {
172216
return this;
173217
}
174218

219+
public Builder withChunkOverlap(int chunkOverlap) {
220+
this.chunkOverlap = chunkOverlap;
221+
return this;
222+
}
223+
175224
public Builder withMinChunkSizeChars(int minChunkSizeChars) {
176225
this.minChunkSizeChars = minChunkSizeChars;
177226
return this;
@@ -193,8 +242,8 @@ public Builder withKeepSeparator(boolean keepSeparator) {
193242
}
194243

195244
public TokenTextSplitter build() {
196-
return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
197-
this.maxNumChunks, this.keepSeparator);
245+
return new TokenTextSplitter(this.chunkSize, this.chunkOverlap, this.minChunkSizeChars,
246+
this.minChunkLengthToEmbed, this.maxNumChunks, this.keepSeparator);
198247
}
199248

200249
}

0 commit comments

Comments
 (0)