Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from src.common.logger import get_module_logger
|
||||
from ..message.message_base import UserInfo
|
||||
from ..config.config import global_config
|
||||
@@ -9,16 +9,17 @@ from .message_storage import MessageStorage, MongoDBMessageStorage
|
||||
|
||||
logger = get_module_logger("chat_observer")
|
||||
|
||||
|
||||
class ChatObserver:
|
||||
"""聊天状态观察器"""
|
||||
|
||||
|
||||
# 类级别的实例管理
|
||||
_instances: Dict[str, 'ChatObserver'] = {}
|
||||
|
||||
_instances: Dict[str, "ChatObserver"] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver':
|
||||
"""获取或创建观察器实例
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
message_storage: 消息存储实现,如果为None则使用MongoDB实现
|
||||
@@ -32,14 +33,14 @@ class ChatObserver:
|
||||
|
||||
def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None):
|
||||
"""初始化观察器
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
message_storage: 消息存储实现,如果为None则使用MongoDB实现
|
||||
"""
|
||||
if stream_id in self._instances:
|
||||
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
|
||||
|
||||
|
||||
self.stream_id = stream_id
|
||||
self.message_storage = message_storage or MongoDBMessageStorage()
|
||||
|
||||
@@ -53,9 +54,9 @@ class ChatObserver:
|
||||
|
||||
# 消息历史记录
|
||||
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
|
||||
self.last_message_id: Optional[str] = None # 最后一条消息的ID
|
||||
self.message_count: int = 0 # 消息计数
|
||||
|
||||
self.last_message_id: Optional[str] = None # 最后一条消息的ID
|
||||
self.message_count: int = 0 # 消息计数
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
@@ -77,7 +78,7 @@ class ChatObserver:
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""检查距离上一次观察之后是否有了新消息
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
@@ -91,7 +92,7 @@ class ChatObserver:
|
||||
if new_message_exists:
|
||||
logger.debug("发现新消息")
|
||||
self.last_check_time = time.time()
|
||||
|
||||
|
||||
return new_message_exists
|
||||
|
||||
async def _add_message_to_history(self, message: Dict[str, Any]):
|
||||
@@ -104,7 +105,7 @@ class ChatObserver:
|
||||
self.last_message_id = message["message_id"]
|
||||
self.last_message_time = message["time"] # 更新最后消息时间
|
||||
self.message_count += 1
|
||||
|
||||
|
||||
# 更新说话时间
|
||||
user_info = UserInfo.from_dict(message.get("user_info", {}))
|
||||
if user_info.user_id == global_config.BOT_QQ:
|
||||
@@ -186,41 +187,40 @@ class ChatObserver:
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
limit: Optional[int] = None,
|
||||
user_id: Optional[str] = None
|
||||
user_id: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取消息历史
|
||||
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回消息数量
|
||||
user_id: 指定用户ID
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
"""
|
||||
filtered_messages = self.message_history
|
||||
|
||||
|
||||
if start_time is not None:
|
||||
filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]
|
||||
|
||||
|
||||
if end_time is not None:
|
||||
filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]
|
||||
|
||||
|
||||
if user_id is not None:
|
||||
filtered_messages = [
|
||||
m for m in filtered_messages
|
||||
if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
|
||||
m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
|
||||
]
|
||||
|
||||
|
||||
if limit is not None:
|
||||
filtered_messages = filtered_messages[-limit:]
|
||||
|
||||
|
||||
return filtered_messages
|
||||
|
||||
|
||||
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
|
||||
"""获取新消息
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 新消息列表
|
||||
"""
|
||||
@@ -231,15 +231,15 @@ class ChatObserver:
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]["message_id"]
|
||||
|
||||
|
||||
return new_messages
|
||||
|
||||
|
||||
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间点之前的消息
|
||||
|
||||
|
||||
Args:
|
||||
time_point: 时间戳
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 最多5条消息
|
||||
"""
|
||||
@@ -250,7 +250,7 @@ class ChatObserver:
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]["message_id"]
|
||||
|
||||
|
||||
return new_messages
|
||||
|
||||
'''主要观察循环'''
|
||||
@@ -263,7 +263,7 @@ class ChatObserver:
|
||||
await self._add_message_to_history(message)
|
||||
except Exception as e:
|
||||
logger.error(f"缓冲消息出错: {e}")
|
||||
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# 等待事件或超时(1秒)
|
||||
@@ -271,13 +271,13 @@ class ChatObserver:
|
||||
await asyncio.wait_for(self._update_event.wait(), timeout=1)
|
||||
except asyncio.TimeoutError:
|
||||
pass # 超时后也执行一次检查
|
||||
|
||||
|
||||
self._update_event.clear() # 重置触发事件
|
||||
self._update_complete.clear() # 重置完成事件
|
||||
|
||||
|
||||
# 获取新消息
|
||||
new_messages = await self._fetch_new_messages()
|
||||
|
||||
|
||||
if new_messages:
|
||||
# 处理新消息
|
||||
for message in new_messages:
|
||||
@@ -285,21 +285,21 @@ class ChatObserver:
|
||||
|
||||
# 设置完成事件
|
||||
self._update_complete.set()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新循环出错: {e}")
|
||||
self._update_complete.set() # 即使出错也要设置完成事件
|
||||
|
||||
|
||||
def trigger_update(self):
|
||||
"""触发一次立即更新"""
|
||||
self._update_event.set()
|
||||
|
||||
|
||||
async def wait_for_update(self, timeout: float = 5.0) -> bool:
|
||||
"""等待更新完成
|
||||
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功完成更新(False表示超时)
|
||||
"""
|
||||
@@ -309,16 +309,16 @@ class ChatObserver:
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"等待更新完成超时({timeout}秒)")
|
||||
return False
|
||||
|
||||
|
||||
def start(self):
|
||||
"""启动观察器"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._update_loop())
|
||||
logger.info(f"ChatObserver for {self.stream_id} started")
|
||||
|
||||
|
||||
def stop(self):
|
||||
"""停止观察器"""
|
||||
self._running = False
|
||||
@@ -327,15 +327,15 @@ class ChatObserver:
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
logger.info(f"ChatObserver for {self.stream_id} stopped")
|
||||
|
||||
|
||||
async def process_chat_history(self, messages: list):
|
||||
"""处理聊天历史
|
||||
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
"""
|
||||
self.update_check_time()
|
||||
|
||||
|
||||
for msg in messages:
|
||||
try:
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
@@ -345,33 +345,33 @@ class ChatObserver:
|
||||
self.update_user_speak_time(msg["time"])
|
||||
except Exception as e:
|
||||
logger.warning(f"处理消息时间时出错: {e}")
|
||||
continue
|
||||
|
||||
continue
|
||||
|
||||
def update_check_time(self):
|
||||
"""更新查看时间"""
|
||||
self.last_check_time = time.time()
|
||||
|
||||
|
||||
def update_bot_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新机器人说话时间"""
|
||||
self.last_bot_speak_time = speak_time or time.time()
|
||||
|
||||
|
||||
def update_user_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新用户说话时间"""
|
||||
self.last_user_speak_time = speak_time or time.time()
|
||||
|
||||
|
||||
def get_time_info(self) -> str:
|
||||
"""获取时间信息文本"""
|
||||
current_time = time.time()
|
||||
time_info = ""
|
||||
|
||||
|
||||
if self.last_bot_speak_time:
|
||||
bot_speak_ago = current_time - self.last_bot_speak_time
|
||||
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}秒"
|
||||
|
||||
|
||||
if self.last_user_speak_time:
|
||||
user_speak_ago = current_time - self.last_user_speak_time
|
||||
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}秒"
|
||||
|
||||
|
||||
return time_info
|
||||
|
||||
def start_periodic_update(self):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#Programmable Friendly Conversationalist
|
||||
#Prefrontal cortex
|
||||
# Programmable Friendly Conversationalist
|
||||
# Prefrontal cortex
|
||||
import datetime
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple, TYPE_CHECKING
|
||||
@@ -29,20 +29,17 @@ logger = get_module_logger("pfc")
|
||||
|
||||
class GoalAnalyzer:
|
||||
"""对话目标分析器"""
|
||||
|
||||
|
||||
def __init__(self, stream_id: str):
|
||||
self.llm = LLM_request(
|
||||
model=global_config.llm_normal,
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
request_type="conversation_goal"
|
||||
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
|
||||
)
|
||||
|
||||
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
|
||||
|
||||
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.nick_name = global_config.BOT_ALIAS_NAMES
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||
|
||||
|
||||
# 多目标存储结构
|
||||
self.goals = [] # 存储多个目标
|
||||
self.max_goals = 3 # 同时保持的最大目标数量
|
||||
@@ -50,10 +47,10 @@ class GoalAnalyzer:
|
||||
|
||||
async def analyze_goal(self) -> Tuple[str, str, str]:
|
||||
"""分析对话历史并设定目标
|
||||
|
||||
|
||||
Args:
|
||||
chat_history: 聊天历史记录列表
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[str, str, str]: (目标, 方法, 原因)
|
||||
"""
|
||||
@@ -70,16 +67,16 @@ class GoalAnalyzer:
|
||||
if sender == self.name:
|
||||
sender = "你说"
|
||||
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
|
||||
|
||||
|
||||
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
||||
|
||||
|
||||
# 构建当前已有目标的文本
|
||||
existing_goals_text = ""
|
||||
if self.goals:
|
||||
existing_goals_text = "当前已有的对话目标:\n"
|
||||
for i, (goal, _, reason) in enumerate(self.goals):
|
||||
existing_goals_text += f"{i+1}. 目标: {goal}, 原因: {reason}\n"
|
||||
|
||||
existing_goals_text += f"{i + 1}. 目标: {goal}, 原因: {reason}\n"
|
||||
|
||||
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
|
||||
这些目标应该反映出对话的不同方面和意图。
|
||||
|
||||
@@ -107,46 +104,44 @@ class GoalAnalyzer:
|
||||
logger.debug(f"发送到LLM的提示词: {prompt}")
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"LLM原始返回内容: {content}")
|
||||
|
||||
|
||||
# 使用简化函数提取JSON内容
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
"goal", "reasoning",
|
||||
required_types={"goal": str, "reasoning": str}
|
||||
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
|
||||
)
|
||||
|
||||
|
||||
if not success:
|
||||
logger.error(f"无法解析JSON,重试第{retry + 1}次")
|
||||
continue
|
||||
|
||||
|
||||
goal = result["goal"]
|
||||
reasoning = result["reasoning"]
|
||||
|
||||
|
||||
# 使用默认的方法
|
||||
method = "以友好的态度回应"
|
||||
|
||||
|
||||
# 更新目标列表
|
||||
await self._update_goals(goal, method, reasoning)
|
||||
|
||||
|
||||
# 返回当前最主要的目标
|
||||
if self.goals:
|
||||
current_goal, current_method, current_reasoning = self.goals[0]
|
||||
return current_goal, current_method, current_reasoning
|
||||
else:
|
||||
return goal, method, reasoning
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析对话目标时出错: {str(e)},重试第{retry + 1}次")
|
||||
if retry == max_retries - 1:
|
||||
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
|
||||
continue
|
||||
|
||||
|
||||
# 所有重试都失败后的默认返回
|
||||
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
|
||||
|
||||
|
||||
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
||||
"""更新目标列表
|
||||
|
||||
|
||||
Args:
|
||||
new_goal: 新的目标
|
||||
method: 实现目标的方法
|
||||
@@ -160,23 +155,23 @@ class GoalAnalyzer:
|
||||
# 将此目标移到列表前面(最主要的位置)
|
||||
self.goals.insert(0, self.goals.pop(i))
|
||||
return
|
||||
|
||||
|
||||
# 添加新目标到列表前面
|
||||
self.goals.insert(0, (new_goal, method, reasoning))
|
||||
|
||||
|
||||
# 限制目标数量
|
||||
if len(self.goals) > self.max_goals:
|
||||
self.goals.pop() # 移除最老的目标
|
||||
|
||||
|
||||
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
|
||||
"""简单计算两个目标之间的相似度
|
||||
|
||||
|
||||
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
|
||||
|
||||
|
||||
Args:
|
||||
goal1: 第一个目标
|
||||
goal2: 第二个目标
|
||||
|
||||
|
||||
Returns:
|
||||
float: 相似度得分 (0-1)
|
||||
"""
|
||||
@@ -186,18 +181,18 @@ class GoalAnalyzer:
|
||||
overlap = len(words1.intersection(words2))
|
||||
total = len(words1.union(words2))
|
||||
return overlap / total if total > 0 else 0
|
||||
|
||||
|
||||
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
|
||||
"""获取所有当前目标
|
||||
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
|
||||
"""
|
||||
return self.goals.copy()
|
||||
|
||||
|
||||
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
|
||||
"""获取除了当前主要目标外的其他备选目标
|
||||
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 备选目标列表
|
||||
"""
|
||||
@@ -215,9 +210,9 @@ class GoalAnalyzer:
|
||||
if sender == self.name:
|
||||
sender = "你说"
|
||||
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
|
||||
|
||||
|
||||
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
||||
|
||||
|
||||
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,
|
||||
当前对话目标:{goal}
|
||||
产生该对话目标的原因:{reasoning}
|
||||
@@ -247,7 +242,7 @@ class GoalAnalyzer:
|
||||
"goal_achieved", "stop_conversation", "reason",
|
||||
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}
|
||||
)
|
||||
|
||||
|
||||
if not success:
|
||||
logger.error("无法解析对话分析结果JSON")
|
||||
return False, False, "解析结果失败"
|
||||
@@ -265,14 +260,15 @@ class GoalAnalyzer:
|
||||
|
||||
class Waiter:
|
||||
"""快 速 等 待"""
|
||||
|
||||
def __init__(self, stream_id: str):
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
|
||||
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
|
||||
|
||||
async def wait(self) -> bool:
|
||||
"""等待
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否超时(True表示超时)
|
||||
"""
|
||||
@@ -298,7 +294,7 @@ class Waiter:
|
||||
|
||||
class DirectMessageSender:
|
||||
"""直接发送消息到平台的发送器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.logger = get_module_logger("direct_sender")
|
||||
self.storage = MessageStorage()
|
||||
@@ -310,7 +306,7 @@ class DirectMessageSender:
|
||||
reply_to_message: Optional[Message] = None,
|
||||
) -> None:
|
||||
"""直接发送消息到平台
|
||||
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流
|
||||
content: 消息内容
|
||||
@@ -323,7 +319,7 @@ class DirectMessageSender:
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
platform=chat_stream.platform,
|
||||
)
|
||||
|
||||
|
||||
message = MessageSending(
|
||||
message_id=f"dm{round(time.time(), 2)}",
|
||||
chat_stream=chat_stream,
|
||||
@@ -343,18 +339,17 @@ class DirectMessageSender:
|
||||
try:
|
||||
message_json = message.to_dict()
|
||||
end_point = global_config.api_urls.get(chat_stream.platform, None)
|
||||
|
||||
|
||||
if not end_point:
|
||||
raise ValueError(f"未找到平台:{chat_stream.platform} 的url配置")
|
||||
|
||||
|
||||
await global_api.send_message_REST(end_point, message_json)
|
||||
|
||||
|
||||
# 存储消息
|
||||
await self.storage.store_message(message, message.chat_stream)
|
||||
|
||||
|
||||
self.logger.info(f"直接发送消息成功: {content[:30]}...")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"直接发送消息失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -7,24 +7,22 @@ from ..chat.message import Message
|
||||
|
||||
logger = get_module_logger("knowledge_fetcher")
|
||||
|
||||
|
||||
class KnowledgeFetcher:
|
||||
"""知识调取器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.llm = LLM_request(
|
||||
model=global_config.llm_normal,
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
request_type="knowledge_fetch"
|
||||
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="knowledge_fetch"
|
||||
)
|
||||
|
||||
|
||||
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
|
||||
"""获取相关知识
|
||||
|
||||
|
||||
Args:
|
||||
query: 查询内容
|
||||
chat_history: 聊天历史
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (获取的知识, 知识来源)
|
||||
"""
|
||||
@@ -33,16 +31,16 @@ class KnowledgeFetcher:
|
||||
for msg in chat_history:
|
||||
# sender = msg.message_info.user_info.user_nickname or f"用户{msg.message_info.user_info.user_id}"
|
||||
chat_history_text += f"{msg.detailed_plain_text}\n"
|
||||
|
||||
|
||||
# 从记忆中获取相关知识
|
||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
text=f"{query}\n{chat_history_text}",
|
||||
max_memory_num=3,
|
||||
max_memory_length=2,
|
||||
max_depth=3,
|
||||
fast_retrieval=False
|
||||
fast_retrieval=False,
|
||||
)
|
||||
|
||||
|
||||
if related_memory:
|
||||
knowledge = ""
|
||||
sources = []
|
||||
@@ -50,5 +48,5 @@ class KnowledgeFetcher:
|
||||
knowledge += memory[1] + "\n"
|
||||
sources.append(f"记忆片段{memory[0]}")
|
||||
return knowledge.strip(), ",".join(sources)
|
||||
|
||||
return "未找到相关知识", "无记忆匹配"
|
||||
|
||||
return "未找到相关知识", "无记忆匹配"
|
||||
|
||||
@@ -5,36 +5,37 @@ from src.common.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger("pfc_utils")
|
||||
|
||||
|
||||
def get_items_from_json(
|
||||
content: str,
|
||||
*items: str,
|
||||
default_values: Optional[Dict[str, Any]] = None,
|
||||
required_types: Optional[Dict[str, type]] = None
|
||||
required_types: Optional[Dict[str, type]] = None,
|
||||
) -> Tuple[bool, Dict[str, Any]]:
|
||||
"""从文本中提取JSON内容并获取指定字段
|
||||
|
||||
|
||||
Args:
|
||||
content: 包含JSON的文本
|
||||
*items: 要提取的字段名
|
||||
default_values: 字段的默认值,格式为 {字段名: 默认值}
|
||||
required_types: 字段的必需类型,格式为 {字段名: 类型}
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典)
|
||||
"""
|
||||
content = content.strip()
|
||||
result = {}
|
||||
|
||||
|
||||
# 设置默认值
|
||||
if default_values:
|
||||
result.update(default_values)
|
||||
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
json_data = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||
json_pattern = r'\{[^{}]*\}'
|
||||
json_pattern = r"\{[^{}]*\}"
|
||||
json_match = re.search(json_pattern, content)
|
||||
if json_match:
|
||||
try:
|
||||
@@ -45,28 +46,28 @@ def get_items_from_json(
|
||||
else:
|
||||
logger.error("无法在返回内容中找到有效的JSON")
|
||||
return False, result
|
||||
|
||||
|
||||
# 提取字段
|
||||
for item in items:
|
||||
if item in json_data:
|
||||
result[item] = json_data[item]
|
||||
|
||||
|
||||
# 验证必需字段
|
||||
if not all(item in result for item in items):
|
||||
logger.error(f"JSON缺少必要字段,实际内容: {json_data}")
|
||||
return False, result
|
||||
|
||||
|
||||
# 验证字段类型
|
||||
if required_types:
|
||||
for field, expected_type in required_types.items():
|
||||
if field in result and not isinstance(result[field], expected_type):
|
||||
logger.error(f"{field} 必须是 {expected_type.__name__} 类型")
|
||||
return False, result
|
||||
|
||||
|
||||
# 验证字符串字段不为空
|
||||
for field in items:
|
||||
if isinstance(result[field], str) and not result[field].strip():
|
||||
logger.error(f"{field} 不能为空")
|
||||
return False, result
|
||||
|
||||
return True, result
|
||||
|
||||
return True, result
|
||||
|
||||
@@ -9,33 +9,26 @@ from ..message.message_base import UserInfo
|
||||
|
||||
logger = get_module_logger("reply_checker")
|
||||
|
||||
|
||||
class ReplyChecker:
|
||||
"""回复检查器"""
|
||||
|
||||
|
||||
def __init__(self, stream_id: str):
|
||||
self.llm = LLM_request(
|
||||
model=global_config.llm_normal,
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
request_type="reply_check"
|
||||
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="reply_check"
|
||||
)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||
self.max_retries = 2 # 最大重试次数
|
||||
|
||||
async def check(
|
||||
self,
|
||||
reply: str,
|
||||
goal: str,
|
||||
retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
|
||||
async def check(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
|
||||
"""检查生成的回复是否合适
|
||||
|
||||
|
||||
Args:
|
||||
reply: 生成的回复
|
||||
goal: 对话目标
|
||||
retry_count: 当前重试次数
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
|
||||
"""
|
||||
@@ -49,7 +42,7 @@ class ReplyChecker:
|
||||
if sender == self.name:
|
||||
sender = "你说"
|
||||
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
|
||||
|
||||
|
||||
prompt = f"""请检查以下回复是否合适:
|
||||
|
||||
当前对话目标:{goal}
|
||||
@@ -83,7 +76,7 @@ class ReplyChecker:
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"检查回复的原始返回: {content}")
|
||||
|
||||
|
||||
# 清理内容,尝试提取JSON部分
|
||||
content = content.strip()
|
||||
try:
|
||||
@@ -92,7 +85,8 @@ class ReplyChecker:
|
||||
except json.JSONDecodeError:
|
||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||
import re
|
||||
json_pattern = r'\{[^{}]*\}'
|
||||
|
||||
json_pattern = r"\{[^{}]*\}"
|
||||
json_match = re.search(json_pattern, content)
|
||||
if json_match:
|
||||
try:
|
||||
@@ -109,33 +103,33 @@ class ReplyChecker:
|
||||
reason = content[:100] if content else "无法解析响应"
|
||||
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
|
||||
return is_suitable, reason, need_replan
|
||||
|
||||
|
||||
# 验证JSON字段
|
||||
suitable = result.get("suitable", None)
|
||||
reason = result.get("reason", "未提供原因")
|
||||
need_replan = result.get("need_replan", False)
|
||||
|
||||
|
||||
# 如果suitable字段是字符串,转换为布尔值
|
||||
if isinstance(suitable, str):
|
||||
suitable = suitable.lower() == "true"
|
||||
|
||||
|
||||
# 如果suitable字段不存在或不是布尔值,从reason中判断
|
||||
if suitable is None:
|
||||
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
|
||||
|
||||
|
||||
# 如果不合适且未达到最大重试次数,返回需要重试
|
||||
if not suitable and retry_count < self.max_retries:
|
||||
return False, reason, False
|
||||
|
||||
|
||||
# 如果不合适且已达到最大重试次数,返回需要重新规划
|
||||
if not suitable and retry_count >= self.max_retries:
|
||||
return False, f"多次重试后仍不合适: {reason}", True
|
||||
|
||||
|
||||
return suitable, reason, need_replan
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查回复时出错: {e}")
|
||||
# 如果出错且已达到最大重试次数,建议重新规划
|
||||
if retry_count >= self.max_retries:
|
||||
return False, "多次检查失败,建议重新规划", True
|
||||
return False, f"检查过程出错,建议重试: {str(e)}", False
|
||||
return False, f"检查过程出错,建议重试: {str(e)}", False
|
||||
|
||||
Reference in New Issue
Block a user