diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion.Abstractions/IngestionChunkWriter.cs b/src/Libraries/Microsoft.Extensions.DataIngestion.Abstractions/IngestionChunkWriter.cs index 119265caf6e..7eb4292a582 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion.Abstractions/IngestionChunkWriter.cs +++ b/src/Libraries/Microsoft.Extensions.DataIngestion.Abstractions/IngestionChunkWriter.cs @@ -15,12 +15,13 @@ namespace Microsoft.Extensions.DataIngestion; public abstract class IngestionChunkWriter : IDisposable { /// - /// Writes chunks asynchronously. + /// Writes the chunks of a single document asynchronously. /// + /// The document from which the chunks were extracted. /// The chunks to write. /// The token to monitor for cancellation requests. /// A task representing the asynchronous write operation. - public abstract Task WriteAsync(IAsyncEnumerable> chunks, CancellationToken cancellationToken = default); + public abstract Task WriteAsync(IngestionDocument document, IAsyncEnumerable> chunks, CancellationToken cancellationToken = default); /// /// Disposes the writer and releases all associated resources. diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionPipeline.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionPipeline.cs index 1eeb94058ee..35ba3d38823 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionPipeline.cs +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionPipeline.cs @@ -187,7 +187,7 @@ private async Task IngestAsync(IngestionDocument document, Ac } _logger?.WritingChunks(GetShortName(_writer)); - await _writer.WriteAsync(chunks, cancellationToken).ConfigureAwait(false); + await _writer.WriteAsync(document, chunks, cancellationToken).ConfigureAwait(false); _logger?.WroteChunks(document.Identifier); return document; diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Writers/VectorStoreWriter.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Writers/VectorStoreWriter.cs index 124c33ab644..967e2e91929 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion/Writers/VectorStoreWriter.cs +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Writers/VectorStoreWriter.cs @@ -43,8 +43,9 @@ public VectorStoreWriter(VectorStoreCollection collection, Vector public VectorStoreCollection VectorStoreCollection { get; } /// - public override async Task WriteAsync(IAsyncEnumerable> chunks, CancellationToken cancellationToken = default) + public override async Task WriteAsync(IngestionDocument document, IAsyncEnumerable> chunks, CancellationToken cancellationToken = default) { + _ = Throw.IfNull(document); _ = Throw.IfNull(chunks); IReadOnlyList? preExistingKeys = null; @@ -62,13 +63,13 @@ public override async Task WriteAsync(IAsyncEnumerable> c // We obtain the IDs of the pre-existing chunks for given document, // and delete them after we finish inserting the new chunks, // to avoid a situation where we delete the chunks and then fail to insert the new ones. - preExistingKeys ??= await GetPreExistingChunksIdsAsync(chunk.Document, cancellationToken).ConfigureAwait(false); + preExistingKeys ??= await GetPreExistingChunksIdsAsync(document, cancellationToken).ConfigureAwait(false); TRecord record = new() { Content = chunk.Content, Context = chunk.Context, - DocumentId = chunk.Document.Identifier, + DocumentId = document.Identifier, }; if (chunk.HasMetadata) diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Writers/VectorStoreWriterTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Writers/VectorStoreWriterTests.cs index e8ab6ab6d73..8cb7b0d9768 100644 --- a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Writers/VectorStoreWriterTests.cs +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Writers/VectorStoreWriterTests.cs @@ -51,7 +51,7 @@ public async Task CanWriteChunksWithCustomDefinition() List> chunks = [chunk]; - await writer.WriteAsync(chunks.ToAsyncEnumerable()); + await writer.WriteAsync(document, chunks.ToAsyncEnumerable()); IngestionChunkVectorRecord record = await writer.VectorStoreCollection .GetAsync(filter: record => record.DocumentId == documentId, top: 1) @@ -82,7 +82,7 @@ public async Task CanWriteChunks() List> chunks = [chunk]; Assert.False(testEmbeddingGenerator.WasCalled); - await writer.WriteAsync(chunks.ToAsyncEnumerable()); + await writer.WriteAsync(document, chunks.ToAsyncEnumerable()); IngestionChunkVectorRecord record = await writer.VectorStoreCollection .GetAsync(filter: record => record.DocumentId == documentId, top: 1) @@ -112,7 +112,7 @@ public async Task CanWriteChunksWithMetadata() List> chunks = [chunk]; - await writer.WriteAsync(chunks.ToAsyncEnumerable()); + await writer.WriteAsync(document, chunks.ToAsyncEnumerable()); TestChunkRecordWithMetadata record = await writer.VectorStoreCollection .GetAsync(filter: record => record.DocumentId == documentId, top: 1) @@ -148,7 +148,7 @@ public async Task DoesSupportIncrementalIngestion() List> chunks = [chunk1, chunk2]; - await writer.WriteAsync(chunks.ToAsyncEnumerable()); + await writer.WriteAsync(document, chunks.ToAsyncEnumerable()); int recordCount = await writer.VectorStoreCollection .GetAsync(filter: record => record.DocumentId == documentId, top: 100) @@ -160,7 +160,7 @@ public async Task DoesSupportIncrementalIngestion() List> updatedChunks = [updatedChunk]; - await writer.WriteAsync(updatedChunks.ToAsyncEnumerable()); + await writer.WriteAsync(document, updatedChunks.ToAsyncEnumerable()); // We ask for 100 records, but we expect only 1 as the previous 2 should have been deleted. IngestionChunkVectorRecord record = await writer.VectorStoreCollection @@ -216,7 +216,7 @@ public async Task BatchesChunks(int? batchTokenCount, int[] chunkTokenCounts) chunks.Add(new($"chunk {i + 1}", document, context: null, tokenCount: chunkTokenCounts[i])); } - await writer.WriteAsync(chunks.ToAsyncEnumerable()); + await writer.WriteAsync(document, chunks.ToAsyncEnumerable()); int recordCount = await writer.VectorStoreCollection .GetAsync(filter: record => record.DocumentId == documentId, top: 100) @@ -252,7 +252,7 @@ public async Task IncrementalIngestion_WithManyRecords_DeletesAllPreExistingChun chunks.Add(TestChunkFactory.CreateChunk($"chunk {i}", document)); } - await writer.WriteAsync(chunks.ToAsyncEnumerable()); + await writer.WriteAsync(document, chunks.ToAsyncEnumerable()); int recordCount = await writer.VectorStoreCollection .GetAsync(filter: record => record.DocumentId == documentId, top: 10000) @@ -266,7 +266,7 @@ public async Task IncrementalIngestion_WithManyRecords_DeletesAllPreExistingChun TestChunkFactory.CreateChunk("updated chunk 2", document) ]; - await writer.WriteAsync(updatedChunks.ToAsyncEnumerable()); + await writer.WriteAsync(document, updatedChunks.ToAsyncEnumerable()); // Verify that all old records were deleted and only the new ones remain List> records = await writer.VectorStoreCollection