Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions internal/autotag/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)

type GalleryFinderUpdater interface {
Expand Down Expand Up @@ -43,9 +44,11 @@ func getGalleryFileTagger(s *models.Gallery, cache *match.Cache) tagger {
}
}

// GalleryPerformers tags the provided gallery with performers whose name matches the gallery's path.
func GalleryPerformers(ctx context.Context, s *models.Gallery, rw GalleryPerformerUpdater, performerReader models.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getGalleryFileTagger(s, cache)
// GalleryPerformersAtPath tags the provided gallery with performers whose
// name matches the gallery's path. A fresh write txn is opened only when a
// match is applied.
func (tagger *Tagger) GalleryPerformersAtPath(ctx context.Context, s *models.Gallery, rw GalleryPerformerUpdater, performerReader models.PerformerAutoTagQueryer) error {
t := getGalleryFileTagger(s, tagger.Cache)

return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadPerformerIDs(ctx, rw); err != nil {
Expand All @@ -57,33 +60,45 @@ func GalleryPerformers(ctx context.Context, s *models.Gallery, rw GalleryPerform
return false, nil
}

if err := gallery.AddPerformer(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return gallery.AddPerformer(ctx, rw, s, otherID)
}); err != nil {
return false, err
}

return true, nil
})
}

// GalleryStudios tags the provided gallery with the first studio whose name matches the gallery's path.
// GalleryStudiosAtPath tags the provided gallery with the first studio whose
// name matches the gallery's path.
//
// Gallerys will not be tagged if studio is already set.
func GalleryStudios(ctx context.Context, s *models.Gallery, rw GalleryFinderUpdater, studioReader models.StudioAutoTagQueryer, cache *match.Cache) error {
// Galleries will not be tagged if studio is already set.
func (tagger *Tagger) GalleryStudiosAtPath(ctx context.Context, s *models.Gallery, rw GalleryFinderUpdater, studioReader models.StudioAutoTagQueryer) error {
if s.StudioID != nil {
// don't modify
return nil
}

t := getGalleryFileTagger(s, cache)
t := getGalleryFileTagger(s, tagger.Cache)

return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addGalleryStudio(ctx, rw, s, otherID)
var added bool
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
var err error
added, err = addGalleryStudio(ctx, rw, s, otherID)
return err
}); err != nil {
return false, err
}
return added, nil
})
}

// GalleryTags tags the provided gallery with tags whose name matches the gallery's path.
func GalleryTags(ctx context.Context, s *models.Gallery, rw GalleryTagUpdater, tagReader models.TagAutoTagQueryer, cache *match.Cache) error {
t := getGalleryFileTagger(s, cache)
// GalleryTagsAtPath tags the provided gallery with tags whose name matches
// the gallery's path.
func (tagger *Tagger) GalleryTagsAtPath(ctx context.Context, s *models.Gallery, rw GalleryTagUpdater, tagReader models.TagAutoTagQueryer) error {
t := getGalleryFileTagger(s, tagger.Cache)

return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadTagIDs(ctx, rw); err != nil {
Expand All @@ -95,7 +110,9 @@ func GalleryTags(ctx context.Context, s *models.Gallery, rw GalleryTagUpdater, t
return false, nil
}

if err := gallery.AddTag(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return gallery.AddTag(ctx, rw, s, otherID)
}); err != nil {
return false, err
}

Expand Down
15 changes: 9 additions & 6 deletions internal/autotag/gallery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,16 @@ func TestGalleryPerformers(t *testing.T) {

return galleryPartialsEqual(got, expected)
})
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
}

gallery := models.Gallery{
ID: galleryID,
Path: test.Path,
PerformerIDs: models.NewRelatedIDs([]int{}),
}
err := GalleryPerformers(testCtx, &gallery, db.Gallery, db.Performer, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.GalleryPerformersAtPath(testCtx, &gallery, db.Gallery, db.Performer)

assert.Nil(err)
db.AssertExpectations(t)
Expand Down Expand Up @@ -114,14 +115,15 @@ func TestGalleryStudios(t *testing.T) {

return galleryPartialsEqual(got, expected)
})
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
}

gallery := models.Gallery{
ID: galleryID,
Path: test.Path,
}
err := GalleryStudios(testCtx, &gallery, db.Gallery, db.Studio, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.GalleryStudiosAtPath(testCtx, &gallery, db.Gallery, db.Studio)

assert.Nil(err)
db.AssertExpectations(t)
Expand Down Expand Up @@ -189,15 +191,16 @@ func TestGalleryTags(t *testing.T) {

return galleryPartialsEqual(got, expected)
})
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
}

gallery := models.Gallery{
ID: galleryID,
Path: test.Path,
TagIDs: models.NewRelatedIDs([]int{}),
}
err := GalleryTags(testCtx, &gallery, db.Gallery, db.Tag, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.GalleryTagsAtPath(testCtx, &gallery, db.Gallery, db.Tag)

assert.Nil(err)
db.AssertExpectations(t)
Expand Down
41 changes: 29 additions & 12 deletions internal/autotag/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)

type ImageFinderUpdater interface {
Expand Down Expand Up @@ -34,9 +35,11 @@ func getImageFileTagger(s *models.Image, cache *match.Cache) tagger {
}
}

// ImagePerformers tags the provided image with performers whose name matches the image's path.
func ImagePerformers(ctx context.Context, s *models.Image, rw ImagePerformerUpdater, performerReader models.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getImageFileTagger(s, cache)
// ImagePerformersAtPath tags the provided image with performers whose name
// matches the image's path. A fresh write txn is opened only when a match is
// applied.
func (tagger *Tagger) ImagePerformersAtPath(ctx context.Context, s *models.Image, rw ImagePerformerUpdater, performerReader models.PerformerAutoTagQueryer) error {
t := getImageFileTagger(s, tagger.Cache)

return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadPerformerIDs(ctx, rw); err != nil {
Expand All @@ -48,33 +51,45 @@ func ImagePerformers(ctx context.Context, s *models.Image, rw ImagePerformerUpda
return false, nil
}

if err := image.AddPerformer(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return image.AddPerformer(ctx, rw, s, otherID)
}); err != nil {
return false, err
}

return true, nil
})
}

// ImageStudios tags the provided image with the first studio whose name matches the image's path.
// ImageStudiosAtPath tags the provided image with the first studio whose
// name matches the image's path.
//
// Images will not be tagged if studio is already set.
func ImageStudios(ctx context.Context, s *models.Image, rw ImageFinderUpdater, studioReader models.StudioAutoTagQueryer, cache *match.Cache) error {
func (tagger *Tagger) ImageStudiosAtPath(ctx context.Context, s *models.Image, rw ImageFinderUpdater, studioReader models.StudioAutoTagQueryer) error {
if s.StudioID != nil {
// don't modify
return nil
}

t := getImageFileTagger(s, cache)
t := getImageFileTagger(s, tagger.Cache)

return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addImageStudio(ctx, rw, s, otherID)
var added bool
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
var err error
added, err = addImageStudio(ctx, rw, s, otherID)
return err
}); err != nil {
return false, err
}
return added, nil
})
}

// ImageTags tags the provided image with tags whose name matches the image's path.
func ImageTags(ctx context.Context, s *models.Image, rw ImageTagUpdater, tagReader models.TagAutoTagQueryer, cache *match.Cache) error {
t := getImageFileTagger(s, cache)
// ImageTagsAtPath tags the provided image with tags whose name matches the
// image's path.
func (tagger *Tagger) ImageTagsAtPath(ctx context.Context, s *models.Image, rw ImageTagUpdater, tagReader models.TagAutoTagQueryer) error {
t := getImageFileTagger(s, tagger.Cache)

return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadTagIDs(ctx, rw); err != nil {
Expand All @@ -86,7 +101,9 @@ func ImageTags(ctx context.Context, s *models.Image, rw ImageTagUpdater, tagRead
return false, nil
}

if err := image.AddTag(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return image.AddTag(ctx, rw, s, otherID)
}); err != nil {
return false, err
}

Expand Down
15 changes: 9 additions & 6 deletions internal/autotag/image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,16 @@ func TestImagePerformers(t *testing.T) {

return imagePartialsEqual(got, expected)
})
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
}

image := models.Image{
ID: imageID,
Path: test.Path,
PerformerIDs: models.NewRelatedIDs([]int{}),
}
err := ImagePerformers(testCtx, &image, db.Image, db.Performer, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.ImagePerformersAtPath(testCtx, &image, db.Image, db.Performer)

assert.Nil(err)
db.AssertExpectations(t)
Expand Down Expand Up @@ -111,14 +112,15 @@ func TestImageStudios(t *testing.T) {

return imagePartialsEqual(got, expected)
})
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
}

image := models.Image{
ID: imageID,
Path: test.Path,
}
err := ImageStudios(testCtx, &image, db.Image, db.Studio, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.ImageStudiosAtPath(testCtx, &image, db.Image, db.Studio)

assert.Nil(err)
db.AssertExpectations(t)
Expand Down Expand Up @@ -186,15 +188,16 @@ func TestImageTags(t *testing.T) {

return imagePartialsEqual(got, expected)
})
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
}

image := models.Image{
ID: imageID,
Path: test.Path,
TagIDs: models.NewRelatedIDs([]int{}),
}
err := ImageTags(testCtx, &image, db.Image, db.Tag, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.ImageTagsAtPath(testCtx, &image, db.Image, db.Tag)

assert.Nil(err)
db.AssertExpectations(t)
Expand Down
42 changes: 30 additions & 12 deletions internal/autotag/scene.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/txn"
)

type SceneFinderUpdater interface {
Expand Down Expand Up @@ -34,9 +35,12 @@ func getSceneFileTagger(s *models.Scene, cache *match.Cache) tagger {
}
}

// ScenePerformers tags the provided scene with performers whose name matches the scene's path.
func ScenePerformers(ctx context.Context, s *models.Scene, rw ScenePerformerUpdater, performerReader models.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getSceneFileTagger(s, cache)
// ScenePerformersAtPath tags the provided scene with performers whose name
// matches the scene's path. The match phase runs using the current context
// (no outer write txn needed); a fresh write txn is opened only when a match
// is applied.
func (tagger *Tagger) ScenePerformersAtPath(ctx context.Context, s *models.Scene, rw ScenePerformerUpdater, performerReader models.PerformerAutoTagQueryer) error {
t := getSceneFileTagger(s, tagger.Cache)

return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadPerformerIDs(ctx, rw); err != nil {
Expand All @@ -48,33 +52,45 @@ func ScenePerformers(ctx context.Context, s *models.Scene, rw ScenePerformerUpda
return false, nil
}

if err := scene.AddPerformer(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return scene.AddPerformer(ctx, rw, s, otherID)
}); err != nil {
return false, err
}

return true, nil
})
}

// SceneStudios tags the provided scene with the first studio whose name matches the scene's path.
// SceneStudiosAtPath tags the provided scene with the first studio whose name
// matches the scene's path.
//
// Scenes will not be tagged if studio is already set.
func SceneStudios(ctx context.Context, s *models.Scene, rw SceneFinderUpdater, studioReader models.StudioAutoTagQueryer, cache *match.Cache) error {
func (tagger *Tagger) SceneStudiosAtPath(ctx context.Context, s *models.Scene, rw SceneFinderUpdater, studioReader models.StudioAutoTagQueryer) error {
if s.StudioID != nil {
// don't modify
return nil
}

t := getSceneFileTagger(s, cache)
t := getSceneFileTagger(s, tagger.Cache)

return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addSceneStudio(ctx, rw, s, otherID)
var added bool
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
var err error
added, err = addSceneStudio(ctx, rw, s, otherID)
return err
}); err != nil {
return false, err
}
return added, nil
})
}

// SceneTags tags the provided scene with tags whose name matches the scene's path.
func SceneTags(ctx context.Context, s *models.Scene, rw SceneTagUpdater, tagReader models.TagAutoTagQueryer, cache *match.Cache) error {
t := getSceneFileTagger(s, cache)
// SceneTagsAtPath tags the provided scene with tags whose name matches the
// scene's path.
func (tagger *Tagger) SceneTagsAtPath(ctx context.Context, s *models.Scene, rw SceneTagUpdater, tagReader models.TagAutoTagQueryer) error {
t := getSceneFileTagger(s, tagger.Cache)

return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadTagIDs(ctx, rw); err != nil {
Expand All @@ -86,7 +102,9 @@ func SceneTags(ctx context.Context, s *models.Scene, rw SceneTagUpdater, tagRead
return false, nil
}

if err := scene.AddTag(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return scene.AddTag(ctx, rw, s, otherID)
}); err != nil {
return false, err
}

Expand Down
Loading
Loading