Merge pull request #682 from Kohaku-hupo/main

Optimized the existing knowledge base system (based on the 2025/4/5 MMC version)
This commit is contained in:
SengokuCola
2025-04-05 21:02:37 +08:00
committed by GitHub
4 changed files with 511 additions and 156 deletions

View File

@@ -8,6 +8,9 @@ import time
 from src.plugins.schedule.schedule_generator import bot_schedule
 from src.plugins.memory_system.Hippocampus import HippocampusManager
 from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG  # noqa: E402
+from src.plugins.chat.utils import get_embedding
+from src.common.database import db
+from typing import Union

 subheartflow_config = LogConfig(
     # use the hippocampus-specific log style
@@ -54,6 +57,8 @@ class SubHeartflow:
         self.observations: list[Observation] = []
+
+        self.running_knowledges = []

     def add_observation(self, observation: Observation):
         """Add a new Observation object to the list; if an observation with the same id already exists, do not add it"""
         # check whether an observation with the same id already exists
@@ -98,49 +103,49 @@ class SubHeartflow:
                 logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活正在销毁...")
                 break  # exit the loop and destroy self

-    async def do_a_thinking(self):
-        current_thinking_info = self.current_mind
-        mood_info = self.current_state.mood
-        observation = self.observations[0]
-        chat_observe_info = observation.observe_info
-        # print(f"chat_observe_info{chat_observe_info}")
-        # retrieve memories
-        related_memory = await HippocampusManager.get_instance().get_memory_from_text(
-            text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
-        )
-        if related_memory:
-            related_memory_info = ""
-            for memory in related_memory:
-                related_memory_info += memory[1]
-        else:
-            related_memory_info = ""
-        # print(f"相关记忆:{related_memory_info}")
-        schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False)
-        prompt = ""
-        prompt += f"你刚刚在做的事情是:{schedule_info}\n"
-        # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
-        prompt += f"{self.personality_info}\n"
-        if related_memory_info:
-            prompt += f"你想起来你之前见过的回忆:{related_memory_info}\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n"
-        prompt += f"刚刚你的想法是{current_thinking_info}\n"
-        prompt += "-----------------------------------\n"
-        prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n"
-        prompt += f"你现在{mood_info}\n"
-        prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长,"
-        prompt += "但是记得结合上述的消息,要记得维持住你的人设,关注聊天和新内容,不要思考太多:"
-        reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
-        self.update_current_mind(reponse)
-        self.current_mind = reponse
-        logger.debug(f"prompt:\n{prompt}\n")
-        logger.info(f"麦麦的脑内状态:{self.current_mind}")
+    # async def do_a_thinking(self):
+    #     current_thinking_info = self.current_mind
+    #     mood_info = self.current_state.mood
+    #     observation = self.observations[0]
+    #     chat_observe_info = observation.observe_info
+    #     # print(f"chat_observe_info{chat_observe_info}")
+    #     # retrieve memories
+    #     related_memory = await HippocampusManager.get_instance().get_memory_from_text(
+    #         text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
+    #     )
+    #     if related_memory:
+    #         related_memory_info = ""
+    #         for memory in related_memory:
+    #             related_memory_info += memory[1]
+    #     else:
+    #         related_memory_info = ""
+    #     # print(f"相关记忆:{related_memory_info}")
+    #     schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False)
+    #     prompt = ""
+    #     prompt += f"你刚刚在做的事情是:{schedule_info}\n"
+    #     # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
+    #     prompt += f"你{self.personality_info}\n"
+    #     if related_memory_info:
+    #         prompt += f"你想起来你之前见过的回忆:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n"
+    #     prompt += f"刚刚你的想法是{current_thinking_info}。\n"
+    #     prompt += "-----------------------------------\n"
+    #     prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n"
+    #     prompt += f"你现在{mood_info}\n"
+    #     prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长,"
+    #     prompt += "但是记得结合上述的消息,要记得维持住你的人设,关注聊天和新内容,不要思考太多:"
+    #     reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
+    #     self.update_current_mind(reponse)
+    #     self.current_mind = reponse
+    #     logger.debug(f"prompt:\n{prompt}\n")
+    #     logger.info(f"麦麦的脑内状态:{self.current_mind}")

     async def do_observe(self):
         observation = self.observations[0]
@@ -166,6 +171,13 @@ class SubHeartflow:
         else:
             related_memory_info = ""

+        related_info, grouped_results = await self.get_prompt_info(chat_observe_info + message_txt, 0.4)
+        print(related_info)
+        for topic, results in grouped_results.items():
+            for result in results:
+                print(result)
+                self.running_knowledges.append(result)
+
         # print(f"相关记忆:{related_memory_info}")

         schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False)
@@ -176,6 +188,8 @@ class SubHeartflow:
         prompt += f"你刚刚在做的事情是:{schedule_info}\n"
         if related_memory_info:
             prompt += f"你想起来你之前见过的回忆:{related_memory_info}\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n"
+        if related_info:
+            prompt += f"你想起你知道:{related_info}\n"
         prompt += f"刚刚你的想法是{current_thinking_info}\n"
         prompt += "-----------------------------------\n"
         prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n"
@@ -251,4 +265,220 @@ class SubHeartflow:
         self.current_mind = reponse

+    async def get_prompt_info(self, message: str, threshold: float):
+        start_time = time.time()
+        related_info = ""
+        logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
+
+        # 1. first get topics from the LLM, similar to the memory system's approach
+        topics = []
+        # try:
+        #     # first try the memory system's way of extracting topics
+        #     hippocampus = HippocampusManager.get_instance()._hippocampus
+        #     topic_num = min(5, max(1, int(len(message) * 0.1)))
+        #     topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
+        #     # extract the keywords
+        #     topics = re.findall(r"<([^>]+)>", topics_response[0])
+        #     if not topics:
+        #         topics = []
+        #     else:
+        #         topics = [
+        #             topic.strip()
+        #             for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
+        #             if topic.strip()
+        #         ]
+        #     logger.info(f"从LLM提取的主题: {', '.join(topics)}")
+        # except Exception as e:
+        #     logger.error(f"从LLM提取主题失败: {str(e)}")
+        #     # if LLM extraction fails, fall back to jieba keyword extraction
+        #     words = jieba.cut(message)
+        #     topics = [word for word in words if len(word) > 1][:5]
+        #     logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
+
+        # if no topics could be extracted, query with the whole message directly
+        if not topics:
+            logger.info("未能提取到任何主题,使用整个消息进行查询")
+            embedding = await get_embedding(message, request_type="info_retrieval")
+            if not embedding:
+                logger.error("获取消息嵌入向量失败")
+                return "", {}
+            related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
+            logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}")
+            return related_info, {}
+
+        # 2. query the knowledge base for each topic
+        logger.info(f"开始处理{len(topics)}个主题的知识库查询")
+
+        # batch the embedding lookups to reduce API calls
+        embeddings = {}
+        topics_batch = [topic for topic in topics if len(topic) > 0]
+        if message:  # make sure the message is non-empty
+            topics_batch.append(message)
+
+        # fetch the embeddings as a batch
+        embed_start_time = time.time()
+        for text in topics_batch:
+            if not text or len(text.strip()) == 0:
+                continue
+            try:
+                embedding = await get_embedding(text, request_type="info_retrieval")
+                if embedding:
+                    embeddings[text] = embedding
+                else:
+                    logger.warning(f"获取'{text}'的嵌入向量失败")
+            except Exception as e:
+                logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
+        logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}")
+
+        if not embeddings:
+            logger.error("所有嵌入向量获取失败")
+            return "", {}
+
+        # 3. run the knowledge base query for each topic
+        all_results = []
+        query_start_time = time.time()
+
+        # add the results for the original message first
+        if message in embeddings:
+            original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
+            if original_results:
+                for result in original_results:
+                    result["topic"] = "原始消息"
+                all_results.extend(original_results)
+                logger.info(f"原始消息查询到{len(original_results)}条结果")
+
+        # then add the results for each topic
+        for topic in topics:
+            if not topic or topic not in embeddings:
+                continue
+            try:
+                topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
+                if topic_results:
+                    # tag each result with its topic
+                    for result in topic_results:
+                        result["topic"] = topic
+                    all_results.extend(topic_results)
+                    logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
+            except Exception as e:
+                logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
+
+        logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
+
+        # 4. deduplicate and filter
+        process_start_time = time.time()
+        unique_contents = set()
+        filtered_results = []
+        for result in all_results:
+            content = result["content"]
+            if content not in unique_contents:
+                unique_contents.add(content)
+                filtered_results.append(result)
+
+        # 5. sort by similarity
+        filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
+
+        # 6. cap the total number of results at 10
+        filtered_results = filtered_results[:10]
+        logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
+
+        # 7. format the output (grouped_results is initialized outside the
+        # branch so the final return is well-defined even with no results)
+        grouped_results = {}
+        if filtered_results:
+            format_start_time = time.time()
+            for result in filtered_results:
+                topic = result["topic"]
+                if topic not in grouped_results:
+                    grouped_results[topic] = []
+                grouped_results[topic].append(result)
+
+            # organize the output by topic
+            for topic, results in grouped_results.items():
+                related_info += f"【主题: {topic}\n"
+                for i, result in enumerate(results, 1):
+                    similarity = result["similarity"]
+                    content = result["content"].strip()
+                    # debug: prefix each entry with its index and similarity
+                    # related_info += f"{i}. [{similarity:.2f}] {content}\n"
+                    related_info += f"{content}\n"
+                related_info += "\n"
+            logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
+
+        logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
+        return related_info, grouped_results
+
+    def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
+        if not query_embedding:
+            return "" if not return_raw else []
+        # compute cosine similarity inside the aggregation pipeline
+        pipeline = [
+            {
+                "$addFields": {
+                    "dotProduct": {
+                        "$reduce": {
+                            "input": {"$range": [0, {"$size": "$embedding"}]},
+                            "initialValue": 0,
+                            "in": {
+                                "$add": [
+                                    "$$value",
+                                    {
+                                        "$multiply": [
+                                            {"$arrayElemAt": ["$embedding", "$$this"]},
+                                            {"$arrayElemAt": [query_embedding, "$$this"]},
+                                        ]
+                                    },
+                                ]
+                            },
+                        }
+                    },
+                    "magnitude1": {
+                        "$sqrt": {
+                            "$reduce": {
+                                "input": "$embedding",
+                                "initialValue": 0,
+                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
+                            }
+                        }
+                    },
+                    "magnitude2": {
+                        "$sqrt": {
+                            "$reduce": {
+                                "input": query_embedding,
+                                "initialValue": 0,
+                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
+                            }
+                        }
+                    },
+                }
+            },
+            {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
+            {
+                "$match": {
+                    "similarity": {"$gte": threshold}  # keep only results with similarity >= threshold
+                }
+            },
+            {"$sort": {"similarity": -1}},
+            {"$limit": limit},
+            {"$project": {"content": 1, "similarity": 1}},
+        ]
+
+        results = list(db.knowledges.aggregate(pipeline))
+        logger.debug(f"知识库查询结果数量: {len(results)}")
+        if not results:
+            return "" if not return_raw else []
+        if return_raw:
+            return results
+        else:
+            # return all retrieved content, newline-separated
+            return "\n".join(str(result["content"]) for result in results)
+
 # subheartflow = SubHeartflow()
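
For reference, the aggregation pipeline above computes cosine similarity entirely inside MongoDB: a dot product over paired array elements, the two vector magnitudes, a $match on the threshold, then sort/limit/project. A minimal numpy sketch of the same scoring logic (a hypothetical standalone helper, not part of this commit):

import numpy as np

def cosine_top_k(query_embedding: list, docs: list, threshold: float = 0.5, limit: int = 3) -> list:
    """Mirror of the pipeline: score each doc, filter by threshold, sort, truncate.

    docs is assumed to look like [{"content": str, "embedding": list}, ...].
    """
    q = np.asarray(query_embedding, dtype=float)
    scored = []
    for doc in docs:
        e = np.asarray(doc["embedding"], dtype=float)
        # dotProduct / (magnitude1 * magnitude2), as in the $addFields stages
        similarity = float(np.dot(e, q) / (np.linalg.norm(e) * np.linalg.norm(q)))
        if similarity >= threshold:  # the $match stage
            scored.append({"content": doc["content"], "similarity": similarity})
    scored.sort(key=lambda d: d["similarity"], reverse=True)  # $sort, descending
    return scored[:limit]  # $limit

One design note: because $reduce iterates over every element of every stored embedding, the pipeline is a full collection scan with no vector index, which is workable for a small knowledge base but will not scale to large ones.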

View File

@@ -1,16 +1,19 @@
 import random
 import time
-from typing import Optional
+from typing import Optional, Union
+import re
+import jieba
+import numpy as np

 from ....common.database import db
-from ...memory_system.Hippocampus import HippocampusManager
-from ...moods.moods import MoodManager
-from ...schedule.schedule_generator import bot_schedule
-from ...config.config import global_config
 from ...chat.utils import get_embedding, get_recent_group_detailed_plain_text, get_recent_group_speaker
 from ...chat.chat_stream import chat_manager
-from src.common.logger import get_module_logger
+from ...moods.moods import MoodManager
+from ...memory_system.Hippocampus import HippocampusManager
+from ...schedule.schedule_generator import bot_schedule
+from ...config.config import global_config
 from ...person_info.relationship_manager import relationship_manager
+from src.common.logger import get_module_logger

 logger = get_module_logger("prompt")
@@ -128,7 +131,7 @@ class PromptBuilder:
         # knowledge construction
         start_time = time.time()
         prompt_info = ""
-        prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
+        prompt_info = await self.get_prompt_info(message_txt, threshold=0.38)
         if prompt_info:
             prompt_info = f"""\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n"""
@@ -159,16 +162,156 @@ class PromptBuilder:
         return prompt

     async def get_prompt_info(self, message: str, threshold: float):
-        related_info = ""
-        logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
-        embedding = await get_embedding(message, request_type="prompt_build")
-        related_info += self.get_info_from_db(embedding, limit=1, threshold=threshold)
-        return related_info
-
-    def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
-        if not query_embedding:
-            return ""
+        start_time = time.time()
+        related_info = ""
+        logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
+
+        # 1. first get topics from the LLM, similar to the memory system's approach
+        topics = []
+        # try:
+        #     # first try the memory system's way of extracting topics
+        #     hippocampus = HippocampusManager.get_instance()._hippocampus
+        #     topic_num = min(5, max(1, int(len(message) * 0.1)))
+        #     topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
+        #     # extract the keywords
+        #     topics = re.findall(r"<([^>]+)>", topics_response[0])
+        #     if not topics:
+        #         topics = []
+        #     else:
+        #         topics = [
+        #             topic.strip()
+        #             for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
+        #             if topic.strip()
+        #         ]
+        #     logger.info(f"从LLM提取的主题: {', '.join(topics)}")
+        # except Exception as e:
+        #     logger.error(f"从LLM提取主题失败: {str(e)}")
+        #     # if LLM extraction fails, fall back to jieba keyword extraction
+        #     words = jieba.cut(message)
+        #     topics = [word for word in words if len(word) > 1][:5]
+        #     logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
+
+        # if no topics could be extracted, query with the whole message directly
+        if not topics:
+            logger.info("未能提取到任何主题,使用整个消息进行查询")
+            embedding = await get_embedding(message, request_type="prompt_build")
+            if not embedding:
+                logger.error("获取消息嵌入向量失败")
+                return ""
+            related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
+            logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}")
+            return related_info
+
+        # 2. query the knowledge base for each topic
+        logger.info(f"开始处理{len(topics)}个主题的知识库查询")
+
+        # batch the embedding lookups to reduce API calls
+        embeddings = {}
+        topics_batch = [topic for topic in topics if len(topic) > 0]
+        if message:  # make sure the message is non-empty
+            topics_batch.append(message)
+
+        # fetch the embeddings as a batch
+        embed_start_time = time.time()
+        for text in topics_batch:
+            if not text or len(text.strip()) == 0:
+                continue
+            try:
+                embedding = await get_embedding(text, request_type="prompt_build")
+                if embedding:
+                    embeddings[text] = embedding
+                else:
+                    logger.warning(f"获取'{text}'的嵌入向量失败")
+            except Exception as e:
+                logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
+        logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}")
+
+        if not embeddings:
+            logger.error("所有嵌入向量获取失败")
+            return ""
+
+        # 3. run the knowledge base query for each topic
+        all_results = []
+        query_start_time = time.time()
+
+        # add the results for the original message first
+        if message in embeddings:
+            original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
+            if original_results:
+                for result in original_results:
+                    result["topic"] = "原始消息"
+                all_results.extend(original_results)
+                logger.info(f"原始消息查询到{len(original_results)}条结果")
+
+        # then add the results for each topic
+        for topic in topics:
+            if not topic or topic not in embeddings:
+                continue
+            try:
+                topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
+                if topic_results:
+                    # tag each result with its topic
+                    for result in topic_results:
+                        result["topic"] = topic
+                    all_results.extend(topic_results)
+                    logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
+            except Exception as e:
+                logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
+
+        logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
+
+        # 4. deduplicate and filter
+        process_start_time = time.time()
+        unique_contents = set()
+        filtered_results = []
+        for result in all_results:
+            content = result["content"]
+            if content not in unique_contents:
+                unique_contents.add(content)
+                filtered_results.append(result)
+
+        # 5. sort by similarity
+        filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
+
+        # 6. cap the total number of results at 10
+        filtered_results = filtered_results[:10]
+        logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
+
+        # 7. format the output
+        if filtered_results:
+            format_start_time = time.time()
+            grouped_results = {}
+            for result in filtered_results:
+                topic = result["topic"]
+                if topic not in grouped_results:
+                    grouped_results[topic] = []
+                grouped_results[topic].append(result)
+
+            # organize the output by topic
+            for topic, results in grouped_results.items():
+                related_info += f"【主题: {topic}\n"
+                for i, result in enumerate(results, 1):
+                    similarity = result["similarity"]
+                    content = result["content"].strip()
+                    # debug: prefix each entry with its index and similarity
+                    # related_info += f"{i}. [{similarity:.2f}] {content}\n"
+                    related_info += f"{content}\n"
+                related_info += "\n"
+            logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
+
+        logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
+        return related_info
+
+    def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
+        if not query_embedding:
+            return "" if not return_raw else []

         # compute cosine similarity inside the aggregation pipeline
         pipeline = [
             {
@@ -222,11 +365,14 @@ class PromptBuilder:
         ]

         results = list(db.knowledges.aggregate(pipeline))
-        # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
+        logger.debug(f"知识库查询结果数量: {len(results)}")
         if not results:
-            return ""
+            return "" if not return_raw else []
+        if return_raw:
+            return results
-        # return all retrieved content, newline-separated
-        return "\n".join(str(result["content"]) for result in results)
+        else:
+            # return all retrieved content, newline-separated
+            return "\n".join(str(result["content"]) for result in results)

View File

@@ -238,25 +238,35 @@ class ThinkFlowChat:
         do_reply = False
         if random() < reply_probability:
+            try:
                 do_reply = True

                 # create the thinking message
+                try:
                     timer1 = time.time()
                     thinking_id = await self._create_thinking_message(message, chat, userinfo, messageinfo)
                     timer2 = time.time()
                     timing_results["创建思考消息"] = timer2 - timer1
+                except Exception as e:
+                    logger.error(f"心流创建思考消息失败: {e}")

+                try:
                     # observe
                     timer1 = time.time()
                     await heartflow.get_subheartflow(chat.stream_id).do_observe()
                     timer2 = time.time()
                     timing_results["观察"] = timer2 - timer1
+                except Exception as e:
+                    logger.error(f"心流观察失败: {e}")

                 # pre-reply state of mind
+                try:
                     timer1 = time.time()
                     await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(message.processed_plain_text)
                     timer2 = time.time()
                     timing_results["思考前脑内状态"] = timer2 - timer1
+                except Exception as e:
+                    logger.error(f"心流思考前脑内状态失败: {e}")

                 # generate the reply
                 timer1 = time.time()

@@ -269,28 +279,43 @@ class ThinkFlowChat:
                     return

                 # send the message
+                try:
                     timer1 = time.time()
                     await self._send_response_messages(message, chat, response_set, thinking_id)
                     timer2 = time.time()
                     timing_results["发送消息"] = timer2 - timer1
+                except Exception as e:
+                    logger.error(f"心流发送消息失败: {e}")

-                # send the emoji
+                # handle the emoji
+                try:
                     timer1 = time.time()
                     await self._handle_emoji(message, chat, response_set)
                     timer2 = time.time()
-                    timing_results["发送表情包"] = timer2 - timer1
+                    timing_results["处理表情包"] = timer2 - timer1
+                except Exception as e:
+                    logger.error(f"心流处理表情包失败: {e}")

                 # update the heart flow
+                try:
                     timer1 = time.time()
                     await self._update_using_response(message, response_set)
                     timer2 = time.time()
                     timing_results["更新心流"] = timer2 - timer1
+                except Exception as e:
+                    logger.error(f"心流更新失败: {e}")

                 # update relationship mood
+                try:
                     timer1 = time.time()
                     await self._update_relationship(message, response_set)
                     timer2 = time.time()
                     timing_results["更新关系情绪"] = timer2 - timer1
+                except Exception as e:
+                    logger.error(f"心流更新关系情绪失败: {e}")
+            except Exception as e:
+                logger.error(f"心流处理消息失败: {e}")

         # output the performance timing results
         if do_reply:
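
Every step in the reply flow now follows the same shape: start a timer, await the step, record the elapsed time under a label, and log failures instead of letting one step abort the whole handler (plus an outer try that catches anything left over). If the pattern keeps spreading it could be factored into a helper; a hypothetical sketch, not in this commit, where logger is the module logger used above (note it also records the duration of failed steps, which the inline version does not):

import time
from typing import Awaitable

async def timed_step(label: str, step: Awaitable, timing_results: dict) -> None:
    """Await one pipeline step, record its duration, and log errors instead of raising."""
    timer1 = time.time()
    try:
        await step
    except Exception as e:
        logger.error(f"心流步骤 {label} 失败: {e}")
    finally:
        timing_results[label] = time.time() - timer1

# usage: await timed_step("观察", heartflow.get_subheartflow(chat.stream_id).do_observe(), timing_results)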

View File

@@ -41,7 +41,7 @@ class KnowledgeLibrary:
             return f.read()

     def split_content(self, content: str, max_length: int = 512) -> list:
-        """Split the content into appropriately sized chunks while keeping paragraphs intact
+        """Split the content into appropriately sized chunks, splitting on blank lines

         Args:
             content: the text content to split
@@ -50,66 +50,20 @@ class KnowledgeLibrary:
         Returns:
             list: the list of text chunks after splitting
         """
-        # first split into paragraphs
+        # split the content on blank lines
         paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
         chunks = []
-        current_chunk = []
-        current_length = 0

         for para in paragraphs:
             para_length = len(para)
-            # if a single paragraph already exceeds the maximum length
-            if para_length > max_length:
-                # if the current chunk is non-empty, save it first
-                if current_chunk:
-                    chunks.append("\n".join(current_chunk))
-                    current_chunk = []
-                    current_length = 0
-                # split the long paragraph into sentences
-                sentences = [
-                    s.strip()
-                    for s in para.replace("。", "\n").replace("!", "\n").replace("?", "\n").split("\n")
-                    if s.strip()
-                ]
-                temp_chunk = []
-                temp_length = 0
-                for sentence in sentences:
-                    sentence_length = len(sentence)
-                    if sentence_length > max_length:
-                        # if a single sentence is overlong, force-split it by length
-                        if temp_chunk:
-                            chunks.append("\n".join(temp_chunk))
-                            temp_chunk = []
-                            temp_length = 0
-                        for i in range(0, len(sentence), max_length):
-                            chunks.append(sentence[i : i + max_length])
-                    elif temp_length + sentence_length + 1 <= max_length:
-                        temp_chunk.append(sentence)
-                        temp_length += sentence_length + 1
-                    else:
-                        chunks.append("\n".join(temp_chunk))
-                        temp_chunk = [sentence]
-                        temp_length = sentence_length
-                if temp_chunk:
-                    chunks.append("\n".join(temp_chunk))
-            # if the current paragraph plus the existing chunk stays within the maximum length
-            elif current_length + para_length + 1 <= max_length:
-                current_chunk.append(para)
-                current_length += para_length + 1
-            else:
-                # save the current chunk and start a new one
-                chunks.append("\n".join(current_chunk))
-                current_chunk = [para]
-                current_length = para_length
-
-        # append the last chunk
-        if current_chunk:
-            chunks.append("\n".join(current_chunk))
+            # if the paragraph fits within the maximum length, add it directly
+            if para_length <= max_length:
+                chunks.append(para)
+            else:
+                # otherwise slice the paragraph at the maximum length
+                for i in range(0, para_length, max_length):
+                    chunks.append(para[i:i + max_length])

         return chunks
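
The new split_content trades the old paragraph-packing and sentence-boundary logic for two rules: a paragraph (blank-line delimited) that fits within max_length is kept whole, and an oversized paragraph is sliced at fixed offsets, even mid-sentence. A quick illustration, assuming KnowledgeLibrary can be constructed with no arguments:

lib = KnowledgeLibrary()
text = "short para\n\n" + "x" * 25
print(lib.split_content(text, max_length=10))
# -> ['short para', 'xxxxxxxxxx', 'xxxxxxxxxx', 'xxxxx']

The simplification also means small adjacent paragraphs are no longer merged into one chunk, so the chunk count now tracks the paragraph count.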