refactor(database): 将同步数据库操作迁移为异步操作
将整个项目的数据库操作从同步模式迁移为异步模式,主要涉及以下修改: - 将 `with get_db_session()` 改为 `async with get_db_session()` - 将同步的 SQLAlchemy 查询方法改为异步执行 - 更新相关的方法签名,添加 async/await 关键字 - 修复由于异步化导致的并发问题和性能问题 这些修改提高了数据库操作的并发性能,避免了阻塞主线程,提升了系统的整体响应能力。涉及修改的模块包括表情包管理、反提示注入统计、用户封禁管理、记忆系统、消息存储等多个核心组件。 BREAKING CHANGE: 所有涉及数据库操作的方法现在都需要使用异步调用,同步调用将不再工作
This commit is contained in:
13
bot.py
13
bot.py
@@ -185,12 +185,12 @@ class MaiBotMain(BaseMain):
|
||||
check_eula()
|
||||
logger.info("检查EULA和隐私条款完成")
|
||||
|
||||
def initialize_database(self):
|
||||
async def initialize_database(self):
|
||||
"""初始化数据库"""
|
||||
|
||||
logger.info("正在初始化数据库连接...")
|
||||
try:
|
||||
initialize_sql_database(global_config.database)
|
||||
await initialize_sql_database(global_config.database)
|
||||
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接初始化失败: {e}")
|
||||
@@ -211,11 +211,11 @@ class MaiBotMain(BaseMain):
|
||||
self.main_system = MainSystem()
|
||||
return self.main_system
|
||||
|
||||
def run(self):
|
||||
async def run(self):
|
||||
"""运行主程序"""
|
||||
self.setup_timezone()
|
||||
self.check_and_confirm_eula()
|
||||
self.initialize_database()
|
||||
await self.initialize_database()
|
||||
|
||||
return self.create_main_system()
|
||||
|
||||
@@ -225,14 +225,14 @@ if __name__ == "__main__":
|
||||
try:
|
||||
# 创建MaiBotMain实例并获取MainSystem
|
||||
maibot = MaiBotMain()
|
||||
main_system = maibot.run()
|
||||
|
||||
# 创建事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 异步初始化数据库表结构
|
||||
# 异步初始化数据库和表结构
|
||||
main_system = loop.run_until_complete(maibot.run())
|
||||
loop.run_until_complete(maibot.initialize_database_async())
|
||||
# 执行初始化和任务调度
|
||||
loop.run_until_complete(main_system.initialize())
|
||||
@@ -269,3 +269,4 @@ if __name__ == "__main__":
|
||||
# 在程序退出前暂停,让你有机会看到输出
|
||||
# input("按 Enter 键退出...") # <--- 添加这行
|
||||
sys.exit(exit_code) # <--- 使用记录的退出码
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
|
||||
from src.config.config import global_config
|
||||
@@ -27,9 +29,11 @@ class AntiInjectionStatistics:
|
||||
async def get_or_create_stats():
|
||||
"""获取或创建统计记录"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 获取最新的统计记录,如果没有则创建
|
||||
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first()
|
||||
stats = (await session.execute(
|
||||
select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())
|
||||
)).scalars().first()
|
||||
if not stats:
|
||||
stats = AntiInjectionStats()
|
||||
session.add(stats)
|
||||
@@ -44,8 +48,10 @@ class AntiInjectionStatistics:
|
||||
async def update_stats(**kwargs):
|
||||
"""更新统计数据"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first()
|
||||
async with get_db_session() as session:
|
||||
stats = (await session.execute(
|
||||
select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())
|
||||
)).scalars().first()
|
||||
if not stats:
|
||||
stats = AntiInjectionStats()
|
||||
session.add(stats)
|
||||
@@ -53,7 +59,7 @@ class AntiInjectionStatistics:
|
||||
# 更新统计字段
|
||||
for key, value in kwargs.items():
|
||||
if key == "processing_time_delta":
|
||||
# 处理时间累加 - 确保不为None
|
||||
# 处理 时间累加 - 确保不为None
|
||||
if stats.processing_time_total is None:
|
||||
stats.processing_time_total = 0.0
|
||||
stats.processing_time_total += value
|
||||
@@ -138,9 +144,9 @@ class AntiInjectionStatistics:
|
||||
async def reset_stats():
|
||||
"""重置统计信息"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 删除现有统计记录
|
||||
session.query(AntiInjectionStats).delete()
|
||||
await session.execute(select(AntiInjectionStats).delete())
|
||||
await session.commit()
|
||||
logger.info("统计信息已重置")
|
||||
except Exception as e:
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import BanUser, get_db_session
|
||||
from ..types import DetectionResult
|
||||
@@ -37,8 +39,9 @@ class UserBanManager:
|
||||
如果用户被封禁则返回拒绝结果,否则返回None
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
ban_record = session.query(BanUser).filter_by(user_id=user_id, platform=platform).first()
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(BanUser).filter_by(user_id=user_id, platform=platform))
|
||||
ban_record = result.scalar_one_or_none()
|
||||
|
||||
if ban_record:
|
||||
# 只有违规次数达到阈值时才算被封禁
|
||||
@@ -70,9 +73,10 @@ class UserBanManager:
|
||||
detection_result: 检测结果
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查找或创建违规记录
|
||||
ban_record = session.query(BanUser).filter_by(user_id=user_id, platform=platform).first()
|
||||
result = await session.execute(select(BanUser).filter_by(user_id=user_id, platform=platform))
|
||||
ban_record = result.scalar_one_or_none()
|
||||
|
||||
if ban_record:
|
||||
ban_record.violation_num += 1
|
||||
|
||||
@@ -149,7 +149,7 @@ class MaiEmoji:
|
||||
# --- 数据库操作 ---
|
||||
try:
|
||||
# 准备数据库记录 for emoji collection
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
emotion_str = ",".join(self.emotion) if self.emotion else ""
|
||||
|
||||
emoji = Emoji(
|
||||
@@ -167,7 +167,7 @@ class MaiEmoji:
|
||||
last_used_time=self.last_used_time,
|
||||
)
|
||||
session.add(emoji)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
|
||||
|
||||
@@ -203,17 +203,18 @@ class MaiEmoji:
|
||||
|
||||
# 2. 删除数据库记录
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
will_delete_emoji = session.execute(
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(
|
||||
select(Emoji).where(Emoji.emoji_hash == self.hash)
|
||||
).scalar_one_or_none()
|
||||
)
|
||||
will_delete_emoji = result.scalar_one_or_none()
|
||||
if will_delete_emoji is None:
|
||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||
result = 0 # Indicate no DB record was deleted
|
||||
else:
|
||||
session.delete(will_delete_emoji)
|
||||
await session.delete(will_delete_emoji)
|
||||
result = 1 # Successfully deleted one record
|
||||
session.commit()
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
|
||||
result = 0
|
||||
@@ -424,17 +425,19 @@ class EmojiManager:
|
||||
# if not self._initialized:
|
||||
# raise RuntimeError("EmojiManager not initialized")
|
||||
|
||||
def record_usage(self, emoji_hash: str) -> None:
|
||||
async def record_usage(self, emoji_hash: str) -> None:
|
||||
"""记录表情使用次数"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
|
||||
async with get_db_session() as session:
|
||||
stmt = select(Emoji).where(Emoji.emoji_hash == emoji_hash)
|
||||
result = await session.execute(stmt)
|
||||
emoji_update = result.scalar_one_or_none()
|
||||
if emoji_update is None:
|
||||
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
||||
else:
|
||||
emoji_update.usage_count += 1
|
||||
emoji_update.last_used_time = time.time() # Update last used time
|
||||
session.commit()
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"记录表情使用失败: {str(e)}")
|
||||
|
||||
@@ -521,7 +524,7 @@ class EmojiManager:
|
||||
|
||||
# 7. 获取选中的表情包并更新使用记录
|
||||
selected_emoji = candidate_emojis[selected_index]
|
||||
self.record_usage(selected_emoji.hash)
|
||||
await self.record_usage(selected_emoji.hash)
|
||||
_time_end = time.time()
|
||||
|
||||
logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s")
|
||||
@@ -657,10 +660,11 @@ class EmojiManager:
|
||||
async def get_all_emoji_from_db(self) -> None:
|
||||
"""获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
logger.debug("[数据库] 开始加载所有表情包记录 ...")
|
||||
|
||||
emoji_instances = session.execute(select(Emoji)).scalars().all()
|
||||
result = await session.execute(select(Emoji))
|
||||
emoji_instances = result.scalars().all()
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
|
||||
# 更新内存中的列表和数量
|
||||
@@ -686,14 +690,16 @@ class EmojiManager:
|
||||
list[MaiEmoji]: 表情包对象列表
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
if emoji_hash:
|
||||
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
|
||||
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
|
||||
query = result.scalars().all()
|
||||
else:
|
||||
logger.warning(
|
||||
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
|
||||
)
|
||||
query = session.execute(select(Emoji)).scalars().all()
|
||||
result = await session.execute(select(Emoji))
|
||||
query = result.scalars().all()
|
||||
|
||||
emoji_instances = query
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
@@ -770,10 +776,10 @@ class EmojiManager:
|
||||
|
||||
# 如果内存中没有,从数据库查找
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
emoji_record = session.execute(
|
||||
select(Emoji).where(Emoji.emoji_hash == emoji_hash)
|
||||
).scalar_one_or_none()
|
||||
async with get_db_session() as session:
|
||||
stmt = select(Emoji).where(Emoji.emoji_hash == emoji_hash)
|
||||
result = await session.execute(stmt)
|
||||
emoji_record = result.scalar_one_or_none()
|
||||
if emoji_record and emoji_record.description:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
|
||||
return emoji_record.description
|
||||
@@ -939,12 +945,13 @@ class EmojiManager:
|
||||
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
|
||||
existing_description = None
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
existing_image = (
|
||||
session.query(Images)
|
||||
.filter((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
.one_or_none()
|
||||
async with get_db_session() as session:
|
||||
stmt = select(Images).where(
|
||||
Images.emoji_hash == image_hash,
|
||||
Images.type == "emoji"
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
existing_image = result.scalar_one_or_none()
|
||||
if existing_image and existing_image.description:
|
||||
existing_description = existing_image.description
|
||||
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
||||
|
||||
@@ -198,7 +198,7 @@ class RecencyEnergyCalculator(EnergyCalculator):
|
||||
class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
"""关系能量计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
async def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于关系计算能量"""
|
||||
user_id = context.get("user_id")
|
||||
if not user_id:
|
||||
@@ -208,7 +208,7 @@ class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
|
||||
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||
relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||
logger.debug(f"使用插件内部系统计算关系分: {relationship_score:.3f}")
|
||||
return max(0.0, min(1.0, relationship_score))
|
||||
|
||||
@@ -273,7 +273,7 @@ class EnergyManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"加载AFC阈值失败,使用默认值: {e}")
|
||||
|
||||
def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
|
||||
async def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
|
||||
"""计算聊天流的focus_energy"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -303,7 +303,16 @@ class EnergyManager:
|
||||
|
||||
for calculator in self.calculators:
|
||||
try:
|
||||
score = calculator.calculate(context)
|
||||
# 支持同步和异步计算器
|
||||
if callable(calculator.calculate):
|
||||
import inspect
|
||||
if inspect.iscoroutinefunction(calculator.calculate):
|
||||
score = await calculator.calculate(context)
|
||||
else:
|
||||
score = calculator.calculate(context)
|
||||
else:
|
||||
score = calculator.calculate(context)
|
||||
|
||||
weight = calculator.get_weight()
|
||||
|
||||
component_scores[calculator.__class__.__name__] = score
|
||||
|
||||
@@ -8,6 +8,7 @@ import traceback
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -610,14 +611,13 @@ class BotInterestManager:
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查询最新的兴趣标签配置
|
||||
db_interests = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == personality_id)
|
||||
db_interests = (await session.execute(
|
||||
select(DBBotPersonalityInterests)
|
||||
.where(DBBotPersonalityInterests.personality_id == personality_id)
|
||||
.order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
|
||||
.first()
|
||||
)
|
||||
)).scalars().first()
|
||||
|
||||
if db_interests:
|
||||
logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}")
|
||||
@@ -700,13 +700,12 @@ class BotInterestManager:
|
||||
# 序列化为JSON
|
||||
json_data = orjson.dumps(tags_data)
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 检查是否已存在相同personality_id的记录
|
||||
existing_record = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||
.first()
|
||||
)
|
||||
existing_record = (await session.execute(
|
||||
select(DBBotPersonalityInterests)
|
||||
.where(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||
)).scalars().first()
|
||||
|
||||
if existing_record:
|
||||
# 更新现有记录
|
||||
@@ -731,19 +730,17 @@ class BotInterestManager:
|
||||
last_updated=interests.last_updated,
|
||||
)
|
||||
session.add(new_record)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
logger.info(f"✅ 成功创建兴趣标签配置,版本: {interests.version}")
|
||||
|
||||
logger.info("✅ 兴趣标签已成功保存到数据库")
|
||||
|
||||
# 验证保存是否成功
|
||||
with get_db_session() as session:
|
||||
saved_record = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||
.first()
|
||||
)
|
||||
session.commit()
|
||||
async with get_db_session() as session:
|
||||
saved_record = (await session.execute(
|
||||
select(DBBotPersonalityInterests)
|
||||
.where(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||
)).scalars().first()
|
||||
if saved_record:
|
||||
logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录")
|
||||
logger.info(f" 版本: {saved_record.version}")
|
||||
|
||||
@@ -882,7 +882,8 @@ class EntorhinalCortex:
|
||||
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
async with get_db_session() as session:
|
||||
db_nodes = {node.concept: node for node in (await session.execute(select(GraphNodes))).scalars()}
|
||||
result = await session.execute(select(GraphNodes))
|
||||
db_nodes = {node.concept: node for node in result.scalars()}
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 批量准备节点数据
|
||||
@@ -978,7 +979,8 @@ class EntorhinalCortex:
|
||||
await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list((await session.execute(select(GraphEdges))).scalars())
|
||||
result = await session.execute(select(GraphEdges))
|
||||
db_edges = list(result.scalars())
|
||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||
|
||||
# 创建边的哈希值字典
|
||||
@@ -1157,7 +1159,8 @@ class EntorhinalCortex:
|
||||
|
||||
# 从数据库加载所有节点
|
||||
async with get_db_session() as session:
|
||||
nodes = list((await session.execute(select(GraphNodes))).scalars())
|
||||
result = await session.execute(select(GraphNodes))
|
||||
nodes = list(result.scalars())
|
||||
for node in nodes:
|
||||
concept = node.concept
|
||||
try:
|
||||
@@ -1192,7 +1195,8 @@ class EntorhinalCortex:
|
||||
continue
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = list((await session.execute(select(GraphEdges))).scalars())
|
||||
result = await session.execute(select(GraphEdges))
|
||||
edges = list(result.scalars())
|
||||
for edge in edges:
|
||||
source = edge.source
|
||||
target = edge.target
|
||||
|
||||
@@ -184,6 +184,11 @@ class AsyncMemoryQueue:
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
|
||||
if hippocampus_manager._initialized:
|
||||
# 确保海马体对象已正确初始化
|
||||
if not hippocampus_manager._hippocampus.parahippocampal_gyrus:
|
||||
logger.warning("海马体对象未完全初始化,进行同步初始化")
|
||||
hippocampus_manager._hippocampus.initialize()
|
||||
|
||||
await hippocampus_manager.build_memory()
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -108,7 +108,7 @@ class InstantMemory:
|
||||
|
||||
@staticmethod
|
||||
async def store_memory(memory_item: MemoryItem):
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
memory = Memory(
|
||||
memory_id=memory_item.memory_id,
|
||||
chat_id=memory_item.chat_id,
|
||||
@@ -161,20 +161,21 @@ class InstantMemory:
|
||||
logger.info(f"start_time: {start_time}, end_time: {end_time}")
|
||||
# 检索包含关键词的记忆
|
||||
memories_set = set()
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
if start_time and end_time:
|
||||
start_ts = start_time.timestamp()
|
||||
end_ts = end_time.timestamp()
|
||||
|
||||
query = session.execute(
|
||||
query = (await session.execute(
|
||||
select(Memory).where(
|
||||
(Memory.chat_id == self.chat_id)
|
||||
& (Memory.create_time >= start_ts)
|
||||
& (Memory.create_time < end_ts)
|
||||
)
|
||||
).scalars()
|
||||
)).scalars()
|
||||
else:
|
||||
query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars()
|
||||
query = result = await session.execute(select(Memory).where(Memory.chat_id == self.chat_id))
|
||||
result.scalars()
|
||||
for mem in query:
|
||||
# 对每条记忆
|
||||
mem_keywords_str = mem.keywords or "[]"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
from .message_manager import MessageManager, message_manager
|
||||
from .context_manager import StreamContextManager, context_manager
|
||||
from .context_manager import SingleStreamContextManager
|
||||
from .distribution_manager import (
|
||||
DistributionManager,
|
||||
DistributionPriority,
|
||||
@@ -16,8 +16,7 @@ from .distribution_manager import (
|
||||
__all__ = [
|
||||
"MessageManager",
|
||||
"message_manager",
|
||||
"StreamContextManager",
|
||||
"context_manager",
|
||||
"SingleStreamContextManager",
|
||||
"DistributionManager",
|
||||
"DistributionPriority",
|
||||
"DistributionTask",
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""
|
||||
重构后的聊天上下文管理器
|
||||
提供统一、稳定的聊天上下文管理功能
|
||||
每个 context_manager 实例只管理一个 stream 的上下文
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any, Union, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
@@ -17,241 +17,112 @@ from .distribution_manager import distribution_manager
|
||||
|
||||
logger = get_logger("context_manager")
|
||||
|
||||
class StreamContextManager:
|
||||
"""流上下文管理器 - 统一管理所有聊天流上下文"""
|
||||
|
||||
def __init__(self, max_context_size: Optional[int] = None, context_ttl: Optional[int] = None):
|
||||
# 上下文存储
|
||||
self.stream_contexts: Dict[str, Any] = {}
|
||||
self.context_metadata: Dict[str, Dict[str, Any]] = {}
|
||||
class SingleStreamContextManager:
|
||||
"""单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Union[int, float, str, Dict]] = {
|
||||
"total_messages": 0,
|
||||
"total_streams": 0,
|
||||
"active_streams": 0,
|
||||
"inactive_streams": 0,
|
||||
"last_activity": time.time(),
|
||||
"creation_time": time.time(),
|
||||
}
|
||||
def __init__(self, stream_id: str, context: StreamContext, max_context_size: Optional[int] = None):
|
||||
self.stream_id = stream_id
|
||||
self.context = context
|
||||
|
||||
# 配置参数
|
||||
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
|
||||
self.context_ttl = context_ttl or getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
|
||||
self.cleanup_interval = getattr(global_config.chat, "context_cleanup_interval", 3600) # 1小时
|
||||
self.auto_cleanup = getattr(global_config.chat, "auto_cleanup_contexts", True)
|
||||
self.enable_validation = getattr(global_config.chat, "enable_context_validation", True)
|
||||
self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
|
||||
|
||||
# 清理任务
|
||||
self.cleanup_task: Optional[Any] = None
|
||||
self.is_running = False
|
||||
# 元数据
|
||||
self.created_time = time.time()
|
||||
self.last_access_time = time.time()
|
||||
self.access_count = 0
|
||||
self.total_messages = 0
|
||||
|
||||
logger.info(f"上下文管理器初始化完成 (最大上下文: {self.max_context_size}, TTL: {self.context_ttl}s)")
|
||||
logger.debug(f"单流上下文管理器初始化: {stream_id}")
|
||||
|
||||
def add_stream_context(self, stream_id: str, context: Any, metadata: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""添加流上下文
|
||||
def get_context(self) -> StreamContext:
|
||||
"""获取流上下文"""
|
||||
self._update_access_stats()
|
||||
return self.context
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
context: 上下文对象
|
||||
metadata: 上下文元数据
|
||||
|
||||
Returns:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
if stream_id in self.stream_contexts:
|
||||
logger.warning(f"流上下文已存在: {stream_id}")
|
||||
return False
|
||||
|
||||
# 添加上下文
|
||||
self.stream_contexts[stream_id] = context
|
||||
|
||||
# 初始化元数据
|
||||
self.context_metadata[stream_id] = {
|
||||
"created_time": time.time(),
|
||||
"last_access_time": time.time(),
|
||||
"access_count": 0,
|
||||
"last_validation_time": 0.0,
|
||||
"custom_metadata": metadata or {},
|
||||
}
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_streams"] += 1
|
||||
self.stats["active_streams"] += 1
|
||||
self.stats["last_activity"] = time.time()
|
||||
|
||||
logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})")
|
||||
return True
|
||||
|
||||
def remove_stream_context(self, stream_id: str) -> bool:
|
||||
"""移除流上下文
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功移除
|
||||
"""
|
||||
if stream_id in self.stream_contexts:
|
||||
context = self.stream_contexts[stream_id]
|
||||
metadata = self.context_metadata.get(stream_id, {})
|
||||
|
||||
del self.stream_contexts[stream_id]
|
||||
if stream_id in self.context_metadata:
|
||||
del self.context_metadata[stream_id]
|
||||
|
||||
self.stats["active_streams"] = max(0, self.stats["active_streams"] - 1)
|
||||
self.stats["inactive_streams"] += 1
|
||||
self.stats["last_activity"] = time.time()
|
||||
|
||||
logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[StreamContext]:
|
||||
"""获取流上下文
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
update_access: 是否更新访问统计
|
||||
|
||||
Returns:
|
||||
Optional[Any]: 上下文对象
|
||||
"""
|
||||
context = self.stream_contexts.get(stream_id)
|
||||
if context and update_access:
|
||||
# 更新访问统计
|
||||
if stream_id in self.context_metadata:
|
||||
metadata = self.context_metadata[stream_id]
|
||||
metadata["last_access_time"] = time.time()
|
||||
metadata["access_count"] = metadata.get("access_count", 0) + 1
|
||||
return context
|
||||
|
||||
def get_context_metadata(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取上下文元数据
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 元数据
|
||||
"""
|
||||
return self.context_metadata.get(stream_id)
|
||||
|
||||
def update_context_metadata(self, stream_id: str, updates: Dict[str, Any]) -> bool:
|
||||
"""更新上下文元数据
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
updates: 更新的元数据
|
||||
|
||||
Returns:
|
||||
bool: 是否成功更新
|
||||
"""
|
||||
if stream_id not in self.context_metadata:
|
||||
return False
|
||||
|
||||
self.context_metadata[stream_id].update(updates)
|
||||
return True
|
||||
|
||||
def add_message_to_context(self, stream_id: str, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
|
||||
def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
|
||||
"""添加消息到上下文
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
message: 消息对象
|
||||
skip_energy_update: 是否跳过能量更新
|
||||
|
||||
Returns:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
logger.warning(f"流上下文不存在: {stream_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 添加消息到上下文
|
||||
context.add_message(message)
|
||||
self.context.add_message(message)
|
||||
|
||||
# 计算消息兴趣度
|
||||
interest_value = self._calculate_message_interest(message)
|
||||
message.interest_value = interest_value
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_messages"] += 1
|
||||
self.stats["last_activity"] = time.time()
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
|
||||
# 更新能量和分发
|
||||
if not skip_energy_update:
|
||||
self._update_stream_energy(stream_id)
|
||||
distribution_manager.add_stream_message(stream_id, 1)
|
||||
self._update_stream_energy()
|
||||
distribution_manager.add_stream_message(self.stream_id, 1)
|
||||
|
||||
logger.debug(f"添加消息到上下文: {stream_id} (兴趣度: {interest_value:.3f})")
|
||||
logger.debug(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到上下文失败 {stream_id}: {e}", exc_info=True)
|
||||
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def update_message_in_context(self, stream_id: str, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
"""更新上下文中的消息
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
message_id: 消息ID
|
||||
updates: 更新的属性
|
||||
|
||||
Returns:
|
||||
bool: 是否成功更新
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
logger.warning(f"流上下文不存在: {stream_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 更新消息信息
|
||||
context.update_message_info(message_id, **updates)
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
|
||||
# 如果更新了兴趣度,重新计算能量
|
||||
if "interest_value" in updates:
|
||||
self._update_stream_energy(stream_id)
|
||||
self._update_stream_energy()
|
||||
|
||||
logger.debug(f"更新上下文消息: {stream_id}/{message_id}")
|
||||
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True)
|
||||
logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
|
||||
def get_messages(self, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
|
||||
"""获取上下文消息
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
limit: 消息数量限制
|
||||
include_unread: 是否包含未读消息
|
||||
|
||||
Returns:
|
||||
List[Any]: 消息列表
|
||||
List[DatabaseMessages]: 消息列表
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return []
|
||||
|
||||
try:
|
||||
messages = []
|
||||
if include_unread:
|
||||
messages.extend(context.get_unread_messages())
|
||||
messages.extend(self.context.get_unread_messages())
|
||||
|
||||
if limit:
|
||||
messages.extend(context.get_history_messages(limit=limit))
|
||||
messages.extend(self.context.get_history_messages(limit=limit))
|
||||
else:
|
||||
messages.extend(context.get_history_messages())
|
||||
messages.extend(self.context.get_history_messages())
|
||||
|
||||
# 按时间排序
|
||||
messages.sort(key=lambda msg: getattr(msg, 'time', 0))
|
||||
messages.sort(key=lambda msg: getattr(msg, "time", 0))
|
||||
|
||||
# 应用限制
|
||||
if limit and len(messages) > limit:
|
||||
@@ -260,103 +131,124 @@ class StreamContextManager:
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def get_unread_messages(self, stream_id: str) -> List[DatabaseMessages]:
|
||||
"""获取未读消息
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
List[Any]: 未读消息列表
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def get_unread_messages(self) -> List[DatabaseMessages]:
|
||||
"""获取未读消息"""
|
||||
try:
|
||||
return context.get_unread_messages()
|
||||
return self.context.get_unread_messages()
|
||||
except Exception as e:
|
||||
logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True)
|
||||
logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def mark_messages_as_read(self, stream_id: str, message_ids: List[str]) -> bool:
|
||||
"""标记消息为已读
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
message_ids: 消息ID列表
|
||||
|
||||
Returns:
|
||||
bool: 是否成功标记
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
logger.warning(f"流上下文不存在: {stream_id}")
|
||||
return False
|
||||
|
||||
def mark_messages_as_read(self, message_ids: List[str]) -> bool:
|
||||
"""标记消息为已读"""
|
||||
try:
|
||||
if not hasattr(context, 'mark_message_as_read'):
|
||||
logger.error(f"上下文对象缺少 mark_message_as_read 方法: {stream_id}")
|
||||
if not hasattr(self.context, "mark_message_as_read"):
|
||||
logger.error(f"上下文对象缺少 mark_message_as_read 方法: {self.stream_id}")
|
||||
return False
|
||||
|
||||
marked_count = 0
|
||||
for message_id in message_ids:
|
||||
try:
|
||||
context.mark_message_as_read(message_id)
|
||||
self.context.mark_message_as_read(message_id)
|
||||
marked_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"标记消息已读失败 {message_id}: {e}")
|
||||
|
||||
logger.debug(f"标记消息为已读: {stream_id} ({marked_count}/{len(message_ids)}条)")
|
||||
logger.debug(f"标记消息为已读: {self.stream_id} ({marked_count}/{len(message_ids)}条)")
|
||||
return marked_count > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"标记消息已读失败 {stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def clear_context(self, stream_id: str) -> bool:
|
||||
"""清空上下文
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功清空
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
logger.warning(f"流上下文不存在: {stream_id}")
|
||||
logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def clear_context(self) -> bool:
|
||||
"""清空上下文"""
|
||||
try:
|
||||
# 清空消息
|
||||
if hasattr(context, 'unread_messages'):
|
||||
context.unread_messages.clear()
|
||||
if hasattr(context, 'history_messages'):
|
||||
context.history_messages.clear()
|
||||
if hasattr(self.context, "unread_messages"):
|
||||
self.context.unread_messages.clear()
|
||||
if hasattr(self.context, "history_messages"):
|
||||
self.context.history_messages.clear()
|
||||
|
||||
# 重置状态
|
||||
reset_attrs = ['interruption_count', 'afc_threshold_adjustment', 'last_check_time']
|
||||
reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
|
||||
for attr in reset_attrs:
|
||||
if hasattr(context, attr):
|
||||
if attr in ['interruption_count', 'afc_threshold_adjustment']:
|
||||
setattr(context, attr, 0)
|
||||
if hasattr(self.context, attr):
|
||||
if attr in ["interruption_count", "afc_threshold_adjustment"]:
|
||||
setattr(self.context, attr, 0)
|
||||
else:
|
||||
setattr(context, attr, time.time())
|
||||
setattr(self.context, attr, time.time())
|
||||
|
||||
# 重新计算能量
|
||||
self._update_stream_energy(stream_id)
|
||||
self._update_stream_energy()
|
||||
|
||||
logger.info(f"清空上下文: {stream_id}")
|
||||
logger.info(f"清空单流上下文: {self.stream_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True)
|
||||
logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""获取流统计信息"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
uptime = current_time - self.created_time
|
||||
|
||||
unread_messages = getattr(self.context, "unread_messages", [])
|
||||
history_messages = getattr(self.context, "history_messages", [])
|
||||
|
||||
return {
|
||||
"stream_id": self.stream_id,
|
||||
"context_type": type(self.context).__name__,
|
||||
"total_messages": len(history_messages) + len(unread_messages),
|
||||
"unread_messages": len(unread_messages),
|
||||
"history_messages": len(history_messages),
|
||||
"is_active": getattr(self.context, "is_active", True),
|
||||
"last_check_time": getattr(self.context, "last_check_time", current_time),
|
||||
"interruption_count": getattr(self.context, "interruption_count", 0),
|
||||
"afc_threshold_adjustment": getattr(self.context, "afc_threshold_adjustment", 0.0),
|
||||
"created_time": self.created_time,
|
||||
"last_access_time": self.last_access_time,
|
||||
"access_count": self.access_count,
|
||||
"uptime_seconds": uptime,
|
||||
"idle_seconds": current_time - self.last_access_time,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取单流统计失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return {}
|
||||
|
||||
def validate_integrity(self) -> bool:
|
||||
"""验证上下文完整性"""
|
||||
try:
|
||||
# 检查基本属性
|
||||
required_attrs = ["stream_id", "unread_messages", "history_messages"]
|
||||
for attr in required_attrs:
|
||||
if not hasattr(self.context, attr):
|
||||
logger.warning(f"上下文缺少必要属性: {attr}")
|
||||
return False
|
||||
|
||||
# 检查消息ID唯一性
|
||||
all_messages = getattr(self.context, "unread_messages", []) + getattr(self.context, "history_messages", [])
|
||||
message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
|
||||
if len(message_ids) != len(set(message_ids)):
|
||||
logger.warning(f"上下文中存在重复消息ID: {self.stream_id}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证单流上下文完整性失败 {self.stream_id}: {e}")
|
||||
return False
|
||||
|
||||
def _update_access_stats(self):
|
||||
"""更新访问统计"""
|
||||
self.last_access_time = time.time()
|
||||
self.access_count += 1
|
||||
|
||||
def _calculate_message_interest(self, message: DatabaseMessages) -> float:
|
||||
"""计算消息兴趣度"""
|
||||
try:
|
||||
@@ -373,8 +265,7 @@ class StreamContextManager:
|
||||
|
||||
interest_score = loop.run_until_complete(
|
||||
chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=message,
|
||||
bot_nickname=global_config.bot.nickname
|
||||
message=message, bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
)
|
||||
interest_value = interest_score.total_score
|
||||
@@ -391,12 +282,12 @@ class StreamContextManager:
|
||||
logger.error(f"计算消息兴趣度失败: {e}")
|
||||
return 0.5
|
||||
|
||||
def _update_stream_energy(self, stream_id: str):
|
||||
async def _update_stream_energy(self):
|
||||
"""更新流能量"""
|
||||
try:
|
||||
# 获取所有消息
|
||||
all_messages = self.get_context_messages(stream_id, self.max_context_size)
|
||||
unread_messages = self.get_unread_messages(stream_id)
|
||||
all_messages = self.get_messages(self.max_context_size)
|
||||
unread_messages = self.get_unread_messages()
|
||||
combined_messages = all_messages + unread_messages
|
||||
|
||||
# 获取用户ID
|
||||
@@ -406,248 +297,12 @@ class StreamContextManager:
|
||||
user_id = last_message.user_info.user_id
|
||||
|
||||
# 计算能量
|
||||
energy = energy_manager.calculate_focus_energy(
|
||||
stream_id=stream_id,
|
||||
messages=combined_messages,
|
||||
user_id=user_id
|
||||
energy = await energy_manager.calculate_focus_energy(
|
||||
stream_id=self.stream_id, messages=combined_messages, user_id=user_id
|
||||
)
|
||||
|
||||
# 更新分发管理器
|
||||
distribution_manager.update_stream_energy(stream_id, energy)
|
||||
distribution_manager.update_stream_energy(self.stream_id, energy)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新流能量失败 {stream_id}: {e}")
|
||||
|
||||
def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取流统计信息
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 统计信息
|
||||
"""
|
||||
context = self.get_stream_context(stream_id, update_access=False)
|
||||
if not context:
|
||||
return None
|
||||
|
||||
try:
|
||||
metadata = self.context_metadata.get(stream_id, {})
|
||||
current_time = time.time()
|
||||
created_time = metadata.get("created_time", current_time)
|
||||
last_access_time = metadata.get("last_access_time", current_time)
|
||||
access_count = metadata.get("access_count", 0)
|
||||
|
||||
unread_messages = getattr(context, "unread_messages", [])
|
||||
history_messages = getattr(context, "history_messages", [])
|
||||
|
||||
return {
|
||||
"stream_id": stream_id,
|
||||
"context_type": type(context).__name__,
|
||||
"total_messages": len(history_messages) + len(unread_messages),
|
||||
"unread_messages": len(unread_messages),
|
||||
"history_messages": len(history_messages),
|
||||
"is_active": getattr(context, "is_active", True),
|
||||
"last_check_time": getattr(context, "last_check_time", current_time),
|
||||
"interruption_count": getattr(context, "interruption_count", 0),
|
||||
"afc_threshold_adjustment": getattr(context, "afc_threshold_adjustment", 0.0),
|
||||
"created_time": created_time,
|
||||
"last_access_time": last_access_time,
|
||||
"access_count": access_count,
|
||||
"uptime_seconds": current_time - created_time,
|
||||
"idle_seconds": current_time - last_access_time,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def get_manager_statistics(self) -> Dict[str, Any]:
|
||||
"""获取管理器统计信息
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 管理器统计信息
|
||||
"""
|
||||
current_time = time.time()
|
||||
uptime = current_time - self.stats.get("creation_time", current_time)
|
||||
|
||||
return {
|
||||
**self.stats,
|
||||
"uptime_hours": uptime / 3600,
|
||||
"stream_count": len(self.stream_contexts),
|
||||
"metadata_count": len(self.context_metadata),
|
||||
"auto_cleanup_enabled": self.auto_cleanup,
|
||||
"cleanup_interval": self.cleanup_interval,
|
||||
}
|
||||
|
||||
def cleanup_inactive_contexts(self, max_inactive_hours: int = 24) -> int:
|
||||
"""清理不活跃的上下文
|
||||
|
||||
Args:
|
||||
max_inactive_hours: 最大不活跃小时数
|
||||
|
||||
Returns:
|
||||
int: 清理的上下文数量
|
||||
"""
|
||||
current_time = time.time()
|
||||
max_inactive_seconds = max_inactive_hours * 3600
|
||||
|
||||
inactive_streams = []
|
||||
for stream_id, context in self.stream_contexts.items():
|
||||
try:
|
||||
# 获取最后活动时间
|
||||
metadata = self.context_metadata.get(stream_id, {})
|
||||
last_activity = metadata.get("last_access_time", metadata.get("created_time", 0))
|
||||
context_last_activity = getattr(context, "last_check_time", 0)
|
||||
actual_last_activity = max(last_activity, context_last_activity)
|
||||
|
||||
# 检查是否不活跃
|
||||
unread_count = len(getattr(context, "unread_messages", []))
|
||||
history_count = len(getattr(context, "history_messages", []))
|
||||
total_messages = unread_count + history_count
|
||||
|
||||
if (current_time - actual_last_activity > max_inactive_seconds and
|
||||
total_messages == 0):
|
||||
inactive_streams.append(stream_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"检查上下文活跃状态失败 {stream_id}: {e}")
|
||||
continue
|
||||
|
||||
# 清理不活跃上下文
|
||||
cleaned_count = 0
|
||||
for stream_id in inactive_streams:
|
||||
if self.remove_stream_context(stream_id):
|
||||
cleaned_count += 1
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"清理了 {cleaned_count} 个不活跃上下文")
|
||||
|
||||
return cleaned_count
|
||||
|
||||
def validate_context_integrity(self, stream_id: str) -> bool:
|
||||
"""验证上下文完整性
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否完整
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return False
|
||||
|
||||
try:
|
||||
# 检查基本属性
|
||||
required_attrs = ["stream_id", "unread_messages", "history_messages"]
|
||||
for attr in required_attrs:
|
||||
if not hasattr(context, attr):
|
||||
logger.warning(f"上下文缺少必要属性: {attr}")
|
||||
return False
|
||||
|
||||
# 检查消息ID唯一性
|
||||
all_messages = getattr(context, "unread_messages", []) + getattr(context, "history_messages", [])
|
||||
message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
|
||||
if len(message_ids) != len(set(message_ids)):
|
||||
logger.warning(f"上下文中存在重复消息ID: {stream_id}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证上下文完整性失败 {stream_id}: {e}")
|
||||
return False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动上下文管理器"""
|
||||
if self.is_running:
|
||||
logger.warning("上下文管理器已经在运行")
|
||||
return
|
||||
|
||||
await self.start_auto_cleanup()
|
||||
logger.info("上下文管理器已启动")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止上下文管理器"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
await self.stop_auto_cleanup()
|
||||
logger.info("上下文管理器已停止")
|
||||
|
||||
async def start_auto_cleanup(self, interval: Optional[float] = None) -> None:
|
||||
"""启动自动清理
|
||||
|
||||
Args:
|
||||
interval: 清理间隔(秒)
|
||||
"""
|
||||
if not self.auto_cleanup:
|
||||
logger.info("自动清理已禁用")
|
||||
return
|
||||
|
||||
if self.is_running:
|
||||
logger.warning("自动清理已在运行")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
cleanup_interval = interval or self.cleanup_interval
|
||||
logger.info(f"启动自动清理(间隔: {cleanup_interval}s)")
|
||||
|
||||
import asyncio
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_loop(cleanup_interval))
|
||||
|
||||
async def stop_auto_cleanup(self) -> None:
|
||||
"""停止自动清理"""
|
||||
self.is_running = False
|
||||
if self.cleanup_task and not self.cleanup_task.done():
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await self.cleanup_task
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("自动清理已停止")
|
||||
|
||||
async def _cleanup_loop(self, interval: float) -> None:
|
||||
"""清理循环
|
||||
|
||||
Args:
|
||||
interval: 清理间隔
|
||||
"""
|
||||
while self.is_running:
|
||||
try:
|
||||
await asyncio.sleep(interval)
|
||||
self.cleanup_inactive_contexts()
|
||||
self._cleanup_expired_contexts()
|
||||
logger.debug("自动清理完成")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理循环出错: {e}", exc_info=True)
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
def _cleanup_expired_contexts(self) -> None:
|
||||
"""清理过期上下文"""
|
||||
current_time = time.time()
|
||||
expired_contexts = []
|
||||
|
||||
for stream_id, metadata in self.context_metadata.items():
|
||||
created_time = metadata.get("created_time", current_time)
|
||||
if current_time - created_time > self.context_ttl:
|
||||
expired_contexts.append(stream_id)
|
||||
|
||||
for stream_id in expired_contexts:
|
||||
self.remove_stream_context(stream_id)
|
||||
|
||||
if expired_contexts:
|
||||
logger.info(f"清理了 {len(expired_contexts)} 个过期上下文")
|
||||
|
||||
def get_active_streams(self) -> List[str]:
|
||||
"""获取活跃流列表
|
||||
|
||||
Returns:
|
||||
List[str]: 活跃流ID列表
|
||||
"""
|
||||
return list(self.stream_contexts.keys())
|
||||
|
||||
|
||||
# 全局上下文管理器实例
|
||||
context_manager = StreamContextManager()
|
||||
logger.error(f"更新单流能量失败 {self.stream_id}: {e}")
|
||||
@@ -14,11 +14,10 @@ from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
from .sleep_manager.sleep_manager import SleepManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
from src.config.config import global_config
|
||||
from .context_manager import context_manager
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
@@ -45,8 +44,7 @@ class MessageManager:
|
||||
self.sleep_manager = SleepManager()
|
||||
self.wakeup_manager = WakeUpManager(self.sleep_manager)
|
||||
|
||||
# 初始化上下文管理器
|
||||
self.context_manager = context_manager
|
||||
# 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context_manager
|
||||
|
||||
async def start(self):
|
||||
"""启动消息管理器"""
|
||||
@@ -57,7 +55,7 @@ class MessageManager:
|
||||
self.is_running = True
|
||||
self.manager_task = asyncio.create_task(self._manager_loop())
|
||||
await self.wakeup_manager.start()
|
||||
await self.context_manager.start()
|
||||
# await self.context_manager.start() # 已删除,需要重构
|
||||
logger.info("消息管理器已启动")
|
||||
|
||||
async def stop(self):
|
||||
@@ -73,28 +71,31 @@ class MessageManager:
|
||||
self.manager_task.cancel()
|
||||
|
||||
await self.wakeup_manager.stop()
|
||||
await self.context_manager.stop()
|
||||
# await self.context_manager.stop() # 已删除,需要重构
|
||||
|
||||
logger.info("消息管理器已停止")
|
||||
|
||||
def add_message(self, stream_id: str, message: DatabaseMessages):
|
||||
"""添加消息到指定聊天流"""
|
||||
# 检查流上下文是否存在,不存在则创建
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context:
|
||||
# 创建新的流上下文
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
context = StreamContext(stream_id=stream_id)
|
||||
# 将创建的上下文添加到管理器
|
||||
self.context_manager.add_stream_context(stream_id, context)
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
|
||||
# 使用 context_manager 添加消息
|
||||
success = self.context_manager.add_message_to_context(stream_id, message)
|
||||
if not chat_stream:
|
||||
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
if success:
|
||||
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
|
||||
else:
|
||||
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
|
||||
# 使用 ChatStream 的 context_manager 添加消息
|
||||
success = chat_stream.context_manager.add_message(message)
|
||||
|
||||
if success:
|
||||
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
|
||||
else:
|
||||
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}")
|
||||
|
||||
def update_message(
|
||||
self,
|
||||
@@ -105,17 +106,60 @@ class MessageManager:
|
||||
should_reply: bool = None,
|
||||
):
|
||||
"""更新消息信息"""
|
||||
# 使用 context_manager 更新消息信息
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
context.update_message_info(message_id, interest_value, actions, should_reply)
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
|
||||
if not chat_stream:
|
||||
logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
# 构建更新字典
|
||||
updates = {}
|
||||
if interest_value is not None:
|
||||
updates["interest_value"] = interest_value
|
||||
if actions is not None:
|
||||
updates["actions"] = actions
|
||||
if should_reply is not None:
|
||||
updates["should_reply"] = should_reply
|
||||
|
||||
# 使用 ChatStream 的 context_manager 更新消息
|
||||
if updates:
|
||||
success = chat_stream.context_manager.update_message(message_id, updates)
|
||||
if success:
|
||||
logger.debug(f"更新消息 {message_id} 成功")
|
||||
else:
|
||||
logger.warning(f"更新消息 {message_id} 失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
||||
|
||||
def add_action(self, stream_id: str, message_id: str, action: str):
|
||||
"""添加动作到消息"""
|
||||
# 使用 context_manager 添加动作到消息
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
context.add_action_to_message(message_id, action)
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
|
||||
if not chat_stream:
|
||||
logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
# 使用 ChatStream 的 context_manager 添加动作
|
||||
# 注意:这里需要根据实际的 API 调整
|
||||
# 假设我们可以通过 update_message 来添加动作
|
||||
success = chat_stream.context_manager.update_message(
|
||||
message_id, {"actions": [action]}
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.debug(f"为消息 {message_id} 添加动作 {action} 成功")
|
||||
else:
|
||||
logger.warning(f"为消息 {message_id} 添加动作 {action} 失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}")
|
||||
|
||||
async def _manager_loop(self):
|
||||
"""管理器主循环 - 独立聊天流分发周期版本"""
|
||||
@@ -145,38 +189,53 @@ class MessageManager:
|
||||
active_streams = 0
|
||||
total_unread = 0
|
||||
|
||||
# 使用 context_manager 获取活跃的流
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
# 通过 ChatManager 获取所有活跃的流
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
active_stream_ids = list(chat_manager.streams.keys())
|
||||
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context:
|
||||
continue
|
||||
for stream_id in active_stream_ids:
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
continue
|
||||
|
||||
active_streams += 1
|
||||
# 检查流是否活跃
|
||||
context = chat_stream.stream_context
|
||||
if not context.is_active:
|
||||
continue
|
||||
|
||||
# 检查是否有未读消息
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if unread_messages:
|
||||
total_unread += len(unread_messages)
|
||||
active_streams += 1
|
||||
|
||||
# 如果没有处理任务,创建一个
|
||||
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
# 检查是否有未读消息
|
||||
unread_messages = chat_stream.context_manager.get_unread_messages()
|
||||
if unread_messages:
|
||||
total_unread += len(unread_messages)
|
||||
|
||||
# 更新统计
|
||||
self.stats.active_streams = active_streams
|
||||
self.stats.total_unread_messages = total_unread
|
||||
# 如果没有处理任务,创建一个
|
||||
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
|
||||
# 更新统计
|
||||
self.stats.active_streams = active_streams
|
||||
self.stats.total_unread_messages = total_unread
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查所有聊天流时发生错误: {e}")
|
||||
|
||||
async def _process_stream_messages(self, stream_id: str):
|
||||
"""处理指定聊天流的消息"""
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return
|
||||
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
logger.warning(f"处理消息失败: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
context = chat_stream.stream_context
|
||||
|
||||
# 获取未读消息
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
unread_messages = chat_stream.context_manager.get_unread_messages()
|
||||
if not unread_messages:
|
||||
return
|
||||
|
||||
@@ -250,8 +309,15 @@ class MessageManager:
|
||||
|
||||
def deactivate_stream(self, stream_id: str):
|
||||
"""停用聊天流"""
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
context = chat_stream.stream_context
|
||||
context.is_active = False
|
||||
|
||||
# 取消处理任务
|
||||
@@ -260,27 +326,50 @@ class MessageManager:
|
||||
|
||||
logger.info(f"停用聊天流: {stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"停用聊天流 {stream_id} 时发生错误: {e}")
|
||||
|
||||
def activate_stream(self, stream_id: str):
|
||||
"""激活聊天流"""
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
context = chat_stream.stream_context
|
||||
context.is_active = True
|
||||
logger.info(f"激活聊天流: {stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}")
|
||||
|
||||
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
|
||||
"""获取聊天流统计"""
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return None
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
return None
|
||||
|
||||
return StreamStats(
|
||||
stream_id=stream_id,
|
||||
is_active=context.is_active,
|
||||
unread_count=len(self.context_manager.get_unread_messages(stream_id)),
|
||||
history_count=len(context.history_messages),
|
||||
last_check_time=context.last_check_time,
|
||||
has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
|
||||
)
|
||||
context = chat_stream.stream_context
|
||||
unread_count = len(chat_stream.context_manager.get_unread_messages())
|
||||
|
||||
return StreamStats(
|
||||
stream_id=stream_id,
|
||||
is_active=context.is_active,
|
||||
unread_count=unread_count,
|
||||
history_count=len(context.history_messages),
|
||||
last_check_time=context.last_check_time,
|
||||
has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天流 {stream_id} 统计时发生错误: {e}")
|
||||
return None
|
||||
|
||||
def get_manager_stats(self) -> Dict[str, Any]:
|
||||
"""获取管理器统计"""
|
||||
@@ -295,9 +384,36 @@ class MessageManager:
|
||||
|
||||
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
|
||||
"""清理不活跃的聊天流"""
|
||||
# 使用 context_manager 的自动清理功能
|
||||
self.context_manager.cleanup_inactive_contexts(max_inactive_hours * 3600)
|
||||
logger.info("已启动不活跃聊天流清理")
|
||||
try:
|
||||
# 通过 ChatManager 清理不活跃的流
|
||||
chat_manager = get_chat_manager()
|
||||
current_time = time.time()
|
||||
max_inactive_seconds = max_inactive_hours * 3600
|
||||
|
||||
inactive_streams = []
|
||||
for stream_id, chat_stream in chat_manager.streams.items():
|
||||
# 检查最后活跃时间
|
||||
if current_time - chat_stream.last_active_time > max_inactive_seconds:
|
||||
inactive_streams.append(stream_id)
|
||||
|
||||
# 清理不活跃的流
|
||||
for stream_id in inactive_streams:
|
||||
try:
|
||||
# 清理流的内容
|
||||
chat_stream.context_manager.clear_context()
|
||||
# 从 ChatManager 中移除
|
||||
del chat_manager.streams[stream_id]
|
||||
logger.info(f"清理不活跃聊天流: {stream_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"清理聊天流 {stream_id} 失败: {e}")
|
||||
|
||||
if inactive_streams:
|
||||
logger.info(f"已清理 {len(inactive_streams)} 个不活跃聊天流")
|
||||
else:
|
||||
logger.debug("没有需要清理的不活跃聊天流")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理不活跃聊天流时发生错误: {e}")
|
||||
|
||||
async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str):
|
||||
"""检查并处理消息打断"""
|
||||
@@ -376,115 +492,123 @@ class MessageManager:
|
||||
min_delay = float("inf")
|
||||
|
||||
# 找到最近需要检查的流
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
for _stream_id, chat_stream in chat_manager.streams.items():
|
||||
context = chat_stream.stream_context
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
|
||||
time_until_check = context.next_check_time - current_time
|
||||
if time_until_check > 0:
|
||||
min_delay = min(min_delay, time_until_check)
|
||||
else:
|
||||
min_delay = 0.1 # 立即检查
|
||||
break
|
||||
time_until_check = context.next_check_time - current_time
|
||||
if time_until_check > 0:
|
||||
min_delay = min(min_delay, time_until_check)
|
||||
else:
|
||||
min_delay = 0.1 # 立即检查
|
||||
break
|
||||
|
||||
# 如果没有活跃流,使用默认间隔
|
||||
if min_delay == float("inf"):
|
||||
# 如果没有活跃流,使用默认间隔
|
||||
if min_delay == float("inf"):
|
||||
return self.check_interval
|
||||
|
||||
# 确保最小延迟
|
||||
return max(0.1, min(min_delay, self.check_interval))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算下次检查延迟时发生错误: {e}")
|
||||
return self.check_interval
|
||||
|
||||
# 确保最小延迟
|
||||
return max(0.1, min(min_delay, self.check_interval))
|
||||
|
||||
async def _check_streams_with_individual_intervals(self):
|
||||
"""检查所有达到检查时间的聊天流"""
|
||||
current_time = time.time()
|
||||
processed_streams = 0
|
||||
|
||||
# 使用 context_manager 获取活跃的流
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
# 通过 ChatManager 获取活跃的流
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
for stream_id, chat_stream in chat_manager.streams.items():
|
||||
context = chat_stream.stream_context
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
# 检查是否达到检查时间
|
||||
if current_time >= context.next_check_time:
|
||||
# 更新检查时间
|
||||
context.last_check_time = current_time
|
||||
|
||||
# 检查是否达到检查时间
|
||||
if current_time >= context.next_check_time:
|
||||
# 更新检查时间
|
||||
context.last_check_time = current_time
|
||||
# 计算下次检查时间和分发周期
|
||||
if global_config.chat.dynamic_distribution_enabled:
|
||||
context.distribution_interval = self._calculate_stream_distribution_interval(context)
|
||||
else:
|
||||
context.distribution_interval = self.check_interval
|
||||
|
||||
# 计算下次检查时间和分发周期
|
||||
if global_config.chat.dynamic_distribution_enabled:
|
||||
context.distribution_interval = self._calculate_stream_distribution_interval(context)
|
||||
else:
|
||||
context.distribution_interval = self.check_interval
|
||||
# 设置下次检查时间
|
||||
context.next_check_time = current_time + context.distribution_interval
|
||||
|
||||
# 设置下次检查时间
|
||||
context.next_check_time = current_time + context.distribution_interval
|
||||
# 检查未读消息
|
||||
unread_messages = chat_stream.context_manager.get_unread_messages()
|
||||
if unread_messages:
|
||||
processed_streams += 1
|
||||
self.stats.total_unread_messages = len(unread_messages)
|
||||
|
||||
# 检查未读消息
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if unread_messages:
|
||||
processed_streams += 1
|
||||
self.stats.total_unread_messages = len(unread_messages)
|
||||
# 如果没有处理任务,创建一个
|
||||
if not context.processing_task or context.processing_task.done():
|
||||
focus_energy = chat_stream.focus_energy
|
||||
|
||||
# 如果没有处理任务,创建一个
|
||||
if not context.processing_task or context.processing_task.done():
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
# 根据优先级记录日志
|
||||
if focus_energy >= 0.7:
|
||||
logger.info(
|
||||
f"高优先级流 {stream_id} 开始处理 | "
|
||||
f"focus_energy: {focus_energy:.3f} | "
|
||||
f"分发周期: {context.distribution_interval:.2f}s | "
|
||||
f"未读消息: {len(unread_messages)}"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"流 {stream_id} 开始处理 | "
|
||||
f"focus_energy: {focus_energy:.3f} | "
|
||||
f"分发周期: {context.distribution_interval:.2f}s"
|
||||
)
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
focus_energy = chat_stream.focus_energy if chat_stream else 0.5
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
|
||||
# 根据优先级记录日志
|
||||
if focus_energy >= 0.7:
|
||||
logger.info(
|
||||
f"高优先级流 {stream_id} 开始处理 | "
|
||||
f"focus_energy: {focus_energy:.3f} | "
|
||||
f"分发周期: {context.distribution_interval:.2f}s | "
|
||||
f"未读消息: {len(unread_messages)}"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"流 {stream_id} 开始处理 | "
|
||||
f"focus_energy: {focus_energy:.3f} | "
|
||||
f"分发周期: {context.distribution_interval:.2f}s"
|
||||
)
|
||||
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
except Exception as e:
|
||||
logger.error(f"检查独立分发周期的聊天流时发生错误: {e}")
|
||||
|
||||
# 更新活跃流计数
|
||||
active_count = len(self.context_manager.get_active_streams())
|
||||
self.stats.active_streams = active_count
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
active_count = len([s for s in chat_manager.streams.values() if s.stream_context.is_active])
|
||||
self.stats.active_streams = active_count
|
||||
|
||||
if processed_streams > 0:
|
||||
logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}")
|
||||
if processed_streams > 0:
|
||||
logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"更新活跃流计数时发生错误: {e}")
|
||||
|
||||
async def _check_all_streams_with_priority(self):
|
||||
"""按优先级检查所有聊天流,高focus_energy的流优先处理"""
|
||||
if not self.context_manager.get_active_streams():
|
||||
return
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
if not chat_manager.streams:
|
||||
return
|
||||
|
||||
# 获取活跃的聊天流并按focus_energy排序
|
||||
active_streams = []
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
# 获取活跃的聊天流并按focus_energy排序
|
||||
active_streams = []
|
||||
for stream_id, chat_stream in chat_manager.streams.items():
|
||||
context = chat_stream.stream_context
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
|
||||
# 获取focus_energy,如果不存在则使用默认值
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
focus_energy = 0.5
|
||||
if chat_stream:
|
||||
# 获取focus_energy
|
||||
focus_energy = chat_stream.focus_energy
|
||||
|
||||
# 计算流优先级分数
|
||||
priority_score = self._calculate_stream_priority(context, focus_energy)
|
||||
active_streams.append((priority_score, stream_id, context))
|
||||
# 计算流优先级分数
|
||||
priority_score = self._calculate_stream_priority(context, focus_energy)
|
||||
active_streams.append((priority_score, stream_id, context))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取活跃流列表时发生错误: {e}")
|
||||
return
|
||||
|
||||
# 按优先级降序排序
|
||||
active_streams.sort(reverse=True, key=lambda x: x[0])
|
||||
@@ -497,21 +621,29 @@ class MessageManager:
|
||||
active_stream_count += 1
|
||||
|
||||
# 检查是否有未读消息
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if unread_messages:
|
||||
total_unread += len(unread_messages)
|
||||
try:
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
continue
|
||||
|
||||
# 如果没有处理任务,创建一个
|
||||
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
unread_messages = chat_stream.context_manager.get_unread_messages()
|
||||
if unread_messages:
|
||||
total_unread += len(unread_messages)
|
||||
|
||||
# 高优先级流的额外日志
|
||||
if priority_score > 0.7:
|
||||
logger.info(
|
||||
f"高优先级流 {stream_id} 开始处理 | "
|
||||
f"优先级: {priority_score:.3f} | "
|
||||
f"未读消息: {len(unread_messages)}"
|
||||
)
|
||||
# 如果没有处理任务,创建一个
|
||||
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
|
||||
# 高优先级流的额外日志
|
||||
if priority_score > 0.7:
|
||||
logger.info(
|
||||
f"高优先级流 {stream_id} 开始处理 | "
|
||||
f"优先级: {priority_score:.3f} | "
|
||||
f"未读消息: {len(unread_messages)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理流 {stream_id} 的未读消息时发生错误: {e}")
|
||||
continue
|
||||
|
||||
# 更新统计
|
||||
self.stats.active_streams = active_stream_count
|
||||
@@ -536,22 +668,33 @@ class MessageManager:
|
||||
|
||||
def _clear_all_unread_messages(self, stream_id: str):
|
||||
"""清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读"""
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if not unread_messages:
|
||||
return
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
logger.warning(f"清除消息失败: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
|
||||
# 获取未读消息
|
||||
unread_messages = chat_stream.context_manager.get_unread_messages()
|
||||
if not unread_messages:
|
||||
return
|
||||
|
||||
# 将所有未读消息标记为已读
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
|
||||
try:
|
||||
context.mark_message_as_read(msg.message_id)
|
||||
self.stats.total_processed_messages += 1
|
||||
logger.debug(f"强制清除消息 {msg.message_id},标记为已读")
|
||||
except Exception as e:
|
||||
logger.error(f"清除消息 {msg.message_id} 时出错: {e}")
|
||||
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
|
||||
|
||||
# 将所有未读消息标记为已读
|
||||
message_ids = [msg.message_id for msg in unread_messages]
|
||||
success = chat_stream.context_manager.mark_messages_as_read(message_ids)
|
||||
|
||||
if success:
|
||||
self.stats.total_processed_messages += len(unread_messages)
|
||||
logger.debug(f"强制清除 {len(unread_messages)} 条消息,标记为已读")
|
||||
else:
|
||||
logger.error("标记未读消息为已读失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清除未读消息时发生错误: {e}")
|
||||
|
||||
|
||||
# 创建全局消息管理器实例
|
||||
|
||||
@@ -49,10 +49,18 @@ class ChatStream:
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
|
||||
# 创建StreamContext
|
||||
self.stream_context: StreamContext = StreamContext(
|
||||
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
|
||||
)
|
||||
|
||||
# 创建单流上下文管理器
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
|
||||
self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
|
||||
stream_id=stream_id, context=self.stream_context
|
||||
)
|
||||
|
||||
# 基础参数
|
||||
self.base_interest_energy = 0.5 # 默认基础兴趣度
|
||||
self._focus_energy = 0.5 # 内部存储的focus_energy值
|
||||
@@ -61,6 +69,37 @@ class ChatStream:
|
||||
# 自动加载历史消息
|
||||
self._load_history_messages()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""自定义深拷贝方法,避免复制不可序列化的 asyncio.Task 对象"""
|
||||
import copy
|
||||
|
||||
# 创建新的实例
|
||||
new_stream = ChatStream(
|
||||
stream_id=self.stream_id,
|
||||
platform=self.platform,
|
||||
user_info=copy.deepcopy(self.user_info, memo),
|
||||
group_info=copy.deepcopy(self.group_info, memo),
|
||||
)
|
||||
|
||||
# 复制基本属性
|
||||
new_stream.create_time = self.create_time
|
||||
new_stream.last_active_time = self.last_active_time
|
||||
new_stream.sleep_pressure = self.sleep_pressure
|
||||
new_stream.saved = self.saved
|
||||
new_stream.base_interest_energy = self.base_interest_energy
|
||||
new_stream._focus_energy = self._focus_energy
|
||||
new_stream.no_reply_consecutive = self.no_reply_consecutive
|
||||
|
||||
# 复制 stream_context,但跳过 processing_task
|
||||
new_stream.stream_context = copy.deepcopy(self.stream_context, memo)
|
||||
if hasattr(new_stream.stream_context, 'processing_task'):
|
||||
new_stream.stream_context.processing_task = None
|
||||
|
||||
# 复制 context_manager
|
||||
new_stream.context_manager = copy.deepcopy(self.context_manager, memo)
|
||||
|
||||
return new_stream
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
@@ -74,10 +113,10 @@ class ChatStream:
|
||||
"focus_energy": self.focus_energy,
|
||||
# 基础兴趣度
|
||||
"base_interest_energy": self.base_interest_energy,
|
||||
# 新增stream_context信息
|
||||
# stream_context基本信息
|
||||
"stream_context_chat_type": self.stream_context.chat_type.value,
|
||||
"stream_context_chat_mode": self.stream_context.chat_mode.value,
|
||||
# 新增interruption_count信息
|
||||
# 统计信息
|
||||
"interruption_count": self.stream_context.interruption_count,
|
||||
}
|
||||
|
||||
@@ -109,6 +148,14 @@ class ChatStream:
|
||||
if "interruption_count" in data:
|
||||
instance.stream_context.interruption_count = data["interruption_count"]
|
||||
|
||||
# 确保 context_manager 已初始化
|
||||
if not hasattr(instance, "context_manager"):
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
|
||||
instance.context_manager = SingleStreamContextManager(
|
||||
stream_id=instance.stream_id, context=instance.stream_context
|
||||
)
|
||||
|
||||
return instance
|
||||
|
||||
def update_active_time(self):
|
||||
@@ -195,12 +242,14 @@ class ChatStream:
|
||||
self.stream_context.priority_info = getattr(message, "priority_info", None)
|
||||
|
||||
# 调试日志:记录数据转移情况
|
||||
logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, "
|
||||
f"chat_id: {db_message.chat_id}, "
|
||||
f"is_mentioned: {db_message.is_mentioned}, "
|
||||
f"is_emoji: {db_message.is_emoji}, "
|
||||
f"is_picid: {db_message.is_picid}, "
|
||||
f"interest_value: {db_message.interest_value}")
|
||||
logger.debug(
|
||||
f"消息数据转移完成 - message_id: {db_message.message_id}, "
|
||||
f"chat_id: {db_message.chat_id}, "
|
||||
f"is_mentioned: {db_message.is_mentioned}, "
|
||||
f"is_emoji: {db_message.is_emoji}, "
|
||||
f"is_picid: {db_message.is_picid}, "
|
||||
f"interest_value: {db_message.interest_value}"
|
||||
)
|
||||
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
|
||||
"""安全获取消息的actions字段"""
|
||||
@@ -213,6 +262,7 @@ class ChatStream:
|
||||
if isinstance(actions, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
actions = json.loads(actions)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"无法解析actions JSON字符串: {actions}")
|
||||
@@ -269,14 +319,17 @@ class ChatStream:
|
||||
|
||||
@property
|
||||
def focus_energy(self) -> float:
|
||||
"""使用重构后的能量管理器计算focus_energy"""
|
||||
try:
|
||||
from src.chat.energy_system import energy_manager
|
||||
"""获取缓存的focus_energy值"""
|
||||
if hasattr(self, "_focus_energy"):
|
||||
return self._focus_energy
|
||||
else:
|
||||
return 0.5
|
||||
|
||||
# 获取所有消息
|
||||
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
|
||||
unread_messages = self.stream_context.get_unread_messages()
|
||||
all_messages = history_messages + unread_messages
|
||||
async def calculate_focus_energy(self) -> float:
|
||||
"""异步计算focus_energy"""
|
||||
try:
|
||||
# 使用单流上下文管理器获取消息
|
||||
all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size)
|
||||
|
||||
# 获取用户ID
|
||||
user_id = None
|
||||
@@ -284,10 +337,10 @@ class ChatStream:
|
||||
user_id = str(self.user_info.user_id)
|
||||
|
||||
# 使用能量管理器计算
|
||||
energy = energy_manager.calculate_focus_energy(
|
||||
stream_id=self.stream_id,
|
||||
messages=all_messages,
|
||||
user_id=user_id
|
||||
from src.chat.energy_system import energy_manager
|
||||
|
||||
energy = await energy_manager.calculate_focus_energy(
|
||||
stream_id=self.stream_id, messages=all_messages, user_id=user_id
|
||||
)
|
||||
|
||||
# 更新内部存储
|
||||
@@ -299,7 +352,7 @@ class ChatStream:
|
||||
except Exception as e:
|
||||
logger.error(f"获取focus_energy失败: {e}", exc_info=True)
|
||||
# 返回缓存的值或默认值
|
||||
if hasattr(self, '_focus_energy'):
|
||||
if hasattr(self, "_focus_energy"):
|
||||
return self._focus_energy
|
||||
else:
|
||||
return 0.5
|
||||
@@ -309,7 +362,7 @@ class ChatStream:
|
||||
"""设置focus_energy值(主要用于初始化或特殊场景)"""
|
||||
self._focus_energy = max(0.0, min(1.0, value))
|
||||
|
||||
def _get_user_relationship_score(self) -> float:
|
||||
async def _get_user_relationship_score(self) -> float:
|
||||
"""获取用户关系分"""
|
||||
# 使用插件内部的兴趣度评分系统
|
||||
try:
|
||||
@@ -317,7 +370,7 @@ class ChatStream:
|
||||
|
||||
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||
user_id = str(self.user_info.user_id)
|
||||
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||
relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||
logger.debug(f"ChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
|
||||
return max(0.0, min(1.0, relationship_score))
|
||||
|
||||
@@ -346,7 +399,8 @@ class ChatStream:
|
||||
.order_by(desc(Messages.time))
|
||||
.limit(global_config.chat.max_context_size)
|
||||
)
|
||||
results = session.execute(stmt).scalars().all()
|
||||
result = session.execute(stmt)
|
||||
results = result.scalars().all()
|
||||
return results
|
||||
|
||||
# 在线程中执行数据库查询
|
||||
@@ -404,7 +458,9 @@ class ChatStream:
|
||||
)
|
||||
|
||||
# 添加调试日志:检查从数据库加载的interest_value
|
||||
logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}")
|
||||
logger.debug(
|
||||
f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}"
|
||||
)
|
||||
|
||||
# 标记为已读并添加到历史消息
|
||||
db_message.is_read = True
|
||||
@@ -548,7 +604,11 @@ class ChatManager:
|
||||
# 检查数据库中是否存在
|
||||
async def _db_find_stream_async(s_id: str):
|
||||
async with get_db_session() as session:
|
||||
return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalars().first()
|
||||
return (
|
||||
(await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
model_instance = await _db_find_stream_async(stream_id)
|
||||
|
||||
@@ -603,6 +663,15 @@ class ChatManager:
|
||||
stream.set_context(self.last_messages[stream_id])
|
||||
else:
|
||||
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||
|
||||
# 确保 ChatStream 有自己的 context_manager
|
||||
if not hasattr(stream, "context_manager"):
|
||||
# 创建新的单流上下文管理器
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
stream.context_manager = SingleStreamContextManager(
|
||||
stream_id=stream_id, context=stream.stream_context
|
||||
)
|
||||
|
||||
# 保存到内存和数据库
|
||||
self.streams[stream_id] = stream
|
||||
await self._save_stream(stream)
|
||||
@@ -704,7 +773,8 @@ class ChatManager:
|
||||
async def _db_load_all_streams_async():
|
||||
loaded_streams_data = []
|
||||
async with get_db_session() as session:
|
||||
for model_instance in (await session.execute(select(ChatStreams))).scalars().all():
|
||||
result = await session.execute(select(ChatStreams))
|
||||
for model_instance in result.scalars().all():
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
@@ -752,6 +822,13 @@ class ChatManager:
|
||||
self.streams[stream.stream_id] = stream
|
||||
if stream.stream_id in self.last_messages:
|
||||
stream.set_context(self.last_messages[stream.stream_id])
|
||||
|
||||
# 确保 ChatStream 有自己的 context_manager
|
||||
if not hasattr(stream, "context_manager"):
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
stream.context_manager = SingleStreamContextManager(
|
||||
stream_id=stream.stream_id, context=stream.stream_context
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class MessageStorage:
|
||||
processed_plain_text = message.processed_plain_text
|
||||
|
||||
if processed_plain_text:
|
||||
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_processed_plain_text = ""
|
||||
@@ -129,9 +129,9 @@ class MessageStorage:
|
||||
key_words=key_words,
|
||||
key_words_lite=key_words_lite,
|
||||
)
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
session.add(new_message)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
except Exception:
|
||||
logger.exception("存储消息失败")
|
||||
@@ -174,13 +174,13 @@ class MessageStorage:
|
||||
# 使用上下文管理器确保session正确管理
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
matched_message = session.execute(
|
||||
async with get_db_session() as session:
|
||||
matched_message = (await session.execute(
|
||||
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
|
||||
).scalar()
|
||||
)).scalar()
|
||||
|
||||
if matched_message:
|
||||
session.execute(
|
||||
await session.execute(
|
||||
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
|
||||
)
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
@@ -195,7 +195,7 @@ class MessageStorage:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_image_descriptions(text: str) -> str:
|
||||
async def replace_image_descriptions(text: str) -> str:
|
||||
"""将[图片:描述]替换为[picid:image_id]"""
|
||||
# 先检查文本中是否有图片标记
|
||||
pattern = r"\[图片:([^\]]+)\]"
|
||||
@@ -205,15 +205,15 @@ class MessageStorage:
|
||||
logger.debug("文本中没有图片标记,直接返回原文本")
|
||||
return text
|
||||
|
||||
def replace_match(match):
|
||||
async def replace_match(match):
|
||||
description = match.group(1).strip()
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
image_record = session.execute(
|
||||
async with get_db_session() as session:
|
||||
image_record = (await session.execute(
|
||||
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
|
||||
).scalar()
|
||||
)).scalar()
|
||||
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||
except Exception:
|
||||
return match.group(0)
|
||||
@@ -271,7 +271,8 @@ class MessageStorage:
|
||||
)
|
||||
).limit(50) # 限制每次修复的数量,避免性能问题
|
||||
|
||||
messages_to_fix = session.execute(query).scalars().all()
|
||||
result = session.execute(query)
|
||||
messages_to_fix = result.scalars().all()
|
||||
fixed_count = 0
|
||||
|
||||
for msg in messages_to_fix:
|
||||
|
||||
@@ -824,7 +824,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
description = "[图片内容未知]" # 默认描述
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
|
||||
result = session.execute(select(Images).where(Images.image_id == pic_id))
|
||||
image = result.scalar_one_or_none()
|
||||
if image and image.description: # type: ignore
|
||||
description = image.description
|
||||
except Exception:
|
||||
|
||||
@@ -308,7 +308,8 @@ class ImageManager:
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 优先检查Images表中是否已有完整的描述
|
||||
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
|
||||
existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
||||
result.scalar()
|
||||
if existing_image:
|
||||
# 更新计数
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
@@ -527,7 +528,8 @@ class ImageManager:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
async with get_db_session() as session:
|
||||
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
|
||||
existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
||||
result.scalar()
|
||||
if existing_image:
|
||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||
if (
|
||||
|
||||
@@ -22,13 +22,14 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import get_db_session, Videos
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = get_logger("utils_video")
|
||||
|
||||
# Rust模块可用性检测
|
||||
RUST_VIDEO_AVAILABLE = False
|
||||
try:
|
||||
import rust_video
|
||||
import rust_video # pyright: ignore[reportMissingImports]
|
||||
|
||||
RUST_VIDEO_AVAILABLE = True
|
||||
logger.info("✅ Rust 视频处理模块加载成功")
|
||||
@@ -202,19 +203,21 @@ class VideoAnalyzer:
|
||||
hash_obj.update(video_data)
|
||||
return hash_obj.hexdigest()
|
||||
|
||||
def _check_video_exists(self, video_hash: str) -> Optional[Videos]:
|
||||
async def _check_video_exists(self, video_hash: str) -> Optional[Videos]:
|
||||
"""检查视频是否已经分析过"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 明确刷新会话以确保看到其他事务的最新提交
|
||||
session.expire_all()
|
||||
return session.query(Videos).filter(Videos.video_hash == video_hash).first()
|
||||
await session.expire_all()
|
||||
stmt = select(Videos).where(Videos.video_hash == video_hash)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.warning(f"检查视频是否存在时出错: {e}")
|
||||
return None
|
||||
|
||||
def _store_video_result(
|
||||
self, video_hash: str, description: str, metadata: Optional[Dict] = None
|
||||
async def _store_video_result(
|
||||
self, video_hash: str, description: str, metadata: Optional[Dict] = None
|
||||
) -> Optional[Videos]:
|
||||
"""存储视频分析结果到数据库"""
|
||||
# 检查描述是否为错误信息,如果是则不保存
|
||||
@@ -223,9 +226,11 @@ class VideoAnalyzer:
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 只根据video_hash查找
|
||||
existing_video = session.query(Videos).filter(Videos.video_hash == video_hash).first()
|
||||
stmt = select(Videos).where(Videos.video_hash == video_hash)
|
||||
result = await session.execute(stmt)
|
||||
existing_video = result.scalar_one_or_none()
|
||||
|
||||
if existing_video:
|
||||
# 如果已存在,更新描述和计数
|
||||
@@ -238,8 +243,8 @@ class VideoAnalyzer:
|
||||
existing_video.fps = metadata.get("fps")
|
||||
existing_video.resolution = metadata.get("resolution")
|
||||
existing_video.file_size = metadata.get("file_size")
|
||||
session.commit()
|
||||
session.refresh(existing_video)
|
||||
await session.commit()
|
||||
await session.refresh(existing_video)
|
||||
logger.info(f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}")
|
||||
return existing_video
|
||||
else:
|
||||
@@ -254,8 +259,8 @@ class VideoAnalyzer:
|
||||
video_record.file_size = metadata.get("file_size")
|
||||
|
||||
session.add(video_record)
|
||||
session.commit()
|
||||
session.refresh(video_record)
|
||||
await session.commit()
|
||||
await session.refresh(video_record)
|
||||
logger.info(f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}...")
|
||||
return video_record
|
||||
except Exception as e:
|
||||
@@ -704,7 +709,7 @@ class VideoAnalyzer:
|
||||
logger.info("✅ 等待结束,检查是否有处理结果")
|
||||
|
||||
# 检查是否有结果了
|
||||
existing_video = self._check_video_exists(video_hash)
|
||||
existing_video = await self._check_video_exists(video_hash)
|
||||
if existing_video:
|
||||
logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})")
|
||||
return {"summary": existing_video.description}
|
||||
@@ -718,7 +723,7 @@ class VideoAnalyzer:
|
||||
logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)")
|
||||
|
||||
# 再次检查数据库(可能在等待期间已经有结果了)
|
||||
existing_video = self._check_video_exists(video_hash)
|
||||
existing_video = await self._check_video_exists(video_hash)
|
||||
if existing_video:
|
||||
logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})")
|
||||
video_event.set() # 通知其他等待者
|
||||
@@ -749,7 +754,7 @@ class VideoAnalyzer:
|
||||
# 保存分析结果到数据库(仅保存成功的结果)
|
||||
if success and not result.startswith("❌"):
|
||||
metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()}
|
||||
self._store_video_result(video_hash=video_hash, description=result, metadata=metadata)
|
||||
await self._store_video_result(video_hash=video_hash, description=result, metadata=metadata)
|
||||
logger.info("✅ 分析结果已保存到数据库")
|
||||
else:
|
||||
logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试")
|
||||
|
||||
@@ -22,9 +22,9 @@ class DatabaseProxy:
|
||||
self._session = None
|
||||
|
||||
@staticmethod
|
||||
def initialize(*args, **kwargs):
|
||||
async def initialize(*args, **kwargs):
|
||||
"""初始化数据库连接"""
|
||||
return initialize_database_compat()
|
||||
return await initialize_database_compat()
|
||||
|
||||
|
||||
class SQLAlchemyTransaction:
|
||||
@@ -88,7 +88,7 @@ async def initialize_sql_database(database_config):
|
||||
logger.info(f" 数据库文件: {db_path}")
|
||||
|
||||
# 使用SQLAlchemy初始化
|
||||
success = initialize_database_compat()
|
||||
success = await initialize_database_compat()
|
||||
if success:
|
||||
_sql_engine = await get_engine()
|
||||
logger.info("SQLAlchemy数据库初始化成功")
|
||||
|
||||
@@ -706,7 +706,8 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
raise RuntimeError("Database session not initialized")
|
||||
session = SessionLocal()
|
||||
yield session
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(f"数据库会话错误: {e}")
|
||||
if session:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
@@ -101,7 +101,8 @@ def find_messages(
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
try:
|
||||
results = session.execute(query).scalars().all()
|
||||
results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行earliest查询失败: {e}")
|
||||
results = []
|
||||
@@ -109,7 +110,8 @@ def find_messages(
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
try:
|
||||
latest_results = session.execute(query).scalars().all()
|
||||
latest_results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
# 将结果按时间正序排列
|
||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
@@ -133,7 +135,8 @@ def find_messages(
|
||||
if sort_terms:
|
||||
query = query.order_by(*sort_terms)
|
||||
try:
|
||||
results = session.execute(query).scalars().all()
|
||||
results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行无限制查询失败: {e}")
|
||||
results = []
|
||||
@@ -207,5 +210,5 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
|
||||
|
||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 await session.add() 和 await session.commit()。
|
||||
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
|
||||
|
||||
@@ -161,7 +161,7 @@ class LLMUsageRecorder:
|
||||
session = None
|
||||
try:
|
||||
# 使用 SQLAlchemy 会话创建记录
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
usage_record = LLMUsage(
|
||||
model_name=model_info.model_identifier,
|
||||
model_assign_name=model_info.name,
|
||||
@@ -172,14 +172,14 @@ class LLMUsageRecorder:
|
||||
prompt_tokens=model_usage.prompt_tokens or 0,
|
||||
completion_tokens=model_usage.completion_tokens or 0,
|
||||
total_tokens=model_usage.total_tokens or 0,
|
||||
cost=total_cost or 0.0,
|
||||
cost=1.0,
|
||||
time_cost=round(time_cost or 0.0, 3),
|
||||
status="success",
|
||||
timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段
|
||||
)
|
||||
|
||||
session.add(usage_record)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.debug(
|
||||
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
||||
|
||||
@@ -163,7 +163,8 @@ class PersonInfoManager:
|
||||
try:
|
||||
# 在需要时获取会话
|
||||
async with get_db_session() as session:
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))).scalar()
|
||||
record = result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))
|
||||
result.scalar()
|
||||
return record.person_id if record else ""
|
||||
except Exception as e:
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
|
||||
@@ -339,7 +340,8 @@ class PersonInfoManager:
|
||||
start_time = time.time()
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
|
||||
record = result.scalar()
|
||||
query_time = time.time()
|
||||
if record:
|
||||
setattr(record, f_name, val_to_set)
|
||||
@@ -401,7 +403,8 @@ class PersonInfoManager:
|
||||
|
||||
async def _db_has_field_async(p_id: str, f_name: str):
|
||||
async with get_db_session() as session:
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
|
||||
record = result.scalar()
|
||||
return bool(record)
|
||||
|
||||
try:
|
||||
@@ -512,10 +515,9 @@ class PersonInfoManager:
|
||||
|
||||
async def _db_check_name_exists_async(name_to_check):
|
||||
async with get_db_session() as session:
|
||||
return (
|
||||
(await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))).scalar()
|
||||
is not None
|
||||
)
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))
|
||||
record = result.scalar()
|
||||
return record is not None
|
||||
|
||||
if await _db_check_name_exists_async(generated_nickname):
|
||||
is_duplicate = True
|
||||
@@ -556,7 +558,8 @@ class PersonInfoManager:
|
||||
async def _db_delete_async(p_id: str):
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
|
||||
record = result.scalar()
|
||||
if record:
|
||||
await session.delete(record)
|
||||
await session.commit()
|
||||
@@ -585,7 +588,9 @@ class PersonInfoManager:
|
||||
|
||||
async def _get_record_sync():
|
||||
async with get_db_session() as session:
|
||||
return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))).scalar()
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
|
||||
record = result.scalar()
|
||||
return record
|
||||
|
||||
try:
|
||||
record = asyncio.run(_get_record_sync())
|
||||
@@ -624,7 +629,9 @@ class PersonInfoManager:
|
||||
|
||||
async def _db_get_record_async(p_id: str):
|
||||
async with get_db_session() as session:
|
||||
return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
|
||||
record = result.scalar()
|
||||
return record
|
||||
|
||||
record = await _db_get_record_async(person_id)
|
||||
|
||||
@@ -700,7 +707,8 @@ class PersonInfoManager:
|
||||
"""原子性的获取或创建操作"""
|
||||
async with get_db_session() as session:
|
||||
# 首先尝试获取现有记录
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
|
||||
record = result.scalar()
|
||||
if record:
|
||||
return record, False # 记录存在,未创建
|
||||
|
||||
@@ -715,9 +723,10 @@ class PersonInfoManager:
|
||||
# 如果创建失败(可能是因为竞态条件),再次尝试获取
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
if record:
|
||||
return record, False # 其他协程已创建,返回现有记录
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
|
||||
record = result.scalar()
|
||||
if record:
|
||||
return record, False # 其他协程已创建,返回现有记录
|
||||
# 如果仍然失败,重新抛出异常
|
||||
raise e
|
||||
|
||||
|
||||
@@ -122,7 +122,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
|
||||
|
||||
# 记录使用次数
|
||||
emoji_manager.record_usage(selected_emoji.hash)
|
||||
await emoji_manager.record_usage(selected_emoji.hash)
|
||||
results.append((emoji_base64, selected_emoji.description, matched_emotion))
|
||||
|
||||
if not results and count > 0:
|
||||
@@ -180,7 +180,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
return None
|
||||
|
||||
# 记录使用次数
|
||||
emoji_manager.record_usage(selected_emoji.hash)
|
||||
await emoji_manager.record_usage(selected_emoji.hash)
|
||||
|
||||
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
|
||||
return emoji_base64, selected_emoji.description, emotion
|
||||
|
||||
@@ -65,7 +65,7 @@ class AffinityChatter(BaseChatter):
|
||||
"""
|
||||
try:
|
||||
# 触发表达学习
|
||||
learner = expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
learner = await expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
asyncio.create_task(learner.trigger_learning_for_chat())
|
||||
|
||||
unread_messages = context.get_unread_messages()
|
||||
|
||||
@@ -69,7 +69,7 @@ class ChatterInterestScoringSystem:
|
||||
|
||||
keywords = self._extract_keywords_from_database(message)
|
||||
interest_match_score = await self._calculate_interest_match_score(message.processed_plain_text, keywords)
|
||||
relationship_score = self._calculate_relationship_score(message.user_info.user_id)
|
||||
relationship_score = await self._calculate_relationship_score(message.user_info.user_id)
|
||||
mentioned_score = self._calculate_mentioned_score(message, bot_nickname)
|
||||
|
||||
total_score = (
|
||||
@@ -189,7 +189,7 @@ class ChatterInterestScoringSystem:
|
||||
unique_keywords = list(set(keywords))
|
||||
return unique_keywords[:10] # 返回前10个唯一关键词
|
||||
|
||||
def _calculate_relationship_score(self, user_id: str) -> float:
|
||||
async def _calculate_relationship_score(self, user_id: str) -> float:
|
||||
"""计算关系分 - 从数据库获取关系分"""
|
||||
# 优先使用内存中的关系分
|
||||
if user_id in self.user_relationships:
|
||||
@@ -212,7 +212,7 @@ class ChatterInterestScoringSystem:
|
||||
|
||||
global_tracker = ChatterRelationshipTracker()
|
||||
if global_tracker:
|
||||
relationship_score = global_tracker.get_user_relationship_score(user_id)
|
||||
relationship_score = await global_tracker.get_user_relationship_score(user_id)
|
||||
# 同时更新内存缓存
|
||||
self.user_relationships[user_id] = relationship_score
|
||||
return relationship_score
|
||||
|
||||
@@ -287,7 +287,7 @@ class ChatterRelationshipTracker:
|
||||
|
||||
# ===== 数据库支持方法 =====
|
||||
|
||||
def get_user_relationship_score(self, user_id: str) -> float:
|
||||
async def get_user_relationship_score(self, user_id: str) -> float:
|
||||
"""获取用户关系分"""
|
||||
# 先检查缓存
|
||||
if user_id in self.user_relationship_cache:
|
||||
@@ -298,7 +298,7 @@ class ChatterRelationshipTracker:
|
||||
return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
|
||||
# 缓存过期或不存在,从数据库获取
|
||||
relationship_data = self._get_user_relationship_from_db(user_id)
|
||||
relationship_data = await self._get_user_relationship_from_db(user_id)
|
||||
if relationship_data:
|
||||
# 更新缓存
|
||||
self.user_relationship_cache[user_id] = {
|
||||
@@ -313,37 +313,38 @@ class ChatterRelationshipTracker:
|
||||
# 数据库中也没有,返回默认值
|
||||
return global_config.affinity_flow.base_relationship_score
|
||||
|
||||
def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
|
||||
async def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
|
||||
"""从数据库获取用户关系数据"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查询用户关系表
|
||||
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
result = session.execute(stmt).scalar_one_or_none()
|
||||
result = await session.execute(stmt)
|
||||
relationship = result.scalar_one_or_none()
|
||||
|
||||
if result:
|
||||
if relationship:
|
||||
return {
|
||||
"relationship_text": result.relationship_text or "",
|
||||
"relationship_score": float(result.relationship_score)
|
||||
if result.relationship_score is not None
|
||||
"relationship_text": relationship.relationship_text or "",
|
||||
"relationship_score": float(relationship.relationship_score)
|
||||
if relationship.relationship_score is not None
|
||||
else 0.3,
|
||||
"last_updated": result.last_updated,
|
||||
"last_updated": relationship.last_updated,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取用户关系失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
|
||||
async def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
|
||||
"""更新数据库中的用户关系"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 检查是否已存在关系记录
|
||||
existing = session.execute(
|
||||
select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
).scalar_one_or_none()
|
||||
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
@@ -362,7 +363,7 @@ class ChatterRelationshipTracker:
|
||||
)
|
||||
session.add(new_relationship)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
@@ -399,7 +400,7 @@ class ChatterRelationshipTracker:
|
||||
logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息")
|
||||
|
||||
# 获取当前关系数据
|
||||
current_relationship = self._get_user_relationship_from_db(user_id)
|
||||
current_relationship = await self._get_user_relationship_from_db(user_id)
|
||||
current_score = (
|
||||
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
if current_relationship
|
||||
@@ -417,14 +418,14 @@ class ChatterRelationshipTracker:
|
||||
logger.error(f"回复后关系追踪失败: {e}")
|
||||
logger.debug("错误详情:", exc_info=True)
|
||||
|
||||
def _get_last_tracked_time(self, user_id: str) -> float:
|
||||
async def _get_last_tracked_time(self, user_id: str) -> float:
|
||||
"""获取上次追踪时间"""
|
||||
# 先检查缓存
|
||||
if user_id in self.user_relationship_cache:
|
||||
return self.user_relationship_cache[user_id].get("last_tracked", 0)
|
||||
|
||||
# 从数据库获取
|
||||
relationship_data = self._get_user_relationship_from_db(user_id)
|
||||
relationship_data = await self._get_user_relationship_from_db(user_id)
|
||||
if relationship_data:
|
||||
return relationship_data.get("last_updated", 0)
|
||||
|
||||
@@ -433,7 +434,7 @@ class ChatterRelationshipTracker:
|
||||
async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]:
|
||||
"""获取上次bot回复该用户的消息"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查询bot回复给该用户的最新消息
|
||||
stmt = (
|
||||
select(Messages)
|
||||
@@ -443,10 +444,11 @@ class ChatterRelationshipTracker:
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = session.execute(stmt).scalar_one_or_none()
|
||||
if result:
|
||||
result = await session.execute(stmt)
|
||||
message = result.scalar_one_or_none()
|
||||
if message:
|
||||
# 将SQLAlchemy模型转换为DatabaseMessages对象
|
||||
return self._sqlalchemy_to_database_messages(result)
|
||||
return self._sqlalchemy_to_database_messages(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取上次回复消息失败: {e}")
|
||||
@@ -456,7 +458,7 @@ class ChatterRelationshipTracker:
|
||||
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]:
|
||||
"""获取用户在bot回复后的反应消息"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查询用户在回复时间之后的5分钟内的消息
|
||||
end_time = reply_time + 5 * 60 # 5分钟
|
||||
|
||||
@@ -468,9 +470,10 @@ class ChatterRelationshipTracker:
|
||||
.order_by(Messages.time)
|
||||
)
|
||||
|
||||
results = session.execute(stmt).scalars().all()
|
||||
if results:
|
||||
return [self._sqlalchemy_to_database_messages(result) for result in results]
|
||||
result = await session.execute(stmt)
|
||||
messages = result.scalars().all()
|
||||
if messages:
|
||||
return [self._sqlalchemy_to_database_messages(message) for message in messages]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户反应消息失败: {e}")
|
||||
@@ -593,7 +596,7 @@ class ChatterRelationshipTracker:
|
||||
quality = response_data.get("interaction_quality", "medium")
|
||||
|
||||
# 更新数据库
|
||||
self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
await self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
|
||||
# 更新缓存
|
||||
self.user_relationship_cache[user_id] = {
|
||||
@@ -696,7 +699,7 @@ class ChatterRelationshipTracker:
|
||||
)
|
||||
|
||||
# 更新数据库和缓存
|
||||
self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
await self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
self.user_relationship_cache[user_id] = {
|
||||
"relationship_text": new_text,
|
||||
"relationship_score": new_score,
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Callable
|
||||
from src.common.logger import get_logger
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus
|
||||
|
||||
from .qzone_service import QZoneService
|
||||
@@ -138,15 +139,13 @@ class SchedulerService:
|
||||
:return: 如果已处理过,返回 True,否则返回 False。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
record = (
|
||||
session.query(MaiZoneScheduleStatus)
|
||||
.filter(
|
||||
MaiZoneScheduleStatus.datetime_hour == hour_str,
|
||||
MaiZoneScheduleStatus.is_processed == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
async with get_db_session() as session:
|
||||
stmt = select(MaiZoneScheduleStatus).where(
|
||||
MaiZoneScheduleStatus.datetime_hour == hour_str,
|
||||
MaiZoneScheduleStatus.is_processed == True, # noqa: E712
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
return record is not None
|
||||
except Exception as e:
|
||||
logger.error(f"检查日程处理状态时发生数据库错误: {e}")
|
||||
@@ -162,11 +161,11 @@ class SchedulerService:
|
||||
:param content: 最终发送的说说内容或错误信息。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查找是否已存在该记录
|
||||
record = (
|
||||
session.query(MaiZoneScheduleStatus).filter(MaiZoneScheduleStatus.datetime_hour == hour_str).first()
|
||||
)
|
||||
stmt = select(MaiZoneScheduleStatus).where(MaiZoneScheduleStatus.datetime_hour == hour_str)
|
||||
result = await session.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
|
||||
if record:
|
||||
# 如果存在,则更新状态
|
||||
@@ -185,7 +184,7 @@ class SchedulerService:
|
||||
send_success=success,
|
||||
)
|
||||
session.add(new_record)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}")
|
||||
except Exception as e:
|
||||
logger.error(f"更新日程处理状态时发生数据库错误: {e}")
|
||||
|
||||
@@ -64,15 +64,9 @@ async def message_recv(server_connection: Server.ServerConnection):
|
||||
|
||||
# 处理完整消息(可能是重组后的,也可能是原本就完整的)
|
||||
post_type = decoded_raw_message.get("post_type")
|
||||
|
||||
# 兼容没有 post_type 的普通消息
|
||||
if not post_type and "message_type" in decoded_raw_message:
|
||||
decoded_raw_message["post_type"] = "message"
|
||||
post_type = "message"
|
||||
|
||||
if post_type in ["meta_event", "message", "notice"]:
|
||||
await message_queue.put(decoded_raw_message)
|
||||
else:
|
||||
elif post_type is None:
|
||||
await put_response(decoded_raw_message)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -428,8 +422,9 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
def get_plugin_components(self):
|
||||
self.register_events()
|
||||
|
||||
components = [(LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler),
|
||||
(StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler)]
|
||||
components = []
|
||||
components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler))
|
||||
components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler))
|
||||
for handler in get_classes_in_module(event_handlers):
|
||||
if issubclass(handler, BaseEventHandler):
|
||||
components.append((handler.get_handler_info(), handler))
|
||||
|
||||
@@ -1,156 +1,162 @@
|
||||
"""Napcat Adapter 插件数据库层 (基于主程序异步SQLAlchemy API)
|
||||
|
||||
本模块替换原先的 sqlmodel + 同步Session 实现:
|
||||
1. 复用主项目的异步数据库连接与迁移体系
|
||||
2. 提供与旧接口名兼容的方法(update_ban_record/create_ban_record/delete_ban_record)
|
||||
3. 新增首选异步方法: update_ban_records / create_or_update / delete_record / get_ban_records
|
||||
|
||||
数据语义:
|
||||
user_id == 0 表示群全体禁言
|
||||
|
||||
注意: 所有方法均为异步, 需要在 async 上下文中调用。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Sequence
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
from sqlalchemy import Column, Integer, BigInteger, UniqueConstraint, select, Index
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.common.database.sqlalchemy_models import Base, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
"""
|
||||
表记录的方式:
|
||||
| group_id | user_id | lift_time |
|
||||
|----------|---------|-----------|
|
||||
|
||||
class NapcatBanRecord(Base):
|
||||
__tablename__ = "napcat_ban_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
group_id = Column(BigInteger, nullable=False, index=True)
|
||||
user_id = Column(BigInteger, nullable=False, index=True) # 0 == 全体禁言
|
||||
lift_time = Column(BigInteger, nullable=True) # -1 / None 表示未知/永久
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("group_id", "user_id", name="uq_napcat_group_user"),
|
||||
Index("idx_napcat_ban_group", "group_id"),
|
||||
Index("idx_napcat_ban_user", "user_id"),
|
||||
)
|
||||
其中使用 user_id == 0 表示群全体禁言
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BanUser:
|
||||
"""
|
||||
程序处理使用的实例
|
||||
"""
|
||||
|
||||
user_id: int
|
||||
group_id: int
|
||||
lift_time: Optional[int] = -1
|
||||
|
||||
def identity(self) -> tuple[int, int]:
|
||||
return self.group_id, self.user_id
|
||||
lift_time: Optional[int] = Field(default=-1)
|
||||
|
||||
|
||||
class NapcatDatabase:
|
||||
async def _fetch_all(self, session: AsyncSession) -> Sequence[NapcatBanRecord]:
|
||||
result = await session.execute(select(NapcatBanRecord))
|
||||
return result.scalars().all()
|
||||
class DB_BanUser(SQLModel, table=True):
|
||||
"""
|
||||
表示数据库中的用户禁言记录。
|
||||
使用双重主键
|
||||
"""
|
||||
|
||||
async def get_ban_records(self) -> List[BanUser]:
|
||||
async with get_db_session() as session:
|
||||
rows = await self._fetch_all(session)
|
||||
return [BanUser(group_id=r.group_id, user_id=r.user_id, lift_time=r.lift_time) for r in rows]
|
||||
user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID
|
||||
group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID
|
||||
lift_time: Optional[int] # 禁言解除的时间(时间戳)
|
||||
|
||||
async def update_ban_records(self, ban_list: List[BanUser]) -> None:
|
||||
target_map = {b.identity(): b for b in ban_list}
|
||||
async with get_db_session() as session:
|
||||
rows = await self._fetch_all(session)
|
||||
existing_map = {(r.group_id, r.user_id): r for r in rows}
|
||||
|
||||
changed = 0
|
||||
for ident, ban in target_map.items():
|
||||
if ident in existing_map:
|
||||
row = existing_map[ident]
|
||||
if row.lift_time != ban.lift_time:
|
||||
row.lift_time = ban.lift_time
|
||||
changed += 1
|
||||
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
|
||||
"""
|
||||
检查两个 BanUser 对象是否相同。
|
||||
"""
|
||||
return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""
|
||||
数据库管理类,负责与数据库交互。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在
|
||||
DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
|
||||
self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL
|
||||
self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎
|
||||
self._ensure_database() # 确保数据库和表已创建
|
||||
|
||||
def _ensure_database(self) -> None:
|
||||
"""
|
||||
确保数据库和表已创建。
|
||||
"""
|
||||
logger.info("确保数据库文件和表已创建...")
|
||||
SQLModel.metadata.create_all(self.engine)
|
||||
logger.info("数据库和表已创建或已存在")
|
||||
|
||||
def update_ban_record(self, ban_list: List[BanUser]) -> None:
|
||||
# sourcery skip: class-extract-method
|
||||
"""
|
||||
更新禁言列表到数据库。
|
||||
支持在不存在时创建新记录,对于多余的项目自动删除。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
all_records = session.exec(select(DB_BanUser)).all()
|
||||
for ban_user in ban_list:
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
|
||||
)
|
||||
if existing_record := session.exec(statement).first():
|
||||
if existing_record.lift_time == ban_user.lift_time:
|
||||
logger.debug(f"禁言记录未变更: {existing_record}")
|
||||
continue
|
||||
# 更新现有记录的 lift_time
|
||||
existing_record.lift_time = ban_user.lift_time
|
||||
session.add(existing_record)
|
||||
logger.debug(f"更新禁言记录: {existing_record}")
|
||||
else:
|
||||
session.add(
|
||||
NapcatBanRecord(group_id=ban.group_id, user_id=ban.user_id, lift_time=ban.lift_time)
|
||||
# 创建新记录
|
||||
db_record = DB_BanUser(
|
||||
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
|
||||
)
|
||||
changed += 1
|
||||
|
||||
removed = 0
|
||||
for ident, row in existing_map.items():
|
||||
if ident not in target_map:
|
||||
await session.delete(row)
|
||||
removed += 1
|
||||
|
||||
logger.debug(
|
||||
f"Napcat ban list sync => total_incoming={len(ban_list)} created_or_updated={changed} removed={removed}"
|
||||
)
|
||||
|
||||
async def create_or_update(self, ban_record: BanUser) -> None:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(NapcatBanRecord).where(
|
||||
NapcatBanRecord.group_id == ban_record.group_id,
|
||||
NapcatBanRecord.user_id == ban_record.user_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalars().first()
|
||||
if row:
|
||||
if row.lift_time != ban_record.lift_time:
|
||||
row.lift_time = ban_record.lift_time
|
||||
logger.debug(
|
||||
f"更新禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
|
||||
session.add(db_record)
|
||||
logger.debug(f"创建新禁言记录: {ban_user}")
|
||||
# 删除不在 ban_list 中的记录
|
||||
for db_record in all_records:
|
||||
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
|
||||
if not any(is_identical(record, ban_user) for ban_user in ban_list):
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
|
||||
)
|
||||
if ban_record := session.exec(statement).first():
|
||||
session.delete(ban_record)
|
||||
|
||||
logger.debug(f"删除禁言记录: {ban_record}")
|
||||
else:
|
||||
logger.info(f"未找到禁言记录: {ban_record}")
|
||||
|
||||
logger.info("禁言记录已更新")
|
||||
|
||||
def get_ban_records(self) -> List[BanUser]:
|
||||
"""
|
||||
读取所有禁言记录。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
statement = select(DB_BanUser)
|
||||
records = session.exec(statement).all()
|
||||
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
|
||||
|
||||
def create_ban_record(self, ban_record: BanUser) -> None:
|
||||
"""
|
||||
为特定群组中的用户创建禁言记录。
|
||||
一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。
|
||||
其同时还是简化版的更新方式。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
# 检查记录是否已存在
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
|
||||
)
|
||||
existing_record = session.exec(statement).first()
|
||||
if existing_record:
|
||||
# 如果记录已存在,更新 lift_time
|
||||
existing_record.lift_time = ban_record.lift_time
|
||||
session.add(existing_record)
|
||||
logger.debug(f"更新禁言记录: {ban_record}")
|
||||
else:
|
||||
session.add(
|
||||
NapcatBanRecord(
|
||||
group_id=ban_record.group_id, user_id=ban_record.user_id, lift_time=ban_record.lift_time
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f"创建禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
|
||||
# 如果记录不存在,创建新记录
|
||||
db_record = DB_BanUser(
|
||||
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
|
||||
)
|
||||
session.add(db_record)
|
||||
logger.debug(f"创建新禁言记录: {ban_record}")
|
||||
|
||||
async def delete_record(self, ban_record: BanUser) -> None:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(NapcatBanRecord).where(
|
||||
NapcatBanRecord.group_id == ban_record.group_id,
|
||||
NapcatBanRecord.user_id == ban_record.user_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalars().first()
|
||||
if row:
|
||||
await session.delete(row)
|
||||
logger.debug(
|
||||
f"删除禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={row.lift_time}"
|
||||
)
|
||||
def delete_ban_record(self, ban_record: BanUser):
|
||||
"""
|
||||
删除特定用户在特定群组中的禁言记录。
|
||||
一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。
|
||||
"""
|
||||
user_id = ban_record.user_id
|
||||
group_id = ban_record.group_id
|
||||
with Session(self.engine) as session:
|
||||
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
|
||||
if ban_record := session.exec(statement).first():
|
||||
session.delete(ban_record)
|
||||
|
||||
logger.debug(f"删除禁言记录: {ban_record}")
|
||||
else:
|
||||
logger.info(
|
||||
f"未找到禁言记录 group={ban_record.group_id} user={ban_record.user_id}"
|
||||
)
|
||||
|
||||
# 兼容旧命名
|
||||
async def update_ban_record(self, ban_list: List[BanUser]) -> None: # old name
|
||||
await self.update_ban_records(ban_list)
|
||||
|
||||
async def create_ban_record(self, ban_record: BanUser) -> None: # old name
|
||||
await self.create_or_update(ban_record)
|
||||
|
||||
async def delete_ban_record(self, ban_record: BanUser) -> None: # old name
|
||||
await self.delete_record(ban_record)
|
||||
logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
|
||||
|
||||
|
||||
napcat_db = NapcatDatabase()
|
||||
|
||||
|
||||
def is_identical(a: BanUser, b: BanUser) -> bool:
|
||||
return a.group_id == b.group_id and a.user_id == b.user_id
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BanUser",
|
||||
"NapcatBanRecord",
|
||||
"napcat_db",
|
||||
"is_identical",
|
||||
]
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
@@ -112,8 +112,7 @@ class MessageChunker:
|
||||
else:
|
||||
return [{"_original_message": message}]
|
||||
|
||||
@staticmethod
|
||||
def is_chunk_message(message: Union[str, Dict[str, Any]]) -> bool:
|
||||
def is_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
|
||||
"""判断是否是切片消息"""
|
||||
try:
|
||||
if isinstance(message, str):
|
||||
|
||||
@@ -14,7 +14,6 @@ class MetaEventHandler:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.last_heart_beat = time.time()
|
||||
self.interval = 5.0 # 默认值,稍后通过set_plugin_config设置
|
||||
self._interval_checking = False
|
||||
self.plugin_config = None
|
||||
@@ -40,6 +39,7 @@ class MetaEventHandler:
|
||||
if message["status"].get("online") and message["status"].get("good"):
|
||||
if not self._interval_checking:
|
||||
asyncio.create_task(self.check_heartbeat())
|
||||
self.last_heart_beat = time.time()
|
||||
self.interval = message.get("interval") / 1000
|
||||
else:
|
||||
self_id = message.get("self_id")
|
||||
|
||||
@@ -76,7 +76,7 @@ class SendHandler:
|
||||
processed_message = await self.handle_seg_recursive(message_segment, user_info)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时发生错误: {e}")
|
||||
return None
|
||||
return
|
||||
|
||||
if not processed_message:
|
||||
logger.critical("现在暂时不支持解析此回复!")
|
||||
@@ -94,7 +94,7 @@ class SendHandler:
|
||||
id_name = "user_id"
|
||||
else:
|
||||
logger.error("无法识别的消息类型")
|
||||
return None
|
||||
return
|
||||
logger.info("尝试发送到napcat")
|
||||
logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'")
|
||||
response = await self.send_message_to_napcat(
|
||||
@@ -108,10 +108,8 @@ class SendHandler:
|
||||
logger.info("消息发送成功")
|
||||
qq_message_id = response.get("data", {}).get("message_id")
|
||||
await self.message_sent_back(raw_message_base, qq_message_id)
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"消息发送失败,napcat返回:{str(response)}")
|
||||
return None
|
||||
|
||||
async def send_command(self, raw_message_base: MessageBase) -> None:
|
||||
"""
|
||||
@@ -149,7 +147,7 @@ class SendHandler:
|
||||
command, args_dict = self.handle_send_like_command(args)
|
||||
case _:
|
||||
logger.error(f"未知命令: {command_name}")
|
||||
return None
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时发生错误: {e}")
|
||||
return None
|
||||
@@ -161,10 +159,8 @@ class SendHandler:
|
||||
response = await self.send_message_to_napcat(command, args_dict)
|
||||
if response.get("status") == "ok":
|
||||
logger.info(f"命令 {command_name} 执行成功")
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}")
|
||||
return None
|
||||
|
||||
async def handle_adapter_command(self, raw_message_base: MessageBase) -> None:
|
||||
"""
|
||||
@@ -272,8 +268,7 @@ class SendHandler:
|
||||
new_payload = self.build_payload(payload, self.handle_file_message(file_path), False)
|
||||
return new_payload
|
||||
|
||||
@staticmethod
|
||||
def build_payload(payload: list, addon: dict | list, is_reply: bool = False) -> list:
|
||||
def build_payload(self, payload: list, addon: dict | list, is_reply: bool = False) -> list:
|
||||
# sourcery skip: for-append-to-extend, merge-list-append, simplify-generator
|
||||
"""构建发送的消息体"""
|
||||
if is_reply:
|
||||
@@ -339,13 +334,11 @@ class SendHandler:
|
||||
logger.info(f"最终返回的回复段: {reply_seg}")
|
||||
return reply_seg
|
||||
|
||||
@staticmethod
|
||||
def handle_text_message(message: str) -> dict:
|
||||
def handle_text_message(self, message: str) -> dict:
|
||||
"""处理文本消息"""
|
||||
return {"type": "text", "data": {"text": message}}
|
||||
|
||||
@staticmethod
|
||||
def handle_image_message(encoded_image: str) -> dict:
|
||||
def handle_image_message(self, encoded_image: str) -> dict:
|
||||
"""处理图片消息"""
|
||||
return {
|
||||
"type": "image",
|
||||
@@ -355,8 +348,7 @@ class SendHandler:
|
||||
},
|
||||
} # base64 编码的图片
|
||||
|
||||
@staticmethod
|
||||
def handle_emoji_message(encoded_emoji: str) -> dict:
|
||||
def handle_emoji_message(self, encoded_emoji: str) -> dict:
|
||||
"""处理表情消息"""
|
||||
encoded_image = encoded_emoji
|
||||
image_format = get_image_format(encoded_emoji)
|
||||
@@ -387,45 +379,39 @@ class SendHandler:
|
||||
"data": {"file": f"base64://{encoded_voice}"},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def handle_voiceurl_message(voice_url: str) -> dict:
|
||||
def handle_voiceurl_message(self, voice_url: str) -> dict:
|
||||
"""处理语音链接消息"""
|
||||
return {
|
||||
"type": "record",
|
||||
"data": {"file": voice_url},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def handle_music_message(song_id: str) -> dict:
|
||||
def handle_music_message(self, song_id: str) -> dict:
|
||||
"""处理音乐消息"""
|
||||
return {
|
||||
"type": "music",
|
||||
"data": {"type": "163", "id": song_id},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def handle_videourl_message(video_url: str) -> dict:
|
||||
def handle_videourl_message(self, video_url: str) -> dict:
|
||||
"""处理视频链接消息"""
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {"file": video_url},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def handle_file_message(file_path: str) -> dict:
|
||||
def handle_file_message(self, file_path: str) -> dict:
|
||||
"""处理文件消息"""
|
||||
return {
|
||||
"type": "file",
|
||||
"data": {"file": f"file://{file_path}"},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def delete_msg_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理删除消息命令"""
|
||||
return "delete_msg", {"message_id": args["message_id"]}
|
||||
|
||||
@staticmethod
|
||||
def handle_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理封禁命令
|
||||
|
||||
Args:
|
||||
@@ -453,8 +439,7 @@ class SendHandler:
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_whole_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理全体禁言命令
|
||||
|
||||
Args:
|
||||
@@ -477,8 +462,7 @@ class SendHandler:
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_kick_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理群成员踢出命令
|
||||
|
||||
Args:
|
||||
@@ -503,8 +487,7 @@ class SendHandler:
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_poke_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理戳一戳命令
|
||||
|
||||
Args:
|
||||
@@ -531,8 +514,7 @@ class SendHandler:
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_set_emoji_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理设置表情回应命令
|
||||
|
||||
Args:
|
||||
@@ -554,8 +536,7 @@ class SendHandler:
|
||||
{"message_id": message_id, "emoji_id": emoji_id, "set": set_like},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_send_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
处理发送点赞命令的逻辑。
|
||||
|
||||
@@ -576,8 +557,7 @@ class SendHandler:
|
||||
{"user_id": user_id, "times": times},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_ai_voice_send_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
处理AI语音发送命令的逻辑。
|
||||
并返回 NapCat 兼容的 (action, params) 元组。
|
||||
@@ -624,8 +604,7 @@ class SendHandler:
|
||||
return {"status": "error", "message": str(e)}
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
async def message_sent_back(message_base: MessageBase, qq_message_id: str) -> None:
|
||||
async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None:
|
||||
# 修改 additional_config,添加 echo 字段
|
||||
if message_base.message_info.additional_config is None:
|
||||
message_base.message_info.additional_config = {}
|
||||
@@ -643,9 +622,8 @@ class SendHandler:
|
||||
logger.debug("已回送消息ID")
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
async def send_adapter_command_response(
|
||||
original_message: MessageBase, response_data: dict, request_id: str
|
||||
self, original_message: MessageBase, response_data: dict, request_id: str
|
||||
) -> None:
|
||||
"""
|
||||
发送适配器命令响应回MaiBot
|
||||
@@ -674,8 +652,7 @@ class SendHandler:
|
||||
except Exception as e:
|
||||
logger.error(f"发送适配器命令响应时出错: {e}")
|
||||
|
||||
@staticmethod
|
||||
def handle_at_message_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_at_message_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理艾特并发送消息命令
|
||||
|
||||
Args:
|
||||
|
||||
@@ -6,7 +6,7 @@ import urllib3
|
||||
import ssl
|
||||
import io
|
||||
|
||||
from .database import BanUser, napcat_db
|
||||
from .database import BanUser, db_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
@@ -270,11 +270,10 @@ async def read_ban_list(
|
||||
]
|
||||
"""
|
||||
try:
|
||||
ban_list = await napcat_db.get_ban_records()
|
||||
ban_list = db_manager.get_ban_records()
|
||||
lifted_list: List[BanUser] = []
|
||||
logger.info("已经读取禁言列表")
|
||||
# 复制列表以避免迭代中修改原列表问题
|
||||
for ban_record in list(ban_list):
|
||||
for ban_record in ban_list:
|
||||
if ban_record.user_id == 0:
|
||||
fetched_group_info = await get_group_info(websocket, ban_record.group_id)
|
||||
if fetched_group_info is None:
|
||||
@@ -302,12 +301,12 @@ async def read_ban_list(
|
||||
ban_list.remove(ban_record)
|
||||
else:
|
||||
ban_record.lift_time = lift_ban_time
|
||||
await napcat_db.update_ban_record(ban_list)
|
||||
db_manager.update_ban_record(ban_list)
|
||||
return ban_list, lifted_list
|
||||
except Exception as e:
|
||||
logger.error(f"读取禁言列表失败: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
async def save_ban_record(list: List[BanUser]):
|
||||
return await napcat_db.update_ban_record(list)
|
||||
def save_ban_record(list: List[BanUser]):
|
||||
return db_manager.update_ban_record(list)
|
||||
|
||||
110
test_deepcopy_fix.py
Normal file
110
test_deepcopy_fix.py
Normal file
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试 ChatStream 的 deepcopy 功能
|
||||
验证 asyncio.Task 序列化问题是否已解决
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import copy
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
|
||||
async def test_chat_stream_deepcopy():
|
||||
"""测试 ChatStream 的 deepcopy 功能"""
|
||||
print("[TEST] 开始测试 ChatStream deepcopy 功能...")
|
||||
|
||||
try:
|
||||
# 创建测试用的用户和群组信息
|
||||
user_info = UserInfo(
|
||||
platform="test_platform",
|
||||
user_id="test_user_123",
|
||||
user_nickname="测试用户",
|
||||
user_cardname="测试卡片名"
|
||||
)
|
||||
|
||||
group_info = GroupInfo(
|
||||
platform="test_platform",
|
||||
group_id="test_group_456",
|
||||
group_name="测试群组"
|
||||
)
|
||||
|
||||
# 创建 ChatStream 实例
|
||||
print("📝 创建 ChatStream 实例...")
|
||||
stream_id = "test_stream_789"
|
||||
platform = "test_platform"
|
||||
|
||||
chat_stream = ChatStream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info
|
||||
)
|
||||
|
||||
print(f"[SUCCESS] ChatStream 创建成功: {chat_stream.stream_id}")
|
||||
|
||||
# 等待一下,让异步任务有机会创建
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 尝试进行 deepcopy
|
||||
print("[INFO] 尝试进行 deepcopy...")
|
||||
copied_stream = copy.deepcopy(chat_stream)
|
||||
|
||||
print("[SUCCESS] deepcopy 成功!")
|
||||
|
||||
# 验证复制后的对象属性
|
||||
print("\n[CHECK] 验证复制后的对象属性:")
|
||||
print(f" - stream_id: {copied_stream.stream_id}")
|
||||
print(f" - platform: {copied_stream.platform}")
|
||||
print(f" - user_info: {copied_stream.user_info.user_nickname}")
|
||||
print(f" - group_info: {copied_stream.group_info.group_name}")
|
||||
|
||||
# 检查 processing_task 是否被正确处理
|
||||
if hasattr(copied_stream.stream_context, 'processing_task'):
|
||||
print(f" - processing_task: {copied_stream.stream_context.processing_task}")
|
||||
if copied_stream.stream_context.processing_task is None:
|
||||
print(" [SUCCESS] processing_task 已被正确设置为 None")
|
||||
else:
|
||||
print(" [WARNING] processing_task 不为 None")
|
||||
else:
|
||||
print(" [SUCCESS] stream_context 没有 processing_task 属性")
|
||||
|
||||
# 验证原始对象和复制对象是不同的实例
|
||||
if id(chat_stream) != id(copied_stream):
|
||||
print("[SUCCESS] 原始对象和复制对象是不同的实例")
|
||||
else:
|
||||
print("[ERROR] 原始对象和复制对象是同一个实例")
|
||||
|
||||
# 验证基本属性是否正确复制
|
||||
if (chat_stream.stream_id == copied_stream.stream_id and
|
||||
chat_stream.platform == copied_stream.platform):
|
||||
print("[SUCCESS] 基本属性正确复制")
|
||||
else:
|
||||
print("[ERROR] 基本属性复制失败")
|
||||
|
||||
print("\n[COMPLETE] 测试完成!deepcopy 功能修复成功!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
result = asyncio.run(test_chat_stream_deepcopy())
|
||||
|
||||
if result:
|
||||
print("\n[SUCCESS] 所有测试通过!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("\n[ERROR] 测试失败!")
|
||||
sys.exit(1)
|
||||
109
test_simple_deepcopy.py
Normal file
109
test_simple_deepcopy.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
简单的 ChatStream deepcopy 测试
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import copy
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
|
||||
async def test_deepcopy():
|
||||
"""测试 deepcopy 功能"""
|
||||
print("开始测试 ChatStream deepcopy 功能...")
|
||||
|
||||
try:
|
||||
# 创建测试用的用户和群组信息
|
||||
user_info = UserInfo(
|
||||
platform="test_platform",
|
||||
user_id="test_user_123",
|
||||
user_nickname="测试用户",
|
||||
user_cardname="测试卡片名"
|
||||
)
|
||||
|
||||
group_info = GroupInfo(
|
||||
platform="test_platform",
|
||||
group_id="test_group_456",
|
||||
group_name="测试群组"
|
||||
)
|
||||
|
||||
# 创建 ChatStream 实例
|
||||
print("创建 ChatStream 实例...")
|
||||
stream_id = "test_stream_789"
|
||||
platform = "test_platform"
|
||||
|
||||
chat_stream = ChatStream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info
|
||||
)
|
||||
|
||||
print(f"ChatStream 创建成功: {chat_stream.stream_id}")
|
||||
|
||||
# 等待一下,让异步任务有机会创建
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 尝试进行 deepcopy
|
||||
print("尝试进行 deepcopy...")
|
||||
copied_stream = copy.deepcopy(chat_stream)
|
||||
|
||||
print("deepcopy 成功!")
|
||||
|
||||
# 验证复制后的对象属性
|
||||
print("\n验证复制后的对象属性:")
|
||||
print(f" - stream_id: {copied_stream.stream_id}")
|
||||
print(f" - platform: {copied_stream.platform}")
|
||||
print(f" - user_info: {copied_stream.user_info.user_nickname}")
|
||||
print(f" - group_info: {copied_stream.group_info.group_name}")
|
||||
|
||||
# 检查 processing_task 是否被正确处理
|
||||
if hasattr(copied_stream.stream_context, 'processing_task'):
|
||||
print(f" - processing_task: {copied_stream.stream_context.processing_task}")
|
||||
if copied_stream.stream_context.processing_task is None:
|
||||
print(" SUCCESS: processing_task 已被正确设置为 None")
|
||||
else:
|
||||
print(" WARNING: processing_task 不为 None")
|
||||
else:
|
||||
print(" SUCCESS: stream_context 没有 processing_task 属性")
|
||||
|
||||
# 验证原始对象和复制对象是不同的实例
|
||||
if id(chat_stream) != id(copied_stream):
|
||||
print("SUCCESS: 原始对象和复制对象是不同的实例")
|
||||
else:
|
||||
print("ERROR: 原始对象和复制对象是同一个实例")
|
||||
|
||||
# 验证基本属性是否正确复制
|
||||
if (chat_stream.stream_id == copied_stream.stream_id and
|
||||
chat_stream.platform == copied_stream.platform):
|
||||
print("SUCCESS: 基本属性正确复制")
|
||||
else:
|
||||
print("ERROR: 基本属性复制失败")
|
||||
|
||||
print("\n测试完成!deepcopy 功能修复成功!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
result = asyncio.run(test_deepcopy())
|
||||
|
||||
if result:
|
||||
print("\n所有测试通过!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("\n测试失败!")
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user