Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 83 additions & 66 deletions databases/Databases.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Bob.Commands.Helpers;
using Bob.Database.Types;
using Bob.Database.Types.DataTransferObjects;
using BobTheBot.Chat.MemoryHandling;
using DotNetEnv;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Design;
using Npgsql;
using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata;
using Pgvector;
using Pgvector.EntityFrameworkCore;


namespace Bob.Database
Expand Down Expand Up @@ -40,6 +37,7 @@ public class BobEntities(DbContextOptions<BobEntities> options) : DbContext(opti
public virtual DbSet<ScheduledAnnouncement> ScheduledAnnouncement { get; set; }
public virtual DbSet<ReactBoardMessage> ReactBoardMessage { get; set; }
public virtual DbSet<Memory> Memory { get; set; }
public DbSet<MemoryDTO> MemoryDTOs { get; set; } // Only for projection queries
public virtual DbSet<Tag> Tag { get; set; }

protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
Expand Down Expand Up @@ -678,31 +676,6 @@ public async Task StoreMemoryAsync(string userId, string userMessage, string bot
await SaveChangesAsync();
}

/// <summary>
/// Looks up the stored memories of a single user that lie closest to a query embedding.
/// </summary>
/// <remarks>
/// Rows flagged as ephemeral are only considered while they are less than two days old.
/// Results are ordered by the <c>&lt;-&gt;</c> distance between the stored embedding and
/// <paramref name="queryEmbedding"/>, smallest distance first.
/// </remarks>
/// <param name="userId">Identifier of the user whose memories are searched.</param>
/// <param name="queryEmbedding">Embedding that the stored memories are ranked against.</param>
/// <param name="limit">Upper bound on the number of rows returned.</param>
/// <returns>The matching <see cref="Memory"/> entities, nearest first.</returns>
public async Task<List<Memory>> GetRelevantMemoriesAsync(string userId, Vector queryEmbedding, int limit = 5)
{
    // Every value is bound through a named parameter; nothing is concatenated into the SQL text.
    const string nearestMemoriesSql = @"SELECT * FROM ""Memory""
WHERE ""UserId"" = @userId
AND (""Ephemeral"" = FALSE OR ""CreatedAt"" > NOW() - INTERVAL '2 days')
ORDER BY ""Embedding"" <-> @embedding
LIMIT @limit;";

    object[] sqlArguments =
    [
        new NpgsqlParameter("userId", userId),
        new NpgsqlParameter("embedding", queryEmbedding),
        new NpgsqlParameter("limit", limit)
    ];

    return await Memory
        .FromSqlRaw(nearestMemoriesSql, sqlArguments)
        .ToListAsync();
}

public async Task<HybridMemoryResult> GetHybridMemoriesAsync(
string userId,
Vector queryEmbedding,
Expand All @@ -719,91 +692,135 @@ public async Task<HybridMemoryResult> GetHybridMemoriesAsync(
var f = DateTime.SpecifyKind(from.Value, DateTimeKind.Utc);
var t = DateTime.SpecifyKind(to.Value, DateTimeKind.Utc);

// semantic search inside timeframe
var sql = @"SELECT * FROM ""Memory""
WHERE ""UserId"" = @userId
AND ""CreatedAt"" >= @from
AND ""CreatedAt"" <= @to
AND (""Ephemeral"" = FALSE OR ""CreatedAt"" > NOW() - INTERVAL '2 days')
ORDER BY ""Embedding"" <-> @embedding
LIMIT @limit;";

semanticMemories = await Memory
const string sql = @"
SELECT ""Id"", ""UserId"", ""CreatedAt"",
""UserMessage"", ""BotResponse"", ""Ephemeral""
FROM ""Memory""
WHERE ""UserId"" = @userId
AND ""CreatedAt"" BETWEEN @from AND @to
AND (""Ephemeral"" = FALSE OR ""CreatedAt"" > NOW() - INTERVAL '2 days')
ORDER BY ""Embedding"" <-> @embedding
LIMIT @limit;";

var semanticDto = await MemoryDTOs
.FromSqlRaw(sql,
new NpgsqlParameter("userId", userId),
new NpgsqlParameter("embedding", queryEmbedding),
new NpgsqlParameter("from", f),
new NpgsqlParameter("to", t),
new NpgsqlParameter("limit", semanticLimit))
.AsNoTracking()
.ToListAsync();

// temporal (chronological) inside timeframe
semanticMemories = [.. semanticDto.Select(d => new Memory
{
Id = d.Id,
UserId = d.UserId,
CreatedAt = d.CreatedAt,
UserMessage = d.UserMessage,
BotResponse = d.BotResponse,
Ephemeral = d.Ephemeral
})];

temporalMemories = await Memory
.Where(m => m.UserId == userId
&& m.CreatedAt >= f
&& m.CreatedAt <= t
&& (!m.Ephemeral || m.CreatedAt > DateTime.UtcNow.AddDays(-2)))
.Where(m => m.UserId == userId &&
m.CreatedAt >= f &&
m.CreatedAt <= t &&
(!m.Ephemeral || m.CreatedAt > DateTime.UtcNow.AddDays(-2)))
.OrderBy(m => m.CreatedAt)
.Take(temporalLimit)
.Select(m => new Memory
{
Id = m.Id,
UserId = m.UserId,
CreatedAt = m.CreatedAt,
UserMessage = m.UserMessage,
BotResponse = m.BotResponse,
Ephemeral = m.Ephemeral
})
.AsNoTracking()
.ToListAsync();
}
else
{
// global semantic search
var sql = @"SELECT * FROM ""Memory""
WHERE ""UserId"" = @userId
AND (""Ephemeral"" = FALSE OR ""CreatedAt"" > NOW() - INTERVAL '2 days')
ORDER BY ""Embedding"" <-> @embedding
LIMIT @limit;";

semanticMemories = await Memory
const string sql = @"
SELECT ""Id"", ""UserId"", ""CreatedAt"",
""UserMessage"", ""BotResponse"", ""Ephemeral""
FROM ""Memory""
WHERE ""UserId"" = @userId
AND (""Ephemeral"" = FALSE OR ""CreatedAt"" > NOW() - INTERVAL '2 days')
ORDER BY ""Embedding"" <-> @embedding
LIMIT @limit;";

var semanticDto = await MemoryDTOs
.FromSqlRaw(sql,
new NpgsqlParameter("userId", userId),
new NpgsqlParameter("embedding", queryEmbedding),
new NpgsqlParameter("limit", semanticLimit))
.AsNoTracking()
.ToListAsync();

semanticMemories = [.. semanticDto.Select(d => new Memory
{
Id = d.Id,
UserId = d.UserId,
CreatedAt = d.CreatedAt,
UserMessage = d.UserMessage,
BotResponse = d.BotResponse,
Ephemeral = d.Ephemeral
})];

temporalMemories = [];
}

// merge + deduplicate
var merged = semanticMemories
.Concat(temporalMemories)
.GroupBy(m => m.Id)
.Select(g => g.First())
.ToList();
// merge and deduplicate with dictionary
var merged = new Dictionary<int, Memory>();
foreach (var m in semanticMemories)
merged[m.Id] = m;
foreach (var m in temporalMemories)
merged.TryAdd(m.Id, m);

return new HybridMemoryResult(
merged,
[.. merged.Values],
SemanticCount: semanticMemories.Count,
TemporalCount: temporalMemories.Count
);
TemporalCount: temporalMemories.Count);
}

public async Task<List<Memory>> GetRecentConversationAsync(
string userId,
int limit = 5,
TimeSpan? maxGap = null)
{
maxGap ??= TimeSpan.FromMinutes(30); // defines what “continuous” means
maxGap ??= TimeSpan.FromMinutes(30);

var ordered = await Memory
.Where(m => m.UserId == userId && !m.Ephemeral)
.OrderByDescending(m => m.CreatedAt)
.Take(limit * 3) // grab a bit extra to detect spacing
.Take(limit * 3)
.Select(m => new Memory
{
Id = m.Id,
UserId = m.UserId,
CreatedAt = m.CreatedAt,
UserMessage = m.UserMessage,
BotResponse = m.BotResponse,
Ephemeral = m.Ephemeral
})
.AsNoTracking()
.ToListAsync();

if (ordered.Count == 0)
return [];

ordered.Reverse(); // chronological order
ordered.Reverse();

// Filter messages that are reasonably close in time
var cluster = new List<Memory> { ordered.Last() };
for (int i = ordered.Count - 2; i >= 0; i--)
{
var newer = cluster.First();
var older = ordered[i];
if (newer.CreatedAt - older.CreatedAt > maxGap)
break; // stop if this older message is too far apart in time
break;
cluster.Insert(0, older);
if (cluster.Count >= limit)
break;
Expand Down
13 changes: 13 additions & 0 deletions databases/types/DTOs/memoryDTO.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System;

namespace Bob.Database.Types.DataTransferObjects;

/// <summary>
/// Lightweight row shape for memory queries that do not need the embedding vector.
/// </summary>
/// <remarks>
/// Holds only the six scalar columns (<c>Id</c>, <c>UserId</c>, <c>CreatedAt</c>,
/// <c>UserMessage</c>, <c>BotResponse</c>, <c>Ephemeral</c>); there is deliberately no
/// embedding property. All members are plain read/write auto-properties.
/// </remarks>
public class MemoryDTO
{
/// <summary>Integer identifier of the memory row.</summary>
public int Id { get; set; }
/// <summary>Identifier of the user the memory belongs to.</summary>
public string UserId { get; set; }
/// <summary>Creation timestamp of the memory. NOTE(review): presumably UTC — confirm against the writer.</summary>
public DateTime CreatedAt { get; set; }
/// <summary>Text of the user's message (presumably the message that produced this memory — verify).</summary>
public string UserMessage { get; set; }
/// <summary>Text of the bot's response (presumably the reply paired with <see cref="UserMessage"/> — verify).</summary>
public string BotResponse { get; set; }
/// <summary>Flag marking the memory as ephemeral, i.e. short-lived rather than permanent.</summary>
public bool Ephemeral { get; set; }
}
2 changes: 2 additions & 0 deletions databases/types/memory.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System;
using Pgvector;

namespace Bob.Database.Types;

public class Memory
{
public int Id { get; set; }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Generic;
using Bob.Database.Types;

namespace BobTheBot.Chat.MemoryHandling;

Expand Down
Loading