Skip to content

Commit 192782e

Browse files
authored
Add Semantic Similarity chunker (#6994)
1 parent 0eb69ee commit 192782e

File tree

5 files changed

+432
-1
lines changed

5 files changed

+432
-1
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Numerics.Tensors;
7+
using System.Runtime.CompilerServices;
8+
using System.Threading;
9+
using System.Threading.Tasks;
10+
using Microsoft.Extensions.AI;
11+
using Microsoft.Shared.Diagnostics;
12+
13+
namespace Microsoft.Extensions.DataIngestion.Chunkers;
14+
/// <summary>
/// Splits an <see cref="IngestionDocument"/> into chunks at the points where consecutive elements are
/// semantically dissimilar, as measured by the cosine distance between their embeddings.
/// </summary>
public sealed class SemanticSimilarityChunker : IngestionChunker<string>
{
    /// <summary>Percentile used when the caller does not specify one.</summary>
    private const float DefaultThresholdPercentile = 95.0f;

    private readonly ElementsChunker _elementsChunker;
    private readonly IEmbeddingGenerator<string, Embedding<float>> _embeddingGenerator;
    private readonly float _thresholdPercentile;

    /// <summary>
    /// Initializes a new instance of the <see cref="SemanticSimilarityChunker"/> class.
    /// </summary>
    /// <param name="embeddingGenerator">Embedding generator used to embed the content of each document element.</param>
    /// <param name="options">The options for the chunker.</param>
    /// <param name="thresholdPercentile">
    /// Percentile (0 to 100) of the observed distances that is used as the split threshold.
    /// The 95th percentile is used if not specified.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="embeddingGenerator"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="thresholdPercentile"/> is not a number, is less than 0, or is greater than 100.
    /// </exception>
    public SemanticSimilarityChunker(
        IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator,
        IngestionChunkerOptions options,
        float? thresholdPercentile = null)
    {
        // Fail at construction time rather than with a NullReferenceException on first use.
        _embeddingGenerator = Throw.IfNull(embeddingGenerator);
        _elementsChunker = new(options);

        // NaN compares false against both bounds, so it has to be rejected explicitly;
        // otherwise it would later be converted to an array index in Percentile.
        if (thresholdPercentile is float percentile
            && (float.IsNaN(percentile) || percentile < 0f || percentile > 100f))
        {
            Throw.ArgumentOutOfRangeException(nameof(thresholdPercentile), "Threshold percentile must be between 0 and 100.");
        }

        _thresholdPercentile = thresholdPercentile ?? DefaultThresholdPercentile;
    }

    /// <inheritdoc/>
    public override async IAsyncEnumerable<IngestionChunk<string>> ProcessAsync(IngestionDocument document,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(document);

        List<(IngestionDocumentElement, float)> distances = await CalculateDistancesAsync(document, cancellationToken).ConfigureAwait(false);
        foreach (var chunk in MakeChunks(document, distances))
        {
            yield return chunk;
        }
    }

    /// <summary>
    /// Embeds every element that has non-empty content and pairs each element with the cosine distance
    /// (1 - cosine similarity) to the element that follows it.
    /// </summary>
    /// <param name="document">The document whose content elements are measured.</param>
    /// <param name="cancellationToken">Token forwarded to the embedding generator.</param>
    /// <returns>
    /// The elements with non-empty content, in enumeration order. The last entry has no successor,
    /// so its distance is left at the default value of 0.
    /// </returns>
    private async Task<List<(IngestionDocumentElement element, float distance)>> CalculateDistancesAsync(IngestionDocument document, CancellationToken cancellationToken)
    {
        List<(IngestionDocumentElement element, float distance)> elementDistances = [];
        List<string> semanticContents = [];

        foreach (IngestionDocumentElement element in document.EnumerateContent())
        {
            // Images are represented by their alternative text (falling back to their text);
            // every other element is represented by its markdown.
            string? semanticContent = element is IngestionDocumentImage img
                ? img.AlternativeText ?? img.Text
                : element.GetMarkdown();

            // Elements without any content are skipped entirely: they are neither embedded nor chunked.
            if (!string.IsNullOrEmpty(semanticContent))
            {
                elementDistances.Add((element, default));
                semanticContents.Add(semanticContent!);
            }
        }

        if (elementDistances.Count > 0)
        {
            var embeddings = await _embeddingGenerator.GenerateAsync(semanticContents, cancellationToken: cancellationToken).ConfigureAwait(false);

            if (embeddings.Count != elementDistances.Count)
            {
                Throw.InvalidOperationException("The number of embeddings returned does not match the number of document elements.");
            }

            // Only the first Count - 1 entries have a following element to compare against.
            for (int i = 0; i < elementDistances.Count - 1; i++)
            {
                float distance = 1 - TensorPrimitives.CosineSimilarity(embeddings[i].Vector.Span, embeddings[i + 1].Vector.Span);
                elementDistances[i] = (elementDistances[i].element, distance);
            }
        }

        return elementDistances;
    }

    /// <summary>
    /// Groups consecutive elements and closes a group whenever the distance to the next element exceeds
    /// the percentile-based threshold, or when the last element is reached.
    /// </summary>
    /// <param name="document">The document the elements belong to.</param>
    /// <param name="elementDistances">Elements paired with the distance to their successor.</param>
    /// <returns>The chunks produced by the elements chunker for each group, in order.</returns>
    private IEnumerable<IngestionChunk<string>> MakeChunks(IngestionDocument document, List<(IngestionDocumentElement element, float distance)> elementDistances)
    {
        float distanceThreshold = Percentile(elementDistances);

        List<IngestionDocumentElement> elementAccumulator = [];
        string context = string.Empty;
        for (int i = 0; i < elementDistances.Count; i++)
        {
            var (element, distance) = elementDistances[i];

            elementAccumulator.Add(element);

            // The final element always flushes the accumulator so no trailing elements are lost.
            if (distance > distanceThreshold || i == elementDistances.Count - 1)
            {
                foreach (var chunk in _elementsChunker.Process(document, context, elementAccumulator))
                {
                    yield return chunk;
                }

                elementAccumulator.Clear();
            }
        }
    }

    /// <summary>
    /// Computes the configured percentile of the supplied distances, interpolating linearly between
    /// the two nearest sorted values.
    /// </summary>
    /// <param name="elementDistances">Elements paired with the distance to their successor.</param>
    /// <returns>
    /// 0 for an empty list, the only distance for a single-entry list, and the interpolated percentile otherwise.
    /// </returns>
    private float Percentile(List<(IngestionDocumentElement element, float distance)> elementDistances)
    {
        if (elementDistances.Count == 0)
        {
            return 0f;
        }
        else if (elementDistances.Count == 1)
        {
            return elementDistances[0].distance;
        }

        // NOTE(review): the final entry's placeholder distance (0) is part of the sample;
        // confirm this is intended before changing it, as it affects where chunks are split.
        float[] sorted = new float[elementDistances.Count];
        for (int elementIndex = 0; elementIndex < elementDistances.Count; elementIndex++)
        {
            sorted[elementIndex] = elementDistances[elementIndex].distance;
        }

        Array.Sort(sorted);

        // Fractional position of the percentile within the sorted values.
        float rank = (_thresholdPercentile / 100f) * (sorted.Length - 1);
        int lower = (int)rank;

        // Clamp so that the 100th percentile does not read past the end of the array.
        int upper = Math.Min(lower + 1, sorted.Length - 1);
        return sorted[lower] + ((rank - lower) * (sorted[upper] - sorted[lower]));
    }
}

src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
<!-- we are not ready to publish yet -->
1111
<IsPackable>false</IsPackable>
1212
<Stage>preview</Stage>
13-
<EnablePackageValidation>false</EnablePackageValidation>
13+
<EnablePackageValidation>false</EnablePackageValidation>
1414
</PropertyGroup>
1515

1616
<ItemGroup>
@@ -21,6 +21,7 @@
2121
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
2222
<PackageReference Include="Microsoft.Extensions.VectorData.Abstractions" />
2323
<PackageReference Include="Microsoft.ML.Tokenizers" />
24+
<PackageReference Include="System.Numerics.Tensors" />
2425
</ItemGroup>
2526

2627
<ItemGroup Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'">
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Threading.Tasks;
7+
using Xunit;
8+
namespace Microsoft.Extensions.DataIngestion.Chunkers.Tests
{
    /// <summary>
    /// Base class holding tests that apply to every <see cref="IngestionChunker{T}"/> of <see cref="string"/>;
    /// derived classes supply the chunker under test.
    /// </summary>
    public abstract class DocumentChunkerTests
    {
        /// <summary>
        /// Creates the chunker instance that the tests in this class exercise.
        /// </summary>
        /// <param name="maxTokensPerChunk">Value handed to the derived class when it builds the chunker.</param>
        /// <param name="overlapTokens">Value handed to the derived class when it builds the chunker.</param>
        /// <returns>The chunker under test.</returns>
        protected abstract IngestionChunker<string> CreateDocumentChunker(int maxTokensPerChunk = 2_000, int overlapTokens = 500);

        /// <summary>
        /// Enumerating the result of processing a null document must throw for the "document" parameter.
        /// </summary>
        [Fact]
        public async Task ProcessAsync_ThrowsArgumentNullException_WhenDocumentIsNull()
        {
            IngestionChunker<string> chunker = CreateDocumentChunker();

            // The sequence has to be enumerated for the failure to be observed.
            async Task EnumerateNullDocumentAsync()
            {
                _ = await chunker.ProcessAsync(null!).ToListAsync();
            }

            await Assert.ThrowsAsync<ArgumentNullException>("document", EnumerateNullDocumentAsync);
        }

        /// <summary>
        /// A document without any content must not produce chunks.
        /// </summary>
        [Fact]
        public async Task EmptyDocument()
        {
            var emptyDoc = new IngestionDocument("emptyDoc");
            var chunker = CreateDocumentChunker();

            IReadOnlyList<IngestionChunk<string>> chunks = await chunker.ProcessAsync(emptyDoc).ToListAsync();

            Assert.Empty(chunks);
        }
    }
}

0 commit comments

Comments
 (0)