diff --git a/internal/autotag/gallery.go b/internal/autotag/gallery.go index 031079e494..3d027ffe24 100644 --- a/internal/autotag/gallery.go +++ b/internal/autotag/gallery.go @@ -7,6 +7,7 @@ import ( "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/txn" ) type GalleryFinderUpdater interface { @@ -43,9 +44,11 @@ func getGalleryFileTagger(s *models.Gallery, cache *match.Cache) tagger { } } -// GalleryPerformers tags the provided gallery with performers whose name matches the gallery's path. -func GalleryPerformers(ctx context.Context, s *models.Gallery, rw GalleryPerformerUpdater, performerReader models.PerformerAutoTagQueryer, cache *match.Cache) error { - t := getGalleryFileTagger(s, cache) +// GalleryPerformersAtPath tags the provided gallery with performers whose +// name matches the gallery's path. A fresh write txn is opened only when a +// match is applied. +func (tagger *Tagger) GalleryPerformersAtPath(ctx context.Context, s *models.Gallery, rw GalleryPerformerUpdater, performerReader models.PerformerAutoTagQueryer) error { + t := getGalleryFileTagger(s, tagger.Cache) return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) { if err := s.LoadPerformerIDs(ctx, rw); err != nil { @@ -57,7 +60,9 @@ func GalleryPerformers(ctx context.Context, s *models.Gallery, rw GalleryPerform return false, nil } - if err := gallery.AddPerformer(ctx, rw, s, otherID); err != nil { + if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error { + return gallery.AddPerformer(ctx, rw, s, otherID) + }); err != nil { return false, err } @@ -65,25 +70,35 @@ func GalleryPerformers(ctx context.Context, s *models.Gallery, rw GalleryPerform }) } -// GalleryStudios tags the provided gallery with the first studio whose name matches the gallery's path. 
+// GalleryStudiosAtPath tags the provided gallery with the first studio whose +// name matches the gallery's path. // -// Gallerys will not be tagged if studio is already set. -func GalleryStudios(ctx context.Context, s *models.Gallery, rw GalleryFinderUpdater, studioReader models.StudioAutoTagQueryer, cache *match.Cache) error { +// Galleries will not be tagged if studio is already set. +func (tagger *Tagger) GalleryStudiosAtPath(ctx context.Context, s *models.Gallery, rw GalleryFinderUpdater, studioReader models.StudioAutoTagQueryer) error { if s.StudioID != nil { // don't modify return nil } - t := getGalleryFileTagger(s, cache) + t := getGalleryFileTagger(s, tagger.Cache) return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) { - return addGalleryStudio(ctx, rw, s, otherID) + var added bool + if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error { + var err error + added, err = addGalleryStudio(ctx, rw, s, otherID) + return err + }); err != nil { + return false, err + } + return added, nil }) } -// GalleryTags tags the provided gallery with tags whose name matches the gallery's path. -func GalleryTags(ctx context.Context, s *models.Gallery, rw GalleryTagUpdater, tagReader models.TagAutoTagQueryer, cache *match.Cache) error { - t := getGalleryFileTagger(s, cache) +// GalleryTagsAtPath tags the provided gallery with tags whose name matches +// the gallery's path. 
+func (tagger *Tagger) GalleryTagsAtPath(ctx context.Context, s *models.Gallery, rw GalleryTagUpdater, tagReader models.TagAutoTagQueryer) error { + t := getGalleryFileTagger(s, tagger.Cache) return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) { if err := s.LoadTagIDs(ctx, rw); err != nil { @@ -95,7 +110,9 @@ func GalleryTags(ctx context.Context, s *models.Gallery, rw GalleryTagUpdater, t return false, nil } - if err := gallery.AddTag(ctx, rw, s, otherID); err != nil { + if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error { + return gallery.AddTag(ctx, rw, s, otherID) + }); err != nil { return false, err } diff --git a/internal/autotag/gallery_test.go b/internal/autotag/gallery_test.go index 6333a6c172..d061cec05d 100644 --- a/internal/autotag/gallery_test.go +++ b/internal/autotag/gallery_test.go @@ -68,7 +68,7 @@ func TestGalleryPerformers(t *testing.T) { return galleryPartialsEqual(got, expected) }) - db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() + db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() } gallery := models.Gallery{ @@ -76,7 +76,8 @@ func TestGalleryPerformers(t *testing.T) { Path: test.Path, PerformerIDs: models.NewRelatedIDs([]int{}), } - err := GalleryPerformers(testCtx, &gallery, db.Gallery, db.Performer, nil) + tagger := &Tagger{TxnManager: db, Cache: nil} + err := tagger.GalleryPerformersAtPath(testCtx, &gallery, db.Gallery, db.Performer) assert.Nil(err) db.AssertExpectations(t) @@ -114,14 +115,15 @@ func TestGalleryStudios(t *testing.T) { return galleryPartialsEqual(got, expected) }) - db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() + db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() } gallery := models.Gallery{ ID: galleryID, Path: test.Path, } - err := GalleryStudios(testCtx, &gallery, db.Gallery, db.Studio, nil) + tagger 
:= &Tagger{TxnManager: db, Cache: nil} + err := tagger.GalleryStudiosAtPath(testCtx, &gallery, db.Gallery, db.Studio) assert.Nil(err) db.AssertExpectations(t) @@ -189,7 +191,7 @@ func TestGalleryTags(t *testing.T) { return galleryPartialsEqual(got, expected) }) - db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() + db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() } gallery := models.Gallery{ @@ -197,7 +199,8 @@ func TestGalleryTags(t *testing.T) { Path: test.Path, TagIDs: models.NewRelatedIDs([]int{}), } - err := GalleryTags(testCtx, &gallery, db.Gallery, db.Tag, nil) + tagger := &Tagger{TxnManager: db, Cache: nil} + err := tagger.GalleryTagsAtPath(testCtx, &gallery, db.Gallery, db.Tag) assert.Nil(err) db.AssertExpectations(t) diff --git a/internal/autotag/image.go b/internal/autotag/image.go index e4acbcd3af..4202798f3d 100644 --- a/internal/autotag/image.go +++ b/internal/autotag/image.go @@ -7,6 +7,7 @@ import ( "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/txn" ) type ImageFinderUpdater interface { @@ -34,9 +35,11 @@ func getImageFileTagger(s *models.Image, cache *match.Cache) tagger { } } -// ImagePerformers tags the provided image with performers whose name matches the image's path. -func ImagePerformers(ctx context.Context, s *models.Image, rw ImagePerformerUpdater, performerReader models.PerformerAutoTagQueryer, cache *match.Cache) error { - t := getImageFileTagger(s, cache) +// ImagePerformersAtPath tags the provided image with performers whose name +// matches the image's path. A fresh write txn is opened only when a match is +// applied. 
+func (tagger *Tagger) ImagePerformersAtPath(ctx context.Context, s *models.Image, rw ImagePerformerUpdater, performerReader models.PerformerAutoTagQueryer) error { + t := getImageFileTagger(s, tagger.Cache) return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) { if err := s.LoadPerformerIDs(ctx, rw); err != nil { @@ -48,7 +51,9 @@ func ImagePerformers(ctx context.Context, s *models.Image, rw ImagePerformerUpda return false, nil } - if err := image.AddPerformer(ctx, rw, s, otherID); err != nil { + if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error { + return image.AddPerformer(ctx, rw, s, otherID) + }); err != nil { return false, err } @@ -56,25 +61,35 @@ func ImagePerformers(ctx context.Context, s *models.Image, rw ImagePerformerUpda }) } -// ImageStudios tags the provided image with the first studio whose name matches the image's path. +// ImageStudiosAtPath tags the provided image with the first studio whose +// name matches the image's path. // // Images will not be tagged if studio is already set. -func ImageStudios(ctx context.Context, s *models.Image, rw ImageFinderUpdater, studioReader models.StudioAutoTagQueryer, cache *match.Cache) error { +func (tagger *Tagger) ImageStudiosAtPath(ctx context.Context, s *models.Image, rw ImageFinderUpdater, studioReader models.StudioAutoTagQueryer) error { if s.StudioID != nil { // don't modify return nil } - t := getImageFileTagger(s, cache) + t := getImageFileTagger(s, tagger.Cache) return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) { - return addImageStudio(ctx, rw, s, otherID) + var added bool + if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error { + var err error + added, err = addImageStudio(ctx, rw, s, otherID) + return err + }); err != nil { + return false, err + } + return added, nil }) } -// ImageTags tags the provided image with tags whose name matches the image's path. 
-func ImageTags(ctx context.Context, s *models.Image, rw ImageTagUpdater, tagReader models.TagAutoTagQueryer, cache *match.Cache) error { - t := getImageFileTagger(s, cache) +// ImageTagsAtPath tags the provided image with tags whose name matches the +// image's path. +func (tagger *Tagger) ImageTagsAtPath(ctx context.Context, s *models.Image, rw ImageTagUpdater, tagReader models.TagAutoTagQueryer) error { + t := getImageFileTagger(s, tagger.Cache) return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) { if err := s.LoadTagIDs(ctx, rw); err != nil { @@ -86,7 +101,9 @@ func ImageTags(ctx context.Context, s *models.Image, rw ImageTagUpdater, tagRead return false, nil } - if err := image.AddTag(ctx, rw, s, otherID); err != nil { + if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error { + return image.AddTag(ctx, rw, s, otherID) + }); err != nil { return false, err } diff --git a/internal/autotag/image_test.go b/internal/autotag/image_test.go index 88e42532d0..ed8c467ed8 100644 --- a/internal/autotag/image_test.go +++ b/internal/autotag/image_test.go @@ -65,7 +65,7 @@ func TestImagePerformers(t *testing.T) { return imagePartialsEqual(got, expected) }) - db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() + db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once() } image := models.Image{ @@ -73,7 +73,8 @@ func TestImagePerformers(t *testing.T) { Path: test.Path, PerformerIDs: models.NewRelatedIDs([]int{}), } - err := ImagePerformers(testCtx, &image, db.Image, db.Performer, nil) + tagger := &Tagger{TxnManager: db, Cache: nil} + err := tagger.ImagePerformersAtPath(testCtx, &image, db.Image, db.Performer) assert.Nil(err) db.AssertExpectations(t) @@ -111,14 +112,15 @@ func TestImageStudios(t *testing.T) { return imagePartialsEqual(got, expected) }) - db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() + db.Image.On("UpdatePartial", 
mock.Anything, imageID, matchPartial).Return(nil, nil).Once() } image := models.Image{ ID: imageID, Path: test.Path, } - err := ImageStudios(testCtx, &image, db.Image, db.Studio, nil) + tagger := &Tagger{TxnManager: db, Cache: nil} + err := tagger.ImageStudiosAtPath(testCtx, &image, db.Image, db.Studio) assert.Nil(err) db.AssertExpectations(t) @@ -186,7 +188,7 @@ func TestImageTags(t *testing.T) { return imagePartialsEqual(got, expected) }) - db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() + db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once() } image := models.Image{ @@ -194,7 +196,8 @@ func TestImageTags(t *testing.T) { Path: test.Path, TagIDs: models.NewRelatedIDs([]int{}), } - err := ImageTags(testCtx, &image, db.Image, db.Tag, nil) + tagger := &Tagger{TxnManager: db, Cache: nil} + err := tagger.ImageTagsAtPath(testCtx, &image, db.Image, db.Tag) assert.Nil(err) db.AssertExpectations(t) diff --git a/internal/autotag/scene.go b/internal/autotag/scene.go index 273378b9bc..d788d817f9 100644 --- a/internal/autotag/scene.go +++ b/internal/autotag/scene.go @@ -7,6 +7,7 @@ import ( "github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" + "github.com/stashapp/stash/pkg/txn" ) type SceneFinderUpdater interface { @@ -34,9 +35,12 @@ func getSceneFileTagger(s *models.Scene, cache *match.Cache) tagger { } } -// ScenePerformers tags the provided scene with performers whose name matches the scene's path. -func ScenePerformers(ctx context.Context, s *models.Scene, rw ScenePerformerUpdater, performerReader models.PerformerAutoTagQueryer, cache *match.Cache) error { - t := getSceneFileTagger(s, cache) +// ScenePerformersAtPath tags the provided scene with performers whose name +// matches the scene's path. The match phase runs using the current context +// (no outer write txn needed); a fresh write txn is opened only when a match +// is applied. 
+func (tagger *Tagger) ScenePerformersAtPath(ctx context.Context, s *models.Scene, rw ScenePerformerUpdater, performerReader models.PerformerAutoTagQueryer) error { + t := getSceneFileTagger(s, tagger.Cache) return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) { if err := s.LoadPerformerIDs(ctx, rw); err != nil { @@ -48,7 +52,9 @@ func ScenePerformers(ctx context.Context, s *models.Scene, rw ScenePerformerUpda return false, nil } - if err := scene.AddPerformer(ctx, rw, s, otherID); err != nil { + if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error { + return scene.AddPerformer(ctx, rw, s, otherID) + }); err != nil { return false, err } @@ -56,25 +62,35 @@ func ScenePerformers(ctx context.Context, s *models.Scene, rw ScenePerformerUpda }) } -// SceneStudios tags the provided scene with the first studio whose name matches the scene's path. +// SceneStudiosAtPath tags the provided scene with the first studio whose name +// matches the scene's path. // // Scenes will not be tagged if studio is already set. -func SceneStudios(ctx context.Context, s *models.Scene, rw SceneFinderUpdater, studioReader models.StudioAutoTagQueryer, cache *match.Cache) error { +func (tagger *Tagger) SceneStudiosAtPath(ctx context.Context, s *models.Scene, rw SceneFinderUpdater, studioReader models.StudioAutoTagQueryer) error { if s.StudioID != nil { // don't modify return nil } - t := getSceneFileTagger(s, cache) + t := getSceneFileTagger(s, tagger.Cache) return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) { - return addSceneStudio(ctx, rw, s, otherID) + var added bool + if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error { + var err error + added, err = addSceneStudio(ctx, rw, s, otherID) + return err + }); err != nil { + return false, err + } + return added, nil }) } -// SceneTags tags the provided scene with tags whose name matches the scene's path. 
-func SceneTags(ctx context.Context, s *models.Scene, rw SceneTagUpdater, tagReader models.TagAutoTagQueryer, cache *match.Cache) error { - t := getSceneFileTagger(s, cache) +// SceneTagsAtPath tags the provided scene with tags whose name matches the +// scene's path. +func (tagger *Tagger) SceneTagsAtPath(ctx context.Context, s *models.Scene, rw SceneTagUpdater, tagReader models.TagAutoTagQueryer) error { + t := getSceneFileTagger(s, tagger.Cache) return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) { if err := s.LoadTagIDs(ctx, rw); err != nil { @@ -86,7 +102,9 @@ func SceneTags(ctx context.Context, s *models.Scene, rw SceneTagUpdater, tagRead return false, nil } - if err := scene.AddTag(ctx, rw, s, otherID); err != nil { + if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error { + return scene.AddTag(ctx, rw, s, otherID) + }); err != nil { return false, err } diff --git a/internal/autotag/scene_test.go b/internal/autotag/scene_test.go index aaf015c8ff..356e72ec73 100644 --- a/internal/autotag/scene_test.go +++ b/internal/autotag/scene_test.go @@ -204,10 +204,11 @@ func TestScenePerformers(t *testing.T) { return scenePartialsEqual(got, expected) }) - db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() + db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() } - err := ScenePerformers(testCtx, &scene, db.Scene, db.Performer, nil) + tagger := &Tagger{TxnManager: db, Cache: nil} + err := tagger.ScenePerformersAtPath(testCtx, &scene, db.Scene, db.Performer) assert.Nil(err) db.AssertExpectations(t) @@ -247,14 +248,15 @@ func TestSceneStudios(t *testing.T) { return scenePartialsEqual(got, expected) }) - db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() + db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() } scene := models.Scene{ ID: sceneID, Path: test.Path, } - err := 
SceneStudios(testCtx, &scene, db.Scene, db.Studio, nil) + tagger := &Tagger{TxnManager: db, Cache: nil} + err := tagger.SceneStudiosAtPath(testCtx, &scene, db.Scene, db.Studio) assert.Nil(err) db.AssertExpectations(t) @@ -322,7 +324,7 @@ func TestSceneTags(t *testing.T) { return scenePartialsEqual(got, expected) }) - db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() + db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() } scene := models.Scene{ @@ -330,7 +332,8 @@ func TestSceneTags(t *testing.T) { Path: test.Path, TagIDs: models.NewRelatedIDs([]int{}), } - err := SceneTags(testCtx, &scene, db.Scene, db.Tag, nil) + tagger := &Tagger{TxnManager: db, Cache: nil} + err := tagger.SceneTagsAtPath(testCtx, &scene, db.Scene, db.Tag) assert.Nil(err) db.AssertExpectations(t) diff --git a/internal/manager/task_autotag.go b/internal/manager/task_autotag.go index e280e79f6f..007415bbc4 100644 --- a/internal/manager/task_autotag.go +++ b/internal/manager/task_autotag.go @@ -6,10 +6,10 @@ import ( "path/filepath" "strconv" "strings" - "sync" "time" "github.com/stashapp/stash/internal/autotag" + "github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/job" "github.com/stashapp/stash/pkg/logger" @@ -51,6 +51,35 @@ func (j *autoTagJob) isFileBasedAutoTag(input AutoTagMetadataInput) bool { } func (j *autoTagJob) autoTagFiles(ctx context.Context, progress *job.Progress, paths []string, performers, studios, tags bool) { + // Preload entity sets once. Each worker then matches against the + // in-memory set instead of paying a QueryForAutoTag roundtrip per file. 
+ r := j.repository + preloadBegin := time.Now() + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + if performers { + if err := j.cache.PreloadPerformers(ctx, r.Performer); err != nil { + return fmt.Errorf("preloading performers: %w", err) + } + } + if studios { + if err := j.cache.PreloadStudios(ctx, r.Studio); err != nil { + return fmt.Errorf("preloading studios: %w", err) + } + } + if tags { + if err := j.cache.PreloadTags(ctx, r.Tag); err != nil { + return fmt.Errorf("preloading tags: %w", err) + } + } + return nil + }); err != nil { + if !job.IsCancelled(ctx) { + logger.Errorf("auto-tag preload error: %v", err) + } + return + } + logger.Infof("Preloaded auto-tag entities in %s", time.Since(preloadBegin)) + t := autoTagFilesTask{ paths: paths, performers: performers, studios: studios, tags: tags, @@ -545,25 +574,41 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context) { return } - logger.Info("Auto-tagging scenes...") + workers := config.GetInstance().GetParallelTasksWithAutoDetection() + logger.Infof("Auto-tagging scenes (workers=%d)...", workers) - batchSize := 1000 + const batchSize = 1000 + const queueSize = batchSize * 4 - findFilter := models.BatchFindFilter(batchSize) + findFilter := models.KeysetFindFilter(batchSize) sceneFilter := t.makeSceneFilter() r := t.repository - more := true - for more { + taskQueue := job.NewTaskQueue(ctx, t.progress, queueSize, workers) + defer taskQueue.Close() + + var lastID, processed int + for { + filter := sceneFilter + if lastID != 0 { + filter = &models.SceneFilterType{ + ID: &models.IntCriterionInput{ + Value: lastID, + Modifier: models.CriterionModifierGreaterThan, + }, + } + filter.And = sceneFilter + } + var scenes []*models.Scene if err := r.WithReadTxn(ctx, func(ctx context.Context) error { var err error - scenes, err = scene.Query(ctx, r.Scene, sceneFilter, findFilter) + scenes, err = scene.Query(ctx, r.Scene, filter, findFilter) return err }); err != nil { if !job.IsCancelled(ctx) { - logger.Errorf("error
querying scenes for auto-tag: %w", err) + logger.Errorf("error querying scenes for auto-tag: %v", err) } return } @@ -573,32 +617,28 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context) { logger.Info("Stopping auto-tag due to user request") return } - - tt := autoTagSceneTask{ - repository: r, - scene: ss, - performers: t.performers, - studios: t.studios, - tags: t.tags, - cache: t.cache, - } - - var wg sync.WaitGroup - wg.Add(1) - go tt.Start(ctx, &wg) - wg.Wait() - - t.progress.Increment() + taskQueue.Add(fmt.Sprintf("Auto-tagging %s", ss.DisplayName()), func(ctx context.Context) { + tt := autoTagSceneTask{ + repository: r, + scene: ss, + performers: t.performers, + studios: t.studios, + tags: t.tags, + cache: t.cache, + } + tt.Start(ctx) + t.progress.Increment() + }) } - if len(scenes) != batchSize { - more = false - } else { - *findFilter.Page++ + if len(scenes) < batchSize { + return + } - if *findFilter.Page%10 == 1 { - logger.Infof("Processed %d scenes...", (*findFilter.Page-1)*batchSize) - } + lastID = scenes[len(scenes)-1].ID + processed += len(scenes) + if processed%(batchSize*10) == 0 { + logger.Infof("Processed %d scenes...", processed) } } } @@ -608,25 +648,41 @@ func (t *autoTagFilesTask) processImages(ctx context.Context) { return } - logger.Info("Auto-tagging images...") + workers := config.GetInstance().GetParallelTasksWithAutoDetection() + logger.Infof("Auto-tagging images (workers=%d)...", workers) - batchSize := 1000 + const batchSize = 1000 + const queueSize = batchSize * 4 - findFilter := models.BatchFindFilter(batchSize) + findFilter := models.KeysetFindFilter(batchSize) imageFilter := t.makeImageFilter() r := t.repository - more := true - for more { + taskQueue := job.NewTaskQueue(ctx, t.progress, queueSize, workers) + defer taskQueue.Close() + + var lastID, processed int + for { + filter := imageFilter + if lastID != 0 { + filter = &models.ImageFilterType{ + ID: &models.IntCriterionInput{ + Value: lastID, + Modifier: 
models.CriterionModifierGreaterThan, + }, + } + filter.And = imageFilter + } + var images []*models.Image if err := r.WithReadTxn(ctx, func(ctx context.Context) error { var err error - images, err = image.Query(ctx, r.Image, imageFilter, findFilter) + images, err = image.Query(ctx, r.Image, filter, findFilter) return err }); err != nil { if !job.IsCancelled(ctx) { - logger.Errorf("error querying images for auto-tag: %w", err) + logger.Errorf("error querying images for auto-tag: %v", err) } return } @@ -636,32 +691,28 @@ func (t *autoTagFilesTask) processImages(ctx context.Context) { logger.Info("Stopping auto-tag due to user request") return } - - tt := autoTagImageTask{ - repository: t.repository, - image: ss, - performers: t.performers, - studios: t.studios, - tags: t.tags, - cache: t.cache, - } - - var wg sync.WaitGroup - wg.Add(1) - go tt.Start(ctx, &wg) - wg.Wait() - - t.progress.Increment() + taskQueue.Add(fmt.Sprintf("Auto-tagging %s", ss.DisplayName()), func(ctx context.Context) { + tt := autoTagImageTask{ + repository: r, + image: ss, + performers: t.performers, + studios: t.studios, + tags: t.tags, + cache: t.cache, + } + tt.Start(ctx) + t.progress.Increment() + }) } - if len(images) != batchSize { - more = false - } else { - *findFilter.Page++ + if len(images) < batchSize { + return + } - if *findFilter.Page%10 == 1 { - logger.Infof("Processed %d images...", (*findFilter.Page-1)*batchSize) - } + lastID = images[len(images)-1].ID + processed += len(images) + if processed%(batchSize*10) == 0 { + logger.Infof("Processed %d images...", processed) } } } @@ -671,25 +722,41 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context) { return } - logger.Info("Auto-tagging galleries...") + workers := config.GetInstance().GetParallelTasksWithAutoDetection() + logger.Infof("Auto-tagging galleries (workers=%d)...", workers) - batchSize := 1000 + const batchSize = 1000 + const queueSize = batchSize * 4 - findFilter := models.BatchFindFilter(batchSize) + 
findFilter := models.KeysetFindFilter(batchSize) galleryFilter := t.makeGalleryFilter() r := t.repository - more := true - for more { + taskQueue := job.NewTaskQueue(ctx, t.progress, queueSize, workers) + defer taskQueue.Close() + + var lastID, processed int + for { + filter := galleryFilter + if lastID != 0 { + filter = &models.GalleryFilterType{ + ID: &models.IntCriterionInput{ + Value: lastID, + Modifier: models.CriterionModifierGreaterThan, + }, + } + filter.And = galleryFilter + } + var galleries []*models.Gallery if err := r.WithReadTxn(ctx, func(ctx context.Context) error { var err error - galleries, _, err = r.Gallery.Query(ctx, galleryFilter, findFilter) + galleries, _, err = r.Gallery.Query(ctx, filter, findFilter) return err }); err != nil { if !job.IsCancelled(ctx) { - logger.Errorf("error querying galleries for auto-tag: %w", err) + logger.Errorf("error querying galleries for auto-tag: %v", err) } return } @@ -699,32 +765,28 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context) { logger.Info("Stopping auto-tag due to user request") return } - - tt := autoTagGalleryTask{ - repository: t.repository, - gallery: ss, - performers: t.performers, - studios: t.studios, - tags: t.tags, - cache: t.cache, - } - - var wg sync.WaitGroup - wg.Add(1) - go tt.Start(ctx, &wg) - wg.Wait() - - t.progress.Increment() + taskQueue.Add(fmt.Sprintf("Auto-tagging %s", ss.DisplayName()), func(ctx context.Context) { + tt := autoTagGalleryTask{ + repository: r, + gallery: ss, + performers: t.performers, + studios: t.studios, + tags: t.tags, + cache: t.cache, + } + tt.Start(ctx) + t.progress.Increment() + }) } - if len(galleries) != batchSize { - more = false - } else { - *findFilter.Page++ + if len(galleries) < batchSize { + return + } - if *findFilter.Page%10 == 1 { - logger.Infof("Processed %d galleries...", (*findFilter.Page-1)*batchSize) - } + lastID = galleries[len(galleries)-1].ID + processed += len(galleries) + if processed%(batchSize*10) == 0 { + 
logger.Infof("Processed %d galleries...", processed) } } } @@ -763,27 +825,28 @@ type autoTagSceneTask struct { cache *match.Cache } -func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (t *autoTagSceneTask) Start(ctx context.Context) { r := t.repository - if err := r.WithTxn(ctx, func(ctx context.Context) error { + tagger := &autotag.Tagger{TxnManager: r.TxnManager, Cache: t.cache} + + if err := r.WithDB(ctx, func(ctx context.Context) error { if t.scene.Path == "" { // nothing to do return nil } if t.performers { - if err := autotag.ScenePerformers(ctx, t.scene, r.Scene, r.Performer, t.cache); err != nil { + if err := tagger.ScenePerformersAtPath(ctx, t.scene, r.Scene, r.Performer); err != nil { return fmt.Errorf("tagging scene performers for %s: %v", t.scene.DisplayName(), err) } } if t.studios { - if err := autotag.SceneStudios(ctx, t.scene, r.Scene, r.Studio, t.cache); err != nil { + if err := tagger.SceneStudiosAtPath(ctx, t.scene, r.Scene, r.Studio); err != nil { return fmt.Errorf("tagging scene studio for %s: %v", t.scene.DisplayName(), err) } } if t.tags { - if err := autotag.SceneTags(ctx, t.scene, r.Scene, r.Tag, t.cache); err != nil { + if err := tagger.SceneTagsAtPath(ctx, t.scene, r.Scene, r.Tag); err != nil { return fmt.Errorf("tagging scene tags for %s: %v", t.scene.DisplayName(), err) } } @@ -807,22 +870,23 @@ type autoTagImageTask struct { cache *match.Cache } -func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (t *autoTagImageTask) Start(ctx context.Context) { r := t.repository - if err := r.WithTxn(ctx, func(ctx context.Context) error { + tagger := &autotag.Tagger{TxnManager: r.TxnManager, Cache: t.cache} + + if err := r.WithDB(ctx, func(ctx context.Context) error { if t.performers { - if err := autotag.ImagePerformers(ctx, t.image, r.Image, r.Performer, t.cache); err != nil { + if err := tagger.ImagePerformersAtPath(ctx, t.image, r.Image, 
r.Performer); err != nil { return fmt.Errorf("tagging image performers for %s: %v", t.image.DisplayName(), err) } } if t.studios { - if err := autotag.ImageStudios(ctx, t.image, r.Image, r.Studio, t.cache); err != nil { + if err := tagger.ImageStudiosAtPath(ctx, t.image, r.Image, r.Studio); err != nil { return fmt.Errorf("tagging image studio for %s: %v", t.image.DisplayName(), err) } } if t.tags { - if err := autotag.ImageTags(ctx, t.image, r.Image, r.Tag, t.cache); err != nil { + if err := tagger.ImageTagsAtPath(ctx, t.image, r.Image, r.Tag); err != nil { return fmt.Errorf("tagging image tags for %s: %v", t.image.DisplayName(), err) } } @@ -846,22 +910,23 @@ type autoTagGalleryTask struct { cache *match.Cache } -func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (t *autoTagGalleryTask) Start(ctx context.Context) { r := t.repository - if err := r.WithTxn(ctx, func(ctx context.Context) error { + tagger := &autotag.Tagger{TxnManager: r.TxnManager, Cache: t.cache} + + if err := r.WithDB(ctx, func(ctx context.Context) error { if t.performers { - if err := autotag.GalleryPerformers(ctx, t.gallery, r.Gallery, r.Performer, t.cache); err != nil { + if err := tagger.GalleryPerformersAtPath(ctx, t.gallery, r.Gallery, r.Performer); err != nil { return fmt.Errorf("tagging gallery performers for %s: %v", t.gallery.DisplayName(), err) } } if t.studios { - if err := autotag.GalleryStudios(ctx, t.gallery, r.Gallery, r.Studio, t.cache); err != nil { + if err := tagger.GalleryStudiosAtPath(ctx, t.gallery, r.Gallery, r.Studio); err != nil { return fmt.Errorf("tagging gallery studio for %s: %v", t.gallery.DisplayName(), err) } } if t.tags { - if err := autotag.GalleryTags(ctx, t.gallery, r.Gallery, r.Tag, t.cache); err != nil { + if err := tagger.GalleryTagsAtPath(ctx, t.gallery, r.Gallery, r.Tag); err != nil { return fmt.Errorf("tagging gallery tags for %s: %v", t.gallery.DisplayName(), err) } } diff --git a/pkg/match/cache.go 
b/pkg/match/cache.go index 002d67116c..2988520e0e 100644 --- a/pkg/match/cache.go +++ b/pkg/match/cache.go @@ -2,17 +2,354 @@ package match import ( "context" + "regexp" + "strings" + "sync" "github.com/stashapp/stash/pkg/models" ) const singleFirstCharacterRegex = `^[\p{L}][.\-_ ]` -// Cache is used to cache queries that should not change across an autotag process. +var singleFirstCharacterRE = regexp.MustCompile(singleFirstCharacterRegex) + +// firstTwoRunesLower returns the first two runes of s, lowercased. Returns +// "" if s has fewer than two runes. Mirrors what getPathWords produces for +// path words, so the two can be compared as index keys. +func firstTwoRunesLower(s string) string { + lower := strings.ToLower(s) + runes := []rune(lower) + if len(runes) < 2 { + return "" + } + return string(runes[0:2]) +} + +// performerCandidates returns the set of preloaded performers that should +// be regex-checked for the given path words. Mirrors the SQL +// `name LIKE 'xx%' OR name LIKE 'yy%'` prefilter, plus always-check +// performers whose name begins with a single-letter word (which the 2-rune +// prefix lookup can't reach). 
+func (c *Cache) performerCandidates(pathWords []string) []*models.Performer { + if len(c.performerByPrefix) == 0 && len(c.performerAlwaysCheck) == 0 { + return nil + } + seen := make(map[int]bool, len(pathWords)*2) + out := make([]*models.Performer, 0, len(pathWords)*2) + for _, w := range pathWords { + key := strings.ToLower(w) + for _, p := range c.performerByPrefix[key] { + if !seen[p.ID] { + seen[p.ID] = true + out = append(out, p) + } + } + } + for _, p := range c.performerAlwaysCheck { + if !seen[p.ID] { + seen[p.ID] = true + out = append(out, p) + } + } + return out +} + +func (c *Cache) studioCandidates(pathWords []string) []cachedStudio { + if len(c.studioByPrefix) == 0 && len(c.studioAlwaysCheck) == 0 { + return nil + } + seen := make(map[int]bool, len(pathWords)*2) + out := make([]cachedStudio, 0, len(pathWords)*2) + for _, w := range pathWords { + key := strings.ToLower(w) + for _, s := range c.studioByPrefix[key] { + if !seen[s.Studio.ID] { + seen[s.Studio.ID] = true + out = append(out, s) + } + } + } + for _, s := range c.studioAlwaysCheck { + if !seen[s.Studio.ID] { + seen[s.Studio.ID] = true + out = append(out, s) + } + } + return out +} + +func (c *Cache) tagCandidates(pathWords []string) []cachedTag { + if len(c.tagByPrefix) == 0 && len(c.tagAlwaysCheck) == 0 { + return nil + } + seen := make(map[int]bool, len(pathWords)*2) + out := make([]cachedTag, 0, len(pathWords)*2) + for _, w := range pathWords { + key := strings.ToLower(w) + for _, t := range c.tagByPrefix[key] { + if !seen[t.Tag.ID] { + seen[t.Tag.ID] = true + out = append(out, t) + } + } + } + for _, t := range c.tagAlwaysCheck { + if !seen[t.Tag.ID] { + seen[t.Tag.ID] = true + out = append(out, t) + } + } + return out +} + +// Cache is used to cache queries that should not change across an autotag +// process. Safe for concurrent use by multiple goroutines. 
type Cache struct {
	// Guards for the lazy single-letter caches below. The *Err fields
	// record the query error observed inside the corresponding Once.Do
	// so every later caller sees the same failure.
	// NOTE(review): the Preload* methods write the preload fields below
	// without synchronization — preloading must complete before any
	// concurrent readers start; confirm callers uphold this ordering.
	performersOnce sync.Once
	performersErr  error
	studiosOnce    sync.Once
	studiosErr     error
	tagsOnce       sync.Once
	tagsErr        error

	// Lazily-loaded entities whose names begin with single-character
	// words (populated by the getSingleLetter* helpers).
	singleCharPerformers []*models.Performer
	singleCharStudios    []*models.Studio
	singleCharTags       []*models.Tag

	// Preloaded candidate sets. When populated (via PreloadX), the
	// PathTo* functions skip the per-path QueryForAutoTag DB roundtrip
	// and consult the in-memory prefix index instead. Nil means
	// "not preloaded, fall back to the old SQL-prefilter path".
	allPerformers []*models.Performer
	allStudios    []cachedStudio
	allTags       []cachedTag

	// Prefix indexes built at preload time. Map key is the first two
	// lowercased runes of name (or alias, for studios/tags). The
	// alwaysCheck slice holds entries whose first "word" is a
	// single letter — they wouldn't be reached by 2-rune path word
	// lookup, so they must always be checked (mirroring the existing
	// single-letter regex query).
	performerByPrefix    map[string][]*models.Performer
	performerAlwaysCheck []*models.Performer
	studioByPrefix       map[string][]cachedStudio
	studioAlwaysCheck    []cachedStudio
	tagByPrefix          map[string][]cachedTag
	tagAlwaysCheck       []cachedTag

	// regexpCache maps regexpCacheKey → *regexp.Regexp. sync.Map rather
	// than the hashicorp LRU used in pkg/sqlite/regex.go: this cache is
	// job-scoped (so LRU's eviction buys us nothing) and is hit by every
	// worker on every candidate (so a single-mutex Get becomes the
	// bottleneck). sync.Map's read-optimised path sidesteps that.
	regexpCache sync.Map
}

// cachedStudio bundles a studio with its aliases so PathToStudio can match
// against both without an N+1 GetAliases query.
type cachedStudio struct {
	Studio  *models.Studio
	Aliases []string
}

// cachedTag bundles a tag with its aliases so PathToTags can match against
// both without an N+1 GetAliases query.
+type cachedTag struct { + Tag *models.Tag + Aliases []string +} + +// PreloadPerformers loads all non-ignored performers into the cache and +// builds a 2-rune prefix index so subsequent PathToPerformers calls can +// skip both the per-path QueryForAutoTag and the per-candidate regex +// when no prefix matches. +func (c *Cache) PreloadPerformers(ctx context.Context, reader models.PerformerAutoTagQueryer) error { + if c.allPerformers != nil { + return nil + } + ignoreAutoTag := false + perPage := -1 + perfs, _, err := reader.Query(ctx, &models.PerformerFilterType{ + IgnoreAutoTag: &ignoreAutoTag, + }, &models.FindFilterType{PerPage: &perPage}) + if err != nil { + return err + } + if perfs == nil { + perfs = []*models.Performer{} + } + c.allPerformers = perfs + + c.performerByPrefix = make(map[string][]*models.Performer, len(perfs)) + for _, p := range perfs { + if prefix := firstTwoRunesLower(p.Name); prefix != "" { + c.performerByPrefix[prefix] = append(c.performerByPrefix[prefix], p) + } + if singleFirstCharacterRE.MatchString(p.Name) { + c.performerAlwaysCheck = append(c.performerAlwaysCheck, p) + } + } + return nil +} + +// loadAllAliases loads aliases for the given ids. Uses the reader's bulk +// GetAllAliases method when available (avoiding the N+1 per-id roundtrip); +// otherwise falls back to per-id GetAliases. +func loadAllAliases(ctx context.Context, reader models.AliasLoader, ids []int) (map[int][]string, error) { + if bulk, ok := reader.(models.AllAliasLoader); ok { + return bulk.GetAllAliases(ctx) + } + ret := make(map[int][]string, len(ids)) + for _, id := range ids { + a, err := reader.GetAliases(ctx, id) + if err != nil { + return nil, err + } + if len(a) > 0 { + ret[id] = a + } + } + return ret, nil +} + +// PreloadStudios loads all non-ignored studios plus their aliases into the +// cache and builds a 2-rune prefix index (over names AND aliases, mirroring +// the SQL LEFT JOIN on studio_aliases). 
+func (c *Cache) PreloadStudios(ctx context.Context, reader models.StudioAutoTagQueryer) error { + if c.allStudios != nil { + return nil + } + ignoreAutoTag := false + perPage := -1 + studios, _, err := reader.Query(ctx, &models.StudioFilterType{ + IgnoreAutoTag: &ignoreAutoTag, + }, &models.FindFilterType{PerPage: &perPage}) + if err != nil { + return err + } + ids := make([]int, len(studios)) + for i, s := range studios { + ids[i] = s.ID + } + aliasesByID, err := loadAllAliases(ctx, reader, ids) + if err != nil { + return err + } + out := make([]cachedStudio, len(studios)) + c.studioByPrefix = make(map[string][]cachedStudio, len(studios)) + seenPerPrefix := make(map[string]map[int]bool) + for i, s := range studios { + aliases := aliasesByID[s.ID] + cs := cachedStudio{Studio: s, Aliases: aliases} + out[i] = cs + + c.indexByPrefix(s.ID, s.Name, aliases, seenPerPrefix, func(prefix string) { + c.studioByPrefix[prefix] = append(c.studioByPrefix[prefix], cs) + }) + if hasSingleFirstChar(s.Name, aliases) { + c.studioAlwaysCheck = append(c.studioAlwaysCheck, cs) + } + } + c.allStudios = out + return nil +} + +// PreloadTags loads all non-ignored tags plus their aliases into the cache +// and builds a 2-rune prefix index (over names AND aliases). 
+func (c *Cache) PreloadTags(ctx context.Context, reader models.TagAutoTagQueryer) error { + if c.allTags != nil { + return nil + } + ignoreAutoTag := false + perPage := -1 + tags, _, err := reader.Query(ctx, &models.TagFilterType{ + IgnoreAutoTag: &ignoreAutoTag, + }, &models.FindFilterType{PerPage: &perPage}) + if err != nil { + return err + } + ids := make([]int, len(tags)) + for i, t := range tags { + ids[i] = t.ID + } + aliasesByID, err := loadAllAliases(ctx, reader, ids) + if err != nil { + return err + } + out := make([]cachedTag, len(tags)) + c.tagByPrefix = make(map[string][]cachedTag, len(tags)) + seenPerPrefix := make(map[string]map[int]bool) + for i, t := range tags { + aliases := aliasesByID[t.ID] + ct := cachedTag{Tag: t, Aliases: aliases} + out[i] = ct + + c.indexByPrefix(t.ID, t.Name, aliases, seenPerPrefix, func(prefix string) { + c.tagByPrefix[prefix] = append(c.tagByPrefix[prefix], ct) + }) + if hasSingleFirstChar(t.Name, aliases) { + c.tagAlwaysCheck = append(c.tagAlwaysCheck, ct) + } + } + c.allTags = out + return nil +} + +// indexByPrefix records the entity under every distinct 2-rune prefix of +// its name/aliases (deduping so a name+alias that share a prefix bucket +// only add the entity once). 
+func (c *Cache) indexByPrefix(id int, name string, aliases []string, seen map[string]map[int]bool, add func(prefix string)) { + emit := func(s string) { + prefix := firstTwoRunesLower(s) + if prefix == "" { + return + } + if seen[prefix] == nil { + seen[prefix] = make(map[int]bool) + } + if !seen[prefix][id] { + seen[prefix][id] = true + add(prefix) + } + } + emit(name) + for _, a := range aliases { + emit(a) + } +} + +func hasSingleFirstChar(name string, aliases []string) bool { + if singleFirstCharacterRE.MatchString(name) { + return true + } + for _, a := range aliases { + if singleFirstCharacterRE.MatchString(a) { + return true + } + } + return false +} + +type regexpCacheKey struct { + name string + useUnicode bool +} + +// nameRegexp returns a compiled regexp for the given name, caching the +// result so repeated autotag calls across many files don't pay the +// compile cost each time. +func (c *Cache) nameRegexp(name string, useUnicode bool) *regexp.Regexp { + if c == nil { + return nameToRegexp(name, useUnicode) + } + + key := regexpCacheKey{name: name, useUnicode: useUnicode} + if r, ok := c.regexpCache.Load(key); ok { + return r.(*regexp.Regexp) + } + r := nameToRegexp(name, useUnicode) + actual, _ := c.regexpCache.LoadOrStore(key, r) + return actual.(*regexp.Regexp) } // getSingleLetterPerformers returns all performers with names that start with single character words. 
@@ -25,7 +362,7 @@ func getSingleLetterPerformers(ctx context.Context, c *Cache, reader models.Perf c = &Cache{} } - if c.singleCharPerformers == nil { + c.performersOnce.Do(func() { pp := -1 performers, _, err := reader.Query(ctx, &models.PerformerFilterType{ Name: &models.StringCriterionInput{ @@ -37,18 +374,18 @@ func getSingleLetterPerformers(ctx context.Context, c *Cache, reader models.Perf }) if err != nil { - return nil, err + c.performersErr = err + return } if len(performers) == 0 { - // make singleWordPerformers not nil c.singleCharPerformers = make([]*models.Performer, 0) } else { c.singleCharPerformers = performers } - } + }) - return c.singleCharPerformers, nil + return c.singleCharPerformers, c.performersErr } // getSingleLetterStudios returns all studios with names that start with single character words. @@ -58,7 +395,7 @@ func getSingleLetterStudios(ctx context.Context, c *Cache, reader models.StudioA c = &Cache{} } - if c.singleCharStudios == nil { + c.studiosOnce.Do(func() { pp := -1 studios, _, err := reader.Query(ctx, &models.StudioFilterType{ Name: &models.StringCriterionInput{ @@ -70,18 +407,18 @@ func getSingleLetterStudios(ctx context.Context, c *Cache, reader models.StudioA }) if err != nil { - return nil, err + c.studiosErr = err + return } if len(studios) == 0 { - // make singleWordStudios not nil c.singleCharStudios = make([]*models.Studio, 0) } else { c.singleCharStudios = studios } - } + }) - return c.singleCharStudios, nil + return c.singleCharStudios, c.studiosErr } // getSingleLetterTags returns all tags with names that start with single character words. 
@@ -91,7 +428,7 @@ func getSingleLetterTags(ctx context.Context, c *Cache, reader models.TagAutoTag c = &Cache{} } - if c.singleCharTags == nil { + c.tagsOnce.Do(func() { pp := -1 tags, _, err := reader.Query(ctx, &models.TagFilterType{ Name: &models.StringCriterionInput{ @@ -111,16 +448,16 @@ func getSingleLetterTags(ctx context.Context, c *Cache, reader models.TagAutoTag }) if err != nil { - return nil, err + c.tagsErr = err + return } if len(tags) == 0 { - // make singleWordTags not nil c.singleCharTags = make([]*models.Tag, 0) } else { c.singleCharTags = tags } - } + }) - return c.singleCharTags, nil + return c.singleCharTags, c.tagsErr } diff --git a/pkg/match/cache_test.go b/pkg/match/cache_test.go new file mode 100644 index 0000000000..9641e923db --- /dev/null +++ b/pkg/match/cache_test.go @@ -0,0 +1,204 @@ +package match + +import ( + "context" + "slices" + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models/mocks" +) + +func TestFirstTwoRunesLower(t *testing.T) { + t.Parallel() + + tests := []struct { + in string + want string + }{ + {"alice smith", "al"}, + {"ALICE", "al"}, + {"Àbc", "àb"}, + {"伏字 name", "伏字"}, + {"ab", "ab"}, + {"a", ""}, // single rune -> no prefix + {"", ""}, // empty -> no prefix + {"X Man", "x "}, // space is preserved in 2-rune prefix + } + + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + t.Parallel() + if got := firstTwoRunesLower(tt.in); got != tt.want { + t.Errorf("firstTwoRunesLower(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestCacheNameRegexpCaches(t *testing.T) { + t.Parallel() + + c := &Cache{} + r1 := c.nameRegexp("alice smith", true) + r2 := c.nameRegexp("alice smith", true) + if r1 != r2 { + t.Error("expected cached regexp to be reused across calls") + } + + // Different useUnicode flag -> different cached regexp. 
+ r3 := c.nameRegexp("alice smith", false) + if r3 == r1 { + t.Error("expected ASCII and unicode variants to be distinct cached entries") + } + + // Nil cache must still return a valid regexp, just uncached. + var nilCache *Cache + if got := nilCache.nameRegexp("alice smith", true); got == nil { + t.Error("nil cache should still return a regexp") + } +} + +func TestPreloadPerformersBuildsIndex(t *testing.T) { + t.Parallel() + + alice := &models.Performer{ID: 1, Name: "Alice Smith"} + bob := &models.Performer{ID: 2, Name: "bob jones"} + xman := &models.Performer{ID: 3, Name: "X Man"} + ignored := &models.Performer{ID: 4, Name: "ignored", IgnoreAutoTag: true} + + performers := []*models.Performer{alice, bob, xman, ignored} + db := mocks.NewDatabase() + primePerformerMock(db.Performer, performers) + + c := &Cache{} + if err := c.PreloadPerformers(context.Background(), db.Performer); err != nil { + t.Fatalf("PreloadPerformers: %v", err) + } + + // allPerformers excludes IgnoreAutoTag=true. + if got := len(c.allPerformers); got != 3 { + t.Errorf("allPerformers len = %d, want 3 (ignored must be excluded)", got) + } + + // Prefix "al" -> alice, "bo" -> bob, "x " -> xman. + assertBucket := func(prefix string, wantIDs []int) { + t.Helper() + var gotIDs []int + for _, p := range c.performerByPrefix[prefix] { + gotIDs = append(gotIDs, p.ID) + } + slices.Sort(gotIDs) + if !slices.Equal(gotIDs, wantIDs) { + t.Errorf("bucket %q = %v, want %v", prefix, gotIDs, wantIDs) + } + } + assertBucket("al", []int{1}) + assertBucket("bo", []int{2}) + assertBucket("x ", []int{3}) + + // Single-letter-first-word performer must also be in alwaysCheck. + var alwaysIDs []int + for _, p := range c.performerAlwaysCheck { + alwaysIDs = append(alwaysIDs, p.ID) + } + if !slices.Equal(alwaysIDs, []int{3}) { + t.Errorf("alwaysCheck IDs = %v, want [3]", alwaysIDs) + } + + // Idempotent: second call is a no-op. 
+ if err := c.PreloadPerformers(context.Background(), db.Performer); err != nil { + t.Fatalf("second PreloadPerformers: %v", err) + } + if got := len(c.allPerformers); got != 3 { + t.Errorf("after idempotent call allPerformers len = %d, want 3", got) + } +} + +func TestPreloadStudiosIndexesAliasPrefixes(t *testing.T) { + t.Parallel() + + // Name "Acme" shares no prefix with alias "Widgets" — both must be + // reachable by their own 2-rune prefix. + s := &models.Studio{ID: 1, Name: "Acme Corp"} + ignored := &models.Studio{ID: 2, Name: "ignored", IgnoreAutoTag: true} + + db := mocks.NewDatabase() + primeStudioMock(db.Studio, []*models.Studio{s, ignored}, map[int][]string{1: {"Widgets Inc"}}) + + c := &Cache{} + if err := c.PreloadStudios(context.Background(), db.Studio); err != nil { + t.Fatalf("PreloadStudios: %v", err) + } + + if got := len(c.allStudios); got != 1 { + t.Errorf("allStudios len = %d, want 1 (ignored must be excluded)", got) + } + + // "ac" bucket has the studio (via name), "wi" bucket has it (via alias). + if len(c.studioByPrefix["ac"]) != 1 || c.studioByPrefix["ac"][0].Studio.ID != 1 { + t.Errorf("bucket 'ac' should hold studio 1, got %+v", c.studioByPrefix["ac"]) + } + if len(c.studioByPrefix["wi"]) != 1 || c.studioByPrefix["wi"][0].Studio.ID != 1 { + t.Errorf("bucket 'wi' should hold studio 1, got %+v", c.studioByPrefix["wi"]) + } +} + +func TestPreloadStudiosDedupsSharedPrefix(t *testing.T) { + t.Parallel() + + // Name and two aliases all share prefix "pr"; the bucket must contain + // the studio exactly once. 
+ s := &models.Studio{ID: 1, Name: "Primary"} + db := mocks.NewDatabase() + primeStudioMock(db.Studio, []*models.Studio{s}, map[int][]string{1: {"Primary Nick", "Primary Alt"}}) + + c := &Cache{} + if err := c.PreloadStudios(context.Background(), db.Studio); err != nil { + t.Fatal(err) + } + + if got := len(c.studioByPrefix["pr"]); got != 1 { + t.Errorf("bucket 'pr' should have 1 entry, got %d", got) + } +} + +func TestPreloadTagsIndexesAliasPrefixes(t *testing.T) { + t.Parallel() + + db := mocks.NewDatabase() + primeTagMock(db.Tag, []*models.Tag{{ID: 1, Name: "documentary"}}, map[int][]string{1: {"film"}}) + + c := &Cache{} + if err := c.PreloadTags(context.Background(), db.Tag); err != nil { + t.Fatal(err) + } + + if len(c.tagByPrefix["do"]) != 1 || c.tagByPrefix["do"][0].Tag.ID != 1 { + t.Errorf("bucket 'do' should hold tag 1") + } + if len(c.tagByPrefix["fi"]) != 1 || c.tagByPrefix["fi"][0].Tag.ID != 1 { + t.Errorf("bucket 'fi' should hold tag 1 (via alias)") + } +} + +func TestCandidateLookupDedupesAcrossPathWords(t *testing.T) { + t.Parallel() + + // A performer with name "alabama" falls in bucket "al". If a path has + // two words that both map to bucket "al" (e.g., from separate tokens), + // the candidate must appear exactly once. 
+ p := &models.Performer{ID: 1, Name: "alabama"} + db := mocks.NewDatabase() + primePerformerMock(db.Performer, []*models.Performer{p}) + + c := &Cache{} + if err := c.PreloadPerformers(context.Background(), db.Performer); err != nil { + t.Fatal(err) + } + + got := c.performerCandidates([]string{"al", "AL", "al"}) // same bucket three times + if len(got) != 1 { + t.Errorf("expected 1 candidate after dedup, got %d: %v", len(got), got) + } +} diff --git a/pkg/match/path.go b/pkg/match/path.go index 1755e70126..85fe097303 100644 --- a/pkg/match/path.go +++ b/pkg/match/path.go @@ -94,6 +94,36 @@ func nameMatchesPath(name, path string) int { return regexpMatchesPath(re, path) } +// pathMatcher holds per-path precomputed values so they aren't recomputed +// for every candidate name. `allASCII` and `strings.ToLower(path)` were +// running once per (candidate, file) pair before; under a worker pool with +// thousands of candidates per file that was the dominant allocation. +type pathMatcher struct { + loweredPath string + useUnicode bool + cache *Cache +} + +func newPathMatcher(path string, cache *Cache) pathMatcher { + return pathMatcher{ + loweredPath: strings.ToLower(path), + useUnicode: !allASCII(path), + cache: cache, + } +} + +// match returns the right-most index where name matches the path, or -1. +// Uses the cache's compiled-regexp table so each name is compiled once per +// autotag run instead of once per file. +func (m *pathMatcher) match(name string) int { + re := m.cache.nameRegexp(name, m.useUnicode) + found := re.FindAllStringIndex(m.loweredPath, -1) + if found == nil { + return -1 + } + return found[len(found)-1][0] +} + // nameToRegexp compiles a regexp pattern to match paths from the given name. // Set useUnicode to true if this regexp is to be used on any strings with unicode characters. 
func nameToRegexp(name string, useUnicode bool) *regexp.Regexp { @@ -141,30 +171,47 @@ func getPerformers(ctx context.Context, words []string, performerReader models.P return append(performers, swPerformers...), nil } +// PathToPerformers returns performers whose name matches the given path. +// +// When the cache has been preloaded via Cache.PreloadPerformers, the full +// non-ignored performer set is already in memory and a 2-rune prefix index +// narrows candidates before regex-matching — this is the path the bulk +// file-based auto-tag job takes. Otherwise (e.g., the built-in scraper, +// which runs on a single scene per request) falls back to a per-call SQL +// prefilter via reader.QueryForAutoTag. func PathToPerformers(ctx context.Context, path string, reader models.PerformerAutoTagQueryer, cache *Cache, trimExt bool) ([]*models.Performer, error) { - words := getPathWords(path, trimExt) - - performers, err := getPerformers(ctx, words, reader, cache) - if err != nil { - return nil, err + var performers []*models.Performer + if cache != nil && cache.allPerformers != nil { + performers = cache.performerCandidates(getPathWords(path, trimExt)) + } else { + words := getPathWords(path, trimExt) + var err error + performers, err = getPerformers(ctx, words, reader, cache) + if err != nil { + return nil, err + } } + pm := newPathMatcher(path, cache) var ret []*models.Performer for _, p := range performers { matches := false - if nameMatchesPath(p.Name, path) != -1 { + if pm.match(p.Name) != -1 { matches = true } // TODO - disabled alias matching until we can get finer - // control over the matching + // control over the matching. To re-enable: + // - uncomment this block (fallback path) + // - have Cache.PreloadPerformers load aliases (e.g. 
via + // loadAllAliases, as PreloadStudios/PreloadTags do) and + // iterate them here in the preloaded path too // if !matches { // if err := p.LoadAliases(ctx, reader); err != nil { // return nil, err // } - // for _, alias := range p.Aliases.List() { - // if nameMatchesPath(alias, path) != -1 { + // if pm.match(alias) != -1 { // matches = true // break // } @@ -193,13 +240,34 @@ func getStudios(ctx context.Context, words []string, reader models.StudioAutoTag return append(studios, swStudios...), nil } -// PathToStudio returns the Studio that matches the given path. -// Where multiple matching studios are found, the one that matches the latest -// position in the path is returned. +// PathToStudio returns the studio whose name or alias matches the given +// path. Where multiple match, the one matching the latest position wins. +// +// See PathToPerformers for the preloaded-vs-fallback behavior. func PathToStudio(ctx context.Context, path string, reader models.StudioAutoTagQueryer, cache *Cache, trimExt bool) (*models.Studio, error) { + pm := newPathMatcher(path, cache) + + if cache != nil && cache.allStudios != nil { + candidates := cache.studioCandidates(getPathWords(path, trimExt)) + var ret *models.Studio + index := -1 + for _, c := range candidates { + if matchIndex := pm.match(c.Studio.Name); matchIndex != -1 && matchIndex > index { + ret = c.Studio + index = matchIndex + } + for _, alias := range c.Aliases { + if matchIndex := pm.match(alias); matchIndex != -1 && matchIndex > index { + ret = c.Studio + index = matchIndex + } + } + } + return ret, nil + } + words := getPathWords(path, trimExt) candidates, err := getStudios(ctx, words, reader, cache) - if err != nil { return nil, err } @@ -207,8 +275,7 @@ func PathToStudio(ctx context.Context, path string, reader models.StudioAutoTagQ var ret *models.Studio index := -1 for _, c := range candidates { - matchIndex := nameMatchesPath(c.Name, path) - if matchIndex != -1 && matchIndex > index { + if matchIndex := 
pm.match(c.Name); matchIndex != -1 && matchIndex > index { ret = c index = matchIndex } @@ -217,10 +284,8 @@ func PathToStudio(ctx context.Context, path string, reader models.StudioAutoTagQ if err != nil { return nil, err } - for _, alias := range aliases { - matchIndex = nameMatchesPath(alias, path) - if matchIndex != -1 && matchIndex > index { + if matchIndex := pm.match(alias); matchIndex != -1 && matchIndex > index { ret = c index = matchIndex } @@ -244,10 +309,32 @@ func getTags(ctx context.Context, words []string, reader models.TagAutoTagQuerye return append(tags, swTags...), nil } +// PathToTags returns tags whose name or alias matches the given path. +// +// See PathToPerformers for the preloaded-vs-fallback behavior. func PathToTags(ctx context.Context, path string, reader models.TagAutoTagQueryer, cache *Cache, trimExt bool) ([]*models.Tag, error) { + pm := newPathMatcher(path, cache) + + if cache != nil && cache.allTags != nil { + candidates := cache.tagCandidates(getPathWords(path, trimExt)) + var ret []*models.Tag + for _, c := range candidates { + if pm.match(c.Tag.Name) != -1 { + ret = append(ret, c.Tag) + continue + } + for _, alias := range c.Aliases { + if pm.match(alias) != -1 { + ret = append(ret, c.Tag) + break + } + } + } + return ret, nil + } + words := getPathWords(path, trimExt) tags, err := getTags(ctx, words, reader, cache) - if err != nil { return nil, err } @@ -255,23 +342,21 @@ func PathToTags(ctx context.Context, path string, reader models.TagAutoTagQuerye var ret []*models.Tag for _, t := range tags { matches := false - if nameMatchesPath(t.Name, path) != -1 { + if pm.match(t.Name) != -1 { matches = true } - if !matches { aliases, err := reader.GetAliases(ctx, t.ID) if err != nil { return nil, err } for _, alias := range aliases { - if nameMatchesPath(alias, path) != -1 { + if pm.match(alias) != -1 { matches = true break } } } - if matches { ret = append(ret, t) } diff --git a/pkg/match/path_semantic_test.go 
b/pkg/match/path_semantic_test.go new file mode 100644 index 0000000000..0c734f668c --- /dev/null +++ b/pkg/match/path_semantic_test.go @@ -0,0 +1,426 @@ +package match + +import ( + "context" + "slices" + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models/mocks" + "github.com/stretchr/testify/mock" +) + +// Path-matching semantic tests that lock in the behavior of +// PathTo{Performers,Studio,Tags} via the generated testify mocks in +// pkg/models/mocks. These are the regression guard when the candidate- +// lookup strategy changes (e.g., replacing the SQL prefilter with an +// in-memory matcher): each case runs against both cache=nil and a +// preloaded cache, asserting identical output. + +// --- mock setup helpers --- + +// preloadFilter matches the filter PreloadX passes: IgnoreAutoTag=false. +// singleLetterFilter matches the filter the single-letter-cache path +// passes: a regex in Name. Keeping them disjoint means testify will +// route each Query call to the right stub regardless of declaration +// order. 
+func performerPreloadFilter() interface{} { + return mock.MatchedBy(func(f *models.PerformerFilterType) bool { + return f != nil && f.IgnoreAutoTag != nil && !*f.IgnoreAutoTag + }) +} +func performerSingleLetterFilter() interface{} { + return mock.MatchedBy(func(f *models.PerformerFilterType) bool { + return f != nil && f.Name != nil + }) +} +func studioPreloadFilter() interface{} { + return mock.MatchedBy(func(f *models.StudioFilterType) bool { + return f != nil && f.IgnoreAutoTag != nil && !*f.IgnoreAutoTag + }) +} +func studioSingleLetterFilter() interface{} { + return mock.MatchedBy(func(f *models.StudioFilterType) bool { + return f != nil && f.Name != nil + }) +} +func tagPreloadFilter() interface{} { + return mock.MatchedBy(func(f *models.TagFilterType) bool { + return f != nil && f.IgnoreAutoTag != nil && !*f.IgnoreAutoTag + }) +} +func tagSingleLetterFilter() interface{} { + return mock.MatchedBy(func(f *models.TagFilterType) bool { + return f != nil && f.Name != nil + }) +} + +// primePerformerMock sets up a PerformerReaderWriter to serve both the +// no-preload path (QueryForAutoTag returns all non-ignored; single-letter +// Query returns nothing) and the preload path (Query with IgnoreAutoTag +// filter returns all non-ignored). All expectations are .Maybe() because +// which ones fire depends on whether the test passes a cache. 
func primePerformerMock(m *mocks.PerformerReaderWriter, performers []*models.Performer) {
	// mirror DB behavior: auto-tag queries never return ignored performers
	var nonIgnored []*models.Performer
	for _, p := range performers {
		if !p.IgnoreAutoTag {
			nonIgnored = append(nonIgnored, p)
		}
	}
	m.On("QueryForAutoTag", mock.Anything, mock.Anything).Return(nonIgnored, nil).Maybe()
	m.On("Query", mock.Anything, performerPreloadFilter(), mock.Anything).Return(nonIgnored, len(nonIgnored), nil).Maybe()
	m.On("Query", mock.Anything, performerSingleLetterFilter(), mock.Anything).Return(nil, 0, nil).Maybe()
}

// primeStudioMock is the studio counterpart of primePerformerMock; it also
// stubs per-studio GetAliases from the provided aliases map.
func primeStudioMock(m *mocks.StudioReaderWriter, studios []*models.Studio, aliases map[int][]string) {
	var nonIgnored []*models.Studio
	for _, s := range studios {
		if !s.IgnoreAutoTag {
			nonIgnored = append(nonIgnored, s)
		}
	}
	m.On("QueryForAutoTag", mock.Anything, mock.Anything).Return(nonIgnored, nil).Maybe()
	m.On("Query", mock.Anything, studioPreloadFilter(), mock.Anything).Return(nonIgnored, len(nonIgnored), nil).Maybe()
	m.On("Query", mock.Anything, studioSingleLetterFilter(), mock.Anything).Return(nil, 0, nil).Maybe()
	for _, s := range studios {
		m.On("GetAliases", mock.Anything, s.ID).Return(aliases[s.ID], nil).Maybe()
	}
}

// primeTagMock is the tag counterpart of primePerformerMock; it also stubs
// per-tag GetAliases from the provided aliases map.
func primeTagMock(m *mocks.TagReaderWriter, tags []*models.Tag, aliases map[int][]string) {
	var nonIgnored []*models.Tag
	for _, t := range tags {
		if !t.IgnoreAutoTag {
			nonIgnored = append(nonIgnored, t)
		}
	}
	m.On("QueryForAutoTag", mock.Anything, mock.Anything).Return(nonIgnored, nil).Maybe()
	m.On("Query", mock.Anything, tagPreloadFilter(), mock.Anything).Return(nonIgnored, len(nonIgnored), nil).Maybe()
	m.On("Query", mock.Anything, tagSingleLetterFilter(), mock.Anything).Return(nil, 0, nil).Maybe()
	for _, t := range tags {
		m.On("GetAliases", mock.Anything, t.ID).Return(aliases[t.ID], nil).Maybe()
	}
}

// --- helpers ---

// perfIDs extracts sorted performer IDs for order-independent comparison.
func perfIDs(ps []*models.Performer) []int {
	ids := make([]int, 0, len(ps))
	for _, p := range ps {
		ids = append(ids, p.ID)
	}
	slices.Sort(ids)
	return ids
}

// tagIDs extracts sorted tag IDs for order-independent comparison.
func tagIDs(ts []*models.Tag) []int {
	ids := make([]int, 0, len(ts))
	for _, t := range ts {
		ids = append(ids, t.ID)
	}
	slices.Sort(ids)
	return ids
}

// --- tests ---

// TestPathToPerformers_Semantics locks in name-vs-path matching semantics,
// running each case against both the no-preload (SQL prefilter) and
// preloaded (in-memory index) strategies.
func TestPathToPerformers_Semantics(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	alice := &models.Performer{ID: 1, Name: "alice smith"}
	bob := &models.Performer{ID: 2, Name: "bob jones"}
	unicodeP := &models.Performer{ID: 3, Name: "伏字"}
	ignored := &models.Performer{ID: 4, Name: "ignored person", IgnoreAutoTag: true}
	substr := &models.Performer{ID: 5, Name: "ali"} // substring of "alice" - should NOT match "alice smith.jpg"

	performers := []*models.Performer{alice, bob, unicodeP, ignored, substr}
	db := mocks.NewDatabase()
	primePerformerMock(db.Performer, performers)

	tests := []struct {
		name    string
		path    string
		wantIDs []int
	}{
		{"plain name match", "/media/alice smith.jpg", []int{1}},
		{"separator variants", "/media/alice.smith.jpg", []int{1}},
		{"separator variants 2", "/media/alice_smith.jpg", []int{1}},
		{"multiple matches", "/media/alice smith and bob jones.jpg", []int{1, 2}},
		{"case insensitive", "/media/ALICE SMITH.jpg", []int{1}},
		{"unicode", "/media/伏字.jpg", []int{3}},
		{"ignore_auto_tag skipped", "/media/ignored person.jpg", nil},
		{"no substring match", "/media/alicent.jpg", nil},
		{"short name does NOT match inside longer", "/media/alice smith.jpg", []int{1}}, // 'ali' should not match
		{"short name matches exact", "/media/ali.jpg", []int{5}},
		{"no match", "/media/nobody here.jpg", nil},
	}

	for _, tt := range tests {
		t.Run(tt.name+"/no-preload", func(t *testing.T) {
			got, err := PathToPerformers(ctx, tt.path, db.Performer, nil, false)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if gotIDs := perfIDs(got); !slices.Equal(gotIDs, tt.wantIDs) {
				t.Errorf("got %v, want %v", gotIDs, tt.wantIDs)
			}
		})
		t.Run(tt.name+"/preloaded", func(t *testing.T) {
			cache := &Cache{}
			if err := cache.PreloadPerformers(ctx, db.Performer); err != nil {
				t.Fatalf("preload: %v", err)
			}
			got, err := PathToPerformers(ctx, tt.path, db.Performer, cache, false)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if gotIDs := perfIDs(got); !slices.Equal(gotIDs, tt.wantIDs) {
				t.Errorf("got %v, want %v", gotIDs, tt.wantIDs)
			}
		})
	}
}

// TestPathToStudio_Semantics: studio matching semantics, including alias
// matching and rightmost-position-wins resolution.
func TestPathToStudio_Semantics(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	s1 := &models.Studio{ID: 1, Name: "first studio"}
	s2 := &models.Studio{ID: 2, Name: "second"}
	s3 := &models.Studio{ID: 3, Name: "third", IgnoreAutoTag: true}

	studios := []*models.Studio{s1, s2, s3}
	aliases := map[int][]string{2: {"second alias"}}
	db := mocks.NewDatabase()
	primeStudioMock(db.Studio, studios, aliases)

	tests := []struct {
		name   string
		path   string
		wantID int // 0 == no match
	}{
		{"primary name", "/first studio/scene.mp4", 1},
		{"alias matches", "/second alias/scene.mp4", 2},
		{"ignore_auto_tag studio skipped", "/third/scene.mp4", 0},
		{"multiple matches - rightmost wins", "/first studio/second/scene.mp4", 2},
		{"no match", "/unknown/scene.mp4", 0},
	}

	runCase := func(t *testing.T, path string, wantID int, cache *Cache) {
		got, err := PathToStudio(ctx, path, db.Studio, cache, false)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		var gotID int
		if got != nil {
			gotID = got.ID
		}
		if gotID != wantID {
			t.Errorf("got %d, want %d", gotID, wantID)
		}
	}

	for _, tt := range tests {
		t.Run(tt.name+"/no-preload", func(t *testing.T) {
			runCase(t, tt.path, tt.wantID, nil)
		})
		t.Run(tt.name+"/preloaded", func(t *testing.T) {
			cache := &Cache{}
			if err := cache.PreloadStudios(ctx, db.Studio); err != nil {
				t.Fatalf("preload: %v", err)
			}
			runCase(t, tt.path, tt.wantID, cache)
		})
	}
}

// TestPathToTags_Semantics: tag matching semantics, including alias
// matching.
func TestPathToTags_Semantics(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	t1 := &models.Tag{ID: 1, Name: "anime"}
	t2 := &models.Tag{ID: 2, Name: "docs"}
	t3 := &models.Tag{ID: 3, Name: "skip me", IgnoreAutoTag: true}

	tags := []*models.Tag{t1, t2, t3}
	aliases := map[int][]string{2: {"documentary"}}
	db := mocks.NewDatabase()
	primeTagMock(db.Tag, tags, aliases)

	tests := []struct {
		name    string
		path    string
		wantIDs []int
	}{
		{"name match", "/media/anime/x.mp4", []int{1}},
		{"alias match", "/media/documentary/x.mp4", []int{2}},
		{"multiple matches", "/media/anime-documentary/x.mp4", []int{1, 2}},
		{"ignore_auto_tag skipped", "/media/skip me/x.mp4", nil},
		{"no match", "/media/comedy/x.mp4", nil},
	}

	runCase := func(t *testing.T, path string, wantIDs []int, cache *Cache) {
		got, err := PathToTags(ctx, path, db.Tag, cache, false)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if gotIDs := tagIDs(got); !slices.Equal(gotIDs, wantIDs) {
			t.Errorf("got %v, want %v", gotIDs, wantIDs)
		}
	}

	for _, tt := range tests {
		t.Run(tt.name+"/no-preload", func(t *testing.T) {
			runCase(t, tt.path, tt.wantIDs, nil)
		})
		t.Run(tt.name+"/preloaded", func(t *testing.T) {
			cache := &Cache{}
			if err := cache.PreloadTags(ctx, db.Tag); err != nil {
				t.Fatalf("preload: %v", err)
			}
			runCase(t, tt.path, tt.wantIDs, cache)
		})
	}
}

// Performer whose name starts with a single-letter word (e.g., "X Man")
// can't be reached via 2-rune prefix lookup (getPathWords drops 1-char
// words). The preload must put them in the alwaysCheck list so they're
// still regex-tested.
+func TestPathToPerformers_SingleLetterFirstWord(t *testing.T) { + t.Parallel() + ctx := context.Background() + xman := &models.Performer{ID: 1, Name: "X Man"} + other := &models.Performer{ID: 2, Name: "alice smith"} + + db := mocks.NewDatabase() + primePerformerMock(db.Performer, []*models.Performer{xman, other}) + + cache := &Cache{} + if err := cache.PreloadPerformers(ctx, db.Performer); err != nil { + t.Fatal(err) + } + + got, err := PathToPerformers(ctx, "/media/X Man.mp4", db.Performer, cache, false) + if err != nil { + t.Fatal(err) + } + if ids := perfIDs(got); !slices.Equal(ids, []int{1}) { + t.Errorf("expected [1], got %v", ids) + } +} + +// A studio whose name shares no prefix with its aliases must be reachable +// by alias prefix. "Acme Corp" with alias "Widgets Inc" must match a path +// containing "widgets inc". +func TestPathToStudio_AliasPrefixDistinctFromName(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := &models.Studio{ID: 1, Name: "Acme Corp"} + + db := mocks.NewDatabase() + primeStudioMock(db.Studio, []*models.Studio{s}, map[int][]string{1: {"Widgets Inc"}}) + + cache := &Cache{} + if err := cache.PreloadStudios(ctx, db.Studio); err != nil { + t.Fatal(err) + } + + got, err := PathToStudio(ctx, "/media/Widgets Inc/scene.mp4", db.Studio, cache, false) + if err != nil { + t.Fatal(err) + } + if got == nil || got.ID != 1 { + t.Errorf("expected studio 1, got %v", got) + } +} + +// Same for tags. 
+func TestPathToTags_AliasPrefixDistinctFromName(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := mocks.NewDatabase() + primeTagMock(db.Tag, []*models.Tag{{ID: 1, Name: "documentary"}}, map[int][]string{1: {"film"}}) + + cache := &Cache{} + if err := cache.PreloadTags(ctx, db.Tag); err != nil { + t.Fatal(err) + } + + got, err := PathToTags(ctx, "/media/film/x.mp4", db.Tag, cache, false) + if err != nil { + t.Fatal(err) + } + if ids := tagIDs(got); !slices.Equal(ids, []int{1}) { + t.Errorf("expected [1], got %v", ids) + } +} + +// Two aliases on the same studio with different prefixes should each +// reach the studio. Index bucket must dedupe inside the bucket. +func TestPathToStudio_MultipleAliasesDedup(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := &models.Studio{ID: 1, Name: "Primary Name"} + + db := mocks.NewDatabase() + primeStudioMock(db.Studio, []*models.Studio{s}, map[int][]string{1: {"Primary Nickname", "Primary Alt"}}) + + cache := &Cache{} + if err := cache.PreloadStudios(ctx, db.Studio); err != nil { + t.Fatal(err) + } + // Studio "Primary Name" and both aliases all share prefix "pr". + // The bucket should contain it exactly once. + if got := len(cache.studioByPrefix["pr"]); got != 1 { + t.Errorf("bucket 'pr' should have 1 entry, got %d", got) + } +} + +// Equivalence test: the function must return the same result regardless of +// whether a match.Cache is passed in. This is the invariant that any +// caching-based optimization must preserve. 
+func TestPathToPerformers_CachedVsUncached(t *testing.T) { + t.Parallel() + ctx := context.Background() + + perfs := []*models.Performer{ + {ID: 1, Name: "alice smith"}, + {ID: 2, Name: "bob jones"}, + {ID: 3, Name: "charlie"}, + {ID: 4, Name: "david wong"}, + } + db := mocks.NewDatabase() + primePerformerMock(db.Performer, perfs) + + paths := []string{ + "/media/alice smith.jpg", + "/media/bob_jones.jpg", + "/media/alice smith and charlie.jpg", + "/media/nobody.jpg", + "/media/alice smith.jpg", // repeat: cached regex should not change outcome + } + + var noCache, withCache [][]int + cache := &Cache{} + for _, p := range paths { + uc, err := PathToPerformers(ctx, p, db.Performer, nil, false) + if err != nil { + t.Fatal(err) + } + wc, err := PathToPerformers(ctx, p, db.Performer, cache, false) + if err != nil { + t.Fatal(err) + } + noCache = append(noCache, perfIDs(uc)) + withCache = append(withCache, perfIDs(wc)) + } + + for i := range paths { + if !slices.Equal(noCache[i], withCache[i]) { + t.Errorf("path %q: no-cache %v vs cached %v", paths[i], noCache[i], withCache[i]) + } + } +} diff --git a/pkg/models/find_filter.go b/pkg/models/find_filter.go index 9934a9ea9c..49e27daea7 100644 --- a/pkg/models/find_filter.go +++ b/pkg/models/find_filter.go @@ -127,3 +127,17 @@ func BatchFindFilter(batchSize int) *FindFilterType { Page: &page, } } + +// KeysetFindFilter returns a FindFilterType suitable for id-ordered keyset +// pagination. Callers pair it with a WHERE id > lastID clause to iterate +// large tables without paying the O(offset) scan that LIMIT/OFFSET pays +// on later pages. 
+func KeysetFindFilter(batchSize int) *FindFilterType { + sort := "id" + sortDir := SortDirectionEnumAsc + return &FindFilterType{ + PerPage: &batchSize, + Sort: &sort, + Direction: &sortDir, + } +} diff --git a/pkg/models/relationships.go b/pkg/models/relationships.go index 5495f858b1..494c664fa8 100644 --- a/pkg/models/relationships.go +++ b/pkg/models/relationships.go @@ -63,6 +63,14 @@ type AliasLoader interface { GetAliases(ctx context.Context, relatedID int) ([]string, error) } +// AllAliasLoader is an optional bulk variant of AliasLoader: it returns +// aliases for every id in one query, letting callers that need aliases for +// many entities skip the N+1 per-id lookups. Implementations are free to +// add this alongside AliasLoader; callers use it via a type assertion. +type AllAliasLoader interface { + GetAllAliases(ctx context.Context) (map[int][]string, error) +} + type URLLoader interface { GetURLs(ctx context.Context, relatedID int) ([]string, error) } diff --git a/pkg/sqlite/repository.go b/pkg/sqlite/repository.go index 18d501e3aa..4eb053e2a1 100644 --- a/pkg/sqlite/repository.go +++ b/pkg/sqlite/repository.go @@ -378,6 +378,23 @@ func (r *stringRepository) get(ctx context.Context, id int) ([]string, error) { return ret, err } +// getAll returns every (id, value) pair in the join table, grouped by id. +// Used to avoid N+1 lookups when callers need values for many ids at once. 
+func (r *stringRepository) getAll(ctx context.Context) (map[int][]string, error) { + query := fmt.Sprintf("SELECT %s, %s from %s", r.idColumn, r.stringColumn, r.tableName) + ret := make(map[int][]string) + err := r.queryFunc(ctx, query, nil, false, func(rows *sqlx.Rows) error { + var id int + var out string + if err := rows.Scan(&id, &out); err != nil { + return err + } + ret[id] = append(ret[id], out) + return nil + }) + return ret, err +} + func (r *stringRepository) insert(ctx context.Context, id int, s string) (sql.Result, error) { stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.stringColumn) return dbWrapper.Exec(ctx, stmt, id, s) diff --git a/pkg/sqlite/studio.go b/pkg/sqlite/studio.go index 87f9059359..a39ddd016b 100644 --- a/pkg/sqlite/studio.go +++ b/pkg/sqlite/studio.go @@ -742,6 +742,12 @@ func (qb *StudioStore) GetAliases(ctx context.Context, studioID int) ([]string, return studiosAliasesTableMgr.get(ctx, studioID) } +// GetAllAliases returns a map of studio id to its aliases. Lets callers that +// need aliases for many studios avoid N+1 per-id lookups. +func (qb *StudioStore) GetAllAliases(ctx context.Context) (map[int][]string, error) { + return studiosAliasesTableMgr.getAll(ctx) +} + func (qb *StudioStore) GetURLs(ctx context.Context, studioID int) ([]string, error) { return studiosURLsTableMgr.get(ctx, studioID) } diff --git a/pkg/sqlite/table.go b/pkg/sqlite/table.go index 3f8dfb70f8..e785230a4c 100644 --- a/pkg/sqlite/table.go +++ b/pkg/sqlite/table.go @@ -423,6 +423,28 @@ func (t *stringTable) get(ctx context.Context, id int) ([]string, error) { return ret, nil } +// getAll returns every (id, value) pair in the join table, grouped by id. +// Used to avoid N+1 lookups when callers need values for many ids at once. 
+func (t *stringTable) getAll(ctx context.Context) (map[int][]string, error) { + q := dialect.Select(t.idColumn, t.stringColumn).From(t.table.table) + + const single = false + ret := make(map[int][]string) + if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error { + var id int + var v string + if err := rows.Scan(&id, &v); err != nil { + return err + } + ret[id] = append(ret[id], v) + return nil + }); err != nil { + return nil, fmt.Errorf("getting all values from %s: %w", t.table.table.GetTable(), err) + } + + return ret, nil +} + func (t *stringTable) insertJoin(ctx context.Context, id int, v string) (sql.Result, error) { q := dialect.Insert(t.table.table).Cols(t.idColumn.GetCol(), t.stringColumn.GetCol()).Vals( goqu.Vals{id, v}, diff --git a/pkg/sqlite/tag.go b/pkg/sqlite/tag.go index 4ee69cc460..76030403bb 100644 --- a/pkg/sqlite/tag.go +++ b/pkg/sqlite/tag.go @@ -940,6 +940,12 @@ func (qb *TagStore) GetAliases(ctx context.Context, tagID int) ([]string, error) return tagRepository.aliases.get(ctx, tagID) } +// GetAllAliases returns a map of tag id to its aliases. Lets callers that +// need aliases for many tags avoid N+1 per-id lookups. +func (qb *TagStore) GetAllAliases(ctx context.Context) (map[int][]string, error) { + return tagRepository.aliases.getAll(ctx) +} + func (qb *TagStore) UpdateAliases(ctx context.Context, tagID int, aliases []string) error { return tagRepository.aliases.replace(ctx, tagID, aliases) }