From 54f7b73ec495a075d3dcd3d35ce57a48dbde02fd Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Thu, 10 Apr 2025 23:32:28 +0800 Subject: [PATCH] =?UTF-8?q?better=EF=BC=9A=E5=BF=83=E6=B5=81=E5=8D=87?= =?UTF-8?q?=E7=BA=A7=EF=BC=8C=E5=A4=A7=E5=A4=A7=E5=87=8F=E5=B0=91=E4=BA=86?= =?UTF-8?q?=E5=A4=8D=E8=AF=BB=E6=83=85=E5=86=B5=EF=BC=8C=E5=B9=B6=E4=B8=94?= =?UTF-8?q?=E7=81=B5=E6=B4=BB=E8=B0=83=E7=94=A8=E5=B7=A5=E5=85=B7=E6=9D=A5?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E7=9F=A5=E8=AF=86=E5=92=8C=E8=AE=B0=E5=BF=86?= =?UTF-8?q?=E6=A3=80=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/do_tool/tool_can_use/README.md | 102 ++++ src/do_tool/tool_can_use/__init__.py | 11 + src/do_tool/tool_can_use/base_tool.py | 119 ++++ src/do_tool/tool_can_use/get_current_task.py | 63 +++ src/do_tool/tool_can_use/get_knowledge.py | 147 +++++ src/do_tool/tool_can_use/get_memory.py | 72 +++ src/do_tool/tool_use.py | 172 ++++++ src/heart_flow/observation.py | 65 ++- src/heart_flow/sub_heartflow.py | 14 +- src/heart_flow/tool_use.py | 561 ------------------- 10 files changed, 729 insertions(+), 597 deletions(-) create mode 100644 src/do_tool/tool_can_use/README.md create mode 100644 src/do_tool/tool_can_use/__init__.py create mode 100644 src/do_tool/tool_can_use/base_tool.py create mode 100644 src/do_tool/tool_can_use/get_current_task.py create mode 100644 src/do_tool/tool_can_use/get_knowledge.py create mode 100644 src/do_tool/tool_can_use/get_memory.py create mode 100644 src/do_tool/tool_use.py delete mode 100644 src/heart_flow/tool_use.py diff --git a/src/do_tool/tool_can_use/README.md b/src/do_tool/tool_can_use/README.md new file mode 100644 index 000000000..15c771887 --- /dev/null +++ b/src/do_tool/tool_can_use/README.md @@ -0,0 +1,102 @@ +# 工具系统使用指南 + +## 概述 + +`tool_can_use` 是一个插件式工具系统,允许轻松扩展和注册新工具。每个工具作为独立的文件存在于该目录下,系统会自动发现和注册这些工具。 + +## 工具结构 + +每个工具应该继承 `BaseTool` 基类并实现必要的属性和方法: + +```python +from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool + +class MyNewTool(BaseTool): + # 工具名称,必须唯一 + name = "my_new_tool" + + # 工具描述,告诉LLM这个工具的用途 + description = "这是一个新工具,用于..." + + # 工具参数定义,遵循JSONSchema格式 + parameters = { + "type": "object", + "properties": { + "param1": { + "type": "string", + "description": "参数1的描述" + }, + "param2": { + "type": "integer", + "description": "参数2的描述" + } + }, + "required": ["param1"] # 必需的参数列表 + } + + async def execute(self, function_args, message_txt=""): + """执行工具逻辑 + + Args: + function_args: 工具调用参数 + message_txt: 原始消息文本 + + Returns: + Dict: 包含执行结果的字典,必须包含name和content字段 + """ + # 实现工具逻辑 + result = f"工具执行结果: {function_args.get('param1')}" + + return { + "name": self.name, + "content": result + } + +# 注册工具 +register_tool(MyNewTool) +``` + +## 自动注册机制 + +工具系统通过以下步骤自动注册工具: + +1. 在`__init__.py`中,`discover_tools()`函数会自动遍历当前目录中的所有Python文件 +2. 对于每个文件,系统会寻找继承自`BaseTool`的类 +3. 这些类会被自动注册到工具注册表中 + +只要确保在每个工具文件的末尾调用`register_tool(YourToolClass)`,工具就会被自动注册。 + +## 添加新工具步骤 + +1. 在`tool_can_use`目录下创建新的Python文件(如`my_new_tool.py`) +2. 导入`BaseTool`和`register_tool` +3. 创建继承自`BaseTool`的工具类 +4. 实现必要的属性(`name`, `description`, `parameters`) +5. 实现`execute`方法 +6. 使用`register_tool`注册工具 + +## 与ToolUser整合 + +`ToolUser`类已经更新为使用这个新的工具系统,它会: + +1. 自动获取所有已注册工具的定义 +2. 基于工具名称找到对应的工具实例 +3. 调用工具的`execute`方法 + +## 使用示例 + +```python +from src.do_tool.tool_use import ToolUser + +# 创建工具用户 +tool_user = ToolUser() + +# 使用工具 +result = await tool_user.use_tool(message_txt="查询关于Python的知识", sender_name="用户", chat_stream=chat_stream) + +# 处理结果 +if result["used_tools"]: + print("工具使用结果:", result["collected_info"]) +else: + print("未使用工具") +``` \ No newline at end of file diff --git a/src/do_tool/tool_can_use/__init__.py b/src/do_tool/tool_can_use/__init__.py new file mode 100644 index 000000000..cc196d07a --- /dev/null +++ b/src/do_tool/tool_can_use/__init__.py @@ -0,0 +1,11 @@ +from src.do_tool.tool_can_use.base_tool import ( + BaseTool, + register_tool, + discover_tools, + get_all_tool_definitions, + get_tool_instance, + TOOL_REGISTRY +) + +# 自动发现并注册工具 +discover_tools() \ No newline at end of file diff --git a/src/do_tool/tool_can_use/base_tool.py b/src/do_tool/tool_can_use/base_tool.py new file mode 100644 index 000000000..03aac5e4c --- /dev/null +++ b/src/do_tool/tool_can_use/base_tool.py @@ -0,0 +1,119 @@ +from typing import Dict, List, Any, Optional, Union, Type +import inspect +import importlib +import pkgutil +import os +from src.common.logger import get_module_logger + +logger = get_module_logger("base_tool") + +# 工具注册表 +TOOL_REGISTRY = {} + +class BaseTool: + """所有工具的基类""" + # 工具名称,子类必须重写 + name = None + # 工具描述,子类必须重写 + description = None + # 工具参数定义,子类必须重写 + parameters = None + + @classmethod + def get_tool_definition(cls) -> Dict[str, Any]: + """获取工具定义,用于LLM工具调用 + + Returns: + Dict: 工具定义字典 + """ + if not cls.name or not cls.description or not cls.parameters: + raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") + + return { + "type": "function", + "function": { + "name": cls.name, + "description": cls.description, + "parameters": cls.parameters + } + } + + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: + """执行工具函数 + + Args: + function_args: 工具调用参数 + message_txt: 原始消息文本 + + Returns: + Dict: 工具执行结果 + """ + raise NotImplementedError("子类必须实现execute方法") + + +def register_tool(tool_class: Type[BaseTool]): + """注册工具到全局注册表 + + Args: + tool_class: 工具类 + """ + if not issubclass(tool_class, BaseTool): + raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类") + + tool_name = tool_class.name + if not tool_name: + raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性") + + TOOL_REGISTRY[tool_name] = tool_class + logger.info(f"已注册工具: {tool_name}") + + +def discover_tools(): + """自动发现并注册tool_can_use目录下的所有工具""" + # 获取当前目录路径 + current_dir = os.path.dirname(os.path.abspath(__file__)) + package_name = os.path.basename(current_dir) + parent_dir = os.path.dirname(current_dir) + + # 导入当前包 + package = importlib.import_module(f"src.do_tool.{package_name}") + + # 遍历包中的所有模块 + for _, module_name, is_pkg in pkgutil.iter_modules([current_dir]): + # 跳过当前模块和__pycache__ + if module_name == "base_tool" or module_name.startswith("__"): + continue + + # 导入模块 + module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}") + + # 查找模块中的工具类 + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: + register_tool(obj) + + logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具") + + +def get_all_tool_definitions() -> List[Dict[str, Any]]: + """获取所有已注册工具的定义 + + Returns: + List[Dict]: 工具定义列表 + """ + return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()] + + +def get_tool_instance(tool_name: str) -> Optional[BaseTool]: + """获取指定名称的工具实例 + + Args: + tool_name: 工具名称 + + Returns: + Optional[BaseTool]: 工具实例,如果找不到则返回None + """ + tool_class = TOOL_REGISTRY.get(tool_name) + if not tool_class: + return None + return tool_class() \ No newline at end of file diff --git a/src/do_tool/tool_can_use/get_current_task.py b/src/do_tool/tool_can_use/get_current_task.py new file mode 100644 index 000000000..dd3402357 --- /dev/null +++ b/src/do_tool/tool_can_use/get_current_task.py @@ -0,0 +1,63 @@ +from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool +from src.plugins.schedule.schedule_generator import bot_schedule +from src.common.logger import get_module_logger +from typing import Dict, Any + +logger = get_module_logger("get_current_task_tool") + +class GetCurrentTaskTool(BaseTool): + """获取当前正在做的事情/最近的任务工具""" + name = "get_current_task" + description = "获取当前正在做的事情/最近的任务" + parameters = { + "type": "object", + "properties": { + "num": { + "type": "integer", + "description": "要获取的任务数量" + }, + "time_info": { + "type": "boolean", + "description": "是否包含时间信息" + } + }, + "required": [] + } + + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: + """执行获取当前任务 + + Args: + function_args: 工具参数 + message_txt: 原始消息文本,此工具不使用 + + Returns: + Dict: 工具执行结果 + """ + try: + # 获取参数,如果没有提供则使用默认值 + num = function_args.get("num", 1) + time_info = function_args.get("time_info", False) + + # 调用日程系统获取当前任务 + current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info) + + # 格式化返回结果 + if current_task: + task_info = current_task + else: + task_info = "当前没有正在进行的任务" + + return { + "name": "get_current_task", + "content": f"当前任务信息: {task_info}" + } + except Exception as e: + logger.error(f"获取当前任务工具执行失败: {str(e)}") + return { + "name": "get_current_task", + "content": f"获取当前任务失败: {str(e)}" + } + +# 注册工具 +register_tool(GetCurrentTaskTool) \ No newline at end of file diff --git a/src/do_tool/tool_can_use/get_knowledge.py b/src/do_tool/tool_can_use/get_knowledge.py new file mode 100644 index 000000000..06ea7a91b --- /dev/null +++ b/src/do_tool/tool_can_use/get_knowledge.py @@ -0,0 +1,147 @@ +from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool +from src.plugins.chat.utils import get_embedding +from src.common.database import db +from src.common.logger import get_module_logger +from typing import Dict, Any, Union, List + +logger = get_module_logger("get_knowledge_tool") + +class SearchKnowledgeTool(BaseTool): + """从知识库中搜索相关信息的工具""" + name = "search_knowledge" + description = "从知识库中搜索相关信息" + parameters = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索查询关键词" + }, + "threshold": { + "type": "number", + "description": "相似度阈值,0.0到1.0之间" + } + }, + "required": ["query"] + } + + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: + """执行知识库搜索 + + Args: + function_args: 工具参数 + message_txt: 原始消息文本 + + Returns: + Dict: 工具执行结果 + """ + try: + query = function_args.get("query", message_txt) + threshold = function_args.get("threshold", 0.4) + + # 调用知识库搜索 + embedding = await get_embedding(query, request_type="info_retrieval") + if embedding: + knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) + if knowledge_info: + content = f"你知道这些知识: {knowledge_info}" + else: + content = f"你不太了解有关{query}的知识" + return { + "name": "search_knowledge", + "content": content + } + return { + "name": "search_knowledge", + "content": f"无法获取关于'{query}'的嵌入向量" + } + except Exception as e: + logger.error(f"知识库搜索工具执行失败: {str(e)}") + return { + "name": "search_knowledge", + "content": f"知识库搜索失败: {str(e)}" + } + + def get_info_from_db( + self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False + ) -> Union[str, list]: + """从数据库中获取相关信息 + + Args: + query_embedding: 查询的嵌入向量 + limit: 最大返回结果数 + threshold: 相似度阈值 + return_raw: 是否返回原始结果 + + Returns: + Union[str, list]: 格式化的信息字符串或原始结果列表 + """ + if not query_embedding: + return "" if not return_raw else [] + + # 使用余弦相似度计算 + pipeline = [ + { + "$addFields": { + "dotProduct": { + "$reduce": { + "input": {"$range": [0, {"$size": "$embedding"}]}, + "initialValue": 0, + "in": { + "$add": [ + "$$value", + { + "$multiply": [ + {"$arrayElemAt": ["$embedding", "$$this"]}, + {"$arrayElemAt": [query_embedding, "$$this"]}, + ] + }, + ] + }, + } + }, + "magnitude1": { + "$sqrt": { + "$reduce": { + "input": "$embedding", + "initialValue": 0, + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, + } + } + }, + "magnitude2": { + "$sqrt": { + "$reduce": { + "input": query_embedding, + "initialValue": 0, + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, + } + } + }, + } + }, + {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, + { + "$match": { + "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 + } + }, + {"$sort": {"similarity": -1}}, + {"$limit": limit}, + {"$project": {"content": 1, "similarity": 1}}, + ] + + results = list(db.knowledges.aggregate(pipeline)) + logger.debug(f"知识库查询结果数量: {len(results)}") + + if not results: + return "" if not return_raw else [] + + if return_raw: + return results + else: + # 返回所有找到的内容,用换行分隔 + return "\n".join(str(result["content"]) for result in results) + +# 注册工具 +register_tool(SearchKnowledgeTool) diff --git a/src/do_tool/tool_can_use/get_memory.py b/src/do_tool/tool_can_use/get_memory.py new file mode 100644 index 000000000..171e8486a --- /dev/null +++ b/src/do_tool/tool_can_use/get_memory.py @@ -0,0 +1,72 @@ +from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool +from src.plugins.memory_system.Hippocampus import HippocampusManager +from src.common.logger import get_module_logger +from typing import Dict, Any + +logger = get_module_logger("get_memory_tool") + +class GetMemoryTool(BaseTool): + """从记忆系统中获取相关记忆的工具""" + name = "get_memory" + description = "从记忆系统中获取相关记忆" + parameters = { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "要查询的相关文本" + }, + "max_memory_num": { + "type": "integer", + "description": "最大返回记忆数量" + } + }, + "required": ["text"] + } + + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: + """执行记忆获取 + + Args: + function_args: 工具参数 + message_txt: 原始消息文本 + + Returns: + Dict: 工具执行结果 + """ + try: + text = function_args.get("text", message_txt) + max_memory_num = function_args.get("max_memory_num", 2) + + # 调用记忆系统 + related_memory = await HippocampusManager.get_instance().get_memory_from_text( + text=text, + max_memory_num=max_memory_num, + max_memory_length=2, + max_depth=3, + fast_retrieval=False + ) + + memory_info = "" + if related_memory: + for memory in related_memory: + memory_info += memory[1] + "\n" + + if memory_info: + content = f"你记得这些事情: {memory_info}" + else: + content = f"你不太记得有关{text}的记忆,你对此不太了解" + + return { + "name": "get_memory", + "content": content + } + except Exception as e: + logger.error(f"记忆获取工具执行失败: {str(e)}") + return { + "name": "get_memory", + "content": f"记忆获取失败: {str(e)}" + } + +# 注册工具 +register_tool(GetMemoryTool) \ No newline at end of file diff --git a/src/do_tool/tool_use.py b/src/do_tool/tool_use.py new file mode 100644 index 000000000..a2e23ab21 --- /dev/null +++ b/src/do_tool/tool_use.py @@ -0,0 +1,172 @@ +from src.plugins.models.utils_model import LLM_request +from src.plugins.config.config import global_config +from src.plugins.chat.chat_stream import ChatStream +from src.common.database import db +import time +import json +from src.common.logger import get_module_logger +from typing import Union +from src.do_tool.tool_can_use import get_all_tool_definitions, get_tool_instance + +logger = get_module_logger("tool_use") + + +class ToolUser: + def __init__(self): + self.llm_model_tool = LLM_request( + model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use" + ) + + async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream): + """构建工具使用的提示词 + + Args: + message_txt: 用户消息文本 + sender_name: 发送者名称 + chat_stream: 聊天流对象 + + Returns: + str: 构建好的提示词 + """ + new_messages = list( + db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}) + .sort("time", 1) + .limit(15) + ) + new_messages_str = "" + for msg in new_messages: + if "detailed_plain_text" in msg: + new_messages_str += f"{msg['detailed_plain_text']}" + + # 这些信息应该从调用者传入,而不是从self获取 + bot_name = global_config.BOT_NICKNAME + prompt = "" + prompt += "你正在思考如何回复群里的消息。\n" + prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" + prompt += f"注意你就是{bot_name},{bot_name}指的就是你。" + prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。" + + return prompt + + def _define_tools(self): + """获取所有已注册工具的定义 + + Returns: + list: 工具定义列表 + """ + return get_all_tool_definitions() + + async def _execute_tool_call(self, tool_call, message_txt:str): + """执行特定的工具调用 + + Args: + tool_call: 工具调用对象 + message_txt: 原始消息文本 + + Returns: + dict: 工具调用结果 + """ + try: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + + # 获取对应工具实例 + tool_instance = get_tool_instance(function_name) + if not tool_instance: + logger.warning(f"未知工具名称: {function_name}") + return None + + # 执行工具 + result = await tool_instance.execute(function_args, message_txt) + if result: + return { + "tool_call_id": tool_call["id"], + "role": "tool", + "name": function_name, + "content": result["content"] + } + return None + except Exception as e: + logger.error(f"执行工具调用时发生错误: {str(e)}") + return None + + async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream): + """使用工具辅助思考,判断是否需要额外信息 + + Args: + message_txt: 用户消息文本 + sender_name: 发送者名称 + chat_stream: 聊天流对象 + + Returns: + dict: 工具使用结果 + """ + try: + # 构建提示词 + prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream) + + # 定义可用工具 + tools = self._define_tools() + + # 使用llm_model_tool发送带工具定义的请求 + payload = { + "model": self.llm_model_tool.model_name, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": global_config.max_response_length, + "tools": tools, + "temperature": 0.2 + } + + logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}") + # 发送请求获取模型是否需要调用工具 + response = await self.llm_model_tool._execute_request( + endpoint="/chat/completions", + payload=payload, + prompt=prompt + ) + + # 根据返回值数量判断是否有工具调用 + if len(response) == 3: + content, reasoning_content, tool_calls = response + logger.info(f"工具思考: {tool_calls}") + + # 检查响应中工具调用是否有效 + if not tool_calls: + logger.info("模型返回了空的tool_calls列表") + return {"used_tools": False} + + logger.info(f"模型请求调用{len(tool_calls)}个工具") + tool_results = [] + collected_info = "" + + # 执行所有工具调用 + for tool_call in tool_calls: + result = await self._execute_tool_call(tool_call, message_txt) + if result: + tool_results.append(result) + # 将工具结果添加到收集的信息中 + collected_info += f"\n{result['name']}返回结果: {result['content']}\n" + + # 如果有工具结果,直接返回收集的信息 + if collected_info: + logger.info(f"工具调用收集到信息: {collected_info}") + return { + "used_tools": True, + "collected_info": collected_info, + } + else: + # 没有工具调用 + content, reasoning_content = response + logger.info("模型没有请求调用任何工具") + + # 如果没有工具调用或处理失败,直接返回原始思考 + return { + "used_tools": False, + } + + except Exception as e: + logger.error(f"工具调用过程中出错: {str(e)}") + return { + "used_tools": False, + "error": str(e), + } \ No newline at end of file diff --git a/src/heart_flow/observation.py b/src/heart_flow/observation.py index 818f1775d..f507374cb 100644 --- a/src/heart_flow/observation.py +++ b/src/heart_flow/observation.py @@ -68,20 +68,24 @@ class ChattingObservation(Observation): self.translate_message_list_to_str() # 更新观察次数 - self.observe_times += 1 + # self.observe_times += 1 self.last_observe_time = new_messages[-1]["time"] # 检查是否需要更新summary - current_time = int(datetime.now().timestamp()) - if current_time - self.last_summary_time >= 30: # 如果超过30秒,重置计数 - self.summary_count = 0 - self.last_summary_time = current_time + # current_time = int(datetime.now().timestamp()) + # if current_time - self.last_summary_time >= 30: # 如果超过30秒,重置计数 + # self.summary_count = 0 + # self.last_summary_time = current_time - if self.summary_count < self.max_update_in_30s: # 如果30秒内更新次数小于2次 - await self.update_talking_summary(new_messages_str) - self.summary_count += 1 + # if self.summary_count < self.max_update_in_30s: # 如果30秒内更新次数小于2次 + # await self.update_talking_summary(new_messages_str) + # print(f"更新聊天总结:{self.observe_info}11111111111111") + # self.summary_count += 1 + updated_observe_info = await self.update_talking_summary(new_messages_str) + print(f"更新聊天总结:{updated_observe_info}11111111111111") + self.observe_info = updated_observe_info - return self.observe_info + return updated_observe_info async def carefully_observe(self): # 查找新消息,限制最多40条 @@ -110,43 +114,46 @@ class ChattingObservation(Observation): self.observe_times += 1 self.last_observe_time = new_messages[-1]["time"] - await self.update_talking_summary(new_messages_str) - return self.observe_info + updated_observe_info = await self.update_talking_summary(new_messages_str) + self.observe_info = updated_observe_info + return updated_observe_info async def update_talking_summary(self, new_messages_str): # 基于已经有的talking_summary,和新的talking_message,生成一个summary # print(f"更新聊天总结:{self.talking_summary}") # 开始构建prompt - prompt_personality = "你" - # person - individuality = Individuality.get_instance() + # prompt_personality = "你" + # # person + # individuality = Individuality.get_instance() - personality_core = individuality.personality.personality_core - prompt_personality += personality_core + # personality_core = individuality.personality.personality_core + # prompt_personality += personality_core - personality_sides = individuality.personality.personality_sides - random.shuffle(personality_sides) - prompt_personality += f",{personality_sides[0]}" + # personality_sides = individuality.personality.personality_sides + # random.shuffle(personality_sides) + # prompt_personality += f",{personality_sides[0]}" - identity_detail = individuality.identity.identity_detail - random.shuffle(identity_detail) - prompt_personality += f",{identity_detail[0]}" + # identity_detail = individuality.identity.identity_detail + # random.shuffle(identity_detail) + # prompt_personality += f",{identity_detail[0]}" - personality_info = prompt_personality + # personality_info = prompt_personality prompt = "" - prompt += f"{personality_info},请注意识别你自己的聊天发言" - prompt += f"你的名字叫:{self.name},你的昵称是:{self.nick_name}\n" + # prompt += f"{personality_info}" + prompt += f"你的名字叫:{self.name}\n,标识'{self.name}'的都是你自己说的话" prompt += f"你正在参与一个qq群聊的讨论,你记得这个群之前在聊的内容是:{self.observe_info}\n" prompt += f"现在群里的群友们产生了新的讨论,有了新的发言,具体内容如下:{new_messages_str}\n" - prompt += """以上是群里在进行的聊天,请你对这个聊天内容进行总结,总结内容要包含聊天的大致内容, - 以及聊天中的一些重要信息,注意识别你自己的发言,记得不要分点,不要太长,精简的概括成一段文本\n""" + prompt += """以上是群里在进行的聊天,请你对这个聊天内容进行总结,总结内容要包含聊天的大致内容,目前最新讨论的话题 + 以及聊天中的一些重要信息,记得不要分点,精简的概括成一段文本\n""" prompt += "总结概括:" try: - self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt) + updated_observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt) except Exception as e: print(f"获取总结失败: {e}") - self.observe_info = "" + updated_observe_info = "" + + return updated_observe_info # print(f"prompt:{prompt}") # print(f"self.observe_info:{self.observe_info}") diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index 6d0138987..80d7efb61 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -17,7 +17,7 @@ from src.plugins.chat.chat_stream import ChatStream from src.plugins.person_info.relationship_manager import relationship_manager from src.plugins.chat.utils import get_recent_group_speaker import json -from src.heart_flow.tool_use import ToolUser +from src.do_tool.tool_use import ToolUser subheartflow_config = LogConfig( # 使用海马体专用样式 @@ -133,14 +133,13 @@ class SubHeartflow: tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream) # 如果工具被使用且获得了结果,将收集到的信息合并到思考中 + collected_info = "" if tool_result.get("used_tools", False): logger.info("使用工具收集了信息") # 如果有收集到的信息,将其添加到当前思考中 if "collected_info" in tool_result: collected_info = tool_result["collected_info"] - - # 开始构建prompt prompt_personality = f"你的名字是{self.bot_name},你" @@ -178,6 +177,7 @@ class SubHeartflow: ) prompt = "" + # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" if tool_result.get("used_tools", False): prompt += f"{collected_info}\n" prompt += f"{relation_prompt_all}\n" @@ -187,10 +187,10 @@ class SubHeartflow: prompt += f"现在你正在上网,和qq群里的网友们聊天,群里正在聊的话题是:{chat_observe_info}\n" prompt += f"你现在{mood_info}\n" prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" - prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长," - prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 ,注意只输出回复内容。" - prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)," - prompt += f"记得结合上述的消息,生成符合内心想法的内心独白,文字不要浮夸,注意你就是{self.bot_name},{self.bot_name}指的就是你。" + prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白" + prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n" + prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,其他描述等)," + prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name},{self.bot_name}指的就是你。" try: response, reasoning_content = await self.llm_model.generate_response_async(prompt) diff --git a/src/heart_flow/tool_use.py b/src/heart_flow/tool_use.py deleted file mode 100644 index 7471e7512..000000000 --- a/src/heart_flow/tool_use.py +++ /dev/null @@ -1,561 +0,0 @@ -from src.plugins.models.utils_model import LLM_request -from src.plugins.config.config import global_config -from src.plugins.chat.chat_stream import ChatStream -from src.plugins.memory_system.Hippocampus import HippocampusManager -from src.common.database import db -import time -import json -from src.common.logger import get_module_logger -from src.plugins.chat.utils import get_embedding -from typing import Union -logger = get_module_logger("tool_use") - - -class ToolUser: - def __init__(self): - self.llm_model_tool = LLM_request( - model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use" - ) - - async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream): - """构建工具使用的提示词 - - Args: - message_txt: 用户消息文本 - sender_name: 发送者名称 - chat_stream: 聊天流对象 - - Returns: - str: 构建好的提示词 - """ - from src.plugins.config.config import global_config - - new_messages = list( - db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}) - .sort("time", 1) - .limit(15) - ) - new_messages_str = "" - for msg in new_messages: - if "detailed_plain_text" in msg: - new_messages_str += f"{msg['detailed_plain_text']}" - - # 这些信息应该从调用者传入,而不是从self获取 - bot_name = global_config.BOT_NICKNAME - prompt = "" - prompt += "你正在思考如何回复群里的消息。\n" - prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" - prompt += f"注意你就是{bot_name},{bot_name}指的就是你。" - prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。" - - return prompt - - def _define_tools(self): - """定义可用的工具列表 - - Returns: - list: 工具定义列表 - """ - tools = [ - { - "type": "function", - "function": { - "name": "search_knowledge", - "description": "从知识库中搜索相关信息", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "搜索查询关键词" - }, - "threshold": { - "type": "number", - "description": "相似度阈值,0.0到1.0之间" - } - }, - "required": ["query"] - } - } - }, - { - "type": "function", - "function": { - "name": "get_memory", - "description": "从记忆系统中获取相关记忆", - "parameters": { - "type": "object", - "properties": { - "text": { - "type": "string", - "description": "要查询的相关文本" - }, - "max_memory_num": { - "type": "integer", - "description": "最大返回记忆数量" - } - }, - "required": ["text"] - } - } - }, - { - "type": "function", - "function": { - "name": "get_current_task", - "description": "获取当前正在做的事情/最近的任务", - "parameters": { - "type": "object", - "properties": { - "num": { - "type": "integer", - "description": "要获取的任务数量" - }, - "time_info": { - "type": "boolean", - "description": "是否包含时间信息" - } - }, - "required": [] - } - } - } - ] - return tools - - async def _execute_tool_call(self, tool_call, message_txt:str): - """执行特定的工具调用 - - Args: - tool_call: 工具调用对象 - message_txt: 原始消息文本 - - Returns: - dict: 工具调用结果 - """ - try: - function_name = tool_call["function"]["name"] - function_args = json.loads(tool_call["function"]["arguments"]) - - if function_name == "search_knowledge": - return await self._execute_search_knowledge(tool_call, function_args, message_txt) - elif function_name == "get_memory": - return await self._execute_get_memory(tool_call, function_args, message_txt) - elif function_name == "get_current_task": - return await self._execute_get_current_task(tool_call, function_args) - - logger.warning(f"未知工具名称: {function_name}") - return None - except Exception as e: - logger.error(f"执行工具调用时发生错误: {str(e)}") - return None - - async def _execute_search_knowledge(self, tool_call, function_args, message_txt:str): - """执行知识库搜索工具 - - Args: - tool_call: 工具调用对象 - function_args: 工具参数 - message_txt: 原始消息文本 - - Returns: - dict: 工具调用结果 - """ - try: - query = function_args.get("query", message_txt) - threshold = function_args.get("threshold", 0.4) - - # 调用知识库搜索 - embedding = await get_embedding(query, request_type="info_retrieval") - if embedding: - knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) - return { - "tool_call_id": tool_call["id"], - "role": "tool", - "name": "search_knowledge", - "content": f"知识库搜索结果: {knowledge_info}" - } - return None - except Exception as e: - logger.error(f"知识库搜索工具执行失败: {str(e)}") - return None - - async def _execute_get_memory(self, tool_call, function_args, message_txt:str): - """执行记忆获取工具 - - Args: - tool_call: 工具调用对象 - function_args: 工具参数 - message_txt: 原始消息文本 - - Returns: - dict: 工具调用结果 - """ - try: - text = function_args.get("text", message_txt) - max_memory_num = function_args.get("max_memory_num", 2) - - # 调用记忆系统 - related_memory = await HippocampusManager.get_instance().get_memory_from_text( - text=text, - max_memory_num=max_memory_num, - max_memory_length=2, - max_depth=3, - fast_retrieval=False - ) - - memory_info = "" - if related_memory: - for memory in related_memory: - memory_info += memory[1] + "\n" - - return { - "tool_call_id": tool_call["id"], - "role": "tool", - "name": "get_memory", - "content": f"记忆系统结果: {memory_info if memory_info else '没有找到相关记忆'}" - } - except Exception as e: - logger.error(f"记忆获取工具执行失败: {str(e)}") - return None - - async def _execute_get_current_task(self, tool_call, function_args): - """执行获取当前任务工具 - - Args: - tool_call: 工具调用对象 - function_args: 工具参数 - - Returns: - dict: 工具调用结果 - """ - try: - from src.plugins.schedule.schedule_generator import bot_schedule - - # 获取参数,如果没有提供则使用默认值 - num = function_args.get("num", 1) - time_info = function_args.get("time_info", False) - - # 调用日程系统获取当前任务 - current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info) - - # 格式化返回结果 - if current_task: - task_info = current_task - else: - task_info = "当前没有正在进行的任务" - - return { - "tool_call_id": tool_call["id"], - "role": "tool", - "name": "get_current_task", - "content": f"当前任务信息: {task_info}" - } - except Exception as e: - logger.error(f"获取当前任务工具执行失败: {str(e)}") - return None - - async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream): - """使用工具辅助思考,判断是否需要额外信息 - - Args: - message_txt: 用户消息文本 - sender_name: 发送者名称 - chat_stream: 聊天流对象 - - Returns: - dict: 工具使用结果 - """ - try: - # 构建提示词 - prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream) - - # 定义可用工具 - tools = self._define_tools() - - # 使用llm_model_tool发送带工具定义的请求 - payload = { - "model": self.llm_model_tool.model_name, - "messages": [{"role": "user", "content": prompt}], - "max_tokens": global_config.max_response_length, - "tools": tools, - "temperature": 0.2 - } - - logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}") - # 发送请求获取模型是否需要调用工具 - response = await self.llm_model_tool._execute_request( - endpoint="/chat/completions", - payload=payload, - prompt=prompt - ) - - # 根据返回值数量判断是否有工具调用 - if len(response) == 3: - content, reasoning_content, tool_calls = response - logger.info(f"工具思考: {tool_calls}") - - # 检查响应中工具调用是否有效 - if not tool_calls: - logger.info("模型返回了空的tool_calls列表") - return {"used_tools": False, "thinking": self.current_mind} - - logger.info(f"模型请求调用{len(tool_calls)}个工具") - tool_results = [] - collected_info = "" - - # 执行所有工具调用 - for tool_call in tool_calls: - result = await self._execute_tool_call(tool_call, message_txt) - if result: - tool_results.append(result) - # 将工具结果添加到收集的信息中 - collected_info += f"\n{result['name']}返回结果: {result['content']}\n" - - # 如果有工具结果,直接返回收集的信息 - if collected_info: - logger.info(f"工具调用收集到信息: {collected_info}") - return { - "used_tools": True, - "collected_info": collected_info, - "thinking": self.current_mind # 保持原始思考不变 - } - else: - # 没有工具调用 - content, reasoning_content = response - logger.info("模型没有请求调用任何工具") - - # 如果没有工具调用或处理失败,直接返回原始思考 - return { - "used_tools": False, - "thinking": self.current_mind - } - - except Exception as e: - logger.error(f"工具调用过程中出错: {str(e)}") - return { - "used_tools": False, - "error": str(e), - "thinking": self.current_mind - } - - - - async def get_prompt_info(self, message: str, threshold: float): - start_time = time.time() - related_info = "" - logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - - # 1. 先从LLM获取主题,类似于记忆系统的做法 - topics = [] - # try: - # # 先尝试使用记忆系统的方法获取主题 - # hippocampus = HippocampusManager.get_instance()._hippocampus - # topic_num = min(5, max(1, int(len(message) * 0.1))) - # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num)) - - # # 提取关键词 - # topics = re.findall(r"<([^>]+)>", topics_response[0]) - # if not topics: - # topics = [] - # else: - # topics = [ - # topic.strip() - # for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - # if topic.strip() - # ] - - # logger.info(f"从LLM提取的主题: {', '.join(topics)}") - # except Exception as e: - # logger.error(f"从LLM提取主题失败: {str(e)}") - # # 如果LLM提取失败,使用jieba分词提取关键词作为备选 - # words = jieba.cut(message) - # topics = [word for word in words if len(word) > 1][:5] - # logger.info(f"使用jieba提取的主题: {', '.join(topics)}") - - # 如果无法提取到主题,直接使用整个消息 - if not topics: - logger.debug("未能提取到任何主题,使用整个消息进行查询") - embedding = await get_embedding(message, request_type="info_retrieval") - if not embedding: - logger.error("获取消息嵌入向量失败") - return "" - - related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) - logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒") - return related_info, {} - - # 2. 对每个主题进行知识库查询 - logger.info(f"开始处理{len(topics)}个主题的知识库查询") - - # 优化:批量获取嵌入向量,减少API调用 - embeddings = {} - topics_batch = [topic for topic in topics if len(topic) > 0] - if message: # 确保消息非空 - topics_batch.append(message) - - # 批量获取嵌入向量 - embed_start_time = time.time() - for text in topics_batch: - if not text or len(text.strip()) == 0: - continue - - try: - embedding = await get_embedding(text, request_type="info_retrieval") - if embedding: - embeddings[text] = embedding - else: - logger.warning(f"获取'{text}'的嵌入向量失败") - except Exception as e: - logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}") - - logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒") - - if not embeddings: - logger.error("所有嵌入向量获取失败") - return "" - - # 3. 对每个主题进行知识库查询 - all_results = [] - query_start_time = time.time() - - # 首先添加原始消息的查询结果 - if message in embeddings: - original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True) - if original_results: - for result in original_results: - result["topic"] = "原始消息" - all_results.extend(original_results) - logger.info(f"原始消息查询到{len(original_results)}条结果") - - # 然后添加每个主题的查询结果 - for topic in topics: - if not topic or topic not in embeddings: - continue - - try: - topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True) - if topic_results: - # 添加主题标记 - for result in topic_results: - result["topic"] = topic - all_results.extend(topic_results) - logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果") - except Exception as e: - logger.error(f"查询主题'{topic}'时发生错误: {str(e)}") - - logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果") - - # 4. 去重和过滤 - process_start_time = time.time() - unique_contents = set() - filtered_results = [] - for result in all_results: - content = result["content"] - if content not in unique_contents: - unique_contents.add(content) - filtered_results.append(result) - - # 5. 按相似度排序 - filtered_results.sort(key=lambda x: x["similarity"], reverse=True) - - # 6. 限制总数量(最多10条) - filtered_results = filtered_results[:10] - logger.info( - f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果" - ) - - # 7. 格式化输出 - if filtered_results: - format_start_time = time.time() - grouped_results = {} - for result in filtered_results: - topic = result["topic"] - if topic not in grouped_results: - grouped_results[topic] = [] - grouped_results[topic].append(result) - - # 按主题组织输出 - for topic, results in grouped_results.items(): - related_info += f"【主题: {topic}】\n" - for _i, result in enumerate(results, 1): - _similarity = result["similarity"] - content = result["content"].strip() - # 调试:为内容添加序号和相似度信息 - # related_info += f"{i}. [{similarity:.2f}] {content}\n" - related_info += f"{content}\n" - related_info += "\n" - - logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒") - - logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒") - return related_info, grouped_results - - def get_info_from_db( - self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False - ) -> Union[str, list]: - if not query_embedding: - return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - { - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] - - results = list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") - - if not results: - return "" if not return_raw else [] - - if return_raw: - return results - else: - # 返回所有找到的内容,用换行分隔 - return "\n".join(str(result["content"]) for result in results) \ No newline at end of file