mirror of
https://github.com/stashapp/CommunityScripts.git
synced 2026-02-05 04:45:09 -06:00
116 lines
4.6 KiB
Python
116 lines
4.6 KiB
Python
import os
|
|
import sys
|
|
import json
|
|
from urllib.parse import urlparse
|
|
import numpy as np
|
|
from typing import Dict, Any, Optional, List
|
|
|
|
# Put the vendored numpy_1.26.4 directory (resolved relative to this file) at
# the front of the module search path.
# NOTE(review): `import numpy as np` above has already executed by the time
# this line runs, so this insertion cannot influence that import -- confirm
# whether the ordering is intentional.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../py_dependencies/numpy_1.26.4")))
|
|
|
|
# Connection settings parsed as JSON from the SERVER_CONNECTION environment
# variable; passed to StashInterface in DataManager.__init__.
# NOTE(review): os.environ.get returns None when the variable is unset, and
# json.loads(None) then raises TypeError while this module is being imported.
server_connection = json.loads(os.environ.get("SERVER_CONNECTION"))
|
|
from stashapi.stashapp import StashInterface
|
|
|
|
class DataManager:
    """
    Hold per-model face embeddings in memory and rank them against a query.

    Embeddings are read once, at construction time, from the 'facenet' and
    'arc' sub-folders of the root folder. Performer details are looked up
    through a StashInterface built from the module-level server_connection.
    """

    # Sub-folders of the root folder that are scanned, one per model.
    _MODELS = ("facenet", "arc")
    # File-name suffixes recognised as embedding files.
    _NPY_SUFFIX = ".voy.npy"
    _VOY_SUFFIX = ".voy"

    def __init__(self, voy_root_folder):
        """
        Initialize the data manager using folders of .voy files for each model.

        Parameters:
            voy_root_folder: Path to the root folder containing 'facenet' and 'arc' subfolders.
        """
        self.voy_root_folder = voy_root_folder
        # model name -> {stash_id -> {"name": str, "embedding": loaded array}}
        self.embeddings = {model: {} for model in self._MODELS}
        self._load_voy_files()
        self.stash = StashInterface(server_connection)

    def _load_voy_files(self):
        """
        Load all .voy / .voy.npy files for each model into memory.

        File names are expected to look like '<stash_id>-<name><suffix>'; the
        part before the first '-' becomes the key and the rest the display
        name. A missing model folder simply leaves that model empty. Files
        that cannot be parsed or loaded are reported on stdout and skipped.
        """
        for model in self._MODELS:
            folder = os.path.join(self.voy_root_folder, model)
            loaded = {}
            self.embeddings[model] = loaded
            if not os.path.isdir(folder):
                continue
            for fname in os.listdir(folder):
                if fname.endswith(self._NPY_SUFFIX):
                    stem = fname[:-len(self._NPY_SUFFIX)]
                elif fname.endswith(self._VOY_SUFFIX):
                    stem = fname[:-len(self._VOY_SUFFIX)]
                else:
                    continue
                # Deliberately broad: one bad file (no '-' in the name,
                # unreadable or malformed content) must not abort the scan.
                try:
                    stash_id, name = stem.split("-", 1)
                    embedding = np.load(os.path.join(folder, fname))
                except Exception as e:
                    print(f"Error loading {fname} for {model}: {e}")
                else:
                    loaded[stash_id] = {"name": name, "embedding": embedding}

    def get_all_ids(self, model: str = "facenet") -> List[str]:
        """Return all performer IDs for a given model (empty for an unknown model)."""
        return list(self.embeddings.get(model, {}))

    def get_performer_info(self, stash_id: str, confidence: float) -> Optional[Dict[str, Any]]:
        """
        Build a performer description for a match.

        Parameters:
            stash_id: Stash ID of the performer
            confidence: Confidence score (0-1)

        Returns:
            Dictionary with performer information. When Stash does not know
            the performer, a reduced dictionary is returned whose name comes
            from the loaded embeddings ("Unknown" if the ID is not loaded
            either) and whose image is None. 'confidence' (and 'distance' in
            the full dictionary) is confidence * 100 truncated to an int.
        """
        performer = self.stash.find_performer(stash_id)
        percent = int(confidence * 100)

        if not performer:
            # Fallback to embedding name if performer not found
            name = "Unknown"
            for entries in self.embeddings.values():
                if stash_id in entries:
                    name = entries[stash_id].get("name", "Unknown")
                    break
            return {
                'id': stash_id,
                "name": name,
                "image": None,
                "confidence": percent,
            }

        image_path = performer.get('image_path')
        return {
            'id': stash_id,
            "name": performer['name'],
            # Only the path component of the image URL is kept.
            "image": urlparse(performer['image_path']).path if image_path else None,
            "confidence": percent,
            'country': performer.get('country'),
            'distance': percent,
            'performer_url': f"/performers/{stash_id}"
        }

    def query_index(self, model: str, embedding: np.ndarray, limit: int = 5):
        """
        Query the loaded embeddings for the closest matches using cosine similarity for a given model.

        Parameters:
            model: 'facenet' or 'arc'
            embedding: The embedding to compare
            limit: Number of top matches to return

        Returns:
            List of (stash_id, distance) tuples, sorted by distance ascending,
            where distance is 1 - cosine similarity. An unknown or empty model
            yields an empty list. A zero-norm query or stored embedding has no
            defined cosine similarity and is reported with distance 1.0.
        """
        candidates = self.embeddings.get(model, {})
        if not candidates:
            return []

        # The query norm does not change between candidates; compute it once.
        query_norm = np.linalg.norm(embedding)

        results = []
        for stash_id, data in candidates.items():
            db_embedding = data["embedding"]
            denominator = query_norm * np.linalg.norm(db_embedding)
            if denominator == 0:
                # Dividing would produce NaN, and NaN keys make the sort
                # below unreliable (real matches can end up behind worse
                # ones). Treat the pair as having no similarity instead.
                distance = 1.0
            else:
                distance = 1 - np.dot(embedding, db_embedding) / denominator
            results.append((stash_id, distance))

        results.sort(key=lambda pair: pair[1])
        return results[:limit]

    def query_facenet_index(self, embedding: np.ndarray, limit: int = 5):
        """Query the Facenet index."""
        return self.query_index("facenet", embedding, limit)

    def query_arc_index(self, embedding: np.ndarray, limit: int = 5):
        """Query the ArcFace index."""
        return self.query_index("arc", embedding, limit)