initial commit of AHavenVLMConnector (#657)

Co-authored-by: DogmaDragon <103123951+DogmaDragon@users.noreply.github.com>
HavenCTO 2026-01-26 22:43:05 -05:00 committed by GitHub
parent 69e44b2099
commit 2a0719091c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 4596 additions and 0 deletions

View File

@ -0,0 +1,8 @@
# Changelog
All notable changes to the A Haven VLM Connector project will be documented in this file.
## [1.0.0] - 2025-06-29
### Added
- **Initial release**

View File

@ -0,0 +1,143 @@
# A Haven VLM Connector
A StashApp plugin for Vision-Language Model (VLM) based content tagging and analysis. This plugin is designed with a **local-first philosophy**, empowering users to run analysis on their own hardware (using CPU or GPU) and their local network. It also supports cloud-based VLM endpoints for additional flexibility. The Haven VLM Engine provides advanced automatic content detection and tagging, delivering superior accuracy compared to traditional image classification methods.
## Features
- **Local Network Empowerment**: Distribute processing across home/office computers without cloud dependencies
- **Context-Aware Detection**: Leverages Vision-Language Models' understanding of visual relationships
- **Advanced Dependency Management**: Uses PythonDepManager for automatic dependency installation
- **Enjoying Funscript Haven?** Check out more tools and projects at https://github.com/Haven-hvn
## Requirements
- Python 3.8+
- StashApp
- PythonDepManager plugin (automatically handles dependencies)
- OpenAI-compatible VLM endpoints (local or cloud-based)
## Installation
1. Clone or download this plugin to your StashApp plugins directory
2. Ensure PythonDepManager is installed in your StashApp plugins
3. Configure your VLM endpoints in `haven_vlm_config.py` (local network endpoints recommended)
4. Restart StashApp
The plugin automatically manages all dependencies.
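Dependency handling is delegated to PythonDepManager's `ensure_import`, which pins and installs the required packages when the plugin starts. A minimal sketch of the pattern the connector itself uses (versions mirror `haven_vlm_connector.py`):

```python
# Sketch of the dependency bootstrap performed by haven_vlm_connector.py
from PythonDepManager import ensure_import

# Install/verify pinned dependencies, then import them as usual
ensure_import(
    "stashapi:stashapp-tools==0.2.58",
    "vlm-engine==0.9.1",
    "aiohttp==3.12.13",
    "pydantic==2.11.7",
    "pyyaml==6.0.2",
)
import stashapi.log as log
```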
## Why Local-First?
- **Complete Control**: Process sensitive content on your own hardware
- **Cost Effective**: Avoid cloud processing fees by using existing resources
- **Flexible Scaling**: Add more computers to your local network for increased capacity
- **Privacy Focused**: Keep your media completely private
- **Hybrid Options**: Combine local and cloud endpoints for optimal flexibility
```mermaid
graph LR
A[User's Computer] --> B[Local GPU Machine]
A --> C[Local CPU Machine 1]
A --> D[Local CPU Machine 2]
A --> E[Cloud Endpoint]
```
## Configuration
### Easy Setup with LM Studio
[LM Studio](https://lmstudio.ai/) provides the easiest way to configure local endpoints:
1. Download and install [LM Studio](https://lmstudio.ai/)
2. [Search for or download](https://huggingface.co/models) a vision-capable model; tested with (from highest to lowest accuracy): zai-org/glm-4.6v-flash, huihui-mistral-small-3.2-24b-instruct-2506-abliterated-v2, qwen/qwen3-vl-8b, lfm2.5-vl
3. Load your desired model
4. On the Developer tab, start the local server using the start toggle
5. Optionally, click the Settings gear and toggle *Serve on local network*
6. Optionally, configure `haven_vlm_config.py`:
By default, localhost is included in the config; **remove the cloud endpoint if you don't want automatic failover**
```python
{
    "base_url": "http://localhost:1234/v1",  # LM Studio default
    "api_key": "",  # API key not required
    "name": "lm-studio-local",
    "weight": 5,
    "is_fallback": False,
    "max_concurrent": 10  # maximum parallel requests; required by the engine's endpoint validation
}
```
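If you spread processing across several machines on your local network, the `multiplexer_endpoints` list in `haven_vlm_config.py` can hold multiple entries: higher weights bias work toward faster machines, and `is_fallback` marks endpoints that are only used when the others fail. A hedged sketch (the IP addresses and names are placeholders):

```python
# Hypothetical multi-endpoint setup; IPs, ports, and names are examples only
"multiplexer_endpoints": [
    {
        "base_url": "http://192.168.1.20:1234/v1",  # GPU machine
        "api_key": "",
        "name": "gpu-box",
        "weight": 9,
        "is_fallback": False,
        "max_concurrent": 10
    },
    {
        "base_url": "http://192.168.1.21:1234/v1",  # CPU machine
        "api_key": "",
        "name": "cpu-box",
        "weight": 2,
        "is_fallback": False,
        "max_concurrent": 2
    }
]
```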
### Tag Configuration
```python
"tag_list": [
"Basketball point", "Foul", "Break-away", "Turnover"
]
```
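In the bundled `haven_vlm_config.py`, every entry in `tag_list` also has a matching entry under `category_config` that controls how raw detections become markers (minimum duration, gap merging, threshold). If you add custom tags, a matching entry presumably belongs there too; a sketch mirroring the bundled defaults:

```python
# Matching category_config entry for a custom tag (values copied from the bundled defaults)
"Basketball point": {
    "RenamedTag": "Basketball point",
    "MinMarkerDuration": "1s",
    "MaxGap": "30s",
    "RequiredDuration": "1s",
    "TagThreshold": 0.5,
},
```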
### Processing Settings
```python
VIDEO_FRAME_INTERVAL = 2.0 # Process every 2 seconds
CONCURRENT_TASK_LIMIT = 8 # Adjust based on local hardware
```
## Usage
### Tag Videos
1. Tag scenes with `VLM_TagMe`
2. Run "Tag Videos" task
3. Plugin processes content using local/network resources
### Performance Tips
- Start with 2-3 local machines for load balancing
- Assign higher weights to GPU-enabled machines
- Adjust `CONCURRENT_TASK_LIMIT` based on total system resources
- Use SSD storage for better I/O performance
## File Structure
```
AHavenVLMConnector/
├── ahavenvlmconnector.yml
├── haven_vlm_connector.py
├── haven_vlm_config.py
├── haven_vlm_engine.py
├── haven_media_handler.py
├── haven_vlm_utility.py
├── requirements.txt
└── README.md
```
## Troubleshooting
### Local Network Setup
- Ensure firewalls allow communication between machines
- Verify all local endpoints are running VLM services
- Use static IPs for local machines
- Check `http://local-machine-ip:port/v1` responds correctly
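Most OpenAI-compatible servers, including LM Studio, typically expose a `GET /v1/models` route, so one quick reachability check from Python (a minimal sketch using only the standard library; substitute your machine's IP and port):

```python
# Quick reachability check for an OpenAI-compatible endpoint (e.g. LM Studio)
import json
import urllib.request

base_url = "http://192.168.1.20:1234/v1"  # placeholder; use your endpoint
with urllib.request.urlopen(f"{base_url}/models", timeout=5) as resp:
    print(json.loads(resp.read()))  # should list the loaded vision-capable model
```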
### Performance Optimization
- **Distribute Load**: Use multiple mid-range machines instead of a single high-end machine
- **GPU Prioritization**: Assign highest weight to GPU machines
- **Network Speed**: Use wired Ethernet connections for faster transfer
- **Resource Monitoring**: Watch system resources during processing
## Development
### Adding Local Endpoints
1. Install VLM service on network machines
2. Add endpoint configuration with local IPs
3. Set appropriate weights based on hardware capability
### Custom Models
Use any OpenAI-compatible model that supports:
- POST requests to `/v1/chat/completions`
- Vision capabilities with image input
- Local deployment options
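For reference, an OpenAI-compatible vision request looks roughly like the following. This is a hedged sketch of the wire format, not the plugin's internal code; the model name, prompt, and image data are placeholders:

```python
# Shape of an OpenAI-compatible vision chat-completions request (illustrative only)
payload = {
    "model": "zai-org/glm-4.6v-flash",  # whichever model is loaded on the endpoint
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Which of these tags apply to this frame?"},
            {"type": "image_url",
             "image_url": {"url": "data:image/jpeg;base64,<frame-bytes>"}},
        ],
    }],
}
# POSTed to <base_url>/chat/completions, with an Authorization header if the endpoint needs one
```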
### Log Messages
Check StashApp logs for detailed processing information and error messages.
## License
This project is part of the StashApp Community Scripts collection.

View File

@ -0,0 +1,22 @@
name: Haven VLM Connector
# requires: PythonDepManager
description: Tag videos with Vision-Language Models using any OpenAI-compatible VLM endpoint
version: 1.0.0
url: https://github.com/stashapp/CommunityScripts/tree/main/plugins/AHavenVLMConnector
exec:
- python
- "{pluginDir}/haven_vlm_connector.py"
interface: raw
tasks:
- name: Tag Videos
description: Run VLM analysis on videos with VLM_TagMe tag
defaultArgs:
mode: tag_videos
- name: Collect Incorrect Markers and Images
description: Collects data from markers and images that were VLM tagged but manually marked with VLM_Incorrect because the VLM made a mistake. The collected data is written out as files that can be used to improve the VLM models.
defaultArgs:
mode: collect_incorrect_markers
- name: Find Marker Settings
description: Find optimal marker settings based on a video that has manually tuned markers and has previously been processed by the VLM. Only one video should be tagged with VLM_TagMe before running.
defaultArgs:
mode: find_marker_settings

View File

@ -0,0 +1,98 @@
"""
Comprehensive sys.exit tracking module
Instruments all sys.exit() calls with full call stack and context
"""
import sys
import traceback
from typing import Optional
# Store original sys.exit
original_exit = sys.exit
# Track if we've already patched
_exit_tracker_patched = False
def install_exit_tracker(logger=None) -> None:
"""
Install the exit tracker by monkey-patching sys.exit
Args:
logger: Optional logger instance (will use fallback print if None)
"""
global _exit_tracker_patched, original_exit
if _exit_tracker_patched:
return
# Store original if not already stored
if hasattr(sys, 'exit') and sys.exit is not original_exit:
original_exit = sys.exit
def tracked_exit(code: int = 0) -> None:
"""Track sys.exit() calls with full call stack"""
# Get current stack trace (not from exception, but current call stack)
stack = traceback.extract_stack()
# Format the stack trace, excluding this tracking function
stack_lines = []
for frame in stack:
# Skip internal Python frames and this tracker
if ('tracked_exit' not in frame.filename and
'/usr/lib' not in frame.filename and
'/System/Library' not in frame.filename and
'exit_tracker.py' not in frame.filename):
stack_lines.append(
f" File \"{frame.filename}\", line {frame.lineno}, in {frame.name}\n {frame.line}"
)
# Take last 15 frames to see the full call chain
stack_str = '\n'.join(stack_lines[-15:])
# Get current exception info if available
exc_info = sys.exc_info()
exc_str = ""
if exc_info[0] is not None:
exc_str = f"\n Active Exception: {exc_info[0].__name__}: {exc_info[1]}"
# Build the error message
error_msg = f"""[DEBUG_EXIT_CODE] ==========================================
[DEBUG_EXIT_CODE] sys.exit() called with code: {code}
[DEBUG_EXIT_CODE] Call stack (last 15 frames):
{stack_str}
{exc_str}
[DEBUG_EXIT_CODE] =========================================="""
# Log using provided logger or fallback to print
if logger:
try:
logger.error(error_msg)
except Exception as log_error:
print(f"[EXIT_TRACKER_LOGGER_ERROR] Failed to log: {log_error}")
print(error_msg)
else:
print(error_msg)
# Call original exit
original_exit(code)
# Install the tracker
sys.exit = tracked_exit
_exit_tracker_patched = True
if logger:
logger.debug("[DEBUG_EXIT_CODE] Exit tracker installed successfully")
else:
print("[DEBUG_EXIT_CODE] Exit tracker installed successfully")
def uninstall_exit_tracker() -> None:
"""Uninstall the exit tracker and restore original sys.exit"""
global _exit_tracker_patched, original_exit
if _exit_tracker_patched:
sys.exit = original_exit
_exit_tracker_patched = False
# Auto-install on import (can be disabled by calling uninstall_exit_tracker())
if not _exit_tracker_patched:
install_exit_tracker()

View File

@ -0,0 +1,333 @@
"""
Haven Media Handler Module
Handles StashApp media operations and tag management
"""
import os
import zipfile
import shutil
from typing import List, Dict, Any, Optional, Tuple, Set
from datetime import datetime
import json
# Use PythonDepManager for dependency management
try:
from PythonDepManager import ensure_import
ensure_import("stashapi:stashapp-tools==0.2.58")
from stashapi.stashapp import StashInterface, StashVersion
import stashapi.log as log
except ImportError as e:
print(f"stashapp-tools not found: {e}")
print("Please ensure PythonDepManager is available and stashapp-tools is accessible")
raise
import haven_vlm_config as config
# Global variables
tag_id_cache: Dict[str, int] = {}
vlm_tag_ids_cache: Set[int] = set()
stash_version: Optional[StashVersion] = None
end_seconds_support: bool = False
# Tag IDs
stash: Optional[StashInterface] = None
vlm_errored_tag_id: Optional[int] = None
vlm_tagme_tag_id: Optional[int] = None
vlm_base_tag_id: Optional[int] = None
vlm_tagged_tag_id: Optional[int] = None
vr_tag_id: Optional[int] = None
vlm_incorrect_tag_id: Optional[int] = None
def initialize(connection: Dict[str, Any]) -> None:
"""Initialize the media handler with StashApp connection"""
global stash, vlm_errored_tag_id, vlm_tagme_tag_id, vlm_base_tag_id
global vlm_tagged_tag_id, vr_tag_id, end_seconds_support, stash_version
global vlm_incorrect_tag_id
# Initialize the Stash API
stash = StashInterface(connection)
# Initialize "metadata" tags
vlm_errored_tag_id = stash.find_tag(config.config.vlm_errored_tag_name, create=True)["id"]
vlm_tagme_tag_id = stash.find_tag(config.config.vlm_tagme_tag_name, create=True)["id"]
vlm_base_tag_id = stash.find_tag(config.config.vlm_base_tag_name, create=True)["id"]
vlm_tagged_tag_id = stash.find_tag(config.config.vlm_tagged_tag_name, create=True)["id"]
vlm_incorrect_tag_id = stash.find_tag(config.config.vlm_incorrect_tag_name, create=True)["id"]
# Get VR tag from configuration
vr_tag_name = stash.get_configuration()["ui"].get("vrTag", None)
if not vr_tag_name:
log.warning("No VR tag found in configuration")
vr_tag_id = None
else:
vr_tag_id = stash.find_tag(vr_tag_name)["id"]
stash_version = get_stash_version()
end_second_support_beyond = StashVersion("v0.27.2-76648")
end_seconds_support = stash_version > end_second_support_beyond
def get_stash_version() -> StashVersion:
"""Get the current StashApp version"""
if not stash:
raise RuntimeError("Stash interface not initialized")
return stash.stash_version()
# ----------------- Tag Management Methods -----------------
def get_tag_ids(tag_names: List[str], create: bool = False) -> List[int]:
"""Get tag IDs for multiple tag names"""
return [get_tag_id(tag_name, create) for tag_name in tag_names]
def get_tag_id(tag_name: str, create: bool = False) -> Optional[int]:
"""Get tag ID for a single tag name"""
if tag_name not in tag_id_cache:
stashtag = stash.find_tag(tag_name)
if stashtag:
tag_id_cache[tag_name] = stashtag["id"]
return stashtag["id"]
else:
if not create:
return None
tag = stash.create_tag({
"name": tag_name,
"ignore_auto_tag": True,
"parent_ids": [vlm_base_tag_id]
})['id']
tag_id_cache[tag_name] = tag
vlm_tag_ids_cache.add(tag)
return tag
return tag_id_cache.get(tag_name)
def get_vlm_tags() -> List[int]:
"""Get all VLM-generated tags"""
if len(vlm_tag_ids_cache) == 0:
vlm_tags = [
item['id'] for item in stash.find_tags(
f={"parents": {"value": vlm_base_tag_id, "modifier": "INCLUDES"}},
fragment="id"
)
]
vlm_tag_ids_cache.update(vlm_tags)
else:
vlm_tags = list(vlm_tag_ids_cache)
return vlm_tags
def is_scene_tagged(tags: List[Dict[str, Any]]) -> bool:
"""Check if a scene has been tagged by VLM"""
for tag in tags:
if tag['id'] == vlm_tagged_tag_id:
return True
return False
def is_vr_scene(tags: List[Dict[str, Any]]) -> bool:
"""Check if a scene is VR content"""
for tag in tags:
if tag['id'] == vr_tag_id:
return True
return False
# ----------------- Scene Management Methods -----------------
def add_tags_to_video(video_id: int, tag_ids: List[int], add_tagged: bool = True) -> None:
"""Add tags to a video scene"""
if add_tagged:
tag_ids.append(vlm_tagged_tag_id)
stash.update_scenes({
"ids": [video_id],
"tag_ids": {"ids": tag_ids, "mode": "ADD"}
})
def clear_all_tags_from_video(scene: Dict[str, Any]) -> None:
"""Clear all tags from a video scene using existing scene data"""
scene_id = scene.get('id')
if scene_id is None:
log.error("Scene missing 'id' field")
return
current_tag_ids = [tag['id'] for tag in scene.get('tags', [])]
if current_tag_ids:
stash.update_scenes({
"ids": [scene_id],
"tag_ids": {"ids": current_tag_ids, "mode": "REMOVE"}
})
log.info(f"Cleared {len(current_tag_ids)} tags from scene {scene_id}")
def clear_all_markers_from_video(video_id: int) -> None:
"""Clear all markers from a video scene"""
markers = get_scene_markers(video_id)
if markers:
delete_markers(markers)
log.info(f"Cleared all {len(markers)} markers from scene {video_id}")
def remove_vlm_tags_from_video(
video_id: int,
remove_tagme: bool = True,
remove_errored: bool = True
) -> None:
"""Remove all VLM tags from a video scene"""
vlm_tags = get_vlm_tags()
if remove_tagme:
vlm_tags.append(vlm_tagme_tag_id)
if remove_errored:
vlm_tags.append(vlm_errored_tag_id)
stash.update_scenes({
"ids": [video_id],
"tag_ids": {"ids": vlm_tags, "mode": "REMOVE"}
})
def get_tagme_scenes() -> List[Dict[str, Any]]:
"""Get scenes tagged with VLM_TagMe"""
return stash.find_scenes(
f={"tags": {"value": vlm_tagme_tag_id, "modifier": "INCLUDES"}},
fragment="id tags {id} files {path duration fingerprint(type: \"phash\")}"
)
def add_error_scene(scene_id: int) -> None:
"""Add error tag to a scene"""
stash.update_scenes({
"ids": [scene_id],
"tag_ids": {"ids": [vlm_errored_tag_id], "mode": "ADD"}
})
def remove_tagme_tag_from_scene(scene_id: int) -> None:
"""Remove VLM_TagMe tag from a scene"""
stash.update_scenes({
"ids": [scene_id],
"tag_ids": {"ids": [vlm_tagme_tag_id], "mode": "REMOVE"}
})
# ----------------- Marker Management Methods -----------------
def add_markers_to_video_from_dict(
video_id: int,
tag_timespans_dict: Dict[str, Dict[str, List[Any]]]
) -> None:
"""Add markers to video from timespan dictionary"""
for _, tag_timespan_dict in tag_timespans_dict.items():
for tag_name, time_frames in tag_timespan_dict.items():
tag_id = get_tag_id(tag_name, create=True)
if tag_id:
add_markers_to_video(video_id, tag_id, tag_name, time_frames)
def get_incorrect_markers() -> List[Dict[str, Any]]:
"""Get markers tagged with VLM_Incorrect"""
if end_seconds_support:
return stash.find_scene_markers(
{"tags": {"value": vlm_incorrect_tag_id, "modifier": "INCLUDES"}},
fragment="id scene {id files{path}} primary_tag {id, name} seconds end_seconds"
)
else:
return stash.find_scene_markers(
{"tags": {"value": vlm_incorrect_tag_id, "modifier": "INCLUDES"}},
fragment="id scene {id files{path}} primary_tag {id, name} seconds"
)
def add_markers_to_video(
video_id: int,
tag_id: int,
tag_name: str,
time_frames: List[Any]
) -> None:
"""Add markers to video for specific time frames"""
for time_frame in time_frames:
if end_seconds_support:
stash.create_scene_marker({
"scene_id": video_id,
"primary_tag_id": tag_id,
"tag_ids": [tag_id],
"seconds": time_frame.start,
"end_seconds": time_frame.end,
"title": tag_name
})
else:
stash.create_scene_marker({
"scene_id": video_id,
"primary_tag_id": tag_id,
"tag_ids": [tag_id],
"seconds": time_frame.start,
"title": tag_name
})
def get_scene_markers(video_id: int) -> List[Dict[str, Any]]:
"""Get all markers for a scene"""
return stash.get_scene_markers(video_id)
def write_scene_marker_to_file(
marker: Dict[str, Any],
scene_file: str,
output_folder: str
) -> None:
"""Write scene marker data to file for analysis"""
try:
marker_id = marker['id']
scene_id = marker['scene']['id']
tag_name = marker['primary_tag']['name']
# Create output filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"marker_{marker_id}_scene_{scene_id}_{tag_name}_{timestamp}.json"
output_path = os.path.join(output_folder, filename)
# Prepare marker data
marker_data = {
"marker_id": marker_id,
"scene_id": scene_id,
"tag_name": tag_name,
"seconds": marker.get("seconds"),
"end_seconds": marker.get("end_seconds"),
"scene_file": scene_file,
"timestamp": timestamp
}
# Write to file
with open(output_path, 'w') as f:
json.dump(marker_data, f, indent=2)
except Exception as e:
log.error(f"Failed to write marker data: {e}")
def delete_markers(markers: List[Dict[str, Any]]) -> None:
"""Delete markers from StashApp"""
for marker in markers:
try:
stash.destroy_scene_marker(marker['id'])
except Exception as e:
log.error(f"Failed to delete marker {marker['id']}: {e}")
def get_scene_markers_by_tag(
video_id: int,
error_if_no_end_seconds: bool = True
) -> List[Dict[str, Any]]:
"""Get scene markers by tag with end_seconds support check"""
if end_seconds_support:
return stash.get_scene_markers(video_id)
else:
if error_if_no_end_seconds:
log.error("End seconds not supported in this StashApp version")
raise RuntimeError("End seconds not supported")
return stash.get_scene_markers(video_id)
def remove_incorrect_tag_from_markers(markers: List[Dict[str, Any]]) -> None:
"""Remove VLM_Incorrect tag from markers"""
marker_ids = [marker['id'] for marker in markers]
for marker_id in marker_ids:
try:
stash.update_scene_marker({
"id": marker_id,
"tag_ids": {"ids": [vlm_incorrect_tag_id], "mode": "REMOVE"}
})
except Exception as e:
log.error(f"Failed to remove incorrect tag from marker {marker_id}: {e}")
def remove_vlm_markers_from_video(video_id: int) -> None:
"""Remove all VLM markers from a video"""
markers = get_scene_markers(video_id)
vlm_tag_ids = get_vlm_tags()
for marker in markers:
if marker['primary_tag']['id'] in vlm_tag_ids:
try:
stash.destroy_scene_marker(marker['id'])
except Exception as e:
log.error(f"Failed to delete VLM marker {marker['id']}: {e}")

View File

@ -0,0 +1,445 @@
"""
Configuration for A Haven VLM Connector
A StashApp plugin for Vision-Language Model based content tagging
"""
from typing import Dict, List, Optional
from dataclasses import dataclass
import os
import yaml
# ----------------- Core Settings -----------------
# VLM Engine Configuration
VLM_ENGINE_CONFIG = {
"active_ai_models": ["vlm_multiplexer_model"],
"pipelines": {
"video_pipeline_dynamic": {
"inputs": [
"video_path",
"return_timestamps",
"time_interval",
"threshold",
"return_confidence",
"vr_video",
"existing_video_data",
"skipped_categories",
],
"output": "results",
"short_name": "dynamic_video",
"version": 1.0,
"models": [
{
"name": "dynamic_video_ai",
"inputs": [
"video_path", "return_timestamps", "time_interval",
"threshold", "return_confidence", "vr_video",
"existing_video_data", "skipped_categories"
],
"outputs": "results",
},
],
}
},
"models": {
"binary_search_processor_dynamic": {
"type": "binary_search_processor",
"model_file_name": "binary_search_processor_dynamic"
},
"vlm_multiplexer_model": {
"type": "vlm_model",
"model_file_name": "vlm_multiplexer_model",
"model_category": "actiondetection",
"model_id": "zai-org/glm-4.6v-flash",
"model_identifier": 93848,
"model_version": "1.0",
"use_multiplexer": True,
"max_concurrent_requests": 13,
"instance_count": 10,
"max_batch_size": 4,
"multiplexer_endpoints": [
{
"base_url": "http://localhost:1234/v1",
"api_key": "",
"name": "lm-studio-primary",
"weight": 9,
"is_fallback": False,
"max_concurrent": 10
},
{
"base_url": "https://cloudagnostic.com:443/v1",
"api_key": "",
"name": "cloud-fallback",
"weight": 1,
"is_fallback": True,
"max_concurrent": 2
}
],
"tag_list": [
"Anal Fucking", "Ass Licking", "Ass Penetration", "Ball Licking/Sucking", "Blowjob", "Cum on Person",
"Cum Swapping", "Cumshot", "Deepthroat", "Double Penetration", "Fingering", "Fisting", "Footjob",
"Gangbang", "Gloryhole", "Grabbing Ass", "Grabbing Boobs", "Grabbing Hair/Head", "Handjob", "Kissing",
"Licking Penis", "Masturbation", "Pissing", "Pussy Licking (Clearly Visible)", "Pussy Licking",
"Pussy Rubbing", "Sucking Fingers", "Sucking Toy/Dildo", "Wet (Genitals)", "Titjob", "Tribbing/Scissoring",
"Undressing", "Vaginal Penetration", "Vaginal Fucking", "Vibrating"
]
},
"result_coalescer": {
"type": "python",
"model_file_name": "result_coalescer"
},
"result_finisher": {
"type": "python",
"model_file_name": "result_finisher"
},
"batch_awaiter": {
"type": "python",
"model_file_name": "batch_awaiter"
},
"video_result_postprocessor": {
"type": "python",
"model_file_name": "video_result_postprocessor"
},
},
"category_config": {
"actiondetection": {
"69": {
"RenamedTag": "69",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Anal Fucking": {
"RenamedTag": "Anal Fucking",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Ass Licking": {
"RenamedTag": "Ass Licking",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Ass Penetration": {
"RenamedTag": "Ass Penetration",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Ball Licking/Sucking": {
"RenamedTag": "Ball Licking/Sucking",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Blowjob": {
"RenamedTag": "Blowjob",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Cum on Person": {
"RenamedTag": "Cum on Person",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Cum Swapping": {
"RenamedTag": "Cum Swapping",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Cumshot": {
"RenamedTag": "Cumshot",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Deepthroat": {
"RenamedTag": "Deepthroat",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Double Penetration": {
"RenamedTag": "Double Penetration",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Fingering": {
"RenamedTag": "Fingering",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Fisting": {
"RenamedTag": "Fisting",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Footjob": {
"RenamedTag": "Footjob",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Gangbang": {
"RenamedTag": "Gangbang",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Gloryhole": {
"RenamedTag": "Gloryhole",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Grabbing Ass": {
"RenamedTag": "Grabbing Ass",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Grabbing Boobs": {
"RenamedTag": "Grabbing Boobs",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Grabbing Hair/Head": {
"RenamedTag": "Grabbing Hair/Head",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Handjob": {
"RenamedTag": "Handjob",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Kissing": {
"RenamedTag": "Kissing",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Licking Penis": {
"RenamedTag": "Licking Penis",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Masturbation": {
"RenamedTag": "Masturbation",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Pissing": {
"RenamedTag": "Pissing",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Pussy Licking (Clearly Visible)": {
"RenamedTag": "Pussy Licking (Clearly Visible)",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Pussy Licking": {
"RenamedTag": "Pussy Licking",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Pussy Rubbing": {
"RenamedTag": "Pussy Rubbing",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Sucking Fingers": {
"RenamedTag": "Sucking Fingers",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Sucking Toy/Dildo": {
"RenamedTag": "Sucking Toy/Dildo",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Wet (Genitals)": {
"RenamedTag": "Wet (Genitals)",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Titjob": {
"RenamedTag": "Titjob",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Tribbing/Scissoring": {
"RenamedTag": "Tribbing/Scissoring",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Undressing": {
"RenamedTag": "Undressing",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Vaginal Penetration": {
"RenamedTag": "Vaginal Penetration",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Vaginal Fucking": {
"RenamedTag": "Vaginal Fucking",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
},
"Vibrating": {
"RenamedTag": "Vibrating",
"MinMarkerDuration": "1s",
"MaxGap": "30s",
"RequiredDuration": "1s",
"TagThreshold": 0.5,
}
}
}
}
# ----------------- Processing Settings -----------------
# Video processing settings
VIDEO_FRAME_INTERVAL = 80 # Process every 80 seconds
VIDEO_THRESHOLD = 0.3
VIDEO_CONFIDENCE_RETURN = True
# Concurrency settings
CONCURRENT_TASK_LIMIT = 20 # Increased for better parallel video processing
SERVER_TIMEOUT = 3700
# ----------------- Tag Configuration -----------------
# Tag names for StashApp integration
VLM_BASE_TAG_NAME = "VLM"
VLM_TAGME_TAG_NAME = "VLM_TagMe"
VLM_UPDATEME_TAG_NAME = "VLM_UpdateMe"
VLM_TAGGED_TAG_NAME = "VLM_Tagged"
VLM_ERRORED_TAG_NAME = "VLM_Errored"
VLM_INCORRECT_TAG_NAME = "VLM_Incorrect"
# ----------------- File System Settings -----------------
# Directory paths
OUTPUT_DATA_DIR = "./output_data"
# File management
DELETE_INCORRECT_MARKERS = True
CREATE_MARKERS = True
# Path mutations for different environments
PATH_MUTATION = {}
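# Example (hypothetical values): PATH_MUTATION = {"E:": "/mnt/media"} would rewrite
# paths starting with "E:" so that the VLM machines can resolve them. Mutations are
# intended to be applied with haven_vlm_utility.apply_path_mutations, which replaces
# only the first matching prefix (see that module for details).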
# ----------------- Configuration Loading -----------------
@dataclass
class VLMConnectorConfig:
"""Configuration class for the VLM Connector"""
vlm_engine_config: Dict
video_frame_interval: float
video_threshold: float
video_confidence_return: bool
concurrent_task_limit: int
server_timeout: int
vlm_base_tag_name: str
vlm_tagme_tag_name: str
vlm_updateme_tag_name: str
vlm_tagged_tag_name: str
vlm_errored_tag_name: str
vlm_incorrect_tag_name: str
output_data_dir: str
delete_incorrect_markers: bool
create_markers: bool
path_mutation: Dict
def load_config_from_yaml(config_path: Optional[str] = None) -> VLMConnectorConfig:
"""Load configuration from YAML file or use defaults"""
if config_path and os.path.exists(config_path):
with open(config_path, 'r') as f:
yaml_config = yaml.safe_load(f)
return VLMConnectorConfig(**yaml_config)
# Return default configuration
return VLMConnectorConfig(
vlm_engine_config=VLM_ENGINE_CONFIG,
video_frame_interval=VIDEO_FRAME_INTERVAL,
video_threshold=VIDEO_THRESHOLD,
video_confidence_return=VIDEO_CONFIDENCE_RETURN,
concurrent_task_limit=CONCURRENT_TASK_LIMIT,
server_timeout=SERVER_TIMEOUT,
vlm_base_tag_name=VLM_BASE_TAG_NAME,
vlm_tagme_tag_name=VLM_TAGME_TAG_NAME,
vlm_updateme_tag_name=VLM_UPDATEME_TAG_NAME,
vlm_tagged_tag_name=VLM_TAGGED_TAG_NAME,
vlm_errored_tag_name=VLM_ERRORED_TAG_NAME,
vlm_incorrect_tag_name=VLM_INCORRECT_TAG_NAME,
output_data_dir=OUTPUT_DATA_DIR,
delete_incorrect_markers=DELETE_INCORRECT_MARKERS,
create_markers=CREATE_MARKERS,
path_mutation=PATH_MUTATION
)
# Global configuration instance
config = load_config_from_yaml()

View File

@ -0,0 +1,444 @@
"""
A Haven VLM Connector
A StashApp plugin for Vision-Language Model based content tagging
"""
import os
import sys
import json
import shutil
import traceback
import asyncio
import logging
import time
from typing import Dict, Any, List, Optional
from datetime import datetime
# Import and install sys.exit tracking FIRST (before any other imports that might call sys.exit)
try:
from exit_tracker import install_exit_tracker
import stashapi.log as log
install_exit_tracker(log)
except ImportError as e:
print(f"Warning: exit_tracker not available: {e}")
print("sys.exit tracking will not be available")
# ----------------- Setup and Dependencies -----------------
# Use PythonDepManager for dependency management
try:
from PythonDepManager import ensure_import
# Install and ensure all required dependencies with specific versions
ensure_import(
"stashapi:stashapp-tools==0.2.58",
"aiohttp==3.12.13",
"pydantic==2.11.7",
"vlm-engine==0.9.1",
"pyyaml==6.0.2"
)
# Import the dependencies after ensuring they're available
import stashapi.log as log
from stashapi.stashapp import StashInterface
import aiohttp
import pydantic
import yaml
except ImportError as e:
print(f"Failed to import PythonDepManager or required dependencies: {e}")
print("Please ensure PythonDepManager is installed and available.")
sys.exit(1)
except Exception as e:
print(f"Error during dependency management: {e}")
print(f"Stack trace: {traceback.format_exc()}")
sys.exit(1)
# Import local modules
try:
import haven_vlm_config as config
except ModuleNotFoundError:
log.error("Please provide a haven_vlm_config.py file with the required variables.")
raise Exception("Please provide a haven_vlm_config.py file with the required variables.")
import haven_media_handler as media_handler
import haven_vlm_engine as vlm_engine
from haven_vlm_engine import TimeFrame
log.debug("Python instance is running at: " + sys.executable)
# ----------------- Global Variables -----------------
semaphore: Optional[asyncio.Semaphore] = None
progress: float = 0.0
increment: float = 0.0
completed_tasks: int = 0
total_tasks: int = 0
video_progress: Dict[str, float] = {}
# ----------------- Main Execution -----------------
async def main() -> None:
"""Main entry point for the plugin"""
global semaphore
# Semaphore initialization logging for hypothesis A
log.debug(f"[DEBUG_HYPOTHESIS_A] Initializing semaphore with limit {config.config.concurrent_task_limit}")
semaphore = asyncio.Semaphore(config.config.concurrent_task_limit)
# Post-semaphore creation logging
log.debug(f"[DEBUG_HYPOTHESIS_A] Semaphore created successfully (limit: {config.config.concurrent_task_limit})")
json_input = read_json_input()
output = {}
await run(json_input, output)
out = json.dumps(output)
print(out + "\n")
def read_json_input() -> Dict[str, Any]:
"""Read JSON input from stdin"""
json_input = sys.stdin.read()
return json.loads(json_input)
async def run(json_input: Dict[str, Any], output: Dict[str, Any]) -> None:
"""Main execution logic"""
plugin_args = None
try:
log.debug(json_input["server_connection"])
os.chdir(json_input["server_connection"]["PluginDir"])
media_handler.initialize(json_input["server_connection"])
except Exception as e:
log.error(f"Failed to initialize media handler: {e}")
raise
try:
plugin_args = json_input['args']["mode"]
except KeyError:
pass
if plugin_args == "tag_videos":
await tag_videos()
output["output"] = "ok"
return
elif plugin_args == "find_marker_settings":
await find_marker_settings()
output["output"] = "ok"
return
elif plugin_args == "collect_incorrect_markers":
collect_incorrect_markers_and_images()
output["output"] = "ok"
return
output["output"] = "ok"
return
# ----------------- High Level Processing Functions -----------------
async def tag_videos() -> None:
"""Tag videos with VLM analysis using improved async orchestration"""
global completed_tasks, total_tasks
scenes = media_handler.get_tagme_scenes()
if not scenes:
log.info("No videos to tag. Have you tagged any scenes with the VLM_TagMe tag to get processed?")
return
total_tasks = len(scenes)
completed_tasks = 0
video_progress.clear()
for scene in scenes:
video_progress[scene.get('id', 'unknown')] = 0.0
log.progress(0.0)
log.info(f"🚀 Starting video processing for {total_tasks} scenes with semaphore limit of {config.config.concurrent_task_limit}")
# Create tasks with proper indexing for debugging
tasks = []
for i, scene in enumerate(scenes):
# Pre-task creation logging for hypothesis A (semaphore deadlock) and E (signal termination)
scene_id = scene.get('id')
log.debug(f"[DEBUG_HYPOTHESIS_A] Creating task {i+1}/{total_tasks} for scene {scene_id}, semaphore limit: {config.config.concurrent_task_limit}")
task = asyncio.create_task(__tag_video_with_timing(scene, i))
tasks.append(task)
# Use asyncio.as_completed to process results as they finish (proves concurrency)
completed_task_futures = asyncio.as_completed(tasks)
batch_start_time = asyncio.get_event_loop().time()
for completed_task in completed_task_futures:
try:
await completed_task
completed_tasks += 1
except Exception as e:
completed_tasks += 1
# Exception logging for hypothesis E (signal termination)
error_type = type(e).__name__
log.debug(f"[DEBUG_HYPOTHESIS_E] Task failed with exception: {error_type}: {str(e)} (Task {completed_tasks}/{total_tasks})")
log.error(f"❌ Task failed: {e}")
total_time = asyncio.get_event_loop().time() - batch_start_time
log.info(f"🎉 All {total_tasks} videos completed in {total_time:.2f}s (avg: {total_time/total_tasks:.2f}s/video)")
log.progress(1.0)
async def find_marker_settings() -> None:
"""Find optimal marker settings based on a single tagged video"""
scenes = media_handler.get_tagme_scenes()
if len(scenes) != 1:
log.error("Please tag exactly one scene with the VLM_TagMe tag to get processed.")
return
scene = scenes[0]
await __find_marker_settings(scene)
def collect_incorrect_markers_and_images() -> None:
"""Collect data from incorrectly tagged markers and images"""
incorrect_images = media_handler.get_incorrect_images()
image_paths, image_ids, temp_files = media_handler.get_image_paths_and_ids(incorrect_images)
incorrect_markers = media_handler.get_incorrect_markers()
if not incorrect_images and not incorrect_markers:
log.info("No incorrect images or markers to collect.")
return
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
try:
# Process images
image_folder = os.path.join(config.config.output_data_dir, "images")
os.makedirs(image_folder, exist_ok=True)
for image_path in image_paths:
try:
shutil.copy(image_path, image_folder)
except Exception as e:
log.error(f"Failed to copy image {image_path} to {image_folder}: {e}")
except Exception as e:
log.error(f"Failed to process images: {e}")
raise e
finally:
# Clean up temp files
for temp_file in temp_files:
try:
if os.path.isdir(temp_file):
shutil.rmtree(temp_file)
else:
os.remove(temp_file)
except Exception as e:
log.debug(f"Failed to remove temp file {temp_file}: {e}")
# Process markers
scene_folder = os.path.join(config.config.output_data_dir, "scenes")
os.makedirs(scene_folder, exist_ok=True)
tag_folders = {}
for marker in incorrect_markers:
scene_path = marker['scene']['files'][0]['path']
if not scene_path:
log.error(f"Marker {marker['id']} has no scene path")
continue
try:
tag_name = marker['primary_tag']['name']
if tag_name not in tag_folders:
tag_folders[tag_name] = os.path.join(scene_folder, tag_name)
os.makedirs(tag_folders[tag_name], exist_ok=True)
media_handler.write_scene_marker_to_file(marker, scene_path, tag_folders[tag_name])
except Exception as e:
log.error(f"Failed to collect scene: {e}")
# Remove incorrect tags from images
image_ids = [image['id'] for image in incorrect_images]
media_handler.remove_incorrect_tag_from_images(image_ids)
# ----------------- Low Level Processing Functions -----------------
async def __tag_video_with_timing(scene: Dict[str, Any], scene_index: int) -> None:
"""Tag a single video scene with timing diagnostics"""
start_time = asyncio.get_event_loop().time()
scene_id = scene.get('id', 'unknown')
log.info(f"🎬 Starting video {scene_index + 1}: Scene {scene_id}")
try:
await __tag_video(scene)
end_time = asyncio.get_event_loop().time()
duration = end_time - start_time
log.info(f"✅ Completed video {scene_index + 1} (Scene {scene_id}) in {duration:.2f}s")
except Exception as e:
end_time = asyncio.get_event_loop().time()
duration = end_time - start_time
log.error(f"❌ Failed video {scene_index + 1} (Scene {scene_id}) after {duration:.2f}s: {e}")
raise
async def __tag_video(scene: Dict[str, Any]) -> None:
"""Tag a single video scene with semaphore timing instrumentation"""
scene_id = scene.get('id')
# Pre-semaphore acquisition logging for hypothesis A (semaphore deadlock)
task_start_time = asyncio.get_event_loop().time()
acquisition_start_time = task_start_time
log.debug(f"[DEBUG_HYPOTHESIS_A] Task starting for scene {scene_id} at {task_start_time:.3f}s")
async with semaphore:
try:
# Semaphore acquisition successful logging
acquisition_end_time = asyncio.get_event_loop().time()
acquisition_time = acquisition_end_time - acquisition_start_time
log.debug(f"[DEBUG_HYPOTHESIS_A] Semaphore acquired for scene {scene_id} after {acquisition_time:.3f}s")
if scene_id is None:
log.error("Scene missing 'id' field")
return
files = scene.get('files', [])
if not files:
log.error(f"Scene {scene_id} has no files")
return
scene_file = files[0].get('path')
if scene_file is None:
log.error(f"Scene {scene_id} file has no path")
return
# Check if scene is VR
is_vr = media_handler.is_vr_scene(scene.get('tags', []))
def progress_cb(p: int) -> None:
global video_progress, total_tasks
video_progress[scene_id] = p / 100.0
total_prog = sum(video_progress.values()) / total_tasks
log.progress(total_prog)
# Process video through VLM Engine with HTTP timing for hypothesis B
processing_start_time = asyncio.get_event_loop().time()
# HTTP request lifecycle tracking start
log.debug(f"[DEBUG_HYPOTHESIS_B] Starting VLM processing for scene {scene_id}: {scene_file}")
video_result = await vlm_engine.process_video_async(
scene_file,
vr_video=is_vr,
frame_interval=config.config.video_frame_interval,
threshold=config.config.video_threshold,
return_confidence=config.config.video_confidence_return,
progress_callback=progress_cb
)
# Extract detected tags
detected_tags = set()
for category_tags in video_result.video_tags.values():
detected_tags.update(category_tags)
# Post-VLM processing logging
processing_end_time = asyncio.get_event_loop().time()
processing_duration = processing_end_time - processing_start_time
log.debug(f"[DEBUG_HYPOTHESIS_B] VLM processing completed for scene {scene_id} in {processing_duration:.2f}s ({len(detected_tags)} detected tags)")
if detected_tags:
# Clear all existing tags and markers before adding new ones
media_handler.clear_all_tags_from_video(scene)
media_handler.clear_all_markers_from_video(scene_id)
# Add tags to scene
tag_ids = media_handler.get_tag_ids(list(detected_tags), create=True)
media_handler.add_tags_to_video(scene_id, tag_ids)
log.info(f"Added tags {list(detected_tags)} to scene {scene_id}")
# Add markers if enabled
if config.config.create_markers:
media_handler.add_markers_to_video_from_dict(scene_id, video_result.tag_timespans)
log.info(f"Added markers to scene {scene_id}")
# Remove VLM_TagMe tag from processed scene
media_handler.remove_tagme_tag_from_scene(scene_id)
# Task completion logging
task_end_time = asyncio.get_event_loop().time()
total_task_time = task_end_time - task_start_time
log.debug(f"[DEBUG_HYPOTHESIS_A] Task completed for scene {scene_id} in {total_task_time:.2f}s")
except Exception as e:
# Exception handling with detailed logging for hypothesis E
exception_time = asyncio.get_event_loop().time()
error_type = type(e).__name__
log.debug(f"[DEBUG_HYPOTHESIS_E] Task exception for scene {scene_id}: {error_type}: {str(e)} at {exception_time:.3f}s")
scene_id = scene.get('id', 'unknown')
log.error(f"Error processing video scene {scene_id}: {e}")
# Add error tag to failed scene if we have a valid ID
if scene_id != 'unknown':
media_handler.add_error_scene(scene_id)
async def __find_marker_settings(scene: Dict[str, Any]) -> None:
"""Find optimal marker settings for a scene"""
try:
scene_id = scene.get('id')
if scene_id is None:
log.error("Scene missing 'id' field")
return
files = scene.get('files', [])
if not files:
log.error(f"Scene {scene_id} has no files")
return
scene_file = files[0].get('path')
if scene_file is None:
log.error(f"Scene {scene_id} file has no path")
return
# Get existing markers for the scene
existing_markers = media_handler.get_scene_markers(scene_id)
# Convert markers to desired timespan format
desired_timespan_data = {}
for marker in existing_markers:
tag_name = marker['primary_tag']['name']
desired_timespan_data[tag_name] = TimeFrame(
start=marker['seconds'],
end=marker.get('end_seconds', marker['seconds'] + 1),
total_confidence=1.0
)
# Find optimal settings
optimal_settings = await vlm_engine.find_optimal_marker_settings_async(
existing_json={}, # No existing JSON data
desired_timespan_data=desired_timespan_data
)
# Output results
log.info(f"Optimal marker settings found for scene {scene_id}:")
log.info(json.dumps(optimal_settings, indent=2))
except Exception as e:
scene_id = scene.get('id', 'unknown')
log.error(f"Error finding marker settings for scene {scene_id}: {e}")
# ----------------- Cleanup -----------------
async def cleanup() -> None:
"""Cleanup resources"""
if vlm_engine.vlm_engine:
await vlm_engine.vlm_engine.shutdown()
# Run main function if script is executed directly
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
log.info("Plugin interrupted by user")
sys.exit(0)
except SystemExit as e:
# Re-raise system exit with the exit code
log.debug(f"[DEBUG_EXIT_CODE] Caught SystemExit with code: {e.code}")
raise
except Exception as e:
log.error(f"Plugin failed: {e}")
sys.exit(1)
finally:
asyncio.run(cleanup())

View File

@ -0,0 +1,299 @@
"""
Haven VLM Engine Integration Module
Provides integration with the Haven VLM Engine for video and image processing
"""
import asyncio
import logging
from typing import Any, Dict, List, Optional, Set, Union, Callable
from dataclasses import dataclass
from datetime import datetime
import json
# Dependencies (vlm-engine and friends) are installed via PythonDepManager by
# haven_vlm_connector before this module is imported
from vlm_engine import VLMEngine
from vlm_engine.config_models import (
EngineConfig,
PipelineConfig,
ModelConfig,
PipelineModelConfig
)
import haven_vlm_config as config
# Configure logging
logging.basicConfig(level=logging.CRITICAL)
logger = logging.getLogger(__name__)
@dataclass
class TimeFrame:
"""Represents a time frame with start and end times"""
start: float
end: float
total_confidence: Optional[float] = None
def to_json(self) -> str:
"""Convert to JSON string"""
return json.dumps({
"start": self.start,
"end": self.end,
"total_confidence": self.total_confidence
})
def __str__(self) -> str:
return f"TimeFrame(start={self.start}, end={self.end}, confidence={self.total_confidence})"
@dataclass
class VideoTagInfo:
"""Represents video tagging information"""
video_duration: float
video_tags: Dict[str, Set[str]]
tag_totals: Dict[str, Dict[str, float]]
tag_timespans: Dict[str, Dict[str, List[TimeFrame]]]
@classmethod
def from_json(cls, json_data: Dict[str, Any]) -> 'VideoTagInfo':
"""Create VideoTagInfo from JSON data"""
logger.debug(f"Creating VideoTagInfo from JSON: {json_data}")
# Convert tag_timespans to TimeFrame objects
tag_timespans = {}
for category, tags in json_data.get("tag_timespans", {}).items():
tag_timespans[category] = {}
for tag_name, timeframes in tags.items():
tag_timespans[category][tag_name] = [
TimeFrame(
start=tf["start"],
end=tf["end"],
total_confidence=tf.get("total_confidence")
) for tf in timeframes
]
return cls(
video_duration=json_data.get("video_duration", 0.0),
video_tags=json_data.get("video_tags", {}),
tag_totals=json_data.get("tag_totals", {}),
tag_timespans=tag_timespans
)
def __str__(self) -> str:
return f"VideoTagInfo(duration={self.video_duration}, tags={len(self.video_tags)}, timespans={len(self.tag_timespans)})"
class HavenVLMEngine:
"""Main VLM Engine integration class"""
def __init__(self):
self.engine: Optional[VLMEngine] = None
self.engine_config: Optional[EngineConfig] = None
self._initialized = False
async def initialize(self) -> None:
"""Initialize the VLM Engine with configuration"""
if self._initialized:
return
try:
logger.info("Initializing Haven VLM Engine...")
# Convert config dict to EngineConfig objects
self.engine_config = self._create_engine_config()
# Create and initialize the engine
self.engine = VLMEngine(config=self.engine_config)
await self.engine.initialize()
self._initialized = True
logger.info("Haven VLM Engine initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize VLM Engine: {e}")
raise
def _create_engine_config(self) -> EngineConfig:
"""Create EngineConfig from the configuration"""
vlm_config = config.config.vlm_engine_config
# Create pipeline configs
pipelines = {}
for pipeline_name, pipeline_data in vlm_config["pipelines"].items():
models = [
PipelineModelConfig(
name=model["name"],
inputs=model["inputs"],
outputs=model["outputs"]
) for model in pipeline_data["models"]
]
pipelines[pipeline_name] = PipelineConfig(
inputs=pipeline_data["inputs"],
output=pipeline_data["output"],
short_name=pipeline_data["short_name"],
version=pipeline_data["version"],
models=models
)
# Create model configs with new architectural changes
models = {}
for model_name, model_data in vlm_config["models"].items():
if model_data["type"] == "vlm_model":
# Process multiplexer_endpoints and validate max_concurrent
multiplexer_endpoints = []
for endpoint in model_data.get("multiplexer_endpoints", []):
# Validate that max_concurrent is present
if "max_concurrent" not in endpoint:
raise ValueError(f"Endpoint '{endpoint.get('name', 'unnamed')}' is missing required 'max_concurrent' parameter")
multiplexer_endpoints.append({
"base_url": endpoint["base_url"],
"api_key": endpoint.get("api_key", ""),
"name": endpoint["name"],
"weight": endpoint.get("weight", 5),
"is_fallback": endpoint.get("is_fallback", False),
"max_concurrent": endpoint["max_concurrent"]
})
models[model_name] = ModelConfig(
type=model_data["type"],
model_file_name=model_data["model_file_name"],
model_category=model_data["model_category"],
model_id=model_data["model_id"],
model_identifier=model_data["model_identifier"],
model_version=model_data["model_version"],
use_multiplexer=model_data.get("use_multiplexer", False),
max_concurrent_requests=model_data.get("max_concurrent_requests", 10),
instance_count=model_data.get("instance_count",1),
max_batch_size=model_data.get("max_batch_size",1),
multiplexer_endpoints=multiplexer_endpoints,
tag_list=model_data.get("tag_list", [])
)
else:
models[model_name] = ModelConfig(
type=model_data["type"],
model_file_name=model_data["model_file_name"]
)
return EngineConfig(
active_ai_models=vlm_config["active_ai_models"],
pipelines=pipelines,
models=models,
category_config=vlm_config["category_config"]
)
async def process_video(
self,
video_path: str,
vr_video: bool = False,
frame_interval: Optional[float] = None,
threshold: Optional[float] = None,
return_confidence: Optional[bool] = None,
existing_json: Optional[Dict[str, Any]] = None,
progress_callback: Optional[Callable[[int], None]] = None
) -> VideoTagInfo:
"""Process a video using the VLM Engine"""
if not self._initialized:
await self.initialize()
try:
logger.info(f"Processing video: {video_path}")
# Use config defaults if not provided
frame_interval = frame_interval or config.config.video_frame_interval
threshold = threshold or config.config.video_threshold
return_confidence = return_confidence if return_confidence is not None else config.config.video_confidence_return
# Process video through the engine
results = await self.engine.process_video(
video_path,
frame_interval=frame_interval,
progress_callback=progress_callback
)
logger.info(f"Video processing completed for: {video_path}")
logger.debug(f"Raw results structure: {type(results)}")
# Extract video_tag_info from the nested structure
if isinstance(results, dict) and 'video_tag_info' in results:
video_tag_data = results['video_tag_info']
logger.debug(f"Using video_tag_info from results: {video_tag_data.keys()}")
else:
# Fallback: assume results is already in the correct format
video_tag_data = results
logger.debug(f"Using results directly: {video_tag_data.keys() if isinstance(video_tag_data, dict) else type(video_tag_data)}")
return VideoTagInfo.from_json(video_tag_data)
except Exception as e:
logger.error(f"Error processing video {video_path}: {e}")
raise
async def find_optimal_marker_settings(
self,
existing_json: Dict[str, Any],
desired_timespan_data: Dict[str, TimeFrame]
) -> Dict[str, Any]:
"""Find optimal marker settings based on existing data"""
if not self._initialized:
await self.initialize()
try:
logger.info("Finding optimal marker settings...")
# Convert TimeFrame objects to dict format
desired_data = {}
for key, timeframe in desired_timespan_data.items():
desired_data[key] = {
"start": timeframe.start,
"end": timeframe.end,
"total_confidence": timeframe.total_confidence
}
# Call the engine's optimization method
results = await self.engine.optimize_timeframe_settings(
existing_json_data=existing_json,
desired_timespan_data=desired_data
)
logger.info("Optimal marker settings found")
return results
except Exception as e:
logger.error(f"Error finding optimal marker settings: {e}")
raise
async def shutdown(self) -> None:
"""Shutdown the VLM Engine"""
if self.engine and self._initialized:
try:
# VLMEngine doesn't have a shutdown method, just perform basic cleanup
logger.info("VLM Engine cleanup completed")
self._initialized = False
except Exception as e:
logger.error(f"Error during VLM Engine cleanup: {e}")
self._initialized = False
# Global VLM Engine instance
vlm_engine = HavenVLMEngine()
# Convenience functions for backward compatibility
async def process_video_async(
video_path: str,
vr_video: bool = False,
frame_interval: Optional[float] = None,
threshold: Optional[float] = None,
return_confidence: Optional[bool] = None,
existing_json: Optional[Dict[str, Any]] = None,
progress_callback: Optional[Callable[[int], None]] = None
) -> VideoTagInfo:
"""Process video asynchronously"""
return await vlm_engine.process_video(
video_path, vr_video, frame_interval, threshold, return_confidence, existing_json,
progress_callback=progress_callback
)
async def find_optimal_marker_settings_async(
existing_json: Dict[str, Any],
desired_timespan_data: Dict[str, TimeFrame]
) -> Dict[str, Any]:
"""Find optimal marker settings asynchronously"""
return await vlm_engine.find_optimal_marker_settings(existing_json, desired_timespan_data)

View File

@ -0,0 +1,316 @@
"""
Haven VLM Utility Module
Utility functions for the A Haven VLM Connector plugin
"""
import os
import json
import logging
from typing import Dict, Any, List, Optional, Union
from pathlib import Path
import yaml
logger = logging.getLogger(__name__)
def apply_path_mutations(path: str, mutations: Dict[str, str]) -> str:
"""
Apply path mutations for different environments
Args:
path: Original file path
mutations: Dictionary of path mutations (e.g., {"E:": "F:", "G:": "D:"})
Returns:
Mutated path string
"""
if not mutations:
return path
mutated_path = path
for old_path, new_path in mutations.items():
if mutated_path.startswith(old_path):
mutated_path = mutated_path.replace(old_path, new_path, 1)
break
return mutated_path
def ensure_directory_exists(directory_path: str) -> None:
"""
Ensure a directory exists, creating it if necessary
Args:
directory_path: Path to the directory
"""
Path(directory_path).mkdir(parents=True, exist_ok=True)
def safe_file_operation(operation_func, *args, **kwargs) -> Optional[Any]:
"""
Safely execute a file operation with error handling
Args:
operation_func: Function to execute
*args: Arguments for the function
**kwargs: Keyword arguments for the function
Returns:
Result of the operation or None if failed
"""
try:
return operation_func(*args, **kwargs)
except (OSError, IOError) as e:
logger.error(f"File operation failed: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error in file operation: {e}")
return None
def load_yaml_config(config_path: str) -> Optional[Dict[str, Any]]:
"""
Load configuration from YAML file
Args:
config_path: Path to the YAML configuration file
Returns:
Configuration dictionary or None if failed
"""
try:
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
logger.info(f"Configuration loaded from {config_path}")
return config
except FileNotFoundError:
logger.warning(f"Configuration file not found: {config_path}")
return None
except yaml.YAMLError as e:
logger.error(f"Error parsing YAML configuration: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error loading configuration: {e}")
return None
def save_yaml_config(config: Dict[str, Any], config_path: str) -> bool:
"""
Save configuration to YAML file
Args:
config: Configuration dictionary
config_path: Path to save the configuration file
Returns:
True if successful, False otherwise
"""
try:
ensure_directory_exists(os.path.dirname(config_path))
with open(config_path, 'w', encoding='utf-8') as f:
yaml.dump(config, f, default_flow_style=False, indent=2)
logger.info(f"Configuration saved to {config_path}")
return True
except Exception as e:
logger.error(f"Error saving configuration: {e}")
return False
def validate_file_path(file_path: str) -> bool:
"""
Validate if a file path exists and is accessible
Args:
file_path: Path to validate
Returns:
True if file exists and is accessible, False otherwise
"""
try:
return os.path.isfile(file_path) and os.access(file_path, os.R_OK)
except Exception:
return False
def get_file_extension(file_path: str) -> str:
"""
Get the file extension from a file path
Args:
file_path: Path to the file
Returns:
File extension (including the dot)
"""
return Path(file_path).suffix.lower()
def is_video_file(file_path: str) -> bool:
"""
Check if a file is a video file based on its extension
Args:
file_path: Path to the file
Returns:
True if it's a video file, False otherwise
"""
video_extensions = {'.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.m4v'}
return get_file_extension(file_path) in video_extensions
def is_image_file(file_path: str) -> bool:
"""
Check if a file is an image file based on its extension
Args:
file_path: Path to the file
Returns:
True if it's an image file, False otherwise
"""
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
return get_file_extension(file_path) in image_extensions
def format_duration(seconds: float) -> str:
"""
Format duration in seconds to human-readable string
Args:
seconds: Duration in seconds
Returns:
Formatted duration string (e.g., "1h 23m 45s")
"""
if seconds < 60:
return f"{seconds:.1f}s"
elif seconds < 3600:
minutes = int(seconds // 60)
remaining_seconds = seconds % 60
return f"{minutes}m {remaining_seconds:.1f}s"
else:
hours = int(seconds // 3600)
remaining_minutes = int((seconds % 3600) // 60)
remaining_seconds = seconds % 60
return f"{hours}h {remaining_minutes}m {remaining_seconds:.1f}s"
def format_file_size(bytes_size: int) -> str:
"""
Format file size in bytes to human-readable string
Args:
bytes_size: Size in bytes
Returns:
Formatted size string (e.g., "1.5 MB")
"""
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
if bytes_size < 1024.0:
return f"{bytes_size:.1f} {unit}"
bytes_size /= 1024.0
return f"{bytes_size:.1f} PB"
def sanitize_filename(filename: str) -> str:
"""
Sanitize a filename by removing or replacing invalid characters
Args:
filename: Original filename
Returns:
Sanitized filename
"""
# Replace invalid characters with underscores
invalid_chars = '<>:"/\\|?*'
for char in invalid_chars:
filename = filename.replace(char, '_')
# Remove leading/trailing spaces and dots
filename = filename.strip(' .')
# Ensure filename is not empty
if not filename:
filename = "unnamed"
return filename
def create_backup_file(file_path: str, backup_suffix: str = ".backup") -> Optional[str]:
"""
Create a backup of a file
Args:
file_path: Path to the file to backup
backup_suffix: Suffix for the backup file
Returns:
Path to the backup file or None if failed
"""
try:
if not os.path.exists(file_path):
logger.warning(f"File does not exist: {file_path}")
return None
backup_path = file_path + backup_suffix
import shutil
shutil.copy2(file_path, backup_path)
logger.info(f"Backup created: {backup_path}")
return backup_path
except Exception as e:
logger.error(f"Failed to create backup: {e}")
return None
def merge_dictionaries(dict1: Dict[str, Any], dict2: Dict[str, Any], overwrite: bool = True) -> Dict[str, Any]:
"""
Merge two dictionaries, recursing into nested dictionaries, with option to overwrite existing keys
Args:
dict1: First dictionary
dict2: Second dictionary
overwrite: Whether to overwrite existing non-dict keys in dict1
Returns:
Merged dictionary
"""
result = dict1.copy()
for key, value in dict2.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
# Merge nested dictionaries recursively instead of replacing them wholesale
result[key] = merge_dictionaries(result[key], value, overwrite)
elif key not in result or overwrite:
result[key] = value
return result
def chunk_list(lst: List[Any], chunk_size: int) -> List[List[Any]]:
"""
Split a list into chunks of specified size
Args:
lst: List to chunk
chunk_size: Size of each chunk
Returns:
List of chunks
"""
return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
def retry_operation(operation_func, max_retries: int = 3, delay: float = 1.0, *args, **kwargs) -> Optional[Any]:
"""
Retry an operation with exponential backoff
Args:
operation_func: Function to retry
max_retries: Maximum number of retries
delay: Initial delay between retries
*args: Arguments for the function
**kwargs: Keyword arguments for the function
Returns:
Result of the operation or None if all retries failed
"""
import time
for attempt in range(max_retries + 1):
try:
return operation_func(*args, **kwargs)
except Exception as e:
if attempt == max_retries:
logger.error(f"Operation failed after {max_retries} retries: {e}")
return None
wait_time = delay * (2 ** attempt)
logger.warning(f"Operation failed (attempt {attempt + 1}/{max_retries + 1}), retrying in {wait_time}s: {e}")
time.sleep(wait_time)
return None
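
For orientation, here is a minimal, hypothetical usage sketch of the helpers above; it assumes the file is importable as `haven_vlm_utility` and that a `haven_vlm_config.yml` file may or may not exist:

```python
# Hypothetical usage of the utility helpers; module and file names are assumptions.
import haven_vlm_utility as util

# Load a YAML config, falling back to defaults when the file is missing or malformed
config = util.load_yaml_config("haven_vlm_config.yml") or {"video_frame_interval": 2.0}

# Merge user overrides into the defaults (nested dictionaries are merged recursively)
merged = util.merge_dictionaries(config, {"video_threshold": 0.3})

# Persist the merged config; the target directory is created if needed
util.save_yaml_config(merged, "./output_data/haven_vlm_config.yml")

# Retry a flaky operation with exponential backoff (waits 1s, 2s, 4s between attempts)
result = util.retry_operation(util.load_yaml_config, max_retries=3, delay=1.0,
                              config_path="haven_vlm_config.yml")

print(util.format_duration(5025.0))      # "1h 23m 45.0s"
print(util.format_file_size(1572864))    # "1.5 MB"
```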

View File

@ -0,0 +1,8 @@
# Core dependencies managed by PythonDepManager
# These are automatically handled by the plugin's dependency management system
# PythonDepManager will ensure the correct versions are installed
# Development and testing dependencies
coverage>=7.0.0
pytest>=7.0.0
pytest-cov>=4.0.0

View File

@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""
Test runner for A Haven VLM Connector
Runs all unit tests with coverage reporting
"""
import sys
import os
import subprocess
import unittest
from pathlib import Path
def install_test_dependencies():
"""Install test dependencies if not already installed"""
test_deps = [
'coverage',
'pytest',
'pytest-cov'
]
for dep in test_deps:
try:
__import__(dep.replace('-', '_'))
except ImportError:
print(f"Installing {dep}...")
subprocess.check_call([sys.executable, "-m", "pip", "install", dep])
def run_tests_with_coverage():
"""Run tests with coverage reporting"""
# Install test dependencies
install_test_dependencies()
# Get the directory containing this script
script_dir = Path(__file__).parent
# Start coverage measurement before test discovery so module import-time code is counted
import coverage
cov = coverage.Coverage(
source=['haven_vlm_config', 'haven_vlm_engine', 'haven_media_handler',
'haven_vlm_connector', 'haven_vlm_utility'],
omit=['*/test_*.py', '*/__pycache__/*', '*/venv/*', '*/env/*']
)
cov.start()
# Discover the tests (importing the test modules also imports the source modules)
loader = unittest.TestLoader()
suite = loader.discover(str(script_dir), pattern='test_*.py')
# Run the tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
# Stop coverage measurement
cov.stop()
cov.save()
# Generate coverage report
print("\n" + "="*60)
print("COVERAGE REPORT")
print("="*60)
cov.report()
# Generate HTML coverage report
cov.html_report(directory='htmlcov')
print(f"\nHTML coverage report generated in: {script_dir}/htmlcov/index.html")
return result.wasSuccessful()
def run_specific_test(test_file):
"""Run a specific test file"""
if not test_file.endswith('.py'):
test_file += '.py'
test_path = Path(__file__).parent / test_file
if not test_path.exists():
print(f"Test file not found: {test_path}")
return False
# Run the specific test
loader = unittest.TestLoader()
suite = loader.loadTestsFromName(test_file[:-3])
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
return result.wasSuccessful()
def main():
"""Main entry point"""
if len(sys.argv) > 1:
# Run specific test file
test_file = sys.argv[1]
success = run_specific_test(test_file)
else:
# Run all tests with coverage
success = run_tests_with_coverage()
if success:
print("\n✅ All tests passed!")
sys.exit(0)
else:
print("\n❌ Some tests failed!")
sys.exit(1)
if __name__ == '__main__':
main()
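
A small, hypothetical sketch of how the runner above might be invoked from another script; the file and module names are assumptions:

```python
# Hypothetical programmatic invocation of run_tests.py; paths are assumptions.
import subprocess
import sys

# Run the full suite with coverage reporting
subprocess.check_call([sys.executable, "run_tests.py"])

# Run a single test module (the .py suffix is optional)
subprocess.check_call([sys.executable, "run_tests.py", "test_haven_vlm_config"])
```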

View File

@ -0,0 +1,98 @@
"""
Unit tests for dependency management functionality using PythonDepManager
"""
import unittest
import sys
from unittest.mock import patch, MagicMock, mock_open
import tempfile
import os
class TestPythonDepManagerIntegration(unittest.TestCase):
"""Test cases for PythonDepManager integration"""
def setUp(self):
"""Set up test fixtures"""
# Mock PythonDepManager module
self.mock_python_dep_manager = MagicMock()
sys.modules['PythonDepManager'] = self.mock_python_dep_manager
def tearDown(self):
"""Clean up after tests"""
if 'PythonDepManager' in sys.modules:
del sys.modules['PythonDepManager']
@patch('builtins.print')
def test_dependency_import_failure(self, mock_print):
"""Test dependency import failure handling"""
# Mock ensure_import to raise ImportError
self.mock_python_dep_manager.ensure_import = MagicMock(side_effect=ImportError("Package not found"))
# Test that the error is handled gracefully
with self.assertRaises(SystemExit):
import haven_vlm_connector
def test_error_messages(self):
"""Test that appropriate error messages are displayed"""
# Mock ensure_import to raise ImportError
self.mock_python_dep_manager.ensure_import = MagicMock(side_effect=ImportError("Package not found"))
with patch('builtins.print') as mock_print:
with self.assertRaises(SystemExit):
import haven_vlm_connector
# Check that appropriate error messages were printed
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Failed to import PythonDepManager" in msg for msg in print_calls if isinstance(msg, str)))
self.assertTrue(any("Please ensure PythonDepManager is installed" in msg for msg in print_calls if isinstance(msg, str)))
class TestDependencyManagementEdgeCases(unittest.TestCase):
"""Test edge cases in dependency management"""
def setUp(self):
"""Set up test fixtures"""
self.mock_python_dep_manager = MagicMock()
sys.modules['PythonDepManager'] = self.mock_python_dep_manager
def tearDown(self):
"""Clean up after tests"""
if 'PythonDepManager' in sys.modules:
del sys.modules['PythonDepManager']
def test_missing_python_dep_manager(self):
"""Test behavior when PythonDepManager is not available"""
# Remove PythonDepManager from sys.modules
if 'PythonDepManager' in sys.modules:
del sys.modules['PythonDepManager']
with patch('builtins.print') as mock_print:
with self.assertRaises(SystemExit):
import haven_vlm_connector
# Check that appropriate error message was printed
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Failed to import PythonDepManager" in msg for msg in print_calls if isinstance(msg, str)))
def test_partial_dependency_failure(self):
"""Test behavior when some dependencies fail to import"""
# Mock ensure_import to succeed but some imports to fail
self.mock_python_dep_manager.ensure_import = MagicMock()
# Mock some successful imports but not all
mock_stashapi = MagicMock()
sys.modules['stashapi.log'] = mock_stashapi
sys.modules['stashapi.stashapp'] = mock_stashapi
# Don't mock aiohttp, so it should fail
with patch('builtins.print') as mock_print:
with self.assertRaises(SystemExit):
import haven_vlm_connector
# Check that appropriate error message was printed
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Error during dependency management" in msg for msg in print_calls if isinstance(msg, str)))
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,387 @@
"""
Unit tests for Haven Media Handler Module
Tests StashApp media operations and tag management
"""
import unittest
from unittest.mock import Mock, patch, MagicMock
from typing import List, Dict, Any, Optional
import sys
import os
# Add the current directory to the path to import the module
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Mock the dependencies before importing the module
sys.modules['PythonDepManager'] = Mock()
sys.modules['stashapi.stashapp'] = Mock()
sys.modules['stashapi.log'] = Mock()
sys.modules['haven_vlm_config'] = Mock()
# Import the module after mocking dependencies
import haven_media_handler
class TestHavenMediaHandler(unittest.TestCase):
"""Test cases for Haven Media Handler"""
def setUp(self) -> None:
"""Set up test fixtures"""
# Mock the stash interface
self.mock_stash = Mock()
self.mock_stash.find_tag.return_value = {"id": 1}
self.mock_stash.get_configuration.return_value = {"ui": {"vrTag": "VR"}}
self.mock_stash.stash_version.return_value = Mock()
# Mock the log module
self.mock_log = Mock()
# Patch the global variables
haven_media_handler.stash = self.mock_stash
haven_media_handler.log = self.mock_log
# Mock tag IDs
haven_media_handler.vlm_errored_tag_id = 1
haven_media_handler.vlm_tagme_tag_id = 2
haven_media_handler.vlm_base_tag_id = 3
haven_media_handler.vlm_tagged_tag_id = 4
haven_media_handler.vr_tag_id = 5
haven_media_handler.vlm_incorrect_tag_id = 6
def tearDown(self) -> None:
"""Clean up after tests"""
# Clear any cached data
haven_media_handler.tag_id_cache.clear()
haven_media_handler.vlm_tag_ids_cache.clear()
def test_clear_all_tags_from_video_with_tags(self) -> None:
"""Test clearing all tags from a video that has tags"""
# Mock scene with tags
mock_scene = {
"id": 123,
"tags": [
{"id": 10, "name": "Tag1"},
{"id": 20, "name": "Tag2"},
{"id": 30, "name": "Tag3"}
]
}
# Call the function
haven_media_handler.clear_all_tags_from_video(mock_scene)
# Verify tags were removed
self.mock_stash.update_scenes.assert_called_once_with({
"ids": [123],
"tag_ids": {"ids": [10, 20, 30], "mode": "REMOVE"}
})
# Verify log message
self.mock_log.info.assert_called_once_with("Cleared 3 tags from scene 123")
def test_clear_all_tags_from_video_no_tags(self) -> None:
"""Test clearing all tags from a video that has no tags"""
# Mock scene without tags
mock_scene = {"id": 123, "tags": []}
# Call the function
haven_media_handler.clear_all_tags_from_video(mock_scene)
# Verify no update was called since there are no tags
self.mock_stash.update_scenes.assert_not_called()
# Verify no log message
self.mock_log.info.assert_not_called()
def test_clear_all_tags_from_video_scene_without_tags_key(self) -> None:
"""Test clearing all tags from a scene that doesn't have a tags key"""
# Mock scene without tags key
mock_scene = {"id": 123}
# Call the function
haven_media_handler.clear_all_tags_from_video(mock_scene)
# Verify no update was called
self.mock_stash.update_scenes.assert_not_called()
@patch('haven_media_handler.get_scene_markers')
@patch('haven_media_handler.delete_markers')
def test_clear_all_markers_from_video_with_markers(self, mock_delete_markers: Mock, mock_get_markers: Mock) -> None:
"""Test clearing all markers from a video that has markers"""
# Mock markers
mock_markers = [
{"id": 1, "title": "Marker1"},
{"id": 2, "title": "Marker2"}
]
mock_get_markers.return_value = mock_markers
# Call the function
haven_media_handler.clear_all_markers_from_video(123)
# Verify markers were retrieved
mock_get_markers.assert_called_once_with(123)
# Verify markers were deleted
mock_delete_markers.assert_called_once_with(mock_markers)
# Verify log message
self.mock_log.info.assert_called_once_with("Cleared all 2 markers from scene 123")
@patch('haven_media_handler.get_scene_markers')
@patch('haven_media_handler.delete_markers')
def test_clear_all_markers_from_video_no_markers(self, mock_delete_markers: Mock, mock_get_markers: Mock) -> None:
"""Test clearing all markers from a video that has no markers"""
# Mock no markers
mock_get_markers.return_value = []
# Call the function
haven_media_handler.clear_all_markers_from_video(123)
# Verify markers were retrieved
mock_get_markers.assert_called_once_with(123)
# Verify no deletion was called
mock_delete_markers.assert_not_called()
# Verify no log message
self.mock_log.info.assert_not_called()
def test_add_tags_to_video_with_tagged(self) -> None:
"""Test adding tags to video with tagged flag enabled"""
# Call the function
haven_media_handler.add_tags_to_video(123, [10, 20, 30], add_tagged=True)
# Verify tags were added (including tagged tag)
self.mock_stash.update_scenes.assert_called_once_with({
"ids": [123],
"tag_ids": {"ids": [10, 20, 30, 4], "mode": "ADD"}
})
def test_add_tags_to_video_without_tagged(self) -> None:
"""Test adding tags to video with tagged flag disabled"""
# Call the function
haven_media_handler.add_tags_to_video(123, [10, 20, 30], add_tagged=False)
# Verify tags were added (without tagged tag)
self.mock_stash.update_scenes.assert_called_once_with({
"ids": [123],
"tag_ids": {"ids": [10, 20, 30], "mode": "ADD"}
})
@patch('haven_media_handler.get_vlm_tags')
def test_remove_vlm_tags_from_video(self, mock_get_vlm_tags: Mock) -> None:
"""Test removing VLM tags from video"""
# Mock VLM tags
mock_get_vlm_tags.return_value = [100, 200, 300]
# Call the function
haven_media_handler.remove_vlm_tags_from_video(123, remove_tagme=True, remove_errored=True)
# Verify VLM tags were retrieved
mock_get_vlm_tags.assert_called_once()
# Verify tags were removed (including tagme and errored tags)
self.mock_stash.update_scenes.assert_called_once_with({
"ids": [123],
"tag_ids": {"ids": [100, 200, 300, 2, 1], "mode": "REMOVE"}
})
def test_get_tagme_scenes(self) -> None:
"""Test getting scenes tagged with VLM_TagMe"""
# Mock scenes
mock_scenes = [{"id": 1}, {"id": 2}]
self.mock_stash.find_scenes.return_value = mock_scenes
# Call the function
result = haven_media_handler.get_tagme_scenes()
# Verify scenes were found
self.mock_stash.find_scenes.assert_called_once_with(
f={"tags": {"value": 2, "modifier": "INCLUDES"}},
fragment="id tags {id} files {path duration fingerprint(type: \"phash\")}"
)
# Verify result
self.assertEqual(result, mock_scenes)
def test_add_error_scene(self) -> None:
"""Test adding error tag to a scene"""
# Call the function
haven_media_handler.add_error_scene(123)
# Verify error tag was added
self.mock_stash.update_scenes.assert_called_once_with({
"ids": [123],
"tag_ids": {"ids": [1], "mode": "ADD"}
})
def test_remove_tagme_tag_from_scene(self) -> None:
"""Test removing VLM_TagMe tag from a scene"""
# Call the function
haven_media_handler.remove_tagme_tag_from_scene(123)
# Verify tagme tag was removed
self.mock_stash.update_scenes.assert_called_once_with({
"ids": [123],
"tag_ids": {"ids": [2], "mode": "REMOVE"}
})
def test_is_scene_tagged_true(self) -> None:
"""Test checking if a scene is tagged (true case)"""
# Mock tags including tagged tag
tags = [
{"id": 10, "name": "Tag1"},
{"id": 4, "name": "VLM_Tagged"}, # This is the tagged tag
{"id": 20, "name": "Tag2"}
]
# Call the function
result = haven_media_handler.is_scene_tagged(tags)
# Verify result
self.assertTrue(result)
def test_is_scene_tagged_false(self) -> None:
"""Test checking if a scene is tagged (false case)"""
# Mock tags without tagged tag
tags = [
{"id": 10, "name": "Tag1"},
{"id": 20, "name": "Tag2"}
]
# Call the function
result = haven_media_handler.is_scene_tagged(tags)
# Verify result
self.assertFalse(result)
def test_is_vr_scene_true(self) -> None:
"""Test checking if a scene is VR (true case)"""
# Mock tags including VR tag
tags = [
{"id": 10, "name": "Tag1"},
{"id": 5, "name": "VR"}, # This is the VR tag
{"id": 20, "name": "Tag2"}
]
# Call the function
result = haven_media_handler.is_vr_scene(tags)
# Verify result
self.assertTrue(result)
def test_is_vr_scene_false(self) -> None:
"""Test checking if a scene is VR (false case)"""
# Mock tags without VR tag
tags = [
{"id": 10, "name": "Tag1"},
{"id": 20, "name": "Tag2"}
]
# Call the function
result = haven_media_handler.is_vr_scene(tags)
# Verify result
self.assertFalse(result)
def test_get_tag_id_existing(self) -> None:
"""Test getting tag ID for existing tag"""
# Mock existing tag
self.mock_stash.find_tag.return_value = {"id": 123, "name": "TestTag"}
# Call the function
result = haven_media_handler.get_tag_id("TestTag", create=False)
# Verify tag was found
self.mock_stash.find_tag.assert_called_once_with("TestTag")
# Verify result
self.assertEqual(result, 123)
def test_get_tag_id_not_existing_no_create(self) -> None:
"""Test getting tag ID for non-existing tag without create"""
# Mock non-existing tag
self.mock_stash.find_tag.return_value = None
# Call the function
result = haven_media_handler.get_tag_id("TestTag", create=False)
# Verify tag was searched
self.mock_stash.find_tag.assert_called_once_with("TestTag")
# Verify result is None
self.assertIsNone(result)
def test_get_tag_id_create_new(self) -> None:
"""Test getting tag ID for non-existing tag with create"""
# Mock non-existing tag
self.mock_stash.find_tag.return_value = None
# Mock created tag
self.mock_stash.create_tag.return_value = {"id": 456, "name": "TestTag"}
# Call the function
result = haven_media_handler.get_tag_id("TestTag", create=True)
# Verify tag was searched
self.mock_stash.find_tag.assert_called_once_with("TestTag")
# Verify tag was created
self.mock_stash.create_tag.assert_called_once_with({
"name": "TestTag",
"ignore_auto_tag": True,
"parent_ids": [3]
})
# Verify result
self.assertEqual(result, 456)
def test_get_tag_ids(self) -> None:
"""Test getting multiple tag IDs"""
# Mock tag IDs
with patch('haven_media_handler.get_tag_id') as mock_get_tag_id:
mock_get_tag_id.side_effect = [10, 20, 30]
# Call the function
result = haven_media_handler.get_tag_ids(["Tag1", "Tag2", "Tag3"], create=True)
# Verify individual tag IDs were retrieved
self.assertEqual(mock_get_tag_id.call_count, 3)
mock_get_tag_id.assert_any_call("Tag1", True)
mock_get_tag_id.assert_any_call("Tag2", True)
mock_get_tag_id.assert_any_call("Tag3", True)
# Verify result
self.assertEqual(result, [10, 20, 30])
@patch('haven_media_handler.vlm_tag_ids_cache')
def test_get_vlm_tags_from_cache(self, mock_cache: Mock) -> None:
"""Test getting VLM tags from cache"""
# Mock cached tags
mock_cache.__len__.return_value = 3
mock_cache.__iter__.return_value = iter([100, 200, 300])
# Call the function
result = haven_media_handler.get_vlm_tags()
# Verify result from cache
self.assertEqual(result, [100, 200, 300])
def test_get_vlm_tags_from_stash(self) -> None:
"""Test getting VLM tags from stash when cache is empty"""
# Mock empty cache
haven_media_handler.vlm_tag_ids_cache.clear()
# Mock stash tags
mock_tags = [
{"id": 100, "name": "VLM_Tag1"},
{"id": 200, "name": "VLM_Tag2"}
]
self.mock_stash.find_tags.return_value = mock_tags
# Call the function
result = haven_media_handler.get_vlm_tags()
# Verify tags were found
self.mock_stash.find_tags.assert_called_once_with(
f={"parents": {"value": 3, "modifier": "INCLUDES"}},
fragment="id"
)
# Verify result
self.assertEqual(result, [100, 200])
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,286 @@
"""
Unit tests for haven_vlm_config module
"""
import unittest
import tempfile
import os
import yaml
from unittest.mock import patch, mock_open
from dataclasses import dataclass
import haven_vlm_config
class TestVLMConnectorConfig(unittest.TestCase):
"""Test cases for VLMConnectorConfig dataclass"""
def test_vlm_connector_config_creation(self):
"""Test creating VLMConnectorConfig with all required fields"""
config = haven_vlm_config.VLMConnectorConfig(
vlm_engine_config={"test": "config"},
video_frame_interval=2.0,
video_threshold=0.3,
video_confidence_return=True,
image_threshold=0.5,
image_batch_size=320,
image_confidence_return=False,
concurrent_task_limit=10,
server_timeout=3700,
vlm_base_tag_name="VLM",
vlm_tagme_tag_name="VLM_TagMe",
vlm_updateme_tag_name="VLM_UpdateMe",
vlm_tagged_tag_name="VLM_Tagged",
vlm_errored_tag_name="VLM_Errored",
vlm_incorrect_tag_name="VLM_Incorrect",
temp_image_dir="./temp_images",
output_data_dir="./output_data",
delete_incorrect_markers=True,
create_markers=True,
path_mutation={}
)
self.assertEqual(config.video_frame_interval, 2.0)
self.assertEqual(config.video_threshold, 0.3)
self.assertEqual(config.image_threshold, 0.5)
self.assertEqual(config.concurrent_task_limit, 10)
self.assertEqual(config.vlm_base_tag_name, "VLM")
self.assertEqual(config.temp_image_dir, "./temp_images")
def test_vlm_connector_config_defaults(self):
"""Test VLMConnectorConfig with minimal required fields"""
config = haven_vlm_config.VLMConnectorConfig(
vlm_engine_config={},
video_frame_interval=1.0,
video_threshold=0.1,
video_confidence_return=False,
image_threshold=0.1,
image_batch_size=100,
image_confidence_return=False,
concurrent_task_limit=5,
server_timeout=1000,
vlm_base_tag_name="TEST",
vlm_tagme_tag_name="TEST_TagMe",
vlm_updateme_tag_name="TEST_UpdateMe",
vlm_tagged_tag_name="TEST_Tagged",
vlm_errored_tag_name="TEST_Errored",
vlm_incorrect_tag_name="TEST_Incorrect",
temp_image_dir="./test_temp",
output_data_dir="./test_output",
delete_incorrect_markers=False,
create_markers=False,
path_mutation={"test": "mutation"}
)
self.assertEqual(config.video_frame_interval, 1.0)
self.assertEqual(config.video_threshold, 0.1)
self.assertEqual(config.path_mutation, {"test": "mutation"})
class TestLoadConfigFromYaml(unittest.TestCase):
"""Test cases for load_config_from_yaml function"""
def setUp(self):
"""Set up test fixtures"""
self.test_config = {
"vlm_engine_config": {
"active_ai_models": ["test_model"],
"pipelines": {},
"models": {},
"category_config": {}
},
"video_frame_interval": 3.0,
"video_threshold": 0.4,
"video_confidence_return": True,
"image_threshold": 0.6,
"image_batch_size": 500,
"image_confidence_return": True,
"concurrent_task_limit": 15,
"server_timeout": 5000,
"vlm_base_tag_name": "TEST_VLM",
"vlm_tagme_tag_name": "TEST_VLM_TagMe",
"vlm_updateme_tag_name": "TEST_VLM_UpdateMe",
"vlm_tagged_tag_name": "TEST_VLM_Tagged",
"vlm_errored_tag_name": "TEST_VLM_Errored",
"vlm_incorrect_tag_name": "TEST_VLM_Incorrect",
"temp_image_dir": "./test_temp_images",
"output_data_dir": "./test_output_data",
"delete_incorrect_markers": False,
"create_markers": False,
"path_mutation": {"E:": "F:"}
}
def test_load_config_from_yaml_with_valid_file(self):
"""Test loading configuration from a valid YAML file"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f:
yaml.dump(self.test_config, f)
config_path = f.name
try:
config = haven_vlm_config.load_config_from_yaml(config_path)
self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig)
self.assertEqual(config.video_frame_interval, 3.0)
self.assertEqual(config.video_threshold, 0.4)
self.assertEqual(config.image_threshold, 0.6)
self.assertEqual(config.concurrent_task_limit, 15)
self.assertEqual(config.vlm_base_tag_name, "TEST_VLM")
self.assertEqual(config.path_mutation, {"E:": "F:"})
finally:
os.unlink(config_path)
def test_load_config_from_yaml_with_nonexistent_file(self):
"""Test loading configuration with nonexistent file path"""
config = haven_vlm_config.load_config_from_yaml("nonexistent_file.yml")
# Should return default configuration
self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig)
self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL)
self.assertEqual(config.video_threshold, haven_vlm_config.VIDEO_THRESHOLD)
def test_load_config_from_yaml_with_none_path(self):
"""Test loading configuration with None path"""
config = haven_vlm_config.load_config_from_yaml(None)
# Should return default configuration
self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig)
self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL)
def test_load_config_from_yaml_with_invalid_yaml(self):
"""Test loading configuration with invalid YAML content"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f:
f.write("invalid: yaml: content: [")
config_path = f.name
try:
config = haven_vlm_config.load_config_from_yaml(config_path)
# Should return default configuration on YAML error
self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig)
self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL)
finally:
os.unlink(config_path)
def test_load_config_from_yaml_with_file_permission_error(self):
"""Test loading configuration with file permission error"""
with patch('builtins.open', side_effect=PermissionError("Permission denied")):
config = haven_vlm_config.load_config_from_yaml("test.yml")
# Should return default configuration on file error
self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig)
self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL)
class TestConfigurationConstants(unittest.TestCase):
"""Test cases for configuration constants"""
def test_vlm_engine_config_structure(self):
"""Test that VLM_ENGINE_CONFIG has the expected structure"""
config = haven_vlm_config.VLM_ENGINE_CONFIG
# Check required top-level keys
self.assertIn("active_ai_models", config)
self.assertIn("pipelines", config)
self.assertIn("models", config)
self.assertIn("category_config", config)
# Check active_ai_models is a list
self.assertIsInstance(config["active_ai_models"], list)
self.assertIn("vlm_multiplexer_model", config["active_ai_models"])
# Check pipelines structure
self.assertIn("video_pipeline_dynamic", config["pipelines"])
pipeline = config["pipelines"]["video_pipeline_dynamic"]
self.assertIn("inputs", pipeline)
self.assertIn("output", pipeline)
self.assertIn("models", pipeline)
# Check models structure
self.assertIn("vlm_multiplexer_model", config["models"])
model = config["models"]["vlm_multiplexer_model"]
self.assertIn("type", model)
self.assertIn("multiplexer_endpoints", model)
self.assertIn("tag_list", model)
def test_processing_settings(self):
"""Test that processing settings have valid values"""
self.assertGreater(haven_vlm_config.VIDEO_FRAME_INTERVAL, 0)
self.assertGreaterEqual(haven_vlm_config.VIDEO_THRESHOLD, 0)
self.assertLessEqual(haven_vlm_config.VIDEO_THRESHOLD, 1)
self.assertGreaterEqual(haven_vlm_config.IMAGE_THRESHOLD, 0)
self.assertLessEqual(haven_vlm_config.IMAGE_THRESHOLD, 1)
self.assertGreater(haven_vlm_config.IMAGE_BATCH_SIZE, 0)
self.assertGreater(haven_vlm_config.CONCURRENT_TASK_LIMIT, 0)
self.assertGreater(haven_vlm_config.SERVER_TIMEOUT, 0)
def test_tag_names(self):
"""Test that tag names are valid strings"""
tag_names = [
haven_vlm_config.VLM_BASE_TAG_NAME,
haven_vlm_config.VLM_TAGME_TAG_NAME,
haven_vlm_config.VLM_UPDATEME_TAG_NAME,
haven_vlm_config.VLM_TAGGED_TAG_NAME,
haven_vlm_config.VLM_ERRORED_TAG_NAME,
haven_vlm_config.VLM_INCORRECT_TAG_NAME
]
for tag_name in tag_names:
self.assertIsInstance(tag_name, str)
self.assertGreater(len(tag_name), 0)
def test_directory_paths(self):
"""Test that directory paths are valid strings"""
self.assertIsInstance(haven_vlm_config.TEMP_IMAGE_DIR, str)
self.assertIsInstance(haven_vlm_config.OUTPUT_DATA_DIR, str)
self.assertGreater(len(haven_vlm_config.TEMP_IMAGE_DIR), 0)
self.assertGreater(len(haven_vlm_config.OUTPUT_DATA_DIR), 0)
def test_boolean_settings(self):
"""Test that boolean settings are valid"""
self.assertIsInstance(haven_vlm_config.DELETE_INCORRECT_MARKERS, bool)
self.assertIsInstance(haven_vlm_config.CREATE_MARKERS, bool)
def test_path_mutation(self):
"""Test that path mutation is a dictionary"""
self.assertIsInstance(haven_vlm_config.PATH_MUTATION, dict)
class TestGlobalConfigInstance(unittest.TestCase):
"""Test cases for the global config instance"""
def test_global_config_exists(self):
"""Test that the global config instance exists and is valid"""
self.assertIsInstance(haven_vlm_config.config, haven_vlm_config.VLMConnectorConfig)
def test_global_config_has_required_attributes(self):
"""Test that the global config has all required attributes"""
config = haven_vlm_config.config
# Check that all required attributes exist
required_attrs = [
'vlm_engine_config', 'video_frame_interval', 'video_threshold',
'video_confidence_return', 'image_threshold', 'image_batch_size',
'image_confidence_return', 'concurrent_task_limit', 'server_timeout',
'vlm_base_tag_name', 'vlm_tagme_tag_name', 'vlm_updateme_tag_name',
'vlm_tagged_tag_name', 'vlm_errored_tag_name', 'vlm_incorrect_tag_name',
'temp_image_dir', 'output_data_dir', 'delete_incorrect_markers',
'create_markers', 'path_mutation'
]
for attr in required_attrs:
self.assertTrue(hasattr(config, attr), f"Missing attribute: {attr}")
def test_global_config_values(self):
"""Test that the global config has expected default values"""
config = haven_vlm_config.config
self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL)
self.assertEqual(config.video_threshold, haven_vlm_config.VIDEO_THRESHOLD)
self.assertEqual(config.image_threshold, haven_vlm_config.IMAGE_THRESHOLD)
self.assertEqual(config.concurrent_task_limit, haven_vlm_config.CONCURRENT_TASK_LIMIT)
self.assertEqual(config.vlm_base_tag_name, haven_vlm_config.VLM_BASE_TAG_NAME)
self.assertEqual(config.temp_image_dir, haven_vlm_config.TEMP_IMAGE_DIR)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,451 @@
"""
Unit tests for haven_vlm_connector module
"""
import unittest
import asyncio
import json
import tempfile
import os
from unittest.mock import patch, MagicMock, AsyncMock, mock_open
import sys
# Mock the stashapi imports
sys.modules['stashapi.log'] = MagicMock()
sys.modules['stashapi.stashapp'] = MagicMock()
# Mock the vlm_engine imports
sys.modules['vlm_engine'] = MagicMock()
sys.modules['vlm_engine.config_models'] = MagicMock()
import haven_vlm_connector
class TestMainExecution(unittest.TestCase):
"""Test cases for main execution functions"""
def setUp(self):
"""Set up test fixtures"""
self.sample_json_input = {
"server_connection": {
"PluginDir": "/tmp/plugin"
},
"args": {
"mode": "tag_videos"
}
}
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.tag_videos')
@patch('haven_vlm_connector.os.chdir')
def test_run_tag_videos(self, mock_chdir, mock_tag_videos, mock_media_handler):
"""Test running tag_videos mode"""
output = {}
with patch('haven_vlm_connector.read_json_input', return_value=self.sample_json_input):
asyncio.run(haven_vlm_connector.run(self.sample_json_input, output))
mock_chdir.assert_called_once_with("/tmp/plugin")
mock_media_handler.initialize.assert_called_once_with(self.sample_json_input["server_connection"])
mock_tag_videos.assert_called_once()
self.assertEqual(output["output"], "ok")
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.tag_images')
@patch('haven_vlm_connector.os.chdir')
def test_run_tag_images(self, mock_chdir, mock_tag_images, mock_media_handler):
"""Test running tag_images mode"""
json_input = self.sample_json_input.copy()
json_input["args"]["mode"] = "tag_images"
output = {}
asyncio.run(haven_vlm_connector.run(json_input, output))
mock_tag_images.assert_called_once()
self.assertEqual(output["output"], "ok")
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.find_marker_settings')
@patch('haven_vlm_connector.os.chdir')
def test_run_find_marker_settings(self, mock_chdir, mock_find_marker_settings, mock_media_handler):
"""Test running find_marker_settings mode"""
json_input = self.sample_json_input.copy()
json_input["args"]["mode"] = "find_marker_settings"
output = {}
asyncio.run(haven_vlm_connector.run(json_input, output))
mock_find_marker_settings.assert_called_once()
self.assertEqual(output["output"], "ok")
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.collect_incorrect_markers_and_images')
@patch('haven_vlm_connector.os.chdir')
def test_run_collect_incorrect_markers(self, mock_chdir, mock_collect, mock_media_handler):
"""Test running collect_incorrect_markers mode"""
json_input = self.sample_json_input.copy()
json_input["args"]["mode"] = "collect_incorrect_markers"
output = {}
asyncio.run(haven_vlm_connector.run(json_input, output))
mock_collect.assert_called_once()
self.assertEqual(output["output"], "ok")
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.os.chdir')
def test_run_no_mode(self, mock_chdir, mock_media_handler):
"""Test running with no mode specified"""
json_input = self.sample_json_input.copy()
del json_input["args"]["mode"]
output = {}
asyncio.run(haven_vlm_connector.run(json_input, output))
self.assertEqual(output["output"], "ok")
@patch('haven_vlm_connector.media_handler')
def test_run_media_handler_initialization_error(self, mock_media_handler):
"""Test handling media handler initialization error"""
mock_media_handler.initialize.side_effect = Exception("Initialization failed")
output = {}
with self.assertRaises(Exception):
asyncio.run(haven_vlm_connector.run(self.sample_json_input, output))
def test_read_json_input(self):
"""Test reading JSON input from stdin"""
test_input = '{"test": "data"}'
with patch('sys.stdin.read', return_value=test_input):
result = haven_vlm_connector.read_json_input()
self.assertEqual(result, {"test": "data"})
class TestHighLevelProcessingFunctions(unittest.TestCase):
"""Test cases for high-level processing functions"""
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.__tag_images')
@patch('haven_vlm_connector.asyncio.gather')
def test_tag_images_with_images(self, mock_gather, mock_tag_images, mock_media_handler):
"""Test tagging images when images are available"""
mock_images = [{"id": 1}, {"id": 2}, {"id": 3}]
mock_media_handler.get_tagme_images.return_value = mock_images
asyncio.run(haven_vlm_connector.tag_images())
mock_media_handler.get_tagme_images.assert_called_once()
mock_gather.assert_called_once()
@patch('haven_vlm_connector.media_handler')
def test_tag_images_no_images(self, mock_media_handler):
"""Test tagging images when no images are available"""
mock_media_handler.get_tagme_images.return_value = []
asyncio.run(haven_vlm_connector.tag_images())
mock_media_handler.get_tagme_images.assert_called_once()
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.__tag_video')
@patch('haven_vlm_connector.asyncio.gather')
def test_tag_videos_with_scenes(self, mock_gather, mock_tag_video, mock_media_handler):
"""Test tagging videos when scenes are available"""
mock_scenes = [{"id": 1}, {"id": 2}]
mock_media_handler.get_tagme_scenes.return_value = mock_scenes
asyncio.run(haven_vlm_connector.tag_videos())
mock_media_handler.get_tagme_scenes.assert_called_once()
mock_gather.assert_called_once()
@patch('haven_vlm_connector.media_handler')
def test_tag_videos_no_scenes(self, mock_media_handler):
"""Test tagging videos when no scenes are available"""
mock_media_handler.get_tagme_scenes.return_value = []
asyncio.run(haven_vlm_connector.tag_videos())
mock_media_handler.get_tagme_scenes.assert_called_once()
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.__find_marker_settings')
def test_find_marker_settings_single_scene(self, mock_find_settings, mock_media_handler):
"""Test finding marker settings with single scene"""
mock_scenes = [{"id": 1}]
mock_media_handler.get_tagme_scenes.return_value = mock_scenes
asyncio.run(haven_vlm_connector.find_marker_settings())
mock_media_handler.get_tagme_scenes.assert_called_once()
mock_find_settings.assert_called_once_with(mock_scenes[0])
@patch('haven_vlm_connector.media_handler')
def test_find_marker_settings_no_scenes(self, mock_media_handler):
"""Test finding marker settings with no scenes"""
mock_media_handler.get_tagme_scenes.return_value = []
asyncio.run(haven_vlm_connector.find_marker_settings())
mock_media_handler.get_tagme_scenes.assert_called_once()
@patch('haven_vlm_connector.media_handler')
def test_find_marker_settings_multiple_scenes(self, mock_media_handler):
"""Test finding marker settings with multiple scenes"""
mock_scenes = [{"id": 1}, {"id": 2}]
mock_media_handler.get_tagme_scenes.return_value = mock_scenes
asyncio.run(haven_vlm_connector.find_marker_settings())
mock_media_handler.get_tagme_scenes.assert_called_once()
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.os.makedirs')
@patch('haven_vlm_connector.shutil.copy')
def test_collect_incorrect_markers_and_images_with_data(self, mock_copy, mock_makedirs, mock_media_handler):
"""Test collecting incorrect markers and images with data"""
mock_images = [{"id": 1, "files": [{"path": "/path/to/image.jpg"}]}]
mock_markers = [{"id": 1, "scene": {"files": [{"path": "/path/to/video.mp4"}]}, "primary_tag": {"name": "test"}}]
mock_media_handler.get_incorrect_images.return_value = mock_images
mock_media_handler.get_incorrect_markers.return_value = mock_markers
mock_media_handler.get_image_paths_and_ids.return_value = (["/path/to/image.jpg"], [1], [])
haven_vlm_connector.collect_incorrect_markers_and_images()
mock_media_handler.get_incorrect_images.assert_called_once()
mock_media_handler.get_incorrect_markers.assert_called_once()
mock_media_handler.remove_incorrect_tag_from_images.assert_called_once()
@patch('haven_vlm_connector.media_handler')
def test_collect_incorrect_markers_and_images_no_data(self, mock_media_handler):
"""Test collecting incorrect markers and images with no data"""
mock_media_handler.get_incorrect_images.return_value = []
mock_media_handler.get_incorrect_markers.return_value = []
haven_vlm_connector.collect_incorrect_markers_and_images()
mock_media_handler.get_incorrect_images.assert_called_once()
mock_media_handler.get_incorrect_markers.assert_called_once()
class TestLowLevelProcessingFunctions(unittest.TestCase):
"""Test cases for low-level processing functions"""
@patch('haven_vlm_connector.vlm_engine')
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.semaphore')
def test_tag_images_success(self, mock_semaphore, mock_media_handler, mock_vlm_engine):
"""Test successful image tagging"""
mock_images = [{"id": 1}, {"id": 2}]
mock_media_handler.get_image_paths_and_ids.return_value = (["/path1.jpg", "/path2.jpg"], [1, 2], [])
mock_vlm_engine.process_images_async.return_value = MagicMock(result=[{"tags": ["tag1"]}, {"tags": ["tag2"]}])
mock_media_handler.get_tag_ids.return_value = [100, 200]
# Mock semaphore context manager
mock_semaphore.__aenter__ = AsyncMock()
mock_semaphore.__aexit__ = AsyncMock()
# getattr avoids Python's name mangling of the module-private helper inside this test class
asyncio.run(getattr(haven_vlm_connector, '__tag_images')(mock_images))
mock_media_handler.get_image_paths_and_ids.assert_called_once_with(mock_images)
mock_vlm_engine.process_images_async.assert_called_once()
mock_media_handler.remove_tagme_tags_from_images.assert_called_once()
@patch('haven_vlm_connector.vlm_engine')
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.semaphore')
def test_tag_images_error(self, mock_semaphore, mock_media_handler, mock_vlm_engine):
"""Test image tagging with error"""
mock_images = [{"id": 1}]
mock_vlm_engine.process_images_async.side_effect = Exception("Processing error")
# Mock semaphore context manager
mock_semaphore.__aenter__ = AsyncMock()
mock_semaphore.__aexit__ = AsyncMock()
asyncio.run(getattr(haven_vlm_connector, '__tag_images')(mock_images))
mock_media_handler.add_error_images.assert_called_once()
@patch('haven_vlm_connector.vlm_engine')
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.semaphore')
def test_tag_video_success(self, mock_semaphore, mock_media_handler, mock_vlm_engine):
"""Test successful video tagging"""
mock_scene = {
"id": 1,
"files": [{"path": "/path/to/video.mp4"}],
"tags": []
}
mock_vlm_engine.process_video_async.return_value = MagicMock(
video_tags={"category": ["tag1", "tag2"]},
tag_timespans={}
)
mock_media_handler.is_vr_scene.return_value = False
mock_media_handler.get_tag_ids.return_value = [100, 200]
# Mock semaphore context manager
mock_semaphore.__aenter__ = AsyncMock()
mock_semaphore.__aexit__ = AsyncMock()
asyncio.run(getattr(haven_vlm_connector, '__tag_video')(mock_scene))
mock_vlm_engine.process_video_async.assert_called_once()
# Verify tags and markers were cleared before adding new ones
mock_media_handler.clear_all_tags_from_video.assert_called_once_with(1)
mock_media_handler.clear_all_markers_from_video.assert_called_once_with(1)
mock_media_handler.add_tags_to_video.assert_called_once()
mock_media_handler.remove_tagme_tag_from_scene.assert_called_once()
@patch('haven_vlm_connector.vlm_engine')
@patch('haven_vlm_connector.media_handler')
@patch('haven_vlm_connector.semaphore')
def test_tag_video_error(self, mock_semaphore, mock_media_handler, mock_vlm_engine):
"""Test video tagging with error"""
mock_scene = {
"id": 1,
"files": [{"path": "/path/to/video.mp4"}],
"tags": []
}
mock_vlm_engine.process_video_async.side_effect = Exception("Processing error")
# Mock semaphore context manager
mock_semaphore.__aenter__ = AsyncMock()
mock_semaphore.__aexit__ = AsyncMock()
asyncio.run(getattr(haven_vlm_connector, '__tag_video')(mock_scene))
mock_media_handler.add_error_scene.assert_called_once()
@patch('haven_vlm_connector.vlm_engine')
@patch('haven_vlm_connector.media_handler')
def test_find_marker_settings_success(self, mock_media_handler, mock_vlm_engine):
"""Test successful marker settings finding"""
mock_scene = {
"id": 1,
"files": [{"path": "/path/to/video.mp4"}]
}
mock_markers = [
{
"primary_tag": {"name": "tag1"},
"seconds": 10.0,
"end_seconds": 15.0
}
]
mock_media_handler.get_scene_markers.return_value = mock_markers
mock_vlm_engine.find_optimal_marker_settings_async.return_value = {"optimal": "settings"}
asyncio.run(getattr(haven_vlm_connector, '__find_marker_settings')(mock_scene))
mock_media_handler.get_scene_markers.assert_called_once_with(1)
mock_vlm_engine.find_optimal_marker_settings_async.assert_called_once()
@patch('haven_vlm_connector.media_handler')
def test_find_marker_settings_error(self, mock_media_handler):
"""Test marker settings finding with error"""
mock_scene = {
"id": 1,
"files": [{"path": "/path/to/video.mp4"}]
}
mock_media_handler.get_scene_markers.side_effect = Exception("Marker error")
asyncio.run(getattr(haven_vlm_connector, '__find_marker_settings')(mock_scene))
mock_media_handler.get_scene_markers.assert_called_once()
class TestUtilityFunctions(unittest.IsolatedAsyncioTestCase):  # async-aware TestCase so the async cleanup test actually runs
"""Test cases for utility functions"""
def test_increment_progress(self):
"""Test progress increment"""
haven_vlm_connector.progress = 0.0
haven_vlm_connector.increment = 0.1
haven_vlm_connector.increment_progress()
self.assertEqual(haven_vlm_connector.progress, 0.1)
@patch('haven_vlm_connector.vlm_engine')
async def test_cleanup(self, mock_vlm_engine):
"""Test cleanup function"""
mock_vlm_engine.vlm_engine = MagicMock()
await haven_vlm_connector.cleanup()
mock_vlm_engine.vlm_engine.shutdown.assert_called_once()
class TestMainFunction(unittest.TestCase):
"""Test cases for main function"""
@patch('haven_vlm_connector.run')
@patch('haven_vlm_connector.read_json_input')
@patch('haven_vlm_connector.json.dumps')
@patch('builtins.print')
def test_main_success(self, mock_print, mock_json_dumps, mock_read_input, mock_run):
"""Test successful main execution"""
mock_read_input.return_value = {"test": "data"}
mock_json_dumps.return_value = '{"output": "ok"}'
asyncio.run(haven_vlm_connector.main())
mock_read_input.assert_called_once()
mock_run.assert_called_once()
mock_json_dumps.assert_called_once()
mock_print.assert_called()
class TestErrorHandling(unittest.TestCase):
"""Test cases for error handling"""
@patch('haven_vlm_connector.media_handler')
def test_tag_images_empty_paths(self, mock_media_handler):
"""Test image tagging with empty paths"""
mock_images = [{"id": 1}]
mock_media_handler.get_image_paths_and_ids.return_value = ([], [1], [])
# Mock semaphore context manager
with patch('haven_vlm_connector.semaphore') as mock_semaphore:
mock_semaphore.__aenter__ = AsyncMock()
mock_semaphore.__aexit__ = AsyncMock()
asyncio.run(getattr(haven_vlm_connector, '__tag_images')(mock_images))
mock_media_handler.get_image_paths_and_ids.assert_called_once()
@patch('haven_vlm_connector.vlm_engine')
@patch('haven_vlm_connector.media_handler')
def test_tag_video_no_detected_tags(self, mock_media_handler, mock_vlm_engine):
"""Test video tagging with no detected tags"""
mock_scene = {
"id": 1,
"files": [{"path": "/path/to/video.mp4"}],
"tags": []
}
mock_vlm_engine.process_video_async.return_value = MagicMock(
video_tags={},
tag_timespans={}
)
mock_media_handler.is_vr_scene.return_value = False
# Mock semaphore context manager
with patch('haven_vlm_connector.semaphore') as mock_semaphore:
mock_semaphore.__aenter__ = AsyncMock()
mock_semaphore.__aexit__ = AsyncMock()
asyncio.run(getattr(haven_vlm_connector, '__tag_video')(mock_scene))
# Verify clearing functions are NOT called when no tags are detected
mock_media_handler.clear_all_tags_from_video.assert_not_called()
mock_media_handler.clear_all_markers_from_video.assert_not_called()
mock_media_handler.remove_tagme_tag_from_scene.assert_called_once()
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,544 @@
"""
Unit tests for haven_vlm_engine module
"""
import unittest
import asyncio
import json
import tempfile
import os
from unittest.mock import patch, MagicMock, AsyncMock, mock_open
import sys
# Mock the vlm_engine imports
sys.modules['vlm_engine'] = MagicMock()
sys.modules['vlm_engine.config_models'] = MagicMock()
import haven_vlm_engine
class TestTimeFrame(unittest.TestCase):
"""Test cases for TimeFrame dataclass"""
def test_timeframe_creation(self):
"""Test creating TimeFrame with all parameters"""
timeframe = haven_vlm_engine.TimeFrame(
start=10.0,
end=15.0,
total_confidence=0.85
)
self.assertEqual(timeframe.start, 10.0)
self.assertEqual(timeframe.end, 15.0)
self.assertEqual(timeframe.total_confidence, 0.85)
def test_timeframe_creation_without_confidence(self):
"""Test creating TimeFrame without confidence"""
timeframe = haven_vlm_engine.TimeFrame(
start=10.0,
end=15.0
)
self.assertEqual(timeframe.start, 10.0)
self.assertEqual(timeframe.end, 15.0)
self.assertIsNone(timeframe.total_confidence)
def test_timeframe_to_json(self):
"""Test TimeFrame to_json method"""
timeframe = haven_vlm_engine.TimeFrame(
start=10.0,
end=15.0,
total_confidence=0.85
)
json_str = timeframe.to_json()
json_data = json.loads(json_str)
self.assertEqual(json_data["start"], 10.0)
self.assertEqual(json_data["end"], 15.0)
self.assertEqual(json_data["total_confidence"], 0.85)
def test_timeframe_to_json_without_confidence(self):
"""Test TimeFrame to_json method without confidence"""
timeframe = haven_vlm_engine.TimeFrame(
start=10.0,
end=15.0
)
json_str = timeframe.to_json()
json_data = json.loads(json_str)
self.assertEqual(json_data["start"], 10.0)
self.assertEqual(json_data["end"], 15.0)
self.assertIsNone(json_data["total_confidence"])
def test_timeframe_str(self):
"""Test TimeFrame string representation"""
timeframe = haven_vlm_engine.TimeFrame(
start=10.0,
end=15.0,
total_confidence=0.85
)
str_repr = str(timeframe)
self.assertIn("10.0", str_repr)
self.assertIn("15.0", str_repr)
self.assertIn("0.85", str_repr)
class TestVideoTagInfo(unittest.TestCase):
"""Test cases for VideoTagInfo dataclass"""
def test_videotaginfo_creation(self):
"""Test creating VideoTagInfo with all parameters"""
video_tags = {"category1": {"tag1", "tag2"}}
tag_totals = {"tag1": {"total": 0.8}}
tag_timespans = {"category1": {"tag1": [haven_vlm_engine.TimeFrame(10.0, 15.0)]}}
video_info = haven_vlm_engine.VideoTagInfo(
video_duration=120.0,
video_tags=video_tags,
tag_totals=tag_totals,
tag_timespans=tag_timespans
)
self.assertEqual(video_info.video_duration, 120.0)
self.assertEqual(video_info.video_tags, video_tags)
self.assertEqual(video_info.tag_totals, tag_totals)
self.assertEqual(video_info.tag_timespans, tag_timespans)
def test_videotaginfo_from_json(self):
"""Test creating VideoTagInfo from JSON data"""
json_data = {
"video_duration": 120.0,
"video_tags": {"category1": ["tag1", "tag2"]},
"tag_totals": {"tag1": {"total": 0.8}},
"tag_timespans": {
"category1": {
"tag1": [
{"start": 10.0, "end": 15.0, "total_confidence": 0.85}
]
}
}
}
video_info = haven_vlm_engine.VideoTagInfo.from_json(json_data)
self.assertEqual(video_info.video_duration, 120.0)
self.assertEqual(video_info.video_tags, {"category1": ["tag1", "tag2"]})
self.assertEqual(video_info.tag_totals, {"tag1": {"total": 0.8}})
# Check that tag_timespans contains TimeFrame objects
self.assertIn("category1", video_info.tag_timespans)
self.assertIn("tag1", video_info.tag_timespans["category1"])
self.assertIsInstance(video_info.tag_timespans["category1"]["tag1"][0], haven_vlm_engine.TimeFrame)
def test_videotaginfo_from_json_without_confidence(self):
"""Test creating VideoTagInfo from JSON data without confidence"""
json_data = {
"video_duration": 120.0,
"video_tags": {"category1": ["tag1"]},
"tag_totals": {"tag1": {"total": 0.8}},
"tag_timespans": {
"category1": {
"tag1": [
{"start": 10.0, "end": 15.0}
]
}
}
}
video_info = haven_vlm_engine.VideoTagInfo.from_json(json_data)
timeframe = video_info.tag_timespans["category1"]["tag1"][0]
self.assertEqual(timeframe.start, 10.0)
self.assertEqual(timeframe.end, 15.0)
self.assertIsNone(timeframe.total_confidence)
def test_videotaginfo_from_json_empty_timespans(self):
"""Test creating VideoTagInfo from JSON data with empty timespans"""
json_data = {
"video_duration": 120.0,
"video_tags": {"category1": ["tag1"]},
"tag_totals": {"tag1": {"total": 0.8}},
"tag_timespans": {}
}
video_info = haven_vlm_engine.VideoTagInfo.from_json(json_data)
self.assertEqual(video_info.video_duration, 120.0)
self.assertEqual(video_info.tag_timespans, {})
def test_videotaginfo_str(self):
"""Test VideoTagInfo string representation"""
video_info = haven_vlm_engine.VideoTagInfo(
video_duration=120.0,
video_tags={"category1": {"tag1"}},
tag_totals={"tag1": {"total": 0.8}},
tag_timespans={"category1": {"tag1": []}}
)
str_repr = str(video_info)
self.assertIn("120.0", str_repr)
self.assertIn("1", str_repr) # number of tags
self.assertIn("1", str_repr) # number of timespans
class TestImageResult(unittest.TestCase):
"""Test cases for ImageResult dataclass"""
def test_imageresult_creation(self):
"""Test creating ImageResult with valid data"""
result_data = [{"tags": ["tag1"], "confidence": 0.8}]
image_result = haven_vlm_engine.ImageResult(result=result_data)
self.assertEqual(image_result.result, result_data)
def test_imageresult_creation_empty_list(self):
"""Test creating ImageResult with empty list"""
with self.assertRaises(ValueError):
haven_vlm_engine.ImageResult(result=[])
def test_imageresult_creation_none_result(self):
"""Test creating ImageResult with None result"""
with self.assertRaises(ValueError):
haven_vlm_engine.ImageResult(result=None)
class TestHavenVLMEngine(unittest.IsolatedAsyncioTestCase):  # async test methods need an async-aware TestCase to actually run
"""Test cases for HavenVLMEngine class"""
def setUp(self):
"""Set up test fixtures"""
self.engine = haven_vlm_engine.HavenVLMEngine()
def test_engine_initialization(self):
"""Test engine initialization"""
self.assertIsNone(self.engine.engine)
self.assertIsNone(self.engine.engine_config)
self.assertFalse(self.engine._initialized)
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_initialize_success(self, mock_vlm_engine_class, mock_config):
"""Test successful engine initialization"""
mock_engine_instance = MagicMock()
mock_vlm_engine_class.return_value = mock_engine_instance
mock_config.config.vlm_engine_config = {"test": "config"}
await self.engine.initialize()
self.assertTrue(self.engine._initialized)
mock_vlm_engine_class.assert_called_once()
mock_engine_instance.initialize.assert_called_once()
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_initialize_already_initialized(self, mock_vlm_engine_class, mock_config):
"""Test initialization when already initialized"""
self.engine._initialized = True
await self.engine.initialize()
# Should not call VLMEngine constructor again
mock_vlm_engine_class.assert_not_called()
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_initialize_error(self, mock_vlm_engine_class, mock_config):
"""Test initialization with error"""
mock_vlm_engine_class.side_effect = Exception("Initialization failed")
mock_config.config.vlm_engine_config = {"test": "config"}
with self.assertRaises(Exception):
await self.engine.initialize()
self.assertFalse(self.engine._initialized)
@patch('haven_vlm_engine.config')
def test_create_engine_config(self, mock_config):
"""Test creating engine configuration"""
mock_config.config.vlm_engine_config = {
"active_ai_models": ["model1"],
"pipelines": {
"pipeline1": {
"inputs": ["input1"],
"output": "output1",
"short_name": "short1",
"version": 1.0,
"models": [
{
"name": "model1",
"inputs": ["input1"],
"outputs": "output1"
}
]
}
},
"models": {
"model1": {
"type": "vlm_model",
"model_file_name": "model1.py",
"model_category": "test",
"model_id": "test_model",
"model_identifier": 123,
"model_version": "1.0",
"use_multiplexer": True,
"max_concurrent_requests": 10,
"connection_pool_size": 20,
"multiplexer_endpoints": [],
"tag_list": ["tag1"]
}
},
"category_config": {"test": {}}
}
config = self.engine._create_engine_config()
self.assertIsNotNone(config)
# Note: We can't easily test the exact structure without the actual VLM Engine classes
# but we can verify the method doesn't raise exceptions
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_process_video_success(self, mock_vlm_engine_class, mock_config):
"""Test successful video processing"""
# Setup mocks
mock_engine_instance = MagicMock()
mock_vlm_engine_class.return_value = mock_engine_instance
mock_config.config.video_frame_interval = 2.0
mock_config.config.video_threshold = 0.3
mock_config.config.video_confidence_return = True
# Mock the engine's process_video method
mock_engine_instance.process_video.return_value = {
"video_duration": 120.0,
"video_tags": {"category1": ["tag1"]},
"tag_totals": {"tag1": {"total": 0.8}},
"tag_timespans": {}
}
# Initialize engine
await self.engine.initialize()
# Process video
result = await self.engine.process_video("/path/to/video.mp4")
self.assertIsInstance(result, haven_vlm_engine.VideoTagInfo)
mock_engine_instance.process_video.assert_called_once()
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_process_video_not_initialized(self, mock_vlm_engine_class, mock_config):
"""Test video processing when not initialized"""
mock_engine_instance = MagicMock()
mock_vlm_engine_class.return_value = mock_engine_instance
mock_config.config.video_frame_interval = 2.0
mock_config.config.video_threshold = 0.3
mock_config.config.video_confidence_return = True
mock_engine_instance.process_video.return_value = {
"video_duration": 120.0,
"video_tags": {"category1": ["tag1"]},
"tag_totals": {"tag1": {"total": 0.8}},
"tag_timespans": {}
}
# Process video without explicit initialization
result = await self.engine.process_video("/path/to/video.mp4")
self.assertIsInstance(result, haven_vlm_engine.VideoTagInfo)
mock_engine_instance.initialize.assert_called_once()
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_process_video_error(self, mock_vlm_engine_class, mock_config):
"""Test video processing with error"""
mock_engine_instance = MagicMock()
mock_vlm_engine_class.return_value = mock_engine_instance
mock_config.config.video_frame_interval = 2.0
mock_config.config.video_threshold = 0.3
mock_config.config.video_confidence_return = True
mock_engine_instance.process_video.side_effect = Exception("Processing failed")
await self.engine.initialize()
with self.assertRaises(Exception):
await self.engine.process_video("/path/to/video.mp4")
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_process_images_success(self, mock_vlm_engine_class, mock_config):
"""Test successful image processing"""
# Setup mocks
mock_engine_instance = MagicMock()
mock_vlm_engine_class.return_value = mock_engine_instance
mock_config.config.image_threshold = 0.5
mock_config.config.image_confidence_return = False
# Mock the engine's process_images method
mock_engine_instance.process_images.return_value = [
{"tags": ["tag1"], "confidence": 0.8}
]
# Initialize engine
await self.engine.initialize()
# Process images
result = await self.engine.process_images(["/path/to/image1.jpg"])
self.assertIsInstance(result, haven_vlm_engine.ImageResult)
mock_engine_instance.process_images.assert_called_once()
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_process_images_error(self, mock_vlm_engine_class, mock_config):
"""Test image processing with error"""
mock_engine_instance = MagicMock()
mock_vlm_engine_class.return_value = mock_engine_instance
mock_config.config.image_threshold = 0.5
mock_config.config.image_confidence_return = False
mock_engine_instance.process_images.side_effect = Exception("Processing failed")
await self.engine.initialize()
with self.assertRaises(Exception):
await self.engine.process_images(["/path/to/image1.jpg"])
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_find_optimal_marker_settings_success(self, mock_vlm_engine_class, mock_config):
"""Test successful marker settings optimization"""
# Setup mocks
mock_engine_instance = MagicMock()
mock_vlm_engine_class.return_value = mock_engine_instance
mock_engine_instance.optimize_timeframe_settings.return_value = {"optimal": "settings"}
# Initialize engine
await self.engine.initialize()
# Test data
existing_json = {"existing": "data"}
desired_timespan_data = {
"tag1": haven_vlm_engine.TimeFrame(10.0, 15.0, 0.8)
}
# Find optimal settings
result = await self.engine.find_optimal_marker_settings(existing_json, desired_timespan_data)
self.assertEqual(result, {"optimal": "settings"})
mock_engine_instance.optimize_timeframe_settings.assert_called_once()
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_find_optimal_marker_settings_error(self, mock_vlm_engine_class, mock_config):
"""Test marker settings optimization with error"""
mock_engine_instance = MagicMock()
mock_vlm_engine_class.return_value = mock_engine_instance
mock_engine_instance.optimize_timeframe_settings.side_effect = Exception("Optimization failed")
await self.engine.initialize()
existing_json = {"existing": "data"}
desired_timespan_data = {
"tag1": haven_vlm_engine.TimeFrame(10.0, 15.0, 0.8)
}
with self.assertRaises(Exception):
await self.engine.find_optimal_marker_settings(existing_json, desired_timespan_data)
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_shutdown_success(self, mock_vlm_engine_class, mock_config):
"""Test successful engine shutdown"""
mock_engine_instance = MagicMock()
mock_vlm_engine_class.return_value = mock_engine_instance
await self.engine.initialize()
await self.engine.shutdown()
mock_engine_instance.shutdown.assert_called_once()
self.assertFalse(self.engine._initialized)
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_shutdown_not_initialized(self, mock_vlm_engine_class, mock_config):
"""Test shutdown when not initialized"""
await self.engine.shutdown()
# Should not raise any exceptions
self.assertFalse(self.engine._initialized)
@patch('haven_vlm_engine.config')
@patch('haven_vlm_engine.VLMEngine')
async def test_shutdown_error(self, mock_vlm_engine_class, mock_config):
"""Test shutdown with error"""
mock_engine_instance = MagicMock()
mock_vlm_engine_class.return_value = mock_engine_instance
mock_engine_instance.shutdown.side_effect = Exception("Shutdown failed")
await self.engine.initialize()
await self.engine.shutdown()
# Should handle the error gracefully
self.assertFalse(self.engine._initialized)
class TestConvenienceFunctions(unittest.IsolatedAsyncioTestCase):
"""Test cases for convenience functions"""
@patch('haven_vlm_engine.vlm_engine')
async def test_process_video_async(self, mock_vlm_engine):
"""Test process_video_async convenience function"""
mock_vlm_engine.process_video.return_value = MagicMock()
result = await haven_vlm_engine.process_video_async("/path/to/video.mp4")
mock_vlm_engine.process_video.assert_called_once()
self.assertEqual(result, mock_vlm_engine.process_video.return_value)
@patch('haven_vlm_engine.vlm_engine')
async def test_process_images_async(self, mock_vlm_engine):
"""Test process_images_async convenience function"""
mock_vlm_engine.process_images.return_value = MagicMock()
result = await haven_vlm_engine.process_images_async(["/path/to/image.jpg"])
mock_vlm_engine.process_images.assert_called_once()
self.assertEqual(result, mock_vlm_engine.process_images.return_value)
@patch('haven_vlm_engine.vlm_engine')
async def test_find_optimal_marker_settings_async(self, mock_vlm_engine):
"""Test find_optimal_marker_settings_async convenience function"""
mock_vlm_engine.find_optimal_marker_settings.return_value = {"optimal": "settings"}
existing_json = {"existing": "data"}
desired_timespan_data = {
"tag1": haven_vlm_engine.TimeFrame(10.0, 15.0, 0.8)
}
result = await haven_vlm_engine.find_optimal_marker_settings_async(existing_json, desired_timespan_data)
mock_vlm_engine.find_optimal_marker_settings.assert_called_once()
self.assertEqual(result, {"optimal": "settings"})
class TestGlobalVLMEngineInstance(unittest.TestCase):
"""Test cases for global VLM engine instance"""
def test_global_vlm_engine_exists(self):
"""Test that global VLM engine instance exists"""
self.assertIsInstance(haven_vlm_engine.vlm_engine, haven_vlm_engine.HavenVLMEngine)
def test_global_vlm_engine_is_singleton(self):
"""Test that global VLM engine is a singleton"""
engine1 = haven_vlm_engine.vlm_engine
engine2 = haven_vlm_engine.vlm_engine
self.assertIs(engine1, engine2)
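# --- Hedged usage sketch (editor's addition, not part of the test suite) ---
# Illustrates how the global singleton and the async convenience wrappers
# exercised above might be driven from plugin code. The call shape mirrors the
# mocks in TestConvenienceFunctions; the asyncio.run entry point below is an
# assumption for illustration, not the connector's actual invocation path.
import asyncio

async def _example_tag_one_video(path: str):
    # Delegates to the module-level wrapper, which uses the shared
    # haven_vlm_engine.vlm_engine singleton verified above.
    return await haven_vlm_engine.process_video_async(path)

def _example_run(path: str = "/path/to/video.mp4"):
    # Manual experimentation helper only; never invoked by the tests.
    return asyncio.run(_example_tag_one_video(path))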
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,604 @@
"""
Unit tests for haven_vlm_utility module
"""
import unittest
import tempfile
import os
import shutil
import time
from unittest.mock import patch, mock_open, MagicMock
import yaml
import haven_vlm_utility
class TestPathMutations(unittest.TestCase):
"""Test cases for path mutation functions"""
def test_apply_path_mutations_with_mutations(self):
"""Test applying path mutations with valid mutations"""
mutations = {"E:": "F:", "G:": "D:"}
path = "E:\\videos\\test.mp4"
result = haven_vlm_utility.apply_path_mutations(path, mutations)
self.assertEqual(result, "F:\\videos\\test.mp4")
def test_apply_path_mutations_without_mutations(self):
"""Test applying path mutations with empty mutations"""
mutations = {}
path = "E:\\videos\\test.mp4"
result = haven_vlm_utility.apply_path_mutations(path, mutations)
self.assertEqual(result, path)
def test_apply_path_mutations_with_none_mutations(self):
"""Test applying path mutations with None mutations"""
mutations = None
path = "E:\\videos\\test.mp4"
result = haven_vlm_utility.apply_path_mutations(path, mutations)
self.assertEqual(result, path)
def test_apply_path_mutations_no_match(self):
"""Test applying path mutations when no mutation matches"""
mutations = {"E:": "F:", "G:": "D:"}
path = "C:\\videos\\test.mp4"
result = haven_vlm_utility.apply_path_mutations(path, mutations)
self.assertEqual(result, path)
def test_apply_path_mutations_multiple_matches(self):
"""Test applying path mutations with multiple possible matches"""
mutations = {"E:": "F:", "E:\\videos": "F:\\movies"}
path = "E:\\videos\\test.mp4"
result = haven_vlm_utility.apply_path_mutations(path, mutations)
# Should use the first match
self.assertEqual(result, "F:\\videos\\test.mp4")
class TestDirectoryOperations(unittest.TestCase):
"""Test cases for directory operations"""
def test_ensure_directory_exists_new_directory(self):
"""Test creating a new directory"""
with tempfile.TemporaryDirectory() as temp_dir:
new_dir = os.path.join(temp_dir, "test_subdir")
haven_vlm_utility.ensure_directory_exists(new_dir)
self.assertTrue(os.path.exists(new_dir))
self.assertTrue(os.path.isdir(new_dir))
def test_ensure_directory_exists_existing_directory(self):
"""Test ensuring directory exists when it already exists"""
with tempfile.TemporaryDirectory() as temp_dir:
haven_vlm_utility.ensure_directory_exists(temp_dir)
self.assertTrue(os.path.exists(temp_dir))
self.assertTrue(os.path.isdir(temp_dir))
def test_ensure_directory_exists_nested_directories(self):
"""Test creating nested directories"""
with tempfile.TemporaryDirectory() as temp_dir:
nested_dir = os.path.join(temp_dir, "level1", "level2", "level3")
haven_vlm_utility.ensure_directory_exists(nested_dir)
self.assertTrue(os.path.exists(nested_dir))
self.assertTrue(os.path.isdir(nested_dir))
class TestSafeFileOperations(unittest.TestCase):
"""Test cases for safe file operations"""
def test_safe_file_operation_success(self):
"""Test successful file operation"""
def test_func(a, b, c=10):
return a + b + c
result = haven_vlm_utility.safe_file_operation(test_func, 1, 2, c=5)
self.assertEqual(result, 8)
def test_safe_file_operation_os_error(self):
"""Test file operation with OSError"""
def test_func():
raise OSError("File not found")
result = haven_vlm_utility.safe_file_operation(test_func)
self.assertIsNone(result)
def test_safe_file_operation_io_error(self):
"""Test file operation with IOError"""
def test_func():
raise IOError("Permission denied")
result = haven_vlm_utility.safe_file_operation(test_func)
self.assertIsNone(result)
def test_safe_file_operation_unexpected_error(self):
"""Test file operation with unexpected error"""
def test_func():
raise ValueError("Unexpected error")
result = haven_vlm_utility.safe_file_operation(test_func)
self.assertIsNone(result)
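# --- Hedged reference sketch (editor's addition) ---
# Minimal shape for safe_file_operation consistent with the tests above:
# forward *args/**kwargs, return the result, and collapse any exception to
# None. The real implementation presumably also logs; this is just a sketch.
def _sketch_safe_file_operation(func, *args, **kwargs):
    try:
        return func(*args, **kwargs)
    except (OSError, IOError):
        # Expected I/O failures (missing file, permissions) become None.
        # (IOError is an alias of OSError on Python 3, kept for symmetry with the tests.)
        return None
    except Exception:
        # Unexpected errors are swallowed too, per test_safe_file_operation_unexpected_error.
        return None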
class TestYamlConfigOperations(unittest.TestCase):
"""Test cases for YAML configuration operations"""
def setUp(self):
"""Set up test fixtures"""
self.test_config = {
"video_frame_interval": 2.0,
"video_threshold": 0.3,
"image_threshold": 0.5,
"endpoints": [
{"url": "http://localhost:1234", "weight": 5},
{"url": "https://cloud.example.com", "weight": 1}
]
}
def test_load_yaml_config_success(self):
"""Test successfully loading YAML configuration"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f:
yaml.dump(self.test_config, f)
config_path = f.name
try:
result = haven_vlm_utility.load_yaml_config(config_path)
self.assertEqual(result, self.test_config)
finally:
os.unlink(config_path)
def test_load_yaml_config_file_not_found(self):
"""Test loading YAML configuration from nonexistent file"""
result = haven_vlm_utility.load_yaml_config("nonexistent_file.yml")
self.assertIsNone(result)
def test_load_yaml_config_invalid_yaml(self):
"""Test loading YAML configuration with invalid YAML"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f:
f.write("invalid: yaml: content: [")
config_path = f.name
try:
result = haven_vlm_utility.load_yaml_config(config_path)
self.assertIsNone(result)
finally:
os.unlink(config_path)
def test_load_yaml_config_permission_error(self):
"""Test loading YAML configuration with permission error"""
with patch('builtins.open', side_effect=PermissionError("Permission denied")):
result = haven_vlm_utility.load_yaml_config("test.yml")
self.assertIsNone(result)
def test_save_yaml_config_success(self):
"""Test successfully saving YAML configuration"""
with tempfile.TemporaryDirectory() as temp_dir:
config_path = os.path.join(temp_dir, "test_config.yml")
result = haven_vlm_utility.save_yaml_config(self.test_config, config_path)
self.assertTrue(result)
self.assertTrue(os.path.exists(config_path))
# Verify the saved content
with open(config_path, 'r') as f:
loaded_config = yaml.safe_load(f)
self.assertEqual(loaded_config, self.test_config)
def test_save_yaml_config_with_nested_directories(self):
"""Test saving YAML configuration to nested directory"""
with tempfile.TemporaryDirectory() as temp_dir:
config_path = os.path.join(temp_dir, "nested", "dir", "test_config.yml")
result = haven_vlm_utility.save_yaml_config(self.test_config, config_path)
self.assertTrue(result)
self.assertTrue(os.path.exists(config_path))
def test_save_yaml_config_permission_error(self):
"""Test saving YAML configuration with permission error"""
with patch('builtins.open', side_effect=PermissionError("Permission denied")):
result = haven_vlm_utility.save_yaml_config(self.test_config, "test.yml")
self.assertFalse(result)
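# --- Hedged reference sketch (editor's addition) ---
# Load/save helpers matching the behaviour asserted above: load returns the
# parsed mapping or None on any failure; save creates missing parent
# directories and reports success as a bool. Illustrative only.
def _sketch_load_yaml_config(config_path):
    try:
        with open(config_path, 'r') as f:
            return yaml.safe_load(f)
    except (OSError, yaml.YAMLError):
        # Missing file, permission error, or malformed YAML all yield None.
        return None

def _sketch_save_yaml_config(config, config_path):
    try:
        os.makedirs(os.path.dirname(config_path) or '.', exist_ok=True)
        with open(config_path, 'w') as f:
            yaml.safe_dump(config, f)
        return True
    except OSError:
        return False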
class TestFileValidation(unittest.TestCase):
"""Test cases for file validation functions"""
def test_validate_file_path_existing_file(self):
"""Test validating an existing file path"""
with tempfile.NamedTemporaryFile(delete=False) as f:
file_path = f.name
try:
result = haven_vlm_utility.validate_file_path(file_path)
self.assertTrue(result)
finally:
os.unlink(file_path)
def test_validate_file_path_nonexistent_file(self):
"""Test validating a nonexistent file path"""
result = haven_vlm_utility.validate_file_path("nonexistent_file.txt")
self.assertFalse(result)
def test_validate_file_path_directory(self):
"""Test validating a directory path"""
with tempfile.TemporaryDirectory() as temp_dir:
result = haven_vlm_utility.validate_file_path(temp_dir)
self.assertFalse(result)
def test_validate_file_path_permission_error(self):
"""Test validating file path with permission error"""
with patch('os.path.isfile', side_effect=OSError("Permission denied")):
result = haven_vlm_utility.validate_file_path("test.txt")
self.assertFalse(result)
class TestFileExtensionFunctions(unittest.TestCase):
"""Test cases for file extension functions"""
def test_get_file_extension_with_extension(self):
"""Test getting file extension from file with extension"""
result = haven_vlm_utility.get_file_extension("test.mp4")
self.assertEqual(result, ".mp4")
def test_get_file_extension_without_extension(self):
"""Test getting file extension from file without extension"""
result = haven_vlm_utility.get_file_extension("test")
self.assertEqual(result, "")
def test_get_file_extension_multiple_dots(self):
"""Test getting file extension from file with multiple dots"""
result = haven_vlm_utility.get_file_extension("test.backup.mp4")
self.assertEqual(result, ".mp4")
def test_get_file_extension_uppercase(self):
"""Test getting file extension from file with uppercase extension"""
result = haven_vlm_utility.get_file_extension("test.MP4")
self.assertEqual(result, ".mp4")
def test_is_video_file_valid_extensions(self):
"""Test video file detection with valid extensions"""
video_files = ["test.mp4", "test.avi", "test.mkv", "test.mov", "test.wmv", "test.flv", "test.webm", "test.m4v"]
for video_file in video_files:
result = haven_vlm_utility.is_video_file(video_file)
self.assertTrue(result, f"Failed for {video_file}")
def test_is_video_file_invalid_extensions(self):
"""Test video file detection with invalid extensions"""
non_video_files = ["test.jpg", "test.txt", "test.pdf", "test.exe"]
for non_video_file in non_video_files:
result = haven_vlm_utility.is_video_file(non_video_file)
self.assertFalse(result, f"Failed for {non_video_file}")
def test_is_image_file_valid_extensions(self):
"""Test image file detection with valid extensions"""
image_files = ["test.jpg", "test.jpeg", "test.png", "test.gif", "test.bmp", "test.tiff", "test.webp"]
for image_file in image_files:
result = haven_vlm_utility.is_image_file(image_file)
self.assertTrue(result, f"Failed for {image_file}")
def test_is_image_file_invalid_extensions(self):
"""Test image file detection with invalid extensions"""
non_image_files = ["test.mp4", "test.txt", "test.pdf", "test.exe"]
for non_image_file in non_image_files:
result = haven_vlm_utility.is_image_file(non_image_file)
self.assertFalse(result, f"Failed for {non_image_file}")
class TestFormattingFunctions(unittest.TestCase):
"""Test cases for formatting functions"""
def test_format_duration_seconds(self):
"""Test formatting duration in seconds"""
result = haven_vlm_utility.format_duration(45.5)
self.assertEqual(result, "45.5s")
def test_format_duration_minutes(self):
"""Test formatting duration in minutes"""
result = haven_vlm_utility.format_duration(125.3)
self.assertEqual(result, "2m 5.3s")
def test_format_duration_hours(self):
"""Test formatting duration in hours"""
result = haven_vlm_utility.format_duration(7325.7)
self.assertEqual(result, "2h 2m 5.7s")
def test_format_duration_zero(self):
"""Test formatting zero duration"""
result = haven_vlm_utility.format_duration(0)
self.assertEqual(result, "0.0s")
def test_format_file_size_bytes(self):
"""Test formatting file size in bytes"""
result = haven_vlm_utility.format_file_size(512)
self.assertEqual(result, "512.0 B")
def test_format_file_size_kilobytes(self):
"""Test formatting file size in kilobytes"""
result = haven_vlm_utility.format_file_size(1536)
self.assertEqual(result, "1.5 KB")
def test_format_file_size_megabytes(self):
"""Test formatting file size in megabytes"""
result = haven_vlm_utility.format_file_size(1572864)
self.assertEqual(result, "1.5 MB")
def test_format_file_size_gigabytes(self):
"""Test formatting file size in gigabytes"""
result = haven_vlm_utility.format_file_size(1610612736)
self.assertEqual(result, "1.5 GB")
def test_format_file_size_zero(self):
"""Test formatting zero file size"""
result = haven_vlm_utility.format_file_size(0)
self.assertEqual(result, "0.0 B")
class TestSanitizationFunctions(unittest.TestCase):
"""Test cases for sanitization functions"""
def test_sanitize_filename_valid(self):
"""Test sanitizing a valid filename"""
result = haven_vlm_utility.sanitize_filename("valid_filename.txt")
self.assertEqual(result, "valid_filename.txt")
def test_sanitize_filename_invalid_chars(self):
"""Test sanitizing filename with invalid characters"""
result = haven_vlm_utility.sanitize_filename("file<name>:with/invalid\\chars|?*")
self.assertEqual(result, "file_name__with_invalid_chars___")
def test_sanitize_filename_leading_trailing_spaces(self):
"""Test sanitizing filename with leading/trailing spaces"""
result = haven_vlm_utility.sanitize_filename(" filename.txt ")
self.assertEqual(result, "filename.txt")
def test_sanitize_filename_leading_trailing_dots(self):
"""Test sanitizing filename with leading/trailing dots"""
result = haven_vlm_utility.sanitize_filename("...filename.txt...")
self.assertEqual(result, "filename.txt")
def test_sanitize_filename_empty(self):
"""Test sanitizing empty filename"""
result = haven_vlm_utility.sanitize_filename("")
self.assertEqual(result, "unnamed")
def test_sanitize_filename_only_spaces(self):
"""Test sanitizing filename with only spaces"""
result = haven_vlm_utility.sanitize_filename(" ")
self.assertEqual(result, "unnamed")
class TestBackupFunctions(unittest.TestCase):
"""Test cases for backup functions"""
def test_create_backup_file_success(self):
"""Test successfully creating a backup file"""
with tempfile.NamedTemporaryFile(delete=False) as f:
original_file = f.name
f.write(b"test content")
try:
result = haven_vlm_utility.create_backup_file(original_file)
self.assertIsNotNone(result)
self.assertTrue(os.path.exists(result))
self.assertTrue(result.endswith(".backup"))
# Verify backup content
with open(result, 'rb') as f:
content = f.read()
self.assertEqual(content, b"test content")
# Clean up backup
os.unlink(result)
finally:
os.unlink(original_file)
def test_create_backup_file_custom_suffix(self):
"""Test creating backup file with custom suffix"""
with tempfile.NamedTemporaryFile(delete=False) as f:
original_file = f.name
f.write(b"test content")
try:
result = haven_vlm_utility.create_backup_file(original_file, ".custom")
self.assertIsNotNone(result)
self.assertTrue(result.endswith(".custom"))
# Clean up backup
os.unlink(result)
finally:
os.unlink(original_file)
def test_create_backup_file_nonexistent(self):
"""Test creating backup of nonexistent file"""
result = haven_vlm_utility.create_backup_file("nonexistent_file.txt")
self.assertIsNone(result)
def test_create_backup_file_permission_error(self):
"""Test creating backup file with permission error"""
with patch('shutil.copy2', side_effect=PermissionError("Permission denied")):
with tempfile.NamedTemporaryFile(delete=False) as f:
original_file = f.name
try:
result = haven_vlm_utility.create_backup_file(original_file)
self.assertIsNone(result)
finally:
os.unlink(original_file)
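# --- Hedged reference sketch (editor's addition) ---
# Backup helper consistent with the tests above: copy the source to
# <path><suffix> with shutil.copy2 and return the backup path, or None when
# the source is missing or the copy fails. Illustrative only.
def _sketch_create_backup_file(file_path, suffix=".backup"):
    if not os.path.isfile(file_path):
        return None
    backup_path = file_path + suffix
    try:
        shutil.copy2(file_path, backup_path)
        return backup_path
    except OSError:
        return None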
class TestDictionaryOperations(unittest.TestCase):
"""Test cases for dictionary operations"""
def test_merge_dictionaries_simple(self):
"""Test simple dictionary merging"""
dict1 = {"a": 1, "b": 2}
dict2 = {"c": 3, "d": 4}
result = haven_vlm_utility.merge_dictionaries(dict1, dict2)
expected = {"a": 1, "b": 2, "c": 3, "d": 4}
self.assertEqual(result, expected)
def test_merge_dictionaries_overwrite(self):
"""Test dictionary merging with overwrite"""
dict1 = {"a": 1, "b": 2}
dict2 = {"b": 3, "c": 4}
result = haven_vlm_utility.merge_dictionaries(dict1, dict2, overwrite=True)
expected = {"a": 1, "b": 3, "c": 4}
self.assertEqual(result, expected)
def test_merge_dictionaries_no_overwrite(self):
"""Test dictionary merging without overwrite"""
dict1 = {"a": 1, "b": 2}
dict2 = {"b": 3, "c": 4}
result = haven_vlm_utility.merge_dictionaries(dict1, dict2, overwrite=False)
expected = {"a": 1, "b": 2, "c": 4}
self.assertEqual(result, expected)
def test_merge_dictionaries_nested(self):
"""Test merging nested dictionaries"""
dict1 = {"a": 1, "b": {"x": 10, "y": 20}}
dict2 = {"c": 3, "b": {"y": 25, "z": 30}}
result = haven_vlm_utility.merge_dictionaries(dict1, dict2, overwrite=True)
expected = {"a": 1, "b": {"x": 10, "y": 25, "z": 30}, "c": 3}
self.assertEqual(result, expected)
def test_merge_dictionaries_empty(self):
"""Test merging with empty dictionaries"""
dict1 = {}
dict2 = {"a": 1, "b": 2}
result = haven_vlm_utility.merge_dictionaries(dict1, dict2)
self.assertEqual(result, dict2)
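# --- Hedged reference sketch (editor's addition) ---
# A merge consistent with the cases above: dict2 is folded into a copy of
# dict1, nested dicts are merged recursively, and the overwrite flag decides
# who wins on scalar conflicts. The default of overwrite=True is an assumption.
def _sketch_merge_dictionaries(dict1, dict2, overwrite=True):
    result = dict(dict1)
    for key, value in dict2.items():
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
            # Nested mappings merge recursively rather than being replaced wholesale.
            result[key] = _sketch_merge_dictionaries(result[key], value, overwrite)
        elif key not in result or overwrite:
            result[key] = value
    return result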
class TestListOperations(unittest.TestCase):
"""Test cases for list operations"""
def test_chunk_list_even_chunks(self):
"""Test chunking list into even chunks"""
lst = [1, 2, 3, 4, 5, 6]
result = haven_vlm_utility.chunk_list(lst, 2)
expected = [[1, 2], [3, 4], [5, 6]]
self.assertEqual(result, expected)
def test_chunk_list_uneven_chunks(self):
"""Test chunking list into uneven chunks"""
lst = [1, 2, 3, 4, 5]
result = haven_vlm_utility.chunk_list(lst, 2)
expected = [[1, 2], [3, 4], [5]]
self.assertEqual(result, expected)
def test_chunk_list_empty_list(self):
"""Test chunking empty list"""
lst = []
result = haven_vlm_utility.chunk_list(lst, 3)
expected = []
self.assertEqual(result, expected)
def test_chunk_list_chunk_size_larger_than_list(self):
"""Test chunking when chunk size is larger than list"""
lst = [1, 2, 3]
result = haven_vlm_utility.chunk_list(lst, 5)
expected = [[1, 2, 3]]
self.assertEqual(result, expected)
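# --- Hedged reference sketch (editor's addition) ---
# Chunking as the tests describe it: fixed-size slices with a shorter final
# chunk, and an empty list chunking to an empty list. Illustrative only.
def _sketch_chunk_list(lst, chunk_size):
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]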
class TestRetryOperations(unittest.TestCase):
"""Test cases for retry operations"""
def test_retry_operation_success_first_try(self):
"""Test retry operation that succeeds on first try"""
def test_func():
return "success"
result = haven_vlm_utility.retry_operation(test_func)
self.assertEqual(result, "success")
def test_retry_operation_success_after_retries(self):
"""Test retry operation that succeeds after some retries"""
call_count = 0
def test_func():
nonlocal call_count
call_count += 1
if call_count < 3:
raise ValueError("Temporary error")
return "success"
result = haven_vlm_utility.retry_operation(test_func, max_retries=3, delay=0.1)
self.assertEqual(result, "success")
self.assertEqual(call_count, 3)
def test_retry_operation_all_retries_fail(self):
"""Test retry operation that fails all retries"""
def test_func():
raise ValueError("Persistent error")
result = haven_vlm_utility.retry_operation(test_func, max_retries=2, delay=0.1)
self.assertIsNone(result)
def test_retry_operation_with_arguments(self):
"""Test retry operation with function arguments"""
def test_func(a, b, c=10):
return a + b + c
result = haven_vlm_utility.retry_operation(test_func, 1, 2, max_retries=1, delay=0.1, c=5)
self.assertEqual(result, 8)
def test_retry_operation_with_keyword_arguments(self):
"""Test retry operation with keyword arguments"""
def test_func(**kwargs):
return kwargs.get('value', 0)
result = haven_vlm_utility.retry_operation(test_func, max_retries=1, delay=0.1, value=42)
self.assertEqual(result, 42)
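# --- Hedged reference sketch (editor's addition) ---
# Retry wrapper matching the call shapes above: positional and keyword
# arguments are forwarded, max_retries bounds the number of attempts, and the
# function returns None once every attempt has failed. Treating max_retries as
# the total attempt count is an assumption.
def _sketch_retry_operation(func, *args, max_retries=3, delay=1.0, **kwargs):
    for attempt in range(max_retries):
        try:
            return func(*args, **kwargs)
        except Exception:
            if attempt < max_retries - 1:
                # Back off briefly before the next attempt.
                time.sleep(delay)
    return None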
if __name__ == '__main__':
unittest.main()