From 14a197acf406212df9a3d31e39b794c8dbba8ddb Mon Sep 17 00:00:00 2001 From: Zach Goodson Date: Sun, 26 Oct 2025 17:27:39 -0500 Subject: [PATCH] Make Memory DB Queries Faster & Lighter on Memory --- databases/Databases.cs | 149 ++++++++++-------- databases/types/DTOs/memoryDTO.cs | 13 ++ databases/types/memory.cs | 2 + .../temporal-handling/HybridMemoryResult.cs | 1 + 4 files changed, 99 insertions(+), 66 deletions(-) create mode 100644 databases/types/DTOs/memoryDTO.cs diff --git a/databases/Databases.cs b/databases/Databases.cs index 3707c208..b9a8a67a 100644 --- a/databases/Databases.cs +++ b/databases/Databases.cs @@ -1,18 +1,15 @@ using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Threading.Tasks; using Bob.Commands.Helpers; using Bob.Database.Types; +using Bob.Database.Types.DataTransferObjects; using BobTheBot.Chat.MemoryHandling; -using DotNetEnv; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Design; using Npgsql; -using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; using Pgvector; -using Pgvector.EntityFrameworkCore; namespace Bob.Database @@ -40,6 +37,7 @@ public class BobEntities(DbContextOptions options) : DbContext(opti public virtual DbSet ScheduledAnnouncement { get; set; } public virtual DbSet ReactBoardMessage { get; set; } public virtual DbSet Memory { get; set; } + public DbSet MemoryDTOs { get; set; } // Only for projection queries public virtual DbSet Tag { get; set; } protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) @@ -678,31 +676,6 @@ public async Task StoreMemoryAsync(string userId, string userMessage, string bot await SaveChangesAsync(); } - /// - /// Retrieves the most relevant memories for a user based on vector similarity. - /// - /// The user's unique identifier. - /// The embedding to compare against stored memories. - /// The maximum number of memories to return. - /// A list of relevant objects. - public async Task> GetRelevantMemoriesAsync(string userId, Vector queryEmbedding, int limit = 5) - { - var sql = @"SELECT * FROM ""Memory"" - WHERE ""UserId"" = @userId - AND (""Ephemeral"" = FALSE OR ""CreatedAt"" > NOW() - INTERVAL '2 days') - ORDER BY ""Embedding"" <-> @embedding - LIMIT @limit;"; - - var memories = await Memory - .FromSqlRaw(sql, - new NpgsqlParameter("userId", userId), - new NpgsqlParameter("embedding", queryEmbedding), - new NpgsqlParameter("limit", limit)) - .ToListAsync(); - - return memories; - } - public async Task GetHybridMemoriesAsync( string userId, Vector queryEmbedding, @@ -719,63 +692,98 @@ public async Task GetHybridMemoriesAsync( var f = DateTime.SpecifyKind(from.Value, DateTimeKind.Utc); var t = DateTime.SpecifyKind(to.Value, DateTimeKind.Utc); - // semantic search inside timeframe - var sql = @"SELECT * FROM ""Memory"" - WHERE ""UserId"" = @userId - AND ""CreatedAt"" >= @from - AND ""CreatedAt"" <= @to - AND (""Ephemeral"" = FALSE OR ""CreatedAt"" > NOW() - INTERVAL '2 days') - ORDER BY ""Embedding"" <-> @embedding - LIMIT @limit;"; - - semanticMemories = await Memory + const string sql = @" + SELECT ""Id"", ""UserId"", ""CreatedAt"", + ""UserMessage"", ""BotResponse"", ""Ephemeral"" + FROM ""Memory"" + WHERE ""UserId"" = @userId + AND ""CreatedAt"" BETWEEN @from AND @to + AND (""Ephemeral"" = FALSE OR ""CreatedAt"" > NOW() - INTERVAL '2 days') + ORDER BY ""Embedding"" <-> @embedding + LIMIT @limit;"; + + var semanticDto = await MemoryDTOs .FromSqlRaw(sql, new NpgsqlParameter("userId", userId), new NpgsqlParameter("embedding", queryEmbedding), new NpgsqlParameter("from", f), new NpgsqlParameter("to", t), new NpgsqlParameter("limit", semanticLimit)) + .AsNoTracking() .ToListAsync(); - // temporal (chronological) inside timeframe + semanticMemories = [.. semanticDto.Select(d => new Memory + { + Id = d.Id, + UserId = d.UserId, + CreatedAt = d.CreatedAt, + UserMessage = d.UserMessage, + BotResponse = d.BotResponse, + Ephemeral = d.Ephemeral + })]; + temporalMemories = await Memory - .Where(m => m.UserId == userId - && m.CreatedAt >= f - && m.CreatedAt <= t - && (!m.Ephemeral || m.CreatedAt > DateTime.UtcNow.AddDays(-2))) + .Where(m => m.UserId == userId && + m.CreatedAt >= f && + m.CreatedAt <= t && + (!m.Ephemeral || m.CreatedAt > DateTime.UtcNow.AddDays(-2))) .OrderBy(m => m.CreatedAt) .Take(temporalLimit) + .Select(m => new Memory + { + Id = m.Id, + UserId = m.UserId, + CreatedAt = m.CreatedAt, + UserMessage = m.UserMessage, + BotResponse = m.BotResponse, + Ephemeral = m.Ephemeral + }) + .AsNoTracking() .ToListAsync(); } else { - // global semantic search - var sql = @"SELECT * FROM ""Memory"" - WHERE ""UserId"" = @userId - AND (""Ephemeral"" = FALSE OR ""CreatedAt"" > NOW() - INTERVAL '2 days') - ORDER BY ""Embedding"" <-> @embedding - LIMIT @limit;"; - - semanticMemories = await Memory + const string sql = @" + SELECT ""Id"", ""UserId"", ""CreatedAt"", + ""UserMessage"", ""BotResponse"", ""Ephemeral"" + FROM ""Memory"" + WHERE ""UserId"" = @userId + AND (""Ephemeral"" = FALSE OR ""CreatedAt"" > NOW() - INTERVAL '2 days') + ORDER BY ""Embedding"" <-> @embedding + LIMIT @limit;"; + + var semanticDto = await MemoryDTOs .FromSqlRaw(sql, new NpgsqlParameter("userId", userId), new NpgsqlParameter("embedding", queryEmbedding), new NpgsqlParameter("limit", semanticLimit)) + .AsNoTracking() .ToListAsync(); + + semanticMemories = [.. semanticDto.Select(d => new Memory + { + Id = d.Id, + UserId = d.UserId, + CreatedAt = d.CreatedAt, + UserMessage = d.UserMessage, + BotResponse = d.BotResponse, + Ephemeral = d.Ephemeral + })]; + + temporalMemories = []; } - // merge + deduplicate - var merged = semanticMemories - .Concat(temporalMemories) - .GroupBy(m => m.Id) - .Select(g => g.First()) - .ToList(); + // merge and deduplicate with dictionary + var merged = new Dictionary(); + foreach (var m in semanticMemories) + merged[m.Id] = m; + foreach (var m in temporalMemories) + merged.TryAdd(m.Id, m); return new HybridMemoryResult( - merged, + [.. merged.Values], SemanticCount: semanticMemories.Count, - TemporalCount: temporalMemories.Count - ); + TemporalCount: temporalMemories.Count); } public async Task> GetRecentConversationAsync( @@ -783,27 +791,36 @@ public async Task> GetRecentConversationAsync( int limit = 5, TimeSpan? maxGap = null) { - maxGap ??= TimeSpan.FromMinutes(30); // defines what “continuous” means + maxGap ??= TimeSpan.FromMinutes(30); var ordered = await Memory .Where(m => m.UserId == userId && !m.Ephemeral) .OrderByDescending(m => m.CreatedAt) - .Take(limit * 3) // grab a bit extra to detect spacing + .Take(limit * 3) + .Select(m => new Memory + { + Id = m.Id, + UserId = m.UserId, + CreatedAt = m.CreatedAt, + UserMessage = m.UserMessage, + BotResponse = m.BotResponse, + Ephemeral = m.Ephemeral + }) + .AsNoTracking() .ToListAsync(); if (ordered.Count == 0) return []; - ordered.Reverse(); // chronological order + ordered.Reverse(); - // Filter messages that are reasonably close in time var cluster = new List { ordered.Last() }; for (int i = ordered.Count - 2; i >= 0; i--) { var newer = cluster.First(); var older = ordered[i]; if (newer.CreatedAt - older.CreatedAt > maxGap) - break; // stop if this older message is too far apart in time + break; cluster.Insert(0, older); if (cluster.Count >= limit) break; diff --git a/databases/types/DTOs/memoryDTO.cs b/databases/types/DTOs/memoryDTO.cs new file mode 100644 index 00000000..4f44eb89 --- /dev/null +++ b/databases/types/DTOs/memoryDTO.cs @@ -0,0 +1,13 @@ +using System; + +namespace Bob.Database.Types.DataTransferObjects; + +public class MemoryDTO +{ + public int Id { get; set; } + public string UserId { get; set; } + public DateTime CreatedAt { get; set; } + public string UserMessage { get; set; } + public string BotResponse { get; set; } + public bool Ephemeral { get; set; } +} \ No newline at end of file diff --git a/databases/types/memory.cs b/databases/types/memory.cs index 52918620..4997cf98 100644 --- a/databases/types/memory.cs +++ b/databases/types/memory.cs @@ -1,6 +1,8 @@ using System; using Pgvector; +namespace Bob.Database.Types; + public class Memory { public int Id { get; set; } diff --git a/general-helpers/chat/temporal-handling/HybridMemoryResult.cs b/general-helpers/chat/temporal-handling/HybridMemoryResult.cs index 01918092..acc3ca52 100644 --- a/general-helpers/chat/temporal-handling/HybridMemoryResult.cs +++ b/general-helpers/chat/temporal-handling/HybridMemoryResult.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using Bob.Database.Types; namespace BobTheBot.Chat.MemoryHandling;