diff --git a/.gitignore b/.gitignore index 518d386d146..718142e387e 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ mobile/native/android/apps/ensu/crypto-auth-core/src/main/jniLibs/ mobile/native/android/apps/ensu/crypto-auth-core/src/main/java/io/ente/ensu/crypto/core_uniffi.kt mobile/native/android/apps/ensu/crypto-auth-core/src/main/java/io/ente/ensu/crypto/core.kt mobile/native/android/packages/rust/.gradle/ +mobile/packages/**/android/.gradle/ mobile/native/android/packages/rust/src/main/jniLibs/ mobile/native/android/packages/rust/src/main/kotlin/io/ente/labs/inference_rs/inference_rs_uniffi.kt mobile/native/android/packages/rust/src/main/kotlin/io/ente/labs/inference_rs/inference.kt diff --git a/mobile/apps/photos/lib/db/ml/base.dart b/mobile/apps/photos/lib/db/ml/base.dart index 8c17601bbbd..50e4db2ecbc 100644 --- a/mobile/apps/photos/lib/db/ml/base.dart +++ b/mobile/apps/photos/lib/db/ml/base.dart @@ -10,7 +10,6 @@ import "package:photos/services/machine_learning/face_ml/face_clustering/face_db abstract class IMLDataDB { Future bulkInsertFaces(List faces); Future bulkInsertPetFaces(List petFaces); - Future bulkInsertPetBodies(List petBodies); Future updateFaceIdToClusterId(Map faceIDToClusterID); Future> faceIndexedFileIds({int minimumMlVersion}); Future getFaceIndexedFileCount({int minimumMlVersion}); @@ -36,7 +35,6 @@ abstract class IMLDataDB { }); Future?> getFacesForGivenFileID(T fileUploadID); Future?> getPetFacesForFileID(T fileUploadID); - Future?> getPetBodiesForFileID(T fileUploadID); Future>> getFileIDsToFacesWithoutEmbedding(); Future>> getClusterToFaceIDs( @@ -117,7 +115,6 @@ abstract class IMLDataDB { }); Future updatePetFaceVectorIds(Map petFaceIdToVectorId); - Future updatePetBodyVectorIds(Map petBodyIdToVectorId); Future> getAllClipVectors(); Future> clipIndexedFileWithVersion(); diff --git a/mobile/apps/photos/lib/db/ml/db.dart b/mobile/apps/photos/lib/db/ml/db.dart index 5791868cbc2..cfce1b9bf16 100644 --- 
a/mobile/apps/photos/lib/db/ml/db.dart +++ b/mobile/apps/photos/lib/db/ml/db.dart @@ -106,6 +106,11 @@ class MLDataDB with SqlDbBase implements IMLDataDB { createPetBodiesTable, createPetFaceVectorIdMappingTable, createPetBodyVectorIdMappingTable, + createPetFaceClustersTable, + petFcClusterIDIndex, + createPetClusterPetTable, + createNotPetFeedbackTable, + petFacesSpeciesIndex, ]; static const List _offlineMigrationScripts = [ ..._defaultMigrationScripts, @@ -218,41 +223,6 @@ class MLDataDB with SqlDbBase implements IMLDataDB { } } - @override - Future bulkInsertPetBodies(List petBodies) async { - final db = await asyncDB; - const batchSize = 500; - final numBatches = (petBodies.length / batchSize).ceil(); - for (int i = 0; i < numBatches; i++) { - final start = i * batchSize; - final end = min((i + 1) * batchSize, petBodies.length); - final batch = petBodies.sublist(start, end); - - const String sql = ''' - INSERT INTO $petBodiesTable ( - $fileIDColumn, $petBodyIDColumn, $detectionColumn, $bodyVectorIdColumn, $speciesColumn, $bodyScore, $imageHeight, $imageWidth, $mlVersionColumn - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
- ON CONFLICT($fileIDColumn, $petBodyIDColumn) DO UPDATE SET $detectionColumn = excluded.$detectionColumn, $bodyVectorIdColumn = excluded.$bodyVectorIdColumn, $speciesColumn = excluded.$speciesColumn, $bodyScore = excluded.$bodyScore, $imageHeight = excluded.$imageHeight, $imageWidth = excluded.$imageWidth, $mlVersionColumn = excluded.$mlVersionColumn - '''; - final parameterSets = batch.map((obj) { - final map = obj.toMap(); - return [ - map[fileIDColumn], - map[petBodyIDColumn], - map[detectionColumn], - map[bodyVectorIdColumn], - map[speciesColumn], - map[bodyScore], - map[imageHeight], - map[imageWidth], - map[mlVersionColumn], - ]; - }).toList(); - - await db.executeBatch(sql, parameterSets); - } - } - @override Future updatePetFaceVectorIds( Map petFaceIdToVectorId, @@ -277,30 +247,6 @@ class MLDataDB with SqlDbBase implements IMLDataDB { } } - @override - Future updatePetBodyVectorIds( - Map petBodyIdToVectorId, - ) async { - if (petBodyIdToVectorId.isEmpty) return; - final db = await asyncDB; - const batchSize = 500; - final entries = petBodyIdToVectorId.entries.toList(); - final numBatches = (entries.length / batchSize).ceil(); - for (int i = 0; i < numBatches; i++) { - final start = i * batchSize; - final end = min((i + 1) * batchSize, entries.length); - final batch = entries.sublist(start, end); - - const String sql = ''' - UPDATE $petBodiesTable - SET $bodyVectorIdColumn = ? - WHERE $petBodyIDColumn = ? - '''; - final parameterSets = batch.map((e) => [e.value, e.key]).toList(); - await db.executeBatch(sql, parameterSets); - } - } - /// Store pet face embeddings into PetVectorDB after SQLite insert. Future storePetFaceEmbeddings( List dbPetFaces, @@ -367,76 +313,6 @@ class MLDataDB with SqlDbBase implements IMLDataDB { } } - /// Store pet body embeddings into PetVectorDB after SQLite insert. 
- Future storePetBodyEmbeddings( - List dbPetBodies, - List petBodies, - ) async { - if (dbPetBodies.length != petBodies.length) { - throw StateError( - 'dbPetBodies.length (${dbPetBodies.length}) != petBodies.length (${petBodies.length})', - ); - } - try { - final db = await asyncDB; - // Group by species (0 = dog, 1 = cat) - final bySpecies = >{}; - for (int i = 0; i < dbPetBodies.length; i++) { - final species = dbPetBodies[i].species; - bySpecies.putIfAbsent(species, () => []); - bySpecies[species]!.add((dbPetBodies[i], petBodies[i])); - } - for (final entry in bySpecies.entries) { - final vdb = PetVectorDB.forModel( - species: entry.key, - isFace: false, - offline: _isOffline, - ); - final bodyIds = entry.value.map((e) => e.$1.petBodyId).toList(); - final idMap = await vdb.getObjectVectorIdMap( - bodyIds, - db: db, - createIfMissing: true, - ); - final vectorIds = []; - final embeddings = []; - final insertedBodyIds = []; - for (final (dbBody, bodyResult) in entry.value) { - final vid = idMap[dbBody.petBodyId]; - if (vid == null) continue; - final emb = Float32List.fromList(bodyResult.embedding); - if (emb.length != PetVectorDB.bodyDimension) { - _logger.warning( - "Skipping pet body embedding with wrong dimension ${emb.length}", - ); - continue; - } - vectorIds.add(vid); - embeddings.add(emb); - insertedBodyIds.add(dbBody.petBodyId); - } - if (vectorIds.isNotEmpty) { - await vdb.bulkInsertEmbeddings( - vectorIds: vectorIds, - embeddings: embeddings, - ); - final updateMap = Map.fromIterables( - insertedBodyIds, - vectorIds, - ); - await updatePetBodyVectorIds(updateMap); - } - } - } catch (e, s) { - _logger.severe( - "Failed to store pet body embeddings in vector DB", - e, - s, - ); - rethrow; - } - } - @override Future updateFaceIdToClusterId( Map faceIDToClusterID, @@ -576,6 +452,11 @@ class MLDataDB with SqlDbBase implements IMLDataDB { await db.execute(deletePetBodiesTable); await db.execute(deletePetFaceVectorIdMappingTable); await 
db.execute(deletePetBodyVectorIdMappingTable); + await db.execute(deletePetFaceClustersTable); + await db.execute(deletePetClusterPetTable); + if (!_isOffline) { + await db.execute(deleteNotPetFeedbackTable); + } final petVdbs = _isOffline ? PetVectorDB.allOfflineInstances : PetVectorDB.allInstances; for (final vdb in petVdbs) { @@ -767,25 +648,6 @@ class MLDataDB with SqlDbBase implements IMLDataDB { return maps.map((e) => DBPetFace.fromMap(e)).toList(); } - @override - Future?> getPetBodiesForFileID( - int fileUploadID, - ) async { - final db = await asyncDB; - const String query = ''' - SELECT * FROM $petBodiesTable - WHERE $fileIDColumn = ? AND $speciesColumn != -1 - '''; - final List> maps = await db.getAll( - query, - [fileUploadID], - ); - if (maps.isEmpty) { - return null; - } - return maps.map((e) => DBPetBody.fromMap(e)).toList(); - } - @override Future>> getFileIDsToFacesWithoutEmbedding() async { @@ -1718,6 +1580,31 @@ class MLDataDB with SqlDbBase implements IMLDataDB { } } + /// WARNING: Deletes all pet clustering data (clusters, summaries, feedback, + /// centroid mappings). Preserves indexed pet faces/bodies. Debug only! 
+ Future dropPetClusteringData() async { + try { + final db = await asyncDB; + + // Clear pet clustering tables + await db.execute(deletePetFaceClustersTable); + await db.execute(deletePetClusterPetTable); + if (!_isOffline) { + await db.execute(deleteNotPetFeedbackTable); + } + + // Recreate the tables + await db.execute(createPetFaceClustersTable); + await db.execute(petFcClusterIDIndex); + await db.execute(createPetClusterPetTable); + if (!_isOffline) { + await db.execute(createNotPetFeedbackTable); + } + } catch (e, s) { + _logger.severe('Error dropping pet clustering data', e, s); + } + } + @override Future> getFileIDsOfPersonID(String personID) async { final db = await asyncDB; @@ -2446,40 +2333,20 @@ class MLDataDB with SqlDbBase implements IMLDataDB { 'FROM $petFacesTable WHERE $fileIDColumn IN ($placeholders)', fileIDs, ); - final bodyRows = await db.getAll( - 'SELECT $petBodyIDColumn, $bodyVectorIdColumn, $speciesColumn ' - 'FROM $petBodiesTable WHERE $fileIDColumn IN ($placeholders)', - fileIDs, - ); - // Group face vector IDs by species for targeted vector DB deletion. final faceVidsBySpecies = >{}; final faceIdsToRemove = []; for (final row in faceRows) { final vid = row[faceVectorIdColumn] as int?; final petFaceId = row[petFaceIDColumn] as String; + final species = row[speciesColumn] as int; faceIdsToRemove.add(petFaceId); if (vid != null) { - final species = row[speciesColumn] as int; faceVidsBySpecies.putIfAbsent(species, () => []); faceVidsBySpecies[species]!.add(vid); } } - // Group body vector IDs by species. - final bodyVidsBySpecies = >{}; - final bodyIdsToRemove = []; - for (final row in bodyRows) { - final vid = row[bodyVectorIdColumn] as int?; - final petBodyId = row[petBodyIDColumn] as String; - bodyIdsToRemove.add(petBodyId); - if (vid != null) { - final species = row[speciesColumn] as int; - bodyVidsBySpecies.putIfAbsent(species, () => []); - bodyVidsBySpecies[species]!.add(vid); - } - } - // Delete from usearch vector indexes. 
for (final entry in faceVidsBySpecies.entries) { try { @@ -2493,35 +2360,21 @@ class MLDataDB with SqlDbBase implements IMLDataDB { _logger.warning("Failed to delete pet face vectors", e, s); } } - for (final entry in bodyVidsBySpecies.entries) { - try { - final vdb = PetVectorDB.forModel( - species: entry.key, - isFace: false, - offline: _isOffline, - ); - await vdb.deleteEmbeddings(entry.value); - } catch (e, s) { - _logger.warning("Failed to delete pet body vectors", e, s); - } - } - - // Delete mapping table entries and detection rows atomically. + // Delete mapping table entries, cluster assignments, and detection rows + // atomically. await db.writeTransaction((tx) async { if (faceIdsToRemove.isNotEmpty) { - final placeholders = List.filled(faceIdsToRemove.length, '?').join(','); + final fpH = List.filled(faceIdsToRemove.length, '?').join(','); await tx.execute( 'DELETE FROM $petFaceVectorIdMappingTable ' - 'WHERE $petFaceIDColumn IN ($placeholders)', + 'WHERE $petFaceIDColumn IN ($fpH)', faceIdsToRemove, ); - } - if (bodyIdsToRemove.isNotEmpty) { - final placeholders = List.filled(bodyIdsToRemove.length, '?').join(','); + // Remove cluster assignments for deleted faces await tx.execute( - 'DELETE FROM $petBodyVectorIdMappingTable ' - 'WHERE $petBodyIDColumn IN ($placeholders)', - bodyIdsToRemove, + 'DELETE FROM $petFaceClustersTable ' + 'WHERE $petFaceIDColumn IN ($fpH)', + faceIdsToRemove, ); } await tx.execute( diff --git a/mobile/apps/photos/lib/db/ml/db_pet_model_mappers.dart b/mobile/apps/photos/lib/db/ml/db_pet_model_mappers.dart index 7d969231772..b27abd44200 100644 --- a/mobile/apps/photos/lib/db/ml/db_pet_model_mappers.dart +++ b/mobile/apps/photos/lib/db/ml/db_pet_model_mappers.dart @@ -73,59 +73,3 @@ class DBPetFace { ); } } - -// ── Detected Object DB Mapper ── - -/// Represents a row in the [petBodiesTable]. -class DBPetBody { - final int fileId; - final String petBodyId; - final String detection; - final int? 
bodyVectorId; - final int species; - final double score; - final int imageHeight; - final int imageWidth; - final int mlVersion; - - DBPetBody({ - required this.fileId, - required this.petBodyId, - required this.detection, - required this.bodyVectorId, - required this.species, - required this.score, - required this.imageHeight, - required this.imageWidth, - required this.mlVersion, - }); - - Map toMap() { - return { - fileIDColumn: fileId, - petBodyIDColumn: petBodyId, - detectionColumn: detection, - bodyVectorIdColumn: bodyVectorId, - speciesColumn: species, - // Use string literals for keys that collide with instance member names. - 'score': score, - 'height': imageHeight, - 'width': imageWidth, - mlVersionColumn: mlVersion, - }; - } - - factory DBPetBody.fromMap(Map map) { - return DBPetBody( - fileId: map[fileIDColumn] as int, - petBodyId: map[petBodyIDColumn] as String, - detection: map[detectionColumn] as String, - bodyVectorId: map[bodyVectorIdColumn] as int?, - species: map[speciesColumn] as int, - score: parseIntOrDoubleAsDouble(map['score']) ?? 
0.0, - imageHeight: map['height'] as int, - imageWidth: map['width'] as int, - mlVersion: map[mlVersionColumn] as int, - ); - } -} diff --git a/mobile/apps/photos/lib/db/ml/pet_vector_db.dart b/mobile/apps/photos/lib/db/ml/pet_vector_db.dart index c48aa54571f..9be970751a2 100644 --- a/mobile/apps/photos/lib/db/ml/pet_vector_db.dart +++ b/mobile/apps/photos/lib/db/ml/pet_vector_db.dart @@ -25,7 +25,6 @@ class PetVectorDB { static final Logger _logger = Logger("PetVectorDB"); static const int faceDimension = 128; - static const int bodyDimension = 192; final String _databaseName; final BigInt _embeddingDimension; @@ -50,16 +49,6 @@ class PetVectorDB { BigInt.from(faceDimension), ); - static final dogBody = PetVectorDB._named( - "ente.ml.vectordb.pet.dog_body.usearch", - BigInt.from(bodyDimension), - ); - - static final catBody = PetVectorDB._named( - "ente.ml.vectordb.pet.cat_body.usearch", - BigInt.from(bodyDimension), - ); - // ── Offline vector spaces ── static final offlineDogFace = PetVectorDB._named( @@ -72,38 +61,23 @@ class PetVectorDB { BigInt.from(faceDimension), ); - static final offlineDogBody = PetVectorDB._named( - "ente.ml.offline.vectordb.pet.dog_body.usearch", - BigInt.from(bodyDimension), - ); - - static final offlineCatBody = PetVectorDB._named( - "ente.ml.offline.vectordb.pet.cat_body.usearch", - BigInt.from(bodyDimension), - ); - /// All online vector DB instances for iteration. static final List allInstances = [ dogFace, catFace, - dogBody, - catBody, ]; /// All offline vector DB instances for iteration. static final List allOfflineInstances = [ offlineDogFace, offlineCatFace, - offlineDogBody, - offlineCatBody, ]; - /// Get the correct vector DB for a species + embedding type. + /// Get the correct face vector DB for a species. 
/// [species]: 0 = dog, 1 = cat - /// [isFace]: true = face embedding, false = body embedding static PetVectorDB forModel({ required int species, - required bool isFace, + bool isFace = true, bool offline = false, }) { assert( @@ -111,22 +85,20 @@ class PetVectorDB { 'Invalid pet species: $species (expected 0=dog or 1=cat)', ); if (offline) { - if (species == 0) { - return isFace ? offlineDogFace : offlineDogBody; - } else { - return isFace ? offlineCatFace : offlineCatBody; - } - } - if (species == 0) { - return isFace ? dogFace : dogBody; - } else { - return isFace ? catFace : catBody; + return species == 0 ? offlineDogFace : offlineCatFace; } + return species == 0 ? dogFace : catFace; } Future? _vectorDbFuture; final Lock _writeLock = Lock(); + /// Get the on-disk file path for this vector index without opening it. + Future getIndexPath() async { + final documentsDirectory = await getApplicationDocumentsDirectory(); + return join(documentsDirectory.path, _databaseName); + } + Future get _vectorDB async { _vectorDbFuture ??= _initVectorDB(); return _vectorDbFuture!; @@ -205,48 +177,6 @@ class PetVectorDB { return result; } - /// Get or create integer vector IDs for body/object embeddings. - /// Uses [petBodyVectorIdMappingTable] — separate ID space from face embeddings. - Future> getObjectVectorIdMap( - Iterable objectIds, { - required SqliteDatabase db, - bool createIfMissing = false, - }) async { - final uniqueIds = objectIds.toSet().toList(growable: false); - if (uniqueIds.isEmpty) return {}; - - if (createIfMissing) { - const insertSql = ''' - INSERT OR IGNORE INTO $petBodyVectorIdMappingTable ($petBodyIDColumn) - VALUES (?) 
- '''; - final insertParams = >[]; - for (final id in uniqueIds) { - insertParams.add([id]); - } - await db.executeBatch(insertSql, insertParams); - } - - final result = {}; - const chunkSize = 800; - for (int i = 0; i < uniqueIds.length; i += chunkSize) { - final chunk = uniqueIds.sublist(i, min(i + chunkSize, uniqueIds.length)); - final rows = await db.getAll( - ''' - SELECT $petBodyIDColumn, $petBodyVectorIdColumn - FROM $petBodyVectorIdMappingTable - WHERE $petBodyIDColumn IN (${List.filled(chunk.length, '?').join(',')}) - ''', - chunk, - ); - for (final row in rows) { - result[row[petBodyIDColumn] as String] = - row[petBodyVectorIdColumn] as int; - } - } - return result; - } - // ── Vector Operations ── Future _runWriteOperation( diff --git a/mobile/apps/photos/lib/db/ml/schema.dart b/mobile/apps/photos/lib/db/ml/schema.dart index 1c3b0753687..28f87993f7a 100644 --- a/mobile/apps/photos/lib/db/ml/schema.dart +++ b/mobile/apps/photos/lib/db/ml/schema.dart @@ -213,6 +213,9 @@ const createPetFacesTable = '''CREATE TABLE IF NOT EXISTS $petFacesTable ( const deletePetFacesTable = 'DELETE FROM $petFacesTable'; +const petFacesSpeciesIndex = + 'CREATE INDEX IF NOT EXISTS idx_pet_faces_species ON $petFacesTable ($speciesColumn)'; + // ── Pet Bodies Table ── const petBodiesTable = 'pet_bodies'; @@ -265,3 +268,46 @@ CREATE TABLE IF NOT EXISTS $petBodyVectorIdMappingTable ( const deletePetBodyVectorIdMappingTable = 'DELETE FROM $petBodyVectorIdMappingTable'; + +// ── Pet Face Clusters Table ── + +const petFaceClustersTable = 'pet_face_clusters'; + +const createPetFaceClustersTable = ''' +CREATE TABLE IF NOT EXISTS $petFaceClustersTable ( + $petFaceIDColumn TEXT NOT NULL PRIMARY KEY, + $clusterIDColumn TEXT NOT NULL +); +'''; + +const petFcClusterIDIndex = + 'CREATE INDEX IF NOT EXISTS idx_pet_fc_cluster ON $petFaceClustersTable ($clusterIDColumn)'; + +// ── Pet Cluster → Pet Mapping Table ── + +const petClusterPetTable = 'pet_cluster_pet'; +const petIdColumn = 'pet_id'; + 
+const createPetClusterPetTable = ''' +CREATE TABLE IF NOT EXISTS $petClusterPetTable ( + $clusterIDColumn TEXT NOT NULL PRIMARY KEY, + $petIdColumn TEXT NOT NULL +); +'''; + +const deletePetFaceClustersTable = 'DELETE FROM $petFaceClustersTable'; +const deletePetClusterPetTable = 'DELETE FROM $petClusterPetTable'; + +// ── Not-Pet Feedback Table ── + +const notPetFeedbackTable = 'not_pet_feedback'; + +const createNotPetFeedbackTable = ''' +CREATE TABLE IF NOT EXISTS $notPetFeedbackTable ( + $clusterIDColumn TEXT NOT NULL, + $petFaceIDColumn TEXT NOT NULL, + PRIMARY KEY($clusterIDColumn, $petFaceIDColumn) +); +'''; + +const deleteNotPetFeedbackTable = 'DELETE FROM $notPetFeedbackTable'; diff --git a/mobile/apps/photos/lib/events/pets_changed_event.dart b/mobile/apps/photos/lib/events/pets_changed_event.dart new file mode 100644 index 00000000000..3bd6473a44b --- /dev/null +++ b/mobile/apps/photos/lib/events/pets_changed_event.dart @@ -0,0 +1,10 @@ +import "package:photos/events/event.dart"; + +class PetsChangedEvent extends Event { + final String source; + + PetsChangedEvent({this.source = ""}); + + @override + String get reason => '$runtimeType{"via": $source}'; +} diff --git a/mobile/apps/photos/lib/gateways/entity/models/type.dart b/mobile/apps/photos/lib/gateways/entity/models/type.dart index 0fbb8aef89f..0acfc03033f 100644 --- a/mobile/apps/photos/lib/gateways/entity/models/type.dart +++ b/mobile/apps/photos/lib/gateways/entity/models/type.dart @@ -4,12 +4,14 @@ enum EntityType { cgroup, unknown, smartAlbum, - memory; + memory, + pet; bool get isZipped { switch (this) { case EntityType.location: case EntityType.person: + case EntityType.pet: return false; default: return true; @@ -28,6 +30,8 @@ enum EntityType { return "smart_album"; case EntityType.memory: return "memory"; + case EntityType.pet: + return "pet"; case EntityType.unknown: return "unknown"; } diff --git a/mobile/apps/photos/lib/l10n/intl_en.arb b/mobile/apps/photos/lib/l10n/intl_en.arb index 
3d51b998c88..dd9e6e415f6 100644 --- a/mobile/apps/photos/lib/l10n/intl_en.arb +++ b/mobile/apps/photos/lib/l10n/intl_en.arb @@ -1677,6 +1677,7 @@ }, "faces": "Faces", "people": "People", + "peopleAndPets": "People & Pets", "contents": "Contents", "addNew": "Add new", "@addNew": { @@ -1734,6 +1735,7 @@ "viewPersonToUnlink": "View {name} to unlink", "enterName": "Enter name", "savePerson": "Save person", + "savePet": "Save pet", "editPerson": "Edit person", "mergedPhotos": "Merged photos", "orMergeWithExistingPerson": "Or merge with existing", @@ -3075,5 +3077,14 @@ "offlineNameFaceBannerSubtitle": "Sign up to tag them and easily find their photos later.", "signUp": "Sign up", "dog": "Dog", - "cat": "Cat" + "cat": "Cat", + "pet": "Pet", + "moveTo": "Move to", + "notThisPet": "Not this pet", + "noClusters": "No clusters", + "noNamedPetsToMerge": "No named pets to merge with.\nName a pet first, then merge.", + "viewClusters": "View clusters", + "ignorePet": "Ignore pet", + "areYouSureYouWantToIgnoreThisPet": "Are you sure you want to ignore this pet?", + "thePetGroupsWillNotBeDisplayed": "The pet groups will not be displayed in the pets section anymore. Photos will remain untouched." 
} diff --git a/mobile/apps/photos/lib/main.dart b/mobile/apps/photos/lib/main.dart index f8c35d8198d..12d64a988d1 100644 --- a/mobile/apps/photos/lib/main.dart +++ b/mobile/apps/photos/lib/main.dart @@ -35,6 +35,7 @@ import 'package:photos/services/home_widget_service.dart'; import 'package:photos/services/local_file_update_service.dart'; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import 'package:photos/services/machine_learning/ml_service.dart'; +import "package:photos/services/machine_learning/pet_ml/pet_service.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; import 'package:photos/services/memory_lane/memory_lane_service.dart'; import 'package:photos/services/memory_share_service.dart'; @@ -383,6 +384,7 @@ Future _init(bool isBackground, {String via = ''}) async { unawaited(MLService.instance.init()); PersonService.init(entityService, MLDataDB.instance, preferences); await PersonService.instance.refreshPersonCache(); + await PetService.init(entityService, MLDataDB.instance); EnteWakeLockService.instance.init(preferences); wrappedService.scheduleInitialLoad(); logLocalSettings(); diff --git a/mobile/apps/photos/lib/models/gallery_type.dart b/mobile/apps/photos/lib/models/gallery_type.dart index 88e8187131e..a2cdb8ca028 100644 --- a/mobile/apps/photos/lib/models/gallery_type.dart +++ b/mobile/apps/photos/lib/models/gallery_type.dart @@ -21,6 +21,7 @@ enum GalleryType { quickLink, peopleTag, cluster, + petCluster, sharedPublicCollection, magic, cleanupHiddenFromDevice, @@ -47,6 +48,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.hiddenOwnedCollection: case GalleryType.trash: case GalleryType.cluster: + case GalleryType.petCluster: case GalleryType.sharedPublicCollection: case GalleryType.deleteSuggestions: case GalleryType.cleanupHiddenFromDevice: @@ -73,6 +75,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.sharedCollection: 
case GalleryType.locationTag: case GalleryType.cluster: + case GalleryType.petCluster: case GalleryType.sharedPublicCollection: case GalleryType.magic: case GalleryType.deleteSuggestions: @@ -95,6 +98,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.quickLink: case GalleryType.peopleTag: case GalleryType.cluster: + case GalleryType.petCluster: case GalleryType.magic: return true; case GalleryType.trash: @@ -125,6 +129,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.quickLink: case GalleryType.peopleTag: case GalleryType.cluster: + case GalleryType.petCluster: case GalleryType.magic: return true; case GalleryType.trash: @@ -146,6 +151,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.locationTag: case GalleryType.peopleTag: case GalleryType.cluster: + case GalleryType.petCluster: case GalleryType.magic: return true; case GalleryType.hiddenSection: @@ -177,6 +183,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.archive: case GalleryType.localFolder: case GalleryType.cluster: + case GalleryType.petCluster: case GalleryType.trash: case GalleryType.locationTag: case GalleryType.sharedPublicCollection: @@ -207,6 +214,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.trash: case GalleryType.sharedCollection: case GalleryType.cluster: + case GalleryType.petCluster: case GalleryType.sharedPublicCollection: case GalleryType.deleteSuggestions: case GalleryType.cleanupHiddenFromDevice: @@ -230,6 +238,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.magic: case GalleryType.peopleTag: case GalleryType.cluster: + case GalleryType.petCluster: return true; case GalleryType.hiddenSection: @@ -266,6 +275,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.quickLink: case GalleryType.favorite: case GalleryType.cluster: + case GalleryType.petCluster: case GalleryType.archive: case GalleryType.localFolder: case GalleryType.trash: @@ -292,6 +302,7 @@ 
extension GalleyTypeExtension on GalleryType { case GalleryType.locationTag: case GalleryType.peopleTag: case GalleryType.cluster: + case GalleryType.petCluster: case GalleryType.magic: return true; @@ -333,12 +344,14 @@ extension GalleyTypeExtension on GalleryType { bool showEditLocation() { return this != GalleryType.sharedCollection && this != GalleryType.cluster && + this != GalleryType.petCluster && this != GalleryType.cleanupHiddenFromDevice; } bool showBulkEditTime() { return this != GalleryType.sharedCollection && this != GalleryType.cluster && + this != GalleryType.petCluster && this != GalleryType.cleanupHiddenFromDevice; } } @@ -442,6 +455,7 @@ extension GalleryAppBarExtn on GalleryType { return false; case GalleryType.uncategorized: case GalleryType.cluster: + case GalleryType.petCluster: case GalleryType.peopleTag: case GalleryType.ownedCollection: case GalleryType.sharedCollection: diff --git a/mobile/apps/photos/lib/models/ml/pet/pet_entity.dart b/mobile/apps/photos/lib/models/ml/pet/pet_entity.dart new file mode 100644 index 00000000000..316f4ebe0eb --- /dev/null +++ b/mobile/apps/photos/lib/models/ml/pet/pet_entity.dart @@ -0,0 +1,158 @@ +import "package:flutter/foundation.dart"; +import "package:photos/models/ml/face/person.dart"; + +const Object _petDataUnchanged = Object(); + +class PetEntity { + final String remoteID; + final PetData data; + PetEntity( + this.remoteID, + this.data, + ); + + PetEntity copyWith({ + String? remoteID, + PetData? data, + }) { + return PetEntity( + remoteID ?? this.remoteID, + data ?? this.data, + ); + } +} + +class PetData { + final String name; + final int species; + + final bool isHidden; + final bool isPinned; + final bool hideFromMemories; + + String? avatarFaceID; + List assigned = List.empty(); + List rejectedFaceIDs = List.empty(); + List manuallyAssigned = List.empty(); + + /// string formatted in `yyyy-MM-dd` + final String? 
birthDate; + + bool hasAvatar() => avatarFaceID != null; + + bool get isIgnored => + (isHidden || name.isEmpty || name == '(hidden)' || name == '(ignored)'); + + PetData({ + required this.name, + required this.species, + this.assigned = const [], + this.rejectedFaceIDs = const [], + this.manuallyAssigned = const [], + this.avatarFaceID, + this.isHidden = false, + this.isPinned = false, + this.hideFromMemories = false, + this.birthDate, + }); + + PetData copyWith({ + String? name, + int? species, + List? assigned, + String? avatarFaceId, + bool? isHidden, + bool? isPinned, + bool? hideFromMemories, + Object? birthDate = _petDataUnchanged, + List? rejectedFaceIDs, + List? manuallyAssigned, + }) { + return PetData( + name: name ?? this.name, + species: species ?? this.species, + assigned: assigned ?? this.assigned, + avatarFaceID: avatarFaceId ?? avatarFaceID, + isHidden: isHidden ?? this.isHidden, + isPinned: isPinned ?? this.isPinned, + hideFromMemories: hideFromMemories ?? this.hideFromMemories, + birthDate: identical(birthDate, _petDataUnchanged) + ? this.birthDate + : birthDate as String?, + rejectedFaceIDs: + rejectedFaceIDs ?? List.from(this.rejectedFaceIDs), + manuallyAssigned: + manuallyAssigned ?? 
List.from(this.manuallyAssigned), + ); + } + + void logStats() { + if (kDebugMode == false) return; + final StringBuffer sb = StringBuffer(); + sb.writeln('Pet: $name (species: $species)'); + int assignedCount = 0; + for (final a in assigned) { + assignedCount += a.faces.length; + } + sb.writeln('Assigned: ${assigned.length} withFaces $assignedCount'); + sb.writeln('Rejected faceIDs: ${rejectedFaceIDs.length}'); + sb.writeln('Manual fileIDs: ${manuallyAssigned.length}'); + for (var cluster in assigned) { + sb.writeln('Cluster: ${cluster.id} - ${cluster.faces.length}'); + } + debugPrint(sb.toString()); + } + + Map toJson() => { + 'name': name, + 'species': species, + 'assigned': assigned.map((e) => e.toJson()).toList(), + 'rejectedFaceIDs': rejectedFaceIDs, + 'avatarFaceID': avatarFaceID, + 'isHidden': isHidden, + 'isPinned': isPinned, + 'hideFromMemories': hideFromMemories, + 'birthDate': birthDate, + 'manuallyAssigned': manuallyAssigned, + }; + + factory PetData.fromJson(Map json) { + final assigned = (json['assigned'] == null || + json['assigned'].length == 0 || + json['assigned'] is! Iterable) + ? [] + : List.from( + json['assigned'] + .where((x) => x is Map) + .map((x) => ClusterInfo.fromJson(x as Map)), + ); + + final List rejectedFaceIDs = + (json['rejectedFaceIDs'] == null || json['rejectedFaceIDs'].length == 0) + ? [] + : List.from( + json['rejectedFaceIDs'], + ); + final manualAssignmentData = json['manuallyAssigned']; + final manuallyAssigned = manualAssignmentData is Iterable + ? List.from( + manualAssignmentData.map((value) { + if (value is num) return value.toInt(); + return int.tryParse(value.toString()); + }).whereType(), + ) + : []; + return PetData( + name: json['name'] as String? ?? '', + species: json['species'] as int? ?? -1, + assigned: assigned, + rejectedFaceIDs: rejectedFaceIDs, + manuallyAssigned: manuallyAssigned, + avatarFaceID: json['avatarFaceID'] as String?, + isHidden: json['isHidden'] as bool? ?? 
false, + isPinned: json['isPinned'] as bool? ?? false, + hideFromMemories: json['hideFromMemories'] as bool? ?? false, + birthDate: json['birthDate'] as String?, + ); + } +} diff --git a/mobile/apps/photos/lib/models/search/search_constants.dart b/mobile/apps/photos/lib/models/search/search_constants.dart index c5d64a2645f..45e2d14bcaf 100644 --- a/mobile/apps/photos/lib/models/search/search_constants.dart +++ b/mobile/apps/photos/lib/models/search/search_constants.dart @@ -3,5 +3,7 @@ const kPersonWidgetKey = 'person_widget_key'; const kPersonPinned = 'person_pinned'; const kClusterParamId = 'cluster_id'; const kFileID = 'file_id'; +const kPetClusterParamId = 'pet_cluster_id'; +const kPetId = 'pet_id'; const kContactEmail = 'contact_email'; const kContactCollections = 'contact_collections'; diff --git a/mobile/apps/photos/lib/models/search/search_types.dart b/mobile/apps/photos/lib/models/search/search_types.dart index 4f48bbc1efe..359983d98ed 100644 --- a/mobile/apps/photos/lib/models/search/search_types.dart +++ b/mobile/apps/photos/lib/models/search/search_types.dart @@ -9,13 +9,13 @@ import "package:photos/events/event.dart"; import "package:photos/events/location_tag_updated_event.dart"; import "package:photos/events/magic_cache_updated_event.dart"; import "package:photos/events/people_changed_event.dart"; +import "package:photos/events/pets_changed_event.dart"; import "package:photos/generated/l10n.dart"; import "package:photos/models/collection/collection.dart"; import "package:photos/models/collection/collection_items.dart"; import "package:photos/models/search/search_result.dart"; import "package:photos/models/typedefs.dart"; import "package:photos/services/collections_service.dart"; -import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/search_service.dart"; import "package:photos/ui/viewer/gallery/collection_page.dart"; import 
"package:photos/ui/viewer/location/add_location_sheet.dart"; @@ -252,12 +252,8 @@ extension SectionTypeExtensions on SectionType { }) { switch (this) { case SectionType.face: - return SearchService.instance.getAllFace( - limit, - minClusterSize: limit == null - ? kMinimumClusterSizeAllFaces - : kMinimumClusterSizeSearchResult, - ); + return SearchService.instance + .getAllPeopleAndPets(limit, context: context); case SectionType.magic: return SearchService.instance.getMagicSectionResults(context!); case SectionType.wrapped: @@ -290,7 +286,10 @@ extension SectionTypeExtensions on SectionType { case SectionType.album: return [Bus.instance.on()]; case SectionType.face: - return [Bus.instance.on()]; + return [ + Bus.instance.on(), + Bus.instance.on(), + ]; case SectionType.contacts: return [Bus.instance.on()]; default: @@ -302,6 +301,8 @@ extension SectionTypeExtensions on SectionType { ///events listened to in AllSectionsExampleState. List> sectionUpdateEvents() { switch (this) { + case SectionType.face: + return [Bus.instance.on()]; case SectionType.location: return [Bus.instance.on()]; case SectionType.magic: diff --git a/mobile/apps/photos/lib/services/entity_service.dart b/mobile/apps/photos/lib/services/entity_service.dart index 948ea1e4fb3..f42f29a8c4f 100644 --- a/mobile/apps/photos/lib/services/entity_service.dart +++ b/mobile/apps/photos/lib/services/entity_service.dart @@ -105,6 +105,8 @@ class EntityService { try { await _remoteToLocalSync(EntityType.location); await _remoteToLocalSync(EntityType.cgroup); + await _remoteToLocalSync(EntityType.person); + await _remoteToLocalSync(EntityType.pet); await _remoteToLocalSync(EntityType.smartAlbum); } catch (e) { _logger.severe("Failed to sync entities", e); diff --git a/mobile/apps/photos/lib/services/machine_learning/compute_controller.dart b/mobile/apps/photos/lib/services/machine_learning/compute_controller.dart index 0b6f8438a79..2d56ece0f89 100644 --- 
a/mobile/apps/photos/lib/services/machine_learning/compute_controller.dart +++ b/mobile/apps/photos/lib/services/machine_learning/compute_controller.dart @@ -275,8 +275,7 @@ class ComputeController { ); return; } - final shouldRunCompute = - this.shouldRunCompute; + final shouldRunCompute = this.shouldRunCompute; if (shouldRunCompute != _canRunCompute) { _canRunCompute = shouldRunCompute; _logger.info( diff --git a/mobile/apps/photos/lib/services/machine_learning/ml_indexing_isolate.dart b/mobile/apps/photos/lib/services/machine_learning/ml_indexing_isolate.dart index 3f6c142d815..953f411d2ff 100644 --- a/mobile/apps/photos/lib/services/machine_learning/ml_indexing_isolate.dart +++ b/mobile/apps/photos/lib/services/machine_learning/ml_indexing_isolate.dart @@ -207,9 +207,6 @@ class MLIndexingIsolate extends SuperIsolate { PetFaceDetectionService.instance.downloadModel(forceRefresh), PetFaceEmbeddingDogService.instance.downloadModel(forceRefresh), PetFaceEmbeddingCatService.instance.downloadModel(forceRefresh), - PetBodyDetectionService.instance.downloadModel(forceRefresh), - PetBodyEmbeddingDogService.instance.downloadModel(forceRefresh), - PetBodyEmbeddingCatService.instance.downloadModel(forceRefresh), ]); } @@ -386,9 +383,6 @@ class MLIndexingIsolate extends SuperIsolate { String petFaceDetectionPath = ""; String petFaceEmbeddingDogPath = ""; String petFaceEmbeddingCatPath = ""; - String petBodyDetectionPath = ""; - String petBodyEmbeddingDogPath = ""; - String petBodyEmbeddingCatPath = ""; if (flagService.petEnabled && localSettings.petRecognitionEnabled) { petFaceDetectionPath = @@ -397,12 +391,6 @@ class MLIndexingIsolate extends SuperIsolate { (await PetFaceEmbeddingDogService.instance.getModelNameAndPath()).$2; petFaceEmbeddingCatPath = (await PetFaceEmbeddingCatService.instance.getModelNameAndPath()).$2; - petBodyDetectionPath = - (await PetBodyDetectionService.instance.getModelNameAndPath()).$2; - petBodyEmbeddingDogPath = - (await 
PetBodyEmbeddingDogService.instance.getModelNameAndPath()).$2; - petBodyEmbeddingCatPath = - (await PetBodyEmbeddingCatService.instance.getModelNameAndPath()).$2; } return { @@ -412,9 +400,9 @@ class MLIndexingIsolate extends SuperIsolate { "petFaceDetectionModelPath": petFaceDetectionPath, "petFaceEmbeddingDogModelPath": petFaceEmbeddingDogPath, "petFaceEmbeddingCatModelPath": petFaceEmbeddingCatPath, - "petBodyDetectionModelPath": petBodyDetectionPath, - "petBodyEmbeddingDogModelPath": petBodyEmbeddingDogPath, - "petBodyEmbeddingCatModelPath": petBodyEmbeddingCatPath, + "petBodyDetectionModelPath": "", + "petBodyEmbeddingDogModelPath": "", + "petBodyEmbeddingCatModelPath": "", "preferCoreml": Platform.isIOS, "preferNnapi": Platform.isAndroid, "preferXnnpack": Platform.isAndroid, diff --git a/mobile/apps/photos/lib/services/machine_learning/ml_models_overview.dart b/mobile/apps/photos/lib/services/machine_learning/ml_models_overview.dart index f1ca929a84c..ae8525275eb 100644 --- a/mobile/apps/photos/lib/services/machine_learning/ml_models_overview.dart +++ b/mobile/apps/photos/lib/services/machine_learning/ml_models_overview.dart @@ -13,9 +13,6 @@ enum MLModels { petFaceDetection, petFaceEmbeddingDog, petFaceEmbeddingCat, - petBodyDetection, - petBodyEmbeddingDog, - petBodyEmbeddingCat, } extension MLModelsExtension on MLModels { @@ -35,12 +32,6 @@ extension MLModelsExtension on MLModels { return PetFaceEmbeddingDogService.instance; case MLModels.petFaceEmbeddingCat: return PetFaceEmbeddingCatService.instance; - case MLModels.petBodyDetection: - return PetBodyDetectionService.instance; - case MLModels.petBodyEmbeddingDog: - return PetBodyEmbeddingDogService.instance; - case MLModels.petBodyEmbeddingCat: - return PetBodyEmbeddingCatService.instance; } } @@ -52,9 +43,6 @@ extension MLModelsExtension on MLModels { case MLModels.petFaceDetection: case MLModels.petFaceEmbeddingDog: case MLModels.petFaceEmbeddingCat: - case MLModels.petBodyDetection: - case 
MLModels.petBodyEmbeddingDog: - case MLModels.petBodyEmbeddingCat: return true; case MLModels.clipTextEncoder: return false; diff --git a/mobile/apps/photos/lib/services/machine_learning/ml_result.dart b/mobile/apps/photos/lib/services/machine_learning/ml_result.dart index 38e121022a6..dfec57df5ed 100644 --- a/mobile/apps/photos/lib/services/machine_learning/ml_result.dart +++ b/mobile/apps/photos/lib/services/machine_learning/ml_result.dart @@ -13,21 +13,19 @@ class MLResult { List? faces = []; ClipResult? clip; List? petFaces; - List? petBodies; Dimensions decodedImageSize; bool get ranML => facesRan || clipRan || petsRan; bool get facesRan => faces != null; bool get clipRan => clip != null; - bool get petsRan => petFaces != null || petBodies != null; + bool get petsRan => petFaces != null; MLResult({ this.fileId = -1, this.faces, this.clip, this.petFaces, - this.petBodies, this.decodedImageSize = const Dimensions(width: -1, height: -1), }); @@ -41,7 +39,6 @@ class MLResult { 'faces': faces?.map((face) => face.toJson()).toList(), 'clip': clip?.toJson(), 'petFaces': petFaces?.map((pf) => pf.toJson()).toList(), - 'petBodies': petBodies?.map((obj) => obj.toJson()).toList(), 'decodedImageSize': { 'width': decodedImageSize.width, 'height': decodedImageSize.height, @@ -68,13 +65,6 @@ class MLResult { ) .toList() : null, - petBodies: json['petBodies'] != null - ? (json['petBodies'] as List) - .map( - (item) => PetBodyResult.fromJson(item as Map), - ) - .toList() - : null, decodedImageSize: json['decodedImageSize'] != null ? 
Dimensions( width: json['decodedImageSize']['width'], @@ -235,38 +225,3 @@ class PetFaceResult { ); } } - -class PetBodyResult { - final List boxXyxy; - final double score; - final int cocoClass; - final String petBodyId; - final Embedding embedding; - - PetBodyResult({ - required this.boxXyxy, - required this.score, - required this.cocoClass, - required this.petBodyId, - required this.embedding, - }); - - Map toJson() => { - 'boxXyxy': boxXyxy, - 'score': score, - 'cocoClass': cocoClass, - 'petBodyId': petBodyId, - 'embedding': embedding, - }; - - static PetBodyResult fromJson(Map json) { - return PetBodyResult( - boxXyxy: - (json['boxXyxy'] as List).map((e) => (e as num).toDouble()).toList(), - score: (json['score'] as num).toDouble(), - cocoClass: json['cocoClass'], - petBodyId: json['petBodyId'], - embedding: Embedding.from(json['embedding']), - ); - } -} diff --git a/mobile/apps/photos/lib/services/machine_learning/ml_service.dart b/mobile/apps/photos/lib/services/machine_learning/ml_service.dart index 6c551c5c132..12c2bde8cba 100644 --- a/mobile/apps/photos/lib/services/machine_learning/ml_service.dart +++ b/mobile/apps/photos/lib/services/machine_learning/ml_service.dart @@ -13,6 +13,7 @@ import "package:photos/db/ml/db_pet_model_mappers.dart"; import "package:photos/db/offline_files_db.dart"; import "package:photos/events/compute_control_event.dart"; import "package:photos/events/people_changed_event.dart"; +import "package:photos/events/pets_changed_event.dart"; import "package:photos/main.dart"; import "package:photos/models/ml/clip.dart"; import "package:photos/models/ml/face/face.dart"; @@ -21,10 +22,11 @@ import "package:photos/service_locator.dart"; import "package:photos/services/filedata/model/file_data.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart"; -import 
"package:photos/services/machine_learning/face_ml/face_detection/detection.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/machine_learning/ml_indexing_isolate.dart"; import "package:photos/services/machine_learning/ml_result.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_clustering_service.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_service.dart"; import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart"; import "package:photos/services/search_service.dart"; import "package:photos/services/video_preview_service.dart"; @@ -393,6 +395,14 @@ class MLService { if ((await mlDataDB.getUnclusteredFaceCount()) > 0) { await clusterAllImages(); } + if (_hasModeChanged(mode)) { + _logger.info("App mode changed during ML run, stopping"); + return; + } + // Pet clustering (internal users only) + if (flagService.petEnabled && localSettings.petRecognitionEnabled) { + await _clusterPets(mlDataDB, mode); + } if (_mlControllerStatus == true) { if (_hasModeChanged(mode)) { _logger.info("App mode changed during ML run, stopping"); @@ -678,7 +688,6 @@ class MLService { } } else { final clusterStartTime = DateTime.now(); - // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID final clusteringResult = await FaceClusteringService.instance.predictLinearIsolate( allFaceInfoForClustering.toSet(), @@ -717,6 +726,25 @@ class MLService { } } + Future _clusterPets(MLDataDB mlDataDB, MLMode mode) async { + if (_shouldPauseIndexingAndClustering) return; + try { + final changed = await PetClusteringService.instance.clusterPets( + mlDataDB: mlDataDB, + isOffline: mode == MLMode.offline, + ); + if (changed) { + Bus.instance.fire(PetsChangedEvent(source: "clustering")); + } + // Reconcile local cluster mappings with synced PetEntity data + if (PetService.isInitialized) { + await 
PetService.instance.reconcileClusters(); + } + } catch (e, s) { + _logger.severe("Pet clustering failed", e, s); + } + } + Future processImage(FileMLInstruction instruction) async { bool actuallyRanML = false; @@ -817,7 +845,7 @@ class MLService { // Pet results locally — delete stale rows before writing so // re-indexing with fewer detections doesn't leave old data behind. - final rustPets = result.petFaces != null || result.petBodies != null; + final rustPets = result.petFaces != null; if (rustPets) { await mlDataDB.deletePetDataForFiles([result.fileId]); if (result.petFaces != null && result.petFaces!.isNotEmpty) { @@ -844,37 +872,6 @@ class MLService { // considered pet-indexed (mirrors Face.empty for human faces). await mlDataDB.bulkInsertPetFaces([DBPetFace.empty(result.fileId)]); } - - if (result.petBodies != null && result.petBodies!.isNotEmpty) { - final dbPetBodies = result.petBodies!.map((obj) { - final detectionObj = FaceDetectionRelative( - score: obj.score, - box: [ - obj.boxXyxy[0], - obj.boxXyxy[1], - obj.boxXyxy[2], - obj.boxXyxy[3], - ], - allKeypoints: const [], - ); - return DBPetBody( - fileId: result.fileId, - petBodyId: obj.petBodyId, - detection: jsonEncode(detectionObj.toJson()), - bodyVectorId: null, - species: obj.cocoClass == 15 ? 
1 : 0, - score: obj.score, - imageHeight: result.decodedImageSize.height, - imageWidth: result.decodedImageSize.width, - mlVersion: petMlVersion, - ); - }).toList(); - await mlDataDB.bulkInsertPetBodies(dbPetBodies); - await mlDataDB.storePetBodyEmbeddings( - dbPetBodies, - result.petBodies!, - ); - } } _logger.info("ML result for fileID ${result.fileId} stored remote+local"); return actuallyRanML; diff --git a/mobile/apps/photos/lib/services/machine_learning/pet_ml/pet_clustering_service.dart b/mobile/apps/photos/lib/services/machine_learning/pet_ml/pet_clustering_service.dart new file mode 100644 index 00000000000..b3a2470a592 --- /dev/null +++ b/mobile/apps/photos/lib/services/machine_learning/pet_ml/pet_clustering_service.dart @@ -0,0 +1,693 @@ +import "dart:math" show Random, min; +import "dart:typed_data" show Float64List; + +import "package:logging/logging.dart"; +import "package:photos/db/ml/db.dart"; +import "package:photos/db/ml/db_pet_model_mappers.dart"; +import "package:photos/db/ml/pet_vector_db.dart"; +import "package:photos/db/ml/schema.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_service.dart"; +import "package:photos/src/rust/api/ml_indexing_api.dart"; +import "package:synchronized/synchronized.dart"; +import "package:uuid/uuid.dart"; + +final _logger = Logger("PetClusteringService"); +const _incrementalClusterSampleSize = 5; + +/// Orchestrates pet clustering by reading indexed pet data from the DB, +/// fetching embeddings from the vector DB, calling the Rust 3-phase +/// clustering engine, and storing the results. +class PetClusteringService { + PetClusteringService._(); + static final instance = PetClusteringService._(); + + final Lock _clusterLock = Lock(); + + /// Run pet clustering on all unclustered pet faces. + /// + /// Groups by species (dog=0, cat=1) and clusters each independently. + /// Supports both batch (first run) and incremental (subsequent runs). 
+ /// Returns `true` if any assignments or summaries were changed. + Future clusterPets({ + required MLDataDB mlDataDB, + bool isOffline = false, + }) async { + return _clusterLock.synchronized(() async { + try { + final unclusteredCount = + await mlDataDB.getUnclusteredPetFaceCount(isOffline: isOffline); + if (unclusteredCount == 0) { + _logger.info("No unclustered pet faces, skipping"); + return false; + } + _logger.info("Starting pet clustering: $unclusteredCount unclustered"); + + bool changed = false; + for (final species in [0, 1]) { + final speciesChanged = await _clusterSpecies( + species: species, + mlDataDB: mlDataDB, + isOffline: isOffline, + ); + changed = changed || speciesChanged; + } + return changed; + } catch (e, s) { + _logger.severe("Pet clustering failed", e, s); + return false; + } + }); + } + + Future _clusterSpecies({ + required int species, + required MLDataDB mlDataDB, + required bool isOffline, + }) async { + final speciesName = species == 0 ? "dog" : "cat"; + _logger.info("Clustering $speciesName faces..."); + + // 1. Read all pet faces for this species (lightweight metadata only) + final faces = + await mlDataDB.getPetFacesForClustering(species, isOffline: isOffline); + if (faces.isEmpty) { + _logger.info("No $speciesName faces to cluster"); + return false; + } + + // 2. Build lightweight metadata — no embeddings cross FFI + final allMetas = []; + final unclusteredMetas = []; + for (final face in faces) { + if (face.faceVectorId == null) continue; + final meta = RustPetFaceMeta( + petFaceId: face.petFaceId, + vectorId: face.faceVectorId!, + species: species, + fileId: face.fileId, + clusterId: face.clusterId ?? "", + ); + allMetas.add(meta); + if (face.clusterId == null) { + unclusteredMetas.add(meta); + } + } + + // 3. 
Get usearch index path — Rust reads embeddings directly + final faceVdb = PetVectorDB.forModel( + species: species, + isFace: true, + offline: isOffline, + ); + final faceIndexPath = await faceVdb.getIndexPath(); + + // 4. Check for existing clusters (incremental mode) + final existingSummaries = + await mlDataDB.getAllPetClusterSummary(species: species); + + late RustPetClusterResult result; + + if (existingSummaries.isEmpty) { + // Batch mode — first run, cluster all faces + if (allMetas.length < 2) { + _logger.info("Not enough $speciesName faces to cluster " + "(${allMetas.length})"); + return false; + } + _logger.info( + "Batch clustering ${allMetas.length} $speciesName faces", + ); + result = await runPetClusteringFromIndex( + faces: allMetas, + faceIndexPath: faceIndexPath, + species: species, + ); + } else { + // Incremental mode — use random cluster samples built on demand. + if (unclusteredMetas.isEmpty) { + _logger.info("No unclustered $speciesName faces, skipping"); + return false; + } + _logger.info( + "Sample-incremental clustering ${unclusteredMetas.length} new " + "$speciesName faces (${existingSummaries.length} existing clusters)", + ); + + final clusterExemplars = await _buildRandomClusterSamples( + existingClusterIds: existingSummaries.keys, + faces: faces, + faceVdb: faceVdb, + ); + + if (clusterExemplars.isEmpty) { + _logger.info("No cluster samples found, falling back to batch"); + result = await runPetClusteringFromIndex( + faces: allMetas, + faceIndexPath: faceIndexPath, + species: species, + ); + } else { + result = await runPetClusteringIncrementalExemplarsFromIndex( + newFaces: unclusteredMetas, + faceIndexPath: faceIndexPath, + clusterExemplars: clusterExemplars, + species: species, + ); + } + } + + // 5. 
Store results — respect user feedback + final faceToCluster = {}; + for (final assignment in result.assignments) { + faceToCluster[assignment.petFaceId] = assignment.clusterId; + } + + // Online mode preserves user corrections by honoring rejected assignments. + // Offline mode is view-only raw clustering output. + if (faceToCluster.isNotEmpty) { + if (!isOffline) { + final allRejected = await mlDataDB.getBulkRejectedPetFaceIds( + faceToCluster.values.toSet(), + ); + if (allRejected.isNotEmpty) { + for (final entry in faceToCluster.entries.toList()) { + final rejected = allRejected[entry.value]; + if (rejected != null && rejected.contains(entry.key)) { + faceToCluster[entry.key] = const Uuid().v4(); + } + } + } + } + await mlDataDB.updatePetFaceIdToClusterId(faceToCluster); + } + + _logger.info( + "$speciesName clustering done: ${faceToCluster.length} assigned, " + "${result.nUnclustered} unclustered, " + "${result.summaries.length} clusters", + ); + return faceToCluster.isNotEmpty; + } + + static Future> _buildRandomClusterSamples({ + required Iterable existingClusterIds, + required List faces, + required PetVectorDB faceVdb, + }) async { + final clusterToVectorIds = >{}; + for (final face in faces) { + final clusterId = face.clusterId; + final vectorId = face.faceVectorId; + if (clusterId == null || clusterId.isEmpty || vectorId == null) { + continue; + } + clusterToVectorIds.putIfAbsent(clusterId, () => []).add(vectorId); + } + + final random = Random(); + final clusterSamples = []; + for (final clusterId in existingClusterIds) { + final vectorIds = clusterToVectorIds[clusterId]; + if (vectorIds == null || vectorIds.isEmpty) { + continue; + } + final sampledVectorIds = _sampleRandomVectorIds( + vectorIds, + _incrementalClusterSampleSize, + random, + ); + final embeddings = await faceVdb.getEmbeddings(sampledVectorIds); + if (embeddings.isEmpty) { + continue; + } + clusterSamples.add( + RustClusterExemplars( + clusterId: clusterId, + exemplars: embeddings + .map( 
+ (embedding) => Float64List.fromList( + embedding.map((value) => value.toDouble()).toList(), + ), + ) + .toList(), + ), + ); + } + return clusterSamples; + } + + static List _sampleRandomVectorIds( + List vectorIds, + int maxSamples, + Random random, + ) { + if (vectorIds.length <= maxSamples) { + return List.from(vectorIds); + } + final shuffled = List.from(vectorIds); + for (int i = shuffled.length - 1; i > 0; i--) { + final j = random.nextInt(i + 1); + final tmp = shuffled[i]; + shuffled[i] = shuffled[j]; + shuffled[j] = tmp; + } + return shuffled.sublist(0, maxSamples); + } +} + +/// Lightweight holder for pet face data needed for clustering. +class PetFaceClusterInfo { + final int fileId; + final String petFaceId; + final int? faceVectorId; + final int species; + final String? clusterId; + + PetFaceClusterInfo({ + required this.fileId, + required this.petFaceId, + required this.faceVectorId, + required this.species, + this.clusterId, + }); +} + +// ── DB helper methods (extension on MLDataDB) ─────────────────────────── + +extension PetClusteringDB on MLDataDB { + /// Count pet faces that don't have a cluster assignment yet. + Future getUnclusteredPetFaceCount({bool isOffline = false}) async { + final db = await asyncDB; + const String query = ''' + SELECT COUNT(*) as count + FROM $petFacesTable f + LEFT JOIN $petFaceClustersTable fc ON f.$petFaceIDColumn = fc.$petFaceIDColumn + WHERE f.$speciesColumn >= 0 + AND f.$faceVectorIdColumn IS NOT NULL + AND fc.$petFaceIDColumn IS NULL + '''; + final rows = await db.getAll(query); + return rows.first['count'] as int; + } + + /// Get all pet faces for a given species, with their existing cluster IDs. 
+ Future> getPetFacesForClustering( + int species, { + bool isOffline = false, + }) async { + final db = await asyncDB; + const String query = ''' + SELECT f.$fileIDColumn, f.$petFaceIDColumn, f.$faceVectorIdColumn, + f.$speciesColumn, fc.$clusterIDColumn + FROM $petFacesTable f + LEFT JOIN $petFaceClustersTable fc ON f.$petFaceIDColumn = fc.$petFaceIDColumn + WHERE f.$speciesColumn = ? + AND f.$faceVectorIdColumn IS NOT NULL + ORDER BY f.$fileIDColumn + '''; + final rows = await db.getAll(query, [species]); + return rows + .map( + (r) => PetFaceClusterInfo( + fileId: r[fileIDColumn] as int, + petFaceId: r[petFaceIDColumn] as String, + faceVectorId: r[faceVectorIdColumn] as int?, + species: r[speciesColumn] as int, + clusterId: r[clusterIDColumn] as String?, + ), + ) + .toList(); + } + + /// Store pet face → cluster assignments. + Future updatePetFaceIdToClusterId( + Map faceIdToClusterId, + ) async { + if (faceIdToClusterId.isEmpty) return; + final db = await asyncDB; + const batchSize = 500; + final entries = faceIdToClusterId.entries.toList(); + for (int i = 0; i < entries.length; i += batchSize) { + final batch = entries.sublist(i, min(i + batchSize, entries.length)); + const String sql = ''' + INSERT INTO $petFaceClustersTable ($petFaceIDColumn, $clusterIDColumn) + VALUES (?, ?) + ON CONFLICT($petFaceIDColumn) DO UPDATE SET $clusterIDColumn = excluded.$clusterIDColumn + '''; + final params = batch.map((e) => [e.key, e.value]).toList(); + await db.executeBatch(sql, params); + } + } + + /// Get all existing pet cluster summaries, optionally filtered by species. + Future> getAllPetClusterSummary({ + int? species, + }) async { + final db = await asyncDB; + final where = species != null ? ' AND f.$speciesColumn = ?' : ''; + final params = species != null ? 
[species] : []; + final rows = await db.getAll( + 'SELECT fc.$clusterIDColumn, COUNT(*) as $countColumn, ' + 'MIN(f.$speciesColumn) as $speciesColumn ' + 'FROM $petFaceClustersTable fc ' + 'INNER JOIN $petFacesTable f ON fc.$petFaceIDColumn = f.$petFaceIDColumn ' + 'WHERE f.$speciesColumn >= 0$where ' + 'GROUP BY fc.$clusterIDColumn', + params, + ); + final result = {}; + for (final r in rows) { + result[r[clusterIDColumn] as String] = ( + r[countColumn] as int, + r[speciesColumn] as int, + ); + } + return result; + } + + // ── Cover face for cluster (face crops) ── + + /// Get the highest-scoring pet face in a cluster. + Future getCoverPetFaceForCluster(String clusterId) async { + final db = await asyncDB; + const String query = ''' + SELECT f.* + FROM $petFaceClustersTable fc + INNER JOIN $petFacesTable f ON fc.$petFaceIDColumn = f.$petFaceIDColumn + WHERE fc.$clusterIDColumn = ? + AND f.$speciesColumn >= 0 + ORDER BY f.$faceScore DESC + LIMIT 1 + '''; + final rows = await db.getAll(query, [clusterId]); + if (rows.isEmpty) return null; + return DBPetFace.fromMap(rows.first); + } + + // ── Cluster → Pet mapping ── + + /// Get all cluster-to-pet-ID mappings. + Future> getClusterToPetId() async { + final db = await asyncDB; + final rows = await db.getAll( + "SELECT $clusterIDColumn, $petIdColumn FROM $petClusterPetTable", + ); + final result = {}; + for (final r in rows) { + result[r[clusterIDColumn] as String] = r[petIdColumn] as String; + } + return result; + } + + /// Map a cluster to a pet ID. + Future setClusterPetId(String clusterId, String petId) async { + final db = await asyncDB; + await db.execute( + '''INSERT INTO $petClusterPetTable ($clusterIDColumn, $petIdColumn) + VALUES (?, ?) + ON CONFLICT($clusterIDColumn) DO UPDATE SET + $petIdColumn = excluded.$petIdColumn''', + [clusterId, petId], + ); + } + + /// Remove a cluster's pet mapping (unmerge). 
+ Future removeClusterPetId(String clusterId) async { + final db = await asyncDB; + await db.execute( + 'DELETE FROM $petClusterPetTable WHERE $clusterIDColumn = ?', + [clusterId], + ); + } + + // ── Manual reassignment helpers ── + + /// Get petFaceIds for given fileIds within a specific cluster. + Future> getPetFaceIdsForFilesInCluster( + List fileIds, + String clusterId, + ) async { + if (fileIds.isEmpty) return []; + final db = await asyncDB; + final placeholders = List.filled(fileIds.length, '?').join(','); + final query = ''' + SELECT fc.$petFaceIDColumn + FROM $petFaceClustersTable fc + INNER JOIN $petFacesTable f ON fc.$petFaceIDColumn = f.$petFaceIDColumn + WHERE fc.$clusterIDColumn = ? + AND f.$fileIDColumn IN ($placeholders) + '''; + final rows = await db.getAll(query, [clusterId, ...fileIds]); + return rows.map((r) => r[petFaceIDColumn] as String).toList(); + } + + /// Get cluster → list of petFaceIds for a given species. + Future>> getPetClusterToFaceIds( + int species, + ) async { + final db = await asyncDB; + final rows = await db.getAll( + 'SELECT fc.$clusterIDColumn, fc.$petFaceIDColumn ' + 'FROM $petFaceClustersTable fc ' + 'INNER JOIN $petFacesTable f ON fc.$petFaceIDColumn = f.$petFaceIDColumn ' + 'WHERE f.$speciesColumn = ?', + [species], + ); + final result = >{}; + for (final r in rows) { + result + .putIfAbsent(r[clusterIDColumn] as String, () => []) + .add(r[petFaceIDColumn] as String); + } + return result; + } + + /// Get petId to {clusterId to Set of faceIds} for reconciliation. 
+ Future>>> + getPetToClusterIdToFaceIds() async { + final db = await asyncDB; + final rows = await db.getAll( + 'SELECT pcp.$petIdColumn, fc.$clusterIDColumn, fc.$petFaceIDColumn ' + 'FROM $petClusterPetTable pcp ' + 'INNER JOIN $petFaceClustersTable fc ' + 'ON pcp.$clusterIDColumn = fc.$clusterIDColumn', + ); + final result = >>{}; + for (final r in rows) { + final petId = r[petIdColumn] as String; + final clusterId = r[clusterIDColumn] as String; + final faceId = r[petFaceIDColumn] as String; + result + .putIfAbsent(petId, () => {}) + .putIfAbsent(clusterId, () => {}) + .add(faceId); + } + return result; + } + + /// Get species and faceVectorId for given petFaceIds. + Future> getPetFaceDetails( + List petFaceIds, + ) async { + if (petFaceIds.isEmpty) return {}; + final db = await asyncDB; + final placeholders = List.filled(petFaceIds.length, '?').join(','); + final rows = await db.getAll( + 'SELECT $petFaceIDColumn, $speciesColumn, $faceVectorIdColumn ' + 'FROM $petFacesTable WHERE $petFaceIDColumn IN ($placeholders)', + petFaceIds, + ); + final result = {}; + for (final r in rows) { + result[r[petFaceIDColumn] as String] = ( + r[speciesColumn] as int, + r[faceVectorIdColumn] as int?, + ); + } + return result; + } + + /// Get all petFaceIds assigned to a given cluster. + Future> getPetFaceIdsForCluster(String clusterId) async { + final db = await asyncDB; + final rows = await db.getAll( + 'SELECT $petFaceIDColumn FROM $petFaceClustersTable ' + 'WHERE $clusterIDColumn = ?', + [clusterId], + ); + return rows.map((r) => r[petFaceIDColumn] as String).toList(); + } + + /// Force-update pet face cluster assignments. + Future forceUpdatePetFaceClusterIds( + Map petFaceIdToClusterId, + ) async { + if (petFaceIdToClusterId.isEmpty) return; + final db = await asyncDB; + const String sql = ''' + INSERT INTO $petFaceClustersTable ($petFaceIDColumn, $clusterIDColumn) + VALUES (?, ?) 
+ ON CONFLICT($petFaceIDColumn) DO UPDATE SET + $clusterIDColumn = excluded.$clusterIDColumn + '''; + const batchSize = 500; + final entries = petFaceIdToClusterId.entries.toList(); + for (int i = 0; i < entries.length; i += batchSize) { + final batch = entries.sublist(i, min(i + batchSize, entries.length)); + final params = batch.map((e) => [e.key, e.value]).toList(); + await db.executeBatch(sql, params); + } + } + + /// Record "not this pet" feedback. + Future bulkInsertNotPetFeedback( + List<(String clusterId, String petFaceId)> feedback, + ) async { + if (feedback.isEmpty) return; + final db = await asyncDB; + const String sql = ''' + INSERT OR IGNORE INTO $notPetFeedbackTable + ($clusterIDColumn, $petFaceIDColumn) + VALUES (?, ?) + '''; + final params = feedback.map((e) => [e.$1, e.$2]).toList(); + await db.executeBatch(sql, params); + } + + /// Remove "not this pet" feedback for faces being moved to a cluster. + /// This clears prior rejections so re-clustering won't eject them. + Future clearNotPetFeedback( + String clusterId, + List petFaceIds, + ) async { + if (petFaceIds.isEmpty) return; + final db = await asyncDB; + final placeholders = List.filled(petFaceIds.length, '?').join(','); + await db.execute( + 'DELETE FROM $notPetFeedbackTable ' + 'WHERE $clusterIDColumn = ? AND $petFaceIDColumn IN ($placeholders)', + [clusterId, ...petFaceIds], + ); + } + + /// Get all rejected petFaceIds for a cluster. + Future> getRejectedPetFaceIds(String clusterId) async { + final db = await asyncDB; + final rows = await db.getAll( + 'SELECT $petFaceIDColumn FROM $notPetFeedbackTable ' + 'WHERE $clusterIDColumn = ?', + [clusterId], + ); + return rows.map((r) => r[petFaceIDColumn] as String).toSet(); + } + + /// Get all rejected petFaceIds for multiple clusters in one query. 
+ Future>> getBulkRejectedPetFaceIds( + Set clusterIds, + ) async { + if (clusterIds.isEmpty) return {}; + final db = await asyncDB; + final result = >{}; + const chunkSize = 500; + final idList = clusterIds.toList(); + for (int i = 0; i < idList.length; i += chunkSize) { + final chunk = idList.sublist(i, min(i + chunkSize, idList.length)); + final placeholders = List.filled(chunk.length, '?').join(','); + final rows = await db.getAll( + 'SELECT $clusterIDColumn, $petFaceIDColumn FROM $notPetFeedbackTable ' + 'WHERE $clusterIDColumn IN ($placeholders)', + chunk, + ); + for (final r in rows) { + result + .putIfAbsent(r[clusterIDColumn] as String, () => {}) + .add(r[petFaceIDColumn] as String); + } + } + return result; + } + + /// Reassign all pet faces in one cluster to another. + Future reassignAllPetFacesInCluster( + String sourceClusterId, + String targetClusterId, + ) async { + final db = await asyncDB; + await db.execute( + 'UPDATE $petFaceClustersTable SET $clusterIDColumn = ? ' + 'WHERE $clusterIDColumn = ?', + [targetClusterId, sourceClusterId], + ); + } + + /// Get file IDs for a single pet cluster. + Future> getPetFileIdsForCluster(String clusterId) async { + final db = await asyncDB; + final rows = await db.getAll( + 'SELECT DISTINCT f.$fileIDColumn ' + 'FROM $petFaceClustersTable fc ' + 'INNER JOIN $petFacesTable f ON fc.$petFaceIDColumn = f.$petFaceIDColumn ' + 'WHERE fc.$clusterIDColumn = ? AND f.$speciesColumn >= 0', + [clusterId], + ); + return rows.map((r) => r[fileIDColumn] as int).toList(); + } + + /// Get a mapping from cluster ID to the list of file IDs in that cluster. 
+ Future>> getPetClusterFileIds() async { + final db = await asyncDB; + const String query = ''' + SELECT fc.$clusterIDColumn, f.$fileIDColumn + FROM $petFaceClustersTable fc + INNER JOIN $petFacesTable f ON fc.$petFaceIDColumn = f.$petFaceIDColumn + WHERE f.$speciesColumn >= 0 + ORDER BY fc.$clusterIDColumn + '''; + final rows = await db.getAll(query); + final result = >{}; + for (final r in rows) { + final cid = r[clusterIDColumn] as String; + final fid = r[fileIDColumn] as int; + result.putIfAbsent(cid, () => {}).add(fid); + } + return result.map((k, v) => MapEntry(k, v.toList())); + } + + /// Get all pet clusters with their file counts. + /// Returns list of (clusterId, species, fileCount, name?). + Future> getAllPetClustersWithInfo() async { + final db = await asyncDB; + const String query = ''' + SELECT fc.$clusterIDColumn, + f.$speciesColumn, + COUNT(DISTINCT f.$fileIDColumn) as file_count + FROM $petFaceClustersTable fc + INNER JOIN $petFacesTable f ON fc.$petFaceIDColumn = f.$petFaceIDColumn + WHERE f.$speciesColumn >= 0 + GROUP BY fc.$clusterIDColumn + ORDER BY file_count DESC + '''; + final rows = await db.getAll(query); + + // Resolve names via cluster → pet mapping + PetService + final clusterToPetId = await getClusterToPetId(); + final petEntities = await PetService.instance.getPetsMap(); + final names = {}; + for (final entry in clusterToPetId.entries) { + final pet = petEntities[entry.value]; + if (pet != null && pet.data.name.isNotEmpty) { + names[entry.key] = pet.data.name; + } + } + + return rows.map((r) { + final cid = r[clusterIDColumn] as String; + return ( + cid, + r[speciesColumn] as int, + r['file_count'] as int, + names[cid], + ); + }).toList(); + } +} diff --git a/mobile/apps/photos/lib/services/machine_learning/pet_ml/pet_model_services.dart b/mobile/apps/photos/lib/services/machine_learning/pet_ml/pet_model_services.dart index c5eb114266e..0c98dd4a251 100644 --- 
a/mobile/apps/photos/lib/services/machine_learning/pet_ml/pet_model_services.dart +++ b/mobile/apps/photos/lib/services/machine_learning/pet_ml/pet_model_services.dart @@ -61,63 +61,3 @@ class PetFaceEmbeddingCatService extends MlModel { static final instance = PetFaceEmbeddingCatService._privateConstructor(); factory PetFaceEmbeddingCatService() => instance; } - -/// Pet body detection model (YOLOv5s — COCO classes 15=cat, 16=dog). -class PetBodyDetectionService extends MlModel { - static const kRemoteBucketModelPath = "yolov5s_object_fp16.onnx"; - static const _modelName = "YOLOv5sPetBody"; - - @override - String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath; - - @override - Logger get logger => _logger; - static final _logger = Logger('PetBodyDetectionService'); - - @override - String get modelName => _modelName; - - PetBodyDetectionService._privateConstructor(); - static final instance = PetBodyDetectionService._privateConstructor(); - factory PetBodyDetectionService() => instance; -} - -/// Dog body embedding model. -class PetBodyEmbeddingDogService extends MlModel { - static const kRemoteBucketModelPath = "dog_body_embedding192.onnx"; - static const _modelName = "DogBody"; - - @override - String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath; - - @override - Logger get logger => _logger; - static final _logger = Logger('PetBodyEmbeddingDogService'); - - @override - String get modelName => _modelName; - - PetBodyEmbeddingDogService._privateConstructor(); - static final instance = PetBodyEmbeddingDogService._privateConstructor(); - factory PetBodyEmbeddingDogService() => instance; -} - -/// Cat body embedding model. 
-class PetBodyEmbeddingCatService extends MlModel { - static const kRemoteBucketModelPath = "cat_body_embedding192.onnx"; - static const _modelName = "CatBody"; - - @override - String get modelRemotePath => kModelBucketEndpoint + kRemoteBucketModelPath; - - @override - Logger get logger => _logger; - static final _logger = Logger('PetBodyEmbeddingCatService'); - - @override - String get modelName => _modelName; - - PetBodyEmbeddingCatService._privateConstructor(); - static final instance = PetBodyEmbeddingCatService._privateConstructor(); - factory PetBodyEmbeddingCatService() => instance; -} diff --git a/mobile/apps/photos/lib/services/machine_learning/pet_ml/pet_service.dart b/mobile/apps/photos/lib/services/machine_learning/pet_ml/pet_service.dart new file mode 100644 index 00000000000..668fa240dc2 --- /dev/null +++ b/mobile/apps/photos/lib/services/machine_learning/pet_ml/pet_service.dart @@ -0,0 +1,381 @@ +import "dart:convert"; + +import "package:computer/computer.dart"; +import "package:flutter/foundation.dart"; +import "package:logging/logging.dart"; +import "package:photos/core/event_bus.dart"; +import "package:photos/db/ml/db.dart"; +import "package:photos/events/pets_changed_event.dart"; +import "package:photos/gateways/entity/models/type.dart"; +import "package:photos/models/local_entity_data.dart"; +import "package:photos/models/ml/face/person.dart" show ClusterInfo; +import "package:photos/models/ml/pet/pet_entity.dart"; +import "package:photos/service_locator.dart"; +import "package:photos/services/entity_service.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_clustering_service.dart"; +import "package:uuid/uuid.dart"; + +/// Manages pet entities synced via the entity sync service. +class PetService { + final EntityService entityService; + final MLDataDB mlDataDB; + final _logger = Logger("PetService"); + + PetService(this.entityService, this.mlDataDB); + + static PetService? 
_instance; + + static PetService get instance { + if (_instance == null) { + throw Exception("PetService not initialized"); + } + return _instance!; + } + + static bool get isInitialized => _instance != null; + + Future>? _cachedPetsFuture; + int _lastCacheRefreshTime = 0; + + static Future init( + EntityService entityService, + MLDataDB mlDataDB, + ) async { + _instance = PetService(entityService, mlDataDB); + await _instance!._refreshCache(); + } + + void clearCache() { + _cachedPetsFuture = null; + _lastCacheRefreshTime = 0; + } + + Future _refreshCache() async { + _lastCacheRefreshTime = 0; + final _ = await getPets(); + } + + int _lastRemoteSyncTime() { + return entityService.lastSyncTime(EntityType.pet); + } + + Future> getPets() async { + if (_lastCacheRefreshTime != _lastRemoteSyncTime()) { + _lastCacheRefreshTime = _lastRemoteSyncTime(); + _cachedPetsFuture = null; + } + _cachedPetsFuture ??= _fetchAndCachePets(); + return _cachedPetsFuture!; + } + + Future> _fetchAndCachePets() async { + _logger.finest("reading all pets from local db"); + final entities = await entityService.getEntities(EntityType.pet); + final pets = await Computer.shared().compute( + _decodePetEntities, + param: {"entity": entities}, + taskName: "decode_pet_entities", + ); + return pets; + } + + static List _decodePetEntities(Map param) { + final entities = param["entity"] as List; + return entities + .map( + (e) => PetEntity( + e.id, + PetData.fromJson(json.decode(e.data)), + ), + ) + .toList(); + } + + Future getPet(String id) async { + final e = await entityService.getEntity(EntityType.pet, id); + if (e == null) return null; + return PetEntity(e.id, PetData.fromJson(json.decode(e.data))); + } + + Future> getPetsMap() async { + final pets = await getPets(); + return {for (final p in pets) p.remoteID: p}; + } + + Future addPet(PetData data) async { + final result = await _addOrUpdateEntity(data.toJson()); + Bus.instance.fire(PetsChangedEvent(source: "PetService.addPet")); + return 
PetEntity(result.id, data); + } + + Future updatePet(String petID, PetData data) async { + await _addOrUpdateEntity(data.toJson(), id: petID); + Bus.instance.fire(PetsChangedEvent(source: "PetService.updatePet")); + return PetEntity(petID, data); + } + + Future deletePet(String petID) async { + await entityService.deleteEntry(petID); + await _removeLocalClusterMappingsForPet(petID); + _invalidateCache(); + Bus.instance.fire(PetsChangedEvent(source: "PetService.deletePet")); + } + + /// Delete all pet entities. Used for debug reset. + Future deleteAllPets() async { + final pets = await getPets(); + for (final pet in pets) { + await entityService.deleteEntry(pet.remoteID); + await _removeLocalClusterMappingsForPet(pet.remoteID); + } + _invalidateCache(); + } + + /// Sync pet entities from remote. Returns true if data changed. + Future syncPets() async { + if (isOfflineMode) { + _logger.finest("Skip syncing pets in offline mode"); + return false; + } + final int changedEntities = await entityService.syncEntity(EntityType.pet); + return changedEntities > 0; + } + + /// Sync local and remote pet-to-cluster mappings in both directions. + /// + /// Direction 1 (local → remote): Push local pet-to-cluster mappings to + /// the server first, so fresh assignments (e.g. user just named or merged + /// a pet) are persisted before the stale-cleanup pass runs. + /// + /// Direction 2 (remote → local): Fetch pet entities from server, update + /// local `pet_cluster_pet` mappings, and remove stale entries whose + /// clusters no longer appear in any remote PetData.assigned. + Future reconcileClusters() async { + await _pushLocalClustersToRemote(); + await fetchRemoteClusterFeedback(skipIfNoChange: false); + } + + /// Fetch remote pet entities and update local ML DB mappings. + /// Returns true if remote data changed. 
+ Future fetchRemoteClusterFeedback({ + bool skipIfNoChange = true, + }) async { + if (isOfflineMode) { + _logger.finest("Skip fetching remote pet clusters in offline mode"); + return false; + } + final int changedEntities = await entityService.syncEntity(EntityType.pet); + final bool changed = changedEntities > 0; + if (!changed && skipIfNoChange) { + return false; + } + + final entities = await entityService.getEntities(EntityType.pet); + final remotePetIDs = entities.map((e) => e.id).toSet(); + + // Remove local mappings for pets that no longer exist remotely + final localMappings = await mlDataDB.getClusterToPetId(); + final localPetIds = localMappings.values.toSet(); + int removedOrphans = 0; + for (final localPetId in localPetIds) { + if (!remotePetIDs.contains(localPetId)) { + // Remove all cluster mappings for this orphaned pet + for (final entry in localMappings.entries) { + if (entry.value == localPetId) { + await mlDataDB.removeClusterPetId(entry.key); + removedOrphans++; + } + } + } + } + if (removedOrphans > 0) { + _logger.info( + "Removed $removedOrphans orphaned local pet cluster mappings", + ); + } + + // Apply remote assignments to local DB + entities.sort((a, b) => a.updatedAt.compareTo(b.updatedAt)); + final Map clusterToPetId = {}; + for (final e in entities) { + final petData = PetData.fromJson(json.decode(e.data)); + for (final cluster in petData.assigned) { + clusterToPetId[cluster.id] = e.id; + } + if (kDebugMode) { + _logger.info( + "Pet ${e.id} ${petData.name} has ${petData.assigned.length} clusters", + ); + } + } + + // Remove stale local mappings: clusters locally assigned to a pet that + // still exists remotely but whose remote assignments no longer include + // that cluster (i.e. unmerge/unassign done on another device). 
+ final remoteClusterIds = clusterToPetId.keys.toSet(); + int removedStale = 0; + for (final entry in localMappings.entries) { + if (remotePetIDs.contains(entry.value) && + !remoteClusterIds.contains(entry.key)) { + await mlDataDB.removeClusterPetId(entry.key); + removedStale++; + } + } + if (removedStale > 0) { + _logger.info( + "Removed $removedStale stale local pet cluster mappings", + ); + } + + // Write all cluster-to-pet mappings + for (final entry in clusterToPetId.entries) { + await mlDataDB.setClusterPetId(entry.key, entry.value); + } + + return changed; + } + + /// Push local cluster assignments to remote PetData.assigned. + Future _pushLocalClustersToRemote() async { + final dbPetClusterInfo = await mlDataDB.getPetToClusterIdToFaceIds(); + final pets = await getPetsMap(); + + for (final petID in dbPetClusterInfo.keys) { + final pet = pets[petID]; + if (pet == null) { + _logger.warning("Pet $petID not found in entities, skipping"); + continue; + } + final dbClusters = dbPetClusterInfo[petID]!; + final petData = pet.data; + + if (!_shouldUpdateAssigned(petData, dbClusters)) { + continue; + } + + petData.assigned = dbClusters.entries + .map( + (e) => ClusterInfo(id: e.key, faces: e.value), + ) + .toList(); + + await _addOrUpdateEntity(petData.toJson(), id: petID); + petData.logStats(); + } + + // Clear remote assigned list for pets that lost all local clusters + // (e.g. after unmerging the last cluster). Without this, stale + // assignments stay on the server and get re-imported on next sync. 
+ for (final pet in pets.values) { + if (dbPetClusterInfo.containsKey(pet.remoteID)) continue; + if (pet.data.assigned.isEmpty) continue; + + pet.data.assigned = []; + await _addOrUpdateEntity(pet.data.toJson(), id: pet.remoteID); + _logger.info( + "Cleared remote assignments for pet ${pet.remoteID} (no local clusters)", + ); + } + } + + bool _shouldUpdateAssigned( + PetData petData, + Map> dbClusters, + ) { + if (petData.assigned.length != dbClusters.length) return true; + for (final info in petData.assigned) { + final dbCluster = dbClusters[info.id]; + if (dbCluster == null) return true; + if (info.faces.length != dbCluster.length) return true; + for (final faceId in info.faces) { + if (!dbCluster.contains(faceId)) return true; + } + } + return false; + } + + Future _addOrUpdateEntity( + Map jsonMap, { + String? id, + }) async { + final result = + await entityService.addOrUpdate(EntityType.pet, jsonMap, id: id); + _invalidateCache(); + return result; + } + + void _invalidateCache() { + _lastCacheRefreshTime = 0; + _cachedPetsFuture = null; + } + + /// Hide a pet cluster so it no longer appears in the pets section. + Future ignorePetCluster(String clusterId, int species) async { + final clusterToPetId = await mlDataDB.getClusterToPetId(); + final existingPetId = clusterToPetId[clusterId]; + + if (existingPetId != null) { + final pet = await getPet(existingPetId); + if (pet != null) { + await updatePet(existingPetId, pet.data.copyWith(isHidden: true)); + } + } else { + final pet = await addPet( + PetData(name: "", species: species, isHidden: true), + ); + await mlDataDB.setClusterPetId(clusterId, pet.remoteID); + } + Bus.instance.fire(PetsChangedEvent(source: "ignorePetCluster")); + } + + /// Remove a secondary cluster from a pet. 
+ Future removeClusterFromPet({ + required String petID, + required String clusterID, + }) async { + await mlDataDB.removeClusterPetId(clusterID); + final pet = await getPet(petID); + if (pet != null) { + pet.data.assigned.removeWhere((e) => e.id == clusterID); + await updatePet(petID, pet.data); + } + Bus.instance.fire(PetsChangedEvent(source: "removeClusterFromPet")); + } + + /// Map a cluster to an existing pet entity. + Future addClusterToExistingPet({ + required String petId, + required String clusterID, + }) async { + await mlDataDB.setClusterPetId(clusterID, petId); + Bus.instance.fire(PetsChangedEvent(source: "addClusterToExistingPet")); + } + + /// Remove selected files' pet faces from a cluster ("Not this pet"). + Future removeFilesFromPetCluster( + List fileIds, + String clusterId, + ) async { + final faceIds = + await mlDataDB.getPetFaceIdsForFilesInCluster(fileIds, clusterId); + if (faceIds.isEmpty) return; + await mlDataDB.bulkInsertNotPetFeedback( + faceIds.map((faceId) => (clusterId, faceId)).toList(), + ); + final newAssignments = {}; + for (final faceId in faceIds) { + newAssignments[faceId] = const Uuid().v4(); + } + await mlDataDB.forceUpdatePetFaceClusterIds(newAssignments); + Bus.instance.fire(PetsChangedEvent(source: "removeFilesFromPetCluster")); + } + + Future _removeLocalClusterMappingsForPet(String petID) async { + final clusterToPetId = await mlDataDB.getClusterToPetId(); + for (final entry in clusterToPetId.entries) { + if (entry.value == petID) { + await mlDataDB.removeClusterPetId(entry.key); + } + } + } +} diff --git a/mobile/apps/photos/lib/services/search_service.dart b/mobile/apps/photos/lib/services/search_service.dart index 4c73e23ce49..3b785c2b37c 100644 --- a/mobile/apps/photos/lib/services/search_service.dart +++ b/mobile/apps/photos/lib/services/search_service.dart @@ -53,7 +53,10 @@ import 'package:photos/services/collections_service.dart'; import "package:photos/services/date_parse_service.dart"; import 
"package:photos/services/filter/db_filters.dart"; import "package:photos/services/location_service.dart"; +import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_clustering_service.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_service.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; import "package:photos/services/memories_cache_service.dart"; import "package:photos/states/location_screen_state.dart"; @@ -61,6 +64,7 @@ import "package:photos/ui/viewer/location/add_location_sheet.dart"; import "package:photos/ui/viewer/location/location_screen.dart"; import "package:photos/ui/viewer/people/cluster_page.dart"; import "package:photos/ui/viewer/people/people_page.dart"; +import "package:photos/ui/viewer/people/pet_cluster_page.dart"; import "package:photos/ui/viewer/search/result/magic_result_screen.dart"; import "package:photos/utils/cache_util.dart"; import "package:photos/utils/file_util.dart"; @@ -949,6 +953,28 @@ class SearchService { return files; } + /// Returns combined people + pets results for the unified section. + Future> getAllPeopleAndPets( + int? limit, { + BuildContext? context, + }) async { + final people = await getAllFace( + null, + minClusterSize: limit == null + ? kMinimumClusterSizeAllFaces + : kMinimumClusterSizeSearchResult, + ); + final pets = flagService.petEnabled + ? await getAllPets(null) + : []; + final combined = [...people, ...pets] + ..sort((a, b) => b.fileCount().compareTo(a.fileCount())); + if (limit != null && combined.length > limit) { + return combined.sublist(0, limit); + } + return combined; + } + Future> getAllFace( int? 
limit, { required int minClusterSize, @@ -1874,4 +1900,168 @@ class SearchService { : DateTime(year, month + 1, 1).microsecondsSinceEpoch, ]; } + + Future> _getFileIdToFileMap(Set fileIds) async { + final allFiles = await getAllFilesForSearch(); + final map = {}; + for (final file in allFiles) { + final id = file.uploadedFileID; + if (id != null && fileIds.contains(id)) { + map[id] = file; + } + } + return map; + } + + Future> getAllPets( + int? limit, + ) async { + try { + final mlDataDB = + isOfflineMode ? MLDataDB.offlineInstance : MLDataDB.instance; + + // Get all pet cluster assignments: cluster_id -> list of file_ids + final clusterToFileIds = await mlDataDB.getPetClusterFileIds(); + if (clusterToFileIds.isEmpty) return []; + + // Get cluster summaries for species info + final clusterSummaries = await mlDataDB.getAllPetClusterSummary(); + + // Get pet names via cluster → pet mapping + PetService + final clusterToPetId = await mlDataDB.getClusterToPetId(); + final petEntities = await PetService.instance.getPetsMap(); + final petNames = {}; + for (final entry in clusterToPetId.entries) { + final pet = petEntities[entry.value]; + if (pet != null && pet.data.name.isNotEmpty) { + petNames[entry.key] = pet.data.name; + } + } + + // Get file ID -> EnteFile mapping. In offline mode the ML DB stores + // local integer IDs, so we must convert via OfflineFilesDB first. 
+ final allFileIds = clusterToFileIds.values.expand((ids) => ids).toSet(); + final Map fileIdToFile; + if (isOfflineMode) { + final localIdMap = + await OfflineFilesDB.instance.getLocalIdsForIntIds(allFileIds); + final allFiles = await getAllFilesForSearch(); + final localIdToFile = {}; + for (final file in allFiles) { + final localId = file.localID; + if (localId != null && localId.isNotEmpty) { + localIdToFile[localId] = file; + } + } + final map = {}; + for (final entry in localIdMap.entries) { + final file = localIdToFile[entry.value]; + if (file != null) { + map[entry.key] = file; + } + } + fileIdToFile = map; + } else { + fileIdToFile = await _getFileIdToFileMap(allFileIds); + } + + // Group clusters by petId. Clusters sharing the same petId (merged) + // are combined into a single search result, mirroring how persons + // aggregate multiple clusters. + final petIdToClusters = >{}; + final standaloneClusters = []; + for (final clusterId in clusterToFileIds.keys) { + final petId = clusterToPetId[clusterId]; + if (petId != null) { + petIdToClusters.putIfAbsent(petId, () => []).add(clusterId); + } else { + standaloneClusters.add(clusterId); + } + } + + // Build groups: each group is (primaryClusterId, allFiles, species, name) + final groups = <(String, List, int, String?)>[]; + + for (final entry in petIdToClusters.entries) { + final clusterIds = entry.value; + final seenFids = {}; + final allFiles = []; + int? species; + for (final cid in clusterIds) { + final fids = clusterToFileIds[cid] ?? []; + for (final fid in fids) { + if (!seenFids.add(fid)) continue; + final file = fileIdToFile[fid]; + if (file != null) allFiles.add(file); + } + species ??= clusterSummaries[cid]?.$2; + } + if (allFiles.isEmpty) continue; + final pet = petEntities[entry.key]; + final name = + pet != null && pet.data.name.isNotEmpty ? pet.data.name : null; + groups.add((clusterIds.first, allFiles, species ?? 
-1, name)); + } + + for (final clusterId in standaloneClusters) { + final fids = clusterToFileIds[clusterId] ?? []; + final files = []; + for (final fid in fids) { + final file = fileIdToFile[fid]; + if (file != null) files.add(file); + } + if (files.isEmpty) continue; + final species = clusterSummaries[clusterId]?.$2 ?? -1; + groups.add((clusterId, files, species, null)); + } + + // Sort by file count descending + groups.sort((a, b) => b.$2.length.compareTo(a.$2.length)); + + final List results = []; + for (final (primaryClusterId, files, species, customName) in groups) { + final label = + (customName != null && customName.isNotEmpty) ? customName : ""; + + results.add( + GenericSearchResult( + ResultType.faces, + label, + files, + params: { + kPetClusterParamId: primaryClusterId, + kFileID: files.first.uploadedFileID, + if (clusterToPetId.containsKey(primaryClusterId)) + kPetId: clusterToPetId[primaryClusterId], + }, + onResultTap: (ctx) { + routeToPage( + ctx, + PetClusterPage( + clusterId: primaryClusterId, + clusterLabel: label, + files: files, + species: species, + ), + ); + }, + hierarchicalSearchFilter: TopLevelGenericFilter( + filterName: label, + occurrence: kMostRelevantFilter, + filterResultType: ResultType.faces, + matchedUploadedIDs: filesToUploadedFileIDs(files), + ), + ), + ); + } + + if (limit != null && results.length > limit) { + return results.sublist(0, limit); + } + return results; + } catch (e, s) { + _logger.severe("Error in getAllPets", e, s); + return []; + } + } } diff --git a/mobile/apps/photos/lib/services/sync/local_sync_service.dart b/mobile/apps/photos/lib/services/sync/local_sync_service.dart index 89cbb9a4104..4d7c3d2eba2 100644 --- a/mobile/apps/photos/lib/services/sync/local_sync_service.dart +++ b/mobile/apps/photos/lib/services/sync/local_sync_service.dart @@ -361,7 +361,8 @@ class LocalSyncService { conflictAlgorithm: SqliteAsyncConflictAlgorithm.ignore, ); _logger.info('Inserted ${files.length} out of ${allFiles.length} 
files'); - if (flagService.syncRecoveryDiagnostics && allFiles.length != files.length) { + if (flagService.syncRecoveryDiagnostics && + allFiles.length != files.length) { final sampleLocalIDs = allFiles.take(3).map((file) => file.localID).toList(); _logger.info( diff --git a/mobile/apps/photos/lib/services/sync/remote_sync_service.dart b/mobile/apps/photos/lib/services/sync/remote_sync_service.dart index f0759d3e5c8..2a12d7031f4 100644 --- a/mobile/apps/photos/lib/services/sync/remote_sync_service.dart +++ b/mobile/apps/photos/lib/services/sync/remote_sync_service.dart @@ -1111,9 +1111,8 @@ class RemoteSyncService { return await _photoManagerPlugin.isLocallyAvailable( file.localID!, isOrigin: true, - subtype: file.fileType == FileType.livePhoto - ? (file.fileSubType ?? 0) - : 0, + subtype: + file.fileType == FileType.livePhoto ? (file.fileSubType ?? 0) : 0, ); } catch (e, s) { _logger.warning( diff --git a/mobile/apps/photos/lib/states/all_sections_examples_state.dart b/mobile/apps/photos/lib/states/all_sections_examples_state.dart index b81f50ff67f..8e1f8aa0851 100644 --- a/mobile/apps/photos/lib/states/all_sections_examples_state.dart +++ b/mobile/apps/photos/lib/states/all_sections_examples_state.dart @@ -9,6 +9,7 @@ import "package:photos/core/event_bus.dart"; import "package:photos/events/files_updated_event.dart"; import "package:photos/events/people_changed_event.dart"; import "package:photos/events/people_sort_order_change_event.dart"; +import "package:photos/events/pets_changed_event.dart"; import "package:photos/events/tab_changed_event.dart"; import "package:photos/models/search/search_result.dart"; import "package:photos/models/search/search_types.dart"; @@ -35,6 +36,7 @@ class _AllSectionsExamplesProviderState late StreamSubscription _filesUpdatedEvent; late StreamSubscription _onPeopleChangedEvent; late StreamSubscription _peopleSortChangedEvent; + late StreamSubscription _onPetsChangedEvent; late StreamSubscription _tabChangeEvent; bool 
hasPendingUpdate = false; bool isOnSearchTab = false; @@ -64,6 +66,9 @@ class _AllSectionsExamplesProviderState Bus.instance.on().listen((event) { onDataUpdate(); }); + _onPetsChangedEvent = Bus.instance.on().listen((event) { + onDataUpdate(); + }); _tabChangeEvent = Bus.instance.on().listen((event) { if (event.source == TabChangedEventSource.pageView && event.selectedIndex == 3) { @@ -138,6 +143,7 @@ class _AllSectionsExamplesProviderState _onPeopleChangedEvent.cancel(); _filesUpdatedEvent.cancel(); _peopleSortChangedEvent.cancel(); + _onPetsChangedEvent.cancel(); _tabChangeEvent.cancel(); _cancelInitialLoadTimer(); _debouncer.cancelDebounceTimer(); diff --git a/mobile/apps/photos/lib/ui/settings/debug/ml_debug_settings_page.dart b/mobile/apps/photos/lib/ui/settings/debug/ml_debug_settings_page.dart index 13618116a08..e5b17b2243a 100644 --- a/mobile/apps/photos/lib/ui/settings/debug/ml_debug_settings_page.dart +++ b/mobile/apps/photos/lib/ui/settings/debug/ml_debug_settings_page.dart @@ -8,12 +8,14 @@ import "package:photos/db/ml/clip_vector_db.dart"; import "package:photos/db/ml/cluster_centroid_vector_db.dart"; import "package:photos/db/ml/db.dart"; import "package:photos/events/people_changed_event.dart"; +import "package:photos/events/pets_changed_event.dart"; import "package:photos/models/ml/face/person.dart"; import "package:photos/service_locator.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/machine_learning/ml_indexing_isolate.dart"; import "package:photos/services/machine_learning/ml_service.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_service.dart"; import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart"; import "package:photos/theme/ente_theme.dart"; import 
"package:photos/ui/components/menu_item_widget/menu_item_widget_new.dart"; @@ -601,6 +603,26 @@ class _MLDebugSettingsPageState extends State { trailingIconIsMuted: true, onTap: () async => _onResetFacesAndClustering(context), ), + MenuItemWidgetNew( + title: "Re-index pet faces", + leadingIconWidget: _buildIconWidget( + context, + HugeIcons.strokeRoundedAiImage, + ), + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async => _onReindexPetFaces(context), + ), + MenuItemWidgetNew( + title: "Reset pet clustering", + leadingIconWidget: _buildIconWidget( + context, + HugeIcons.strokeRoundedAiImage, + ), + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async => _onResetPetClustering(context), + ), MenuItemWidgetNew( title: "Reset all local faces", leadingIconWidget: _buildIconWidget( @@ -916,6 +938,56 @@ class _MLDebugSettingsPageState extends State { ); } + Future _onReindexPetFaces(BuildContext context) async { + await showChoiceDialog( + context, + title: "Are you sure?", + body: "This will delete ALL pet data (indexed faces, bodies, embeddings, " + "clusters, and PetEntity data) and re-run detection, embedding, " + "and clustering from scratch.", + firstButtonLabel: "Yes, confirm", + firstButtonOnTap: () async { + try { + // Drop clustering data first + await mlDataDB.dropPetClusteringData(); + await PetService.instance.deleteAllPets(); + // Drop indexed faces/bodies and vector DBs + final allFileIds = (await mlDataDB.petIndexedFileIds()).keys.toList(); + if (allFileIds.isNotEmpty) { + await mlDataDB.deletePetDataForFiles(allFileIds); + } + Bus.instance.fire(PetsChangedEvent(source: "reindexPetFaces")); + showShortToast(context, "Done — pet re-indexing will start shortly"); + } catch (e, s) { + logger.warning('re-index pet faces failed', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ); + } + + Future _onResetPetClustering(BuildContext context) async { + await 
showChoiceDialog( + context, + title: "Are you sure?", + body: "This will delete all pet clusters, PetEntity data, and feedback. " + "Indexed pet faces/bodies are preserved. " + "Pet clustering will re-run automatically.", + firstButtonLabel: "Yes, confirm", + firstButtonOnTap: () async { + try { + await mlDataDB.dropPetClusteringData(); + await PetService.instance.deleteAllPets(); + Bus.instance.fire(PetsChangedEvent(source: "resetPetClusters")); + showShortToast(context, "Done"); + } catch (e, s) { + logger.warning('reset pet clustering failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ); + } + Future _onResetAllLocalFaces(BuildContext context) async { await showChoiceDialog( context, diff --git a/mobile/apps/photos/lib/ui/settings/ml/ml_user_dev_screen.dart b/mobile/apps/photos/lib/ui/settings/ml/ml_user_dev_screen.dart index caa23b83a76..e192d31f4cb 100644 --- a/mobile/apps/photos/lib/ui/settings/ml/ml_user_dev_screen.dart +++ b/mobile/apps/photos/lib/ui/settings/ml/ml_user_dev_screen.dart @@ -166,6 +166,42 @@ class _MLUserDeveloperOptionsState extends State { isGestureDetectorDisabled: true, ) : const SizedBox(), + widget.mlIsEnabled + ? MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Pet recognition", + ), + menuItemColor: colorScheme.fillFaint, + trailingWidget: ToggleSwitchWidget( + value: () => localSettings.petRecognitionEnabled, + onChanged: () async { + try { + await localSettings.togglePetRecognition(); + _logger.info( + 'Pet recognition is turned ${localSettings.petRecognitionEnabled ? 
'on' : 'off'}', + ); + if (mounted) { + setState(() {}); + } + } catch (e, s) { + _logger.warning( + 'Pet recognition toggle failed ', + e, + s, + ); + await showGenericErrorDialog( + context: context, + error: e, + ); + } + }, + ), + singleBorderRadius: 8, + alignCaptionedTextToLeft: true, + isBottomBorderRadiusRemoved: true, + isGestureDetectorDisabled: true, + ) + : const SizedBox(), widget.mlIsEnabled ? const SizedBox(height: 24) : const SizedBox.shrink(), diff --git a/mobile/apps/photos/lib/ui/viewer/actions/file_selection_actions_widget.dart b/mobile/apps/photos/lib/ui/viewer/actions/file_selection_actions_widget.dart index 692a5842e3b..29b1fdb279e 100644 --- a/mobile/apps/photos/lib/ui/viewer/actions/file_selection_actions_widget.dart +++ b/mobile/apps/photos/lib/ui/viewer/actions/file_selection_actions_widget.dart @@ -27,6 +27,7 @@ import 'package:photos/services/collections_service.dart'; import 'package:photos/services/hidden_service.dart'; import 'package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart'; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_service.dart"; import "package:photos/theme/colors.dart"; import "package:photos/theme/ente_theme.dart"; import 'package:photos/ui/actions/collection/collection_file_actions.dart'; @@ -256,6 +257,16 @@ class _FileSelectionActionsWidgetState ); } + if (widget.type == GalleryType.petCluster && widget.clusterID != null) { + items.add( + SelectionActionButton( + labelText: AppLocalizations.of(context).notThisPet, + icon: Icons.remove_circle_outline, + onTap: _onRemoveFromPetClusterClicked, + ), + ); + } + final showUploadIcon = widget.type == GalleryType.localFolder && split.ownedByCurrentUser.isEmpty; if (widget.type.showAddToAlbum() && !isOfflineMode) { @@ -1019,6 +1030,47 @@ class _FileSelectionActionsWidgetState } } + Future _onRemoveFromPetClusterClicked() async { + if 
(widget.clusterID == null) return; + final actionResult = await showActionSheet( + context: context, + buttons: [ + ButtonWidget( + labelText: AppLocalizations.of(context).yesRemove, + buttonType: ButtonType.neutral, + buttonSize: ButtonSize.large, + shouldStickToDarkTheme: true, + buttonAction: ButtonAction.first, + isInAlert: true, + ), + ButtonWidget( + labelText: AppLocalizations.of(context).cancel, + buttonType: ButtonType.secondary, + buttonSize: ButtonSize.large, + buttonAction: ButtonAction.second, + shouldStickToDarkTheme: true, + isInAlert: true, + ), + ], + body: AppLocalizations.of(context).notThisPet, + actionSheetType: ActionSheetType.defaultActionSheet, + ); + if (actionResult?.action == ButtonAction.first) { + final fileIds = widget.selectedFiles.files + .map((f) => f.uploadedFileID ?? 0) + .where((id) => id > 0) + .toList(); + await PetService.instance.removeFilesFromPetCluster( + fileIds, + widget.clusterID!, + ); + } + widget.selectedFiles.clearAll(); + if (mounted) { + setState(() => {}); + } + } + Future _sendLink() async { if (_cachedCollectionForSharedLink != null) { final String url = CollectionsService.instance.getPublicUrl( diff --git a/mobile/apps/photos/lib/ui/viewer/file/qr_code_highlight_overlay.dart b/mobile/apps/photos/lib/ui/viewer/file/qr_code_highlight_overlay.dart index 2e1298312c4..0b4fab2503b 100644 --- a/mobile/apps/photos/lib/ui/viewer/file/qr_code_highlight_overlay.dart +++ b/mobile/apps/photos/lib/ui/viewer/file/qr_code_highlight_overlay.dart @@ -21,46 +21,46 @@ class QrCodeHighlightOverlay extends StatelessWidget { } return LayoutBuilder( - builder: (context, constraints) { - final screenWidth = constraints.maxWidth; - final screenHeight = constraints.maxHeight; + builder: (context, constraints) { + final screenWidth = constraints.maxWidth; + final screenHeight = constraints.maxHeight; - double displayWidth; - double displayHeight; - if (file.hasDimensions) { - final imageAspect = file.width / file.height; - final 
screenAspect = screenWidth / screenHeight; - if (imageAspect > screenAspect) { - displayWidth = screenWidth; - displayHeight = screenWidth / imageAspect; - } else { - displayHeight = screenHeight; - displayWidth = screenHeight * imageAspect; - } - } else { - displayWidth = screenWidth; - displayHeight = screenHeight; - } + double displayWidth; + double displayHeight; + if (file.hasDimensions) { + final imageAspect = file.width / file.height; + final screenAspect = screenWidth / screenHeight; + if (imageAspect > screenAspect) { + displayWidth = screenWidth; + displayHeight = screenWidth / imageAspect; + } else { + displayHeight = screenHeight; + displayWidth = screenHeight * imageAspect; + } + } else { + displayWidth = screenWidth; + displayHeight = screenHeight; + } - final offsetX = (screenWidth - displayWidth) / 2; - final offsetY = (screenHeight - displayHeight) / 2; + final offsetX = (screenWidth - displayWidth) / 2; + final offsetY = (screenHeight - displayHeight) / 2; - return SizedBox.expand( - child: Stack( - children: [ - for (final detection in detections) - _QrTapRegion( - detection: detection, - offsetX: offsetX, - offsetY: offsetY, - displayWidth: displayWidth, - displayHeight: displayHeight, - ), - ], - ), - ); - }, + return SizedBox.expand( + child: Stack( + children: [ + for (final detection in detections) + _QrTapRegion( + detection: detection, + offsetX: offsetX, + offsetY: offsetY, + displayWidth: displayWidth, + displayHeight: displayHeight, + ), + ], + ), ); + }, + ); } } diff --git a/mobile/apps/photos/lib/ui/viewer/file/zoomable_image.dart b/mobile/apps/photos/lib/ui/viewer/file/zoomable_image.dart index 156b6527a15..f452afe55bc 100644 --- a/mobile/apps/photos/lib/ui/viewer/file/zoomable_image.dart +++ b/mobile/apps/photos/lib/ui/viewer/file/zoomable_image.dart @@ -85,9 +85,8 @@ class _ZoomableImageState extends State { widget.shouldDisableScroll!(value != PhotoViewScaleState.initial); } _isZooming = value != PhotoViewScaleState.initial; - 
InheritedDetailPageState.maybeOf(context) - ?.isZoomedNotifier - .value = _isZooming; + InheritedDetailPageState.maybeOf(context)?.isZoomedNotifier.value = + _isZooming; debugPrint("isZooming = $_isZooming, currentState $value"); // _logger.info('is reakky zooming $_isZooming with state $value'); }; diff --git a/mobile/apps/photos/lib/ui/viewer/file_details/file_info_pets_item_widget.dart b/mobile/apps/photos/lib/ui/viewer/file_details/file_info_pets_item_widget.dart index 8d037bc4b99..0f0185be3fa 100644 --- a/mobile/apps/photos/lib/ui/viewer/file_details/file_info_pets_item_widget.dart +++ b/mobile/apps/photos/lib/ui/viewer/file_details/file_info_pets_item_widget.dart @@ -204,7 +204,11 @@ class _PetsItemWidgetState extends State { EnteTextTheme textTheme, ) { final l10n = AppLocalizations.of(context); - final speciesLabel = info.species == 0 ? l10n.dog : l10n.cat; + final speciesLabel = info.species == 0 + ? l10n.dog + : info.species == 1 + ? l10n.cat + : l10n.pet; return SizedBox( width: thumbnailWidth, diff --git a/mobile/apps/photos/lib/ui/viewer/people/merge_pet_sheet.dart b/mobile/apps/photos/lib/ui/viewer/people/merge_pet_sheet.dart new file mode 100644 index 00000000000..a4ff9e26cd3 --- /dev/null +++ b/mobile/apps/photos/lib/ui/viewer/people/merge_pet_sheet.dart @@ -0,0 +1,173 @@ +import "package:flutter/material.dart"; +import "package:photos/db/ml/db.dart"; +import "package:photos/l10n/l10n.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_clustering_service.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_service.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/viewer/people/face_thumbnail_squircle.dart"; +import "package:photos/ui/viewer/people/pet_face_widget.dart"; + +class MergePetResult { + final String petId; + final String petName; + MergePetResult(this.petId, this.petName); +} + +/// Show a page to select an existing named pet to merge into. 
+Future showMergePetPage( + BuildContext context, { + String? currentClusterId, +}) async { + return Navigator.of(context).push( + MaterialPageRoute( + builder: (_) => _MergePetPage(currentClusterId: currentClusterId), + ), + ); +} + +class _MergePetPage extends StatefulWidget { + final String? currentClusterId; + const _MergePetPage({this.currentClusterId}); + + @override + State<_MergePetPage> createState() => _MergePetPageState(); +} + +class _MergePetPageState extends State<_MergePetPage> { + List<_PetGridItem>? _pets; + bool _loading = true; + + @override + void initState() { + super.initState(); + _loadPets(); + } + + Future _loadPets() async { + final pets = await PetService.instance.getPets(); + final clusterToPetId = await MLDataDB.instance.getClusterToPetId(); + + // Find the current pet ID to exclude from the list + final currentPetId = widget.currentClusterId != null + ? clusterToPetId[widget.currentClusterId] + : null; + + // Build a map of petId -> first clusterId for the thumbnail + final petToCluster = {}; + for (final entry in clusterToPetId.entries) { + petToCluster.putIfAbsent(entry.value, () => entry.key); + } + + final items = <_PetGridItem>[]; + for (final pet in pets) { + if (pet.data.isIgnored) continue; + if (pet.data.name.isEmpty) continue; + if (pet.remoteID == currentPetId) continue; + final clusterId = petToCluster[pet.remoteID]; + if (clusterId == null) continue; + items.add( + _PetGridItem( + petId: pet.remoteID, + name: pet.data.name, + clusterId: clusterId, + ), + ); + } + items.sort((a, b) => a.name.compareTo(b.name)); + + if (mounted) { + setState(() { + _pets = items; + _loading = false; + }); + } + } + + @override + Widget build(BuildContext context) { + final textTheme = getEnteTextTheme(context); + final colorScheme = getEnteColorScheme(context); + + return Scaffold( + appBar: AppBar(title: Text(context.l10n.merge)), + body: _loading + ? const Center(child: CircularProgressIndicator()) + : _pets == null || _pets!.isEmpty + ? 
Center( + child: Padding( + padding: const EdgeInsets.all(32), + child: Text( + context.l10n.noNamedPetsToMerge, + style: + textTheme.body.copyWith(color: colorScheme.textMuted), + textAlign: TextAlign.center, + ), + ), + ) + : GridView.builder( + padding: const EdgeInsets.all(12), + gridDelegate: const SliverGridDelegateWithFixedCrossAxisCount( + crossAxisCount: 3, + crossAxisSpacing: 8, + mainAxisSpacing: 8, + childAspectRatio: 0.75, + ), + itemCount: _pets!.length, + itemBuilder: (context, index) { + final item = _pets![index]; + return _PetGridTile( + item: item, + onTap: () => Navigator.pop( + context, + MergePetResult(item.petId, item.name), + ), + ); + }, + ), + ); + } +} + +class _PetGridItem { + final String petId; + final String name; + final String clusterId; + _PetGridItem({ + required this.petId, + required this.name, + required this.clusterId, + }); +} + +class _PetGridTile extends StatelessWidget { + final _PetGridItem item; + final VoidCallback onTap; + const _PetGridTile({required this.item, required this.onTap}); + + @override + Widget build(BuildContext context) { + final textTheme = getEnteTextTheme(context); + return GestureDetector( + onTap: onTap, + child: Column( + children: [ + Expanded( + child: AspectRatio( + aspectRatio: 1, + child: FaceThumbnailSquircleClip( + child: PetFaceWidget(petClusterId: item.clusterId), + ), + ), + ), + const SizedBox(height: 4), + Text( + item.name, + style: textTheme.mini, + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + ], + ), + ); + } +} diff --git a/mobile/apps/photos/lib/ui/viewer/people/pet_cluster_page.dart b/mobile/apps/photos/lib/ui/viewer/people/pet_cluster_page.dart new file mode 100644 index 00000000000..2167854286b --- /dev/null +++ b/mobile/apps/photos/lib/ui/viewer/people/pet_cluster_page.dart @@ -0,0 +1,352 @@ +import "dart:async"; + +import "package:ente_pure_utils/ente_pure_utils.dart"; +import "package:flutter/material.dart"; +import "package:photos/core/event_bus.dart"; +import 
"package:photos/db/files_db.dart"; +import "package:photos/db/ml/db.dart"; +import "package:photos/db/offline_files_db.dart"; +import "package:photos/events/files_updated_event.dart"; +import "package:photos/events/local_photos_updated_event.dart"; +import "package:photos/events/pets_changed_event.dart"; +import "package:photos/generated/intl/app_localizations.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/models/file_load_result.dart"; +import "package:photos/models/gallery_type.dart"; +import "package:photos/models/selected_files.dart"; +import "package:photos/service_locator.dart" show isOfflineMode; +import "package:photos/services/machine_learning/pet_ml/pet_clustering_service.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_service.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/common/popup_item.dart"; +import "package:photos/ui/components/buttons/button_widget.dart" + show ButtonAction; +import "package:photos/ui/viewer/actions/file_selection_overlay_bar.dart"; +import "package:photos/ui/viewer/gallery/gallery.dart"; +import "package:photos/ui/viewer/gallery/state/gallery_boundaries_provider.dart"; +import "package:photos/ui/viewer/gallery/state/gallery_files_inherited_widget.dart"; +import "package:photos/ui/viewer/gallery/state/selection_state.dart"; +import "package:photos/ui/viewer/people/face_thumbnail_squircle.dart"; +import "package:photos/ui/viewer/people/merge_pet_sheet.dart"; +import "package:photos/ui/viewer/people/pet_clusters_page.dart"; +import "package:photos/ui/viewer/people/pet_face_widget.dart"; +import "package:photos/ui/viewer/people/save_or_edit_pet.dart"; +import "package:photos/ui/viewer/people/save_person_banner.dart"; +import "package:photos/utils/dialog_util.dart"; + +/// Detail page for a pet cluster with gallery, name editing, and reassignment. 
+class PetClusterPage extends StatefulWidget { + final String clusterId; + final String clusterLabel; + final List files; + final int species; + + const PetClusterPage({ + required this.clusterId, + required this.clusterLabel, + required this.files, + required this.species, + super.key, + }); + + @override + State createState() => _PetClusterPageState(); +} + +class _PetClusterPageState extends State { + final _selectedFiles = SelectedFiles(); + late List _files; + late String _label; + bool _isBannerDismissed = false; + late final StreamSubscription _filesUpdated; + late final StreamSubscription _petsChanged; + + @override + void initState() { + super.initState(); + _files = List.from(widget.files) + ..sort((a, b) => (b.creationTime ?? 0).compareTo(a.creationTime ?? 0)); + _label = widget.clusterLabel; + _filesUpdated = Bus.instance.on().listen((event) { + if (event.type == EventType.deletedFromEverywhere || + event.type == EventType.deletedFromRemote || + event.type == EventType.hide) { + for (final f in event.updatedFiles) { + _files.remove(f); + } + setState(() {}); + } + }); + _petsChanged = Bus.instance.on().listen((_) { + WidgetsBinding.instance.addPostFrameCallback((_) { + if (mounted) _reloadClusterFiles(); + }); + }); + } + + Future _reloadClusterFiles() async { + if (!mounted) return; + final mlDataDB = + isOfflineMode ? MLDataDB.offlineInstance : MLDataDB.instance; + + // Load files from all clusters sharing the same pet (after merge). 
+ final clusterToPetId = await mlDataDB.getClusterToPetId(); + final petId = clusterToPetId[widget.clusterId]; + final allFileIds = []; + if (petId != null) { + final siblingClusters = clusterToPetId.entries + .where((e) => e.value == petId) + .map((e) => e.key); + for (final cid in siblingClusters) { + allFileIds.addAll(await mlDataDB.getPetFileIdsForCluster(cid)); + } + } else { + allFileIds + .addAll(await mlDataDB.getPetFileIdsForCluster(widget.clusterId)); + } + final fileIds = allFileIds.toSet().toList(); + if (fileIds.isEmpty) { + if (mounted) Navigator.pop(context); + return; + } + + final List files; + if (isOfflineMode) { + final localIdMap = + await OfflineFilesDB.instance.getLocalIdsForIntIds(fileIds.toSet()); + final localIds = localIdMap.values.toList(); + files = localIds.isEmpty + ? [] + : await FilesDB.instance.getLocalFiles(localIds); + } else { + files = await FilesDB.instance + .getFilesFromIDs(fileIds, dedupeByUploadId: true); + } + + if (!mounted) return; + // Deduplicate by generatedID to avoid duplicate-key errors in Gallery. + final seen = {}; + final dedupedFiles = []; + for (final f in files) { + if (seen.add(f.generatedID ?? 0)) { + dedupedFiles.add(f); + } + } + setState(() { + _files = dedupedFiles + ..sort( + (a, b) => (b.creationTime ?? 0).compareTo(a.creationTime ?? 
0), + ); + }); + } + + @override + void dispose() { + _filesUpdated.cancel(); + _petsChanged.cancel(); + super.dispose(); + } + + @override + Widget build(BuildContext context) { + final l10n = AppLocalizations.of(context); + final textTheme = getEnteTextTheme(context); + final colorScheme = getEnteColorScheme(context); + + final bool showBanner = !_isBannerDismissed && + !isOfflineMode && + _files.isNotEmpty && + (widget.clusterLabel.isEmpty || + widget.clusterLabel.startsWith("Dog") || + widget.clusterLabel.startsWith("Cat") || + widget.clusterLabel.startsWith("Pet")); + + final gallery = Gallery( + asyncLoader: (creationStartTime, creationEndTime, {limit, asc}) { + final result = _files + .where( + (file) => + (file.creationTime ?? 0) >= creationStartTime && + (file.creationTime ?? 0) <= creationEndTime, + ) + .toList(); + return Future.value( + FileLoadResult(result, result.length < _files.length), + ); + }, + reloadEvent: Bus.instance.on(), + forceReloadEvents: [Bus.instance.on()], + removalEventTypes: const { + EventType.deletedFromRemote, + EventType.deletedFromEverywhere, + EventType.hide, + }, + selectedFiles: _selectedFiles, + tagPrefix: "pet_cluster_${widget.clusterId}", + enableFileGrouping: true, + initialFiles: _files, + header: showBanner + ? SavePersonBanner( + faceWidget: PetFaceWidget(petClusterId: widget.clusterId), + text: l10n.savePet, + subText: l10n.findThemQuickly, + primaryActionLabel: l10n.save, + secondaryActionLabel: l10n.merge, + onPrimaryTap: () => _editName(), + onSecondaryTap: () => _handleMergePet(), + onDismissed: () => setState(() => _isBannerDismissed = true), + dismissibleKey: ValueKey("pet_banner_${widget.clusterId}"), + ) + : null, + ); + + return GalleryBoundariesProvider( + child: GalleryFilesState( + child: Scaffold( + appBar: AppBar( + title: GestureDetector( + onTap: isOfflineMode ? 
null : _editName, + child: Row( + mainAxisSize: MainAxisSize.min, + children: [ + SizedBox( + width: 32, + height: 32, + child: FaceThumbnailSquircleClip( + child: PetFaceWidget(petClusterId: widget.clusterId), + ), + ), + const SizedBox(width: 8), + Flexible( + child: Text( + _label, + overflow: TextOverflow.ellipsis, + ), + ), + if (!isOfflineMode) ...[ + const SizedBox(width: 4), + Icon( + Icons.edit, + size: 16, + color: colorScheme.strokeMuted, + ), + ], + ], + ), + ), + actions: [ + if (!isOfflineMode) + PopupMenuButton<_PetClusterAction>( + icon: const Icon(Icons.more_horiz), + onSelected: (action) { + switch (action) { + case _PetClusterAction.viewClusters: + _viewClusters(); + case _PetClusterAction.ignore: + _ignorePet(); + } + }, + itemBuilder: (_) => [ + EntePopupMenuItem( + l10n.viewClusters, + value: _PetClusterAction.viewClusters, + icon: Icons.account_tree_outlined, + ), + EntePopupMenuItem( + l10n.ignorePet, + value: _PetClusterAction.ignore, + icon: Icons.hide_image_outlined, + ), + ], + ), + Text( + "${_files.length}", + style: textTheme.body.copyWith(color: colorScheme.textMuted), + ), + const SizedBox(width: 16), + ], + ), + body: SelectionState( + selectedFiles: _selectedFiles, + child: Stack( + alignment: Alignment.bottomCenter, + children: [ + gallery, + FileSelectionOverlayBar( + GalleryType.petCluster, + _selectedFiles, + clusterID: widget.clusterId, + ), + ], + ), + ), + ), + ), + ); + } + + Future _editName() async { + if (isOfflineMode) return; + final result = await routeToPage( + context, + SaveOrEditPet( + clusterId: widget.clusterId, + species: widget.species, + currentName: _label, + ), + ); + if (result is String && result.isNotEmpty && mounted) { + setState(() { + _label = result; + _isBannerDismissed = true; + }); + } + } + + Future _handleMergePet() async { + if (isOfflineMode) return; + final selection = await showMergePetPage( + context, + currentClusterId: widget.clusterId, + ); + if (selection == null || !mounted) return; + 
await PetService.instance.addClusterToExistingPet( + petId: selection.petId, + clusterID: widget.clusterId, + ); + Navigator.of(context).pop(); + } + + Future _ignorePet() async { + if (isOfflineMode) return; + final l10n = AppLocalizations.of(context); + final result = await showChoiceDialog( + context, + title: l10n.areYouSureYouWantToIgnoreThisPet, + body: l10n.thePetGroupsWillNotBeDisplayed, + firstButtonLabel: l10n.confirm, + firstButtonOnTap: () async { + await PetService.instance + .ignorePetCluster(widget.clusterId, widget.species); + }, + ); + if (!mounted || result?.action != ButtonAction.first) return; + Navigator.of(context).pop(); + } + + Future _viewClusters() async { + if (isOfflineMode) return; + final mlDataDB = + isOfflineMode ? MLDataDB.offlineInstance : MLDataDB.instance; + final clusterToPetId = await mlDataDB.getClusterToPetId(); + final petId = clusterToPetId[widget.clusterId]; + if (petId == null || !mounted) return; + await routeToPage( + context, + PetClustersPage(petId: petId, petName: _label), + ); + await _reloadClusterFiles(); + } +} + +enum _PetClusterAction { viewClusters, ignore } diff --git a/mobile/apps/photos/lib/ui/viewer/people/pet_clusters_page.dart b/mobile/apps/photos/lib/ui/viewer/people/pet_clusters_page.dart new file mode 100644 index 00000000000..39eb3b1d58b --- /dev/null +++ b/mobile/apps/photos/lib/ui/viewer/people/pet_clusters_page.dart @@ -0,0 +1,124 @@ +import "package:flutter/material.dart"; +import "package:photos/db/ml/db.dart"; +import "package:photos/l10n/l10n.dart"; +import "package:photos/service_locator.dart" show isOfflineMode; +import "package:photos/services/machine_learning/pet_ml/pet_clustering_service.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_service.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/viewer/people/face_thumbnail_squircle.dart"; +import "package:photos/ui/viewer/people/pet_face_widget.dart"; + +/// Shows all clusters belonging to 
a pet. Mirrors [PersonClustersPage]. +class PetClustersPage extends StatefulWidget { + final String petId; + final String petName; + + const PetClustersPage({ + required this.petId, + required this.petName, + super.key, + }); + + @override + State createState() => _PetClustersPageState(); +} + +class _PetClustersPageState extends State { + List<(String clusterId, int fileCount, int species)>? _clusters; + bool _loading = true; + + @override + void initState() { + super.initState(); + _loadClusters(); + } + + Future _loadClusters() async { + final mlDataDB = + isOfflineMode ? MLDataDB.offlineInstance : MLDataDB.instance; + final clusterToPetId = await mlDataDB.getClusterToPetId(); + final clusterIds = clusterToPetId.entries + .where((e) => e.value == widget.petId) + .map((e) => e.key) + .toList(); + + final allInfo = await mlDataDB.getAllPetClustersWithInfo(); + final infoMap = {for (final c in allInfo) c.$1: c}; + + final result = <(String, int, int)>[]; + for (final cid in clusterIds) { + final info = infoMap[cid]; + if (info != null) { + result.add((cid, info.$3, info.$2)); + } + } + // Largest cluster first (primary) + result.sort((a, b) => b.$2.compareTo(a.$2)); + if (mounted) { + setState(() { + _clusters = result; + _loading = false; + }); + } + } + + Future _removeCluster(String clusterId) async { + await PetService.instance.removeClusterFromPet( + petID: widget.petId, + clusterID: clusterId, + ); + await _loadClusters(); + } + + @override + Widget build(BuildContext context) { + final textTheme = getEnteTextTheme(context); + final colorScheme = getEnteColorScheme(context); + + return Scaffold( + appBar: AppBar( + title: Text(widget.petName), + ), + body: _loading + ? const Center(child: CircularProgressIndicator()) + : _clusters == null || _clusters!.isEmpty + ? 
Center( + child: Text( + context.l10n.noClusters, + style: + textTheme.body.copyWith(color: colorScheme.textMuted), + ), + ) + : ListView.builder( + itemCount: _clusters!.length, + itemBuilder: (context, index) { + final (clusterId, fileCount, _) = _clusters![index]; + final canRemove = + !isOfflineMode && _clusters!.length > 1 && index != 0; + return ListTile( + leading: SizedBox( + width: 56, + height: 56, + child: FaceThumbnailSquircleClip( + child: PetFaceWidget(petClusterId: clusterId), + ), + ), + title: Text( + context.l10n.photosCount(count: fileCount), + style: textTheme.body, + ), + trailing: canRemove + ? IconButton( + icon: Icon( + Icons.remove_circle_outline, + color: colorScheme.warning500, + ), + onPressed: () => _removeCluster(clusterId), + ) + : null, + ); + }, + ), + ); + } +} diff --git a/mobile/apps/photos/lib/ui/viewer/people/pet_face_widget.dart b/mobile/apps/photos/lib/ui/viewer/people/pet_face_widget.dart new file mode 100644 index 00000000000..132a0186480 --- /dev/null +++ b/mobile/apps/photos/lib/ui/viewer/people/pet_face_widget.dart @@ -0,0 +1,126 @@ +import "dart:async"; +import "dart:convert"; +import "dart:typed_data"; + +import "package:flutter/material.dart"; +import "package:logging/logging.dart"; +import "package:photos/core/event_bus.dart"; +import "package:photos/db/files_db.dart"; +import "package:photos/db/ml/db.dart"; +import "package:photos/db/offline_files_db.dart"; +import "package:photos/events/pets_changed_event.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/models/ml/face/box.dart"; +import "package:photos/models/ml/face/detection.dart"; +import "package:photos/models/ml/face/face.dart"; +import "package:photos/service_locator.dart" show isOfflineMode; +import "package:photos/services/machine_learning/pet_ml/pet_clustering_service.dart"; +import "package:photos/utils/face/face_thumbnail_cache.dart"; + +final _logger = Logger("PetFaceWidget"); + +class PetFaceWidget extends StatefulWidget { 
+ final String petClusterId; + + const PetFaceWidget({required this.petClusterId, super.key}); + + @override + State createState() => _PetFaceWidgetState(); +} + +class _PetFaceWidgetState extends State { + Future? _faceCropFuture; + late final StreamSubscription _petsChangedSub; + + @override + void initState() { + super.initState(); + _faceCropFuture = _loadFaceCrop(); + _petsChangedSub = + Bus.instance.on().listen((_) => _reload()); + } + + @override + void dispose() { + _petsChangedSub.cancel(); + super.dispose(); + } + + void _reload() { + if (mounted) { + setState(() { + _faceCropFuture = _loadFaceCrop(); + }); + } + } + + Future _loadFaceCrop() async { + try { + final mlDataDB = + isOfflineMode ? MLDataDB.offlineInstance : MLDataDB.instance; + final dbPetFace = + await mlDataDB.getCoverPetFaceForCluster(widget.petClusterId); + if (dbPetFace == null) return null; + + EnteFile? enteFile; + if (isOfflineMode) { + final localId = + await OfflineFilesDB.instance.getLocalIdForIntId(dbPetFace.fileId); + if (localId != null) { + final files = await FilesDB.instance.getLocalFiles([localId]); + enteFile = files.firstOrNull; + } + } else { + enteFile = await FilesDB.instance.getAnyUploadedFile(dbPetFace.fileId); + } + if (enteFile == null) return null; + + final json = jsonDecode(dbPetFace.detection) as Map; + final boxList = + (json['box'] as List).map((e) => (e as num).toDouble()).toList(); + final detection = Detection( + box: FaceBox( + x: boxList[0], + y: boxList[1], + width: boxList[2] - boxList[0], + height: boxList[3] - boxList[1], + ), + landmarks: const [], + ); + final face = Face( + dbPetFace.petFaceId, + dbPetFace.fileId, + const [], + dbPetFace.faceScore, + detection, + 0.0, + ); + + final crops = await getCachedFaceCrops( + enteFile, + [face], + useTempCache: true, + ); + return crops?[face.faceID]; + } catch (e, s) { + _logger.warning("Failed to load pet face crop", e, s); + return null; + } + } + + @override + Widget build(BuildContext context) { + 
return FutureBuilder( + future: _faceCropFuture, + builder: (context, snapshot) { + if (snapshot.hasData && snapshot.data != null) { + return Image.memory( + snapshot.data!, + fit: BoxFit.cover, + ); + } + return const Center(child: Icon(Icons.pets, size: 32)); + }, + ); + } +} diff --git a/mobile/apps/photos/lib/ui/viewer/people/save_or_edit_pet.dart b/mobile/apps/photos/lib/ui/viewer/people/save_or_edit_pet.dart new file mode 100644 index 00000000000..88a74a97862 --- /dev/null +++ b/mobile/apps/photos/lib/ui/viewer/people/save_or_edit_pet.dart @@ -0,0 +1,250 @@ +import "dart:async"; + +import "package:flutter/material.dart"; +import "package:logging/logging.dart"; +import "package:photos/db/ml/db.dart"; +import "package:photos/l10n/l10n.dart"; +import "package:photos/models/ml/pet/pet_entity.dart"; +import "package:photos/service_locator.dart" show isOfflineMode; +import "package:photos/services/machine_learning/pet_ml/pet_clustering_service.dart"; +import "package:photos/services/machine_learning/pet_ml/pet_service.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/common/date_input.dart"; +import "package:photos/ui/components/buttons/button_widget.dart"; +import "package:photos/ui/components/models/button_type.dart"; +import "package:photos/ui/viewer/people/face_thumbnail_squircle.dart"; +import "package:photos/ui/viewer/people/pet_face_widget.dart"; + +/// Full-page screen for saving or editing a pet, mirroring the person +/// save/edit screen layout (minus email/suggestions). +class SaveOrEditPet extends StatefulWidget { + final String clusterId; + final int? species; + final String? currentName; + final String? 
petId; + final bool isEditing; + + const SaveOrEditPet({ + required this.clusterId, + this.species, + this.currentName, + this.petId, + this.isEditing = false, + super.key, + }); + + @override + State createState() => _SaveOrEditPetState(); +} + +class _SaveOrEditPetState extends State { + final _logger = Logger("_SaveOrEditPetState"); + String _inputName = ""; + String? _selectedDate; + PetData? _existingData; + int? _species; + Timer? _debounce; + + @override + void initState() { + super.initState(); + _inputName = widget.currentName ?? ""; + _species = widget.species; + _loadExistingData(); + } + + @override + void dispose() { + _debounce?.cancel(); + super.dispose(); + } + + Future _loadExistingData() async { + final mlDataDB = + isOfflineMode ? MLDataDB.offlineInstance : MLDataDB.instance; + final resolvedPetId = + widget.petId ?? (await mlDataDB.getClusterToPetId())[widget.clusterId]; + if (resolvedPetId != null) { + final pet = await PetService.instance.getPet(resolvedPetId); + if (pet != null && mounted) { + setState(() { + _existingData = pet.data; + _selectedDate = pet.data.birthDate; + if (_inputName.isEmpty) _inputName = pet.data.name; + }); + } + } + _species ??= await _resolveSpecies(mlDataDB); + } + + Future _resolveSpecies(MLDataDB mlDataDB) async { + final clusters = await mlDataDB.getAllPetClustersWithInfo(); + for (final cluster in clusters) { + if (cluster.$1 == widget.clusterId) { + return cluster.$2; + } + } + return null; + } + + bool get _hasChanges { + if (_existingData == null) return _inputName.isNotEmpty; + return _inputName != _existingData!.name || + _selectedDate != _existingData!.birthDate; + } + + @override + Widget build(BuildContext context) { + final textTheme = getEnteTextTheme(context); + final colorScheme = getEnteColorScheme(context); + + return Scaffold( + resizeToAvoidBottomInset: true, + appBar: AppBar( + title: Align( + alignment: Alignment.centerLeft, + child: Text( + widget.isEditing ? 
context.l10n.editPerson : context.l10n.savePet, + ), + ), + ), + body: GestureDetector( + onTap: () => FocusScope.of(context).unfocus(), + child: SafeArea( + child: Column( + children: [ + Expanded( + child: SingleChildScrollView( + padding: const EdgeInsets.only( + bottom: 32.0, + left: 16.0, + right: 16.0, + ), + child: Column( + children: [ + const SizedBox(height: 48), + SizedBox( + height: 110, + width: 110, + child: FaceThumbnailSquircleClip( + child: PetFaceWidget( + petClusterId: widget.clusterId, + ), + ), + ), + const SizedBox(height: 36), + TextFormField( + keyboardType: TextInputType.name, + textCapitalization: TextCapitalization.words, + autocorrect: false, + initialValue: _inputName, + onChanged: (value) { + if (_debounce?.isActive ?? false) { + _debounce?.cancel(); + } + _debounce = + Timer(const Duration(milliseconds: 300), () { + setState(() => _inputName = value); + }); + }, + decoration: InputDecoration( + focusedBorder: OutlineInputBorder( + borderRadius: const BorderRadius.all( + Radius.circular(8.0), + ), + borderSide: BorderSide( + color: colorScheme.strokeMuted, + ), + ), + fillColor: colorScheme.fillFaint, + filled: true, + hintText: context.l10n.enterName, + hintStyle: textTheme.bodyFaint, + contentPadding: const EdgeInsets.symmetric( + horizontal: 16, + vertical: 14, + ), + border: UnderlineInputBorder( + borderSide: BorderSide.none, + borderRadius: BorderRadius.circular(8), + ), + ), + ), + const SizedBox(height: 16), + DatePickerField( + hintText: context.l10n.enterDateOfBirth, + firstDate: DateTime(100), + lastDate: DateTime.now(), + initialValue: _selectedDate, + isRequired: false, + onChanged: (date) { + setState(() { + _selectedDate = + date?.toIso8601String().split("T").first; + }); + }, + ), + const SizedBox(height: 24), + ButtonWidget( + buttonType: ButtonType.primary, + labelText: context.l10n.save, + isDisabled: !_hasChanges || _inputName.trim().isEmpty, + onTap: () async => _save(), + ), + ], + ), + ), + ), + ], + ), + ), + ), 
+ ); + } + + Future _save() async { + final name = _inputName.trim(); + if (name.isEmpty) return; + + final mlDataDB = + isOfflineMode ? MLDataDB.offlineInstance : MLDataDB.instance; + final petService = PetService.instance; + + try { + String petId; + final resolvedPetId = widget.petId ?? + (await mlDataDB.getClusterToPetId())[widget.clusterId]; + + if (resolvedPetId != null) { + petId = resolvedPetId; + final existingPet = await petService.getPet(petId); + final updatedData = existingPet != null + ? existingPet.data.copyWith(name: name, birthDate: _selectedDate) + : PetData( + name: name, + species: _species ?? -1, + birthDate: _selectedDate, + ); + await petService.updatePet(petId, updatedData); + } else { + _species ??= await _resolveSpecies(mlDataDB); + final pet = await petService.addPet( + PetData( + name: name, + species: _species ?? -1, + birthDate: _selectedDate, + ), + ); + petId = pet.remoteID; + } + + await mlDataDB.setClusterPetId(widget.clusterId, petId); + + if (mounted) { + Navigator.pop(context, name); + } + } catch (e) { + _logger.severe("Error saving pet", e); + } + } +} diff --git a/mobile/apps/photos/lib/ui/viewer/people/save_person_banner.dart b/mobile/apps/photos/lib/ui/viewer/people/save_person_banner.dart index e21d8da956d..3f3dafcc234 100644 --- a/mobile/apps/photos/lib/ui/viewer/people/save_person_banner.dart +++ b/mobile/apps/photos/lib/ui/viewer/people/save_person_banner.dart @@ -1,10 +1,9 @@ import "package:flutter/material.dart"; import "package:photos/theme/ente_theme.dart"; import "package:photos/ui/viewer/people/face_thumbnail_squircle.dart"; -import "package:photos/ui/viewer/people/person_face_widget.dart"; class SavePersonBanner extends StatelessWidget { - final PersonFaceWidget? faceWidget; + final Widget? faceWidget; final String text; final String? subText; final String? 
primaryActionLabel; diff --git a/mobile/apps/photos/lib/ui/viewer/search/result/people_section_all_page.dart b/mobile/apps/photos/lib/ui/viewer/search/result/people_section_all_page.dart index 14e9513a0ee..7008da30808 100644 --- a/mobile/apps/photos/lib/ui/viewer/search/result/people_section_all_page.dart +++ b/mobile/apps/photos/lib/ui/viewer/search/result/people_section_all_page.dart @@ -31,6 +31,7 @@ import "package:photos/ui/viewer/file/thumbnail_widget.dart"; import "package:photos/ui/viewer/people/face_thumbnail_squircle.dart"; import "package:photos/ui/viewer/people/person_face_widget.dart"; import "package:photos/ui/viewer/people/person_gallery_suggestion.dart"; +import "package:photos/ui/viewer/people/pet_face_widget.dart"; import "package:photos/ui/viewer/people/pinned_person_badge.dart"; import "package:photos/ui/viewer/search/result/search_result_page.dart"; import "package:photos/ui/viewer/search_tab/people_section.dart"; @@ -301,6 +302,13 @@ class FaceSearchResult extends StatelessWidget { @override Widget build(BuildContext context) { final params = (searchResult as GenericSearchResult).params; + final petClusterId = params[kPetClusterParamId] as String?; + if (petClusterId != null) { + return PetFaceWidget( + petClusterId: petClusterId, + key: ValueKey(petClusterId), + ); + } final int cachedPixelWidth = (displaySize * MediaQuery.devicePixelRatioOf(context)).toInt(); return PersonFaceWidget( @@ -375,7 +383,8 @@ class _PeopleSectionAllWidgetState extends State { } return faces.where((face) { final personId = face.params[kPersonParamID] as String?; - if (personId == null || personId.isEmpty) { + final petClusterId = face.params[kPetClusterParamId] as String?; + if ((personId == null || personId.isEmpty) && petClusterId == null) { return false; } return face.name().toLowerCase().contains(query); @@ -467,6 +476,11 @@ class _PeopleSectionAllWidgetState extends State { minClusterSize: kMinimumClusterSizeAllFaces, showIgnoredOnly: _showingIgnoredPeople, ); 
+ if (!_showingIgnoredPeople && flagService.petEnabled) { + allFaces.addAll( + await SearchService.instance.getAllPets(null), + ); + } normalFaces.clear(); extraFaces.clear(); if (_showingIgnoredPeople) { @@ -629,7 +643,11 @@ class _PeopleSectionAllWidgetState extends State { final slivers = [ if (widget.showSearchBar) SearchableAppBar( - title: Text(SectionType.face.sectionTitle(context)), + title: Text( + flagService.petEnabled + ? AppLocalizations.of(context).peopleAndPets + : SectionType.face.sectionTitle(context), + ), autoActivateSearch: widget.startInSearchMode, onSearch: _updateSearchQuery, onSearchClosed: _clearSearchQuery, diff --git a/mobile/apps/photos/lib/ui/viewer/search_tab/people_section.dart b/mobile/apps/photos/lib/ui/viewer/search_tab/people_section.dart index 942fb477133..236f1dbe17f 100644 --- a/mobile/apps/photos/lib/ui/viewer/search_tab/people_section.dart +++ b/mobile/apps/photos/lib/ui/viewer/search_tab/people_section.dart @@ -21,7 +21,9 @@ import "package:photos/ui/viewer/file/thumbnail_widget.dart"; import "package:photos/ui/viewer/people/add_person_action_sheet.dart"; import "package:photos/ui/viewer/people/face_thumbnail_squircle.dart"; import "package:photos/ui/viewer/people/people_page.dart"; -import 'package:photos/ui/viewer/people/person_face_widget.dart'; +import "package:photos/ui/viewer/people/person_face_widget.dart"; +import "package:photos/ui/viewer/people/pet_face_widget.dart"; +import "package:photos/ui/viewer/people/save_or_edit_pet.dart"; import "package:photos/ui/viewer/search/result/people_section_all_page.dart"; import "package:photos/ui/viewer/search/result/search_result_page.dart"; import "package:photos/ui/viewer/search/search_section_cta.dart"; @@ -84,6 +86,12 @@ class _PeopleSectionState extends State { final shouldShowMore = _examples.length >= widget.limit - 1; final textTheme = getEnteTextTheme(context); final colorScheme = getEnteColorScheme(context); + final hasPets = _examples.any( + (e) => 
e.params.containsKey(kPetClusterParamId), + ); + final sectionTitle = hasPets + ? AppLocalizations.of(context).peopleAndPets + : widget.sectionType.sectionTitle(context); return _examples.isNotEmpty ? GestureDetector( behavior: HitTestBehavior.opaque, @@ -104,7 +112,7 @@ class _PeopleSectionState extends State { Padding( padding: const EdgeInsets.all(12), child: Text( - widget.sectionType.sectionTitle(context), + sectionTitle, style: textTheme.largeBold, ), ), @@ -202,6 +210,8 @@ class PersonSearchExample extends StatelessWidget { this.size = 102, }); + bool get _isPet => searchResult.params.containsKey(kPetClusterParamId); + void toggleSelection() { selectedPeople ?.toggleSelection(searchResult.params[kPersonParamID]! as String); @@ -209,6 +219,84 @@ class PersonSearchExample extends StatelessWidget { @override Widget build(BuildContext context) { + if (_isPet) return _buildPetItem(context); + return _buildPersonItem(context); + } + + Widget _buildPetItem(BuildContext context) { + final textTheme = getEnteTextTheme(context); + final clusterId = searchResult.params[kPetClusterParamId] as String?; + final hasCustomName = searchResult.name().isNotEmpty; + + return GestureDetector( + onTap: () { + RecentSearches().add(searchResult.name()); + if (searchResult.onResultTap != null) { + searchResult.onResultTap!(context); + } else { + routeToPage(context, SearchResultPage(searchResult)); + } + }, + child: Column( + mainAxisSize: MainAxisSize.min, + children: [ + SizedBox( + width: size, + height: size, + child: clusterId != null + ? FaceThumbnailSquircleClip( + child: PetFaceWidget(petClusterId: clusterId), + ) + : FaceThumbnailSquircleClip( + child: Container( + color: getEnteColorScheme(context).strokeFaint, + child: const Icon(Icons.pets, size: 40), + ), + ), + ), + hasCustomName + ? 
Padding( + padding: const EdgeInsets.only(top: 6), + child: SizedBox( + width: size, + child: Text( + searchResult.name(), + maxLines: 1, + textAlign: TextAlign.center, + overflow: TextOverflow.ellipsis, + style: textTheme.small, + ), + ), + ) + : GestureDetector( + behavior: HitTestBehavior.translucent, + onTap: () { + if (clusterId != null) { + routeToPage( + context, + SaveOrEditPet( + clusterId: clusterId, + ), + ); + } + }, + child: Padding( + padding: const EdgeInsets.only(top: 6), + child: Text( + AppLocalizations.of(context).addName, + maxLines: 1, + textAlign: TextAlign.center, + overflow: TextOverflow.ellipsis, + style: textTheme.small, + ), + ), + ), + ], + ), + ); + } + + Widget _buildPersonItem(BuildContext context) { final bool isCluster = searchResult.type() == ResultType.faces && searchResult.params.containsKey(kClusterParamId); diff --git a/mobile/apps/photos/lib/utils/face_crop_util.dart b/mobile/apps/photos/lib/utils/face_crop_util.dart index 2c3cf010d15..e697f89a823 100644 --- a/mobile/apps/photos/lib/utils/face_crop_util.dart +++ b/mobile/apps/photos/lib/utils/face_crop_util.dart @@ -54,14 +54,12 @@ FaceBox computePaddedFaceCropBox(FaceBox faceBox) { final xCrop = faceBox.x - faceBox.width * _regularPadding; final xOvershoot = min(0.0, xCrop).abs() / faceBox.width; - final widthCrop = - faceBox.width * (1 + 2 * _regularPadding) - + final widthCrop = faceBox.width * (1 + 2 * _regularPadding) - 2 * min(xOvershoot, _regularPadding - _minimumPadding) * faceBox.width; final yCrop = faceBox.y - faceBox.height * _regularPadding; final yOvershoot = min(0.0, yCrop).abs() / faceBox.height; - final heightCrop = - faceBox.height * (1 + 2 * _regularPadding) - + final heightCrop = faceBox.height * (1 + 2 * _regularPadding) - 2 * min(yOvershoot, _regularPadding - _minimumPadding) * faceBox.height; final xCropSafe = xCrop.clamp(0.0, 1.0); diff --git a/mobile/apps/photos/lib/utils/isolate/isolate_operations.dart 
b/mobile/apps/photos/lib/utils/isolate/isolate_operations.dart index 69ae162163c..7b65b217142 100644 --- a/mobile/apps/photos/lib/utils/isolate/isolate_operations.dart +++ b/mobile/apps/photos/lib/utils/isolate/isolate_operations.dart @@ -453,11 +453,9 @@ Future _ensureRustRuntimePrepared(Map args) async { (args["petFaceEmbeddingDogModelPath"] as String?) ?? "", petFaceEmbeddingCat: (args["petFaceEmbeddingCatModelPath"] as String?) ?? "", - petBodyDetection: (args["petBodyDetectionModelPath"] as String?) ?? "", - petBodyEmbeddingDog: - (args["petBodyEmbeddingDogModelPath"] as String?) ?? "", - petBodyEmbeddingCat: - (args["petBodyEmbeddingCatModelPath"] as String?) ?? "", + petBodyDetection: "", + petBodyEmbeddingDog: "", + petBodyEmbeddingCat: "", ); final providerPolicy = rust_ml.RustExecutionProviderPolicy( preferCoreml: args["preferCoreml"] as bool? ?? true, @@ -522,9 +520,6 @@ String _runtimeConfigCacheKey( modelPaths.petFaceDetection, modelPaths.petFaceEmbeddingDog, modelPaths.petFaceEmbeddingCat, - modelPaths.petBodyDetection, - modelPaths.petBodyEmbeddingDog, - modelPaths.petBodyEmbeddingCat, providerPolicy.preferCoreml, providerPolicy.preferNnapi, providerPolicy.preferXnnpack, diff --git a/mobile/apps/photos/lib/utils/ml_util.dart b/mobile/apps/photos/lib/utils/ml_util.dart index f49dd00f526..d479da01a52 100644 --- a/mobile/apps/photos/lib/utils/ml_util.dart +++ b/mobile/apps/photos/lib/utils/ml_util.dart @@ -1,5 +1,6 @@ import "dart:io" show Directory, File, Platform; import "dart:math" as math show sqrt, min, max; +import "dart:typed_data" show Float32List; import "package:ente_pure_utils/ente_pure_utils.dart"; import "package:flutter/services.dart" show PlatformException; @@ -762,12 +763,6 @@ Future analyzeImageRust(Map args) async { args["petFaceEmbeddingDogModelPath"] as String?; final String? petFaceEmbeddingCatModelPath = args["petFaceEmbeddingCatModelPath"] as String?; - final String? 
petBodyDetectionModelPath = - args["petBodyDetectionModelPath"] as String?; - final String? petBodyEmbeddingDogModelPath = - args["petBodyEmbeddingDogModelPath"] as String?; - final String? petBodyEmbeddingCatModelPath = - args["petBodyEmbeddingCatModelPath"] as String?; final bool preferCoreml = args["preferCoreml"] as bool? ?? true; final bool preferNnapi = args["preferNnapi"] as bool? ?? true; final bool preferXnnpack = args["preferXnnpack"] as bool? ?? false; @@ -797,15 +792,6 @@ Future analyzeImageRust(Map args) async { if (isMissingModelPath(petFaceEmbeddingCatModelPath)) { missingModelPaths.add("petFaceEmbeddingCatModelPath"); } - if (isMissingModelPath(petBodyDetectionModelPath)) { - missingModelPaths.add("petBodyDetectionModelPath"); - } - if (isMissingModelPath(petBodyEmbeddingDogModelPath)) { - missingModelPaths.add("petBodyEmbeddingDogModelPath"); - } - if (isMissingModelPath(petBodyEmbeddingCatModelPath)) { - missingModelPaths.add("petBodyEmbeddingCatModelPath"); - } } if (missingModelPaths.isNotEmpty) { throw Exception( @@ -821,9 +807,9 @@ Future analyzeImageRust(Map args) async { petFaceDetection: petFaceDetectionModelPath ?? "", petFaceEmbeddingDog: petFaceEmbeddingDogModelPath ?? "", petFaceEmbeddingCat: petFaceEmbeddingCatModelPath ?? "", - petBodyDetection: petBodyDetectionModelPath ?? "", - petBodyEmbeddingDog: petBodyEmbeddingDogModelPath ?? "", - petBodyEmbeddingCat: petBodyEmbeddingCatModelPath ?? 
"", + petBodyDetection: "", + petBodyEmbeddingDog: "", + petBodyEmbeddingCat: "", ); final providerPolicy = rust_ml.RustExecutionProviderPolicy( preferCoreml: preferCoreml, @@ -1023,18 +1009,6 @@ Future analyzeImageRust(Map args) async { ); }).toList(growable: false); } - - if (rustResult.petBodies != null) { - result.petBodies = rustResult.petBodies!.map((body) { - return PetBodyResult( - boxXyxy: body.boxXyxy.toList(growable: false), - score: body.score, - cocoClass: body.cocoClass, - petBodyId: body.petBodyId, - embedding: Embedding.from(body.bodyEmbedding), - ); - }).toList(growable: false); - } } return result; @@ -1219,3 +1193,27 @@ Future _cleanupDecodeFallback(_DecodeFallbackFile fallback) async { _logger.warning("Could not cleanup decode fallback file", e, s); } } + +/// Compute L2-normalized mean centroid from a list of Float32 embeddings. +Float32List computeL2MeanCentroid(List embeddings) { + final dim = embeddings.first.length; + final centroid = Float32List(dim); + for (final emb in embeddings) { + for (int i = 0; i < dim; i++) { + centroid[i] += emb[i]; + } + } + final n = embeddings.length.toDouble(); + double norm = 0; + for (int i = 0; i < dim; i++) { + centroid[i] /= n; + norm += centroid[i] * centroid[i]; + } + norm = math.sqrt(norm); + if (norm > 0) { + for (int i = 0; i < dim; i++) { + centroid[i] /= norm; + } + } + return centroid; +} diff --git a/mobile/apps/photos/plugins/ente_crypto/lib/src/crypto.dart b/mobile/apps/photos/plugins/ente_crypto/lib/src/crypto.dart index d7eddd4f748..86a2b9712e0 100644 --- a/mobile/apps/photos/plugins/ente_crypto/lib/src/crypto.dart +++ b/mobile/apps/photos/plugins/ente_crypto/lib/src/crypto.dart @@ -943,8 +943,8 @@ class CryptoUtil { // account KDF memory limit (128 MiB), not libsodium's absolute minimum. 
const int minMemLimit = 128 * 1024 * 1024; Uint8List key; - while (memLimit >= minMemLimit && - opsLimit <= Sodium.cryptoPwhashOpslimitMax) { + while ( + memLimit >= minMemLimit && opsLimit <= Sodium.cryptoPwhashOpslimitMax) { try { key = await deriveKey(password, salt, memLimit, opsLimit); return DerivedKeyResult(key, memLimit, opsLimit); diff --git a/mobile/apps/photos/plugins/ente_feature_flag/lib/src/service.dart b/mobile/apps/photos/plugins/ente_feature_flag/lib/src/service.dart index 5993fa2dc39..7c65b0d7827 100644 --- a/mobile/apps/photos/plugins/ente_feature_flag/lib/src/service.dart +++ b/mobile/apps/photos/plugins/ente_feature_flag/lib/src/service.dart @@ -104,20 +104,22 @@ class FlagService { bool get enableMemoryShareLink => true; - bool get useRustForML => internalUser; + bool get useRustForML => + internalUser || (_prefs.getBool("ls.pet_recognition_enabled") ?? false); bool get enableMLInBackground => internalUser; bool get useRustForFaceThumbnails => internalUser; - bool get petEnabled => internalUser; + bool get petEnabled => + internalUser || (_prefs.getBool("ls.pet_recognition_enabled") ?? 
false); bool get qrFeatureEnabled => true; - bool get ocrOverlayEnabled => true; - bool get enableBgLocalUploadPriority => internalUser; + bool get ocrOverlayEnabled => true; + bool get syncRecoveryDiagnostics => internalUser; Future tryRefreshFlags() async { diff --git a/mobile/apps/photos/rust/Cargo.lock b/mobile/apps/photos/rust/Cargo.lock index e514df2105d..fc8d7b5a43b 100644 --- a/mobile/apps/photos/rust/Cargo.lock +++ b/mobile/apps/photos/rust/Cargo.lock @@ -2873,9 +2873,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" dependencies = [ "tinyvec_macros", ] diff --git a/mobile/apps/photos/rust/src/api/ml_indexing_api.rs b/mobile/apps/photos/rust/src/api/ml_indexing_api.rs index 07626371905..709c1dda80a 100644 --- a/mobile/apps/photos/rust/src/api/ml_indexing_api.rs +++ b/mobile/apps/photos/rust/src/api/ml_indexing_api.rs @@ -1,8 +1,12 @@ +use std::collections::HashMap; + use ente_media_inspector::ml::{ indexing as shared_indexing, + pet::cluster::{self, ClusterConfig, PetClusterInput, Species}, runtime::{ExecutionProviderPolicy, MlRuntimeConfig, ModelPaths}, types as shared_types, }; +use ente_media_inspector::vector_db::VectorDB; #[derive(Clone, Debug)] pub struct RustExecutionProviderPolicy { @@ -333,3 +337,438 @@ fn to_api_pet_body_result(result: shared_types::PetBodyResult) -> RustPetBodyRes .collect(), } } + +// -- Pet Clustering API -- + +/// A single pet face/body entry for clustering, passed from Dart. +#[derive(Clone, Debug)] +pub struct RustPetClusterInput { + pub pet_face_id: String, + /// L2-normalized face embedding. Empty if no face detected. + pub face_embedding: Vec, + /// L2-normalized body embedding. Empty if no body detected. + pub body_embedding: Vec, + /// 0 = dog, 1 = cat. 
+ pub species: u8, + pub file_id: i64, +} + +/// Result entry: one pet_face_id mapped to a cluster. +#[derive(Clone, Debug)] +pub struct RustPetClusterEntry { + pub pet_face_id: String, + pub cluster_id: String, +} + +/// Cluster summary: centroid + count. +#[derive(Clone, Debug)] +pub struct RustPetClusterSummary { + pub cluster_id: String, + pub centroid: Vec, + pub count: i32, +} + +/// Cluster exemplars: multiple diverse real embeddings per cluster. +#[derive(Clone, Debug)] +pub struct RustPetClusterExemplarSummary { + pub cluster_id: String, + /// Multiple L2-normalized exemplar embeddings (real faces, not averaged). + pub exemplars: Vec>, + pub count: i32, +} + +/// Existing cluster exemplars passed from Dart for incremental matching. +#[derive(Clone, Debug)] +pub struct RustPetClusterExemplarInput { + pub cluster_id: String, + pub exemplars: Vec>, +} + +/// Full clustering result returned to Dart. +#[derive(Clone, Debug)] +pub struct RustPetClusterResult { + pub assignments: Vec, + pub summaries: Vec, + /// Exemplar summaries for multi-exemplar incremental matching. + pub exemplar_summaries: Vec, + pub n_unclustered: i32, +} + +/// Run batch pet clustering on all provided inputs. +pub fn run_pet_clustering_rust( + inputs: Vec, + species: u8, +) -> Result { + let config = ClusterConfig::for_species(Species::from_u8(species)); + + let cluster_inputs: Vec = inputs + .into_iter() + .map(|i| PetClusterInput { + pet_face_id: i.pet_face_id, + face_embedding: i.face_embedding.into_iter().map(|v| v as f32).collect(), + species: i.species, + file_id: i.file_id, + }) + .collect(); + + let result = cluster::run_pet_clustering(&cluster_inputs, &config); + + Ok(to_api_cluster_result(result)) +} + +/// Run incremental pet clustering: assign new inputs to existing clusters, +/// then cluster remainder among themselves. 
+///
+/// `new_inputs` embeddings arrive as `f64` and are narrowed to `f32` before
+/// clustering. `existing_face_centroids` is turned into a map keyed by each
+/// summary's `cluster_id`. `existing_body_centroids` is accepted only so the
+/// signature stays stable for callers; it is not handed to the clustering
+/// routine. `species` selects the [`ClusterConfig`] via [`Species::from_u8`].
+///
+/// This function itself never produces an `Err`.
+pub fn run_pet_clustering_incremental_rust(
+    new_inputs: Vec<RustPetClusterInput>,
+    existing_face_centroids: Vec<RustPetClusterSummary>,
+    existing_body_centroids: Vec<RustPetClusterSummary>,
+    species: u8,
+) -> Result<RustPetClusterResult, String> {
+    // The clustering call below only takes face centroids. Release the body
+    // centroids right away instead of converting every value to f32 for a map
+    // that is never read.
+    drop(existing_body_centroids);
+
+    let config = ClusterConfig::for_species(Species::from_u8(species));
+
+    // Narrow the f64 embeddings received over FFI to the f32 the clusterer uses.
+    let cluster_inputs: Vec<PetClusterInput> = new_inputs
+        .into_iter()
+        .map(|i| PetClusterInput {
+            pet_face_id: i.pet_face_id,
+            face_embedding: i.face_embedding.into_iter().map(|v| v as f32).collect(),
+            species: i.species,
+            file_id: i.file_id,
+        })
+        .collect();
+
+    // cluster_id -> centroid. A repeated cluster_id keeps the last entry.
+    let face_centroids: HashMap<String, Vec<f32>> = existing_face_centroids
+        .into_iter()
+        .map(|s| {
+            (
+                s.cluster_id,
+                s.centroid.into_iter().map(|v| v as f32).collect(),
+            )
+        })
+        .collect();
+
+    let result = cluster::run_pet_clustering_incremental(
+        &cluster_inputs,
+        &face_centroids,
+        &config,
+    );
+
+    Ok(to_api_cluster_result(result))
+}
+
+/// Run incremental pet clustering using multi-exemplar matching.
+///
+/// Instead of comparing against a single centroid, compares new faces against
+/// multiple real exemplar embeddings per cluster for better accuracy.
+pub fn run_pet_clustering_incremental_exemplars_rust( + new_inputs: Vec, + existing_exemplars: Vec, + species: u8, +) -> Result { + let config = ClusterConfig::for_species(Species::from_u8(species)); + + let cluster_inputs: Vec = new_inputs + .into_iter() + .map(|i| PetClusterInput { + pet_face_id: i.pet_face_id, + face_embedding: i.face_embedding.into_iter().map(|v| v as f32).collect(), + species: i.species, + file_id: i.file_id, + }) + .collect(); + + let exemplars: HashMap>> = existing_exemplars + .into_iter() + .map(|e| { + ( + e.cluster_id, + e.exemplars + .into_iter() + .map(|ex| ex.into_iter().map(|v| v as f32).collect()) + .collect(), + ) + }) + .collect(); + + let result = cluster::run_pet_clustering_incremental_with_exemplars( + &cluster_inputs, + &exemplars, + &config, + ); + + Ok(to_api_cluster_result(result)) +} + +fn to_api_cluster_result(result: cluster::PetClusterResult) -> RustPetClusterResult { + let assignments: Vec = result + .face_to_cluster + .into_iter() + .map(|(face_id, cluster_id)| RustPetClusterEntry { + pet_face_id: face_id, + cluster_id, + }) + .collect(); + + let summaries: Vec = result + .cluster_centroids + .into_iter() + .map(|(cluster_id, centroid)| { + let count = result + .cluster_counts + .get(&cluster_id) + .copied() + .unwrap_or(0) as i32; + RustPetClusterSummary { + cluster_id, + centroid: centroid.into_iter().map(|v| v as f64).collect(), + count, + } + }) + .collect(); + + let exemplar_summaries: Vec = result + .cluster_exemplars + .into_iter() + .map(|(cluster_id, exemplars)| { + let count = result + .cluster_counts + .get(&cluster_id) + .copied() + .unwrap_or(0) as i32; + RustPetClusterExemplarSummary { + cluster_id, + exemplars: exemplars + .into_iter() + .map(|ex| ex.into_iter().map(|v| v as f64).collect()) + .collect(), + count, + } + }) + .collect(); + + RustPetClusterResult { + assignments, + summaries, + exemplar_summaries, + n_unclustered: result.n_unclustered as i32, + } +} + +// -- Pet Clustering with direct 
usearch access -- + +/// Lightweight face metadata passed from Dart (no embeddings). +#[derive(Clone, Debug)] +pub struct RustPetFaceMeta { + pub pet_face_id: String, + /// Integer key in the usearch index. + pub vector_id: i64, + pub species: u8, + pub file_id: i64, + /// Existing cluster ID, or empty string if unclustered. + pub cluster_id: String, +} + +/// Run batch pet clustering by reading embeddings directly from usearch. +/// +/// Dart passes only lightweight metadata + the path to the usearch index file. +/// Rust opens the index, bulk-reads embeddings, clusters, and returns +/// assignments — no embedding round-trip through FFI. +pub fn run_pet_clustering_from_index( + faces: Vec, + face_index_path: String, + species: u8, +) -> Result { + let config = ClusterConfig::for_species(Species::from_u8(species)); + let dim = 128; // pet face embedding dimension + + let vdb = VectorDB::new(&face_index_path, dim) + .map_err(|e| format!("Failed to open face index: {e}"))?; + + let mut inputs = Vec::with_capacity(faces.len()); + for face in &faces { + let emb = match vdb.get_vector(face.vector_id as u64) { + Ok(v) => v, + Err(_) => continue, + }; + inputs.push(PetClusterInput { + pet_face_id: face.pet_face_id.clone(), + face_embedding: emb, + species: face.species, + file_id: face.file_id, + }); + } + + if inputs.len() < 2 { + return Ok(RustPetClusterResult { + assignments: Vec::new(), + summaries: Vec::new(), + exemplar_summaries: Vec::new(), + n_unclustered: inputs.len() as i32, + }); + } + + let result = cluster::run_pet_clustering(&inputs, &config); + Ok(to_api_cluster_result(result)) +} + +/// Run incremental pet clustering by reading embeddings directly from usearch. +/// +/// Only unclustered faces are clustered against existing centroids. +/// Centroids are read from a separate usearch index. 
+pub fn run_pet_clustering_incremental_from_index( + new_faces: Vec, + face_index_path: String, + centroid_index_path: String, + // cluster_id -> vector_id in the centroid index + centroid_mappings: Vec, + // cluster_id -> face count (unused for now, reserved for weighted merge) + _centroid_counts: Vec, + species: u8, +) -> Result { + let config = ClusterConfig::for_species(Species::from_u8(species)); + let dim = 128; + + let vdb = VectorDB::new(&face_index_path, dim) + .map_err(|e| format!("Failed to open face index: {e}"))?; + + let mut inputs = Vec::with_capacity(new_faces.len()); + for face in &new_faces { + let emb = match vdb.get_vector(face.vector_id as u64) { + Ok(v) => v, + Err(_) => continue, + }; + inputs.push(PetClusterInput { + pet_face_id: face.pet_face_id.clone(), + face_embedding: emb, + species: face.species, + file_id: face.file_id, + }); + } + + if inputs.is_empty() { + return Ok(RustPetClusterResult { + assignments: Vec::new(), + summaries: Vec::new(), + exemplar_summaries: Vec::new(), + n_unclustered: 0, + }); + } + + // Load existing centroids from centroid index + let face_centroids: HashMap> = if !centroid_mappings.is_empty() { + let centroid_vdb = VectorDB::new(¢roid_index_path, dim) + .map_err(|e| format!("Failed to open centroid index: {e}"))?; + + let mut centroids = HashMap::new(); + for mapping in ¢roid_mappings { + if let Ok(emb) = centroid_vdb.get_vector(mapping.vector_id as u64) { + centroids.insert(mapping.cluster_id.clone(), emb); + } + } + centroids + } else { + HashMap::new() + }; + + let result = cluster::run_pet_clustering_incremental( + &inputs, + &face_centroids, + &config, + ); + + Ok(to_api_cluster_result(result)) +} + +/// Mapping from cluster ID to its vector ID in the centroid usearch index. +#[derive(Clone, Debug)] +pub struct RustCentroidMapping { + pub cluster_id: String, + pub vector_id: i64, +} + +/// Cluster ID with its face count (for incremental clustering). 
+#[derive(Clone, Debug)] +pub struct RustCentroidCount { + pub cluster_id: String, + pub count: i32, +} + +/// Exemplar embeddings for a cluster, used for incremental matching. +#[derive(Clone, Debug)] +pub struct RustClusterExemplars { + pub cluster_id: String, + /// Multiple real face embeddings (not averaged), f64 for Dart compatibility. + pub exemplars: Vec>, +} + +/// Run incremental pet clustering using multi-exemplar matching. +/// +/// Instead of comparing new faces against a single centroid per cluster, +/// compares against multiple diverse real face embeddings (exemplars). +/// Gives F1=0.96 vs centroid's F1=0.86. +pub fn run_pet_clustering_incremental_exemplars_from_index( + new_faces: Vec, + face_index_path: String, + cluster_exemplars: Vec, + species: u8, +) -> Result { + let config = ClusterConfig::for_species(Species::from_u8(species)); + let dim = 128; + + let vdb = VectorDB::new(&face_index_path, dim) + .map_err(|e| format!("Failed to open face index: {e}"))?; + + let mut inputs = Vec::with_capacity(new_faces.len()); + for face in &new_faces { + let emb = match vdb.get_vector(face.vector_id as u64) { + Ok(v) => v, + Err(_) => continue, + }; + inputs.push(PetClusterInput { + pet_face_id: face.pet_face_id.clone(), + face_embedding: emb, + species: face.species, + file_id: face.file_id, + }); + } + + if inputs.is_empty() { + return Ok(RustPetClusterResult { + assignments: Vec::new(), + summaries: Vec::new(), + exemplar_summaries: Vec::new(), + n_unclustered: 0, + }); + } + + // Convert exemplars from f64 to f32 + let existing_exemplars: HashMap>> = cluster_exemplars + .into_iter() + .map(|ce| { + let exs: Vec> = ce + .exemplars + .into_iter() + .map(|e| e.into_iter().map(|v| v as f32).collect()) + .collect(); + (ce.cluster_id, exs) + }) + .collect(); + + let result = cluster::run_pet_clustering_incremental_with_exemplars( + &inputs, + &existing_exemplars, + &config, + ); + + Ok(to_api_cluster_result(result)) +} + diff --git 
a/mobile/apps/photos/test/db/ml/pet_entity_db_test.dart b/mobile/apps/photos/test/db/ml/pet_entity_db_test.dart new file mode 100644 index 00000000000..64fcfd434b0 --- /dev/null +++ b/mobile/apps/photos/test/db/ml/pet_entity_db_test.dart @@ -0,0 +1,199 @@ +import 'dart:io'; + +import 'package:photos/db/ml/schema.dart'; +import 'package:photos/models/ml/pet/pet_entity.dart'; +import 'package:sqlite_async/sqlite_async.dart'; +import 'package:test/test.dart'; + +void main() { + // ── PetData model tests ── + + group('PetData', () { + test('toJson/fromJson roundtrip', () { + final data = PetData(name: 'Buddy', species: 0); + final restored = PetData.fromJson(data.toJson()); + + expect(restored.name, 'Buddy'); + expect(restored.species, 0); + }); + + test('copyWith updates name only', () { + final data = PetData(name: 'Buddy', species: 0); + final renamed = data.copyWith(name: 'Max'); + + expect(renamed.name, 'Max'); + expect(renamed.species, 0); + }); + + test('copyWith updates species only', () { + final data = PetData(name: 'Buddy', species: 0); + final changed = data.copyWith(species: 1); + + expect(changed.name, 'Buddy'); + expect(changed.species, 1); + }); + + test('toJson contains expected keys', () { + final data = PetData(name: 'Luna', species: 1); + final map = data.toJson(); + + expect(map.keys, containsAll(['name', 'species'])); + expect(map.length, 10); + }); + + test('fromJson handles missing fields with defaults', () { + final data = PetData.fromJson({}); + + expect(data.name, ''); + expect(data.species, -1); + }); + }); + + // ── PetEntity tests ── + + group('PetEntity', () { + test('copyWith replaces data', () { + final pet = PetEntity('pet-1', PetData(name: 'Buddy', species: 0)); + final updated = pet.copyWith( + data: PetData(name: 'Max', species: 0), + ); + + expect(updated.remoteID, 'pet-1'); + expect(updated.data.name, 'Max'); + }); + + test('remoteID is preserved on copyWith', () { + final pet = PetEntity('pet-1', PetData(name: 'Buddy', species: 
0)); + final copy = pet.copyWith(); + + expect(copy.remoteID, 'pet-1'); + expect(copy.data.name, 'Buddy'); + }); + }); + + // ── pet_cluster_pet mapping table tests ── + + group('pet_cluster_pet mapping', () { + late SqliteDatabase db; + late Directory tempDir; + + setUp(() async { + tempDir = Directory.systemTemp.createTempSync('pet_mapping_test_'); + final dbPath = '${tempDir.path}/test_mapping.db'; + db = SqliteDatabase(path: dbPath); + await db.writeTransaction((tx) async { + await tx.execute(createPetClusterPetTable); + }); + }); + + tearDown(() async { + await db.close(); + }); + + Future setClusterPetId(String clusterId, String petId) async { + await db.execute( + '''INSERT INTO $petClusterPetTable ($clusterIDColumn, $petIdColumn) + VALUES (?, ?) + ON CONFLICT($clusterIDColumn) DO UPDATE SET + $petIdColumn = excluded.$petIdColumn''', + [clusterId, petId], + ); + } + + Future> getClusterToPetId() async { + final rows = await db.getAll( + 'SELECT $clusterIDColumn, $petIdColumn FROM $petClusterPetTable', + ); + return { + for (final r in rows) + r[clusterIDColumn] as String: r[petIdColumn] as String, + }; + } + + test('map a cluster to a pet', () async { + await setClusterPetId('cluster-1', 'pet-1'); + + final mappings = await getClusterToPetId(); + expect(mappings['cluster-1'], 'pet-1'); + }); + + test('update mapping for existing cluster', () async { + await setClusterPetId('cluster-1', 'pet-1'); + await setClusterPetId('cluster-1', 'pet-2'); + + final mappings = await getClusterToPetId(); + expect(mappings['cluster-1'], 'pet-2'); + expect(mappings.length, 1); + }); + + test('multiple clusters map to same pet (merge)', () async { + await setClusterPetId('cluster-1', 'pet-1'); + await setClusterPetId('cluster-2', 'pet-1'); + await setClusterPetId('cluster-3', 'pet-1'); + + final mappings = await getClusterToPetId(); + expect(mappings.length, 3); + expect(mappings.values.toSet(), {'pet-1'}); + }); + + test('different clusters map to different pets', () async { 
+ await setClusterPetId('cluster-1', 'pet-1'); + await setClusterPetId('cluster-2', 'pet-2'); + + final mappings = await getClusterToPetId(); + expect(mappings['cluster-1'], 'pet-1'); + expect(mappings['cluster-2'], 'pet-2'); + }); + + test('unmerge by deleting a mapping', () async { + await setClusterPetId('cluster-1', 'pet-1'); + await setClusterPetId('cluster-2', 'pet-1'); + + await db.execute( + 'DELETE FROM $petClusterPetTable WHERE $clusterIDColumn = ?', + ['cluster-2'], + ); + + final mappings = await getClusterToPetId(); + expect(mappings.length, 1); + expect(mappings['cluster-1'], 'pet-1'); + expect(mappings.containsKey('cluster-2'), isFalse); + }); + + test('empty table returns empty map', () async { + final mappings = await getClusterToPetId(); + expect(mappings, isEmpty); + }); + + test('stale mapping removed when remote unassigns cluster', () async { + // Simulate initial state: pet-1 has cluster-1, cluster-2, cluster-3 + await setClusterPetId('cluster-1', 'pet-1'); + await setClusterPetId('cluster-2', 'pet-1'); + await setClusterPetId('cluster-3', 'pet-1'); + + // Simulate remote unassign of cluster-2: remote now has only + // cluster-1 and cluster-3 assigned to pet-1 + final remoteClusterIds = {'cluster-1', 'cluster-3'}; + final remotePetIDs = {'pet-1'}; + final localMappings = await getClusterToPetId(); + + // Remove stale: local mapping where pet exists remotely but + // cluster is no longer in remote assignments + for (final entry in localMappings.entries) { + if (remotePetIDs.contains(entry.value) && + !remoteClusterIds.contains(entry.key)) { + await db.execute( + 'DELETE FROM $petClusterPetTable WHERE $clusterIDColumn = ?', + [entry.key], + ); + } + } + + final mappings = await getClusterToPetId(); + expect(mappings.length, 2); + expect(mappings['cluster-1'], 'pet-1'); + expect(mappings['cluster-3'], 'pet-1'); + expect(mappings.containsKey('cluster-2'), isFalse); + }); + }); +} diff --git 
a/mobile/apps/photos/test/db/ml/pet_indexing_pipeline_test.dart b/mobile/apps/photos/test/db/ml/pet_indexing_pipeline_test.dart new file mode 100644 index 00000000000..e2fc567ebc4 --- /dev/null +++ b/mobile/apps/photos/test/db/ml/pet_indexing_pipeline_test.dart @@ -0,0 +1,425 @@ +import 'dart:io'; + +import 'package:photos/db/ml/db_pet_model_mappers.dart'; +import 'package:photos/db/ml/schema.dart'; +import 'package:photos/models/ml/ml_versions.dart'; +import 'package:sqlite_async/sqlite_async.dart'; +import 'package:test/test.dart'; + +/// Test the pet indexing pipeline after migrating from a separate +/// `pet_indexed_files` table to dummy [DBPetFace.empty] rows in +/// `pet_faces` (matching the human face detection pattern). +void main() { + // ── DBPetFace.empty() factory tests ── + + group('DBPetFace.empty', () { + test('creates a dummy entry with correct default values', () { + final dummy = DBPetFace.empty(42); + + expect(dummy.fileId, 42); + expect(dummy.petFaceId, '42_pet_0_0_0_0'); + expect(dummy.detection, '{}'); + expect(dummy.faceVectorId, isNull); + expect(dummy.species, -1); + expect(dummy.faceScore, 0.0); + expect(dummy.imageHeight, 0); + expect(dummy.imageWidth, 0); + expect(dummy.mlVersion, petMlVersion); + }); + + test('creates an error dummy with score -1.0', () { + final dummy = DBPetFace.empty(7, error: true); + + expect(dummy.fileId, 7); + expect(dummy.petFaceId, '7_pet_0_0_0_0'); + expect(dummy.faceScore, -1.0); + expect(dummy.species, -1); + expect(dummy.mlVersion, petMlVersion); + }); + + test('dummy petFaceId does not collide with real detection IDs', () { + // Real IDs use detection coordinates like "42_pet_0.12_0.34_0.56_0.78" + final dummy = DBPetFace.empty(42); + const realId = '42_pet_0.12_0.34_0.56_0.78'; + + expect(dummy.petFaceId, isNot(equals(realId))); + }); + }); + + // ── toMap / fromMap roundtrip tests ── + + group('DBPetFace serialisation roundtrip', () { + test('empty entry survives toMap/fromMap roundtrip', () { + final 
original = DBPetFace.empty(99); + final map = original.toMap(); + final restored = DBPetFace.fromMap(map); + + expect(restored.fileId, original.fileId); + expect(restored.petFaceId, original.petFaceId); + expect(restored.detection, original.detection); + expect(restored.faceVectorId, original.faceVectorId); + expect(restored.species, original.species); + expect(restored.faceScore, original.faceScore); + expect(restored.imageHeight, original.imageHeight); + expect(restored.imageWidth, original.imageWidth); + expect(restored.mlVersion, original.mlVersion); + }); + + test('real entry survives toMap/fromMap roundtrip', () { + final original = DBPetFace( + fileId: 10, + petFaceId: '10_pet_0.1_0.2_0.3_0.4', + detection: '{"box":[0.1,0.2,0.3,0.4]}', + faceVectorId: 5, + species: 0, + faceScore: 0.95, + imageHeight: 1080, + imageWidth: 1920, + mlVersion: petMlVersion, + ); + final restored = DBPetFace.fromMap(original.toMap()); + + expect(restored.fileId, 10); + expect(restored.species, 0); + expect(restored.faceScore, closeTo(0.95, 0.001)); + expect(restored.faceVectorId, 5); + }); + + test('toMap does not contain embedding column', () { + final face = DBPetFace.empty(1); + final map = face.toMap(); + + expect(map.containsKey('pet_face_embedding'), isFalse); + }); + }); + + // ── SQLite query tests (in-memory database) ── + + group('Pet faces DB queries', () { + late SqliteDatabase db; + + setUp(() async { + // Create a temp file for the database since sqlite_async needs a path + final tempDir = Directory.systemTemp.createTempSync('pet_test_'); + final dbPath = '${tempDir.path}/test_pet.db'; + db = SqliteDatabase(path: dbPath); + + // Run pet-related schema migrations + await db.writeTransaction((tx) async { + await tx.execute(createPetFacesTable); + }); + }); + + tearDown(() async { + await db.close(); + }); + + Future insertPetFace(DBPetFace face) async { + final map = face.toMap(); + await db.execute( + '''INSERT INTO $petFacesTable ( + $fileIDColumn, $petFaceIDColumn, 
$faceDetectionColumn, + $faceVectorIdColumn, $speciesColumn, $faceScore, + $imageHeight, $imageWidth, $mlVersionColumn + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT($fileIDColumn, $petFaceIDColumn) DO UPDATE SET + $faceDetectionColumn = excluded.$faceDetectionColumn, + $faceVectorIdColumn = excluded.$faceVectorIdColumn, + $speciesColumn = excluded.$speciesColumn, + $faceScore = excluded.$faceScore, + $imageHeight = excluded.$imageHeight, + $imageWidth = excluded.$imageWidth, + $mlVersionColumn = excluded.$mlVersionColumn + ''', + [ + map[fileIDColumn], + map[petFaceIDColumn], + map[faceDetectionColumn], + map[faceVectorIdColumn], + map[speciesColumn], + map['score'], + map['height'], + map['width'], + map[mlVersionColumn], + ], + ); + } + + // ── petIndexedFileIds (should include dummies) ── + + test('petIndexedFileIds includes dummy entries', () async { + await insertPetFace(DBPetFace.empty(1)); + await insertPetFace(DBPetFace.empty(2, error: true)); + + final rows = await db.getAll( + 'SELECT $fileIDColumn, $mlVersionColumn FROM $petFacesTable ' + 'WHERE $mlVersionColumn >= $petMlVersion', + ); + + final indexed = {for (final r in rows) r[fileIDColumn] as int}; + expect(indexed, contains(1)); + expect(indexed, contains(2)); + }); + + // ── getPetFacesForFileID (should exclude dummies) ── + + test('getPetFacesForFileID excludes dummy entries', () async { + // Insert a dummy and a real face for the same file + await insertPetFace(DBPetFace.empty(10)); + await insertPetFace( + DBPetFace( + fileId: 10, + petFaceId: '10_pet_0.1_0.2_0.3_0.4', + detection: '{"box":[0.1,0.2,0.3,0.4]}', + faceVectorId: 5, + species: 0, + faceScore: 0.92, + imageHeight: 1080, + imageWidth: 1920, + mlVersion: petMlVersion, + ), + ); + + // Query with species filter (matches production query) + final rows = await db.getAll( + 'SELECT * FROM $petFacesTable ' + 'WHERE $fileIDColumn = ? 
AND $speciesColumn != -1', + [10], + ); + + expect(rows.length, 1); + final face = DBPetFace.fromMap(rows.first); + expect(face.petFaceId, '10_pet_0.1_0.2_0.3_0.4'); + expect(face.species, 0); + }); + + test('getPetFacesForFileID returns null-equivalent for dummy-only files', + () async { + await insertPetFace(DBPetFace.empty(20)); + + final rows = await db.getAll( + 'SELECT * FROM $petFacesTable ' + 'WHERE $fileIDColumn = ? AND $speciesColumn != -1', + [20], + ); + + expect(rows, isEmpty); + }); + + // ── getPetIndexedFileCount (DISTINCT, includes dummies) ── + + test('getPetIndexedFileCount counts distinct files including dummies', + () async { + // File 1: dummy only + await insertPetFace(DBPetFace.empty(1)); + // File 2: two real faces + await insertPetFace( + DBPetFace( + fileId: 2, + petFaceId: '2_pet_a', + detection: '{"box":[0.1,0.2,0.3,0.4]}', + faceVectorId: 1, + species: 0, + faceScore: 0.9, + imageHeight: 100, + imageWidth: 100, + mlVersion: petMlVersion, + ), + ); + await insertPetFace( + DBPetFace( + fileId: 2, + petFaceId: '2_pet_b', + detection: '{"box":[0.5,0.6,0.7,0.8]}', + faceVectorId: 2, + species: 1, + faceScore: 0.85, + imageHeight: 100, + imageWidth: 100, + mlVersion: petMlVersion, + ), + ); + + final countRows = await db.getAll( + 'SELECT COUNT(DISTINCT $fileIDColumn) as count FROM $petFacesTable ' + 'WHERE $mlVersionColumn >= $petMlVersion', + ); + + expect(countRows.first['count'], 2); + }); + + // ── ON CONFLICT: dummy replaced when same ID re-inserted ── + + test('re-indexing same file replaces dummy via ON CONFLICT', () async { + // Initial: no pets found + await insertPetFace(DBPetFace.empty(30)); + + var rows = await db.getAll( + 'SELECT * FROM $petFacesTable WHERE $fileIDColumn = ?', + [30], + ); + expect(rows.length, 1); + expect(DBPetFace.fromMap(rows.first).species, -1); + + // Re-index with same dummy ID (e.g. 
mlVersion bump, still no pets) + await insertPetFace(DBPetFace.empty(30)); + + rows = await db.getAll( + 'SELECT * FROM $petFacesTable WHERE $fileIDColumn = ?', + [30], + ); + // ON CONFLICT should update, not duplicate + expect(rows.length, 1); + }); + + // ── Coexistence: dummy + real faces for same file ── + + test('dummy and real faces coexist for the same file', () async { + await insertPetFace(DBPetFace.empty(40)); + await insertPetFace( + DBPetFace( + fileId: 40, + petFaceId: '40_pet_0.2_0.3_0.6_0.7', + detection: '{"box":[0.2,0.3,0.6,0.7]}', + faceVectorId: 3, + species: 1, + faceScore: 0.88, + imageHeight: 720, + imageWidth: 1280, + mlVersion: petMlVersion, + ), + ); + + // All rows (for indexing tracking) + final allRows = await db.getAll( + 'SELECT * FROM $petFacesTable WHERE $fileIDColumn = ?', + [40], + ); + expect(allRows.length, 2); + + // Filtered rows (for UI display) + final realRows = await db.getAll( + 'SELECT * FROM $petFacesTable ' + 'WHERE $fileIDColumn = ? AND $speciesColumn != -1', + [40], + ); + expect(realRows.length, 1); + expect(DBPetFace.fromMap(realRows.first).species, 1); + }); + + // ── Delete cleans up both dummy and real entries ── + + test('deletePetDataForFiles removes dummy and real entries', () async { + await insertPetFace(DBPetFace.empty(50)); + await insertPetFace( + DBPetFace( + fileId: 50, + petFaceId: '50_pet_real', + detection: '{"box":[0.1,0.2,0.3,0.4]}', + faceVectorId: 1, + species: 0, + faceScore: 0.9, + imageHeight: 100, + imageWidth: 100, + mlVersion: petMlVersion, + ), + ); + + await db.execute( + 'DELETE FROM $petFacesTable WHERE $fileIDColumn IN (50)', + ); + + final rows = await db.getAll( + 'SELECT * FROM $petFacesTable WHERE $fileIDColumn = ?', + [50], + ); + expect(rows, isEmpty); + }); + + // ── Error dummy (-1 score) is still tracked as indexed ── + + test('error dummy entries are tracked as indexed', () async { + await insertPetFace(DBPetFace.empty(60, error: true)); + + final rows = await db.getAll( + 
'SELECT $fileIDColumn FROM $petFacesTable ' + 'WHERE $mlVersionColumn >= $petMlVersion', + ); + + expect(rows.length, 1); + expect(rows.first[fileIDColumn], 60); + }); + + // ── Error dummies are excluded from UI queries ── + + test('error dummy entries are excluded from UI queries', () async { + await insertPetFace(DBPetFace.empty(60, error: true)); + + final rows = await db.getAll( + 'SELECT * FROM $petFacesTable ' + 'WHERE $fileIDColumn = ? AND $speciesColumn != -1', + [60], + ); + + expect(rows, isEmpty); + }); + + // ── Vector ID mapping: faceVectorId tracks vector DB entry ── + + test('faceVectorId is stored and retrievable', () async { + await insertPetFace( + DBPetFace( + fileId: 70, + petFaceId: '70_pet_real', + detection: '{"box":[0.1,0.2,0.3,0.4]}', + faceVectorId: 42, + species: 0, + faceScore: 0.9, + imageHeight: 100, + imageWidth: 100, + mlVersion: petMlVersion, + ), + ); + + final rows = await db.getAll( + 'SELECT $faceVectorIdColumn FROM $petFacesTable ' + 'WHERE $petFaceIDColumn = ?', + ['70_pet_real'], + ); + + expect(rows.first[faceVectorIdColumn], 42); + }); + + test('faceVectorId can be updated after initial insert', () async { + // Insert with placeholder vectorId = -1 + await insertPetFace( + DBPetFace( + fileId: 80, + petFaceId: '80_pet_real', + detection: '{"box":[0.1,0.2,0.3,0.4]}', + faceVectorId: null, + species: 1, + faceScore: 0.85, + imageHeight: 100, + imageWidth: 100, + mlVersion: petMlVersion, + ), + ); + + // Simulate updatePetFaceVectorIds + await db.execute( + 'UPDATE $petFacesTable SET $faceVectorIdColumn = ? 
' + 'WHERE $petFaceIDColumn = ?', + [99, '80_pet_real'], + ); + + final rows = await db.getAll( + 'SELECT $faceVectorIdColumn FROM $petFacesTable ' + 'WHERE $petFaceIDColumn = ?', + ['80_pet_real'], + ); + + expect(rows.first[faceVectorIdColumn], 99); + }); + }); +} diff --git a/rust/photos/src/ml/cluster.rs b/rust/photos/src/ml/cluster.rs new file mode 100644 index 00000000000..515ecb233d8 --- /dev/null +++ b/rust/photos/src/ml/cluster.rs @@ -0,0 +1,489 @@ +//! Shared clustering algorithms used by the pet pipeline. +//! +//! Contains agglomerative clustering (average linkage) and helper functions. + +use std::cmp::Ordering; +use std::collections::BinaryHeap; +use std::collections::HashMap; + +// ── Agglomerative clustering (average linkage, precomputed) ───────────── + +/// Hierarchical agglomerative clustering with average linkage on a +/// precomputed distance matrix. Cuts the dendrogram at `threshold`. +/// +/// Mirrors Python's `sklearn.cluster.AgglomerativeClustering( +/// metric="precomputed", linkage="average", distance_threshold=threshold)`. +pub fn agglomerative_precomputed(dist: &[f32], n: usize, threshold: f32) -> Vec { + agglomerative_precomputed_min_size(dist, n, threshold, 1) +} + +/// Same as [agglomerative_precomputed] but clusters smaller than +/// `min_cluster_size` are marked as noise (-1). +pub fn agglomerative_precomputed_min_size( + dist: &[f32], + n: usize, + threshold: f32, + min_cluster_size: usize, +) -> Vec { + agglomerative_precomputed_min_size_heap(dist, n, threshold, min_cluster_size) +} + +/// Naive exact average-linkage agglomerative clustering. +/// +/// This retains the original O(n^3)-like merge selection logic for +/// verification and benchmarking. Production code should use +/// [agglomerative_precomputed_min_size], which now routes to the optimized +/// heap-based exact implementation. 
///
/// `dist` is read as a row-major `n x n` matrix of pairwise distances and is
/// expected to be symmetric. The result holds one label per input point;
/// points whose final cluster has fewer than `min_cluster_size` members are
/// labelled `-1`. Labels are assigned in ascending order of the surviving
/// cluster's lowest original slot.
///
/// # Panics
///
/// May panic with an out-of-bounds index if `dist` holds fewer than `n * n`
/// entries.
pub fn agglomerative_precomputed_min_size_naive(
    dist: &[f32],
    n: usize,
    threshold: f32,
    min_cluster_size: usize,
) -> Vec<i32> {
    if n == 0 {
        return vec![];
    }
    if n == 1 {
        return if min_cluster_size <= 1 {
            vec![0]
        } else {
            vec![-1]
        };
    }

    let mut clusters: Vec<Option<Vec<usize>>> = (0..n).map(|i| Some(vec![i])).collect();
    // Slots of clusters that are still alive, kept in ascending order.
    let mut active: Vec<usize> = (0..n).collect();

    let mut cdist: Vec<f32> = dist.to_vec();
    let cap = n;

    while active.len() >= 2 {
        // Full scan for the closest pair of live clusters.
        let mut best_d = f32::INFINITY;
        let mut best_i = 0;
        let mut best_j = 0;
        for ai in 0..active.len() {
            for aj in (ai + 1)..active.len() {
                let d = cdist[active[ai] * cap + active[aj]];
                if d < best_d {
                    best_d = d;
                    best_i = ai;
                    best_j = aj;
                }
            }
        }

        if best_d > threshold {
            break;
        }
        // No pair was selected at all (every distance is infinite or NaN).
        // This is only reachable when `threshold` itself is infinite or NaN;
        // merging a cluster with itself would be meaningless, so stop.
        if best_i == best_j {
            break;
        }

        let ci = active[best_i];
        let cj = active[best_j];

        let size_i = clusters[ci].as_ref().map_or(0, Vec::len);
        let size_j = clusters[cj].as_ref().map_or(0, Vec::len);
        let merged_size = size_i + size_j;

        // Average-linkage update: d(i∪j, k) = (size_i*d(i,k) + size_j*d(j,k)) / merged
        for &ck in &active {
            if ck == ci || ck == cj {
                continue;
            }
            let d_ik = cdist[ci.min(ck) * cap + ci.max(ck)];
            let d_jk = cdist[cj.min(ck) * cap + cj.max(ck)];
            let new_d = (size_i as f32 * d_ik + size_j as f32 * d_jk) / merged_size as f32;
            let (lo, hi) = (ci.min(ck), ci.max(ck));
            cdist[lo * cap + hi] = new_d;
            cdist[hi * cap + lo] = new_d;
        }

        let cj_members = clusters[cj]
            .take()
            .expect("active slots always hold a cluster");
        clusters[ci]
            .as_mut()
            .expect("active slots always hold a cluster")
            .extend(cj_members);

        active.remove(best_j);
    }

    label_clusters(
        n,
        active.iter().filter_map(|&ci| clusters[ci].as_ref()),
        min_cluster_size,
    )
}

/// Optimized exact average-linkage agglomerative clustering using
/// lazy nearest-neighbor tracking plus a heap of best merge candidates.
///
/// This keeps the same distance-matrix representation as the naive variant
/// (row-major `n x n`, expected symmetric) but avoids a full O(n^2) scan
/// for the best merge on every iteration. Each live cluster remembers its
/// current nearest neighbor; a heap entry is only acted upon if it still
/// matches that record, otherwise it is discarded as stale.
///
/// Returns one label per input point, `-1` for members of clusters smaller
/// than `min_cluster_size`.
///
/// # Panics
///
/// Panics with an out-of-bounds index if `dist` holds fewer than `n * n`
/// entries (for `n >= 2`).
pub fn agglomerative_precomputed_min_size_heap(
    dist: &[f32],
    n: usize,
    threshold: f32,
    min_cluster_size: usize,
) -> Vec<i32> {
    if n == 0 {
        return vec![];
    }
    if n == 1 {
        return if min_cluster_size <= 1 {
            vec![0]
        } else {
            vec![-1]
        };
    }

    let mut clusters: Vec<Option<Vec<usize>>> = (0..n).map(|i| Some(vec![i])).collect();
    let mut active = vec![true; n];
    let mut sizes = vec![1usize; n];
    let mut cdist: Vec<f32> = dist.to_vec();
    let cap = n;

    let mut nearest: Vec<Option<(usize, f32)>> = vec![None; n];
    let mut heap = BinaryHeap::new();

    for (i, nearest_i) in nearest.iter_mut().enumerate() {
        if let Some((j, d)) = recompute_nearest(i, &active, &cdist, cap) {
            *nearest_i = Some((j, d));
            heap.push(Candidate {
                dist: d,
                from: i,
                to: j,
            });
        }
    }

    let mut active_count = n;
    while active_count >= 2 {
        let Some(Candidate {
            dist: best_d,
            from: ci,
            to: cj,
        }) = heap.pop()
        else {
            break;
        };

        // Discard stale entries: either endpoint was merged away, or the
        // entry no longer matches the recorded nearest neighbor of `ci`.
        if !active[ci] || !active[cj] {
            continue;
        }
        match nearest[ci] {
            Some((n_to, n_dist)) if n_to == cj && same_distance(n_dist, best_d) => {}
            _ => continue,
        }

        if best_d > threshold {
            break;
        }

        let size_i = sizes[ci];
        let size_j = sizes[cj];
        let merged_size = size_i + size_j;

        // Average-linkage update of every live cluster's distance to the
        // merged cluster, which keeps living in slot `ci`.
        for ck in 0..n {
            if !active[ck] || ck == ci || ck == cj {
                continue;
            }
            let d_ik = cdist[ci * cap + ck];
            let d_jk = cdist[cj * cap + ck];
            let new_d = (size_i as f32 * d_ik + size_j as f32 * d_jk) / merged_size as f32;
            cdist[ci * cap + ck] = new_d;
            cdist[ck * cap + ci] = new_d;
        }

        let cj_members = clusters[cj]
            .take()
            .expect("active slots always hold a cluster");
        clusters[ci]
            .as_mut()
            .expect("active slots always hold a cluster")
            .extend(cj_members);
        active[cj] = false;
        nearest[cj] = None;
        sizes[ci] = merged_size;
        active_count -= 1;

        if let Some((to, d)) = recompute_nearest(ci, &active, &cdist, cap) {
            nearest[ci] = Some((to, d));
            heap.push(Candidate {
                dist: d,
                from: ci,
                to,
            });
        } else {
            nearest[ci] = None;
        }

        // Only clusters whose recorded neighbor was touched by the merge, or
        // that are now closer to the merged cluster, need a fresh scan.
        for ck in 0..n {
            if !active[ck] || ck == ci {
                continue;
            }

            let should_recompute = match nearest[ck] {
                None => true,
                Some((to, d)) => to == ci || to == cj || cdist[ck * cap + ci] < d,
            };
            if !should_recompute {
                continue;
            }

            if let Some((to, d)) = recompute_nearest(ck, &active, &cdist, cap) {
                nearest[ck] = Some((to, d));
                heap.push(Candidate {
                    dist: d,
                    from: ck,
                    to,
                });
            } else {
                nearest[ck] = None;
            }
        }
    }

    label_clusters(
        n,
        (0..n)
            .filter(|&ci| active[ci])
            .filter_map(|ci| clusters[ci].as_ref()),
        min_cluster_size,
    )
}

/// Turns the surviving clusters into a per-point label vector.
///
/// Clusters are numbered consecutively from 0 in iteration order; clusters
/// with fewer than `min_cluster_size` members leave their points at `-1`.
fn label_clusters<'a>(
    n: usize,
    survivors: impl Iterator<Item = &'a Vec<usize>>,
    min_cluster_size: usize,
) -> Vec<i32> {
    let mut labels = vec![-1i32; n];
    let mut next_label = 0i32;
    for members in survivors.filter(|members| members.len() >= min_cluster_size) {
        for &m in members {
            labels[m] = next_label;
        }
        next_label += 1;
    }
    labels
}

/// A proposed merge of cluster `from` with its nearest neighbor `to`.
///
/// The ordering is reversed so that `BinaryHeap` (a max-heap) pops the
/// smallest distance first; ties go to the smaller `from`, then smaller `to`.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    from: usize,
    to: usize,
}

impl PartialEq for Candidate {
    // Defined through `cmp` so equality can never disagree with the ordering
    // (the `Eq`/`Ord` contract). Tolerance-based comparison of distances is
    // done explicitly with `same_distance` where it is wanted.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.from.cmp(&self.from))
            .then_with(|| other.to.cmp(&self.to))
    }
}

/// Scans row `i` of the distance matrix for the closest other live cluster.
///
/// Returns `None` when `i` is inactive or no other cluster is active.
/// Near-equal distances (see [`same_distance`]) resolve to the lower index.
fn recompute_nearest(i: usize, active: &[bool], cdist: &[f32], cap: usize) -> Option<(usize, f32)> {
    if !active[i] {
        return None;
    }

    let mut best: Option<(usize, f32)> = None;
    for j in 0..active.len() {
        if !active[j] || i == j {
            continue;
        }
        let d = cdist[i * cap + j];
        match best {
            None => best = Some((j, d)),
            Some((best_j, best_d)) => {
                if d < best_d || (same_distance(d, best_d) && j < best_j) {
                    best = Some((j, d));
                }
            }
        }
    }
    best
}

/// Whether two distances differ by at most `1e-7`.
///
/// Non-finite inputs whose difference is NaN (e.g. two infinities) compare
/// as *not* the same.
fn same_distance(a: f32, b: f32) -> bool {
    (a - b).abs() <= 1e-7
}

// ── Helper functions ────────────────────────────────────────────────────

/// Dot product over the common prefix of `a` and `b`.
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

/// Euclidean (L2) norm of `v`.
pub fn l2_norm(v: &[f32]) -> f32 {
    dot(v, v).sqrt()
}

/// Scales `v` to unit L2 norm in place; vectors with norm `<= 1e-8` are left
/// untouched to avoid dividing by (almost) zero.
pub fn normalize(v: &mut [f32]) {
    let n = l2_norm(v);
    if n > 1e-8 {
        for x in v.iter_mut() {
            *x /= n;
        }
    }
}

/// Compute L2-normalized median centroid from a list of embeddings.
///
/// The median is taken per dimension. An embedding shorter than `dim` simply
/// does not contribute to the dimensions it lacks (mirroring the tolerance of
/// [`mean_centroid`]); a dimension no embedding provides stays `0.0`.
/// Returns an all-zero vector of length `dim` when `embs` is empty.
pub fn median_centroid(embs: &[&Vec<f32>], dim: usize) -> Vec<f32> {
    if embs.is_empty() {
        return vec![0.0; dim];
    }
    let mut centroid = vec![0.0f32; dim];
    // One scratch buffer reused for every dimension.
    let mut vals: Vec<f32> = Vec::with_capacity(embs.len());
    for (d, slot) in centroid.iter_mut().enumerate() {
        vals.clear();
        vals.extend(embs.iter().filter_map(|e| e.get(d).copied()));
        if vals.is_empty() {
            continue;
        }
        vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        let mid = vals.len() / 2;
        *slot = if vals.len() % 2 == 1 {
            vals[mid]
        } else {
            (vals[mid - 1] + vals[mid]) / 2.0
        };
    }
    normalize(&mut centroid);
    centroid
}

/// Select up to `k` diverse exemplars from a set of embeddings.
///
/// Uses greedy farthest-first traversal:
/// 1. Start with the embedding closest to the mean centroid (most typical).
/// 2. Repeatedly add the embedding farthest from all already-selected exemplars.
///
/// "Close" and "far" are measured by dot product, so embeddings are expected
/// to be L2-normalized. Returns an empty vector when `embs` is empty or
/// `k == 0`, and a copy of every embedding when there are at most `k`.
pub fn select_exemplars(embs: &[&Vec<f32>], k: usize, dim: usize) -> Vec<Vec<f32>> {
    let n = embs.len();
    if n == 0 || k == 0 {
        return Vec::new();
    }
    if n <= k {
        return embs.iter().map(|e| (*e).clone()).collect();
    }

    let centroid = mean_centroid(embs, dim);
    let centroid_sim: Vec<f32> = embs.iter().map(|e| dot(e, &centroid)).collect();

    // First exemplar: closest to the centroid (most representative)
    let first = (0..n)
        .max_by(|&a, &b| {
            centroid_sim[a]
                .partial_cmp(&centroid_sim[b])
                .unwrap_or(Ordering::Equal)
        })
        .expect("n > 0 was checked above");

    // Greedy farthest-first: pick the point whose nearest selected exemplar
    // is as far away as possible (maximises diversity). `nearest_sim[i]` is
    // the running maximum similarity of point `i` to any selected exemplar,
    // updated once per pick instead of being recomputed per comparison.
    let mut taken = vec![false; n];
    let mut nearest_sim = vec![f32::NEG_INFINITY; n];
    let mut picked: Vec<usize> = Vec::with_capacity(k);
    let mut latest = first;
    loop {
        taken[latest] = true;
        picked.push(latest);
        if picked.len() == k {
            break;
        }
        for (i, sim) in nearest_sim.iter_mut().enumerate() {
            if !taken[i] {
                *sim = sim.max(dot(embs[i], embs[latest]));
            }
        }
        // lower max-similarity = farther = more diverse
        latest = (0..n)
            .filter(|&i| !taken[i])
            .min_by(|&a, &b| {
                nearest_sim[a]
                    .partial_cmp(&nearest_sim[b])
                    .unwrap_or(Ordering::Equal)
            })
            .expect("n > k, so an unselected embedding remains");
    }

    picked.into_iter().map(|i| embs[i].clone()).collect()
}

/// Compute L2-normalized mean centroid from a list of embeddings.
///
/// Components beyond `dim` are ignored and embeddings shorter than `dim`
/// contribute nothing to the dimensions they lack. Returns an all-zero vector
/// of length `dim` when `embs` is empty.
pub fn mean_centroid(embs: &[&Vec<f32>], dim: usize) -> Vec<f32> {
    if embs.is_empty() {
        return vec![0.0; dim];
    }
    let mut centroid = vec![0.0f32; dim];
    for emb in embs {
        // `zip` stops at the shorter side, i.e. at `dim` or the embedding's end.
        for (acc, &v) in centroid.iter_mut().zip(emb.iter()) {
            *acc += v;
        }
    }
    let count = embs.len() as f32;
    for v in centroid.iter_mut() {
        *v /= count;
    }
    normalize(&mut centroid);
    centroid
}

/// Sorted, de-duplicated list of the non-negative labels in `labels`
/// (negative labels denote noise and are skipped).
pub fn unique_cluster_ids(labels: &[i32]) -> Vec<i32> {
    let mut ids: Vec<i32> = labels.iter().copied().filter(|&l| l >= 0).collect();
    ids.sort_unstable();
    ids.dedup();
    ids
}

/// Rewrites the non-negative labels in place to the dense range `0..m`,
/// preserving their relative order; negative (noise) labels are untouched.
pub fn renumber_labels(labels: &mut [i32]) {
    let unique = unique_cluster_ids(labels);
    let mapping: HashMap<i32, i32> = unique
        .iter()
        .enumerate()
        .map(|(new, &old)| (old, new as i32))
        .collect();
    for label in labels.iter_mut().filter(|l| **l >= 0) {
        if let Some(&new) = mapping.get(label) {
            *label = new;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_agglomerative_empty() {
        assert!(agglomerative_precomputed_min_size_heap(&[], 0, 0.5, 1).is_empty());
        assert!(agglomerative_precomputed_min_size_naive(&[], 0, 0.5, 1).is_empty());
    }

    #[test]
    fn test_agglomerative_single() {
        assert_eq!(
            agglomerative_precomputed_min_size_heap(&[0.0], 1, 0.5, 1),
            vec![0]
        );
        assert_eq!(
            agglomerative_precomputed_min_size_naive(&[0.0], 1, 0.5, 1),
            vec![0]
        );
        assert_eq!(
            agglomerative_precomputed_min_size_heap(&[0.0], 1, 0.5, 2),
            vec![-1]
        );
    }

    #[test]
    fn test_naive_and_heap_agree() {
        let dist = vec![
            0.0, 0.1, 0.9, 0.9, //
            0.1, 0.0, 0.9, 0.9, //
            0.9, 0.9, 0.0, 0.2, //
            0.9, 0.9, 0.2, 0.0,
        ];
        let heap = agglomerative_precomputed_min_size_heap(&dist, 4, 0.5, 1);
        let naive = agglomerative_precomputed_min_size_naive(&dist, 4, 0.5, 1);
        assert_eq!(heap, vec![0, 0, 1, 1]);
        assert_eq!(heap, naive);
    }

    #[test]
    fn test_select_exemplars_zero_k() {
        let a = vec![1.0f32, 0.0];
        let b = vec![0.0f32, 1.0];
        assert!(select_exemplars(&[&a, &b], 0, 2).is_empty());
    }
}
86689f80c9b..4bb7f9031f8 100644 --- a/rust/photos/src/ml/indexing.rs +++ b/rust/photos/src/ml/indexing.rs @@ -8,12 +8,11 @@ use crate::{ error::{MlError, MlResult}, face::{align::run_face_alignment, detect::run_face_detection, embed::run_face_embedding}, pet::{ - align::run_pet_face_alignment, - detect::{run_pet_body_detection, run_pet_face_detection}, - embed::{run_pet_body_embedding, run_pet_face_embedding}, + align::run_pet_face_alignment, detect::run_pet_face_detection, + embed::run_pet_face_embedding, }, runtime::{self, ExecutionProviderPolicy, MlRuntimeConfig, ModelPaths}, - types::{self, ClipResult, Dimensions, FaceResult, PetBodyResult, PetFaceResult}, + types::{ClipResult, Dimensions, FaceResult, PetBodyResult, PetFaceResult}, }, }; @@ -74,13 +73,18 @@ pub fn analyze_image(req: AnalyzeImageRequest) -> MlResult { let decoded = decode_image_from_path(&image_path)?; let dims = decoded.dimensions.clone(); + let face_detections = if run_faces { + run_face_detection(runtime, &decoded)? 
+ } else { + Vec::new() + }; + let faces = if run_faces { - let detections = run_face_detection(runtime, &decoded)?; - if detections.is_empty() { + if face_detections.is_empty() { Some(Vec::new()) } else { let (aligned, mut face_results) = - run_face_alignment(file_id, &decoded, detections)?; + run_face_alignment(file_id, &decoded, face_detections)?; run_face_embedding(runtime, &aligned, &mut face_results)?; Some(face_results) } @@ -94,39 +98,19 @@ pub fn analyze_image(req: AnalyzeImageRequest) -> MlResult { None }; - let (pet_faces, pet_bodies) = if run_pets { + let pet_faces = if run_pets { let pet_face_detections = run_pet_face_detection(runtime, &decoded)?; - let body_detections = run_pet_body_detection(runtime, &decoded)?; - let pet_face_results = if !pet_face_detections.is_empty() { + if !pet_face_detections.is_empty() { let (aligned, mut pet_results) = run_pet_face_alignment(file_id, &decoded, &pet_face_detections)?; run_pet_face_embedding(runtime, &aligned, &mut pet_results)?; - pet_results + Some(pet_results) } else { - Vec::new() - }; - - let mut body_results: Vec = body_detections - .into_iter() - .map(|det| { - let base_id = types::to_face_id(file_id, det.box_xyxy); - let pet_body_id = format!("{base_id}_c{}", det.coco_class); - PetBodyResult { - pet_body_id, - detection: det, - body_embedding: Vec::new(), - } - }) - .collect(); - - if !body_results.is_empty() { - run_pet_body_embedding(runtime, &decoded, &mut body_results)?; + Some(Vec::new()) } - - (Some(pet_face_results), Some(body_results)) } else { - (None, None) + None }; Ok(AnalyzeImageResult { @@ -135,7 +119,7 @@ pub fn analyze_image(req: AnalyzeImageRequest) -> MlResult { faces, clip, pet_faces, - pet_bodies, + pet_bodies: None, }) }) } @@ -211,21 +195,12 @@ fn validate_request_model_paths(req: &AnalyzeImageRequest) -> MlResult<()> { if model_paths.pet_face_detection.trim().is_empty() { missing.push("petFaceDetectionModelPath"); } - if model_paths.pet_body_detection.trim().is_empty() { - 
missing.push("petBodyDetectionModelPath"); - } if model_paths.pet_face_embedding_dog.trim().is_empty() { missing.push("petFaceEmbeddingDogModelPath"); } if model_paths.pet_face_embedding_cat.trim().is_empty() { missing.push("petFaceEmbeddingCatModelPath"); } - if model_paths.pet_body_embedding_dog.trim().is_empty() { - missing.push("petBodyEmbeddingDogModelPath"); - } - if model_paths.pet_body_embedding_cat.trim().is_empty() { - missing.push("petBodyEmbeddingCatModelPath"); - } } if missing.is_empty() { return Ok(()); diff --git a/rust/photos/src/ml/mod.rs b/rust/photos/src/ml/mod.rs index a492cbe1ad3..c13151989d1 100644 --- a/rust/photos/src/ml/mod.rs +++ b/rust/photos/src/ml/mod.rs @@ -1,4 +1,5 @@ pub mod clip; +pub mod cluster; pub mod error; pub mod face; pub mod indexing; diff --git a/rust/photos/src/ml/onnx.rs b/rust/photos/src/ml/onnx.rs index d0d7e6f6f46..656caad749f 100644 --- a/rust/photos/src/ml/onnx.rs +++ b/rust/photos/src/ml/onnx.rs @@ -196,8 +196,19 @@ pub fn run_i32_f32( return Err(MlError::Ort("missing first output tensor".to_string())); } let output = &outputs[0]; - let tensor = output.try_extract_tensor::()?; - let shape = tensor.shape().iter().map(|d| *d as i64).collect::>(); - let data = tensor.iter().copied().collect::>(); - Ok((shape, data)) + + // Extract output: try f32 first, fall back to f16 with conversion. 
+ if let Ok(tensor) = output.try_extract_tensor::() { + let shape = tensor.shape().iter().map(|d| *d as i64).collect::>(); + let data = tensor.iter().copied().collect::>(); + Ok((shape, data)) + } else { + let tensor = output.try_extract_tensor::()?; + let shape = tensor.shape().iter().map(|d| *d as i64).collect::>(); + let data = tensor + .iter() + .map(|v: &half::f16| v.to_f32()) + .collect::>(); + Ok((shape, data)) + } } diff --git a/rust/photos/src/ml/pet/align.rs b/rust/photos/src/ml/pet/align.rs index bd96dcc3b8f..046a8cd70ac 100644 --- a/rust/photos/src/ml/pet/align.rs +++ b/rust/photos/src/ml/pet/align.rs @@ -20,9 +20,9 @@ const CROP_EXPAND: f32 = 0.1; /// /// Mirrors `pet_pipeline/detection.py` `_align_face()`: /// 1. Skip if eye distance < 5 px -/// 2. If angle < 1°, crop bounding box directly (no rotation) +/// 2. If angle < 1 deg, crop bounding box directly (no rotation) /// 3. Otherwise rotate around face center, then crop with 10% expand -/// 4. Resize to 224×224 +/// 4. Resize to 224x224 /// 5. Apply ImageNet normalization (CHW) pub fn run_pet_face_alignment( file_id: i64, @@ -69,7 +69,7 @@ pub fn run_pet_face_alignment( let angle_rad = dy.atan2(dx); let aligned_rgb = if angle_deg.abs() < ANGLE_SKIP_DEG { - // No rotation needed — just crop the bounding box directly + // No rotation needed -- just crop the bounding box directly let cx1 = box_x1.max(0) as u32; let cy1 = box_y1.max(0) as u32; let cx2 = (box_x2 as u32).min(img_w); @@ -204,7 +204,7 @@ fn rotate_around_center(source: &RgbImage, angle_rad: f64, cx: f64, cy: f64) -> output } -/// Crop a region from an RGB image and resize to 224×224 using bilinear interpolation. +/// Crop a region from an RGB image and resize to 224x224 using bilinear interpolation. 
fn crop_and_resize_rgb(source: &RgbImage, x: u32, y: u32, w: u32, h: u32) -> MlResult { // Extract crop bytes let src_w = source.width(); @@ -266,7 +266,7 @@ fn extract_rgb_region( .ok_or_else(|| MlError::Preprocess("failed to build region image".to_string())) } -/// Crop directly from decoded image bytes and resize — avoids building a +/// Crop directly from decoded image bytes and resize -- avoids building a /// full-size RgbImage when no rotation is needed. fn crop_and_resize_decoded( decoded: &DecodedImage, diff --git a/rust/photos/src/ml/pet/cluster.rs b/rust/photos/src/ml/pet/cluster.rs new file mode 100644 index 00000000000..1792eaf41e8 --- /dev/null +++ b/rust/photos/src/ml/pet/cluster.rs @@ -0,0 +1,1104 @@ +//! Pet clustering engine — face-based agglomerative clustering. +//! +//! Clusters pet face embeddings using average-linkage agglomerative +//! clustering with species-specific distance thresholds. +//! +//! All embeddings are assumed L2-normalized (cosine distance = 1 − dot). + +use std::collections::HashMap; + +use crate::ml::cluster::{ + agglomerative_precomputed_min_size, dot, mean_centroid, renumber_labels, select_exemplars, + unique_cluster_ids, +}; + +// ── Species-specific configuration ────────────────────────────────────── + +/// Species identifier for threshold lookup. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Species { + Dog = 0, + Cat = 1, +} + +impl Species { + pub fn from_u8(v: u8) -> Self { + match v { + 0 => Species::Dog, + _ => Species::Cat, + } + } +} + +/// Clustering thresholds per species. +#[derive(Clone, Debug)] +pub struct ClusterConfig { + pub species: Species, + /// Minimum number of faces to form a cluster. + pub min_cluster_size: usize, + /// Distance threshold for agglomerative clustering (average linkage). + /// Lower = tighter clusters, higher = more permissive merging. + pub agglomerative_threshold: f32, + /// Maximum number of exemplar embeddings to store per cluster. 
+ /// Used for multi-exemplar incremental matching. + pub max_exemplars: usize, +} + +impl ClusterConfig { + pub fn dog() -> Self { + Self { + species: Species::Dog, + min_cluster_size: 2, + agglomerative_threshold: 0.625, + max_exemplars: 5, + } + } + + pub fn cat() -> Self { + Self { + species: Species::Cat, + min_cluster_size: 2, + agglomerative_threshold: 0.75, + max_exemplars: 5, + } + } + + pub fn for_species(species: Species) -> Self { + match species { + Species::Dog => Self::dog(), + Species::Cat => Self::cat(), + } + } +} + +// ── Input / Output types ──────────────────────────────────────────────── + +/// One image's data for clustering. Index-aligned across the batch. +#[derive(Clone, Debug)] +pub struct PetClusterInput { + /// Unique ID for this pet face (from indexing). Used as the key in results. + pub pet_face_id: String, + /// L2-normalized face embedding (128-d). Empty vec if no face. + pub face_embedding: Vec, + /// 0 = dog, 1 = cat. + pub species: u8, + /// File ID that this detection belongs to. + pub file_id: i64, +} + +impl PetClusterInput { + pub fn has_face(&self) -> bool { + !self.face_embedding.is_empty() + } +} + +/// Result of clustering: maps each pet_face_id to a cluster_id string. +#[derive(Clone, Debug, Default)] +pub struct PetClusterResult { + /// pet_face_id → cluster_id (UUID-style string). + pub face_to_cluster: HashMap, + /// cluster_id → centroid face embedding (L2-normalized). + pub cluster_centroids: HashMap>, + /// cluster_id → diverse exemplar embeddings (real faces, not averaged). + /// Used for multi-exemplar incremental matching. + pub cluster_exemplars: HashMap>>, + /// cluster_id → member count. + pub cluster_counts: HashMap, + /// Number of inputs that remained unclustered. + pub n_unclustered: usize, +} + +// ── Core clustering engine ────────────────────────────────────────────── + +/// Run face-only pet clustering (agglomerative average linkage). +/// +/// This is the main entry point called from the API layer. 
+pub fn run_pet_clustering(inputs: &[PetClusterInput], config: &ClusterConfig) -> PetClusterResult { + let n = inputs.len(); + if n == 0 { + return PetClusterResult::default(); + } + + let has_face: Vec = inputs.iter().map(|i| i.has_face()).collect(); + + // Face-based agglomerative clustering + let mut labels = phase1_face_cluster(inputs, &has_face, config); + renumber_labels(&mut labels); + + build_result(inputs, &labels, &has_face, config) +} + +/// Run incremental face-only clustering: assign new inputs to existing +/// face clusters by centroid similarity, then cluster the remainder. +/// +/// Uses a relaxed threshold for centroid matching (centroid is an average +/// that doesn't represent any individual face perfectly, so it needs more +/// slack than pairwise comparisons in batch mode). +pub fn run_pet_clustering_incremental( + new_inputs: &[PetClusterInput], + existing_centroids_face: &HashMap>, + config: &ClusterConfig, +) -> PetClusterResult { + let n = new_inputs.len(); + if n == 0 { + return PetClusterResult::default(); + } + + let mut labels = vec![-1i32; n]; + let mut cluster_name_map: HashMap = HashMap::new(); + + // Step 1: Try to assign each new face to the closest existing cluster. + // Use a relaxed threshold: centroids are averages that drift from + // individual members, so we allow 15% more distance than batch mode. + let centroid_threshold = config.agglomerative_threshold * 1.15; + let min_sim = 1.0 - centroid_threshold; + + for (i, inp) in new_inputs.iter().enumerate() { + if !inp.has_face() { + continue; + } + + let mut best_sim = -1.0f32; + let mut second_best_sim = -1.0f32; + let mut best_id: Option<&String> = None; + + for (cluster_id, centroid) in existing_centroids_face { + let sim = dot(&inp.face_embedding, centroid); + if sim > best_sim { + second_best_sim = best_sim; + best_sim = sim; + best_id = Some(cluster_id); + } else if sim > second_best_sim { + second_best_sim = sim; + } + } + + // Assign if: + // 1. 
Distance to best centroid is below relaxed threshold + // 2. Best is clearly better than second-best (margin > 0.05) + // to avoid ambiguous assignments + let has_clear_winner = + existing_centroids_face.len() <= 1 || (best_sim - second_best_sim) > 0.05; + + if best_sim > min_sim + && has_clear_winner + && let Some(cid) = best_id + { + let numeric = cluster_name_map + .iter() + .find(|(_, v)| *v == cid) + .map(|(k, _)| *k) + .unwrap_or_else(|| { + let new_label = cluster_name_map.len() as i32; + cluster_name_map.insert(new_label, cid.clone()); + new_label + }); + labels[i] = numeric; + } + } + + // Step 2: Cluster unassigned among themselves + let unassigned: Vec = labels + .iter() + .enumerate() + .filter(|(_, l)| **l == -1) + .map(|(i, _)| i) + .collect(); + + if unassigned.len() >= config.min_cluster_size { + let sub_inputs: Vec = + unassigned.iter().map(|&i| new_inputs[i].clone()).collect(); + let sub_result = run_pet_clustering(&sub_inputs, config); + + let mut next_label = cluster_name_map.keys().copied().max().unwrap_or(-1) + 1; + for (sub_idx, &global_idx) in unassigned.iter().enumerate() { + let pet_face_id = &sub_inputs[sub_idx].pet_face_id; + if let Some(cluster_id) = sub_result.face_to_cluster.get(pet_face_id) { + let numeric = cluster_name_map + .iter() + .find(|(_, v)| *v == cluster_id) + .map(|(k, _)| *k) + .unwrap_or_else(|| { + let label = next_label; + next_label += 1; + cluster_name_map.insert(label, cluster_id.clone()); + label + }); + labels[global_idx] = numeric; + } + } + } + + // Build final result + let mut result = PetClusterResult::default(); + for (i, inp) in new_inputs.iter().enumerate() { + if labels[i] >= 0 { + if let Some(cluster_id) = cluster_name_map.get(&labels[i]) { + result + .face_to_cluster + .insert(inp.pet_face_id.clone(), cluster_id.clone()); + *result.cluster_counts.entry(cluster_id.clone()).or_insert(0) += 1; + } + } else { + result.n_unclustered += 1; + } + } + + // Recompute face centroids + for cluster_id in 
result.cluster_counts.keys() { + let face_embs: Vec<&Vec> = new_inputs + .iter() + .filter(|inp| { + result + .face_to_cluster + .get(&inp.pet_face_id) + .map(|c| c == cluster_id) + .unwrap_or(false) + && inp.has_face() + }) + .map(|inp| &inp.face_embedding) + .collect(); + + if !face_embs.is_empty() { + let centroid = mean_centroid(&face_embs, face_embs[0].len()); + result + .cluster_centroids + .insert(cluster_id.clone(), centroid); + } + } + + result +} + +/// Run incremental clustering using multi-exemplar matching. +/// +/// Instead of comparing new faces against a single mean centroid per cluster, +/// this compares against multiple real exemplar embeddings. A face matches a +/// cluster if it is similar enough to ANY exemplar in that cluster. +/// +/// Benefits over centroid matching: +/// - No "centroid drift" — exemplars are real embeddings that don't degrade. +/// - Captures cluster shape (e.g., front vs. profile views of the same pet). +/// - No need for the 1.15× threshold relaxation hack. +pub fn run_pet_clustering_incremental_with_exemplars( + new_inputs: &[PetClusterInput], + existing_exemplars: &HashMap>>, + config: &ClusterConfig, +) -> PetClusterResult { + let n = new_inputs.len(); + if n == 0 { + return PetClusterResult::default(); + } + + let mut labels = vec![-1i32; n]; + let mut cluster_name_map: HashMap = HashMap::new(); + + // Step 1: Match each new face against all exemplars of each cluster. + // No threshold relaxation needed — we're comparing against real faces. 
+ let min_sim = 1.0 - config.agglomerative_threshold; + + for (i, inp) in new_inputs.iter().enumerate() { + if !inp.has_face() { + continue; + } + + let mut best_sim = f32::NEG_INFINITY; + let mut second_best_sim = f32::NEG_INFINITY; + let mut best_id: Option<&String> = None; + + for (cluster_id, exemplars) in existing_exemplars { + // Best similarity to ANY exemplar in this cluster + let cluster_sim = exemplars + .iter() + .map(|ex| dot(&inp.face_embedding, ex)) + .fold(f32::NEG_INFINITY, f32::max); + + if cluster_sim > best_sim { + second_best_sim = best_sim; + best_sim = cluster_sim; + best_id = Some(cluster_id); + } else if cluster_sim > second_best_sim { + second_best_sim = cluster_sim; + } + } + + // Assign if: + // 1. Similarity to best exemplar exceeds threshold + // 2. Clear winner (margin > 0.05) to avoid ambiguous assignments + let has_clear_winner = existing_exemplars.len() <= 1 || (best_sim - second_best_sim) > 0.05; + + if best_sim > min_sim + && has_clear_winner + && let Some(cid) = best_id + { + let numeric = cluster_name_map + .iter() + .find(|(_, v)| *v == cid) + .map(|(k, _)| *k) + .unwrap_or_else(|| { + let new_label = cluster_name_map.len() as i32; + cluster_name_map.insert(new_label, cid.clone()); + new_label + }); + labels[i] = numeric; + } + } + + // Step 2: Cluster unassigned among themselves + let unassigned: Vec = labels + .iter() + .enumerate() + .filter(|(_, l)| **l == -1) + .map(|(i, _)| i) + .collect(); + + if unassigned.len() >= config.min_cluster_size { + let sub_inputs: Vec = + unassigned.iter().map(|&i| new_inputs[i].clone()).collect(); + let sub_result = run_pet_clustering(&sub_inputs, config); + + let mut next_label = cluster_name_map.keys().copied().max().unwrap_or(-1) + 1; + for (sub_idx, &global_idx) in unassigned.iter().enumerate() { + let pet_face_id = &sub_inputs[sub_idx].pet_face_id; + if let Some(cluster_id) = sub_result.face_to_cluster.get(pet_face_id) { + let numeric = cluster_name_map + .iter() + .find(|(_, v)| *v == 
cluster_id) + .map(|(k, _)| *k) + .unwrap_or_else(|| { + let label = next_label; + next_label += 1; + cluster_name_map.insert(label, cluster_id.clone()); + label + }); + labels[global_idx] = numeric; + } + } + } + + // Build final result + let face_dim = new_inputs + .iter() + .find(|i| i.has_face()) + .map(|i| i.face_embedding.len()) + .unwrap_or(128); + + let mut result = PetClusterResult::default(); + for (i, inp) in new_inputs.iter().enumerate() { + if labels[i] >= 0 { + if let Some(cluster_id) = cluster_name_map.get(&labels[i]) { + result + .face_to_cluster + .insert(inp.pet_face_id.clone(), cluster_id.clone()); + *result.cluster_counts.entry(cluster_id.clone()).or_insert(0) += 1; + } + } else { + result.n_unclustered += 1; + } + } + + // Compute centroids and exemplars from the new faces in each cluster + for cluster_id in result.cluster_counts.keys() { + let face_embs: Vec<&Vec> = new_inputs + .iter() + .filter(|inp| { + result + .face_to_cluster + .get(&inp.pet_face_id) + .map(|c| c == cluster_id) + .unwrap_or(false) + && inp.has_face() + }) + .map(|inp| &inp.face_embedding) + .collect(); + + if !face_embs.is_empty() { + result + .cluster_centroids + .insert(cluster_id.clone(), mean_centroid(&face_embs, face_dim)); + result.cluster_exemplars.insert( + cluster_id.clone(), + select_exemplars(&face_embs, config.max_exemplars, face_dim), + ); + } + } + + result +} + +// ── Phase 1: Face-based agglomerative clustering ──────────────────────── + +/// Agglomerative clustering (average linkage) on face embeddings. +fn phase1_face_cluster( + inputs: &[PetClusterInput], + has_face: &[bool], + config: &ClusterConfig, +) -> Vec { + let n = inputs.len(); + let mut labels = vec![-1i32; n]; + + let face_indices: Vec = has_face + .iter() + .enumerate() + .filter(|(_, h)| **h) + .map(|(i, _)| i) + .collect(); + + if face_indices.len() < 2 { + return labels; + } + + let nf = face_indices.len(); + + // Guard against excessive memory usage: n^2 * 4 bytes. 
+ // 5000^2 * 4 = ~100MB, which is the upper bound for mobile devices. + if nf > 5000 { + return labels; + } + + // Compute pairwise cosine distance matrix: dist = 1 - dot(a, b) + let mut dist = vec![0.0f32; nf * nf]; + for i in 0..nf { + for j in (i + 1)..nf { + let sim = dot( + &inputs[face_indices[i]].face_embedding, + &inputs[face_indices[j]].face_embedding, + ); + let d = (1.0 - sim).clamp(0.0, 2.0); + dist[i * nf + j] = d; + dist[j * nf + i] = d; + } + } + + // Run configured clustering algorithm on the distance matrix + let face_labels = agglomerative_precomputed_min_size( + &dist, + nf, + config.agglomerative_threshold, + config.min_cluster_size, + ); + + // Map back to global indices + for (local, &global) in face_indices.iter().enumerate() { + labels[global] = face_labels[local]; + } + + labels +} + +// Helper functions (dot, l2_norm, normalize, mean_centroid) imported +// from crate::ml::cluster + +// unique_cluster_ids and renumber_labels imported from crate::ml::cluster + +fn build_result( + inputs: &[PetClusterInput], + labels: &[i32], + has_face: &[bool], + config: &ClusterConfig, +) -> PetClusterResult { + let mut result = PetClusterResult::default(); + + // Generate cluster ID strings (UUID-like from hash for determinism) + let unique = unique_cluster_ids(labels); + let cluster_id_map: HashMap = unique + .iter() + .map(|&c| { + // Create a deterministic cluster ID from the first member's pet_face_id + let first_member = labels + .iter() + .enumerate() + .find(|(_, l)| **l == c) + .map(|(i, _)| i) + .unwrap(); + let id = format!("pet_cluster_{}", inputs[first_member].pet_face_id); + (c, id) + }) + .collect(); + + // Map each input to its cluster + for (i, inp) in inputs.iter().enumerate() { + if labels[i] >= 0 { + if let Some(cluster_id) = cluster_id_map.get(&labels[i]) { + result + .face_to_cluster + .insert(inp.pet_face_id.clone(), cluster_id.clone()); + *result.cluster_counts.entry(cluster_id.clone()).or_insert(0) += 1; + } + } else { + 
result.n_unclustered += 1; + } + } + + // Compute centroids and exemplars for each cluster + let face_dim = inputs + .iter() + .find(|i| i.has_face()) + .map(|i| i.face_embedding.len()) + .unwrap_or(128); + + for (&numeric, cluster_id) in &cluster_id_map { + let face_embs: Vec<&Vec> = labels + .iter() + .enumerate() + .filter(|(i, l)| **l == numeric && has_face[*i]) + .map(|(i, _)| &inputs[i].face_embedding) + .collect(); + if !face_embs.is_empty() { + result + .cluster_centroids + .insert(cluster_id.clone(), mean_centroid(&face_embs, face_dim)); + result.cluster_exemplars.insert( + cluster_id.clone(), + select_exemplars(&face_embs, config.max_exemplars, face_dim), + ); + } + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_input(face_id: &str, face: Vec, species: u8) -> PetClusterInput { + PetClusterInput { + pet_face_id: face_id.to_string(), + face_embedding: face, + species, + file_id: 0, + } + } + + /// Make a face embedding clustered around a base direction. 
+ fn make_face(base_dim: usize, noise_seed: u32) -> Vec { + let mut v = vec![0.0f32; 128]; + v[base_dim] = 1.0; + for k in 0..10u32 { + let dim = (noise_seed.wrapping_mul(7).wrapping_add(k * 13 + 3) % 128) as usize; + let t = (noise_seed as f32 * 0.618 + k as f32 * 1.377).sin(); + v[dim] += t * 0.35; + } + normalized(v) + } + + #[allow(dead_code)] + fn perturb(base: &[f32], dim: usize, amount: f32) -> Vec { + let mut v = base.to_vec(); + let idx = dim % v.len(); + v[idx] += amount; + normalized(v) + } + + fn normalized(mut v: Vec) -> Vec { + let n: f32 = v.iter().map(|x| x * x).sum::().sqrt(); + if n > 0.0 { + for x in v.iter_mut() { + *x /= n; + } + } + v + } + + #[test] + fn test_empty_input() { + let config = ClusterConfig::dog(); + let result = run_pet_clustering(&[], &config); + assert_eq!(result.n_unclustered, 0); + assert!(result.face_to_cluster.is_empty()); + } + + #[test] + fn test_single_input() { + let config = ClusterConfig::dog(); + let inputs = vec![make_input("face_1", normalized(vec![1.0; 128]), 0)]; + let result = run_pet_clustering(&inputs, &config); + // Single input can't form a cluster + assert_eq!(result.n_unclustered, 1); + } + + #[test] + fn test_two_similar_faces_cluster() { + let config = ClusterConfig::dog(); + // HDBSCAN needs density contrast — provide 2 groups + let mut inputs: Vec<_> = (0..5) + .map(|i| make_input(&format!("a{}", i), make_face(0, i), 0)) + .collect(); + // Second group so HDBSCAN sees density contrast + for i in 0..5 { + inputs.push(make_input(&format!("b{}", i), make_face(64, i + 100), 0)); + } + + let result = run_pet_clustering(&inputs, &config); + let c0 = result.face_to_cluster.get("a0"); + assert!(c0.is_some(), "a0 should be clustered"); + for i in 1..5 { + assert_eq!( + c0, + result.face_to_cluster.get(&format!("a{}", i)), + "a{} should be in same cluster as a0", + i + ); + } + // Groups should be separate + assert_ne!( + result.face_to_cluster.get("a0"), + result.face_to_cluster.get("b0"), + "Group A and B 
should be separate" + ); + } + + #[test] + fn test_two_groups_separate() { + let config = ClusterConfig::dog(); + + // Group A: 3 nearly-identical faces along dimension 0 + let base_a = { + let mut v = vec![0.0f32; 128]; + v[0] = 1.0; + v + }; + // Group B: 3 nearly-identical faces along dimension 64 + let base_b = { + let mut v = vec![0.0f32; 128]; + v[64] = 1.0; + v + }; + + let perturb = |base: &Vec, dim: usize, amt: f32| { + let mut v = base.clone(); + v[dim] += amt; + normalized(v) + }; + + let inputs = vec![ + make_input("a1", base_a.clone(), 0), + make_input("a2", perturb(&base_a, 1, 0.02), 0), + make_input("a3", perturb(&base_a, 2, 0.02), 0), + make_input("b1", base_b.clone(), 0), + make_input("b2", perturb(&base_b, 65, 0.02), 0), + make_input("b3", perturb(&base_b, 66, 0.02), 0), + ]; + let result = run_pet_clustering(&inputs, &config); + + let ca1 = result.face_to_cluster.get("a1"); + let ca2 = result.face_to_cluster.get("a2"); + let ca3 = result.face_to_cluster.get("a3"); + let cb1 = result.face_to_cluster.get("b1"); + let cb2 = result.face_to_cluster.get("b2"); + let cb3 = result.face_to_cluster.get("b3"); + assert_eq!(ca1, ca2, "a1 and a2 should be in same cluster"); + assert_eq!(ca1, ca3, "a1 and a3 should be in same cluster"); + assert_eq!(cb1, cb2, "b1 and b2 should be in same cluster"); + assert_eq!(cb1, cb3, "b1 and b3 should be in same cluster"); + assert_ne!( + ca1, cb1, + "group a and group b should be in different clusters" + ); + } + + /// Simulate the real app flow: + /// ONE: Batch cluster initial faces → creates clusters + centroids + /// TWO: New faces arrive → incremental assigns to existing clusters + /// THREE: More faces arrive → incremental assigns again + #[test] + fn test_incremental_one_two_three() { + let config = ClusterConfig::dog(); + + // ── ONE: Initial batch of 10 faces (2 dogs, 5 each) ── + let mut batch_inputs = Vec::new(); + for i in 0..5u32 { + batch_inputs.push(make_input(&format!("dogA_{}", i), make_face(0, i), 0)); + 
} + for i in 0..5u32 { + batch_inputs.push(make_input( + &format!("dogB_{}", i), + make_face(64, i + 100), + 0, + )); + } + + let batch_result = run_pet_clustering(&batch_inputs, &config); + + // Verify batch: 2 clusters, 0 unclustered + let ca = batch_result.face_to_cluster.get("dogA_0"); + let cb = batch_result.face_to_cluster.get("dogB_0"); + assert!(ca.is_some(), "ONE: dogA_0 should be clustered"); + assert!(cb.is_some(), "ONE: dogB_0 should be clustered"); + assert_ne!(ca, cb, "ONE: dog A and dog B should be separate clusters"); + + // All A's in same cluster + for i in 1..5 { + assert_eq!( + ca, + batch_result.face_to_cluster.get(&format!("dogA_{}", i)), + "ONE: dogA_{} should be in same cluster as dogA_0", + i + ); + } + // All B's in same cluster + for i in 1..5 { + assert_eq!( + cb, + batch_result.face_to_cluster.get(&format!("dogB_{}", i)), + "ONE: dogB_{} should be in same cluster as dogB_0", + i + ); + } + + let cluster_a_id = ca.unwrap().clone(); + let cluster_b_id = cb.unwrap().clone(); + + println!( + "ONE: {} clusters, {} unclustered, A={}, B={}", + batch_result.cluster_counts.len(), + batch_result.n_unclustered, + cluster_a_id, + cluster_b_id + ); + + // Extract centroids (simulating what Dart stores in VDB) + let centroids = batch_result.cluster_centroids.clone(); + assert!( + centroids.contains_key(&cluster_a_id), + "ONE: centroid for cluster A should exist" + ); + assert!( + centroids.contains_key(&cluster_b_id), + "ONE: centroid for cluster B should exist" + ); + + // ── TWO: 4 new faces arrive (2 for each dog) ── + let new_faces_two = vec![ + make_input("dogA_new1", make_face(0, 50), 0), + make_input("dogA_new2", make_face(0, 51), 0), + make_input("dogB_new1", make_face(64, 150), 0), + make_input("dogB_new2", make_face(64, 151), 0), + ]; + + let incr_result_two = run_pet_clustering_incremental(&new_faces_two, ¢roids, &config); + + println!( + "TWO: {} assigned, {} unclustered, clusters={:?}", + incr_result_two.face_to_cluster.len(), + 
incr_result_two.n_unclustered, + incr_result_two.cluster_counts + ); + + // dogA new faces should go to cluster A + let ca_new1 = incr_result_two.face_to_cluster.get("dogA_new1"); + let ca_new2 = incr_result_two.face_to_cluster.get("dogA_new2"); + assert_eq!( + ca_new1, + Some(&cluster_a_id), + "TWO: dogA_new1 should be assigned to cluster A" + ); + assert_eq!( + ca_new2, + Some(&cluster_a_id), + "TWO: dogA_new2 should be assigned to cluster A" + ); + + // dogB new faces should go to cluster B + let cb_new1 = incr_result_two.face_to_cluster.get("dogB_new1"); + let cb_new2 = incr_result_two.face_to_cluster.get("dogB_new2"); + assert_eq!( + cb_new1, + Some(&cluster_b_id), + "TWO: dogB_new1 should be assigned to cluster B" + ); + assert_eq!( + cb_new2, + Some(&cluster_b_id), + "TWO: dogB_new2 should be assigned to cluster B" + ); + + assert_eq!( + incr_result_two.n_unclustered, 0, + "TWO: all new faces should be assigned" + ); + + // ── THREE: 3 more faces (2 dog A, 1 completely new dog C) ── + let new_faces_three = vec![ + make_input("dogA_new3", make_face(0, 52), 0), + make_input("dogA_new4", make_face(0, 53), 0), + // Dog C: new identity, dim 32 (not matching A or B) + make_input("dogC_0", make_face(32, 200), 0), + ]; + + let incr_result_three = + run_pet_clustering_incremental(&new_faces_three, ¢roids, &config); + + println!( + "THREE: {} assigned, {} unclustered, clusters={:?}", + incr_result_three.face_to_cluster.len(), + incr_result_three.n_unclustered, + incr_result_three.cluster_counts + ); + + // Dog A faces should still go to cluster A + assert_eq!( + incr_result_three.face_to_cluster.get("dogA_new3"), + Some(&cluster_a_id), + "THREE: dogA_new3 should be assigned to cluster A" + ); + assert_eq!( + incr_result_three.face_to_cluster.get("dogA_new4"), + Some(&cluster_a_id), + "THREE: dogA_new4 should be assigned to cluster A" + ); + + // Dog C should NOT be assigned to A or B (it's a new identity) + let cc = incr_result_three.face_to_cluster.get("dogC_0"); + 
if let Some(cid) = cc { + assert_ne!(cid, &cluster_a_id, "THREE: dogC should not be in cluster A"); + assert_ne!(cid, &cluster_b_id, "THREE: dogC should not be in cluster B"); + println!("THREE: dogC_0 created new cluster {}", cid); + } else { + println!("THREE: dogC_0 stayed unclustered (only 1 face, expected)"); + } + } + + /// Test incremental with a face that's borderline between two clusters. + /// The margin check should prevent ambiguous assignment. + #[test] + fn test_incremental_ambiguous_face_stays_unassigned() { + let config = ClusterConfig::dog(); + + // Batch: 2 clusters + let mut batch_inputs = Vec::new(); + for i in 0..5u32 { + batch_inputs.push(make_input(&format!("a{}", i), make_face(0, i), 0)); + } + for i in 0..5u32 { + batch_inputs.push(make_input(&format!("b{}", i), make_face(64, i + 100), 0)); + } + + let batch_result = run_pet_clustering(&batch_inputs, &config); + let centroids = batch_result.cluster_centroids.clone(); + + // Create a face that's equidistant from both centroids + // (halfway between dim 0 and dim 64) + let mut ambiguous = vec![0.0f32; 128]; + ambiguous[0] = 1.0; + ambiguous[64] = 1.0; + let ambiguous = normalized(ambiguous); + + let new_faces = vec![make_input("ambiguous", ambiguous, 0)]; + + let result = run_pet_clustering_incremental(&new_faces, ¢roids, &config); + + println!( + "AMBIGUOUS: assigned={:?}, unclustered={}", + result.face_to_cluster.get("ambiguous"), + result.n_unclustered + ); + + // Should either be unassigned or in its own new cluster — NOT in A or B + let cluster_a = batch_result.face_to_cluster.get("a0").unwrap(); + let cluster_b = batch_result.face_to_cluster.get("b0").unwrap(); + + if let Some(assigned) = result.face_to_cluster.get("ambiguous") { + assert_ne!( + assigned, cluster_a, + "Ambiguous face should not be forced into cluster A" + ); + assert_ne!( + assigned, cluster_b, + "Ambiguous face should not be forced into cluster B" + ); + } + } + + /// Test incremental handles faceless (body-only) 
inputs. + #[test] + fn test_incremental_faceless_stays_unassigned() { + let config = ClusterConfig::dog(); + + // Batch: 2 clusters + let mut batch_inputs = Vec::new(); + for i in 0..5u32 { + batch_inputs.push(make_input(&format!("a{}", i), make_face(0, i), 0)); + } + for i in 0..5u32 { + batch_inputs.push(make_input(&format!("b{}", i), make_face(64, i + 100), 0)); + } + + let batch_result = run_pet_clustering(&batch_inputs, &config); + let centroids = batch_result.cluster_centroids.clone(); + + // Faceless input — can't match against face centroids + let new_faces = vec![make_input("no_face", vec![], 0)]; + + let result = run_pet_clustering_incremental(&new_faces, ¢roids, &config); + + assert_eq!( + result.n_unclustered, 1, + "Faceless input should be unclustered in incremental mode" + ); + } + + /// Exemplar-based incremental: same ONE→TWO→THREE flow but using + /// multi-exemplar matching instead of centroid matching. + #[test] + fn test_incremental_exemplars_one_two_three() { + let config = ClusterConfig::dog(); + + // ── ONE: Batch cluster 10 faces (2 dogs, 5 each) ── + let mut batch_inputs = Vec::new(); + for i in 0..5u32 { + batch_inputs.push(make_input(&format!("dogA_{}", i), make_face(0, i), 0)); + } + for i in 0..5u32 { + batch_inputs.push(make_input( + &format!("dogB_{}", i), + make_face(64, i + 100), + 0, + )); + } + + let batch_result = run_pet_clustering(&batch_inputs, &config); + + let ca = batch_result.face_to_cluster.get("dogA_0"); + let cb = batch_result.face_to_cluster.get("dogB_0"); + assert!(ca.is_some()); + assert!(cb.is_some()); + assert_ne!(ca, cb); + + let cluster_a_id = ca.unwrap().clone(); + let cluster_b_id = cb.unwrap().clone(); + + // Verify exemplars were computed + let exemplars = batch_result.cluster_exemplars.clone(); + assert!( + exemplars.contains_key(&cluster_a_id), + "ONE: exemplars for cluster A should exist" + ); + assert!( + exemplars.contains_key(&cluster_b_id), + "ONE: exemplars for cluster B should exist" + ); + assert!( 
+ exemplars[&cluster_a_id].len() <= config.max_exemplars, + "ONE: exemplar count should not exceed max_exemplars" + ); + + println!( + "ONE (exemplars): A has {} exemplars, B has {} exemplars", + exemplars[&cluster_a_id].len(), + exemplars[&cluster_b_id].len() + ); + + // ── TWO: 4 new faces, using exemplar matching ── + let new_faces_two = vec![ + make_input("dogA_new1", make_face(0, 50), 0), + make_input("dogA_new2", make_face(0, 51), 0), + make_input("dogB_new1", make_face(64, 150), 0), + make_input("dogB_new2", make_face(64, 151), 0), + ]; + + let incr_result_two = + run_pet_clustering_incremental_with_exemplars(&new_faces_two, &exemplars, &config); + + println!( + "TWO (exemplars): {} assigned, {} unclustered", + incr_result_two.face_to_cluster.len(), + incr_result_two.n_unclustered, + ); + + assert_eq!( + incr_result_two.face_to_cluster.get("dogA_new1"), + Some(&cluster_a_id), + "TWO: dogA_new1 should be assigned to cluster A" + ); + assert_eq!( + incr_result_two.face_to_cluster.get("dogA_new2"), + Some(&cluster_a_id), + "TWO: dogA_new2 should be assigned to cluster A" + ); + assert_eq!( + incr_result_two.face_to_cluster.get("dogB_new1"), + Some(&cluster_b_id), + "TWO: dogB_new1 should be assigned to cluster B" + ); + assert_eq!( + incr_result_two.face_to_cluster.get("dogB_new2"), + Some(&cluster_b_id), + "TWO: dogB_new2 should be assigned to cluster B" + ); + assert_eq!(incr_result_two.n_unclustered, 0); + + // ── THREE: 3 more faces (2 dog A, 1 new dog C) ── + let new_faces_three = vec![ + make_input("dogA_new3", make_face(0, 52), 0), + make_input("dogA_new4", make_face(0, 53), 0), + make_input("dogC_0", make_face(32, 200), 0), + ]; + + let incr_result_three = + run_pet_clustering_incremental_with_exemplars(&new_faces_three, &exemplars, &config); + + println!( + "THREE (exemplars): {} assigned, {} unclustered", + incr_result_three.face_to_cluster.len(), + incr_result_three.n_unclustered, + ); + + assert_eq!( + 
incr_result_three.face_to_cluster.get("dogA_new3"), + Some(&cluster_a_id), + "THREE: dogA_new3 should be assigned to cluster A" + ); + assert_eq!( + incr_result_three.face_to_cluster.get("dogA_new4"), + Some(&cluster_a_id), + "THREE: dogA_new4 should be assigned to cluster A" + ); + + // Dog C should NOT be in cluster A or B + let cc = incr_result_three.face_to_cluster.get("dogC_0"); + if let Some(cid) = cc { + assert_ne!(cid, &cluster_a_id); + assert_ne!(cid, &cluster_b_id); + println!("THREE (exemplars): dogC_0 created new cluster {}", cid); + } else { + println!("THREE (exemplars): dogC_0 stayed unclustered (only 1 face)"); + } + } + + /// Verify that batch clustering produces exemplars for all clusters. + #[test] + fn test_batch_produces_exemplars() { + let config = ClusterConfig::dog(); + + let mut inputs = Vec::new(); + for i in 0..8u32 { + inputs.push(make_input(&format!("a{}", i), make_face(0, i), 0)); + } + for i in 0..8u32 { + inputs.push(make_input(&format!("b{}", i), make_face(64, i + 100), 0)); + } + + let result = run_pet_clustering(&inputs, &config); + + // Every cluster with a centroid should also have exemplars + for cid in result.cluster_centroids.keys() { + assert!( + result.cluster_exemplars.contains_key(cid), + "Cluster {} should have exemplars", + cid + ); + let exs = &result.cluster_exemplars[cid]; + assert!(!exs.is_empty(), "Exemplars should not be empty"); + assert!( + exs.len() <= config.max_exemplars, + "Should not exceed max_exemplars" + ); + // Each exemplar should be 128-d + for ex in exs { + assert_eq!(ex.len(), 128); + } + } + } +} diff --git a/rust/photos/src/ml/pet/detect.rs b/rust/photos/src/ml/pet/detect.rs index 817ecbd9c6c..9671b32eb07 100644 --- a/rust/photos/src/ml/pet/detect.rs +++ b/rust/photos/src/ml/pet/detect.rs @@ -10,7 +10,7 @@ const INPUT_HEIGHT: f32 = 640.0; // Pet face detection thresholds (from Python config) const PET_FACE_IOU_THRESHOLD: f32 = 0.5; -const PET_FACE_MIN_SCORE: f32 = 0.3; +const PET_FACE_MIN_SCORE: 
f32 = 0.5; // Body detection thresholds const BODY_IOU_THRESHOLD: f32 = 0.5; @@ -34,13 +34,21 @@ const COCO_DOG: u8 = 16; pub fn run_pet_face_detection( runtime: &MlRuntimeView<'_>, decoded: &DecodedImage, +) -> MlResult> { + let session = runtime.pet_face_detection_session()?; + run_pet_face_detection_with_session(&session, decoded) +} + +/// Same as [run_pet_face_detection] but accepts a pre-built session directly. +pub fn run_pet_face_detection_with_session( + session: &ort::Session, + decoded: &DecodedImage, ) -> MlResult> { let (input, scaled_width, scaled_height, pad_left, pad_top) = preprocess::preprocess_yolo(decoded)?; - let pet_face_detection = runtime.pet_face_detection_session()?; let (output_shape, output_data) = onnx::run_f32( - &pet_face_detection, + session, input, [1, 3, INPUT_HEIGHT as i64, INPUT_WIDTH as i64], )?; @@ -66,15 +74,8 @@ pub fn run_pet_face_detection( total, output_shape ))); }; - // Warn if the total is ambiguously divisible by multiple candidates. let candidates = [11usize, 12, 13]; - let valid_count = candidates.iter().filter(|&&c| total % c == 0).count(); - if valid_count > 1 { - eprintln!( - "[ml][pet] WARNING: flat output len={total} is divisible by {valid_count} row-length candidates; using {inferred}. \ - Prefer a model with 2D output shape for reliability." - ); - } + let _valid_count = candidates.iter().filter(|&&c| total % c == 0).count(); inferred } else { return Err(MlError::Postprocess( @@ -130,14 +131,23 @@ pub fn run_pet_face_detection( // For a 2-class model (row_len >= 13): row[11] = cat score, // row[12] = dog score. Pick argmax and map to 0=dog, 1=cat. + // Skip detections where species confidence is too low (< 60% of + // the winning class) to avoid misclassifying cats as dogs or vice versa. // For a 1-class model (row_len == 12): row[11] is the single class // score; class is always 0 (dog). 
let class_id: u8 = if row_len >= 13 { - if row[12] > row[11] { 0 } else { 1 } + let cat_score = row[11]; + let dog_score = row[12]; + let max_score = cat_score.max(dog_score); + let min_score = cat_score.min(dog_score); + // Skip if species scores are too close (ambiguous) + if max_score > 0.0 && min_score / max_score > 0.7 { + continue; + } + if dog_score > cat_score { 0 } else { 1 } } else { 0 }; - detections.push(PetFaceDetection { score, box_xyxy, @@ -187,7 +197,7 @@ pub fn run_pet_body_detection( // Find the winning class across all 80 COCO classes and only // keep detections whose predicted class is cat (15) or dog (16). let class_logits = &row[5..85]; - let (best_cls, best_logit) = class_logits + let (best_cls, _) = class_logits .iter() .enumerate() .max_by(|a, b| a.1.total_cmp(b.1)) @@ -196,11 +206,11 @@ pub fn run_pet_body_detection( if best_cls != COCO_CAT && best_cls != COCO_DOG { continue; } - let class_score = best_logit * obj_conf; + let class_score = row[5 + best_cls as usize] * obj_conf; + let class_id = best_cls; if class_score < BODY_MIN_SCORE { continue; } - let class_id = best_cls; let x_min_abs = row[0] - row[2] / 2.0; let y_min_abs = row[1] - row[3] / 2.0; @@ -239,28 +249,30 @@ fn correct_box_for_aspect_ratio( pad_left: usize, pad_top: usize, ) { - if scaled_width == INPUT_WIDTH as usize - && scaled_height == INPUT_HEIGHT as usize - && pad_left == 0 - && pad_top == 0 + if scaled_width != INPUT_WIDTH as usize + || scaled_height != INPUT_HEIGHT as usize + || pad_left != 0 + || pad_top != 0 { - return; + let scaled_width = scaled_width as f32; + let scaled_height = scaled_height as f32; + let pad_left = pad_left as f32; + let pad_top = pad_top as f32; + + let transform_x = |x: f32| -> f32 { (x * INPUT_WIDTH - pad_left) / scaled_width }; + let transform_y = |y: f32| -> f32 { (y * INPUT_HEIGHT - pad_top) / scaled_height }; + + box_xyxy[0] = transform_x(box_xyxy[0]); + box_xyxy[1] = transform_y(box_xyxy[1]); + box_xyxy[2] = transform_x(box_xyxy[2]); 
+ box_xyxy[3] = transform_y(box_xyxy[3]); } - let scaled_width = scaled_width as f32; - let scaled_height = scaled_height as f32; - let pad_left = pad_left as f32; - let pad_top = pad_top as f32; - - let transform_x = - |x: f32| -> f32 { ((x * INPUT_WIDTH - pad_left) / scaled_width).clamp(0.0, 1.0) }; - let transform_y = - |y: f32| -> f32 { ((y * INPUT_HEIGHT - pad_top) / scaled_height).clamp(0.0, 1.0) }; - - box_xyxy[0] = transform_x(box_xyxy[0]); - box_xyxy[1] = transform_y(box_xyxy[1]); - box_xyxy[2] = transform_x(box_xyxy[2]); - box_xyxy[3] = transform_y(box_xyxy[3]); + // Always clamp to [0, 1] -- YOLO can predict boxes that extend beyond + // the image boundary, matching the Python pipeline's post-detection clamp. + for v in box_xyxy.iter_mut() { + *v = v.clamp(0.0, 1.0); + } } fn correct_for_maintained_aspect_ratio_3kp( @@ -273,24 +285,26 @@ fn correct_for_maintained_aspect_ratio_3kp( ) { correct_box_for_aspect_ratio(box_xyxy, scaled_width, scaled_height, pad_left, pad_top); - if scaled_width == INPUT_WIDTH as usize - && scaled_height == INPUT_HEIGHT as usize - && pad_left == 0 - && pad_top == 0 + if scaled_width != INPUT_WIDTH as usize + || scaled_height != INPUT_HEIGHT as usize + || pad_left != 0 + || pad_top != 0 { - return; + let transform_x = + |x: f32| -> f32 { (x * INPUT_WIDTH - pad_left as f32) / scaled_width as f32 }; + let transform_y = + |y: f32| -> f32 { (y * INPUT_HEIGHT - pad_top as f32) / scaled_height as f32 }; + + for point in keypoints.iter_mut() { + point[0] = transform_x(point[0]); + point[1] = transform_y(point[1]); + } } - let transform_x = |x: f32| -> f32 { - ((x * INPUT_WIDTH - pad_left as f32) / scaled_width as f32).clamp(0.0, 1.0) - }; - let transform_y = |y: f32| -> f32 { - ((y * INPUT_HEIGHT - pad_top as f32) / scaled_height as f32).clamp(0.0, 1.0) - }; - + // Always clamp keypoints to [0, 1], matching box clamping above. 
for point in keypoints.iter_mut() { - point[0] = transform_x(point[0]); - point[1] = transform_y(point[1]); + point[0] = point[0].clamp(0.0, 1.0); + point[1] = point[1].clamp(0.0, 1.0); } } @@ -377,3 +391,31 @@ fn naive_nms_pet_body( .filter_map(|(d, s)| if s { None } else { Some(d) }) .collect() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn iou_identical_boxes() { + let b = [0.1, 0.1, 0.5, 0.5]; + assert!((calculate_iou_4(&b, &b) - 1.0).abs() < 1e-6); + } + + #[test] + fn iou_disjoint_boxes() { + let a = [0.0, 0.0, 0.1, 0.1]; + let b = [0.5, 0.5, 0.6, 0.6]; + assert_eq!(calculate_iou_4(&a, &b), 0.0); + } + + #[test] + fn iou_partial_overlap() { + let a = [0.0, 0.0, 0.4, 0.4]; + let b = [0.2, 0.2, 0.6, 0.6]; + // Intersection: [0.2,0.2]->[0.4,0.4] = 0.2*0.2 = 0.04 + // Union: 0.16 + 0.16 - 0.04 = 0.28 + let iou = calculate_iou_4(&a, &b); + assert!((iou - 0.04 / 0.28).abs() < 1e-5); + } +} diff --git a/rust/photos/src/ml/pet/embed.rs b/rust/photos/src/ml/pet/embed.rs index f640f56f5ca..22646e3384d 100644 --- a/rust/photos/src/ml/pet/embed.rs +++ b/rust/photos/src/ml/pet/embed.rs @@ -12,22 +12,53 @@ const BODY_EMBED_INPUT_SIZE: i64 = 224; const FACE_EMBED_CHANNELS: i64 = 3; const BODY_EMBED_CHANNELS: i64 = 3; -/// Run pet face embedding on aligned face inputs. +/// Run pet face embedding using each face's own `class_id` to select the model. /// -/// The species parameter (0=dog, 1=cat) selects the model to use. +/// Faces are grouped by species and batched per model to avoid running the +/// wrong embedding model on any detection. /// /// Input per face: CHW float32 of shape [1, 3, 224, 224], ImageNet-normalized. /// Output: L2-normalized embedding vector (128-d for BYOL). /// /// This mirrors `pet_pipeline/embedding.py` `Embedder.embed_face()`. -/// Run pet face embedding using each face's own `class_id` to select the model. 
-/// -/// Faces are grouped by species and batched per model to avoid running the -/// wrong embedding model on any detection. pub fn run_pet_face_embedding( runtime: &MlRuntimeView<'_>, aligned_faces: &[Vec], face_results: &mut [PetFaceResult], +) -> MlResult<()> { + let get_session = |species: u8| -> MlResult<_> { + if species == 0 { + runtime.pet_face_embedding_dog_session() + } else { + runtime.pet_face_embedding_cat_session() + } + }; + run_pet_face_embedding_inner(aligned_faces, face_results, &get_session) +} + +/// Same as [run_pet_face_embedding] but accepts pre-built sessions directly. +pub fn run_pet_face_embedding_with_sessions( + aligned_faces: &[Vec], + face_results: &mut [PetFaceResult], + dog_session: &ort::Session, + cat_session: &ort::Session, +) -> MlResult<()> { + let get_session = |species: u8| -> MlResult<_> { + // Return a reference that satisfies the closure's lifetime. + // We wrap in Ok() since the caller already built the sessions. + if species == 0 { + Ok(dog_session) + } else { + Ok(cat_session) + } + }; + run_pet_face_embedding_inner(aligned_faces, face_results, &get_session) +} + +fn run_pet_face_embedding_inner>( + aligned_faces: &[Vec], + face_results: &mut [PetFaceResult], + get_session: &dyn Fn(u8) -> MlResult, ) -> MlResult<()> { if aligned_faces.is_empty() { return Ok(()); @@ -73,11 +104,7 @@ pub fn run_pet_face_embedding( input.extend_from_slice(aligned); } - let session = if species == 0 { - runtime.pet_face_embedding_dog_session()? - } else { - runtime.pet_face_embedding_cat_session()? - }; + let session = get_session(species)?; let (shape, output) = onnx::run_f32( &session, diff --git a/rust/photos/src/ml/pet/mod.rs b/rust/photos/src/ml/pet/mod.rs index de8d91a2406..3ff62374712 100644 --- a/rust/photos/src/ml/pet/mod.rs +++ b/rust/photos/src/ml/pet/mod.rs @@ -1,4 +1,5 @@ pub mod align; +pub mod cluster; pub mod detect; pub mod embed; pub mod preprocess;