refactor(database): 将同步数据库操作迁移为异步操作

将整个项目的数据库操作从同步模式迁移为异步模式,主要涉及以下修改:

- 将 `with get_db_session()` 改为 `async with get_db_session()`
- 将同步的 SQLAlchemy 查询方法改为异步执行
- 更新相关的方法签名,添加 async/await 关键字
- 修复异步化迁移过程中暴露或引入的并发问题和性能问题

这些修改提高了数据库操作的并发性能,避免了阻塞主线程,提升了系统的整体响应能力。涉及修改的模块包括表情包管理、反提示注入统计、用户封禁管理、记忆系统、消息存储等多个核心组件。

BREAKING CHANGE: 所有涉及数据库操作的方法现在均为异步方法,必须通过 await 调用;原有的同步调用方式不再受支持
This commit is contained in:
Windpicker-owo
2025-09-28 15:42:30 +08:00
parent ff24bd8148
commit 08ef960947
35 changed files with 1180 additions and 1053 deletions

13
bot.py
View File

@@ -185,12 +185,12 @@ class MaiBotMain(BaseMain):
check_eula()
logger.info("检查EULA和隐私条款完成")
def initialize_database(self):
async def initialize_database(self):
"""初始化数据库"""
logger.info("正在初始化数据库连接...")
try:
initialize_sql_database(global_config.database)
await initialize_sql_database(global_config.database)
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库")
except Exception as e:
logger.error(f"数据库连接初始化失败: {e}")
@@ -211,11 +211,11 @@ class MaiBotMain(BaseMain):
self.main_system = MainSystem()
return self.main_system
def run(self):
async def run(self):
"""运行主程序"""
self.setup_timezone()
self.check_and_confirm_eula()
self.initialize_database()
await self.initialize_database()
return self.create_main_system()
@@ -225,14 +225,14 @@ if __name__ == "__main__":
try:
# 创建MaiBotMain实例并获取MainSystem
maibot = MaiBotMain()
main_system = maibot.run()
# 创建事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 异步初始化数据库表结构
# 异步初始化数据库表结构
main_system = loop.run_until_complete(maibot.run())
loop.run_until_complete(maibot.initialize_database_async())
# 执行初始化和任务调度
loop.run_until_complete(main_system.initialize())
@@ -269,3 +269,4 @@ if __name__ == "__main__":
# 在程序退出前暂停,让你有机会看到输出
# input("按 Enter 键退出...") # <--- 添加这行
sys.exit(exit_code) # <--- 使用记录的退出码

View File

@@ -8,6 +8,8 @@
import datetime
from typing import Dict, Any
from sqlalchemy import select
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
from src.config.config import global_config
@@ -27,9 +29,11 @@ class AntiInjectionStatistics:
async def get_or_create_stats():
"""获取或创建统计记录"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 获取最新的统计记录,如果没有则创建
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first()
stats = (await session.execute(
select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())
)).scalars().first()
if not stats:
stats = AntiInjectionStats()
session.add(stats)
@@ -44,8 +48,10 @@ class AntiInjectionStatistics:
async def update_stats(**kwargs):
"""更新统计数据"""
try:
with get_db_session() as session:
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first()
async with get_db_session() as session:
stats = (await session.execute(
select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())
)).scalars().first()
if not stats:
stats = AntiInjectionStats()
session.add(stats)
@@ -138,9 +144,9 @@ class AntiInjectionStatistics:
async def reset_stats():
"""重置统计信息"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 删除现有统计记录
session.query(AntiInjectionStats).delete()
await session.execute(select(AntiInjectionStats).delete())
await session.commit()
logger.info("统计信息已重置")
except Exception as e:

View File

@@ -8,6 +8,8 @@
import datetime
from typing import Optional, Tuple
from sqlalchemy import select
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import BanUser, get_db_session
from ..types import DetectionResult
@@ -37,8 +39,9 @@ class UserBanManager:
如果用户被封禁则返回拒绝结果否则返回None
"""
try:
with get_db_session() as session:
ban_record = session.query(BanUser).filter_by(user_id=user_id, platform=platform).first()
async with get_db_session() as session:
result = await session.execute(select(BanUser).filter_by(user_id=user_id, platform=platform))
ban_record = result.scalar_one_or_none()
if ban_record:
# 只有违规次数达到阈值时才算被封禁
@@ -70,9 +73,10 @@ class UserBanManager:
detection_result: 检测结果
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 查找或创建违规记录
ban_record = session.query(BanUser).filter_by(user_id=user_id, platform=platform).first()
result = await session.execute(select(BanUser).filter_by(user_id=user_id, platform=platform))
ban_record = result.scalar_one_or_none()
if ban_record:
ban_record.violation_num += 1

View File

@@ -149,7 +149,7 @@ class MaiEmoji:
# --- 数据库操作 ---
try:
# 准备数据库记录 for emoji collection
with get_db_session() as session:
async with get_db_session() as session:
emotion_str = ",".join(self.emotion) if self.emotion else ""
emoji = Emoji(
@@ -167,7 +167,7 @@ class MaiEmoji:
last_used_time=self.last_used_time,
)
session.add(emoji)
session.commit()
await session.commit()
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -203,17 +203,18 @@ class MaiEmoji:
# 2. 删除数据库记录
try:
with get_db_session() as session:
will_delete_emoji = session.execute(
async with get_db_session() as session:
result = await session.execute(
select(Emoji).where(Emoji.emoji_hash == self.hash)
).scalar_one_or_none()
)
will_delete_emoji = result.scalar_one_or_none()
if will_delete_emoji is None:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
else:
session.delete(will_delete_emoji)
await session.delete(will_delete_emoji)
result = 1 # Successfully deleted one record
session.commit()
await session.commit()
except Exception as e:
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
result = 0
@@ -424,17 +425,19 @@ class EmojiManager:
# if not self._initialized:
# raise RuntimeError("EmojiManager not initialized")
def record_usage(self, emoji_hash: str) -> None:
async def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数"""
try:
with get_db_session() as session:
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
async with get_db_session() as session:
stmt = select(Emoji).where(Emoji.emoji_hash == emoji_hash)
result = await session.execute(stmt)
emoji_update = result.scalar_one_or_none()
if emoji_update is None:
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
else:
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time
session.commit()
await session.commit()
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
@@ -521,7 +524,7 @@ class EmojiManager:
# 7. 获取选中的表情包并更新使用记录
selected_emoji = candidate_emojis[selected_index]
self.record_usage(selected_emoji.hash)
await self.record_usage(selected_emoji.hash)
_time_end = time.time()
logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s")
@@ -657,10 +660,11 @@ class EmojiManager:
async def get_all_emoji_from_db(self) -> None:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try:
with get_db_session() as session:
async with get_db_session() as session:
logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_instances = session.execute(select(Emoji)).scalars().all()
result = await session.execute(select(Emoji))
emoji_instances = result.scalars().all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
# 更新内存中的列表和数量
@@ -686,14 +690,16 @@ class EmojiManager:
list[MaiEmoji]: 表情包对象列表
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
if emoji_hash:
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
query = result.scalars().all()
else:
logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
query = session.execute(select(Emoji)).scalars().all()
result = await session.execute(select(Emoji))
query = result.scalars().all()
emoji_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
@@ -770,10 +776,10 @@ class EmojiManager:
# 如果内存中没有,从数据库查找
try:
with get_db_session() as session:
emoji_record = session.execute(
select(Emoji).where(Emoji.emoji_hash == emoji_hash)
).scalar_one_or_none()
async with get_db_session() as session:
stmt = select(Emoji).where(Emoji.emoji_hash == emoji_hash)
result = await session.execute(stmt)
emoji_record = result.scalar_one_or_none()
if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
return emoji_record.description
@@ -939,12 +945,13 @@ class EmojiManager:
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
existing_description = None
try:
with get_db_session() as session:
existing_image = (
session.query(Images)
.filter((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
.one_or_none()
async with get_db_session() as session:
stmt = select(Images).where(
Images.emoji_hash == image_hash,
Images.type == "emoji"
)
result = await session.execute(stmt)
existing_image = result.scalar_one_or_none()
if existing_image and existing_image.description:
existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")

View File

@@ -198,7 +198,7 @@ class RecencyEnergyCalculator(EnergyCalculator):
class RelationshipEnergyCalculator(EnergyCalculator):
"""关系能量计算器"""
def calculate(self, context: Dict[str, Any]) -> float:
async def calculate(self, context: Dict[str, Any]) -> float:
"""基于关系计算能量"""
user_id = context.get("user_id")
if not user_id:
@@ -208,7 +208,7 @@ class RelationshipEnergyCalculator(EnergyCalculator):
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id)
logger.debug(f"使用插件内部系统计算关系分: {relationship_score:.3f}")
return max(0.0, min(1.0, relationship_score))
@@ -273,7 +273,7 @@ class EnergyManager:
except Exception as e:
logger.warning(f"加载AFC阈值失败使用默认值: {e}")
def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
async def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
"""计算聊天流的focus_energy"""
start_time = time.time()
@@ -303,7 +303,16 @@ class EnergyManager:
for calculator in self.calculators:
try:
# 支持同步和异步计算器
if callable(calculator.calculate):
import inspect
if inspect.iscoroutinefunction(calculator.calculate):
score = await calculator.calculate(context)
else:
score = calculator.calculate(context)
else:
score = calculator.calculate(context)
weight = calculator.get_weight()
component_scores[calculator.__class__.__name__] = score

View File

@@ -8,6 +8,7 @@ import traceback
from typing import List, Dict, Optional, Any
from datetime import datetime
import numpy as np
from sqlalchemy import select
from src.common.logger import get_logger
from src.config.config import global_config
@@ -610,14 +611,13 @@ class BotInterestManager:
from src.common.database.sqlalchemy_database_api import get_db_session
import orjson
with get_db_session() as session:
async with get_db_session() as session:
# 查询最新的兴趣标签配置
db_interests = (
session.query(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == personality_id)
db_interests = (await session.execute(
select(DBBotPersonalityInterests)
.where(DBBotPersonalityInterests.personality_id == personality_id)
.order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
.first()
)
)).scalars().first()
if db_interests:
logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}")
@@ -700,13 +700,12 @@ class BotInterestManager:
# 序列化为JSON
json_data = orjson.dumps(tags_data)
with get_db_session() as session:
async with get_db_session() as session:
# 检查是否已存在相同personality_id的记录
existing_record = (
session.query(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first()
)
existing_record = (await session.execute(
select(DBBotPersonalityInterests)
.where(DBBotPersonalityInterests.personality_id == interests.personality_id)
)).scalars().first()
if existing_record:
# 更新现有记录
@@ -731,19 +730,17 @@ class BotInterestManager:
last_updated=interests.last_updated,
)
session.add(new_record)
session.commit()
await session.commit()
logger.info(f"✅ 成功创建兴趣标签配置,版本: {interests.version}")
logger.info("✅ 兴趣标签已成功保存到数据库")
# 验证保存是否成功
with get_db_session() as session:
saved_record = (
session.query(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first()
)
session.commit()
async with get_db_session() as session:
saved_record = (await session.execute(
select(DBBotPersonalityInterests)
.where(DBBotPersonalityInterests.personality_id == interests.personality_id)
)).scalars().first()
if saved_record:
logger.info(f"✅ 验证成功数据库中存在personality_id为 {interests.personality_id} 的记录")
logger.info(f" 版本: {saved_record.version}")

View File

@@ -882,7 +882,8 @@ class EntorhinalCortex:
# 获取数据库中所有节点和内存中所有节点
async with get_db_session() as session:
db_nodes = {node.concept: node for node in (await session.execute(select(GraphNodes))).scalars()}
result = await session.execute(select(GraphNodes))
db_nodes = {node.concept: node for node in result.scalars()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 批量准备节点数据
@@ -978,7 +979,8 @@ class EntorhinalCortex:
await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
# 处理边的信息
db_edges = list((await session.execute(select(GraphEdges))).scalars())
result = await session.execute(select(GraphEdges))
db_edges = list(result.scalars())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
@@ -1157,7 +1159,8 @@ class EntorhinalCortex:
# 从数据库加载所有节点
async with get_db_session() as session:
nodes = list((await session.execute(select(GraphNodes))).scalars())
result = await session.execute(select(GraphNodes))
nodes = list(result.scalars())
for node in nodes:
concept = node.concept
try:
@@ -1192,7 +1195,8 @@ class EntorhinalCortex:
continue
# 从数据库加载所有边
edges = list((await session.execute(select(GraphEdges))).scalars())
result = await session.execute(select(GraphEdges))
edges = list(result.scalars())
for edge in edges:
source = edge.source
target = edge.target

View File

@@ -184,6 +184,11 @@ class AsyncMemoryQueue:
from src.chat.memory_system.Hippocampus import hippocampus_manager
if hippocampus_manager._initialized:
# 确保海马体对象已正确初始化
if not hippocampus_manager._hippocampus.parahippocampal_gyrus:
logger.warning("海马体对象未完全初始化,进行同步初始化")
hippocampus_manager._hippocampus.initialize()
await hippocampus_manager.build_memory()
return True
return False

View File

@@ -108,7 +108,7 @@ class InstantMemory:
@staticmethod
async def store_memory(memory_item: MemoryItem):
with get_db_session() as session:
async with get_db_session() as session:
memory = Memory(
memory_id=memory_item.memory_id,
chat_id=memory_item.chat_id,
@@ -161,20 +161,21 @@ class InstantMemory:
logger.info(f"start_time: {start_time}, end_time: {end_time}")
# 检索包含关键词的记忆
memories_set = set()
with get_db_session() as session:
async with get_db_session() as session:
if start_time and end_time:
start_ts = start_time.timestamp()
end_ts = end_time.timestamp()
query = session.execute(
query = (await session.execute(
select(Memory).where(
(Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts)
& (Memory.create_time < end_ts)
)
).scalars()
)).scalars()
else:
query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars()
query = result = await session.execute(select(Memory).where(Memory.chat_id == self.chat_id))
result.scalars()
for mem in query:
# 对每条记忆
mem_keywords_str = mem.keywords or "[]"

View File

@@ -4,7 +4,7 @@
"""
from .message_manager import MessageManager, message_manager
from .context_manager import StreamContextManager, context_manager
from .context_manager import SingleStreamContextManager
from .distribution_manager import (
DistributionManager,
DistributionPriority,
@@ -16,8 +16,7 @@ from .distribution_manager import (
__all__ = [
"MessageManager",
"message_manager",
"StreamContextManager",
"context_manager",
"SingleStreamContextManager",
"DistributionManager",
"DistributionPriority",
"DistributionTask",

View File

@@ -1,12 +1,12 @@
"""
重构后的聊天上下文管理器
提供统一、稳定的聊天上下文管理功能
每个 context_manager 实例只管理一个 stream 的上下文
"""
import asyncio
import time
from typing import Dict, List, Optional, Any, Union, Tuple
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
from src.common.data_models.message_manager_data_model import StreamContext
from src.common.logger import get_logger
@@ -17,241 +17,112 @@ from .distribution_manager import distribution_manager
logger = get_logger("context_manager")
class StreamContextManager:
"""流上下文管理器 - 统一管理所有聊天流上下文"""
def __init__(self, max_context_size: Optional[int] = None, context_ttl: Optional[int] = None):
# 上下文存储
self.stream_contexts: Dict[str, Any] = {}
self.context_metadata: Dict[str, Dict[str, Any]] = {}
class SingleStreamContextManager:
"""单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
# 统计信息
self.stats: Dict[str, Union[int, float, str, Dict]] = {
"total_messages": 0,
"total_streams": 0,
"active_streams": 0,
"inactive_streams": 0,
"last_activity": time.time(),
"creation_time": time.time(),
}
def __init__(self, stream_id: str, context: StreamContext, max_context_size: Optional[int] = None):
self.stream_id = stream_id
self.context = context
# 配置参数
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
self.context_ttl = context_ttl or getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
self.cleanup_interval = getattr(global_config.chat, "context_cleanup_interval", 3600) # 1小时
self.auto_cleanup = getattr(global_config.chat, "auto_cleanup_contexts", True)
self.enable_validation = getattr(global_config.chat, "enable_context_validation", True)
self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
# 清理任务
self.cleanup_task: Optional[Any] = None
self.is_running = False
# 元数据
self.created_time = time.time()
self.last_access_time = time.time()
self.access_count = 0
self.total_messages = 0
logger.info(f"上下文管理器初始化完成 (最大上下文: {self.max_context_size}, TTL: {self.context_ttl}s)")
logger.debug(f"单流上下文管理器初始化: {stream_id}")
def add_stream_context(self, stream_id: str, context: Any, metadata: Optional[Dict[str, Any]] = None) -> bool:
"""添加流上下文
def get_context(self) -> StreamContext:
"""获取流上下文"""
self._update_access_stats()
return self.context
Args:
stream_id: 流ID
context: 上下文对象
metadata: 上下文元数据
Returns:
bool: 是否成功添加
"""
if stream_id in self.stream_contexts:
logger.warning(f"流上下文已存在: {stream_id}")
return False
# 添加上下文
self.stream_contexts[stream_id] = context
# 初始化元数据
self.context_metadata[stream_id] = {
"created_time": time.time(),
"last_access_time": time.time(),
"access_count": 0,
"last_validation_time": 0.0,
"custom_metadata": metadata or {},
}
# 更新统计
self.stats["total_streams"] += 1
self.stats["active_streams"] += 1
self.stats["last_activity"] = time.time()
logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})")
return True
def remove_stream_context(self, stream_id: str) -> bool:
"""移除流上下文
Args:
stream_id: 流ID
Returns:
bool: 是否成功移除
"""
if stream_id in self.stream_contexts:
context = self.stream_contexts[stream_id]
metadata = self.context_metadata.get(stream_id, {})
del self.stream_contexts[stream_id]
if stream_id in self.context_metadata:
del self.context_metadata[stream_id]
self.stats["active_streams"] = max(0, self.stats["active_streams"] - 1)
self.stats["inactive_streams"] += 1
self.stats["last_activity"] = time.time()
logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})")
return True
return False
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[StreamContext]:
"""获取流上下文
Args:
stream_id: 流ID
update_access: 是否更新访问统计
Returns:
Optional[Any]: 上下文对象
"""
context = self.stream_contexts.get(stream_id)
if context and update_access:
# 更新访问统计
if stream_id in self.context_metadata:
metadata = self.context_metadata[stream_id]
metadata["last_access_time"] = time.time()
metadata["access_count"] = metadata.get("access_count", 0) + 1
return context
def get_context_metadata(self, stream_id: str) -> Optional[Dict[str, Any]]:
"""获取上下文元数据
Args:
stream_id: 流ID
Returns:
Optional[Dict[str, Any]]: 元数据
"""
return self.context_metadata.get(stream_id)
def update_context_metadata(self, stream_id: str, updates: Dict[str, Any]) -> bool:
"""更新上下文元数据
Args:
stream_id: 流ID
updates: 更新的元数据
Returns:
bool: 是否成功更新
"""
if stream_id not in self.context_metadata:
return False
self.context_metadata[stream_id].update(updates)
return True
def add_message_to_context(self, stream_id: str, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
"""添加消息到上下文
Args:
stream_id: 流ID
message: 消息对象
skip_energy_update: 是否跳过能量更新
Returns:
bool: 是否成功添加
"""
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
return False
try:
# 添加消息到上下文
context.add_message(message)
self.context.add_message(message)
# 计算消息兴趣度
interest_value = self._calculate_message_interest(message)
message.interest_value = interest_value
# 更新统计
self.stats["total_messages"] += 1
self.stats["last_activity"] = time.time()
self.total_messages += 1
self.last_access_time = time.time()
# 更新能量和分发
if not skip_energy_update:
self._update_stream_energy(stream_id)
distribution_manager.add_stream_message(stream_id, 1)
self._update_stream_energy()
distribution_manager.add_stream_message(self.stream_id, 1)
logger.debug(f"添加消息到上下文: {stream_id} (兴趣度: {interest_value:.3f})")
logger.debug(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
return True
except Exception as e:
logger.error(f"添加消息到上下文失败 {stream_id}: {e}", exc_info=True)
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
def update_message_in_context(self, stream_id: str, message_id: str, updates: Dict[str, Any]) -> bool:
def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
"""更新上下文中的消息
Args:
stream_id: 流ID
message_id: 消息ID
updates: 更新的属性
Returns:
bool: 是否成功更新
"""
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
return False
try:
# 更新消息信息
context.update_message_info(message_id, **updates)
self.context.update_message_info(message_id, **updates)
# 如果更新了兴趣度,重新计算能量
if "interest_value" in updates:
self._update_stream_energy(stream_id)
self._update_stream_energy()
logger.debug(f"更新上下文消息: {stream_id}/{message_id}")
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
return True
except Exception as e:
logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True)
logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
return False
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
def get_messages(self, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
"""获取上下文消息
Args:
stream_id: 流ID
limit: 消息数量限制
include_unread: 是否包含未读消息
Returns:
List[Any]: 消息列表
List[DatabaseMessages]: 消息列表
"""
context = self.get_stream_context(stream_id)
if not context:
return []
try:
messages = []
if include_unread:
messages.extend(context.get_unread_messages())
messages.extend(self.context.get_unread_messages())
if limit:
messages.extend(context.get_history_messages(limit=limit))
messages.extend(self.context.get_history_messages(limit=limit))
else:
messages.extend(context.get_history_messages())
messages.extend(self.context.get_history_messages())
# 按时间排序
messages.sort(key=lambda msg: getattr(msg, 'time', 0))
messages.sort(key=lambda msg: getattr(msg, "time", 0))
# 应用限制
if limit and len(messages) > limit:
@@ -260,103 +131,124 @@ class StreamContextManager:
return messages
except Exception as e:
logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True)
return []
def get_unread_messages(self, stream_id: str) -> List[DatabaseMessages]:
"""获取未读消息
Args:
stream_id: 流ID
Returns:
List[Any]: 未读消息列表
"""
context = self.get_stream_context(stream_id)
if not context:
logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True)
return []
def get_unread_messages(self) -> List[DatabaseMessages]:
"""获取未读消息"""
try:
return context.get_unread_messages()
return self.context.get_unread_messages()
except Exception as e:
logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True)
logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True)
return []
def mark_messages_as_read(self, stream_id: str, message_ids: List[str]) -> bool:
"""标记消息为已读
Args:
stream_id: 流ID
message_ids: 消息ID列表
Returns:
bool: 是否成功标记
"""
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
return False
def mark_messages_as_read(self, message_ids: List[str]) -> bool:
"""标记消息为已读"""
try:
if not hasattr(context, 'mark_message_as_read'):
logger.error(f"上下文对象缺少 mark_message_as_read 方法: {stream_id}")
if not hasattr(self.context, "mark_message_as_read"):
logger.error(f"上下文对象缺少 mark_message_as_read 方法: {self.stream_id}")
return False
marked_count = 0
for message_id in message_ids:
try:
context.mark_message_as_read(message_id)
self.context.mark_message_as_read(message_id)
marked_count += 1
except Exception as e:
logger.warning(f"标记消息已读失败 {message_id}: {e}")
logger.debug(f"标记消息为已读: {stream_id} ({marked_count}/{len(message_ids)}条)")
logger.debug(f"标记消息为已读: {self.stream_id} ({marked_count}/{len(message_ids)}条)")
return marked_count > 0
except Exception as e:
logger.error(f"标记消息已读失败 {stream_id}: {e}", exc_info=True)
return False
def clear_context(self, stream_id: str) -> bool:
"""清空上下文
Args:
stream_id: 流ID
Returns:
bool: 是否成功清空
"""
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True)
return False
def clear_context(self) -> bool:
"""清空上下文"""
try:
# 清空消息
if hasattr(context, 'unread_messages'):
context.unread_messages.clear()
if hasattr(context, 'history_messages'):
context.history_messages.clear()
if hasattr(self.context, "unread_messages"):
self.context.unread_messages.clear()
if hasattr(self.context, "history_messages"):
self.context.history_messages.clear()
# 重置状态
reset_attrs = ['interruption_count', 'afc_threshold_adjustment', 'last_check_time']
reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
for attr in reset_attrs:
if hasattr(context, attr):
if attr in ['interruption_count', 'afc_threshold_adjustment']:
setattr(context, attr, 0)
if hasattr(self.context, attr):
if attr in ["interruption_count", "afc_threshold_adjustment"]:
setattr(self.context, attr, 0)
else:
setattr(context, attr, time.time())
setattr(self.context, attr, time.time())
# 重新计算能量
self._update_stream_energy(stream_id)
self._update_stream_energy()
logger.info(f"清空上下文: {stream_id}")
logger.info(f"清空单流上下文: {self.stream_id}")
return True
except Exception as e:
logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True)
logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
def get_statistics(self) -> Dict[str, Any]:
"""获取流统计信息"""
try:
current_time = time.time()
uptime = current_time - self.created_time
unread_messages = getattr(self.context, "unread_messages", [])
history_messages = getattr(self.context, "history_messages", [])
return {
"stream_id": self.stream_id,
"context_type": type(self.context).__name__,
"total_messages": len(history_messages) + len(unread_messages),
"unread_messages": len(unread_messages),
"history_messages": len(history_messages),
"is_active": getattr(self.context, "is_active", True),
"last_check_time": getattr(self.context, "last_check_time", current_time),
"interruption_count": getattr(self.context, "interruption_count", 0),
"afc_threshold_adjustment": getattr(self.context, "afc_threshold_adjustment", 0.0),
"created_time": self.created_time,
"last_access_time": self.last_access_time,
"access_count": self.access_count,
"uptime_seconds": uptime,
"idle_seconds": current_time - self.last_access_time,
}
except Exception as e:
logger.error(f"获取单流统计失败 {self.stream_id}: {e}", exc_info=True)
return {}
def validate_integrity(self) -> bool:
"""验证上下文完整性"""
try:
# 检查基本属性
required_attrs = ["stream_id", "unread_messages", "history_messages"]
for attr in required_attrs:
if not hasattr(self.context, attr):
logger.warning(f"上下文缺少必要属性: {attr}")
return False
# 检查消息ID唯一性
all_messages = getattr(self.context, "unread_messages", []) + getattr(self.context, "history_messages", [])
message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
if len(message_ids) != len(set(message_ids)):
logger.warning(f"上下文中存在重复消息ID: {self.stream_id}")
return False
return True
except Exception as e:
logger.error(f"验证单流上下文完整性失败 {self.stream_id}: {e}")
return False
def _update_access_stats(self):
"""更新访问统计"""
self.last_access_time = time.time()
self.access_count += 1
def _calculate_message_interest(self, message: DatabaseMessages) -> float:
"""计算消息兴趣度"""
try:
@@ -373,8 +265,7 @@ class StreamContextManager:
interest_score = loop.run_until_complete(
chatter_interest_scoring_system._calculate_single_message_score(
message=message,
bot_nickname=global_config.bot.nickname
message=message, bot_nickname=global_config.bot.nickname
)
)
interest_value = interest_score.total_score
@@ -391,12 +282,12 @@ class StreamContextManager:
logger.error(f"计算消息兴趣度失败: {e}")
return 0.5
def _update_stream_energy(self, stream_id: str):
async def _update_stream_energy(self):
"""更新流能量"""
try:
# 获取所有消息
all_messages = self.get_context_messages(stream_id, self.max_context_size)
unread_messages = self.get_unread_messages(stream_id)
all_messages = self.get_messages(self.max_context_size)
unread_messages = self.get_unread_messages()
combined_messages = all_messages + unread_messages
# 获取用户ID
@@ -406,248 +297,12 @@ class StreamContextManager:
user_id = last_message.user_info.user_id
# 计算能量
energy = energy_manager.calculate_focus_energy(
stream_id=stream_id,
messages=combined_messages,
user_id=user_id
energy = await energy_manager.calculate_focus_energy(
stream_id=self.stream_id, messages=combined_messages, user_id=user_id
)
# 更新分发管理器
distribution_manager.update_stream_energy(stream_id, energy)
distribution_manager.update_stream_energy(self.stream_id, energy)
except Exception as e:
logger.error(f"更新流能量失败 {stream_id}: {e}")
def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]:
"""获取流统计信息
Args:
stream_id: 流ID
Returns:
Optional[Dict[str, Any]]: 统计信息
"""
context = self.get_stream_context(stream_id, update_access=False)
if not context:
return None
try:
metadata = self.context_metadata.get(stream_id, {})
current_time = time.time()
created_time = metadata.get("created_time", current_time)
last_access_time = metadata.get("last_access_time", current_time)
access_count = metadata.get("access_count", 0)
unread_messages = getattr(context, "unread_messages", [])
history_messages = getattr(context, "history_messages", [])
return {
"stream_id": stream_id,
"context_type": type(context).__name__,
"total_messages": len(history_messages) + len(unread_messages),
"unread_messages": len(unread_messages),
"history_messages": len(history_messages),
"is_active": getattr(context, "is_active", True),
"last_check_time": getattr(context, "last_check_time", current_time),
"interruption_count": getattr(context, "interruption_count", 0),
"afc_threshold_adjustment": getattr(context, "afc_threshold_adjustment", 0.0),
"created_time": created_time,
"last_access_time": last_access_time,
"access_count": access_count,
"uptime_seconds": current_time - created_time,
"idle_seconds": current_time - last_access_time,
}
except Exception as e:
logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True)
return None
def get_manager_statistics(self) -> Dict[str, Any]:
    """Return aggregate statistics for the whole manager.

    Merges the raw counters in ``self.stats`` with derived values such as
    uptime in hours and the current container sizes.

    Returns:
        Dict[str, Any]: combined manager statistics.
    """
    now = time.time()
    started_at = self.stats.get("creation_time", now)

    snapshot: Dict[str, Any] = dict(self.stats)
    snapshot.update(
        uptime_hours=(now - started_at) / 3600,
        stream_count=len(self.stream_contexts),
        metadata_count=len(self.context_metadata),
        auto_cleanup_enabled=self.auto_cleanup,
        cleanup_interval=self.cleanup_interval,
    )
    return snapshot
def cleanup_inactive_contexts(self, max_inactive_hours: int = 24) -> int:
    """Remove contexts that are empty and have been idle too long.

    A context is eligible for removal only when it holds no messages at
    all AND its most recent activity (metadata access time or the
    context's own ``last_check_time``, whichever is later) is older than
    the threshold.

    Args:
        max_inactive_hours: idle threshold, in hours.

    Returns:
        int: number of contexts actually removed.
    """
    now = time.time()
    idle_limit = max_inactive_hours * 3600

    # First pass: collect candidates without mutating the dict we iterate.
    candidates = []
    for sid, ctx in self.stream_contexts.items():
        try:
            meta = self.context_metadata.get(sid, {})
            meta_activity = meta.get("last_access_time", meta.get("created_time", 0))
            last_activity = max(meta_activity, getattr(ctx, "last_check_time", 0))

            message_total = len(getattr(ctx, "unread_messages", [])) + len(
                getattr(ctx, "history_messages", [])
            )

            if message_total == 0 and now - last_activity > idle_limit:
                candidates.append(sid)
        except Exception as e:
            logger.warning(f"检查上下文活跃状态失败 {sid}: {e}")
            continue

    # Second pass: actually drop the collected streams.
    cleaned = sum(1 for sid in candidates if self.remove_stream_context(sid))
    if cleaned > 0:
        logger.info(f"清理了 {cleaned} 个不活跃上下文")
    return cleaned
def validate_context_integrity(self, stream_id: str) -> bool:
    """Check that a stream context has the expected shape.

    Verifies the required attributes exist and that no message id appears
    twice across unread + history messages.

    Args:
        stream_id: stream to validate.

    Returns:
        bool: True when the context exists and passes all checks.
    """
    ctx = self.get_stream_context(stream_id)
    if ctx is None:
        return False

    try:
        # Structural check: all required attributes must be present.
        for attr in ("stream_id", "unread_messages", "history_messages"):
            if not hasattr(ctx, attr):
                logger.warning(f"上下文缺少必要属性: {attr}")
                return False

        # Uniqueness check: message ids must not repeat across both lists.
        combined = getattr(ctx, "unread_messages", []) + getattr(ctx, "history_messages", [])
        ids = [m.message_id for m in combined if hasattr(m, "message_id")]
        if len(set(ids)) != len(ids):
            logger.warning(f"上下文中存在重复消息ID: {stream_id}")
            return False

        return True
    except Exception as e:
        logger.error(f"验证上下文完整性失败 {stream_id}: {e}")
        return False
async def start(self) -> None:
    """Start the context manager and its auto-cleanup task.

    No-op (with a warning) if the manager is already running.
    """
    if self.is_running:
        logger.warning("上下文管理器已经在运行")
        return
    await self.start_auto_cleanup()
    logger.info("上下文管理器已启动")
async def stop(self) -> None:
    """Stop the context manager; returns immediately if it was never started."""
    if not self.is_running:
        return
    await self.stop_auto_cleanup()
    logger.info("上下文管理器已停止")
async def start_auto_cleanup(self, interval: Optional[float] = None) -> None:
    """Spawn the periodic cleanup background task.

    Args:
        interval: seconds between cleanup passes; any falsy value
            (including an explicit 0) falls back to ``self.cleanup_interval``.
    """
    if not self.auto_cleanup:
        logger.info("自动清理已禁用")
        return
    if self.is_running:
        logger.warning("自动清理已在运行")
        return

    import asyncio

    self.is_running = True
    # NOTE: `or` (not `is None`) on purpose — 0 also falls back to the default.
    period = interval or self.cleanup_interval
    logger.info(f"启动自动清理(间隔: {period}s")
    self.cleanup_task = asyncio.create_task(self._cleanup_loop(period))
async def stop_auto_cleanup(self) -> None:
    """Stop the auto-cleanup background task.

    Clears the running flag, cancels the task (if one is active) and waits
    for it to wind down before logging completion.
    """
    import asyncio

    self.is_running = False
    if self.cleanup_task and not self.cleanup_task.done():
        self.cleanup_task.cancel()
        try:
            await self.cleanup_task
        except asyncio.CancelledError:
            # Expected: we just cancelled it. CancelledError is a
            # BaseException since Python 3.8, so a bare `except Exception`
            # would let it escape and propagate to the caller, skipping the
            # final log line below.
            pass
        except Exception:
            # Any other shutdown error is deliberately swallowed best-effort.
            pass
    logger.info("自动清理已停止")
async def _cleanup_loop(self, interval: float) -> None:
    """Background cleanup loop.

    Sleeps *interval* seconds, then runs both the inactivity-based and the
    TTL-based cleanup passes, repeating while the manager is running.
    Exits cleanly on task cancellation; any other error is logged and the
    loop backs off for one more interval before retrying.

    Args:
        interval: seconds between cleanup passes.
    """
    while self.is_running:
        try:
            await asyncio.sleep(interval)
            self.cleanup_inactive_contexts()
            self._cleanup_expired_contexts()
            logger.debug("自动清理完成")
        except asyncio.CancelledError:
            # Normal shutdown path: stop_auto_cleanup() cancels this task.
            break
        except Exception as e:
            logger.error(f"清理循环出错: {e}", exc_info=True)
            # Back off so a persistently failing pass cannot busy-spin.
            await asyncio.sleep(interval)
def _cleanup_expired_contexts(self) -> None:
"""清理过期上下文"""
current_time = time.time()
expired_contexts = []
for stream_id, metadata in self.context_metadata.items():
created_time = metadata.get("created_time", current_time)
if current_time - created_time > self.context_ttl:
expired_contexts.append(stream_id)
for stream_id in expired_contexts:
self.remove_stream_context(stream_id)
if expired_contexts:
logger.info(f"清理了 {len(expired_contexts)} 个过期上下文")
def get_active_streams(self) -> List[str]:
    """Return the ids of every stream that currently has a context.

    Returns:
        List[str]: stream ids, in insertion order.
    """
    return [stream_id for stream_id in self.stream_contexts]
# Global context manager instance shared by the rest of the application.
context_manager = StreamContextManager()
logger.error(f"更新流能量失败 {self.stream_id}: {e}")

View File

@@ -14,11 +14,10 @@ from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats
from src.chat.chatter_manager import ChatterManager
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.plugin_system.base.component_types import ChatMode
from .sleep_manager.sleep_manager import SleepManager
from .sleep_manager.wakeup_manager import WakeUpManager
from src.config.config import global_config
from .context_manager import context_manager
from src.plugin_system.apis.chat_api import get_chat_manager
if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext
@@ -45,8 +44,7 @@ class MessageManager:
self.sleep_manager = SleepManager()
self.wakeup_manager = WakeUpManager(self.sleep_manager)
# 初始化上下文管理器
self.context_manager = context_manager
# 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context_manager
async def start(self):
"""启动消息管理器"""
@@ -57,7 +55,7 @@ class MessageManager:
self.is_running = True
self.manager_task = asyncio.create_task(self._manager_loop())
await self.wakeup_manager.start()
await self.context_manager.start()
# await self.context_manager.start() # 已删除,需要重构
logger.info("消息管理器已启动")
async def stop(self):
@@ -73,29 +71,32 @@ class MessageManager:
self.manager_task.cancel()
await self.wakeup_manager.stop()
await self.context_manager.stop()
# await self.context_manager.stop() # 已删除,需要重构
logger.info("消息管理器已停止")
def add_message(self, stream_id: str, message: DatabaseMessages):
"""添加消息到指定聊天流"""
# 检查流上下文是否存在,不存在则创建
context = self.context_manager.get_stream_context(stream_id)
if not context:
# 创建新的流上下文
from src.common.data_models.message_manager_data_model import StreamContext
context = StreamContext(stream_id=stream_id)
# 将创建的上下文添加到管理器
self.context_manager.add_stream_context(stream_id, context)
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
# 使用 context_manager 添加消息
success = self.context_manager.add_message_to_context(stream_id, message)
if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return
# 使用 ChatStream 的 context_manager 添加消息
success = chat_stream.context_manager.add_message(message)
if success:
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
else:
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
except Exception as e:
logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}")
def update_message(
self,
stream_id: str,
@@ -105,17 +106,60 @@ class MessageManager:
should_reply: bool = None,
):
"""更新消息信息"""
# 使用 context_manager 更新消息信息
context = self.context_manager.get_stream_context(stream_id)
if context:
context.update_message_info(message_id, interest_value, actions, should_reply)
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在")
return
# 构建更新字典
updates = {}
if interest_value is not None:
updates["interest_value"] = interest_value
if actions is not None:
updates["actions"] = actions
if should_reply is not None:
updates["should_reply"] = should_reply
# 使用 ChatStream 的 context_manager 更新消息
if updates:
success = chat_stream.context_manager.update_message(message_id, updates)
if success:
logger.debug(f"更新消息 {message_id} 成功")
else:
logger.warning(f"更新消息 {message_id} 失败")
except Exception as e:
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
def add_action(self, stream_id: str, message_id: str, action: str):
"""添加动作到消息"""
# 使用 context_manager 添加动作到消息
context = self.context_manager.get_stream_context(stream_id)
if context:
context.add_action_to_message(message_id, action)
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
return
# 使用 ChatStream 的 context_manager 添加动作
# 注意:这里需要根据实际的 API 调整
# 假设我们可以通过 update_message 来添加动作
success = chat_stream.context_manager.update_message(
message_id, {"actions": [action]}
)
if success:
logger.debug(f"为消息 {message_id} 添加动作 {action} 成功")
else:
logger.warning(f"为消息 {message_id} 添加动作 {action} 失败")
except Exception as e:
logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}")
async def _manager_loop(self):
"""管理器主循环 - 独立聊天流分发周期版本"""
@@ -145,18 +189,25 @@ class MessageManager:
active_streams = 0
total_unread = 0
# 使用 context_manager 获取活跃的流
active_stream_ids = self.context_manager.get_active_streams()
# 通过 ChatManager 获取所有活跃的流
try:
chat_manager = get_chat_manager()
active_stream_ids = list(chat_manager.streams.keys())
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
if not context:
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
continue
# 检查流是否活跃
context = chat_stream.stream_context
if not context.is_active:
continue
active_streams += 1
# 检查是否有未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id)
unread_messages = chat_stream.context_manager.get_unread_messages()
if unread_messages:
total_unread += len(unread_messages)
@@ -168,15 +219,23 @@ class MessageManager:
self.stats.active_streams = active_streams
self.stats.total_unread_messages = total_unread
except Exception as e:
logger.error(f"检查所有聊天流时发生错误: {e}")
async def _process_stream_messages(self, stream_id: str):
"""处理指定聊天流的消息"""
context = self.context_manager.get_stream_context(stream_id)
if not context:
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"处理消息失败: 聊天流 {stream_id} 不存在")
return
try:
context = chat_stream.stream_context
# 获取未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id)
unread_messages = chat_stream.context_manager.get_unread_messages()
if not unread_messages:
return
@@ -250,8 +309,15 @@ class MessageManager:
def deactivate_stream(self, stream_id: str):
"""停用聊天流"""
context = self.context_manager.get_stream_context(stream_id)
if context:
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.stream_context
context.is_active = False
# 取消处理任务
@@ -260,28 +326,51 @@ class MessageManager:
logger.info(f"停用聊天流: {stream_id}")
except Exception as e:
logger.error(f"停用聊天流 {stream_id} 时发生错误: {e}")
def activate_stream(self, stream_id: str):
"""激活聊天流"""
context = self.context_manager.get_stream_context(stream_id)
if context:
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.stream_context
context.is_active = True
logger.info(f"激活聊天流: {stream_id}")
except Exception as e:
logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}")
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
"""获取聊天流统计"""
context = self.context_manager.get_stream_context(stream_id)
if not context:
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
return None
context = chat_stream.stream_context
unread_count = len(chat_stream.context_manager.get_unread_messages())
return StreamStats(
stream_id=stream_id,
is_active=context.is_active,
unread_count=len(self.context_manager.get_unread_messages(stream_id)),
unread_count=unread_count,
history_count=len(context.history_messages),
last_check_time=context.last_check_time,
has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
)
except Exception as e:
logger.error(f"获取聊天流 {stream_id} 统计时发生错误: {e}")
return None
def get_manager_stats(self) -> Dict[str, Any]:
"""获取管理器统计"""
return {
@@ -295,9 +384,36 @@ class MessageManager:
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
"""清理不活跃的聊天流"""
# 使用 context_manager 的自动清理功能
self.context_manager.cleanup_inactive_contexts(max_inactive_hours * 3600)
logger.info("已启动不活跃聊天流清理")
try:
# 通过 ChatManager 清理不活跃的流
chat_manager = get_chat_manager()
current_time = time.time()
max_inactive_seconds = max_inactive_hours * 3600
inactive_streams = []
for stream_id, chat_stream in chat_manager.streams.items():
# 检查最后活跃时间
if current_time - chat_stream.last_active_time > max_inactive_seconds:
inactive_streams.append(stream_id)
# 清理不活跃的流
for stream_id in inactive_streams:
try:
# 清理流的内容
chat_stream.context_manager.clear_context()
# 从 ChatManager 中移除
del chat_manager.streams[stream_id]
logger.info(f"清理不活跃聊天流: {stream_id}")
except Exception as e:
logger.error(f"清理聊天流 {stream_id} 失败: {e}")
if inactive_streams:
logger.info(f"已清理 {len(inactive_streams)} 个不活跃聊天流")
else:
logger.debug("没有需要清理的不活跃聊天流")
except Exception as e:
logger.error(f"清理不活跃聊天流时发生错误: {e}")
async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str):
"""检查并处理消息打断"""
@@ -376,9 +492,10 @@ class MessageManager:
min_delay = float("inf")
# 找到最近需要检查的流
active_stream_ids = self.context_manager.get_active_streams()
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
try:
chat_manager = get_chat_manager()
for _stream_id, chat_stream in chat_manager.streams.items():
context = chat_stream.stream_context
if not context or not context.is_active:
continue
@@ -396,16 +513,20 @@ class MessageManager:
# 确保最小延迟
return max(0.1, min(min_delay, self.check_interval))
except Exception as e:
logger.error(f"计算下次检查延迟时发生错误: {e}")
return self.check_interval
async def _check_streams_with_individual_intervals(self):
"""检查所有达到检查时间的聊天流"""
current_time = time.time()
processed_streams = 0
# 使用 context_manager 获取活跃的流
active_stream_ids = self.context_manager.get_active_streams()
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
# 通过 ChatManager 获取活跃的流
try:
chat_manager = get_chat_manager()
for stream_id, chat_stream in chat_manager.streams.items():
context = chat_stream.stream_context
if not context or not context.is_active:
continue
@@ -424,17 +545,14 @@ class MessageManager:
context.next_check_time = current_time + context.distribution_interval
# 检查未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id)
unread_messages = chat_stream.context_manager.get_unread_messages()
if unread_messages:
processed_streams += 1
self.stats.total_unread_messages = len(unread_messages)
# 如果没有处理任务,创建一个
if not context.processing_task or context.processing_task.done():
from src.plugin_system.apis.chat_api import get_chat_manager
chat_stream = get_chat_manager().get_stream(context.stream_id)
focus_energy = chat_stream.focus_energy if chat_stream else 0.5
focus_energy = chat_stream.focus_energy
# 根据优先级记录日志
if focus_energy >= 0.7:
@@ -453,39 +571,45 @@ class MessageManager:
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
except Exception as e:
logger.error(f"检查独立分发周期的聊天流时发生错误: {e}")
# 更新活跃流计数
active_count = len(self.context_manager.get_active_streams())
try:
chat_manager = get_chat_manager()
active_count = len([s for s in chat_manager.streams.values() if s.stream_context.is_active])
self.stats.active_streams = active_count
if processed_streams > 0:
logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}")
except Exception as e:
logger.error(f"更新活跃流计数时发生错误: {e}")
async def _check_all_streams_with_priority(self):
"""按优先级检查所有聊天流高focus_energy的流优先处理"""
if not self.context_manager.get_active_streams():
try:
chat_manager = get_chat_manager()
if not chat_manager.streams:
return
# 获取活跃的聊天流并按focus_energy排序
active_streams = []
active_stream_ids = self.context_manager.get_active_streams()
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
for stream_id, chat_stream in chat_manager.streams.items():
context = chat_stream.stream_context
if not context or not context.is_active:
continue
# 获取focus_energy,如果不存在则使用默认值
from src.plugin_system.apis.chat_api import get_chat_manager
chat_stream = get_chat_manager().get_stream(context.stream_id)
focus_energy = 0.5
if chat_stream:
# 获取focus_energy
focus_energy = chat_stream.focus_energy
# 计算流优先级分数
priority_score = self._calculate_stream_priority(context, focus_energy)
active_streams.append((priority_score, stream_id, context))
except Exception as e:
logger.error(f"获取活跃流列表时发生错误: {e}")
return
# 按优先级降序排序
active_streams.sort(reverse=True, key=lambda x: x[0])
@@ -497,7 +621,12 @@ class MessageManager:
active_stream_count += 1
# 检查是否有未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id)
try:
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
continue
unread_messages = chat_stream.context_manager.get_unread_messages()
if unread_messages:
total_unread += len(unread_messages)
@@ -512,6 +641,9 @@ class MessageManager:
f"优先级: {priority_score:.3f} | "
f"未读消息: {len(unread_messages)}"
)
except Exception as e:
logger.error(f"处理流 {stream_id} 的未读消息时发生错误: {e}")
continue
# 更新统计
self.stats.active_streams = active_stream_count
@@ -536,22 +668,33 @@ class MessageManager:
def _clear_all_unread_messages(self, stream_id: str):
"""清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读"""
unread_messages = self.context_manager.get_unread_messages(stream_id)
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"清除消息失败: 聊天流 {stream_id} 不存在")
return
# 获取未读消息
unread_messages = chat_stream.context_manager.get_unread_messages()
if not unread_messages:
return
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
# 将所有未读消息标记为已读
context = self.context_manager.get_stream_context(stream_id)
if context:
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
try:
context.mark_message_as_read(msg.message_id)
self.stats.total_processed_messages += 1
logger.debug(f"强制清除消息 {msg.message_id},标记为已读")
message_ids = [msg.message_id for msg in unread_messages]
success = chat_stream.context_manager.mark_messages_as_read(message_ids)
if success:
self.stats.total_processed_messages += len(unread_messages)
logger.debug(f"强制清除 {len(unread_messages)} 条消息,标记为已读")
else:
logger.error("标记未读消息为已读失败")
except Exception as e:
logger.error(f"清除消息 {msg.message_id} 时出错: {e}")
logger.error(f"清除未读消息时发生错误: {e}")
# 创建全局消息管理器实例

View File

@@ -49,10 +49,18 @@ class ChatStream:
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatType, ChatMode
# 创建StreamContext
self.stream_context: StreamContext = StreamContext(
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
)
# 创建单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
stream_id=stream_id, context=self.stream_context
)
# 基础参数
self.base_interest_energy = 0.5 # 默认基础兴趣度
self._focus_energy = 0.5 # 内部存储的focus_energy值
@@ -61,6 +69,37 @@ class ChatStream:
# 自动加载历史消息
self._load_history_messages()
def __deepcopy__(self, memo):
    """Deep-copy this stream without copying its un-picklable asyncio task.

    Builds a fresh ChatStream, mirrors the scalar state, deep-copies the
    nested context objects, and nulls out ``processing_task`` on the copy
    (an asyncio.Task cannot be deep-copied).
    """
    import copy

    clone = ChatStream(
        stream_id=self.stream_id,
        platform=self.platform,
        user_info=copy.deepcopy(self.user_info, memo),
        group_info=copy.deepcopy(self.group_info, memo),
    )

    # Mirror the simple (immutable / scalar) attributes one-to-one.
    for attr in (
        "create_time",
        "last_active_time",
        "sleep_pressure",
        "saved",
        "base_interest_energy",
        "_focus_energy",
        "no_reply_consecutive",
    ):
        setattr(clone, attr, getattr(self, attr))

    # Deep-copy the stream context, then drop the task handle on the copy.
    clone.stream_context = copy.deepcopy(self.stream_context, memo)
    if hasattr(clone.stream_context, "processing_task"):
        clone.stream_context.processing_task = None

    clone.context_manager = copy.deepcopy(self.context_manager, memo)
    return clone
def to_dict(self) -> dict:
"""转换为字典格式"""
return {
@@ -74,10 +113,10 @@ class ChatStream:
"focus_energy": self.focus_energy,
# 基础兴趣度
"base_interest_energy": self.base_interest_energy,
# 新增stream_context信息
# stream_context基本信息
"stream_context_chat_type": self.stream_context.chat_type.value,
"stream_context_chat_mode": self.stream_context.chat_mode.value,
# 新增interruption_count信息
# 统计信息
"interruption_count": self.stream_context.interruption_count,
}
@@ -109,6 +148,14 @@ class ChatStream:
if "interruption_count" in data:
instance.stream_context.interruption_count = data["interruption_count"]
# 确保 context_manager 已初始化
if not hasattr(instance, "context_manager"):
from src.chat.message_manager.context_manager import SingleStreamContextManager
instance.context_manager = SingleStreamContextManager(
stream_id=instance.stream_id, context=instance.stream_context
)
return instance
def update_active_time(self):
@@ -195,12 +242,14 @@ class ChatStream:
self.stream_context.priority_info = getattr(message, "priority_info", None)
# 调试日志:记录数据转移情况
logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, "
logger.debug(
f"消息数据转移完成 - message_id: {db_message.message_id}, "
f"chat_id: {db_message.chat_id}, "
f"is_mentioned: {db_message.is_mentioned}, "
f"is_emoji: {db_message.is_emoji}, "
f"is_picid: {db_message.is_picid}, "
f"interest_value: {db_message.interest_value}")
f"interest_value: {db_message.interest_value}"
)
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
"""安全获取消息的actions字段"""
@@ -213,6 +262,7 @@ class ChatStream:
if isinstance(actions, str):
try:
import json
actions = json.loads(actions)
except json.JSONDecodeError:
logger.warning(f"无法解析actions JSON字符串: {actions}")
@@ -269,14 +319,17 @@ class ChatStream:
@property
def focus_energy(self) -> float:
"""使用重构后的能量管理器计算focus_energy"""
try:
from src.chat.energy_system import energy_manager
"""获取缓存的focus_energy"""
if hasattr(self, "_focus_energy"):
return self._focus_energy
else:
return 0.5
# 获取所有消息
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
unread_messages = self.stream_context.get_unread_messages()
all_messages = history_messages + unread_messages
async def calculate_focus_energy(self) -> float:
"""异步计算focus_energy"""
try:
# 使用单流上下文管理器获取消息
all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size)
# 获取用户ID
user_id = None
@@ -284,10 +337,10 @@ class ChatStream:
user_id = str(self.user_info.user_id)
# 使用能量管理器计算
energy = energy_manager.calculate_focus_energy(
stream_id=self.stream_id,
messages=all_messages,
user_id=user_id
from src.chat.energy_system import energy_manager
energy = await energy_manager.calculate_focus_energy(
stream_id=self.stream_id, messages=all_messages, user_id=user_id
)
# 更新内部存储
@@ -299,7 +352,7 @@ class ChatStream:
except Exception as e:
logger.error(f"获取focus_energy失败: {e}", exc_info=True)
# 返回缓存的值或默认值
if hasattr(self, '_focus_energy'):
if hasattr(self, "_focus_energy"):
return self._focus_energy
else:
return 0.5
@@ -309,7 +362,7 @@ class ChatStream:
"""设置focus_energy值主要用于初始化或特殊场景"""
self._focus_energy = max(0.0, min(1.0, value))
def _get_user_relationship_score(self) -> float:
async def _get_user_relationship_score(self) -> float:
"""获取用户关系分"""
# 使用插件内部的兴趣度评分系统
try:
@@ -317,7 +370,7 @@ class ChatStream:
if self.user_info and hasattr(self.user_info, "user_id"):
user_id = str(self.user_info.user_id)
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id)
logger.debug(f"ChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
return max(0.0, min(1.0, relationship_score))
@@ -346,7 +399,8 @@ class ChatStream:
.order_by(desc(Messages.time))
.limit(global_config.chat.max_context_size)
)
results = session.execute(stmt).scalars().all()
result = session.execute(stmt)
results = result.scalars().all()
return results
# 在线程中执行数据库查询
@@ -404,7 +458,9 @@ class ChatStream:
)
# 添加调试日志检查从数据库加载的interest_value
logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}")
logger.debug(
f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}"
)
# 标记为已读并添加到历史消息
db_message.is_read = True
@@ -548,7 +604,11 @@ class ChatManager:
# 检查数据库中是否存在
async def _db_find_stream_async(s_id: str):
async with get_db_session() as session:
return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalars().first()
return (
(await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)))
.scalars()
.first()
)
model_instance = await _db_find_stream_async(stream_id)
@@ -603,6 +663,15 @@ class ChatManager:
stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):
# 创建新的单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
stream.context_manager = SingleStreamContextManager(
stream_id=stream_id, context=stream.stream_context
)
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
@@ -704,7 +773,8 @@ class ChatManager:
async def _db_load_all_streams_async():
loaded_streams_data = []
async with get_db_session() as session:
for model_instance in (await session.execute(select(ChatStreams))).scalars().all():
result = await session.execute(select(ChatStreams))
for model_instance in result.scalars().all():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
@@ -752,6 +822,13 @@ class ChatManager:
self.streams[stream.stream_id] = stream
if stream.stream_id in self.last_messages:
stream.set_context(self.last_messages[stream.stream_id])
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):
from src.chat.message_manager.context_manager import SingleStreamContextManager
stream.context_manager = SingleStreamContextManager(
stream_id=stream.stream_id, context=stream.stream_context
)
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)

View File

@@ -41,7 +41,7 @@ class MessageStorage:
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
@@ -129,9 +129,9 @@ class MessageStorage:
key_words=key_words,
key_words_lite=key_words_lite,
)
with get_db_session() as session:
async with get_db_session() as session:
session.add(new_message)
session.commit()
await session.commit()
except Exception:
logger.exception("存储消息失败")
@@ -174,13 +174,13 @@ class MessageStorage:
# 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
matched_message = session.execute(
async with get_db_session() as session:
matched_message = (await session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
).scalar()
)).scalar()
if matched_message:
session.execute(
await session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
)
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
@@ -195,7 +195,7 @@ class MessageStorage:
)
@staticmethod
def replace_image_descriptions(text: str) -> str:
async def replace_image_descriptions(text: str) -> str:
"""将[图片:描述]替换为[picid:image_id]"""
# 先检查文本中是否有图片标记
pattern = r"\[图片:([^\]]+)\]"
@@ -205,15 +205,15 @@ class MessageStorage:
logger.debug("文本中没有图片标记,直接返回原文本")
return text
def replace_match(match):
async def replace_match(match):
description = match.group(1).strip()
try:
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
image_record = session.execute(
async with get_db_session() as session:
image_record = (await session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
).scalar()
)).scalar()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
except Exception:
return match.group(0)
@@ -271,7 +271,8 @@ class MessageStorage:
)
).limit(50) # 限制每次修复的数量,避免性能问题
messages_to_fix = session.execute(query).scalars().all()
result = session.execute(query)
messages_to_fix = result.scalars().all()
fixed_count = 0
for msg in messages_to_fix:

View File

@@ -824,7 +824,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
description = "[图片内容未知]" # 默认描述
try:
with get_db_session() as session:
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
result = session.execute(select(Images).where(Images.image_id == pic_id))
image = result.scalar_one_or_none()
if image and image.description: # type: ignore
description = image.description
except Exception:

View File

@@ -308,7 +308,8 @@ class ImageManager:
async with get_db_session() as session:
# 优先检查Images表中是否已有完整的描述
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
result.scalar()
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
@@ -527,7 +528,8 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
async with get_db_session() as session:
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
result.scalar()
if existing_image:
# 检查是否缺少必要字段,如果缺少则创建新记录
if (

View File

@@ -22,13 +22,14 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import get_db_session, Videos
from sqlalchemy import select
logger = get_logger("utils_video")
# Rust模块可用性检测
RUST_VIDEO_AVAILABLE = False
try:
import rust_video
import rust_video # pyright: ignore[reportMissingImports]
RUST_VIDEO_AVAILABLE = True
logger.info("✅ Rust 视频处理模块加载成功")
@@ -202,18 +203,20 @@ class VideoAnalyzer:
hash_obj.update(video_data)
return hash_obj.hexdigest()
def _check_video_exists(self, video_hash: str) -> Optional[Videos]:
async def _check_video_exists(self, video_hash: str) -> Optional[Videos]:
"""检查视频是否已经分析过"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 明确刷新会话以确保看到其他事务的最新提交
session.expire_all()
return session.query(Videos).filter(Videos.video_hash == video_hash).first()
await session.expire_all()
stmt = select(Videos).where(Videos.video_hash == video_hash)
result = await session.execute(stmt)
return result.scalar_one_or_none()
except Exception as e:
logger.warning(f"检查视频是否存在时出错: {e}")
return None
def _store_video_result(
async def _store_video_result(
self, video_hash: str, description: str, metadata: Optional[Dict] = None
) -> Optional[Videos]:
"""存储视频分析结果到数据库"""
@@ -223,9 +226,11 @@ class VideoAnalyzer:
return None
try:
with get_db_session() as session:
async with get_db_session() as session:
# 只根据video_hash查找
existing_video = session.query(Videos).filter(Videos.video_hash == video_hash).first()
stmt = select(Videos).where(Videos.video_hash == video_hash)
result = await session.execute(stmt)
existing_video = result.scalar_one_or_none()
if existing_video:
# 如果已存在,更新描述和计数
@@ -238,8 +243,8 @@ class VideoAnalyzer:
existing_video.fps = metadata.get("fps")
existing_video.resolution = metadata.get("resolution")
existing_video.file_size = metadata.get("file_size")
session.commit()
session.refresh(existing_video)
await session.commit()
await session.refresh(existing_video)
logger.info(f"✅ 更新已存在的视频记录hash: {video_hash[:16]}..., count: {existing_video.count}")
return existing_video
else:
@@ -254,8 +259,8 @@ class VideoAnalyzer:
video_record.file_size = metadata.get("file_size")
session.add(video_record)
session.commit()
session.refresh(video_record)
await session.commit()
await session.refresh(video_record)
logger.info(f"✅ 新视频分析结果已保存到数据库hash: {video_hash[:16]}...")
return video_record
except Exception as e:
@@ -704,7 +709,7 @@ class VideoAnalyzer:
logger.info("✅ 等待结束,检查是否有处理结果")
# 检查是否有结果了
existing_video = self._check_video_exists(video_hash)
existing_video = await self._check_video_exists(video_hash)
if existing_video:
logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})")
return {"summary": existing_video.description}
@@ -718,7 +723,7 @@ class VideoAnalyzer:
logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)")
# 再次检查数据库(可能在等待期间已经有结果了)
existing_video = self._check_video_exists(video_hash)
existing_video = await self._check_video_exists(video_hash)
if existing_video:
logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})")
video_event.set() # 通知其他等待者
@@ -749,7 +754,7 @@ class VideoAnalyzer:
# 保存分析结果到数据库(仅保存成功的结果)
if success and not result.startswith(""):
metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()}
self._store_video_result(video_hash=video_hash, description=result, metadata=metadata)
await self._store_video_result(video_hash=video_hash, description=result, metadata=metadata)
logger.info("✅ 分析结果已保存到数据库")
else:
logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试")

View File

@@ -22,9 +22,9 @@ class DatabaseProxy:
self._session = None
@staticmethod
def initialize(*args, **kwargs):
async def initialize(*args, **kwargs):
"""初始化数据库连接"""
return initialize_database_compat()
return await initialize_database_compat()
class SQLAlchemyTransaction:
@@ -88,7 +88,7 @@ async def initialize_sql_database(database_config):
logger.info(f" 数据库文件: {db_path}")
# 使用SQLAlchemy初始化
success = initialize_database_compat()
success = await initialize_database_compat()
if success:
_sql_engine = await get_engine()
logger.info("SQLAlchemy数据库初始化成功")

View File

@@ -706,7 +706,8 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
raise RuntimeError("Database session not initialized")
session = SessionLocal()
yield session
except Exception:
except Exception as e:
logger.error(f"数据库会话错误: {e}")
if session:
await session.rollback()
raise

View File

@@ -101,7 +101,8 @@ def find_messages(
# 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit)
try:
results = session.execute(query).scalars().all()
                results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行earliest查询失败: {e}")
results = []
@@ -109,7 +110,8 @@ def find_messages(
# 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit)
try:
latest_results = session.execute(query).scalars().all()
                latest_results = session.execute(query).scalars().all()
# 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e:
@@ -133,7 +135,8 @@ def find_messages(
if sort_terms:
query = query.order_by(*sort_terms)
try:
results = session.execute(query).scalars().all()
            results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行无限制查询失败: {e}")
results = []
@@ -207,5 +210,5 @@ def count_messages(message_filter: dict[str, Any]) -> int:
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 SQLAlchemy插入操作通常是使用 session.add() 和 session.commit()。
# 注意:对于 SQLAlchemy插入操作通常是使用 session.add() 和 await session.commit()。
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。

View File

@@ -161,7 +161,7 @@ class LLMUsageRecorder:
session = None
try:
# 使用 SQLAlchemy 会话创建记录
with get_db_session() as session:
async with get_db_session() as session:
usage_record = LLMUsage(
model_name=model_info.model_identifier,
model_assign_name=model_info.name,
@@ -172,14 +172,14 @@ class LLMUsageRecorder:
prompt_tokens=model_usage.prompt_tokens or 0,
completion_tokens=model_usage.completion_tokens or 0,
total_tokens=model_usage.total_tokens or 0,
cost=total_cost or 0.0,
                    cost=total_cost or 0.0,
time_cost=round(time_cost or 0.0, 3),
status="success",
timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段
)
session.add(usage_record)
session.commit()
await session.commit()
logger.debug(
f"Token使用情况 - 模型: {model_usage.model_name}, "

View File

@@ -163,7 +163,8 @@ class PersonInfoManager:
try:
# 在需要时获取会话
async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))).scalar()
                result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))
                record = result.scalar()
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
@@ -339,7 +340,8 @@ class PersonInfoManager:
start_time = time.time()
async with get_db_session() as session:
try:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
query_time = time.time()
if record:
setattr(record, f_name, val_to_set)
@@ -401,7 +403,8 @@ class PersonInfoManager:
async def _db_has_field_async(p_id: str, f_name: str):
async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
return bool(record)
try:
@@ -512,10 +515,9 @@ class PersonInfoManager:
async def _db_check_name_exists_async(name_to_check):
async with get_db_session() as session:
return (
(await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))).scalar()
is not None
)
result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))
record = result.scalar()
return record is not None
if await _db_check_name_exists_async(generated_nickname):
is_duplicate = True
@@ -556,7 +558,8 @@ class PersonInfoManager:
async def _db_delete_async(p_id: str):
try:
async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
if record:
await session.delete(record)
await session.commit()
@@ -585,7 +588,9 @@ class PersonInfoManager:
async def _get_record_sync():
async with get_db_session() as session:
return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))).scalar()
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
record = result.scalar()
return record
try:
record = asyncio.run(_get_record_sync())
@@ -624,7 +629,9 @@ class PersonInfoManager:
async def _db_get_record_async(p_id: str):
async with get_db_session() as session:
return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
return record
record = await _db_get_record_async(person_id)
@@ -700,7 +707,8 @@ class PersonInfoManager:
"""原子性的获取或创建操作"""
async with get_db_session() as session:
# 首先尝试获取现有记录
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
if record:
return record, False # 记录存在,未创建
@@ -715,7 +723,8 @@ class PersonInfoManager:
# 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
if record:
return record, False # 其他协程已创建,返回现有记录
# 如果仍然失败,重新抛出异常

View File

@@ -122,7 +122,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
# 记录使用次数
emoji_manager.record_usage(selected_emoji.hash)
await emoji_manager.record_usage(selected_emoji.hash)
results.append((emoji_base64, selected_emoji.description, matched_emotion))
if not results and count > 0:
@@ -180,7 +180,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
return None
# 记录使用次数
emoji_manager.record_usage(selected_emoji.hash)
await emoji_manager.record_usage(selected_emoji.hash)
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, emotion

View File

@@ -65,7 +65,7 @@ class AffinityChatter(BaseChatter):
"""
try:
# 触发表达学习
learner = expression_learner_manager.get_expression_learner(self.stream_id)
learner = await expression_learner_manager.get_expression_learner(self.stream_id)
asyncio.create_task(learner.trigger_learning_for_chat())
unread_messages = context.get_unread_messages()

View File

@@ -69,7 +69,7 @@ class ChatterInterestScoringSystem:
keywords = self._extract_keywords_from_database(message)
interest_match_score = await self._calculate_interest_match_score(message.processed_plain_text, keywords)
relationship_score = self._calculate_relationship_score(message.user_info.user_id)
relationship_score = await self._calculate_relationship_score(message.user_info.user_id)
mentioned_score = self._calculate_mentioned_score(message, bot_nickname)
total_score = (
@@ -189,7 +189,7 @@ class ChatterInterestScoringSystem:
unique_keywords = list(set(keywords))
return unique_keywords[:10] # 返回前10个唯一关键词
def _calculate_relationship_score(self, user_id: str) -> float:
async def _calculate_relationship_score(self, user_id: str) -> float:
"""计算关系分 - 从数据库获取关系分"""
# 优先使用内存中的关系分
if user_id in self.user_relationships:
@@ -212,7 +212,7 @@ class ChatterInterestScoringSystem:
global_tracker = ChatterRelationshipTracker()
if global_tracker:
relationship_score = global_tracker.get_user_relationship_score(user_id)
relationship_score = await global_tracker.get_user_relationship_score(user_id)
# 同时更新内存缓存
self.user_relationships[user_id] = relationship_score
return relationship_score

View File

@@ -287,7 +287,7 @@ class ChatterRelationshipTracker:
# ===== 数据库支持方法 =====
def get_user_relationship_score(self, user_id: str) -> float:
async def get_user_relationship_score(self, user_id: str) -> float:
"""获取用户关系分"""
# 先检查缓存
if user_id in self.user_relationship_cache:
@@ -298,7 +298,7 @@ class ChatterRelationshipTracker:
return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
# 缓存过期或不存在,从数据库获取
relationship_data = self._get_user_relationship_from_db(user_id)
relationship_data = await self._get_user_relationship_from_db(user_id)
if relationship_data:
# 更新缓存
self.user_relationship_cache[user_id] = {
@@ -313,37 +313,38 @@ class ChatterRelationshipTracker:
# 数据库中也没有,返回默认值
return global_config.affinity_flow.base_relationship_score
def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
async def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
"""从数据库获取用户关系数据"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 查询用户关系表
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
result = session.execute(stmt).scalar_one_or_none()
result = await session.execute(stmt)
relationship = result.scalar_one_or_none()
if result:
if relationship:
return {
"relationship_text": result.relationship_text or "",
"relationship_score": float(result.relationship_score)
if result.relationship_score is not None
"relationship_text": relationship.relationship_text or "",
"relationship_score": float(relationship.relationship_score)
if relationship.relationship_score is not None
else 0.3,
"last_updated": result.last_updated,
"last_updated": relationship.last_updated,
}
except Exception as e:
logger.error(f"从数据库获取用户关系失败: {e}")
return None
def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
async def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
"""更新数据库中的用户关系"""
try:
current_time = time.time()
with get_db_session() as session:
async with get_db_session() as session:
# 检查是否已存在关系记录
existing = session.execute(
select(UserRelationships).where(UserRelationships.user_id == user_id)
).scalar_one_or_none()
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# 更新现有记录
@@ -362,7 +363,7 @@ class ChatterRelationshipTracker:
)
session.add(new_relationship)
session.commit()
await session.commit()
logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}")
except Exception as e:
@@ -399,7 +400,7 @@ class ChatterRelationshipTracker:
logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息")
# 获取当前关系数据
current_relationship = self._get_user_relationship_from_db(user_id)
current_relationship = await self._get_user_relationship_from_db(user_id)
current_score = (
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
if current_relationship
@@ -417,14 +418,14 @@ class ChatterRelationshipTracker:
logger.error(f"回复后关系追踪失败: {e}")
logger.debug("错误详情:", exc_info=True)
def _get_last_tracked_time(self, user_id: str) -> float:
async def _get_last_tracked_time(self, user_id: str) -> float:
"""获取上次追踪时间"""
# 先检查缓存
if user_id in self.user_relationship_cache:
return self.user_relationship_cache[user_id].get("last_tracked", 0)
# 从数据库获取
relationship_data = self._get_user_relationship_from_db(user_id)
relationship_data = await self._get_user_relationship_from_db(user_id)
if relationship_data:
return relationship_data.get("last_updated", 0)
@@ -433,7 +434,7 @@ class ChatterRelationshipTracker:
async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]:
"""获取上次bot回复该用户的消息"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 查询bot回复给该用户的最新消息
stmt = (
select(Messages)
@@ -443,10 +444,11 @@ class ChatterRelationshipTracker:
.limit(1)
)
result = session.execute(stmt).scalar_one_or_none()
if result:
result = await session.execute(stmt)
message = result.scalar_one_or_none()
if message:
# 将SQLAlchemy模型转换为DatabaseMessages对象
return self._sqlalchemy_to_database_messages(result)
return self._sqlalchemy_to_database_messages(message)
except Exception as e:
logger.error(f"获取上次回复消息失败: {e}")
@@ -456,7 +458,7 @@ class ChatterRelationshipTracker:
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]:
"""获取用户在bot回复后的反应消息"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 查询用户在回复时间之后的5分钟内的消息
end_time = reply_time + 5 * 60 # 5分钟
@@ -468,9 +470,10 @@ class ChatterRelationshipTracker:
.order_by(Messages.time)
)
results = session.execute(stmt).scalars().all()
if results:
return [self._sqlalchemy_to_database_messages(result) for result in results]
result = await session.execute(stmt)
messages = result.scalars().all()
if messages:
return [self._sqlalchemy_to_database_messages(message) for message in messages]
except Exception as e:
logger.error(f"获取用户反应消息失败: {e}")
@@ -593,7 +596,7 @@ class ChatterRelationshipTracker:
quality = response_data.get("interaction_quality", "medium")
# 更新数据库
self._update_user_relationship_in_db(user_id, new_text, new_score)
await self._update_user_relationship_in_db(user_id, new_text, new_score)
# 更新缓存
self.user_relationship_cache[user_id] = {
@@ -696,7 +699,7 @@ class ChatterRelationshipTracker:
)
# 更新数据库和缓存
self._update_user_relationship_in_db(user_id, new_text, new_score)
await self._update_user_relationship_in_db(user_id, new_text, new_score)
self.user_relationship_cache[user_id] = {
"relationship_text": new_text,
"relationship_score": new_score,

View File

@@ -13,6 +13,7 @@ from typing import Callable
from src.common.logger import get_logger
from src.schedule.schedule_manager import schedule_manager
from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select
from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus
from .qzone_service import QZoneService
@@ -138,15 +139,13 @@ class SchedulerService:
:return: 如果已处理过,返回 True否则返回 False。
"""
try:
with get_db_session() as session:
record = (
session.query(MaiZoneScheduleStatus)
.filter(
async with get_db_session() as session:
stmt = select(MaiZoneScheduleStatus).where(
MaiZoneScheduleStatus.datetime_hour == hour_str,
MaiZoneScheduleStatus.is_processed == True, # noqa: E712
)
.first()
)
result = await session.execute(stmt)
record = result.scalar_one_or_none()
return record is not None
except Exception as e:
logger.error(f"检查日程处理状态时发生数据库错误: {e}")
@@ -162,11 +161,11 @@ class SchedulerService:
:param content: 最终发送的说说内容或错误信息。
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 查找是否已存在该记录
record = (
session.query(MaiZoneScheduleStatus).filter(MaiZoneScheduleStatus.datetime_hour == hour_str).first()
)
stmt = select(MaiZoneScheduleStatus).where(MaiZoneScheduleStatus.datetime_hour == hour_str)
result = await session.execute(stmt)
record = result.scalar_one_or_none()
if record:
# 如果存在,则更新状态
@@ -185,7 +184,7 @@ class SchedulerService:
send_success=success,
)
session.add(new_record)
session.commit()
await session.commit()
logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}")
except Exception as e:
logger.error(f"更新日程处理状态时发生数据库错误: {e}")

View File

@@ -64,15 +64,9 @@ async def message_recv(server_connection: Server.ServerConnection):
# 处理完整消息(可能是重组后的,也可能是原本就完整的)
post_type = decoded_raw_message.get("post_type")
# 兼容没有 post_type 的普通消息
if not post_type and "message_type" in decoded_raw_message:
decoded_raw_message["post_type"] = "message"
post_type = "message"
if post_type in ["meta_event", "message", "notice"]:
await message_queue.put(decoded_raw_message)
else:
elif post_type is None:
await put_response(decoded_raw_message)
except json.JSONDecodeError as e:
@@ -428,8 +422,9 @@ class NapcatAdapterPlugin(BasePlugin):
def get_plugin_components(self):
self.register_events()
components = [(LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler),
(StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler)]
components = []
components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler))
components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler))
for handler in get_classes_in_module(event_handlers):
if issubclass(handler, BaseEventHandler):
components.append((handler.get_handler_info(), handler))

View File

@@ -1,156 +1,162 @@
"""Napcat Adapter 插件数据库层 (基于主程序异步SQLAlchemy API)
本模块替换原先的 sqlmodel + 同步Session 实现:
1. 复用主项目的异步数据库连接与迁移体系
2. 提供与旧接口名兼容的方法(update_ban_record/create_ban_record/delete_ban_record)
3. 新增首选异步方法: update_ban_records / create_or_update / delete_record / get_ban_records
数据语义:
user_id == 0 表示群全体禁言
注意: 所有方法均为异步, 需要在 async 上下文中调用。
"""
from __future__ import annotations
import os
from typing import Optional, List
from dataclasses import dataclass
from typing import Optional, List, Sequence
from sqlmodel import Field, Session, SQLModel, create_engine, select
from sqlalchemy import Column, Integer, BigInteger, UniqueConstraint, select, Index
from sqlalchemy.ext.asyncio import AsyncSession
from src.common.database.sqlalchemy_models import Base, get_db_session
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
"""
表记录的方式:
| group_id | user_id | lift_time |
|----------|---------|-----------|
class NapcatBanRecord(Base):
__tablename__ = "napcat_ban_records"
id = Column(Integer, primary_key=True, autoincrement=True)
group_id = Column(BigInteger, nullable=False, index=True)
user_id = Column(BigInteger, nullable=False, index=True) # 0 == 全体禁言
lift_time = Column(BigInteger, nullable=True) # -1 / None 表示未知/永久
__table_args__ = (
UniqueConstraint("group_id", "user_id", name="uq_napcat_group_user"),
Index("idx_napcat_ban_group", "group_id"),
Index("idx_napcat_ban_user", "user_id"),
)
其中使用 user_id == 0 表示群全体禁言
"""
@dataclass
class BanUser:
"""
程序处理使用的实例
"""
user_id: int
group_id: int
lift_time: Optional[int] = -1
def identity(self) -> tuple[int, int]:
return self.group_id, self.user_id
lift_time: Optional[int] = Field(default=-1)
class NapcatDatabase:
async def _fetch_all(self, session: AsyncSession) -> Sequence[NapcatBanRecord]:
result = await session.execute(select(NapcatBanRecord))
return result.scalars().all()
class DB_BanUser(SQLModel, table=True):
"""
表示数据库中的用户禁言记录。
使用双重主键
"""
async def get_ban_records(self) -> List[BanUser]:
async with get_db_session() as session:
rows = await self._fetch_all(session)
return [BanUser(group_id=r.group_id, user_id=r.user_id, lift_time=r.lift_time) for r in rows]
user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID
group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID
lift_time: Optional[int] # 禁言解除的时间(时间戳)
async def update_ban_records(self, ban_list: List[BanUser]) -> None:
target_map = {b.identity(): b for b in ban_list}
async with get_db_session() as session:
rows = await self._fetch_all(session)
existing_map = {(r.group_id, r.user_id): r for r in rows}
changed = 0
for ident, ban in target_map.items():
if ident in existing_map:
row = existing_map[ident]
if row.lift_time != ban.lift_time:
row.lift_time = ban.lift_time
changed += 1
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
"""
检查两个 BanUser 对象是否相同。
"""
return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
class DatabaseManager:
"""
数据库管理类,负责与数据库交互。
"""
def __init__(self):
os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在
DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL
self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎
self._ensure_database() # 确保数据库和表已创建
def _ensure_database(self) -> None:
"""
确保数据库和表已创建。
"""
logger.info("确保数据库文件和表已创建...")
SQLModel.metadata.create_all(self.engine)
logger.info("数据库和表已创建或已存在")
def update_ban_record(self, ban_list: List[BanUser]) -> None:
# sourcery skip: class-extract-method
"""
更新禁言列表到数据库。
支持在不存在时创建新记录,对于多余的项目自动删除。
"""
with Session(self.engine) as session:
all_records = session.exec(select(DB_BanUser)).all()
for ban_user in ban_list:
statement = select(DB_BanUser).where(
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
)
if existing_record := session.exec(statement).first():
if existing_record.lift_time == ban_user.lift_time:
logger.debug(f"禁言记录未变更: {existing_record}")
continue
# 更新现有记录的 lift_time
existing_record.lift_time = ban_user.lift_time
session.add(existing_record)
logger.debug(f"更新禁言记录: {existing_record}")
else:
session.add(
NapcatBanRecord(group_id=ban.group_id, user_id=ban.user_id, lift_time=ban.lift_time)
# 创建新记录
db_record = DB_BanUser(
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
)
changed += 1
session.add(db_record)
logger.debug(f"创建新禁言记录: {ban_user}")
# 删除不在 ban_list 中的记录
for db_record in all_records:
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
if not any(is_identical(record, ban_user) for ban_user in ban_list):
statement = select(DB_BanUser).where(
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
removed = 0
for ident, row in existing_map.items():
if ident not in target_map:
await session.delete(row)
removed += 1
logger.debug(
f"Napcat ban list sync => total_incoming={len(ban_list)} created_or_updated={changed} removed={removed}"
)
async def create_or_update(self, ban_record: BanUser) -> None:
async with get_db_session() as session:
stmt = select(NapcatBanRecord).where(
NapcatBanRecord.group_id == ban_record.group_id,
NapcatBanRecord.user_id == ban_record.user_id,
)
result = await session.execute(stmt)
row = result.scalars().first()
if row:
if row.lift_time != ban_record.lift_time:
row.lift_time = ban_record.lift_time
logger.debug(
f"更新禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
)
logger.debug(f"删除禁言记录: {ban_record}")
else:
session.add(
NapcatBanRecord(
group_id=ban_record.group_id, user_id=ban_record.user_id, lift_time=ban_record.lift_time
)
)
logger.debug(
f"创建禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
)
logger.info(f"未找到禁言记录: {ban_record}")
async def delete_record(self, ban_record: BanUser) -> None:
async with get_db_session() as session:
stmt = select(NapcatBanRecord).where(
NapcatBanRecord.group_id == ban_record.group_id,
NapcatBanRecord.user_id == ban_record.user_id,
)
result = await session.execute(stmt)
row = result.scalars().first()
if row:
await session.delete(row)
logger.debug(
f"删除禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={row.lift_time}"
logger.info("禁言记录已更新")
def get_ban_records(self) -> List[BanUser]:
"""
读取所有禁言记录。
"""
with Session(self.engine) as session:
statement = select(DB_BanUser)
records = session.exec(statement).all()
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
def create_ban_record(self, ban_record: BanUser) -> None:
"""
为特定群组中的用户创建禁言记录。
一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。
其同时还是简化版的更新方式。
"""
with Session(self.engine) as session:
# 检查记录是否已存在
statement = select(DB_BanUser).where(
DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
)
existing_record = session.exec(statement).first()
if existing_record:
# 如果记录已存在,更新 lift_time
existing_record.lift_time = ban_record.lift_time
session.add(existing_record)
logger.debug(f"更新禁言记录: {ban_record}")
else:
logger.info(
f"未找到禁言记录 group={ban_record.group_id} user={ban_record.user_id}"
# 如果记录不存在,创建新记录
db_record = DB_BanUser(
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
)
session.add(db_record)
logger.debug(f"创建新禁言记录: {ban_record}")
# 兼容旧命名
async def update_ban_record(self, ban_list: List[BanUser]) -> None: # old name
await self.update_ban_records(ban_list)
def delete_ban_record(self, ban_record: BanUser):
"""
删除特定用户在特定群组中的禁言记录。
一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。
"""
user_id = ban_record.user_id
group_id = ban_record.group_id
with Session(self.engine) as session:
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
async def create_ban_record(self, ban_record: BanUser) -> None: # old name
await self.create_or_update(ban_record)
async def delete_ban_record(self, ban_record: BanUser) -> None: # old name
await self.delete_record(ban_record)
logger.debug(f"删除禁言记录: {ban_record}")
else:
logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
napcat_db = NapcatDatabase()
def is_identical(a: BanUser, b: BanUser) -> bool:
return a.group_id == b.group_id and a.user_id == b.user_id
__all__ = [
"BanUser",
"NapcatBanRecord",
"napcat_db",
"is_identical",
]
db_manager = DatabaseManager()

View File

@@ -112,8 +112,7 @@ class MessageChunker:
else:
return [{"_original_message": message}]
@staticmethod
def is_chunk_message(message: Union[str, Dict[str, Any]]) -> bool:
def is_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
"""判断是否是切片消息"""
try:
if isinstance(message, str):

View File

@@ -14,7 +14,6 @@ class MetaEventHandler:
"""
def __init__(self):
self.last_heart_beat = time.time()
self.interval = 5.0 # 默认值稍后通过set_plugin_config设置
self._interval_checking = False
self.plugin_config = None
@@ -40,6 +39,7 @@ class MetaEventHandler:
if message["status"].get("online") and message["status"].get("good"):
if not self._interval_checking:
asyncio.create_task(self.check_heartbeat())
self.last_heart_beat = time.time()
self.interval = message.get("interval") / 1000
else:
self_id = message.get("self_id")

View File

@@ -76,7 +76,7 @@ class SendHandler:
processed_message = await self.handle_seg_recursive(message_segment, user_info)
except Exception as e:
logger.error(f"处理消息时发生错误: {e}")
return None
return
if not processed_message:
logger.critical("现在暂时不支持解析此回复!")
@@ -94,7 +94,7 @@ class SendHandler:
id_name = "user_id"
else:
logger.error("无法识别的消息类型")
return None
return
logger.info("尝试发送到napcat")
logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'")
response = await self.send_message_to_napcat(
@@ -108,10 +108,8 @@ class SendHandler:
logger.info("消息发送成功")
qq_message_id = response.get("data", {}).get("message_id")
await self.message_sent_back(raw_message_base, qq_message_id)
return None
else:
logger.warning(f"消息发送失败napcat返回{str(response)}")
return None
async def send_command(self, raw_message_base: MessageBase) -> None:
"""
@@ -149,7 +147,7 @@ class SendHandler:
command, args_dict = self.handle_send_like_command(args)
case _:
logger.error(f"未知命令: {command_name}")
return None
return
except Exception as e:
logger.error(f"处理命令时发生错误: {e}")
return None
@@ -161,10 +159,8 @@ class SendHandler:
response = await self.send_message_to_napcat(command, args_dict)
if response.get("status") == "ok":
logger.info(f"命令 {command_name} 执行成功")
return None
else:
logger.warning(f"命令 {command_name} 执行失败napcat返回{str(response)}")
return None
async def handle_adapter_command(self, raw_message_base: MessageBase) -> None:
"""
@@ -272,8 +268,7 @@ class SendHandler:
new_payload = self.build_payload(payload, self.handle_file_message(file_path), False)
return new_payload
@staticmethod
def build_payload(payload: list, addon: dict | list, is_reply: bool = False) -> list:
def build_payload(self, payload: list, addon: dict | list, is_reply: bool = False) -> list:
# sourcery skip: for-append-to-extend, merge-list-append, simplify-generator
"""构建发送的消息体"""
if is_reply:
@@ -339,13 +334,11 @@ class SendHandler:
logger.info(f"最终返回的回复段: {reply_seg}")
return reply_seg
@staticmethod
def handle_text_message(message: str) -> dict:
def handle_text_message(self, message: str) -> dict:
"""处理文本消息"""
return {"type": "text", "data": {"text": message}}
@staticmethod
def handle_image_message(encoded_image: str) -> dict:
def handle_image_message(self, encoded_image: str) -> dict:
"""处理图片消息"""
return {
"type": "image",
@@ -355,8 +348,7 @@ class SendHandler:
},
} # base64 编码的图片
@staticmethod
def handle_emoji_message(encoded_emoji: str) -> dict:
def handle_emoji_message(self, encoded_emoji: str) -> dict:
"""处理表情消息"""
encoded_image = encoded_emoji
image_format = get_image_format(encoded_emoji)
@@ -387,45 +379,39 @@ class SendHandler:
"data": {"file": f"base64://{encoded_voice}"},
}
@staticmethod
def handle_voiceurl_message(voice_url: str) -> dict:
def handle_voiceurl_message(self, voice_url: str) -> dict:
"""处理语音链接消息"""
return {
"type": "record",
"data": {"file": voice_url},
}
@staticmethod
def handle_music_message(song_id: str) -> dict:
def handle_music_message(self, song_id: str) -> dict:
"""处理音乐消息"""
return {
"type": "music",
"data": {"type": "163", "id": song_id},
}
@staticmethod
def handle_videourl_message(video_url: str) -> dict:
def handle_videourl_message(self, video_url: str) -> dict:
"""处理视频链接消息"""
return {
"type": "video",
"data": {"file": video_url},
}
@staticmethod
def handle_file_message(file_path: str) -> dict:
def handle_file_message(self, file_path: str) -> dict:
"""处理文件消息"""
return {
"type": "file",
"data": {"file": f"file://{file_path}"},
}
@staticmethod
def delete_msg_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""处理删除消息命令"""
return "delete_msg", {"message_id": args["message_id"]}
@staticmethod
def handle_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理封禁命令
Args:
@@ -453,8 +439,7 @@ class SendHandler:
},
)
@staticmethod
def handle_whole_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理全体禁言命令
Args:
@@ -477,8 +462,7 @@ class SendHandler:
},
)
@staticmethod
def handle_kick_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理群成员踢出命令
Args:
@@ -503,8 +487,7 @@ class SendHandler:
},
)
@staticmethod
def handle_poke_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理戳一戳命令
Args:
@@ -531,8 +514,7 @@ class SendHandler:
},
)
@staticmethod
def handle_set_emoji_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""处理设置表情回应命令
Args:
@@ -554,8 +536,7 @@ class SendHandler:
{"message_id": message_id, "emoji_id": emoji_id, "set": set_like},
)
@staticmethod
def handle_send_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""
处理发送点赞命令的逻辑。
@@ -576,8 +557,7 @@ class SendHandler:
{"user_id": user_id, "times": times},
)
@staticmethod
def handle_ai_voice_send_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""
处理AI语音发送命令的逻辑。
并返回 NapCat 兼容的 (action, params) 元组。
@@ -624,8 +604,7 @@ class SendHandler:
return {"status": "error", "message": str(e)}
return response
@staticmethod
async def message_sent_back(message_base: MessageBase, qq_message_id: str) -> None:
async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None:
# 修改 additional_config添加 echo 字段
if message_base.message_info.additional_config is None:
message_base.message_info.additional_config = {}
@@ -643,9 +622,8 @@ class SendHandler:
logger.debug("已回送消息ID")
return
@staticmethod
async def send_adapter_command_response(
original_message: MessageBase, response_data: dict, request_id: str
self, original_message: MessageBase, response_data: dict, request_id: str
) -> None:
"""
发送适配器命令响应回MaiBot
@@ -674,8 +652,7 @@ class SendHandler:
except Exception as e:
logger.error(f"发送适配器命令响应时出错: {e}")
@staticmethod
def handle_at_message_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_at_message_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理艾特并发送消息命令
Args:

View File

@@ -6,7 +6,7 @@ import urllib3
import ssl
import io
from .database import BanUser, napcat_db
from .database import BanUser, db_manager
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
@@ -270,11 +270,10 @@ async def read_ban_list(
]
"""
try:
ban_list = await napcat_db.get_ban_records()
ban_list = db_manager.get_ban_records()
lifted_list: List[BanUser] = []
logger.info("已经读取禁言列表")
# 复制列表以避免迭代中修改原列表问题
for ban_record in list(ban_list):
for ban_record in ban_list:
if ban_record.user_id == 0:
fetched_group_info = await get_group_info(websocket, ban_record.group_id)
if fetched_group_info is None:
@@ -302,12 +301,12 @@ async def read_ban_list(
ban_list.remove(ban_record)
else:
ban_record.lift_time = lift_ban_time
await napcat_db.update_ban_record(ban_list)
db_manager.update_ban_record(ban_list)
return ban_list, lifted_list
except Exception as e:
logger.error(f"读取禁言列表失败: {e}")
return [], []
async def save_ban_record(list: List[BanUser]):
return await napcat_db.update_ban_record(list)
def save_ban_record(list: List[BanUser]):
return db_manager.update_ban_record(list)

110
test_deepcopy_fix.py Normal file
View File

@@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""
测试 ChatStream 的 deepcopy 功能
验证 asyncio.Task 序列化问题是否已解决
"""
import asyncio
import sys
import os
import copy
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import UserInfo, GroupInfo
async def test_chat_stream_deepcopy():
    """Verify that a ChatStream instance can be deep-copied.

    Exercises the fix for the asyncio.Task serialization problem: builds a
    ChatStream, deep-copies it, then checks that the copy is a distinct
    object, that its basic attributes match, and that any
    ``processing_task`` on the stream context was cleared (an asyncio.Task
    cannot be pickled/copied).

    Returns:
        bool: True when every check passed, False on any exception.
    """
    print("[TEST] 开始测试 ChatStream deepcopy 功能...")
    try:
        # Build the user / group info the stream is keyed on.
        user_info = UserInfo(
            platform="test_platform",
            user_id="test_user_123",
            user_nickname="测试用户",
            user_cardname="测试卡片名"
        )
        group_info = GroupInfo(
            platform="test_platform",
            group_id="test_group_456",
            group_name="测试群组"
        )
        # Create the ChatStream under test.
        print("📝 创建 ChatStream 实例...")
        stream_id = "test_stream_789"
        platform = "test_platform"
        chat_stream = ChatStream(
            stream_id=stream_id,
            platform=platform,
            user_info=user_info,
            group_info=group_info
        )
        print(f"[SUCCESS] ChatStream 创建成功: {chat_stream.stream_id}")
        # Yield control briefly so any background async task gets created
        # before we attempt to copy the stream.
        await asyncio.sleep(0.1)
        # The actual operation under test.
        print("[INFO] 尝试进行 deepcopy...")
        copied_stream = copy.deepcopy(chat_stream)
        print("[SUCCESS] deepcopy 成功!")
        # Report the copied object's attributes.
        print("\n[CHECK] 验证复制后的对象属性:")
        print(f" - stream_id: {copied_stream.stream_id}")
        print(f" - platform: {copied_stream.platform}")
        print(f" - user_info: {copied_stream.user_info.user_nickname}")
        print(f" - group_info: {copied_stream.group_info.group_name}")
        # processing_task must not survive the copy (it held an asyncio.Task).
        if hasattr(copied_stream.stream_context, 'processing_task'):
            print(f" - processing_task: {copied_stream.stream_context.processing_task}")
            if copied_stream.stream_context.processing_task is None:
                print(" [SUCCESS] processing_task 已被正确设置为 None")
            else:
                print(" [WARNING] processing_task 不为 None")
        else:
            print(" [SUCCESS] stream_context 没有 processing_task 属性")
        # Identity check: use `is not` instead of comparing id() values.
        if chat_stream is not copied_stream:
            print("[SUCCESS] 原始对象和复制对象是不同的实例")
        else:
            print("[ERROR] 原始对象和复制对象是同一个实例")
        # Plain attributes must have been copied verbatim.
        if (chat_stream.stream_id == copied_stream.stream_id and
                chat_stream.platform == copied_stream.platform):
            print("[SUCCESS] 基本属性正确复制")
        else:
            print("[ERROR] 基本属性复制失败")
        print("\n[COMPLETE] 测试完成deepcopy 功能修复成功!")
        return True
    except Exception as e:
        # Broad catch is deliberate: this is a standalone diagnostic script
        # that should always report, never crash.
        print(f"[ERROR] 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Drive the async test and map its boolean result onto an exit code.
    passed = asyncio.run(test_chat_stream_deepcopy())
    if passed:
        print("\n[SUCCESS] 所有测试通过!")
        sys.exit(0)
    print("\n[ERROR] 测试失败!")
    sys.exit(1)

109
test_simple_deepcopy.py Normal file
View File

@@ -0,0 +1,109 @@
#!/usr/bin/env python3
"""
简单的 ChatStream deepcopy 测试
"""
import asyncio
import sys
import os
import copy
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import UserInfo, GroupInfo
async def test_deepcopy():
    """Simple smoke test for deep-copying a ChatStream.

    Creates a ChatStream with test user/group info, deep-copies it, and
    verifies identity separation, attribute equality, and that the stream
    context's ``processing_task`` (if present) was reset to None.

    Returns:
        bool: True when every check passed, False on any exception.
    """
    print("开始测试 ChatStream deepcopy 功能...")
    try:
        # Build the user / group info the stream is keyed on.
        user_info = UserInfo(
            platform="test_platform",
            user_id="test_user_123",
            user_nickname="测试用户",
            user_cardname="测试卡片名"
        )
        group_info = GroupInfo(
            platform="test_platform",
            group_id="test_group_456",
            group_name="测试群组"
        )
        # Create the ChatStream under test.
        print("创建 ChatStream 实例...")
        stream_id = "test_stream_789"
        platform = "test_platform"
        chat_stream = ChatStream(
            stream_id=stream_id,
            platform=platform,
            user_info=user_info,
            group_info=group_info
        )
        print(f"ChatStream 创建成功: {chat_stream.stream_id}")
        # Yield control briefly so any background async task gets created
        # before we attempt to copy the stream.
        await asyncio.sleep(0.1)
        # The actual operation under test.
        print("尝试进行 deepcopy...")
        copied_stream = copy.deepcopy(chat_stream)
        print("deepcopy 成功!")
        # Report the copied object's attributes.
        print("\n验证复制后的对象属性:")
        print(f" - stream_id: {copied_stream.stream_id}")
        print(f" - platform: {copied_stream.platform}")
        print(f" - user_info: {copied_stream.user_info.user_nickname}")
        print(f" - group_info: {copied_stream.group_info.group_name}")
        # processing_task must not survive the copy (it held an asyncio.Task).
        if hasattr(copied_stream.stream_context, 'processing_task'):
            print(f" - processing_task: {copied_stream.stream_context.processing_task}")
            if copied_stream.stream_context.processing_task is None:
                print(" SUCCESS: processing_task 已被正确设置为 None")
            else:
                print(" WARNING: processing_task 不为 None")
        else:
            print(" SUCCESS: stream_context 没有 processing_task 属性")
        # Identity check: use `is not` instead of comparing id() values.
        if chat_stream is not copied_stream:
            print("SUCCESS: 原始对象和复制对象是不同的实例")
        else:
            print("ERROR: 原始对象和复制对象是同一个实例")
        # Plain attributes must have been copied verbatim.
        if (chat_stream.stream_id == copied_stream.stream_id and
                chat_stream.platform == copied_stream.platform):
            print("SUCCESS: 基本属性正确复制")
        else:
            print("ERROR: 基本属性复制失败")
        print("\n测试完成deepcopy 功能修复成功!")
        return True
    except Exception as e:
        # Broad catch is deliberate: this is a standalone diagnostic script
        # that should always report, never crash.
        print(f"ERROR: 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Drive the async test and map its boolean result onto an exit code.
    passed = asyncio.run(test_deepcopy())
    if passed:
        print("\n所有测试通过!")
        sys.exit(0)
    print("\n测试失败!")
    sys.exit(1)