refactor: 移除对机器人自身用户ID的特殊处理,统一使用QQ号进行比较
This commit is contained in:
@@ -280,8 +280,6 @@ class MessageStorageBatcher:
|
|||||||
user_platform = user_info_dict.get("platform")
|
user_platform = user_info_dict.get("platform")
|
||||||
user_id = user_info_dict.get("user_id")
|
user_id = user_info_dict.get("user_id")
|
||||||
# 将机器人自己的user_id标记为"SELF",增强对自我身份的识别
|
# 将机器人自己的user_id标记为"SELF",增强对自我身份的识别
|
||||||
if user_id == global_config.bot.qq_account:
|
|
||||||
user_id = "SELF"
|
|
||||||
user_nickname = user_info_dict.get("user_nickname")
|
user_nickname = user_info_dict.get("user_nickname")
|
||||||
user_cardname = user_info_dict.get("user_cardname")
|
user_cardname = user_info_dict.get("user_cardname")
|
||||||
|
|
||||||
@@ -630,9 +628,6 @@ class MessageStorage:
|
|||||||
|
|
||||||
user_platform = user_info_dict.get("platform")
|
user_platform = user_info_dict.get("platform")
|
||||||
user_id = user_info_dict.get("user_id")
|
user_id = user_info_dict.get("user_id")
|
||||||
# 将机器人自己的user_id标记为"SELF",增强对自我身份的识别
|
|
||||||
if user_id == global_config.bot.qq_account:
|
|
||||||
user_id = "SELF"
|
|
||||||
user_nickname = user_info_dict.get("user_nickname")
|
user_nickname = user_info_dict.get("user_nickname")
|
||||||
user_cardname = user_info_dict.get("user_cardname")
|
user_cardname = user_info_dict.get("user_cardname")
|
||||||
|
|
||||||
|
|||||||
@@ -1901,6 +1901,8 @@ class DefaultReplyer:
|
|||||||
# 获取用户ID
|
# 获取用户ID
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_id = await person_info_manager.get_person_id_by_person_name(sender)
|
person_id = await person_info_manager.get_person_id_by_person_name(sender)
|
||||||
|
if person_id == "SELF":
|
||||||
|
return f"你将要回复的是你自己发送的消息。"
|
||||||
if not person_id:
|
if not person_id:
|
||||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||||
|
|||||||
@@ -44,8 +44,8 @@ def replace_user_references_sync(
|
|||||||
|
|
||||||
if name_resolver is None:
|
if name_resolver is None:
|
||||||
def default_resolver(platform: str, user_id: str) -> str:
|
def default_resolver(platform: str, user_id: str) -> str:
|
||||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
# 检查是否是机器人自己
|
||||||
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
|
if replace_bot_name and (user_id == str(global_config.bot.qq_account)):
|
||||||
return f"{global_config.bot.nickname}(你)"
|
return f"{global_config.bot.nickname}(你)"
|
||||||
# 同步函数中无法使用异步的 get_value,直接返回 user_id
|
# 同步函数中无法使用异步的 get_value,直接返回 user_id
|
||||||
# 建议调用方使用 replace_user_references_async 以获取完整的用户名
|
# 建议调用方使用 replace_user_references_async 以获取完整的用户名
|
||||||
@@ -61,7 +61,7 @@ def replace_user_references_sync(
|
|||||||
bbb = match[2]
|
bbb = match[2]
|
||||||
try:
|
try:
|
||||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
||||||
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
|
if replace_bot_name and (bbb == str(global_config.bot.qq_account)):
|
||||||
reply_person_name = f"{global_config.bot.nickname}(你)"
|
reply_person_name = f"{global_config.bot.nickname}(你)"
|
||||||
else:
|
else:
|
||||||
reply_person_name = name_resolver(platform, bbb) or aaa
|
reply_person_name = name_resolver(platform, bbb) or aaa
|
||||||
@@ -81,8 +81,8 @@ def replace_user_references_sync(
|
|||||||
aaa = m.group(1)
|
aaa = m.group(1)
|
||||||
bbb = m.group(2)
|
bbb = m.group(2)
|
||||||
try:
|
try:
|
||||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
# 检查是否是机器人自己
|
||||||
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
|
if replace_bot_name and (bbb == str(global_config.bot.qq_account)):
|
||||||
at_person_name = f"{global_config.bot.nickname}(你)"
|
at_person_name = f"{global_config.bot.nickname}(你)"
|
||||||
else:
|
else:
|
||||||
at_person_name = name_resolver(platform, bbb) or aaa
|
at_person_name = name_resolver(platform, bbb) or aaa
|
||||||
@@ -118,8 +118,8 @@ async def replace_user_references_async(
|
|||||||
"""
|
"""
|
||||||
if name_resolver is None:
|
if name_resolver is None:
|
||||||
async def default_resolver(platform: str, user_id: str) -> str:
|
async def default_resolver(platform: str, user_id: str) -> str:
|
||||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
# 检查是否是机器人自己
|
||||||
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
|
if replace_bot_name and (user_id == str(global_config.bot.qq_account)):
|
||||||
return f"{global_config.bot.nickname}(你)"
|
return f"{global_config.bot.nickname}(你)"
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
||||||
@@ -134,8 +134,8 @@ async def replace_user_references_async(
|
|||||||
aaa = match.group(1)
|
aaa = match.group(1)
|
||||||
bbb = match.group(2)
|
bbb = match.group(2)
|
||||||
try:
|
try:
|
||||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
# 检查是否是机器人自己
|
||||||
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
|
if replace_bot_name and (bbb == str(global_config.bot.qq_account)):
|
||||||
reply_person_name = f"{global_config.bot.nickname}(你)"
|
reply_person_name = f"{global_config.bot.nickname}(你)"
|
||||||
else:
|
else:
|
||||||
reply_person_name = await name_resolver(platform, bbb) or aaa
|
reply_person_name = await name_resolver(platform, bbb) or aaa
|
||||||
@@ -155,8 +155,8 @@ async def replace_user_references_async(
|
|||||||
aaa = m.group(1)
|
aaa = m.group(1)
|
||||||
bbb = m.group(2)
|
bbb = m.group(2)
|
||||||
try:
|
try:
|
||||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
# 检查是否是机器人自己
|
||||||
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
|
if replace_bot_name and (bbb == str(global_config.bot.qq_account)):
|
||||||
at_person_name = f"{global_config.bot.nickname}(你)"
|
at_person_name = f"{global_config.bot.nickname}(你)"
|
||||||
else:
|
else:
|
||||||
at_person_name = await name_resolver(platform, bbb) or aaa
|
at_person_name = await name_resolver(platform, bbb) or aaa
|
||||||
@@ -640,7 +640,7 @@ async def _build_readable_messages_internal(
|
|||||||
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
||||||
person_name: str
|
person_name: str
|
||||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
||||||
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
|
if replace_bot_name and user_id == str(global_config.bot.qq_account):
|
||||||
person_name = f"{global_config.bot.nickname}(你)"
|
person_name = f"{global_config.bot.nickname}(你)"
|
||||||
else:
|
else:
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||||
@@ -656,8 +656,8 @@ async def _build_readable_messages_internal(
|
|||||||
else:
|
else:
|
||||||
person_name = "某人"
|
person_name = "某人"
|
||||||
|
|
||||||
# 在用户名后面添加 QQ 号, 但机器人本体不用(包括SELF标记)
|
# 在用户名后面添加 QQ 号, 但机器人本体不用
|
||||||
if user_id != global_config.bot.qq_account and user_id != "SELF":
|
if user_id != str(global_config.bot.qq_account):
|
||||||
person_name = f"{person_name}({user_id})"
|
person_name = f"{person_name}({user_id})"
|
||||||
|
|
||||||
# 使用独立函数处理用户引用格式
|
# 使用独立函数处理用户引用格式
|
||||||
@@ -1026,7 +1026,7 @@ async def build_readable_messages(
|
|||||||
actions = [
|
actions = [
|
||||||
{
|
{
|
||||||
"time": a.time,
|
"time": a.time,
|
||||||
"user_id": global_config.bot.qq_account,
|
"user_id": str(global_config.bot.qq_account),
|
||||||
"user_nickname": global_config.bot.nickname,
|
"user_nickname": global_config.bot.nickname,
|
||||||
"user_cardname": "",
|
"user_cardname": "",
|
||||||
"processed_plain_text": f"{a.action_prompt_display}",
|
"processed_plain_text": f"{a.action_prompt_display}",
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ async def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12)
|
|||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
(user_info.platform, user_info.user_id) != sender
|
(user_info.platform, user_info.user_id) != sender
|
||||||
and user_info.user_id != global_config.bot.qq_account
|
and user_info.user_id != str(global_config.bot.qq_account)
|
||||||
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
||||||
and len(who_chat_in_group) < 5
|
and len(who_chat_in_group) < 5
|
||||||
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ async def find_messages(
|
|||||||
query = query.where(*conditions)
|
query = query.where(*conditions)
|
||||||
|
|
||||||
if filter_bot:
|
if filter_bot:
|
||||||
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
query = query.where(Messages.user_id != str(global_config.bot.qq_account))
|
||||||
|
|
||||||
if filter_command:
|
if filter_command:
|
||||||
query = query.where(not_(Messages.is_command))
|
query = query.where(not_(Messages.is_command))
|
||||||
|
|||||||
@@ -462,6 +462,8 @@ class PersonInfoManager:
|
|||||||
若未命中则查询数据库并更新缓存。
|
若未命中则查询数据库并更新缓存。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
if person_name == f"{global_config.bot.nickname}(你)":
|
||||||
|
return "SELF"
|
||||||
# 优先使用内存缓存加速查找:self.person_name_list maps person_id -> person_name
|
# 优先使用内存缓存加速查找:self.person_name_list maps person_id -> person_name
|
||||||
for pid, pname in self.person_name_list.items():
|
for pid, pname in self.person_name_list.items():
|
||||||
if pname == person_name:
|
if pname == person_name:
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ class RelationshipManager:
|
|||||||
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
|
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
|
||||||
|
|
||||||
# 跳过机器人自己
|
# 跳过机器人自己
|
||||||
if replace_user_id == global_config.bot.qq_account:
|
if replace_user_id == str(global_config.bot.qq_account):
|
||||||
name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}"
|
name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ class StreamToolHistoryManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
|
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
|
||||||
|
|
||||||
async def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
|
def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
|
||||||
"""获取最近的历史记录
|
"""获取最近的历史记录
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import time
|
import time
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
@@ -16,6 +17,26 @@ from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_strea
|
|||||||
logger = get_logger("tool_use")
|
logger = get_logger("tool_use")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolExecutionConfig:
|
||||||
|
"""工具执行配置"""
|
||||||
|
enable_parallel: bool = True # 是否启用并行执行
|
||||||
|
max_concurrent_tools: int = 5 # 最大并发工具数量
|
||||||
|
tool_timeout: float = 60.0 # 单个工具超时时间(秒)
|
||||||
|
enable_dependency_check: bool = True # 是否启用依赖检查
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolExecutionResult:
|
||||||
|
"""工具执行结果"""
|
||||||
|
tool_call: ToolCall
|
||||||
|
result: dict[str, Any] | None
|
||||||
|
error: Exception | None = None
|
||||||
|
execution_time: float = 0.0
|
||||||
|
is_timeout: bool = False
|
||||||
|
original_index: int = 0 # 原始索引,用于保持结果顺序
|
||||||
|
|
||||||
|
|
||||||
def init_tool_executor_prompt():
|
def init_tool_executor_prompt():
|
||||||
"""初始化工具执行器的提示词"""
|
"""初始化工具执行器的提示词"""
|
||||||
tool_executor_prompt = """
|
tool_executor_prompt = """
|
||||||
@@ -75,16 +96,19 @@ class ToolExecutor:
|
|||||||
"""独立的工具执行器组件
|
"""独立的工具执行器组件
|
||||||
|
|
||||||
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
|
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||||
|
支持并发执行多个工具,提升执行效率。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, chat_id: str):
|
def __init__(self, chat_id: str, execution_config: ToolExecutionConfig | None = None):
|
||||||
"""初始化工具执行器
|
"""初始化工具执行器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
executor_id: 执行器标识符,用于日志记录
|
|
||||||
chat_id: 聊天标识符,用于日志记录
|
chat_id: 聊天标识符,用于日志记录
|
||||||
|
execution_config: 工具执行配置,如果不提供则使用默认配置
|
||||||
"""
|
"""
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
|
self.execution_config = execution_config or ToolExecutionConfig()
|
||||||
|
|
||||||
# chat_stream 和 log_prefix 将在异步方法中初始化
|
# chat_stream 和 log_prefix 将在异步方法中初始化
|
||||||
self.chat_stream = None # type: ignore
|
self.chat_stream = None # type: ignore
|
||||||
self.log_prefix = f"[{chat_id}]"
|
self.log_prefix = f"[{chat_id}]"
|
||||||
@@ -199,7 +223,7 @@ class ToolExecutor:
|
|||||||
|
|
||||||
|
|
||||||
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
|
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
|
||||||
"""执行工具调用
|
"""执行工具调用,支持并发执行
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_calls: LLM返回的工具调用列表
|
tool_calls: LLM返回的工具调用列表
|
||||||
@@ -216,70 +240,50 @@ class ToolExecutor:
|
|||||||
|
|
||||||
# 提取tool_calls中的函数名称
|
# 提取tool_calls中的函数名称
|
||||||
func_names = []
|
func_names = []
|
||||||
for call in tool_calls:
|
valid_tool_calls = []
|
||||||
|
for i, call in enumerate(tool_calls):
|
||||||
try:
|
try:
|
||||||
if hasattr(call, "func_name"):
|
if hasattr(call, "func_name"):
|
||||||
func_names.append(call.func_name)
|
func_names.append(call.func_name)
|
||||||
|
valid_tool_calls.append(call)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix}获取工具名称失败: {e}")
|
logger.error(f"{self.log_prefix}获取工具名称失败: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if func_names:
|
if not valid_tool_calls:
|
||||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"{self.log_prefix}未找到有效的工具调用")
|
logger.warning(f"{self.log_prefix}未找到有效的工具调用")
|
||||||
|
return [], []
|
||||||
|
|
||||||
# 执行每个工具调用
|
if func_names:
|
||||||
for tool_call in tool_calls:
|
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names} (模式: {'并发' if self.execution_config.enable_parallel else '串行'})")
|
||||||
tool_name = getattr(tool_call, "func_name", "unknown_tool")
|
|
||||||
tool_args = getattr(tool_call, "args", {})
|
|
||||||
try:
|
|
||||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
|
||||||
|
|
||||||
# 执行工具
|
# 选择执行模式
|
||||||
result = await self.execute_tool_call(tool_call)
|
if self.execution_config.enable_parallel and len(valid_tool_calls) > 1:
|
||||||
|
# 并发执行模式
|
||||||
|
execution_results = await self._execute_tools_concurrently(valid_tool_calls)
|
||||||
|
else:
|
||||||
|
# 串行执行模式(保持原有逻辑)
|
||||||
|
execution_results = await self._execute_tools_sequentially(valid_tool_calls)
|
||||||
|
|
||||||
if result:
|
# 处理执行结果,保持原始顺序
|
||||||
tool_info = {
|
execution_results.sort(key=lambda x: x.original_index)
|
||||||
"type": result.get("type", "unknown_type"),
|
|
||||||
"id": result.get("id", f"tool_exec_{time.time()}"),
|
|
||||||
"content": result.get("content", ""),
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"timestamp": time.time(),
|
|
||||||
}
|
|
||||||
content = tool_info["content"]
|
|
||||||
if not isinstance(content, str | list | tuple):
|
|
||||||
tool_info["content"] = str(content)
|
|
||||||
|
|
||||||
tool_results.append(tool_info)
|
for exec_result in execution_results:
|
||||||
used_tools.append(tool_name)
|
tool_name = getattr(exec_result.tool_call, "func_name", "unknown_tool")
|
||||||
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
|
tool_args = getattr(exec_result.tool_call, "args", {})
|
||||||
preview = content[:200] if isinstance(content, str) else str(content)[:200]
|
|
||||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
|
||||||
|
|
||||||
# 记录到历史
|
if exec_result.error:
|
||||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
# 处理错误结果
|
||||||
tool_name=tool_name,
|
error_msg = f"工具{tool_name}执行失败"
|
||||||
args=tool_args,
|
if exec_result.is_timeout:
|
||||||
result=result,
|
error_msg += f" (超时: {self.execution_config.tool_timeout}s)"
|
||||||
status="success"
|
error_msg += f": {exec_result.error!s}"
|
||||||
))
|
|
||||||
else:
|
|
||||||
# 工具返回空结果也记录到历史
|
|
||||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
|
||||||
tool_name=tool_name,
|
|
||||||
args=tool_args,
|
|
||||||
result=None,
|
|
||||||
status="success"
|
|
||||||
))
|
|
||||||
|
|
||||||
except Exception as e:
|
logger.error(f"{self.log_prefix}{error_msg}")
|
||||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
|
||||||
# 添加错误信息到结果中
|
|
||||||
error_info = {
|
error_info = {
|
||||||
"type": "tool_error",
|
"type": "tool_error",
|
||||||
"id": f"tool_error_{time.time()}",
|
"id": f"tool_error_{time.time()}",
|
||||||
"content": f"工具{tool_name}执行失败: {e!s}",
|
"content": error_msg,
|
||||||
"tool_name": tool_name,
|
"tool_name": tool_name,
|
||||||
"timestamp": time.time(),
|
"timestamp": time.time(),
|
||||||
}
|
}
|
||||||
@@ -290,12 +294,188 @@ class ToolExecutor:
|
|||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
args=tool_args,
|
args=tool_args,
|
||||||
result=None,
|
result=None,
|
||||||
status="error",
|
status="error" if not exec_result.is_timeout else "timeout",
|
||||||
error_message=str(e)
|
error_message=str(exec_result.error),
|
||||||
|
execution_time=exec_result.execution_time
|
||||||
|
))
|
||||||
|
elif exec_result.result:
|
||||||
|
# 处理成功结果
|
||||||
|
tool_info = {
|
||||||
|
"type": exec_result.result.get("type", "unknown_type"),
|
||||||
|
"id": exec_result.result.get("id", f"tool_exec_{time.time()}"),
|
||||||
|
"content": exec_result.result.get("content", ""),
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
}
|
||||||
|
content = tool_info["content"]
|
||||||
|
if not isinstance(content, str | list | tuple):
|
||||||
|
tool_info["content"] = str(content)
|
||||||
|
|
||||||
|
tool_results.append(tool_info)
|
||||||
|
used_tools.append(tool_name)
|
||||||
|
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}, 耗时: {exec_result.execution_time:.2f}s")
|
||||||
|
preview = content[:200] if isinstance(content, str) else str(content)[:200]
|
||||||
|
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||||
|
|
||||||
|
# 记录到历史
|
||||||
|
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||||
|
tool_name=tool_name,
|
||||||
|
args=tool_args,
|
||||||
|
result=exec_result.result,
|
||||||
|
status="success",
|
||||||
|
execution_time=exec_result.execution_time
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
# 工具返回空结果也记录到历史
|
||||||
|
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||||
|
tool_name=tool_name,
|
||||||
|
args=tool_args,
|
||||||
|
result=None,
|
||||||
|
status="success",
|
||||||
|
execution_time=exec_result.execution_time
|
||||||
))
|
))
|
||||||
|
|
||||||
return tool_results, used_tools
|
return tool_results, used_tools
|
||||||
|
|
||||||
|
async def _execute_tools_concurrently(self, tool_calls: list[ToolCall]) -> list[ToolExecutionResult]:
|
||||||
|
"""并发执行多个工具调用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_calls: 工具调用列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[ToolExecutionResult]: 执行结果列表
|
||||||
|
"""
|
||||||
|
logger.info(f"{self.log_prefix}启动并发执行,工具数量: {len(tool_calls)}, 最大并发数: {self.execution_config.max_concurrent_tools}")
|
||||||
|
|
||||||
|
# 创建信号量控制并发数量
|
||||||
|
semaphore = asyncio.Semaphore(self.execution_config.max_concurrent_tools)
|
||||||
|
|
||||||
|
async def execute_with_semaphore(tool_call: ToolCall, index: int) -> ToolExecutionResult:
|
||||||
|
"""在信号量控制下执行单个工具"""
|
||||||
|
async with semaphore:
|
||||||
|
return await self._execute_single_tool_with_timeout(tool_call, index)
|
||||||
|
|
||||||
|
# 创建所有任务
|
||||||
|
tasks = [
|
||||||
|
execute_with_semaphore(tool_call, i)
|
||||||
|
for i, tool_call in enumerate(tool_calls)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 并发执行所有任务
|
||||||
|
try:
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# 处理异常结果
|
||||||
|
processed_results = []
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error(f"{self.log_prefix}工具执行任务异常: {result}")
|
||||||
|
processed_results.append(ToolExecutionResult(
|
||||||
|
tool_call=tool_calls[i],
|
||||||
|
result=None,
|
||||||
|
error=result,
|
||||||
|
original_index=i
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
processed_results.append(result)
|
||||||
|
|
||||||
|
return processed_results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.log_prefix}并发执行过程中发生异常: {e}")
|
||||||
|
# 返回所有工具的错误结果
|
||||||
|
return [
|
||||||
|
ToolExecutionResult(
|
||||||
|
tool_call=tool_call,
|
||||||
|
result=None,
|
||||||
|
error=e,
|
||||||
|
original_index=i
|
||||||
|
)
|
||||||
|
for i, tool_call in enumerate(tool_calls)
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _execute_tools_sequentially(self, tool_calls: list[ToolCall]) -> list[ToolExecutionResult]:
|
||||||
|
"""串行执行多个工具调用(保持原有逻辑)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_calls: 工具调用列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[ToolExecutionResult]: 执行结果列表
|
||||||
|
"""
|
||||||
|
logger.info(f"{self.log_prefix}启动串行执行,工具数量: {len(tool_calls)}")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i, tool_call in enumerate(tool_calls):
|
||||||
|
result = await self._execute_single_tool_with_timeout(tool_call, i)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _execute_single_tool_with_timeout(self, tool_call: ToolCall, index: int) -> ToolExecutionResult:
|
||||||
|
"""执行单个工具调用,支持超时控制
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_call: 工具调用
|
||||||
|
index: 原始索引
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ToolExecutionResult: 执行结果
|
||||||
|
"""
|
||||||
|
tool_name = getattr(tool_call, "func_name", "unknown_tool")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug(f"{self.log_prefix}开始执行工具: {tool_name}")
|
||||||
|
|
||||||
|
# 使用 asyncio.wait_for 实现超时控制
|
||||||
|
if self.execution_config.tool_timeout > 0:
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
self.execute_tool_call(tool_call),
|
||||||
|
timeout=self.execution_config.tool_timeout
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = await self.execute_tool_call(tool_call)
|
||||||
|
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
logger.debug(f"{self.log_prefix}工具 {tool_name} 执行完成,耗时: {execution_time:.2f}s")
|
||||||
|
|
||||||
|
return ToolExecutionResult(
|
||||||
|
tool_call=tool_call,
|
||||||
|
result=result,
|
||||||
|
error=None,
|
||||||
|
execution_time=execution_time,
|
||||||
|
is_timeout=False,
|
||||||
|
original_index=index
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
logger.warning(f"{self.log_prefix}工具 {tool_name} 执行超时 ({self.execution_config.tool_timeout}s)")
|
||||||
|
|
||||||
|
return ToolExecutionResult(
|
||||||
|
tool_call=tool_call,
|
||||||
|
result=None,
|
||||||
|
error=asyncio.TimeoutError(f"工具执行超时 ({self.execution_config.tool_timeout}s)"),
|
||||||
|
execution_time=execution_time,
|
||||||
|
is_timeout=True,
|
||||||
|
original_index=index
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
logger.error(f"{self.log_prefix}工具 {tool_name} 执行失败: {e}")
|
||||||
|
|
||||||
|
return ToolExecutionResult(
|
||||||
|
tool_call=tool_call,
|
||||||
|
result=None,
|
||||||
|
error=e,
|
||||||
|
execution_time=execution_time,
|
||||||
|
is_timeout=False,
|
||||||
|
original_index=index
|
||||||
|
)
|
||||||
|
|
||||||
async def execute_tool_call(
|
async def execute_tool_call(
|
||||||
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
|
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
@@ -529,6 +709,59 @@ class ToolExecutor:
|
|||||||
"""
|
"""
|
||||||
return self.history_manager.get_stats()
|
return self.history_manager.get_stats()
|
||||||
|
|
||||||
|
def set_execution_config(self, config: ToolExecutionConfig) -> None:
|
||||||
|
"""设置工具执行配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: 新的执行配置
|
||||||
|
"""
|
||||||
|
self.execution_config = config
|
||||||
|
logger.info(f"{self.log_prefix}工具执行配置已更新: 并发={config.enable_parallel}, 最大并发数={config.max_concurrent_tools}, 超时={config.tool_timeout}s")
|
||||||
|
|
||||||
|
def enable_parallel_execution(self, max_concurrent_tools: int = 5, timeout: float = 60.0) -> None:
|
||||||
|
"""启用并发执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_concurrent_tools: 最大并发工具数量
|
||||||
|
timeout: 单个工具超时时间(秒)
|
||||||
|
"""
|
||||||
|
self.execution_config.enable_parallel = True
|
||||||
|
self.execution_config.max_concurrent_tools = max_concurrent_tools
|
||||||
|
self.execution_config.tool_timeout = timeout
|
||||||
|
logger.info(f"{self.log_prefix}已启用并发执行: 最大并发数={max_concurrent_tools}, 超时={timeout}s")
|
||||||
|
|
||||||
|
def disable_parallel_execution(self) -> None:
|
||||||
|
"""禁用并发执行,使用串行模式"""
|
||||||
|
self.execution_config.enable_parallel = False
|
||||||
|
logger.info(f"{self.log_prefix}已禁用并发执行,使用串行模式")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_with_parallel_config(
|
||||||
|
cls,
|
||||||
|
chat_id: str,
|
||||||
|
max_concurrent_tools: int = 5,
|
||||||
|
tool_timeout: float = 60.0,
|
||||||
|
enable_dependency_check: bool = True
|
||||||
|
) -> "ToolExecutor":
|
||||||
|
"""创建支持并发执行的工具执行器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天标识符
|
||||||
|
max_concurrent_tools: 最大并发工具数量
|
||||||
|
tool_timeout: 单个工具超时时间(秒)
|
||||||
|
enable_dependency_check: 是否启用依赖检查
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好并发执行的ToolExecutor实例
|
||||||
|
"""
|
||||||
|
config = ToolExecutionConfig(
|
||||||
|
enable_parallel=True,
|
||||||
|
max_concurrent_tools=max_concurrent_tools,
|
||||||
|
tool_timeout=tool_timeout,
|
||||||
|
enable_dependency_check=enable_dependency_check
|
||||||
|
)
|
||||||
|
return cls(chat_id, config)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ToolExecutor使用示例:
|
ToolExecutor使用示例:
|
||||||
@@ -541,7 +774,25 @@ results, _, _ = await executor.execute_from_chat_message(
|
|||||||
sender="用户"
|
sender="用户"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 获取详细信息
|
# 2. 并发执行配置 - 创建支持并发的执行器
|
||||||
|
parallel_executor = ToolExecutor.create_with_parallel_config(
|
||||||
|
chat_id=my_chat_id,
|
||||||
|
max_concurrent_tools=3, # 最大3个工具并发
|
||||||
|
tool_timeout=30.0 # 单个工具30秒超时
|
||||||
|
)
|
||||||
|
|
||||||
|
# 或者动态配置并发执行
|
||||||
|
executor.enable_parallel_execution(max_concurrent_tools=5, timeout=60.0)
|
||||||
|
|
||||||
|
# 3. 并发执行多个工具 - 当LLM返回多个工具调用时自动并发执行
|
||||||
|
results, used_tools, _ = await parallel_executor.execute_from_chat_message(
|
||||||
|
target_message="帮我查询天气、新闻和股票价格",
|
||||||
|
chat_history="",
|
||||||
|
sender="用户"
|
||||||
|
)
|
||||||
|
# 多个工具将并发执行,显著提升性能
|
||||||
|
|
||||||
|
# 4. 获取详细信息
|
||||||
results, used_tools, prompt = await executor.execute_from_chat_message(
|
results, used_tools, prompt = await executor.execute_from_chat_message(
|
||||||
target_message="帮我查询Python相关知识",
|
target_message="帮我查询Python相关知识",
|
||||||
chat_history="",
|
chat_history="",
|
||||||
@@ -549,13 +800,13 @@ results, used_tools, prompt = await executor.execute_from_chat_message(
|
|||||||
return_details=True
|
return_details=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 直接执行特定工具
|
# 5. 直接执行特定工具
|
||||||
result = await executor.execute_specific_tool_simple(
|
result = await executor.execute_specific_tool_simple(
|
||||||
tool_name="get_knowledge",
|
tool_name="get_knowledge",
|
||||||
tool_args={"query": "机器学习"}
|
tool_args={"query": "机器学习"}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 使用工具历史 - 连续对话中的工具调用
|
# 6. 使用工具历史 - 连续对话中的工具调用
|
||||||
# 第一次调用
|
# 第一次调用
|
||||||
await executor.execute_from_chat_message(
|
await executor.execute_from_chat_message(
|
||||||
target_message="查询今天的天气",
|
target_message="查询今天的天气",
|
||||||
@@ -569,7 +820,27 @@ await executor.execute_from_chat_message(
|
|||||||
sender="用户"
|
sender="用户"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. 获取和清除历史
|
# 7. 配置管理
|
||||||
|
config = ToolExecutionConfig(
|
||||||
|
enable_parallel=True,
|
||||||
|
max_concurrent_tools=10,
|
||||||
|
tool_timeout=120.0,
|
||||||
|
enable_dependency_check=True
|
||||||
|
)
|
||||||
|
executor.set_execution_config(config)
|
||||||
|
|
||||||
|
# 8. 获取和清除历史
|
||||||
history = executor.get_tool_history() # 获取历史记录
|
history = executor.get_tool_history() # 获取历史记录
|
||||||
|
stats = executor.get_tool_stats() # 获取执行统计信息
|
||||||
executor.clear_tool_history() # 清除历史记录
|
executor.clear_tool_history() # 清除历史记录
|
||||||
|
|
||||||
|
# 9. 禁用并发执行(如需要串行执行)
|
||||||
|
executor.disable_parallel_execution()
|
||||||
|
|
||||||
|
并发执行优势:
|
||||||
|
- 🚀 性能提升:多个工具同时执行,减少总体等待时间
|
||||||
|
- 🛡️ 错误隔离:单个工具失败不影响其他工具执行
|
||||||
|
- ⏱️ 超时控制:防止单个工具无限等待
|
||||||
|
- 🔧 灵活配置:可根据需要调整并发数量和超时时间
|
||||||
|
- 📊 统计信息:提供详细的执行时间和性能数据
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user