@@ -62,13 +62,16 @@ class ActionPlanner:
|
|||||||
async def plan(
|
async def plan(
|
||||||
self,
|
self,
|
||||||
goal: str,
|
goal: str,
|
||||||
|
method: str,
|
||||||
reasoning: str,
|
reasoning: str,
|
||||||
action_history: List[Dict[str, str]] = None
|
action_history: List[Dict[str, str]] = None,
|
||||||
|
chat_observer: Optional[ChatObserver] = None, # 添加chat_observer参数
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""规划下一步行动
|
"""规划下一步行动
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
goal: 对话目标
|
goal: 对话目标
|
||||||
|
method: 实现方式
|
||||||
reasoning: 目标原因
|
reasoning: 目标原因
|
||||||
action_history: 行动历史记录
|
action_history: 行动历史记录
|
||||||
|
|
||||||
@@ -103,6 +106,7 @@ class ActionPlanner:
|
|||||||
prompt = f"""现在你在参与一场QQ聊天,请分析以下内容,根据信息决定下一步行动:
|
prompt = f"""现在你在参与一场QQ聊天,请分析以下内容,根据信息决定下一步行动:
|
||||||
{personality_text}
|
{personality_text}
|
||||||
当前对话目标:{goal}
|
当前对话目标:{goal}
|
||||||
|
实现该对话目标的方式:{method}
|
||||||
产生该对话目标的原因:{reasoning}
|
产生该对话目标的原因:{reasoning}
|
||||||
{time_info}
|
{time_info}
|
||||||
最近的对话记录:
|
最近的对话记录:
|
||||||
@@ -281,7 +285,9 @@ class GoalAnalyzer:
|
|||||||
logger.error(f"JSON字段内容为空,重试第{retry + 1}次")
|
logger.error(f"JSON字段内容为空,重试第{retry + 1}次")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return goal, reasoning
|
# 使用默认的方法
|
||||||
|
method = "以友好的态度回应"
|
||||||
|
return goal, method, reasoning
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"分析对话目标时出错: {str(e)},重试第{retry + 1}次")
|
logger.error(f"分析对话目标时出错: {str(e)},重试第{retry + 1}次")
|
||||||
@@ -438,6 +444,7 @@ class ReplyGenerator:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
goal: 对话目标
|
goal: 对话目标
|
||||||
|
method: 实现方式
|
||||||
chat_history: 聊天历史
|
chat_history: 聊天历史
|
||||||
knowledge_cache: 知识缓存
|
knowledge_cache: 知识缓存
|
||||||
previous_reply: 上一次生成的回复(如果有)
|
previous_reply: 上一次生成的回复(如果有)
|
||||||
@@ -558,6 +565,7 @@ class Conversation:
|
|||||||
self.stream_id = stream_id
|
self.stream_id = stream_id
|
||||||
self.state = ConversationState.INIT
|
self.state = ConversationState.INIT
|
||||||
self.current_goal: Optional[str] = None
|
self.current_goal: Optional[str] = None
|
||||||
|
self.current_method: Optional[str] = None
|
||||||
self.goal_reasoning: Optional[str] = None
|
self.goal_reasoning: Optional[str] = None
|
||||||
self.generated_reply: Optional[str] = None
|
self.generated_reply: Optional[str] = None
|
||||||
self.should_continue = True
|
self.should_continue = True
|
||||||
@@ -598,7 +606,7 @@ class Conversation:
|
|||||||
async def _conversation_loop(self):
|
async def _conversation_loop(self):
|
||||||
"""对话循环"""
|
"""对话循环"""
|
||||||
# 获取最近的消息历史
|
# 获取最近的消息历史
|
||||||
self.current_goal, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
||||||
|
|
||||||
while self.should_continue:
|
while self.should_continue:
|
||||||
# 执行行动
|
# 执行行动
|
||||||
@@ -606,15 +614,12 @@ class Conversation:
|
|||||||
if not await self.chat_observer.wait_for_update():
|
if not await self.chat_observer.wait_for_update():
|
||||||
logger.warning("等待消息更新超时")
|
logger.warning("等待消息更新超时")
|
||||||
|
|
||||||
# 如果用户最后发言时间比当前时间晚2秒,说明消息还没到数据库,跳过这次循环
|
|
||||||
if self.chat_observer.last_user_speak_time - time.time() < 1.5:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
action, reason = await self.action_planner.plan(
|
action, reason = await self.action_planner.plan(
|
||||||
self.current_goal,
|
self.current_goal,
|
||||||
|
self.current_method,
|
||||||
self.goal_reasoning,
|
self.goal_reasoning,
|
||||||
self.action_history, # 传入action历史
|
self.action_history, # 传入action历史
|
||||||
|
self.chat_observer # 传入chat_observer
|
||||||
)
|
)
|
||||||
|
|
||||||
# 执行行动
|
# 执行行动
|
||||||
@@ -659,12 +664,13 @@ class Conversation:
|
|||||||
messages = self.chat_observer.get_message_history(limit=30)
|
messages = self.chat_observer.get_message_history(limit=30)
|
||||||
self.generated_reply, need_replan = await self.reply_generator.generate(
|
self.generated_reply, need_replan = await self.reply_generator.generate(
|
||||||
self.current_goal,
|
self.current_goal,
|
||||||
|
self.current_method,
|
||||||
[self._convert_to_message(msg) for msg in messages],
|
[self._convert_to_message(msg) for msg in messages],
|
||||||
self.knowledge_cache
|
self.knowledge_cache
|
||||||
)
|
)
|
||||||
if need_replan:
|
if need_replan:
|
||||||
self.state = ConversationState.RETHINKING
|
self.state = ConversationState.RETHINKING
|
||||||
self.current_goal, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
||||||
else:
|
else:
|
||||||
await self._send_reply()
|
await self._send_reply()
|
||||||
|
|
||||||
@@ -682,18 +688,19 @@ class Conversation:
|
|||||||
|
|
||||||
self.generated_reply, need_replan = await self.reply_generator.generate(
|
self.generated_reply, need_replan = await self.reply_generator.generate(
|
||||||
self.current_goal,
|
self.current_goal,
|
||||||
|
self.current_method,
|
||||||
[self._convert_to_message(msg) for msg in messages],
|
[self._convert_to_message(msg) for msg in messages],
|
||||||
self.knowledge_cache
|
self.knowledge_cache
|
||||||
)
|
)
|
||||||
if need_replan:
|
if need_replan:
|
||||||
self.state = ConversationState.RETHINKING
|
self.state = ConversationState.RETHINKING
|
||||||
self.current_goal, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
||||||
else:
|
else:
|
||||||
await self._send_reply()
|
await self._send_reply()
|
||||||
|
|
||||||
elif action == "rethink_goal":
|
elif action == "rethink_goal":
|
||||||
self.state = ConversationState.RETHINKING
|
self.state = ConversationState.RETHINKING
|
||||||
self.current_goal, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
||||||
|
|
||||||
elif action == "judge_conversation":
|
elif action == "judge_conversation":
|
||||||
self.state = ConversationState.JUDGING
|
self.state = ConversationState.JUDGING
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from typing import Union, Optional
|
from typing import Union
|
||||||
|
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
from ..chat.message import MessageSending, MessageRecv
|
from ..chat.message import MessageSending, MessageRecv
|
||||||
from ..chat.chat_stream import ChatStream
|
from ..chat.chat_stream import ChatStream
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from ..message.message_base import BaseMessageInfo, Seg, UserInfo
|
|
||||||
|
|
||||||
logger = get_module_logger("message_storage")
|
logger = get_module_logger("message_storage")
|
||||||
|
|
||||||
@@ -27,57 +26,6 @@ class MessageStorage:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("存储消息失败")
|
logger.exception("存储消息失败")
|
||||||
|
|
||||||
async def get_last_message(self, chat_id: str, user_id: str) -> Optional[MessageRecv]:
|
|
||||||
"""获取指定聊天流和用户的最后一条消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天流ID
|
|
||||||
user_id: 用户ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[MessageRecv]: 最后一条消息,如果没有找到则返回None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 查找最后一条消息
|
|
||||||
message_data = db.messages.find_one(
|
|
||||||
{
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"user_info.user_id": user_id
|
|
||||||
},
|
|
||||||
sort=[("time", -1)] # 按时间降序排序
|
|
||||||
)
|
|
||||||
|
|
||||||
if not message_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 构建消息字典
|
|
||||||
message_dict = {
|
|
||||||
"message_info": {
|
|
||||||
"platform": message_data["chat_info"]["platform"],
|
|
||||||
"message_id": message_data["message_id"],
|
|
||||||
"time": message_data["time"],
|
|
||||||
"group_info": message_data["chat_info"].get("group_info"),
|
|
||||||
"user_info": message_data["user_info"]
|
|
||||||
},
|
|
||||||
"message_segment": {
|
|
||||||
"type": "text",
|
|
||||||
"data": message_data["processed_plain_text"]
|
|
||||||
},
|
|
||||||
"raw_message": message_data["processed_plain_text"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# 创建并返回消息对象
|
|
||||||
message = MessageRecv(message_dict)
|
|
||||||
message.processed_plain_text = message_data["processed_plain_text"]
|
|
||||||
message.detailed_plain_text = message_data["detailed_plain_text"]
|
|
||||||
message.update_chat_stream(ChatStream.from_dict(message_data["chat_info"]))
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
logger.exception("获取最后一条消息失败")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def store_recalled_message(self, message_id: str, time: str, chat_stream: ChatStream) -> None:
|
async def store_recalled_message(self, message_id: str, time: str, chat_stream: ChatStream) -> None:
|
||||||
"""存储撤回消息到数据库"""
|
"""存储撤回消息到数据库"""
|
||||||
if "recalled_messages" not in db.list_collection_names():
|
if "recalled_messages" not in db.list_collection_names():
|
||||||
|
|||||||
Reference in New Issue
Block a user