mirror of
https://github.com/stashapp/CommunityScripts.git
synced 2026-02-04 01:52:30 -06:00
initial commit of AHavenVLMConnector (#657)
Co-authored-by: DogmaDragon <103123951+DogmaDragon@users.noreply.github.com>
This commit is contained in:
parent 69e44b2099
commit 2a0719091c
8
plugins/AHavenVLMConnector/CHANGELOG.md
Normal file
@@ -0,0 +1,8 @@
# Changelog

All notable changes to the A Haven VLM Connector project will be documented in this file.

## [1.0.0] - 2025-06-29

### Added

- **Initial release**
143
plugins/AHavenVLMConnector/README.md
Normal file
@@ -0,0 +1,143 @@
# A Haven VLM Connector

A StashApp plugin for Vision-Language Model (VLM) based content tagging and analysis. This plugin is designed with a **local-first philosophy**, empowering users to run analysis on their own hardware (using CPU or GPU) and their local network. It also supports cloud-based VLM endpoints for additional flexibility. The Haven VLM Engine provides advanced automatic content detection and tagging, delivering superior accuracy compared to traditional image classification methods.

## Features

- **Local Network Empowerment**: Distribute processing across home/office computers without cloud dependencies
- **Context-Aware Detection**: Leverages Vision-Language Models' understanding of visual relationships
- **Advanced Dependency Management**: Uses PythonDepManager for automatic dependency installation
- **Enjoying Funscript Haven?** Check out more tools and projects at https://github.com/Haven-hvn

## Requirements

- Python 3.8+
- StashApp
- PythonDepManager plugin (automatically handles dependencies)
- OpenAI-compatible VLM endpoints (local or cloud-based)

## Installation

1. Clone or download this plugin to your StashApp plugins directory
2. Ensure PythonDepManager is installed in your StashApp plugins
3. Configure your VLM endpoints in `haven_vlm_config.py` (local network endpoints recommended)
4. Restart StashApp

The plugin automatically manages all dependencies.

## Why Local-First?

- **Complete Control**: Process sensitive content on your own hardware
- **Cost Effective**: Avoid cloud processing fees by using existing resources
- **Flexible Scaling**: Add more computers to your local network for increased capacity
- **Privacy Focused**: Keep your media completely private
- **Hybrid Options**: Combine local and cloud endpoints for optimal flexibility

```mermaid
graph LR
    A[User's Computer] --> B[Local GPU Machine]
    A --> C[Local CPU Machine 1]
    A --> D[Local CPU Machine 2]
    A --> E[Cloud Endpoint]
```
## Configuration

### Easy Setup with LM Studio

[LM Studio](https://lmstudio.ai/) provides the easiest way to configure local endpoints:

1. Download and install [LM Studio](https://lmstudio.ai/)
2. [Search for or download](https://huggingface.co/models) a vision-capable model; tested with (in order of decreasing accuracy): zai-org/glm-4.6v-flash, huihui-mistral-small-3.2-24b-instruct-2506-abliterated-v2, qwen/qwen3-vl-8b, lfm2.5-vl
3. Load your desired model
4. On the Developer tab, start the local server using the start toggle
5. Optionally, click the Settings gear and toggle *Serve on local network*
6. Optionally, configure `haven_vlm_config.py`:

By default, localhost is included in the config. **Remove the cloud endpoint if you don't want automatic failover.**

```python
{
    "base_url": "http://localhost:1234/v1",  # LM Studio default
    "api_key": "",  # API key not required
    "name": "lm-studio-local",
    "weight": 5,
    "is_fallback": False
}
```
### Tag Configuration

```python
"tag_list": [
    "Basketball point", "Foul", "Break-away", "Turnover"
]
```
### Processing Settings

```python
VIDEO_FRAME_INTERVAL = 2.0   # Process every 2 seconds
CONCURRENT_TASK_LIMIT = 8    # Adjust based on local hardware
```
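As a rough sizing guide, `VIDEO_FRAME_INTERVAL` determines how many frames each video contributes to analysis. A minimal sketch (the helper below is illustrative, not part of the plugin):

```python
def estimate_sampled_frames(duration_seconds: float, frame_interval: float = 2.0) -> int:
    """Estimate how many frames a video contributes at a given sampling interval."""
    if frame_interval <= 0:
        raise ValueError("frame_interval must be positive")
    # One frame at t=0, then one every frame_interval seconds
    return int(duration_seconds // frame_interval) + 1

# A 10-minute video sampled every 2 seconds contributes 301 frames
print(estimate_sampled_frames(600, 2.0))
```

Halving the interval roughly doubles the number of requests sent to your endpoints, so tune it together with `CONCURRENT_TASK_LIMIT`.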
## Usage

### Tag Videos

1. Tag scenes with `VLM_TagMe`
2. Run "Tag Videos" task
3. Plugin processes content using local/network resources
### Performance Tips

- Start with 2-3 local machines for load balancing
- Assign higher weights to GPU-enabled machines
- Adjust `CONCURRENT_TASK_LIMIT` based on total system resources
- Use SSD storage for better I/O performance
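Endpoint weights behave like a weighted load balancer: higher-weight machines receive proportionally more requests. A sketch of the idea (illustrative only; the actual multiplexing is handled by the Haven VLM Engine):

```python
import random

# Hypothetical endpoint list mirroring the config format
endpoints = [
    {"name": "gpu-box", "weight": 9},
    {"name": "cpu-box-1", "weight": 3},
    {"name": "cpu-box-2", "weight": 3},
]

def pick_endpoint() -> dict:
    """Choose an endpoint with probability proportional to its weight."""
    return random.choices(endpoints, weights=[e["weight"] for e in endpoints], k=1)[0]

# Over many requests, "gpu-box" handles about 9/15 (60%) of the load
print(pick_endpoint()["name"])
```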
## File Structure

```
AHavenVLMConnector/
├── ahavenvlmconnector.yml
├── haven_vlm_connector.py
├── haven_vlm_config.py
├── haven_vlm_engine.py
├── haven_media_handler.py
├── haven_vlm_utility.py
├── requirements.txt
└── README.md
```
## Troubleshooting

### Local Network Setup

- Ensure firewalls allow communication between machines
- Verify all local endpoints are running VLM services
- Use static IPs for local machines
- Check that `http://local-machine-ip:port/v1` responds correctly
### Performance Optimization

- **Distribute Load**: Use multiple mid-range machines instead of one high-end machine
- **GPU Prioritization**: Assign the highest weight to GPU machines
- **Network Speed**: Use wired Ethernet connections for faster transfer
- **Resource Monitoring**: Watch system resources during processing
## Development

### Adding Local Endpoints

1. Install a VLM service on your network machines
2. Add endpoint configurations with local IPs
3. Set appropriate weights based on hardware capability
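For step 2, a new endpoint is just another entry in the `multiplexer_endpoints` list in `haven_vlm_config.py`; the IP, name, and weight below are placeholder values:

```python
# Hypothetical extra endpoint for a LAN machine (placeholder values)
new_endpoint = {
    "base_url": "http://192.168.1.50:1234/v1",
    "api_key": "",
    "name": "lan-gpu-box",
    "weight": 8,           # higher weight = more requests routed here
    "is_fallback": False,
    "max_concurrent": 8,
}
```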
### Custom Models

Use any OpenAI-compatible model that supports:

- POST requests to `/v1/chat/completions`
- Vision capabilities with image input
- Local deployment options
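The compatibility requirement above boils down to accepting the OpenAI chat-completions vision message shape. A minimal sketch of building (not sending) such a request; the model name and image bytes are placeholders:

```python
import base64
import json

def build_vision_request(model: str, prompt: str, image_bytes: bytes) -> dict:
    """Build an OpenAI-compatible /v1/chat/completions payload with one image."""
    b64 = base64.b64encode(image_bytes).decode("ascii")
    return {
        "model": model,
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
            ],
        }],
    }

payload = build_vision_request("some-vision-model", "Describe this frame.", b"\xff\xd8\xff")
print(json.dumps(payload)[:80])
```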
### Log Messages

Check StashApp logs for detailed processing information and error messages.
## License

This project is part of the StashApp Community Scripts collection.
22
plugins/AHavenVLMConnector/ahavenvlmconnector.yml
Normal file
@@ -0,0 +1,22 @@
name: Haven VLM Connector
# requires: PythonDepManager
description: Tag videos with Vision-Language Models using any OpenAI-compatible VLM endpoint
version: 1.0.0
url: https://github.com/stashapp/CommunityScripts/tree/main/plugins/AHavenVLMConnector
exec:
  - python
  - "{pluginDir}/haven_vlm_connector.py"
interface: raw
tasks:
  - name: Tag Videos
    description: Run VLM analysis on videos with the VLM_TagMe tag
    defaultArgs:
      mode: tag_videos
  - name: Collect Incorrect Markers and Images
    description: Collects data from markers and images that were VLM tagged but were manually marked with VLM_Incorrect due to the VLM making a mistake. This will collect the data and output it as a file which can be used to improve the VLM models.
    defaultArgs:
      mode: collect_incorrect_markers
  - name: Find Marker Settings
    description: Find optimal marker settings based on a video that has manually tuned markers and has been processed by the VLM previously. Only one video should have VLM_TagMe before running.
    defaultArgs:
      mode: find_marker_settings
98
plugins/AHavenVLMConnector/exit_tracker.py
Normal file
@@ -0,0 +1,98 @@
"""
Comprehensive sys.exit tracking module
Instruments all sys.exit() calls with full call stack and context
"""

import sys
import traceback
from typing import Optional

# Store original sys.exit
original_exit = sys.exit

# Track if we've already patched
_exit_tracker_patched = False


def install_exit_tracker(logger=None) -> None:
    """
    Install the exit tracker by monkey-patching sys.exit

    Args:
        logger: Optional logger instance (will use fallback print if None)
    """
    global _exit_tracker_patched, original_exit

    if _exit_tracker_patched:
        return

    # Store original if not already stored
    if hasattr(sys, 'exit') and sys.exit is not original_exit:
        original_exit = sys.exit

    def tracked_exit(code: int = 0) -> None:
        """Track sys.exit() calls with full call stack"""
        # Get current stack trace (not from exception, but current call stack)
        stack = traceback.extract_stack()

        # Format the stack trace, excluding this tracking function
        stack_lines = []
        for frame in stack:
            # Skip internal Python frames and this tracker
            if ('tracked_exit' not in frame.filename and
                    '/usr/lib' not in frame.filename and
                    '/System/Library' not in frame.filename and
                    'exit_tracker.py' not in frame.filename):
                stack_lines.append(
                    f"  File \"{frame.filename}\", line {frame.lineno}, in {frame.name}\n    {frame.line}"
                )

        # Take last 15 frames to see the full call chain
        stack_str = '\n'.join(stack_lines[-15:])

        # Get current exception info if available
        exc_info = sys.exc_info()
        exc_str = ""
        if exc_info[0] is not None:
            exc_str = f"\n  Active Exception: {exc_info[0].__name__}: {exc_info[1]}"

        # Build the error message
        error_msg = f"""[DEBUG_EXIT_CODE] ==========================================
[DEBUG_EXIT_CODE] sys.exit() called with code: {code}
[DEBUG_EXIT_CODE] Call stack (last 15 frames):
{stack_str}
{exc_str}
[DEBUG_EXIT_CODE] =========================================="""

        # Log using provided logger or fallback to print
        if logger:
            try:
                logger.error(error_msg)
            except Exception as log_error:
                print(f"[EXIT_TRACKER_LOGGER_ERROR] Failed to log: {log_error}")
                print(error_msg)
        else:
            print(error_msg)

        # Call original exit
        original_exit(code)

    # Install the tracker
    sys.exit = tracked_exit
    _exit_tracker_patched = True

    if logger:
        logger.debug("[DEBUG_EXIT_CODE] Exit tracker installed successfully")
    else:
        print("[DEBUG_EXIT_CODE] Exit tracker installed successfully")


def uninstall_exit_tracker() -> None:
    """Uninstall the exit tracker and restore original sys.exit"""
    global _exit_tracker_patched, original_exit

    if _exit_tracker_patched:
        sys.exit = original_exit
        _exit_tracker_patched = False


# Auto-install on import (can be disabled by calling uninstall_exit_tracker())
if not _exit_tracker_patched:
    install_exit_tracker()
333
plugins/AHavenVLMConnector/haven_media_handler.py
Normal file
@@ -0,0 +1,333 @@
"""
Haven Media Handler Module
Handles StashApp media operations and tag management
"""

import os
import zipfile
import shutil
from typing import List, Dict, Any, Optional, Tuple, Set
from datetime import datetime
import json

# Use PythonDepManager for dependency management
try:
    from PythonDepManager import ensure_import
    ensure_import("stashapi:stashapp-tools==0.2.58")

    from stashapi.stashapp import StashInterface, StashVersion
    import stashapi.log as log
except ImportError as e:
    print(f"stashapp-tools not found: {e}")
    print("Please ensure PythonDepManager is available and stashapp-tools is accessible")
    raise

import haven_vlm_config as config

# Global variables
tag_id_cache: Dict[str, int] = {}
vlm_tag_ids_cache: Set[int] = set()
stash_version: Optional[StashVersion] = None
end_seconds_support: bool = False

# Tag IDs
stash: Optional[StashInterface] = None
vlm_errored_tag_id: Optional[int] = None
vlm_tagme_tag_id: Optional[int] = None
vlm_base_tag_id: Optional[int] = None
vlm_tagged_tag_id: Optional[int] = None
vr_tag_id: Optional[int] = None
vlm_incorrect_tag_id: Optional[int] = None


def initialize(connection: Dict[str, Any]) -> None:
    """Initialize the media handler with StashApp connection"""
    global stash, vlm_errored_tag_id, vlm_tagme_tag_id, vlm_base_tag_id
    global vlm_tagged_tag_id, vr_tag_id, end_seconds_support, stash_version
    global vlm_incorrect_tag_id

    # Initialize the Stash API
    stash = StashInterface(connection)

    # Initialize "metadata" tags
    vlm_errored_tag_id = stash.find_tag(config.config.vlm_errored_tag_name, create=True)["id"]
    vlm_tagme_tag_id = stash.find_tag(config.config.vlm_tagme_tag_name, create=True)["id"]
    vlm_base_tag_id = stash.find_tag(config.config.vlm_base_tag_name, create=True)["id"]
    vlm_tagged_tag_id = stash.find_tag(config.config.vlm_tagged_tag_name, create=True)["id"]
    vlm_incorrect_tag_id = stash.find_tag(config.config.vlm_incorrect_tag_name, create=True)["id"]

    # Get VR tag from configuration
    vr_tag_name = stash.get_configuration()["ui"].get("vrTag", None)
    if not vr_tag_name:
        log.warning("No VR tag found in configuration")
        vr_tag_id = None
    else:
        vr_tag_id = stash.find_tag(vr_tag_name)["id"]

    stash_version = get_stash_version()
    end_second_support_beyond = StashVersion("v0.27.2-76648")
    end_seconds_support = stash_version > end_second_support_beyond


def get_stash_version() -> StashVersion:
    """Get the current StashApp version"""
    if not stash:
        raise RuntimeError("Stash interface not initialized")
    return stash.stash_version()


# ----------------- Tag Management Methods -----------------

def get_tag_ids(tag_names: List[str], create: bool = False) -> List[int]:
    """Get tag IDs for multiple tag names"""
    return [get_tag_id(tag_name, create) for tag_name in tag_names]


def get_tag_id(tag_name: str, create: bool = False) -> Optional[int]:
    """Get tag ID for a single tag name"""
    if tag_name not in tag_id_cache:
        stashtag = stash.find_tag(tag_name)
        if stashtag:
            tag_id_cache[tag_name] = stashtag["id"]
            return stashtag["id"]
        else:
            if not create:
                return None
            tag = stash.create_tag({
                "name": tag_name,
                "ignore_auto_tag": True,
                "parent_ids": [vlm_base_tag_id]
            })['id']
            tag_id_cache[tag_name] = tag
            vlm_tag_ids_cache.add(tag)
            return tag
    return tag_id_cache.get(tag_name)


def get_vlm_tags() -> List[int]:
    """Get all VLM-generated tags"""
    if len(vlm_tag_ids_cache) == 0:
        vlm_tags = [
            item['id'] for item in stash.find_tags(
                f={"parents": {"value": vlm_base_tag_id, "modifier": "INCLUDES"}},
                fragment="id"
            )
        ]
        vlm_tag_ids_cache.update(vlm_tags)
    else:
        vlm_tags = list(vlm_tag_ids_cache)
    return vlm_tags


def is_scene_tagged(tags: List[Dict[str, Any]]) -> bool:
    """Check if a scene has been tagged by VLM"""
    for tag in tags:
        if tag['id'] == vlm_tagged_tag_id:
            return True
    return False


def is_vr_scene(tags: List[Dict[str, Any]]) -> bool:
    """Check if a scene is VR content"""
    for tag in tags:
        if tag['id'] == vr_tag_id:
            return True
    return False


# ----------------- Scene Management Methods -----------------

def add_tags_to_video(video_id: int, tag_ids: List[int], add_tagged: bool = True) -> None:
    """Add tags to a video scene"""
    if add_tagged:
        tag_ids.append(vlm_tagged_tag_id)
    stash.update_scenes({
        "ids": [video_id],
        "tag_ids": {"ids": tag_ids, "mode": "ADD"}
    })


def clear_all_tags_from_video(scene: Dict[str, Any]) -> None:
    """Clear all tags from a video scene using existing scene data"""
    scene_id = scene.get('id')
    if scene_id is None:
        log.error("Scene missing 'id' field")
        return

    current_tag_ids = [tag['id'] for tag in scene.get('tags', [])]
    if current_tag_ids:
        stash.update_scenes({
            "ids": [scene_id],
            "tag_ids": {"ids": current_tag_ids, "mode": "REMOVE"}
        })
        log.info(f"Cleared {len(current_tag_ids)} tags from scene {scene_id}")


def clear_all_markers_from_video(video_id: int) -> None:
    """Clear all markers from a video scene"""
    markers = get_scene_markers(video_id)
    if markers:
        delete_markers(markers)
        log.info(f"Cleared all {len(markers)} markers from scene {video_id}")


def remove_vlm_tags_from_video(
    video_id: int,
    remove_tagme: bool = True,
    remove_errored: bool = True
) -> None:
    """Remove all VLM tags from a video scene"""
    vlm_tags = get_vlm_tags()
    if remove_tagme:
        vlm_tags.append(vlm_tagme_tag_id)
    if remove_errored:
        vlm_tags.append(vlm_errored_tag_id)
    stash.update_scenes({
        "ids": [video_id],
        "tag_ids": {"ids": vlm_tags, "mode": "REMOVE"}
    })


def get_tagme_scenes() -> List[Dict[str, Any]]:
    """Get scenes tagged with VLM_TagMe"""
    return stash.find_scenes(
        f={"tags": {"value": vlm_tagme_tag_id, "modifier": "INCLUDES"}},
        fragment="id tags {id} files {path duration fingerprint(type: \"phash\")}"
    )


def add_error_scene(scene_id: int) -> None:
    """Add error tag to a scene"""
    stash.update_scenes({
        "ids": [scene_id],
        "tag_ids": {"ids": [vlm_errored_tag_id], "mode": "ADD"}
    })


def remove_tagme_tag_from_scene(scene_id: int) -> None:
    """Remove VLM_TagMe tag from a scene"""
    stash.update_scenes({
        "ids": [scene_id],
        "tag_ids": {"ids": [vlm_tagme_tag_id], "mode": "REMOVE"}
    })


# ----------------- Marker Management Methods -----------------

def add_markers_to_video_from_dict(
    video_id: int,
    tag_timespans_dict: Dict[str, Dict[str, List[Any]]]
) -> None:
    """Add markers to video from timespan dictionary"""
    for _, tag_timespan_dict in tag_timespans_dict.items():
        for tag_name, time_frames in tag_timespan_dict.items():
            tag_id = get_tag_id(tag_name, create=True)
            if tag_id:
                add_markers_to_video(video_id, tag_id, tag_name, time_frames)


def get_incorrect_markers() -> List[Dict[str, Any]]:
    """Get markers tagged with VLM_Incorrect"""
    if end_seconds_support:
        return stash.find_scene_markers(
            {"tags": {"value": vlm_incorrect_tag_id, "modifier": "INCLUDES"}},
            fragment="id scene {id files{path}} primary_tag {id, name} seconds end_seconds"
        )
    else:
        return stash.find_scene_markers(
            {"tags": {"value": vlm_incorrect_tag_id, "modifier": "INCLUDES"}},
            fragment="id scene {id files{path}} primary_tag {id, name} seconds"
        )


def add_markers_to_video(
    video_id: int,
    tag_id: int,
    tag_name: str,
    time_frames: List[Any]
) -> None:
    """Add markers to video for specific time frames"""
    for time_frame in time_frames:
        if end_seconds_support:
            stash.create_scene_marker({
                "scene_id": video_id,
                "primary_tag_id": tag_id,
                "tag_ids": [tag_id],
                "seconds": time_frame.start,
                "end_seconds": time_frame.end,
                "title": tag_name
            })
        else:
            stash.create_scene_marker({
                "scene_id": video_id,
                "primary_tag_id": tag_id,
                "tag_ids": [tag_id],
                "seconds": time_frame.start,
                "title": tag_name
            })


def get_scene_markers(video_id: int) -> List[Dict[str, Any]]:
    """Get all markers for a scene"""
    return stash.get_scene_markers(video_id)


def write_scene_marker_to_file(
    marker: Dict[str, Any],
    scene_file: str,
    output_folder: str
) -> None:
    """Write scene marker data to file for analysis"""
    try:
        marker_id = marker['id']
        scene_id = marker['scene']['id']
        tag_name = marker['primary_tag']['name']

        # Create output filename
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"marker_{marker_id}_scene_{scene_id}_{tag_name}_{timestamp}.json"
        output_path = os.path.join(output_folder, filename)

        # Prepare marker data
        marker_data = {
            "marker_id": marker_id,
            "scene_id": scene_id,
            "tag_name": tag_name,
            "seconds": marker.get("seconds"),
            "end_seconds": marker.get("end_seconds"),
            "scene_file": scene_file,
            "timestamp": timestamp
        }

        # Write to file
        with open(output_path, 'w') as f:
            json.dump(marker_data, f, indent=2)

    except Exception as e:
        log.error(f"Failed to write marker data: {e}")


def delete_markers(markers: List[Dict[str, Any]]) -> None:
    """Delete markers from StashApp"""
    for marker in markers:
        try:
            stash.destroy_scene_marker(marker['id'])
        except Exception as e:
            log.error(f"Failed to delete marker {marker['id']}: {e}")


def get_scene_markers_by_tag(
    video_id: int,
    error_if_no_end_seconds: bool = True
) -> List[Dict[str, Any]]:
    """Get scene markers by tag with end_seconds support check"""
    if end_seconds_support:
        return stash.get_scene_markers(video_id)
    else:
        if error_if_no_end_seconds:
            log.error("End seconds not supported in this StashApp version")
            raise RuntimeError("End seconds not supported")
        return stash.get_scene_markers(video_id)


def remove_incorrect_tag_from_markers(markers: List[Dict[str, Any]]) -> None:
    """Remove VLM_Incorrect tag from markers"""
    marker_ids = [marker['id'] for marker in markers]
    for marker_id in marker_ids:
        try:
            stash.update_scene_marker({
                "id": marker_id,
                "tag_ids": {"ids": [vlm_incorrect_tag_id], "mode": "REMOVE"}
            })
        except Exception as e:
            log.error(f"Failed to remove incorrect tag from marker {marker_id}: {e}")


def remove_vlm_markers_from_video(video_id: int) -> None:
    """Remove all VLM markers from a video"""
    markers = get_scene_markers(video_id)
    vlm_tag_ids = get_vlm_tags()

    for marker in markers:
        if marker['primary_tag']['id'] in vlm_tag_ids:
            try:
                stash.destroy_scene_marker(marker['id'])
            except Exception as e:
                log.error(f"Failed to delete VLM marker {marker['id']}: {e}")
445
plugins/AHavenVLMConnector/haven_vlm_config.py
Normal file
@@ -0,0 +1,445 @@
"""
Configuration for A Haven VLM Connector
A StashApp plugin for Vision-Language Model based content tagging
"""

from typing import Dict, List, Optional
from dataclasses import dataclass
import os
import yaml

# ----------------- Core Settings -----------------

# VLM Engine Configuration
VLM_ENGINE_CONFIG = {
    "active_ai_models": ["vlm_multiplexer_model"],
    "pipelines": {
        "video_pipeline_dynamic": {
            "inputs": [
                "video_path",
                "return_timestamps",
                "time_interval",
                "threshold",
                "return_confidence",
                "vr_video",
                "existing_video_data",
                "skipped_categories",
            ],
            "output": "results",
            "short_name": "dynamic_video",
            "version": 1.0,
            "models": [
                {
                    "name": "dynamic_video_ai",
                    "inputs": [
                        "video_path", "return_timestamps", "time_interval",
                        "threshold", "return_confidence", "vr_video",
                        "existing_video_data", "skipped_categories"
                    ],
                    "outputs": "results",
                },
            ],
        }
    },
    "models": {
        "binary_search_processor_dynamic": {
            "type": "binary_search_processor",
            "model_file_name": "binary_search_processor_dynamic"
        },
        "vlm_multiplexer_model": {
            "type": "vlm_model",
            "model_file_name": "vlm_multiplexer_model",
            "model_category": "actiondetection",
            "model_id": "zai-org/glm-4.6v-flash",
            "model_identifier": 93848,
            "model_version": "1.0",
            "use_multiplexer": True,
            "max_concurrent_requests": 13,
            "instance_count": 10,
            "max_batch_size": 4,
            "multiplexer_endpoints": [
                {
                    "base_url": "http://localhost:1234/v1",
                    "api_key": "",
                    "name": "lm-studio-primary",
                    "weight": 9,
                    "is_fallback": False,
                    "max_concurrent": 10
                },
                {
                    "base_url": "https://cloudagnostic.com:443/v1",
                    "api_key": "",
                    "name": "cloud-fallback",
                    "weight": 1,
                    "is_fallback": True,
                    "max_concurrent": 2
                }
            ],
            "tag_list": [
                "Anal Fucking", "Ass Licking", "Ass Penetration", "Ball Licking/Sucking", "Blowjob", "Cum on Person",
                "Cum Swapping", "Cumshot", "Deepthroat", "Double Penetration", "Fingering", "Fisting", "Footjob",
                "Gangbang", "Gloryhole", "Grabbing Ass", "Grabbing Boobs", "Grabbing Hair/Head", "Handjob", "Kissing",
                "Licking Penis", "Masturbation", "Pissing", "Pussy Licking (Clearly Visible)", "Pussy Licking",
                "Pussy Rubbing", "Sucking Fingers", "Sucking Toy/Dildo", "Wet (Genitals)", "Titjob", "Tribbing/Scissoring",
                "Undressing", "Vaginal Penetration", "Vaginal Fucking", "Vibrating"
            ]
        },
        "result_coalescer": {
            "type": "python",
            "model_file_name": "result_coalescer"
        },
        "result_finisher": {
            "type": "python",
            "model_file_name": "result_finisher"
        },
        "batch_awaiter": {
            "type": "python",
            "model_file_name": "batch_awaiter"
        },
        "video_result_postprocessor": {
            "type": "python",
            "model_file_name": "video_result_postprocessor"
        },
    },
    "category_config": {
        "actiondetection": {
            "69": {
                "RenamedTag": "69",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Anal Fucking": {
                "RenamedTag": "Anal Fucking",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Ass Licking": {
                "RenamedTag": "Ass Licking",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Ass Penetration": {
                "RenamedTag": "Ass Penetration",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Ball Licking/Sucking": {
                "RenamedTag": "Ball Licking/Sucking",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Blowjob": {
                "RenamedTag": "Blowjob",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Cum on Person": {
                "RenamedTag": "Cum on Person",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Cum Swapping": {
                "RenamedTag": "Cum Swapping",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Cumshot": {
                "RenamedTag": "Cumshot",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Deepthroat": {
                "RenamedTag": "Deepthroat",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Double Penetration": {
                "RenamedTag": "Double Penetration",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Fingering": {
                "RenamedTag": "Fingering",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Fisting": {
                "RenamedTag": "Fisting",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Footjob": {
                "RenamedTag": "Footjob",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Gangbang": {
                "RenamedTag": "Gangbang",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Gloryhole": {
                "RenamedTag": "Gloryhole",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Grabbing Ass": {
                "RenamedTag": "Grabbing Ass",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Grabbing Boobs": {
                "RenamedTag": "Grabbing Boobs",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Grabbing Hair/Head": {
                "RenamedTag": "Grabbing Hair/Head",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Handjob": {
                "RenamedTag": "Handjob",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Kissing": {
                "RenamedTag": "Kissing",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Licking Penis": {
                "RenamedTag": "Licking Penis",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Masturbation": {
                "RenamedTag": "Masturbation",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Pissing": {
                "RenamedTag": "Pissing",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Pussy Licking (Clearly Visible)": {
                "RenamedTag": "Pussy Licking (Clearly Visible)",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Pussy Licking": {
                "RenamedTag": "Pussy Licking",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Pussy Rubbing": {
                "RenamedTag": "Pussy Rubbing",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Sucking Fingers": {
                "RenamedTag": "Sucking Fingers",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Sucking Toy/Dildo": {
                "RenamedTag": "Sucking Toy/Dildo",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Wet (Genitals)": {
                "RenamedTag": "Wet (Genitals)",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Titjob": {
                "RenamedTag": "Titjob",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Tribbing/Scissoring": {
                "RenamedTag": "Tribbing/Scissoring",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Undressing": {
                "RenamedTag": "Undressing",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Vaginal Penetration": {
                "RenamedTag": "Vaginal Penetration",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Vaginal Fucking": {
                "RenamedTag": "Vaginal Fucking",
                "MinMarkerDuration": "1s",
                "MaxGap": "30s",
                "RequiredDuration": "1s",
                "TagThreshold": 0.5,
            },
            "Vibrating": {
                "RenamedTag": "Vibrating",
|
||||
"MinMarkerDuration": "1s",
|
||||
"MaxGap": "30s",
|
||||
"RequiredDuration": "1s",
|
||||
"TagThreshold": 0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
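The per-tag settings above suggest a post-processing pass that merges raw detections into markers: detections for a tag whose gap is at most `MaxGap` are joined, and merged spans shorter than a minimum duration are discarded. The actual merging lives inside the VLM engine; the sketch below is a hypothetical illustration of that idea. `merge_detections` and its parameters are not part of the plugin, and it takes gap/duration values as plain seconds rather than the `"30s"` string form used in the config.

```python
def merge_detections(spans, max_gap=30.0, min_duration=1.0):
    """Illustrative sketch: merge (start, end) detection spans whose gap
    is <= max_gap seconds, then drop merged spans shorter than
    min_duration seconds. Not the engine's real implementation."""
    merged = []
    for start, end in sorted(spans):
        if merged and start - merged[-1][1] <= max_gap:
            # Close enough to the previous span: extend it
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    # Keep only spans long enough to be meaningful markers
    return [(s, e) for s, e in merged if e - s >= min_duration]
```

With the config defaults (`MaxGap` 30s, `RequiredDuration` 1s), two detections 5 seconds apart would collapse into one marker, while an isolated half-second blip would be dropped.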
# ----------------- Processing Settings -----------------

# Video processing settings
VIDEO_FRAME_INTERVAL = 80  # Process one frame every 80 seconds
VIDEO_THRESHOLD = 0.3
VIDEO_CONFIDENCE_RETURN = True

# Concurrency settings
CONCURRENT_TASK_LIMIT = 20  # Increased for better parallel video processing
SERVER_TIMEOUT = 3700

# ----------------- Tag Configuration -----------------

# Tag names for StashApp integration
VLM_BASE_TAG_NAME = "VLM"
VLM_TAGME_TAG_NAME = "VLM_TagMe"
VLM_UPDATEME_TAG_NAME = "VLM_UpdateMe"
VLM_TAGGED_TAG_NAME = "VLM_Tagged"
VLM_ERRORED_TAG_NAME = "VLM_Errored"
VLM_INCORRECT_TAG_NAME = "VLM_Incorrect"

# ----------------- File System Settings -----------------

# Directory paths
OUTPUT_DATA_DIR = "./output_data"

# File management
DELETE_INCORRECT_MARKERS = True
CREATE_MARKERS = True

# Path mutations for different environments
PATH_MUTATION = {}

# ----------------- Configuration Loading -----------------

@dataclass
class VLMConnectorConfig:
    """Configuration class for the VLM Connector"""
    vlm_engine_config: Dict
    video_frame_interval: float
    video_threshold: float
    video_confidence_return: bool
    concurrent_task_limit: int
    server_timeout: int
    vlm_base_tag_name: str
    vlm_tagme_tag_name: str
    vlm_updateme_tag_name: str
    vlm_tagged_tag_name: str
    vlm_errored_tag_name: str
    vlm_incorrect_tag_name: str
    output_data_dir: str
    delete_incorrect_markers: bool
    create_markers: bool
    path_mutation: Dict

def load_config_from_yaml(config_path: Optional[str] = None) -> VLMConnectorConfig:
    """Load configuration from YAML file or use defaults"""
    if config_path and os.path.exists(config_path):
        with open(config_path, 'r') as f:
            yaml_config = yaml.safe_load(f)
        return VLMConnectorConfig(**yaml_config)

    # Return default configuration
    return VLMConnectorConfig(
        vlm_engine_config=VLM_ENGINE_CONFIG,
        video_frame_interval=VIDEO_FRAME_INTERVAL,
        video_threshold=VIDEO_THRESHOLD,
        video_confidence_return=VIDEO_CONFIDENCE_RETURN,
        concurrent_task_limit=CONCURRENT_TASK_LIMIT,
        server_timeout=SERVER_TIMEOUT,
        vlm_base_tag_name=VLM_BASE_TAG_NAME,
        vlm_tagme_tag_name=VLM_TAGME_TAG_NAME,
        vlm_updateme_tag_name=VLM_UPDATEME_TAG_NAME,
        vlm_tagged_tag_name=VLM_TAGGED_TAG_NAME,
        vlm_errored_tag_name=VLM_ERRORED_TAG_NAME,
        vlm_incorrect_tag_name=VLM_INCORRECT_TAG_NAME,
        output_data_dir=OUTPUT_DATA_DIR,
        delete_incorrect_markers=DELETE_INCORRECT_MARKERS,
        create_markers=CREATE_MARKERS,
        path_mutation=PATH_MUTATION
    )

# Global configuration instance
config = load_config_from_yaml()
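Because `load_config_from_yaml` expands the YAML mapping directly into the dataclass with `VLMConnectorConfig(**yaml_config)`, an override file must supply every field. The pattern can be shown with a minimal self-contained sketch; `MiniConfig` and `load_mini` are illustrative stand-ins, not part of the plugin, and `json.loads` substitutes for `yaml.safe_load` here (flow-style JSON is valid YAML) so the sketch needs no third-party packages.

```python
import json
from dataclasses import dataclass

@dataclass
class MiniConfig:
    # Reduced, hypothetical field set for illustration only
    video_threshold: float
    concurrent_task_limit: int

def load_mini(text: str) -> MiniConfig:
    # The plugin uses yaml.safe_load; json.loads keeps this sketch stdlib-only
    return MiniConfig(**json.loads(text))

cfg = load_mini('{"video_threshold": 0.3, "concurrent_task_limit": 20}')
```

A missing key raises `TypeError` from the dataclass constructor, which is why the plugin falls back to the module-level defaults when no config path is given.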
444 plugins/AHavenVLMConnector/haven_vlm_connector.py Normal file
@@ -0,0 +1,444 @@
"""
A Haven VLM Connector
A StashApp plugin for Vision-Language Model based content tagging
"""

import os
import sys
import json
import shutil
import traceback
import asyncio
import logging
import time
from typing import Dict, Any, List, Optional
from datetime import datetime

# Import and install sys.exit tracking FIRST (before any other imports that might call sys.exit)
try:
    from exit_tracker import install_exit_tracker
    import stashapi.log as log
    install_exit_tracker(log)
except ImportError as e:
    print(f"Warning: exit_tracker not available: {e}")
    print("sys.exit tracking will not be available")

# ----------------- Setup and Dependencies -----------------

# Use PythonDepManager for dependency management
try:
    from PythonDepManager import ensure_import

    # Install and ensure all required dependencies with specific versions
    ensure_import(
        "stashapi:stashapp-tools==0.2.58",
        "aiohttp==3.12.13",
        "pydantic==2.11.7",
        "vlm-engine==0.9.1",
        "pyyaml==6.0.2"
    )

    # Import the dependencies after ensuring they're available
    import stashapi.log as log
    from stashapi.stashapp import StashInterface
    import aiohttp
    import pydantic
    import yaml

except ImportError as e:
    print(f"Failed to import PythonDepManager or required dependencies: {e}")
    print("Please ensure PythonDepManager is installed and available.")
    sys.exit(1)
except Exception as e:
    print(f"Error during dependency management: {e}")
    print(f"Stack trace: {traceback.format_exc()}")
    sys.exit(1)

# Import local modules
try:
    import haven_vlm_config as config
except ModuleNotFoundError:
    log.error("Please provide a haven_vlm_config.py file with the required variables.")
    raise Exception("Please provide a haven_vlm_config.py file with the required variables.")

import haven_media_handler as media_handler
import haven_vlm_engine as vlm_engine
from haven_vlm_engine import TimeFrame

log.debug("Python instance is running at: " + sys.executable)

# ----------------- Global Variables -----------------

semaphore: Optional[asyncio.Semaphore] = None
progress: float = 0.0
increment: float = 0.0
completed_tasks: int = 0
total_tasks: int = 0
video_progress: Dict[str, float] = {}

# ----------------- Main Execution -----------------

async def main() -> None:
    """Main entry point for the plugin"""
    global semaphore

    # Semaphore initialization logging for hypothesis A
    log.debug(f"[DEBUG_HYPOTHESIS_A] Initializing semaphore with limit {config.config.concurrent_task_limit}")

    semaphore = asyncio.Semaphore(config.config.concurrent_task_limit)

    # Post-semaphore creation logging
    log.debug(f"[DEBUG_HYPOTHESIS_A] Semaphore created successfully (limit: {config.config.concurrent_task_limit})")

    json_input = read_json_input()
    output = {}
    await run(json_input, output)
    out = json.dumps(output)
    print(out + "\n")

def read_json_input() -> Dict[str, Any]:
    """Read JSON input from stdin"""
    json_input = sys.stdin.read()
    return json.loads(json_input)

async def run(json_input: Dict[str, Any], output: Dict[str, Any]) -> None:
    """Main execution logic"""
    plugin_args = None
    try:
        log.debug(json_input["server_connection"])
        os.chdir(json_input["server_connection"]["PluginDir"])
        media_handler.initialize(json_input["server_connection"])
    except Exception as e:
        log.error(f"Failed to initialize media handler: {e}")
        raise

    try:
        plugin_args = json_input['args']["mode"]
    except KeyError:
        pass

    if plugin_args == "tag_videos":
        await tag_videos()
        output["output"] = "ok"
        return
    elif plugin_args == "find_marker_settings":
        await find_marker_settings()
        output["output"] = "ok"
        return
    elif plugin_args == "collect_incorrect_markers":
        collect_incorrect_markers_and_images()
        output["output"] = "ok"
        return

    output["output"] = "ok"
    return

# ----------------- High Level Processing Functions -----------------

async def tag_videos() -> None:
    """Tag videos with VLM analysis using improved async orchestration"""
    global completed_tasks, total_tasks

    scenes = media_handler.get_tagme_scenes()
    if not scenes:
        log.info("No videos to tag. Have you tagged any scenes with the VLM_TagMe tag so they can be processed?")
        return

    total_tasks = len(scenes)
    completed_tasks = 0

    video_progress.clear()
    for scene in scenes:
        video_progress[scene.get('id', 'unknown')] = 0.0
    log.progress(0.0)

    log.info(f"🚀 Starting video processing for {total_tasks} scenes with semaphore limit of {config.config.concurrent_task_limit}")

    # Create tasks with proper indexing for debugging
    tasks = []
    for i, scene in enumerate(scenes):
        # Pre-task creation logging for hypothesis A (semaphore deadlock) and E (signal termination)
        scene_id = scene.get('id')
        log.debug(f"[DEBUG_HYPOTHESIS_A] Creating task {i+1}/{total_tasks} for scene {scene_id}, semaphore limit: {config.config.concurrent_task_limit}")

        task = asyncio.create_task(__tag_video_with_timing(scene, i))
        tasks.append(task)

    # Use asyncio.as_completed to process results as they finish (proves concurrency)
    completed_task_futures = asyncio.as_completed(tasks)

    batch_start_time = asyncio.get_event_loop().time()

    for completed_task in completed_task_futures:
        try:
            await completed_task
            completed_tasks += 1

        except Exception as e:
            completed_tasks += 1
            # Exception logging for hypothesis E (signal termination)
            error_type = type(e).__name__
            log.debug(f"[DEBUG_HYPOTHESIS_E] Task failed with exception: {error_type}: {str(e)} (Task {completed_tasks}/{total_tasks})")

            log.error(f"❌ Task failed: {e}")

    total_time = asyncio.get_event_loop().time() - batch_start_time

    log.info(f"🎉 All {total_tasks} videos completed in {total_time:.2f}s (avg: {total_time/total_tasks:.2f}s/video)")
    log.progress(1.0)

async def find_marker_settings() -> None:
    """Find optimal marker settings based on a single tagged video"""
    scenes = media_handler.get_tagme_scenes()
    if len(scenes) != 1:
        log.error("Please tag exactly one scene with the VLM_TagMe tag to be processed.")
        return
    scene = scenes[0]
    await __find_marker_settings(scene)

def collect_incorrect_markers_and_images() -> None:
    """Collect data from incorrectly tagged markers and images"""
    incorrect_images = media_handler.get_incorrect_images()
    image_paths, image_ids, temp_files = media_handler.get_image_paths_and_ids(incorrect_images)
    incorrect_markers = media_handler.get_incorrect_markers()

    if not (len(incorrect_images) > 0 or len(incorrect_markers) > 0):
        log.info("No incorrect images or markers to collect.")
        return

    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    try:
        # Process images
        image_folder = os.path.join(config.config.output_data_dir, "images")
        os.makedirs(image_folder, exist_ok=True)
        for image_path in image_paths:
            try:
                shutil.copy(image_path, image_folder)
            except Exception as e:
                log.error(f"Failed to copy image {image_path} to {image_folder}: {e}")
    except Exception as e:
        log.error(f"Failed to process images: {e}")
        raise
    finally:
        # Clean up temp files
        for temp_file in temp_files:
            try:
                if os.path.isdir(temp_file):
                    shutil.rmtree(temp_file)
                else:
                    os.remove(temp_file)
            except Exception as e:
                log.debug(f"Failed to remove temp file {temp_file}: {e}")

    # Process markers
    scene_folder = os.path.join(config.config.output_data_dir, "scenes")
    os.makedirs(scene_folder, exist_ok=True)
    tag_folders = {}

    for marker in incorrect_markers:
        scene_path = marker['scene']['files'][0]['path']
        if not scene_path:
            log.error(f"Marker {marker['id']} has no scene path")
            continue
        try:
            tag_name = marker['primary_tag']['name']
            if tag_name not in tag_folders:
                tag_folders[tag_name] = os.path.join(scene_folder, tag_name)
                os.makedirs(tag_folders[tag_name], exist_ok=True)
            media_handler.write_scene_marker_to_file(marker, scene_path, tag_folders[tag_name])
        except Exception as e:
            log.error(f"Failed to collect scene: {e}")

    # Remove incorrect tags from images
    image_ids = [image['id'] for image in incorrect_images]
    media_handler.remove_incorrect_tag_from_images(image_ids)

# ----------------- Low Level Processing Functions -----------------

async def __tag_video_with_timing(scene: Dict[str, Any], scene_index: int) -> None:
    """Tag a single video scene with timing diagnostics"""
    start_time = asyncio.get_event_loop().time()
    scene_id = scene.get('id', 'unknown')

    log.info(f"🎬 Starting video {scene_index + 1}: Scene {scene_id}")

    try:
        await __tag_video(scene)
        end_time = asyncio.get_event_loop().time()
        duration = end_time - start_time
        log.info(f"✅ Completed video {scene_index + 1} (Scene {scene_id}) in {duration:.2f}s")

    except Exception as e:
        end_time = asyncio.get_event_loop().time()
        duration = end_time - start_time
        log.error(f"❌ Failed video {scene_index + 1} (Scene {scene_id}) after {duration:.2f}s: {e}")
        raise

async def __tag_video(scene: Dict[str, Any]) -> None:
    """Tag a single video scene with semaphore timing instrumentation"""
    scene_id = scene.get('id')

    # Pre-semaphore acquisition logging for hypothesis A (semaphore deadlock)
    task_start_time = asyncio.get_event_loop().time()
    acquisition_start_time = task_start_time
    log.debug(f"[DEBUG_HYPOTHESIS_A] Task starting for scene {scene_id} at {task_start_time:.3f}s")

    async with semaphore:
        try:
            # Semaphore acquisition successful logging
            acquisition_end_time = asyncio.get_event_loop().time()
            acquisition_time = acquisition_end_time - acquisition_start_time
            log.debug(f"[DEBUG_HYPOTHESIS_A] Semaphore acquired for scene {scene_id} after {acquisition_time:.3f}s")

            if scene_id is None:
                log.error("Scene missing 'id' field")
                return

            files = scene.get('files', [])
            if not files:
                log.error(f"Scene {scene_id} has no files")
                return

            scene_file = files[0].get('path')
            if scene_file is None:
                log.error(f"Scene {scene_id} file has no path")
                return

            # Check if scene is VR
            is_vr = media_handler.is_vr_scene(scene.get('tags', []))

            def progress_cb(p: int) -> None:
                global video_progress, total_tasks
                video_progress[scene_id] = p / 100.0
                total_prog = sum(video_progress.values()) / total_tasks
                log.progress(total_prog)

            # Process video through VLM Engine with HTTP timing for hypothesis B
            processing_start_time = asyncio.get_event_loop().time()

            # HTTP request lifecycle tracking start
            log.debug(f"[DEBUG_HYPOTHESIS_B] Starting VLM processing for scene {scene_id}: {scene_file}")

            video_result = await vlm_engine.process_video_async(
                scene_file,
                vr_video=is_vr,
                frame_interval=config.config.video_frame_interval,
                threshold=config.config.video_threshold,
                return_confidence=config.config.video_confidence_return,
                progress_callback=progress_cb
            )

            # Extract detected tags
            detected_tags = set()
            for category_tags in video_result.video_tags.values():
                detected_tags.update(category_tags)

            # Post-VLM processing logging
            processing_end_time = asyncio.get_event_loop().time()
            processing_duration = processing_end_time - processing_start_time
            log.debug(f"[DEBUG_HYPOTHESIS_B] VLM processing completed for scene {scene_id} in {processing_duration:.2f}s ({len(detected_tags)} detected tags)")

            if detected_tags:
                # Clear all existing tags and markers before adding new ones
                media_handler.clear_all_tags_from_video(scene)
                media_handler.clear_all_markers_from_video(scene_id)

                # Add tags to scene
                tag_ids = media_handler.get_tag_ids(list(detected_tags), create=True)
                media_handler.add_tags_to_video(scene_id, tag_ids)
                log.info(f"Added tags {list(detected_tags)} to scene {scene_id}")

                # Add markers if enabled
                if config.config.create_markers:
                    media_handler.add_markers_to_video_from_dict(scene_id, video_result.tag_timespans)
                    log.info(f"Added markers to scene {scene_id}")

            # Remove VLM_TagMe tag from processed scene
            media_handler.remove_tagme_tag_from_scene(scene_id)

            # Task completion logging
            task_end_time = asyncio.get_event_loop().time()
            total_task_time = task_end_time - task_start_time
            log.debug(f"[DEBUG_HYPOTHESIS_A] Task completed for scene {scene_id} in {total_task_time:.2f}s")

        except Exception as e:
            # Exception handling with detailed logging for hypothesis E
            exception_time = asyncio.get_event_loop().time()
            error_type = type(e).__name__
            log.debug(f"[DEBUG_HYPOTHESIS_E] Task exception for scene {scene_id}: {error_type}: {str(e)} at {exception_time:.3f}s")

            scene_id = scene.get('id', 'unknown')
            log.error(f"Error processing video scene {scene_id}: {e}")
            # Add error tag to failed scene if we have a valid ID
            if scene_id != 'unknown':
                media_handler.add_error_scene(scene_id)

async def __find_marker_settings(scene: Dict[str, Any]) -> None:
    """Find optimal marker settings for a scene"""
    try:
        scene_id = scene.get('id')
        if scene_id is None:
            log.error("Scene missing 'id' field")
            return

        files = scene.get('files', [])
        if not files:
            log.error(f"Scene {scene_id} has no files")
            return

        scene_file = files[0].get('path')
        if scene_file is None:
            log.error(f"Scene {scene_id} file has no path")
            return

        # Get existing markers for the scene
        existing_markers = media_handler.get_scene_markers(scene_id)

        # Convert markers to desired timespan format
        desired_timespan_data = {}
        for marker in existing_markers:
            tag_name = marker['primary_tag']['name']
            desired_timespan_data[tag_name] = TimeFrame(
                start=marker['seconds'],
                end=marker.get('end_seconds', marker['seconds'] + 1),
                total_confidence=1.0
            )

        # Find optimal settings
        optimal_settings = await vlm_engine.find_optimal_marker_settings_async(
            existing_json={},  # No existing JSON data
            desired_timespan_data=desired_timespan_data
        )

        # Output results
        log.info(f"Optimal marker settings found for scene {scene_id}:")
        log.info(json.dumps(optimal_settings, indent=2))

    except Exception as e:
        scene_id = scene.get('id', 'unknown')
        log.error(f"Error finding marker settings for scene {scene_id}: {e}")

# ----------------- Cleanup -----------------

async def cleanup() -> None:
    """Cleanup resources"""
    if vlm_engine.vlm_engine:
        await vlm_engine.vlm_engine.shutdown()

# Run main function if script is executed directly
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        log.info("Plugin interrupted by user")
        sys.exit(0)
    except SystemExit as e:
        # Re-raise system exit with the exit code
        log.debug(f"[DEBUG_EXIT_CODE] Caught SystemExit with code: {e.code}")
        raise
    except Exception as e:
        log.error(f"Plugin failed: {e}")
        sys.exit(1)
    finally:
        asyncio.run(cleanup())
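The orchestration in `tag_videos()` combines two asyncio idioms: an `asyncio.Semaphore` bounds how many scenes are processed at once, while `asyncio.as_completed` consumes results in completion order so progress can be reported as each task finishes. A minimal, self-contained sketch of that pattern (the names `work` and `run_all` are illustrative, not part of the plugin):

```python
import asyncio

async def work(i: int, sem: asyncio.Semaphore) -> int:
    async with sem:              # at most `limit` tasks inside this block at once
        await asyncio.sleep(0)   # stand-in for the real VLM processing call
        return i * i

async def run_all(n: int, limit: int) -> list:
    sem = asyncio.Semaphore(limit)
    tasks = [asyncio.create_task(work(i, sem)) for i in range(n)]
    results = []
    for fut in asyncio.as_completed(tasks):
        # Futures arrive in completion order, not submission order
        results.append(await fut)
    return sorted(results)

results = asyncio.run(run_all(5, 2))
```

Catching exceptions around the `await fut`, as the plugin does, lets one failed scene be logged without cancelling the rest of the batch.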
299 plugins/AHavenVLMConnector/haven_vlm_engine.py Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Haven VLM Engine Integration Module
|
||||
Provides integration with the Haven VLM Engine for video and image processing
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Set, Union, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
# Use PythonDepManager for dependency management
|
||||
from vlm_engine import VLMEngine
|
||||
from vlm_engine.config_models import (
|
||||
EngineConfig,
|
||||
PipelineConfig,
|
||||
ModelConfig,
|
||||
PipelineModelConfig
|
||||
)
|
||||
|
||||
import haven_vlm_config as config
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.CRITICAL)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class TimeFrame:
|
||||
"""Represents a time frame with start and end times"""
|
||||
start: float
|
||||
end: float
|
||||
total_confidence: Optional[float] = None
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Convert to JSON string"""
|
||||
return json.dumps({
|
||||
"start": self.start,
|
||||
"end": self.end,
|
||||
"total_confidence": self.total_confidence
|
||||
})
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"TimeFrame(start={self.start}, end={self.end}, confidence={self.total_confidence})"
|
||||
|
||||
@dataclass
|
||||
class VideoTagInfo:
|
||||
"""Represents video tagging information"""
|
||||
video_duration: float
|
||||
video_tags: Dict[str, Set[str]]
|
||||
tag_totals: Dict[str, Dict[str, float]]
|
||||
tag_timespans: Dict[str, Dict[str, List[TimeFrame]]]
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_data: Dict[str, Any]) -> 'VideoTagInfo':
|
||||
"""Create VideoTagInfo from JSON data"""
|
||||
logger.debug(f"Creating VideoTagInfo from JSON: {json_data}")
|
||||
|
||||
# Convert tag_timespans to TimeFrame objects
|
||||
tag_timespans = {}
|
||||
for category, tags in json_data.get("tag_timespans", {}).items():
|
||||
tag_timespans[category] = {}
|
||||
for tag_name, timeframes in tags.items():
|
||||
tag_timespans[category][tag_name] = [
|
||||
TimeFrame(
|
||||
start=tf["start"],
|
||||
end=tf["end"],
|
||||
total_confidence=tf.get("total_confidence")
|
||||
) for tf in timeframes
|
||||
]
|
||||
|
||||
return cls(
|
||||
video_duration=json_data.get("video_duration", 0.0),
|
||||
video_tags=json_data.get("video_tags", {}),
|
||||
tag_totals=json_data.get("tag_totals", {}),
|
||||
tag_timespans=tag_timespans
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"VideoTagInfo(duration={self.video_duration}, tags={len(self.video_tags)}, timespans={len(self.tag_timespans)})"
|
||||
|
||||
class HavenVLMEngine:
|
||||
"""Main VLM Engine integration class"""
|
||||
|
||||
def __init__(self):
|
||||
self.engine: Optional[VLMEngine] = None
|
||||
self.engine_config: Optional[EngineConfig] = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the VLM Engine with configuration"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("Initializing Haven VLM Engine...")
|
||||
|
||||
# Convert config dict to EngineConfig objects
|
||||
self.engine_config = self._create_engine_config()
|
||||
|
||||
# Create and initialize the engine
|
||||
self.engine = VLMEngine(config=self.engine_config)
|
||||
await self.engine.initialize()
|
||||
|
||||
self._initialized = True
|
||||
logger.info("Haven VLM Engine initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize VLM Engine: {e}")
|
||||
raise
|
||||
|
||||
def _create_engine_config(self) -> EngineConfig:
|
||||
"""Create EngineConfig from the configuration"""
|
||||
vlm_config = config.config.vlm_engine_config
|
||||
|
||||
# Create pipeline configs
|
||||
pipelines = {}
|
||||
for pipeline_name, pipeline_data in vlm_config["pipelines"].items():
|
||||
models = [
|
||||
PipelineModelConfig(
|
||||
name=model["name"],
|
||||
inputs=model["inputs"],
|
||||
outputs=model["outputs"]
|
||||
) for model in pipeline_data["models"]
|
||||
]
|
||||
|
||||
pipelines[pipeline_name] = PipelineConfig(
|
||||
inputs=pipeline_data["inputs"],
|
||||
output=pipeline_data["output"],
|
||||
short_name=pipeline_data["short_name"],
|
||||
version=pipeline_data["version"],
|
||||
models=models
|
||||
)
|
||||
|
||||
# Create model configs with new architectural changes
|
||||
models = {}
|
||||
for model_name, model_data in vlm_config["models"].items():
|
||||
if model_data["type"] == "vlm_model":
|
||||
# Process multiplexer_endpoints and validate max_concurrent
|
||||
multiplexer_endpoints = []
|
||||
for endpoint in model_data.get("multiplexer_endpoints", []):
|
||||
# Validate that max_concurrent is present
|
||||
if "max_concurrent" not in endpoint:
|
||||
raise ValueError(f"Endpoint '{endpoint.get('name', 'unnamed')}' is missing required 'max_concurrent' parameter")
|
||||
|
||||
multiplexer_endpoints.append({
|
||||
"base_url": endpoint["base_url"],
|
||||
"api_key": endpoint.get("api_key", ""),
|
||||
"name": endpoint["name"],
|
||||
"weight": endpoint.get("weight", 5),
|
||||
"is_fallback": endpoint.get("is_fallback", False),
|
||||
"max_concurrent": endpoint["max_concurrent"]
|
||||
})
|
||||
|
||||
models[model_name] = ModelConfig(
|
||||
type=model_data["type"],
|
||||
model_file_name=model_data["model_file_name"],
|
||||
model_category=model_data["model_category"],
|
||||
model_id=model_data["model_id"],
|
||||
model_identifier=model_data["model_identifier"],
|
||||
model_version=model_data["model_version"],
|
||||
use_multiplexer=model_data.get("use_multiplexer", False),
|
||||
max_concurrent_requests=model_data.get("max_concurrent_requests", 10),
|
||||
instance_count=model_data.get("instance_count",1),
|
||||
max_batch_size=model_data.get("max_batch_size",1),
|
||||
multiplexer_endpoints=multiplexer_endpoints,
|
||||
tag_list=model_data.get("tag_list", [])
|
||||
)
|
||||
else:
|
||||
models[model_name] = ModelConfig(
|
||||
type=model_data["type"],
|
||||
model_file_name=model_data["model_file_name"]
|
||||
)
|
||||
|
||||
return EngineConfig(
|
||||
active_ai_models=vlm_config["active_ai_models"],
|
||||
pipelines=pipelines,
|
||||
models=models,
|
||||
category_config=vlm_config["category_config"]
|
||||
)
|
||||
|
||||
async def process_video(
|
||||
self,
|
||||
video_path: str,
|
||||
vr_video: bool = False,
|
||||
frame_interval: Optional[float] = None,
|
||||
threshold: Optional[float] = None,
|
||||
return_confidence: Optional[bool] = None,
|
||||
existing_json: Optional[Dict[str, Any]] = None,
|
||||
progress_callback: Optional[Callable[[int], None]] = None
|
||||
) -> VideoTagInfo:
|
||||
"""Process a video using the VLM Engine"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
try:
|
||||
logger.info(f"Processing video: {video_path}")
|
||||
|
||||
# Use config defaults if not provided
|
||||
frame_interval = frame_interval or config.config.video_frame_interval
|
||||
threshold = threshold or config.config.video_threshold
|
||||
return_confidence = return_confidence if return_confidence is not None else config.config.video_confidence_return
|
||||
|
||||
# Process video through the engine
|
||||
results = await self.engine.process_video(
|
||||
video_path,
|
||||
frame_interval=frame_interval,
|
||||
progress_callback=progress_callback
|
||||
)
|
            logger.info(f"Video processing completed for: {video_path}")
            logger.debug(f"Raw results structure: {type(results)}")

            # Extract video_tag_info from the nested structure
            if isinstance(results, dict) and 'video_tag_info' in results:
                video_tag_data = results['video_tag_info']
                logger.debug(f"Using video_tag_info from results: {video_tag_data.keys()}")
            else:
                # Fallback: assume results is already in the correct format
                video_tag_data = results
                logger.debug(f"Using results directly: {video_tag_data.keys() if isinstance(video_tag_data, dict) else type(video_tag_data)}")

            return VideoTagInfo.from_json(video_tag_data)

        except Exception as e:
            logger.error(f"Error processing video {video_path}: {e}")
            raise

    async def find_optimal_marker_settings(
        self,
        existing_json: Dict[str, Any],
        desired_timespan_data: Dict[str, TimeFrame]
    ) -> Dict[str, Any]:
        """Find optimal marker settings based on existing data"""
        if not self._initialized:
            await self.initialize()

        try:
            logger.info("Finding optimal marker settings...")

            # Convert TimeFrame objects to dict format
            desired_data = {}
            for key, timeframe in desired_timespan_data.items():
                desired_data[key] = {
                    "start": timeframe.start,
                    "end": timeframe.end,
                    "total_confidence": timeframe.total_confidence
                }

            # Call the engine's optimization method
            results = await self.engine.optimize_timeframe_settings(
                existing_json_data=existing_json,
                desired_timespan_data=desired_data
            )

            logger.info("Optimal marker settings found")
            return results

        except Exception as e:
            logger.error(f"Error finding optimal marker settings: {e}")
            raise

    async def shutdown(self) -> None:
        """Shutdown the VLM Engine"""
        if self.engine and self._initialized:
            try:
                # VLMEngine doesn't have a shutdown method, so just perform basic cleanup
                logger.info("VLM Engine cleanup completed")
                self._initialized = False
            except Exception as e:
                logger.error(f"Error during VLM Engine cleanup: {e}")
                self._initialized = False


# Global VLM Engine instance
vlm_engine = HavenVLMEngine()


# Convenience functions for backward compatibility
async def process_video_async(
    video_path: str,
    vr_video: bool = False,
    frame_interval: Optional[float] = None,
    threshold: Optional[float] = None,
    return_confidence: Optional[bool] = None,
    existing_json: Optional[Dict[str, Any]] = None,
    progress_callback: Optional[Callable[[int], None]] = None
) -> VideoTagInfo:
    """Process video asynchronously"""
    return await vlm_engine.process_video(
        video_path, vr_video, frame_interval, threshold, return_confidence, existing_json,
        progress_callback=progress_callback
    )


async def find_optimal_marker_settings_async(
    existing_json: Dict[str, Any],
    desired_timespan_data: Dict[str, TimeFrame]
) -> Dict[str, Any]:
    """Find optimal marker settings asynchronously"""
    return await vlm_engine.find_optimal_marker_settings(existing_json, desired_timespan_data)
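The engine accepts either a nested payload (with a `video_tag_info` key) or an already-flat payload. A minimal standalone sketch of that unwrapping logic, using hypothetical result dicts:

```python
def extract_video_tag_data(results):
    # Prefer the nested 'video_tag_info' key; otherwise assume the
    # payload is already in the expected flat format.
    if isinstance(results, dict) and 'video_tag_info' in results:
        return results['video_tag_info']
    return results

nested = {"video_tag_info": {"video_tags": {"action": ["run"]}}}
flat = {"video_tags": {"action": ["run"]}}
print(extract_video_tag_data(nested) == extract_video_tag_data(flat))  # True
```

Both shapes resolve to the same flat dictionary, which is why the fallback branch is safe.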
316 plugins/AHavenVLMConnector/haven_vlm_utility.py Normal file

@@ -0,0 +1,316 @@
"""
Haven VLM Utility Module
Utility functions for the A Haven VLM Connector plugin
"""

import os
import json
import logging
from typing import Dict, Any, List, Optional, Union
from pathlib import Path

import yaml

logger = logging.getLogger(__name__)


def apply_path_mutations(path: str, mutations: Dict[str, str]) -> str:
    """
    Apply path mutations for different environments

    Args:
        path: Original file path
        mutations: Dictionary of path mutations (e.g., {"E:": "F:", "G:": "D:"})

    Returns:
        Mutated path string
    """
    if not mutations:
        return path

    mutated_path = path
    for old_path, new_path in mutations.items():
        if mutated_path.startswith(old_path):
            mutated_path = mutated_path.replace(old_path, new_path, 1)
            break

    return mutated_path

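The prefix remapping above can be exercised in isolation; this inline copy (with hypothetical drive-letter mappings) shows that only the first matching prefix is rewritten, once:

```python
def apply_path_mutations(path, mutations):
    # Rewrite the first matching path prefix (e.g. a Windows drive letter).
    if not mutations:
        return path
    for old_path, new_path in mutations.items():
        if path.startswith(old_path):
            return path.replace(old_path, new_path, 1)
    return path

print(apply_path_mutations("E:/videos/clip.mp4", {"E:": "F:"}))  # F:/videos/clip.mp4
print(apply_path_mutations("G:/other.mp4", {"E:": "F:"}))        # G:/other.mp4 (unchanged)
```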
def ensure_directory_exists(directory_path: str) -> None:
    """
    Ensure a directory exists, creating it if necessary

    Args:
        directory_path: Path to the directory
    """
    Path(directory_path).mkdir(parents=True, exist_ok=True)


def safe_file_operation(operation_func, *args, **kwargs) -> Optional[Any]:
    """
    Safely execute a file operation with error handling

    Args:
        operation_func: Function to execute
        *args: Arguments for the function
        **kwargs: Keyword arguments for the function

    Returns:
        Result of the operation or None if failed
    """
    try:
        return operation_func(*args, **kwargs)
    except (OSError, IOError) as e:
        logger.error(f"File operation failed: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error in file operation: {e}")
        return None


def load_yaml_config(config_path: str) -> Optional[Dict[str, Any]]:
    """
    Load configuration from YAML file

    Args:
        config_path: Path to the YAML configuration file

    Returns:
        Configuration dictionary or None if failed
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        logger.info(f"Configuration loaded from {config_path}")
        return config
    except FileNotFoundError:
        logger.warning(f"Configuration file not found: {config_path}")
        return None
    except yaml.YAMLError as e:
        logger.error(f"Error parsing YAML configuration: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error loading configuration: {e}")
        return None


def save_yaml_config(config: Dict[str, Any], config_path: str) -> bool:
    """
    Save configuration to YAML file

    Args:
        config: Configuration dictionary
        config_path: Path to save the configuration file

    Returns:
        True if successful, False otherwise
    """
    try:
        ensure_directory_exists(os.path.dirname(config_path))
        with open(config_path, 'w', encoding='utf-8') as f:
            yaml.dump(config, f, default_flow_style=False, indent=2)
        logger.info(f"Configuration saved to {config_path}")
        return True
    except Exception as e:
        logger.error(f"Error saving configuration: {e}")
        return False


def validate_file_path(file_path: str) -> bool:
    """
    Validate if a file path exists and is accessible

    Args:
        file_path: Path to validate

    Returns:
        True if file exists and is accessible, False otherwise
    """
    try:
        return os.path.isfile(file_path) and os.access(file_path, os.R_OK)
    except Exception:
        return False


def get_file_extension(file_path: str) -> str:
    """
    Get the file extension from a file path

    Args:
        file_path: Path to the file

    Returns:
        File extension (including the dot)
    """
    return Path(file_path).suffix.lower()


def is_video_file(file_path: str) -> bool:
    """
    Check if a file is a video file based on its extension

    Args:
        file_path: Path to the file

    Returns:
        True if it's a video file, False otherwise
    """
    video_extensions = {'.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.m4v'}
    return get_file_extension(file_path) in video_extensions


def is_image_file(file_path: str) -> bool:
    """
    Check if a file is an image file based on its extension

    Args:
        file_path: Path to the file

    Returns:
        True if it's an image file, False otherwise
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
    return get_file_extension(file_path) in image_extensions


def format_duration(seconds: float) -> str:
    """
    Format duration in seconds to human-readable string

    Args:
        seconds: Duration in seconds

    Returns:
        Formatted duration string (e.g., "1h 23m 45s")
    """
    if seconds < 60:
        return f"{seconds:.1f}s"
    elif seconds < 3600:
        minutes = int(seconds // 60)
        remaining_seconds = seconds % 60
        return f"{minutes}m {remaining_seconds:.1f}s"
    else:
        hours = int(seconds // 3600)
        remaining_minutes = int((seconds % 3600) // 60)
        remaining_seconds = seconds % 60
        return f"{hours}h {remaining_minutes}m {remaining_seconds:.1f}s"


def format_file_size(bytes_size: int) -> str:
    """
    Format file size in bytes to human-readable string

    Args:
        bytes_size: Size in bytes

    Returns:
        Formatted size string (e.g., "1.5 MB")
    """
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if bytes_size < 1024.0:
            return f"{bytes_size:.1f} {unit}"
        bytes_size /= 1024.0
    return f"{bytes_size:.1f} PB"


def sanitize_filename(filename: str) -> str:
    """
    Sanitize a filename by removing or replacing invalid characters

    Args:
        filename: Original filename

    Returns:
        Sanitized filename
    """
    # Replace invalid characters with underscores
    invalid_chars = '<>:"/\\|?*'
    for char in invalid_chars:
        filename = filename.replace(char, '_')

    # Remove leading/trailing spaces and dots
    filename = filename.strip(' .')

    # Ensure filename is not empty
    if not filename:
        filename = "unnamed"

    return filename


def create_backup_file(file_path: str, backup_suffix: str = ".backup") -> Optional[str]:
    """
    Create a backup of a file

    Args:
        file_path: Path to the file to backup
        backup_suffix: Suffix for the backup file

    Returns:
        Path to the backup file or None if failed
    """
    try:
        if not os.path.exists(file_path):
            logger.warning(f"File does not exist: {file_path}")
            return None

        backup_path = file_path + backup_suffix
        import shutil
        shutil.copy2(file_path, backup_path)
        logger.info(f"Backup created: {backup_path}")
        return backup_path
    except Exception as e:
        logger.error(f"Failed to create backup: {e}")
        return None


def merge_dictionaries(dict1: Dict[str, Any], dict2: Dict[str, Any], overwrite: bool = True) -> Dict[str, Any]:
    """
    Merge two dictionaries, with option to overwrite existing keys

    Args:
        dict1: First dictionary
        dict2: Second dictionary
        overwrite: Whether to overwrite existing keys in dict1

    Returns:
        Merged dictionary
    """
    result = dict1.copy()

    for key, value in dict2.items():
        if key not in result or overwrite:
            result[key] = value
        elif isinstance(result[key], dict) and isinstance(value, dict):
            result[key] = merge_dictionaries(result[key], value, overwrite)

    return result


def chunk_list(lst: List[Any], chunk_size: int) -> List[List[Any]]:
    """
    Split a list into chunks of specified size

    Args:
        lst: List to chunk
        chunk_size: Size of each chunk

    Returns:
        List of chunks
    """
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

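The slicing idiom used by `chunk_list` is worth seeing with concrete input; this standalone copy shows that the last chunk is simply shorter when the list length is not a multiple of the chunk size:

```python
def chunk_list(lst, chunk_size):
    # Step through the list in chunk_size strides; the final slice may be short.
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

print(chunk_list([1, 2, 3, 4, 5], 2))  # [[1, 2], [3, 4], [5]]
```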
def retry_operation(operation_func, max_retries: int = 3, delay: float = 1.0, *args, **kwargs) -> Optional[Any]:
    """
    Retry an operation with exponential backoff

    Args:
        operation_func: Function to retry
        max_retries: Maximum number of retries
        delay: Initial delay between retries
        *args: Arguments for the function
        **kwargs: Keyword arguments for the function

    Returns:
        Result of the operation or None if all retries failed
    """
    import time

    for attempt in range(max_retries + 1):
        try:
            return operation_func(*args, **kwargs)
        except Exception as e:
            if attempt == max_retries:
                logger.error(f"Operation failed after {max_retries} retries: {e}")
                return None

            wait_time = delay * (2 ** attempt)
            logger.warning(f"Operation failed (attempt {attempt + 1}/{max_retries + 1}), retrying in {wait_time}s: {e}")
            time.sleep(wait_time)

    return None
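`retry_operation` waits `delay * 2**attempt` seconds before each retry. A small sketch of that schedule (mirroring the `wait_time` computation, without actually sleeping) makes the growth explicit:

```python
def backoff_schedule(max_retries, delay):
    # Wait time before retry k is delay * 2**k, as in retry_operation's wait_time.
    return [delay * (2 ** attempt) for attempt in range(max_retries)]

print(backoff_schedule(3, 1.0))  # [1.0, 2.0, 4.0]
```

So with the defaults (`max_retries=3`, `delay=1.0`), a persistently failing call costs about seven seconds of waiting before giving up.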
8 plugins/AHavenVLMConnector/requirements.txt Normal file

@@ -0,0 +1,8 @@

# Core dependencies managed by PythonDepManager
# These are automatically handled by the plugin's dependency management system
# PythonDepManager will ensure the correct versions are installed

# Development and testing dependencies
coverage>=7.0.0
pytest>=7.0.0
pytest-cov>=4.0.0
110 plugins/AHavenVLMConnector/run_tests.py Normal file

@@ -0,0 +1,110 @@

#!/usr/bin/env python3
"""
Test runner for A Haven VLM Connector
Runs all unit tests with coverage reporting
"""

import sys
import os
import subprocess
import unittest
from pathlib import Path


def install_test_dependencies():
    """Install test dependencies if not already installed"""
    test_deps = [
        'coverage',
        'pytest',
        'pytest-cov'
    ]

    for dep in test_deps:
        try:
            __import__(dep.replace('-', '_'))
        except ImportError:
            print(f"Installing {dep}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", dep])


def run_tests_with_coverage():
    """Run tests with coverage reporting"""
    # Install test dependencies
    install_test_dependencies()

    # Get the directory containing this script
    script_dir = Path(__file__).parent

    # Discover and run tests
    loader = unittest.TestLoader()
    start_dir = script_dir
    suite = loader.discover(start_dir, pattern='test_*.py')

    # Run tests with coverage
    import coverage

    # Start coverage measurement
    cov = coverage.Coverage(
        source=['haven_vlm_config.py', 'haven_vlm_engine.py', 'haven_media_handler.py',
                'haven_vlm_connector.py', 'haven_vlm_utility.py'],
        omit=['*/test_*.py', '*/__pycache__/*', '*/venv/*', '*/env/*']
    )
    cov.start()

    # Run the tests
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    # Stop coverage measurement
    cov.stop()
    cov.save()

    # Generate coverage report
    print("\n" + "="*60)
    print("COVERAGE REPORT")
    print("="*60)
    cov.report()

    # Generate HTML coverage report
    cov.html_report(directory='htmlcov')
    print(f"\nHTML coverage report generated in: {script_dir}/htmlcov/index.html")

    return result.wasSuccessful()


def run_specific_test(test_file):
    """Run a specific test file"""
    if not test_file.endswith('.py'):
        test_file += '.py'

    test_path = Path(__file__).parent / test_file

    if not test_path.exists():
        print(f"Test file not found: {test_path}")
        return False

    # Run the specific test
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromName(test_file[:-3])

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    return result.wasSuccessful()


def main():
    """Main entry point"""
    if len(sys.argv) > 1:
        # Run specific test file
        test_file = sys.argv[1]
        success = run_specific_test(test_file)
    else:
        # Run all tests with coverage
        success = run_tests_with_coverage()

    if success:
        print("\n✅ All tests passed!")
        sys.exit(0)
    else:
        print("\n❌ Some tests failed!")
        sys.exit(1)


if __name__ == '__main__':
    main()
98 plugins/AHavenVLMConnector/test_dependency_management.py Normal file

@@ -0,0 +1,98 @@

"""
Unit tests for dependency management functionality using PythonDepManager
"""

import unittest
import sys
from unittest.mock import patch, MagicMock, mock_open
import tempfile
import os


class TestPythonDepManagerIntegration(unittest.TestCase):
    """Test cases for PythonDepManager integration"""

    def setUp(self):
        """Set up test fixtures"""
        # Mock PythonDepManager module
        self.mock_python_dep_manager = MagicMock()
        sys.modules['PythonDepManager'] = self.mock_python_dep_manager

    def tearDown(self):
        """Clean up after tests"""
        if 'PythonDepManager' in sys.modules:
            del sys.modules['PythonDepManager']

    @patch('builtins.print')
    def test_dependency_import_failure(self, mock_print):
        """Test dependency import failure handling"""
        # Mock ensure_import to raise ImportError
        self.mock_python_dep_manager.ensure_import = MagicMock(side_effect=ImportError("Package not found"))

        # Test that the error is handled gracefully
        with self.assertRaises(SystemExit):
            import haven_vlm_connector

    def test_error_messages(self):
        """Test that appropriate error messages are displayed"""
        # Mock ensure_import to raise ImportError
        self.mock_python_dep_manager.ensure_import = MagicMock(side_effect=ImportError("Package not found"))

        with patch('builtins.print') as mock_print:
            with self.assertRaises(SystemExit):
                import haven_vlm_connector

            # Check that appropriate error messages were printed
            print_calls = [call[0][0] for call in mock_print.call_args_list]
            self.assertTrue(any("Failed to import PythonDepManager" in msg for msg in print_calls if isinstance(msg, str)))
            self.assertTrue(any("Please ensure PythonDepManager is installed" in msg for msg in print_calls if isinstance(msg, str)))


class TestDependencyManagementEdgeCases(unittest.TestCase):
    """Test edge cases in dependency management"""

    def setUp(self):
        """Set up test fixtures"""
        self.mock_python_dep_manager = MagicMock()
        sys.modules['PythonDepManager'] = self.mock_python_dep_manager

    def tearDown(self):
        """Clean up after tests"""
        if 'PythonDepManager' in sys.modules:
            del sys.modules['PythonDepManager']

    def test_missing_python_dep_manager(self):
        """Test behavior when PythonDepManager is not available"""
        # Remove PythonDepManager from sys.modules
        if 'PythonDepManager' in sys.modules:
            del sys.modules['PythonDepManager']

        with patch('builtins.print') as mock_print:
            with self.assertRaises(SystemExit):
                import haven_vlm_connector

            # Check that appropriate error message was printed
            print_calls = [call[0][0] for call in mock_print.call_args_list]
            self.assertTrue(any("Failed to import PythonDepManager" in msg for msg in print_calls if isinstance(msg, str)))

    def test_partial_dependency_failure(self):
        """Test behavior when some dependencies fail to import"""
        # Mock ensure_import to succeed but some imports to fail
        self.mock_python_dep_manager.ensure_import = MagicMock()

        # Mock some successful imports but not all
        mock_stashapi = MagicMock()
        sys.modules['stashapi.log'] = mock_stashapi
        sys.modules['stashapi.stashapp'] = mock_stashapi

        # Don't mock aiohttp, so it should fail
        with patch('builtins.print') as mock_print:
            with self.assertRaises(SystemExit):
                import haven_vlm_connector

            # Check that appropriate error message was printed
            print_calls = [call[0][0] for call in mock_print.call_args_list]
            self.assertTrue(any("Error during dependency management" in msg for msg in print_calls if isinstance(msg, str)))


if __name__ == '__main__':
    unittest.main()
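These tests exercise the connector's guarded-import pattern: try to import a managed dependency, print a clear message, and exit on failure. A minimal standalone sketch of that pattern (with a hypothetical `guarded_import` helper, not the connector's actual code):

```python
import importlib
import sys

def guarded_import(module_name):
    # Import a dependency by name; on failure, report and exit so the
    # plugin never runs with a half-initialized environment.
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        print(f"Failed to import {module_name}: {e}")
        sys.exit(1)

mod = guarded_import("json")  # stdlib module used as a stand-in dependency
print(mod.dumps({"ok": True}))
```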
387 plugins/AHavenVLMConnector/test_haven_media_handler.py Normal file

@@ -0,0 +1,387 @@

"""
Unit tests for Haven Media Handler Module
Tests StashApp media operations and tag management
"""

import unittest
from unittest.mock import Mock, patch, MagicMock
from typing import List, Dict, Any, Optional
import sys
import os

# Add the current directory to the path to import the module
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Mock the dependencies before importing the module
sys.modules['PythonDepManager'] = Mock()
sys.modules['stashapi.stashapp'] = Mock()
sys.modules['stashapi.log'] = Mock()
sys.modules['haven_vlm_config'] = Mock()

# Import the module after mocking dependencies
import haven_media_handler


class TestHavenMediaHandler(unittest.TestCase):
    """Test cases for Haven Media Handler"""

    def setUp(self) -> None:
        """Set up test fixtures"""
        # Mock the stash interface
        self.mock_stash = Mock()
        self.mock_stash.find_tag.return_value = {"id": 1}
        self.mock_stash.get_configuration.return_value = {"ui": {"vrTag": "VR"}}
        self.mock_stash.stash_version.return_value = Mock()

        # Mock the log module
        self.mock_log = Mock()

        # Patch the global variables
        haven_media_handler.stash = self.mock_stash
        haven_media_handler.log = self.mock_log

        # Mock tag IDs
        haven_media_handler.vlm_errored_tag_id = 1
        haven_media_handler.vlm_tagme_tag_id = 2
        haven_media_handler.vlm_base_tag_id = 3
        haven_media_handler.vlm_tagged_tag_id = 4
        haven_media_handler.vr_tag_id = 5
        haven_media_handler.vlm_incorrect_tag_id = 6

    def tearDown(self) -> None:
        """Clean up after tests"""
        # Clear any cached data
        haven_media_handler.tag_id_cache.clear()
        haven_media_handler.vlm_tag_ids_cache.clear()

    def test_clear_all_tags_from_video_with_tags(self) -> None:
        """Test clearing all tags from a video that has tags"""
        # Mock scene with tags
        mock_scene = {
            "id": 123,
            "tags": [
                {"id": 10, "name": "Tag1"},
                {"id": 20, "name": "Tag2"},
                {"id": 30, "name": "Tag3"}
            ]
        }
        # Call the function
        haven_media_handler.clear_all_tags_from_video(mock_scene)
        # Verify tags were removed
        self.mock_stash.update_scenes.assert_called_once_with({
            "ids": [123],
            "tag_ids": {"ids": [10, 20, 30], "mode": "REMOVE"}
        })
        # Verify log message
        self.mock_log.info.assert_called_once_with("Cleared 3 tags from scene 123")

    def test_clear_all_tags_from_video_no_tags(self) -> None:
        """Test clearing all tags from a video that has no tags"""
        # Mock scene without tags
        mock_scene = {"id": 123, "tags": []}
        # Call the function
        haven_media_handler.clear_all_tags_from_video(mock_scene)
        # Verify no update was called since there are no tags
        self.mock_stash.update_scenes.assert_not_called()
        # Verify no log message
        self.mock_log.info.assert_not_called()

    def test_clear_all_tags_from_video_scene_without_tags_key(self) -> None:
        """Test clearing all tags from a scene that doesn't have a tags key"""
        # Mock scene without tags key
        mock_scene = {"id": 123}
        # Call the function
        haven_media_handler.clear_all_tags_from_video(mock_scene)
        # Verify no update was called
        self.mock_stash.update_scenes.assert_not_called()

    @patch('haven_media_handler.get_scene_markers')
    @patch('haven_media_handler.delete_markers')
    def test_clear_all_markers_from_video_with_markers(self, mock_delete_markers: Mock, mock_get_markers: Mock) -> None:
        """Test clearing all markers from a video that has markers"""
        # Mock markers
        mock_markers = [
            {"id": 1, "title": "Marker1"},
            {"id": 2, "title": "Marker2"}
        ]
        mock_get_markers.return_value = mock_markers

        # Call the function
        haven_media_handler.clear_all_markers_from_video(123)

        # Verify markers were retrieved
        mock_get_markers.assert_called_once_with(123)

        # Verify markers were deleted
        mock_delete_markers.assert_called_once_with(mock_markers)

        # Verify log message
        self.mock_log.info.assert_called_once_with("Cleared all 2 markers from scene 123")

    @patch('haven_media_handler.get_scene_markers')
    @patch('haven_media_handler.delete_markers')
    def test_clear_all_markers_from_video_no_markers(self, mock_delete_markers: Mock, mock_get_markers: Mock) -> None:
        """Test clearing all markers from a video that has no markers"""
        # Mock no markers
        mock_get_markers.return_value = []

        # Call the function
        haven_media_handler.clear_all_markers_from_video(123)

        # Verify markers were retrieved
        mock_get_markers.assert_called_once_with(123)

        # Verify no deletion was called
        mock_delete_markers.assert_not_called()

        # Verify no log message
        self.mock_log.info.assert_not_called()

    def test_add_tags_to_video_with_tagged(self) -> None:
        """Test adding tags to video with tagged flag enabled"""
        # Call the function
        haven_media_handler.add_tags_to_video(123, [10, 20, 30], add_tagged=True)

        # Verify tags were added (including tagged tag)
        self.mock_stash.update_scenes.assert_called_once_with({
            "ids": [123],
            "tag_ids": {"ids": [10, 20, 30, 4], "mode": "ADD"}
        })

    def test_add_tags_to_video_without_tagged(self) -> None:
        """Test adding tags to video with tagged flag disabled"""
        # Call the function
        haven_media_handler.add_tags_to_video(123, [10, 20, 30], add_tagged=False)

        # Verify tags were added (without tagged tag)
        self.mock_stash.update_scenes.assert_called_once_with({
            "ids": [123],
            "tag_ids": {"ids": [10, 20, 30], "mode": "ADD"}
        })

    @patch('haven_media_handler.get_vlm_tags')
    def test_remove_vlm_tags_from_video(self, mock_get_vlm_tags: Mock) -> None:
        """Test removing VLM tags from video"""
        # Mock VLM tags
        mock_get_vlm_tags.return_value = [100, 200, 300]

        # Call the function
        haven_media_handler.remove_vlm_tags_from_video(123, remove_tagme=True, remove_errored=True)

        # Verify VLM tags were retrieved
        mock_get_vlm_tags.assert_called_once()

        # Verify tags were removed (including tagme and errored tags)
        self.mock_stash.update_scenes.assert_called_once_with({
            "ids": [123],
            "tag_ids": {"ids": [100, 200, 300, 2, 1], "mode": "REMOVE"}
        })

    def test_get_tagme_scenes(self) -> None:
        """Test getting scenes tagged with VLM_TagMe"""
        # Mock scenes
        mock_scenes = [{"id": 1}, {"id": 2}]
        self.mock_stash.find_scenes.return_value = mock_scenes

        # Call the function
        result = haven_media_handler.get_tagme_scenes()

        # Verify scenes were found
        self.mock_stash.find_scenes.assert_called_once_with(
            f={"tags": {"value": 2, "modifier": "INCLUDES"}},
            fragment="id tags {id} files {path duration fingerprint(type: \"phash\")}"
        )

        # Verify result
        self.assertEqual(result, mock_scenes)

    def test_add_error_scene(self) -> None:
        """Test adding error tag to a scene"""
        # Call the function
        haven_media_handler.add_error_scene(123)

        # Verify error tag was added
        self.mock_stash.update_scenes.assert_called_once_with({
            "ids": [123],
            "tag_ids": {"ids": [1], "mode": "ADD"}
        })

    def test_remove_tagme_tag_from_scene(self) -> None:
        """Test removing VLM_TagMe tag from a scene"""
        # Call the function
        haven_media_handler.remove_tagme_tag_from_scene(123)

        # Verify tagme tag was removed
        self.mock_stash.update_scenes.assert_called_once_with({
            "ids": [123],
            "tag_ids": {"ids": [2], "mode": "REMOVE"}
        })

    def test_is_scene_tagged_true(self) -> None:
        """Test checking if a scene is tagged (true case)"""
        # Mock tags including tagged tag
        tags = [
            {"id": 10, "name": "Tag1"},
            {"id": 4, "name": "VLM_Tagged"},  # This is the tagged tag
            {"id": 20, "name": "Tag2"}
        ]

        # Call the function
        result = haven_media_handler.is_scene_tagged(tags)

        # Verify result
        self.assertTrue(result)

    def test_is_scene_tagged_false(self) -> None:
        """Test checking if a scene is tagged (false case)"""
        # Mock tags without tagged tag
        tags = [
            {"id": 10, "name": "Tag1"},
            {"id": 20, "name": "Tag2"}
        ]

        # Call the function
        result = haven_media_handler.is_scene_tagged(tags)

        # Verify result
        self.assertFalse(result)

    def test_is_vr_scene_true(self) -> None:
        """Test checking if a scene is VR (true case)"""
        # Mock tags including VR tag
        tags = [
            {"id": 10, "name": "Tag1"},
            {"id": 5, "name": "VR"},  # This is the VR tag
            {"id": 20, "name": "Tag2"}
        ]

        # Call the function
        result = haven_media_handler.is_vr_scene(tags)

        # Verify result
        self.assertTrue(result)

    def test_is_vr_scene_false(self) -> None:
        """Test checking if a scene is VR (false case)"""
        # Mock tags without VR tag
        tags = [
            {"id": 10, "name": "Tag1"},
            {"id": 20, "name": "Tag2"}
        ]

        # Call the function
        result = haven_media_handler.is_vr_scene(tags)

        # Verify result
        self.assertFalse(result)

    def test_get_tag_id_existing(self) -> None:
        """Test getting tag ID for existing tag"""
        # Mock existing tag
        self.mock_stash.find_tag.return_value = {"id": 123, "name": "TestTag"}

        # Call the function
        result = haven_media_handler.get_tag_id("TestTag", create=False)

        # Verify tag was found
        self.mock_stash.find_tag.assert_called_once_with("TestTag")

        # Verify result
        self.assertEqual(result, 123)

    def test_get_tag_id_not_existing_no_create(self) -> None:
        """Test getting tag ID for non-existing tag without create"""
        # Mock non-existing tag
        self.mock_stash.find_tag.return_value = None

        # Call the function
        result = haven_media_handler.get_tag_id("TestTag", create=False)

        # Verify tag was searched
        self.mock_stash.find_tag.assert_called_once_with("TestTag")

        # Verify result is None
        self.assertIsNone(result)

    def test_get_tag_id_create_new(self) -> None:
        """Test getting tag ID for non-existing tag with create"""
        # Mock non-existing tag
        self.mock_stash.find_tag.return_value = None

        # Mock created tag
        self.mock_stash.create_tag.return_value = {"id": 456, "name": "TestTag"}

        # Call the function
        result = haven_media_handler.get_tag_id("TestTag", create=True)

        # Verify tag was searched
        self.mock_stash.find_tag.assert_called_once_with("TestTag")

        # Verify tag was created
        self.mock_stash.create_tag.assert_called_once_with({
            "name": "TestTag",
            "ignore_auto_tag": True,
            "parent_ids": [3]
        })

        # Verify result
        self.assertEqual(result, 456)

    def test_get_tag_ids(self) -> None:
        """Test getting multiple tag IDs"""
        # Mock tag IDs
        with patch('haven_media_handler.get_tag_id') as mock_get_tag_id:
            mock_get_tag_id.side_effect = [10, 20, 30]

            # Call the function
            result = haven_media_handler.get_tag_ids(["Tag1", "Tag2", "Tag3"], create=True)

            # Verify individual tag IDs were retrieved
            self.assertEqual(mock_get_tag_id.call_count, 3)
            mock_get_tag_id.assert_any_call("Tag1", True)
            mock_get_tag_id.assert_any_call("Tag2", True)
            mock_get_tag_id.assert_any_call("Tag3", True)

            # Verify result
            self.assertEqual(result, [10, 20, 30])

    @patch('haven_media_handler.vlm_tag_ids_cache')
    def test_get_vlm_tags_from_cache(self, mock_cache: Mock) -> None:
        """Test getting VLM tags from cache"""
        # Mock cached tags
        mock_cache.__len__.return_value = 3
        mock_cache.__iter__.return_value = iter([100, 200, 300])

        # Call the function
        result = haven_media_handler.get_vlm_tags()

        # Verify result from cache
        self.assertEqual(result, [100, 200, 300])

    def test_get_vlm_tags_from_stash(self) -> None:
        """Test getting VLM tags from stash when cache is empty"""
        # Mock empty cache
        haven_media_handler.vlm_tag_ids_cache.clear()

        # Mock stash tags
        mock_tags = [
            {"id": 100, "name": "VLM_Tag1"},
            {"id": 200, "name": "VLM_Tag2"}
        ]
        self.mock_stash.find_tags.return_value = mock_tags

        # Call the function
        result = haven_media_handler.get_vlm_tags()

        # Verify tags were found
        self.mock_stash.find_tags.assert_called_once_with(
            f={"parents": {"value": 3, "modifier": "INCLUDES"}},
            fragment="id"
        )

        # Verify result
        self.assertEqual(result, [100, 200])


if __name__ == '__main__':
    unittest.main()
286
plugins/AHavenVLMConnector/test_haven_vlm_config.py
Normal file
@ -0,0 +1,286 @@
"""
Unit tests for haven_vlm_config module
"""

import unittest
import tempfile
import os
import yaml
from unittest.mock import patch, mock_open
from dataclasses import dataclass

import haven_vlm_config


class TestVLMConnectorConfig(unittest.TestCase):
    """Test cases for VLMConnectorConfig dataclass"""

    def test_vlm_connector_config_creation(self):
        """Test creating VLMConnectorConfig with all required fields"""
        config = haven_vlm_config.VLMConnectorConfig(
            vlm_engine_config={"test": "config"},
            video_frame_interval=2.0,
            video_threshold=0.3,
            video_confidence_return=True,
            image_threshold=0.5,
            image_batch_size=320,
            image_confidence_return=False,
            concurrent_task_limit=10,
            server_timeout=3700,
            vlm_base_tag_name="VLM",
            vlm_tagme_tag_name="VLM_TagMe",
            vlm_updateme_tag_name="VLM_UpdateMe",
            vlm_tagged_tag_name="VLM_Tagged",
            vlm_errored_tag_name="VLM_Errored",
            vlm_incorrect_tag_name="VLM_Incorrect",
            temp_image_dir="./temp_images",
            output_data_dir="./output_data",
            delete_incorrect_markers=True,
            create_markers=True,
            path_mutation={}
        )

        self.assertEqual(config.video_frame_interval, 2.0)
        self.assertEqual(config.video_threshold, 0.3)
        self.assertEqual(config.image_threshold, 0.5)
        self.assertEqual(config.concurrent_task_limit, 10)
        self.assertEqual(config.vlm_base_tag_name, "VLM")
        self.assertEqual(config.temp_image_dir, "./temp_images")

    def test_vlm_connector_config_defaults(self):
        """Test VLMConnectorConfig with minimal required fields"""
        config = haven_vlm_config.VLMConnectorConfig(
            vlm_engine_config={},
            video_frame_interval=1.0,
            video_threshold=0.1,
            video_confidence_return=False,
            image_threshold=0.1,
            image_batch_size=100,
            image_confidence_return=False,
            concurrent_task_limit=5,
            server_timeout=1000,
            vlm_base_tag_name="TEST",
            vlm_tagme_tag_name="TEST_TagMe",
            vlm_updateme_tag_name="TEST_UpdateMe",
            vlm_tagged_tag_name="TEST_Tagged",
            vlm_errored_tag_name="TEST_Errored",
            vlm_incorrect_tag_name="TEST_Incorrect",
            temp_image_dir="./test_temp",
            output_data_dir="./test_output",
            delete_incorrect_markers=False,
            create_markers=False,
            path_mutation={"test": "mutation"}
        )

        self.assertEqual(config.video_frame_interval, 1.0)
        self.assertEqual(config.video_threshold, 0.1)
        self.assertEqual(config.path_mutation, {"test": "mutation"})


class TestLoadConfigFromYaml(unittest.TestCase):
    """Test cases for load_config_from_yaml function"""

    def setUp(self):
        """Set up test fixtures"""
        self.test_config = {
            "vlm_engine_config": {
                "active_ai_models": ["test_model"],
                "pipelines": {},
                "models": {},
                "category_config": {}
            },
            "video_frame_interval": 3.0,
            "video_threshold": 0.4,
            "video_confidence_return": True,
            "image_threshold": 0.6,
            "image_batch_size": 500,
            "image_confidence_return": True,
            "concurrent_task_limit": 15,
            "server_timeout": 5000,
            "vlm_base_tag_name": "TEST_VLM",
            "vlm_tagme_tag_name": "TEST_VLM_TagMe",
            "vlm_updateme_tag_name": "TEST_VLM_UpdateMe",
            "vlm_tagged_tag_name": "TEST_VLM_Tagged",
            "vlm_errored_tag_name": "TEST_VLM_Errored",
            "vlm_incorrect_tag_name": "TEST_VLM_Incorrect",
            "temp_image_dir": "./test_temp_images",
            "output_data_dir": "./test_output_data",
            "delete_incorrect_markers": False,
            "create_markers": False,
            "path_mutation": {"E:": "F:"}
        }

    def test_load_config_from_yaml_with_valid_file(self):
        """Test loading configuration from a valid YAML file"""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f:
            yaml.dump(self.test_config, f)
            config_path = f.name

        try:
            config = haven_vlm_config.load_config_from_yaml(config_path)

            self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig)
            self.assertEqual(config.video_frame_interval, 3.0)
            self.assertEqual(config.video_threshold, 0.4)
            self.assertEqual(config.image_threshold, 0.6)
            self.assertEqual(config.concurrent_task_limit, 15)
            self.assertEqual(config.vlm_base_tag_name, "TEST_VLM")
            self.assertEqual(config.path_mutation, {"E:": "F:"})
        finally:
            os.unlink(config_path)

    def test_load_config_from_yaml_with_nonexistent_file(self):
        """Test loading configuration with nonexistent file path"""
        config = haven_vlm_config.load_config_from_yaml("nonexistent_file.yml")

        # Should return default configuration
        self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig)
        self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL)
        self.assertEqual(config.video_threshold, haven_vlm_config.VIDEO_THRESHOLD)

    def test_load_config_from_yaml_with_none_path(self):
        """Test loading configuration with None path"""
        config = haven_vlm_config.load_config_from_yaml(None)

        # Should return default configuration
        self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig)
        self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL)

    def test_load_config_from_yaml_with_invalid_yaml(self):
        """Test loading configuration with invalid YAML content"""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f:
            f.write("invalid: yaml: content: [")
            config_path = f.name

        try:
            config = haven_vlm_config.load_config_from_yaml(config_path)

            # Should return default configuration on YAML error
            self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig)
            self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL)
        finally:
            os.unlink(config_path)

    def test_load_config_from_yaml_with_file_permission_error(self):
        """Test loading configuration with file permission error"""
        with patch('builtins.open', side_effect=PermissionError("Permission denied")):
            config = haven_vlm_config.load_config_from_yaml("test.yml")

        # Should return default configuration on file error
        self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig)
        self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL)


class TestConfigurationConstants(unittest.TestCase):
    """Test cases for configuration constants"""

    def test_vlm_engine_config_structure(self):
        """Test that VLM_ENGINE_CONFIG has the expected structure"""
        config = haven_vlm_config.VLM_ENGINE_CONFIG

        # Check required top-level keys
        self.assertIn("active_ai_models", config)
        self.assertIn("pipelines", config)
        self.assertIn("models", config)
        self.assertIn("category_config", config)

        # Check active_ai_models is a list
        self.assertIsInstance(config["active_ai_models"], list)
        self.assertIn("vlm_multiplexer_model", config["active_ai_models"])

        # Check pipelines structure
        self.assertIn("video_pipeline_dynamic", config["pipelines"])
        pipeline = config["pipelines"]["video_pipeline_dynamic"]
        self.assertIn("inputs", pipeline)
        self.assertIn("output", pipeline)
        self.assertIn("models", pipeline)

        # Check models structure
        self.assertIn("vlm_multiplexer_model", config["models"])
        model = config["models"]["vlm_multiplexer_model"]
        self.assertIn("type", model)
        self.assertIn("multiplexer_endpoints", model)
        self.assertIn("tag_list", model)

    def test_processing_settings(self):
        """Test that processing settings have valid values"""
        self.assertGreater(haven_vlm_config.VIDEO_FRAME_INTERVAL, 0)
        self.assertGreaterEqual(haven_vlm_config.VIDEO_THRESHOLD, 0)
        self.assertLessEqual(haven_vlm_config.VIDEO_THRESHOLD, 1)
        self.assertGreaterEqual(haven_vlm_config.IMAGE_THRESHOLD, 0)
        self.assertLessEqual(haven_vlm_config.IMAGE_THRESHOLD, 1)
        self.assertGreater(haven_vlm_config.IMAGE_BATCH_SIZE, 0)
        self.assertGreater(haven_vlm_config.CONCURRENT_TASK_LIMIT, 0)
        self.assertGreater(haven_vlm_config.SERVER_TIMEOUT, 0)

    def test_tag_names(self):
        """Test that tag names are valid strings"""
        tag_names = [
            haven_vlm_config.VLM_BASE_TAG_NAME,
            haven_vlm_config.VLM_TAGME_TAG_NAME,
            haven_vlm_config.VLM_UPDATEME_TAG_NAME,
            haven_vlm_config.VLM_TAGGED_TAG_NAME,
            haven_vlm_config.VLM_ERRORED_TAG_NAME,
            haven_vlm_config.VLM_INCORRECT_TAG_NAME
        ]

        for tag_name in tag_names:
            self.assertIsInstance(tag_name, str)
            self.assertGreater(len(tag_name), 0)

    def test_directory_paths(self):
        """Test that directory paths are valid strings"""
        self.assertIsInstance(haven_vlm_config.TEMP_IMAGE_DIR, str)
        self.assertIsInstance(haven_vlm_config.OUTPUT_DATA_DIR, str)
        self.assertGreater(len(haven_vlm_config.TEMP_IMAGE_DIR), 0)
        self.assertGreater(len(haven_vlm_config.OUTPUT_DATA_DIR), 0)

    def test_boolean_settings(self):
        """Test that boolean settings are valid"""
        self.assertIsInstance(haven_vlm_config.DELETE_INCORRECT_MARKERS, bool)
        self.assertIsInstance(haven_vlm_config.CREATE_MARKERS, bool)

    def test_path_mutation(self):
        """Test that path mutation is a dictionary"""
        self.assertIsInstance(haven_vlm_config.PATH_MUTATION, dict)


class TestGlobalConfigInstance(unittest.TestCase):
    """Test cases for the global config instance"""

    def test_global_config_exists(self):
        """Test that the global config instance exists and is valid"""
        self.assertIsInstance(haven_vlm_config.config, haven_vlm_config.VLMConnectorConfig)

    def test_global_config_has_required_attributes(self):
        """Test that the global config has all required attributes"""
        config = haven_vlm_config.config

        # Check that all required attributes exist
        required_attrs = [
            'vlm_engine_config', 'video_frame_interval', 'video_threshold',
            'video_confidence_return', 'image_threshold', 'image_batch_size',
            'image_confidence_return', 'concurrent_task_limit', 'server_timeout',
            'vlm_base_tag_name', 'vlm_tagme_tag_name', 'vlm_updateme_tag_name',
            'vlm_tagged_tag_name', 'vlm_errored_tag_name', 'vlm_incorrect_tag_name',
            'temp_image_dir', 'output_data_dir', 'delete_incorrect_markers',
            'create_markers', 'path_mutation'
        ]

        for attr in required_attrs:
            self.assertTrue(hasattr(config, attr), f"Missing attribute: {attr}")

    def test_global_config_values(self):
        """Test that the global config has expected default values"""
        config = haven_vlm_config.config

        self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL)
        self.assertEqual(config.video_threshold, haven_vlm_config.VIDEO_THRESHOLD)
        self.assertEqual(config.image_threshold, haven_vlm_config.IMAGE_THRESHOLD)
        self.assertEqual(config.concurrent_task_limit, haven_vlm_config.CONCURRENT_TASK_LIMIT)
        self.assertEqual(config.vlm_base_tag_name, haven_vlm_config.VLM_BASE_TAG_NAME)
        self.assertEqual(config.temp_image_dir, haven_vlm_config.TEMP_IMAGE_DIR)


if __name__ == '__main__':
    unittest.main()
451
plugins/AHavenVLMConnector/test_haven_vlm_connector.py
Normal file
@ -0,0 +1,451 @@
"""
Unit tests for haven_vlm_connector module
"""

import unittest
import asyncio
import json
import tempfile
import os
from unittest.mock import patch, MagicMock, AsyncMock, mock_open
import sys

# Mock the stashapi imports
sys.modules['stashapi.log'] = MagicMock()
sys.modules['stashapi.stashapp'] = MagicMock()

# Mock the vlm_engine imports
sys.modules['vlm_engine'] = MagicMock()
sys.modules['vlm_engine.config_models'] = MagicMock()

import haven_vlm_connector


class TestMainExecution(unittest.TestCase):
    """Test cases for main execution functions"""

    def setUp(self):
        """Set up test fixtures"""
        self.sample_json_input = {
            "server_connection": {
                "PluginDir": "/tmp/plugin"
            },
            "args": {
                "mode": "tag_videos"
            }
        }

    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.tag_videos')
    @patch('haven_vlm_connector.os.chdir')
    def test_run_tag_videos(self, mock_chdir, mock_tag_videos, mock_media_handler):
        """Test running tag_videos mode"""
        output = {}

        with patch('haven_vlm_connector.read_json_input', return_value=self.sample_json_input):
            asyncio.run(haven_vlm_connector.run(self.sample_json_input, output))

        mock_chdir.assert_called_once_with("/tmp/plugin")
        mock_media_handler.initialize.assert_called_once_with(self.sample_json_input["server_connection"])
        mock_tag_videos.assert_called_once()
        self.assertEqual(output["output"], "ok")

    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.tag_images')
    @patch('haven_vlm_connector.os.chdir')
    def test_run_tag_images(self, mock_chdir, mock_tag_images, mock_media_handler):
        """Test running tag_images mode"""
        json_input = self.sample_json_input.copy()
        json_input["args"]["mode"] = "tag_images"
        output = {}

        asyncio.run(haven_vlm_connector.run(json_input, output))

        mock_tag_images.assert_called_once()
        self.assertEqual(output["output"], "ok")

    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.find_marker_settings')
    @patch('haven_vlm_connector.os.chdir')
    def test_run_find_marker_settings(self, mock_chdir, mock_find_marker_settings, mock_media_handler):
        """Test running find_marker_settings mode"""
        json_input = self.sample_json_input.copy()
        json_input["args"]["mode"] = "find_marker_settings"
        output = {}

        asyncio.run(haven_vlm_connector.run(json_input, output))

        mock_find_marker_settings.assert_called_once()
        self.assertEqual(output["output"], "ok")

    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.collect_incorrect_markers_and_images')
    @patch('haven_vlm_connector.os.chdir')
    def test_run_collect_incorrect_markers(self, mock_chdir, mock_collect, mock_media_handler):
        """Test running collect_incorrect_markers mode"""
        json_input = self.sample_json_input.copy()
        json_input["args"]["mode"] = "collect_incorrect_markers"
        output = {}

        asyncio.run(haven_vlm_connector.run(json_input, output))

        mock_collect.assert_called_once()
        self.assertEqual(output["output"], "ok")

    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.os.chdir')
    def test_run_no_mode(self, mock_chdir, mock_media_handler):
        """Test running with no mode specified"""
        json_input = self.sample_json_input.copy()
        del json_input["args"]["mode"]
        output = {}

        asyncio.run(haven_vlm_connector.run(json_input, output))

        self.assertEqual(output["output"], "ok")

    @patch('haven_vlm_connector.media_handler')
    def test_run_media_handler_initialization_error(self, mock_media_handler):
        """Test handling media handler initialization error"""
        mock_media_handler.initialize.side_effect = Exception("Initialization failed")
        output = {}

        with self.assertRaises(Exception):
            asyncio.run(haven_vlm_connector.run(self.sample_json_input, output))

    def test_read_json_input(self):
        """Test reading JSON input from stdin"""
        test_input = '{"test": "data"}'

        with patch('sys.stdin.read', return_value=test_input):
            result = haven_vlm_connector.read_json_input()

        self.assertEqual(result, {"test": "data"})


class TestHighLevelProcessingFunctions(unittest.TestCase):
    """Test cases for high-level processing functions"""

    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.__tag_images')
    @patch('haven_vlm_connector.asyncio.gather')
    def test_tag_images_with_images(self, mock_gather, mock_tag_images, mock_media_handler):
        """Test tagging images when images are available"""
        mock_images = [{"id": 1}, {"id": 2}, {"id": 3}]
        mock_media_handler.get_tagme_images.return_value = mock_images

        asyncio.run(haven_vlm_connector.tag_images())

        mock_media_handler.get_tagme_images.assert_called_once()
        mock_gather.assert_called_once()

    @patch('haven_vlm_connector.media_handler')
    def test_tag_images_no_images(self, mock_media_handler):
        """Test tagging images when no images are available"""
        mock_media_handler.get_tagme_images.return_value = []

        asyncio.run(haven_vlm_connector.tag_images())

        mock_media_handler.get_tagme_images.assert_called_once()

    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.__tag_video')
    @patch('haven_vlm_connector.asyncio.gather')
    def test_tag_videos_with_scenes(self, mock_gather, mock_tag_video, mock_media_handler):
        """Test tagging videos when scenes are available"""
        mock_scenes = [{"id": 1}, {"id": 2}]
        mock_media_handler.get_tagme_scenes.return_value = mock_scenes

        asyncio.run(haven_vlm_connector.tag_videos())

        mock_media_handler.get_tagme_scenes.assert_called_once()
        mock_gather.assert_called_once()

    @patch('haven_vlm_connector.media_handler')
    def test_tag_videos_no_scenes(self, mock_media_handler):
        """Test tagging videos when no scenes are available"""
        mock_media_handler.get_tagme_scenes.return_value = []

        asyncio.run(haven_vlm_connector.tag_videos())

        mock_media_handler.get_tagme_scenes.assert_called_once()

    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.__find_marker_settings')
    def test_find_marker_settings_single_scene(self, mock_find_settings, mock_media_handler):
        """Test finding marker settings with single scene"""
        mock_scenes = [{"id": 1}]
        mock_media_handler.get_tagme_scenes.return_value = mock_scenes

        asyncio.run(haven_vlm_connector.find_marker_settings())

        mock_media_handler.get_tagme_scenes.assert_called_once()
        mock_find_settings.assert_called_once_with(mock_scenes[0])

    @patch('haven_vlm_connector.media_handler')
    def test_find_marker_settings_no_scenes(self, mock_media_handler):
        """Test finding marker settings with no scenes"""
        mock_media_handler.get_tagme_scenes.return_value = []

        asyncio.run(haven_vlm_connector.find_marker_settings())

        mock_media_handler.get_tagme_scenes.assert_called_once()

    @patch('haven_vlm_connector.media_handler')
    def test_find_marker_settings_multiple_scenes(self, mock_media_handler):
        """Test finding marker settings with multiple scenes"""
        mock_scenes = [{"id": 1}, {"id": 2}]
        mock_media_handler.get_tagme_scenes.return_value = mock_scenes

        asyncio.run(haven_vlm_connector.find_marker_settings())

        mock_media_handler.get_tagme_scenes.assert_called_once()

    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.os.makedirs')
    @patch('haven_vlm_connector.shutil.copy')
    def test_collect_incorrect_markers_and_images_with_data(self, mock_copy, mock_makedirs, mock_media_handler):
        """Test collecting incorrect markers and images with data"""
        mock_images = [{"id": 1, "files": [{"path": "/path/to/image.jpg"}]}]
        mock_markers = [{"id": 1, "scene": {"files": [{"path": "/path/to/video.mp4"}]}, "primary_tag": {"name": "test"}}]
        mock_media_handler.get_incorrect_images.return_value = mock_images
        mock_media_handler.get_incorrect_markers.return_value = mock_markers
        mock_media_handler.get_image_paths_and_ids.return_value = (["/path/to/image.jpg"], [1], [])

        haven_vlm_connector.collect_incorrect_markers_and_images()

        mock_media_handler.get_incorrect_images.assert_called_once()
        mock_media_handler.get_incorrect_markers.assert_called_once()
        mock_media_handler.remove_incorrect_tag_from_images.assert_called_once()

    @patch('haven_vlm_connector.media_handler')
    def test_collect_incorrect_markers_and_images_no_data(self, mock_media_handler):
        """Test collecting incorrect markers and images with no data"""
        mock_media_handler.get_incorrect_images.return_value = []
        mock_media_handler.get_incorrect_markers.return_value = []

        haven_vlm_connector.collect_incorrect_markers_and_images()

        mock_media_handler.get_incorrect_images.assert_called_once()
        mock_media_handler.get_incorrect_markers.assert_called_once()


class TestLowLevelProcessingFunctions(unittest.TestCase):
    """Test cases for low-level processing functions"""

    @patch('haven_vlm_connector.vlm_engine')
    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.semaphore')
    def test_tag_images_success(self, mock_semaphore, mock_media_handler, mock_vlm_engine):
        """Test successful image tagging"""
        mock_images = [{"id": 1}, {"id": 2}]
        mock_media_handler.get_image_paths_and_ids.return_value = (["/path1.jpg", "/path2.jpg"], [1, 2], [])
        mock_vlm_engine.process_images_async.return_value = MagicMock(result=[{"tags": ["tag1"]}, {"tags": ["tag2"]}])
        mock_media_handler.get_tag_ids.return_value = [100, 200]

        # Mock semaphore context manager
        mock_semaphore.__aenter__ = AsyncMock()
        mock_semaphore.__aexit__ = AsyncMock()

        # Use getattr: a literal haven_vlm_connector.__tag_images reference
        # inside this class would be name-mangled by Python
        tag_images = getattr(haven_vlm_connector, '__tag_images')
        asyncio.run(tag_images(mock_images))

        mock_media_handler.get_image_paths_and_ids.assert_called_once_with(mock_images)
        mock_vlm_engine.process_images_async.assert_called_once()
        mock_media_handler.remove_tagme_tags_from_images.assert_called_once()

    @patch('haven_vlm_connector.vlm_engine')
    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.semaphore')
    def test_tag_images_error(self, mock_semaphore, mock_media_handler, mock_vlm_engine):
        """Test image tagging with error"""
        mock_images = [{"id": 1}]
        mock_vlm_engine.process_images_async.side_effect = Exception("Processing error")

        # Mock semaphore context manager
        mock_semaphore.__aenter__ = AsyncMock()
        mock_semaphore.__aexit__ = AsyncMock()

        tag_images = getattr(haven_vlm_connector, '__tag_images')
        asyncio.run(tag_images(mock_images))

        mock_media_handler.add_error_images.assert_called_once()

    @patch('haven_vlm_connector.vlm_engine')
    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.semaphore')
    def test_tag_video_success(self, mock_semaphore, mock_media_handler, mock_vlm_engine):
        """Test successful video tagging"""
        mock_scene = {
            "id": 1,
            "files": [{"path": "/path/to/video.mp4"}],
            "tags": []
        }
        mock_vlm_engine.process_video_async.return_value = MagicMock(
            video_tags={"category": ["tag1", "tag2"]},
            tag_timespans={}
        )
        mock_media_handler.is_vr_scene.return_value = False
        mock_media_handler.get_tag_ids.return_value = [100, 200]

        # Mock semaphore context manager
        mock_semaphore.__aenter__ = AsyncMock()
        mock_semaphore.__aexit__ = AsyncMock()

        tag_video = getattr(haven_vlm_connector, '__tag_video')
        asyncio.run(tag_video(mock_scene))

        mock_vlm_engine.process_video_async.assert_called_once()

        # Verify tags and markers were cleared before adding new ones
        mock_media_handler.clear_all_tags_from_video.assert_called_once_with(1)
        mock_media_handler.clear_all_markers_from_video.assert_called_once_with(1)

        mock_media_handler.add_tags_to_video.assert_called_once()
        mock_media_handler.remove_tagme_tag_from_scene.assert_called_once()

    @patch('haven_vlm_connector.vlm_engine')
    @patch('haven_vlm_connector.media_handler')
    @patch('haven_vlm_connector.semaphore')
    def test_tag_video_error(self, mock_semaphore, mock_media_handler, mock_vlm_engine):
        """Test video tagging with error"""
        mock_scene = {
            "id": 1,
            "files": [{"path": "/path/to/video.mp4"}],
            "tags": []
        }
        mock_vlm_engine.process_video_async.side_effect = Exception("Processing error")

        # Mock semaphore context manager
        mock_semaphore.__aenter__ = AsyncMock()
        mock_semaphore.__aexit__ = AsyncMock()

        tag_video = getattr(haven_vlm_connector, '__tag_video')
        asyncio.run(tag_video(mock_scene))

        mock_media_handler.add_error_scene.assert_called_once()

    @patch('haven_vlm_connector.vlm_engine')
    @patch('haven_vlm_connector.media_handler')
    def test_find_marker_settings_success(self, mock_media_handler, mock_vlm_engine):
        """Test successful marker settings finding"""
        mock_scene = {
            "id": 1,
            "files": [{"path": "/path/to/video.mp4"}]
        }
        mock_markers = [
            {
                "primary_tag": {"name": "tag1"},
                "seconds": 10.0,
                "end_seconds": 15.0
            }
        ]
        mock_media_handler.get_scene_markers.return_value = mock_markers
        mock_vlm_engine.find_optimal_marker_settings_async.return_value = {"optimal": "settings"}

        find_settings = getattr(haven_vlm_connector, '__find_marker_settings')
        asyncio.run(find_settings(mock_scene))

        mock_media_handler.get_scene_markers.assert_called_once_with(1)
        mock_vlm_engine.find_optimal_marker_settings_async.assert_called_once()

    @patch('haven_vlm_connector.media_handler')
    def test_find_marker_settings_error(self, mock_media_handler):
        """Test marker settings finding with error"""
        mock_scene = {
            "id": 1,
            "files": [{"path": "/path/to/video.mp4"}]
        }
        mock_media_handler.get_scene_markers.side_effect = Exception("Marker error")

        find_settings = getattr(haven_vlm_connector, '__find_marker_settings')
        asyncio.run(find_settings(mock_scene))

        mock_media_handler.get_scene_markers.assert_called_once()


class TestUtilityFunctions(unittest.TestCase):
    """Test cases for utility functions"""

    def test_increment_progress(self):
        """Test progress increment"""
        haven_vlm_connector.progress = 0.0
        haven_vlm_connector.increment = 0.1

        haven_vlm_connector.increment_progress()

        self.assertEqual(haven_vlm_connector.progress, 0.1)

    @patch('haven_vlm_connector.vlm_engine')
    def test_cleanup(self, mock_vlm_engine):
        """Test cleanup function"""
        mock_vlm_engine.vlm_engine = MagicMock()

        # cleanup is a coroutine; plain unittest.TestCase never awaits async
        # test methods, so drive it with asyncio.run instead
        asyncio.run(haven_vlm_connector.cleanup())

        mock_vlm_engine.vlm_engine.shutdown.assert_called_once()


class TestMainFunction(unittest.TestCase):
    """Test cases for main function"""

    @patch('haven_vlm_connector.run')
    @patch('haven_vlm_connector.read_json_input')
    @patch('haven_vlm_connector.json.dumps')
    @patch('builtins.print')
    def test_main_success(self, mock_print, mock_json_dumps, mock_read_input, mock_run):
        """Test successful main execution"""
        mock_read_input.return_value = {"test": "data"}
        mock_json_dumps.return_value = '{"output": "ok"}'

        asyncio.run(haven_vlm_connector.main())

        mock_read_input.assert_called_once()
        mock_run.assert_called_once()
        mock_json_dumps.assert_called_once()
        mock_print.assert_called()


class TestErrorHandling(unittest.TestCase):
    """Test cases for error handling"""

    @patch('haven_vlm_connector.media_handler')
    def test_tag_images_empty_paths(self, mock_media_handler):
        """Test image tagging with empty paths"""
        mock_images = [{"id": 1}]
        mock_media_handler.get_image_paths_and_ids.return_value = ([], [1], [])

        # Mock semaphore context manager
        with patch('haven_vlm_connector.semaphore') as mock_semaphore:
            mock_semaphore.__aenter__ = AsyncMock()
            mock_semaphore.__aexit__ = AsyncMock()

            # getattr avoids name mangling of the double-underscore module function
            tag_images = getattr(haven_vlm_connector, '__tag_images')
            asyncio.run(tag_images(mock_images))

        mock_media_handler.get_image_paths_and_ids.assert_called_once()

    @patch('haven_vlm_connector.vlm_engine')
    @patch('haven_vlm_connector.media_handler')
    def test_tag_video_no_detected_tags(self, mock_media_handler, mock_vlm_engine):
        """Test video tagging with no detected tags"""
        mock_scene = {
            "id": 1,
            "files": [{"path": "/path/to/video.mp4"}],
            "tags": []
        }
        mock_vlm_engine.process_video_async.return_value = MagicMock(
            video_tags={},
            tag_timespans={}
        )
        mock_media_handler.is_vr_scene.return_value = False

        # Mock semaphore context manager
        with patch('haven_vlm_connector.semaphore') as mock_semaphore:
            mock_semaphore.__aenter__ = AsyncMock()
            mock_semaphore.__aexit__ = AsyncMock()

            tag_video = getattr(haven_vlm_connector, '__tag_video')
            asyncio.run(tag_video(mock_scene))

        # Verify clearing functions are NOT called when no tags are detected
        mock_media_handler.clear_all_tags_from_video.assert_not_called()
        mock_media_handler.clear_all_markers_from_video.assert_not_called()

        mock_media_handler.remove_tagme_tag_from_scene.assert_called_once()


if __name__ == '__main__':
    unittest.main()
544  plugins/AHavenVLMConnector/test_haven_vlm_engine.py  Normal file
@@ -0,0 +1,544 @@
"""
|
||||
Unit tests for haven_vlm_engine module
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock, AsyncMock, mock_open
|
||||
import sys
|
||||
|
||||
# Mock the vlm_engine imports
|
||||
sys.modules['vlm_engine'] = MagicMock()
|
||||
sys.modules['vlm_engine.config_models'] = MagicMock()
|
||||
|
||||
import haven_vlm_engine
|
||||
|
||||
|
||||
class TestTimeFrame(unittest.TestCase):
    """Test cases for TimeFrame dataclass"""

    def test_timeframe_creation(self):
        """Test creating TimeFrame with all parameters"""
        timeframe = haven_vlm_engine.TimeFrame(
            start=10.0,
            end=15.0,
            total_confidence=0.85
        )

        self.assertEqual(timeframe.start, 10.0)
        self.assertEqual(timeframe.end, 15.0)
        self.assertEqual(timeframe.total_confidence, 0.85)

    def test_timeframe_creation_without_confidence(self):
        """Test creating TimeFrame without confidence"""
        timeframe = haven_vlm_engine.TimeFrame(
            start=10.0,
            end=15.0
        )

        self.assertEqual(timeframe.start, 10.0)
        self.assertEqual(timeframe.end, 15.0)
        self.assertIsNone(timeframe.total_confidence)

    def test_timeframe_to_json(self):
        """Test TimeFrame to_json method"""
        timeframe = haven_vlm_engine.TimeFrame(
            start=10.0,
            end=15.0,
            total_confidence=0.85
        )

        json_str = timeframe.to_json()
        json_data = json.loads(json_str)

        self.assertEqual(json_data["start"], 10.0)
        self.assertEqual(json_data["end"], 15.0)
        self.assertEqual(json_data["total_confidence"], 0.85)

    def test_timeframe_to_json_without_confidence(self):
        """Test TimeFrame to_json method without confidence"""
        timeframe = haven_vlm_engine.TimeFrame(
            start=10.0,
            end=15.0
        )

        json_str = timeframe.to_json()
        json_data = json.loads(json_str)

        self.assertEqual(json_data["start"], 10.0)
        self.assertEqual(json_data["end"], 15.0)
        self.assertIsNone(json_data["total_confidence"])

    def test_timeframe_str(self):
        """Test TimeFrame string representation"""
        timeframe = haven_vlm_engine.TimeFrame(
            start=10.0,
            end=15.0,
            total_confidence=0.85
        )

        str_repr = str(timeframe)
        self.assertIn("10.0", str_repr)
        self.assertIn("15.0", str_repr)
        self.assertIn("0.85", str_repr)

class TestVideoTagInfo(unittest.TestCase):
    """Test cases for VideoTagInfo dataclass"""

    def test_videotaginfo_creation(self):
        """Test creating VideoTagInfo with all parameters"""
        video_tags = {"category1": {"tag1", "tag2"}}
        tag_totals = {"tag1": {"total": 0.8}}
        tag_timespans = {"category1": {"tag1": [haven_vlm_engine.TimeFrame(10.0, 15.0)]}}

        video_info = haven_vlm_engine.VideoTagInfo(
            video_duration=120.0,
            video_tags=video_tags,
            tag_totals=tag_totals,
            tag_timespans=tag_timespans
        )

        self.assertEqual(video_info.video_duration, 120.0)
        self.assertEqual(video_info.video_tags, video_tags)
        self.assertEqual(video_info.tag_totals, tag_totals)
        self.assertEqual(video_info.tag_timespans, tag_timespans)

    def test_videotaginfo_from_json(self):
        """Test creating VideoTagInfo from JSON data"""
        json_data = {
            "video_duration": 120.0,
            "video_tags": {"category1": ["tag1", "tag2"]},
            "tag_totals": {"tag1": {"total": 0.8}},
            "tag_timespans": {
                "category1": {
                    "tag1": [
                        {"start": 10.0, "end": 15.0, "total_confidence": 0.85}
                    ]
                }
            }
        }

        video_info = haven_vlm_engine.VideoTagInfo.from_json(json_data)

        self.assertEqual(video_info.video_duration, 120.0)
        self.assertEqual(video_info.video_tags, {"category1": ["tag1", "tag2"]})
        self.assertEqual(video_info.tag_totals, {"tag1": {"total": 0.8}})

        # Check that tag_timespans contains TimeFrame objects
        self.assertIn("category1", video_info.tag_timespans)
        self.assertIn("tag1", video_info.tag_timespans["category1"])
        self.assertIsInstance(video_info.tag_timespans["category1"]["tag1"][0], haven_vlm_engine.TimeFrame)

    def test_videotaginfo_from_json_without_confidence(self):
        """Test creating VideoTagInfo from JSON data without confidence"""
        json_data = {
            "video_duration": 120.0,
            "video_tags": {"category1": ["tag1"]},
            "tag_totals": {"tag1": {"total": 0.8}},
            "tag_timespans": {
                "category1": {
                    "tag1": [
                        {"start": 10.0, "end": 15.0}
                    ]
                }
            }
        }

        video_info = haven_vlm_engine.VideoTagInfo.from_json(json_data)

        timeframe = video_info.tag_timespans["category1"]["tag1"][0]
        self.assertEqual(timeframe.start, 10.0)
        self.assertEqual(timeframe.end, 15.0)
        self.assertIsNone(timeframe.total_confidence)

    def test_videotaginfo_from_json_empty_timespans(self):
        """Test creating VideoTagInfo from JSON data with empty timespans"""
        json_data = {
            "video_duration": 120.0,
            "video_tags": {"category1": ["tag1"]},
            "tag_totals": {"tag1": {"total": 0.8}},
            "tag_timespans": {}
        }

        video_info = haven_vlm_engine.VideoTagInfo.from_json(json_data)

        self.assertEqual(video_info.video_duration, 120.0)
        self.assertEqual(video_info.tag_timespans, {})

    def test_videotaginfo_str(self):
        """Test VideoTagInfo string representation"""
        video_info = haven_vlm_engine.VideoTagInfo(
            video_duration=120.0,
            video_tags={"category1": {"tag1"}},
            tag_totals={"tag1": {"total": 0.8}},
            tag_timespans={"category1": {"tag1": []}}
        )

        str_repr = str(video_info)
        self.assertIn("120.0", str_repr)
        self.assertIn("1", str_repr)  # number of tags
        self.assertIn("1", str_repr)  # number of timespans

class TestImageResult(unittest.TestCase):
    """Test cases for ImageResult dataclass"""

    def test_imageresult_creation(self):
        """Test creating ImageResult with valid data"""
        result_data = [{"tags": ["tag1"], "confidence": 0.8}]
        image_result = haven_vlm_engine.ImageResult(result=result_data)

        self.assertEqual(image_result.result, result_data)

    def test_imageresult_creation_empty_list(self):
        """Test creating ImageResult with empty list"""
        with self.assertRaises(ValueError):
            haven_vlm_engine.ImageResult(result=[])

    def test_imageresult_creation_none_result(self):
        """Test creating ImageResult with None result"""
        with self.assertRaises(ValueError):
            haven_vlm_engine.ImageResult(result=None)

# IsolatedAsyncioTestCase is required so the async test methods are actually awaited;
# on a plain TestCase they would return un-run coroutines and be reported as passing.
class TestHavenVLMEngine(unittest.IsolatedAsyncioTestCase):
    """Test cases for HavenVLMEngine class"""

    def setUp(self):
        """Set up test fixtures"""
        self.engine = haven_vlm_engine.HavenVLMEngine()

    def test_engine_initialization(self):
        """Test engine initialization"""
        self.assertIsNone(self.engine.engine)
        self.assertIsNone(self.engine.engine_config)
        self.assertFalse(self.engine._initialized)

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_initialize_success(self, mock_vlm_engine_class, mock_config):
        """Test successful engine initialization"""
        # AsyncMock rather than MagicMock so the engine's awaited methods return awaitables
        mock_engine_instance = AsyncMock()
        mock_vlm_engine_class.return_value = mock_engine_instance
        mock_config.config.vlm_engine_config = {"test": "config"}

        await self.engine.initialize()

        self.assertTrue(self.engine._initialized)
        mock_vlm_engine_class.assert_called_once()
        mock_engine_instance.initialize.assert_called_once()

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_initialize_already_initialized(self, mock_vlm_engine_class, mock_config):
        """Test initialization when already initialized"""
        self.engine._initialized = True

        await self.engine.initialize()

        # Should not call VLMEngine constructor again
        mock_vlm_engine_class.assert_not_called()

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_initialize_error(self, mock_vlm_engine_class, mock_config):
        """Test initialization with error"""
        mock_vlm_engine_class.side_effect = Exception("Initialization failed")
        mock_config.config.vlm_engine_config = {"test": "config"}

        with self.assertRaises(Exception):
            await self.engine.initialize()

        self.assertFalse(self.engine._initialized)

    @patch('haven_vlm_engine.config')
    def test_create_engine_config(self, mock_config):
        """Test creating engine configuration"""
        mock_config.config.vlm_engine_config = {
            "active_ai_models": ["model1"],
            "pipelines": {
                "pipeline1": {
                    "inputs": ["input1"],
                    "output": "output1",
                    "short_name": "short1",
                    "version": 1.0,
                    "models": [
                        {
                            "name": "model1",
                            "inputs": ["input1"],
                            "outputs": "output1"
                        }
                    ]
                }
            },
            "models": {
                "model1": {
                    "type": "vlm_model",
                    "model_file_name": "model1.py",
                    "model_category": "test",
                    "model_id": "test_model",
                    "model_identifier": 123,
                    "model_version": "1.0",
                    "use_multiplexer": True,
                    "max_concurrent_requests": 10,
                    "connection_pool_size": 20,
                    "multiplexer_endpoints": [],
                    "tag_list": ["tag1"]
                }
            },
            "category_config": {"test": {}}
        }

        config = self.engine._create_engine_config()

        self.assertIsNotNone(config)
        # Note: we can't easily test the exact structure without the actual VLM Engine classes,
        # but we can verify the method doesn't raise exceptions.

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_process_video_success(self, mock_vlm_engine_class, mock_config):
        """Test successful video processing"""
        # Setup mocks
        mock_engine_instance = AsyncMock()
        mock_vlm_engine_class.return_value = mock_engine_instance
        mock_config.config.video_frame_interval = 2.0
        mock_config.config.video_threshold = 0.3
        mock_config.config.video_confidence_return = True

        # Mock the engine's process_video method
        mock_engine_instance.process_video.return_value = {
            "video_duration": 120.0,
            "video_tags": {"category1": ["tag1"]},
            "tag_totals": {"tag1": {"total": 0.8}},
            "tag_timespans": {}
        }

        # Initialize engine
        await self.engine.initialize()

        # Process video
        result = await self.engine.process_video("/path/to/video.mp4")

        self.assertIsInstance(result, haven_vlm_engine.VideoTagInfo)
        mock_engine_instance.process_video.assert_called_once()

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_process_video_not_initialized(self, mock_vlm_engine_class, mock_config):
        """Test video processing when not initialized"""
        mock_engine_instance = AsyncMock()
        mock_vlm_engine_class.return_value = mock_engine_instance
        mock_config.config.video_frame_interval = 2.0
        mock_config.config.video_threshold = 0.3
        mock_config.config.video_confidence_return = True

        mock_engine_instance.process_video.return_value = {
            "video_duration": 120.0,
            "video_tags": {"category1": ["tag1"]},
            "tag_totals": {"tag1": {"total": 0.8}},
            "tag_timespans": {}
        }

        # Process video without explicit initialization
        result = await self.engine.process_video("/path/to/video.mp4")

        self.assertIsInstance(result, haven_vlm_engine.VideoTagInfo)
        mock_engine_instance.initialize.assert_called_once()

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_process_video_error(self, mock_vlm_engine_class, mock_config):
        """Test video processing with error"""
        mock_engine_instance = AsyncMock()
        mock_vlm_engine_class.return_value = mock_engine_instance
        mock_config.config.video_frame_interval = 2.0
        mock_config.config.video_threshold = 0.3
        mock_config.config.video_confidence_return = True

        mock_engine_instance.process_video.side_effect = Exception("Processing failed")

        await self.engine.initialize()

        with self.assertRaises(Exception):
            await self.engine.process_video("/path/to/video.mp4")

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_process_images_success(self, mock_vlm_engine_class, mock_config):
        """Test successful image processing"""
        # Setup mocks
        mock_engine_instance = AsyncMock()
        mock_vlm_engine_class.return_value = mock_engine_instance
        mock_config.config.image_threshold = 0.5
        mock_config.config.image_confidence_return = False

        # Mock the engine's process_images method
        mock_engine_instance.process_images.return_value = [
            {"tags": ["tag1"], "confidence": 0.8}
        ]

        # Initialize engine
        await self.engine.initialize()

        # Process images
        result = await self.engine.process_images(["/path/to/image1.jpg"])

        self.assertIsInstance(result, haven_vlm_engine.ImageResult)
        mock_engine_instance.process_images.assert_called_once()

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_process_images_error(self, mock_vlm_engine_class, mock_config):
        """Test image processing with error"""
        mock_engine_instance = AsyncMock()
        mock_vlm_engine_class.return_value = mock_engine_instance
        mock_config.config.image_threshold = 0.5
        mock_config.config.image_confidence_return = False

        mock_engine_instance.process_images.side_effect = Exception("Processing failed")

        await self.engine.initialize()

        with self.assertRaises(Exception):
            await self.engine.process_images(["/path/to/image1.jpg"])

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_find_optimal_marker_settings_success(self, mock_vlm_engine_class, mock_config):
        """Test successful marker settings optimization"""
        # Setup mocks
        mock_engine_instance = AsyncMock()
        mock_vlm_engine_class.return_value = mock_engine_instance
        mock_engine_instance.optimize_timeframe_settings.return_value = {"optimal": "settings"}

        # Initialize engine
        await self.engine.initialize()

        # Test data
        existing_json = {"existing": "data"}
        desired_timespan_data = {
            "tag1": haven_vlm_engine.TimeFrame(10.0, 15.0, 0.8)
        }

        # Find optimal settings
        result = await self.engine.find_optimal_marker_settings(existing_json, desired_timespan_data)

        self.assertEqual(result, {"optimal": "settings"})
        mock_engine_instance.optimize_timeframe_settings.assert_called_once()

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_find_optimal_marker_settings_error(self, mock_vlm_engine_class, mock_config):
        """Test marker settings optimization with error"""
        mock_engine_instance = AsyncMock()
        mock_vlm_engine_class.return_value = mock_engine_instance
        mock_engine_instance.optimize_timeframe_settings.side_effect = Exception("Optimization failed")

        await self.engine.initialize()

        existing_json = {"existing": "data"}
        desired_timespan_data = {
            "tag1": haven_vlm_engine.TimeFrame(10.0, 15.0, 0.8)
        }

        with self.assertRaises(Exception):
            await self.engine.find_optimal_marker_settings(existing_json, desired_timespan_data)

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_shutdown_success(self, mock_vlm_engine_class, mock_config):
        """Test successful engine shutdown"""
        mock_engine_instance = AsyncMock()
        mock_vlm_engine_class.return_value = mock_engine_instance

        await self.engine.initialize()
        await self.engine.shutdown()

        mock_engine_instance.shutdown.assert_called_once()
        self.assertFalse(self.engine._initialized)

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_shutdown_not_initialized(self, mock_vlm_engine_class, mock_config):
        """Test shutdown when not initialized"""
        await self.engine.shutdown()

        # Should not raise any exceptions
        self.assertFalse(self.engine._initialized)

    @patch('haven_vlm_engine.config')
    @patch('haven_vlm_engine.VLMEngine')
    async def test_shutdown_error(self, mock_vlm_engine_class, mock_config):
        """Test shutdown with error"""
        mock_engine_instance = AsyncMock()
        mock_vlm_engine_class.return_value = mock_engine_instance
        mock_engine_instance.shutdown.side_effect = Exception("Shutdown failed")

        await self.engine.initialize()
        await self.engine.shutdown()

        # Should handle the error gracefully
        self.assertFalse(self.engine._initialized)

# IsolatedAsyncioTestCase so these async tests are actually awaited
class TestConvenienceFunctions(unittest.IsolatedAsyncioTestCase):
    """Test cases for convenience functions"""

    @patch('haven_vlm_engine.vlm_engine')
    async def test_process_video_async(self, mock_vlm_engine):
        """Test process_video_async convenience function"""
        # AsyncMock so the awaited engine method returns an awaitable
        mock_vlm_engine.process_video = AsyncMock(return_value=MagicMock())

        result = await haven_vlm_engine.process_video_async("/path/to/video.mp4")

        mock_vlm_engine.process_video.assert_called_once()
        self.assertEqual(result, mock_vlm_engine.process_video.return_value)

    @patch('haven_vlm_engine.vlm_engine')
    async def test_process_images_async(self, mock_vlm_engine):
        """Test process_images_async convenience function"""
        mock_vlm_engine.process_images = AsyncMock(return_value=MagicMock())

        result = await haven_vlm_engine.process_images_async(["/path/to/image.jpg"])

        mock_vlm_engine.process_images.assert_called_once()
        self.assertEqual(result, mock_vlm_engine.process_images.return_value)

    @patch('haven_vlm_engine.vlm_engine')
    async def test_find_optimal_marker_settings_async(self, mock_vlm_engine):
        """Test find_optimal_marker_settings_async convenience function"""
        mock_vlm_engine.find_optimal_marker_settings = AsyncMock(return_value={"optimal": "settings"})

        existing_json = {"existing": "data"}
        desired_timespan_data = {
            "tag1": haven_vlm_engine.TimeFrame(10.0, 15.0, 0.8)
        }

        result = await haven_vlm_engine.find_optimal_marker_settings_async(existing_json, desired_timespan_data)

        mock_vlm_engine.find_optimal_marker_settings.assert_called_once()
        self.assertEqual(result, {"optimal": "settings"})

class TestGlobalVLMEngineInstance(unittest.TestCase):
    """Test cases for global VLM engine instance"""

    def test_global_vlm_engine_exists(self):
        """Test that global VLM engine instance exists"""
        self.assertIsInstance(haven_vlm_engine.vlm_engine, haven_vlm_engine.HavenVLMEngine)

    def test_global_vlm_engine_is_singleton(self):
        """Test that global VLM engine is a singleton"""
        engine1 = haven_vlm_engine.vlm_engine
        engine2 = haven_vlm_engine.vlm_engine

        self.assertIs(engine1, engine2)


if __name__ == '__main__':
    unittest.main()
604  plugins/AHavenVLMConnector/test_haven_vlm_utility.py  Normal file
@@ -0,0 +1,604 @@
"""
|
||||
Unit tests for haven_vlm_utility module
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from unittest.mock import patch, mock_open, MagicMock
|
||||
import yaml
|
||||
|
||||
import haven_vlm_utility
|
||||
|
||||
|
||||
class TestPathMutations(unittest.TestCase):
    """Test cases for path mutation functions"""

    def test_apply_path_mutations_with_mutations(self):
        """Test applying path mutations with valid mutations"""
        mutations = {"E:": "F:", "G:": "D:"}
        path = "E:\\videos\\test.mp4"

        result = haven_vlm_utility.apply_path_mutations(path, mutations)

        self.assertEqual(result, "F:\\videos\\test.mp4")

    def test_apply_path_mutations_without_mutations(self):
        """Test applying path mutations with empty mutations"""
        mutations = {}
        path = "E:\\videos\\test.mp4"

        result = haven_vlm_utility.apply_path_mutations(path, mutations)

        self.assertEqual(result, path)

    def test_apply_path_mutations_with_none_mutations(self):
        """Test applying path mutations with None mutations"""
        mutations = None
        path = "E:\\videos\\test.mp4"

        result = haven_vlm_utility.apply_path_mutations(path, mutations)

        self.assertEqual(result, path)

    def test_apply_path_mutations_no_match(self):
        """Test applying path mutations when no mutation matches"""
        mutations = {"E:": "F:", "G:": "D:"}
        path = "C:\\videos\\test.mp4"

        result = haven_vlm_utility.apply_path_mutations(path, mutations)

        self.assertEqual(result, path)

    def test_apply_path_mutations_multiple_matches(self):
        """Test applying path mutations with multiple possible matches"""
        mutations = {"E:": "F:", "E:\\videos": "F:\\movies"}
        path = "E:\\videos\\test.mp4"

        result = haven_vlm_utility.apply_path_mutations(path, mutations)

        # Should use the first match
        self.assertEqual(result, "F:\\videos\\test.mp4")

class TestDirectoryOperations(unittest.TestCase):
    """Test cases for directory operations"""

    def test_ensure_directory_exists_new_directory(self):
        """Test creating a new directory"""
        with tempfile.TemporaryDirectory() as temp_dir:
            new_dir = os.path.join(temp_dir, "test_subdir")

            haven_vlm_utility.ensure_directory_exists(new_dir)

            self.assertTrue(os.path.exists(new_dir))
            self.assertTrue(os.path.isdir(new_dir))

    def test_ensure_directory_exists_existing_directory(self):
        """Test ensuring directory exists when it already exists"""
        with tempfile.TemporaryDirectory() as temp_dir:
            haven_vlm_utility.ensure_directory_exists(temp_dir)

            self.assertTrue(os.path.exists(temp_dir))
            self.assertTrue(os.path.isdir(temp_dir))

    def test_ensure_directory_exists_nested_directories(self):
        """Test creating nested directories"""
        with tempfile.TemporaryDirectory() as temp_dir:
            nested_dir = os.path.join(temp_dir, "level1", "level2", "level3")

            haven_vlm_utility.ensure_directory_exists(nested_dir)

            self.assertTrue(os.path.exists(nested_dir))
            self.assertTrue(os.path.isdir(nested_dir))

class TestSafeFileOperations(unittest.TestCase):
    """Test cases for safe file operations"""

    def test_safe_file_operation_success(self):
        """Test successful file operation"""
        def test_func(a, b, c=10):
            return a + b + c

        result = haven_vlm_utility.safe_file_operation(test_func, 1, 2, c=5)

        self.assertEqual(result, 8)

    def test_safe_file_operation_os_error(self):
        """Test file operation with OSError"""
        def test_func():
            raise OSError("File not found")

        result = haven_vlm_utility.safe_file_operation(test_func)

        self.assertIsNone(result)

    def test_safe_file_operation_io_error(self):
        """Test file operation with IOError"""
        def test_func():
            raise IOError("Permission denied")

        result = haven_vlm_utility.safe_file_operation(test_func)

        self.assertIsNone(result)

    def test_safe_file_operation_unexpected_error(self):
        """Test file operation with unexpected error"""
        def test_func():
            raise ValueError("Unexpected error")

        result = haven_vlm_utility.safe_file_operation(test_func)

        self.assertIsNone(result)

class TestYamlConfigOperations(unittest.TestCase):
    """Test cases for YAML configuration operations"""

    def setUp(self):
        """Set up test fixtures"""
        self.test_config = {
            "video_frame_interval": 2.0,
            "video_threshold": 0.3,
            "image_threshold": 0.5,
            "endpoints": [
                {"url": "http://localhost:1234", "weight": 5},
                {"url": "https://cloud.example.com", "weight": 1}
            ]
        }

    def test_load_yaml_config_success(self):
        """Test successfully loading YAML configuration"""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f:
            yaml.dump(self.test_config, f)
            config_path = f.name

        try:
            result = haven_vlm_utility.load_yaml_config(config_path)

            self.assertEqual(result, self.test_config)
        finally:
            os.unlink(config_path)

    def test_load_yaml_config_file_not_found(self):
        """Test loading YAML configuration from nonexistent file"""
        result = haven_vlm_utility.load_yaml_config("nonexistent_file.yml")

        self.assertIsNone(result)

    def test_load_yaml_config_invalid_yaml(self):
        """Test loading YAML configuration with invalid YAML"""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f:
            f.write("invalid: yaml: content: [")
            config_path = f.name

        try:
            result = haven_vlm_utility.load_yaml_config(config_path)

            self.assertIsNone(result)
        finally:
            os.unlink(config_path)

    def test_load_yaml_config_permission_error(self):
        """Test loading YAML configuration with permission error"""
        with patch('builtins.open', side_effect=PermissionError("Permission denied")):
            result = haven_vlm_utility.load_yaml_config("test.yml")

        self.assertIsNone(result)

    def test_save_yaml_config_success(self):
        """Test successfully saving YAML configuration"""
        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, "test_config.yml")

            result = haven_vlm_utility.save_yaml_config(self.test_config, config_path)

            self.assertTrue(result)
            self.assertTrue(os.path.exists(config_path))

            # Verify the saved content
            with open(config_path, 'r') as f:
                loaded_config = yaml.safe_load(f)

            self.assertEqual(loaded_config, self.test_config)

    def test_save_yaml_config_with_nested_directories(self):
        """Test saving YAML configuration to nested directory"""
        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, "nested", "dir", "test_config.yml")

            result = haven_vlm_utility.save_yaml_config(self.test_config, config_path)

            self.assertTrue(result)
            self.assertTrue(os.path.exists(config_path))

    def test_save_yaml_config_permission_error(self):
        """Test saving YAML configuration with permission error"""
        with patch('builtins.open', side_effect=PermissionError("Permission denied")):
            result = haven_vlm_utility.save_yaml_config(self.test_config, "test.yml")

        self.assertFalse(result)

class TestFileValidation(unittest.TestCase):
    """Test cases for file validation functions"""

    def test_validate_file_path_existing_file(self):
        """Test validating an existing file path"""
        with tempfile.NamedTemporaryFile(delete=False) as f:
            file_path = f.name

        try:
            result = haven_vlm_utility.validate_file_path(file_path)
            self.assertTrue(result)
        finally:
            os.unlink(file_path)

    def test_validate_file_path_nonexistent_file(self):
        """Test validating a nonexistent file path"""
        result = haven_vlm_utility.validate_file_path("nonexistent_file.txt")
        self.assertFalse(result)

    def test_validate_file_path_directory(self):
        """Test validating a directory path"""
        with tempfile.TemporaryDirectory() as temp_dir:
            result = haven_vlm_utility.validate_file_path(temp_dir)
            self.assertFalse(result)

    def test_validate_file_path_permission_error(self):
        """Test validating file path with permission error"""
        with patch('os.path.isfile', side_effect=OSError("Permission denied")):
            result = haven_vlm_utility.validate_file_path("test.txt")
            self.assertFalse(result)

class TestFileExtensionFunctions(unittest.TestCase):
    """Test cases for file extension functions"""

    def test_get_file_extension_with_extension(self):
        """Test getting file extension from file with extension"""
        result = haven_vlm_utility.get_file_extension("test.mp4")
        self.assertEqual(result, ".mp4")

    def test_get_file_extension_without_extension(self):
        """Test getting file extension from file without extension"""
        result = haven_vlm_utility.get_file_extension("test")
        self.assertEqual(result, "")

    def test_get_file_extension_multiple_dots(self):
        """Test getting file extension from file with multiple dots"""
        result = haven_vlm_utility.get_file_extension("test.backup.mp4")
        self.assertEqual(result, ".mp4")

    def test_get_file_extension_uppercase(self):
        """Test getting file extension from file with uppercase extension"""
        result = haven_vlm_utility.get_file_extension("test.MP4")
        self.assertEqual(result, ".mp4")

    def test_is_video_file_valid_extensions(self):
        """Test video file detection with valid extensions"""
        video_files = ["test.mp4", "test.avi", "test.mkv", "test.mov", "test.wmv", "test.flv", "test.webm", "test.m4v"]

        for video_file in video_files:
            result = haven_vlm_utility.is_video_file(video_file)
            self.assertTrue(result, f"Failed for {video_file}")

    def test_is_video_file_invalid_extensions(self):
        """Test video file detection with invalid extensions"""
        non_video_files = ["test.jpg", "test.txt", "test.pdf", "test.exe"]

        for non_video_file in non_video_files:
            result = haven_vlm_utility.is_video_file(non_video_file)
            self.assertFalse(result, f"Failed for {non_video_file}")

    def test_is_image_file_valid_extensions(self):
        """Test image file detection with valid extensions"""
        image_files = ["test.jpg", "test.jpeg", "test.png", "test.gif", "test.bmp", "test.tiff", "test.webp"]

        for image_file in image_files:
            result = haven_vlm_utility.is_image_file(image_file)
            self.assertTrue(result, f"Failed for {image_file}")

    def test_is_image_file_invalid_extensions(self):
        """Test image file detection with invalid extensions"""
        non_image_files = ["test.mp4", "test.txt", "test.pdf", "test.exe"]

        for non_image_file in non_image_files:
            result = haven_vlm_utility.is_image_file(non_image_file)
            self.assertFalse(result, f"Failed for {non_image_file}")

class TestFormattingFunctions(unittest.TestCase):
    """Test cases for formatting functions"""

    def test_format_duration_seconds(self):
        """Test formatting duration in seconds"""
        result = haven_vlm_utility.format_duration(45.5)
        self.assertEqual(result, "45.5s")

    def test_format_duration_minutes(self):
        """Test formatting duration in minutes"""
        result = haven_vlm_utility.format_duration(125.3)
        self.assertEqual(result, "2m 5.3s")

    def test_format_duration_hours(self):
        """Test formatting duration in hours"""
        result = haven_vlm_utility.format_duration(7325.7)
        self.assertEqual(result, "2h 2m 5.7s")

    def test_format_duration_zero(self):
        """Test formatting zero duration"""
        result = haven_vlm_utility.format_duration(0)
        self.assertEqual(result, "0.0s")

    def test_format_file_size_bytes(self):
        """Test formatting file size in bytes"""
        result = haven_vlm_utility.format_file_size(512)
        self.assertEqual(result, "512.0 B")

    def test_format_file_size_kilobytes(self):
        """Test formatting file size in kilobytes"""
        result = haven_vlm_utility.format_file_size(1536)
        self.assertEqual(result, "1.5 KB")

    def test_format_file_size_megabytes(self):
        """Test formatting file size in megabytes"""
        result = haven_vlm_utility.format_file_size(1572864)
        self.assertEqual(result, "1.5 MB")

    def test_format_file_size_gigabytes(self):
        """Test formatting file size in gigabytes"""
        result = haven_vlm_utility.format_file_size(1610612736)
        self.assertEqual(result, "1.5 GB")

    def test_format_file_size_zero(self):
        """Test formatting zero file size"""
        result = haven_vlm_utility.format_file_size(0)
        self.assertEqual(result, "0.0 B")


class TestSanitizationFunctions(unittest.TestCase):
    """Test cases for sanitization functions"""

    def test_sanitize_filename_valid(self):
        """Test sanitizing a valid filename"""
        result = haven_vlm_utility.sanitize_filename("valid_filename.txt")
        self.assertEqual(result, "valid_filename.txt")

    def test_sanitize_filename_invalid_chars(self):
        """Test sanitizing filename with invalid characters"""
        result = haven_vlm_utility.sanitize_filename("file<name>:with/invalid\\chars|?*")
        self.assertEqual(result, "file_name__with_invalid_chars___")

    def test_sanitize_filename_leading_trailing_spaces(self):
        """Test sanitizing filename with leading/trailing spaces"""
        result = haven_vlm_utility.sanitize_filename("  filename.txt  ")
        self.assertEqual(result, "filename.txt")

    def test_sanitize_filename_leading_trailing_dots(self):
        """Test sanitizing filename with leading/trailing dots"""
        result = haven_vlm_utility.sanitize_filename("...filename.txt...")
        self.assertEqual(result, "filename.txt")

    def test_sanitize_filename_empty(self):
        """Test sanitizing empty filename"""
        result = haven_vlm_utility.sanitize_filename("")
        self.assertEqual(result, "unnamed")

    def test_sanitize_filename_only_spaces(self):
        """Test sanitizing filename with only spaces"""
        result = haven_vlm_utility.sanitize_filename("   ")
        self.assertEqual(result, "unnamed")


class TestBackupFunctions(unittest.TestCase):
    """Test cases for backup functions"""

    def test_create_backup_file_success(self):
        """Test successfully creating a backup file"""
        with tempfile.NamedTemporaryFile(delete=False) as f:
            original_file = f.name
            f.write(b"test content")

        try:
            result = haven_vlm_utility.create_backup_file(original_file)

            self.assertIsNotNone(result)
            self.assertTrue(os.path.exists(result))
            self.assertTrue(result.endswith(".backup"))

            # Verify backup content
            with open(result, 'rb') as f:
                content = f.read()
                self.assertEqual(content, b"test content")

            # Clean up backup
            os.unlink(result)
        finally:
            os.unlink(original_file)

    def test_create_backup_file_custom_suffix(self):
        """Test creating backup file with custom suffix"""
        with tempfile.NamedTemporaryFile(delete=False) as f:
            original_file = f.name
            f.write(b"test content")

        try:
            result = haven_vlm_utility.create_backup_file(original_file, ".custom")

            self.assertIsNotNone(result)
            self.assertTrue(result.endswith(".custom"))

            # Clean up backup
            os.unlink(result)
        finally:
            os.unlink(original_file)

    def test_create_backup_file_nonexistent(self):
        """Test creating backup of nonexistent file"""
        result = haven_vlm_utility.create_backup_file("nonexistent_file.txt")
        self.assertIsNone(result)

    def test_create_backup_file_permission_error(self):
        """Test creating backup file with permission error"""
        with patch('shutil.copy2', side_effect=PermissionError("Permission denied")):
            with tempfile.NamedTemporaryFile(delete=False) as f:
                original_file = f.name

            try:
                result = haven_vlm_utility.create_backup_file(original_file)
                self.assertIsNone(result)
            finally:
                os.unlink(original_file)


class TestDictionaryOperations(unittest.TestCase):
    """Test cases for dictionary operations"""

    def test_merge_dictionaries_simple(self):
        """Test simple dictionary merging"""
        dict1 = {"a": 1, "b": 2}
        dict2 = {"c": 3, "d": 4}

        result = haven_vlm_utility.merge_dictionaries(dict1, dict2)

        expected = {"a": 1, "b": 2, "c": 3, "d": 4}
        self.assertEqual(result, expected)

    def test_merge_dictionaries_overwrite(self):
        """Test dictionary merging with overwrite"""
        dict1 = {"a": 1, "b": 2}
        dict2 = {"b": 3, "c": 4}

        result = haven_vlm_utility.merge_dictionaries(dict1, dict2, overwrite=True)

        expected = {"a": 1, "b": 3, "c": 4}
        self.assertEqual(result, expected)

    def test_merge_dictionaries_no_overwrite(self):
        """Test dictionary merging without overwrite"""
        dict1 = {"a": 1, "b": 2}
        dict2 = {"b": 3, "c": 4}

        result = haven_vlm_utility.merge_dictionaries(dict1, dict2, overwrite=False)

        expected = {"a": 1, "b": 2, "c": 4}
        self.assertEqual(result, expected)

    def test_merge_dictionaries_nested(self):
        """Test merging nested dictionaries"""
        dict1 = {"a": 1, "b": {"x": 10, "y": 20}}
        dict2 = {"c": 3, "b": {"y": 25, "z": 30}}

        result = haven_vlm_utility.merge_dictionaries(dict1, dict2, overwrite=True)

        expected = {"a": 1, "b": {"x": 10, "y": 25, "z": 30}, "c": 3}
        self.assertEqual(result, expected)

    def test_merge_dictionaries_empty(self):
        """Test merging with empty dictionaries"""
        dict1 = {}
        dict2 = {"a": 1, "b": 2}

        result = haven_vlm_utility.merge_dictionaries(dict1, dict2)

        self.assertEqual(result, dict2)


class TestListOperations(unittest.TestCase):
    """Test cases for list operations"""

    def test_chunk_list_even_chunks(self):
        """Test chunking list into even chunks"""
        lst = [1, 2, 3, 4, 5, 6]
        result = haven_vlm_utility.chunk_list(lst, 2)

        expected = [[1, 2], [3, 4], [5, 6]]
        self.assertEqual(result, expected)

    def test_chunk_list_uneven_chunks(self):
        """Test chunking list into uneven chunks"""
        lst = [1, 2, 3, 4, 5]
        result = haven_vlm_utility.chunk_list(lst, 2)

        expected = [[1, 2], [3, 4], [5]]
        self.assertEqual(result, expected)

    def test_chunk_list_empty_list(self):
        """Test chunking empty list"""
        lst = []
        result = haven_vlm_utility.chunk_list(lst, 3)

        expected = []
        self.assertEqual(result, expected)

    def test_chunk_list_chunk_size_larger_than_list(self):
        """Test chunking when chunk size is larger than list"""
        lst = [1, 2, 3]
        result = haven_vlm_utility.chunk_list(lst, 5)

        expected = [[1, 2, 3]]
        self.assertEqual(result, expected)


class TestRetryOperations(unittest.TestCase):
    """Test cases for retry operations"""

    def test_retry_operation_success_first_try(self):
        """Test retry operation that succeeds on first try"""
        def test_func():
            return "success"

        result = haven_vlm_utility.retry_operation(test_func)

        self.assertEqual(result, "success")

    def test_retry_operation_success_after_retries(self):
        """Test retry operation that succeeds after some retries"""
        call_count = 0

        def test_func():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise ValueError("Temporary error")
            return "success"

        result = haven_vlm_utility.retry_operation(test_func, max_retries=3, delay=0.1)

        self.assertEqual(result, "success")
        self.assertEqual(call_count, 3)

    def test_retry_operation_all_retries_fail(self):
        """Test retry operation that fails all retries"""
        def test_func():
            raise ValueError("Persistent error")

        result = haven_vlm_utility.retry_operation(test_func, max_retries=2, delay=0.1)

        self.assertIsNone(result)

    def test_retry_operation_with_arguments(self):
        """Test retry operation with function arguments"""
        def test_func(a, b, c=10):
            return a + b + c

        # Positional arguments must come before keyword arguments in the call
        result = haven_vlm_utility.retry_operation(test_func, 1, 2, max_retries=1, delay=0.1, c=5)

        self.assertEqual(result, 8)

    def test_retry_operation_with_keyword_arguments(self):
        """Test retry operation with keyword arguments"""
        def test_func(**kwargs):
            return kwargs.get('value', 0)

        result = haven_vlm_utility.retry_operation(test_func, max_retries=1, delay=0.1, value=42)

        self.assertEqual(result, 42)


if __name__ == '__main__':
    unittest.main()