ruff reformatted

This commit is contained in:
春河晴
2025-04-08 15:31:13 +09:00
parent 0d7068acab
commit 7840a6080d
40 changed files with 1227 additions and 1336 deletions

View File

@@ -1,6 +1,6 @@
import time
import asyncio
from typing import Optional, Dict, Any, List, Tuple
from typing import Optional, Dict, Any, List, Tuple
from src.common.logger import get_module_logger
from src.common.database import db
from ..message.message_base import UserInfo
@@ -8,99 +8,97 @@ from ..config.config import global_config
logger = get_module_logger("chat_observer")
class ChatObserver:
"""聊天状态观察器"""
# 类级别的实例管理
_instances: Dict[str, 'ChatObserver'] = {}
_instances: Dict[str, "ChatObserver"] = {}
@classmethod
def get_instance(cls, stream_id: str) -> 'ChatObserver':
def get_instance(cls, stream_id: str) -> "ChatObserver":
"""获取或创建观察器实例
Args:
stream_id: 聊天流ID
Returns:
ChatObserver: 观察器实例
"""
if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id)
return cls._instances[stream_id]
def __init__(self, stream_id: str):
"""初始化观察器
Args:
stream_id: 聊天流ID
"""
if stream_id in self._instances:
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
self.stream_id = stream_id
self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
self.last_check_time: float = time.time() # 上次查看聊天记录时间
self.last_message_read: Optional[str] = None # 最后读取的消息ID
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
self.waiting_start_time: Optional[float] = None # 等待开始时间
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
self.last_check_time: float = time.time() # 上次查看聊天记录时间
self.last_message_read: Optional[str] = None # 最后读取的消息ID
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
self.waiting_start_time: Optional[float] = None # 等待开始时间
# 消息历史记录
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
self.last_message_id: Optional[str] = None # 最后一条消息的ID
self.message_count: int = 0 # 消息计数
self.last_message_id: Optional[str] = None # 最后一条消息的ID
self.message_count: int = 0 # 消息计数
# 运行状态
self._running: bool = False
self._task: Optional[asyncio.Task] = None
self._update_event = asyncio.Event() # 触发更新的事件
self._update_complete = asyncio.Event() # 更新完成的事件
def check(self) -> bool:
"""检查距离上一次观察之后是否有了新消息
Returns:
bool: 是否有新消息
"""
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
query = {
"chat_id": self.stream_id,
"time": {"$gt": self.last_check_time}
}
query = {"chat_id": self.stream_id, "time": {"$gt": self.last_check_time}}
# 只需要查询是否存在,不需要获取具体消息
new_message_exists = db.messages.find_one(query) is not None
if new_message_exists:
logger.debug("发现新消息")
self.last_check_time = time.time()
return new_message_exists
def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
messages = self.get_message_history(self.last_check_time)
for message in messages:
self._add_message_to_history(message)
return messages, self.message_history
def new_message_after(self, time_point: float) -> bool:
"""判断是否在指定时间点后有新消息
Args:
time_point: 时间戳
Returns:
bool: 是否有新消息
"""
logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point}")
return self.last_message_time is None or self.last_message_time > time_point
def _add_message_to_history(self, message: Dict[str, Any]):
"""添加消息到历史记录
Args:
message: 消息数据
"""
@@ -108,54 +106,53 @@ class ChatObserver:
self.last_message_id = message["message_id"]
self.last_message_time = message["time"] # 更新最后消息时间
self.message_count += 1
# 更新说话时间
user_info = UserInfo.from_dict(message.get("user_info", {}))
if user_info.user_id == global_config.BOT_QQ:
self.last_bot_speak_time = message["time"]
else:
self.last_user_speak_time = message["time"]
def get_message_history(
self,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
limit: Optional[int] = None,
user_id: Optional[str] = None
user_id: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""获取消息历史
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回消息数量
user_id: 指定用户ID
Returns:
List[Dict[str, Any]]: 消息列表
"""
filtered_messages = self.message_history
if start_time is not None:
filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]
if end_time is not None:
filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]
if user_id is not None:
filtered_messages = [
m for m in filtered_messages
if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
]
if limit is not None:
filtered_messages = filtered_messages[-limit:]
return filtered_messages
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
"""获取新消息
Returns:
List[Dict[str, Any]]: 新消息列表
"""
@@ -165,42 +162,37 @@ class ChatObserver:
last_message = db.messages.find_one({"message_id": self.last_message_read})
if last_message:
query["time"] = {"$gt": last_message["time"]}
new_messages = list(
db.messages.find(query).sort("time", 1)
)
new_messages = list(db.messages.find(query).sort("time", 1))
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
return new_messages
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
"""获取指定时间点之前的消息
Args:
time_point: 时间戳
Returns:
List[Dict[str, Any]]: 最多5条消息
"""
query = {
"chat_id": self.stream_id,
"time": {"$lt": time_point}
}
query = {"chat_id": self.stream_id, "time": {"$lt": time_point}}
new_messages = list(
db.messages.find(query).sort("time", -1).limit(5) # 倒序获取5条
)
# 将消息按时间正序排列
new_messages.reverse()
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
return new_messages
async def _update_loop(self):
"""更新循环"""
try:
@@ -210,7 +202,7 @@ class ChatObserver:
self._add_message_to_history(message)
except Exception as e:
logger.error(f"缓冲消息出错: {e}")
while self._running:
try:
# 等待事件或超时1秒
@@ -218,35 +210,35 @@ class ChatObserver:
await asyncio.wait_for(self._update_event.wait(), timeout=1)
except asyncio.TimeoutError:
pass # 超时后也执行一次检查
self._update_event.clear() # 重置触发事件
self._update_complete.clear() # 重置完成事件
# 获取新消息
new_messages = await self._fetch_new_messages()
if new_messages:
# 处理新消息
for message in new_messages:
self._add_message_to_history(message)
# 设置完成事件
self._update_complete.set()
except Exception as e:
logger.error(f"更新循环出错: {e}")
self._update_complete.set() # 即使出错也要设置完成事件
def trigger_update(self):
"""触发一次立即更新"""
self._update_event.set()
async def wait_for_update(self, timeout: float = 5.0) -> bool:
"""等待更新完成
Args:
timeout: 超时时间(秒)
Returns:
bool: 是否成功完成更新False表示超时
"""
@@ -256,16 +248,16 @@ class ChatObserver:
except asyncio.TimeoutError:
logger.warning(f"等待更新完成超时({timeout}秒)")
return False
def start(self):
"""启动观察器"""
if self._running:
return
self._running = True
self._task = asyncio.create_task(self._update_loop())
logger.info(f"ChatObserver for {self.stream_id} started")
def stop(self):
"""停止观察器"""
self._running = False
@@ -274,15 +266,15 @@ class ChatObserver:
if self._task:
self._task.cancel()
logger.info(f"ChatObserver for {self.stream_id} stopped")
async def process_chat_history(self, messages: list):
"""处理聊天历史
Args:
messages: 消息列表
"""
self.update_check_time()
for msg in messages:
try:
user_info = UserInfo.from_dict(msg.get("user_info", {}))
@@ -292,31 +284,31 @@ class ChatObserver:
self.update_user_speak_time(msg["time"])
except Exception as e:
logger.warning(f"处理消息时间时出错: {e}")
continue
continue
def update_check_time(self):
"""更新查看时间"""
self.last_check_time = time.time()
def update_bot_speak_time(self, speak_time: Optional[float] = None):
"""更新机器人说话时间"""
self.last_bot_speak_time = speak_time or time.time()
def update_user_speak_time(self, speak_time: Optional[float] = None):
"""更新用户说话时间"""
self.last_user_speak_time = speak_time or time.time()
def get_time_info(self) -> str:
"""获取时间信息文本"""
current_time = time.time()
time_info = ""
if self.last_bot_speak_time:
bot_speak_ago = current_time - self.last_bot_speak_time
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}"
if self.last_user_speak_time:
user_speak_ago = current_time - self.last_user_speak_time
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}"
return time_info

View File

@@ -1,5 +1,5 @@
#Programmable Friendly Conversationalist
#Prefrontal cortex
# Programmable Friendly Conversationalist
# Prefrontal cortex
import datetime
import asyncio
from typing import List, Optional, Dict, Any, Tuple, Literal
@@ -26,6 +26,7 @@ logger = get_module_logger("pfc")
class ConversationState(Enum):
"""对话状态"""
INIT = "初始化"
RETHINKING = "重新思考"
ANALYZING = "分析历史"
@@ -44,40 +45,37 @@ ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]
class ActionPlanner:
"""行动规划器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=1000,
request_type="action_planning"
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="action_planning"
)
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
async def plan(
self,
goal: str,
method: str,
self,
goal: str,
method: str,
reasoning: str,
action_history: List[Dict[str, str]] = None,
chat_observer: Optional[ChatObserver] = None, # 添加chat_observer参数
) -> Tuple[str, str]:
"""规划下一步行动
Args:
goal: 对话目标
reasoning: 目标原因
action_history: 行动历史记录
Returns:
Tuple[str, str]: (行动类型, 行动原因)
"""
# 构建提示词
# 获取最近20条消息
self.chat_observer.waiting_start_time = time.time()
messages = self.chat_observer.get_message_history(limit=20)
chat_history_text = ""
for msg in messages:
@@ -87,13 +85,13 @@ class ActionPlanner:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本
action_history_text = ""
if action_history:
if action_history[-1]['action'] == "direct_reply":
if action_history[-1]["action"] == "direct_reply":
action_history_text = "你刚刚发言回复了对方"
# 获取时间信息
@@ -127,29 +125,34 @@ judge_conversation: 判断对话是否结束,当发现对话目标已经达到
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"action", "reason",
default_values={"action": "direct_reply", "reason": "默认原因"}
content, "action", "reason", default_values={"action": "direct_reply", "reason": "默认原因"}
)
if not success:
return "direct_reply", "JSON解析失败选择直接回复"
action = result["action"]
reason = result["reason"]
# 验证action类型
if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal", "judge_conversation"]:
if action not in [
"direct_reply",
"fetch_knowledge",
"wait",
"listening",
"rethink_goal",
"judge_conversation",
]:
logger.warning(f"未知的行动类型: {action}默认使用listening")
action = "listening"
logger.info(f"规划的行动: {action}")
logger.info(f"行动原因: {reason}")
return action, reason
except Exception as e:
logger.error(f"规划行动时出错: {str(e)}")
return "direct_reply", "发生错误,选择直接回复"
@@ -157,20 +160,17 @@ judge_conversation: 判断对话是否结束,当发现对话目标已经达到
class GoalAnalyzer:
"""对话目标分析器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=1000,
request_type="conversation_goal"
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
)
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.chat_observer = ChatObserver.get_instance(stream_id)
# 多目标存储结构
self.goals = [] # 存储多个目标
self.max_goals = 3 # 同时保持的最大目标数量
@@ -178,10 +178,10 @@ class GoalAnalyzer:
async def analyze_goal(self) -> Tuple[str, str, str]:
"""分析对话历史并设定目标
Args:
chat_history: 聊天历史记录列表
Returns:
Tuple[str, str, str]: (目标, 方法, 原因)
"""
@@ -198,16 +198,16 @@ class GoalAnalyzer:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建当前已有目标的文本
existing_goals_text = ""
if self.goals:
existing_goals_text = "当前已有的对话目标:\n"
for i, (goal, _, reason) in enumerate(self.goals):
existing_goals_text += f"{i+1}. 目标: {goal}, 原因: {reason}\n"
existing_goals_text += f"{i + 1}. 目标: {goal}, 原因: {reason}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下聊天记录并根据你的性格特征确定多个明确的对话目标。
这些目标应该反映出对话的不同方面和意图。
@@ -235,46 +235,44 @@ class GoalAnalyzer:
logger.debug(f"发送到LLM的提示词: {prompt}")
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"goal", "reasoning",
required_types={"goal": str, "reasoning": str}
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
)
if not success:
logger.error(f"无法解析JSON重试第{retry + 1}")
continue
goal = result["goal"]
reasoning = result["reasoning"]
# 使用默认的方法
method = "以友好的态度回应"
# 更新目标列表
await self._update_goals(goal, method, reasoning)
# 返回当前最主要的目标
if self.goals:
current_goal, current_method, current_reasoning = self.goals[0]
return current_goal, current_method, current_reasoning
else:
return goal, method, reasoning
except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)},重试第{retry + 1}")
if retry == max_retries - 1:
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
continue
# 所有重试都失败后的默认返回
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表
Args:
new_goal: 新的目标
method: 实现目标的方法
@@ -288,23 +286,23 @@ class GoalAnalyzer:
# 将此目标移到列表前面(最主要的位置)
self.goals.insert(0, self.goals.pop(i))
return
# 添加新目标到列表前面
self.goals.insert(0, (new_goal, method, reasoning))
# 限制目标数量
if len(self.goals) > self.max_goals:
self.goals.pop() # 移除最老的目标
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
"""简单计算两个目标之间的相似度
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
Args:
goal1: 第一个目标
goal2: 第二个目标
Returns:
float: 相似度得分 (0-1)
"""
@@ -314,18 +312,18 @@ class GoalAnalyzer:
overlap = len(words1.intersection(words2))
total = len(words1.union(words2))
return overlap / total if total > 0 else 0
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
"""获取所有当前目标
Returns:
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
"""
return self.goals.copy()
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
"""获取除了当前主要目标外的其他备选目标
Returns:
List[Tuple[str, str, str]]: 备选目标列表
"""
@@ -343,9 +341,9 @@ class GoalAnalyzer:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
personality_text = f"你的名字是{self.name}{self.personality_info}"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天
当前对话目标:{goal}
产生该对话目标的原因:{reasoning}
@@ -368,21 +366,19 @@ class GoalAnalyzer:
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"goal_achieved", "stop_conversation", "reason",
required_types={
"goal_achieved": bool,
"stop_conversation": bool,
"reason": str
}
"goal_achieved",
"stop_conversation",
"reason",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
)
if not success:
return False, False, "确保对话顺利进行"
# 如果当前目标达成,从目标列表中移除
if result["goal_achieved"] and not result["stop_conversation"]:
for i, (g, _, _) in enumerate(self.goals):
@@ -392,9 +388,9 @@ class GoalAnalyzer:
if self.goals:
result["stop_conversation"] = False
break
return result["goal_achieved"], result["stop_conversation"], result["reason"]
except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)}")
return False, False, "确保对话顺利进行"
@@ -402,14 +398,15 @@ class GoalAnalyzer:
class Waiter:
"""快 速 等 待"""
def __init__(self, stream_id: str):
self.chat_observer = ChatObserver.get_instance(stream_id)
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
async def wait(self) -> bool:
"""等待
Returns:
bool: 是否超时True表示超时
"""
@@ -424,39 +421,36 @@ class Waiter:
logger.info("等待结束")
return False
class ReplyGenerator:
"""回复生成器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=300,
request_type="reply_generation"
model=global_config.llm_normal, temperature=0.7, max_tokens=300, request_type="reply_generation"
)
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
self.reply_checker = ReplyChecker(stream_id)
async def generate(
self,
goal: str,
chat_history: List[Message],
knowledge_cache: Dict[str, str],
previous_reply: Optional[str] = None,
retry_count: int = 0
retry_count: int = 0,
) -> str:
"""生成回复
Args:
goal: 对话目标
chat_history: 聊天历史
knowledge_cache: 知识缓存
previous_reply: 上一次生成的回复(如果有)
retry_count: 当前重试次数
Returns:
str: 生成的回复
"""
@@ -465,7 +459,7 @@ class ReplyGenerator:
self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update():
logger.warning("等待消息更新超时")
messages = self.chat_observer.get_message_history(limit=20)
chat_history_text = ""
for msg in messages:
@@ -475,7 +469,7 @@ class ReplyGenerator:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
# 整理知识缓存
knowledge_text = ""
if knowledge_cache:
@@ -486,14 +480,14 @@ class ReplyGenerator:
elif isinstance(knowledge_cache, list):
for item in knowledge_cache:
knowledge_text += f"\n{item}"
# 添加上一次生成的回复信息
previous_reply_text = ""
if previous_reply:
previous_reply_text = f"\n上一次生成的回复(需要改进):\n{previous_reply}"
personality_text = f"你的名字是{self.name}{self.personality_info}"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请根据以下信息生成回复
当前对话目标:{goal}
@@ -507,7 +501,7 @@ class ReplyGenerator:
2. 体现你的性格特征
3. 自然流畅,像正常聊天一样,简短
4. 适当利用相关知识,但不要生硬引用
{'5. 改进上一次回复中的问题' if previous_reply else ''}
{"5. 改进上一次回复中的问题" if previous_reply else ""}
请注意把握聊天内容,不要回复的太有条理,可以有个性。请分清""和对方说的话,不要把""说的话当做对方说的话,这是你自己说的话。
请你回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话
@@ -521,34 +515,26 @@ class ReplyGenerator:
logger.info(f"生成的回复: {content}")
is_new = self.chat_observer.check()
logger.debug(f"再看一眼聊天记录,{'' if is_new else '没有'}新消息")
# 如果有新消息,重新生成回复
if is_new:
logger.info("检测到新消息,重新生成回复")
return await self.generate(
goal, chat_history, knowledge_cache,
None, retry_count
)
return await self.generate(goal, chat_history, knowledge_cache, None, retry_count)
return content
except Exception as e:
logger.error(f"生成回复时出错: {e}")
return "抱歉,我现在有点混乱,让我重新思考一下..."
async def check_reply(
self,
reply: str,
goal: str,
retry_count: int = 0
) -> Tuple[bool, str, bool]:
async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
"""检查回复是否合适
Args:
reply: 生成的回复
goal: 对话目标
retry_count: 当前重试次数
Returns:
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
"""
@@ -557,18 +543,18 @@ class ReplyGenerator:
class Conversation:
# 类级别的实例管理
_instances: Dict[str, 'Conversation'] = {}
_instances: Dict[str, "Conversation"] = {}
_instance_lock = asyncio.Lock() # 类级别的全局锁
_init_events: Dict[str, asyncio.Event] = {} # 初始化完成事件
_initializing: Dict[str, bool] = {} # 标记是否正在初始化
@classmethod
async def get_instance(cls, stream_id: str) -> Optional['Conversation']:
async def get_instance(cls, stream_id: str) -> Optional["Conversation"]:
"""获取或创建对话实例
Args:
stream_id: 聊天流ID
Returns:
Optional[Conversation]: 对话实例如果创建或等待失败则返回None
"""
@@ -586,23 +572,23 @@ class Conversation:
return None
finally:
await cls._instance_lock.acquire()
# 如果实例不存在,创建新实例
if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id)
cls._init_events[stream_id] = asyncio.Event()
cls._initializing[stream_id] = True
logger.info(f"创建新的对话实例: {stream_id}")
return cls._instances[stream_id]
except Exception as e:
logger.error(f"获取对话实例失败: {e}")
return None
@classmethod
async def remove_instance(cls, stream_id: str):
"""删除对话实例
Args:
stream_id: 聊天流ID
"""
@@ -628,16 +614,16 @@ class Conversation:
self.goal_reasoning: Optional[str] = None
self.generated_reply: Optional[str] = None
self.should_continue = True
# 初始化聊天观察器
self.chat_observer = ChatObserver.get_instance(stream_id)
# 添加action历史记录
self.action_history: List[Dict[str, str]] = []
# 知识缓存
self.knowledge_cache: Dict[str, str] = {} # 确保初始化为字典
# 初始化各个组件
self.goal_analyzer = GoalAnalyzer(self.stream_id)
self.action_planner = ActionPlanner(self.stream_id)
@@ -645,14 +631,14 @@ class Conversation:
self.knowledge_fetcher = KnowledgeFetcher()
self.direct_sender = DirectMessageSender()
self.waiter = Waiter(self.stream_id)
# 创建聊天流
self.chat_stream = chat_manager.get_stream(self.stream_id)
def _clear_knowledge_cache(self):
"""清空知识缓存"""
self.knowledge_cache.clear() # 使用clear方法清空字典
async def start(self):
"""开始对话流程"""
try:
@@ -674,38 +660,38 @@ class Conversation:
"""对话循环"""
# 获取最近的消息历史
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
while self.should_continue:
# 执行行动
self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update():
logger.warning("等待消息更新超时")
action, reason = await self.action_planner.plan(
self.current_goal,
self.current_method,
self.goal_reasoning,
self.action_history, # 传入action历史
self.chat_observer # 传入chat_observer
self.chat_observer, # 传入chat_observer
)
# 执行行动
await self._handle_action(action, reason)
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象"""
try:
chat_info = msg_dict.get("chat_info", {})
chat_stream = ChatStream.from_dict(chat_info)
user_info = UserInfo.from_dict(msg_dict.get("user_info", {}))
return Message(
message_id=msg_dict["message_id"],
chat_stream=chat_stream,
time=msg_dict["time"],
user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", "")
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
)
except Exception as e:
logger.warning(f"转换消息时出错: {e}")
@@ -714,18 +700,16 @@ class Conversation:
async def _handle_action(self, action: str, reason: str):
"""处理规划的行动"""
logger.info(f"执行行动: {action}, 原因: {reason}")
# 记录action历史
self.action_history.append({
"action": action,
"reason": reason,
"time": datetime.datetime.now().strftime("%H:%M:%S")
})
self.action_history.append(
{"action": action, "reason": reason, "time": datetime.datetime.now().strftime("%H:%M:%S")}
)
# 只保留最近的10条记录
if len(self.action_history) > 10:
self.action_history = self.action_history[-10:]
if action == "direct_reply":
self.state = ConversationState.GENERATING
messages = self.chat_observer.get_message_history(limit=30)
@@ -733,15 +717,14 @@ class Conversation:
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
self.knowledge_cache,
)
# 检查回复是否合适
is_suitable, reason, need_replan = await self.reply_generator.check_reply(
self.generated_reply,
self.current_goal
self.generated_reply, self.current_goal
)
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
if need_replan:
@@ -756,29 +739,34 @@ class Conversation:
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
self.knowledge_cache,
)
# 检查使用新目标生成的回复是否合适
is_suitable, reason, _ = await self.reply_generator.check_reply(
self.generated_reply,
self.current_goal
self.generated_reply, self.current_goal
)
if is_suitable:
# 如果新目标的回复合适,调整目标优先级
await self.goal_analyzer._update_goals(
self.current_goal,
self.current_method,
self.goal_reasoning
self.current_goal, self.current_method, self.goal_reasoning
)
else:
# 如果新目标还是不合适,重新思考目标
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
(
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
(
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
@@ -787,9 +775,9 @@ class Conversation:
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
self.generated_reply, # 将不合适的回复作为previous_reply传入
)
while self.chat_observer.check():
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
@@ -805,13 +793,17 @@ class Conversation:
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
self.knowledge_cache,
)
is_suitable = True # 假设使用新目标后回复是合适的
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
(
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
@@ -820,36 +812,34 @@ class Conversation:
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
self.generated_reply, # 将不合适的回复作为previous_reply传入
)
await self._send_reply()
elif action == "fetch_knowledge":
self.state = ConversationState.GENERATING
messages = self.chat_observer.get_message_history(limit=30)
knowledge, sources = await self.knowledge_fetcher.fetch(
self.current_goal,
[self._convert_to_message(msg) for msg in messages]
self.current_goal, [self._convert_to_message(msg) for msg in messages]
)
logger.info(f"获取到知识,来源: {sources}")
if knowledge != "未找到相关知识":
self.knowledge_cache[sources] = knowledge
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
self.knowledge_cache,
)
# 检查回复是否合适
is_suitable, reason, need_replan = await self.reply_generator.check_reply(
self.generated_reply,
self.current_goal
self.generated_reply, self.current_goal
)
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
if need_replan:
@@ -861,22 +851,25 @@ class Conversation:
logger.info(f"切换到备选目标: {self.current_goal}")
# 使用新目标获取知识并生成回复
knowledge, sources = await self.knowledge_fetcher.fetch(
self.current_goal,
[self._convert_to_message(msg) for msg in messages]
self.current_goal, [self._convert_to_message(msg) for msg in messages]
)
if knowledge != "未找到相关知识":
self.knowledge_cache[sources] = knowledge
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
self.knowledge_cache,
)
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
(
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
@@ -885,19 +878,21 @@ class Conversation:
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
self.generated_reply, # 将不合适的回复作为previous_reply传入
)
await self._send_reply()
elif action == "rethink_goal":
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
elif action == "judge_conversation":
self.state = ConversationState.JUDGING
self.goal_achieved, self.stop_conversation, self.reason = await self.goal_analyzer.analyze_conversation(self.current_goal, self.goal_reasoning)
self.goal_achieved, self.stop_conversation, self.reason = await self.goal_analyzer.analyze_conversation(
self.current_goal, self.goal_reasoning
)
# 如果当前目标达成但还有其他目标
if self.goal_achieved and not self.stop_conversation:
alternative_goals = await self.goal_analyzer.get_alternative_goals()
@@ -906,17 +901,17 @@ class Conversation:
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
logger.info(f"当前目标已达成,切换到新目标: {self.current_goal}")
return
if self.stop_conversation:
await self._stop_conversation()
elif action == "listening":
self.state = ConversationState.LISTENING
logger.info("倾听对方发言...")
if await self.waiter.wait(): # 如果返回True表示超时
await self._send_timeout_message()
await self._stop_conversation()
else: # wait
self.state = ConversationState.WAITING
logger.info("等待更多信息...")
@@ -938,12 +933,12 @@ class Conversation:
messages = self.chat_observer.get_message_history(limit=1)
if not messages:
return
latest_message = self._convert_to_message(messages[0])
await self.direct_sender.send_message(
chat_stream=self.chat_stream,
content="抱歉,由于等待时间过长,我需要先去忙别的了。下次再聊吧~",
reply_to_message=latest_message
reply_to_message=latest_message,
)
except Exception as e:
logger.error(f"发送超时消息失败: {str(e)}")
@@ -953,23 +948,21 @@ class Conversation:
if not self.generated_reply:
logger.warning("没有生成回复")
return
messages = self.chat_observer.get_message_history(limit=1)
if not messages:
logger.warning("没有最近的消息可以回复")
return
latest_message = self._convert_to_message(messages[0])
try:
await self.direct_sender.send_message(
chat_stream=self.chat_stream,
content=self.generated_reply,
reply_to_message=latest_message
chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message
)
self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update():
logger.warning("等待消息更新超时")
self.state = ConversationState.ANALYZING
except Exception as e:
logger.error(f"发送消息失败: {str(e)}")
@@ -978,7 +971,7 @@ class Conversation:
class DirectMessageSender:
"""直接发送消息到平台的发送器"""
def __init__(self):
self.logger = get_module_logger("direct_sender")
self.storage = MessageStorage()
@@ -990,7 +983,7 @@ class DirectMessageSender:
reply_to_message: Optional[Message] = None,
) -> None:
"""直接发送消息到平台
Args:
chat_stream: 聊天流
content: 消息内容
@@ -1003,7 +996,7 @@ class DirectMessageSender:
user_nickname=global_config.BOT_NICKNAME,
platform=chat_stream.platform,
)
message = MessageSending(
message_id=f"dm{round(time.time(), 2)}",
chat_stream=chat_stream,
@@ -1023,18 +1016,17 @@ class DirectMessageSender:
try:
message_json = message.to_dict()
end_point = global_config.api_urls.get(chat_stream.platform, None)
if not end_point:
raise ValueError(f"未找到平台:{chat_stream.platform} 的url配置")
await global_api.send_message_REST(end_point, message_json)
# 存储消息
await self.storage.store_message(message, message.chat_stream)
self.logger.info(f"直接发送消息成功: {content[:30]}...")
except Exception as e:
self.logger.error(f"直接发送消息失败: {str(e)}")
raise

View File

@@ -7,24 +7,22 @@ from ..chat.message import Message
logger = get_module_logger("knowledge_fetcher")
class KnowledgeFetcher:
"""知识调取器"""
def __init__(self):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=1000,
request_type="knowledge_fetch"
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="knowledge_fetch"
)
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
"""获取相关知识
Args:
query: 查询内容
chat_history: 聊天历史
Returns:
Tuple[str, str]: (获取的知识, 知识来源)
"""
@@ -33,16 +31,16 @@ class KnowledgeFetcher:
for msg in chat_history:
# sender = msg.message_info.user_info.user_nickname or f"用户{msg.message_info.user_info.user_id}"
chat_history_text += f"{msg.detailed_plain_text}\n"
# 从记忆中获取相关知识
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=f"{query}\n{chat_history_text}",
max_memory_num=3,
max_memory_length=2,
max_depth=3,
fast_retrieval=False
fast_retrieval=False,
)
if related_memory:
knowledge = ""
sources = []
@@ -50,5 +48,5 @@ class KnowledgeFetcher:
knowledge += memory[1] + "\n"
sources.append(f"记忆片段{memory[0]}")
return knowledge.strip(), "".join(sources)
return "未找到相关知识", "无记忆匹配"
return "未找到相关知识", "无记忆匹配"

View File

@@ -5,36 +5,37 @@ from src.common.logger import get_module_logger
logger = get_module_logger("pfc_utils")
def get_items_from_json(
content: str,
*items: str,
default_values: Optional[Dict[str, Any]] = None,
required_types: Optional[Dict[str, type]] = None
required_types: Optional[Dict[str, type]] = None,
) -> Tuple[bool, Dict[str, Any]]:
"""从文本中提取JSON内容并获取指定字段
Args:
content: 包含JSON的文本
*items: 要提取的字段名
default_values: 字段的默认值,格式为 {字段名: 默认值}
required_types: 字段的必需类型,格式为 {字段名: 类型}
Returns:
Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典)
"""
content = content.strip()
result = {}
# 设置默认值
if default_values:
result.update(default_values)
# 尝试解析JSON
try:
json_data = json.loads(content)
except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分
json_pattern = r'\{[^{}]*\}'
json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content)
if json_match:
try:
@@ -45,28 +46,28 @@ def get_items_from_json(
else:
logger.error("无法在返回内容中找到有效的JSON")
return False, result
# 提取字段
for item in items:
if item in json_data:
result[item] = json_data[item]
# 验证必需字段
if not all(item in result for item in items):
logger.error(f"JSON缺少必要字段实际内容: {json_data}")
return False, result
# 验证字段类型
if required_types:
for field, expected_type in required_types.items():
if field in result and not isinstance(result[field], expected_type):
logger.error(f"{field} 必须是 {expected_type.__name__} 类型")
return False, result
# 验证字符串字段不为空
for field in items:
if isinstance(result[field], str) and not result[field].strip():
logger.error(f"{field} 不能为空")
return False, result
return True, result
return True, result

View File

@@ -9,33 +9,26 @@ from ..message.message_base import UserInfo
logger = get_module_logger("reply_checker")
class ReplyChecker:
"""回复检查器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=1000,
request_type="reply_check"
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="reply_check"
)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
self.max_retries = 2 # 最大重试次数
async def check(
self,
reply: str,
goal: str,
retry_count: int = 0
) -> Tuple[bool, str, bool]:
async def check(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
"""检查生成的回复是否合适
Args:
reply: 生成的回复
goal: 对话目标
retry_count: 当前重试次数
Returns:
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
"""
@@ -49,7 +42,7 @@ class ReplyChecker:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
prompt = f"""请检查以下回复是否合适:
当前对话目标:{goal}
@@ -83,7 +76,7 @@ class ReplyChecker:
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"检查回复的原始返回: {content}")
# 清理内容尝试提取JSON部分
content = content.strip()
try:
@@ -92,7 +85,8 @@ class ReplyChecker:
except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分
import re
json_pattern = r'\{[^{}]*\}'
json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content)
if json_match:
try:
@@ -109,33 +103,33 @@ class ReplyChecker:
reason = content[:100] if content else "无法解析响应"
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
return is_suitable, reason, need_replan
# 验证JSON字段
suitable = result.get("suitable", None)
reason = result.get("reason", "未提供原因")
need_replan = result.get("need_replan", False)
# 如果suitable字段是字符串转换为布尔值
if isinstance(suitable, str):
suitable = suitable.lower() == "true"
# 如果suitable字段不存在或不是布尔值从reason中判断
if suitable is None:
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
# 如果不合适且未达到最大重试次数,返回需要重试
if not suitable and retry_count < self.max_retries:
return False, reason, False
# 如果不合适且已达到最大重试次数,返回需要重新规划
if not suitable and retry_count >= self.max_retries:
return False, f"多次重试后仍不合适: {reason}", True
return suitable, reason, need_replan
except Exception as e:
logger.error(f"检查回复时出错: {e}")
# 如果出错且已达到最大重试次数,建议重新规划
if retry_count >= self.max_retries:
return False, "多次检查失败,建议重新规划", True
return False, f"检查过程出错,建议重试: {str(e)}", False
return False, f"检查过程出错,建议重试: {str(e)}", False