重构ChatStream和StreamContext:移除context_manager引用
- 在ChatStream及相关类中,将所有context_manager的实例替换为直接上下文访问。 - 更新方法,利用新的上下文结构来管理聊天状态和消息。 - 增强的StreamContext,增加了用于消息处理、统计和历史管理的方法。 - 在重构过程中改进了错误处理和日志记录。
This commit is contained in:
@@ -8,6 +8,8 @@ from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||
from src.common.database.compatibility import get_db_session
|
||||
@@ -41,18 +43,10 @@ class ChatStream:
|
||||
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
|
||||
self.saved = False
|
||||
|
||||
# 创建单流上下文管理器(包含StreamContext)
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
|
||||
self.context: StreamContext = StreamContext(
|
||||
stream_id=stream_id,
|
||||
context=StreamContext(
|
||||
stream_id=stream_id,
|
||||
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
|
||||
chat_mode=ChatMode.FOCUS,
|
||||
),
|
||||
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
|
||||
chat_mode=ChatMode.FOCUS,
|
||||
)
|
||||
|
||||
# 基础参数
|
||||
@@ -73,11 +67,11 @@ class ChatStream:
|
||||
"focus_energy": self.focus_energy,
|
||||
# 基础兴趣度
|
||||
"base_interest_energy": self.base_interest_energy,
|
||||
# stream_context基本信息(通过context_manager访问)
|
||||
"stream_context_chat_type": self.context_manager.context.chat_type.value,
|
||||
"stream_context_chat_mode": self.context_manager.context.chat_mode.value,
|
||||
# stream_context基本信息
|
||||
"stream_context_chat_type": self.context.chat_type.value,
|
||||
"stream_context_chat_mode": self.context.chat_mode.value,
|
||||
# 统计信息
|
||||
"interruption_count": self.context_manager.context.interruption_count,
|
||||
"interruption_count": self.context.interruption_count,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -94,19 +88,19 @@ class ChatStream:
|
||||
data=data,
|
||||
)
|
||||
|
||||
# 恢复stream_context信息(通过context_manager访问)
|
||||
# 恢复stream_context信息
|
||||
if "stream_context_chat_type" in data:
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.context_manager.context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
instance.context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
if "stream_context_chat_mode" in data:
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.context_manager.context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
instance.context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
|
||||
# 恢复interruption_count信息
|
||||
if "interruption_count" in data:
|
||||
instance.context_manager.context.interruption_count = data["interruption_count"]
|
||||
instance.context.interruption_count = data["interruption_count"]
|
||||
|
||||
return instance
|
||||
|
||||
@@ -131,15 +125,7 @@ class ChatStream:
|
||||
message: DatabaseMessages 对象,直接使用不需要转换
|
||||
"""
|
||||
# 直接使用传入的 DatabaseMessages,设置到上下文中
|
||||
self.context_manager.context.set_current_message(message)
|
||||
|
||||
# 设置优先级信息(如果存在)
|
||||
priority_mode = getattr(message, "priority_mode", None)
|
||||
priority_info = getattr(message, "priority_info", None)
|
||||
if priority_mode:
|
||||
self.context_manager.context.priority_mode = priority_mode
|
||||
if priority_info:
|
||||
self.context_manager.context.priority_info = priority_info
|
||||
self.context.set_current_message(message)
|
||||
|
||||
# 调试日志
|
||||
logger.debug(
|
||||
@@ -253,7 +239,7 @@ class ChatStream:
|
||||
"""异步计算focus_energy"""
|
||||
try:
|
||||
# 使用单流上下文管理器获取消息
|
||||
all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size)
|
||||
all_messages = self.context.get_messages(limit=global_config.chat.max_context_size)
|
||||
|
||||
# 获取用户ID
|
||||
user_id = None
|
||||
@@ -318,7 +304,6 @@ class ChatManager:
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message
|
||||
@@ -409,135 +394,87 @@ class ChatManager:
|
||||
async def get_or_create_stream(
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
) -> ChatStream:
|
||||
"""获取或创建聊天流 - 优化版本使用缓存管理器
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
user_info: 用户信息
|
||||
group_info: 群组信息(可选)
|
||||
|
||||
Returns:
|
||||
ChatStream: 聊天流对象
|
||||
"""
|
||||
# 生成stream_id
|
||||
"""获取或创建聊天流 - 优化版本使用缓存机制"""
|
||||
try:
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
|
||||
# 检查内存中是否存在
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams[stream_id]
|
||||
|
||||
# 更新用户信息和群组信息
|
||||
stream.update_active_time()
|
||||
if user_info.platform and user_info.user_id:
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
|
||||
# 检查是否有最后一条消息(现在使用 DatabaseMessages)
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
||||
await stream.set_context(self.last_messages[stream_id])
|
||||
else:
|
||||
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
|
||||
return stream
|
||||
|
||||
# 使用优化后的API查询(带缓存)
|
||||
current_time = time.time()
|
||||
model_instance, _ = await get_or_create_chat_stream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
defaults={
|
||||
"create_time": current_time,
|
||||
"last_active_time": current_time,
|
||||
"user_platform": user_info.platform if user_info else platform,
|
||||
"user_id": user_info.user_id if user_info else "",
|
||||
"user_nickname": user_info.user_nickname if user_info else "",
|
||||
"user_cardname": user_info.user_cardname if user_info else "",
|
||||
"group_platform": group_info.platform if group_info else None,
|
||||
"group_id": group_info.group_id if group_info else None,
|
||||
"group_name": group_info.group_name if group_info else None,
|
||||
}
|
||||
)
|
||||
|
||||
if model_instance:
|
||||
# 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance and getattr(model_instance, "group_id", None):
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
"energy_value": model_instance.energy_value,
|
||||
"sleep_pressure": model_instance.sleep_pressure,
|
||||
}
|
||||
stream = ChatStream.from_dict(data_for_from_dict)
|
||||
# 更新用户信息和群组信息
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
stream.update_active_time()
|
||||
else:
|
||||
# 创建新的聊天流
|
||||
stream = ChatStream(
|
||||
current_time = time.time()
|
||||
model_instance, _ = await get_or_create_chat_stream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
defaults={
|
||||
"create_time": current_time,
|
||||
"last_active_time": current_time,
|
||||
"user_platform": user_info.platform if user_info else platform,
|
||||
"user_id": user_info.user_id if user_info else "",
|
||||
"user_nickname": user_info.user_nickname if user_info else "",
|
||||
"user_cardname": user_info.user_cardname if user_info else "",
|
||||
"group_platform": group_info.platform if group_info else None,
|
||||
"group_id": group_info.group_id if group_info else None,
|
||||
"group_name": group_info.group_name if group_info else None,
|
||||
},
|
||||
)
|
||||
|
||||
if model_instance:
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if getattr(model_instance, "group_id", None):
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
"energy_value": model_instance.energy_value,
|
||||
"sleep_pressure": model_instance.sleep_pressure,
|
||||
}
|
||||
stream = ChatStream.from_dict(data_for_from_dict)
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
stream.update_active_time()
|
||||
else:
|
||||
stream = ChatStream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
||||
await stream.set_context(self.last_messages[stream_id])
|
||||
else:
|
||||
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||
|
||||
# 确保 ChatStream 有自己的 context_manager
|
||||
if not hasattr(stream, "context_manager") or stream.context_manager is None:
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
logger.info(f"为 stream {stream_id} 创建新的 context_manager")
|
||||
stream.context_manager = SingleStreamContextManager(
|
||||
stream_id=stream_id,
|
||||
context=StreamContext(
|
||||
stream_id=stream_id,
|
||||
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
|
||||
chat_mode=ChatMode.FOCUS,
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.info(f"stream {stream_id} 已有 context_manager,跳过创建")
|
||||
|
||||
# 保存到内存和数据库
|
||||
self.streams[stream_id] = stream
|
||||
await self._save_stream(stream)
|
||||
return stream
|
||||
|
||||
async def get_stream(self, stream_id: str) -> ChatStream | None:
|
||||
"""通过stream_id获取聊天流"""
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
stream = self.streams.get(stream_id)
|
||||
if not stream:
|
||||
return None
|
||||
@@ -765,23 +702,6 @@ class ChatManager:
|
||||
# if stream.stream_id in self.last_messages:
|
||||
# await stream.set_context(self.last_messages[stream.stream_id])
|
||||
|
||||
# 确保 ChatStream 有自己的 context_manager
|
||||
if not hasattr(stream, "context_manager") or stream.context_manager is None:
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
logger.debug(f"为加载的 stream {stream.stream_id} 创建新的 context_manager")
|
||||
stream.context_manager = SingleStreamContextManager(
|
||||
stream_id=stream.stream_id,
|
||||
context=StreamContext(
|
||||
stream_id=stream.stream_id,
|
||||
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
|
||||
chat_mode=ChatMode.FOCUS,
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.debug(f"加载的 stream {stream.stream_id} 已有 context_manager")
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user