refactor(database): 将同步数据库操作迁移为异步操作

将整个项目的数据库操作从同步模式迁移为异步模式,主要涉及以下修改:

- 将 `with get_db_session()` 改为 `async with get_db_session()`
- 将同步的 SQLAlchemy 查询方法改为异步执行
- 更新相关的方法签名,添加 async/await 关键字
- 修复由于异步化导致的并发问题和性能问题

这些修改提高了数据库操作的并发性能,避免了阻塞主线程,提升了系统的整体响应能力。涉及修改的模块包括表情包管理、反提示注入统计、用户封禁管理、记忆系统、消息存储等多个核心组件。

BREAKING CHANGE: 所有涉及数据库操作的方法现在都需要使用异步调用,同步调用将不再工作
This commit is contained in:
Windpicker-owo
2025-09-28 15:42:30 +08:00
parent ff24bd8148
commit 08ef960947
35 changed files with 1180 additions and 1053 deletions

13
bot.py
View File

@@ -185,12 +185,12 @@ class MaiBotMain(BaseMain):
check_eula() check_eula()
logger.info("检查EULA和隐私条款完成") logger.info("检查EULA和隐私条款完成")
def initialize_database(self): async def initialize_database(self):
"""初始化数据库""" """初始化数据库"""
logger.info("正在初始化数据库连接...") logger.info("正在初始化数据库连接...")
try: try:
initialize_sql_database(global_config.database) await initialize_sql_database(global_config.database)
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库") logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库")
except Exception as e: except Exception as e:
logger.error(f"数据库连接初始化失败: {e}") logger.error(f"数据库连接初始化失败: {e}")
@@ -211,11 +211,11 @@ class MaiBotMain(BaseMain):
self.main_system = MainSystem() self.main_system = MainSystem()
return self.main_system return self.main_system
def run(self): async def run(self):
"""运行主程序""" """运行主程序"""
self.setup_timezone() self.setup_timezone()
self.check_and_confirm_eula() self.check_and_confirm_eula()
self.initialize_database() await self.initialize_database()
return self.create_main_system() return self.create_main_system()
@@ -225,14 +225,14 @@ if __name__ == "__main__":
try: try:
# 创建MaiBotMain实例并获取MainSystem # 创建MaiBotMain实例并获取MainSystem
maibot = MaiBotMain() maibot = MaiBotMain()
main_system = maibot.run()
# 创建事件循环 # 创建事件循环
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
# 异步初始化数据库表结构 # 异步初始化数据库表结构
main_system = loop.run_until_complete(maibot.run())
loop.run_until_complete(maibot.initialize_database_async()) loop.run_until_complete(maibot.initialize_database_async())
# 执行初始化和任务调度 # 执行初始化和任务调度
loop.run_until_complete(main_system.initialize()) loop.run_until_complete(main_system.initialize())
@@ -269,3 +269,4 @@ if __name__ == "__main__":
# 在程序退出前暂停,让你有机会看到输出 # 在程序退出前暂停,让你有机会看到输出
# input("按 Enter 键退出...") # <--- 添加这行 # input("按 Enter 键退出...") # <--- 添加这行
sys.exit(exit_code) # <--- 使用记录的退出码 sys.exit(exit_code) # <--- 使用记录的退出码

View File

@@ -8,6 +8,8 @@
import datetime import datetime
from typing import Dict, Any from typing import Dict, Any
from sqlalchemy import select
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
from src.config.config import global_config from src.config.config import global_config
@@ -27,9 +29,11 @@ class AntiInjectionStatistics:
async def get_or_create_stats(): async def get_or_create_stats():
"""获取或创建统计记录""" """获取或创建统计记录"""
try: try:
with get_db_session() as session: async with get_db_session() as session:
# 获取最新的统计记录,如果没有则创建 # 获取最新的统计记录,如果没有则创建
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() stats = (await session.execute(
select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())
)).scalars().first()
if not stats: if not stats:
stats = AntiInjectionStats() stats = AntiInjectionStats()
session.add(stats) session.add(stats)
@@ -44,8 +48,10 @@ class AntiInjectionStatistics:
async def update_stats(**kwargs): async def update_stats(**kwargs):
"""更新统计数据""" """更新统计数据"""
try: try:
with get_db_session() as session: async with get_db_session() as session:
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() stats = (await session.execute(
select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())
)).scalars().first()
if not stats: if not stats:
stats = AntiInjectionStats() stats = AntiInjectionStats()
session.add(stats) session.add(stats)
@@ -138,9 +144,9 @@ class AntiInjectionStatistics:
async def reset_stats(): async def reset_stats():
"""重置统计信息""" """重置统计信息"""
try: try:
with get_db_session() as session: async with get_db_session() as session:
# 删除现有统计记录 # 删除现有统计记录
session.query(AntiInjectionStats).delete() await session.execute(select(AntiInjectionStats).delete())
await session.commit() await session.commit()
logger.info("统计信息已重置") logger.info("统计信息已重置")
except Exception as e: except Exception as e:

View File

@@ -8,6 +8,8 @@
import datetime import datetime
from typing import Optional, Tuple from typing import Optional, Tuple
from sqlalchemy import select
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import BanUser, get_db_session from src.common.database.sqlalchemy_models import BanUser, get_db_session
from ..types import DetectionResult from ..types import DetectionResult
@@ -37,8 +39,9 @@ class UserBanManager:
如果用户被封禁则返回拒绝结果否则返回None 如果用户被封禁则返回拒绝结果否则返回None
""" """
try: try:
with get_db_session() as session: async with get_db_session() as session:
ban_record = session.query(BanUser).filter_by(user_id=user_id, platform=platform).first() result = await session.execute(select(BanUser).filter_by(user_id=user_id, platform=platform))
ban_record = result.scalar_one_or_none()
if ban_record: if ban_record:
# 只有违规次数达到阈值时才算被封禁 # 只有违规次数达到阈值时才算被封禁
@@ -70,9 +73,10 @@ class UserBanManager:
detection_result: 检测结果 detection_result: 检测结果
""" """
try: try:
with get_db_session() as session: async with get_db_session() as session:
# 查找或创建违规记录 # 查找或创建违规记录
ban_record = session.query(BanUser).filter_by(user_id=user_id, platform=platform).first() result = await session.execute(select(BanUser).filter_by(user_id=user_id, platform=platform))
ban_record = result.scalar_one_or_none()
if ban_record: if ban_record:
ban_record.violation_num += 1 ban_record.violation_num += 1

View File

@@ -149,7 +149,7 @@ class MaiEmoji:
# --- 数据库操作 --- # --- 数据库操作 ---
try: try:
# 准备数据库记录 for emoji collection # 准备数据库记录 for emoji collection
with get_db_session() as session: async with get_db_session() as session:
emotion_str = ",".join(self.emotion) if self.emotion else "" emotion_str = ",".join(self.emotion) if self.emotion else ""
emoji = Emoji( emoji = Emoji(
@@ -167,7 +167,7 @@ class MaiEmoji:
last_used_time=self.last_used_time, last_used_time=self.last_used_time,
) )
session.add(emoji) session.add(emoji)
session.commit() await session.commit()
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -203,17 +203,18 @@ class MaiEmoji:
# 2. 删除数据库记录 # 2. 删除数据库记录
try: try:
with get_db_session() as session: async with get_db_session() as session:
will_delete_emoji = session.execute( result = await session.execute(
select(Emoji).where(Emoji.emoji_hash == self.hash) select(Emoji).where(Emoji.emoji_hash == self.hash)
).scalar_one_or_none() )
will_delete_emoji = result.scalar_one_or_none()
if will_delete_emoji is None: if will_delete_emoji is None:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted result = 0 # Indicate no DB record was deleted
else: else:
session.delete(will_delete_emoji) await session.delete(will_delete_emoji)
result = 1 # Successfully deleted one record result = 1 # Successfully deleted one record
session.commit() await session.commit()
except Exception as e: except Exception as e:
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}") logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
result = 0 result = 0
@@ -424,17 +425,19 @@ class EmojiManager:
# if not self._initialized: # if not self._initialized:
# raise RuntimeError("EmojiManager not initialized") # raise RuntimeError("EmojiManager not initialized")
def record_usage(self, emoji_hash: str) -> None: async def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数""" """记录表情使用次数"""
try: try:
with get_db_session() as session: async with get_db_session() as session:
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none() stmt = select(Emoji).where(Emoji.emoji_hash == emoji_hash)
result = await session.execute(stmt)
emoji_update = result.scalar_one_or_none()
if emoji_update is None: if emoji_update is None:
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
else: else:
emoji_update.usage_count += 1 emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time emoji_update.last_used_time = time.time() # Update last used time
session.commit() await session.commit()
except Exception as e: except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}") logger.error(f"记录表情使用失败: {str(e)}")
@@ -521,7 +524,7 @@ class EmojiManager:
# 7. 获取选中的表情包并更新使用记录 # 7. 获取选中的表情包并更新使用记录
selected_emoji = candidate_emojis[selected_index] selected_emoji = candidate_emojis[selected_index]
self.record_usage(selected_emoji.hash) await self.record_usage(selected_emoji.hash)
_time_end = time.time() _time_end = time.time()
logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s") logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s")
@@ -657,10 +660,11 @@ class EmojiManager:
async def get_all_emoji_from_db(self) -> None: async def get_all_emoji_from_db(self) -> None:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects""" """获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try: try:
with get_db_session() as session: async with get_db_session() as session:
logger.debug("[数据库] 开始加载所有表情包记录 ...") logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_instances = session.execute(select(Emoji)).scalars().all() result = await session.execute(select(Emoji))
emoji_instances = result.scalars().all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances) emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
# 更新内存中的列表和数量 # 更新内存中的列表和数量
@@ -686,14 +690,16 @@ class EmojiManager:
list[MaiEmoji]: 表情包对象列表 list[MaiEmoji]: 表情包对象列表
""" """
try: try:
with get_db_session() as session: async with get_db_session() as session:
if emoji_hash: if emoji_hash:
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all() result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
query = result.scalars().all()
else: else:
logger.warning( logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。" "[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
) )
query = session.execute(select(Emoji)).scalars().all() result = await session.execute(select(Emoji))
query = result.scalars().all()
emoji_instances = query emoji_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_instances) emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
@@ -770,10 +776,10 @@ class EmojiManager:
# 如果内存中没有,从数据库查找 # 如果内存中没有,从数据库查找
try: try:
with get_db_session() as session: async with get_db_session() as session:
emoji_record = session.execute( stmt = select(Emoji).where(Emoji.emoji_hash == emoji_hash)
select(Emoji).where(Emoji.emoji_hash == emoji_hash) result = await session.execute(stmt)
).scalar_one_or_none() emoji_record = result.scalar_one_or_none()
if emoji_record and emoji_record.description: if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...") logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
return emoji_record.description return emoji_record.description
@@ -939,12 +945,13 @@ class EmojiManager:
# 2. 检查数据库中是否已存在该表情包的描述,实现复用 # 2. 检查数据库中是否已存在该表情包的描述,实现复用
existing_description = None existing_description = None
try: try:
with get_db_session() as session: async with get_db_session() as session:
existing_image = ( stmt = select(Images).where(
session.query(Images) Images.emoji_hash == image_hash,
.filter((Images.emoji_hash == image_hash) & (Images.type == "emoji")) Images.type == "emoji"
.one_or_none()
) )
result = await session.execute(stmt)
existing_image = result.scalar_one_or_none()
if existing_image and existing_image.description: if existing_image and existing_image.description:
existing_description = existing_image.description existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")

View File

@@ -198,7 +198,7 @@ class RecencyEnergyCalculator(EnergyCalculator):
class RelationshipEnergyCalculator(EnergyCalculator): class RelationshipEnergyCalculator(EnergyCalculator):
"""关系能量计算器""" """关系能量计算器"""
def calculate(self, context: Dict[str, Any]) -> float: async def calculate(self, context: Dict[str, Any]) -> float:
"""基于关系计算能量""" """基于关系计算能量"""
user_id = context.get("user_id") user_id = context.get("user_id")
if not user_id: if not user_id:
@@ -208,7 +208,7 @@ class RelationshipEnergyCalculator(EnergyCalculator):
try: try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id) relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id)
logger.debug(f"使用插件内部系统计算关系分: {relationship_score:.3f}") logger.debug(f"使用插件内部系统计算关系分: {relationship_score:.3f}")
return max(0.0, min(1.0, relationship_score)) return max(0.0, min(1.0, relationship_score))
@@ -273,7 +273,7 @@ class EnergyManager:
except Exception as e: except Exception as e:
logger.warning(f"加载AFC阈值失败使用默认值: {e}") logger.warning(f"加载AFC阈值失败使用默认值: {e}")
def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float: async def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
"""计算聊天流的focus_energy""" """计算聊天流的focus_energy"""
start_time = time.time() start_time = time.time()
@@ -303,7 +303,16 @@ class EnergyManager:
for calculator in self.calculators: for calculator in self.calculators:
try: try:
# 支持同步和异步计算器
if callable(calculator.calculate):
import inspect
if inspect.iscoroutinefunction(calculator.calculate):
score = await calculator.calculate(context)
else:
score = calculator.calculate(context) score = calculator.calculate(context)
else:
score = calculator.calculate(context)
weight = calculator.get_weight() weight = calculator.get_weight()
component_scores[calculator.__class__.__name__] = score component_scores[calculator.__class__.__name__] = score

View File

@@ -8,6 +8,7 @@ import traceback
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any
from datetime import datetime from datetime import datetime
import numpy as np import numpy as np
from sqlalchemy import select
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -610,14 +611,13 @@ class BotInterestManager:
from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_database_api import get_db_session
import orjson import orjson
with get_db_session() as session: async with get_db_session() as session:
# 查询最新的兴趣标签配置 # 查询最新的兴趣标签配置
db_interests = ( db_interests = (await session.execute(
session.query(DBBotPersonalityInterests) select(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == personality_id) .where(DBBotPersonalityInterests.personality_id == personality_id)
.order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc()) .order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
.first() )).scalars().first()
)
if db_interests: if db_interests:
logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}") logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}")
@@ -700,13 +700,12 @@ class BotInterestManager:
# 序列化为JSON # 序列化为JSON
json_data = orjson.dumps(tags_data) json_data = orjson.dumps(tags_data)
with get_db_session() as session: async with get_db_session() as session:
# 检查是否已存在相同personality_id的记录 # 检查是否已存在相同personality_id的记录
existing_record = ( existing_record = (await session.execute(
session.query(DBBotPersonalityInterests) select(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id) .where(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first() )).scalars().first()
)
if existing_record: if existing_record:
# 更新现有记录 # 更新现有记录
@@ -731,19 +730,17 @@ class BotInterestManager:
last_updated=interests.last_updated, last_updated=interests.last_updated,
) )
session.add(new_record) session.add(new_record)
session.commit() await session.commit()
logger.info(f"✅ 成功创建兴趣标签配置,版本: {interests.version}") logger.info(f"✅ 成功创建兴趣标签配置,版本: {interests.version}")
logger.info("✅ 兴趣标签已成功保存到数据库") logger.info("✅ 兴趣标签已成功保存到数据库")
# 验证保存是否成功 # 验证保存是否成功
with get_db_session() as session: async with get_db_session() as session:
saved_record = ( saved_record = (await session.execute(
session.query(DBBotPersonalityInterests) select(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id) .where(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first() )).scalars().first()
)
session.commit()
if saved_record: if saved_record:
logger.info(f"✅ 验证成功数据库中存在personality_id为 {interests.personality_id} 的记录") logger.info(f"✅ 验证成功数据库中存在personality_id为 {interests.personality_id} 的记录")
logger.info(f" 版本: {saved_record.version}") logger.info(f" 版本: {saved_record.version}")

View File

@@ -882,7 +882,8 @@ class EntorhinalCortex:
# 获取数据库中所有节点和内存中所有节点 # 获取数据库中所有节点和内存中所有节点
async with get_db_session() as session: async with get_db_session() as session:
db_nodes = {node.concept: node for node in (await session.execute(select(GraphNodes))).scalars()} result = await session.execute(select(GraphNodes))
db_nodes = {node.concept: node for node in result.scalars()}
memory_nodes = list(self.memory_graph.G.nodes(data=True)) memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 批量准备节点数据 # 批量准备节点数据
@@ -978,7 +979,8 @@ class EntorhinalCortex:
await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete))) await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
# 处理边的信息 # 处理边的信息
db_edges = list((await session.execute(select(GraphEdges))).scalars()) result = await session.execute(select(GraphEdges))
db_edges = list(result.scalars())
memory_edges = list(self.memory_graph.G.edges(data=True)) memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典 # 创建边的哈希值字典
@@ -1157,7 +1159,8 @@ class EntorhinalCortex:
# 从数据库加载所有节点 # 从数据库加载所有节点
async with get_db_session() as session: async with get_db_session() as session:
nodes = list((await session.execute(select(GraphNodes))).scalars()) result = await session.execute(select(GraphNodes))
nodes = list(result.scalars())
for node in nodes: for node in nodes:
concept = node.concept concept = node.concept
try: try:
@@ -1192,7 +1195,8 @@ class EntorhinalCortex:
continue continue
# 从数据库加载所有边 # 从数据库加载所有边
edges = list((await session.execute(select(GraphEdges))).scalars()) result = await session.execute(select(GraphEdges))
edges = list(result.scalars())
for edge in edges: for edge in edges:
source = edge.source source = edge.source
target = edge.target target = edge.target

View File

@@ -184,6 +184,11 @@ class AsyncMemoryQueue:
from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.memory_system.Hippocampus import hippocampus_manager
if hippocampus_manager._initialized: if hippocampus_manager._initialized:
# 确保海马体对象已正确初始化
if not hippocampus_manager._hippocampus.parahippocampal_gyrus:
logger.warning("海马体对象未完全初始化,进行同步初始化")
hippocampus_manager._hippocampus.initialize()
await hippocampus_manager.build_memory() await hippocampus_manager.build_memory()
return True return True
return False return False

View File

@@ -108,7 +108,7 @@ class InstantMemory:
@staticmethod @staticmethod
async def store_memory(memory_item: MemoryItem): async def store_memory(memory_item: MemoryItem):
with get_db_session() as session: async with get_db_session() as session:
memory = Memory( memory = Memory(
memory_id=memory_item.memory_id, memory_id=memory_item.memory_id,
chat_id=memory_item.chat_id, chat_id=memory_item.chat_id,
@@ -161,20 +161,21 @@ class InstantMemory:
logger.info(f"start_time: {start_time}, end_time: {end_time}") logger.info(f"start_time: {start_time}, end_time: {end_time}")
# 检索包含关键词的记忆 # 检索包含关键词的记忆
memories_set = set() memories_set = set()
with get_db_session() as session: async with get_db_session() as session:
if start_time and end_time: if start_time and end_time:
start_ts = start_time.timestamp() start_ts = start_time.timestamp()
end_ts = end_time.timestamp() end_ts = end_time.timestamp()
query = session.execute( query = (await session.execute(
select(Memory).where( select(Memory).where(
(Memory.chat_id == self.chat_id) (Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts) & (Memory.create_time >= start_ts)
& (Memory.create_time < end_ts) & (Memory.create_time < end_ts)
) )
).scalars() )).scalars()
else: else:
query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars() query = result = await session.execute(select(Memory).where(Memory.chat_id == self.chat_id))
result.scalars()
for mem in query: for mem in query:
# 对每条记忆 # 对每条记忆
mem_keywords_str = mem.keywords or "[]" mem_keywords_str = mem.keywords or "[]"

View File

@@ -4,7 +4,7 @@
""" """
from .message_manager import MessageManager, message_manager from .message_manager import MessageManager, message_manager
from .context_manager import StreamContextManager, context_manager from .context_manager import SingleStreamContextManager
from .distribution_manager import ( from .distribution_manager import (
DistributionManager, DistributionManager,
DistributionPriority, DistributionPriority,
@@ -16,8 +16,7 @@ from .distribution_manager import (
__all__ = [ __all__ = [
"MessageManager", "MessageManager",
"message_manager", "message_manager",
"StreamContextManager", "SingleStreamContextManager",
"context_manager",
"DistributionManager", "DistributionManager",
"DistributionPriority", "DistributionPriority",
"DistributionTask", "DistributionTask",

View File

@@ -1,12 +1,12 @@
""" """
重构后的聊天上下文管理器 重构后的聊天上下文管理器
提供统一、稳定的聊天上下文管理功能 提供统一、稳定的聊天上下文管理功能
每个 context_manager 实例只管理一个 stream 的上下文
""" """
import asyncio import asyncio
import time import time
from typing import Dict, List, Optional, Any, Union, Tuple from typing import Dict, List, Optional, Any
from abc import ABC, abstractmethod
from src.common.data_models.message_manager_data_model import StreamContext from src.common.data_models.message_manager_data_model import StreamContext
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -17,241 +17,112 @@ from .distribution_manager import distribution_manager
logger = get_logger("context_manager") logger = get_logger("context_manager")
class StreamContextManager:
"""流上下文管理器 - 统一管理所有聊天流上下文"""
def __init__(self, max_context_size: Optional[int] = None, context_ttl: Optional[int] = None): class SingleStreamContextManager:
# 上下文存储 """单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
self.stream_contexts: Dict[str, Any] = {}
self.context_metadata: Dict[str, Dict[str, Any]] = {}
# 统计信息 def __init__(self, stream_id: str, context: StreamContext, max_context_size: Optional[int] = None):
self.stats: Dict[str, Union[int, float, str, Dict]] = { self.stream_id = stream_id
"total_messages": 0, self.context = context
"total_streams": 0,
"active_streams": 0,
"inactive_streams": 0,
"last_activity": time.time(),
"creation_time": time.time(),
}
# 配置参数 # 配置参数
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100) self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
self.context_ttl = context_ttl or getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时 self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
self.cleanup_interval = getattr(global_config.chat, "context_cleanup_interval", 3600) # 1小时
self.auto_cleanup = getattr(global_config.chat, "auto_cleanup_contexts", True)
self.enable_validation = getattr(global_config.chat, "enable_context_validation", True)
# 清理任务 # 元数据
self.cleanup_task: Optional[Any] = None self.created_time = time.time()
self.is_running = False self.last_access_time = time.time()
self.access_count = 0
self.total_messages = 0
logger.info(f"上下文管理器初始化完成 (最大上下文: {self.max_context_size}, TTL: {self.context_ttl}s)") logger.debug(f"单流上下文管理器初始化: {stream_id}")
def add_stream_context(self, stream_id: str, context: Any, metadata: Optional[Dict[str, Any]] = None) -> bool: def get_context(self) -> StreamContext:
"""添加流上下文 """获取流上下文"""
self._update_access_stats()
return self.context
Args: def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
stream_id: 流ID
context: 上下文对象
metadata: 上下文元数据
Returns:
bool: 是否成功添加
"""
if stream_id in self.stream_contexts:
logger.warning(f"流上下文已存在: {stream_id}")
return False
# 添加上下文
self.stream_contexts[stream_id] = context
# 初始化元数据
self.context_metadata[stream_id] = {
"created_time": time.time(),
"last_access_time": time.time(),
"access_count": 0,
"last_validation_time": 0.0,
"custom_metadata": metadata or {},
}
# 更新统计
self.stats["total_streams"] += 1
self.stats["active_streams"] += 1
self.stats["last_activity"] = time.time()
logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})")
return True
def remove_stream_context(self, stream_id: str) -> bool:
"""移除流上下文
Args:
stream_id: 流ID
Returns:
bool: 是否成功移除
"""
if stream_id in self.stream_contexts:
context = self.stream_contexts[stream_id]
metadata = self.context_metadata.get(stream_id, {})
del self.stream_contexts[stream_id]
if stream_id in self.context_metadata:
del self.context_metadata[stream_id]
self.stats["active_streams"] = max(0, self.stats["active_streams"] - 1)
self.stats["inactive_streams"] += 1
self.stats["last_activity"] = time.time()
logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})")
return True
return False
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[StreamContext]:
"""获取流上下文
Args:
stream_id: 流ID
update_access: 是否更新访问统计
Returns:
Optional[Any]: 上下文对象
"""
context = self.stream_contexts.get(stream_id)
if context and update_access:
# 更新访问统计
if stream_id in self.context_metadata:
metadata = self.context_metadata[stream_id]
metadata["last_access_time"] = time.time()
metadata["access_count"] = metadata.get("access_count", 0) + 1
return context
def get_context_metadata(self, stream_id: str) -> Optional[Dict[str, Any]]:
"""获取上下文元数据
Args:
stream_id: 流ID
Returns:
Optional[Dict[str, Any]]: 元数据
"""
return self.context_metadata.get(stream_id)
def update_context_metadata(self, stream_id: str, updates: Dict[str, Any]) -> bool:
"""更新上下文元数据
Args:
stream_id: 流ID
updates: 更新的元数据
Returns:
bool: 是否成功更新
"""
if stream_id not in self.context_metadata:
return False
self.context_metadata[stream_id].update(updates)
return True
def add_message_to_context(self, stream_id: str, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
"""添加消息到上下文 """添加消息到上下文
Args: Args:
stream_id: 流ID
message: 消息对象 message: 消息对象
skip_energy_update: 是否跳过能量更新 skip_energy_update: 是否跳过能量更新
Returns: Returns:
bool: 是否成功添加 bool: 是否成功添加
""" """
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
return False
try: try:
# 添加消息到上下文 # 添加消息到上下文
context.add_message(message) self.context.add_message(message)
# 计算消息兴趣度 # 计算消息兴趣度
interest_value = self._calculate_message_interest(message) interest_value = self._calculate_message_interest(message)
message.interest_value = interest_value message.interest_value = interest_value
# 更新统计 # 更新统计
self.stats["total_messages"] += 1 self.total_messages += 1
self.stats["last_activity"] = time.time() self.last_access_time = time.time()
# 更新能量和分发 # 更新能量和分发
if not skip_energy_update: if not skip_energy_update:
self._update_stream_energy(stream_id) self._update_stream_energy()
distribution_manager.add_stream_message(stream_id, 1) distribution_manager.add_stream_message(self.stream_id, 1)
logger.debug(f"添加消息到上下文: {stream_id} (兴趣度: {interest_value:.3f})") logger.debug(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
return True return True
except Exception as e: except Exception as e:
logger.error(f"添加消息到上下文失败 {stream_id}: {e}", exc_info=True) logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False return False
def update_message_in_context(self, stream_id: str, message_id: str, updates: Dict[str, Any]) -> bool: def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
"""更新上下文中的消息 """更新上下文中的消息
Args: Args:
stream_id: 流ID
message_id: 消息ID message_id: 消息ID
updates: 更新的属性 updates: 更新的属性
Returns: Returns:
bool: 是否成功更新 bool: 是否成功更新
""" """
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
return False
try: try:
# 更新消息信息 # 更新消息信息
context.update_message_info(message_id, **updates) self.context.update_message_info(message_id, **updates)
# 如果更新了兴趣度,重新计算能量 # 如果更新了兴趣度,重新计算能量
if "interest_value" in updates: if "interest_value" in updates:
self._update_stream_energy(stream_id) self._update_stream_energy()
logger.debug(f"更新上下文消息: {stream_id}/{message_id}") logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True) logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
return False return False
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]: def get_messages(self, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
"""获取上下文消息 """获取上下文消息
Args: Args:
stream_id: 流ID
limit: 消息数量限制 limit: 消息数量限制
include_unread: 是否包含未读消息 include_unread: 是否包含未读消息
Returns: Returns:
List[Any]: 消息列表 List[DatabaseMessages]: 消息列表
""" """
context = self.get_stream_context(stream_id)
if not context:
return []
try: try:
messages = [] messages = []
if include_unread: if include_unread:
messages.extend(context.get_unread_messages()) messages.extend(self.context.get_unread_messages())
if limit: if limit:
messages.extend(context.get_history_messages(limit=limit)) messages.extend(self.context.get_history_messages(limit=limit))
else: else:
messages.extend(context.get_history_messages()) messages.extend(self.context.get_history_messages())
# 按时间排序 # 按时间排序
messages.sort(key=lambda msg: getattr(msg, 'time', 0)) messages.sort(key=lambda msg: getattr(msg, "time", 0))
# 应用限制 # 应用限制
if limit and len(messages) > limit: if limit and len(messages) > limit:
@@ -260,103 +131,124 @@ class StreamContextManager:
return messages return messages
except Exception as e: except Exception as e:
logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True) logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True)
return []
def get_unread_messages(self, stream_id: str) -> List[DatabaseMessages]:
"""获取未读消息
Args:
stream_id: 流ID
Returns:
List[Any]: 未读消息列表
"""
context = self.get_stream_context(stream_id)
if not context:
return [] return []
def get_unread_messages(self) -> List[DatabaseMessages]:
"""获取未读消息"""
try: try:
return context.get_unread_messages() return self.context.get_unread_messages()
except Exception as e: except Exception as e:
logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True) logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True)
return [] return []
def mark_messages_as_read(self, stream_id: str, message_ids: List[str]) -> bool: def mark_messages_as_read(self, message_ids: List[str]) -> bool:
"""标记消息为已读 """标记消息为已读"""
Args:
stream_id: 流ID
message_ids: 消息ID列表
Returns:
bool: 是否成功标记
"""
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
return False
try: try:
if not hasattr(context, 'mark_message_as_read'): if not hasattr(self.context, "mark_message_as_read"):
logger.error(f"上下文对象缺少 mark_message_as_read 方法: {stream_id}") logger.error(f"上下文对象缺少 mark_message_as_read 方法: {self.stream_id}")
return False return False
marked_count = 0 marked_count = 0
for message_id in message_ids: for message_id in message_ids:
try: try:
context.mark_message_as_read(message_id) self.context.mark_message_as_read(message_id)
marked_count += 1 marked_count += 1
except Exception as e: except Exception as e:
logger.warning(f"标记消息已读失败 {message_id}: {e}") logger.warning(f"标记消息已读失败 {message_id}: {e}")
logger.debug(f"标记消息为已读: {stream_id} ({marked_count}/{len(message_ids)}条)") logger.debug(f"标记消息为已读: {self.stream_id} ({marked_count}/{len(message_ids)}条)")
return marked_count > 0 return marked_count > 0
except Exception as e: except Exception as e:
logger.error(f"标记消息已读失败 {stream_id}: {e}", exc_info=True) logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True)
return False
def clear_context(self, stream_id: str) -> bool:
"""清空上下文
Args:
stream_id: 流ID
Returns:
bool: 是否成功清空
"""
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
return False return False
def clear_context(self) -> bool:
"""清空上下文"""
try: try:
# 清空消息 # 清空消息
if hasattr(context, 'unread_messages'): if hasattr(self.context, "unread_messages"):
context.unread_messages.clear() self.context.unread_messages.clear()
if hasattr(context, 'history_messages'): if hasattr(self.context, "history_messages"):
context.history_messages.clear() self.context.history_messages.clear()
# 重置状态 # 重置状态
reset_attrs = ['interruption_count', 'afc_threshold_adjustment', 'last_check_time'] reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
for attr in reset_attrs: for attr in reset_attrs:
if hasattr(context, attr): if hasattr(self.context, attr):
if attr in ['interruption_count', 'afc_threshold_adjustment']: if attr in ["interruption_count", "afc_threshold_adjustment"]:
setattr(context, attr, 0) setattr(self.context, attr, 0)
else: else:
setattr(context, attr, time.time()) setattr(self.context, attr, time.time())
# 重新计算能量 # 重新计算能量
self._update_stream_energy(stream_id) self._update_stream_energy()
logger.info(f"清空上下文: {stream_id}") logger.info(f"清空单流上下文: {self.stream_id}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True) logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False return False
def get_statistics(self) -> Dict[str, Any]:
"""获取流统计信息"""
try:
current_time = time.time()
uptime = current_time - self.created_time
unread_messages = getattr(self.context, "unread_messages", [])
history_messages = getattr(self.context, "history_messages", [])
return {
"stream_id": self.stream_id,
"context_type": type(self.context).__name__,
"total_messages": len(history_messages) + len(unread_messages),
"unread_messages": len(unread_messages),
"history_messages": len(history_messages),
"is_active": getattr(self.context, "is_active", True),
"last_check_time": getattr(self.context, "last_check_time", current_time),
"interruption_count": getattr(self.context, "interruption_count", 0),
"afc_threshold_adjustment": getattr(self.context, "afc_threshold_adjustment", 0.0),
"created_time": self.created_time,
"last_access_time": self.last_access_time,
"access_count": self.access_count,
"uptime_seconds": uptime,
"idle_seconds": current_time - self.last_access_time,
}
except Exception as e:
logger.error(f"获取单流统计失败 {self.stream_id}: {e}", exc_info=True)
return {}
def validate_integrity(self) -> bool:
"""验证上下文完整性"""
try:
# 检查基本属性
required_attrs = ["stream_id", "unread_messages", "history_messages"]
for attr in required_attrs:
if not hasattr(self.context, attr):
logger.warning(f"上下文缺少必要属性: {attr}")
return False
# 检查消息ID唯一性
all_messages = getattr(self.context, "unread_messages", []) + getattr(self.context, "history_messages", [])
message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
if len(message_ids) != len(set(message_ids)):
logger.warning(f"上下文中存在重复消息ID: {self.stream_id}")
return False
return True
except Exception as e:
logger.error(f"验证单流上下文完整性失败 {self.stream_id}: {e}")
return False
def _update_access_stats(self):
"""更新访问统计"""
self.last_access_time = time.time()
self.access_count += 1
def _calculate_message_interest(self, message: DatabaseMessages) -> float: def _calculate_message_interest(self, message: DatabaseMessages) -> float:
"""计算消息兴趣度""" """计算消息兴趣度"""
try: try:
@@ -373,8 +265,7 @@ class StreamContextManager:
interest_score = loop.run_until_complete( interest_score = loop.run_until_complete(
chatter_interest_scoring_system._calculate_single_message_score( chatter_interest_scoring_system._calculate_single_message_score(
message=message, message=message, bot_nickname=global_config.bot.nickname
bot_nickname=global_config.bot.nickname
) )
) )
interest_value = interest_score.total_score interest_value = interest_score.total_score
@@ -391,12 +282,12 @@ class StreamContextManager:
logger.error(f"计算消息兴趣度失败: {e}") logger.error(f"计算消息兴趣度失败: {e}")
return 0.5 return 0.5
def _update_stream_energy(self, stream_id: str): async def _update_stream_energy(self):
"""更新流能量""" """更新流能量"""
try: try:
# 获取所有消息 # 获取所有消息
all_messages = self.get_context_messages(stream_id, self.max_context_size) all_messages = self.get_messages(self.max_context_size)
unread_messages = self.get_unread_messages(stream_id) unread_messages = self.get_unread_messages()
combined_messages = all_messages + unread_messages combined_messages = all_messages + unread_messages
# 获取用户ID # 获取用户ID
@@ -406,248 +297,12 @@ class StreamContextManager:
user_id = last_message.user_info.user_id user_id = last_message.user_info.user_id
# 计算能量 # 计算能量
energy = energy_manager.calculate_focus_energy( energy = await energy_manager.calculate_focus_energy(
stream_id=stream_id, stream_id=self.stream_id, messages=combined_messages, user_id=user_id
messages=combined_messages,
user_id=user_id
) )
# 更新分发管理器 # 更新分发管理器
distribution_manager.update_stream_energy(stream_id, energy) distribution_manager.update_stream_energy(self.stream_id, energy)
except Exception as e: except Exception as e:
logger.error(f"更新流能量失败 {stream_id}: {e}") logger.error(f"更新流能量失败 {self.stream_id}: {e}")
def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]:
"""获取流统计信息
Args:
stream_id: 流ID
Returns:
Optional[Dict[str, Any]]: 统计信息
"""
context = self.get_stream_context(stream_id, update_access=False)
if not context:
return None
try:
metadata = self.context_metadata.get(stream_id, {})
current_time = time.time()
created_time = metadata.get("created_time", current_time)
last_access_time = metadata.get("last_access_time", current_time)
access_count = metadata.get("access_count", 0)
unread_messages = getattr(context, "unread_messages", [])
history_messages = getattr(context, "history_messages", [])
return {
"stream_id": stream_id,
"context_type": type(context).__name__,
"total_messages": len(history_messages) + len(unread_messages),
"unread_messages": len(unread_messages),
"history_messages": len(history_messages),
"is_active": getattr(context, "is_active", True),
"last_check_time": getattr(context, "last_check_time", current_time),
"interruption_count": getattr(context, "interruption_count", 0),
"afc_threshold_adjustment": getattr(context, "afc_threshold_adjustment", 0.0),
"created_time": created_time,
"last_access_time": last_access_time,
"access_count": access_count,
"uptime_seconds": current_time - created_time,
"idle_seconds": current_time - last_access_time,
}
except Exception as e:
logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True)
return None
def get_manager_statistics(self) -> Dict[str, Any]:
"""获取管理器统计信息
Returns:
Dict[str, Any]: 管理器统计信息
"""
current_time = time.time()
uptime = current_time - self.stats.get("creation_time", current_time)
return {
**self.stats,
"uptime_hours": uptime / 3600,
"stream_count": len(self.stream_contexts),
"metadata_count": len(self.context_metadata),
"auto_cleanup_enabled": self.auto_cleanup,
"cleanup_interval": self.cleanup_interval,
}
def cleanup_inactive_contexts(self, max_inactive_hours: int = 24) -> int:
"""清理不活跃的上下文
Args:
max_inactive_hours: 最大不活跃小时数
Returns:
int: 清理的上下文数量
"""
current_time = time.time()
max_inactive_seconds = max_inactive_hours * 3600
inactive_streams = []
for stream_id, context in self.stream_contexts.items():
try:
# 获取最后活动时间
metadata = self.context_metadata.get(stream_id, {})
last_activity = metadata.get("last_access_time", metadata.get("created_time", 0))
context_last_activity = getattr(context, "last_check_time", 0)
actual_last_activity = max(last_activity, context_last_activity)
# 检查是否不活跃
unread_count = len(getattr(context, "unread_messages", []))
history_count = len(getattr(context, "history_messages", []))
total_messages = unread_count + history_count
if (current_time - actual_last_activity > max_inactive_seconds and
total_messages == 0):
inactive_streams.append(stream_id)
except Exception as e:
logger.warning(f"检查上下文活跃状态失败 {stream_id}: {e}")
continue
# 清理不活跃上下文
cleaned_count = 0
for stream_id in inactive_streams:
if self.remove_stream_context(stream_id):
cleaned_count += 1
if cleaned_count > 0:
logger.info(f"清理了 {cleaned_count} 个不活跃上下文")
return cleaned_count
def validate_context_integrity(self, stream_id: str) -> bool:
"""验证上下文完整性
Args:
stream_id: 流ID
Returns:
bool: 是否完整
"""
context = self.get_stream_context(stream_id)
if not context:
return False
try:
# 检查基本属性
required_attrs = ["stream_id", "unread_messages", "history_messages"]
for attr in required_attrs:
if not hasattr(context, attr):
logger.warning(f"上下文缺少必要属性: {attr}")
return False
# 检查消息ID唯一性
all_messages = getattr(context, "unread_messages", []) + getattr(context, "history_messages", [])
message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
if len(message_ids) != len(set(message_ids)):
logger.warning(f"上下文中存在重复消息ID: {stream_id}")
return False
return True
except Exception as e:
logger.error(f"验证上下文完整性失败 {stream_id}: {e}")
return False
async def start(self) -> None:
"""启动上下文管理器"""
if self.is_running:
logger.warning("上下文管理器已经在运行")
return
await self.start_auto_cleanup()
logger.info("上下文管理器已启动")
async def stop(self) -> None:
"""停止上下文管理器"""
if not self.is_running:
return
await self.stop_auto_cleanup()
logger.info("上下文管理器已停止")
async def start_auto_cleanup(self, interval: Optional[float] = None) -> None:
"""启动自动清理
Args:
interval: 清理间隔(秒)
"""
if not self.auto_cleanup:
logger.info("自动清理已禁用")
return
if self.is_running:
logger.warning("自动清理已在运行")
return
self.is_running = True
cleanup_interval = interval or self.cleanup_interval
logger.info(f"启动自动清理(间隔: {cleanup_interval}s")
import asyncio
self.cleanup_task = asyncio.create_task(self._cleanup_loop(cleanup_interval))
async def stop_auto_cleanup(self) -> None:
"""停止自动清理"""
self.is_running = False
if self.cleanup_task and not self.cleanup_task.done():
self.cleanup_task.cancel()
try:
await self.cleanup_task
except Exception:
pass
logger.info("自动清理已停止")
async def _cleanup_loop(self, interval: float) -> None:
"""清理循环
Args:
interval: 清理间隔
"""
while self.is_running:
try:
await asyncio.sleep(interval)
self.cleanup_inactive_contexts()
self._cleanup_expired_contexts()
logger.debug("自动清理完成")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"清理循环出错: {e}", exc_info=True)
await asyncio.sleep(interval)
def _cleanup_expired_contexts(self) -> None:
"""清理过期上下文"""
current_time = time.time()
expired_contexts = []
for stream_id, metadata in self.context_metadata.items():
created_time = metadata.get("created_time", current_time)
if current_time - created_time > self.context_ttl:
expired_contexts.append(stream_id)
for stream_id in expired_contexts:
self.remove_stream_context(stream_id)
if expired_contexts:
logger.info(f"清理了 {len(expired_contexts)} 个过期上下文")
def get_active_streams(self) -> List[str]:
"""获取活跃流列表
Returns:
List[str]: 活跃流ID列表
"""
return list(self.stream_contexts.keys())
# 全局上下文管理器实例
context_manager = StreamContextManager()

View File

@@ -14,11 +14,10 @@ from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats
from src.chat.chatter_manager import ChatterManager from src.chat.chatter_manager import ChatterManager
from src.chat.planner_actions.action_manager import ChatterActionManager from src.chat.planner_actions.action_manager import ChatterActionManager
from src.plugin_system.base.component_types import ChatMode
from .sleep_manager.sleep_manager import SleepManager from .sleep_manager.sleep_manager import SleepManager
from .sleep_manager.wakeup_manager import WakeUpManager from .sleep_manager.wakeup_manager import WakeUpManager
from src.config.config import global_config from src.config.config import global_config
from .context_manager import context_manager from src.plugin_system.apis.chat_api import get_chat_manager
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext from src.common.data_models.message_manager_data_model import StreamContext
@@ -45,8 +44,7 @@ class MessageManager:
self.sleep_manager = SleepManager() self.sleep_manager = SleepManager()
self.wakeup_manager = WakeUpManager(self.sleep_manager) self.wakeup_manager = WakeUpManager(self.sleep_manager)
# 初始化上下文管理器 # 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context_manager
self.context_manager = context_manager
async def start(self): async def start(self):
"""启动消息管理器""" """启动消息管理器"""
@@ -57,7 +55,7 @@ class MessageManager:
self.is_running = True self.is_running = True
self.manager_task = asyncio.create_task(self._manager_loop()) self.manager_task = asyncio.create_task(self._manager_loop())
await self.wakeup_manager.start() await self.wakeup_manager.start()
await self.context_manager.start() # await self.context_manager.start() # 已删除,需要重构
logger.info("消息管理器已启动") logger.info("消息管理器已启动")
async def stop(self): async def stop(self):
@@ -73,29 +71,32 @@ class MessageManager:
self.manager_task.cancel() self.manager_task.cancel()
await self.wakeup_manager.stop() await self.wakeup_manager.stop()
await self.context_manager.stop() # await self.context_manager.stop() # 已删除,需要重构
logger.info("消息管理器已停止") logger.info("消息管理器已停止")
def add_message(self, stream_id: str, message: DatabaseMessages): def add_message(self, stream_id: str, message: DatabaseMessages):
"""添加消息到指定聊天流""" """添加消息到指定聊天流"""
# 检查流上下文是否存在,不存在则创建 try:
context = self.context_manager.get_stream_context(stream_id) # 通过 ChatManager 获取 ChatStream
if not context: chat_manager = get_chat_manager()
# 创建新的流上下文 chat_stream = chat_manager.get_stream(stream_id)
from src.common.data_models.message_manager_data_model import StreamContext
context = StreamContext(stream_id=stream_id)
# 将创建的上下文添加到管理器
self.context_manager.add_stream_context(stream_id, context)
# 使用 context_manager 添加消息 if not chat_stream:
success = self.context_manager.add_message_to_context(stream_id, message) logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return
# 使用 ChatStream 的 context_manager 添加消息
success = chat_stream.context_manager.add_message(message)
if success: if success:
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}") logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
else: else:
logger.warning(f"添加消息到聊天流 {stream_id} 失败") logger.warning(f"添加消息到聊天流 {stream_id} 失败")
except Exception as e:
logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}")
def update_message( def update_message(
self, self,
stream_id: str, stream_id: str,
@@ -105,17 +106,60 @@ class MessageManager:
should_reply: bool = None, should_reply: bool = None,
): ):
"""更新消息信息""" """更新消息信息"""
# 使用 context_manager 更新消息信息 try:
context = self.context_manager.get_stream_context(stream_id) # 通过 ChatManager 获取 ChatStream
if context: chat_manager = get_chat_manager()
context.update_message_info(message_id, interest_value, actions, should_reply) chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在")
return
# 构建更新字典
updates = {}
if interest_value is not None:
updates["interest_value"] = interest_value
if actions is not None:
updates["actions"] = actions
if should_reply is not None:
updates["should_reply"] = should_reply
# 使用 ChatStream 的 context_manager 更新消息
if updates:
success = chat_stream.context_manager.update_message(message_id, updates)
if success:
logger.debug(f"更新消息 {message_id} 成功")
else:
logger.warning(f"更新消息 {message_id} 失败")
except Exception as e:
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
def add_action(self, stream_id: str, message_id: str, action: str): def add_action(self, stream_id: str, message_id: str, action: str):
"""添加动作到消息""" """添加动作到消息"""
# 使用 context_manager 添加动作到消息 try:
context = self.context_manager.get_stream_context(stream_id) # 通过 ChatManager 获取 ChatStream
if context: chat_manager = get_chat_manager()
context.add_action_to_message(message_id, action) chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
return
# 使用 ChatStream 的 context_manager 添加动作
# 注意:这里需要根据实际的 API 调整
# 假设我们可以通过 update_message 来添加动作
success = chat_stream.context_manager.update_message(
message_id, {"actions": [action]}
)
if success:
logger.debug(f"为消息 {message_id} 添加动作 {action} 成功")
else:
logger.warning(f"为消息 {message_id} 添加动作 {action} 失败")
except Exception as e:
logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}")
async def _manager_loop(self): async def _manager_loop(self):
"""管理器主循环 - 独立聊天流分发周期版本""" """管理器主循环 - 独立聊天流分发周期版本"""
@@ -145,18 +189,25 @@ class MessageManager:
active_streams = 0 active_streams = 0
total_unread = 0 total_unread = 0
# 使用 context_manager 获取活跃的流 # 通过 ChatManager 获取所有活跃的流
active_stream_ids = self.context_manager.get_active_streams() try:
chat_manager = get_chat_manager()
active_stream_ids = list(chat_manager.streams.keys())
for stream_id in active_stream_ids: for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id) chat_stream = chat_manager.get_stream(stream_id)
if not context: if not chat_stream:
continue
# 检查流是否活跃
context = chat_stream.stream_context
if not context.is_active:
continue continue
active_streams += 1 active_streams += 1
# 检查是否有未读消息 # 检查是否有未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id) unread_messages = chat_stream.context_manager.get_unread_messages()
if unread_messages: if unread_messages:
total_unread += len(unread_messages) total_unread += len(unread_messages)
@@ -168,15 +219,23 @@ class MessageManager:
self.stats.active_streams = active_streams self.stats.active_streams = active_streams
self.stats.total_unread_messages = total_unread self.stats.total_unread_messages = total_unread
except Exception as e:
logger.error(f"检查所有聊天流时发生错误: {e}")
async def _process_stream_messages(self, stream_id: str): async def _process_stream_messages(self, stream_id: str):
"""处理指定聊天流的消息""" """处理指定聊天流的消息"""
context = self.context_manager.get_stream_context(stream_id) try:
if not context: # 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"处理消息失败: 聊天流 {stream_id} 不存在")
return return
try: context = chat_stream.stream_context
# 获取未读消息 # 获取未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id) unread_messages = chat_stream.context_manager.get_unread_messages()
if not unread_messages: if not unread_messages:
return return
@@ -250,8 +309,15 @@ class MessageManager:
def deactivate_stream(self, stream_id: str): def deactivate_stream(self, stream_id: str):
"""停用聊天流""" """停用聊天流"""
context = self.context_manager.get_stream_context(stream_id) try:
if context: # 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.stream_context
context.is_active = False context.is_active = False
# 取消处理任务 # 取消处理任务
@@ -260,28 +326,51 @@ class MessageManager:
logger.info(f"停用聊天流: {stream_id}") logger.info(f"停用聊天流: {stream_id}")
except Exception as e:
logger.error(f"停用聊天流 {stream_id} 时发生错误: {e}")
def activate_stream(self, stream_id: str): def activate_stream(self, stream_id: str):
"""激活聊天流""" """激活聊天流"""
context = self.context_manager.get_stream_context(stream_id) try:
if context: # 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.stream_context
context.is_active = True context.is_active = True
logger.info(f"激活聊天流: {stream_id}") logger.info(f"激活聊天流: {stream_id}")
except Exception as e:
logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}")
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]: def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
"""获取聊天流统计""" """获取聊天流统计"""
context = self.context_manager.get_stream_context(stream_id) try:
if not context: # 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
return None return None
context = chat_stream.stream_context
unread_count = len(chat_stream.context_manager.get_unread_messages())
return StreamStats( return StreamStats(
stream_id=stream_id, stream_id=stream_id,
is_active=context.is_active, is_active=context.is_active,
unread_count=len(self.context_manager.get_unread_messages(stream_id)), unread_count=unread_count,
history_count=len(context.history_messages), history_count=len(context.history_messages),
last_check_time=context.last_check_time, last_check_time=context.last_check_time,
has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()), has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
) )
except Exception as e:
logger.error(f"获取聊天流 {stream_id} 统计时发生错误: {e}")
return None
def get_manager_stats(self) -> Dict[str, Any]: def get_manager_stats(self) -> Dict[str, Any]:
"""获取管理器统计""" """获取管理器统计"""
return { return {
@@ -295,9 +384,36 @@ class MessageManager:
def cleanup_inactive_streams(self, max_inactive_hours: int = 24): def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
"""清理不活跃的聊天流""" """清理不活跃的聊天流"""
# 使用 context_manager 的自动清理功能 try:
self.context_manager.cleanup_inactive_contexts(max_inactive_hours * 3600) # 通过 ChatManager 清理不活跃的流
logger.info("已启动不活跃聊天流清理") chat_manager = get_chat_manager()
current_time = time.time()
max_inactive_seconds = max_inactive_hours * 3600
inactive_streams = []
for stream_id, chat_stream in chat_manager.streams.items():
# 检查最后活跃时间
if current_time - chat_stream.last_active_time > max_inactive_seconds:
inactive_streams.append(stream_id)
# 清理不活跃的流
for stream_id in inactive_streams:
try:
# 清理流的内容
chat_stream.context_manager.clear_context()
# 从 ChatManager 中移除
del chat_manager.streams[stream_id]
logger.info(f"清理不活跃聊天流: {stream_id}")
except Exception as e:
logger.error(f"清理聊天流 {stream_id} 失败: {e}")
if inactive_streams:
logger.info(f"已清理 {len(inactive_streams)} 个不活跃聊天流")
else:
logger.debug("没有需要清理的不活跃聊天流")
except Exception as e:
logger.error(f"清理不活跃聊天流时发生错误: {e}")
async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str): async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str):
"""检查并处理消息打断""" """检查并处理消息打断"""
@@ -376,9 +492,10 @@ class MessageManager:
min_delay = float("inf") min_delay = float("inf")
# 找到最近需要检查的流 # 找到最近需要检查的流
active_stream_ids = self.context_manager.get_active_streams() try:
for stream_id in active_stream_ids: chat_manager = get_chat_manager()
context = self.context_manager.get_stream_context(stream_id) for _stream_id, chat_stream in chat_manager.streams.items():
context = chat_stream.stream_context
if not context or not context.is_active: if not context or not context.is_active:
continue continue
@@ -396,16 +513,20 @@ class MessageManager:
# 确保最小延迟 # 确保最小延迟
return max(0.1, min(min_delay, self.check_interval)) return max(0.1, min(min_delay, self.check_interval))
except Exception as e:
logger.error(f"计算下次检查延迟时发生错误: {e}")
return self.check_interval
async def _check_streams_with_individual_intervals(self): async def _check_streams_with_individual_intervals(self):
"""检查所有达到检查时间的聊天流""" """检查所有达到检查时间的聊天流"""
current_time = time.time() current_time = time.time()
processed_streams = 0 processed_streams = 0
# 使用 context_manager 获取活跃的流 # 通过 ChatManager 获取活跃的流
active_stream_ids = self.context_manager.get_active_streams() try:
chat_manager = get_chat_manager()
for stream_id in active_stream_ids: for stream_id, chat_stream in chat_manager.streams.items():
context = self.context_manager.get_stream_context(stream_id) context = chat_stream.stream_context
if not context or not context.is_active: if not context or not context.is_active:
continue continue
@@ -424,17 +545,14 @@ class MessageManager:
context.next_check_time = current_time + context.distribution_interval context.next_check_time = current_time + context.distribution_interval
# 检查未读消息 # 检查未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id) unread_messages = chat_stream.context_manager.get_unread_messages()
if unread_messages: if unread_messages:
processed_streams += 1 processed_streams += 1
self.stats.total_unread_messages = len(unread_messages) self.stats.total_unread_messages = len(unread_messages)
# 如果没有处理任务,创建一个 # 如果没有处理任务,创建一个
if not context.processing_task or context.processing_task.done(): if not context.processing_task or context.processing_task.done():
from src.plugin_system.apis.chat_api import get_chat_manager focus_energy = chat_stream.focus_energy
chat_stream = get_chat_manager().get_stream(context.stream_id)
focus_energy = chat_stream.focus_energy if chat_stream else 0.5
# 根据优先级记录日志 # 根据优先级记录日志
if focus_energy >= 0.7: if focus_energy >= 0.7:
@@ -453,39 +571,45 @@ class MessageManager:
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
except Exception as e:
logger.error(f"检查独立分发周期的聊天流时发生错误: {e}")
# 更新活跃流计数 # 更新活跃流计数
active_count = len(self.context_manager.get_active_streams()) try:
chat_manager = get_chat_manager()
active_count = len([s for s in chat_manager.streams.values() if s.stream_context.is_active])
self.stats.active_streams = active_count self.stats.active_streams = active_count
if processed_streams > 0: if processed_streams > 0:
logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}") logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}")
except Exception as e:
logger.error(f"更新活跃流计数时发生错误: {e}")
async def _check_all_streams_with_priority(self): async def _check_all_streams_with_priority(self):
"""按优先级检查所有聊天流高focus_energy的流优先处理""" """按优先级检查所有聊天流高focus_energy的流优先处理"""
if not self.context_manager.get_active_streams(): try:
chat_manager = get_chat_manager()
if not chat_manager.streams:
return return
# 获取活跃的聊天流并按focus_energy排序 # 获取活跃的聊天流并按focus_energy排序
active_streams = [] active_streams = []
active_stream_ids = self.context_manager.get_active_streams() for stream_id, chat_stream in chat_manager.streams.items():
context = chat_stream.stream_context
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
if not context or not context.is_active: if not context or not context.is_active:
continue continue
# 获取focus_energy,如果不存在则使用默认值 # 获取focus_energy
from src.plugin_system.apis.chat_api import get_chat_manager
chat_stream = get_chat_manager().get_stream(context.stream_id)
focus_energy = 0.5
if chat_stream:
focus_energy = chat_stream.focus_energy focus_energy = chat_stream.focus_energy
# 计算流优先级分数 # 计算流优先级分数
priority_score = self._calculate_stream_priority(context, focus_energy) priority_score = self._calculate_stream_priority(context, focus_energy)
active_streams.append((priority_score, stream_id, context)) active_streams.append((priority_score, stream_id, context))
except Exception as e:
logger.error(f"获取活跃流列表时发生错误: {e}")
return
# 按优先级降序排序 # 按优先级降序排序
active_streams.sort(reverse=True, key=lambda x: x[0]) active_streams.sort(reverse=True, key=lambda x: x[0])
@@ -497,7 +621,12 @@ class MessageManager:
active_stream_count += 1 active_stream_count += 1
# 检查是否有未读消息 # 检查是否有未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id) try:
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
continue
unread_messages = chat_stream.context_manager.get_unread_messages()
if unread_messages: if unread_messages:
total_unread += len(unread_messages) total_unread += len(unread_messages)
@@ -512,6 +641,9 @@ class MessageManager:
f"优先级: {priority_score:.3f} | " f"优先级: {priority_score:.3f} | "
f"未读消息: {len(unread_messages)}" f"未读消息: {len(unread_messages)}"
) )
except Exception as e:
logger.error(f"处理流 {stream_id} 的未读消息时发生错误: {e}")
continue
# 更新统计 # 更新统计
self.stats.active_streams = active_stream_count self.stats.active_streams = active_stream_count
@@ -536,22 +668,33 @@ class MessageManager:
def _clear_all_unread_messages(self, stream_id: str): def _clear_all_unread_messages(self, stream_id: str):
"""清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读""" """清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读"""
unread_messages = self.context_manager.get_unread_messages(stream_id) try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"清除消息失败: 聊天流 {stream_id} 不存在")
return
# 获取未读消息
unread_messages = chat_stream.context_manager.get_unread_messages()
if not unread_messages: if not unread_messages:
return return
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息") logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
# 将所有未读消息标记为已读 # 将所有未读消息标记为已读
context = self.context_manager.get_stream_context(stream_id) message_ids = [msg.message_id for msg in unread_messages]
if context: success = chat_stream.context_manager.mark_messages_as_read(message_ids)
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
try: if success:
context.mark_message_as_read(msg.message_id) self.stats.total_processed_messages += len(unread_messages)
self.stats.total_processed_messages += 1 logger.debug(f"强制清除 {len(unread_messages)} 条消息,标记为已读")
logger.debug(f"强制清除消息 {msg.message_id},标记为已读") else:
logger.error("标记未读消息为已读失败")
except Exception as e: except Exception as e:
logger.error(f"清除消息 {msg.message_id} 时出错: {e}") logger.error(f"清除未读消息时发生错误: {e}")
# 创建全局消息管理器实例 # 创建全局消息管理器实例

View File

@@ -49,10 +49,18 @@ class ChatStream:
from src.common.data_models.message_manager_data_model import StreamContext from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatType, ChatMode from src.plugin_system.base.component_types import ChatType, ChatMode
# 创建StreamContext
self.stream_context: StreamContext = StreamContext( self.stream_context: StreamContext = StreamContext(
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
) )
# 创建单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
stream_id=stream_id, context=self.stream_context
)
# 基础参数 # 基础参数
self.base_interest_energy = 0.5 # 默认基础兴趣度 self.base_interest_energy = 0.5 # 默认基础兴趣度
self._focus_energy = 0.5 # 内部存储的focus_energy值 self._focus_energy = 0.5 # 内部存储的focus_energy值
@@ -61,6 +69,37 @@ class ChatStream:
# 自动加载历史消息 # 自动加载历史消息
self._load_history_messages() self._load_history_messages()
def __deepcopy__(self, memo):
"""自定义深拷贝方法,避免复制不可序列化的 asyncio.Task 对象"""
import copy
# 创建新的实例
new_stream = ChatStream(
stream_id=self.stream_id,
platform=self.platform,
user_info=copy.deepcopy(self.user_info, memo),
group_info=copy.deepcopy(self.group_info, memo),
)
# 复制基本属性
new_stream.create_time = self.create_time
new_stream.last_active_time = self.last_active_time
new_stream.sleep_pressure = self.sleep_pressure
new_stream.saved = self.saved
new_stream.base_interest_energy = self.base_interest_energy
new_stream._focus_energy = self._focus_energy
new_stream.no_reply_consecutive = self.no_reply_consecutive
# 复制 stream_context但跳过 processing_task
new_stream.stream_context = copy.deepcopy(self.stream_context, memo)
if hasattr(new_stream.stream_context, 'processing_task'):
new_stream.stream_context.processing_task = None
# 复制 context_manager
new_stream.context_manager = copy.deepcopy(self.context_manager, memo)
return new_stream
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""转换为字典格式""" """转换为字典格式"""
return { return {
@@ -74,10 +113,10 @@ class ChatStream:
"focus_energy": self.focus_energy, "focus_energy": self.focus_energy,
# 基础兴趣度 # 基础兴趣度
"base_interest_energy": self.base_interest_energy, "base_interest_energy": self.base_interest_energy,
# 新增stream_context信息 # stream_context基本信息
"stream_context_chat_type": self.stream_context.chat_type.value, "stream_context_chat_type": self.stream_context.chat_type.value,
"stream_context_chat_mode": self.stream_context.chat_mode.value, "stream_context_chat_mode": self.stream_context.chat_mode.value,
# 新增interruption_count信息 # 统计信息
"interruption_count": self.stream_context.interruption_count, "interruption_count": self.stream_context.interruption_count,
} }
@@ -109,6 +148,14 @@ class ChatStream:
if "interruption_count" in data: if "interruption_count" in data:
instance.stream_context.interruption_count = data["interruption_count"] instance.stream_context.interruption_count = data["interruption_count"]
# 确保 context_manager 已初始化
if not hasattr(instance, "context_manager"):
from src.chat.message_manager.context_manager import SingleStreamContextManager
instance.context_manager = SingleStreamContextManager(
stream_id=instance.stream_id, context=instance.stream_context
)
return instance return instance
def update_active_time(self): def update_active_time(self):
@@ -195,12 +242,14 @@ class ChatStream:
self.stream_context.priority_info = getattr(message, "priority_info", None) self.stream_context.priority_info = getattr(message, "priority_info", None)
# 调试日志:记录数据转移情况 # 调试日志:记录数据转移情况
logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, " logger.debug(
f"消息数据转移完成 - message_id: {db_message.message_id}, "
f"chat_id: {db_message.chat_id}, " f"chat_id: {db_message.chat_id}, "
f"is_mentioned: {db_message.is_mentioned}, " f"is_mentioned: {db_message.is_mentioned}, "
f"is_emoji: {db_message.is_emoji}, " f"is_emoji: {db_message.is_emoji}, "
f"is_picid: {db_message.is_picid}, " f"is_picid: {db_message.is_picid}, "
f"interest_value: {db_message.interest_value}") f"interest_value: {db_message.interest_value}"
)
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]: def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
"""安全获取消息的actions字段""" """安全获取消息的actions字段"""
@@ -213,6 +262,7 @@ class ChatStream:
if isinstance(actions, str): if isinstance(actions, str):
try: try:
import json import json
actions = json.loads(actions) actions = json.loads(actions)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"无法解析actions JSON字符串: {actions}") logger.warning(f"无法解析actions JSON字符串: {actions}")
@@ -269,14 +319,17 @@ class ChatStream:
@property @property
def focus_energy(self) -> float: def focus_energy(self) -> float:
"""使用重构后的能量管理器计算focus_energy""" """获取缓存的focus_energy"""
try: if hasattr(self, "_focus_energy"):
from src.chat.energy_system import energy_manager return self._focus_energy
else:
return 0.5
# 获取所有消息 async def calculate_focus_energy(self) -> float:
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size) """异步计算focus_energy"""
unread_messages = self.stream_context.get_unread_messages() try:
all_messages = history_messages + unread_messages # 使用单流上下文管理器获取消息
all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size)
# 获取用户ID # 获取用户ID
user_id = None user_id = None
@@ -284,10 +337,10 @@ class ChatStream:
user_id = str(self.user_info.user_id) user_id = str(self.user_info.user_id)
# 使用能量管理器计算 # 使用能量管理器计算
energy = energy_manager.calculate_focus_energy( from src.chat.energy_system import energy_manager
stream_id=self.stream_id,
messages=all_messages, energy = await energy_manager.calculate_focus_energy(
user_id=user_id stream_id=self.stream_id, messages=all_messages, user_id=user_id
) )
# 更新内部存储 # 更新内部存储
@@ -299,7 +352,7 @@ class ChatStream:
except Exception as e: except Exception as e:
logger.error(f"获取focus_energy失败: {e}", exc_info=True) logger.error(f"获取focus_energy失败: {e}", exc_info=True)
# 返回缓存的值或默认值 # 返回缓存的值或默认值
if hasattr(self, '_focus_energy'): if hasattr(self, "_focus_energy"):
return self._focus_energy return self._focus_energy
else: else:
return 0.5 return 0.5
@@ -309,7 +362,7 @@ class ChatStream:
"""设置focus_energy值主要用于初始化或特殊场景""" """设置focus_energy值主要用于初始化或特殊场景"""
self._focus_energy = max(0.0, min(1.0, value)) self._focus_energy = max(0.0, min(1.0, value))
def _get_user_relationship_score(self) -> float: async def _get_user_relationship_score(self) -> float:
"""获取用户关系分""" """获取用户关系分"""
# 使用插件内部的兴趣度评分系统 # 使用插件内部的兴趣度评分系统
try: try:
@@ -317,7 +370,7 @@ class ChatStream:
if self.user_info and hasattr(self.user_info, "user_id"): if self.user_info and hasattr(self.user_info, "user_id"):
user_id = str(self.user_info.user_id) user_id = str(self.user_info.user_id)
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id) relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id)
logger.debug(f"ChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}") logger.debug(f"ChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
return max(0.0, min(1.0, relationship_score)) return max(0.0, min(1.0, relationship_score))
@@ -346,7 +399,8 @@ class ChatStream:
.order_by(desc(Messages.time)) .order_by(desc(Messages.time))
.limit(global_config.chat.max_context_size) .limit(global_config.chat.max_context_size)
) )
results = session.execute(stmt).scalars().all() result = session.execute(stmt)
results = result.scalars().all()
return results return results
# 在线程中执行数据库查询 # 在线程中执行数据库查询
@@ -404,7 +458,9 @@ class ChatStream:
) )
# 添加调试日志检查从数据库加载的interest_value # 添加调试日志检查从数据库加载的interest_value
logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}") logger.debug(
f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}"
)
# 标记为已读并添加到历史消息 # 标记为已读并添加到历史消息
db_message.is_read = True db_message.is_read = True
@@ -548,7 +604,11 @@ class ChatManager:
# 检查数据库中是否存在 # 检查数据库中是否存在
async def _db_find_stream_async(s_id: str): async def _db_find_stream_async(s_id: str):
async with get_db_session() as session: async with get_db_session() as session:
return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalars().first() return (
(await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)))
.scalars()
.first()
)
model_instance = await _db_find_stream_async(stream_id) model_instance = await _db_find_stream_async(stream_id)
@@ -603,6 +663,15 @@ class ChatManager:
stream.set_context(self.last_messages[stream_id]) stream.set_context(self.last_messages[stream_id])
else: else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):
# 创建新的单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
stream.context_manager = SingleStreamContextManager(
stream_id=stream_id, context=stream.stream_context
)
# 保存到内存和数据库 # 保存到内存和数据库
self.streams[stream_id] = stream self.streams[stream_id] = stream
await self._save_stream(stream) await self._save_stream(stream)
@@ -704,7 +773,8 @@ class ChatManager:
async def _db_load_all_streams_async(): async def _db_load_all_streams_async():
loaded_streams_data = [] loaded_streams_data = []
async with get_db_session() as session: async with get_db_session() as session:
for model_instance in (await session.execute(select(ChatStreams))).scalars().all(): result = await session.execute(select(ChatStreams))
for model_instance in result.scalars().all():
user_info_data = { user_info_data = {
"platform": model_instance.user_platform, "platform": model_instance.user_platform,
"user_id": model_instance.user_id, "user_id": model_instance.user_id,
@@ -752,6 +822,13 @@ class ChatManager:
self.streams[stream.stream_id] = stream self.streams[stream.stream_id] = stream
if stream.stream_id in self.last_messages: if stream.stream_id in self.last_messages:
stream.set_context(self.last_messages[stream.stream_id]) stream.set_context(self.last_messages[stream.stream_id])
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):
from src.chat.message_manager.context_manager import SingleStreamContextManager
stream.context_manager = SingleStreamContextManager(
stream_id=stream.stream_id, context=stream.stream_context
)
except Exception as e: except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True) logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)

View File

@@ -41,7 +41,7 @@ class MessageStorage:
processed_plain_text = message.processed_plain_text processed_plain_text = message.processed_plain_text
if processed_plain_text: if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text) processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL) filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
else: else:
filtered_processed_plain_text = "" filtered_processed_plain_text = ""
@@ -129,9 +129,9 @@ class MessageStorage:
key_words=key_words, key_words=key_words,
key_words_lite=key_words_lite, key_words_lite=key_words_lite,
) )
with get_db_session() as session: async with get_db_session() as session:
session.add(new_message) session.add(new_message)
session.commit() await session.commit()
except Exception: except Exception:
logger.exception("存储消息失败") logger.exception("存储消息失败")
@@ -174,13 +174,13 @@ class MessageStorage:
# 使用上下文管理器确保session正确管理 # 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session: async with get_db_session() as session:
matched_message = session.execute( matched_message = (await session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
).scalar() )).scalar()
if matched_message: if matched_message:
session.execute( await session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id) update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
) )
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
@@ -195,7 +195,7 @@ class MessageStorage:
) )
@staticmethod @staticmethod
def replace_image_descriptions(text: str) -> str: async def replace_image_descriptions(text: str) -> str:
"""将[图片:描述]替换为[picid:image_id]""" """将[图片:描述]替换为[picid:image_id]"""
# 先检查文本中是否有图片标记 # 先检查文本中是否有图片标记
pattern = r"\[图片:([^\]]+)\]" pattern = r"\[图片:([^\]]+)\]"
@@ -205,15 +205,15 @@ class MessageStorage:
logger.debug("文本中没有图片标记,直接返回原文本") logger.debug("文本中没有图片标记,直接返回原文本")
return text return text
def replace_match(match): async def replace_match(match):
description = match.group(1).strip() description = match.group(1).strip()
try: try:
from src.common.database.sqlalchemy_models import get_db_session from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session: async with get_db_session() as session:
image_record = session.execute( image_record = (await session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
).scalar() )).scalar()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0) return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
except Exception: except Exception:
return match.group(0) return match.group(0)
@@ -271,7 +271,8 @@ class MessageStorage:
) )
).limit(50) # 限制每次修复的数量,避免性能问题 ).limit(50) # 限制每次修复的数量,避免性能问题
messages_to_fix = session.execute(query).scalars().all() result = session.execute(query)
messages_to_fix = result.scalars().all()
fixed_count = 0 fixed_count = 0
for msg in messages_to_fix: for msg in messages_to_fix:

View File

@@ -824,7 +824,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
description = "[图片内容未知]" # 默认描述 description = "[图片内容未知]" # 默认描述
try: try:
with get_db_session() as session: with get_db_session() as session:
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none() result = session.execute(select(Images).where(Images.image_id == pic_id))
image = result.scalar_one_or_none()
if image and image.description: # type: ignore if image and image.description: # type: ignore
description = image.description description = image.description
except Exception: except Exception:

View File

@@ -308,7 +308,8 @@ class ImageManager:
async with get_db_session() as session: async with get_db_session() as session:
# 优先检查Images表中是否已有完整的描述 # 优先检查Images表中是否已有完整的描述
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar() existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
result.scalar()
if existing_image: if existing_image:
# 更新计数 # 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None: if hasattr(existing_image, "count") and existing_image.count is not None:
@@ -527,7 +528,8 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
async with get_db_session() as session: async with get_db_session() as session:
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar() existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
result.scalar()
if existing_image: if existing_image:
# 检查是否缺少必要字段,如果缺少则创建新记录 # 检查是否缺少必要字段,如果缺少则创建新记录
if ( if (

View File

@@ -22,13 +22,14 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import get_db_session, Videos from src.common.database.sqlalchemy_models import get_db_session, Videos
from sqlalchemy import select
logger = get_logger("utils_video") logger = get_logger("utils_video")
# Rust模块可用性检测 # Rust模块可用性检测
RUST_VIDEO_AVAILABLE = False RUST_VIDEO_AVAILABLE = False
try: try:
import rust_video import rust_video # pyright: ignore[reportMissingImports]
RUST_VIDEO_AVAILABLE = True RUST_VIDEO_AVAILABLE = True
logger.info("✅ Rust 视频处理模块加载成功") logger.info("✅ Rust 视频处理模块加载成功")
@@ -202,18 +203,20 @@ class VideoAnalyzer:
hash_obj.update(video_data) hash_obj.update(video_data)
return hash_obj.hexdigest() return hash_obj.hexdigest()
def _check_video_exists(self, video_hash: str) -> Optional[Videos]: async def _check_video_exists(self, video_hash: str) -> Optional[Videos]:
"""检查视频是否已经分析过""" """检查视频是否已经分析过"""
try: try:
with get_db_session() as session: async with get_db_session() as session:
# 明确刷新会话以确保看到其他事务的最新提交 # 明确刷新会话以确保看到其他事务的最新提交
session.expire_all() await session.expire_all()
return session.query(Videos).filter(Videos.video_hash == video_hash).first() stmt = select(Videos).where(Videos.video_hash == video_hash)
result = await session.execute(stmt)
return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.warning(f"检查视频是否存在时出错: {e}") logger.warning(f"检查视频是否存在时出错: {e}")
return None return None
def _store_video_result( async def _store_video_result(
self, video_hash: str, description: str, metadata: Optional[Dict] = None self, video_hash: str, description: str, metadata: Optional[Dict] = None
) -> Optional[Videos]: ) -> Optional[Videos]:
"""存储视频分析结果到数据库""" """存储视频分析结果到数据库"""
@@ -223,9 +226,11 @@ class VideoAnalyzer:
return None return None
try: try:
with get_db_session() as session: async with get_db_session() as session:
# 只根据video_hash查找 # 只根据video_hash查找
existing_video = session.query(Videos).filter(Videos.video_hash == video_hash).first() stmt = select(Videos).where(Videos.video_hash == video_hash)
result = await session.execute(stmt)
existing_video = result.scalar_one_or_none()
if existing_video: if existing_video:
# 如果已存在,更新描述和计数 # 如果已存在,更新描述和计数
@@ -238,8 +243,8 @@ class VideoAnalyzer:
existing_video.fps = metadata.get("fps") existing_video.fps = metadata.get("fps")
existing_video.resolution = metadata.get("resolution") existing_video.resolution = metadata.get("resolution")
existing_video.file_size = metadata.get("file_size") existing_video.file_size = metadata.get("file_size")
session.commit() await session.commit()
session.refresh(existing_video) await session.refresh(existing_video)
logger.info(f"✅ 更新已存在的视频记录hash: {video_hash[:16]}..., count: {existing_video.count}") logger.info(f"✅ 更新已存在的视频记录hash: {video_hash[:16]}..., count: {existing_video.count}")
return existing_video return existing_video
else: else:
@@ -254,8 +259,8 @@ class VideoAnalyzer:
video_record.file_size = metadata.get("file_size") video_record.file_size = metadata.get("file_size")
session.add(video_record) session.add(video_record)
session.commit() await session.commit()
session.refresh(video_record) await session.refresh(video_record)
logger.info(f"✅ 新视频分析结果已保存到数据库hash: {video_hash[:16]}...") logger.info(f"✅ 新视频分析结果已保存到数据库hash: {video_hash[:16]}...")
return video_record return video_record
except Exception as e: except Exception as e:
@@ -704,7 +709,7 @@ class VideoAnalyzer:
logger.info("✅ 等待结束,检查是否有处理结果") logger.info("✅ 等待结束,检查是否有处理结果")
# 检查是否有结果了 # 检查是否有结果了
existing_video = self._check_video_exists(video_hash) existing_video = await self._check_video_exists(video_hash)
if existing_video: if existing_video:
logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})") logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})")
return {"summary": existing_video.description} return {"summary": existing_video.description}
@@ -718,7 +723,7 @@ class VideoAnalyzer:
logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)") logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)")
# 再次检查数据库(可能在等待期间已经有结果了) # 再次检查数据库(可能在等待期间已经有结果了)
existing_video = self._check_video_exists(video_hash) existing_video = await self._check_video_exists(video_hash)
if existing_video: if existing_video:
logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})") logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})")
video_event.set() # 通知其他等待者 video_event.set() # 通知其他等待者
@@ -749,7 +754,7 @@ class VideoAnalyzer:
# 保存分析结果到数据库(仅保存成功的结果) # 保存分析结果到数据库(仅保存成功的结果)
if success and not result.startswith(""): if success and not result.startswith(""):
metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()} metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()}
self._store_video_result(video_hash=video_hash, description=result, metadata=metadata) await self._store_video_result(video_hash=video_hash, description=result, metadata=metadata)
logger.info("✅ 分析结果已保存到数据库") logger.info("✅ 分析结果已保存到数据库")
else: else:
logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试") logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试")

View File

@@ -22,9 +22,9 @@ class DatabaseProxy:
self._session = None self._session = None
@staticmethod @staticmethod
def initialize(*args, **kwargs): async def initialize(*args, **kwargs):
"""初始化数据库连接""" """初始化数据库连接"""
return initialize_database_compat() return await initialize_database_compat()
class SQLAlchemyTransaction: class SQLAlchemyTransaction:
@@ -88,7 +88,7 @@ async def initialize_sql_database(database_config):
logger.info(f" 数据库文件: {db_path}") logger.info(f" 数据库文件: {db_path}")
# 使用SQLAlchemy初始化 # 使用SQLAlchemy初始化
success = initialize_database_compat() success = await initialize_database_compat()
if success: if success:
_sql_engine = await get_engine() _sql_engine = await get_engine()
logger.info("SQLAlchemy数据库初始化成功") logger.info("SQLAlchemy数据库初始化成功")

View File

@@ -706,7 +706,8 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
raise RuntimeError("Database session not initialized") raise RuntimeError("Database session not initialized")
session = SessionLocal() session = SessionLocal()
yield session yield session
except Exception: except Exception as e:
logger.error(f"数据库会话错误: {e}")
if session: if session:
await session.rollback() await session.rollback()
raise raise

View File

@@ -101,7 +101,8 @@ def find_messages(
# 获取时间最早的 limit 条记录,已经是正序 # 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit) query = query.order_by(Messages.time.asc()).limit(limit)
try: try:
results = session.execute(query).scalars().all() results = result = session.execute(query)
result.scalars().all()
except Exception as e: except Exception as e:
logger.error(f"执行earliest查询失败: {e}") logger.error(f"执行earliest查询失败: {e}")
results = [] results = []
@@ -109,7 +110,8 @@ def find_messages(
# 获取时间最晚的 limit 条记录 # 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit) query = query.order_by(Messages.time.desc()).limit(limit)
try: try:
latest_results = session.execute(query).scalars().all() latest_results = result = session.execute(query)
result.scalars().all()
# 将结果按时间正序排列 # 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time) results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e: except Exception as e:
@@ -133,7 +135,8 @@ def find_messages(
if sort_terms: if sort_terms:
query = query.order_by(*sort_terms) query = query.order_by(*sort_terms)
try: try:
results = session.execute(query).scalars().all() results = result = session.execute(query)
result.scalars().all()
except Exception as e: except Exception as e:
logger.error(f"执行无限制查询失败: {e}") logger.error(f"执行无限制查询失败: {e}")
results = [] results = []
@@ -207,5 +210,5 @@ def count_messages(message_filter: dict[str, Any]) -> int:
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 SQLAlchemy插入操作通常是使用 session.add() 和 session.commit()。 # 注意:对于 SQLAlchemy插入操作通常是使用 await session.add() 和 await session.commit()。
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。 # 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。

View File

@@ -161,7 +161,7 @@ class LLMUsageRecorder:
session = None session = None
try: try:
# 使用 SQLAlchemy 会话创建记录 # 使用 SQLAlchemy 会话创建记录
with get_db_session() as session: async with get_db_session() as session:
usage_record = LLMUsage( usage_record = LLMUsage(
model_name=model_info.model_identifier, model_name=model_info.model_identifier,
model_assign_name=model_info.name, model_assign_name=model_info.name,
@@ -172,14 +172,14 @@ class LLMUsageRecorder:
prompt_tokens=model_usage.prompt_tokens or 0, prompt_tokens=model_usage.prompt_tokens or 0,
completion_tokens=model_usage.completion_tokens or 0, completion_tokens=model_usage.completion_tokens or 0,
total_tokens=model_usage.total_tokens or 0, total_tokens=model_usage.total_tokens or 0,
cost=total_cost or 0.0, cost=1.0,
time_cost=round(time_cost or 0.0, 3), time_cost=round(time_cost or 0.0, 3),
status="success", status="success",
timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段 timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段
) )
session.add(usage_record) session.add(usage_record)
session.commit() await session.commit()
logger.debug( logger.debug(
f"Token使用情况 - 模型: {model_usage.model_name}, " f"Token使用情况 - 模型: {model_usage.model_name}, "

View File

@@ -163,7 +163,8 @@ class PersonInfoManager:
try: try:
# 在需要时获取会话 # 在需要时获取会话
async with get_db_session() as session: async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))).scalar() record = result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))
result.scalar()
return record.person_id if record else "" return record.person_id if record else ""
except Exception as e: except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}") logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
@@ -339,7 +340,8 @@ class PersonInfoManager:
start_time = time.time() start_time = time.time()
async with get_db_session() as session: async with get_db_session() as session:
try: try:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
query_time = time.time() query_time = time.time()
if record: if record:
setattr(record, f_name, val_to_set) setattr(record, f_name, val_to_set)
@@ -401,7 +403,8 @@ class PersonInfoManager:
async def _db_has_field_async(p_id: str, f_name: str): async def _db_has_field_async(p_id: str, f_name: str):
async with get_db_session() as session: async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
return bool(record) return bool(record)
try: try:
@@ -512,10 +515,9 @@ class PersonInfoManager:
async def _db_check_name_exists_async(name_to_check): async def _db_check_name_exists_async(name_to_check):
async with get_db_session() as session: async with get_db_session() as session:
return ( result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))
(await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))).scalar() record = result.scalar()
is not None return record is not None
)
if await _db_check_name_exists_async(generated_nickname): if await _db_check_name_exists_async(generated_nickname):
is_duplicate = True is_duplicate = True
@@ -556,7 +558,8 @@ class PersonInfoManager:
async def _db_delete_async(p_id: str): async def _db_delete_async(p_id: str):
try: try:
async with get_db_session() as session: async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
if record: if record:
await session.delete(record) await session.delete(record)
await session.commit() await session.commit()
@@ -585,7 +588,9 @@ class PersonInfoManager:
async def _get_record_sync(): async def _get_record_sync():
async with get_db_session() as session: async with get_db_session() as session:
return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))).scalar() result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
record = result.scalar()
return record
try: try:
record = asyncio.run(_get_record_sync()) record = asyncio.run(_get_record_sync())
@@ -624,7 +629,9 @@ class PersonInfoManager:
async def _db_get_record_async(p_id: str): async def _db_get_record_async(p_id: str):
async with get_db_session() as session: async with get_db_session() as session:
return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
return record
record = await _db_get_record_async(person_id) record = await _db_get_record_async(person_id)
@@ -700,7 +707,8 @@ class PersonInfoManager:
"""原子性的获取或创建操作""" """原子性的获取或创建操作"""
async with get_db_session() as session: async with get_db_session() as session:
# 首先尝试获取现有记录 # 首先尝试获取现有记录
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
if record: if record:
return record, False # 记录存在,未创建 return record, False # 记录存在,未创建
@@ -715,7 +723,8 @@ class PersonInfoManager:
# 如果创建失败(可能是因为竞态条件),再次尝试获取 # 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e): if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录") logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
record = result.scalar()
if record: if record:
return record, False # 其他协程已创建,返回现有记录 return record, False # 其他协程已创建,返回现有记录
# 如果仍然失败,重新抛出异常 # 如果仍然失败,重新抛出异常

View File

@@ -122,7 +122,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情" matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
# 记录使用次数 # 记录使用次数
emoji_manager.record_usage(selected_emoji.hash) await emoji_manager.record_usage(selected_emoji.hash)
results.append((emoji_base64, selected_emoji.description, matched_emotion)) results.append((emoji_base64, selected_emoji.description, matched_emotion))
if not results and count > 0: if not results and count > 0:
@@ -180,7 +180,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
return None return None
# 记录使用次数 # 记录使用次数
emoji_manager.record_usage(selected_emoji.hash) await emoji_manager.record_usage(selected_emoji.hash)
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}") logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, emotion return emoji_base64, selected_emoji.description, emotion

View File

@@ -65,7 +65,7 @@ class AffinityChatter(BaseChatter):
""" """
try: try:
# 触发表达学习 # 触发表达学习
learner = expression_learner_manager.get_expression_learner(self.stream_id) learner = await expression_learner_manager.get_expression_learner(self.stream_id)
asyncio.create_task(learner.trigger_learning_for_chat()) asyncio.create_task(learner.trigger_learning_for_chat())
unread_messages = context.get_unread_messages() unread_messages = context.get_unread_messages()

View File

@@ -69,7 +69,7 @@ class ChatterInterestScoringSystem:
keywords = self._extract_keywords_from_database(message) keywords = self._extract_keywords_from_database(message)
interest_match_score = await self._calculate_interest_match_score(message.processed_plain_text, keywords) interest_match_score = await self._calculate_interest_match_score(message.processed_plain_text, keywords)
relationship_score = self._calculate_relationship_score(message.user_info.user_id) relationship_score = await self._calculate_relationship_score(message.user_info.user_id)
mentioned_score = self._calculate_mentioned_score(message, bot_nickname) mentioned_score = self._calculate_mentioned_score(message, bot_nickname)
total_score = ( total_score = (
@@ -189,7 +189,7 @@ class ChatterInterestScoringSystem:
unique_keywords = list(set(keywords)) unique_keywords = list(set(keywords))
return unique_keywords[:10] # 返回前10个唯一关键词 return unique_keywords[:10] # 返回前10个唯一关键词
def _calculate_relationship_score(self, user_id: str) -> float: async def _calculate_relationship_score(self, user_id: str) -> float:
"""计算关系分 - 从数据库获取关系分""" """计算关系分 - 从数据库获取关系分"""
# 优先使用内存中的关系分 # 优先使用内存中的关系分
if user_id in self.user_relationships: if user_id in self.user_relationships:
@@ -212,7 +212,7 @@ class ChatterInterestScoringSystem:
global_tracker = ChatterRelationshipTracker() global_tracker = ChatterRelationshipTracker()
if global_tracker: if global_tracker:
relationship_score = global_tracker.get_user_relationship_score(user_id) relationship_score = await global_tracker.get_user_relationship_score(user_id)
# 同时更新内存缓存 # 同时更新内存缓存
self.user_relationships[user_id] = relationship_score self.user_relationships[user_id] = relationship_score
return relationship_score return relationship_score

View File

@@ -287,7 +287,7 @@ class ChatterRelationshipTracker:
# ===== 数据库支持方法 ===== # ===== 数据库支持方法 =====
def get_user_relationship_score(self, user_id: str) -> float: async def get_user_relationship_score(self, user_id: str) -> float:
"""获取用户关系分""" """获取用户关系分"""
# 先检查缓存 # 先检查缓存
if user_id in self.user_relationship_cache: if user_id in self.user_relationship_cache:
@@ -298,7 +298,7 @@ class ChatterRelationshipTracker:
return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score) return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
# 缓存过期或不存在,从数据库获取 # 缓存过期或不存在,从数据库获取
relationship_data = self._get_user_relationship_from_db(user_id) relationship_data = await self._get_user_relationship_from_db(user_id)
if relationship_data: if relationship_data:
# 更新缓存 # 更新缓存
self.user_relationship_cache[user_id] = { self.user_relationship_cache[user_id] = {
@@ -313,37 +313,38 @@ class ChatterRelationshipTracker:
# 数据库中也没有,返回默认值 # 数据库中也没有,返回默认值
return global_config.affinity_flow.base_relationship_score return global_config.affinity_flow.base_relationship_score
def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]: async def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
"""从数据库获取用户关系数据""" """从数据库获取用户关系数据"""
try: try:
with get_db_session() as session: async with get_db_session() as session:
# 查询用户关系表 # 查询用户关系表
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
result = session.execute(stmt).scalar_one_or_none() result = await session.execute(stmt)
relationship = result.scalar_one_or_none()
if result: if relationship:
return { return {
"relationship_text": result.relationship_text or "", "relationship_text": relationship.relationship_text or "",
"relationship_score": float(result.relationship_score) "relationship_score": float(relationship.relationship_score)
if result.relationship_score is not None if relationship.relationship_score is not None
else 0.3, else 0.3,
"last_updated": result.last_updated, "last_updated": relationship.last_updated,
} }
except Exception as e: except Exception as e:
logger.error(f"从数据库获取用户关系失败: {e}") logger.error(f"从数据库获取用户关系失败: {e}")
return None return None
def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float): async def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
"""更新数据库中的用户关系""" """更新数据库中的用户关系"""
try: try:
current_time = time.time() current_time = time.time()
with get_db_session() as session: async with get_db_session() as session:
# 检查是否已存在关系记录 # 检查是否已存在关系记录
existing = session.execute( stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
select(UserRelationships).where(UserRelationships.user_id == user_id) result = await session.execute(stmt)
).scalar_one_or_none() existing = result.scalar_one_or_none()
if existing: if existing:
# 更新现有记录 # 更新现有记录
@@ -362,7 +363,7 @@ class ChatterRelationshipTracker:
) )
session.add(new_relationship) session.add(new_relationship)
session.commit() await session.commit()
logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}") logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}")
except Exception as e: except Exception as e:
@@ -399,7 +400,7 @@ class ChatterRelationshipTracker:
logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息") logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息")
# 获取当前关系数据 # 获取当前关系数据
current_relationship = self._get_user_relationship_from_db(user_id) current_relationship = await self._get_user_relationship_from_db(user_id)
current_score = ( current_score = (
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score) current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
if current_relationship if current_relationship
@@ -417,14 +418,14 @@ class ChatterRelationshipTracker:
logger.error(f"回复后关系追踪失败: {e}") logger.error(f"回复后关系追踪失败: {e}")
logger.debug("错误详情:", exc_info=True) logger.debug("错误详情:", exc_info=True)
def _get_last_tracked_time(self, user_id: str) -> float: async def _get_last_tracked_time(self, user_id: str) -> float:
"""获取上次追踪时间""" """获取上次追踪时间"""
# 先检查缓存 # 先检查缓存
if user_id in self.user_relationship_cache: if user_id in self.user_relationship_cache:
return self.user_relationship_cache[user_id].get("last_tracked", 0) return self.user_relationship_cache[user_id].get("last_tracked", 0)
# 从数据库获取 # 从数据库获取
relationship_data = self._get_user_relationship_from_db(user_id) relationship_data = await self._get_user_relationship_from_db(user_id)
if relationship_data: if relationship_data:
return relationship_data.get("last_updated", 0) return relationship_data.get("last_updated", 0)
@@ -433,7 +434,7 @@ class ChatterRelationshipTracker:
async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]: async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]:
"""获取上次bot回复该用户的消息""" """获取上次bot回复该用户的消息"""
try: try:
with get_db_session() as session: async with get_db_session() as session:
# 查询bot回复给该用户的最新消息 # 查询bot回复给该用户的最新消息
stmt = ( stmt = (
select(Messages) select(Messages)
@@ -443,10 +444,11 @@ class ChatterRelationshipTracker:
.limit(1) .limit(1)
) )
result = session.execute(stmt).scalar_one_or_none() result = await session.execute(stmt)
if result: message = result.scalar_one_or_none()
if message:
# 将SQLAlchemy模型转换为DatabaseMessages对象 # 将SQLAlchemy模型转换为DatabaseMessages对象
return self._sqlalchemy_to_database_messages(result) return self._sqlalchemy_to_database_messages(message)
except Exception as e: except Exception as e:
logger.error(f"获取上次回复消息失败: {e}") logger.error(f"获取上次回复消息失败: {e}")
@@ -456,7 +458,7 @@ class ChatterRelationshipTracker:
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]: async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]:
"""获取用户在bot回复后的反应消息""" """获取用户在bot回复后的反应消息"""
try: try:
with get_db_session() as session: async with get_db_session() as session:
# 查询用户在回复时间之后的5分钟内的消息 # 查询用户在回复时间之后的5分钟内的消息
end_time = reply_time + 5 * 60 # 5分钟 end_time = reply_time + 5 * 60 # 5分钟
@@ -468,9 +470,10 @@ class ChatterRelationshipTracker:
.order_by(Messages.time) .order_by(Messages.time)
) )
results = session.execute(stmt).scalars().all() result = await session.execute(stmt)
if results: messages = result.scalars().all()
return [self._sqlalchemy_to_database_messages(result) for result in results] if messages:
return [self._sqlalchemy_to_database_messages(message) for message in messages]
except Exception as e: except Exception as e:
logger.error(f"获取用户反应消息失败: {e}") logger.error(f"获取用户反应消息失败: {e}")
@@ -593,7 +596,7 @@ class ChatterRelationshipTracker:
quality = response_data.get("interaction_quality", "medium") quality = response_data.get("interaction_quality", "medium")
# 更新数据库 # 更新数据库
self._update_user_relationship_in_db(user_id, new_text, new_score) await self._update_user_relationship_in_db(user_id, new_text, new_score)
# 更新缓存 # 更新缓存
self.user_relationship_cache[user_id] = { self.user_relationship_cache[user_id] = {
@@ -696,7 +699,7 @@ class ChatterRelationshipTracker:
) )
# 更新数据库和缓存 # 更新数据库和缓存
self._update_user_relationship_in_db(user_id, new_text, new_score) await self._update_user_relationship_in_db(user_id, new_text, new_score)
self.user_relationship_cache[user_id] = { self.user_relationship_cache[user_id] = {
"relationship_text": new_text, "relationship_text": new_text,
"relationship_score": new_score, "relationship_score": new_score,

View File

@@ -13,6 +13,7 @@ from typing import Callable
from src.common.logger import get_logger from src.common.logger import get_logger
from src.schedule.schedule_manager import schedule_manager from src.schedule.schedule_manager import schedule_manager
from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select
from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus
from .qzone_service import QZoneService from .qzone_service import QZoneService
@@ -138,15 +139,13 @@ class SchedulerService:
:return: 如果已处理过,返回 True否则返回 False。 :return: 如果已处理过,返回 True否则返回 False。
""" """
try: try:
with get_db_session() as session: async with get_db_session() as session:
record = ( stmt = select(MaiZoneScheduleStatus).where(
session.query(MaiZoneScheduleStatus)
.filter(
MaiZoneScheduleStatus.datetime_hour == hour_str, MaiZoneScheduleStatus.datetime_hour == hour_str,
MaiZoneScheduleStatus.is_processed == True, # noqa: E712 MaiZoneScheduleStatus.is_processed == True, # noqa: E712
) )
.first() result = await session.execute(stmt)
) record = result.scalar_one_or_none()
return record is not None return record is not None
except Exception as e: except Exception as e:
logger.error(f"检查日程处理状态时发生数据库错误: {e}") logger.error(f"检查日程处理状态时发生数据库错误: {e}")
@@ -162,11 +161,11 @@ class SchedulerService:
:param content: 最终发送的说说内容或错误信息。 :param content: 最终发送的说说内容或错误信息。
""" """
try: try:
with get_db_session() as session: async with get_db_session() as session:
# 查找是否已存在该记录 # 查找是否已存在该记录
record = ( stmt = select(MaiZoneScheduleStatus).where(MaiZoneScheduleStatus.datetime_hour == hour_str)
session.query(MaiZoneScheduleStatus).filter(MaiZoneScheduleStatus.datetime_hour == hour_str).first() result = await session.execute(stmt)
) record = result.scalar_one_or_none()
if record: if record:
# 如果存在,则更新状态 # 如果存在,则更新状态
@@ -185,7 +184,7 @@ class SchedulerService:
send_success=success, send_success=success,
) )
session.add(new_record) session.add(new_record)
session.commit() await session.commit()
logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}") logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}")
except Exception as e: except Exception as e:
logger.error(f"更新日程处理状态时发生数据库错误: {e}") logger.error(f"更新日程处理状态时发生数据库错误: {e}")

View File

@@ -64,15 +64,9 @@ async def message_recv(server_connection: Server.ServerConnection):
# 处理完整消息(可能是重组后的,也可能是原本就完整的) # 处理完整消息(可能是重组后的,也可能是原本就完整的)
post_type = decoded_raw_message.get("post_type") post_type = decoded_raw_message.get("post_type")
# 兼容没有 post_type 的普通消息
if not post_type and "message_type" in decoded_raw_message:
decoded_raw_message["post_type"] = "message"
post_type = "message"
if post_type in ["meta_event", "message", "notice"]: if post_type in ["meta_event", "message", "notice"]:
await message_queue.put(decoded_raw_message) await message_queue.put(decoded_raw_message)
else: elif post_type is None:
await put_response(decoded_raw_message) await put_response(decoded_raw_message)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
@@ -428,8 +422,9 @@ class NapcatAdapterPlugin(BasePlugin):
def get_plugin_components(self): def get_plugin_components(self):
self.register_events() self.register_events()
components = [(LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler), components = []
(StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler)] components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler))
components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler))
for handler in get_classes_in_module(event_handlers): for handler in get_classes_in_module(event_handlers):
if issubclass(handler, BaseEventHandler): if issubclass(handler, BaseEventHandler):
components.append((handler.get_handler_info(), handler)) components.append((handler.get_handler_info(), handler))

View File

@@ -1,156 +1,162 @@
"""Napcat Adapter 插件数据库层 (基于主程序异步SQLAlchemy API) import os
from typing import Optional, List
本模块替换原先的 sqlmodel + 同步Session 实现:
1. 复用主项目的异步数据库连接与迁移体系
2. 提供与旧接口名兼容的方法(update_ban_record/create_ban_record/delete_ban_record)
3. 新增首选异步方法: update_ban_records / create_or_update / delete_record / get_ban_records
数据语义:
user_id == 0 表示群全体禁言
注意: 所有方法均为异步, 需要在 async 上下文中调用。
"""
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, List, Sequence from sqlmodel import Field, Session, SQLModel, create_engine, select
from sqlalchemy import Column, Integer, BigInteger, UniqueConstraint, select, Index
from sqlalchemy.ext.asyncio import AsyncSession
from src.common.database.sqlalchemy_models import Base, get_db_session
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("napcat_adapter") logger = get_logger("napcat_adapter")
"""
表记录的方式:
| group_id | user_id | lift_time |
|----------|---------|-----------|
class NapcatBanRecord(Base): 其中使用 user_id == 0 表示群全体禁言
__tablename__ = "napcat_ban_records" """
id = Column(Integer, primary_key=True, autoincrement=True)
group_id = Column(BigInteger, nullable=False, index=True)
user_id = Column(BigInteger, nullable=False, index=True) # 0 == 全体禁言
lift_time = Column(BigInteger, nullable=True) # -1 / None 表示未知/永久
__table_args__ = (
UniqueConstraint("group_id", "user_id", name="uq_napcat_group_user"),
Index("idx_napcat_ban_group", "group_id"),
Index("idx_napcat_ban_user", "user_id"),
)
@dataclass @dataclass
class BanUser: class BanUser:
"""
程序处理使用的实例
"""
user_id: int user_id: int
group_id: int group_id: int
lift_time: Optional[int] = -1 lift_time: Optional[int] = Field(default=-1)
def identity(self) -> tuple[int, int]:
return self.group_id, self.user_id
class NapcatDatabase: class DB_BanUser(SQLModel, table=True):
async def _fetch_all(self, session: AsyncSession) -> Sequence[NapcatBanRecord]: """
result = await session.execute(select(NapcatBanRecord)) 表示数据库中的用户禁言记录。
return result.scalars().all() 使用双重主键
"""
async def get_ban_records(self) -> List[BanUser]: user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID
async with get_db_session() as session: group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID
rows = await self._fetch_all(session) lift_time: Optional[int] # 禁言解除的时间(时间戳)
return [BanUser(group_id=r.group_id, user_id=r.user_id, lift_time=r.lift_time) for r in rows]
async def update_ban_records(self, ban_list: List[BanUser]) -> None:
target_map = {b.identity(): b for b in ban_list}
async with get_db_session() as session:
rows = await self._fetch_all(session)
existing_map = {(r.group_id, r.user_id): r for r in rows}
changed = 0 def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
for ident, ban in target_map.items(): """
if ident in existing_map: 检查两个 BanUser 对象是否相同。
row = existing_map[ident] """
if row.lift_time != ban.lift_time: return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
row.lift_time = ban.lift_time
changed += 1
class DatabaseManager:
"""
数据库管理类,负责与数据库交互。
"""
def __init__(self):
os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在
DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL
self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎
self._ensure_database() # 确保数据库和表已创建
def _ensure_database(self) -> None:
"""
确保数据库和表已创建。
"""
logger.info("确保数据库文件和表已创建...")
SQLModel.metadata.create_all(self.engine)
logger.info("数据库和表已创建或已存在")
def update_ban_record(self, ban_list: List[BanUser]) -> None:
# sourcery skip: class-extract-method
"""
更新禁言列表到数据库。
支持在不存在时创建新记录,对于多余的项目自动删除。
"""
with Session(self.engine) as session:
all_records = session.exec(select(DB_BanUser)).all()
for ban_user in ban_list:
statement = select(DB_BanUser).where(
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
)
if existing_record := session.exec(statement).first():
if existing_record.lift_time == ban_user.lift_time:
logger.debug(f"禁言记录未变更: {existing_record}")
continue
# 更新现有记录的 lift_time
existing_record.lift_time = ban_user.lift_time
session.add(existing_record)
logger.debug(f"更新禁言记录: {existing_record}")
else: else:
session.add( # 创建新记录
NapcatBanRecord(group_id=ban.group_id, user_id=ban.user_id, lift_time=ban.lift_time) db_record = DB_BanUser(
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
) )
changed += 1 session.add(db_record)
logger.debug(f"创建新禁言记录: {ban_user}")
# 删除不在 ban_list 中的记录
for db_record in all_records:
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
if not any(is_identical(record, ban_user) for ban_user in ban_list):
statement = select(DB_BanUser).where(
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
removed = 0 logger.debug(f"删除禁言记录: {ban_record}")
for ident, row in existing_map.items():
if ident not in target_map:
await session.delete(row)
removed += 1
logger.debug(
f"Napcat ban list sync => total_incoming={len(ban_list)} created_or_updated={changed} removed={removed}"
)
async def create_or_update(self, ban_record: BanUser) -> None:
async with get_db_session() as session:
stmt = select(NapcatBanRecord).where(
NapcatBanRecord.group_id == ban_record.group_id,
NapcatBanRecord.user_id == ban_record.user_id,
)
result = await session.execute(stmt)
row = result.scalars().first()
if row:
if row.lift_time != ban_record.lift_time:
row.lift_time = ban_record.lift_time
logger.debug(
f"更新禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
)
else: else:
session.add( logger.info(f"未找到禁言记录: {ban_record}")
NapcatBanRecord(
group_id=ban_record.group_id, user_id=ban_record.user_id, lift_time=ban_record.lift_time
)
)
logger.debug(
f"创建禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
)
async def delete_record(self, ban_record: BanUser) -> None: logger.info("禁言记录已更新")
async with get_db_session() as session:
stmt = select(NapcatBanRecord).where( def get_ban_records(self) -> List[BanUser]:
NapcatBanRecord.group_id == ban_record.group_id, """
NapcatBanRecord.user_id == ban_record.user_id, 读取所有禁言记录。
) """
result = await session.execute(stmt) with Session(self.engine) as session:
row = result.scalars().first() statement = select(DB_BanUser)
if row: records = session.exec(statement).all()
await session.delete(row) return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
logger.debug(
f"删除禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={row.lift_time}" def create_ban_record(self, ban_record: BanUser) -> None:
"""
为特定群组中的用户创建禁言记录。
一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。
其同时还是简化版的更新方式。
"""
with Session(self.engine) as session:
# 检查记录是否已存在
statement = select(DB_BanUser).where(
DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
) )
existing_record = session.exec(statement).first()
if existing_record:
# 如果记录已存在,更新 lift_time
existing_record.lift_time = ban_record.lift_time
session.add(existing_record)
logger.debug(f"更新禁言记录: {ban_record}")
else: else:
logger.info( # 如果记录不存在,创建新记录
f"未找到禁言记录 group={ban_record.group_id} user={ban_record.user_id}" db_record = DB_BanUser(
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
) )
session.add(db_record)
logger.debug(f"创建新禁言记录: {ban_record}")
# 兼容旧命名 def delete_ban_record(self, ban_record: BanUser):
async def update_ban_record(self, ban_list: List[BanUser]) -> None: # old name """
await self.update_ban_records(ban_list) 删除特定用户在特定群组中的禁言记录。
一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。
"""
user_id = ban_record.user_id
group_id = ban_record.group_id
with Session(self.engine) as session:
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
async def create_ban_record(self, ban_record: BanUser) -> None: # old name logger.debug(f"删除禁言记录: {ban_record}")
await self.create_or_update(ban_record) else:
logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
async def delete_ban_record(self, ban_record: BanUser) -> None: # old name
await self.delete_record(ban_record)
napcat_db = NapcatDatabase() db_manager = DatabaseManager()
def is_identical(a: BanUser, b: BanUser) -> bool:
return a.group_id == b.group_id and a.user_id == b.user_id
__all__ = [
"BanUser",
"NapcatBanRecord",
"napcat_db",
"is_identical",
]

View File

@@ -112,8 +112,7 @@ class MessageChunker:
else: else:
return [{"_original_message": message}] return [{"_original_message": message}]
@staticmethod def is_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
def is_chunk_message(message: Union[str, Dict[str, Any]]) -> bool:
"""判断是否是切片消息""" """判断是否是切片消息"""
try: try:
if isinstance(message, str): if isinstance(message, str):

View File

@@ -14,7 +14,6 @@ class MetaEventHandler:
""" """
def __init__(self): def __init__(self):
self.last_heart_beat = time.time()
self.interval = 5.0 # 默认值稍后通过set_plugin_config设置 self.interval = 5.0 # 默认值稍后通过set_plugin_config设置
self._interval_checking = False self._interval_checking = False
self.plugin_config = None self.plugin_config = None
@@ -40,6 +39,7 @@ class MetaEventHandler:
if message["status"].get("online") and message["status"].get("good"): if message["status"].get("online") and message["status"].get("good"):
if not self._interval_checking: if not self._interval_checking:
asyncio.create_task(self.check_heartbeat()) asyncio.create_task(self.check_heartbeat())
self.last_heart_beat = time.time()
self.interval = message.get("interval") / 1000 self.interval = message.get("interval") / 1000
else: else:
self_id = message.get("self_id") self_id = message.get("self_id")

View File

@@ -76,7 +76,7 @@ class SendHandler:
processed_message = await self.handle_seg_recursive(message_segment, user_info) processed_message = await self.handle_seg_recursive(message_segment, user_info)
except Exception as e: except Exception as e:
logger.error(f"处理消息时发生错误: {e}") logger.error(f"处理消息时发生错误: {e}")
return None return
if not processed_message: if not processed_message:
logger.critical("现在暂时不支持解析此回复!") logger.critical("现在暂时不支持解析此回复!")
@@ -94,7 +94,7 @@ class SendHandler:
id_name = "user_id" id_name = "user_id"
else: else:
logger.error("无法识别的消息类型") logger.error("无法识别的消息类型")
return None return
logger.info("尝试发送到napcat") logger.info("尝试发送到napcat")
logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'") logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'")
response = await self.send_message_to_napcat( response = await self.send_message_to_napcat(
@@ -108,10 +108,8 @@ class SendHandler:
logger.info("消息发送成功") logger.info("消息发送成功")
qq_message_id = response.get("data", {}).get("message_id") qq_message_id = response.get("data", {}).get("message_id")
await self.message_sent_back(raw_message_base, qq_message_id) await self.message_sent_back(raw_message_base, qq_message_id)
return None
else: else:
logger.warning(f"消息发送失败napcat返回{str(response)}") logger.warning(f"消息发送失败napcat返回{str(response)}")
return None
async def send_command(self, raw_message_base: MessageBase) -> None: async def send_command(self, raw_message_base: MessageBase) -> None:
""" """
@@ -149,7 +147,7 @@ class SendHandler:
command, args_dict = self.handle_send_like_command(args) command, args_dict = self.handle_send_like_command(args)
case _: case _:
logger.error(f"未知命令: {command_name}") logger.error(f"未知命令: {command_name}")
return None return
except Exception as e: except Exception as e:
logger.error(f"处理命令时发生错误: {e}") logger.error(f"处理命令时发生错误: {e}")
return None return None
@@ -161,10 +159,8 @@ class SendHandler:
response = await self.send_message_to_napcat(command, args_dict) response = await self.send_message_to_napcat(command, args_dict)
if response.get("status") == "ok": if response.get("status") == "ok":
logger.info(f"命令 {command_name} 执行成功") logger.info(f"命令 {command_name} 执行成功")
return None
else: else:
logger.warning(f"命令 {command_name} 执行失败napcat返回{str(response)}") logger.warning(f"命令 {command_name} 执行失败napcat返回{str(response)}")
return None
async def handle_adapter_command(self, raw_message_base: MessageBase) -> None: async def handle_adapter_command(self, raw_message_base: MessageBase) -> None:
""" """
@@ -272,8 +268,7 @@ class SendHandler:
new_payload = self.build_payload(payload, self.handle_file_message(file_path), False) new_payload = self.build_payload(payload, self.handle_file_message(file_path), False)
return new_payload return new_payload
@staticmethod def build_payload(self, payload: list, addon: dict | list, is_reply: bool = False) -> list:
def build_payload(payload: list, addon: dict | list, is_reply: bool = False) -> list:
# sourcery skip: for-append-to-extend, merge-list-append, simplify-generator # sourcery skip: for-append-to-extend, merge-list-append, simplify-generator
"""构建发送的消息体""" """构建发送的消息体"""
if is_reply: if is_reply:
@@ -339,13 +334,11 @@ class SendHandler:
logger.info(f"最终返回的回复段: {reply_seg}") logger.info(f"最终返回的回复段: {reply_seg}")
return reply_seg return reply_seg
@staticmethod def handle_text_message(self, message: str) -> dict:
def handle_text_message(message: str) -> dict:
"""处理文本消息""" """处理文本消息"""
return {"type": "text", "data": {"text": message}} return {"type": "text", "data": {"text": message}}
@staticmethod def handle_image_message(self, encoded_image: str) -> dict:
def handle_image_message(encoded_image: str) -> dict:
"""处理图片消息""" """处理图片消息"""
return { return {
"type": "image", "type": "image",
@@ -355,8 +348,7 @@ class SendHandler:
}, },
} # base64 编码的图片 } # base64 编码的图片
@staticmethod def handle_emoji_message(self, encoded_emoji: str) -> dict:
def handle_emoji_message(encoded_emoji: str) -> dict:
"""处理表情消息""" """处理表情消息"""
encoded_image = encoded_emoji encoded_image = encoded_emoji
image_format = get_image_format(encoded_emoji) image_format = get_image_format(encoded_emoji)
@@ -387,45 +379,39 @@ class SendHandler:
"data": {"file": f"base64://{encoded_voice}"}, "data": {"file": f"base64://{encoded_voice}"},
} }
@staticmethod def handle_voiceurl_message(self, voice_url: str) -> dict:
def handle_voiceurl_message(voice_url: str) -> dict:
"""处理语音链接消息""" """处理语音链接消息"""
return { return {
"type": "record", "type": "record",
"data": {"file": voice_url}, "data": {"file": voice_url},
} }
@staticmethod def handle_music_message(self, song_id: str) -> dict:
def handle_music_message(song_id: str) -> dict:
"""处理音乐消息""" """处理音乐消息"""
return { return {
"type": "music", "type": "music",
"data": {"type": "163", "id": song_id}, "data": {"type": "163", "id": song_id},
} }
@staticmethod def handle_videourl_message(self, video_url: str) -> dict:
def handle_videourl_message(video_url: str) -> dict:
"""处理视频链接消息""" """处理视频链接消息"""
return { return {
"type": "video", "type": "video",
"data": {"file": video_url}, "data": {"file": video_url},
} }
@staticmethod def handle_file_message(self, file_path: str) -> dict:
def handle_file_message(file_path: str) -> dict:
"""处理文件消息""" """处理文件消息"""
return { return {
"type": "file", "type": "file",
"data": {"file": f"file://{file_path}"}, "data": {"file": f"file://{file_path}"},
} }
@staticmethod def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
def delete_msg_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""处理删除消息命令""" """处理删除消息命令"""
return "delete_msg", {"message_id": args["message_id"]} return "delete_msg", {"message_id": args["message_id"]}
@staticmethod def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理封禁命令 """处理封禁命令
Args: Args:
@@ -453,8 +439,7 @@ class SendHandler:
}, },
) )
@staticmethod def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_whole_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理全体禁言命令 """处理全体禁言命令
Args: Args:
@@ -477,8 +462,7 @@ class SendHandler:
}, },
) )
@staticmethod def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_kick_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理群成员踢出命令 """处理群成员踢出命令
Args: Args:
@@ -503,8 +487,7 @@ class SendHandler:
}, },
) )
@staticmethod def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_poke_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理戳一戳命令 """处理戳一戳命令
Args: Args:
@@ -531,8 +514,7 @@ class SendHandler:
}, },
) )
@staticmethod def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
def handle_set_emoji_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""处理设置表情回应命令 """处理设置表情回应命令
Args: Args:
@@ -554,8 +536,7 @@ class SendHandler:
{"message_id": message_id, "emoji_id": emoji_id, "set": set_like}, {"message_id": message_id, "emoji_id": emoji_id, "set": set_like},
) )
@staticmethod def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
def handle_send_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
""" """
处理发送点赞命令的逻辑。 处理发送点赞命令的逻辑。
@@ -576,8 +557,7 @@ class SendHandler:
{"user_id": user_id, "times": times}, {"user_id": user_id, "times": times},
) )
@staticmethod def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_ai_voice_send_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
""" """
处理AI语音发送命令的逻辑。 处理AI语音发送命令的逻辑。
并返回 NapCat 兼容的 (action, params) 元组。 并返回 NapCat 兼容的 (action, params) 元组。
@@ -624,8 +604,7 @@ class SendHandler:
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
return response return response
@staticmethod async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None:
async def message_sent_back(message_base: MessageBase, qq_message_id: str) -> None:
# 修改 additional_config添加 echo 字段 # 修改 additional_config添加 echo 字段
if message_base.message_info.additional_config is None: if message_base.message_info.additional_config is None:
message_base.message_info.additional_config = {} message_base.message_info.additional_config = {}
@@ -643,9 +622,8 @@ class SendHandler:
logger.debug("已回送消息ID") logger.debug("已回送消息ID")
return return
@staticmethod
async def send_adapter_command_response( async def send_adapter_command_response(
original_message: MessageBase, response_data: dict, request_id: str self, original_message: MessageBase, response_data: dict, request_id: str
) -> None: ) -> None:
""" """
发送适配器命令响应回MaiBot 发送适配器命令响应回MaiBot
@@ -674,8 +652,7 @@ class SendHandler:
except Exception as e: except Exception as e:
logger.error(f"发送适配器命令响应时出错: {e}") logger.error(f"发送适配器命令响应时出错: {e}")
@staticmethod def handle_at_message_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_at_message_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理艾特并发送消息命令 """处理艾特并发送消息命令
Args: Args:

View File

@@ -6,7 +6,7 @@ import urllib3
import ssl import ssl
import io import io
from .database import BanUser, napcat_db from .database import BanUser, db_manager
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("napcat_adapter") logger = get_logger("napcat_adapter")
@@ -270,11 +270,10 @@ async def read_ban_list(
] ]
""" """
try: try:
ban_list = await napcat_db.get_ban_records() ban_list = db_manager.get_ban_records()
lifted_list: List[BanUser] = [] lifted_list: List[BanUser] = []
logger.info("已经读取禁言列表") logger.info("已经读取禁言列表")
# 复制列表以避免迭代中修改原列表问题 for ban_record in ban_list:
for ban_record in list(ban_list):
if ban_record.user_id == 0: if ban_record.user_id == 0:
fetched_group_info = await get_group_info(websocket, ban_record.group_id) fetched_group_info = await get_group_info(websocket, ban_record.group_id)
if fetched_group_info is None: if fetched_group_info is None:
@@ -302,12 +301,12 @@ async def read_ban_list(
ban_list.remove(ban_record) ban_list.remove(ban_record)
else: else:
ban_record.lift_time = lift_ban_time ban_record.lift_time = lift_ban_time
await napcat_db.update_ban_record(ban_list) db_manager.update_ban_record(ban_list)
return ban_list, lifted_list return ban_list, lifted_list
except Exception as e: except Exception as e:
logger.error(f"读取禁言列表失败: {e}") logger.error(f"读取禁言列表失败: {e}")
return [], [] return [], []
async def save_ban_record(list: List[BanUser]): def save_ban_record(list: List[BanUser]):
return await napcat_db.update_ban_record(list) return db_manager.update_ban_record(list)

110
test_deepcopy_fix.py Normal file
View File

@@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""
测试 ChatStream 的 deepcopy 功能
验证 asyncio.Task 序列化问题是否已解决
"""
import asyncio
import sys
import os
import copy
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import UserInfo, GroupInfo
async def test_chat_stream_deepcopy():
"""测试 ChatStream 的 deepcopy 功能"""
print("[TEST] 开始测试 ChatStream deepcopy 功能...")
try:
# 创建测试用的用户和群组信息
user_info = UserInfo(
platform="test_platform",
user_id="test_user_123",
user_nickname="测试用户",
user_cardname="测试卡片名"
)
group_info = GroupInfo(
platform="test_platform",
group_id="test_group_456",
group_name="测试群组"
)
# 创建 ChatStream 实例
print("📝 创建 ChatStream 实例...")
stream_id = "test_stream_789"
platform = "test_platform"
chat_stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
)
print(f"[SUCCESS] ChatStream 创建成功: {chat_stream.stream_id}")
# 等待一下,让异步任务有机会创建
await asyncio.sleep(0.1)
# 尝试进行 deepcopy
print("[INFO] 尝试进行 deepcopy...")
copied_stream = copy.deepcopy(chat_stream)
print("[SUCCESS] deepcopy 成功!")
# 验证复制后的对象属性
print("\n[CHECK] 验证复制后的对象属性:")
print(f" - stream_id: {copied_stream.stream_id}")
print(f" - platform: {copied_stream.platform}")
print(f" - user_info: {copied_stream.user_info.user_nickname}")
print(f" - group_info: {copied_stream.group_info.group_name}")
# 检查 processing_task 是否被正确处理
if hasattr(copied_stream.stream_context, 'processing_task'):
print(f" - processing_task: {copied_stream.stream_context.processing_task}")
if copied_stream.stream_context.processing_task is None:
print(" [SUCCESS] processing_task 已被正确设置为 None")
else:
print(" [WARNING] processing_task 不为 None")
else:
print(" [SUCCESS] stream_context 没有 processing_task 属性")
# 验证原始对象和复制对象是不同的实例
if id(chat_stream) != id(copied_stream):
print("[SUCCESS] 原始对象和复制对象是不同的实例")
else:
print("[ERROR] 原始对象和复制对象是同一个实例")
# 验证基本属性是否正确复制
if (chat_stream.stream_id == copied_stream.stream_id and
chat_stream.platform == copied_stream.platform):
print("[SUCCESS] 基本属性正确复制")
else:
print("[ERROR] 基本属性复制失败")
print("\n[COMPLETE] 测试完成deepcopy 功能修复成功!")
return True
except Exception as e:
print(f"[ERROR] 测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
# 运行测试
result = asyncio.run(test_chat_stream_deepcopy())
if result:
print("\n[SUCCESS] 所有测试通过!")
sys.exit(0)
else:
print("\n[ERROR] 测试失败!")
sys.exit(1)

109
test_simple_deepcopy.py Normal file
View File

@@ -0,0 +1,109 @@
#!/usr/bin/env python3
"""
简单的 ChatStream deepcopy 测试
"""
import asyncio
import sys
import os
import copy
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import UserInfo, GroupInfo
async def test_deepcopy():
"""测试 deepcopy 功能"""
print("开始测试 ChatStream deepcopy 功能...")
try:
# 创建测试用的用户和群组信息
user_info = UserInfo(
platform="test_platform",
user_id="test_user_123",
user_nickname="测试用户",
user_cardname="测试卡片名"
)
group_info = GroupInfo(
platform="test_platform",
group_id="test_group_456",
group_name="测试群组"
)
# 创建 ChatStream 实例
print("创建 ChatStream 实例...")
stream_id = "test_stream_789"
platform = "test_platform"
chat_stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
)
print(f"ChatStream 创建成功: {chat_stream.stream_id}")
# 等待一下,让异步任务有机会创建
await asyncio.sleep(0.1)
# 尝试进行 deepcopy
print("尝试进行 deepcopy...")
copied_stream = copy.deepcopy(chat_stream)
print("deepcopy 成功!")
# 验证复制后的对象属性
print("\n验证复制后的对象属性:")
print(f" - stream_id: {copied_stream.stream_id}")
print(f" - platform: {copied_stream.platform}")
print(f" - user_info: {copied_stream.user_info.user_nickname}")
print(f" - group_info: {copied_stream.group_info.group_name}")
# 检查 processing_task 是否被正确处理
if hasattr(copied_stream.stream_context, 'processing_task'):
print(f" - processing_task: {copied_stream.stream_context.processing_task}")
if copied_stream.stream_context.processing_task is None:
print(" SUCCESS: processing_task 已被正确设置为 None")
else:
print(" WARNING: processing_task 不为 None")
else:
print(" SUCCESS: stream_context 没有 processing_task 属性")
# 验证原始对象和复制对象是不同的实例
if id(chat_stream) != id(copied_stream):
print("SUCCESS: 原始对象和复制对象是不同的实例")
else:
print("ERROR: 原始对象和复制对象是同一个实例")
# 验证基本属性是否正确复制
if (chat_stream.stream_id == copied_stream.stream_id and
chat_stream.platform == copied_stream.platform):
print("SUCCESS: 基本属性正确复制")
else:
print("ERROR: 基本属性复制失败")
print("\n测试完成deepcopy 功能修复成功!")
return True
except Exception as e:
print(f"ERROR: 测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
# 运行测试
result = asyncio.run(test_deepcopy())
if result:
print("\n所有测试通过!")
sys.exit(0)
else:
print("\n测试失败!")
sys.exit(1)