refactor: 移除对机器人自身用户ID的特殊处理,统一使用QQ号进行比较
This commit is contained in:
@@ -280,8 +280,6 @@ class MessageStorageBatcher:
|
||||
user_platform = user_info_dict.get("platform")
|
||||
user_id = user_info_dict.get("user_id")
|
||||
# 将机器人自己的user_id标记为"SELF",增强对自我身份的识别
|
||||
if user_id == global_config.bot.qq_account:
|
||||
user_id = "SELF"
|
||||
user_nickname = user_info_dict.get("user_nickname")
|
||||
user_cardname = user_info_dict.get("user_cardname")
|
||||
|
||||
@@ -630,9 +628,6 @@ class MessageStorage:
|
||||
|
||||
user_platform = user_info_dict.get("platform")
|
||||
user_id = user_info_dict.get("user_id")
|
||||
# 将机器人自己的user_id标记为"SELF",增强对自我身份的识别
|
||||
if user_id == global_config.bot.qq_account:
|
||||
user_id = "SELF"
|
||||
user_nickname = user_info_dict.get("user_nickname")
|
||||
user_cardname = user_info_dict.get("user_cardname")
|
||||
|
||||
|
||||
@@ -1872,6 +1872,8 @@ class DefaultReplyer:
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = await person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id == "SELF":
|
||||
return f"你将要回复的是你自己发送的消息。"
|
||||
if not person_id:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
@@ -44,8 +44,8 @@ def replace_user_references_sync(
|
||||
|
||||
if name_resolver is None:
|
||||
def default_resolver(platform: str, user_id: str) -> str:
|
||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
||||
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and (user_id == str(global_config.bot.qq_account)):
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
# 同步函数中无法使用异步的 get_value,直接返回 user_id
|
||||
# 建议调用方使用 replace_user_references_async 以获取完整的用户名
|
||||
@@ -61,7 +61,7 @@ def replace_user_references_sync(
|
||||
bbb = match[2]
|
||||
try:
|
||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
||||
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
|
||||
if replace_bot_name and (bbb == str(global_config.bot.qq_account)):
|
||||
reply_person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
reply_person_name = name_resolver(platform, bbb) or aaa
|
||||
@@ -81,8 +81,8 @@ def replace_user_references_sync(
|
||||
aaa = m.group(1)
|
||||
bbb = m.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
||||
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and (bbb == str(global_config.bot.qq_account)):
|
||||
at_person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
at_person_name = name_resolver(platform, bbb) or aaa
|
||||
@@ -120,8 +120,8 @@ async def replace_user_references_async(
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
async def default_resolver(platform: str, user_id: str) -> str:
|
||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
||||
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and (user_id == str(global_config.bot.qq_account)):
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
|
||||
@@ -135,8 +135,8 @@ async def replace_user_references_async(
|
||||
aaa = match.group(1)
|
||||
bbb = match.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
||||
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and (bbb == str(global_config.bot.qq_account)):
|
||||
reply_person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
reply_person_name = await name_resolver(platform, bbb) or aaa
|
||||
@@ -156,8 +156,8 @@ async def replace_user_references_async(
|
||||
aaa = m.group(1)
|
||||
bbb = m.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
||||
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and (bbb == str(global_config.bot.qq_account)):
|
||||
at_person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
at_person_name = await name_resolver(platform, bbb) or aaa
|
||||
@@ -641,7 +641,7 @@ async def _build_readable_messages_internal(
|
||||
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
||||
person_name: str
|
||||
# 检查是否是机器人自己(支持SELF标记或直接比对QQ号)
|
||||
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
|
||||
if replace_bot_name and user_id == str(global_config.bot.qq_account):
|
||||
person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
@@ -657,8 +657,8 @@ async def _build_readable_messages_internal(
|
||||
else:
|
||||
person_name = "某人"
|
||||
|
||||
# 在用户名后面添加 QQ 号, 但机器人本体不用(包括SELF标记)
|
||||
if user_id != global_config.bot.qq_account and user_id != "SELF":
|
||||
# 在用户名后面添加 QQ 号, 但机器人本体不用
|
||||
if user_id != str(global_config.bot.qq_account):
|
||||
person_name = f"{person_name}({user_id})"
|
||||
|
||||
# 使用独立函数处理用户引用格式
|
||||
@@ -1022,7 +1022,7 @@ async def build_readable_messages(
|
||||
actions = [
|
||||
{
|
||||
"time": a.time,
|
||||
"user_id": global_config.bot.qq_account,
|
||||
"user_id": str(global_config.bot.qq_account),
|
||||
"user_nickname": global_config.bot.nickname,
|
||||
"user_cardname": "",
|
||||
"processed_plain_text": f"{a.action_prompt_display}",
|
||||
|
||||
@@ -166,7 +166,7 @@ async def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12)
|
||||
)
|
||||
if (
|
||||
(user_info.platform, user_info.user_id) != sender
|
||||
and user_info.user_id != global_config.bot.qq_account
|
||||
and user_info.user_id != str(global_config.bot.qq_account)
|
||||
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
||||
and len(who_chat_in_group) < 5
|
||||
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||
|
||||
@@ -89,7 +89,7 @@ async def find_messages(
|
||||
query = query.where(*conditions)
|
||||
|
||||
if filter_bot:
|
||||
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
||||
query = query.where(Messages.user_id != str(global_config.bot.qq_account))
|
||||
|
||||
if filter_command:
|
||||
query = query.where(not_(Messages.is_command))
|
||||
|
||||
@@ -132,6 +132,8 @@ class PersonInfoManager:
|
||||
若未命中则查询数据库并更新缓存。
|
||||
"""
|
||||
try:
|
||||
if person_name == f"{global_config.bot.nickname}(你)":
|
||||
return "SELF"
|
||||
# 优先使用内存缓存加速查找:self.person_name_list maps person_id -> person_name
|
||||
for pid, pname in self.person_name_list.items():
|
||||
if pname == person_name:
|
||||
|
||||
@@ -104,7 +104,7 @@ class RelationshipManager:
|
||||
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
|
||||
|
||||
# 跳过机器人自己
|
||||
if replace_user_id == global_config.bot.qq_account:
|
||||
if replace_user_id == str(global_config.bot.qq_account):
|
||||
name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}"
|
||||
continue
|
||||
|
||||
|
||||
@@ -209,7 +209,7 @@ class StreamToolHistoryManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
|
||||
|
||||
async def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
|
||||
def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
|
||||
"""获取最近的历史记录
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
@@ -16,6 +17,26 @@ from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_strea
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolExecutionConfig:
|
||||
"""工具执行配置"""
|
||||
enable_parallel: bool = True # 是否启用并行执行
|
||||
max_concurrent_tools: int = 5 # 最大并发工具数量
|
||||
tool_timeout: float = 60.0 # 单个工具超时时间(秒)
|
||||
enable_dependency_check: bool = True # 是否启用依赖检查
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolExecutionResult:
|
||||
"""工具执行结果"""
|
||||
tool_call: ToolCall
|
||||
result: dict[str, Any] | None
|
||||
error: Exception | None = None
|
||||
execution_time: float = 0.0
|
||||
is_timeout: bool = False
|
||||
original_index: int = 0 # 原始索引,用于保持结果顺序
|
||||
|
||||
|
||||
def init_tool_executor_prompt():
|
||||
"""初始化工具执行器的提示词"""
|
||||
tool_executor_prompt = """
|
||||
@@ -75,16 +96,19 @@ class ToolExecutor:
|
||||
"""独立的工具执行器组件
|
||||
|
||||
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
支持并发执行多个工具,提升执行效率。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
def __init__(self, chat_id: str, execution_config: ToolExecutionConfig | None = None):
|
||||
"""初始化工具执行器
|
||||
|
||||
Args:
|
||||
executor_id: 执行器标识符,用于日志记录
|
||||
chat_id: 聊天标识符,用于日志记录
|
||||
execution_config: 工具执行配置,如果不提供则使用默认配置
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.execution_config = execution_config or ToolExecutionConfig()
|
||||
|
||||
# chat_stream 和 log_prefix 将在异步方法中初始化
|
||||
self.chat_stream = None # type: ignore
|
||||
self.log_prefix = f"[{chat_id}]"
|
||||
@@ -199,7 +223,7 @@ class ToolExecutor:
|
||||
|
||||
|
||||
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
|
||||
"""执行工具调用
|
||||
"""执行工具调用,支持并发执行
|
||||
|
||||
Args:
|
||||
tool_calls: LLM返回的工具调用列表
|
||||
@@ -216,70 +240,50 @@ class ToolExecutor:
|
||||
|
||||
# 提取tool_calls中的函数名称
|
||||
func_names = []
|
||||
for call in tool_calls:
|
||||
valid_tool_calls = []
|
||||
for i, call in enumerate(tool_calls):
|
||||
try:
|
||||
if hasattr(call, "func_name"):
|
||||
func_names.append(call.func_name)
|
||||
valid_tool_calls.append(call)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}获取工具名称失败: {e}")
|
||||
continue
|
||||
|
||||
if func_names:
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
||||
else:
|
||||
if not valid_tool_calls:
|
||||
logger.warning(f"{self.log_prefix}未找到有效的工具调用")
|
||||
return [], []
|
||||
|
||||
# 执行每个工具调用
|
||||
for tool_call in tool_calls:
|
||||
tool_name = getattr(tool_call, "func_name", "unknown_tool")
|
||||
tool_args = getattr(tool_call, "args", {})
|
||||
try:
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
if func_names:
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names} (模式: {'并发' if self.execution_config.enable_parallel else '串行'})")
|
||||
|
||||
# 执行工具
|
||||
result = await self.execute_tool_call(tool_call)
|
||||
# 选择执行模式
|
||||
if self.execution_config.enable_parallel and len(valid_tool_calls) > 1:
|
||||
# 并发执行模式
|
||||
execution_results = await self._execute_tools_concurrently(valid_tool_calls)
|
||||
else:
|
||||
# 串行执行模式(保持原有逻辑)
|
||||
execution_results = await self._execute_tools_sequentially(valid_tool_calls)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"tool_exec_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
content = tool_info["content"]
|
||||
if not isinstance(content, str | list | tuple):
|
||||
tool_info["content"] = str(content)
|
||||
# 处理执行结果,保持原始顺序
|
||||
execution_results.sort(key=lambda x: x.original_index)
|
||||
|
||||
tool_results.append(tool_info)
|
||||
used_tools.append(tool_name)
|
||||
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
|
||||
preview = content[:200] if isinstance(content, str) else str(content)[:200]
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||
for exec_result in execution_results:
|
||||
tool_name = getattr(exec_result.tool_call, "func_name", "unknown_tool")
|
||||
tool_args = getattr(exec_result.tool_call, "args", {})
|
||||
|
||||
# 记录到历史
|
||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||
tool_name=tool_name,
|
||||
args=tool_args,
|
||||
result=result,
|
||||
status="success"
|
||||
))
|
||||
else:
|
||||
# 工具返回空结果也记录到历史
|
||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||
tool_name=tool_name,
|
||||
args=tool_args,
|
||||
result=None,
|
||||
status="success"
|
||||
))
|
||||
if exec_result.error:
|
||||
# 处理错误结果
|
||||
error_msg = f"工具{tool_name}执行失败"
|
||||
if exec_result.is_timeout:
|
||||
error_msg += f" (超时: {self.execution_config.tool_timeout}s)"
|
||||
error_msg += f": {exec_result.error!s}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||
# 添加错误信息到结果中
|
||||
logger.error(f"{self.log_prefix}{error_msg}")
|
||||
error_info = {
|
||||
"type": "tool_error",
|
||||
"id": f"tool_error_{time.time()}",
|
||||
"content": f"工具{tool_name}执行失败: {e!s}",
|
||||
"content": error_msg,
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
@@ -290,12 +294,188 @@ class ToolExecutor:
|
||||
tool_name=tool_name,
|
||||
args=tool_args,
|
||||
result=None,
|
||||
status="error",
|
||||
error_message=str(e)
|
||||
status="error" if not exec_result.is_timeout else "timeout",
|
||||
error_message=str(exec_result.error),
|
||||
execution_time=exec_result.execution_time
|
||||
))
|
||||
elif exec_result.result:
|
||||
# 处理成功结果
|
||||
tool_info = {
|
||||
"type": exec_result.result.get("type", "unknown_type"),
|
||||
"id": exec_result.result.get("id", f"tool_exec_{time.time()}"),
|
||||
"content": exec_result.result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
content = tool_info["content"]
|
||||
if not isinstance(content, str | list | tuple):
|
||||
tool_info["content"] = str(content)
|
||||
|
||||
tool_results.append(tool_info)
|
||||
used_tools.append(tool_name)
|
||||
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}, 耗时: {exec_result.execution_time:.2f}s")
|
||||
preview = content[:200] if isinstance(content, str) else str(content)[:200]
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||
|
||||
# 记录到历史
|
||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||
tool_name=tool_name,
|
||||
args=tool_args,
|
||||
result=exec_result.result,
|
||||
status="success",
|
||||
execution_time=exec_result.execution_time
|
||||
))
|
||||
else:
|
||||
# 工具返回空结果也记录到历史
|
||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||
tool_name=tool_name,
|
||||
args=tool_args,
|
||||
result=None,
|
||||
status="success",
|
||||
execution_time=exec_result.execution_time
|
||||
))
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def _execute_tools_concurrently(self, tool_calls: list[ToolCall]) -> list[ToolExecutionResult]:
|
||||
"""并发执行多个工具调用
|
||||
|
||||
Args:
|
||||
tool_calls: 工具调用列表
|
||||
|
||||
Returns:
|
||||
List[ToolExecutionResult]: 执行结果列表
|
||||
"""
|
||||
logger.info(f"{self.log_prefix}启动并发执行,工具数量: {len(tool_calls)}, 最大并发数: {self.execution_config.max_concurrent_tools}")
|
||||
|
||||
# 创建信号量控制并发数量
|
||||
semaphore = asyncio.Semaphore(self.execution_config.max_concurrent_tools)
|
||||
|
||||
async def execute_with_semaphore(tool_call: ToolCall, index: int) -> ToolExecutionResult:
|
||||
"""在信号量控制下执行单个工具"""
|
||||
async with semaphore:
|
||||
return await self._execute_single_tool_with_timeout(tool_call, index)
|
||||
|
||||
# 创建所有任务
|
||||
tasks = [
|
||||
execute_with_semaphore(tool_call, i)
|
||||
for i, tool_call in enumerate(tool_calls)
|
||||
]
|
||||
|
||||
# 并发执行所有任务
|
||||
try:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理异常结果
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"{self.log_prefix}工具执行任务异常: {result}")
|
||||
processed_results.append(ToolExecutionResult(
|
||||
tool_call=tool_calls[i],
|
||||
result=None,
|
||||
error=result,
|
||||
original_index=i
|
||||
))
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
return processed_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}并发执行过程中发生异常: {e}")
|
||||
# 返回所有工具的错误结果
|
||||
return [
|
||||
ToolExecutionResult(
|
||||
tool_call=tool_call,
|
||||
result=None,
|
||||
error=e,
|
||||
original_index=i
|
||||
)
|
||||
for i, tool_call in enumerate(tool_calls)
|
||||
]
|
||||
|
||||
async def _execute_tools_sequentially(self, tool_calls: list[ToolCall]) -> list[ToolExecutionResult]:
|
||||
"""串行执行多个工具调用(保持原有逻辑)
|
||||
|
||||
Args:
|
||||
tool_calls: 工具调用列表
|
||||
|
||||
Returns:
|
||||
List[ToolExecutionResult]: 执行结果列表
|
||||
"""
|
||||
logger.info(f"{self.log_prefix}启动串行执行,工具数量: {len(tool_calls)}")
|
||||
|
||||
results = []
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
result = await self._execute_single_tool_with_timeout(tool_call, i)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
async def _execute_single_tool_with_timeout(self, tool_call: ToolCall, index: int) -> ToolExecutionResult:
|
||||
"""执行单个工具调用,支持超时控制
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用
|
||||
index: 原始索引
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 执行结果
|
||||
"""
|
||||
tool_name = getattr(tool_call, "func_name", "unknown_tool")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
logger.debug(f"{self.log_prefix}开始执行工具: {tool_name}")
|
||||
|
||||
# 使用 asyncio.wait_for 实现超时控制
|
||||
if self.execution_config.tool_timeout > 0:
|
||||
result = await asyncio.wait_for(
|
||||
self.execute_tool_call(tool_call),
|
||||
timeout=self.execution_config.tool_timeout
|
||||
)
|
||||
else:
|
||||
result = await self.execute_tool_call(tool_call)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
logger.debug(f"{self.log_prefix}工具 {tool_name} 执行完成,耗时: {execution_time:.2f}s")
|
||||
|
||||
return ToolExecutionResult(
|
||||
tool_call=tool_call,
|
||||
result=result,
|
||||
error=None,
|
||||
execution_time=execution_time,
|
||||
is_timeout=False,
|
||||
original_index=index
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
execution_time = time.time() - start_time
|
||||
logger.warning(f"{self.log_prefix}工具 {tool_name} 执行超时 ({self.execution_config.tool_timeout}s)")
|
||||
|
||||
return ToolExecutionResult(
|
||||
tool_call=tool_call,
|
||||
result=None,
|
||||
error=asyncio.TimeoutError(f"工具执行超时 ({self.execution_config.tool_timeout}s)"),
|
||||
execution_time=execution_time,
|
||||
is_timeout=True,
|
||||
original_index=index
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
logger.error(f"{self.log_prefix}工具 {tool_name} 执行失败: {e}")
|
||||
|
||||
return ToolExecutionResult(
|
||||
tool_call=tool_call,
|
||||
result=None,
|
||||
error=e,
|
||||
execution_time=execution_time,
|
||||
is_timeout=False,
|
||||
original_index=index
|
||||
)
|
||||
|
||||
async def execute_tool_call(
|
||||
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
@@ -529,6 +709,59 @@ class ToolExecutor:
|
||||
"""
|
||||
return self.history_manager.get_stats()
|
||||
|
||||
def set_execution_config(self, config: ToolExecutionConfig) -> None:
|
||||
"""设置工具执行配置
|
||||
|
||||
Args:
|
||||
config: 新的执行配置
|
||||
"""
|
||||
self.execution_config = config
|
||||
logger.info(f"{self.log_prefix}工具执行配置已更新: 并发={config.enable_parallel}, 最大并发数={config.max_concurrent_tools}, 超时={config.tool_timeout}s")
|
||||
|
||||
def enable_parallel_execution(self, max_concurrent_tools: int = 5, timeout: float = 60.0) -> None:
|
||||
"""启用并发执行
|
||||
|
||||
Args:
|
||||
max_concurrent_tools: 最大并发工具数量
|
||||
timeout: 单个工具超时时间(秒)
|
||||
"""
|
||||
self.execution_config.enable_parallel = True
|
||||
self.execution_config.max_concurrent_tools = max_concurrent_tools
|
||||
self.execution_config.tool_timeout = timeout
|
||||
logger.info(f"{self.log_prefix}已启用并发执行: 最大并发数={max_concurrent_tools}, 超时={timeout}s")
|
||||
|
||||
def disable_parallel_execution(self) -> None:
|
||||
"""禁用并发执行,使用串行模式"""
|
||||
self.execution_config.enable_parallel = False
|
||||
logger.info(f"{self.log_prefix}已禁用并发执行,使用串行模式")
|
||||
|
||||
@classmethod
|
||||
def create_with_parallel_config(
|
||||
cls,
|
||||
chat_id: str,
|
||||
max_concurrent_tools: int = 5,
|
||||
tool_timeout: float = 60.0,
|
||||
enable_dependency_check: bool = True
|
||||
) -> "ToolExecutor":
|
||||
"""创建支持并发执行的工具执行器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天标识符
|
||||
max_concurrent_tools: 最大并发工具数量
|
||||
tool_timeout: 单个工具超时时间(秒)
|
||||
enable_dependency_check: 是否启用依赖检查
|
||||
|
||||
Returns:
|
||||
配置好并发执行的ToolExecutor实例
|
||||
"""
|
||||
config = ToolExecutionConfig(
|
||||
enable_parallel=True,
|
||||
max_concurrent_tools=max_concurrent_tools,
|
||||
tool_timeout=tool_timeout,
|
||||
enable_dependency_check=enable_dependency_check
|
||||
)
|
||||
return cls(chat_id, config)
|
||||
|
||||
|
||||
"""
|
||||
ToolExecutor使用示例:
|
||||
@@ -541,7 +774,25 @@ results, _, _ = await executor.execute_from_chat_message(
|
||||
sender="用户"
|
||||
)
|
||||
|
||||
# 2. 获取详细信息
|
||||
# 2. 并发执行配置 - 创建支持并发的执行器
|
||||
parallel_executor = ToolExecutor.create_with_parallel_config(
|
||||
chat_id=my_chat_id,
|
||||
max_concurrent_tools=3, # 最大3个工具并发
|
||||
tool_timeout=30.0 # 单个工具30秒超时
|
||||
)
|
||||
|
||||
# 或者动态配置并发执行
|
||||
executor.enable_parallel_execution(max_concurrent_tools=5, timeout=60.0)
|
||||
|
||||
# 3. 并发执行多个工具 - 当LLM返回多个工具调用时自动并发执行
|
||||
results, used_tools, _ = await parallel_executor.execute_from_chat_message(
|
||||
target_message="帮我查询天气、新闻和股票价格",
|
||||
chat_history="",
|
||||
sender="用户"
|
||||
)
|
||||
# 多个工具将并发执行,显著提升性能
|
||||
|
||||
# 4. 获取详细信息
|
||||
results, used_tools, prompt = await executor.execute_from_chat_message(
|
||||
target_message="帮我查询Python相关知识",
|
||||
chat_history="",
|
||||
@@ -549,13 +800,13 @@ results, used_tools, prompt = await executor.execute_from_chat_message(
|
||||
return_details=True
|
||||
)
|
||||
|
||||
# 3. 直接执行特定工具
|
||||
# 5. 直接执行特定工具
|
||||
result = await executor.execute_specific_tool_simple(
|
||||
tool_name="get_knowledge",
|
||||
tool_args={"query": "机器学习"}
|
||||
)
|
||||
|
||||
# 4. 使用工具历史 - 连续对话中的工具调用
|
||||
# 6. 使用工具历史 - 连续对话中的工具调用
|
||||
# 第一次调用
|
||||
await executor.execute_from_chat_message(
|
||||
target_message="查询今天的天气",
|
||||
@@ -569,7 +820,27 @@ await executor.execute_from_chat_message(
|
||||
sender="用户"
|
||||
)
|
||||
|
||||
# 5. 获取和清除历史
|
||||
# 7. 配置管理
|
||||
config = ToolExecutionConfig(
|
||||
enable_parallel=True,
|
||||
max_concurrent_tools=10,
|
||||
tool_timeout=120.0,
|
||||
enable_dependency_check=True
|
||||
)
|
||||
executor.set_execution_config(config)
|
||||
|
||||
# 8. 获取和清除历史
|
||||
history = executor.get_tool_history() # 获取历史记录
|
||||
stats = executor.get_tool_stats() # 获取执行统计信息
|
||||
executor.clear_tool_history() # 清除历史记录
|
||||
|
||||
# 9. 禁用并发执行(如需要串行执行)
|
||||
executor.disable_parallel_execution()
|
||||
|
||||
并发执行优势:
|
||||
- 🚀 性能提升:多个工具同时执行,减少总体等待时间
|
||||
- 🛡️ 错误隔离:单个工具失败不影响其他工具执行
|
||||
- ⏱️ 超时控制:防止单个工具无限等待
|
||||
- 🔧 灵活配置:可根据需要调整并发数量和超时时间
|
||||
- 📊 统计信息:提供详细的执行时间和性能数据
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user