Compare commits

..

2 Commits

Author SHA1 Message Date
DogmaDragon
7c97a140f9 Merge branch 'develop' of https://github.com/stashapp/stash into docs-patchable-components 2026-01-22 13:01:43 +02:00
DogmaDragon
70ad014ac4 docs: add missing patchable components and library 2026-01-22 12:53:00 +02:00
177 changed files with 3100 additions and 6634 deletions

View File

@@ -5,39 +5,20 @@ import (
"fmt"
"os"
"os/exec"
"path/filepath"
flag "github.com/spf13/pflag"
"github.com/stashapp/stash/pkg/ffmpeg"
"github.com/stashapp/stash/pkg/hash/imagephash"
"github.com/stashapp/stash/pkg/hash/videophash"
"github.com/stashapp/stash/pkg/models"
)
func customUsage() {
fmt.Fprintf(os.Stderr, "Usage:\n")
fmt.Fprintf(os.Stderr, "%s [OPTIONS] FILE...\n\nOptions:\n", os.Args[0])
fmt.Fprintf(os.Stderr, "%s [OPTIONS] VIDEOFILE...\n\nOptions:\n", os.Args[0])
flag.PrintDefaults()
}
func printPhash(ff *ffmpeg.FFMpeg, ffp *ffmpeg.FFProbe, inputfile string, quiet *bool) error {
// Determine if this is a video or image file based on extension
ext := filepath.Ext(inputfile)
ext = ext[1:] // remove the leading dot
// Common image extensions
imageExts := map[string]bool{
"jpg": true, "jpeg": true, "png": true, "gif": true, "webp": true, "bmp": true,
}
if imageExts[ext] {
return printImagePhash(inputfile, quiet)
}
return printVideoPhash(ff, ffp, inputfile, quiet)
}
func printVideoPhash(ff *ffmpeg.FFMpeg, ffp *ffmpeg.FFProbe, inputfile string, quiet *bool) error {
ffvideoFile, err := ffp.NewVideoFile(inputfile)
if err != nil {
return err
@@ -65,24 +46,6 @@ func printVideoPhash(ff *ffmpeg.FFMpeg, ffp *ffmpeg.FFProbe, inputfile string, q
return nil
}
func printImagePhash(inputfile string, quiet *bool) error {
imgFile := &models.ImageFile{
BaseFile: &models.BaseFile{Path: inputfile},
}
phash, err := imagephash.Generate(imgFile)
if err != nil {
return err
}
if *quiet {
fmt.Printf("%x\n", *phash)
} else {
fmt.Printf("%x %v\n", *phash, imgFile.Path)
}
return nil
}
func getPaths() (string, string) {
ffmpegPath, _ := exec.LookPath("ffmpeg")
ffprobePath, _ := exec.LookPath("ffprobe")
@@ -104,7 +67,7 @@ func main() {
args := flag.Args()
if len(args) < 1 {
fmt.Fprintf(os.Stderr, "Missing FILE argument.\n")
fmt.Fprintf(os.Stderr, "Missing VIDEOFILE argument.\n")
flag.Usage()
os.Exit(2)
}
@@ -124,5 +87,4 @@ func main() {
fmt.Fprintln(os.Stderr, err)
}
}
}

View File

@@ -422,8 +422,6 @@ type Mutation {
"""
moveFiles(input: MoveFilesInput!): Boolean!
deleteFiles(ids: [ID!]!): Boolean!
"Deletes file entries from the database without deleting the files from the filesystem"
destroyFiles(ids: [ID!]!): Boolean!
fileSetFingerprints(input: FileSetFingerprintsInput!): Boolean!

View File

@@ -395,9 +395,6 @@ input ConfigInterfaceInput {
customLocales: String
customLocalesEnabled: Boolean
"When true, disables all customizations (plugins, CSS, JavaScript, locales) for troubleshooting"
disableCustomizations: Boolean
"Interface language"
language: String
@@ -472,9 +469,6 @@ type ConfigInterfaceResult {
customLocales: String
customLocalesEnabled: Boolean
"When true, disables all customizations (plugins, CSS, JavaScript, locales) for troubleshooting"
disableCustomizations: Boolean
"Interface language"
language: String

View File

@@ -308,8 +308,6 @@ input SceneFilterType {
@deprecated(reason: "use stash_ids_endpoint instead")
"Filter by StashIDs"
stash_ids_endpoint: StashIDsCriterionInput
"Filter by StashID count"
stash_id_count: IntCriterionInput
"Filter by url"
url: StringCriterionInput
"Filter by interactive"
@@ -487,8 +485,6 @@ input StudioFilterType {
created_at: TimestampCriterionInput
"Filter by last update time"
updated_at: TimestampCriterionInput
custom_fields: [CustomFieldCriterionInput!]
}
input GalleryFilterType {
@@ -662,8 +658,6 @@ input ImageFilterType {
id: IntCriterionInput
"Filter by file checksum"
checksum: StringCriterionInput
"Filter by file phash distance"
phash_distance: PhashDistanceCriterionInput
"Filter by path"
path: StringCriterionInput
"Filter by file count"

View File

@@ -100,8 +100,6 @@ input GalleryDestroyInput {
"""
delete_file: Boolean
delete_generated: Boolean
"If true, delete the file entry from the database if the file is not assigned to any other objects"
destroy_file_entry: Boolean
}
type FindGalleriesResultType {

View File

@@ -82,16 +82,12 @@ input ImageDestroyInput {
id: ID!
delete_file: Boolean
delete_generated: Boolean
"If true, delete the file entry from the database if the file is not assigned to any other objects"
destroy_file_entry: Boolean
}
input ImagesDestroyInput {
ids: [ID!]!
delete_file: Boolean
delete_generated: Boolean
"If true, delete the file entry from the database if the file is not assigned to any other objects"
destroy_file_entry: Boolean
}
type FindImagesResultType {

View File

@@ -10,11 +10,8 @@ input GenerateMetadataInput {
transcodes: Boolean
"Generate transcodes even if not required"
forceTranscodes: Boolean
"Generate video phashes during scan"
phashes: Boolean
interactiveHeatmapsSpeeds: Boolean
"Generate image phashes during scan"
imagePhashes: Boolean
imageThumbnails: Boolean
clipPreviews: Boolean
@@ -22,10 +19,6 @@ input GenerateMetadataInput {
sceneIDs: [ID!]
"marker ids to generate for"
markerIDs: [ID!]
"image ids to generate for"
imageIDs: [ID!]
"gallery ids to generate for"
galleryIDs: [ID!]
"overwrite existing media"
overwrite: Boolean
@@ -92,10 +85,8 @@ input ScanMetadataInput {
scanGenerateImagePreviews: Boolean
"Generate sprites during scan"
scanGenerateSprites: Boolean
"Generate video phashes during scan"
"Generate phashes during scan"
scanGeneratePhashes: Boolean
"Generate image phashes during scan"
scanGenerateImagePhashes: Boolean
"Generate image thumbnails during scan"
scanGenerateThumbnails: Boolean
"Generate image clip previews during scan"
@@ -116,10 +107,8 @@ type ScanMetadataOptions {
scanGenerateImagePreviews: Boolean!
"Generate sprites during scan"
scanGenerateSprites: Boolean!
"Generate video phashes during scan"
"Generate phashes during scan"
scanGeneratePhashes: Boolean!
"Generate image phashes during scan"
scanGenerateImagePhashes: Boolean
"Generate image thumbnails during scan"
scanGenerateThumbnails: Boolean!
"Generate image clip previews during scan"

View File

@@ -80,7 +80,6 @@ input PerformerCreateInput {
career_length: String
tattoos: String
piercings: String
"Duplicate aliases and those equal to name will be ignored (case-insensitive)"
alias_list: [String!]
twitter: String @deprecated(reason: "Use urls")
instagram: String @deprecated(reason: "Use urls")
@@ -119,7 +118,6 @@ input PerformerUpdateInput {
career_length: String
tattoos: String
piercings: String
"Duplicate aliases and those equal to name will be ignored (case-insensitive)"
alias_list: [String!]
twitter: String @deprecated(reason: "Use urls")
instagram: String @deprecated(reason: "Use urls")
@@ -163,7 +161,6 @@ input BulkPerformerUpdateInput {
career_length: String
tattoos: String
piercings: String
"Duplicate aliases and those equal to name will result in an error (case-insensitive)"
alias_list: BulkUpdateStrings
twitter: String @deprecated(reason: "Use urls")
instagram: String @deprecated(reason: "Use urls")

View File

@@ -196,16 +196,12 @@ input SceneDestroyInput {
id: ID!
delete_file: Boolean
delete_generated: Boolean
"If true, delete the file entry from the database if the file is not assigned to any other objects"
destroy_file_entry: Boolean
}
input ScenesDestroyInput {
ids: [ID!]!
delete_file: Boolean
delete_generated: Boolean
"If true, delete the file entry from the database if the file is not assigned to any other objects"
destroy_file_entry: Boolean
}
type FindScenesResultType {

View File

@@ -26,8 +26,6 @@ type Studio {
groups: [Group!]!
movies: [Movie!]! @deprecated(reason: "use groups instead")
o_counter: Int
custom_fields: Map!
}
input StudioCreateInput {
@@ -42,12 +40,9 @@ input StudioCreateInput {
rating100: Int
favorite: Boolean
details: String
"Duplicate aliases and those equal to name will be ignored (case-insensitive)"
aliases: [String!]
tag_ids: [ID!]
ignore_auto_tag: Boolean
custom_fields: Map
}
input StudioUpdateInput {
@@ -63,12 +58,9 @@ input StudioUpdateInput {
rating100: Int
favorite: Boolean
details: String
"Duplicate aliases and those equal to name will be ignored (case-insensitive)"
aliases: [String!]
tag_ids: [ID!]
ignore_auto_tag: Boolean
custom_fields: CustomFieldsInput
}
input BulkStudioUpdateInput {

View File

@@ -31,7 +31,6 @@ input TagCreateInput {
"Value that does not appear in the UI but overrides name for sorting"
sort_name: String
description: String
"Duplicate aliases and those equal to name will be ignored (case-insensitive)"
aliases: [String!]
ignore_auto_tag: Boolean
favorite: Boolean
@@ -49,7 +48,6 @@ input TagUpdateInput {
"Value that does not appear in the UI but overrides name for sorting"
sort_name: String
description: String
"Duplicate aliases and those equal to name will be ignored (case-insensitive)"
aliases: [String!]
ignore_auto_tag: Boolean
favorite: Boolean
@@ -78,7 +76,6 @@ input TagsMergeInput {
input BulkTagUpdateInput {
ids: [ID!]
description: String
"Duplicate aliases and those equal to name will result in an error (case-insensitive)"
aliases: BulkUpdateStrings
ignore_auto_tag: Boolean
favorite: Boolean

View File

@@ -59,9 +59,7 @@ type Loaders struct {
PerformerByID *PerformerLoader
PerformerCustomFields *CustomFieldsLoader
StudioByID *StudioLoader
StudioCustomFields *CustomFieldsLoader
StudioByID *StudioLoader
TagByID *TagLoader
GroupByID *GroupLoader
FileByID *FileLoader
@@ -101,11 +99,6 @@ func (m Middleware) Middleware(next http.Handler) http.Handler {
maxBatch: maxBatch,
fetch: m.fetchPerformerCustomFields(ctx),
},
StudioCustomFields: &CustomFieldsLoader{
wait: wait,
maxBatch: maxBatch,
fetch: m.fetchStudioCustomFields(ctx),
},
StudioByID: &StudioLoader{
wait: wait,
maxBatch: maxBatch,
@@ -260,18 +253,6 @@ func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*model
}
}
func (m Middleware) fetchStudioCustomFields(ctx context.Context) func(keys []int) ([]models.CustomFieldMap, []error) {
return func(keys []int) (ret []models.CustomFieldMap, errs []error) {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.Studio.GetCustomFieldsBulk(ctx, keys)
return err
})
return ret, toErrorSlice(err)
}
}
func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.Tag, []error) {
return func(keys []int) (ret []*models.Tag, errs []error) {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {

View File

@@ -207,19 +207,6 @@ func (r *studioResolver) Groups(ctx context.Context, obj *models.Studio) (ret []
return ret, nil
}
func (r *studioResolver) CustomFields(ctx context.Context, obj *models.Studio) (map[string]interface{}, error) {
m, err := loaders.From(ctx).StudioCustomFields.Load(obj.ID)
if err != nil {
return nil, err
}
if m == nil {
return make(map[string]interface{}), nil
}
return m, nil
}
// deprecated
func (r *studioResolver) Movies(ctx context.Context, obj *models.Studio) (ret []*models.Group, err error) {
return r.Groups(ctx, obj)

View File

@@ -515,8 +515,6 @@ func (r *mutationResolver) ConfigureInterface(ctx context.Context, input ConfigI
r.setConfigBool(config.CustomLocalesEnabled, input.CustomLocalesEnabled)
r.setConfigBool(config.DisableCustomizations, input.DisableCustomizations)
if input.DisableDropdownCreate != nil {
ddc := input.DisableDropdownCreate
r.setConfigBool(config.DisableDropdownCreatePerformer, ddc.Performer)

View File

@@ -210,58 +210,6 @@ func (r *mutationResolver) DeleteFiles(ctx context.Context, ids []string) (ret b
return true, nil
}
func (r *mutationResolver) DestroyFiles(ctx context.Context, ids []string) (ret bool, err error) {
fileIDs, err := stringslice.StringSliceToIntSlice(ids)
if err != nil {
return false, fmt.Errorf("converting ids: %w", err)
}
destroyer := &file.ZipDestroyer{
FileDestroyer: r.repository.File,
FolderDestroyer: r.repository.Folder,
}
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.File
for _, fileIDInt := range fileIDs {
fileID := models.FileID(fileIDInt)
f, err := qb.Find(ctx, fileID)
if err != nil {
return err
}
if len(f) == 0 {
return fmt.Errorf("file with id %d not found", fileID)
}
path := f[0].Base().Path
// ensure not a primary file
isPrimary, err := qb.IsPrimary(ctx, fileID)
if err != nil {
return fmt.Errorf("checking if file %s is primary: %w", path, err)
}
if isPrimary {
return fmt.Errorf("cannot destroy primary file entry %s", path)
}
// destroy DB entries only (no filesystem deletion)
const deleteFile = false
if err := destroyer.DestroyZip(ctx, f[0], nil, deleteFile); err != nil {
return fmt.Errorf("destroying file entry %s: %w", path, err)
}
}
return nil
}); err != nil {
return false, err
}
return true, nil
}
func (r *mutationResolver) FileSetFingerprints(ctx context.Context, input FileSetFingerprintsInput) (bool, error) {
fileIDInt, err := strconv.Atoi(input.ID)
if err != nil {

View File

@@ -346,7 +346,6 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile)
destroyFileEntry := utils.IsTrue(input.DestroyFileEntry)
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Gallery
@@ -367,7 +366,7 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
galleries = append(galleries, gallery)
imgsDestroyed, err = r.galleryService.Destroy(ctx, gallery, fileDeleter, deleteGenerated, deleteFile, destroyFileEntry)
imgsDestroyed, err = r.galleryService.Destroy(ctx, gallery, fileDeleter, deleteGenerated, deleteFile)
if err != nil {
return err
}

View File

@@ -325,7 +325,7 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
return fmt.Errorf("image with id %d not found", imageID)
}
return r.imageService.Destroy(ctx, i, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile), utils.IsTrue(input.DestroyFileEntry))
return r.imageService.Destroy(ctx, i, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile))
}); err != nil {
fileDeleter.Rollback()
return false, err
@@ -372,7 +372,7 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
images = append(images, i)
if err := r.imageService.Destroy(ctx, i, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile), utils.IsTrue(input.DestroyFileEntry)); err != nil {
if err := r.imageService.Destroy(ctx, i, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)); err != nil {
return err
}
}

View File

@@ -43,7 +43,7 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per
newPerformer.Name = strings.TrimSpace(input.Name)
newPerformer.Disambiguation = translator.string(input.Disambiguation)
newPerformer.Aliases = models.NewRelatedStrings(stringslice.UniqueExcludeFold(stringslice.TrimSpace(input.AliasList), newPerformer.Name))
newPerformer.Aliases = models.NewRelatedStrings(stringslice.TrimSpace(input.AliasList))
newPerformer.Gender = input.Gender
newPerformer.Ethnicity = translator.string(input.Ethnicity)
newPerformer.Country = translator.string(input.Country)
@@ -348,27 +348,6 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per
}
}
if updatedPerformer.Aliases != nil {
p, err := qb.Find(ctx, performerID)
if err != nil {
return err
}
if p != nil {
if err := p.LoadAliases(ctx, qb); err != nil {
return err
}
effectiveAliases := updatedPerformer.Aliases.Apply(p.Aliases.List())
name := p.Name
if updatedPerformer.Name.Set {
name = updatedPerformer.Name.Value
}
sanitized := stringslice.UniqueExcludeFold(effectiveAliases, name)
updatedPerformer.Aliases.Values = sanitized
updatedPerformer.Aliases.Mode = models.RelationshipUpdateModeSet
}
}
if err := performer.ValidateUpdate(ctx, performerID, *updatedPerformer, qb); err != nil {
return err
}

View File

@@ -441,7 +441,6 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile)
destroyFileEntry := utils.IsTrue(input.DestroyFileEntry)
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
@@ -458,7 +457,7 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
// kill any running encoders
manager.KillRunningStreams(s, fileNamingAlgo)
return r.sceneService.Destroy(ctx, s, fileDeleter, deleteGenerated, deleteFile, destroyFileEntry)
return r.sceneService.Destroy(ctx, s, fileDeleter, deleteGenerated, deleteFile)
}); err != nil {
fileDeleter.Rollback()
return false, err
@@ -496,7 +495,6 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile)
destroyFileEntry := utils.IsTrue(input.DestroyFileEntry)
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
@@ -515,7 +513,7 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
// kill any running encoders
manager.KillRunningStreams(scene, fileNamingAlgo)
if err := r.sceneService.Destroy(ctx, scene, fileDeleter, deleteGenerated, deleteFile, destroyFileEntry); err != nil {
if err := r.sceneService.Destroy(ctx, scene, fileDeleter, deleteGenerated, deleteFile); err != nil {
return err
}
}
@@ -624,12 +622,7 @@ func (r *mutationResolver) SceneMerge(ctx context.Context, input SceneMergeInput
return fmt.Errorf("scene with id %d not found", destID)
}
// only update cover image if one was provided
if len(coverImageData) > 0 {
return r.sceneUpdateCoverImage(ctx, ret, coverImageData)
}
return nil
return r.sceneUpdateCoverImage(ctx, ret, coverImageData)
}); err != nil {
return nil, err
}

View File

@@ -31,14 +31,14 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
}
// Populate a new studio from the input
newStudio := models.NewCreateStudioInput()
newStudio := models.NewStudio()
newStudio.Name = strings.TrimSpace(input.Name)
newStudio.Rating = input.Rating100
newStudio.Favorite = translator.bool(input.Favorite)
newStudio.Details = translator.string(input.Details)
newStudio.IgnoreAutoTag = translator.bool(input.IgnoreAutoTag)
newStudio.Aliases = models.NewRelatedStrings(stringslice.UniqueExcludeFold(stringslice.TrimSpace(input.Aliases), newStudio.Name))
newStudio.Aliases = models.NewRelatedStrings(stringslice.TrimSpace(input.Aliases))
newStudio.StashIDs = models.NewRelatedStashIDs(models.StashIDInputs(input.StashIds).ToStashIDs())
var err error
@@ -61,7 +61,6 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
if err != nil {
return nil, fmt.Errorf("converting tag ids: %w", err)
}
newStudio.CustomFields = convertMapJSONNumbers(input.CustomFields)
// Process the base 64 encoded image string
var imageData []byte
@@ -153,11 +152,6 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio
}
}
updatedStudio.CustomFields = input.CustomFields
// convert json.Numbers to int/float
updatedStudio.CustomFields.Full = convertMapJSONNumbers(updatedStudio.CustomFields.Full)
updatedStudio.CustomFields.Partial = convertMapJSONNumbers(updatedStudio.CustomFields.Partial)
// Process the base 64 encoded image string
var imageData []byte
imageIncluded := translator.hasField("image")
@@ -173,28 +167,6 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Studio
if updatedStudio.Aliases != nil {
s, err := qb.Find(ctx, studioID)
if err != nil {
return err
}
if s != nil {
if err := s.LoadAliases(ctx, qb); err != nil {
return err
}
effectiveAliases := updatedStudio.Aliases.Apply(s.Aliases.List())
name := s.Name
if updatedStudio.Name.Set {
name = updatedStudio.Name.Value
}
sanitized := stringslice.UniqueExcludeFold(effectiveAliases, name)
updatedStudio.Aliases.Values = sanitized
updatedStudio.Aliases.Mode = models.RelationshipUpdateModeSet
}
}
if err := studio.ValidateModify(ctx, updatedStudio, qb); err != nil {
return err
}

View File

@@ -35,7 +35,7 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput)
newTag.Name = strings.TrimSpace(input.Name)
newTag.SortName = translator.string(input.SortName)
newTag.Aliases = models.NewRelatedStrings(stringslice.UniqueExcludeFold(stringslice.TrimSpace(input.Aliases), newTag.Name))
newTag.Aliases = models.NewRelatedStrings(stringslice.TrimSpace(input.Aliases))
newTag.Favorite = translator.bool(input.Favorite)
newTag.Description = translator.string(input.Description)
newTag.IgnoreAutoTag = translator.bool(input.IgnoreAutoTag)
@@ -151,28 +151,6 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput)
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Tag
if updatedTag.Aliases != nil {
t, err := qb.Find(ctx, tagID)
if err != nil {
return err
}
if t != nil {
if err := t.LoadAliases(ctx, qb); err != nil {
return err
}
newAliases := updatedTag.Aliases.Apply(t.Aliases.List())
name := t.Name
if updatedTag.Name.Set {
name = updatedTag.Name.Value
}
sanitized := stringslice.UniqueExcludeFold(newAliases, name)
updatedTag.Aliases.Values = sanitized
updatedTag.Aliases.Mode = models.RelationshipUpdateModeSet
}
}
if err := tag.ValidateUpdate(ctx, tagID, updatedTag, qb); err != nil {
return err
}

View File

@@ -156,7 +156,6 @@ func makeConfigInterfaceResult() *ConfigInterfaceResult {
javascriptEnabled := config.GetJavascriptEnabled()
customLocales := config.GetCustomLocales()
customLocalesEnabled := config.GetCustomLocalesEnabled()
disableCustomizations := config.GetDisableCustomizations()
language := config.GetLanguage()
handyKey := config.GetHandyKey()
scriptOffset := config.GetFunscriptOffset()
@@ -184,7 +183,6 @@ func makeConfigInterfaceResult() *ConfigInterfaceResult {
JavascriptEnabled: &javascriptEnabled,
CustomLocales: &customLocales,
CustomLocalesEnabled: &customLocalesEnabled,
DisableCustomizations: &disableCustomizations,
Language: &language,
ImageLightbox: &imageLightboxOptions,

View File

@@ -450,7 +450,7 @@ func cssHandler(c *config.Config) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var paths []string
if c.GetCSSEnabled() && !c.GetDisableCustomizations() {
if c.GetCSSEnabled() {
// search for custom.css in current directory, then $HOME/.stash
fn := c.GetCSSPath()
exists, _ := fsutil.FileExists(fn)
@@ -468,7 +468,7 @@ func javascriptHandler(c *config.Config) func(w http.ResponseWriter, r *http.Req
return func(w http.ResponseWriter, r *http.Request) {
var paths []string
if c.GetJavascriptEnabled() && !c.GetDisableCustomizations() {
if c.GetJavascriptEnabled() {
// search for custom.js in current directory, then $HOME/.stash
fn := c.GetJavascriptPath()
exists, _ := fsutil.FileExists(fn)
@@ -486,7 +486,7 @@ func customLocalesHandler(c *config.Config) func(w http.ResponseWriter, r *http.
return func(w http.ResponseWriter, r *http.Request) {
buffer := bytes.Buffer{}
if c.GetCustomLocalesEnabled() && !c.GetDisableCustomizations() {
if c.GetCustomLocalesEnabled() {
// search for custom-locales.json in current directory, then $HOME/.stash
path := c.GetCustomLocalesPath()
exists, _ := fsutil.FileExists(path)

View File

@@ -101,15 +101,16 @@ func createPerformer(ctx context.Context, pqb models.PerformerWriter) error {
func createStudio(ctx context.Context, qb models.StudioWriter, name string) (*models.Studio, error) {
// create the studio
studio := models.NewCreateStudioInput()
studio.Name = name
studio := models.Studio{
Name: name,
}
err := qb.Create(ctx, &studio)
if err != nil {
return nil, err
}
return studio.Studio, nil
return &studio, nil
}
func createTag(ctx context.Context, qb models.TagWriter) error {

View File

@@ -27,7 +27,7 @@ func Test_sceneRelationships_studio(t *testing.T) {
db := mocks.NewDatabase()
db.Studio.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s := args.Get(1).(*models.Studio)
s.ID = validStoredIDInt
}).Return(nil)

View File

@@ -21,13 +21,13 @@ func Test_createMissingStudio(t *testing.T) {
db := mocks.NewDatabase()
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateStudioInput) bool {
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
return p.Name == validName
})).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s := args.Get(1).(*models.Studio)
s.ID = createdID
}).Return(nil)
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateStudioInput) bool {
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
return p.Name == invalidName
})).Return(errors.New("error creating studio"))

View File

@@ -194,7 +194,6 @@ const (
CSSEnabled = "cssenabled"
JavascriptEnabled = "javascriptenabled"
CustomLocalesEnabled = "customlocalesenabled"
DisableCustomizations = "disable_customizations"
ShowScrubber = "show_scrubber"
showScrubberDefault = true
@@ -1480,13 +1479,6 @@ func (i *Config) GetCustomLocalesEnabled() bool {
return i.getBool(CustomLocalesEnabled)
}
// GetDisableCustomizations returns true if all customizations (plugins, custom CSS,
// custom JavaScript, and custom locales) should be disabled. This is useful for
// troubleshooting issues without permanently disabling individual customizations.
func (i *Config) GetDisableCustomizations() bool {
return i.getBool(DisableCustomizations)
}
func (i *Config) GetHandyKey() string {
return i.getString(HandyKey)
}

View File

@@ -11,10 +11,8 @@ type ScanMetadataOptions struct {
ScanGenerateImagePreviews bool `json:"scanGenerateImagePreviews"`
// Generate sprites during scan
ScanGenerateSprites bool `json:"scanGenerateSprites"`
// Generate video phashes during scan
// Generate phashes during scan
ScanGeneratePhashes bool `json:"scanGeneratePhashes"`
// Generate image phashes during scan
ScanGenerateImagePhashes bool `json:"scanGenerateImagePhashes"`
// Generate image thumbnails during scan
ScanGenerateThumbnails bool `json:"scanGenerateThumbnails"`
// Generate image thumbnails during scan

View File

@@ -100,8 +100,6 @@ func (s *Manager) Scan(ctx context.Context, input ScanMetadataInput) (int, error
return 0, err
}
cfg := config.GetInstance()
scanner := &file.Scanner{
Repository: file.NewRepository(s.Repository),
FileDecorators: []file.Decorator{
@@ -120,10 +118,6 @@ func (s *Manager) Scan(ctx context.Context, input ScanMetadataInput) (int, error
},
FingerprintCalculator: &fingerprintCalculator{s.Config},
FS: &file.OsFS{},
ZipFileExtensions: cfg.GetGalleryExtensions(),
// ScanFilters is set in ScanJob.Execute
// HandlerRequiredFilters is set in ScanJob.Execute
Rescan: input.Rescan,
}
scanJob := ScanJob{

View File

@@ -13,14 +13,14 @@ type SceneService interface {
Create(ctx context.Context, input *models.Scene, fileIDs []models.FileID, coverImage []byte) (*models.Scene, error)
AssignFile(ctx context.Context, sceneID int, fileID models.FileID) error
Merge(ctx context.Context, sourceIDs []int, destinationID int, fileDeleter *scene.FileDeleter, options scene.MergeOptions) error
Destroy(ctx context.Context, scene *models.Scene, fileDeleter *scene.FileDeleter, deleteGenerated, deleteFile, destroyFileEntry bool) error
Destroy(ctx context.Context, scene *models.Scene, fileDeleter *scene.FileDeleter, deleteGenerated, deleteFile bool) error
FindByIDs(ctx context.Context, ids []int, load ...scene.LoadRelationshipOption) ([]*models.Scene, error)
sceneFingerprintGetter
}
type ImageService interface {
Destroy(ctx context.Context, image *models.Image, fileDeleter *image.FileDeleter, deleteGenerated, deleteFile, destroyFileEntry bool) error
Destroy(ctx context.Context, image *models.Image, fileDeleter *image.FileDeleter, deleteGenerated, deleteFile bool) error
DestroyZipImages(ctx context.Context, zipFile models.File, fileDeleter *image.FileDeleter, deleteGenerated bool) ([]*models.Image, error)
}
@@ -31,7 +31,7 @@ type GalleryService interface {
SetCover(ctx context.Context, g *models.Gallery, coverImageId int) error
ResetCover(ctx context.Context, g *models.Gallery) error
Destroy(ctx context.Context, i *models.Gallery, fileDeleter *image.FileDeleter, deleteGenerated, deleteFile, destroyFileEntry bool) ([]*models.Image, error)
Destroy(ctx context.Context, i *models.Gallery, fileDeleter *image.FileDeleter, deleteGenerated, deleteFile bool) ([]*models.Image, error)
ValidateImageGalleryChange(ctx context.Context, i *models.Image, updateIDs models.UpdateIDs) error

View File

@@ -300,10 +300,7 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil
// only delete if the scene has no other files
if len(scene.Files.List()) <= 1 {
logger.Infof("Deleting scene %q since it has no other related files", scene.DisplayName())
const deleteGenerated = true
const deleteFile = false
const destroyFileEntry = false
if err := mgr.SceneService.Destroy(ctx, scene, sceneFileDeleter, deleteGenerated, deleteFile, destroyFileEntry); err != nil {
if err := mgr.SceneService.Destroy(ctx, scene, sceneFileDeleter, true, false); err != nil {
return err
}
@@ -424,10 +421,7 @@ func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *fil
if len(i.Files.List()) <= 1 {
logger.Infof("Deleting image %q since it has no other related files", i.DisplayName())
const deleteGenerated = true
const deleteFile = false
const destroyFileEntry = false
if err := mgr.ImageService.Destroy(ctx, i, imageFileDeleter, deleteGenerated, deleteFile, destroyFileEntry); err != nil {
if err := mgr.ImageService.Destroy(ctx, i, imageFileDeleter, true, false); err != nil {
return err
}

View File

@@ -29,7 +29,6 @@ type GenerateMetadataInput struct {
// Generate transcodes even if not required
ForceTranscodes bool `json:"forceTranscodes"`
Phashes bool `json:"phashes"`
ImagePhashes bool `json:"imagePhashes"`
InteractiveHeatmapsSpeeds bool `json:"interactiveHeatmapsSpeeds"`
ClipPreviews bool `json:"clipPreviews"`
ImageThumbnails bool `json:"imageThumbnails"`
@@ -37,10 +36,6 @@ type GenerateMetadataInput struct {
SceneIDs []string `json:"sceneIDs"`
// marker ids to generate for
MarkerIDs []string `json:"markerIDs"`
// image ids to generate for
ImageIDs []string `json:"imageIDs"`
// gallery ids to generate for
GalleryIDs []string `json:"galleryIDs"`
// overwrite existing media
Overwrite bool `json:"overwrite"`
}
@@ -78,7 +73,6 @@ type totalsGenerate struct {
markers int64
transcodes int64
phashes int64
imagePhashes int64
interactiveHeatmapSpeeds int64
clipPreviews int64
imageThumbnails int64
@@ -88,9 +82,8 @@ type totalsGenerate struct {
func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) error {
var scenes []*models.Scene
var markers []*models.SceneMarker
var images []*models.Image
var err error
var markers []*models.SceneMarker
j.overwrite = j.input.Overwrite
j.fileNamingAlgo = config.GetInstance().GetVideoFileNamingAlgorithm()
@@ -112,14 +105,6 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) error
if err != nil {
logger.Error(err.Error())
}
imageIDs, err := stringslice.StringSliceToIntSlice(j.input.ImageIDs)
if err != nil {
logger.Error(err.Error())
}
galleryIDs, err := stringslice.StringSliceToIntSlice(j.input.GalleryIDs)
if err != nil {
logger.Error(err.Error())
}
g := &generate.Generator{
Encoder: instance.FFMpeg,
@@ -133,7 +118,7 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) error
r := j.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Scene
if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 && len(j.input.ImageIDs) == 0 && len(j.input.GalleryIDs) == 0 {
if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 {
j.queueTasks(ctx, g, queue)
} else {
if len(j.input.SceneIDs) > 0 {
@@ -156,33 +141,6 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) error
j.queueMarkerJob(g, m, queue)
}
}
if len(j.input.ImageIDs) > 0 {
images, err = r.Image.FindMany(ctx, imageIDs)
for _, i := range images {
if err := i.LoadFiles(ctx, r.Image); err != nil {
return err
}
j.queueImageJob(g, i, queue)
}
}
if len(j.input.GalleryIDs) > 0 {
for _, galleryID := range galleryIDs {
imgs, err := r.Image.FindByGalleryID(ctx, galleryID)
if err != nil {
return err
}
for _, img := range imgs {
if err := img.LoadFiles(ctx, r.Image); err != nil {
return err
}
j.queueImageJob(g, img, queue)
}
}
}
}
return nil
@@ -214,17 +172,14 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) error
if j.input.Phashes {
logMsg += fmt.Sprintf(" %d phashes", totals.phashes)
}
if j.input.ImagePhashes {
logMsg += fmt.Sprintf(" %d image phashes", totals.imagePhashes)
}
if j.input.InteractiveHeatmapsSpeeds {
logMsg += fmt.Sprintf(" %d heatmaps & speeds", totals.interactiveHeatmapSpeeds)
}
if j.input.ClipPreviews {
logMsg += fmt.Sprintf(" %d image clip previews", totals.clipPreviews)
logMsg += fmt.Sprintf(" %d Image Clip Previews", totals.clipPreviews)
}
if j.input.ImageThumbnails {
logMsg += fmt.Sprintf(" %d image thumbnails", totals.imageThumbnails)
logMsg += fmt.Sprintf(" %d Image Thumbnails", totals.imageThumbnails)
}
if logMsg == "Generating" {
logMsg = "Nothing selected to generate"
@@ -329,7 +284,7 @@ func (j *GenerateJob) queueImagesTasks(ctx context.Context, g *generate.Generato
r := j.repository
for more := j.input.ClipPreviews || j.input.ImageThumbnails || j.input.ImagePhashes; more; {
for more := j.input.ClipPreviews || j.input.ImageThumbnails; more; {
if job.IsCancelled(ctx) {
return
}
@@ -570,23 +525,4 @@ func (j *GenerateJob) queueImageJob(g *generate.Generator, image *models.Image,
queue <- task
}
}
if j.input.ImagePhashes {
// generate for all files in image
for _, f := range image.Files.List() {
if imageFile, ok := f.(*models.ImageFile); ok {
task := &GenerateImagePhashTask{
repository: j.repository,
File: imageFile,
Overwrite: j.overwrite,
}
if task.required() {
j.totals.imagePhashes++
j.totals.tasks++
queue <- task
}
}
}
}
}

View File

@@ -1,103 +0,0 @@
package manager
import (
"context"
"fmt"
"github.com/stashapp/stash/pkg/hash/imagephash"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
)
type GenerateImagePhashTask struct {
repository models.Repository
File *models.ImageFile
Overwrite bool
}
func (t *GenerateImagePhashTask) GetDescription() string {
return fmt.Sprintf("Generating phash for %s", t.File.Path)
}
func (t *GenerateImagePhashTask) Start(ctx context.Context) {
if !t.required() {
return
}
var hash int64
set := false
// #4393 - if there is a file with the same md5, we can use the same phash
// only use this if we're not overwriting
if !t.Overwrite {
existing, err := t.findExistingPhash(ctx)
if err != nil {
logger.Warnf("Error finding existing phash: %v", err)
} else if existing != nil {
logger.Infof("Using existing phash for %s", t.File.Path)
hash = existing.(int64)
set = true
}
}
if !set {
generated, err := imagephash.Generate(t.File)
if err != nil {
logger.Errorf("Error generating phash for %q: %v", t.File.Path, err)
logErrorOutput(err)
return
}
hash = int64(*generated)
}
r := t.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
t.File.Fingerprints = t.File.Fingerprints.AppendUnique(models.Fingerprint{
Type: models.FingerprintTypePhash,
Fingerprint: hash,
})
return r.File.Update(ctx, t.File)
}); err != nil && ctx.Err() == nil {
logger.Errorf("Error setting phash: %v", err)
}
}
func (t *GenerateImagePhashTask) findExistingPhash(ctx context.Context) (interface{}, error) {
r := t.repository
var ret interface{}
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
md5 := t.File.Fingerprints.Get(models.FingerprintTypeMD5)
// find other files with the same md5
files, err := r.File.FindByFingerprint(ctx, models.Fingerprint{
Type: models.FingerprintTypeMD5,
Fingerprint: md5,
})
if err != nil {
return fmt.Errorf("finding files by md5: %w", err)
}
// find the first file with a phash
for _, file := range files {
if phash := file.Base().Fingerprints.Get(models.FingerprintTypePhash); phash != nil {
ret = phash
return nil
}
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
func (t *GenerateImagePhashTask) required() bool {
if t.Overwrite {
return true
}
return t.File.Fingerprints.Get(models.FingerprintTypePhash) == nil
}

View File

@@ -44,7 +44,7 @@ func (t *GeneratePhashTask) Start(ctx context.Context) {
if !set {
generated, err := videophash.Generate(instance.FFMpeg, t.File)
if err != nil {
logger.Errorf("Error generating phash for %q: %v", t.File.Path, err)
logger.Errorf("Error generating phash: %v", err)
logErrorOutput(err)
return
}

View File

@@ -2,17 +2,13 @@ package manager
import (
"context"
"errors"
"fmt"
"io/fs"
"path/filepath"
"regexp"
"runtime/debug"
"sync"
"time"
"github.com/99designs/gqlgen/graphql/handler/lru"
"github.com/remeh/sizedwaitgroup"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/file"
"github.com/stashapp/stash/pkg/file/video"
@@ -28,13 +24,14 @@ import (
"github.com/stashapp/stash/pkg/txn"
)
type scanner interface {
Scan(ctx context.Context, handlers []file.Handler, options file.ScanOptions, progressReporter file.ProgressReporter)
}
type ScanJob struct {
scanner *file.Scanner
scanner scanner
input ScanMetadataInput
subscriptions *subscriptionManager
fileQueue chan file.ScannedFile
count int
}
func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) error {
@@ -58,22 +55,22 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) error {
start := time.Now()
nTasks := cfg.GetParallelTasksWithAutoDetection()
const taskQueueSize = 200000
taskQueue := job.NewTaskQueue(ctx, progress, taskQueueSize, nTasks)
taskQueue := job.NewTaskQueue(ctx, progress, taskQueueSize, cfg.GetParallelTasksWithAutoDetection())
var minModTime time.Time
if j.input.Filter != nil && j.input.Filter.MinModTime != nil {
minModTime = *j.input.Filter.MinModTime
}
// HACK - these should really be set in the scanner initialization
j.scanner.FileHandlers = getScanHandlers(j.input, taskQueue, progress)
j.scanner.ScanFilters = []file.PathFilter{newScanFilter(c, repo, minModTime)}
j.scanner.HandlerRequiredFilters = []file.Filter{newHandlerRequiredFilter(cfg, repo)}
j.runJob(ctx, paths, nTasks, progress)
j.scanner.Scan(ctx, getScanHandlers(j.input, taskQueue, progress), file.ScanOptions{
Paths: paths,
ScanFilters: []file.PathFilter{newScanFilter(c, repo, minModTime)},
ZipFileExtensions: cfg.GetGalleryExtensions(),
ParallelTasks: cfg.GetParallelTasksWithAutoDetection(),
HandlerRequiredFilters: []file.Filter{newHandlerRequiredFilter(cfg, repo)},
Rescan: j.input.Rescan,
}, progress)
taskQueue.Close()
@@ -89,264 +86,6 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) error {
return nil
}
func (j *ScanJob) runJob(ctx context.Context, paths []string, nTasks int, progress *job.Progress) {
var wg sync.WaitGroup
wg.Add(1)
j.fileQueue = make(chan file.ScannedFile, scanQueueSize)
go func() {
defer func() {
wg.Done()
// handle panics in goroutine
if p := recover(); p != nil {
logger.Errorf("panic while queuing files for scan: %v", p)
logger.Errorf(string(debug.Stack()))
}
}()
if err := j.queueFiles(ctx, paths, progress); err != nil {
if errors.Is(err, context.Canceled) {
return
}
logger.Errorf("error queuing files for scan: %v", err)
return
}
logger.Infof("Finished adding files to queue. %d files queued", j.count)
}()
defer wg.Wait()
j.processQueue(ctx, nTasks, progress)
}
const scanQueueSize = 200000
func (j *ScanJob) queueFiles(ctx context.Context, paths []string, progress *job.Progress) error {
fs := &file.OsFS{}
defer func() {
close(j.fileQueue)
progress.AddTotal(j.count)
progress.Definite()
}()
var err error
progress.ExecuteTask("Walking directory tree", func() {
for _, p := range paths {
err = file.SymWalk(fs, p, j.queueFileFunc(ctx, fs, nil, progress))
if err != nil {
return
}
}
})
return err
}
func (j *ScanJob) queueFileFunc(ctx context.Context, f models.FS, zipFile *file.ScannedFile, progress *job.Progress) fs.WalkDirFunc {
return func(path string, d fs.DirEntry, err error) error {
if err != nil {
// don't let errors prevent scanning
logger.Errorf("error scanning %s: %v", path, err)
return nil
}
if err = ctx.Err(); err != nil {
return err
}
info, err := d.Info()
if err != nil {
logger.Errorf("reading info for %q: %v", path, err)
return nil
}
if !j.scanner.AcceptEntry(ctx, path, info) {
if info.IsDir() {
logger.Debugf("Skipping directory %s", path)
return fs.SkipDir
}
logger.Debugf("Skipping file %s", path)
return nil
}
size, err := file.GetFileSize(f, path, info)
if err != nil {
return err
}
ff := file.ScannedFile{
BaseFile: &models.BaseFile{
DirEntry: models.DirEntry{
ModTime: file.ModTime(info),
},
Path: path,
Basename: filepath.Base(path),
Size: size,
},
FS: f,
Info: info,
}
if zipFile != nil {
ff.ZipFileID = &zipFile.ID
ff.ZipFile = zipFile
}
if info.IsDir() {
// handle folders immediately
if err := j.handleFolder(ctx, ff, progress); err != nil {
if !errors.Is(err, context.Canceled) {
logger.Errorf("error processing %q: %v", path, err)
}
// skip the directory since we won't be able to process the files anyway
return fs.SkipDir
}
return nil
}
// if zip file is present, we handle immediately
if zipFile != nil {
progress.ExecuteTask("Scanning "+path, func() {
// don't increment progress in zip files
if err := j.handleFile(ctx, ff, nil); err != nil {
if !errors.Is(err, context.Canceled) {
logger.Errorf("error processing %q: %v", path, err)
}
// don't return an error, just skip the file
}
})
return nil
}
logger.Tracef("Queueing file %s for scanning", path)
j.fileQueue <- ff
j.count++
return nil
}
}
func (j *ScanJob) processQueue(ctx context.Context, parallelTasks int, progress *job.Progress) {
if parallelTasks < 1 {
parallelTasks = 1
}
wg := sizedwaitgroup.New(parallelTasks)
func() {
defer func() {
wg.Wait()
// handle panics in goroutine
if p := recover(); p != nil {
logger.Errorf("panic while scanning files: %v", p)
logger.Errorf(string(debug.Stack()))
}
}()
for f := range j.fileQueue {
logger.Tracef("Processing queued file %s", f.Path)
if err := ctx.Err(); err != nil {
return
}
wg.Add()
ff := f
go func() {
defer wg.Done()
j.processQueueItem(ctx, ff, progress)
}()
}
}()
}
func (j *ScanJob) processQueueItem(ctx context.Context, f file.ScannedFile, progress *job.Progress) {
progress.ExecuteTask("Scanning "+f.Path, func() {
var err error
if f.Info.IsDir() {
err = j.handleFolder(ctx, f, progress)
} else {
err = j.handleFile(ctx, f, progress)
}
if err != nil && !errors.Is(err, context.Canceled) {
logger.Errorf("error processing %q: %v", f.Path, err)
}
})
}
func (j *ScanJob) handleFolder(ctx context.Context, f file.ScannedFile, progress *job.Progress) error {
if progress != nil {
defer progress.Increment()
}
_, err := j.scanner.ScanFolder(ctx, f)
if err != nil {
return err
}
return nil
}
func (j *ScanJob) handleFile(ctx context.Context, f file.ScannedFile, progress *job.Progress) error {
if progress != nil {
defer progress.Increment()
}
r, err := j.scanner.ScanFile(ctx, f)
if err != nil {
return err
}
// handle rename should have already handled the contents of the zip file
// so shouldn't need to scan it again
if (r.New || r.Updated) && j.scanner.IsZipFile(f.Info.Name()) {
ff := r.File
f.BaseFile = ff.Base()
// scan zip files with a different context that is not cancellable
// cancelling while scanning zip file contents results in the scan
// contents being partially completed
zipCtx := context.WithoutCancel(ctx)
if err := j.scanZipFile(zipCtx, f, progress); err != nil {
logger.Errorf("Error scanning zip file %q: %v", f.Path, err)
}
}
return nil
}
func (j *ScanJob) scanZipFile(ctx context.Context, f file.ScannedFile, progress *job.Progress) error {
zipFS, err := f.FS.OpenZip(f.Path, f.Size)
if err != nil {
if errors.Is(err, file.ErrNotReaderAt) {
// can't walk the zip file
// just return
logger.Debugf("Skipping zip file %q as it cannot be opened for walking", f.Path)
return nil
}
return err
}
defer zipFS.Close()
return file.SymWalk(zipFS, f.Path, j.queueFileFunc(ctx, zipFS, &f, progress))
}
type extensionConfig struct {
vidExt []string
imgExt []string
@@ -724,29 +463,6 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model
}
}
if t.ScanGenerateImagePhashes {
progress.AddTotal(1)
phashFn := func(ctx context.Context) {
mgr := GetInstance()
// Only generate phash for image files, not video files
if imageFile, ok := f.(*models.ImageFile); ok {
taskPhash := GenerateImagePhashTask{
repository: mgr.Repository,
File: imageFile,
Overwrite: overwrite,
}
taskPhash.Start(ctx)
}
progress.Increment()
}
if g.sequentialScanning {
phashFn(ctx)
} else {
g.taskQueue.Add(fmt.Sprintf("Generating phash for %s", path), phashFn)
}
}
return nil
}

View File

@@ -3,10 +3,6 @@ package file
import (
"context"
"fmt"
"io/fs"
"os"
"time"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
@@ -39,23 +35,3 @@ func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithDatabase(ctx, r.TxnManager, fn)
}
// ModTime returns the modification time truncated to seconds.
func ModTime(info fs.FileInfo) time.Time {
// truncate to seconds, since we don't store beyond that in the database
return info.ModTime().Truncate(time.Second)
}
// GetFileSize gets the size of the file, taking into account symlinks.
func GetFileSize(f models.FS, path string, info fs.FileInfo) (int64, error) {
// #2196/#3042 - replace size with target size if file is a symlink
if info.Mode()&os.ModeSymlink == os.ModeSymlink {
targetInfo, err := f.Stat(path)
if err != nil {
return 0, fmt.Errorf("reading info for symlink %q: %w", path, err)
}
return targetInfo.Size(), nil
}
return info.Size(), nil
}

View File

@@ -75,7 +75,7 @@ func (d *folderRenameDetector) bestCandidate() *models.Folder {
return best.folder
}
func (s *Scanner) detectFolderMove(ctx context.Context, file ScannedFile) (*models.Folder, error) {
func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.Folder, error) {
// in order for a folder to be considered moved, the existing folder must be
// missing, and the majority of the old folder's files must be present, unchanged,
// in the new folder.
@@ -88,7 +88,7 @@ func (s *Scanner) detectFolderMove(ctx context.Context, file ScannedFile) (*mode
r := s.Repository
if err := SymWalk(file.FS, file.Path, func(path string, d fs.DirEntry, err error) error {
if err := symWalk(file.fs, file.Path, func(path string, d fs.DirEntry, err error) error {
if err != nil {
// don't let errors prevent scanning
logger.Errorf("error scanning %s: %v", path, err)
@@ -111,11 +111,11 @@ func (s *Scanner) detectFolderMove(ctx context.Context, file ScannedFile) (*mode
return nil
}
if !s.AcceptEntry(ctx, path, info) {
if !s.acceptEntry(ctx, path, info) {
return nil
}
size, err := GetFileSize(file.FS, path, info)
size, err := getFileSize(file.fs, path, info)
if err != nil {
return fmt.Errorf("getting file size for %q: %w", path, err)
}
@@ -154,7 +154,7 @@ func (s *Scanner) detectFolderMove(ctx context.Context, file ScannedFile) (*mode
}
// parent folder must be missing
_, err = file.FS.Lstat(pf.Path)
_, err = file.fs.Lstat(pf.Path)
if err == nil {
// parent folder exists, not a candidate
detector.reject(parentFolderID)

View File

@@ -2,18 +2,29 @@ package file
import (
"context"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"runtime/debug"
"strings"
"sync"
"time"
"github.com/remeh/sizedwaitgroup"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
const (
scanQueueSize = 200000
// maximum number of times to retry in the event of a locked database
// use -1 to retry forever
maxRetries = -1
)
// Scanner scans files into the database.
//
// The scan process works using two goroutines. The first walks through the provided paths
@@ -44,26 +55,8 @@ type Scanner struct {
Repository Repository
FingerprintCalculator FingerprintCalculator
// ZipFileExtensions is a list of file extensions that are considered zip files.
// Extension does not include the . character.
ZipFileExtensions []string
// ScanFilters are used to determine if a file should be scanned.
ScanFilters []PathFilter
// HandlerRequiredFilters are used to determine if an unchanged file needs to be handled
HandlerRequiredFilters []Filter
// FileDecorators are applied to files as they are scanned.
FileDecorators []Decorator
// handlers are called after a file has been scanned.
FileHandlers []Handler
// Rescan indicates whether files should be rescanned even if they haven't changed.
Rescan bool
folderPathToID sync.Map
}
// FingerprintCalculator calculates a fingerprint for the provided file.
@@ -98,18 +91,257 @@ func (d *FilteredDecorator) IsMissingMetadata(ctx context.Context, fs models.FS,
return false
}
// ScannedFile represents a file being scanned.
type ScannedFile struct {
*models.BaseFile
FS models.FS
Info fs.FileInfo
// ProgressReporter is used to report progress of the scan.
type ProgressReporter interface {
AddTotal(total int)
Increment()
Definite()
ExecuteTask(description string, fn func())
}
// AcceptEntry determines if the file entry should be accepted for scanning
func (s *Scanner) AcceptEntry(ctx context.Context, path string, info fs.FileInfo) bool {
type scanJob struct {
*Scanner
// handlers are called after a file has been scanned.
handlers []Handler
ProgressReports ProgressReporter
options ScanOptions
startTime time.Time
fileQueue chan scanFile
retryList []scanFile
retrying bool
folderPathToID sync.Map
zipPathToID sync.Map
count int
txnRetryer txn.Retryer
}
// ScanOptions provides options for scanning files.
type ScanOptions struct {
Paths []string
// ZipFileExtensions is a list of file extensions that are considered zip files.
// Extension does not include the . character.
ZipFileExtensions []string
// ScanFilters are used to determine if a file should be scanned.
ScanFilters []PathFilter
// HandlerRequiredFilters are used to determine if an unchanged file needs to be handled
HandlerRequiredFilters []Filter
ParallelTasks int
// When true files in path will be rescanned even if they haven't changed
Rescan bool
}
// Scan starts the scanning process.
func (s *Scanner) Scan(ctx context.Context, handlers []Handler, options ScanOptions, progressReporter ProgressReporter) {
job := &scanJob{
Scanner: s,
handlers: handlers,
ProgressReports: progressReporter,
options: options,
txnRetryer: txn.Retryer{
Manager: s.Repository.TxnManager,
Retries: maxRetries,
},
}
job.execute(ctx)
}
type scanFile struct {
*models.BaseFile
fs models.FS
info fs.FileInfo
}
func (s *scanJob) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
return s.txnRetryer.WithTxn(ctx, fn)
}
func (s *scanJob) withDB(ctx context.Context, fn func(ctx context.Context) error) error {
return s.Repository.WithDB(ctx, fn)
}
func (s *scanJob) execute(ctx context.Context) {
paths := s.options.Paths
logger.Infof("scanning %d paths", len(paths))
s.startTime = time.Now()
s.fileQueue = make(chan scanFile, scanQueueSize)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer func() {
wg.Done()
// handle panics in goroutine
if p := recover(); p != nil {
logger.Errorf("panic while queuing files for scan: %v", p)
logger.Errorf(string(debug.Stack()))
}
}()
if err := s.queueFiles(ctx, paths); err != nil {
if errors.Is(err, context.Canceled) {
return
}
logger.Errorf("error queuing files for scan: %v", err)
return
}
logger.Infof("Finished adding files to queue. %d files queued", s.count)
}()
defer wg.Wait()
if err := s.processQueue(ctx); err != nil {
if errors.Is(err, context.Canceled) {
return
}
logger.Errorf("error scanning files: %v", err)
return
}
}
func (s *scanJob) queueFiles(ctx context.Context, paths []string) error {
defer func() {
close(s.fileQueue)
if s.ProgressReports != nil {
s.ProgressReports.AddTotal(s.count)
s.ProgressReports.Definite()
}
}()
var err error
s.ProgressReports.ExecuteTask("Walking directory tree", func() {
for _, p := range paths {
err = symWalk(s.FS, p, s.queueFileFunc(ctx, s.FS, nil))
if err != nil {
return
}
}
})
return err
}
func (s *scanJob) queueFileFunc(ctx context.Context, f models.FS, zipFile *scanFile) fs.WalkDirFunc {
return func(path string, d fs.DirEntry, err error) error {
if err != nil {
// don't let errors prevent scanning
logger.Errorf("error scanning %s: %v", path, err)
return nil
}
if err = ctx.Err(); err != nil {
return err
}
info, err := d.Info()
if err != nil {
logger.Errorf("reading info for %q: %v", path, err)
return nil
}
if !s.acceptEntry(ctx, path, info) {
if info.IsDir() {
return fs.SkipDir
}
return nil
}
size, err := getFileSize(f, path, info)
if err != nil {
return err
}
ff := scanFile{
BaseFile: &models.BaseFile{
DirEntry: models.DirEntry{
ModTime: modTime(info),
},
Path: path,
Basename: filepath.Base(path),
Size: size,
},
fs: f,
info: info,
}
if zipFile != nil {
zipFileID, err := s.getZipFileID(ctx, zipFile)
if err != nil {
return err
}
ff.ZipFileID = zipFileID
ff.ZipFile = zipFile
}
if info.IsDir() {
// handle folders immediately
if err := s.handleFolder(ctx, ff); err != nil {
if !errors.Is(err, context.Canceled) {
logger.Errorf("error processing %q: %v", path, err)
}
// skip the directory since we won't be able to process the files anyway
return fs.SkipDir
}
return nil
}
// if zip file is present, we handle immediately
if zipFile != nil {
s.ProgressReports.ExecuteTask("Scanning "+path, func() {
if err := s.handleFile(ctx, ff); err != nil {
if !errors.Is(err, context.Canceled) {
logger.Errorf("error processing %q: %v", path, err)
}
// don't return an error, just skip the file
}
})
return nil
}
s.fileQueue <- ff
s.count++
return nil
}
}
func getFileSize(f models.FS, path string, info fs.FileInfo) (int64, error) {
// #2196/#3042 - replace size with target size if file is a symlink
if info.Mode()&os.ModeSymlink == os.ModeSymlink {
targetInfo, err := f.Stat(path)
if err != nil {
return 0, fmt.Errorf("reading info for symlink %q: %w", path, err)
}
return targetInfo.Size(), nil
}
return info.Size(), nil
}
func (s *scanJob) acceptEntry(ctx context.Context, path string, info fs.FileInfo) bool {
// always accept if there's no filters
accept := len(s.ScanFilters) == 0
for _, filter := range s.ScanFilters {
accept := len(s.options.ScanFilters) == 0
for _, filter := range s.options.ScanFilters {
// accept if any filter accepts the file
if filter.Accept(ctx, path, info) {
accept = true
@@ -120,7 +352,102 @@ func (s *Scanner) AcceptEntry(ctx context.Context, path string, info fs.FileInfo
return accept
}
func (s *Scanner) getFolderID(ctx context.Context, path string) (*models.FolderID, error) {
func (s *scanJob) scanZipFile(ctx context.Context, f scanFile) error {
zipFS, err := f.fs.OpenZip(f.Path, f.Size)
if err != nil {
if errors.Is(err, errNotReaderAt) {
// can't walk the zip file
// just return
return nil
}
return err
}
defer zipFS.Close()
return symWalk(zipFS, f.Path, s.queueFileFunc(ctx, zipFS, &f))
}
func (s *scanJob) processQueue(ctx context.Context) error {
parallelTasks := s.options.ParallelTasks
if parallelTasks < 1 {
parallelTasks = 1
}
wg := sizedwaitgroup.New(parallelTasks)
if err := func() error {
defer wg.Wait()
for f := range s.fileQueue {
if err := ctx.Err(); err != nil {
return err
}
wg.Add()
ff := f
go func() {
defer wg.Done()
s.processQueueItem(ctx, ff)
}()
}
return nil
}(); err != nil {
return err
}
s.retrying = true
if err := func() error {
defer wg.Wait()
for _, f := range s.retryList {
if err := ctx.Err(); err != nil {
return err
}
wg.Add()
ff := f
go func() {
defer wg.Done()
s.processQueueItem(ctx, ff)
}()
}
return nil
}(); err != nil {
return err
}
return nil
}
func (s *scanJob) incrementProgress(f scanFile) {
// don't increment for files inside zip files since these aren't
// counted during the initial walking
if s.ProgressReports != nil && f.ZipFile == nil {
s.ProgressReports.Increment()
}
}
func (s *scanJob) processQueueItem(ctx context.Context, f scanFile) {
s.ProgressReports.ExecuteTask("Scanning "+f.Path, func() {
var err error
if f.info.IsDir() {
err = s.handleFolder(ctx, f)
} else {
err = s.handleFile(ctx, f)
}
if err != nil && !errors.Is(err, context.Canceled) {
logger.Errorf("error processing %q: %v", f.Path, err)
}
})
}
func (s *scanJob) getFolderID(ctx context.Context, path string) (*models.FolderID, error) {
// check the folder cache first
if f, ok := s.folderPathToID.Load(path); ok {
v := f.(models.FolderID)
@@ -143,17 +470,48 @@ func (s *Scanner) getFolderID(ctx context.Context, path string) (*models.FolderI
return &ret.ID, nil
}
// ScanFolder scans the provided folder into the database, returning the folder entry.
// If the folder already exists, it is updated if necessary.
func (s *Scanner) ScanFolder(ctx context.Context, file ScannedFile) (*models.Folder, error) {
var f *models.Folder
var err error
func (s *scanJob) getZipFileID(ctx context.Context, zipFile *scanFile) (*models.FileID, error) {
if zipFile == nil {
return nil, nil
}
if zipFile.ID != 0 {
return &zipFile.ID, nil
}
path := zipFile.Path
// check the folder cache first
if f, ok := s.zipPathToID.Load(path); ok {
v := f.(models.FileID)
return &v, nil
}
// assume case sensitive when searching for the zip file
const caseSensitive = true
ret, err := s.Repository.File.FindByPath(ctx, path, caseSensitive)
if err != nil {
return nil, fmt.Errorf("getting zip file ID for %q: %w", path, err)
}
if ret == nil {
return nil, fmt.Errorf("zip file %q doesn't exist in database", zipFile.Path)
}
s.zipPathToID.Store(path, ret.Base().ID)
return &ret.Base().ID, nil
}
func (s *scanJob) handleFolder(ctx context.Context, file scanFile) error {
path := file.Path
err = s.Repository.WithTxn(ctx, func(ctx context.Context) error {
return s.withTxn(ctx, func(ctx context.Context) error {
defer s.incrementProgress(file)
// determine if folder already exists in data store (by path)
// assume case sensitive by default
f, err = s.Repository.Folder.FindByPath(ctx, path, true)
f, err := s.Repository.Folder.FindByPath(ctx, path, true)
if err != nil {
return fmt.Errorf("checking for existing folder %q: %w", path, err)
}
@@ -162,7 +520,7 @@ func (s *Scanner) ScanFolder(ctx context.Context, file ScannedFile) (*models.Fol
// case insensitive searching
// assume case sensitive if in zip
if f == nil && file.ZipFileID == nil {
caseSensitive, _ := file.FS.IsPathCaseSensitive(file.Path)
caseSensitive, _ := file.fs.IsPathCaseSensitive(file.Path)
if !caseSensitive {
f, err = s.Repository.Folder.FindByPath(ctx, path, false)
@@ -189,11 +547,9 @@ func (s *Scanner) ScanFolder(ctx context.Context, file ScannedFile) (*models.Fol
return nil
})
return f, err
}
func (s *Scanner) onNewFolder(ctx context.Context, file ScannedFile) (*models.Folder, error) {
func (s *scanJob) onNewFolder(ctx context.Context, file scanFile) (*models.Folder, error) {
renamed, err := s.handleFolderRename(ctx, file)
if err != nil {
return nil, err
@@ -240,7 +596,7 @@ func (s *Scanner) onNewFolder(ctx context.Context, file ScannedFile) (*models.Fo
return toCreate, nil
}
func (s *Scanner) handleFolderRename(ctx context.Context, file ScannedFile) (*models.Folder, error) {
func (s *scanJob) handleFolderRename(ctx context.Context, file scanFile) (*models.Folder, error) {
// ignore folders in zip files
if file.ZipFileID != nil {
return nil, nil
@@ -281,7 +637,7 @@ func (s *Scanner) handleFolderRename(ctx context.Context, file ScannedFile) (*mo
return renamedFrom, nil
}
func (s *Scanner) onExistingFolder(ctx context.Context, f ScannedFile, existing *models.Folder) (*models.Folder, error) {
func (s *scanJob) onExistingFolder(ctx context.Context, f scanFile, existing *models.Folder) (*models.Folder, error) {
update := false
// update if mod time is changed
@@ -322,22 +678,22 @@ func (s *Scanner) onExistingFolder(ctx context.Context, f ScannedFile, existing
return existing, nil
}
type ScanFileResult struct {
File models.File
New bool
Renamed bool
Updated bool
func modTime(info fs.FileInfo) time.Time {
// truncate to seconds, since we don't store beyond that in the database
return info.ModTime().Truncate(time.Second)
}
// ScanFile scans the provided file into the database, returning the scan result.
func (s *Scanner) ScanFile(ctx context.Context, f ScannedFile) (*ScanFileResult, error) {
var r *ScanFileResult
func (s *scanJob) handleFile(ctx context.Context, f scanFile) error {
defer s.incrementProgress(f)
var ff models.File
// don't use a transaction to check if new or existing
if err := s.Repository.WithDB(ctx, func(ctx context.Context) error {
if err := s.withDB(ctx, func(ctx context.Context) error {
// determine if file already exists in data store
// assume case sensitive when searching for the file to begin with
ff, err := s.Repository.File.FindByPath(ctx, f.Path, true)
var err error
ff, err = s.Repository.File.FindByPath(ctx, f.Path, true)
if err != nil {
return fmt.Errorf("checking for existing file %q: %w", f.Path, err)
}
@@ -346,7 +702,7 @@ func (s *Scanner) ScanFile(ctx context.Context, f ScannedFile) (*ScanFileResult,
// case insensitive search
// assume case sensitive if in zip
if ff == nil && f.ZipFileID != nil {
caseSensitive, _ := f.FS.IsPathCaseSensitive(f.Path)
caseSensitive, _ := f.fs.IsPathCaseSensitive(f.Path)
if !caseSensitive {
ff, err = s.Repository.File.FindByPath(ctx, f.Path, false)
@@ -358,23 +714,35 @@ func (s *Scanner) ScanFile(ctx context.Context, f ScannedFile) (*ScanFileResult,
if ff == nil {
// returns a file only if it is actually new
r, err = s.onNewFile(ctx, f)
ff, err = s.onNewFile(ctx, f)
return err
}
r, err = s.onExistingFile(ctx, f, ff)
ff, err = s.onExistingFile(ctx, f, ff)
return err
}); err != nil {
return nil, err
return err
}
return r, nil
if ff != nil && s.isZipFile(f.info.Name()) {
f.BaseFile = ff.Base()
// scan zip files with a different context that is not cancellable
// cancelling while scanning zip file contents results in the scan
// contents being partially completed
zipCtx := context.WithoutCancel(ctx)
if err := s.scanZipFile(zipCtx, f); err != nil {
logger.Errorf("Error scanning zip file %q: %v", f.Path, err)
}
}
return nil
}
// IsZipFile determines if the provided path is a zip file based on its extension.
func (s *Scanner) IsZipFile(path string) bool {
func (s *scanJob) isZipFile(path string) bool {
fExt := filepath.Ext(path)
for _, ext := range s.ZipFileExtensions {
for _, ext := range s.options.ZipFileExtensions {
if strings.EqualFold(fExt, "."+ext) {
return true
}
@@ -383,7 +751,7 @@ func (s *Scanner) IsZipFile(path string) bool {
return false
}
func (s *Scanner) onNewFile(ctx context.Context, f ScannedFile) (*ScanFileResult, error) {
func (s *scanJob) onNewFile(ctx context.Context, f scanFile) (models.File, error) {
now := time.Now()
baseFile := f.BaseFile
@@ -399,20 +767,28 @@ func (s *Scanner) onNewFile(ctx context.Context, f ScannedFile) (*ScanFileResult
}
if parentFolderID == nil {
return nil, fmt.Errorf("parent folder for %q doesn't exist", path)
// if parent folder doesn't exist, assume it's not yet created
// add this file to the queue to be created later
if s.retrying {
// if we're retrying and the folder still doesn't exist, then it's a problem
return nil, fmt.Errorf("parent folder for %q doesn't exist", path)
}
s.retryList = append(s.retryList, f)
return nil, nil
}
baseFile.ParentFolderID = *parentFolderID
const useExisting = false
fp, err := s.calculateFingerprints(f.FS, baseFile, path, useExisting)
fp, err := s.calculateFingerprints(f.fs, baseFile, path, useExisting)
if err != nil {
return nil, err
}
baseFile.SetFingerprints(fp)
file, err := s.fireDecorators(ctx, f.FS, baseFile)
file, err := s.fireDecorators(ctx, f.fs, baseFile)
if err != nil {
return nil, err
}
@@ -425,17 +801,14 @@ func (s *Scanner) onNewFile(ctx context.Context, f ScannedFile) (*ScanFileResult
}
if renamed != nil {
return &ScanFileResult{
File: renamed,
Renamed: true,
}, nil
// handle rename should have already handled the contents of the zip file
// so shouldn't need to scan it again
// return nil so it doesn't
return nil, nil
}
// if not renamed, queue file for creation
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.File.Create(ctx, file); err != nil {
return fmt.Errorf("creating file %q: %w", path, err)
}
@@ -449,13 +822,10 @@ func (s *Scanner) onNewFile(ctx context.Context, f ScannedFile) (*ScanFileResult
return nil, err
}
return &ScanFileResult{
File: file,
New: true,
}, nil
return file, nil
}
func (s *Scanner) fireDecorators(ctx context.Context, fs models.FS, f models.File) (models.File, error) {
func (s *scanJob) fireDecorators(ctx context.Context, fs models.FS, f models.File) (models.File, error) {
for _, h := range s.FileDecorators {
var err error
f, err = h.Decorate(ctx, fs, f)
@@ -467,8 +837,8 @@ func (s *Scanner) fireDecorators(ctx context.Context, fs models.FS, f models.Fil
return f, nil
}
func (s *Scanner) fireHandlers(ctx context.Context, f models.File, oldFile models.File) error {
for _, h := range s.FileHandlers {
func (s *scanJob) fireHandlers(ctx context.Context, f models.File, oldFile models.File) error {
for _, h := range s.handlers {
if err := h.Handle(ctx, f, oldFile); err != nil {
return err
}
@@ -477,7 +847,7 @@ func (s *Scanner) fireHandlers(ctx context.Context, f models.File, oldFile model
return nil
}
func (s *Scanner) calculateFingerprints(fs models.FS, f *models.BaseFile, path string, useExisting bool) (models.Fingerprints, error) {
func (s *scanJob) calculateFingerprints(fs models.FS, f *models.BaseFile, path string, useExisting bool) (models.Fingerprints, error) {
// only log if we're (re)calculating fingerprints
if !useExisting {
logger.Infof("Calculating fingerprints for %s ...", path)
@@ -514,7 +884,7 @@ func appendFileUnique(v []models.File, toAdd []models.File) []models.File {
return v
}
func (s *Scanner) getFileFS(f *models.BaseFile) (models.FS, error) {
func (s *scanJob) getFileFS(f *models.BaseFile) (models.FS, error) {
if f.ZipFile == nil {
return s.FS, nil
}
@@ -529,7 +899,7 @@ func (s *Scanner) getFileFS(f *models.BaseFile) (models.FS, error) {
return fs.OpenZip(zipPath, zipSize)
}
func (s *Scanner) handleRename(ctx context.Context, f models.File, fp []models.Fingerprint) (models.File, error) {
func (s *scanJob) handleRename(ctx context.Context, f models.File, fp []models.Fingerprint) (models.File, error) {
var others []models.File
for _, tfp := range fp {
@@ -571,7 +941,7 @@ func (s *Scanner) handleRename(ctx context.Context, f models.File, fp []models.F
// treat as a move
missing = append(missing, other)
}
case !s.AcceptEntry(ctx, other.Base().Path, info):
case !s.acceptEntry(ctx, other.Base().Path, info):
// #4393 - if the file is no longer in the configured library paths, treat it as a move
logger.Debugf("File %q no longer in library paths. Treating as a move.", other.Base().Path)
missing = append(missing, other)
@@ -604,12 +974,12 @@ func (s *Scanner) handleRename(ctx context.Context, f models.File, fp []models.F
fBaseCopy.Fingerprints = updatedBase.Fingerprints
*updatedBase = fBaseCopy
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.File.Update(ctx, updated); err != nil {
return fmt.Errorf("updating file for rename %q: %w", newPath, err)
}
if s.IsZipFile(updatedBase.Basename) {
if s.isZipFile(updatedBase.Basename) {
if err := transferZipHierarchy(ctx, s.Repository.Folder, s.Repository.File, updatedBase.ID, oldPath, newPath); err != nil {
return fmt.Errorf("moving zip hierarchy for renamed zip file %q: %w", newPath, err)
}
@@ -627,9 +997,9 @@ func (s *Scanner) handleRename(ctx context.Context, f models.File, fp []models.F
return updated, nil
}
func (s *Scanner) isHandlerRequired(ctx context.Context, f models.File) bool {
accept := len(s.HandlerRequiredFilters) == 0
for _, filter := range s.HandlerRequiredFilters {
func (s *scanJob) isHandlerRequired(ctx context.Context, f models.File) bool {
accept := len(s.options.HandlerRequiredFilters) == 0
for _, filter := range s.options.HandlerRequiredFilters {
// accept if any filter accepts the file
if filter.Accept(ctx, f) {
accept = true
@@ -648,9 +1018,9 @@ func (s *Scanner) isHandlerRequired(ctx context.Context, f models.File) bool {
// - file size
// - image format, width or height
// - video codec, audio codec, format, width, height, framerate or bitrate
func (s *Scanner) isMissingMetadata(ctx context.Context, f ScannedFile, existing models.File) bool {
func (s *scanJob) isMissingMetadata(ctx context.Context, f scanFile, existing models.File) bool {
for _, h := range s.FileDecorators {
if h.IsMissingMetadata(ctx, f.FS, existing) {
if h.IsMissingMetadata(ctx, f.fs, existing) {
return true
}
}
@@ -658,20 +1028,20 @@ func (s *Scanner) isMissingMetadata(ctx context.Context, f ScannedFile, existing
return false
}
func (s *Scanner) setMissingMetadata(ctx context.Context, f ScannedFile, existing models.File) (models.File, error) {
func (s *scanJob) setMissingMetadata(ctx context.Context, f scanFile, existing models.File) (models.File, error) {
path := existing.Base().Path
logger.Infof("Updating metadata for %s", path)
existing.Base().Size = f.Size
var err error
existing, err = s.fireDecorators(ctx, f.FS, existing)
existing, err = s.fireDecorators(ctx, f.fs, existing)
if err != nil {
return nil, err
}
// queue file for update
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", path, err)
}
@@ -684,9 +1054,9 @@ func (s *Scanner) setMissingMetadata(ctx context.Context, f ScannedFile, existin
return existing, nil
}
func (s *Scanner) setMissingFingerprints(ctx context.Context, f ScannedFile, existing models.File) (models.File, error) {
func (s *scanJob) setMissingFingerprints(ctx context.Context, f scanFile, existing models.File) (models.File, error) {
const useExisting = true
fp, err := s.calculateFingerprints(f.FS, existing.Base(), f.Path, useExisting)
fp, err := s.calculateFingerprints(f.fs, existing.Base(), f.Path, useExisting)
if err != nil {
return nil, err
}
@@ -694,7 +1064,7 @@ func (s *Scanner) setMissingFingerprints(ctx context.Context, f ScannedFile, exi
if fp.ContentsChanged(existing.Base().Fingerprints) {
existing.SetFingerprints(fp)
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", f.Path, err)
}
@@ -709,14 +1079,14 @@ func (s *Scanner) setMissingFingerprints(ctx context.Context, f ScannedFile, exi
}
// returns a file only if it was updated
func (s *Scanner) onExistingFile(ctx context.Context, f ScannedFile, existing models.File) (*ScanFileResult, error) {
func (s *scanJob) onExistingFile(ctx context.Context, f scanFile, existing models.File) (models.File, error) {
base := existing.Base()
path := base.Path
fileModTime := f.ModTime
// #6326 - also force a rescan if the basename changed
updated := !fileModTime.Equal(base.ModTime) || base.Basename != f.Basename
forceRescan := s.Rescan
forceRescan := s.options.Rescan
if !updated && !forceRescan {
return s.onUnchangedFile(ctx, f, existing)
@@ -738,7 +1108,7 @@ func (s *Scanner) onExistingFile(ctx context.Context, f ScannedFile, existing mo
// calculate and update fingerprints for the file
const useExisting = false
fp, err := s.calculateFingerprints(f.FS, base, path, useExisting)
fp, err := s.calculateFingerprints(f.fs, base, path, useExisting)
if err != nil {
return nil, err
}
@@ -746,13 +1116,13 @@ func (s *Scanner) onExistingFile(ctx context.Context, f ScannedFile, existing mo
s.removeOutdatedFingerprints(existing, fp)
existing.SetFingerprints(fp)
existing, err = s.fireDecorators(ctx, f.FS, existing)
existing, err = s.fireDecorators(ctx, f.fs, existing)
if err != nil {
return nil, err
}
// queue file for update
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", path, err)
}
@@ -765,13 +1135,11 @@ func (s *Scanner) onExistingFile(ctx context.Context, f ScannedFile, existing mo
}); err != nil {
return nil, err
}
return &ScanFileResult{
File: existing,
Updated: true,
}, nil
return existing, nil
}
func (s *Scanner) removeOutdatedFingerprints(existing models.File, fp models.Fingerprints) {
func (s *scanJob) removeOutdatedFingerprints(existing models.File, fp models.Fingerprints) {
// HACK - if no MD5 fingerprint was returned, and the oshash is changed
// then remove the MD5 fingerprint
oshash := fp.For(models.FingerprintTypeOshash)
@@ -799,7 +1167,7 @@ func (s *Scanner) removeOutdatedFingerprints(existing models.File, fp models.Fin
}
// returns a file only if it was updated
func (s *Scanner) onUnchangedFile(ctx context.Context, f ScannedFile, existing models.File) (*ScanFileResult, error) {
func (s *scanJob) onUnchangedFile(ctx context.Context, f scanFile, existing models.File) (models.File, error) {
var err error
isMissingMetdata := s.isMissingMetadata(ctx, f, existing)
@@ -818,7 +1186,7 @@ func (s *Scanner) onUnchangedFile(ctx context.Context, f ScannedFile, existing m
}
handlerRequired := false
if err := s.Repository.WithDB(ctx, func(ctx context.Context) error {
if err := s.withDB(ctx, func(ctx context.Context) error {
// check if the handler needs to be run
handlerRequired = s.isHandlerRequired(ctx, existing)
return nil
@@ -828,20 +1196,15 @@ func (s *Scanner) onUnchangedFile(ctx context.Context, f ScannedFile, existing m
if !handlerRequired {
// if this file is a zip file, then we need to rescan the contents
// as well. We do this by indicating that the file is updated.
// as well. We do this by returning the file, instead of nil.
if isMissingMetdata {
return &ScanFileResult{
File: existing,
Updated: true,
}, nil
return existing, nil
}
return &ScanFileResult{
File: existing,
}, nil
return nil, nil
}
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.fireHandlers(ctx, existing, nil); err != nil {
return err
}
@@ -852,9 +1215,6 @@ func (s *Scanner) onUnchangedFile(ctx context.Context, f ScannedFile, existing m
}
// if this file is a zip file, then we need to rescan the contents
// as well. We do this by indicating that the file is updated.
return &ScanFileResult{
File: existing,
Updated: true,
}, nil
// as well. We do this by returning the file, instead of nil.
return existing, nil
}

View File

@@ -81,8 +81,8 @@ func walkSym(f models.FS, filename string, linkDirname string, walkFn fs.WalkDir
return fsWalk(f, filename, symWalkFunc)
}
// SymWalk extends filepath.Walk to also follow symlinks
func SymWalk(fs models.FS, path string, walkFn fs.WalkDirFunc) error {
// symWalk extends filepath.Walk to also follow symlinks
func symWalk(fs models.FS, path string, walkFn fs.WalkDirFunc) error {
return walkSym(fs, path, path, walkFn)
}

View File

@@ -18,7 +18,7 @@ import (
)
var (
ErrNotReaderAt = errors.New("invalid reader: does not implement io.ReaderAt")
errNotReaderAt = errors.New("not a ReaderAt")
errZipFSOpenZip = errors.New("cannot open zip file inside zip file")
)
@@ -38,7 +38,7 @@ func newZipFS(fs models.FS, path string, size int64) (*zipFS, error) {
asReaderAt, _ := reader.(io.ReaderAt)
if asReaderAt == nil {
reader.Close()
return nil, ErrNotReaderAt
return nil, errNotReaderAt
}
zipReader, err := zip.NewReader(asReaderAt, size)

View File

@@ -8,13 +8,13 @@ import (
"github.com/stashapp/stash/pkg/models"
)
func (s *Service) Destroy(ctx context.Context, i *models.Gallery, fileDeleter *image.FileDeleter, deleteGenerated, deleteFile, destroyFileEntry bool) ([]*models.Image, error) {
func (s *Service) Destroy(ctx context.Context, i *models.Gallery, fileDeleter *image.FileDeleter, deleteGenerated, deleteFile bool) ([]*models.Image, error) {
var imgsDestroyed []*models.Image
// chapter deletion is done via delete cascade, so we don't need to do anything here
// if this is a zip-based gallery, delete the images as well first
zipImgsDestroyed, err := s.destroyZipFileImages(ctx, i, fileDeleter, deleteGenerated, deleteFile, destroyFileEntry)
zipImgsDestroyed, err := s.destroyZipFileImages(ctx, i, fileDeleter, deleteGenerated, deleteFile)
if err != nil {
return nil, err
}
@@ -45,7 +45,7 @@ func DestroyChapter(ctx context.Context, galleryChapter *models.GalleryChapter,
return qb.Destroy(ctx, galleryChapter.ID)
}
func (s *Service) destroyZipFileImages(ctx context.Context, i *models.Gallery, fileDeleter *image.FileDeleter, deleteGenerated, deleteFile, destroyFileEntry bool) ([]*models.Image, error) {
func (s *Service) destroyZipFileImages(ctx context.Context, i *models.Gallery, fileDeleter *image.FileDeleter, deleteGenerated, deleteFile bool) ([]*models.Image, error) {
if err := i.LoadFiles(ctx, s.Repository); err != nil {
return nil, err
}
@@ -81,12 +81,6 @@ func (s *Service) destroyZipFileImages(ctx context.Context, i *models.Gallery, f
if err := destroyer.DestroyZip(ctx, f, fileDeleter.Deleter, deleteFile); err != nil {
return nil, err
}
} else if destroyFileEntry {
// destroy file DB entry without deleting filesystem file
const deleteFileFromFS = false
if err := destroyer.DestroyZip(ctx, f, nil, deleteFileFromFS); err != nil {
return nil, err
}
}
}

View File

@@ -126,7 +126,7 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewCreateStudioInput()
newStudio := models.NewStudio()
newStudio.Name = name
err := i.StudioWriter.Create(ctx, &newStudio)

View File

@@ -115,9 +115,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s.Studio.ID = existingStudioID
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
err := i.PreImport(testCtx)
@@ -147,7 +147,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)

View File

@@ -16,7 +16,7 @@ type ImageFinder interface {
}
type ImageService interface {
Destroy(ctx context.Context, i *models.Image, fileDeleter *image.FileDeleter, deleteGenerated, deleteFile, destroyFileEntry bool) error
Destroy(ctx context.Context, i *models.Image, fileDeleter *image.FileDeleter, deleteGenerated, deleteFile bool) error
DestroyZipImages(ctx context.Context, zipFile models.File, fileDeleter *image.FileDeleter, deleteGenerated bool) ([]*models.Image, error)
DestroyFolderImages(ctx context.Context, folderID models.FolderID, fileDeleter *image.FileDeleter, deleteGenerated, deleteFile bool) ([]*models.Image, error)
}

View File

@@ -203,7 +203,7 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewCreateStudioInput()
newStudio := models.NewStudio()
newStudio.Name = name
err := i.StudioWriter.Create(ctx, &newStudio)

View File

@@ -121,9 +121,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s.Studio.ID = existingStudioID
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
err := i.PreImport(testCtx)
@@ -156,7 +156,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)

View File

@@ -1,48 +0,0 @@
package imagephash
import (
"bytes"
"fmt"
"image"
"github.com/corona10/goimagehash"
"github.com/stashapp/stash/pkg/file"
"github.com/stashapp/stash/pkg/models"
)
// Generate computes a perceptual hash for an image file.
func Generate(imageFile *models.ImageFile) (*uint64, error) {
img, err := loadImage(imageFile)
if err != nil {
return nil, fmt.Errorf("loading image: %w", err)
}
hash, err := goimagehash.PerceptionHash(img)
if err != nil {
return nil, fmt.Errorf("computing phash from image: %w", err)
}
hashValue := hash.GetHash()
return &hashValue, nil
}
// loadImage loads an image from disk and decodes it.
func loadImage(imageFile *models.ImageFile) (image.Image, error) {
reader, err := imageFile.Open(&file.OsFS{})
if err != nil {
return nil, err
}
defer reader.Close()
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(reader); err != nil {
return nil, err
}
img, _, err := image.Decode(buf)
if err != nil {
return nil, fmt.Errorf("decoding image: %w", err)
}
return img, nil
}

View File

@@ -37,8 +37,8 @@ func (d *FileDeleter) MarkGeneratedFiles(image *models.Image) error {
}
// Destroy destroys an image, optionally marking the file and generated files for deletion.
func (s *Service) Destroy(ctx context.Context, i *models.Image, fileDeleter *FileDeleter, deleteGenerated, deleteFile, destroyFileEntry bool) error {
return s.destroyImage(ctx, i, fileDeleter, deleteGenerated, deleteFile, destroyFileEntry)
func (s *Service) Destroy(ctx context.Context, i *models.Image, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error {
return s.destroyImage(ctx, i, fileDeleter, deleteGenerated, deleteFile)
}
// DestroyZipImages destroys all images in zip, optionally marking the files and generated files for deletion.
@@ -75,8 +75,7 @@ func (s *Service) DestroyZipImages(ctx context.Context, zipFile models.File, fil
}
const deleteFileInZip = false
const destroyFileEntry = false
if err := s.destroyImage(ctx, img, fileDeleter, deleteGenerated, deleteFileInZip, destroyFileEntry); err != nil {
if err := s.destroyImage(ctx, img, fileDeleter, deleteGenerated, deleteFileInZip); err != nil {
return nil, err
}
@@ -136,8 +135,7 @@ func (s *Service) DestroyFolderImages(ctx context.Context, folderID models.Folde
continue
}
const destroyFileEntry = false
if err := s.Destroy(ctx, img, fileDeleter, deleteGenerated, deleteFile, destroyFileEntry); err != nil {
if err := s.Destroy(ctx, img, fileDeleter, deleteGenerated, deleteFile); err != nil {
return nil, err
}
@@ -148,15 +146,11 @@ func (s *Service) DestroyFolderImages(ctx context.Context, folderID models.Folde
}
// Destroy destroys an image, optionally marking the file and generated files for deletion.
func (s *Service) destroyImage(ctx context.Context, i *models.Image, fileDeleter *FileDeleter, deleteGenerated, deleteFile, destroyFileEntry bool) error {
func (s *Service) destroyImage(ctx context.Context, i *models.Image, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error {
if deleteFile {
if err := s.deleteFiles(ctx, i, fileDeleter); err != nil {
return err
}
} else if destroyFileEntry {
if err := s.destroyFileEntries(ctx, i); err != nil {
return err
}
}
if deleteGenerated {
@@ -198,35 +192,3 @@ func (s *Service) deleteFiles(ctx context.Context, i *models.Image, fileDeleter
return nil
}
// destroyFileEntries destroys file entries from the database without deleting
// the files from the filesystem
func (s *Service) destroyFileEntries(ctx context.Context, i *models.Image) error {
if err := i.LoadFiles(ctx, s.Repository); err != nil {
return err
}
for _, f := range i.Files.List() {
// only destroy file entries where there is no other associated image
otherImages, err := s.Repository.FindByFileID(ctx, f.Base().ID)
if err != nil {
return err
}
if len(otherImages) > 1 {
// other image associated, don't remove
continue
}
// don't destroy files in zip archives
if f.Base().ZipFileID == nil {
const deleteFile = false
logger.Info("Destroying image file entry: ", f.Base().Path)
if err := file.Destroy(ctx, s.File, f, nil, deleteFile); err != nil {
return err
}
}
}
return nil
}

View File

@@ -159,7 +159,7 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewCreateStudioInput()
newStudio := models.NewStudio()
newStudio.Name = name
err := i.StudioWriter.Create(ctx, &newStudio)

View File

@@ -77,9 +77,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s.Studio.ID = existingStudioID
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
err := i.PreImport(testCtx)
@@ -109,7 +109,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)

View File

@@ -95,7 +95,6 @@ type GalleryDestroyInput struct {
// If true, then the zip file will be deleted if the gallery is zip-file-based.
// If gallery is folder-based, then any files not associated with other
// galleries will be deleted, along with the folder, if it is not empty.
DeleteFile *bool `json:"delete_file"`
DeleteGenerated *bool `json:"delete_generated"`
DestroyFileEntry *bool `json:"destroy_file_entry"`
DeleteFile *bool `json:"delete_file"`
DeleteGenerated *bool `json:"delete_generated"`
}

View File

@@ -11,8 +11,6 @@ type ImageFilterType struct {
Photographer *StringCriterionInput `json:"photographer"`
// Filter by file checksum
Checksum *StringCriterionInput `json:"checksum"`
// Filter by phash distance
PhashDistance *PhashDistanceCriterionInput `json:"phash_distance"`
// Filter by path
Path *StringCriterionInput `json:"path"`
// Filter by file count
@@ -90,17 +88,15 @@ type ImageUpdateInput struct {
}
type ImageDestroyInput struct {
ID string `json:"id"`
DeleteFile *bool `json:"delete_file"`
DeleteGenerated *bool `json:"delete_generated"`
DestroyFileEntry *bool `json:"destroy_file_entry"`
ID string `json:"id"`
DeleteFile *bool `json:"delete_file"`
DeleteGenerated *bool `json:"delete_generated"`
}
type ImagesDestroyInput struct {
Ids []string `json:"ids"`
DeleteFile *bool `json:"delete_file"`
DeleteGenerated *bool `json:"delete_generated"`
DestroyFileEntry *bool `json:"destroy_file_entry"`
Ids []string `json:"ids"`
DeleteFile *bool `json:"delete_file"`
DeleteGenerated *bool `json:"delete_generated"`
}
type ImageQueryOptions struct {

View File

@@ -80,11 +80,11 @@ func (_m *StudioReaderWriter) CountByTagID(ctx context.Context, tagID int) (int,
}
// Create provides a mock function with given fields: ctx, newStudio
func (_m *StudioReaderWriter) Create(ctx context.Context, newStudio *models.CreateStudioInput) error {
func (_m *StudioReaderWriter) Create(ctx context.Context, newStudio *models.Studio) error {
ret := _m.Called(ctx, newStudio)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *models.CreateStudioInput) error); ok {
if rf, ok := ret.Get(0).(func(context.Context, *models.Studio) error); ok {
r0 = rf(ctx, newStudio)
} else {
r0 = ret.Error(0)
@@ -291,52 +291,6 @@ func (_m *StudioReaderWriter) GetAliases(ctx context.Context, relatedID int) ([]
return r0, r1
}
// GetCustomFields provides a mock function with given fields: ctx, id
func (_m *StudioReaderWriter) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) {
ret := _m.Called(ctx, id)
var r0 map[string]interface{}
if rf, ok := ret.Get(0).(func(context.Context, int) map[string]interface{}); ok {
r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]interface{})
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetCustomFieldsBulk provides a mock function with given fields: ctx, ids
func (_m *StudioReaderWriter) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) {
ret := _m.Called(ctx, ids)
var r0 []models.CustomFieldMap
if rf, ok := ret.Get(0).(func(context.Context, []int) []models.CustomFieldMap); ok {
r0 = rf(ctx, ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]models.CustomFieldMap)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
r1 = rf(ctx, ids)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetImage provides a mock function with given fields: ctx, studioID
func (_m *StudioReaderWriter) GetImage(ctx context.Context, studioID int) ([]byte, error) {
ret := _m.Called(ctx, studioID)
@@ -525,11 +479,11 @@ func (_m *StudioReaderWriter) QueryForAutoTag(ctx context.Context, words []strin
}
// Update provides a mock function with given fields: ctx, updatedStudio
func (_m *StudioReaderWriter) Update(ctx context.Context, updatedStudio *models.UpdateStudioInput) error {
func (_m *StudioReaderWriter) Update(ctx context.Context, updatedStudio *models.Studio) error {
ret := _m.Called(ctx, updatedStudio)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *models.UpdateStudioInput) error); ok {
if rf, ok := ret.Get(0).(func(context.Context, *models.Studio) error); ok {
r0 = rf(ctx, updatedStudio)
} else {
r0 = ret.Error(0)

View File

@@ -27,9 +27,9 @@ type ScrapedStudio struct {
func (ScrapedStudio) IsScrapedContent() {}
func (s *ScrapedStudio) ToStudio(endpoint string, excluded map[string]bool) *CreateStudioInput {
func (s *ScrapedStudio) ToStudio(endpoint string, excluded map[string]bool) *Studio {
// Populate a new studio from the input
ret := NewCreateStudioInput()
ret := NewStudio()
ret.Name = strings.TrimSpace(s.Name)
if s.RemoteSiteID != nil && endpoint != "" && *s.RemoteSiteID != "" {

View File

@@ -113,7 +113,7 @@ func Test_scrapedToStudioInput(t *testing.T) {
got.StashIDs.List()[stid].UpdatedAt = time.Time{}
}
}
assert.Equal(t, tt.want, got.Studio)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -23,18 +23,6 @@ type Studio struct {
StashIDs RelatedStashIDs `json:"stash_ids"`
}
type CreateStudioInput struct {
*Studio
CustomFields map[string]interface{} `json:"custom_fields"`
}
type UpdateStudioInput struct {
*Studio
CustomFields CustomFieldsInput `json:"custom_fields"`
}
func NewStudio() Studio {
currentTime := time.Now()
return Studio{
@@ -43,13 +31,6 @@ func NewStudio() Studio {
}
}
func NewCreateStudioInput() CreateStudioInput {
s := NewStudio()
return CreateStudioInput{
Studio: &s,
}
}
// StudioPartial represents part of a Studio object. It is used to update the database entry.
type StudioPartial struct {
ID int
@@ -67,8 +48,6 @@ type StudioPartial struct {
URLs *UpdateStrings
TagIDs *UpdateIDs
StashIDs *UpdateStashIDs
CustomFields CustomFieldsInput
}
func NewStudioPartial() StudioPartial {

View File

@@ -42,12 +42,12 @@ type StudioCounter interface {
// StudioCreator provides methods to create studios.
type StudioCreator interface {
Create(ctx context.Context, newStudio *CreateStudioInput) error
Create(ctx context.Context, newStudio *Studio) error
}
// StudioUpdater provides methods to update studios.
type StudioUpdater interface {
Update(ctx context.Context, updatedStudio *UpdateStudioInput) error
Update(ctx context.Context, updatedStudio *Studio) error
UpdatePartial(ctx context.Context, updatedStudio StudioPartial) (*Studio, error)
UpdateImage(ctx context.Context, studioID int, image []byte) error
}
@@ -79,8 +79,6 @@ type StudioReader interface {
TagIDLoader
URLLoader
CustomFieldsReader
All(ctx context.Context) ([]*Studio, error)
GetImage(ctx context.Context, studioID int) ([]byte, error)
HasImage(ctx context.Context, studioID int) (bool, error)

View File

@@ -81,8 +81,6 @@ type SceneFilterType struct {
StashIDEndpoint *StashIDCriterionInput `json:"stash_id_endpoint"`
// Filter by StashIDs Endpoint
StashIDsEndpoint *StashIDsCriterionInput `json:"stash_ids_endpoint"`
// Filter by StashID count
StashIDCount *IntCriterionInput `json:"stash_id_count"`
// Filter by url
URL *StringCriterionInput `json:"url"`
// Filter by interactive
@@ -206,17 +204,15 @@ type SceneUpdateInput struct {
}
type SceneDestroyInput struct {
ID string `json:"id"`
DeleteFile *bool `json:"delete_file"`
DeleteGenerated *bool `json:"delete_generated"`
DestroyFileEntry *bool `json:"destroy_file_entry"`
ID string `json:"id"`
DeleteFile *bool `json:"delete_file"`
DeleteGenerated *bool `json:"delete_generated"`
}
type ScenesDestroyInput struct {
Ids []string `json:"ids"`
DeleteFile *bool `json:"delete_file"`
DeleteGenerated *bool `json:"delete_generated"`
DestroyFileEntry *bool `json:"destroy_file_entry"`
Ids []string `json:"ids"`
DeleteFile *bool `json:"delete_file"`
DeleteGenerated *bool `json:"delete_generated"`
}
func NewSceneQueryResult(getter SceneGetter) *SceneQueryResult {

View File

@@ -46,9 +46,6 @@ type StudioFilterType struct {
CreatedAt *TimestampCriterionInput `json:"created_at"`
// Filter by updated at
UpdatedAt *TimestampCriterionInput `json:"updated_at"`
// Filter by custom fields
CustomFields []CustomFieldCriterionInput `json:"custom_fields"`
}
type StudioCreateInput struct {
@@ -65,8 +62,6 @@ type StudioCreateInput struct {
Aliases []string `json:"aliases"`
TagIds []string `json:"tag_ids"`
IgnoreAutoTag *bool `json:"ignore_auto_tag"`
CustomFields map[string]interface{} `json:"custom_fields"`
}
type StudioUpdateInput struct {
@@ -84,6 +79,4 @@ type StudioUpdateInput struct {
Aliases []string `json:"aliases"`
TagIds []string `json:"tag_ids"`
IgnoreAutoTag *bool `json:"ignore_auto_tag"`
CustomFields CustomFieldsInput `json:"custom_fields"`
}

View File

@@ -225,11 +225,6 @@ func ValidateUpdateAliases(existing models.Performer, name models.OptionalString
newName = name.Value
}
// If aliases is nil, we're only changing the name - check existing aliases against new name
if aliases == nil {
return ValidateAliases(newName, existing.Aliases)
}
newAliases := aliases.Apply(existing.Aliases.List())
return ValidateAliases(newName, models.NewRelatedStrings(newAliases))

View File

@@ -213,12 +213,12 @@ func TestValidateUpdateAliases(t *testing.T) {
want error
}{
{"both unset", osUnset, nil, nil},
{"name conflicts with alias", os2, nil, &DuplicateAliasError{name2}},
{"invalid name set", os2, nil, &DuplicateAliasError{name2}},
{"valid name set", os3, nil, nil},
{"valid aliases empty", os1, []string{}, nil},
{"alias matches name", osUnset, []string{name1U}, &DuplicateAliasError{name1U}},
{"invalid aliases set", osUnset, []string{name1U}, &DuplicateAliasError{name1U}},
{"valid aliases set", osUnset, []string{name3, name2}, nil},
{"alias matches new name", os4, []string{name4}, &DuplicateAliasError{name4}},
{"invalid both set", os4, []string{name4}, &DuplicateAliasError{name4}},
{"valid both set", os2, []string{name1}, nil},
}

View File

@@ -109,7 +109,7 @@ func (d *FileDeleter) MarkMarkerFiles(scene *models.Scene, seconds int) error {
// Destroy deletes a scene and its associated relationships from the
// database.
func (s *Service) Destroy(ctx context.Context, scene *models.Scene, fileDeleter *FileDeleter, deleteGenerated, deleteFile, destroyFileEntry bool) error {
func (s *Service) Destroy(ctx context.Context, scene *models.Scene, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error {
mqb := s.MarkerRepository
markers, err := mqb.FindBySceneID(ctx, scene.ID)
if err != nil {
@@ -126,10 +126,6 @@ func (s *Service) Destroy(ctx context.Context, scene *models.Scene, fileDeleter
if err := s.deleteFiles(ctx, scene, fileDeleter); err != nil {
return err
}
} else if destroyFileEntry {
if err := s.destroyFileEntries(ctx, scene); err != nil {
return err
}
}
if deleteGenerated {
@@ -184,35 +180,6 @@ func (s *Service) deleteFiles(ctx context.Context, scene *models.Scene, fileDele
return nil
}
// destroyFileEntries destroys file entries from the database without deleting
// the files from the filesystem
func (s *Service) destroyFileEntries(ctx context.Context, scene *models.Scene) error {
if err := scene.LoadFiles(ctx, s.Repository); err != nil {
return err
}
for _, f := range scene.Files.List() {
// only destroy file entries where there is no other associated scene
otherScenes, err := s.Repository.FindByFileID(ctx, f.ID)
if err != nil {
return err
}
if len(otherScenes) > 1 {
// other scenes associated, don't remove
continue
}
const deleteFile = false
logger.Info("Destroying scene file entry: ", f.Path)
if err := file.Destroy(ctx, s.File, f, nil, deleteFile); err != nil {
return err
}
}
return nil
}
// DestroyMarker deletes the scene marker from the database and returns a
// function that removes the generated files, to be executed after the
// transaction is successfully committed.

View File

@@ -213,7 +213,7 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewCreateStudioInput()
newStudio := models.NewStudio()
newStudio.Name = name
err := i.StudioWriter.Create(ctx, &newStudio)

View File

@@ -241,9 +241,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s.Studio.ID = existingStudioID
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
err := i.PreImport(testCtx)
@@ -273,7 +273,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)

View File

@@ -120,8 +120,7 @@ func (s *Service) Merge(ctx context.Context, sourceIDs []int, destinationID int,
for _, src := range sources {
const deleteGenerated = true
const deleteFile = false
const destroyFileEntry = false
if err := s.Destroy(ctx, src, fileDeleter, deleteGenerated, deleteFile, destroyFileEntry); err != nil {
if err := s.Destroy(ctx, src, fileDeleter, deleteGenerated, deleteFile); err != nil {
return fmt.Errorf("deleting scene %d: %w", src.ID, err)
}
}

View File

@@ -24,85 +24,9 @@ func (e scraperAction) IsValid() bool {
return false
}
type urlScraperActionImpl interface {
type scraperActionImpl interface {
scrapeByURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error)
}
func (c Definition) getURLScraper(def ByURLDefinition, client *http.Client, globalConfig GlobalConfig) urlScraperActionImpl {
switch def.Action {
case scraperActionScript:
return &scriptURLScraper{
scriptScraper: scriptScraper{
definition: c,
globalConfig: globalConfig,
},
definition: def,
}
case scraperActionStash:
return newStashScraper(client, c, globalConfig)
case scraperActionXPath:
return &xpathURLScraper{
xpathScraper: xpathScraper{
definition: c,
globalConfig: globalConfig,
client: client,
},
definition: def,
}
case scraperActionJson:
return &jsonURLScraper{
jsonScraper: jsonScraper{
definition: c,
globalConfig: globalConfig,
client: client,
},
definition: def,
}
}
panic("unknown scraper action: " + def.Action)
}
type nameScraperActionImpl interface {
scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error)
}
func (c Definition) getNameScraper(def ByNameDefinition, client *http.Client, globalConfig GlobalConfig) nameScraperActionImpl {
switch def.Action {
case scraperActionScript:
return &scriptNameScraper{
scriptScraper: scriptScraper{
definition: c,
globalConfig: globalConfig,
},
definition: def,
}
case scraperActionStash:
return newStashScraper(client, c, globalConfig)
case scraperActionXPath:
return &xpathNameScraper{
xpathScraper: xpathScraper{
definition: c,
globalConfig: globalConfig,
client: client,
},
definition: def,
}
case scraperActionJson:
return &jsonNameScraper{
jsonScraper: jsonScraper{
definition: c,
globalConfig: globalConfig,
client: client,
},
definition: def,
}
}
panic("unknown scraper action: " + def.Action)
}
type fragmentScraperActionImpl interface {
scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error)
scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error)
@@ -110,37 +34,17 @@ type fragmentScraperActionImpl interface {
scrapeImageByImage(ctx context.Context, image *models.Image) (*models.ScrapedImage, error)
}
func (c Definition) getFragmentScraper(actionDef ByFragmentDefinition, client *http.Client, globalConfig GlobalConfig) fragmentScraperActionImpl {
switch actionDef.Action {
func (c config) getScraper(scraper scraperTypeConfig, client *http.Client, globalConfig GlobalConfig) scraperActionImpl {
switch scraper.Action {
case scraperActionScript:
return &scriptFragmentScraper{
scriptScraper: scriptScraper{
definition: c,
globalConfig: globalConfig,
},
definition: actionDef,
}
return newScriptScraper(scraper, c, globalConfig)
case scraperActionStash:
return newStashScraper(client, c, globalConfig)
return newStashScraper(scraper, client, c, globalConfig)
case scraperActionXPath:
return &xpathFragmentScraper{
xpathScraper: xpathScraper{
definition: c,
globalConfig: globalConfig,
client: client,
},
definition: actionDef,
}
return newXpathScraper(scraper, client, c, globalConfig)
case scraperActionJson:
return &jsonFragmentScraper{
jsonScraper: jsonScraper{
definition: c,
globalConfig: globalConfig,
client: client,
},
definition: actionDef,
}
return newJsonScraper(scraper, client, c, globalConfig)
}
panic("unknown scraper action: " + actionDef.Action)
panic("unknown scraper action: " + scraper.Action)
}

View File

@@ -182,7 +182,7 @@ func (c *Cache) ReloadScrapers() {
if err != nil {
logger.Errorf("Error loading scraper %s: %v", fp, err)
} else {
scraper := scraperFromDefinition(*conf, c.globalConfig)
scraper := newGroupScraper(*conf, c.globalConfig)
scrapers[scraper.spec().ID] = scraper
}
}

View File

@@ -11,8 +11,7 @@ import (
"gopkg.in/yaml.v2"
)
// Definition represents a scraper definition (typically) loaded from a YAML configuration file.
type Definition struct {
type config struct {
ID string
path string
@@ -20,43 +19,43 @@ type Definition struct {
Name string `yaml:"name"`
// Configuration for querying performers by name
PerformerByName *ByNameDefinition `yaml:"performerByName"`
PerformerByName *scraperTypeConfig `yaml:"performerByName"`
// Configuration for querying performers by a Performer fragment
PerformerByFragment *ByFragmentDefinition `yaml:"performerByFragment"`
PerformerByFragment *scraperTypeConfig `yaml:"performerByFragment"`
// Configuration for querying a performer by a URL
PerformerByURL []*ByURLDefinition `yaml:"performerByURL"`
PerformerByURL []*scrapeByURLConfig `yaml:"performerByURL"`
// Configuration for querying scenes by a Scene fragment
SceneByFragment *ByFragmentDefinition `yaml:"sceneByFragment"`
SceneByFragment *scraperTypeConfig `yaml:"sceneByFragment"`
// Configuration for querying gallery by a Gallery fragment
GalleryByFragment *ByFragmentDefinition `yaml:"galleryByFragment"`
GalleryByFragment *scraperTypeConfig `yaml:"galleryByFragment"`
// Configuration for querying scenes by name
SceneByName *ByNameDefinition `yaml:"sceneByName"`
SceneByName *scraperTypeConfig `yaml:"sceneByName"`
// Configuration for querying scenes by query fragment
SceneByQueryFragment *ByFragmentDefinition `yaml:"sceneByQueryFragment"`
SceneByQueryFragment *scraperTypeConfig `yaml:"sceneByQueryFragment"`
// Configuration for querying a scene by a URL
SceneByURL []*ByURLDefinition `yaml:"sceneByURL"`
SceneByURL []*scrapeByURLConfig `yaml:"sceneByURL"`
// Configuration for querying a gallery by a URL
GalleryByURL []*ByURLDefinition `yaml:"galleryByURL"`
GalleryByURL []*scrapeByURLConfig `yaml:"galleryByURL"`
// Configuration for querying an image by a URL
ImageByURL []*ByURLDefinition `yaml:"imageByURL"`
ImageByURL []*scrapeByURLConfig `yaml:"imageByURL"`
// Configuration for querying image by an Image fragment
ImageByFragment *ByFragmentDefinition `yaml:"imageByFragment"`
ImageByFragment *scraperTypeConfig `yaml:"imageByFragment"`
// Configuration for querying a movie by a URL - deprecated, use GroupByURL
MovieByURL []*ByURLDefinition `yaml:"movieByURL"`
MovieByURL []*scrapeByURLConfig `yaml:"movieByURL"`
// Configuration for querying a group by a URL
GroupByURL []*ByURLDefinition `yaml:"groupByURL"`
GroupByURL []*scrapeByURLConfig `yaml:"groupByURL"`
// Scraper debugging options
DebugOptions *scraperDebugOptions `yaml:"debug"`
@@ -74,7 +73,7 @@ type Definition struct {
DriverOptions *scraperDriverOptions `yaml:"driver"`
}
func (c Definition) validate() error {
func (c config) validate() error {
if strings.TrimSpace(c.Name) == "" {
return errors.New("name must not be empty")
}
@@ -127,13 +126,17 @@ type stashServer struct {
ApiKey string `yaml:"apiKey"`
}
type ActionDefinition struct {
type scraperTypeConfig struct {
Action scraperAction `yaml:"action"`
Script []string `yaml:"script,flow"`
Scraper string `yaml:"scraper"`
// for xpath name scraper only
QueryURL string `yaml:"queryURL"`
QueryURLReplacements queryURLReplacements `yaml:"queryURLReplace"`
}
func (c ActionDefinition) validate() error {
func (c scraperTypeConfig) validate() error {
if !c.Action.IsValid() {
return fmt.Errorf("%s is not a valid scraper action", c.Action)
}
@@ -145,22 +148,20 @@ func (c ActionDefinition) validate() error {
return nil
}
type ByURLDefinition struct {
ActionDefinition `yaml:",inline"`
URL []string `yaml:"url,flow"`
QueryURL string `yaml:"queryURL"`
QueryURLReplacements queryURLReplacements `yaml:"queryURLReplace"`
type scrapeByURLConfig struct {
scraperTypeConfig `yaml:",inline"`
URL []string `yaml:"url,flow"`
}
func (c ByURLDefinition) validate() error {
func (c scrapeByURLConfig) validate() error {
if len(c.URL) == 0 {
return errors.New("url is mandatory for scrape by url scrapers")
}
return c.ActionDefinition.validate()
return c.scraperTypeConfig.validate()
}
func (c ByURLDefinition) matchesURL(url string) bool {
func (c scrapeByURLConfig) matchesURL(url string) bool {
for _, thisURL := range c.URL {
if strings.Contains(url, thisURL) {
return true
@@ -170,18 +171,6 @@ func (c ByURLDefinition) matchesURL(url string) bool {
return false
}
type ByFragmentDefinition struct {
ActionDefinition `yaml:",inline"`
QueryURL string `yaml:"queryURL"`
QueryURLReplacements queryURLReplacements `yaml:"queryURLReplace"`
}
type ByNameDefinition struct {
ActionDefinition `yaml:",inline"`
QueryURL string `yaml:"queryURL"`
}
type scraperDebugOptions struct {
PrintHTML bool `yaml:"printHTML"`
}
@@ -217,8 +206,8 @@ type scraperDriverOptions struct {
Headers []*header `yaml:"headers"`
}
func loadConfigFromYAML(id string, reader io.Reader) (*Definition, error) {
ret := &Definition{}
func loadConfigFromYAML(id string, reader io.Reader) (*config, error) {
ret := &config{}
parser := yaml.NewDecoder(reader)
parser.SetStrict(true)
@@ -236,7 +225,7 @@ func loadConfigFromYAML(id string, reader io.Reader) (*Definition, error) {
return ret, nil
}
func loadConfigFromYAMLFile(path string) (*Definition, error) {
func loadConfigFromYAMLFile(path string) (*config, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
@@ -257,7 +246,7 @@ func loadConfigFromYAMLFile(path string) (*Definition, error) {
return ret, nil
}
func (c Definition) spec() Scraper {
func (c config) spec() Scraper {
ret := Scraper{
ID: c.ID,
Name: c.Name,
@@ -345,7 +334,7 @@ func (c Definition) spec() Scraper {
return ret
}
func (c Definition) supports(ty ScrapeContentType) bool {
func (c config) supports(ty ScrapeContentType) bool {
switch ty {
case ScrapeContentTypePerformer:
return c.PerformerByName != nil || c.PerformerByFragment != nil || len(c.PerformerByURL) > 0
@@ -362,7 +351,7 @@ func (c Definition) supports(ty ScrapeContentType) bool {
panic("Unhandled ScrapeContentType")
}
func (c Definition) matchesURL(url string, ty ScrapeContentType) bool {
func (c config) matchesURL(url string, ty ScrapeContentType) bool {
switch ty {
case ScrapeContentTypePerformer:
for _, scraper := range c.PerformerByURL {

View File

@@ -18,7 +18,7 @@ import (
)
// jar constructs a cookie jar from a configuration
func (c Definition) jar() (*cookiejar.Jar, error) {
func (c config) jar() (*cookiejar.Jar, error) {
opts := c.DriverOptions
jar, err := cookiejar.New(&cookiejar.Options{
PublicSuffixList: publicsuffix.List,
@@ -77,7 +77,7 @@ func randomSequence(n int) string {
}
// printCookies prints all cookies from the given cookie jar
func printCookies(jar *cookiejar.Jar, scraperConfig Definition, msg string) {
func printCookies(jar *cookiejar.Jar, scraperConfig config, msg string) {
driverOptions := scraperConfig.DriverOptions
if driverOptions != nil && !driverOptions.UseCDP {
var foundURLs []*url.URL

View File

@@ -139,5 +139,5 @@ func getFreeonesScraper(globalConfig GlobalConfig) scraper {
logger.Fatalf("Error loading builtin freeones scraper: %s", err.Error())
}
return scraperFromDefinition(*c, globalConfig)
return newGroupScraper(*c, globalConfig)
}

View File

@@ -8,26 +8,25 @@ import (
"github.com/stashapp/stash/pkg/models"
)
// definedScraper implements the scraper interface using a Definition object.
type definedScraper struct {
config Definition
type group struct {
config config
globalConf GlobalConfig
}
func scraperFromDefinition(c Definition, globalConfig GlobalConfig) definedScraper {
return definedScraper{
func newGroupScraper(c config, globalConfig GlobalConfig) scraper {
return group{
config: c,
globalConf: globalConfig,
}
}
func (g definedScraper) spec() Scraper {
func (g group) spec() Scraper {
return g.config.spec()
}
// fragmentScraper finds an appropriate fragment scraper based on input.
func (g definedScraper) fragmentScraper(input Input) *ByFragmentDefinition {
func (g group) fragmentScraper(input Input) *scraperTypeConfig {
switch {
case input.Performer != nil:
return g.config.PerformerByFragment
@@ -44,7 +43,7 @@ func (g definedScraper) fragmentScraper(input Input) *ByFragmentDefinition {
return nil
}
func (g definedScraper) viaFragment(ctx context.Context, client *http.Client, input Input) (ScrapedContent, error) {
func (g group) viaFragment(ctx context.Context, client *http.Client, input Input) (ScrapedContent, error) {
stc := g.fragmentScraper(input)
if stc == nil {
// If there's no performer fragment scraper in the group, we try to use
@@ -57,38 +56,38 @@ func (g definedScraper) viaFragment(ctx context.Context, client *http.Client, in
return nil, ErrNotSupported
}
s := g.config.getFragmentScraper(*stc, client, g.globalConf)
s := g.config.getScraper(*stc, client, g.globalConf)
return s.scrapeByFragment(ctx, input)
}
func (g definedScraper) viaScene(ctx context.Context, client *http.Client, scene *models.Scene) (*models.ScrapedScene, error) {
func (g group) viaScene(ctx context.Context, client *http.Client, scene *models.Scene) (*models.ScrapedScene, error) {
if g.config.SceneByFragment == nil {
return nil, ErrNotSupported
}
s := g.config.getFragmentScraper(*g.config.SceneByFragment, client, g.globalConf)
s := g.config.getScraper(*g.config.SceneByFragment, client, g.globalConf)
return s.scrapeSceneByScene(ctx, scene)
}
func (g definedScraper) viaGallery(ctx context.Context, client *http.Client, gallery *models.Gallery) (*models.ScrapedGallery, error) {
func (g group) viaGallery(ctx context.Context, client *http.Client, gallery *models.Gallery) (*models.ScrapedGallery, error) {
if g.config.GalleryByFragment == nil {
return nil, ErrNotSupported
}
s := g.config.getFragmentScraper(*g.config.GalleryByFragment, client, g.globalConf)
s := g.config.getScraper(*g.config.GalleryByFragment, client, g.globalConf)
return s.scrapeGalleryByGallery(ctx, gallery)
}
func (g definedScraper) viaImage(ctx context.Context, client *http.Client, gallery *models.Image) (*models.ScrapedImage, error) {
func (g group) viaImage(ctx context.Context, client *http.Client, gallery *models.Image) (*models.ScrapedImage, error) {
if g.config.ImageByFragment == nil {
return nil, ErrNotSupported
}
s := g.config.getFragmentScraper(*g.config.ImageByFragment, client, g.globalConf)
s := g.config.getScraper(*g.config.ImageByFragment, client, g.globalConf)
return s.scrapeImageByImage(ctx, gallery)
}
func loadUrlCandidates(c Definition, ty ScrapeContentType) []*ByURLDefinition {
func loadUrlCandidates(c config, ty ScrapeContentType) []*scrapeByURLConfig {
switch ty {
case ScrapeContentTypePerformer:
return c.PerformerByURL
@@ -105,13 +104,12 @@ func loadUrlCandidates(c Definition, ty ScrapeContentType) []*ByURLDefinition {
panic("loadUrlCandidates: unreachable")
}
func (g definedScraper) viaURL(ctx context.Context, client *http.Client, url string, ty ScrapeContentType) (ScrapedContent, error) {
func (g group) viaURL(ctx context.Context, client *http.Client, url string, ty ScrapeContentType) (ScrapedContent, error) {
candidates := loadUrlCandidates(g.config, ty)
for _, scraper := range candidates {
if scraper.matchesURL(url) {
u := replaceURL(url, *scraper) // allow a URL Replace for url-queries
s := g.config.getURLScraper(*scraper, client, g.globalConf)
ret, err := s.scrapeByURL(ctx, u, ty)
s := g.config.getScraper(scraper.scraperTypeConfig, client, g.globalConf)
ret, err := s.scrapeByURL(ctx, url, ty)
if err != nil {
return nil, err
}
@@ -125,31 +123,31 @@ func (g definedScraper) viaURL(ctx context.Context, client *http.Client, url str
return nil, nil
}
func (g definedScraper) viaName(ctx context.Context, client *http.Client, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
func (g group) viaName(ctx context.Context, client *http.Client, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
switch ty {
case ScrapeContentTypePerformer:
if g.config.PerformerByName == nil {
break
}
s := g.config.getNameScraper(*g.config.PerformerByName, client, g.globalConf)
s := g.config.getScraper(*g.config.PerformerByName, client, g.globalConf)
return s.scrapeByName(ctx, name, ty)
case ScrapeContentTypeScene:
if g.config.SceneByName == nil {
break
}
s := g.config.getNameScraper(*g.config.SceneByName, client, g.globalConf)
s := g.config.getScraper(*g.config.SceneByName, client, g.globalConf)
return s.scrapeByName(ctx, name, ty)
}
return nil, fmt.Errorf("%w: cannot load %v by name", ErrNotSupported, ty)
}
func (g definedScraper) supports(ty ScrapeContentType) bool {
func (g group) supports(ty ScrapeContentType) bool {
return g.config.supports(ty)
}
func (g definedScraper) supportsURL(url string, ty ScrapeContentType) bool {
func (g group) supportsURL(url string, ty ScrapeContentType) bool {
return g.config.matchesURL(url, ty)
}

View File

@@ -15,22 +15,43 @@ import (
)
type jsonScraper struct {
definition Definition
scraper scraperTypeConfig
config config
globalConfig GlobalConfig
client *http.Client
}
func (s *jsonScraper) getJsonScraper(name string) (*mappedScraper, error) {
ret, ok := s.definition.JsonScrapers[name]
if !ok {
return nil, fmt.Errorf("json scraper with name %s not found in config", name)
func newJsonScraper(scraper scraperTypeConfig, client *http.Client, config config, globalConfig GlobalConfig) *jsonScraper {
return &jsonScraper{
scraper: scraper,
config: config,
client: client,
globalConfig: globalConfig,
}
}
func (s *jsonScraper) getJsonScraper() *mappedScraper {
return s.config.JsonScrapers[s.scraper.Scraper]
}
func (s *jsonScraper) scrapeURL(ctx context.Context, url string) (string, *mappedScraper, error) {
scraper := s.getJsonScraper()
if scraper == nil {
return "", nil, errors.New("json scraper with name " + s.scraper.Scraper + " not found in config")
}
return &ret, nil
doc, err := s.loadURL(ctx, url)
if err != nil {
return "", nil, err
}
return doc, scraper, nil
}
func (s *jsonScraper) loadURL(ctx context.Context, url string) (string, error) {
r, err := loadURL(ctx, url, s.client, s.definition, s.globalConfig)
r, err := loadURL(ctx, url, s.client, s.config, s.globalConfig)
if err != nil {
return "", err
}
@@ -45,30 +66,21 @@ func (s *jsonScraper) loadURL(ctx context.Context, url string) (string, error) {
return "", errors.New("not valid json")
}
if s.definition.DebugOptions != nil && s.definition.DebugOptions.PrintHTML {
if s.config.DebugOptions != nil && s.config.DebugOptions.PrintHTML {
logger.Infof("loadURL (%s) response: \n%s", url, docStr)
}
return docStr, err
}
type jsonURLScraper struct {
jsonScraper
definition ByURLDefinition
}
func (s *jsonURLScraper) scrapeByURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error) {
scraper, err := s.getJsonScraper(s.definition.Scraper)
func (s *jsonScraper) scrapeByURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error) {
u := replaceURL(url, s.scraper) // allow a URL Replace for url-queries
doc, scraper, err := s.scrapeURL(ctx, u)
if err != nil {
return nil, err
}
doc, err := s.loadURL(ctx, url)
if err != nil {
return nil, err
}
q := s.getJsonQuery(doc, url)
q := s.getJsonQuery(doc, u)
// if these just return the return values from scraper.scrape* functions then
// it ends up returning ScrapedContent(nil) rather than nil
switch ty {
@@ -107,15 +119,11 @@ func (s *jsonURLScraper) scrapeByURL(ctx context.Context, url string, ty ScrapeC
return nil, ErrNotSupported
}
type jsonNameScraper struct {
jsonScraper
definition ByNameDefinition
}
func (s *jsonScraper) scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
scraper := s.getJsonScraper()
func (s *jsonNameScraper) scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
scraper, err := s.getJsonScraper(s.definition.Scraper)
if err != nil {
return nil, err
if scraper == nil {
return nil, fmt.Errorf("%w: name %v", ErrNotFound, s.scraper.Scraper)
}
const placeholder = "{}"
@@ -123,7 +131,7 @@ func (s *jsonNameScraper) scrapeByName(ctx context.Context, name string, ty Scra
// replace the placeholder string with the URL-escaped name
escapedName := url.QueryEscape(name)
url := s.definition.QueryURL
url := s.scraper.QueryURL
url = strings.ReplaceAll(url, placeholder, escapedName)
doc, err := s.loadURL(ctx, url)
@@ -164,22 +172,18 @@ func (s *jsonNameScraper) scrapeByName(ctx context.Context, name string, ty Scra
return nil, ErrNotSupported
}
type jsonFragmentScraper struct {
jsonScraper
definition ByFragmentDefinition
}
func (s *jsonFragmentScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error) {
func (s *jsonScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error) {
// construct the URL
queryURL := queryURLParametersFromScene(scene)
if s.definition.QueryURLReplacements != nil {
queryURL.applyReplacements(s.definition.QueryURLReplacements)
if s.scraper.QueryURLReplacements != nil {
queryURL.applyReplacements(s.scraper.QueryURLReplacements)
}
url := queryURL.constructURL(s.definition.QueryURL)
url := queryURL.constructURL(s.scraper.QueryURL)
scraper, err := s.getJsonScraper(s.definition.Scraper)
if err != nil {
return nil, err
scraper := s.getJsonScraper()
if scraper == nil {
return nil, errors.New("json scraper with name " + s.scraper.Scraper + " not found in config")
}
doc, err := s.loadURL(ctx, url)
@@ -192,7 +196,7 @@ func (s *jsonFragmentScraper) scrapeSceneByScene(ctx context.Context, scene *mod
return scraper.scrapeScene(ctx, q)
}
func (s *jsonFragmentScraper) scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error) {
func (s *jsonScraper) scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error) {
switch {
case input.Gallery != nil:
return nil, fmt.Errorf("%w: cannot use a json scraper as a gallery fragment scraper", ErrNotSupported)
@@ -206,14 +210,15 @@ func (s *jsonFragmentScraper) scrapeByFragment(ctx context.Context, input Input)
// construct the URL
queryURL := queryURLParametersFromScrapedScene(scene)
if s.definition.QueryURLReplacements != nil {
queryURL.applyReplacements(s.definition.QueryURLReplacements)
if s.scraper.QueryURLReplacements != nil {
queryURL.applyReplacements(s.scraper.QueryURLReplacements)
}
url := queryURL.constructURL(s.definition.QueryURL)
url := queryURL.constructURL(s.scraper.QueryURL)
scraper, err := s.getJsonScraper(s.definition.Scraper)
if err != nil {
return nil, err
scraper := s.getJsonScraper()
if scraper == nil {
return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config")
}
doc, err := s.loadURL(ctx, url)
@@ -226,17 +231,18 @@ func (s *jsonFragmentScraper) scrapeByFragment(ctx context.Context, input Input)
return scraper.scrapeScene(ctx, q)
}
func (s *jsonFragmentScraper) scrapeImageByImage(ctx context.Context, image *models.Image) (*models.ScrapedImage, error) {
func (s *jsonScraper) scrapeImageByImage(ctx context.Context, image *models.Image) (*models.ScrapedImage, error) {
// construct the URL
queryURL := queryURLParametersFromImage(image)
if s.definition.QueryURLReplacements != nil {
queryURL.applyReplacements(s.definition.QueryURLReplacements)
if s.scraper.QueryURLReplacements != nil {
queryURL.applyReplacements(s.scraper.QueryURLReplacements)
}
url := queryURL.constructURL(s.definition.QueryURL)
url := queryURL.constructURL(s.scraper.QueryURL)
scraper, err := s.getJsonScraper(s.definition.Scraper)
if err != nil {
return nil, err
scraper := s.getJsonScraper()
if scraper == nil {
return nil, errors.New("json scraper with name " + s.scraper.Scraper + " not found in config")
}
doc, err := s.loadURL(ctx, url)
@@ -249,17 +255,18 @@ func (s *jsonFragmentScraper) scrapeImageByImage(ctx context.Context, image *mod
return scraper.scrapeImage(ctx, q)
}
func (s *jsonFragmentScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*models.ScrapedGallery, error) {
func (s *jsonScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*models.ScrapedGallery, error) {
// construct the URL
queryURL := queryURLParametersFromGallery(gallery)
if s.definition.QueryURLReplacements != nil {
queryURL.applyReplacements(s.definition.QueryURLReplacements)
if s.scraper.QueryURLReplacements != nil {
queryURL.applyReplacements(s.scraper.QueryURLReplacements)
}
url := queryURL.constructURL(s.definition.QueryURL)
url := queryURL.constructURL(s.scraper.QueryURL)
scraper, err := s.getJsonScraper(s.definition.Scraper)
if err != nil {
return nil, err
scraper := s.getJsonScraper()
if scraper == nil {
return nil, errors.New("json scraper with name " + s.scraper.Scraper + " not found in config")
}
doc, err := s.loadURL(ctx, url)

View File

@@ -68,7 +68,7 @@ jsonScrapers:
}
`
c := &Definition{}
c := &config{}
err := yaml.Unmarshal([]byte(yamlStr), &c)
if err != nil {

File diff suppressed because it is too large Load Diff

View File

@@ -1,537 +0,0 @@
package scraper
import (
"context"
"errors"
"net/url"
"strings"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/sliceutil"
"gopkg.in/yaml.v2"
)
type commonMappedConfig map[string]string
type mappedConfig map[string]mappedScraperAttrConfig
func (s mappedConfig) applyCommon(c commonMappedConfig, src string) string {
if c == nil {
return src
}
ret := src
for commonKey, commonVal := range c {
ret = strings.ReplaceAll(ret, commonKey, commonVal)
}
return ret
}
// extractHostname parses a URL string and returns the hostname.
// Returns empty string if the URL cannot be parsed.
func extractHostname(urlStr string) string {
if urlStr == "" {
return ""
}
u, err := url.Parse(urlStr)
if err != nil {
logger.Warnf("Error parsing URL '%s': %s", urlStr, err.Error())
return ""
}
return u.Hostname()
}
type isMultiFunc func(key string) bool
func (s mappedConfig) process(ctx context.Context, q mappedQuery, common commonMappedConfig, isMulti isMultiFunc) mappedResults {
var ret mappedResults
for k, attrConfig := range s {
if attrConfig.Fixed != "" {
// TODO - not sure if this needs to set _all_ indexes for the key
const i = 0
// Support {inputURL} and {inputHostname} placeholders in fixed values
value := strings.ReplaceAll(attrConfig.Fixed, "{inputURL}", q.getURL())
value = strings.ReplaceAll(value, "{inputHostname}", extractHostname(q.getURL()))
ret = ret.setSingleValue(i, k, value)
} else {
selector := attrConfig.Selector
selector = s.applyCommon(common, selector)
// Support {inputURL} and {inputHostname} placeholders in selectors
selector = strings.ReplaceAll(selector, "{inputURL}", q.getURL())
selector = strings.ReplaceAll(selector, "{inputHostname}", extractHostname(q.getURL()))
found, err := q.runQuery(selector)
if err != nil {
logger.Warnf("key '%v': %v", k, err)
}
if len(found) > 0 {
result := s.postProcess(ctx, q, attrConfig, found)
// HACK - if the key is URLs, then we need to set the value as a multi-value
isMulti := isMulti != nil && isMulti(k)
if isMulti {
ret = ret.setMultiValue(0, k, result)
} else {
for i, text := range result {
ret = ret.setSingleValue(i, k, text)
}
}
}
}
}
return ret
}
func (s mappedConfig) postProcess(ctx context.Context, q mappedQuery, attrConfig mappedScraperAttrConfig, found []string) []string {
// check if we're concatenating the results into a single result
var ret []string
if attrConfig.hasConcat() {
result := attrConfig.concatenateResults(found)
result = attrConfig.postProcess(ctx, result, q)
if attrConfig.hasSplit() {
results := attrConfig.splitString(result)
// skip cleaning when the query is used for searching
if q.getType() == SearchQuery {
return results
}
results = attrConfig.cleanResults(results)
return results
}
ret = []string{result}
} else {
for _, text := range found {
text = attrConfig.postProcess(ctx, text, q)
if attrConfig.hasSplit() {
return attrConfig.splitString(text)
}
ret = append(ret, text)
}
// skip cleaning when the query is used for searching
if q.getType() == SearchQuery {
return ret
}
ret = attrConfig.cleanResults(ret)
}
return ret
}
type mappedSceneScraperConfig struct {
mappedConfig
Tags mappedConfig `yaml:"Tags"`
Performers mappedPerformerScraperConfig `yaml:"Performers"`
Studio mappedConfig `yaml:"Studio"`
Movies mappedConfig `yaml:"Movies"`
Groups mappedConfig `yaml:"Groups"`
}
type _mappedSceneScraperConfig mappedSceneScraperConfig
const (
mappedScraperConfigSceneTags = "Tags"
mappedScraperConfigScenePerformers = "Performers"
mappedScraperConfigSceneStudio = "Studio"
mappedScraperConfigSceneMovies = "Movies"
mappedScraperConfigSceneGroups = "Groups"
)
func (s *mappedSceneScraperConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
// HACK - unmarshal to map first, then remove known scene sub-fields, then
// remarshal to yaml and pass that down to the base map
parentMap := make(map[string]interface{})
if err := unmarshal(parentMap); err != nil {
return err
}
// move the known sub-fields to a separate map
thisMap := make(map[string]interface{})
thisMap[mappedScraperConfigSceneTags] = parentMap[mappedScraperConfigSceneTags]
thisMap[mappedScraperConfigScenePerformers] = parentMap[mappedScraperConfigScenePerformers]
thisMap[mappedScraperConfigSceneStudio] = parentMap[mappedScraperConfigSceneStudio]
thisMap[mappedScraperConfigSceneMovies] = parentMap[mappedScraperConfigSceneMovies]
thisMap[mappedScraperConfigSceneGroups] = parentMap[mappedScraperConfigSceneGroups]
delete(parentMap, mappedScraperConfigSceneTags)
delete(parentMap, mappedScraperConfigScenePerformers)
delete(parentMap, mappedScraperConfigSceneStudio)
delete(parentMap, mappedScraperConfigSceneMovies)
delete(parentMap, mappedScraperConfigSceneGroups)
// re-unmarshal the sub-fields
yml, err := yaml.Marshal(thisMap)
if err != nil {
return err
}
// needs to be a different type to prevent infinite recursion
c := _mappedSceneScraperConfig{}
if err := yaml.Unmarshal(yml, &c); err != nil {
return err
}
*s = mappedSceneScraperConfig(c)
yml, err = yaml.Marshal(parentMap)
if err != nil {
return err
}
if err := yaml.Unmarshal(yml, &s.mappedConfig); err != nil {
return err
}
return nil
}
type mappedGalleryScraperConfig struct {
mappedConfig
Tags mappedConfig `yaml:"Tags"`
Performers mappedConfig `yaml:"Performers"`
Studio mappedConfig `yaml:"Studio"`
}
type _mappedGalleryScraperConfig mappedGalleryScraperConfig
func (s *mappedGalleryScraperConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
// HACK - unmarshal to map first, then remove known scene sub-fields, then
// remarshal to yaml and pass that down to the base map
parentMap := make(map[string]interface{})
if err := unmarshal(parentMap); err != nil {
return err
}
// move the known sub-fields to a separate map
thisMap := make(map[string]interface{})
thisMap[mappedScraperConfigSceneTags] = parentMap[mappedScraperConfigSceneTags]
thisMap[mappedScraperConfigScenePerformers] = parentMap[mappedScraperConfigScenePerformers]
thisMap[mappedScraperConfigSceneStudio] = parentMap[mappedScraperConfigSceneStudio]
delete(parentMap, mappedScraperConfigSceneTags)
delete(parentMap, mappedScraperConfigScenePerformers)
delete(parentMap, mappedScraperConfigSceneStudio)
// re-unmarshal the sub-fields
yml, err := yaml.Marshal(thisMap)
if err != nil {
return err
}
// needs to be a different type to prevent infinite recursion
c := _mappedGalleryScraperConfig{}
if err := yaml.Unmarshal(yml, &c); err != nil {
return err
}
*s = mappedGalleryScraperConfig(c)
yml, err = yaml.Marshal(parentMap)
if err != nil {
return err
}
if err := yaml.Unmarshal(yml, &s.mappedConfig); err != nil {
return err
}
return nil
}
type mappedImageScraperConfig struct {
mappedConfig
Tags mappedConfig `yaml:"Tags"`
Performers mappedConfig `yaml:"Performers"`
Studio mappedConfig `yaml:"Studio"`
}
type _mappedImageScraperConfig mappedImageScraperConfig
func (s *mappedImageScraperConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
// HACK - unmarshal to map first, then remove known scene sub-fields, then
// remarshal to yaml and pass that down to the base map
parentMap := make(map[string]interface{})
if err := unmarshal(parentMap); err != nil {
return err
}
// move the known sub-fields to a separate map
thisMap := make(map[string]interface{})
thisMap[mappedScraperConfigSceneTags] = parentMap[mappedScraperConfigSceneTags]
thisMap[mappedScraperConfigScenePerformers] = parentMap[mappedScraperConfigScenePerformers]
thisMap[mappedScraperConfigSceneStudio] = parentMap[mappedScraperConfigSceneStudio]
delete(parentMap, mappedScraperConfigSceneTags)
delete(parentMap, mappedScraperConfigScenePerformers)
delete(parentMap, mappedScraperConfigSceneStudio)
// re-unmarshal the sub-fields
yml, err := yaml.Marshal(thisMap)
if err != nil {
return err
}
// needs to be a different type to prevent infinite recursion
c := _mappedImageScraperConfig{}
if err := yaml.Unmarshal(yml, &c); err != nil {
return err
}
*s = mappedImageScraperConfig(c)
yml, err = yaml.Marshal(parentMap)
if err != nil {
return err
}
if err := yaml.Unmarshal(yml, &s.mappedConfig); err != nil {
return err
}
return nil
}
type mappedPerformerScraperConfig struct {
mappedConfig
Tags mappedConfig `yaml:"Tags"`
}
type _mappedPerformerScraperConfig mappedPerformerScraperConfig
const (
mappedScraperConfigPerformerTags = "Tags"
)
func (s *mappedPerformerScraperConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
// HACK - unmarshal to map first, then remove known scene sub-fields, then
// remarshal to yaml and pass that down to the base map
parentMap := make(map[string]interface{})
if err := unmarshal(parentMap); err != nil {
return err
}
// move the known sub-fields to a separate map
thisMap := make(map[string]interface{})
thisMap[mappedScraperConfigPerformerTags] = parentMap[mappedScraperConfigPerformerTags]
delete(parentMap, mappedScraperConfigPerformerTags)
// re-unmarshal the sub-fields
yml, err := yaml.Marshal(thisMap)
if err != nil {
return err
}
// needs to be a different type to prevent infinite recursion
c := _mappedPerformerScraperConfig{}
if err := yaml.Unmarshal(yml, &c); err != nil {
return err
}
*s = mappedPerformerScraperConfig(c)
yml, err = yaml.Marshal(parentMap)
if err != nil {
return err
}
if err := yaml.Unmarshal(yml, &s.mappedConfig); err != nil {
return err
}
return nil
}
type mappedMovieScraperConfig struct {
mappedConfig
Studio mappedConfig `yaml:"Studio"`
Tags mappedConfig `yaml:"Tags"`
}
type _mappedMovieScraperConfig mappedMovieScraperConfig
const (
mappedScraperConfigMovieStudio = "Studio"
mappedScraperConfigMovieTags = "Tags"
)
func (s *mappedMovieScraperConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
// HACK - unmarshal to map first, then remove known movie sub-fields, then
// remarshal to yaml and pass that down to the base map
parentMap := make(map[string]interface{})
if err := unmarshal(parentMap); err != nil {
return err
}
// move the known sub-fields to a separate map
thisMap := make(map[string]interface{})
thisMap[mappedScraperConfigMovieStudio] = parentMap[mappedScraperConfigMovieStudio]
delete(parentMap, mappedScraperConfigMovieStudio)
thisMap[mappedScraperConfigMovieTags] = parentMap[mappedScraperConfigMovieTags]
delete(parentMap, mappedScraperConfigMovieTags)
// re-unmarshal the sub-fields
yml, err := yaml.Marshal(thisMap)
if err != nil {
return err
}
// needs to be a different type to prevent infinite recursion
c := _mappedMovieScraperConfig{}
if err := yaml.Unmarshal(yml, &c); err != nil {
return err
}
*s = mappedMovieScraperConfig(c)
yml, err = yaml.Marshal(parentMap)
if err != nil {
return err
}
if err := yaml.Unmarshal(yml, &s.mappedConfig); err != nil {
return err
}
return nil
}
type mappedScraperAttrConfig struct {
Selector string `yaml:"selector"`
Fixed string `yaml:"fixed"`
PostProcess []mappedPostProcessAction `yaml:"postProcess"`
Concat string `yaml:"concat"`
Split string `yaml:"split"`
postProcessActions []postProcessAction
// Deprecated: use PostProcess instead
ParseDate string `yaml:"parseDate"`
Replace mappedRegexConfigs `yaml:"replace"`
SubScraper *mappedScraperAttrConfig `yaml:"subScraper"`
}
type _mappedScraperAttrConfig mappedScraperAttrConfig
func (c *mappedScraperAttrConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
// try unmarshalling into a string first
if err := unmarshal(&c.Selector); err != nil {
// if it's a type error then we try to unmarshall to the full object
var typeErr *yaml.TypeError
if !errors.As(err, &typeErr) {
return err
}
// unmarshall to full object
// need it as a separate object
t := _mappedScraperAttrConfig{}
if err = unmarshal(&t); err != nil {
return err
}
*c = mappedScraperAttrConfig(t)
}
return c.convertPostProcessActions()
}
func (c *mappedScraperAttrConfig) convertPostProcessActions() error {
// ensure we don't have the old deprecated fields and the new post process field
if len(c.PostProcess) > 0 {
if c.ParseDate != "" || len(c.Replace) > 0 || c.SubScraper != nil {
return errors.New("cannot include postProcess and (parseDate, replace, subScraper) deprecated fields")
}
// convert xpathPostProcessAction actions to postProcessActions
for _, a := range c.PostProcess {
action, err := a.ToPostProcessAction()
if err != nil {
return err
}
c.postProcessActions = append(c.postProcessActions, action)
}
c.PostProcess = nil
} else {
// convert old deprecated fields if present
// in same order as they used to be executed
if len(c.Replace) > 0 {
action := postProcessReplace(c.Replace)
c.postProcessActions = append(c.postProcessActions, &action)
c.Replace = nil
}
if c.SubScraper != nil {
action := postProcessSubScraper(*c.SubScraper)
c.postProcessActions = append(c.postProcessActions, &action)
c.SubScraper = nil
}
if c.ParseDate != "" {
action := postProcessParseDate(c.ParseDate)
c.postProcessActions = append(c.postProcessActions, &action)
c.ParseDate = ""
}
}
return nil
}
func (c mappedScraperAttrConfig) hasConcat() bool {
return c.Concat != ""
}
func (c mappedScraperAttrConfig) hasSplit() bool {
return c.Split != ""
}
func (c mappedScraperAttrConfig) concatenateResults(nodes []string) string {
separator := c.Concat
return strings.Join(nodes, separator)
}
func (c mappedScraperAttrConfig) cleanResults(nodes []string) []string {
cleaned := sliceutil.Unique(nodes) // remove duplicate values
cleaned = sliceutil.Delete(cleaned, "") // remove empty values
return cleaned
}
func (c mappedScraperAttrConfig) splitString(value string) []string {
separator := c.Split
var res []string
if separator == "" {
return []string{value}
}
for _, str := range strings.Split(value, separator) {
if str != "" {
res = append(res, str)
}
}
return res
}
func (c mappedScraperAttrConfig) postProcess(ctx context.Context, value string, q mappedQuery) string {
for _, action := range c.postProcessActions {
value = action.Apply(ctx, value, q)
}
return value
}

View File

@@ -1,333 +0,0 @@
package scraper
import (
"context"
"errors"
"fmt"
"math"
"regexp"
"strconv"
"strings"
"time"
"github.com/stashapp/stash/pkg/javascript"
"github.com/stashapp/stash/pkg/logger"
)
type mappedRegexConfig struct {
Regex string `yaml:"regex"`
With string `yaml:"with"`
}
type mappedRegexConfigs []mappedRegexConfig
func (c mappedRegexConfig) apply(value string) string {
if c.Regex != "" {
re, err := regexp.Compile(c.Regex)
if err != nil {
logger.Warnf("Error compiling regex '%s': %s", c.Regex, err.Error())
return value
}
ret := re.ReplaceAllString(value, c.With)
// trim leading and trailing whitespace
// this is done to maintain backwards compatibility with existing
// scrapers
ret = strings.TrimSpace(ret)
logger.Debugf(`Replace: '%s' with '%s'`, c.Regex, c.With)
logger.Debugf("Before: %s", value)
logger.Debugf("After: %s", ret)
return ret
}
return value
}
func (c mappedRegexConfigs) apply(value string) string {
// apply regex in order
for _, config := range c {
value = config.apply(value)
}
return value
}
type postProcessAction interface {
Apply(ctx context.Context, value string, q mappedQuery) string
}
type postProcessParseDate string
func (p *postProcessParseDate) Apply(ctx context.Context, value string, q mappedQuery) string {
parseDate := string(*p)
const internalDateFormat = "2006-01-02"
valueLower := strings.ToLower(value)
if valueLower == "today" || valueLower == "yesterday" { // handle today, yesterday
dt := time.Now()
if valueLower == "yesterday" { // subtract 1 day from now
dt = dt.AddDate(0, 0, -1)
}
return dt.Format(internalDateFormat)
}
if parseDate == "" {
return value
}
if parseDate == "unix" {
// try to parse the date using unix timestamp format
// if it fails, then just fall back to the original value
timeAsInt, err := strconv.ParseInt(value, 10, 64)
if err != nil {
logger.Warnf("Error parsing date string '%s' using unix timestamp format : %s", value, err.Error())
return value
}
parsedValue := time.Unix(timeAsInt, 0)
return parsedValue.Format(internalDateFormat)
}
// try to parse the date using the pattern
// if it fails, then just fall back to the original value
parsedValue, err := time.Parse(parseDate, value)
if err != nil {
logger.Warnf("Error parsing date string '%s' using format '%s': %s", value, parseDate, err.Error())
return value
}
// convert it into our date format
return parsedValue.Format(internalDateFormat)
}
type postProcessSubtractDays bool
func (p *postProcessSubtractDays) Apply(ctx context.Context, value string, q mappedQuery) string {
const internalDateFormat = "2006-01-02"
i, err := strconv.Atoi(value)
if err != nil {
logger.Warnf("Error parsing day string %s: %s", value, err)
return value
}
dt := time.Now()
dt = dt.AddDate(0, 0, -i)
return dt.Format(internalDateFormat)
}
type postProcessReplace mappedRegexConfigs
func (c *postProcessReplace) Apply(ctx context.Context, value string, q mappedQuery) string {
replace := mappedRegexConfigs(*c)
return replace.apply(value)
}
type postProcessSubScraper mappedScraperAttrConfig
func (p *postProcessSubScraper) Apply(ctx context.Context, value string, q mappedQuery) string {
subScrapeConfig := mappedScraperAttrConfig(*p)
logger.Debugf("Sub-scraping for: %s", value)
ss := q.subScrape(ctx, value)
if ss != nil {
found, err := ss.runQuery(subScrapeConfig.Selector)
if err != nil {
logger.Warnf("subscrape for '%v': %v", value, err)
}
if len(found) > 0 {
// check if we're concatenating the results into a single result
var result string
if subScrapeConfig.hasConcat() {
result = subScrapeConfig.concatenateResults(found)
} else {
result = found[0]
}
result = subScrapeConfig.postProcess(ctx, result, ss)
return result
}
}
return ""
}
type postProcessMap map[string]string
func (p *postProcessMap) Apply(ctx context.Context, value string, q mappedQuery) string {
// return the mapped value if present
m := *p
mapped, ok := m[value]
if ok {
return mapped
}
return value
}
type postProcessFeetToCm bool
func (p *postProcessFeetToCm) Apply(ctx context.Context, value string, q mappedQuery) string {
const foot_in_cm = 30.48
const inch_in_cm = 2.54
reg := regexp.MustCompile("[0-9]+")
filtered := reg.FindAllString(value, -1)
var feet float64
var inches float64
if len(filtered) > 0 {
feet, _ = strconv.ParseFloat(filtered[0], 64)
}
if len(filtered) > 1 {
inches, _ = strconv.ParseFloat(filtered[1], 64)
}
var centimeters = feet*foot_in_cm + inches*inch_in_cm
// Return rounded integer string
return strconv.Itoa(int(math.Round(centimeters)))
}
type postProcessLbToKg bool
func (p *postProcessLbToKg) Apply(ctx context.Context, value string, q mappedQuery) string {
const lb_in_kg = 0.45359237
w, err := strconv.ParseFloat(value, 64)
if err == nil {
w *= lb_in_kg
value = strconv.Itoa(int(math.Round(w)))
}
return value
}
type postProcessJavascript string
func (p *postProcessJavascript) Apply(ctx context.Context, value string, q mappedQuery) string {
vm := javascript.NewVM()
if err := vm.Set("value", value); err != nil {
logger.Warnf("javascript failed to set value: %v", err)
return value
}
log := &javascript.Log{
Logger: logger.Logger,
Prefix: "",
ProgressChan: make(chan float64),
}
if err := log.AddToVM("log", vm); err != nil {
logger.Logger.Errorf("error adding log API: %w", err)
}
util := &javascript.Util{}
if err := util.AddToVM("util", vm); err != nil {
logger.Logger.Errorf("error adding util API: %w", err)
}
script, err := javascript.CompileScript("", "(function() { "+string(*p)+"})()")
if err != nil {
logger.Warnf("javascript failed to compile: %v", err)
return value
}
output, err := vm.RunProgram(script)
if err != nil {
logger.Warnf("javascript failed to run: %v", err)
return value
}
// assume output is string
return output.String()
}
type mappedPostProcessAction struct {
ParseDate string `yaml:"parseDate"`
SubtractDays bool `yaml:"subtractDays"`
Replace mappedRegexConfigs `yaml:"replace"`
SubScraper *mappedScraperAttrConfig `yaml:"subScraper"`
Map map[string]string `yaml:"map"`
FeetToCm bool `yaml:"feetToCm"`
LbToKg bool `yaml:"lbToKg"`
Javascript string `yaml:"javascript"`
}
func (a mappedPostProcessAction) ToPostProcessAction() (postProcessAction, error) {
var found string
var ret postProcessAction
ensureOnly := func(field string) error {
if found != "" {
return fmt.Errorf("post-process actions must have a single field, found %s and %s", found, field)
}
found = field
return nil
}
if a.ParseDate != "" {
found = "parseDate"
action := postProcessParseDate(a.ParseDate)
ret = &action
}
if len(a.Replace) > 0 {
if err := ensureOnly("replace"); err != nil {
return nil, err
}
action := postProcessReplace(a.Replace)
ret = &action
}
if a.SubScraper != nil {
if err := ensureOnly("subScraper"); err != nil {
return nil, err
}
action := postProcessSubScraper(*a.SubScraper)
ret = &action
}
if a.Map != nil {
if err := ensureOnly("map"); err != nil {
return nil, err
}
action := postProcessMap(a.Map)
ret = &action
}
if a.FeetToCm {
if err := ensureOnly("feetToCm"); err != nil {
return nil, err
}
action := postProcessFeetToCm(a.FeetToCm)
ret = &action
}
if a.LbToKg {
if err := ensureOnly("lbToKg"); err != nil {
return nil, err
}
action := postProcessLbToKg(a.LbToKg)
ret = &action
}
if a.SubtractDays {
if err := ensureOnly("subtractDays"); err != nil {
return nil, err
}
action := postProcessSubtractDays(a.SubtractDays)
ret = &action
}
if a.Javascript != "" {
if err := ensureOnly("javascript"); err != nil {
return nil, err
}
action := postProcessJavascript(a.Javascript)
ret = &action
}
if ret == nil {
return nil, errors.New("invalid post-process action")
}
return ret, nil
}

View File

@@ -1,276 +0,0 @@
package scraper
import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
)
type mappedResult map[string]interface{}
type mappedResults []mappedResult
func (r mappedResult) string(key string) (string, bool) {
v, ok := r[key]
if !ok {
return "", false
}
val, ok := v.(string)
if !ok {
logger.Errorf("String field %s is %T in mappedResult", key, r[key])
}
return val, true
}
func (r mappedResult) mustString(key string) string {
v, ok := r[key]
if !ok {
logger.Errorf("Missing required string field %s in mappedResult", key)
return ""
}
val, ok := v.(string)
if !ok {
logger.Errorf("String field %s is %T in mappedResult", key, r[key])
}
return val
}
func (r mappedResult) stringPtr(key string) *string {
val, ok := r.string(key)
if !ok {
return nil
}
return &val
}
func (r mappedResult) stringSlice(key string) []string {
v, ok := r[key]
if !ok {
return nil
}
// need to try both []string and string
val, ok := v.([]string)
if ok {
return val
}
// try single string
singleVal, ok := v.(string)
if !ok {
logger.Errorf("String slice field %s is %T in mappedResult", key, r[key])
return nil
}
return []string{singleVal}
}
func (r mappedResult) IntPtr(key string) *int {
v, ok := r[key]
if !ok {
return nil
}
val, ok := v.(int)
if !ok {
logger.Errorf("Int field %s is %T in mappedResult", key, r[key])
return nil
}
return &val
}
func (r mappedResults) setSingleValue(index int, key string, value string) mappedResults {
if index >= len(r) {
r = append(r, make(mappedResult))
}
logger.Debugf(`[%d][%s] = %s`, index, key, value)
r[index][key] = value
return r
}
func (r mappedResults) setMultiValue(index int, key string, value []string) mappedResults {
if index >= len(r) {
r = append(r, make(mappedResult))
}
logger.Debugf(`[%d][%s] = %s`, index, key, value)
r[index][key] = value
return r
}
func (r mappedResults) scrapedTags() []*models.ScrapedTag {
if len(r) == 0 {
return nil
}
ret := make([]*models.ScrapedTag, len(r))
for i, result := range r {
ret[i] = result.scrapedTag()
}
return ret
}
func (r mappedResult) scrapedTag() *models.ScrapedTag {
return &models.ScrapedTag{
Name: r.mustString("Name"),
}
}
func (r mappedResult) scrapedPerformer() *models.ScrapedPerformer {
ret := &models.ScrapedPerformer{
Name: r.stringPtr("Name"),
Disambiguation: r.stringPtr("Disambiguation"),
Gender: r.stringPtr("Gender"),
URL: r.stringPtr("URL"),
URLs: r.stringSlice("URLs"),
Twitter: r.stringPtr("Twitter"),
Birthdate: r.stringPtr("Birthdate"),
Ethnicity: r.stringPtr("Ethnicity"),
Country: r.stringPtr("Country"),
EyeColor: r.stringPtr("EyeColor"),
Height: r.stringPtr("Height"),
Measurements: r.stringPtr("Measurements"),
FakeTits: r.stringPtr("FakeTits"),
PenisLength: r.stringPtr("PenisLength"),
Circumcised: r.stringPtr("Circumcised"),
CareerLength: r.stringPtr("CareerLength"),
Tattoos: r.stringPtr("Tattoos"),
Piercings: r.stringPtr("Piercings"),
Aliases: r.stringPtr("Aliases"),
Image: r.stringPtr("Image"),
Images: r.stringSlice("Images"),
Details: r.stringPtr("Details"),
DeathDate: r.stringPtr("DeathDate"),
HairColor: r.stringPtr("HairColor"),
Weight: r.stringPtr("Weight"),
}
return ret
}
func (r mappedResults) scrapedPerformers() []*models.ScrapedPerformer {
if len(r) == 0 {
return nil
}
ret := make([]*models.ScrapedPerformer, len(r))
for i, result := range r {
ret[i] = result.scrapedPerformer()
}
return ret
}
func (r mappedResult) scrapedScene() *models.ScrapedScene {
ret := &models.ScrapedScene{
Title: r.stringPtr("Title"),
Code: r.stringPtr("Code"),
Details: r.stringPtr("Details"),
Director: r.stringPtr("Director"),
URL: r.stringPtr("URL"),
URLs: r.stringSlice("URLs"),
Date: r.stringPtr("Date"),
Image: r.stringPtr("Image"),
Duration: r.IntPtr("Duration"),
}
return ret
}
func (r mappedResult) scrapedImage() *models.ScrapedImage {
ret := &models.ScrapedImage{
Title: r.stringPtr("Title"),
Code: r.stringPtr("Code"),
Details: r.stringPtr("Details"),
Photographer: r.stringPtr("Photographer"),
URLs: r.stringSlice("URLs"),
Date: r.stringPtr("Date"),
}
return ret
}
func (r mappedResult) scrapedGallery() *models.ScrapedGallery {
ret := &models.ScrapedGallery{
Title: r.stringPtr("Title"),
Code: r.stringPtr("Code"),
Details: r.stringPtr("Details"),
Photographer: r.stringPtr("Photographer"),
URL: r.stringPtr("URL"),
URLs: r.stringSlice("URLs"),
Date: r.stringPtr("Date"),
}
return ret
}
func (r mappedResult) scrapedStudio() *models.ScrapedStudio {
ret := &models.ScrapedStudio{
Name: r.mustString("Name"),
URL: r.stringPtr("URL"),
URLs: r.stringSlice("URLs"),
Image: r.stringPtr("Image"),
Details: r.stringPtr("Details"),
Aliases: r.stringPtr("Aliases"),
}
return ret
}
func (r mappedResult) scrapedMovie() *models.ScrapedMovie {
ret := &models.ScrapedMovie{
Name: r.stringPtr("Name"),
Aliases: r.stringPtr("Aliases"),
URLs: r.stringSlice("URLs"),
Duration: r.stringPtr("Duration"),
Date: r.stringPtr("Date"),
Director: r.stringPtr("Director"),
Synopsis: r.stringPtr("Synopsis"),
FrontImage: r.stringPtr("FrontImage"),
BackImage: r.stringPtr("BackImage"),
}
return ret
}
func (r mappedResult) scrapedGroup() *models.ScrapedGroup {
ret := &models.ScrapedGroup{
Name: r.stringPtr("Name"),
Aliases: r.stringPtr("Aliases"),
URL: r.stringPtr("URL"),
URLs: r.stringSlice("URLs"),
Duration: r.stringPtr("Duration"),
Date: r.stringPtr("Date"),
Director: r.stringPtr("Director"),
Synopsis: r.stringPtr("Synopsis"),
FrontImage: r.stringPtr("FrontImage"),
BackImage: r.stringPtr("BackImage"),
}
return ret
}
func (r mappedResults) scrapedMovies() []*models.ScrapedMovie {
if len(r) == 0 {
return nil
}
ret := make([]*models.ScrapedMovie, len(r))
for i, result := range r {
ret[i] = result.scrapedMovie()
}
return ret
}
func (r mappedResults) scrapedGroups() []*models.ScrapedGroup {
if len(r) == 0 {
return nil
}
ret := make([]*models.ScrapedGroup, len(r))
for i, result := range r {
ret[i] = result.scrapedGroup()
}
return ret
}

View File

@@ -1,908 +0,0 @@
package scraper
import (
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stretchr/testify/assert"
)
// Test string method
func TestMappedResultString(t *testing.T) {
tests := []struct {
name string
data mappedResult
key string
expectedValue string
expectedOk bool
}{
{
name: "valid string",
data: mappedResult{"name": "test"},
key: "name",
expectedValue: "test",
expectedOk: true,
},
{
name: "missing key",
data: mappedResult{},
key: "missing",
expectedValue: "",
expectedOk: false,
},
{
name: "wrong type still returns ok true but empty value",
data: mappedResult{"num": 123},
key: "num",
expectedValue: "",
expectedOk: true, // logs error but returns ok=true
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val, ok := test.data.string(test.key)
assert.Equal(t, test.expectedValue, val)
assert.Equal(t, test.expectedOk, ok)
})
}
}
// Test mustString method
func TestMappedResultMustString(t *testing.T) {
tests := []struct {
name string
data mappedResult
key string
expectedValue string
}{
{
name: "valid string",
data: mappedResult{"name": "test"},
key: "name",
expectedValue: "test",
},
{
name: "missing key returns empty string",
data: mappedResult{},
key: "missing",
expectedValue: "",
},
{
name: "wrong type returns empty string",
data: mappedResult{"num": 123},
key: "num",
expectedValue: "",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val := test.data.mustString(test.key)
assert.Equal(t, test.expectedValue, val)
})
}
}
// Test stringPtr method
func TestMappedResultStringPtr(t *testing.T) {
tests := []struct {
name string
data mappedResult
key string
expectedValue *string
}{
{
name: "valid string",
data: mappedResult{"name": "test"},
key: "name",
expectedValue: strPtr("test"),
},
{
name: "missing key returns nil",
data: mappedResult{},
key: "missing",
expectedValue: nil,
},
{
name: "wrong type returns non-nil pointer to empty string",
data: mappedResult{"num": 123},
key: "num",
expectedValue: strPtr(""), // string() returns empty string but ok=true
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val := test.data.stringPtr(test.key)
if test.expectedValue == nil {
assert.Nil(t, val)
} else {
assert.NotNil(t, val)
assert.Equal(t, *test.expectedValue, *val)
}
})
}
}
// Test stringSlice method
func TestMappedResultStringSlice(t *testing.T) {
tests := []struct {
name string
data mappedResult
key string
expectedValue []string
}{
{
name: "valid slice",
data: mappedResult{"tags": []string{"a", "b", "c"}},
key: "tags",
expectedValue: []string{"a", "b", "c"},
},
{
name: "missing key returns nil",
data: mappedResult{},
key: "missing",
expectedValue: nil,
},
{
name: "single value converted to slice",
data: mappedResult{"tags": "not a slice"},
key: "tags",
expectedValue: []string{"not a slice"},
},
{
name: "wrong type returns nil",
data: mappedResult{"tags": 123},
key: "tags",
expectedValue: nil,
},
{
name: "empty slice",
data: mappedResult{"tags": []string{}},
key: "tags",
expectedValue: []string{},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val := test.data.stringSlice(test.key)
assert.Equal(t, test.expectedValue, val)
})
}
}
// Test IntPtr method
func TestMappedResultIntPtr(t *testing.T) {
tests := []struct {
name string
data mappedResult
key string
expectedValue *int
}{
{
name: "valid int",
data: mappedResult{"duration": 120},
key: "duration",
expectedValue: intPtr(120),
},
{
name: "missing key returns nil",
data: mappedResult{},
key: "missing",
expectedValue: nil,
},
{
name: "wrong type returns nil",
data: mappedResult{"duration": "120"},
key: "duration",
expectedValue: nil,
},
{
name: "zero value",
data: mappedResult{"duration": 0},
key: "duration",
expectedValue: intPtr(0),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val := test.data.IntPtr(test.key)
assert.Equal(t, test.expectedValue, val)
})
}
}
// Test setSingleValue method
func TestMappedResultsSetSingleValue(t *testing.T) {
tests := []struct {
name string
initialResults mappedResults
index int
key string
value string
expectedLen int
shouldPanic bool
}{
{
name: "append to empty",
initialResults: mappedResults{},
index: 0,
key: "name",
value: "test",
expectedLen: 1,
shouldPanic: false,
},
{
name: "set in existing",
initialResults: mappedResults{mappedResult{}},
index: 0,
key: "name",
value: "test",
expectedLen: 1,
shouldPanic: false,
},
{
name: "append to existing",
initialResults: mappedResults{mappedResult{}},
index: 1,
key: "name",
value: "test",
expectedLen: 2,
shouldPanic: false,
},
{
name: "sparse index causes panic",
initialResults: mappedResults{mappedResult{}},
index: 5,
key: "name",
value: "test",
expectedLen: 6,
shouldPanic: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.shouldPanic {
assert.Panics(t, func() {
test.initialResults.setSingleValue(test.index, test.key, test.value)
})
} else {
results := test.initialResults.setSingleValue(test.index, test.key, test.value)
assert.Equal(t, test.expectedLen, len(results))
assert.Equal(t, test.value, results[test.index][test.key])
}
})
}
}
// Test setMultiValue method
func TestMappedResultsSetMultiValue(t *testing.T) {
tests := []struct {
name string
initialResults mappedResults
index int
key string
value []string
expectedLen int
}{
{
name: "append to empty",
initialResults: mappedResults{},
index: 0,
key: "tags",
value: []string{"a", "b"},
expectedLen: 1,
},
{
name: "set in existing",
initialResults: mappedResults{mappedResult{}},
index: 0,
key: "tags",
value: []string{"a", "b"},
expectedLen: 1,
},
{
name: "append to existing",
initialResults: mappedResults{mappedResult{}},
index: 1,
key: "tags",
value: []string{"x", "y"},
expectedLen: 2,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
results := test.initialResults.setMultiValue(test.index, test.key, test.value)
assert.Equal(t, test.expectedLen, len(results))
assert.Equal(t, test.value, results[test.index][test.key])
})
}
}
// Test scrapedTag method
func TestMappedResultScrapedTag(t *testing.T) {
tests := []struct {
name string
data mappedResult
expectedName string
}{
{
name: "valid tag",
data: mappedResult{"Name": "Action"},
expectedName: "Action",
},
{
name: "missing name",
data: mappedResult{},
expectedName: "",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tag := test.data.scrapedTag()
assert.NotNil(t, tag)
assert.Equal(t, test.expectedName, tag.Name)
})
}
}
// Test scrapedTags method
func TestMappedResultsScrapedTags(t *testing.T) {
tests := []struct {
name string
data mappedResults
expectedCount int
expectedNames []string
}{
{
name: "empty results",
data: mappedResults{},
expectedCount: 0,
},
{
name: "single tag",
data: mappedResults{
mappedResult{"Name": "Action"},
},
expectedCount: 1,
expectedNames: []string{"Action"},
},
{
name: "multiple tags",
data: mappedResults{
mappedResult{"Name": "Action"},
mappedResult{"Name": "Drama"},
mappedResult{"Name": "Comedy"},
},
expectedCount: 3,
expectedNames: []string{"Action", "Drama", "Comedy"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tags := test.data.scrapedTags()
if test.expectedCount == 0 {
assert.Nil(t, tags)
} else {
assert.NotNil(t, tags)
assert.Equal(t, test.expectedCount, len(tags))
for i, expectedName := range test.expectedNames {
assert.Equal(t, expectedName, tags[i].Name)
}
}
})
}
}
// Test scrapedPerformer method
func TestMappedResultScrapedPerformer(t *testing.T) {
tests := []struct {
name string
data mappedResult
validate func(t *testing.T, p *models.ScrapedPerformer)
}{
{
name: "full performer",
data: mappedResult{
"Name": "Jane Doe",
"Disambiguation": "Actress",
"Gender": "Female",
"URL": "https://example.com/jane",
"URLs": []string{"url1", "url2"},
"Twitter": "@jane",
"Birthdate": "1990-01-01",
"Ethnicity": "Caucasian",
"Country": "USA",
"EyeColor": "Blue",
"Height": "5'6\"",
"Measurements": "36-24-36",
"FakeTits": "No",
"PenisLength": "N/A",
"Circumcised": "N/A",
"CareerLength": "10 years",
"Tattoos": "Yes",
"Piercings": "Yes",
"Aliases": "Jane Smith",
"Image": "image.jpg",
"Images": []string{"img1", "img2"},
"Details": "Some details",
"DeathDate": "N/A",
"HairColor": "Blonde",
"Weight": "130 lbs",
},
validate: func(t *testing.T, p *models.ScrapedPerformer) {
assert.NotNil(t, p)
assert.Equal(t, "Jane Doe", *p.Name)
assert.Equal(t, "Actress", *p.Disambiguation)
assert.Equal(t, "Female", *p.Gender)
assert.Equal(t, "https://example.com/jane", *p.URL)
assert.Equal(t, []string{"url1", "url2"}, p.URLs)
assert.Equal(t, "@jane", *p.Twitter)
assert.Equal(t, "Blonde", *p.HairColor)
assert.Equal(t, "130 lbs", *p.Weight)
},
},
{
name: "minimal performer",
data: mappedResult{},
validate: func(t *testing.T, p *models.ScrapedPerformer) {
assert.NotNil(t, p)
assert.Nil(t, p.Name)
assert.Nil(t, p.Gender)
assert.Empty(t, p.URLs)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
performer := test.data.scrapedPerformer()
test.validate(t, performer)
})
}
}
// Test scrapedPerformers method
func TestMappedResultsScrapedPerformers(t *testing.T) {
tests := []struct {
name string
data mappedResults
expectedCount int
}{
{
name: "empty results",
data: mappedResults{},
expectedCount: 0,
},
{
name: "single performer",
data: mappedResults{
mappedResult{"Name": "Jane Doe"},
},
expectedCount: 1,
},
{
name: "multiple performers",
data: mappedResults{
mappedResult{"Name": "Jane Doe"},
mappedResult{"Name": "John Doe"},
mappedResult{"Name": "Alice"},
},
expectedCount: 3,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
performers := test.data.scrapedPerformers()
if test.expectedCount == 0 {
assert.Nil(t, performers)
} else {
assert.NotNil(t, performers)
assert.Equal(t, test.expectedCount, len(performers))
}
})
}
}
// Test scrapedScene method
func TestMappedResultScrapedScene(t *testing.T) {
tests := []struct {
name string
data mappedResult
validate func(t *testing.T, s *models.ScrapedScene)
}{
{
name: "full scene",
data: mappedResult{
"Title": "Scene Title",
"Code": "CODE123",
"Details": "Scene details",
"Director": "John Smith",
"URL": "https://example.com/scene",
"URLs": []string{"url1", "url2"},
"Date": "2020-01-01",
"Image": "scene.jpg",
"Duration": 3600,
},
validate: func(t *testing.T, s *models.ScrapedScene) {
assert.NotNil(t, s)
assert.Equal(t, "Scene Title", *s.Title)
assert.Equal(t, "CODE123", *s.Code)
assert.Equal(t, "Scene details", *s.Details)
assert.Equal(t, "John Smith", *s.Director)
assert.Equal(t, "https://example.com/scene", *s.URL)
assert.Equal(t, []string{"url1", "url2"}, s.URLs)
assert.Equal(t, "2020-01-01", *s.Date)
assert.Equal(t, "scene.jpg", *s.Image)
assert.Equal(t, 3600, *s.Duration)
},
},
{
name: "minimal scene",
data: mappedResult{},
validate: func(t *testing.T, s *models.ScrapedScene) {
assert.NotNil(t, s)
assert.Nil(t, s.Title)
assert.Nil(t, s.Duration)
assert.Empty(t, s.URLs)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
scene := test.data.scrapedScene()
test.validate(t, scene)
})
}
}
// Test scrapedImage method
func TestMappedResultScrapedImage(t *testing.T) {
tests := []struct {
name string
data mappedResult
validate func(t *testing.T, i *models.ScrapedImage)
}{
{
name: "full image",
data: mappedResult{
"Title": "Image Title",
"Code": "IMG123",
"Details": "Image details",
"Photographer": "Jane Photographer",
"URLs": []string{"url1", "url2"},
"Date": "2020-06-15",
},
validate: func(t *testing.T, i *models.ScrapedImage) {
assert.NotNil(t, i)
assert.Equal(t, "Image Title", *i.Title)
assert.Equal(t, "IMG123", *i.Code)
assert.Equal(t, "Image details", *i.Details)
assert.Equal(t, "Jane Photographer", *i.Photographer)
assert.Equal(t, []string{"url1", "url2"}, i.URLs)
assert.Equal(t, "2020-06-15", *i.Date)
},
},
{
name: "minimal image",
data: mappedResult{},
validate: func(t *testing.T, i *models.ScrapedImage) {
assert.NotNil(t, i)
assert.Nil(t, i.Title)
assert.Empty(t, i.URLs)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
image := test.data.scrapedImage()
test.validate(t, image)
})
}
}
// Test scrapedGallery method
func TestMappedResultScrapedGallery(t *testing.T) {
tests := []struct {
name string
data mappedResult
validate func(t *testing.T, g *models.ScrapedGallery)
}{
{
name: "full gallery",
data: mappedResult{
"Title": "Gallery Title",
"Code": "GAL123",
"Details": "Gallery details",
"Photographer": "Jane Photographer",
"URL": "https://example.com/gallery",
"URLs": []string{"url1", "url2"},
"Date": "2020-07-20",
},
validate: func(t *testing.T, g *models.ScrapedGallery) {
assert.NotNil(t, g)
assert.Equal(t, "Gallery Title", *g.Title)
assert.Equal(t, "GAL123", *g.Code)
assert.Equal(t, "Gallery details", *g.Details)
assert.Equal(t, "Jane Photographer", *g.Photographer)
assert.Equal(t, "https://example.com/gallery", *g.URL)
assert.Equal(t, []string{"url1", "url2"}, g.URLs)
assert.Equal(t, "2020-07-20", *g.Date)
},
},
{
name: "minimal gallery",
data: mappedResult{},
validate: func(t *testing.T, g *models.ScrapedGallery) {
assert.NotNil(t, g)
assert.Nil(t, g.Title)
assert.Empty(t, g.URLs)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
gallery := test.data.scrapedGallery()
test.validate(t, gallery)
})
}
}
// Test scrapedStudio method
func TestMappedResultScrapedStudio(t *testing.T) {
tests := []struct {
name string
data mappedResult
validate func(t *testing.T, st *models.ScrapedStudio)
}{
{
name: "full studio",
data: mappedResult{
"Name": "Studio Name",
"URL": "https://example.com/studio",
"URLs": []string{"url1", "url2"},
"Image": "studio.jpg",
"Details": "Studio details",
"Aliases": "Studio Alias",
},
validate: func(t *testing.T, st *models.ScrapedStudio) {
assert.NotNil(t, st)
assert.Equal(t, "Studio Name", st.Name)
assert.Equal(t, "https://example.com/studio", *st.URL)
assert.Equal(t, []string{"url1", "url2"}, st.URLs)
assert.Equal(t, "studio.jpg", *st.Image)
assert.Equal(t, "Studio details", *st.Details)
assert.Equal(t, "Studio Alias", *st.Aliases)
},
},
{
name: "minimal studio",
data: mappedResult{},
validate: func(t *testing.T, st *models.ScrapedStudio) {
assert.NotNil(t, st)
assert.Equal(t, "", st.Name) // mustString returns empty string
assert.Nil(t, st.URL)
assert.Empty(t, st.URLs)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
studio := test.data.scrapedStudio()
test.validate(t, studio)
})
}
}
// Test scrapedMovie method
func TestMappedResultScrapedMovie(t *testing.T) {
tests := []struct {
name string
data mappedResult
validate func(t *testing.T, m *models.ScrapedMovie)
}{
{
name: "full movie",
data: mappedResult{
"Name": "Movie Title",
"Aliases": "Movie Alias",
"URLs": []string{"url1", "url2"},
"Duration": "120 minutes",
"Date": "2020-05-10",
"Director": "John Director",
"Synopsis": "Movie synopsis",
"FrontImage": "front.jpg",
"BackImage": "back.jpg",
},
validate: func(t *testing.T, m *models.ScrapedMovie) {
assert.NotNil(t, m)
assert.Equal(t, "Movie Title", *m.Name)
assert.Equal(t, "Movie Alias", *m.Aliases)
assert.Equal(t, []string{"url1", "url2"}, m.URLs)
assert.Equal(t, "120 minutes", *m.Duration)
assert.Equal(t, "2020-05-10", *m.Date)
assert.Equal(t, "John Director", *m.Director)
assert.Equal(t, "Movie synopsis", *m.Synopsis)
assert.Equal(t, "front.jpg", *m.FrontImage)
assert.Equal(t, "back.jpg", *m.BackImage)
},
},
{
name: "minimal movie",
data: mappedResult{},
validate: func(t *testing.T, m *models.ScrapedMovie) {
assert.NotNil(t, m)
assert.Nil(t, m.Name)
assert.Empty(t, m.URLs)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
movie := test.data.scrapedMovie()
test.validate(t, movie)
})
}
}
// Test scrapedMovies method
func TestMappedResultsScrapedMovies(t *testing.T) {
tests := []struct {
name string
data mappedResults
expectedCount int
}{
{
name: "empty results",
data: mappedResults{},
expectedCount: 0,
},
{
name: "single movie",
data: mappedResults{
mappedResult{"Name": "Movie 1"},
},
expectedCount: 1,
},
{
name: "multiple movies",
data: mappedResults{
mappedResult{"Name": "Movie 1"},
mappedResult{"Name": "Movie 2"},
mappedResult{"Name": "Movie 3"},
},
expectedCount: 3,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
movies := test.data.scrapedMovies()
if test.expectedCount == 0 {
assert.Nil(t, movies)
} else {
assert.NotNil(t, movies)
assert.Equal(t, test.expectedCount, len(movies))
}
})
}
}
// Test scrapedGroup method
func TestMappedResultScrapedGroup(t *testing.T) {
tests := []struct {
name string
data mappedResult
validate func(t *testing.T, g *models.ScrapedGroup)
}{
{
name: "full group",
data: mappedResult{
"Name": "Group Title",
"Aliases": "Group Alias",
"URL": "https://example.com/group",
"URLs": []string{"url1", "url2"},
"Duration": "240 minutes",
"Date": "2020-08-15",
"Director": "Jane Director",
"Synopsis": "Group synopsis",
"FrontImage": "front.jpg",
"BackImage": "back.jpg",
},
validate: func(t *testing.T, g *models.ScrapedGroup) {
assert.NotNil(t, g)
assert.Equal(t, "Group Title", *g.Name)
assert.Equal(t, "Group Alias", *g.Aliases)
assert.Equal(t, "https://example.com/group", *g.URL)
assert.Equal(t, []string{"url1", "url2"}, g.URLs)
assert.Equal(t, "240 minutes", *g.Duration)
assert.Equal(t, "2020-08-15", *g.Date)
assert.Equal(t, "Jane Director", *g.Director)
assert.Equal(t, "Group synopsis", *g.Synopsis)
assert.Equal(t, "front.jpg", *g.FrontImage)
assert.Equal(t, "back.jpg", *g.BackImage)
},
},
{
name: "minimal group",
data: mappedResult{},
validate: func(t *testing.T, g *models.ScrapedGroup) {
assert.NotNil(t, g)
assert.Nil(t, g.Name)
assert.Empty(t, g.URLs)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
group := test.data.scrapedGroup()
test.validate(t, group)
})
}
}
// Test scrapedGroups method
func TestMappedResultsScrapedGroups(t *testing.T) {
tests := []struct {
name string
data mappedResults
expectedCount int
}{
{
name: "empty results",
data: mappedResults{},
expectedCount: 0,
},
{
name: "single group",
data: mappedResults{
mappedResult{"Name": "Group 1"},
},
expectedCount: 1,
},
{
name: "multiple groups",
data: mappedResults{
mappedResult{"Name": "Group 1"},
mappedResult{"Name": "Group 2"},
mappedResult{"Name": "Group 3"},
},
expectedCount: 3,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
groups := test.data.scrapedGroups()
if test.expectedCount == 0 {
assert.Nil(t, groups)
} else {
assert.NotNil(t, groups)
assert.Equal(t, test.expectedCount, len(groups))
}
})
}
}
// Helper functions
func strPtr(s string) *string {
return &s
}
func intPtr(i int) *int {
return &i
}

View File

@@ -25,7 +25,7 @@ xPathScrapers:
- anything
`
c := &Definition{}
c := &config{}
err := yaml.Unmarshal([]byte(yamlStr), &c)
if err == nil {

View File

@@ -110,7 +110,7 @@ func (p queryURLParameters) constructURL(url string) string {
}
// replaceURL does a partial URL Replace ( only url parameter is used)
func replaceURL(url string, scraperConfig ByURLDefinition) string {
func replaceURL(url string, scraperConfig scraperTypeConfig) string {
u := url
queryURL := queryURLParameterFromURL(u)
if scraperConfig.QueryURLReplacements != nil {

View File

@@ -208,11 +208,22 @@ func galleryInputFromGallery(gallery *models.Gallery) galleryInput {
var ErrScraperScript = errors.New("scraper script error")
type scriptScraper struct {
definition Definition
scraper scraperTypeConfig
config config
globalConfig GlobalConfig
}
func (s *scriptScraper) runScraperScript(ctx context.Context, command []string, inString string, out interface{}) error {
func newScriptScraper(scraper scraperTypeConfig, config config, globalConfig GlobalConfig) *scriptScraper {
return &scriptScraper{
scraper: scraper,
config: config,
globalConfig: globalConfig,
}
}
func (s *scriptScraper) runScraperScript(ctx context.Context, inString string, out interface{}) error {
command := s.scraper.Script
var cmd *exec.Cmd
if python.IsPythonCommand(command[0]) {
pythonPath := s.globalConfig.GetPythonPath()
@@ -222,7 +233,7 @@ func (s *scriptScraper) runScraperScript(ctx context.Context, command []string,
logger.Warnf("%s", err)
} else {
cmd = p.Command(ctx, command[1:])
envVariable, _ := filepath.Abs(filepath.Dir(filepath.Dir(s.definition.path)))
envVariable, _ := filepath.Abs(filepath.Dir(filepath.Dir(s.config.path)))
python.AppendPythonPath(cmd, envVariable)
}
}
@@ -232,7 +243,7 @@ func (s *scriptScraper) runScraperScript(ctx context.Context, command []string,
cmd = stashExec.CommandContext(ctx, command[0], command[1:]...)
}
cmd.Dir = filepath.Dir(s.definition.path)
cmd.Dir = filepath.Dir(s.config.path)
stdin, err := cmd.StdinPipe()
if err != nil {
@@ -262,7 +273,7 @@ func (s *scriptScraper) runScraperScript(ctx context.Context, command []string,
return errors.New("error running scraper script")
}
go handleScraperStderr(s.definition.Name, stderr)
go handleScraperStderr(s.config.Name, stderr)
logger.Debugf("Scraper script <%s> started", strings.Join(cmd.Args, " "))
@@ -301,39 +312,7 @@ func (s *scriptScraper) runScraperScript(ctx context.Context, command []string,
return nil
}
func (s *scriptScraper) scrape(ctx context.Context, command []string, input string, ty ScrapeContentType) (ScrapedContent, error) {
switch ty {
case ScrapeContentTypePerformer:
var performer *models.ScrapedPerformer
err := s.runScraperScript(ctx, command, input, &performer)
return performer, err
case ScrapeContentTypeGallery:
var gallery *models.ScrapedGallery
err := s.runScraperScript(ctx, command, input, &gallery)
return gallery, err
case ScrapeContentTypeScene:
var scene *models.ScrapedScene
err := s.runScraperScript(ctx, command, input, &scene)
return scene, err
case ScrapeContentTypeMovie, ScrapeContentTypeGroup:
var movie *models.ScrapedMovie
err := s.runScraperScript(ctx, command, input, &movie)
return movie, err
case ScrapeContentTypeImage:
var image *models.ScrapedImage
err := s.runScraperScript(ctx, command, input, &image)
return image, err
}
return nil, ErrNotSupported
}
type scriptNameScraper struct {
scriptScraper
definition ByNameDefinition
}
func (s *scriptNameScraper) scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
func (s *scriptScraper) scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
input := `{"name": "` + name + `"}`
var ret []ScrapedContent
@@ -341,7 +320,7 @@ func (s *scriptNameScraper) scrapeByName(ctx context.Context, name string, ty Sc
switch ty {
case ScrapeContentTypePerformer:
var performers []models.ScrapedPerformer
err = s.runScraperScript(ctx, s.definition.Script, input, &performers)
err = s.runScraperScript(ctx, input, &performers)
if err == nil {
for _, p := range performers {
v := p
@@ -350,7 +329,7 @@ func (s *scriptNameScraper) scrapeByName(ctx context.Context, name string, ty Sc
}
case ScrapeContentTypeScene:
var scenes []models.ScrapedScene
err = s.runScraperScript(ctx, s.definition.Script, input, &scenes)
err = s.runScraperScript(ctx, input, &scenes)
if err == nil {
for _, s := range scenes {
v := s
@@ -364,21 +343,7 @@ func (s *scriptNameScraper) scrapeByName(ctx context.Context, name string, ty Sc
return ret, err
}
type scriptURLScraper struct {
scriptScraper
definition ByURLDefinition
}
func (s *scriptURLScraper) scrapeByURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error) {
return s.scrape(ctx, s.definition.Script, `{"url": "`+url+`"}`, ty)
}
type scriptFragmentScraper struct {
scriptScraper
definition ByFragmentDefinition
}
func (s *scriptFragmentScraper) scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error) {
func (s *scriptScraper) scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error) {
var inString []byte
var err error
var ty ScrapeContentType
@@ -398,10 +363,41 @@ func (s *scriptFragmentScraper) scrapeByFragment(ctx context.Context, input Inpu
return nil, err
}
return s.scrape(ctx, s.definition.Script, string(inString), ty)
return s.scrape(ctx, string(inString), ty)
}
func (s *scriptFragmentScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error) {
func (s *scriptScraper) scrapeByURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error) {
return s.scrape(ctx, `{"url": "`+url+`"}`, ty)
}
func (s *scriptScraper) scrape(ctx context.Context, input string, ty ScrapeContentType) (ScrapedContent, error) {
switch ty {
case ScrapeContentTypePerformer:
var performer *models.ScrapedPerformer
err := s.runScraperScript(ctx, input, &performer)
return performer, err
case ScrapeContentTypeGallery:
var gallery *models.ScrapedGallery
err := s.runScraperScript(ctx, input, &gallery)
return gallery, err
case ScrapeContentTypeScene:
var scene *models.ScrapedScene
err := s.runScraperScript(ctx, input, &scene)
return scene, err
case ScrapeContentTypeMovie, ScrapeContentTypeGroup:
var movie *models.ScrapedMovie
err := s.runScraperScript(ctx, input, &movie)
return movie, err
case ScrapeContentTypeImage:
var image *models.ScrapedImage
err := s.runScraperScript(ctx, input, &image)
return image, err
}
return nil, ErrNotSupported
}
func (s *scriptScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error) {
inString, err := json.Marshal(sceneInputFromScene(scene))
if err != nil {
@@ -410,12 +406,12 @@ func (s *scriptFragmentScraper) scrapeSceneByScene(ctx context.Context, scene *m
var ret *models.ScrapedScene
err = s.runScraperScript(ctx, s.definition.Script, string(inString), &ret)
err = s.runScraperScript(ctx, string(inString), &ret)
return ret, err
}
func (s *scriptFragmentScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*models.ScrapedGallery, error) {
func (s *scriptScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*models.ScrapedGallery, error) {
inString, err := json.Marshal(galleryInputFromGallery(gallery))
if err != nil {
@@ -424,12 +420,12 @@ func (s *scriptFragmentScraper) scrapeGalleryByGallery(ctx context.Context, gall
var ret *models.ScrapedGallery
err = s.runScraperScript(ctx, s.definition.Script, string(inString), &ret)
err = s.runScraperScript(ctx, string(inString), &ret)
return ret, err
}
func (s *scriptFragmentScraper) scrapeImageByImage(ctx context.Context, image *models.Image) (*models.ScrapedImage, error) {
func (s *scriptScraper) scrapeImageByImage(ctx context.Context, image *models.Image) (*models.ScrapedImage, error) {
inString, err := json.Marshal(imageToUpdateInput(image))
if err != nil {
@@ -438,7 +434,7 @@ func (s *scriptFragmentScraper) scrapeImageByImage(ctx context.Context, image *m
var ret *models.ScrapedImage
err = s.runScraperScript(ctx, s.definition.Script, string(inString), &ret)
err = s.runScraperScript(ctx, string(inString), &ret)
return ret, err
}

View File

@@ -14,13 +14,15 @@ import (
)
type stashScraper struct {
config Definition
scraper scraperTypeConfig
config config
globalConfig GlobalConfig
client *http.Client
}
func newStashScraper(client *http.Client, config Definition, globalConfig GlobalConfig) *stashScraper {
func newStashScraper(scraper scraperTypeConfig, client *http.Client, config config, globalConfig GlobalConfig) *stashScraper {
return &stashScraper{
scraper: scraper,
config: config,
client: client,
globalConfig: globalConfig,

View File

@@ -25,8 +25,8 @@ import (
const scrapeDefaultSleep = time.Second * 2
func loadURL(ctx context.Context, loadURL string, client *http.Client, def Definition, globalConfig GlobalConfig) (io.Reader, error) {
driverOptions := def.DriverOptions
func loadURL(ctx context.Context, loadURL string, client *http.Client, scraperConfig config, globalConfig GlobalConfig) (io.Reader, error) {
driverOptions := scraperConfig.DriverOptions
if driverOptions != nil && driverOptions.UseCDP {
// get the page using chrome dp
return urlFromCDP(ctx, loadURL, *driverOptions, globalConfig)
@@ -37,7 +37,7 @@ func loadURL(ctx context.Context, loadURL string, client *http.Client, def Defin
return nil, err
}
jar, err := def.jar()
jar, err := scraperConfig.jar()
if err != nil {
return nil, fmt.Errorf("error creating cookie jar: %w", err)
}
@@ -83,7 +83,7 @@ func loadURL(ctx context.Context, loadURL string, client *http.Client, def Defin
}
bodyReader := bytes.NewReader(body)
printCookies(jar, def, "Jar cookies found for scraper urls")
printCookies(jar, scraperConfig, "Jar cookies found for scraper urls")
return charset.NewReader(bodyReader, resp.Header.Get("Content-Type"))
}

View File

@@ -3,6 +3,7 @@ package scraper
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"net/url"
@@ -18,36 +19,49 @@ import (
)
type xpathScraper struct {
definition Definition
scraper scraperTypeConfig
config config
globalConfig GlobalConfig
client *http.Client
}
func (s *xpathScraper) getXpathScraper(name string) (*mappedScraper, error) {
ret, ok := s.definition.XPathScrapers[name]
if !ok {
return nil, fmt.Errorf("xpath scraper with name %s not found in config", name)
func newXpathScraper(scraper scraperTypeConfig, client *http.Client, config config, globalConfig GlobalConfig) *xpathScraper {
return &xpathScraper{
scraper: scraper,
config: config,
globalConfig: globalConfig,
client: client,
}
return &ret, nil
}
type xpathURLScraper struct {
xpathScraper
definition ByURLDefinition
func (s *xpathScraper) getXpathScraper() *mappedScraper {
return s.config.XPathScrapers[s.scraper.Scraper]
}
func (s *xpathURLScraper) scrapeByURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error) {
scraper, err := s.getXpathScraper(s.definition.Scraper)
if err != nil {
return nil, err
func (s *xpathScraper) scrapeURL(ctx context.Context, url string) (*html.Node, *mappedScraper, error) {
scraper := s.getXpathScraper()
if scraper == nil {
return nil, nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config")
}
doc, err := s.loadURL(ctx, url)
if err != nil {
return nil, nil, err
}
return doc, scraper, nil
}
func (s *xpathScraper) scrapeByURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error) {
u := replaceURL(url, s.scraper) // allow a URL Replace for performer by URL queries
doc, scraper, err := s.scrapeURL(ctx, u)
if err != nil {
return nil, err
}
q := s.getXPathQuery(doc, url)
q := s.getXPathQuery(doc, u)
// if these just return the return values from scraper.scrape* functions then
// it ends up returning ScrapedContent(nil) rather than nil
switch ty {
@@ -86,15 +100,11 @@ func (s *xpathURLScraper) scrapeByURL(ctx context.Context, url string, ty Scrape
return nil, ErrNotSupported
}
type xpathNameScraper struct {
xpathScraper
definition ByNameDefinition
}
func (s *xpathScraper) scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
scraper := s.getXpathScraper()
func (s *xpathNameScraper) scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
scraper, err := s.getXpathScraper(s.definition.Scraper)
if err != nil {
return nil, err
if scraper == nil {
return nil, fmt.Errorf("%w: name %v", ErrNotFound, s.scraper.Scraper)
}
const placeholder = "{}"
@@ -102,7 +112,7 @@ func (s *xpathNameScraper) scrapeByName(ctx context.Context, name string, ty Scr
// replace the placeholder string with the URL-escaped name
escapedName := url.QueryEscape(name)
url := s.definition.QueryURL
url := s.scraper.QueryURL
url = strings.ReplaceAll(url, placeholder, escapedName)
doc, err := s.loadURL(ctx, url)
@@ -141,22 +151,18 @@ func (s *xpathNameScraper) scrapeByName(ctx context.Context, name string, ty Scr
return nil, ErrNotSupported
}
type xpathFragmentScraper struct {
xpathScraper
definition ByFragmentDefinition
}
func (s *xpathFragmentScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error) {
func (s *xpathScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error) {
// construct the URL
queryURL := queryURLParametersFromScene(scene)
if s.definition.QueryURLReplacements != nil {
queryURL.applyReplacements(s.definition.QueryURLReplacements)
if s.scraper.QueryURLReplacements != nil {
queryURL.applyReplacements(s.scraper.QueryURLReplacements)
}
url := queryURL.constructURL(s.definition.QueryURL)
url := queryURL.constructURL(s.scraper.QueryURL)
scraper, err := s.getXpathScraper(s.definition.Scraper)
if err != nil {
return nil, err
scraper := s.getXpathScraper()
if scraper == nil {
return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config")
}
doc, err := s.loadURL(ctx, url)
@@ -169,7 +175,7 @@ func (s *xpathFragmentScraper) scrapeSceneByScene(ctx context.Context, scene *mo
return scraper.scrapeScene(ctx, q)
}
func (s *xpathFragmentScraper) scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error) {
func (s *xpathScraper) scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error) {
switch {
case input.Gallery != nil:
return nil, fmt.Errorf("%w: cannot use an xpath scraper as a gallery fragment scraper", ErrNotSupported)
@@ -183,14 +189,15 @@ func (s *xpathFragmentScraper) scrapeByFragment(ctx context.Context, input Input
// construct the URL
queryURL := queryURLParametersFromScrapedScene(scene)
if s.definition.QueryURLReplacements != nil {
queryURL.applyReplacements(s.definition.QueryURLReplacements)
if s.scraper.QueryURLReplacements != nil {
queryURL.applyReplacements(s.scraper.QueryURLReplacements)
}
url := queryURL.constructURL(s.definition.QueryURL)
url := queryURL.constructURL(s.scraper.QueryURL)
scraper, err := s.getXpathScraper(s.definition.Scraper)
if err != nil {
return nil, err
scraper := s.getXpathScraper()
if scraper == nil {
return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config")
}
doc, err := s.loadURL(ctx, url)
@@ -203,17 +210,18 @@ func (s *xpathFragmentScraper) scrapeByFragment(ctx context.Context, input Input
return scraper.scrapeScene(ctx, q)
}
func (s *xpathFragmentScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*models.ScrapedGallery, error) {
func (s *xpathScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*models.ScrapedGallery, error) {
// construct the URL
queryURL := queryURLParametersFromGallery(gallery)
if s.definition.QueryURLReplacements != nil {
queryURL.applyReplacements(s.definition.QueryURLReplacements)
if s.scraper.QueryURLReplacements != nil {
queryURL.applyReplacements(s.scraper.QueryURLReplacements)
}
url := queryURL.constructURL(s.definition.QueryURL)
url := queryURL.constructURL(s.scraper.QueryURL)
scraper, err := s.getXpathScraper(s.definition.Scraper)
if err != nil {
return nil, err
scraper := s.getXpathScraper()
if scraper == nil {
return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config")
}
doc, err := s.loadURL(ctx, url)
@@ -226,17 +234,18 @@ func (s *xpathFragmentScraper) scrapeGalleryByGallery(ctx context.Context, galle
return scraper.scrapeGallery(ctx, q)
}
func (s *xpathFragmentScraper) scrapeImageByImage(ctx context.Context, image *models.Image) (*models.ScrapedImage, error) {
func (s *xpathScraper) scrapeImageByImage(ctx context.Context, image *models.Image) (*models.ScrapedImage, error) {
// construct the URL
queryURL := queryURLParametersFromImage(image)
if s.definition.QueryURLReplacements != nil {
queryURL.applyReplacements(s.definition.QueryURLReplacements)
if s.scraper.QueryURLReplacements != nil {
queryURL.applyReplacements(s.scraper.QueryURLReplacements)
}
url := queryURL.constructURL(s.definition.QueryURL)
url := queryURL.constructURL(s.scraper.QueryURL)
scraper, err := s.getXpathScraper(s.definition.Scraper)
if err != nil {
return nil, err
scraper := s.getXpathScraper()
if scraper == nil {
return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config")
}
doc, err := s.loadURL(ctx, url)
@@ -250,14 +259,14 @@ func (s *xpathFragmentScraper) scrapeImageByImage(ctx context.Context, image *mo
}
func (s *xpathScraper) loadURL(ctx context.Context, url string) (*html.Node, error) {
r, err := loadURL(ctx, url, s.client, s.definition, s.globalConfig)
r, err := loadURL(ctx, url, s.client, s.config, s.globalConfig)
if err != nil {
return nil, fmt.Errorf("failed to load URL %q: %w", url, err)
}
ret, err := html.Parse(r)
if err == nil && s.definition.DebugOptions != nil && s.definition.DebugOptions.PrintHTML {
if err == nil && s.config.DebugOptions != nil && s.config.DebugOptions.PrintHTML {
var b bytes.Buffer
if err := html.Render(&b, ret); err != nil {
logger.Warnf("could not render HTML: %v", err)

View File

@@ -674,10 +674,10 @@ func verifyPerformers(t *testing.T, expectedNames []string, expectedURLs []strin
}
if expectedName != actualName {
t.Errorf("Expected performer name %q, got %q", expectedName, actualName)
t.Errorf("Expected performer name %s, got %s", expectedName, actualName)
}
if expectedURL != actualURL {
t.Errorf("Expected performer URL %q, got %q", expectedURL, actualURL)
t.Errorf("Expected performer URL %s, got %s", expectedName, actualName)
}
i++
}
@@ -780,7 +780,7 @@ xPathScrapers:
Name: //studio
`
c := &Definition{}
c := &config{}
err := yaml.Unmarshal([]byte(yamlStr), &c)
if err != nil {
@@ -892,7 +892,7 @@ xPathScrapers:
selector: //span
`
c := &Definition{}
c := &config{}
err := yaml.Unmarshal([]byte(yamlStr), &c)
if err != nil {
@@ -904,8 +904,12 @@ xPathScrapers:
client := &http.Client{}
ctx := context.Background()
s := scraperFromDefinition(*c, globalConfig)
content, err := s.viaURL(ctx, client, ts.URL, ScrapeContentTypePerformer)
s := newGroupScraper(*c, globalConfig)
us, ok := s.(urlScraper)
if !ok {
t.Error("couldn't convert scraper into url scraper")
}
content, err := us.viaURL(ctx, client, ts.URL, ScrapeContentTypePerformer)
if err != nil {
t.Errorf("Error scraping performer: %s", err.Error())

View File

@@ -45,23 +45,6 @@ func UniqueFold(s []string) []string {
return ret
}
// UniqueExcludeFold returns a deduplicated slice of strings with the excluded string removed.
// The comparison is case-insensitive.
func UniqueExcludeFold(values []string, exclude string) []string {
seen := make(map[string]struct{}, len(values))
seen[strings.ToLower(exclude)] = struct{}{}
ret := make([]string, 0, len(values))
for _, v := range values {
vLower := strings.ToLower(v)
if _, exists := seen[vLower]; exists {
continue
}
seen[vLower] = struct{}{}
ret = append(ret, v)
}
return ret
}
// TrimSpace trims whitespace from each string in a slice.
func TrimSpace(s []string) []string {
for i, v := range s {

View File

@@ -1126,40 +1126,3 @@ func (h *relatedFilterHandler) handle(ctx context.Context, f *filterBuilder) {
f.addWhere(fmt.Sprintf("%s IN ("+subQuery.toSQL(false)+")", h.relatedIDCol), subQuery.args...)
}
type phashDistanceCriterionHandler struct {
// assumes that applicable fingerprints table is joined as fingerprints_phash
joinFn func(f *filterBuilder)
criterion *models.PhashDistanceCriterionInput
}
func (h *phashDistanceCriterionHandler) handle(ctx context.Context, f *filterBuilder) {
phashDistance := h.criterion
if phashDistance == nil {
return
}
h.joinFn(f)
value, _ := utils.StringToPhash(phashDistance.Value)
distance := 0
if phashDistance.Distance != nil {
distance = *phashDistance.Distance
}
switch {
case phashDistance.Modifier == models.CriterionModifierEquals && distance > 0:
// needed to avoid a type mismatch
f.addWhere("typeof(fingerprints_phash.fingerprint) = 'integer'")
f.addWhere("phash_distance(fingerprints_phash.fingerprint, ?) < ?", value, distance)
case phashDistance.Modifier == models.CriterionModifierNotEquals && distance > 0:
// needed to avoid a type mismatch
f.addWhere("typeof(fingerprints_phash.fingerprint) = 'integer'")
f.addWhere("phash_distance(fingerprints_phash.fingerprint, ?) > ?", value, distance)
default:
intCriterionHandler(&models.IntCriterionInput{
Value: int(value),
Modifier: phashDistance.Modifier,
}, "fingerprints_phash.fingerprint", nil)(ctx, f)
}
}

View File

@@ -34,7 +34,7 @@ const (
cacheSizeEnv = "STASH_SQLITE_CACHE_SIZE"
)
var appSchemaVersion uint = 76
var appSchemaVersion uint = 75
//go:embed migrations/*.sql
var migrationsBox embed.FS

View File

@@ -62,15 +62,6 @@ func (qb *imageFilterHandler) criterionHandler() criterionHandler {
stringCriterionHandler(imageFilter.Checksum, "fingerprints_md5.fingerprint")(ctx, f)
}),
&phashDistanceCriterionHandler{
joinFn: func(f *filterBuilder) {
imageRepository.addImagesFilesTable(f)
f.addLeftJoin(fingerprintTable, "fingerprints_phash", "images_files.file_id = fingerprints_phash.file_id AND fingerprints_phash.type = 'phash'")
},
criterion: imageFilter.PhashDistance,
},
stringCriterionHandler(imageFilter.Title, "images.title"),
stringCriterionHandler(imageFilter.Code, "images.code"),
stringCriterionHandler(imageFilter.Details, "images.details"),

View File

@@ -1,9 +0,0 @@
CREATE TABLE `studio_custom_fields` (
`studio_id` integer NOT NULL,
`field` varchar(64) NOT NULL,
`value` BLOB NOT NULL,
PRIMARY KEY (`studio_id`, `field`),
foreign key(`studio_id`) references `studios`(`id`) on delete CASCADE
);
CREATE INDEX `index_studio_custom_fields_field_value` ON `studio_custom_fields` (`field`, `value`);

View File

@@ -706,28 +706,6 @@ func (qb *PerformerStore) sortByLastOAt(direction string) string {
return " ORDER BY (" + selectPerformerLastOAtSQL + ") " + direction
}
// used for sorting on performer latest scene
var selectPerformerLatestSceneSQL = utils.StrFormat(
"SELECT MAX(date) FROM ("+
"SELECT {date} FROM {performers_scenes} s "+
"LEFT JOIN {scenes} ON {scenes}.id = s.{scene_id} "+
"WHERE s.{performer_id} = {performers}.id"+
")",
map[string]interface{}{
"performer_id": performerIDColumn,
"performers": performerTable,
"performers_scenes": performersScenesTable,
"scenes": sceneTable,
"scene_id": sceneIDColumn,
"date": sceneDateColumn,
},
)
func (qb *PerformerStore) sortByLatestScene(direction string) string {
// need to get the latest date from scenes
return " ORDER BY (" + selectPerformerLatestSceneSQL + ") " + direction
}
// used for sorting on performer last view_date
var selectPerformerLastPlayedAtSQL = utils.StrFormat(
"SELECT MAX(view_date) FROM ("+
@@ -784,7 +762,6 @@ var performerSortOptions = sortOptions{
"images_count",
"last_o_at",
"last_played_at",
"latest_scene",
"measurements",
"name",
"o_counter",
@@ -835,8 +812,6 @@ func (qb *PerformerStore) getPerformerSort(findFilter *models.FindFilterType) (s
sortQuery += qb.sortByLastPlayedAt(direction)
case "last_o_at":
sortQuery += qb.sortByLastOAt(direction)
case "latest_scene":
sortQuery += qb.sortByLatestScene(direction)
default:
sortQuery += getSort(sort, direction, "performers")
}

View File

@@ -26,7 +26,6 @@ const (
sceneTable = "scenes"
scenesFilesTable = "scenes_files"
sceneIDColumn = "scene_id"
sceneDateColumn = "date"
performersScenesTable = "performers_scenes"
scenesTagsTable = "scenes_tags"
scenesGalleriesTable = "scenes_galleries"

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils"
)
type sceneFilterHandler struct {
@@ -82,27 +83,14 @@ func (qb *sceneFilterHandler) criterionHandler() criterionHandler {
criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) {
if sceneFilter.Phash != nil {
// backwards compatibility
h := phashDistanceCriterionHandler{
joinFn: func(f *filterBuilder) {
qb.addSceneFilesTable(f)
f.addLeftJoin(fingerprintTable, "fingerprints_phash", "scenes_files.file_id = fingerprints_phash.file_id AND fingerprints_phash.type = 'phash'")
},
criterion: &models.PhashDistanceCriterionInput{
Value: sceneFilter.Phash.Value,
Modifier: sceneFilter.Phash.Modifier,
},
}
h.handle(ctx, f)
qb.phashDistanceCriterionHandler(&models.PhashDistanceCriterionInput{
Value: sceneFilter.Phash.Value,
Modifier: sceneFilter.Phash.Modifier,
})(ctx, f)
}
}),
&phashDistanceCriterionHandler{
joinFn: func(f *filterBuilder) {
qb.addSceneFilesTable(f)
f.addLeftJoin(fingerprintTable, "fingerprints_phash", "scenes_files.file_id = fingerprints_phash.file_id AND fingerprints_phash.type = 'phash'")
},
criterion: sceneFilter.PhashDistance,
},
qb.phashDistanceCriterionHandler(sceneFilter.PhashDistance),
intCriterionHandler(sceneFilter.Rating100, "scenes.rating", nil),
qb.oCountCriterionHandler(sceneFilter.OCounter),
@@ -139,8 +127,6 @@ func (qb *sceneFilterHandler) criterionHandler() criterionHandler {
parentIDCol: "scenes.id",
},
qb.stashIDCountCriterionHandler(sceneFilter.StashIDCount),
boolCriterionHandler(sceneFilter.Interactive, "video_files.interactive", qb.addVideoFilesTable),
intCriterionHandler(sceneFilter.InteractiveSpeed, "video_files.interactive_speed", qb.addVideoFilesTable),
@@ -455,16 +441,6 @@ func (qb *sceneFilterHandler) tagCountCriterionHandler(tagCount *models.IntCrite
return h.handler(tagCount)
}
func (qb *sceneFilterHandler) stashIDCountCriterionHandler(stashIDCount *models.IntCriterionInput) criterionHandlerFunc {
h := countCriterionHandlerBuilder{
primaryTable: sceneTable,
joinTable: "scene_stash_ids",
primaryFK: sceneIDColumn,
}
return h.handler(stashIDCount)
}
func (qb *sceneFilterHandler) performersCriterionHandler(performers *models.MultiCriterionInput) criterionHandlerFunc {
h := joinedMultiCriterionHandlerBuilder{
primaryTable: sceneTable,
@@ -571,3 +547,42 @@ func (qb *sceneFilterHandler) performerTagsCriterionHandler(tags *models.Hierarc
joinPrimaryKey: sceneIDColumn,
}
}
func (qb *sceneFilterHandler) phashDistanceCriterionHandler(phashDistance *models.PhashDistanceCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if phashDistance != nil {
qb.addSceneFilesTable(f)
f.addLeftJoin(fingerprintTable, "fingerprints_phash", "scenes_files.file_id = fingerprints_phash.file_id AND fingerprints_phash.type = 'phash'")
value, _ := utils.StringToPhash(phashDistance.Value)
distance := 0
if phashDistance.Distance != nil {
distance = *phashDistance.Distance
}
if distance == 0 {
// use the default handler
intCriterionHandler(&models.IntCriterionInput{
Value: int(value),
Modifier: phashDistance.Modifier,
}, "fingerprints_phash.fingerprint", nil)(ctx, f)
}
switch {
case phashDistance.Modifier == models.CriterionModifierEquals && distance > 0:
// needed to avoid a type mismatch
f.addWhere("typeof(fingerprints_phash.fingerprint) = 'integer'")
f.addWhere("phash_distance(fingerprints_phash.fingerprint, ?) < ?", value, distance)
case phashDistance.Modifier == models.CriterionModifierNotEquals && distance > 0:
// needed to avoid a type mismatch
f.addWhere("typeof(fingerprints_phash.fingerprint) = 'integer'")
f.addWhere("phash_distance(fingerprints_phash.fingerprint, ?) > ?", value, distance)
default:
intCriterionHandler(&models.IntCriterionInput{
Value: int(value),
Modifier: phashDistance.Modifier,
}, "fingerprints_phash.fingerprint", nil)(ctx, f)
}
}
}
}

View File

@@ -2273,32 +2273,6 @@ func TestSceneQuery(t *testing.T) {
nil,
false,
},
{
"single stash id",
nil,
&models.SceneFilterType{
StashIDCount: &models.IntCriterionInput{
Modifier: models.CriterionModifierEquals,
Value: 1,
},
},
[]int{sceneIdxWithGallery, sceneIdxWithPerformer},
[]int{sceneIdxWithGroup},
false,
},
{
"less than one stash id",
nil,
&models.SceneFilterType{
StashIDCount: &models.IntCriterionInput{
Modifier: models.CriterionModifierLessThan,
Value: 1,
},
},
[]int{sceneIdxWithGroup},
[]int{sceneIdxWithGallery, sceneIdxWithPerformer},
false,
},
}
for _, tt := range tests {

View File

@@ -1076,13 +1076,6 @@ func getObjectDate(index int) *models.Date {
return &ret
}
func sceneStashIDs(i int) []models.StashID {
if i%5 == 0 {
return nil
}
return []models.StashID{sceneStashID(i)}
}
func sceneStashID(i int) models.StashID {
return models.StashID{
StashID: getSceneStringValue(i, "stashid"),
@@ -1181,7 +1174,9 @@ func makeScene(i int) *models.Scene {
PerformerIDs: models.NewRelatedIDs(pids),
TagIDs: models.NewRelatedIDs(tids),
Groups: models.NewRelatedGroups(groups),
StashIDs: models.NewRelatedStashIDs(sceneStashIDs(i)),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
sceneStashID(i),
}),
PlayDuration: getScenePlayDuration(i),
ResumeTime: getSceneResumeTime(i),
}
@@ -1765,19 +1760,7 @@ func getStudioNullStringValue(index int, field string) string {
return ret.String
}
func getStudioCustomFields(index int) map[string]interface{} {
if index%5 == 0 {
return nil
}
return map[string]interface{}{
"string": getStudioStringValue(index, "custom"),
"int": int64(index % 5),
"real": float64(index) / 10,
}
}
func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int, customFields map[string]interface{}) (*models.Studio, error) {
func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int) (*models.Studio, error) {
studio := models.Studio{
Name: name,
}
@@ -1786,7 +1769,7 @@ func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, par
studio.ParentID = parentID
}
err := createStudioFromModel(ctx, sqb, &studio, customFields)
err := createStudioFromModel(ctx, sqb, &studio)
if err != nil {
return nil, err
}
@@ -1794,11 +1777,8 @@ func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, par
return &studio, nil
}
func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio, customFields map[string]interface{}) error {
err := sqb.Create(ctx, &models.CreateStudioInput{
Studio: studio,
CustomFields: customFields,
})
func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio) error {
err := sqb.Create(ctx, studio)
if err != nil {
return fmt.Errorf("Error creating studio %v+: %s", studio, err.Error())
@@ -1860,7 +1840,7 @@ func createStudios(ctx context.Context, n int, o int) error {
alias := getStudioStringValue(i, "Alias")
studio.Aliases = models.NewRelatedStrings([]string{alias})
}
err := createStudioFromModel(ctx, sqb, &studio, getStudioCustomFields(i))
err := createStudioFromModel(ctx, sqb, &studio)
if err != nil {
return err

View File

@@ -15,7 +15,6 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/utils"
)
const (
@@ -141,7 +140,6 @@ var (
type StudioStore struct {
blobJoinQueryBuilder
customFieldsStore
tagRelationshipStore
tableMgr *table
@@ -153,10 +151,6 @@ func NewStudioStore(blobStore *BlobStore) *StudioStore {
blobStore: blobStore,
joinTable: studioTable,
},
customFieldsStore: customFieldsStore{
table: studiosCustomFieldsTable,
fk: studiosCustomFieldsTable.Col(studioIDColumn),
},
tagRelationshipStore: tagRelationshipStore{
idRelationshipStore: idRelationshipStore{
joinTable: studiosTagsTableMgr,
@@ -175,11 +169,11 @@ func (qb *StudioStore) selectDataset() *goqu.SelectDataset {
return dialect.From(qb.table()).Select(qb.table().All())
}
func (qb *StudioStore) Create(ctx context.Context, newObject *models.CreateStudioInput) error {
func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) error {
var err error
var r studioRow
r.fromStudio(*newObject.Studio)
r.fromStudio(*newObject)
id, err := qb.tableMgr.insertID(ctx, r)
if err != nil {
@@ -213,17 +207,12 @@ func (qb *StudioStore) Create(ctx context.Context, newObject *models.CreateStudi
}
}
const partial = false
if err := qb.setCustomFields(ctx, id, newObject.CustomFields, partial); err != nil {
return err
}
updated, err := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
}
*newObject.Studio = *updated
*newObject = *updated
return nil
}
@@ -264,17 +253,13 @@ func (qb *StudioStore) UpdatePartial(ctx context.Context, input models.StudioPar
}
}
if err := qb.SetCustomFields(ctx, input.ID, input.CustomFields); err != nil {
return nil, err
}
return qb.find(ctx, input.ID)
return qb.Find(ctx, input.ID)
}
// This is only used by the Import/Export functionality
func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.UpdateStudioInput) error {
func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio) error {
var r studioRow
r.fromStudio(*updatedObject.Studio)
r.fromStudio(*updatedObject)
if err := qb.tableMgr.updateByID(ctx, updatedObject.ID, r); err != nil {
return err
@@ -302,10 +287,6 @@ func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.UpdateS
}
}
if err := qb.SetCustomFields(ctx, updatedObject.ID, updatedObject.CustomFields); err != nil {
return err
}
return nil
}
@@ -620,32 +601,12 @@ func (qb *StudioStore) sortByScenesDuration(direction string) string {
) %s`, sceneTable, scenesFilesTable, scenesFilesTable, sceneIDColumn, sceneTable, scenesFilesTable, sceneTable, studioIDColumn, studioTable, getSortDirection(direction))
}
// used for sorting on performer latest scene
var selectStudioLatestSceneSQL = utils.StrFormat(
"SELECT MAX(date) FROM ("+
"SELECT {date} FROM {scenes} s "+
"WHERE s.{studio_id} = {studios}.id"+
")",
map[string]interface{}{
"scenes": sceneTable,
"studios": studioTable,
"studio_id": studioIDColumn,
"date": sceneDateColumn,
},
)
func (qb *StudioStore) sortByLatestScene(direction string) string {
// need to get the latest date from scenes
return " ORDER BY (" + selectStudioLatestSceneSQL + ") " + direction
}
var studioSortOptions = sortOptions{
"child_count",
"created_at",
"galleries_count",
"id",
"images_count",
"latest_scene",
"name",
"scenes_count",
"scenes_duration",
@@ -685,8 +646,6 @@ func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) (string,
sortQuery += getCountSort(studioTable, galleryTable, studioIDColumn, direction)
case "child_count":
sortQuery += getCountSort(studioTable, studioTable, studioParentIDColumn, direction)
case "latest_scene":
sortQuery += qb.sortByLatestScene(direction)
default:
sortQuery += getSort(sort, direction, "studios")
}

View File

@@ -117,13 +117,6 @@ func (qb *studioFilterHandler) criterionHandler() criterionHandler {
studioRepository.galleries.innerJoin(f, "", "studios.id")
},
},
&customFieldsFilterHandler{
table: studiosCustomFieldsTable.GetTable(),
fkCol: studioIDColumn,
c: studioFilter.CustomFields,
idCol: "studios.id",
},
}
}

View File

@@ -11,7 +11,6 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/stashapp/stash/pkg/models"
"github.com/stretchr/testify/assert"
@@ -48,559 +47,6 @@ func TestStudioFindByName(t *testing.T) {
})
}
func loadStudioRelationships(ctx context.Context, expected models.Studio, actual *models.Studio) error {
if expected.Aliases.Loaded() {
if err := actual.LoadAliases(ctx, db.Studio); err != nil {
return err
}
}
if expected.URLs.Loaded() {
if err := actual.LoadURLs(ctx, db.Studio); err != nil {
return err
}
}
if expected.TagIDs.Loaded() {
if err := actual.LoadTagIDs(ctx, db.Studio); err != nil {
return err
}
}
if expected.StashIDs.Loaded() {
if err := actual.LoadStashIDs(ctx, db.Studio); err != nil {
return err
}
}
return nil
}
func Test_StudioStore_Create(t *testing.T) {
var (
name = "name"
details = "details"
url = "url"
rating = 3
aliases = []string{"alias1", "alias2"}
ignoreAutoTag = true
favorite = true
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
)
tests := []struct {
name string
newObject models.CreateStudioInput
wantErr bool
}{
{
"full",
models.CreateStudioInput{
Studio: &models.Studio{
Name: name,
URLs: models.NewRelatedStrings([]string{url}),
Favorite: favorite,
Rating: &rating,
Details: details,
IgnoreAutoTag: ignoreAutoTag,
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithStudio], tagIDs[tagIdx1WithDupName]}),
Aliases: models.NewRelatedStrings(aliases),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
UpdatedAt: epochTime,
},
{
StashID: stashID2,
Endpoint: endpoint2,
UpdatedAt: epochTime,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
CustomFields: testCustomFields,
},
false,
},
{
"invalid tag id",
models.CreateStudioInput{
Studio: &models.Studio{
Name: name,
TagIDs: models.NewRelatedIDs([]int{invalidID}),
},
},
true,
},
}
qb := db.Studio
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
p := tt.newObject
if err := qb.Create(ctx, &p); (err != nil) != tt.wantErr {
t.Errorf("StudioStore.Create() error = %v, wantErr = %v", err, tt.wantErr)
}
if tt.wantErr {
assert.Zero(p.ID)
return
}
assert.NotZero(p.ID)
copy := *tt.newObject.Studio
copy.ID = p.ID
// load relationships
if err := loadStudioRelationships(ctx, copy, p.Studio); err != nil {
t.Errorf("loadStudioRelationships() error = %v", err)
return
}
assert.Equal(copy, *p.Studio)
// ensure can find the Studio
found, err := qb.Find(ctx, p.ID)
if err != nil {
t.Errorf("StudioStore.Find() error = %v", err)
}
if !assert.NotNil(found) {
return
}
// load relationships
if err := loadStudioRelationships(ctx, copy, found); err != nil {
t.Errorf("loadStudioRelationships() error = %v", err)
return
}
assert.Equal(copy, *found)
// ensure custom fields are set
cf, err := qb.GetCustomFields(ctx, p.ID)
if err != nil {
t.Errorf("StudioStore.GetCustomFields() error = %v", err)
return
}
assert.Equal(tt.newObject.CustomFields, cf)
return
})
}
}
func Test_StudioStore_Update(t *testing.T) {
var (
name = "name"
details = "details"
url = "url"
rating = 3
aliases = []string{"aliasX", "aliasY"}
ignoreAutoTag = true
favorite = true
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
)
tests := []struct {
name string
updatedObject models.UpdateStudioInput
wantErr bool
}{
{
"full",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[studioIdxWithGallery],
Name: name,
URLs: models.NewRelatedStrings([]string{url}),
Favorite: favorite,
Rating: &rating,
Details: details,
IgnoreAutoTag: ignoreAutoTag,
Aliases: models.NewRelatedStrings(aliases),
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithStudio]}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
UpdatedAt: epochTime,
},
{
StashID: stashID2,
Endpoint: endpoint2,
UpdatedAt: epochTime,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
},
false,
},
{
"clear nullables",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[studioIdxWithGallery],
Name: name, // name is mandatory
URLs: models.NewRelatedStrings([]string{}),
Aliases: models.NewRelatedStrings([]string{}),
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
},
},
false,
},
{
"clear tag ids",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[sceneIdxWithTag],
Name: name, // name is mandatory
TagIDs: models.NewRelatedIDs([]int{}),
},
},
false,
},
{
"set custom fields",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[studioIdxWithGallery],
Name: name, // name is mandatory
},
CustomFields: models.CustomFieldsInput{
Full: testCustomFields,
},
},
false,
},
{
"clear custom fields",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[studioIdxWithGallery],
Name: name, // name is mandatory
},
CustomFields: models.CustomFieldsInput{
Full: map[string]interface{}{},
},
},
false,
},
{
"invalid tag id",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[sceneIdxWithGallery],
Name: name, // name is mandatory
TagIDs: models.NewRelatedIDs([]int{invalidID}),
},
},
true,
},
}
qb := db.Studio
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
copy := *tt.updatedObject.Studio
if err := qb.Update(ctx, &tt.updatedObject); (err != nil) != tt.wantErr {
t.Errorf("StudioStore.Update() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
return
}
s, err := qb.Find(ctx, tt.updatedObject.ID)
if err != nil {
t.Errorf("StudioStore.Find() error = %v", err)
}
// load relationships
if err := loadStudioRelationships(ctx, copy, s); err != nil {
t.Errorf("loadStudioRelationships() error = %v", err)
return
}
assert.Equal(copy, *s)
// ensure custom fields are correct
if tt.updatedObject.CustomFields.Full != nil {
cf, err := qb.GetCustomFields(ctx, tt.updatedObject.ID)
if err != nil {
t.Errorf("StudioStore.GetCustomFields() error = %v", err)
return
}
assert.Equal(tt.updatedObject.CustomFields.Full, cf)
}
})
}
}
func clearStudioPartial() models.StudioPartial {
nullString := models.OptionalString{Set: true, Null: true}
nullInt := models.OptionalInt{Set: true, Null: true}
// leave mandatory fields
return models.StudioPartial{
URLs: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet},
Aliases: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet},
Rating: nullInt,
Details: nullString,
TagIDs: &models.UpdateIDs{Mode: models.RelationshipUpdateModeSet},
StashIDs: &models.UpdateStashIDs{Mode: models.RelationshipUpdateModeSet},
}
}
func Test_StudioStore_UpdatePartial(t *testing.T) {
var (
name = "name"
details = "details"
url = "url"
aliases = []string{"aliasX", "aliasY"}
rating = 3
ignoreAutoTag = true
favorite = true
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
)
tests := []struct {
name string
id int
partial models.StudioPartial
want models.Studio
wantErr bool
}{
{
"full",
studioIDs[studioIdxWithDupName],
models.StudioPartial{
Name: models.NewOptionalString(name),
URLs: &models.UpdateStrings{
Values: []string{url},
Mode: models.RelationshipUpdateModeSet,
},
Aliases: &models.UpdateStrings{
Values: aliases,
Mode: models.RelationshipUpdateModeSet,
},
Favorite: models.NewOptionalBool(favorite),
Rating: models.NewOptionalInt(rating),
Details: models.NewOptionalString(details),
IgnoreAutoTag: models.NewOptionalBool(ignoreAutoTag),
TagIDs: &models.UpdateIDs{
IDs: []int{tagIDs[tagIdx1WithStudio], tagIDs[tagIdx1WithDupName]},
Mode: models.RelationshipUpdateModeSet,
},
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
UpdatedAt: epochTime,
},
{
StashID: stashID2,
Endpoint: endpoint2,
UpdatedAt: epochTime,
},
},
Mode: models.RelationshipUpdateModeSet,
},
CreatedAt: models.NewOptionalTime(createdAt),
UpdatedAt: models.NewOptionalTime(updatedAt),
},
models.Studio{
ID: studioIDs[studioIdxWithDupName],
Name: name,
URLs: models.NewRelatedStrings([]string{url}),
Aliases: models.NewRelatedStrings(aliases),
Favorite: favorite,
Rating: &rating,
Details: details,
IgnoreAutoTag: ignoreAutoTag,
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithStudio]}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
UpdatedAt: epochTime,
},
{
StashID: stashID2,
Endpoint: endpoint2,
UpdatedAt: epochTime,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
false,
},
{
"clear all",
studioIDs[studioIdxWithTwoTags],
clearStudioPartial(),
models.Studio{
ID: studioIDs[studioIdxWithTwoTags],
Name: getStudioStringValue(studioIdxWithTwoTags, "Name"),
Favorite: getStudioBoolValue(studioIdxWithTwoTags),
Aliases: models.NewRelatedStrings([]string{}),
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
IgnoreAutoTag: getIgnoreAutoTag(studioIdxWithTwoTags),
},
false,
},
{
"invalid id",
invalidID,
models.StudioPartial{Name: models.NewOptionalString(name)},
models.Studio{},
true,
},
}
for _, tt := range tests {
qb := db.Studio
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
tt.partial.ID = tt.id
got, err := qb.UpdatePartial(ctx, tt.partial)
if (err != nil) != tt.wantErr {
t.Errorf("StudioStore.UpdatePartial() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if err := loadStudioRelationships(ctx, tt.want, got); err != nil {
t.Errorf("loadStudioRelationships() error = %v", err)
return
}
assert.Equal(tt.want, *got)
s, err := qb.Find(ctx, tt.id)
if err != nil {
t.Errorf("StudioStore.Find() error = %v", err)
}
// load relationships
if err := loadStudioRelationships(ctx, tt.want, s); err != nil {
t.Errorf("loadStudioRelationships() error = %v", err)
return
}
assert.Equal(tt.want, *s)
})
}
}
func Test_StudioStore_UpdatePartialCustomFields(t *testing.T) {
tests := []struct {
name string
id int
partial models.StudioPartial
expected map[string]interface{} // nil to use the partial
}{
{
"set custom fields",
studioIDs[studioIdxWithGallery],
models.StudioPartial{
CustomFields: models.CustomFieldsInput{
Full: testCustomFields,
},
},
nil,
},
{
"clear custom fields",
studioIDs[studioIdxWithGallery],
models.StudioPartial{
CustomFields: models.CustomFieldsInput{
Full: map[string]interface{}{},
},
},
nil,
},
{
"partial custom fields",
studioIDs[studioIdxWithGallery],
models.StudioPartial{
CustomFields: models.CustomFieldsInput{
Partial: map[string]interface{}{
"string": "bbb",
"new_field": "new",
},
},
},
map[string]interface{}{
"int": int64(2),
"real": 0.7,
"string": "bbb",
"new_field": "new",
},
},
}
for _, tt := range tests {
qb := db.Studio
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
tt.partial.ID = tt.id
_, err := qb.UpdatePartial(ctx, tt.partial)
if err != nil {
t.Errorf("StudioStore.UpdatePartial() error = %v", err)
return
}
// ensure custom fields are correct
cf, err := qb.GetCustomFields(ctx, tt.id)
if err != nil {
t.Errorf("StudioStore.GetCustomFields() error = %v", err)
return
}
if tt.expected == nil {
assert.Equal(tt.partial.CustomFields.Full, cf)
} else {
assert.Equal(tt.expected, cf)
}
})
}
}
func TestStudioQueryNameOr(t *testing.T) {
const studio1Idx = 1
const studio2Idx = 2
@@ -636,6 +82,14 @@ func TestStudioQueryNameOr(t *testing.T) {
})
}
func loadStudioRelationships(ctx context.Context, t *testing.T, s *models.Studio) error {
if err := s.LoadURLs(ctx, db.Studio); err != nil {
return err
}
return nil
}
func TestStudioQueryNameAndUrl(t *testing.T) {
const studioIdx = 1
studioName := getStudioStringValue(studioIdx, "Name")
@@ -857,13 +311,13 @@ func TestStudioDestroyParent(t *testing.T) {
// create parent and child studios
if err := withTxn(func(ctx context.Context) error {
createdParent, err := createStudio(ctx, db.Studio, parentName, nil, nil)
createdParent, err := createStudio(ctx, db.Studio, parentName, nil)
if err != nil {
return fmt.Errorf("Error creating parent studio: %s", err.Error())
}
parentID := createdParent.ID
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID, nil)
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID)
if err != nil {
return fmt.Errorf("Error creating child studio: %s", err.Error())
}
@@ -919,13 +373,13 @@ func TestStudioUpdateClearParent(t *testing.T) {
// create parent and child studios
if err := withTxn(func(ctx context.Context) error {
createdParent, err := createStudio(ctx, db.Studio, parentName, nil, nil)
createdParent, err := createStudio(ctx, db.Studio, parentName, nil)
if err != nil {
return fmt.Errorf("Error creating parent studio: %s", err.Error())
}
parentID := createdParent.ID
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID, nil)
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID)
if err != nil {
return fmt.Errorf("Error creating child studio: %s", err.Error())
}
@@ -960,7 +414,7 @@ func TestStudioUpdateStudioImage(t *testing.T) {
// create studio to test against
const name = "TestStudioUpdateStudioImage"
created, err := createStudio(ctx, db.Studio, name, nil, nil)
created, err := createStudio(ctx, db.Studio, name, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
@@ -1124,7 +578,7 @@ func TestStudioStashIDs(t *testing.T) {
// create studio to test against
const name = "TestStudioStashIDs"
created, err := createStudio(ctx, db.Studio, name, nil, nil)
created, err := createStudio(ctx, db.Studio, name, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
@@ -1536,7 +990,7 @@ func TestStudioAlias(t *testing.T) {
// create studio to test against
const name = "TestStudioAlias"
created, err := createStudio(ctx, db.Studio, name, nil, nil)
created, err := createStudio(ctx, db.Studio, name, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}

View File

@@ -40,7 +40,6 @@ var (
studiosURLsJoinTable = goqu.T(studioURLsTable)
studiosTagsJoinTable = goqu.T(studiosTagsTable)
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
studiosCustomFieldsTable = goqu.T("studio_custom_fields")
groupsURLsJoinTable = goqu.T(groupURLsTable)
groupsTagsJoinTable = goqu.T(groupsTagsTable)

View File

@@ -153,7 +153,7 @@ func (i *Importer) populateParentStudio(ctx context.Context) error {
}
func (i *Importer) createParentStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewCreateStudioInput()
newStudio := models.NewStudio()
newStudio.Name = name
err := i.ReaderWriter.Create(ctx, &newStudio)
@@ -194,7 +194,7 @@ func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
}
func (i *Importer) Create(ctx context.Context) (*int, error) {
err := i.ReaderWriter.Create(ctx, &models.CreateStudioInput{Studio: &i.studio})
err := i.ReaderWriter.Create(ctx, &i.studio)
if err != nil {
return nil, fmt.Errorf("error creating studio: %v", err)
}
@@ -206,7 +206,7 @@ func (i *Importer) Create(ctx context.Context) (*int, error) {
func (i *Importer) Update(ctx context.Context, id int) error {
studio := i.studio
studio.ID = id
err := i.ReaderWriter.Update(ctx, &models.UpdateStudioInput{Studio: &studio})
err := i.ReaderWriter.Update(ctx, &studio)
if err != nil {
return fmt.Errorf("error updating existing studio: %v", err)
}

Some files were not shown because too many files have changed in this diff Show More