This commit is contained in:
SengokuCola
2025-05-29 10:16:36 +08:00
7 changed files with 29 additions and 126 deletions

View File

@@ -339,7 +339,9 @@ class HeartFChatting:
for pname, ptime in processor_time_costs.items():
formatted_ptime = f"{ptime * 1000:.2f}毫秒" if ptime < 1 else f"{ptime:.2f}"
processor_time_strings.append(f"{pname}: {formatted_ptime}")
processor_time_log = ("\n各处理器耗时: " + "; ".join(processor_time_strings)) if processor_time_strings else ""
processor_time_log = (
("\n各处理器耗时: " + "; ".join(processor_time_strings)) if processor_time_strings else ""
)
logger.info(
f"{self.log_prefix}{self._current_cycle.cycle_id}次思考,"
@@ -481,7 +483,9 @@ class HeartFChatting:
observations.append(self.action_observation)
with Timer("执行 信息处理器", cycle_timers):
all_plan_info, processor_time_costs = await self._process_processors(observations, running_memorys, cycle_timers)
all_plan_info, processor_time_costs = await self._process_processors(
observations, running_memorys, cycle_timers
)
loop_processor_info = {
"all_plan_info": all_plan_info,

View File

@@ -457,7 +457,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
return "SELF"
try:
person_id = person_info_manager.get_person_id(platform, user_id)
except Exception as e:
except Exception as _e:
person_id = None
if not person_id:
return "?"
@@ -472,7 +472,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# user_info = msg.get("user_info", {})
platform = msg.get("chat_info_platform")
user_id = msg.get("user_id")
timestamp = msg.get("time")
_timestamp = msg.get("time")
# print(f"msg:{msg}")
# print(f"platform:{platform}")
# print(f"user_id:{user_id}")

View File

@@ -22,7 +22,9 @@ class PicAction(PluginAction):
"""根据描述使用火山引擎HTTP API生成图片的动作处理类"""
action_name = "pic_action"
action_description = "可以根据特定的描述,生成并发送一张图片,如果没提供描述,就根据聊天内容生成,你可以立刻画好,不用等待"
action_description = (
"可以根据特定的描述,生成并发送一张图片,如果没提供描述,就根据聊天内容生成,你可以立刻画好,不用等待"
)
action_parameters = {
"description": "图片描述,输入你想要生成并发送的图片的描述,必填",
"size": "图片尺寸,例如 '1024x1024' (可选, 默认从配置或 '1024x1024')",

View File

@@ -38,125 +38,22 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
# threshold = function_args.get("threshold", 0.4)
# 调用知识库搜索
embedding = await get_embedding(query, request_type="info_retrieval")
if embedding:
knowledge_info = qa_manager.get_knowledge(query)
logger.debug(f"知识库查询结果: {knowledge_info}")
if knowledge_info:
content = f"你知道这些知识: {knowledge_info}"
else:
content = f"你不太了解有关{query}的知识"
return {"type": "lpmm_knowledge", "id": query, "content": content}
# 如果获取嵌入失败
return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量你lpmm知识库炸了"}
except Exception as e:
# 捕获异常并记录错误
logger.error(f"知识库搜索工具执行失败: {str(e)}")
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
query_id = query if "query" in locals() else "unknown_query"
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败炸了: {str(e)}"}
# def get_info_from_db(
# self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
# ) -> Union[str, list]:
# """从数据库中获取相关信息
# Args:
# query_embedding: 查询的嵌入向量
# limit: 最大返回结果数
# threshold: 相似度阈值
# return_raw: 是否返回原始结果
# Returns:
# Union[str, list]: 格式化的信息字符串或原始结果列表
# """
# if not query_embedding:
# return "" if not return_raw else []
# # 使用余弦相似度计算
# pipeline = [
# {
# "$addFields": {
# "dotProduct": {
# "$reduce": {
# "input": {"$range": [0, {"$size": "$embedding"}]},
# "initialValue": 0,
# "in": {
# "$add": [
# "$$value",
# {
# "$multiply": [
# {"$arrayElemAt": ["$embedding", "$$this"]},
# {"$arrayElemAt": [query_embedding, "$$this"]},
# ]
# },
# ]
# },
# }
# },
# "magnitude1": {
# "$sqrt": {
# "$reduce": {
# "input": "$embedding",
# "initialValue": 0,
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
# }
# }
# },
# "magnitude2": {
# "$sqrt": {
# "$reduce": {
# "input": query_embedding,
# "initialValue": 0,
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
# }
# }
# },
# }
# },
# {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
# {
# "$match": {
# "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
# }
# },
# {"$sort": {"similarity": -1}},
# {"$limit": limit},
# {"$project": {"content": 1, "similarity": 1}},
# ]
# results = list(db.knowledges.aggregate(pipeline))
# logger.debug(f"知识库查询结果数量: {len(results)}")
# if not results:
# return "" if not return_raw else []
# if return_raw:
# return results
# else:
# # 返回所有找到的内容,用换行分隔
# return "\n".join(str(result["content"]) for result in results)
def _format_results(self, results: list) -> str:
"""格式化结果"""
if not results:
return "未找到相关知识。"
formatted_string = "我找到了一些相关知识:\n"
for i, result in enumerate(results):
# chunk_id = result.get("chunk_id")
text = result.get("text", "")
source = result.get("source", "未知来源")
source_type = result.get("source_type", "未知类型")
similarity = result.get("similarity", 0.0)
formatted_string += (
f"{i + 1}. (相似度: {similarity:.2f}) 类型: {source_type}, 来源: {source} \n内容片段: {text}\n\n"
)
# 暂时去掉chunk_id
# formatted_string += f"{i + 1}. (相似度: {similarity:.2f}) 类型: {source_type}, 来源: {source}, Chunk ID: {chunk_id} \n内容片段: {text}\n\n"
return formatted_string
# 注册工具
# register_tool(SearchKnowledgeTool)