add marker support to ai tagger plugin

This commit is contained in:
skier233
2024-06-13 20:56:36 -04:00
parent 39332d354b
commit 759b90630a
7 changed files with 624 additions and 276 deletions

View File

@@ -0,0 +1,90 @@
from typing import Any, Dict, List
import aiohttp
import pydantic
import config
import stashapi.log as log
# ----------------- AI Server Calling Functions -----------------
async def post_api_async(session, endpoint, payload):
    """POST `payload` as JSON to an AI-server endpoint.

    Returns the decoded JSON body on HTTP 200, None on any other status.
    Re-raises aiohttp.ClientConnectionError after logging it.
    """
    target = f'{config.API_BASE_URL}/{endpoint}'
    try:
        async with session.post(target, json=payload) as response:
            if response.status != 200:
                log.error(f"Failed to process {endpoint} status_code: {response.status}")
                return None
            return await response.json()
    except aiohttp.ClientConnectionError as e:
        log.error(f"Failed to connect to AI server. Is the AI server running at {config.API_BASE_URL}? {e}")
        raise e
async def get_api_async(session, endpoint, params=None):
    """GET an AI-server endpoint with optional query parameters.

    Returns the decoded JSON body on HTTP 200, None on any other status.
    Re-raises aiohttp.ClientConnectionError after logging it.
    """
    target = f'{config.API_BASE_URL}/{endpoint}'
    try:
        async with session.get(target, params=params) as response:
            if response.status != 200:
                log.error(f"Failed to process {endpoint} status_code: {response.status}")
                return None
            return await response.json()
    except aiohttp.ClientConnectionError as e:
        log.error(f"Failed to connect to AI server. Is the AI server running at {config.API_BASE_URL}? {e}")
        raise e
async def process_images_async(image_paths, threshold=config.IMAGE_THRESHOLD, return_confidence=False):
    """Submit a batch of image paths to the AI server for tagging."""
    payload = {"paths": image_paths, "threshold": threshold, "return_confidence": return_confidence}
    timeout = aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        return await post_api_async(session, 'process_images/', payload)
async def process_video_async(video_path, frame_interval=config.FRAME_INTERVAL, threshold=config.AI_VIDEO_THRESHOLD, return_confidence=True, vr_video=False):
    """Submit a single video path to the AI server for frame-by-frame tagging.

    NOTE(review): callers should pass vr_video by keyword — a bare second
    positional argument binds to frame_interval instead; verify call sites.
    """
    payload = {
        "path": video_path,
        "frame_interval": frame_interval,
        "threshold": threshold,
        "return_confidence": return_confidence,
        "vr_video": vr_video,
    }
    timeout = aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        return await post_api_async(session, 'process_video/', payload)
async def get_image_config_async(threshold=config.IMAGE_THRESHOLD):
    """Fetch the server's current image pipeline configuration."""
    timeout = aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        return await get_api_async(session, f'image_pipeline_info/?threshold={threshold}')
async def get_video_config_async(frame_interval=config.FRAME_INTERVAL, threshold=config.AI_VIDEO_THRESHOLD):
    """Fetch the server's current video pipeline configuration."""
    timeout = aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        return await get_api_async(session, f'video_pipeline_info/?frame_interval={frame_interval}&threshold={threshold}&return_confidence=True')
class VideoResult(pydantic.BaseModel):
    """Response payload from the server's process_video/ endpoint."""
    # Per-frame result dicts; must contain at least one entry.
    # NOTE(review): `min_items` is the pydantic v1 spelling; v2 deprecates it
    # in favor of `min_length` — confirm which pydantic major version is pinned.
    result: List[Dict[str, Any]] = pydantic.Field(..., min_items=1)
    # Identity and configuration of the pipeline that produced the result.
    pipeline_short_name: str
    pipeline_version: float
    threshold: float
    frame_interval: float
    return_confidence: bool
class ImageResult(pydantic.BaseModel):
    """Response payload from the server's process_images/ endpoint."""
    # One result dict per submitted image; must contain at least one entry.
    result: List[Dict[str, Any]] = pydantic.Field(..., min_items=1)
    pipeline_short_name: str
    pipeline_version: float
    threshold: float
    return_confidence: bool
class ImagePipelineInfo(pydantic.BaseModel):
    """Configuration reported by the image_pipeline_info/ endpoint."""
    pipeline_short_name: str
    pipeline_version: float
    threshold: float
    return_confidence: bool
class VideoPipelineInfo(pydantic.BaseModel):
    """Configuration reported by the video_pipeline_info/ endpoint."""
    pipeline_short_name: str
    pipeline_version: float
    threshold: float
    frame_interval: float
    return_confidence: bool
# Module-level cache for the server's video pipeline info. `globals().get`
# keeps any value assigned earlier in the module; the original referenced this
# global without a visible initialization, so the first `is not None` check
# could raise NameError.
current_videopipeline = globals().get("current_videopipeline")

async def get_current_video_pipeline():
    """Return the server's VideoPipelineInfo, fetching and caching it on first use.

    On a connection error the failure is logged and None is returned (no
    caching); any other failure is logged and re-raised.
    """
    global current_videopipeline
    if current_videopipeline is not None:
        return current_videopipeline
    try:
        current_videopipeline = VideoPipelineInfo(**await get_video_config_async())
    except aiohttp.ClientConnectionError as e:
        log.error(f"Failed to connect to AI server. Is the AI server running at {config.API_BASE_URL}? {e}")
    except Exception as e:
        log.error(f"Failed to get pipeline info: {e}. Ensure the AI server is running with at least version 1.3.1!")
        raise
    return current_videopipeline

View File

@@ -2,11 +2,8 @@ import os
import sys
import json
import subprocess
import csv
import zipfile
import shutil
from typing import Any
import traceback
# ----------------- Setup -----------------
def install(package):
@@ -52,7 +49,9 @@ try:
except ModuleNotFoundError:
log.error("Please provide a config.py file with the required variables.")
raise Exception("Please provide a config.py file with the required variables.")
from ai_video_result import AIVideoResult
import media_handler
import ai_server
except:
log.error("Attempted to install required packages, please retry the task.")
sys.exit(1)
@@ -60,14 +59,10 @@ except:
# ----------------- Variable Definitions -----------------
tagid_mappings = {}
tagname_mappings = {}
max_gaps = {}
min_durations = {}
required_durations = {}
semaphore = asyncio.Semaphore(config.CONCURRENT_TASK_LIMIT)
progress = 0
increment = 0.0
current_videopipeline = None
# ----------------- Main Execution -----------------
@@ -84,31 +79,13 @@ def read_json_input():
async def run(json_input, output):
PLUGIN_ARGS = False
HOOKCONTEXT = False
global stash
global aierroed_tag_id
global tagme_tag_id
global ai_base_tag_id
global ai_tagged_tag_id
global updateme_tag_id
try:
log.debug(json_input["server_connection"])
os.chdir(json_input["server_connection"]["PluginDir"])
stash = StashInterface(json_input["server_connection"])
aierroed_tag_id = stash.find_tag(config.aierrored_tag_name, create=True)["id"]
tagme_tag_id = stash.find_tag(config.tagme_tag_name, create=True)["id"]
ai_base_tag_id = stash.find_tag(config.ai_base_tag_name, create=True)["id"]
ai_tagged_tag_id = stash.find_tag(config.aitagged_tag_name, create=True)["id"]
updateme_tag_id = stash.find_tag(config.updateme_tag_name, create=True)["id"]
media_handler.initialize(json_input["server_connection"])
except Exception:
raise
try:
parse_csv("tag_mappings.csv")
except Exception as e:
log.error("Failed to parse tag_mappings.csv: {e}")
try:
PLUGIN_ARGS = json_input['args']["mode"]
except:
@@ -124,12 +101,11 @@ async def run(json_input, output):
output["output"] = "ok"
return
# ----------------- High Level Calls -----------------
async def tag_images():
global increment
images = stash.find_images(f={"tags": {"value":tagme_tag_id, "modifier":"INCLUDES"}}, fragment="id files {path}")
images = media_handler.get_tagme_images()
if images:
image_batches = [images[i:i + config.IMAGE_REQUEST_BATCH_SIZE] for i in range(0, len(images), config.IMAGE_REQUEST_BATCH_SIZE)]
increment = 1.0 / len(image_batches)
@@ -141,9 +117,9 @@ async def tag_images():
async def tag_scenes():
global increment
scenes = stash.find_scenes(f={"tags": {"value":tagme_tag_id, "modifier":"INCLUDES"}}, fragment="id files {path}")
increment = 1.0 / len(scenes)
scenes = media_handler.get_tagme_scenes()
if scenes:
increment = 1.0 / len(scenes)
tasks = [__tag_scene(scene) for scene in scenes]
await asyncio.gather(*tasks)
else:
@@ -151,33 +127,42 @@ async def tag_scenes():
# ----------------- Image Processing -----------------
def add_error_images(image_ids):
stash.update_images({"ids": image_ids, "tag_ids": {"ids": [aierroed_tag_id], "mode": "ADD"}})
async def __tag_images(images):
async with semaphore:
imagePaths = [image['files'][0]['path'] for image in images]
imageIds = [image['id'] for image in images]
temp_files = []
for i, path in enumerate(imagePaths):
if '.zip' in path:
zip_index = path.index('.zip') + 4
zip_path, img_path = path[:zip_index], path[zip_index+1:].replace('\\', '/')
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
temp_path = os.path.join(config.temp_image_dir, img_path)
os.makedirs(os.path.dirname(temp_path), exist_ok=True)
zip_ref.extract(img_path, config.temp_image_dir)
imagePaths[i] = os.path.abspath(os.path.normpath(temp_path))
temp_files.append(imagePaths[i])
imagePaths, imageIds, temp_files = media_handler.get_image_paths_and_ids(images)
try:
server_results = ImageResult(**await process_images_async(imagePaths))
process_server_image_results(server_results, imageIds)
server_result = await ai_server.process_images_async(imagePaths)
if server_result is None:
log.error("Server returned no results")
media_handler.add_error_images(imageIds)
media_handler.remove_tagme_tags_from_images(imageIds)
return
server_results = ai_server.ImageResult(**server_result)
results = server_results.result
if len(results) != len(imageIds):
log.error("Server returned incorrect number of results")
media_handler.add_error_images(imageIds)
else:
for id, result in zip(imageIds, results):
if 'error' in result:
log.error(f"Error processing image: {result['error']}")
media_handler.add_error_images([id])
else:
actions = result['actions']
action_stashtag_ids = media_handler.get_tag_ids(actions)
action_stashtag_ids.append(media_handler.ai_tagged_tag_id)
media_handler.add_tags_to_image(id, action_stashtag_ids)
log.info(f"Tagged {len(imageIds)} images")
media_handler.remove_tagme_tags_from_images(imageIds)
except aiohttp.ClientConnectionError as e:
log.error(f"Failed to connect to AI server. Is the AI server running at {config.API_BASE_URL}? {e}")
except asyncio.TimeoutError as a:
log.error(f"Timeout processing images: {a}")
except Exception as e:
log.error(f"Failed to process images: {e}")
add_error_images(imageIds)
stash.update_images({"ids": imageIds, "tag_ids": {"ids": [tagme_tag_id], "mode": "REMOVE"}})
media_handler.add_error_images(imageIds)
media_handler.remove_tagme_tags_from_images(imageIds)
finally:
increment_progress()
for temp_file in temp_files:
@@ -185,203 +170,79 @@ async def __tag_images(images):
shutil.rmtree(temp_file)
else:
os.remove(temp_file)
def process_server_image_results(server_results, imageIds):
    """Apply an ImageResult batch to stash: tag each image or mark it errored.

    `server_results.result` must be parallel to `imageIds` (same length, same
    order); any shape mismatch marks the whole batch with the AI-error tag.
    """
    results = server_results.result
    if results is None:
        add_error_images(imageIds)
    elif len(results) == 0:
        log.error("Server returned no results")
        add_error_images(imageIds)
    elif len(results) != len(imageIds):
        log.error("Server returned incorrect number of results")
        add_error_images(imageIds)
    else:
        for id, result in zip(imageIds, results):
            if 'error' in result:
                # Per-image failure: only this image gets the error tag.
                log.error(f"Error processing image: {result['error']}")
                stash.update_images({"ids": [id], "tag_ids": {"ids": [aierroed_tag_id], "mode": "ADD"}})
            else:
                # Map server action names to stash tag ids; unmapped actions are dropped.
                actions = result['actions']
                action_stashtag_ids = [tagid_mappings[action] for action in actions if action in tagid_mappings]
                action_stashtag_ids.append(ai_tagged_tag_id)
                stash.update_images({"ids": [id], "tag_ids": {"ids": action_stashtag_ids, "mode": "ADD"}})
        log.info(f"Tagged {len(imageIds)} images")
        # Batch is done (success or per-image error): drop the AI_TagMe tag.
        stash.update_images({"ids": imageIds, "tag_ids": {"ids": [tagme_tag_id], "mode": "REMOVE"}})
# ----------------- Scene Processing -----------------
def add_error_scene(scene_id):
    """Mark a single scene with the AI-error tag in stash."""
    payload = {"ids": [scene_id], "tag_ids": {"ids": [aierroed_tag_id], "mode": "ADD"}}
    stash.update_scenes(payload)
async def __tag_scene(scene):
async with semaphore:
scenePath = scene['files'][0]['path']
sceneId = scene['id']
log.info("files result:" + str(scene['files'][0]))
phash = scene['files'][0].get('fingerprint', None)
duration = scene['files'][0].get('duration', None)
if duration is None:
log.error(f"Scene {sceneId} has no duration")
return
try:
server_result = VideoResult(**await process_video_async(scenePath))
process_server_video_result(server_result, sceneId, scenePath)
already_ai_tagged = media_handler.is_scene_tagged(scene.get('tags'))
ai_file_path = scenePath + ".AI.json"
ai_video_result = None
if already_ai_tagged:
if os.path.exists(ai_file_path):
try:
ai_video_result = AIVideoResult.from_json_file(ai_file_path)
current_pipeline_video = await ai_server.get_current_video_pipeline()
if ai_video_result.already_contains_model(current_pipeline_video):
log.info(f"Skipping running AI for scene {scenePath} as it has already been processed with the same pipeline version and configuration. Updating tags and markers instead.")
ai_video_result.update_stash_tags()
ai_video_result.update_stash_markers()
return
except Exception as e:
log.error(f"Failed to load AI results from file: {e}")
elif os.path.exists(os.path.join(os.path.dirname(scenePath), os.path.splitext(os.path.basename(scenePath))[0] + f"__vid_giddy__1.0.csv")):
ai_video_result = AIVideoResult.from_csv_file(os.path.join(os.path.dirname(scenePath), os.path.splitext(os.path.basename(scenePath))[0] + f"__vid_giddy__1.0.csv"), scene_id=sceneId, phash=phash, duration=duration)
log.info(f"Loading AI results from CSV file for scene {scenePath}: {ai_video_result}")
current_pipeline_video = await ai_server.get_current_video_pipeline()
if ai_video_result.already_contains_model(current_pipeline_video):
log.info(f"Skipping running AI for scene {scenePath} as it has already been processed with the same pipeline version and configuration. Updating tags and markers instead.")
ai_video_result.to_json_file(ai_file_path)
ai_video_result.update_stash_tags()
ai_video_result.update_stash_markers()
return
else:
log.warning(f"Scene {scenePath} is already tagged but has no AI results file. Running AI again.")
vr_video = media_handler.is_vr_scene(scene.get('tags'))
if vr_video:
log.info(f"Processing VR video {scenePath}")
server_result = await ai_server.process_video_async(scenePath, vr_video)
if server_result is None:
log.error("Server returned no results")
media_handler.add_error_scene(sceneId)
media_handler.remove_tagme_tag_from_scene(sceneId)
return
server_result = ai_server.VideoResult(**server_result)
if ai_video_result:
ai_video_result.add_server_response(server_result)
else:
ai_video_result = AIVideoResult.from_server_response(server_result, sceneId, phash, duration)
ai_video_result.to_json_file(ai_file_path)
ai_video_result.update_stash_tags()
ai_video_result.update_stash_markers()
log.info(f"Processed video with {len(server_result.result)} AI tagged frames")
except aiohttp.ClientConnectionError as e:
log.error(f"Failed to connect to AI server. Is the AI server running at {config.API_BASE_URL}? {e}")
except asyncio.TimeoutError as a:
log.error(f"Timeout processing scene: {a}")
except Exception as e:
log.error(f"Failed to process video: {e}")
add_error_scene(sceneId)
stash.update_scenes({"ids": [sceneId], "tag_ids": {"ids": [tagme_tag_id], "mode": "REMOVE"}})
log.error(f"Failed to process video: {e}\n{traceback.format_exc()}")
media_handler.add_error_scene(sceneId)
media_handler.remove_tagme_tag_from_scene(sceneId)
return
finally:
increment_progress()
def process_server_video_result(server_result, sceneId, scenePath):
    """Turn a VideoResult into stash tags for one scene and persist a CSV copy.

    Groups per-frame actions into contiguous runs (allowing per-tag gaps),
    sums run durations, and adds a stash tag only when the total meets the
    tag's required duration. Assumes results are time-ordered with at least
    two entries (timespan is derived from the first two frame_index values).
    """
    results = server_result.result
    if results is None:
        add_error_scene(sceneId)
    elif len(results) == 0:
        log.error("Server returned no results")
        add_error_scene(sceneId)
    else:
        # Get the directory of the scene file
        directory = os.path.dirname(scenePath)
        # Get the base name of the scene file (without the extension)
        base_name = os.path.splitext(os.path.basename(scenePath))[0]
        # Create the CSV file path
        csv_path = os.path.join(directory, base_name + f"__{server_result.pipeline_short_name}__{server_result.pipeline_version}.csv")
        save_to_csv(csv_path, results)
        # Step 1: Group results by tag
        timespan = results[1]['frame_index'] - results[0]['frame_index']
        log.debug(f"Server returned results every {timespan}s")
        tag_timestamps = {}
        for result in results:
            for action in result['actions']:
                if action not in tag_timestamps:
                    tag_timestamps[action] = []
                tag_timestamps[action].append(result['frame_index'])
        # Step 2: Process each tag
        tag_durations = {}
        for tag, timestamps in tag_timestamps.items():
            start = timestamps[0]
            total_duration = 0
            for i in range(1, len(timestamps)):
                if timestamps[i] - timestamps[i - 1] > timespan + max_gaps.get(tag, 0):
                    # End of current marker, start of new one
                    duration = timestamps[i - 1] - start
                    # A run must last at least one sampling interval to count.
                    min_duration_temp = min_durations.get(tag, 0)
                    min_duration_temp = min_duration_temp if min_duration_temp > timespan else timespan
                    if duration >= min_duration_temp:
                        # The marker is long enough, add its duration
                        total_duration += duration
                    # README: This code works for generating markers but stash markers don't have a way to be deleted in batch and are missing a lot of other
                    # needed features so this code will remain disabled until stash adds the needed features.
                    # log.debug(f"Creating marker for {tagname_mappings[tag]} with range {start} - {timestamps[i - 1]}")
                    # stash.create_scene_marker({"scene_id": sceneId, "primary_tag_id":tagid_mappings[tag], "tag_ids": [tagid_mappings[tag]], "seconds": start, "title":tagname_mappings[tag]})
                    start = timestamps[i]
            # Check the last marker
            duration = timestamps[-1] - start
            if duration >= min_durations.get(tag, 0):
                total_duration += duration
                # README: This code works for generating markers but stash markers don't have a way to be deleted in batch and are missing a lot of other
                # needed features so this code will remain disabled until stash adds the needed features.
                # log.debug(f"Creating marker for {tagname_mappings[tag]} with range {start} - {timestamps[-1]}")
                # stash.create_scene_marker({"scene_id": sceneId, "primary_tag_id":tagid_mappings[tag], "tag_ids": [tagid_mappings[tag]], "seconds": start, "title":tagname_mappings[tag]})
            tag_durations[tag] = total_duration
        scene_duration = results[-1]['frame_index']
        # Step 3: Check if each tag meets the required duration
        tags_to_add = [ai_tagged_tag_id]
        for tag, duration in tag_durations.items():
            # RequiredDuration may be absolute ("30s") or relative ("10%").
            required_duration = required_durations.get(tag, "0s")
            if required_duration.endswith("s"):
                required_duration = float(required_duration[:-1])
            elif required_duration.endswith("%"):
                required_duration = float(required_duration[:-1]) / 100 * scene_duration
            if duration < required_duration:
                log.debug(f"Tag {tagname_mappings[tag]} does not meet the required duration of {required_duration}s. It only has a duration of {duration}s.")
            else:
                log.debug(f"Tag {tagname_mappings[tag]} meets the required duration of {required_duration}s. It has a duration of {duration}s.")
                tags_to_add.append(tagid_mappings[tag])
        log.info(f"Processed video with {len(results)} AI tagged frames")
        stash.update_scenes({"ids": [sceneId], "tag_ids": {"ids": [tagme_tag_id], "mode": "REMOVE"}})
        stash.update_scenes({"ids": [sceneId], "tag_ids": {"ids": tags_to_add, "mode": "ADD"}})
# ----------------- AI Server Calling Functions -----------------
async def call_api_async(session, endpoint, payload):
    """POST `payload` as JSON to an AI-server endpoint.

    Returns the decoded JSON body on HTTP 200, None otherwise. A connection
    failure is logged and re-raised as a plain Exception.
    """
    target = f'{config.API_BASE_URL}/{endpoint}'
    try:
        async with session.post(target, json=payload) as response:
            if response.status != 200:
                log.error(f"Failed to process {endpoint} status_code: {response.status}")
                return None
            return await response.json()
    except aiohttp.ClientConnectionError as e:
        log.error(f"Failed to connect to AI server. Is the AI server running at {config.API_BASE_URL}? {e}")
        raise Exception(f"Failed to connect to AI server. Is the AI server running at {config.API_BASE_URL}?")
async def process_images_async(image_paths):
    """Send a batch of image paths to the AI server (30-minute timeout)."""
    timeout = aiohttp.ClientTimeout(total=1800)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        return await call_api_async(session, 'process_images/', {"paths": image_paths})
async def process_video_async(video_path):
    """Send a single video path to the AI server (30-minute timeout)."""
    timeout = aiohttp.ClientTimeout(total=1800)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        return await call_api_async(session, 'process_video/', {"path": video_path})
class VideoResult(pydantic.BaseModel):
    """Response payload from the process_video/ endpoint."""
    # List of per-frame result dicts; left untyped here (Any).
    result: Any
    pipeline_short_name: str
    pipeline_version: float
class ImageResult(pydantic.BaseModel):
    """Response payload from the process_images/ endpoint."""
    # List of per-image result dicts; left untyped here (Any).
    result: Any
    pipeline_short_name: str
    pipeline_version: float
# ----------------- Utility Functions -----------------
def save_to_csv(file_path, server_result):
    """Write one CSV row per frame result: frame_index followed by its action tags."""
    with open(file_path, mode='w', newline='') as outfile:
        writer = csv.writer(outfile)
        for entry in server_result:
            writer.writerow([entry["frame_index"]] + entry['actions'])
def parse_csv(file_path):
    """Load tag_mappings.csv into the module-level mapping dicts.

    Expected columns: ServerTag, StashTag, MinDuration, MaxGap,
    RequiredDuration. Missing stash tags are created as children of the AI
    base tag with auto-tagging disabled.
    """
    global tagid_mappings
    global tagname_mappings
    global max_gaps
    global min_durations
    global required_durations
    with open(file_path, mode='r') as infile:
        reader = csv.DictReader(infile)
        for row in reader:
            server_tag = row['ServerTag']
            stash_tag = row['StashTag']
            min_duration = float(row['MinDuration'])
            max_gap = float(row['MaxGap'])
            # RequiredDuration is kept as a raw string ("30s", "10%", ...)
            # and interpreted later at tagging time.
            required_duration = row['RequiredDuration']
            tag = stash.find_tag(stash_tag)
            if not tag:
                tag = stash.create_tag({"name":stash_tag, "ignore_auto_tag": True, "parent_ids":[ai_base_tag_id]})
            tagid_mappings[server_tag] = tag["id"]
            tagname_mappings[server_tag] = stash_tag
            min_durations[server_tag] = min_duration
            max_gaps[server_tag] = max_gap
            required_durations[server_tag] = required_duration
def increment_progress():
global progress
global increment

View File

@@ -1,6 +1,6 @@
name: AI Tagger
description: Tag videos and images with locally hosted AI using Skier's Patreon AI models
version: 1.2
version: 1.3
url: https://github.com/stashapp/CommunityScripts/tree/main/plugins/AITagger
exec:
- python

View File

@@ -0,0 +1,192 @@
from copy import deepcopy
import csv
from typing import Dict, List, Optional
from pydantic import BaseModel
import stashapi.log as log
import config
import media_handler
class ModelConfig(BaseModel):
    """Sampling/threshold configuration a pipeline run was executed with."""
    # Seconds between sampled frames.
    frame_interval: float
    # Confidence threshold used server-side for this run.
    threshold: float
    def __str__(self):
        return f"ModelConfig(frame_interval={self.frame_interval}, threshold={self.threshold})"
class ModelInfo(BaseModel):
    """A pipeline version plus the configuration it ran with."""
    version: float
    # Named ai_model_config (not model_config) to avoid pydantic's reserved name.
    ai_model_config: ModelConfig
    def __str__(self):
        return f"ModelInfo(version={self.version}, ai_model_config={self.ai_model_config})"
class VideoMetadata(BaseModel):
    """Identity of a tagged scene and the pipelines that processed it."""
    # Stash scene id.
    video_id: int
    # Scene duration in seconds.
    duration: float
    # Perceptual hash fingerprint; may be absent.
    phash: Optional[str]
    # pipeline_short_name -> ModelInfo for every run stored in this result.
    models: Dict[str, ModelInfo]
    def __str__(self):
        return f"VideoMetadata(video_id={self.video_id}, duration={self.duration}, phash={self.phash}, models={self.models})"
class TagTimeFrame(BaseModel):
    """A contiguous span (seconds) where one tag was detected at one confidence."""
    start: float
    # None means the span covers a single sampled frame.
    end: Optional[float] = None
    confidence: float
    def __str__(self):
        return f"TagTimeFrame(start={self.start}, end={self.end}, confidence={self.confidence})"
class TagData(BaseModel):
    """All detected time frames for one tag, attributed to the producing model."""
    ai_model_name: str
    time_frames: List[TagTimeFrame]
    def __str__(self):
        return f"TagData(model_name={self.ai_model_name}, time_frames={self.time_frames})"
class AIVideoResult(BaseModel):
    """Persisted AI tagging result for one scene (stored as <scene>.AI.json).

    Holds the scene metadata, the pipelines that produced results, and the
    per-tag time frames; knows how to sync tags and markers back into stash.
    """
    video_metadata: VideoMetadata
    tags: Dict[str, TagData]

    def add_server_response(self, response):
        """Merge a new VideoResult into this result, replacing any previous
        output from the same pipeline."""
        frame_interval = response.frame_interval
        model_name = response.pipeline_short_name
        if model_name in self.video_metadata.models:
            # Re-run of a known pipeline: drop its old tags before merging.
            self.tags = {tag_name: tag_data for tag_name, tag_data in self.tags.items() if tag_data.ai_model_name != model_name}
        model_info = ModelInfo(version=response.pipeline_version, ai_model_config=ModelConfig(frame_interval=frame_interval, threshold=response.threshold))
        self.video_metadata.models[model_name] = model_info
        tagsToAdd = AIVideoResult.__mutate_server_result_tags(response.result, model_name, frame_interval)
        for tag_name, tag_data in tagsToAdd.items():
            self.tags[tag_name] = tag_data

    def update_stash_tags(self):
        """Replace the scene's AI tags in stash with tags whose cumulative
        detected duration meets the configured required duration."""
        tagsToAdd = []
        for tag_name, tag_data in self.tags.items():
            if media_handler.is_ai_tag(tag_name):
                requiredDuration = media_handler.get_required_duration(tag_name, self.video_metadata.duration)
                # To disable making tags for certain tags, set RequiredDuration to above 100%
                if requiredDuration <= self.video_metadata.duration:
                    totalDuration = 0.0
                    frame_interval = self.video_metadata.models[tag_data.ai_model_name].ai_model_config.frame_interval
                    tag_threshold = media_handler.get_tag_threshold(tag_name)
                    tag_id = media_handler.get_tag_id(tag_name)
                    for time_frame in tag_data.time_frames:
                        # A single-frame span counts for one sampling interval.
                        if time_frame.end is None and time_frame.confidence >= tag_threshold:
                            totalDuration += frame_interval
                        elif time_frame.confidence >= tag_threshold:
                            totalDuration += time_frame.end - time_frame.start + frame_interval
                    if totalDuration >= requiredDuration:
                        tagsToAdd.append(tag_id)
        media_handler.remove_ai_tags_from_video(self.video_metadata.video_id, True)
        media_handler.add_tags_to_video(self.video_metadata.video_id, tagsToAdd, True)

    def update_stash_markers(self):
        """Rebuild the scene's AI markers from the stored time frames,
        merging spans separated by less than the per-tag max gap."""
        if not config.CREATE_MARKERS:
            log.debug("Not creating markers since marker creation is disabled")
            return
        media_handler.remove_ai_markers_from_video(self.video_metadata.video_id)
        for tag_name, tag_data in self.tags.items():
            if media_handler.is_ai_marker_supported(tag_name):
                tag_threshold = media_handler.get_tag_threshold(tag_name)
                frame_interval = self.video_metadata.models[tag_data.ai_model_name].ai_model_config.frame_interval
                tag_id = media_handler.get_tag_id(tag_name)
                max_gap = media_handler.get_max_gap(tag_name)
                min_duration = media_handler.get_min_duration(tag_name)
                merged_time_frames = []
                for time_frame in tag_data.time_frames:
                    if time_frame.confidence < tag_threshold:
                        continue
                    if not merged_time_frames:
                        merged_time_frames.append(deepcopy(time_frame))
                        continue
                    last_time_frame = merged_time_frames[-1]
                    # Extend the previous span when the gap (excluding one
                    # sampling interval) is within the per-tag tolerance.
                    if last_time_frame.end is None:
                        if (time_frame.start - last_time_frame.start - frame_interval) <= max_gap:
                            last_time_frame.end = time_frame.end or time_frame.start
                        else:
                            merged_time_frames.append(deepcopy(time_frame))
                    else:
                        if (time_frame.start - last_time_frame.end - frame_interval) <= max_gap:
                            last_time_frame.end = time_frame.end or time_frame.start
                        else:
                            merged_time_frames.append(deepcopy(time_frame))
                # Drop merged spans shorter than the per-tag minimum duration.
                merged_time_frames = [tf for tf in merged_time_frames if (tf.end or tf.start) - tf.start + frame_interval >= min_duration]
                media_handler.add_markers_to_video(self.video_metadata.video_id, tag_id, tag_name, merged_time_frames)

    def already_contains_model(self, model_config):
        """Return True when this result already holds output from the given
        pipeline at the same version, frame interval and threshold."""
        correspondingModelInfo = self.video_metadata.models.get(model_config.pipeline_short_name)
        # Fix: the original logged the field comparisons unconditionally,
        # which raised AttributeError whenever no matching entry existed.
        if correspondingModelInfo is None:
            log.info("Already contains model: False (no stored result for this pipeline)")
            return False
        toReturn = (correspondingModelInfo.version == model_config.pipeline_version and
                    correspondingModelInfo.ai_model_config.frame_interval == model_config.frame_interval and
                    correspondingModelInfo.ai_model_config.threshold == model_config.threshold)
        log.info(f"Already contains model: {toReturn}, True, {correspondingModelInfo.version == model_config.pipeline_version}, {correspondingModelInfo.ai_model_config.frame_interval == model_config.frame_interval}, {correspondingModelInfo.ai_model_config.threshold == model_config.threshold}")
        return toReturn

    def __str__(self):
        return f"AIVideoResult(video_metadata={self.video_metadata}, tags={self.tags})"

    def to_json_file(self, json_file):
        """Serialize this result to `json_file`, omitting None fields."""
        with open(json_file, 'w') as f:
            f.write(self.model_dump_json(exclude_none=True))

    @classmethod
    def from_server_response(cls, response, sceneId, phash, duration):
        """Build a fresh result from a VideoResult for a newly processed scene."""
        frame_interval = response.frame_interval
        model_name = response.pipeline_short_name
        model_info = ModelInfo(version=response.pipeline_version, ai_model_config=ModelConfig(frame_interval=frame_interval, threshold=response.threshold))
        video_metadata = VideoMetadata(video_id=sceneId, phash=phash, models={model_name : model_info}, duration=duration)
        tags = AIVideoResult.__mutate_server_result_tags(response.result, model_name, frame_interval)
        return cls(video_metadata=video_metadata, tags=tags)

    @classmethod
    def __mutate_server_result_tags(cls, server_result, model_name, frame_interval):
        """Collapse per-frame (tag, confidence) actions into TagTimeFrame runs.

        Consecutive frames (exactly one frame_interval apart) with identical
        confidence extend the current span; anything else opens a new one.
        """
        tags = {}
        for result in server_result:
            frame_index = result["frame_index"]
            actions = result["actions"]
            for action in actions:
                tag_name, confidence = action
                if tag_name not in tags:
                    tags[tag_name] = TagData(ai_model_name=model_name, time_frames=[TagTimeFrame(start=frame_index, end=None, confidence=confidence)])
                else:
                    last_time_frame = tags[tag_name].time_frames[-1]
                    if last_time_frame.end is None:
                        if frame_index - last_time_frame.start == frame_interval and last_time_frame.confidence == confidence:
                            last_time_frame.end = frame_index
                        else:
                            tags[tag_name].time_frames.append(TagTimeFrame(start=frame_index, end=None, confidence=confidence))
                    elif last_time_frame.confidence == confidence and frame_index - last_time_frame.end == frame_interval:
                        last_time_frame.end = frame_index
                    else:
                        tags[tag_name].time_frames.append(TagTimeFrame(start=frame_index, end=None, confidence=confidence))
        return tags

    @classmethod
    def from_json_file(cls, json_file):
        """Load a previously saved result from its JSON file."""
        with open(json_file, 'r') as f:
            return cls.model_validate_json(f.read())

    @classmethod
    def from_csv_file(cls, csv_file, scene_id, phash, duration):
        """Import a legacy CSV export (frame_index, tag, tag, ...) as a result.

        All imported detections get confidence 1.0 under the legacy
        "actiondetection" pipeline.
        NOTE(review): a CSV with fewer than two rows leaves frame_interval as
        None, which ModelConfig will reject — confirm legacy files always
        contain multiple rows.
        """
        server_results = []
        frame_interval = None
        last_frame_index = None
        with open(csv_file, 'r') as f:
            reader = csv.reader(f)
            for row in reader:
                frame_index = float(row[0])
                if last_frame_index is not None and frame_interval is None:
                    log.info(f"Calculating frame interval: {frame_index} - {last_frame_index}")
                    frame_interval = frame_index - last_frame_index
                for tag_name in row[1:]: # Skip the first column (frame_indexes)
                    if tag_name: # If the cell is not empty
                        server_results.append({"frame_index": frame_index, "actions": [(tag_name, 1.0)]})
                last_frame_index = frame_index
        tags = cls.__mutate_server_result_tags(server_results, "actiondetection", frame_interval)
        model_info = ModelInfo(version=1.0, ai_model_config=ModelConfig(frame_interval=frame_interval, threshold=0.3))
        video_metadata = VideoMetadata(video_id=scene_id, phash=phash, models={"actiondetection" : model_info}, duration=duration)
        return cls(video_metadata=video_metadata, tags=tags)

View File

@@ -1,8 +1,14 @@
CREATE_MARKERS = True
FRAME_INTERVAL = 0.5
IMAGE_THRESHOLD = 0.5
API_BASE_URL = 'http://localhost:8000'
IMAGE_REQUEST_BATCH_SIZE = 320
CONCURRENT_TASK_LIMIT = 10
temp_image_dir = "./temp_images"
SERVER_TIMEOUT = 2700
AI_VIDEO_THRESHOLD = 0.3
temp_image_dir = "./temp_images"
ai_base_tag_name = "AI"
tagme_tag_name = "AI_TagMe"
updateme_tag_name = "AI_UpdateMe"

View File

@@ -0,0 +1,199 @@
import csv
import os
import zipfile
from stashapi.stashapp import StashInterface
import stashapi.log as log
import config
# Module-level mapping state keyed by server tag name; populated during
# initialize() (presumably via this module's parse_csv — confirm, the loader
# body is outside this view).
tagid_mappings = {}       # server tag name -> stash tag id
tagname_mappings = {}     # server tag name -> stash tag name
max_gaps = {}             # server tag name -> max gap (s) tolerated when merging time frames
min_durations = {}        # server tag name -> minimum marker duration (s)
required_durations = {}   # server tag name -> raw RequiredDuration string ("30s", "10%", ...)
tag_thresholds = {}       # server tag name -> per-tag confidence threshold
def initialize(connection):
    """Connect to stash and resolve/create the plugin's metadata tag ids.

    Must be called before any other function in this module: every helper
    relies on the `stash` handle and tag-id globals set here.
    """
    global stash
    global aierroed_tag_id
    global tagme_tag_id
    global ai_base_tag_id
    global ai_tagged_tag_id
    global vr_tag_id
    # Initialize the Stash API
    stash = StashInterface(connection)
    # Initialize "metadata" tags
    aierroed_tag_id = stash.find_tag(config.aierrored_tag_name, create=True)["id"]
    tagme_tag_id = stash.find_tag(config.tagme_tag_name, create=True)["id"]
    ai_base_tag_id = stash.find_tag(config.ai_base_tag_name, create=True)["id"]
    ai_tagged_tag_id = stash.find_tag(config.aitagged_tag_name, create=True)["id"]
    # NOTE(review): assumes the stash UI has a vrTag configured and that tag
    # exists; find_tag(None) would fail here — confirm for fresh installs.
    vr_tag_id = stash.find_tag(stash.get_configuration()["ui"]["vrTag"])["id"]
    try:
        parse_csv("tag_mappings.csv")
    except Exception as e:
        log.error(f"Failed to parse tag_mappings.csv: {e}")
# ----------------- Tag Methods -----------------
def get_tag_ids(tag_names):
    """Resolve a list of server tag names to their stash tag ids."""
    return list(map(get_tag_id, tag_names))
def get_tag_id(tag_name):
    """Resolve one tag name to a stash tag id, preferring the CSV mapping
    and falling back to a live stash lookup."""
    if tag_name in tagid_mappings:
        return tagid_mappings.get(tag_name)
    return stash.find_tag(tag_name)["id"]
def get_tag_threshold(tag_name):
    """Per-tag confidence threshold; 0.5 when no override is configured."""
    if tag_name in tag_thresholds:
        return tag_thresholds[tag_name]
    return 0.5
def is_ai_tag(tag_name):
    """True when the tag name appears in the CSV-loaded AI tag mappings."""
    known = tag_name in tagname_mappings
    return known
def is_scene_tagged(tags):
    """True when the scene's tag list contains the AI-tagged metadata tag."""
    return any(tag['id'] == ai_tagged_tag_id for tag in tags)
def is_vr_scene(tags):
    """True when the scene's tag list contains the configured VR tag."""
    return any(tag['id'] == vr_tag_id for tag in tags)
# ----------------- Image Methods -----------------
def get_tagme_images():
    """Fetch all images currently carrying the AI_TagMe tag."""
    tagme_filter = {"tags": {"value": tagme_tag_id, "modifier": "INCLUDES"}}
    return stash.find_images(f=tagme_filter, fragment="id files {path}")
def add_error_images(image_ids):
    """Mark a batch of images with the AI-error tag."""
    payload = {"ids": image_ids, "tag_ids": {"ids": [aierroed_tag_id], "mode": "ADD"}}
    stash.update_images(payload)
def remove_tagme_tags_from_images(image_ids):
    """Strip the AI_TagMe tag from a batch of images."""
    payload = {"ids": image_ids, "tag_ids": {"ids": [tagme_tag_id], "mode": "REMOVE"}}
    stash.update_images(payload)
def add_tags_to_image(image_id, tag_ids):
    """Add the given stash tag ids to one image."""
    payload = {"ids": [image_id], "tag_ids": {"ids": tag_ids, "mode": "ADD"}}
    stash.update_images(payload)
def get_image_paths_and_ids(images):
    """Extract (paths, ids, temp_files) from stash image records.

    Images stored inside zip galleries are extracted to config.temp_image_dir
    and their extracted paths are also returned in temp_files so the caller
    can clean them up. Records without a file entry are logged and skipped.
    """
    imagePaths = []
    imageIds = []
    temp_files = []
    for image in images:
        try:
            imagePath = image['files'][0]['path']
            imageId = image['id']
            if '.zip' in imagePath:
                # Split "<archive>.zip<sep><inner path>" at the first ".zip".
                zip_index = imagePath.index('.zip') + 4
                zip_path, img_path = imagePath[:zip_index], imagePath[zip_index+1:].replace('\\', '/')
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    temp_path = os.path.join(config.temp_image_dir, img_path)
                    os.makedirs(os.path.dirname(temp_path), exist_ok=True)
                    zip_ref.extract(img_path, config.temp_image_dir)
                # Hand the AI server the extracted copy, not the zip member.
                imagePath = os.path.abspath(os.path.normpath(temp_path))
                temp_files.append(imagePath)
            imagePaths.append(imagePath)
            imageIds.append(imageId)
        except IndexError:
            log.error(f"Failed to process image: {image}")
            continue
    return imagePaths, imageIds, temp_files
# ----------------- Scene Methods -----------------
def add_tags_to_video(video_id, tag_ids, add_tagged=True):
    """Add *tag_ids* to a scene, optionally also marking it AI_Tagged.

    Builds a fresh list instead of appending to the caller's *tag_ids*:
    the old in-place append leaked ai_tagged_tag_id back into the caller's
    list, growing it on every call that reused the same list.
    """
    ids_to_add = list(tag_ids)
    if add_tagged:
        ids_to_add.append(ai_tagged_tag_id)
    stash.update_scenes({"ids": [video_id], "tag_ids": {"ids": ids_to_add, "mode": "ADD"}})
def remove_ai_tags_from_video(video_id, remove_tagme=True):
    """Strip every mapped AI tag (and optionally AI_TagMe) from a scene."""
    tags_to_remove = list(tagid_mappings.values())
    if remove_tagme:
        tags_to_remove = tags_to_remove + [tagme_tag_id]
    update = {"ids": [video_id], "tag_ids": {"ids": tags_to_remove, "mode": "REMOVE"}}
    stash.update_scenes(update)
def get_tagme_scenes():
    """Fetch every scene carrying the AI_TagMe tag, with duration and phash."""
    tag_filter = {"tags": {"value": tagme_tag_id, "modifier": "INCLUDES"}}
    fields = "id tags {id} files {path duration fingerprint(type: \"phash\")}"
    return stash.find_scenes(f=tag_filter, fragment=fields)
def add_error_scene(scene_id):
    """Mark a scene with the AI_Errored tag."""
    update = {"ids": [scene_id], "tag_ids": {"ids": [aierroed_tag_id], "mode": "ADD"}}
    stash.update_scenes(update)
def remove_tagme_tag_from_scene(scene_id):
    """Strip the AI_TagMe tag from a scene."""
    update = {"ids": [scene_id], "tag_ids": {"ids": [tagme_tag_id], "mode": "REMOVE"}}
    stash.update_scenes(update)
def get_required_duration(tag_name, scene_duration):
    """Convert a tag's RequiredDuration setting into seconds.

    Supported formats (populated by parse_csv): "20s" (absolute seconds),
    "200%" (percentage of *scene_duration*), "0.5" (proportion of
    *scene_duration*), or a plain number of seconds.
    """
    if not required_durations:
        log.error("Tag mappings not initialized")
    raw = required_durations.get(tag_name)
    if raw is None:
        # Unknown tag: fall back to parse_csv's default instead of crashing —
        # the old code did str(None) -> "none" and then float("none").
        raw = "200%"
    value = str(raw).replace(" ", "").lower()
    if value.endswith("s"):
        # Value ends with 's': absolute number of seconds, e.g. "20s"
        return float(value[:-1])
    if value.endswith("%"):
        # Value ends with '%': percentage of the scene's duration
        percentage = float(value[:-1])
        return (percentage / 100) * scene_duration
    if "." in value and 0 <= float(value) <= 1:
        # Decimal in [0, 1]: proportion of the scene's duration
        proportion = float(value)
        return proportion * scene_duration
    # Plain number: absolute seconds
    return float(value)
# ----------------- Marker Methods -----------------
def is_ai_marker_supported(tag_name):
    """True when *tag_name* has a minimum marker duration configured."""
    if tag_name in min_durations:
        return True
    return False
def get_min_duration(tag_name):
    """Minimum marker duration for *tag_name*, or None when not configured."""
    if tag_name in min_durations:
        return min_durations[tag_name]
    return None
def get_max_gap(tag_name):
    """Maximum allowed gap (seconds) between detections for *tag_name* (0 when unset)."""
    try:
        return max_gaps[tag_name]
    except KeyError:
        return 0
def add_markers_to_video(video_id, tag_id, tag_name, time_frames):
    """Create one scene marker per detected time frame for *tag_name*."""
    title = tagname_mappings[tag_name]  # same title for every marker; look up once
    for frame in time_frames:
        marker = {
            "scene_id": video_id,
            "primary_tag_id": tag_id,
            "tag_ids": [tag_id],
            "seconds": frame.start,
            "title": title,
        }
        stash.create_scene_marker(marker)
def remove_ai_markers_from_video(video_id):
    """Delete every marker on the scene whose primary tag is a mapped AI tag."""
    ai_tag_ids = set(tagid_mappings.values())
    markers = stash.get_scene_markers(video_id, fragment="id primary_tag {id}")
    for marker in markers:
        primary_id = marker.get("primary_tag").get("id")
        if primary_id in ai_tag_ids:
            stash.destroy_scene_marker(marker.get("id"))
# ----------------- Helpers -----------------
def parse_csv(file_path):
    """Load the AI-tag configuration CSV and populate the module mappings.

    Each row maps a server-side tag (ServerTag) to a stash tag (StashTag),
    creating the stash tag under the AI base tag when it doesn't exist yet.
    The MinMarkerDuration, MaxGap, RequiredDuration and TagThreshold columns
    are optional; a missing column OR a blank cell falls back to the default
    (the old code crashed on float('') when a present column had a blank cell).
    """
    global tagid_mappings
    global tagname_mappings
    global max_gaps
    global min_durations
    global required_durations
    global tag_thresholds
    # newline='' is the csv module's documented requirement for open().
    with open(file_path, mode='r', newline='') as infile:
        reader = csv.DictReader(infile)
        for row in reader:
            server_tag = row.get('ServerTag')
            stash_tag = row.get('StashTag')
            # "or default" covers both an absent column and a blank cell.
            min_duration = float(row.get('MinMarkerDuration') or -1)
            max_gap = float(row.get('MaxGap') or 0)
            required_duration = row.get('RequiredDuration') or "200%"
            tag_threshold = float(row.get('TagThreshold') or 0.5)
            tag = stash.find_tag(stash_tag)
            if not tag:
                tag = stash.create_tag({"name": stash_tag, "ignore_auto_tag": True, "parent_ids": [ai_base_tag_id]})
            tagid_mappings[server_tag] = tag["id"]
            tagname_mappings[server_tag] = stash_tag
            if min_duration != -1:
                # -1 sentinel: marker support disabled for this tag
                min_durations[server_tag] = min_duration
            max_gaps[server_tag] = max_gap
            required_durations[server_tag] = required_duration
            tag_thresholds[server_tag] = tag_threshold

View File

@@ -1,37 +1,37 @@
ServerTag,StashTag,MinDuration,MaxGap,RequiredDuration
69,69_AI,5,2,20s
Anal Fucking,Anal Fucking_AI,5,2,20s
Ass Licking,Ass Licking_AI,5,2,20s
Ass Penetration,Ass Penetration_AI,5,2,20s
Ball Licking/Sucking,Ball Licking/Sucking_AI,2,2,20s
Blowjob,Blowjob_AI,5,2,20s
Cum on Person,Cum on Person_AI,3,2,15s
Cum Swapping,Cum Swapping_AI,2,2,15s
Cumshot,Cumshot_AI,1,2,10s
Deepthroat,Deepthroat_AI,1,2,20s
Double Penetration,Double Penetration_AI,5,2,20s
Fingering,Fingering_AI,5,2,20s
Fisting,Fisting_AI,3,2,20s
Footjob,Footjob_AI,3,2,20s
Gangbang,Gangbang_AI,5,2,20s
Gloryhole,Gloryhole_AI,5,2,20s
Grabbing Ass,Grabbing Ass_AI,5,2,20s
Grabbing Boobs,Grabbing Boobs_AI,5,2,20s
Grabbing Hair/Head,Grabbing Hair/Head_AI,5,2,20s
Handjob,Handjob_AI,5,2,20s
Kissing,Kissing_AI,5,2,20s
Licking Penis,Licking Penis_AI,2,2,20s
Masturbation,Masturbation_AI,5,2,20s
Pissing,Pissing_AI,2,2,20s
Pussy Licking (Clearly Visible),Pussy Licking (Clearly Visible)_AI,5,2,20s
Pussy Licking,Pussy Licking_AI,3,2,20s
Pussy Rubbing,Pussy Rubbing_AI,5,2,20s
Sucking Fingers,Sucking Fingers_AI,1,2,20s
Sucking Toy/Dildo,Sucking Toy/Dildo_AI,1,2,20s
Wet (Genitals),Wet (Genitals)_AI,3,2,20s
Titjob,Titjob_AI,5,2,20s
Tribbing/Scissoring,Tribbing/Scissoring_AI,3,2,20s
Undressing,Undressing_AI,3,2,20s
Vaginal Penetration,Vaginal Penetration_AI,5,2,20s
Vaginal Fucking,Vaginal Fucking_AI,5,2,20s
Vibrating,Vibrating_AI,5,2,20s
ServerTag,StashTag,MinMarkerDuration,MaxGap,RequiredDuration,TagThreshold
69,69_AI,5,2,20s,0.5
Anal Fucking,Anal Fucking_AI,5,2,20s,0.5
Ass Licking,Ass Licking_AI,5,2,20s,0.5
Ass Penetration,Ass Penetration_AI,5,2,20s,0.5
Ball Licking/Sucking,Ball Licking/Sucking_AI,2,2,20s,0.5
Blowjob,Blowjob_AI,5,2,20s,0.5
Cum on Person,Cum on Person_AI,3,2,15s,0.5
Cum Swapping,Cum Swapping_AI,2,2,15s,0.5
Cumshot,Cumshot_AI,1,2,10s,0.5
Deepthroat,Deepthroat_AI,1,2,20s,0.5
Double Penetration,Double Penetration_AI,5,2,20s,0.5
Fingering,Fingering_AI,5,2,20s,0.5
Fisting,Fisting_AI,3,2,20s,0.5
Footjob,Footjob_AI,3,2,20s,0.5
Gangbang,Gangbang_AI,5,2,20s,0.5
Gloryhole,Gloryhole_AI,5,2,20s,0.5
Grabbing Ass,Grabbing Ass_AI,5,2,20s,0.5
Grabbing Boobs,Grabbing Boobs_AI,5,2,20s,0.5
Grabbing Hair/Head,Grabbing Hair/Head_AI,5,2,20s,0.5
Handjob,Handjob_AI,5,2,20s,0.5
Kissing,Kissing_AI,5,2,20s,0.5
Licking Penis,Licking Penis_AI,2,2,20s,0.5
Masturbation,Masturbation_AI,5,2,20s,0.5
Pissing,Pissing_AI,2,2,20s,0.5
Pussy Licking (Clearly Visible),Pussy Licking (Clearly Visible)_AI,5,2,20s,0.5
Pussy Licking,Pussy Licking_AI,3,2,20s,0.5
Pussy Rubbing,Pussy Rubbing_AI,5,2,20s,0.5
Sucking Fingers,Sucking Fingers_AI,1,2,20s,0.5
Sucking Toy/Dildo,Sucking Toy/Dildo_AI,1,2,20s,0.5
Wet (Genitals),Wet (Genitals)_AI,3,2,20s,0.5
Titjob,Titjob_AI,5,2,20s,0.5
Tribbing/Scissoring,Tribbing/Scissoring_AI,3,2,20s,0.5
Undressing,Undressing_AI,3,2,20s,0.5
Vaginal Penetration,Vaginal Penetration_AI,5,2,20s,0.5
Vaginal Fucking,Vaginal Fucking_AI,5,2,20s,0.5
Vibrating,Vibrating_AI,5,2,20s,0.5
1 ServerTag StashTag MinDuration MinMarkerDuration MaxGap RequiredDuration TagThreshold
2 69 69_AI 5 2 20s 0.5
3 Anal Fucking Anal Fucking_AI 5 2 20s 0.5
4 Ass Licking Ass Licking_AI 5 2 20s 0.5
5 Ass Penetration Ass Penetration_AI 5 2 20s 0.5
6 Ball Licking/Sucking Ball Licking/Sucking_AI 2 2 20s 0.5
7 Blowjob Blowjob_AI 5 2 20s 0.5
8 Cum on Person Cum on Person_AI 3 2 15s 0.5
9 Cum Swapping Cum Swapping_AI 2 2 15s 0.5
10 Cumshot Cumshot_AI 1 2 10s 0.5
11 Deepthroat Deepthroat_AI 1 2 20s 0.5
12 Double Penetration Double Penetration_AI 5 2 20s 0.5
13 Fingering Fingering_AI 5 2 20s 0.5
14 Fisting Fisting_AI 3 2 20s 0.5
15 Footjob Footjob_AI 3 2 20s 0.5
16 Gangbang Gangbang_AI 5 2 20s 0.5
17 Gloryhole Gloryhole_AI 5 2 20s 0.5
18 Grabbing Ass Grabbing Ass_AI 5 2 20s 0.5
19 Grabbing Boobs Grabbing Boobs_AI 5 2 20s 0.5
20 Grabbing Hair/Head Grabbing Hair/Head_AI 5 2 20s 0.5
21 Handjob Handjob_AI 5 2 20s 0.5
22 Kissing Kissing_AI 5 2 20s 0.5
23 Licking Penis Licking Penis_AI 2 2 20s 0.5
24 Masturbation Masturbation_AI 5 2 20s 0.5
25 Pissing Pissing_AI 2 2 20s 0.5
26 Pussy Licking (Clearly Visible) Pussy Licking (Clearly Visible)_AI 5 2 20s 0.5
27 Pussy Licking Pussy Licking_AI 3 2 20s 0.5
28 Pussy Rubbing Pussy Rubbing_AI 5 2 20s 0.5
29 Sucking Fingers Sucking Fingers_AI 1 2 20s 0.5
30 Sucking Toy/Dildo Sucking Toy/Dildo_AI 1 2 20s 0.5
31 Wet (Genitals) Wet (Genitals)_AI 3 2 20s 0.5
32 Titjob Titjob_AI 5 2 20s 0.5
33 Tribbing/Scissoring Tribbing/Scissoring_AI 3 2 20s 0.5
34 Undressing Undressing_AI 3 2 20s 0.5
35 Vaginal Penetration Vaginal Penetration_AI 5 2 20s 0.5
36 Vaginal Fucking Vaginal Fucking_AI 5 2 20s 0.5
37 Vibrating Vibrating_AI 5 2 20s 0.5