better: Heartflow upgrade that greatly reduces repeated replies and flexibly calls tools for knowledge and memory retrieval

This commit is contained in:
SengokuCola
2025-04-10 23:32:28 +08:00
parent 110f94353f
commit 54f7b73ec4
10 changed files with 729 additions and 597 deletions

View File

@@ -0,0 +1,102 @@
# Tool System Usage Guide
## Overview
`tool_can_use` is a plugin-style tool system that makes it easy to extend and register new tools. Each tool lives as a standalone file in this directory, and the system discovers and registers these tools automatically.
## Tool Structure
Each tool should inherit from the `BaseTool` base class and implement the required attributes and methods:
```python
from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool


class MyNewTool(BaseTool):
    # Tool name; must be unique
    name = "my_new_tool"

    # Tool description that tells the LLM what this tool is for
    description = "这是一个新工具,用于..."

    # Tool parameter definition, following the JSON Schema format
    parameters = {
        "type": "object",
        "properties": {
            "param1": {
                "type": "string",
                "description": "参数1的描述"
            },
            "param2": {
                "type": "integer",
                "description": "参数2的描述"
            }
        },
        "required": ["param1"]  # List of required parameters
    }

    async def execute(self, function_args, message_txt=""):
        """Run the tool logic

        Args:
            function_args: Arguments of the tool call
            message_txt: Original message text

        Returns:
            Dict: Execution result; must contain the "name" and "content" fields
        """
        # Implement the tool logic here
        result = f"工具执行结果: {function_args.get('param1')}"
        return {
            "name": self.name,
            "content": result
        }


# Register the tool
register_tool(MyNewTool)
```
## Automatic Registration
The tool system registers tools automatically through the following steps:
1. In `__init__.py`, the `discover_tools()` function walks every Python file in the current directory
2. For each file, the system looks for classes that inherit from `BaseTool`
3. Those classes are registered into the tool registry automatically
As long as `register_tool(YourToolClass)` is called at the end of each tool file, the tool will be registered automatically; a quick way to verify this is sketched below.
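A minimal sketch of how to check that discovery worked, assuming the `src.do_tool.tool_can_use` package from this commit is importable (the tool names shown in the comment are simply the ones added in this commit):
```python
# Importing the package runs discover_tools() (see __init__.py), so the registry is already populated.
from src.do_tool.tool_can_use import TOOL_REGISTRY, get_tool_instance

# TOOL_REGISTRY maps tool names to tool classes.
print(sorted(TOOL_REGISTRY.keys()))  # e.g. ['get_current_task', 'get_memory', 'search_knowledge']

tool = get_tool_instance("search_knowledge")
if tool is not None:
    # get_tool_definition() returns the function-call definition that is passed to the LLM.
    print(tool.get_tool_definition()["function"]["name"])
```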
## Steps to Add a New Tool
1. Create a new Python file, e.g. `my_new_tool.py`, in the `tool_can_use` directory
2. Import `BaseTool` and `register_tool`
3. Create a tool class that inherits from `BaseTool`
4. Implement the required attributes (`name`, `description`, `parameters`)
5. Implement the `execute` method
6. Register the tool with `register_tool`
## Integration with ToolUser
The `ToolUser` class has been updated to use this new tool system. It will:
1. Automatically collect the definitions of all registered tools
2. Find the matching tool instance by tool name
3. Call the tool's `execute` method
## Usage Example
```python
from src.do_tool.tool_use import ToolUser

# Create a tool user
tool_user = ToolUser()

# Use the tools (use_tool is async, so this must run inside a coroutine)
result = await tool_user.use_tool(message_txt="查询关于Python的知识", sender_name="用户", chat_stream=chat_stream)

# Handle the result
if result["used_tools"]:
    print("工具使用结果:", result["collected_info"])
else:
    print("未使用工具")
```

View File

@@ -0,0 +1,11 @@
from src.do_tool.tool_can_use.base_tool import (
BaseTool,
register_tool,
discover_tools,
get_all_tool_definitions,
get_tool_instance,
TOOL_REGISTRY
)
# 自动发现并注册工具
discover_tools()

View File

@@ -0,0 +1,119 @@
from typing import Dict, List, Any, Optional, Union, Type
import inspect
import importlib
import pkgutil
import os
from src.common.logger import get_module_logger
logger = get_module_logger("base_tool")
# 工具注册表
TOOL_REGISTRY = {}
class BaseTool:
"""所有工具的基类"""
# 工具名称,子类必须重写
name = None
# 工具描述,子类必须重写
description = None
# 工具参数定义,子类必须重写
parameters = None
@classmethod
def get_tool_definition(cls) -> Dict[str, Any]:
"""获取工具定义用于LLM工具调用
Returns:
Dict: 工具定义字典
"""
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return {
"type": "function",
"function": {
"name": cls.name,
"description": cls.description,
"parameters": cls.parameters
}
}
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行工具函数
Args:
function_args: 工具调用参数
message_txt: 原始消息文本
Returns:
Dict: 工具执行结果
"""
raise NotImplementedError("子类必须实现execute方法")
def register_tool(tool_class: Type[BaseTool]):
"""注册工具到全局注册表
Args:
tool_class: 工具类
"""
if not issubclass(tool_class, BaseTool):
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
tool_name = tool_class.name
if not tool_name:
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
TOOL_REGISTRY[tool_name] = tool_class
logger.info(f"已注册工具: {tool_name}")
def discover_tools():
"""自动发现并注册tool_can_use目录下的所有工具"""
# 获取当前目录路径
current_dir = os.path.dirname(os.path.abspath(__file__))
package_name = os.path.basename(current_dir)
parent_dir = os.path.dirname(current_dir)
# 导入当前包
package = importlib.import_module(f"src.do_tool.{package_name}")
# 遍历包中的所有模块
for _, module_name, is_pkg in pkgutil.iter_modules([current_dir]):
# 跳过当前模块和__pycache__
if module_name == "base_tool" or module_name.startswith("__"):
continue
# 导入模块
module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}")
# 查找模块中的工具类
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
register_tool(obj)
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
def get_all_tool_definitions() -> List[Dict[str, Any]]:
"""获取所有已注册工具的定义
Returns:
List[Dict]: 工具定义列表
"""
return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()]
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取指定名称的工具实例
Args:
tool_name: 工具名称
Returns:
Optional[BaseTool]: 工具实例如果找不到则返回None
"""
tool_class = TOOL_REGISTRY.get(tool_name)
if not tool_class:
return None
return tool_class()

View File

@@ -0,0 +1,63 @@
from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool
from src.plugins.schedule.schedule_generator import bot_schedule
from src.common.logger import get_module_logger
from typing import Dict, Any
logger = get_module_logger("get_current_task_tool")
class GetCurrentTaskTool(BaseTool):
"""获取当前正在做的事情/最近的任务工具"""
name = "get_current_task"
description = "获取当前正在做的事情/最近的任务"
parameters = {
"type": "object",
"properties": {
"num": {
"type": "integer",
"description": "要获取的任务数量"
},
"time_info": {
"type": "boolean",
"description": "是否包含时间信息"
}
},
"required": []
}
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行获取当前任务
Args:
function_args: 工具参数
message_txt: 原始消息文本,此工具不使用
Returns:
Dict: 工具执行结果
"""
try:
# 获取参数,如果没有提供则使用默认值
num = function_args.get("num", 1)
time_info = function_args.get("time_info", False)
# 调用日程系统获取当前任务
current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info)
# 格式化返回结果
if current_task:
task_info = current_task
else:
task_info = "当前没有正在进行的任务"
return {
"name": "get_current_task",
"content": f"当前任务信息: {task_info}"
}
except Exception as e:
logger.error(f"获取当前任务工具执行失败: {str(e)}")
return {
"name": "get_current_task",
"content": f"获取当前任务失败: {str(e)}"
}
# 注册工具
register_tool(GetCurrentTaskTool)

View File

@@ -0,0 +1,147 @@
from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool
from src.plugins.chat.utils import get_embedding
from src.common.database import db
from src.common.logger import get_module_logger
from typing import Dict, Any, Union, List
logger = get_module_logger("get_knowledge_tool")
class SearchKnowledgeTool(BaseTool):
"""从知识库中搜索相关信息的工具"""
name = "search_knowledge"
description = "从知识库中搜索相关信息"
parameters = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "搜索查询关键词"
},
"threshold": {
"type": "number",
"description": "相似度阈值0.0到1.0之间"
}
},
"required": ["query"]
}
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行知识库搜索
Args:
function_args: 工具参数
message_txt: 原始消息文本
Returns:
Dict: 工具执行结果
"""
try:
query = function_args.get("query", message_txt)
threshold = function_args.get("threshold", 0.4)
# 调用知识库搜索
embedding = await get_embedding(query, request_type="info_retrieval")
if embedding:
knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
if knowledge_info:
content = f"你知道这些知识: {knowledge_info}"
else:
content = f"你不太了解有关{query}的知识"
return {
"name": "search_knowledge",
"content": content
}
return {
"name": "search_knowledge",
"content": f"无法获取关于'{query}'的嵌入向量"
}
except Exception as e:
logger.error(f"知识库搜索工具执行失败: {str(e)}")
return {
"name": "search_knowledge",
"content": f"知识库搜索失败: {str(e)}"
}
def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
"""从数据库中获取相关信息
Args:
query_embedding: 查询的嵌入向量
limit: 最大返回结果数
threshold: 相似度阈值
return_raw: 是否返回原始结果
Returns:
Union[str, list]: 格式化的信息字符串或原始结果列表
"""
if not query_embedding:
return "" if not return_raw else []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline))
logger.debug(f"知识库查询结果数量: {len(results)}")
if not results:
return "" if not return_raw else []
if return_raw:
return results
else:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)
# 注册工具
register_tool(SearchKnowledgeTool)

View File

@@ -0,0 +1,72 @@
from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool
from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.common.logger import get_module_logger
from typing import Dict, Any
logger = get_module_logger("get_memory_tool")
class GetMemoryTool(BaseTool):
"""从记忆系统中获取相关记忆的工具"""
name = "get_memory"
description = "从记忆系统中获取相关记忆"
parameters = {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "要查询的相关文本"
},
"max_memory_num": {
"type": "integer",
"description": "最大返回记忆数量"
}
},
"required": ["text"]
}
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行记忆获取
Args:
function_args: 工具参数
message_txt: 原始消息文本
Returns:
Dict: 工具执行结果
"""
try:
text = function_args.get("text", message_txt)
max_memory_num = function_args.get("max_memory_num", 2)
# 调用记忆系统
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=text,
max_memory_num=max_memory_num,
max_memory_length=2,
max_depth=3,
fast_retrieval=False
)
memory_info = ""
if related_memory:
for memory in related_memory:
memory_info += memory[1] + "\n"
if memory_info:
content = f"你记得这些事情: {memory_info}"
else:
content = f"你不太记得有关{text}的记忆,你对此不太了解"
return {
"name": "get_memory",
"content": content
}
except Exception as e:
logger.error(f"记忆获取工具执行失败: {str(e)}")
return {
"name": "get_memory",
"content": f"记忆获取失败: {str(e)}"
}
# 注册工具
register_tool(GetMemoryTool)

172 src/do_tool/tool_use.py Normal file
View File

@@ -0,0 +1,172 @@
from src.plugins.models.utils_model import LLM_request
from src.plugins.config.config import global_config
from src.plugins.chat.chat_stream import ChatStream
from src.common.database import db
import time
import json
from src.common.logger import get_module_logger
from typing import Union
from src.do_tool.tool_can_use import get_all_tool_definitions, get_tool_instance
logger = get_module_logger("tool_use")
class ToolUser:
def __init__(self):
self.llm_model_tool = LLM_request(
model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use"
)
async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
"""构建工具使用的提示词
Args:
message_txt: 用户消息文本
sender_name: 发送者名称
chat_stream: 聊天流对象
Returns:
str: 构建好的提示词
"""
new_messages = list(
db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}})
.sort("time", 1)
.limit(15)
)
new_messages_str = ""
for msg in new_messages:
if "detailed_plain_text" in msg:
new_messages_str += f"{msg['detailed_plain_text']}"
# 这些信息应该从调用者传入而不是从self获取
bot_name = global_config.BOT_NICKNAME
prompt = ""
prompt += "你正在思考如何回复群里的消息。\n"
prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
prompt += f"注意你就是{bot_name}{bot_name}指的就是你。"
prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。"
return prompt
def _define_tools(self):
"""获取所有已注册工具的定义
Returns:
list: 工具定义列表
"""
return get_all_tool_definitions()
async def _execute_tool_call(self, tool_call, message_txt:str):
"""执行特定的工具调用
Args:
tool_call: 工具调用对象
message_txt: 原始消息文本
Returns:
dict: 工具调用结果
"""
try:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
# 获取对应工具实例
tool_instance = get_tool_instance(function_name)
if not tool_instance:
logger.warning(f"未知工具名称: {function_name}")
return None
# 执行工具
result = await tool_instance.execute(function_args, message_txt)
if result:
return {
"tool_call_id": tool_call["id"],
"role": "tool",
"name": function_name,
"content": result["content"]
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
return None
async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
"""使用工具辅助思考,判断是否需要额外信息
Args:
message_txt: 用户消息文本
sender_name: 发送者名称
chat_stream: 聊天流对象
Returns:
dict: 工具使用结果
"""
try:
# 构建提示词
prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream)
# 定义可用工具
tools = self._define_tools()
# 使用llm_model_tool发送带工具定义的请求
payload = {
"model": self.llm_model_tool.model_name,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length,
"tools": tools,
"temperature": 0.2
}
logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}")
# 发送请求获取模型是否需要调用工具
response = await self.llm_model_tool._execute_request(
endpoint="/chat/completions",
payload=payload,
prompt=prompt
)
# 根据返回值数量判断是否有工具调用
if len(response) == 3:
content, reasoning_content, tool_calls = response
logger.info(f"工具思考: {tool_calls}")
# 检查响应中工具调用是否有效
if not tool_calls:
logger.info("模型返回了空的tool_calls列表")
return {"used_tools": False}
logger.info(f"模型请求调用{len(tool_calls)}个工具")
tool_results = []
collected_info = ""
# 执行所有工具调用
for tool_call in tool_calls:
result = await self._execute_tool_call(tool_call, message_txt)
if result:
tool_results.append(result)
# 将工具结果添加到收集的信息中
collected_info += f"\n{result['name']}返回结果: {result['content']}\n"
# 如果有工具结果,直接返回收集的信息
if collected_info:
logger.info(f"工具调用收集到信息: {collected_info}")
return {
"used_tools": True,
"collected_info": collected_info,
}
else:
# 没有工具调用
content, reasoning_content = response
logger.info("模型没有请求调用任何工具")
# 如果没有工具调用或处理失败,直接返回原始思考
return {
"used_tools": False,
}
except Exception as e:
logger.error(f"工具调用过程中出错: {str(e)}")
return {
"used_tools": False,
"error": str(e),
}

View File

@@ -68,20 +68,24 @@ class ChattingObservation(Observation):
self.translate_message_list_to_str()
# 更新观察次数
self.observe_times += 1
# self.observe_times += 1
self.last_observe_time = new_messages[-1]["time"]
# 检查是否需要更新summary
current_time = int(datetime.now().timestamp())
if current_time - self.last_summary_time >= 30: # 如果超过30秒重置计数
self.summary_count = 0
self.last_summary_time = current_time
# current_time = int(datetime.now().timestamp())
# if current_time - self.last_summary_time >= 30: # 如果超过30秒重置计数
# self.summary_count = 0
# self.last_summary_time = current_time
if self.summary_count < self.max_update_in_30s: # 如果30秒内更新次数小于2次
await self.update_talking_summary(new_messages_str)
self.summary_count += 1
# if self.summary_count < self.max_update_in_30s: # 如果30秒内更新次数小于2次
# await self.update_talking_summary(new_messages_str)
# print(f"更新聊天总结:{self.observe_info}11111111111111")
# self.summary_count += 1
updated_observe_info = await self.update_talking_summary(new_messages_str)
print(f"更新聊天总结:{updated_observe_info}11111111111111")
self.observe_info = updated_observe_info
return self.observe_info
return updated_observe_info
async def carefully_observe(self):
# 查找新消息限制最多40条
@@ -110,43 +114,46 @@ class ChattingObservation(Observation):
self.observe_times += 1
self.last_observe_time = new_messages[-1]["time"]
await self.update_talking_summary(new_messages_str)
return self.observe_info
updated_observe_info = await self.update_talking_summary(new_messages_str)
self.observe_info = updated_observe_info
return updated_observe_info
async def update_talking_summary(self, new_messages_str):
# 基于已经有的talking_summary和新的talking_message生成一个summary
# print(f"更新聊天总结:{self.talking_summary}")
# 开始构建prompt
prompt_personality = ""
# person
individuality = Individuality.get_instance()
# prompt_personality = "你"
# # person
# individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
# personality_core = individuality.personality.personality_core
# prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
# personality_sides = individuality.personality.personality_sides
# random.shuffle(personality_sides)
# prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
# identity_detail = individuality.identity.identity_detail
# random.shuffle(identity_detail)
# prompt_personality += f",{identity_detail[0]}"
personality_info = prompt_personality
# personality_info = prompt_personality
prompt = ""
prompt += f"{personality_info},请注意识别你自己的聊天发言"
prompt += f"你的名字叫:{self.name},你的昵称是:{self.nick_name}\n"
# prompt += f"{personality_info}"
prompt += f"你的名字叫:{self.name}\n,标识'{self.name}'的都是你自己说的话"
prompt += f"你正在参与一个qq群聊的讨论你记得这个群之前在聊的内容是{self.observe_info}\n"
prompt += f"现在群里的群友们产生了新的讨论,有了新的发言,具体内容如下:{new_messages_str}\n"
prompt += """以上是群里在进行的聊天,请你对这个聊天内容进行总结,总结内容要包含聊天的大致内容,
以及聊天中的一些重要信息,注意识别你自己的发言,记得不要分点,不要太长,精简的概括成一段文本\n"""
prompt += """以上是群里在进行的聊天,请你对这个聊天内容进行总结,总结内容要包含聊天的大致内容,目前最新讨论的话题
以及聊天中的一些重要信息,记得不要分点,精简的概括成一段文本\n"""
prompt += "总结概括:"
try:
self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt)
updated_observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt)
except Exception as e:
print(f"获取总结失败: {e}")
self.observe_info = ""
updated_observe_info = ""
return updated_observe_info
# print(f"prompt{prompt}")
# print(f"self.observe_info{self.observe_info}")

View File

@@ -17,7 +17,7 @@ from src.plugins.chat.chat_stream import ChatStream
from src.plugins.person_info.relationship_manager import relationship_manager
from src.plugins.chat.utils import get_recent_group_speaker
import json
from src.heart_flow.tool_use import ToolUser
from src.do_tool.tool_use import ToolUser
subheartflow_config = LogConfig(
# 使用海马体专用样式
@@ -133,6 +133,7 @@ class SubHeartflow:
tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream)
# 如果工具被使用且获得了结果,将收集到的信息合并到思考中
collected_info = ""
if tool_result.get("used_tools", False):
logger.info("使用工具收集了信息")
@@ -140,8 +141,6 @@ class SubHeartflow:
if "collected_info" in tool_result:
collected_info = tool_result["collected_info"]
# 开始构建prompt
prompt_personality = f"你的名字是{self.bot_name},你"
# person
@@ -178,6 +177,7 @@ class SubHeartflow:
)
prompt = ""
# prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
if tool_result.get("used_tools", False):
prompt += f"{collected_info}\n"
prompt += f"{relation_prompt_all}\n"
@@ -187,10 +187,10 @@ class SubHeartflow:
prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n"
prompt += f"你现在{mood_info}\n"
prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长,"
prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 ,注意只输出回复内容。"
prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
prompt += f"记得结合上述的消息,生成符合内心想法的内心独白,文字不要浮夸,注意你就是{self.bot_name}{self.bot_name}指的就是你。"
prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白"
prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n"
prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,其他描述等)"
prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name}{self.bot_name}指的就是你。"
try:
response, reasoning_content = await self.llm_model.generate_response_async(prompt)

View File

@@ -1,561 +0,0 @@
from src.plugins.models.utils_model import LLM_request
from src.plugins.config.config import global_config
from src.plugins.chat.chat_stream import ChatStream
from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.common.database import db
import time
import json
from src.common.logger import get_module_logger
from src.plugins.chat.utils import get_embedding
from typing import Union
logger = get_module_logger("tool_use")
class ToolUser:
def __init__(self):
self.llm_model_tool = LLM_request(
model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use"
)
async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
"""构建工具使用的提示词
Args:
message_txt: 用户消息文本
sender_name: 发送者名称
chat_stream: 聊天流对象
Returns:
str: 构建好的提示词
"""
from src.plugins.config.config import global_config
new_messages = list(
db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}})
.sort("time", 1)
.limit(15)
)
new_messages_str = ""
for msg in new_messages:
if "detailed_plain_text" in msg:
new_messages_str += f"{msg['detailed_plain_text']}"
# 这些信息应该从调用者传入而不是从self获取
bot_name = global_config.BOT_NICKNAME
prompt = ""
prompt += "你正在思考如何回复群里的消息。\n"
prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
prompt += f"注意你就是{bot_name}{bot_name}指的就是你。"
prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。"
return prompt
def _define_tools(self):
"""定义可用的工具列表
Returns:
list: 工具定义列表
"""
tools = [
{
"type": "function",
"function": {
"name": "search_knowledge",
"description": "从知识库中搜索相关信息",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "搜索查询关键词"
},
"threshold": {
"type": "number",
"description": "相似度阈值0.0到1.0之间"
}
},
"required": ["query"]
}
}
},
{
"type": "function",
"function": {
"name": "get_memory",
"description": "从记忆系统中获取相关记忆",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "要查询的相关文本"
},
"max_memory_num": {
"type": "integer",
"description": "最大返回记忆数量"
}
},
"required": ["text"]
}
}
},
{
"type": "function",
"function": {
"name": "get_current_task",
"description": "获取当前正在做的事情/最近的任务",
"parameters": {
"type": "object",
"properties": {
"num": {
"type": "integer",
"description": "要获取的任务数量"
},
"time_info": {
"type": "boolean",
"description": "是否包含时间信息"
}
},
"required": []
}
}
}
]
return tools
async def _execute_tool_call(self, tool_call, message_txt:str):
"""执行特定的工具调用
Args:
tool_call: 工具调用对象
message_txt: 原始消息文本
Returns:
dict: 工具调用结果
"""
try:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
if function_name == "search_knowledge":
return await self._execute_search_knowledge(tool_call, function_args, message_txt)
elif function_name == "get_memory":
return await self._execute_get_memory(tool_call, function_args, message_txt)
elif function_name == "get_current_task":
return await self._execute_get_current_task(tool_call, function_args)
logger.warning(f"未知工具名称: {function_name}")
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
return None
async def _execute_search_knowledge(self, tool_call, function_args, message_txt:str):
"""执行知识库搜索工具
Args:
tool_call: 工具调用对象
function_args: 工具参数
message_txt: 原始消息文本
Returns:
dict: 工具调用结果
"""
try:
query = function_args.get("query", message_txt)
threshold = function_args.get("threshold", 0.4)
# 调用知识库搜索
embedding = await get_embedding(query, request_type="info_retrieval")
if embedding:
knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
return {
"tool_call_id": tool_call["id"],
"role": "tool",
"name": "search_knowledge",
"content": f"知识库搜索结果: {knowledge_info}"
}
return None
except Exception as e:
logger.error(f"知识库搜索工具执行失败: {str(e)}")
return None
async def _execute_get_memory(self, tool_call, function_args, message_txt:str):
"""执行记忆获取工具
Args:
tool_call: 工具调用对象
function_args: 工具参数
message_txt: 原始消息文本
Returns:
dict: 工具调用结果
"""
try:
text = function_args.get("text", message_txt)
max_memory_num = function_args.get("max_memory_num", 2)
# 调用记忆系统
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=text,
max_memory_num=max_memory_num,
max_memory_length=2,
max_depth=3,
fast_retrieval=False
)
memory_info = ""
if related_memory:
for memory in related_memory:
memory_info += memory[1] + "\n"
return {
"tool_call_id": tool_call["id"],
"role": "tool",
"name": "get_memory",
"content": f"记忆系统结果: {memory_info if memory_info else '没有找到相关记忆'}"
}
except Exception as e:
logger.error(f"记忆获取工具执行失败: {str(e)}")
return None
async def _execute_get_current_task(self, tool_call, function_args):
"""执行获取当前任务工具
Args:
tool_call: 工具调用对象
function_args: 工具参数
Returns:
dict: 工具调用结果
"""
try:
from src.plugins.schedule.schedule_generator import bot_schedule
# 获取参数,如果没有提供则使用默认值
num = function_args.get("num", 1)
time_info = function_args.get("time_info", False)
# 调用日程系统获取当前任务
current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info)
# 格式化返回结果
if current_task:
task_info = current_task
else:
task_info = "当前没有正在进行的任务"
return {
"tool_call_id": tool_call["id"],
"role": "tool",
"name": "get_current_task",
"content": f"当前任务信息: {task_info}"
}
except Exception as e:
logger.error(f"获取当前任务工具执行失败: {str(e)}")
return None
async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
"""使用工具辅助思考,判断是否需要额外信息
Args:
message_txt: 用户消息文本
sender_name: 发送者名称
chat_stream: 聊天流对象
Returns:
dict: 工具使用结果
"""
try:
# 构建提示词
prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream)
# 定义可用工具
tools = self._define_tools()
# 使用llm_model_tool发送带工具定义的请求
payload = {
"model": self.llm_model_tool.model_name,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length,
"tools": tools,
"temperature": 0.2
}
logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}")
# 发送请求获取模型是否需要调用工具
response = await self.llm_model_tool._execute_request(
endpoint="/chat/completions",
payload=payload,
prompt=prompt
)
# 根据返回值数量判断是否有工具调用
if len(response) == 3:
content, reasoning_content, tool_calls = response
logger.info(f"工具思考: {tool_calls}")
# 检查响应中工具调用是否有效
if not tool_calls:
logger.info("模型返回了空的tool_calls列表")
return {"used_tools": False, "thinking": self.current_mind}
logger.info(f"模型请求调用{len(tool_calls)}个工具")
tool_results = []
collected_info = ""
# 执行所有工具调用
for tool_call in tool_calls:
result = await self._execute_tool_call(tool_call, message_txt)
if result:
tool_results.append(result)
# 将工具结果添加到收集的信息中
collected_info += f"\n{result['name']}返回结果: {result['content']}\n"
# 如果有工具结果,直接返回收集的信息
if collected_info:
logger.info(f"工具调用收集到信息: {collected_info}")
return {
"used_tools": True,
"collected_info": collected_info,
"thinking": self.current_mind # 保持原始思考不变
}
else:
# 没有工具调用
content, reasoning_content = response
logger.info("模型没有请求调用任何工具")
# 如果没有工具调用或处理失败,直接返回原始思考
return {
"used_tools": False,
"thinking": self.current_mind
}
except Exception as e:
logger.error(f"工具调用过程中出错: {str(e)}")
return {
"used_tools": False,
"error": str(e),
"thinking": self.current_mind
}
async def get_prompt_info(self, message: str, threshold: float):
start_time = time.time()
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法
topics = []
# try:
# # 先尝试使用记忆系统的方法获取主题
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
# topics = []
# else:
# topics = [
# topic.strip()
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
# # 如果LLM提取失败使用jieba分词提取关键词作为备选
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息
if not topics:
logger.debug("未能提取到任何主题,使用整个消息进行查询")
embedding = await get_embedding(message, request_type="info_retrieval")
if not embedding:
logger.error("获取消息嵌入向量失败")
return ""
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}")
return related_info, {}
# 2. 对每个主题进行知识库查询
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
# 优化批量获取嵌入向量减少API调用
embeddings = {}
topics_batch = [topic for topic in topics if len(topic) > 0]
if message: # 确保消息非空
topics_batch.append(message)
# 批量获取嵌入向量
embed_start_time = time.time()
for text in topics_batch:
if not text or len(text.strip()) == 0:
continue
try:
embedding = await get_embedding(text, request_type="info_retrieval")
if embedding:
embeddings[text] = embedding
else:
logger.warning(f"获取'{text}'的嵌入向量失败")
except Exception as e:
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}")
if not embeddings:
logger.error("所有嵌入向量获取失败")
return ""
# 3. 对每个主题进行知识库查询
all_results = []
query_start_time = time.time()
# 首先添加原始消息的查询结果
if message in embeddings:
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
if original_results:
for result in original_results:
result["topic"] = "原始消息"
all_results.extend(original_results)
logger.info(f"原始消息查询到{len(original_results)}条结果")
# 然后添加每个主题的查询结果
for topic in topics:
if not topic or topic not in embeddings:
continue
try:
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
if topic_results:
# 添加主题标记
for result in topic_results:
result["topic"] = topic
all_results.extend(topic_results)
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
except Exception as e:
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
# 4. 去重和过滤
process_start_time = time.time()
unique_contents = set()
filtered_results = []
for result in all_results:
content = result["content"]
if content not in unique_contents:
unique_contents.add(content)
filtered_results.append(result)
# 5. 按相似度排序
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
# 6. 限制总数量最多10条
filtered_results = filtered_results[:10]
logger.info(
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
)
# 7. 格式化输出
if filtered_results:
format_start_time = time.time()
grouped_results = {}
for result in filtered_results:
topic = result["topic"]
if topic not in grouped_results:
grouped_results[topic] = []
grouped_results[topic].append(result)
# 按主题组织输出
for topic, results in grouped_results.items():
related_info += f"【主题: {topic}\n"
for _i, result in enumerate(results, 1):
_similarity = result["similarity"]
content = result["content"].strip()
# 调试:为内容添加序号和相似度信息
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info, grouped_results
def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline))
logger.debug(f"知识库查询结果数量: {len(results)}")
if not results:
return "" if not return_raw else []
if return_raw:
return results
else:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)