fix(tasks): replace fire-and-forget asyncio.create_task with create_tracked_task
Add utils/task_tracker.py with create_tracked_task() that wraps background tasks with error logging, cancellation handling, and reference tracking. Replace all 17 fire-and-forget asyncio.create_task() calls across 7 files: - bot/bot.py (5 interjection checks) - bot/utils/autonomous.py (2 check-and-act/react tasks) - bot/utils/bipolar_mode.py (3 argument tasks) - bot/commands/uno.py (1 game loop task) - bot/utils/voice_receiver.py (3 STT/interruption callbacks) - bot/utils/persona_dialogue.py (4 dialogue turn/interjection tasks) Previously-tracked tasks (voice_audio.py, voice_manager.py) were left as-is since they already store task references for cancellation. Closes #1
This commit is contained in:
27
bot/bot.py
27
bot/bot.py
@@ -10,6 +10,9 @@ import signal
|
|||||||
import atexit
|
import atexit
|
||||||
from api import app
|
from api import app
|
||||||
|
|
||||||
|
# Import new configuration system
|
||||||
|
from config import CONFIG, SECRETS, validate_config, print_config_summary
|
||||||
|
|
||||||
from server_manager import server_manager
|
from server_manager import server_manager
|
||||||
from utils.scheduled import (
|
from utils.scheduled import (
|
||||||
send_monday_video
|
send_monday_video
|
||||||
@@ -47,12 +50,26 @@ from utils.autonomous import (
|
|||||||
from utils.dm_logger import dm_logger
|
from utils.dm_logger import dm_logger
|
||||||
from utils.dm_interaction_analyzer import init_dm_analyzer
|
from utils.dm_interaction_analyzer import init_dm_analyzer
|
||||||
from utils.logger import get_logger
|
from utils.logger import get_logger
|
||||||
|
from utils.task_tracker import create_tracked_task
|
||||||
|
|
||||||
import globals
|
import globals
|
||||||
|
|
||||||
# Initialize bot logger
|
# Initialize bot logger
|
||||||
logger = get_logger('bot')
|
logger = get_logger('bot')
|
||||||
|
|
||||||
|
# Validate configuration on startup
|
||||||
|
is_valid, validation_errors = validate_config()
|
||||||
|
if not is_valid:
|
||||||
|
logger.error("❌ Configuration validation failed!")
|
||||||
|
for error in validation_errors:
|
||||||
|
logger.error(f" - {error}")
|
||||||
|
logger.error("Please check your .env file and restart.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Print configuration summary for debugging
|
||||||
|
if CONFIG.autonomous.debug_mode:
|
||||||
|
print_config_summary()
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s %(levelname)s: %(message)s",
|
format="%(asctime)s %(levelname)s: %(message)s",
|
||||||
@@ -281,7 +298,7 @@ async def on_message(message):
|
|||||||
try:
|
try:
|
||||||
from utils.persona_dialogue import check_for_interjection
|
from utils.persona_dialogue import check_for_interjection
|
||||||
current_persona = "evil" if globals.EVIL_MODE else "miku"
|
current_persona = "evil" if globals.EVIL_MODE else "miku"
|
||||||
asyncio.create_task(check_for_interjection(response_message, current_persona))
|
create_tracked_task(check_for_interjection(response_message, current_persona), task_name="interjection_check")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking for persona interjection: {e}")
|
logger.error(f"Error checking for persona interjection: {e}")
|
||||||
|
|
||||||
@@ -353,7 +370,7 @@ async def on_message(message):
|
|||||||
try:
|
try:
|
||||||
from utils.persona_dialogue import check_for_interjection
|
from utils.persona_dialogue import check_for_interjection
|
||||||
current_persona = "evil" if globals.EVIL_MODE else "miku"
|
current_persona = "evil" if globals.EVIL_MODE else "miku"
|
||||||
asyncio.create_task(check_for_interjection(response_message, current_persona))
|
create_tracked_task(check_for_interjection(response_message, current_persona), task_name="interjection_check")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking for persona interjection: {e}")
|
logger.error(f"Error checking for persona interjection: {e}")
|
||||||
|
|
||||||
@@ -435,7 +452,7 @@ async def on_message(message):
|
|||||||
try:
|
try:
|
||||||
from utils.persona_dialogue import check_for_interjection
|
from utils.persona_dialogue import check_for_interjection
|
||||||
current_persona = "evil" if globals.EVIL_MODE else "miku"
|
current_persona = "evil" if globals.EVIL_MODE else "miku"
|
||||||
asyncio.create_task(check_for_interjection(response_message, current_persona))
|
create_tracked_task(check_for_interjection(response_message, current_persona), task_name="interjection_check")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking for persona interjection: {e}")
|
logger.error(f"Error checking for persona interjection: {e}")
|
||||||
|
|
||||||
@@ -557,7 +574,7 @@ async def on_message(message):
|
|||||||
try:
|
try:
|
||||||
from utils.persona_dialogue import check_for_interjection
|
from utils.persona_dialogue import check_for_interjection
|
||||||
current_persona = "evil" if globals.EVIL_MODE else "miku"
|
current_persona = "evil" if globals.EVIL_MODE else "miku"
|
||||||
asyncio.create_task(check_for_interjection(response_message, current_persona))
|
create_tracked_task(check_for_interjection(response_message, current_persona), task_name="interjection_check")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking for persona interjection: {e}")
|
logger.error(f"Error checking for persona interjection: {e}")
|
||||||
|
|
||||||
@@ -650,7 +667,7 @@ async def on_message(message):
|
|||||||
current_persona = "evil" if globals.EVIL_MODE else "miku"
|
current_persona = "evil" if globals.EVIL_MODE else "miku"
|
||||||
logger.debug(f"Creating interjection check task for persona: {current_persona}")
|
logger.debug(f"Creating interjection check task for persona: {current_persona}")
|
||||||
# Pass the bot's response message for analysis
|
# Pass the bot's response message for analysis
|
||||||
asyncio.create_task(check_for_interjection(response_message, current_persona))
|
create_tracked_task(check_for_interjection(response_message, current_persona), task_name="interjection_check")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking for persona interjection: {e}")
|
logger.error(f"Error checking for persona interjection: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from utils.logger import get_logger
|
from utils.logger import get_logger
|
||||||
|
from utils.task_tracker import create_tracked_task
|
||||||
|
|
||||||
logger = get_logger('uno')
|
logger = get_logger('uno')
|
||||||
|
|
||||||
@@ -64,7 +65,7 @@ async def join_uno_game(message: discord.Message, room_code: str):
|
|||||||
await message.channel.send(f"✅ Joined room **{room_code}**! Waiting for Player 1 to start the game... 🎮")
|
await message.channel.send(f"✅ Joined room **{room_code}**! Waiting for Player 1 to start the game... 🎮")
|
||||||
|
|
||||||
# Start the game loop
|
# Start the game loop
|
||||||
asyncio.create_task(player.play_game())
|
create_tracked_task(player.play_game(), task_name=f"uno_game_{room_code}")
|
||||||
else:
|
else:
|
||||||
await message.channel.send(f"❌ Couldn't join room **{room_code}**. Make sure the room exists and has space!")
|
await message.channel.send(f"❌ Couldn't join room **{room_code}**. Make sure the room exists and has space!")
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from utils.autonomous_engine import autonomous_engine
|
|||||||
from server_manager import server_manager
|
from server_manager import server_manager
|
||||||
import globals
|
import globals
|
||||||
from utils.logger import get_logger
|
from utils.logger import get_logger
|
||||||
|
from utils.task_tracker import create_tracked_task
|
||||||
|
|
||||||
logger = get_logger('autonomous')
|
logger = get_logger('autonomous')
|
||||||
|
|
||||||
@@ -166,10 +167,10 @@ def on_message_event(message):
|
|||||||
|
|
||||||
# Check if we should act (async, non-blocking)
|
# Check if we should act (async, non-blocking)
|
||||||
if not message.author.bot: # Only check for human messages
|
if not message.author.bot: # Only check for human messages
|
||||||
asyncio.create_task(_check_and_act(guild_id))
|
create_tracked_task(_check_and_act(guild_id), task_name="autonomous_check_act")
|
||||||
|
|
||||||
# Also check if we should react to this specific message
|
# Also check if we should react to this specific message
|
||||||
asyncio.create_task(_check_and_react(guild_id, message))
|
create_tracked_task(_check_and_react(guild_id, message), task_name="autonomous_check_react")
|
||||||
|
|
||||||
|
|
||||||
async def _check_and_react(guild_id: int, message):
|
async def _check_and_react(guild_id: int, message):
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import asyncio
|
|||||||
import discord
|
import discord
|
||||||
import globals
|
import globals
|
||||||
from utils.logger import get_logger
|
from utils.logger import get_logger
|
||||||
|
from utils.task_tracker import create_tracked_task
|
||||||
|
|
||||||
logger = get_logger('persona')
|
logger = get_logger('persona')
|
||||||
|
|
||||||
@@ -1113,7 +1114,7 @@ async def maybe_trigger_argument(channel: discord.TextChannel, client, context:
|
|||||||
|
|
||||||
if should_trigger_argument():
|
if should_trigger_argument():
|
||||||
# Run argument in background
|
# Run argument in background
|
||||||
asyncio.create_task(run_argument(channel, client, context))
|
create_tracked_task(run_argument(channel, client, context), task_name="bipolar_argument")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@@ -1136,7 +1137,7 @@ async def force_trigger_argument(channel: discord.TextChannel, client, context:
|
|||||||
logger.warning("Argument already in progress in this channel")
|
logger.warning("Argument already in progress in this channel")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
asyncio.create_task(run_argument(channel, client, context, starting_message))
|
create_tracked_task(run_argument(channel, client, context, starting_message), task_name="bipolar_argument_forced")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@@ -1174,5 +1175,5 @@ async def force_trigger_argument_from_message_id(channel_id: int, message_id: in
|
|||||||
return False, f"Failed to fetch message: {str(e)}"
|
return False, f"Failed to fetch message: {str(e)}"
|
||||||
|
|
||||||
# Trigger the argument with this message as starting point
|
# Trigger the argument with this message as starting point
|
||||||
asyncio.create_task(run_argument(channel, client, context, message))
|
create_tracked_task(run_argument(channel, client, context, message), task_name="bipolar_argument_from_msg")
|
||||||
return True, None
|
return True, None
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
import globals
|
import globals
|
||||||
from utils.logger import get_logger
|
from utils.logger import get_logger
|
||||||
|
from utils.task_tracker import create_tracked_task
|
||||||
|
|
||||||
logger = get_logger('persona')
|
logger = get_logger('persona')
|
||||||
|
|
||||||
@@ -668,15 +669,16 @@ You can use emojis naturally! ✨💙"""
|
|||||||
opposite = "evil" if responding_persona == "miku" else "miku"
|
opposite = "evil" if responding_persona == "miku" else "miku"
|
||||||
|
|
||||||
if should_continue and confidence in ["HIGH", "MEDIUM"]:
|
if should_continue and confidence in ["HIGH", "MEDIUM"]:
|
||||||
asyncio.create_task(self._next_turn(channel, opposite))
|
create_tracked_task(self._next_turn(channel, opposite), task_name="persona_next_turn")
|
||||||
|
|
||||||
elif should_continue and confidence == "LOW":
|
elif should_continue and confidence == "LOW":
|
||||||
asyncio.create_task(self._next_turn(channel, opposite))
|
create_tracked_task(self._next_turn(channel, opposite), task_name="persona_next_turn")
|
||||||
|
|
||||||
elif not should_continue and confidence == "LOW":
|
elif not should_continue and confidence == "LOW":
|
||||||
# Offer opposite persona the last word
|
# Offer opposite persona the last word
|
||||||
asyncio.create_task(
|
create_tracked_task(
|
||||||
self._offer_last_word(channel, opposite, context + f"\n{responding_persona}: {response_text}")
|
self._offer_last_word(channel, opposite, context + f"\n{responding_persona}: {response_text}"),
|
||||||
|
task_name="persona_last_word"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Clear signal to end
|
# Clear signal to end
|
||||||
@@ -788,7 +790,7 @@ Don't force a response if you have nothing meaningful to contribute."""
|
|||||||
logger.info(f"Dialogue ended after last word, {state['turn_count']} turns total")
|
logger.info(f"Dialogue ended after last word, {state['turn_count']} turns total")
|
||||||
self.end_dialogue(channel.id)
|
self.end_dialogue(channel.id)
|
||||||
else:
|
else:
|
||||||
asyncio.create_task(self._next_turn(channel, opposite))
|
create_tracked_task(self._next_turn(channel, opposite), task_name="persona_next_turn")
|
||||||
|
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
# ARGUMENT ESCALATION
|
# ARGUMENT ESCALATION
|
||||||
@@ -953,8 +955,9 @@ async def check_for_interjection(message: discord.Message, current_persona: str)
|
|||||||
|
|
||||||
# Start dialogue with the opposite persona responding first
|
# Start dialogue with the opposite persona responding first
|
||||||
dialogue_manager.start_dialogue(message.channel.id)
|
dialogue_manager.start_dialogue(message.channel.id)
|
||||||
asyncio.create_task(
|
create_tracked_task(
|
||||||
dialogue_manager.handle_dialogue_turn(message.channel, opposite_persona, trigger_reason=reason)
|
dialogue_manager.handle_dialogue_turn(message.channel, opposite_persona, trigger_reason=reason),
|
||||||
|
task_name="persona_dialogue_turn"
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
54
bot/utils/task_tracker.py
Normal file
54
bot/utils/task_tracker.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# utils/task_tracker.py
|
||||||
|
"""
|
||||||
|
Tracked asyncio task creation utility.
|
||||||
|
|
||||||
|
Replaces fire-and-forget asyncio.create_task() calls with error-logging wrappers
|
||||||
|
so that exceptions in background tasks are never silently swallowed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional, Coroutine, Set
|
||||||
|
from utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("task_tracker")
|
||||||
|
|
||||||
|
# Keep references to running tasks so they aren't garbage-collected
|
||||||
|
_active_tasks: Set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
|
def create_tracked_task(
|
||||||
|
coro: Coroutine,
|
||||||
|
task_name: Optional[str] = None,
|
||||||
|
) -> asyncio.Task:
|
||||||
|
"""
|
||||||
|
Create an asyncio task with automatic error logging.
|
||||||
|
|
||||||
|
Unlike bare asyncio.create_task(), this wrapper:
|
||||||
|
- Names the task for easier debugging
|
||||||
|
- Logs any unhandled exception (with full traceback) instead of swallowing it
|
||||||
|
- Keeps a strong reference so the task isn't garbage-collected mid-flight
|
||||||
|
- Auto-cleans the reference set when the task finishes
|
||||||
|
|
||||||
|
Args:
|
||||||
|
coro: The coroutine to schedule.
|
||||||
|
task_name: Human-readable name for log messages.
|
||||||
|
Defaults to the coroutine's __qualname__.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created asyncio.Task (tracked internally).
|
||||||
|
"""
|
||||||
|
name = task_name or getattr(coro, "__qualname__", str(coro))
|
||||||
|
|
||||||
|
async def _wrapped():
|
||||||
|
try:
|
||||||
|
await coro
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug(f"Task '{name}' was cancelled")
|
||||||
|
raise # re-raise so Task.cancelled() works correctly
|
||||||
|
except Exception:
|
||||||
|
logger.error(f"Background task '{name}' failed", exc_info=True)
|
||||||
|
|
||||||
|
task = asyncio.create_task(_wrapped(), name=name)
|
||||||
|
_active_tasks.add(task)
|
||||||
|
task.add_done_callback(_active_tasks.discard)
|
||||||
|
return task
|
||||||
@@ -17,6 +17,7 @@ import discord
|
|||||||
from discord.ext import voice_recv
|
from discord.ext import voice_recv
|
||||||
|
|
||||||
from utils.stt_client import STTClient
|
from utils.stt_client import STTClient
|
||||||
|
from utils.task_tracker import create_tracked_task
|
||||||
|
|
||||||
logger = logging.getLogger('voice_receiver')
|
logger = logging.getLogger('voice_receiver')
|
||||||
|
|
||||||
@@ -256,11 +257,11 @@ class VoiceReceiverSink(voice_recv.AudioSink):
|
|||||||
stt_client = STTClient(
|
stt_client = STTClient(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
stt_url=self.stt_url,
|
stt_url=self.stt_url,
|
||||||
on_partial_transcript=lambda text, timestamp: asyncio.create_task(
|
on_partial_transcript=lambda text, timestamp: create_tracked_task(
|
||||||
self._on_partial_transcript(user_id, text)
|
self._on_partial_transcript(user_id, text), task_name="stt_partial_transcript"
|
||||||
),
|
),
|
||||||
on_final_transcript=lambda text, timestamp: asyncio.create_task(
|
on_final_transcript=lambda text, timestamp: create_tracked_task(
|
||||||
self._on_final_transcript(user_id, text, user)
|
self._on_final_transcript(user_id, text, user), task_name="stt_final_transcript"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -421,8 +422,9 @@ class VoiceReceiverSink(voice_recv.AudioSink):
|
|||||||
self.interruption_audio_count.pop(user_id, None)
|
self.interruption_audio_count.pop(user_id, None)
|
||||||
|
|
||||||
# Call interruption handler (this sets miku_speaking=False)
|
# Call interruption handler (this sets miku_speaking=False)
|
||||||
asyncio.create_task(
|
create_tracked_task(
|
||||||
self.voice_manager.on_user_interruption(user_id)
|
self.voice_manager.on_user_interruption(user_id),
|
||||||
|
task_name="voice_user_interruption"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Audio below RMS threshold (silence) - reset interruption tracking
|
# Audio below RMS threshold (silence) - reset interruption tracking
|
||||||
|
|||||||
Reference in New Issue
Block a user