Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@
<PackageVersion Include="System.Numerics.Tensors" Version="10.0.3" />
<PackageVersion Include="Microsoft.Extensions.Hosting" Version="10.0.3" />
<PackageVersion Include="Microsoft.Extensions.Hosting.WindowsServices" Version="10.0.3" />
<PackageVersion Include="diskann-garnet" Version="1.0.23" />
<PackageVersion Include="diskann-garnet" Version="1.0.25" />
</ItemGroup>
</Project>
12 changes: 9 additions & 3 deletions libs/server/Resp/Vector/DiskANNService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Garnet.common;
using Tsavorite.core;

namespace Garnet.server
Expand Down Expand Up @@ -34,11 +35,15 @@ public nint CreateIndex(
delegate* unmanaged[Cdecl]<ulong, nint, nuint, nuint, nint, nint, byte> readModifyWriteCallback
)
{
// TODO: actually pass distance metric

unsafe
{
return NativeDiskANNMethods.create_index(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, (nint)readCallback, (nint)writeCallback, (nint)deleteCallback, (nint)readModifyWriteCallback);
var index = NativeDiskANNMethods.create_index(context, dimensions, reduceDims, quantType, (int)distanceMetric, buildExplorationFactor, numLinks, (nint)readCallback, (nint)writeCallback, (nint)deleteCallback, (nint)readModifyWriteCallback);
if (index == nint.Zero)
{
throw new GarnetException("Failed to create DiskANN index, native create_index returned null");
}

return index;
}
}

Expand Down Expand Up @@ -308,6 +313,7 @@ public static partial nint create_index(
uint dimensions,
uint reduceDims,
VectorQuantType quantType,
int metricType,
uint buildExplorationFactor,
uint numLinks,
nint readCallback,
Expand Down
6 changes: 3 additions & 3 deletions test/Garnet.test/DiskANN/DiskANNServiceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength
var deleteFuncPtr = Marshal.GetFunctionPointerForDelegate(deleteDel);
var rmwFuncPtr = Marshal.GetFunctionPointerForDelegate(rmwDel);

var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);
var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, (int)VectorDistanceMetricType.Cosine, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);

Span<byte> id = [0, 1, 2, 3];
Span<byte> elem = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray();
Expand Down Expand Up @@ -365,7 +365,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength
var deleteFuncPtr = Marshal.GetFunctionPointerForDelegate(deleteDel);
var rmwFuncPtr = Marshal.GetFunctionPointerForDelegate(rmwDel);

var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);
var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, (int)VectorDistanceMetricType.Cosine, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);

Span<byte> id = [0, 1, 2, 3];
Span<byte> elem = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray();
Expand Down Expand Up @@ -410,7 +410,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength
{
NativeDiskANNMethods.drop_index(Context, rawIndex);

rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);
rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, (int)VectorDistanceMetricType.Cosine, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr);
}

// Search value
Expand Down
Loading