293 lines
11 KiB
Python
293 lines
11 KiB
Python
# face_detector_manager.py
|
|
"""
|
|
Manages on-demand starting/stopping of the anime-face-detector container
|
|
to free up VRAM when not needed.
|
|
"""
|
|
|
|
import asyncio
|
|
import aiohttp
|
|
import subprocess
|
|
import time
|
|
from typing import Optional, Dict
|
|
from utils.logger import get_logger
|
|
|
|
# Module-level logger, obtained from the project's logging helper under the 'gpu' name.
logger = get_logger('gpu')
|
|
|
|
|
|
class FaceDetectorManager:
    """Manages the anime-face-detector container lifecycle.

    The container is created/started right before a detection request and
    stopped afterwards so that it only occupies VRAM while it is needed.
    Docker is driven through the ``docker`` CLI; the detector itself is
    reached over HTTP.

    NOTE(review): concurrent calls to ``detect_face_with_management`` are not
    serialized, so one call may stop the container while another is still
    using it -- confirm callers never overlap.
    """

    FACE_DETECTOR_API = "http://anime-face-detector:6078/detect"
    HEALTH_ENDPOINT = "http://anime-face-detector:6078/health"
    CONTAINER_NAME = "anime-face-detector"
    STARTUP_TIMEOUT = 60  # seconds - increased to allow for model loading

    def __init__(self):
        # What this manager believes about the container; only updated by
        # start_container() / stop_container().
        self.is_running = False

    async def _run_docker(self, args, timeout):
        """Run ``docker <args>`` without blocking the event loop.

        Args:
            args: Arguments placed after ``docker`` on the command line.
            timeout: Seconds before ``subprocess.TimeoutExpired`` is raised.

        Returns:
            The ``subprocess.CompletedProcess`` with captured text output.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: subprocess.run(
                ["docker", *args],
                capture_output=True,
                text=True,
                timeout=timeout,
            ),
        )

    def _container_exists(self) -> bool:
        """Check if the anime-face-detector container exists (created but may not be running).

        Returns:
            True if ``docker ps -a`` lists a container with exactly this
            name; False otherwise, including when the docker command fails.
        """
        try:
            result = subprocess.run(
                ["docker", "ps", "-a", "--filter", f"name=^/{self.CONTAINER_NAME}$", "--format", "{{.Names}}"],
                capture_output=True,
                text=True,
                timeout=5
            )
            # Compare whole names instead of doing a substring search so a
            # longer, similarly named container can never produce a match.
            return self.CONTAINER_NAME in result.stdout.split()
        except Exception as e:
            logger.error(f"Error checking if container exists: {e}")
            return False

    def _create_container(self, debug: bool = False) -> bool:
        """Create (but do not start) the container using ``docker create``.

        Args:
            debug: Emit informational progress messages.

        Returns:
            True if ``docker create`` exited with status 0, False otherwise.
        """
        try:
            if debug:
                logger.info("Creating anime-face-detector container...")

            # This replicates the docker-compose configuration for anime-face-detector
            cmd = [
                "docker", "create",
                "--name", self.CONTAINER_NAME,
                "--network", "miku-discord_default",  # Use the same network as miku-bot
                "--runtime", "nvidia",
                "-e", "NVIDIA_VISIBLE_DEVICES=all",
                "-e", "NVIDIA_DRIVER_CAPABILITIES=compute,utility",
                "-p", "7860:7860",
                "-p", "6078:6078",
                "--restart", "no",
                "--gpus", "all",
                "miku-discord-anime-face-detector:latest"
            ]

            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=30
            )

            if result.returncode != 0:
                # Failures are always logged, not only in debug mode.
                logger.error(f"Failed to create container: {result.stderr}")
                return False

            if debug:
                logger.info("Container created successfully")
            return True

        except Exception as e:
            logger.error(f"Error creating container: {e}")
            return False

    async def start_container(self, debug: bool = False) -> bool:
        """
        Start the anime-face-detector container.
        Creates the container if it doesn't exist, then starts it and polls
        the health endpoint once per second for up to STARTUP_TIMEOUT seconds.

        Args:
            debug: Emit informational progress messages.

        Returns:
            True if started successfully, False otherwise
        """
        try:
            if debug:
                logger.debug("Starting anime-face-detector container...")

            loop = asyncio.get_running_loop()

            # Step 1: Check if container exists, create if it doesn't.
            # Both helpers shell out to docker, so run them in the executor
            # instead of stalling the event loop.
            if not await loop.run_in_executor(None, self._container_exists):
                created = await loop.run_in_executor(
                    None, lambda: self._create_container(debug=debug)
                )
                if not created:
                    return False

            # Step 2: Start the container
            result = await self._run_docker(["start", self.CONTAINER_NAME], timeout=30)

            if result.returncode != 0:
                logger.error(f"Failed to start container: {result.stderr}")
                return False

            # Wait for API to be ready. A monotonic clock keeps the timeout
            # correct even if the wall clock is adjusted meanwhile.
            deadline = time.monotonic() + self.STARTUP_TIMEOUT
            while time.monotonic() < deadline:
                if await self._check_health():
                    self.is_running = True
                    if debug:
                        logger.info("Face detector container started and ready")
                    return True
                await asyncio.sleep(1)

            logger.warning(f"Container started but API not ready after {self.STARTUP_TIMEOUT}s")
            return False

        except Exception as e:
            logger.error(f"Error starting face detector container: {e}")
            return False

    async def stop_container(self, debug: bool = False) -> bool:
        """
        Stop the anime-face-detector container to free VRAM.

        Args:
            debug: Emit informational progress messages.

        Returns:
            True if stopped successfully, False otherwise
        """
        try:
            if debug:
                logger.debug("Stopping anime-face-detector container...")

            result = await self._run_docker(["stop", self.CONTAINER_NAME], timeout=15)

            if result.returncode != 0:
                logger.error(f"Failed to stop container: {result.stderr}")
                return False

            self.is_running = False
            if debug:
                logger.info("Face detector container stopped")
            return True

        except Exception as e:
            logger.error(f"Error stopping face detector container: {e}")
            return False

    async def _check_health(self) -> bool:
        """Check if the face detector API is responding.

        Returns:
            True if the health endpoint answered with HTTP 200 within two
            seconds, False on any other status or on a request error.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    self.HEALTH_ENDPOINT,
                    timeout=aiohttp.ClientTimeout(total=2)
                ) as response:
                    return response.status == 200
        except Exception:
            # Deliberately not a bare ``except``: task cancellation and
            # interpreter shutdown signals must not be swallowed here.
            return False

    async def detect_face_with_management(
        self,
        image_bytes: bytes,
        unload_vision_model: callable = None,
        reload_vision_model: callable = None,
        debug: bool = False
    ) -> Optional[Dict]:
        """
        Detect face with automatic container lifecycle management.

        Args:
            image_bytes: Image data as bytes
            unload_vision_model: Optional callback to unload vision model first
            reload_vision_model: Optional callback to reload vision model after
            debug: Enable debug output

        Returns:
            Detection dict or None
        """
        container_was_started = False

        try:
            # Step 1: Unload vision model if callback provided
            if unload_vision_model:
                if debug:
                    logger.debug("Unloading vision model to free VRAM...")
                await unload_vision_model()
                await asyncio.sleep(2)  # Give time for VRAM to clear

            # Step 2: Start face detector if not running
            if not self.is_running:
                if not await self.start_container(debug=debug):
                    logger.error("Could not start face detector container")
                    return None
                container_was_started = True

            # Step 3: Detect face
            return await self._detect_face_api(image_bytes, debug=debug)

        finally:
            # Step 4: Stop container and reload vision model. The nested
            # ``finally`` guarantees the reload callback still runs when
            # stopping the container raises or is cancelled.
            try:
                if container_was_started:
                    await self.stop_container(debug=debug)
            finally:
                if reload_vision_model:
                    if debug:
                        logger.debug("Reloading vision model...")
                    await reload_vision_model()

    @staticmethod
    def _parse_detection_response(payload, debug: bool = False) -> Optional[Dict]:
        """Pick the highest-confidence detection out of a decoded API response.

        Args:
            payload: Decoded JSON body; expected to be a dict carrying a
                'detections' list and, optionally, a 'count'.
            debug: Emit debug messages describing the outcome.

        Returns:
            Dict with 'center', 'bbox', 'confidence', 'keypoints' and 'count'
            for the best detection, or None when there is no usable detection.
        """
        if not isinstance(payload, dict):
            return None

        detections = payload.get('detections') or []

        # A response without 'count' is judged by the detections list itself;
        # an explicit count of 0 is still honoured.
        if payload.get('count', len(detections)) == 0 or not detections:
            if debug:
                logger.debug("No faces detected by API")
            return None

        best_detection = max(detections, key=lambda d: d.get('confidence', 0))
        bbox = best_detection.get('bbox') or []
        confidence = best_detection.get('confidence', 0)
        keypoints = best_detection.get('keypoints', [])

        if len(bbox) < 4:
            if debug:
                logger.warning("Best detection has no usable bounding box")
            return None

        x1, y1, x2, y2 = bbox[:4]
        center_x = int((x1 + x2) / 2)
        center_y = int((y1 + y2) / 2)

        if debug:
            width = int(x2 - x1)
            height = int(y2 - y1)
            logger.debug(f"Detected {len(detections)} face(s) via API, using best at ({center_x}, {center_y}) [confidence: {confidence:.2%}]")
            logger.debug(f" Bounding box: x={int(x1)}, y={int(y1)}, w={width}, h={height}")
            logger.debug(f" Keypoints: {len(keypoints)} facial landmarks detected")

        return {
            'center': (center_x, center_y),
            'bbox': bbox,
            'confidence': confidence,
            'keypoints': keypoints,
            'count': len(detections)
        }

    async def _detect_face_api(self, image_bytes: bytes, debug: bool = False) -> Optional[Dict]:
        """Call the face detection API.

        Args:
            image_bytes: Image data, uploaded as the multipart field 'file'.
            debug: Enable debug output.

        Returns:
            The best detection as a dict, or None if the request failed, the
            API answered with a non-200 status, or no usable face was found.
        """
        try:
            async with aiohttp.ClientSession() as session:
                form = aiohttp.FormData()
                form.add_field('file', image_bytes, filename='image.jpg', content_type='image/jpeg')

                async with session.post(
                    self.FACE_DETECTOR_API,
                    data=form,
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status != 200:
                        logger.warning(f"Face detection API returned status {response.status}")
                        return None

                    result = await response.json()

            return self._parse_detection_response(result, debug=debug)

        except Exception as e:
            logger.error(f"Error calling face detection API: {e}")

        return None
|
|
|
|
|
|
# Global instance
|
|
# Created at import time; this instance carries the is_running flag that
# decides whether a detection call has to start the container first.
face_detector_manager = FaceDetectorManager()
|