From 4a47e4d53d23ef49b873be344ee2215ad405c409 Mon Sep 17 00:00:00 2001 From: skier233 <39396856+skier233@users.noreply.github.com> Date: Wed, 1 Jan 2025 15:02:31 -0500 Subject: [PATCH] [AI Tagger] V2 (#478) --- plugins/AITagger/ai_server.py | 84 +++++------ plugins/AITagger/ai_tagger.py | 229 ++++++++++++++++++++--------- plugins/AITagger/ai_tagger.yml | 10 +- plugins/AITagger/config.py | 2 + plugins/AITagger/media_handler.py | 234 ++++++++++++++++++------------ plugins/AITagger/tag_mappings.csv | 125 ---------------- plugins/AITagger/utility.py | 19 +++ 7 files changed, 367 insertions(+), 336 deletions(-) delete mode 100644 plugins/AITagger/tag_mappings.csv create mode 100644 plugins/AITagger/utility.py diff --git a/plugins/AITagger/ai_server.py b/plugins/AITagger/ai_server.py index 7c93c7c..f738896 100644 --- a/plugins/AITagger/ai_server.py +++ b/plugins/AITagger/ai_server.py @@ -1,11 +1,9 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Set import aiohttp import pydantic import config import stashapi.log as log -current_videopipeline = None - # ----------------- AI Server Calling Functions ----------------- async def post_api_async(session, endpoint, payload): @@ -38,55 +36,47 @@ async def process_images_async(image_paths, threshold=config.IMAGE_THRESHOLD, re async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session: return await post_api_async(session, 'process_images/', {"paths": image_paths, "threshold": threshold, "return_confidence": return_confidence}) -async def process_video_async(video_path, vr_video=False, frame_interval=config.FRAME_INTERVAL,threshold=config.AI_VIDEO_THRESHOLD, return_confidence=True): +async def process_video_async(video_path, vr_video=False, frame_interval=config.FRAME_INTERVAL,threshold=config.AI_VIDEO_THRESHOLD, return_confidence=True, existing_json=None): async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session: - return await post_api_async(session, 'process_video/', {"path": video_path, "frame_interval": frame_interval, "threshold": threshold, "return_confidence": return_confidence, "vr_video": vr_video}) - -async def get_image_config_async(threshold=config.IMAGE_THRESHOLD): - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session: - return await get_api_async(session, f'image_pipeline_info/?threshold={threshold}') + return await post_api_async(session, 'process_video/', {"path": video_path, "frame_interval": frame_interval, "threshold": threshold, "return_confidence": return_confidence, "vr_video": vr_video, "existing_json_data": existing_json}) -async def get_video_config_async(frame_interval=config.FRAME_INTERVAL, threshold=config.AI_VIDEO_THRESHOLD): +async def find_optimal_marker_settings(existing_json, desired_timespan_data): async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session: - return await get_api_async(session, f'video_pipeline_info/?frame_interval={frame_interval}&threshold={threshold}&return_confidence=True') - + return await post_api_async(session, 'optimize_timeframe_settings/', {"existing_json_data": existing_json, "desired_timespan_data": desired_timespan_data}) + + class VideoResult(pydantic.BaseModel): - result: List[Dict[str, Any]] = pydantic.Field(..., min_items=1) - pipeline_short_name: str - pipeline_version: float - threshold: float - frame_interval: float - return_confidence: bool + result: Dict[str, Any] +class TimeFrame(pydantic.BaseModel): + start: float + end: float + totalConfidence: Optional[float] + + def to_json(self): + return self.model_dump_json(exclude_none=True) + + def __str__(self): + return f"TimeFrame(start={self.start}, end={self.end})" + +class VideoTagInfo(pydantic.BaseModel): + video_duration: float + video_tags: Dict[str, Set[str]] + tag_totals: Dict[str, Dict[str, float]] + tag_timespans: Dict[str, Dict[str, List[TimeFrame]]] + + @classmethod + def from_json(cls, json_str: str): + log.info(f"json_str: {json_str}") + log.info(f"video_duration: {json_str['video_duration']}, video_tags: {json_str['video_tags']}, tag_totals: {json_str['tag_totals']}, tag_timespans: {json_str['tag_timespans']}") + return cls(video_duration=json_str["video_duration"], video_tags=json_str["video_tags"], tag_totals=json_str["tag_totals"], tag_timespans=json_str["tag_timespans"]) + + def __str__(self): + return f"VideoTagInfo(video_duration={self.video_duration}, video_tags={self.video_tags}, tag_totals={self.tag_totals}, tag_timespans={self.tag_timespans})" + class ImageResult(pydantic.BaseModel): result: List[Dict[str, Any]] = pydantic.Field(..., min_items=1) - pipeline_short_name: str - pipeline_version: float - threshold: float - return_confidence: bool -class ImagePipelineInfo(pydantic.BaseModel): - pipeline_short_name: str - pipeline_version: float - threshold: float - return_confidence: bool - -class VideoPipelineInfo(pydantic.BaseModel): - pipeline_short_name: str - pipeline_version: float - threshold: float - frame_interval: float - return_confidence: bool - -async def get_current_video_pipeline(): - global current_videopipeline - if current_videopipeline is not None: - return current_videopipeline - try: - current_videopipeline = VideoPipelineInfo(**await get_video_config_async()) - except aiohttp.ClientConnectionError as e: - log.error(f"Failed to connect to AI server. Is the AI server running at {config.API_BASE_URL}? {e}") - except Exception as e: - log.error(f"Failed to get pipeline info: {e}. Ensure the AI server is running with at least version 1.3.1!") - raise - return current_videopipeline \ No newline at end of file +class OptimizeMarkerSettings(pydantic.BaseModel): + existing_json_data: Any = None + desired_timespan_data: Dict[str, TimeFrame] \ No newline at end of file diff --git a/plugins/AITagger/ai_tagger.py b/plugins/AITagger/ai_tagger.py index c3b23d2..4ca195c 100644 --- a/plugins/AITagger/ai_tagger.py +++ b/plugins/AITagger/ai_tagger.py @@ -49,23 +49,32 @@ try: except ModuleNotFoundError: log.error("Please provide a config.py file with the required variables.") raise Exception("Please provide a config.py file with the required variables.") - from ai_video_result import AIVideoResult import media_handler import ai_server + import utility + from datetime import datetime + try: + import cv2 + except ModuleNotFoundError: + install('opencv-python') + toRaise = True except: log.error("Attempted to install required packages, please retry the task.") + log.error(f"Stack trace {traceback.format_exc()}") sys.exit(1) raise # ----------------- Variable Definitions ----------------- -semaphore = asyncio.Semaphore(config.CONCURRENT_TASK_LIMIT) +semaphore = None progress = 0 increment = 0.0 # ----------------- Main Execution ----------------- async def main(): + global semaphore + semaphore = asyncio.Semaphore(config.CONCURRENT_TASK_LIMIT) json_input = read_json_input() output = {} await run(json_input, output) @@ -97,6 +106,10 @@ async def run(json_input, output): await tag_scenes() output["output"] = "ok" return + elif PLUGIN_ARGS == "find_marker_settings": + await find_marker_settings() + elif PLUGIN_ARGS == "collect_incorrect_markers": + collect_incorrect_markers_and_images() output["output"] = "ok" return @@ -124,18 +137,85 @@ async def tag_scenes(): else: log.info("No scenes to tag. Have you tagged any scenes with the AI_TagMe tag to get processed?") +async def find_marker_settings(): + scenes = media_handler.get_tagme_scenes() + if len(scenes) != 1: + log.error("Please tag exactly one scene with the AI_TagMe tag to get processed.") + return + scene = scenes[0] + await __find_marker_settings(scene) + +def collect_incorrect_markers_and_images(): + incorrect_images = media_handler.get_incorrect_images() + imagePaths, imageIds, temp_files = media_handler.get_image_paths_and_ids(incorrect_images) + incorrect_markers = media_handler.get_incorrect_markers() + if not (len(incorrect_images) > 0 or len(incorrect_markers) > 0): + log.info("No incorrect images or markers to collect.") + return + current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + try: + image_folder = os.path.join(config.output_data_dir, "images") + os.makedirs(image_folder, exist_ok=True) + for imagePath in imagePaths: + try: + shutil.copy(imagePath, image_folder) + except Exception as e: + log.error(f"Failed to copy image {imagePath} to {image_folder}: {e}") + except Exception as e: + log.error(f"Failed to process images: {e}") + raise e + finally: + for temp_file in temp_files: + try: + if os.path.isdir(temp_file): + shutil.rmtree(temp_file) + else: + os.remove(temp_file) + except Exception as e: + log.debug(f"Failed to remove temp file {temp_file}: {e}") + + + scene_folder = os.path.join(config.output_data_dir, "scenes") + os.makedirs(scene_folder, exist_ok=True) + tag_folders = {} + for marker in incorrect_markers: + scene_path = marker['scene']['files'][0]['path'] + if not scene_path: + log.error(f"Marker {marker['id']} has no scene path") + continue + try: + tag_name = marker['primary_tag']['name'] + if tag_name not in tag_folders: + tag_folders[tag_name] = os.path.join(scene_folder, tag_name) + os.makedirs(tag_folders[tag_name], exist_ok=True) + media_handler.write_scene_marker_to_file(marker, scene_path, tag_folders[tag_name]) + + except Exception as e: + log.error(f"Failed to collect scene: {e}") + image_ids = [image['id'] for image in incorrect_images] + media_handler.remove_incorrect_tag_from_images(image_ids) + + if config.delete_incorrect_markers: + media_handler.delete_markers(incorrect_markers) + else: + media_handler.remove_incorrect_tag_from_markers(incorrect_markers) + + destination_folder = "./send_to_Skier" + os.makedirs(destination_folder, exist_ok=True) + # Zip the entire output data directory + output_zip_path = os.path.join(destination_folder, f"{current_time}.zip") + + shutil.make_archive(output_zip_path.replace('.zip', ''), 'zip', config.output_data_dir) + shutil.rmtree(config.output_data_dir) + log.info(f"Please send the following file to Skier to help improve the AI: {os.path.abspath(output_zip_path)}") + + # ----------------- Image Processing ----------------- async def __tag_images(images): async with semaphore: imagePaths, imageIds, temp_files = media_handler.get_image_paths_and_ids(images) - mutated_image_paths = [] - for path in imagePaths: - mutated_path = path - for key, value in config.path_mutation.items(): - mutated_path = mutated_path.replace(key, value) - mutated_image_paths.append(mutated_path) - imagePaths = mutated_image_paths + imagePaths = [utility.mutate_path(path) for path in imagePaths] try: server_result = await ai_server.process_images_async(imagePaths) if server_result is None: @@ -149,15 +229,19 @@ async def __tag_images(images): log.error("Server returned incorrect number of results") media_handler.add_error_images(imageIds) else: + media_handler.remove_ai_tags_from_images(imageIds, remove_tagme=False) + for id, result in zip(imageIds, results): if 'error' in result: log.error(f"Error processing image: {result['error']}") media_handler.add_error_images([id]) else: - tags = media_handler.get_all_tags_from_server_result(result) - stashtag_ids = media_handler.get_tag_ids(tags) - stashtag_ids.append(media_handler.ai_tagged_tag_id) - media_handler.add_tags_to_image(id, stashtag_ids) + tags_list = [] + for _, tags in result.items(): + stashtag_ids = media_handler.get_tag_ids(tags) + stashtag_ids.append(media_handler.ai_tagged_tag_id) + tags_list.extend(stashtag_ids) + media_handler.add_tags_to_image(id, tags_list) log.info(f"Tagged {len(imageIds)} images") media_handler.remove_tagme_tags_from_images(imageIds) @@ -180,86 +264,64 @@ async def __tag_images(images): except Exception as e: log.debug(f"Failed to remove temp file {temp_file}: {e}") -# ----------------- Scene Processing ----------------- +# ----------------- Scene Processing ----------------- async def __tag_scene(scene): async with semaphore: scenePath = scene['files'][0]['path'] - mutated_path = scenePath - for key, value in config.path_mutation.items(): - mutated_path = mutated_path.replace(key, value) sceneId = scene['id'] - log.debug("files result:" + str(scene['files'][0])) - phash = scene['files'][0].get('fingerprint', None) duration = scene['files'][0].get('duration', None) + log.debug("files result:" + str(scene['files'][0])) if duration is None: log.error(f"Scene {sceneId} has no duration") return + + mutated_path = utility.mutate_path(scenePath) + try: already_ai_tagged = media_handler.is_scene_tagged(scene.get('tags')) ai_file_path = scenePath + ".AI.json" - ai_video_result = None + saved_json = None if already_ai_tagged: if os.path.exists(ai_file_path): try: - ai_video_result = AIVideoResult.from_json_file(ai_file_path) - current_pipeline_video = await ai_server.get_current_video_pipeline() - if ai_video_result.already_contains_model(current_pipeline_video): - log.info(f"Skipping running AI for scene {scenePath} as it has already been processed with the same pipeline version and configuration. Updating tags and markers instead.") - ai_video_result.update_stash_tags() - ai_video_result.update_stash_markers() - return + saved_json = utility.read_json_from_file(ai_file_path) except Exception as e: log.error(f"Failed to load AI results from file: {e}") - elif os.path.exists(os.path.join(os.path.dirname(scenePath), os.path.splitext(os.path.basename(scenePath))[0] + f"__vid_giddy__1.0.csv")): - ai_video_result = AIVideoResult.from_csv_file(os.path.join(os.path.dirname(scenePath), os.path.splitext(os.path.basename(scenePath))[0] + f"__vid_giddy__1.0.csv"), scene_id=sceneId, phash=phash, duration=duration) - log.info(f"Loading AI results from CSV file for scene {scenePath}") - current_pipeline_video = await ai_server.get_current_video_pipeline() - if ai_video_result.already_contains_model(current_pipeline_video): - log.info(f"Skipping running AI for scene {scenePath} as it has already been processed with the same pipeline version and configuration. Updating tags and markers instead.") - ai_video_result.to_json_file(ai_file_path) - ai_video_result.update_stash_tags() - ai_video_result.update_stash_markers() - return - elif os.path.exists(os.path.join(os.path.dirname(scenePath), os.path.splitext(os.path.basename(scenePath))[0] + f"__actiondetection__1.0.csv")): - ai_video_result = AIVideoResult.from_csv_file(os.path.join(os.path.dirname(scenePath), os.path.splitext(os.path.basename(scenePath))[0] + f"__actiondetection__1.0.csv"), scene_id=sceneId, phash=phash, duration=duration, version=1.0) - log.info(f"Loading AI results from CSV file for scene {scenePath}") - current_pipeline_video = await ai_server.get_current_video_pipeline() - if ai_video_result.already_contains_model(current_pipeline_video): - log.info(f"Skipping running AI for scene {scenePath} as it has already been processed with the same pipeline version and configuration. Updating tags and markers instead.") - ai_video_result.to_json_file(ai_file_path) - ai_video_result.update_stash_tags() - ai_video_result.update_stash_markers() - return - elif os.path.exists(os.path.join(os.path.dirname(scenePath), os.path.splitext(os.path.basename(scenePath))[0] + f"__actiondetection__2.0.csv")): - ai_video_result = AIVideoResult.from_csv_file(os.path.join(os.path.dirname(scenePath), os.path.splitext(os.path.basename(scenePath))[0] + f"__actiondetection__2.0.csv"), scene_id=sceneId, phash=phash, duration=duration, version=2.0) - log.info(f"Loading AI results from CSV file for scene {scenePath}") - current_pipeline_video = await ai_server.get_current_video_pipeline() - if ai_video_result.already_contains_model(current_pipeline_video): - log.info(f"Skipping running AI for scene {scenePath} as it has already been processed with the same pipeline version and configuration. Updating tags and markers instead.") - ai_video_result.to_json_file(ai_file_path) - ai_video_result.update_stash_tags() - ai_video_result.update_stash_markers() - return else: log.warning(f"Scene {scenePath} is already tagged but has no AI results file. Running AI again.") vr_video = media_handler.is_vr_scene(scene.get('tags')) if vr_video: log.info(f"Processing VR video {scenePath}") - server_result = await ai_server.process_video_async(video_path=mutated_path, vr_video=vr_video) + server_result = await ai_server.process_video_async(video_path=mutated_path, vr_video=vr_video, existing_json=saved_json) + if server_result is None: log.error("Server returned no results") media_handler.add_error_scene(sceneId) media_handler.remove_tagme_tag_from_scene(sceneId) return server_result = ai_server.VideoResult(**server_result) - if ai_video_result: - ai_video_result.add_server_response(server_result) - else: - ai_video_result = AIVideoResult.from_server_response(server_result, sceneId, phash, duration) - ai_video_result.to_json_file(ai_file_path) - ai_video_result.update_stash_tags() - ai_video_result.update_stash_markers() + + result = server_result.result + json_to_write = result['json_result'] + if json_to_write: + utility.write_json_to_file(ai_file_path, json_to_write) + video_tag_info = ai_server.VideoTagInfo(**result['video_tag_info']) + + media_handler.remove_ai_tags_from_video(sceneId, remove_tagme=True) + allTags = [] + for _, tag_set in video_tag_info.video_tags.items(): + allTags.extend(tag_set) + tagIdsToAdd = media_handler.get_tag_ids(allTags, create=True) + media_handler.add_tags_to_video(sceneId, tagIdsToAdd) + + #TODO: find a good place to store total durations of tags in a video and ideally be able to query them and see them in stash's UI (via custom plugin db fields?) + #todo = video_tag_info.tag_totals + + if config.CREATE_MARKERS: + media_handler.remove_ai_markers_from_video(sceneId) + media_handler.add_markers_to_video_from_dict(sceneId, video_tag_info.tag_timespans) + log.info(f"Server Result: {server_result}") log.info(f"Processed video with {len(server_result.result)} AI tagged frames") except aiohttp.ClientConnectionError as e: log.error(f"Failed to connect to AI server. Is the AI server running at {config.API_BASE_URL}? {e}") @@ -272,12 +334,45 @@ async def __tag_scene(scene): return finally: increment_progress() + + +# ----------------- Find Marker Settings ------------- + +async def __find_marker_settings(scene): + scenePath = scene['files'][0]['path'] + + already_ai_tagged = media_handler.is_scene_tagged(scene.get('tags')) + ai_file_path = scenePath + ".AI.json" + saved_json = None + if already_ai_tagged: + if os.path.exists(ai_file_path): + try: + saved_json = utility.read_json_from_file(ai_file_path) + except Exception as e: + log.error(f"Failed to load AI results from file: {e}") + else: + log.warning(f"Scene {scenePath} is already tagged but has no AI results file. Running AI again.") + if saved_json is None: + log.error(f"Scene {scenePath} has no AI results to optimize. Run the AI on this scene first and tune the markers manually.") + return + sorted_markers = media_handler.get_scene_markers_by_tag(scene['id']) + + for tag in sorted_markers: + sorted_markers[tag].sort(key=lambda x: x['seconds']) + + tag_timespans = {} + for tag, markers in sorted_markers.items(): + timeframes = [(ai_server.TimeFrame(start=marker['seconds'], end=marker['end_seconds'], totalConfidence=None)).to_json() for marker in markers] + tag_timespans[tag] = timeframes + log.info(f"Sending {tag_timespans} to AI server to optimize marker settings") + await ai_server.find_optimal_marker_settings(saved_json, tag_timespans) -# ----------------- Utility Functions ----------------- + +# ----------------- Utility Functions ---------------- def increment_progress(): global progress global increment progress += increment log.progress(progress) -asyncio.run(main()) +asyncio.run(main()) \ No newline at end of file diff --git a/plugins/AITagger/ai_tagger.yml b/plugins/AITagger/ai_tagger.yml index 3d990c1..1aa1fd1 100644 --- a/plugins/AITagger/ai_tagger.yml +++ b/plugins/AITagger/ai_tagger.yml @@ -1,6 +1,6 @@ name: AI Tagger description: Tag videos and Images with Locally hosted AI using Skier's Free and Patreon AI models -version: 1.8 +version: 2.0 url: https://github.com/stashapp/CommunityScripts/tree/main/plugins/AITagger exec: - python @@ -15,3 +15,11 @@ tasks: description: Run AI Tagger on scenes with AI_TagMe tag defaultArgs: mode: tag_scenes + - name: Collect Incorrect Markers and Images + description: Collects data from markers and images that were AI Tagged but were manually marked with AI_Incorrect due to the AI making a mistake. This will collect the data and output as a file which can be sent to Skier to improve the AI. + defaultArgs: + mode: collect_incorrect_markers + - name: Find Marker Settings + description: Find Optimal Marker Settings based on a video that has manually tuned markers and has been processed by the AI previously. Only 1 video should have AI_TagMe before running. + defaultArgs: + mode: find_marker_settings diff --git a/plugins/AITagger/config.py b/plugins/AITagger/config.py index f3df3ef..e287219 100644 --- a/plugins/AITagger/config.py +++ b/plugins/AITagger/config.py @@ -9,6 +9,8 @@ SERVER_TIMEOUT = 3700 AI_VIDEO_THRESHOLD = 0.3 temp_image_dir = "./temp_images" +output_data_dir = "./output_data" +delete_incorrect_markers = True ai_base_tag_name = "AI" tagme_tag_name = "AI_TagMe" updateme_tag_name = "AI_UpdateMe" diff --git a/plugins/AITagger/media_handler.py b/plugins/AITagger/media_handler.py index b75fcb0..b0ec43b 100644 --- a/plugins/AITagger/media_handler.py +++ b/plugins/AITagger/media_handler.py @@ -1,16 +1,15 @@ -import csv import os import zipfile -from stashapi.stashapp import StashInterface +from stashapi.stashapp import StashInterface, StashVersion import stashapi.log as log import config +import cv2 -tagid_mappings = {} -tagname_mappings = {} -max_gaps = {} -min_durations = {} -required_durations = {} -tag_thresholds = {} +tagid_cache = {} + +ai_tag_ids_cache = set() +stash_version = None +end_seconds_support = False def initialize(connection): global stash @@ -19,6 +18,9 @@ def initialize(connection): global ai_base_tag_id global ai_tagged_tag_id global vr_tag_id + global end_seconds_support + global stash_version + global ai_incorrect_tag_id # Initialize the Stash API stash = StashInterface(connection) @@ -28,6 +30,7 @@ def initialize(connection): tagme_tag_id = stash.find_tag(config.tagme_tag_name, create=True)["id"] ai_base_tag_id = stash.find_tag(config.ai_base_tag_name, create=True)["id"] ai_tagged_tag_id = stash.find_tag(config.aitagged_tag_name, create=True)["id"] + ai_incorrect_tag_id = stash.find_tag("AI_Incorrect", create=True)["id"] vr_tag_name = stash.get_configuration()["ui"].get("vrTag", None) if not vr_tag_name: log.warning("No VR tag found in configuration") @@ -35,35 +38,41 @@ def initialize(connection): else: vr_tag_id = stash.find_tag(vr_tag_name)["id"] - try: - parse_csv("tag_mappings.csv") - except Exception as e: - log.error(f"Failed to parse tag_mappings.csv: {e}") + stash_version = get_stash_version() + end_second_support_beyond = StashVersion("v0.27.2") + end_seconds_support = stash_version > end_second_support_beyond + +def get_stash_version(): + return stash.stash_version() + # ----------------- Tag Methods ----------------- +def get_tag_ids(tag_names, create=False): + return [get_tag_id(tag_name, create) for tag_name in tag_names] -tag_categories = ["actions", "bodyparts", "bdsm", "clothing", "describingperson", "environment", "describingbody", "describingimage", "describingscene", "sextoys"] +def get_tag_id(tag_name, create=False): + if tag_name not in tagid_cache: + stashtag = stash.find_tag(tag_name) + if stashtag: + tagid_cache[tag_name] = stashtag["id"] + return stashtag["id"] + else: + if not create: + return None + tag = stash.create_tag({"name":tag_name, "ignore_auto_tag": True, "parent_ids":[ai_base_tag_id]})['id'] + tagid_cache[tag_name] = tag + ai_tag_ids_cache.add(tag) + return tag + return tagid_cache.get(tag_name) -def get_all_tags_from_server_result(result): - alltags = [] - for category in tag_categories: - alltags.extend(result.get(category, [])) - return alltags - -def get_tag_ids(tag_names): - return [get_tag_id(tag_name) for tag_name in tag_names] - -def get_tag_id(tag_name): - if tag_name not in tagid_mappings: - return stash.find_tag(tag_name)["id"] - return tagid_mappings.get(tag_name) - -def get_tag_threshold(tag_name): - return tag_thresholds.get(tag_name, 0.5) - -def is_ai_tag(tag_name): - return tag_name in tagname_mappings +def get_ai_tags(): + if len(ai_tag_ids_cache) == 0: + ai_tags = [item['id'] for item in stash.find_tags(f={"parents": {"value":1410, "modifier":"INCLUDES"}}, fragment="id")] + ai_tag_ids_cache.update(ai_tags) + else : + ai_tags = list(ai_tag_ids_cache) + return ai_tags def is_scene_tagged(tags): for tag in tags: @@ -85,9 +94,23 @@ def get_tagme_images(): def add_error_images(image_ids): stash.update_images({"ids": image_ids, "tag_ids": {"ids": [aierroed_tag_id], "mode": "ADD"}}) +def get_incorrect_images(): + return stash.find_images(f={"tags": {"value":ai_incorrect_tag_id, "modifier":"INCLUDES"}}, fragment="id files {path}") + def remove_tagme_tags_from_images(image_ids): stash.update_images({"ids": image_ids, "tag_ids": {"ids": [tagme_tag_id], "mode": "REMOVE"}}) +def remove_incorrect_tag_from_images(image_ids): + stash.update_images({"ids": image_ids, "tag_ids": {"ids": [ai_incorrect_tag_id], "mode": "REMOVE"}}) + +def remove_ai_tags_from_images(image_ids, remove_tagme=True, remove_errored=True): + ai_tags = get_ai_tags() + if remove_tagme: + ai_tags.append(tagme_tag_id) + if remove_errored: + ai_tags.append(aierroed_tag_id) + stash.update_images({"ids": image_ids, "tag_ids": {"ids": ai_tags, "mode": "REMOVE"}}) + def add_tags_to_image(image_id, tag_ids): stash.update_images({"ids": [image_id], "tag_ids": {"ids": tag_ids, "mode": "ADD"}}) @@ -137,10 +160,12 @@ def add_tags_to_video(video_id, tag_ids, add_tagged=True): tag_ids.append(ai_tagged_tag_id) stash.update_scenes({"ids": [video_id], "tag_ids": {"ids": tag_ids, "mode": "ADD"}}) -def remove_ai_tags_from_video(video_id, remove_tagme=True): - ai_tags = list(tagid_mappings.values()) +def remove_ai_tags_from_video(video_id, remove_tagme=True, remove_errored=True): + ai_tags = get_ai_tags() if remove_tagme: ai_tags.append(tagme_tag_id) + if remove_errored: + ai_tags.append(aierroed_tag_id) stash.update_scenes({"ids": [video_id], "tag_ids": {"ids": ai_tags, "mode": "REMOVE"}}) def get_tagme_scenes(): @@ -152,77 +177,94 @@ def add_error_scene(scene_id): def remove_tagme_tag_from_scene(scene_id): stash.update_scenes({"ids": [scene_id], "tag_ids": {"ids": [tagme_tag_id], "mode": "REMOVE"}}) -def get_required_duration(tag_name, scene_duration): - if not required_durations: - log.error("Tag mappings not initialized") - required_duration_value = str(required_durations.get(tag_name)) - required_duration_value = required_duration_value.replace(" ", "").lower() - - if required_duration_value.endswith("s"): - # If the value ends with 's', remove 's' and convert to float - return float(required_duration_value[:-1]) - elif required_duration_value.endswith("%"): - # If the value ends with '%', remove '%' and calculate the percentage of scene_duration - percentage = float(required_duration_value[:-1]) - return (percentage / 100) * scene_duration - elif "." in required_duration_value and 0 <= float(required_duration_value) <= 1: - # If the value is a proportion, calculate the proportion of scene_duration - proportion = float(required_duration_value) - return proportion * scene_duration - else: - # If the value is a straight number, convert to float - return float(required_duration_value) - # ----------------- Marker Methods ----------------- -def is_ai_marker_supported(tag_name): - return tag_name in min_durations +def add_markers_to_video_from_dict(video_id, tag_timespans_dict): + for _, tag_timespan_dict in tag_timespans_dict.items(): + for tag_name, time_frames in tag_timespan_dict.items(): + tag_id = get_tag_id(tag_name, create=True) + add_markers_to_video(video_id, tag_id, tag_name, time_frames) -def get_min_duration(tag_name): - return min_durations.get(tag_name) -def get_max_gap(tag_name): - return max_gaps.get(tag_name, 0) +def get_incorrect_markers(): + if end_seconds_support: + return stash.find_scene_markers({"tags": {"value":ai_incorrect_tag_id, "modifier":"INCLUDES"}}, fragment="id scene {id files{path}} primary_tag {id, name} seconds end_seconds") + else: + return stash.find_scene_markers({"tags": {"value":ai_incorrect_tag_id, "modifier":"INCLUDES"}}, fragment="id scene {id files{path}} primary_tag {id, name} seconds") def add_markers_to_video(video_id, tag_id, tag_name, time_frames): for time_frame in time_frames: - stash.create_scene_marker({"scene_id": video_id, "primary_tag_id":tag_id, "tag_ids": [tag_id], "seconds": time_frame.start, "title":tagname_mappings[tag_name]}) + if end_seconds_support: + stash.create_scene_marker({"scene_id": video_id, "primary_tag_id":tag_id, "tag_ids": [tag_id], "seconds": time_frame.start, "end_seconds": time_frame.end, "title":tag_name}) + else: + stash.create_scene_marker({"scene_id": video_id, "primary_tag_id":tag_id, "tag_ids": [tag_id], "seconds": time_frame.start, "title":tag_name}) + +def get_scene_markers(video_id): + return stash.get_scene_markers(video_id, fragment="id primary_tag {id} seconds end_seconds") + +def write_scene_marker_to_file(marker, scene_file, output_folder): + start = marker.get("seconds", None) + end = marker.get("end_seconds", None) + try: + cap = cv2.VideoCapture(scene_file) + if not cap.isOpened(): + log.error(f"Failed to open video {scene_file}") + return + + timestamps = [] + if end is None: + timestamps.append(start) + else: + duration = end - start + if duration > 4 and duration < 30: + timestamps.append(start + 4) + elif duration >= 30 and duration < 60: + timestamps.append(start + 4) + timestamps.append(start + 20) + elif duration >= 60 and duration < 120: + timestamps.append(start + 4) + timestamps.append(start + 20) + timestamps.append(start + 50) + elif duration >= 120: + timestamps.append(start + 4) + timestamps.append(start + 20) + timestamps.append(start + 50) + timestamps.append(start + 100) + + for timestamp in timestamps: + cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000) + ret, frame = cap.read() + if not ret: + log.error(f"Failed to read frame at {timestamp} seconds from {scene_file}") + return + output_path = os.path.join(output_folder, f"{marker.get('id')}_{timestamp}.jpg") + cv2.imwrite(output_path, frame) + except Exception as e: + log.error(f"Failed to write scene marker to file: {e}") + +def delete_markers(markers): + for scene_marker in markers: + stash.destroy_scene_marker(scene_marker["id"]) + +def get_scene_markers_by_tag(video_id, error_if_no_end_seconds=True): + scene_markers = stash.get_scene_markers(video_id, fragment="id primary_tag {name} seconds end_seconds") + scene_markers_by_tag = {} + for scene_marker in scene_markers: + tag_name = scene_marker.get("primary_tag").get("name") + if tag_name not in scene_markers_by_tag: + scene_markers_by_tag[tag_name] = [] + if error_if_no_end_seconds and scene_marker.get("end_seconds", None) is None: + raise ValueError(f"Scene marker {scene_marker.get('id')} has no end_seconds") + scene_markers_by_tag[tag_name].append(scene_marker) + return scene_markers_by_tag + +def remove_incorrect_tag_from_markers(markers): + for marker in markers: + stash.update_scene_marker({"id": marker["id"], "tag_ids": []}) def remove_ai_markers_from_video(video_id): - ai_tags = set(tagid_mappings.values()) + ai_tags = set(get_ai_tags()) scene_markers = stash.get_scene_markers(video_id, fragment="id primary_tag {id}") for scene_marker in scene_markers: if scene_marker.get("primary_tag").get("id") in ai_tags: - stash.destroy_scene_marker(scene_marker.get("id")) - -# ----------------- Helpers ----------------- - -def parse_csv(file_path): - global tagid_mappings - global tagname_mappings - global max_gaps - global min_durations - global required_durations - global tag_thresholds - - with open(file_path, mode='r') as infile: - reader = csv.DictReader(infile) - for row in reader: - server_tag = row.get('ServerTag') - stash_tag = row.get('StashTag') - min_duration = float(row.get('MinMarkerDuration', -1)) #float(row['MinDuration']) - max_gap = float(row.get('MaxGap', 0)) #float(row['MaxGap']) - required_duration = row.get('RequiredDuration', "200%") - tag_threshold = float(row.get('TagThreshold', 0.5)) - - tag = stash.find_tag(stash_tag) - if not tag: - tag = stash.create_tag({"name":stash_tag, "ignore_auto_tag": True, "parent_ids":[ai_base_tag_id]}) - - tagid_mappings[server_tag] = tag["id"] - tagname_mappings[server_tag] = stash_tag - if min_duration != -1: - min_durations[server_tag] = min_duration - max_gaps[server_tag] = max_gap - required_durations[server_tag] = required_duration - tag_thresholds[server_tag] = tag_threshold + stash.destroy_scene_marker(scene_marker.get("id")) \ No newline at end of file diff --git a/plugins/AITagger/tag_mappings.csv b/plugins/AITagger/tag_mappings.csv deleted file mode 100644 index f82f7b0..0000000 --- a/plugins/AITagger/tag_mappings.csv +++ /dev/null @@ -1,125 +0,0 @@ -ServerTag,StashTag,MinMarkerDuration,MaxGap,RequiredDuration,TagThreshold -69,69_AI,15,6,20s,0.5 -Anal Fucking,Anal Fucking_AI,15,6,20s,0.5 -Ass Licking,Ass Licking_AI,15,6,20s,0.5 -Ass Penetration,Ass Penetration_AI,15,6,20s,0.5 -Ball Licking/Sucking,Ball Licking/Sucking_AI,5,4,20s,0.5 -Blowjob,Blowjob_AI,15,6,20s,0.5 -Cum on Person,Cum on Person_AI,5,4,15s,0.5 -Cum Swapping,Cum Swapping_AI,5,4,15s,0.5 -Cumshot,Cumshot_AI,4,4,10s,0.5 -Deepthroat,Deepthroat_AI,5,4,20s,0.5 -Double Penetration,Double Penetration_AI,10,4,20s,0.5 -Fingering,Fingering_AI,15,6,20s,0.5 -Fisting,Fisting_AI,15,6,20s,0.5 -Footjob,Footjob_AI,15,6,20s,0.5 -Gangbang,Gangbang_AI,15,6,20s,0.5 -Gloryhole,Gloryhole_AI,15,8,20s,0.5 -Grabbing Ass,Grabbing Ass_AI,10,8,20s,0.5 -Grabbing Boobs,Grabbing Boobs_AI,6,6,20s,0.5 -Grabbing Hair/Head,Grabbing Hair/Head_AI,6,6,20s,0.5 -Handjob,Handjob_AI,15,6,20s,0.5 -Kissing,Kissing_AI,10,4,20s,0.5 -Licking Penis,Licking Penis_AI,6,4,20s,0.5 -Masturbation,Masturbation_AI,15,10,20s,0.5 -Pissing,Pissing_AI,5,4,20s,0.5 -Pussy Licking (Clearly Visible),Pussy Licking (Clearly Visible)_AI,10,4,20s,0.5 -Pussy Licking,Pussy Licking_AI,15,6,20s,0.5 -Pussy Rubbing,Pussy Rubbing_AI,15,6,20s,0.5 -Sucking Fingers,Sucking Fingers_AI,5,4,20s,0.5 -Sucking Toy/Dildo,Sucking Toy/Dildo_AI,5,4,20s,0.5 -Wet (Genitals),Wet (Genitals)_AI,15,6,20s,0.5 -Titjob,Titjob_AI,10,4,20s,0.5 -Tribbing/Scissoring,Tribbing/Scissoring_AI,15,6,20s,0.5 -Undressing,Undressing_AI,15,6,20s,0.5 -Vaginal Penetration,Vaginal Penetration_AI,15,6,20s,0.5 -Vaginal Fucking,Vaginal Fucking_AI,15,6,20s,0.5 -Vibrating,Vibrating_AI,10,6,20s,0.5 -Ass,Ass_AI,-1,6,20s,0.5 -Asshole,Asshole_AI,-1,6,20s,0.5 -Anal Gape,Anal Gape_AI,10,6,20s,0.5 -Balls,Balls_AI,-1,6,20s,0.5 -Boobs,Boobs_AI,-1,6,20s,0.5 -Cum,Cum_AI,10,6,20s,0.5 -Dick,Dick_AI,-1,6,20s,0.5 -Face,Face_AI,-1,6,20s,0.5 -Feet,Feet_AI,-1,6,20s,0.5 -Fingers,Fingers_AI,-1,6,20s,0.5 -Belly Button,Belly Button_AI,-1,6,20s,0.5 -Nipples,Nipples_AI,-1,6,20s,0.5 -Thighs,Thighs_AI,-1,6,20s,0.5 -Lower Legs,Lower Legs_AI,-1,6,20s,0.5 -Tongue,Tongue_AI,10,6,20s,0.5 -Pussy,Pussy_AI,-1,6,20s,0.5 -Pussy Gape,Pussy Gape_AI,10,6,20s,0.5 -Spit,Spit_AI,10,6,20s,0.5 -Oiled,Oiled_AI,10,6,20s,0.5 -Wet (Water),Wet (Water)_AI,10,6,20s,0.5 -Chastity,Chastity_AI,10,6,20s,0.5 -Bondage,Bondage_AI,10,6,20s,0.5 -Female Bondage,Female Bondage_AI,-1,6,20s,0.5 -Male Bondage,Male Bondage_AI,-1,6,20s,0.5 -Choking,Choking_AI,5,4,20s,0.5 -Pegging,Pegging_AI,10,6,20s,0.5 -Nipple Clamps,Nipple Clamps_AI,10,6,20s,0.5 -Gag,Gag_AI,10,6,20s,0.5 -Pain,Pain_AI,10,6,20s,0.5 -Anal Hook,Anal Hook_AI,10,6,20s,0.5 -Chastity (Male),Chastity (Male)_AI,10,6,20s,0.5 -Chastity (Female),Chastity (Female)_AI,10,6,20s,0.5 -Metal Chastity,Metal Chastity_AI,10,6,20s,0.5 -Plastic Chastity,Plastic Chastity_AI,10,6,20s,0.5 -Cum in Chastity,Cum in Chastity_AI,10,6,20s,0.5 -Crotch Roped,Crotch Roped_AI,10,6,20s,0.5 -Bondaged Boobs,Bondaged Boobs_AI,10,6,20s,0.5 -Tied Penis,Tied Penis_AI,10,6,20s,0.5 -Tied Balls,Tied Balls_AI,10,6,20s,0.5 -Clover Clamps,Clover Clamps_AI,10,6,20s,0.5 -Clothes Pin,Clothes Pin_AI,10,6,20s,0.5 -Weights,Weights_AI,10,6,20s,0.5 -Alligator Clamp,Alligator Clamp_AI,10,6,20s,0.5 -Ball Gag,Ball Gag_AI,10,6,20s,0.5 -Ring Gag,Ring Gag_AI,10,6,20s,0.5 -Harness Gag,Harness Gag_AI,10,6,20s,0.5 -Bit Gag,Bit Gag_AI,10,6,20s,0.5 -Muzzle Gag,Muzzle Gag_AI,10,6,20s,0.5 -Dildo Gag,Dildo Gag_AI,10,6,20s,0.5 -Inflatable Gag,Inflatable Gag_AI,10,6,20s,0.5 -Tape Gag,Tape Gag_AI,10,6,20s,0.5 -Rope Bondage,Rope Bondage_AI,10,6,20s,0.5 -Metal Bondage,Metal Bondage_AI,10,6,20s,0.5 -Leather Bondage,Leather Bondage_AI,10,6,20s,0.5 -Latex Bondage,Latex Bondage_AI,10,6,20s,0.5 -Collared,Collared_AI,10,6,20s,0.5 -Blindfolded,Blindfolded_AI,10,6,20s,0.5 -Chair Tied,Chair Tied_AI,10,6,20s,0.5 -Straight Jacket,Straight Jacket_AI,10,6,20s,0.5 -Yoke,Yoke_AI,10,6,20s,0.5 -Whip,Whip_AI,10,6,20s,0.5 -Flogger,Flogger_AI,10,6,20s,0.5 -Electric Torture,Electric Torture_AI,10,6,20s,0.5 -Crush Torture,Crush Torture_AI,10,6,20s,0.5 -Arm Binder,Arm Binder_AI,10,6,20s,0.5 -Rope Collar,Rope Collar_AI,10,6,20s,0.5 -Leather Collar,Leather Collar_AI,10,6,20s,0.5 -Metal Collar,Metal Collar_AI,10,6,20s,0.5 -Leash,Leash_AI,10,6,20s,0.5 -Pussy Fully Visible,Pussy Fully Visible_AI,15,6,20s,0.5 -Pussy Closeup,Pussy Closeup_AI,15,6,20s,0.5 -Pussy Very Closeup,Pussy Very Closeup_AI,15,6,20s,0.5 -Wet Pussy,Wet Pussy_AI,10,6,20s,0.5 -Very Wet Pussy,Very Wet Pussy_AI,5,6,20s,0.5 -Cum on Pussy,Cum on Pussy_AI,5,6,20s,0.5 -Small Labia,Small Labia_AI,20,6,20s,0.5 -Big Labia,Big Labia_AI,20,6,20s,0.5 -Pierced Pussy,Pierced Pussy_AI,20,6,20s,0.5 -Pussy Hair,Pussy Hair_AI,20,6,20s,0.5 -Very Hairy Pussy,Very Hairy Pussy_AI,20,6,20s,0.5 -Innie,Innie_AI,20,6,20s,0.5 -Medium Labia,Medium Labia_AI,20,6,20s,0.5 -Spread Labia,Spread Labia_AI,10,6,20s,0.5 -Dark Pink Pussy,Dark Pink Pussy_AI,20,6,20s,0.5 -Bright Pink Pussy,Bright Pink Pussy_AI,20,6,20s,0.5 -Brown Pussy,Brown Pussy_AI,20,6,20s,0.5 -Light Brown Pussy,Light Brown Pussy_AI,20,6,20s,0.5 -Shaved Pussy,Shaved Pussy_AI,20,6,20s,0.5 diff --git a/plugins/AITagger/utility.py b/plugins/AITagger/utility.py new file mode 100644 index 0000000..f44e143 --- /dev/null +++ b/plugins/AITagger/utility.py @@ -0,0 +1,19 @@ +import json +import config + +def mutate_path(to_mutate): + if isinstance(to_mutate, str): + for key, value in config.path_mutation.items(): + to_mutate = to_mutate.replace(key, value) + elif isinstance(to_mutate, list): + for i in range(len(to_mutate)): + to_mutate[i] = mutate_path(to_mutate[i]) + return to_mutate + +def read_json_from_file(file_path): + with open(file_path, 'r') as f: + return json.load(f) + +def write_json_to_file(file_path, json_data): + with open(file_path, 'w') as f: + f.write(json_data) \ No newline at end of file