fix:加入工具调用能力

This commit is contained in:
SengokuCola
2025-04-10 22:13:17 +08:00
parent de061024c1
commit 110f94353f
6 changed files with 627 additions and 265 deletions

View File

@@ -16,6 +16,8 @@ import random
from src.plugins.chat.chat_stream import ChatStream
from src.plugins.person_info.relationship_manager import relationship_manager
from src.plugins.chat.utils import get_recent_group_speaker
import json
from src.heart_flow.tool_use import ToolUser
subheartflow_config = LogConfig(
# 使用海马体专用样式
@@ -47,6 +49,7 @@ class SubHeartflow:
self.llm_model = LLM_request(
model=global_config.llm_sub_heartflow, temperature=0.2, max_tokens=600, request_type="sub_heart_flow"
)
self.main_heartflow_info = ""
@@ -63,6 +66,8 @@ class SubHeartflow:
self.running_knowledges = []
self.bot_name = global_config.BOT_NICKNAME
self.tool_user = ToolUser()
def add_observation(self, observation: Observation):
"""添加一个新的observation对象到列表中如果已存在相同id的observation则不添加"""
@@ -115,6 +120,7 @@ class SubHeartflow:
observation = self.observations[0]
await observation.observe()
async def do_thinking_before_reply(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
@@ -123,6 +129,19 @@ class SubHeartflow:
chat_observe_info = observation.observe_info
# print(f"chat_observe_info{chat_observe_info}")
# 首先尝试使用工具获取更多信息
tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream)
# 如果工具被使用且获得了结果,将收集到的信息合并到思考中
if tool_result.get("used_tools", False):
logger.info("使用工具收集了信息")
# 如果有收集到的信息,将其添加到当前思考中
if "collected_info" in tool_result:
collected_info = tool_result["collected_info"]
# 开始构建prompt
prompt_personality = f"你的名字是{self.bot_name},你"
# person
@@ -158,38 +177,11 @@ class SubHeartflow:
f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。"
)
# 调取记忆
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
)
if related_memory:
related_memory_info = ""
for memory in related_memory:
related_memory_info += memory[1]
else:
related_memory_info = ""
related_info, grouped_results = await self.get_prompt_info(chat_observe_info + message_txt, 0.4)
# print(related_info)
for _topic, results in grouped_results.items():
for result in results:
# print(result)
self.running_knowledges.append(result)
# print(f"相关记忆:{related_memory_info}")
schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False)
prompt = ""
# prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
if tool_result.get("used_tools", False):
prompt += f"{collected_info}\n"
prompt += f"{relation_prompt_all}\n"
prompt += f"{prompt_personality}\n"
# prompt += f"你刚刚在做的事情是:{schedule_info}\n"
# if related_memory_info:
# prompt += f"你想起来你之前见过的回忆:{related_memory_info}。\n以上是你的回忆不一定是目前聊天里的人说的也不一定是现在发生的事情请记住。\n"
# if related_info:
# prompt += f"你想起你知道:{related_info}\n"
prompt += f"刚刚你的想法是{current_thinking_info}。如果有新的内容,记得转换话题\n"
prompt += "-----------------------------------\n"
prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n"
@@ -211,7 +203,7 @@ class SubHeartflow:
logger.info(f"prompt:\n{prompt}\n")
logger.info(f"麦麦的思考前脑内状态:{self.current_mind}")
return self.current_mind ,self.past_mind
return self.current_mind, self.past_mind
async def do_thinking_after_reply(self, reply_content, chat_talking_prompt):
# print("麦麦回复之后脑袋转起来了")
@@ -310,224 +302,5 @@ class SubHeartflow:
self.past_mind.append(self.current_mind)
self.current_mind = response
async def get_prompt_info(self, message: str, threshold: float):
start_time = time.time()
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法
topics = []
# try:
# # 先尝试使用记忆系统的方法获取主题
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
# topics = []
# else:
# topics = [
# topic.strip()
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
# # 如果LLM提取失败使用jieba分词提取关键词作为备选
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息
if not topics:
logger.debug("未能提取到任何主题,使用整个消息进行查询")
embedding = await get_embedding(message, request_type="info_retrieval")
if not embedding:
logger.error("获取消息嵌入向量失败")
return ""
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}")
return related_info, {}
# 2. 对每个主题进行知识库查询
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
# 优化批量获取嵌入向量减少API调用
embeddings = {}
topics_batch = [topic for topic in topics if len(topic) > 0]
if message: # 确保消息非空
topics_batch.append(message)
# 批量获取嵌入向量
embed_start_time = time.time()
for text in topics_batch:
if not text or len(text.strip()) == 0:
continue
try:
embedding = await get_embedding(text, request_type="info_retrieval")
if embedding:
embeddings[text] = embedding
else:
logger.warning(f"获取'{text}'的嵌入向量失败")
except Exception as e:
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}")
if not embeddings:
logger.error("所有嵌入向量获取失败")
return ""
# 3. 对每个主题进行知识库查询
all_results = []
query_start_time = time.time()
# 首先添加原始消息的查询结果
if message in embeddings:
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
if original_results:
for result in original_results:
result["topic"] = "原始消息"
all_results.extend(original_results)
logger.info(f"原始消息查询到{len(original_results)}条结果")
# 然后添加每个主题的查询结果
for topic in topics:
if not topic or topic not in embeddings:
continue
try:
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
if topic_results:
# 添加主题标记
for result in topic_results:
result["topic"] = topic
all_results.extend(topic_results)
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
except Exception as e:
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
# 4. 去重和过滤
process_start_time = time.time()
unique_contents = set()
filtered_results = []
for result in all_results:
content = result["content"]
if content not in unique_contents:
unique_contents.add(content)
filtered_results.append(result)
# 5. 按相似度排序
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
# 6. 限制总数量最多10条
filtered_results = filtered_results[:10]
logger.info(
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
)
# 7. 格式化输出
if filtered_results:
format_start_time = time.time()
grouped_results = {}
for result in filtered_results:
topic = result["topic"]
if topic not in grouped_results:
grouped_results[topic] = []
grouped_results[topic].append(result)
# 按主题组织输出
for topic, results in grouped_results.items():
related_info += f"【主题: {topic}\n"
for _i, result in enumerate(results, 1):
_similarity = result["similarity"]
content = result["content"].strip()
# 调试:为内容添加序号和相似度信息
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info, grouped_results
def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline))
logger.debug(f"知识库查询结果数量: {len(results)}")
if not results:
return "" if not return_raw else []
if return_raw:
return results
else:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)
# subheartflow = SubHeartflow()