This commit is contained in:
tt-P607
2025-08-28 21:04:05 +08:00
27 changed files with 436 additions and 609 deletions

View File

@@ -10,7 +10,6 @@ from src.chat.express.expression_learner import expression_learner_manager
from src.plugin_system.base.component_types import ChatMode
from src.schedule.schedule_manager import schedule_manager
from src.plugin_system.apis import message_api
from src.mood.mood_manager import mood_manager
from .hfc_context import HfcContext
from .energy_manager import EnergyManager

View File

@@ -4,7 +4,7 @@ import hashlib
from pathlib import Path
import numpy as np
import faiss
from typing import Any, Dict, Optional, Union, List
from typing import Any, Dict, Optional, Union
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
@@ -14,6 +14,7 @@ from src.common.vector_db import vector_db_service
logger = get_logger("cache_manager")
class CacheManager:
"""
一个支持分层和语义缓存的通用工具缓存管理器。
@@ -21,6 +22,7 @@ class CacheManager:
L1缓存: 内存字典 (KV) + FAISS (Vector)。
L2缓存: 数据库 (KV) + ChromaDB (Vector)。
"""
_instance = None
def __new__(cls, *args, **kwargs):
@@ -32,7 +34,7 @@ class CacheManager:
"""
初始化缓存管理器。
"""
if not hasattr(self, '_initialized'):
if not hasattr(self, "_initialized"):
self.default_ttl = default_ttl
self.semantic_cache_collection_name = "semantic_cache"
@@ -41,7 +43,7 @@ class CacheManager:
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
self.l1_vector_id_to_key: Dict[int, str] = {}
# L2 向量缓存 (使用新的服务)
vector_db_service.get_or_create_collection(self.semantic_cache_collection_name)
@@ -58,32 +60,32 @@ class CacheManager:
try:
if embedding_result is None:
return None
# 确保embedding_result是一维数组或列表
if isinstance(embedding_result, (list, tuple, np.ndarray)):
# 转换为numpy数组进行处理
embedding_array = np.array(embedding_result)
# 如果是多维数组,展平它
if embedding_array.ndim > 1:
embedding_array = embedding_array.flatten()
# 检查维度是否符合预期
expected_dim = global_config.lpmm_knowledge.embedding_dimension
if embedding_array.shape[0] != expected_dim:
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
return None
# 检查是否包含有效的数值
if np.isnan(embedding_array).any() or np.isinf(embedding_array).any():
logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)")
return None
return embedding_array.astype('float32')
return embedding_array.astype("float32")
else:
logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}")
return None
except Exception as e:
logger.error(f"验证嵌入向量时发生错误: {e}")
return None
@@ -102,14 +104,20 @@ class CacheManager:
except (OSError, TypeError) as e:
file_hash = "unknown"
logger.warning(f"无法获取文件信息: {tool_file_path},错误: {e}")
try:
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
except TypeError:
sorted_args = repr(sorted(function_args.items()))
return f"{tool_name}::{sorted_args}::{file_hash}"
async def get(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], semantic_query: Optional[str] = None) -> Optional[Any]:
async def get(
self,
tool_name: str,
function_args: Dict[str, Any],
tool_file_path: Union[str, Path],
semantic_query: Optional[str] = None,
) -> Optional[Any]:
"""
从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
"""
@@ -136,13 +144,13 @@ class CacheManager:
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
validated_embedding = self._validate_embedding(embedding_vector)
if validated_embedding is not None:
query_embedding = np.array([validated_embedding], dtype='float32')
query_embedding = np.array([validated_embedding], dtype="float32")
# 步骤 2a: L1 语义缓存 (FAISS)
if query_embedding is not None and self.l1_vector_index.ntotal > 0:
faiss.normalize_L2(query_embedding)
distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore
if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似
distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore
if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似
hit_index = indices[0][0]
l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
if l1_hit_key and l1_hit_key in self.l1_kv_cache:
@@ -151,12 +159,9 @@ class CacheManager:
# 步骤 2b: L2 精确缓存 (数据库)
cache_results_obj = await db_query(
model_class=CacheEntries,
query_type="get",
filters={"cache_key": key},
single_result=True
model_class=CacheEntries, query_type="get", filters={"cache_key": key}, single_result=True
)
if cache_results_obj:
# 使用 getattr 安全访问属性,避免 Pylance 类型检查错误
expires_at = getattr(cache_results_obj, "expires_at", 0)
@@ -164,7 +169,7 @@ class CacheManager:
logger.info(f"命中L2键值缓存: {key}")
cache_value = getattr(cache_results_obj, "cache_value", "{}")
data = orjson.loads(cache_value)
# 更新访问统计
await db_query(
model_class=CacheEntries,
@@ -172,20 +177,16 @@ class CacheManager:
filters={"cache_key": key},
data={
"last_accessed": time.time(),
"access_count": getattr(cache_results_obj, "access_count", 0) + 1
}
"access_count": getattr(cache_results_obj, "access_count", 0) + 1,
},
)
# 回填 L1
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
return data
else:
# 删除过期的缓存条目
await db_query(
model_class=CacheEntries,
query_type="delete",
filters={"cache_key": key}
)
await db_query(model_class=CacheEntries, query_type="delete", filters={"cache_key": key})
# 步骤 2c: L2 语义缓存 (VectorDB Service)
if query_embedding is not None:
@@ -193,31 +194,33 @@ class CacheManager:
results = vector_db_service.query(
collection_name=self.semantic_cache_collection_name,
query_embeddings=query_embedding.tolist(),
n_results=1
n_results=1,
)
if results and results.get('ids') and results['ids'][0]:
distance = results['distances'][0][0] if results.get('distances') and results['distances'][0] else 'N/A'
if results and results.get("ids") and results["ids"][0]:
distance = (
results["distances"][0][0] if results.get("distances") and results["distances"][0] else "N/A"
)
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
if distance != 'N/A' and distance < 0.75:
l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
if distance != "N/A" and distance < 0.75:
l2_hit_key = results["ids"][0][0] if isinstance(results["ids"][0], list) else results["ids"][0]
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
# 从数据库获取缓存数据
semantic_cache_results_obj = await db_query(
model_class=CacheEntries,
query_type="get",
filters={"cache_key": l2_hit_key},
single_result=True
single_result=True,
)
if semantic_cache_results_obj:
expires_at = getattr(semantic_cache_results_obj, "expires_at", 0)
if time.time() < expires_at:
cache_value = getattr(semantic_cache_results_obj, "cache_value", "{}")
data = orjson.loads(cache_value)
logger.debug(f"L2语义缓存返回的数据: {data}")
# 回填 L1
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
if query_embedding is not None:
@@ -235,7 +238,15 @@ class CacheManager:
logger.debug(f"缓存未命中: {key}")
return None
async def set(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
async def set(
self,
tool_name: str,
function_args: Dict[str, Any],
tool_file_path: Union[str, Path],
data: Any,
ttl: Optional[int] = None,
semantic_query: Optional[str] = None,
):
"""将结果存入所有缓存层。"""
if ttl is None:
ttl = self.default_ttl
@@ -244,27 +255,22 @@ class CacheManager:
key = self._generate_key(tool_name, function_args, tool_file_path)
expires_at = time.time() + ttl
# 写入 L1
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
# 写入 L2 (数据库)
cache_data = {
"cache_key": key,
"cache_value": orjson.dumps(data).decode('utf-8'),
"cache_value": orjson.dumps(data).decode("utf-8"),
"expires_at": expires_at,
"tool_name": tool_name,
"created_at": time.time(),
"last_accessed": time.time(),
"access_count": 1
"access_count": 1,
}
await db_save(
model_class=CacheEntries,
data=cache_data,
key_field="cache_key",
key_value=key
)
await db_save(model_class=CacheEntries, data=cache_data, key_field="cache_key", key_value=key)
# 写入语义缓存
if semantic_query and self.embedding_model:
@@ -274,19 +280,19 @@ class CacheManager:
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
validated_embedding = self._validate_embedding(embedding_vector)
if validated_embedding is not None:
embedding = np.array([validated_embedding], dtype='float32')
embedding = np.array([validated_embedding], dtype="float32")
# 写入 L1 Vector
new_id = self.l1_vector_index.ntotal
faiss.normalize_L2(embedding)
self.l1_vector_index.add(x=embedding) # type: ignore
self.l1_vector_id_to_key[new_id] = key
# 写入 L2 Vector (使用新的服务)
vector_db_service.add(
collection_name=self.semantic_cache_collection_name,
embeddings=embedding.tolist(),
ids=[key]
ids=[key],
)
except Exception as e:
logger.warning(f"语义缓存写入失败: {e}")
@@ -306,16 +312,16 @@ class CacheManager:
await db_query(
model_class=CacheEntries,
query_type="delete",
filters={} # 删除所有记录
filters={}, # 删除所有记录
)
# 清空 VectorDB
try:
vector_db_service.delete_collection(name=self.semantic_cache_collection_name)
vector_db_service.get_or_create_collection(name=self.semantic_cache_collection_name)
except Exception as e:
logger.warning(f"清空 VectorDB 集合失败: {e}")
logger.info("L2 (数据库 & VectorDB) 缓存已清空。")
async def clear_all(self):
@@ -327,85 +333,23 @@ class CacheManager:
async def clean_expired(self):
"""清理过期的缓存条目"""
current_time = time.time()
# 清理L1过期条目
expired_keys = []
for key, entry in self.l1_kv_cache.items():
if current_time >= entry["expires_at"]:
expired_keys.append(key)
for key in expired_keys:
del self.l1_kv_cache[key]
# 清理L2过期条目
await db_query(
model_class=CacheEntries,
query_type="delete",
filters={"expires_at": {"$lt": current_time}}
)
await db_query(model_class=CacheEntries, query_type="delete", filters={"expires_at": {"$lt": current_time}})
if expired_keys:
logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目")
# 全局实例
tool_cache = CacheManager()
import inspect
import time
def wrap_tool_executor():
"""
包装工具执行器以添加缓存功能
这个函数应该在系统启动时被调用一次
"""
from src.plugin_system.core.tool_use import ToolExecutor
from src.plugin_system.apis.tool_api import get_tool_instance
original_execute = ToolExecutor.execute_tool_call
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
if not tool_instance:
tool_instance = get_tool_instance(tool_call.func_name)
if not tool_instance or not tool_instance.enable_cache:
return await original_execute(self, tool_call, tool_instance)
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
cached_result = await tool_cache.get(
tool_name=tool_call.func_name,
function_args=tool_call.args,
tool_file_path=tool_file_path,
semantic_query=semantic_query
)
if cached_result:
logger.info(f"{getattr(self, 'log_prefix', '')}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
return cached_result
except Exception as e:
logger.error(f"{getattr(self, 'log_prefix', '')}检查工具缓存时出错: {e}")
result = await original_execute(self, tool_call, tool_instance)
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
await tool_cache.set(
tool_name=tool_call.func_name,
function_args=tool_call.args,
tool_file_path=tool_file_path,
data=result,
ttl=tool_instance.cache_ttl,
semantic_query=semantic_query
)
except Exception as e:
logger.error(f"{getattr(self, 'log_prefix', '')}设置工具缓存时出错: {e}")
return result
ToolExecutor.execute_tool_call = wrapped_execute_tool_call

View File

@@ -5,7 +5,7 @@ import random
from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator
from src.common.logger import get_logger
from src.config.config import model_config
@@ -283,131 +283,130 @@ class LLMRequest:
tools: Optional[List[Dict[str, Any]]] = None,
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""执行单次请求"""
# 模型选择和请求准备
start_time = time.time()
model_info, api_provider, client = self._select_model()
model_name = model_info.name
# 检查是否启用反截断
use_anti_truncation = getattr(api_provider, "anti_truncation", False)
processed_prompt = prompt
if use_anti_truncation:
processed_prompt += self.anti_truncation_instruction
logger.info(f"{api_provider} '{self.task_name}' 已启用反截断功能")
processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
message_builder = MessageBuilder()
message_builder.add_text_content(processed_prompt)
messages = [message_builder.build()]
tool_built = self._build_tool_options(tools)
# 空回复重试逻辑
empty_retry_count = 0
max_empty_retry = api_provider.max_retry
empty_retry_interval = api_provider.retry_interval
while empty_retry_count <= max_empty_retry:
"""
执行单次请求,并在模型失败时按顺序切换到下一个可用模型。
"""
failed_models = set()
last_exception: Optional[Exception] = None
model_scheduler = self._model_scheduler(failed_models)
for model_info, api_provider, client in model_scheduler:
start_time = time.time()
model_name = model_info.name
logger.info(f"正在尝试使用模型: {model_name}")
try:
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=messages,
tool_options=tool_built,
temperature=temperature,
max_tokens=max_tokens,
)
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
# 从内容中提取<think>标签的推理内容(向后兼容)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
is_empty_reply = False
is_truncated = False
# 检测是否为空回复或截断
if not tool_calls:
is_empty_reply = not content or content.strip() == ""
is_truncated = False
# 检查是否启用反截断
use_anti_truncation = getattr(api_provider, "anti_truncation", False)
processed_prompt = prompt
if use_anti_truncation:
if content.endswith("[done]"):
content = content[:-6].strip()
logger.debug("检测到并已移除 [done] 标记")
else:
is_truncated = True
logger.warning("未检测到 [done] 标记,判定为截断")
processed_prompt += self.anti_truncation_instruction
logger.info(f"'{model_name}' for task '{self.task_name}' 已启用反截断功能")
if is_empty_reply or is_truncated:
if empty_retry_count < max_empty_retry:
empty_retry_count += 1
reason = "空回复" if is_empty_reply else "截断"
logger.warning(f"检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成")
processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
message_builder = MessageBuilder()
message_builder.add_text_content(processed_prompt)
messages = [message_builder.build()]
tool_built = self._build_tool_options(tools)
model_info, api_provider, client = self._select_model()
continue
else:
# 已达到最大重试次数,但仍然是空回复或截断
reason = "空回复" if is_empty_reply else "截断"
# 抛出异常,由外层重试逻辑或最终的异常处理器捕获
raise RuntimeError(f"经过 {max_empty_retry + 1} 次尝试后仍然是{reason}的回复")
# 针对当前模型的空回复/截断重试逻辑
empty_retry_count = 0
max_empty_retry = api_provider.max_retry
empty_retry_interval = api_provider.retry_interval
# 记录使用情况
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
while empty_retry_count <= max_empty_retry:
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
model_usage=usage,
time_cost=time.time() - start_time,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions",
message_list=messages,
tool_options=tool_built,
temperature=temperature,
max_tokens=max_tokens,
)
# 处理空回复
if not content and not tool_calls:
if raise_when_empty:
raise RuntimeError(f"经过 {empty_retry_count} 次重试后仍然生成空回复")
content = "生成的响应为空,请检查模型配置或输入内容是否正确"
elif empty_retry_count > 0:
logger.info(f"经过 {empty_retry_count} 次重试后成功生成回复")
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
return content, (reasoning_content, model_info.name, tool_calls)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
is_empty_reply = not tool_calls and (not content or content.strip() == "")
is_truncated = False
if use_anti_truncation:
if content.endswith("[done]"):
content = content[:-6].strip()
else:
is_truncated = True
if is_empty_reply or is_truncated:
empty_retry_count += 1
if empty_retry_count <= max_empty_retry:
reason = "空回复" if is_empty_reply else "截断"
logger.warning(f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成...")
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
continue # 继续使用当前模型重试
else:
# 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型
reason = "空回复" if is_empty_reply else "截断"
logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。")
raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数")
# 成功获取响应
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info, model_usage=usage, time_cost=time.time() - start_time,
user_id="system", request_type=self.request_type, endpoint="/chat/completions",
)
if not content and not tool_calls:
if raise_when_empty:
raise RuntimeError("生成空回复")
content = "生成的响应为空"
logger.info(f"模型 '{model_name}' 成功生成回复。")
return content, (reasoning_content, model_name, tool_calls)
except RespNotOkException as e:
if e.status_code in [401, 403]:
logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。")
failed_models.add(model_name)
last_exception = e
continue # 切换到下一个模型
else:
logger.error(f"模型 '{model_name}' 请求失败HTTP状态码: {e.status_code}")
if raise_when_empty:
raise
# 对于其他HTTP错误直接抛出不再尝试其他模型
return f"请求失败: {e}", ("", model_name, None)
except RuntimeError as e:
# 捕获所有重试失败(包括空回复和网络问题)
logger.error(f"模型 '{model_name}' 在所有重试后仍然失败: {e},将尝试下一个模型。")
failed_models.add(model_name)
last_exception = e
continue # 切换到下一个模型
except Exception as e:
logger.error(f"请求执行失败: {e}")
if raise_when_empty:
# 在非并发模式下,如果第一次尝试就失败,则直接抛出异常
if empty_retry_count == 0:
raise
logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}")
failed_models.add(model_name)
last_exception = e
continue # 切换到下一个模型
# 如果在重试过程中失败,则继续重试
empty_retry_count += 1
if empty_retry_count <= max_empty_retry:
logger.warning(f"请求失败,将在 {empty_retry_interval} 秒后进行第 {empty_retry_count}/{max_empty_retry} 次重试...")
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
continue
else:
logger.error(f"经过 {max_empty_retry} 次重试后仍然失败")
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复") from e
else:
# 在并发模式下,单个请求的失败不应中断整个并发流程,
# 而是将异常返回给调用者(即 execute_concurrently进行统一处理
raise # 重新抛出异常,由 execute_concurrently 中的 gather 捕获
# 重试失败
# 所有模型都尝试失败
logger.error("所有可用模型都已尝试失败。")
if raise_when_empty:
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复")
return "生成的响应为空,请检查模型配置或输入内容是否正确", ("", model_name, None)
if last_exception:
raise RuntimeError("所有模型都请求失败") from last_exception
raise RuntimeError("所有模型都请求失败,且没有具体的异常信息")
return "所有模型都请求失败", ("", "unknown", None)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""获取嵌入向量
@@ -446,9 +445,24 @@ class LLMRequest:
return embedding, model_info.name
def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
"""
一个模型调度器,按顺序提供模型,并跳过已失败的模型。
"""
for model_name in self.model_for_task.model_list:
if model_name in failed_models:
continue
model_info = model_config.get_model_info(model_name)
api_provider = model_config.get_provider(model_info.api_provider)
force_new_client = (self.request_type == "embedding")
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
yield model_info, api_provider, client
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
"""
根据总tokens和惩罚值选择的模型
根据总tokens和惩罚值选择的模型 (负载均衡)
"""
least_used_model_name = min(
self.model_usage,

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Type
from typing import Optional, Type
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Type, Tuple, Union, TYPE_CHECKING
from typing import List, Type, Tuple, Union
from .plugin_base import PluginBase
from src.common.logger import get_logger

View File

@@ -4,7 +4,7 @@
"""
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional, List
from typing import Tuple, Optional, List
import re
from src.common.logger import get_logger

View File

@@ -7,8 +7,10 @@ from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
import inspect
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache
logger = get_logger("tool_use")
@@ -184,21 +186,65 @@ class ToolExecutor:
return tool_results, used_tools
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
# sourcery skip: use-assigned-variable
"""执行单个工具调用
"""执行单个工具调用,并处理缓存"""
function_args = tool_call.args or {}
tool_instance = tool_instance or get_tool_instance(tool_call.func_name)
Args:
tool_call: 工具调用对象
# 如果工具不存在或未启用缓存,则直接执行
if not tool_instance or not tool_instance.enable_cache:
return await self._original_execute_tool_call(tool_call, tool_instance)
Returns:
Optional[Dict]: 工具调用结果如果失败则返回None
"""
# --- 缓存逻辑开始 ---
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
cached_result = await tool_cache.get(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
semantic_query=semantic_query
)
if cached_result:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
return cached_result
except Exception as e:
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
# 缓存未命中,执行原始工具调用
result = await self._original_execute_tool_call(tool_call, tool_instance)
# 将结果存入缓存
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
await tool_cache.set(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
data=result,
ttl=tool_instance.cache_ttl,
semantic_query=semantic_query
)
except Exception as e:
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
# --- 缓存逻辑结束 ---
return result
async def _original_execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
"""执行单个工具调用的原始逻辑"""
try:
function_name = tool_call.func_name
function_args = tool_call.args or {}
logger.info(f"🤖 {self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
logger.info(f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name)
if not tool_instance:

View File

@@ -9,12 +9,9 @@ import datetime
import base64
import aiohttp
from src.common.logger import get_logger
import base64
import aiohttp
import imghdr
import asyncio
from src.common.logger import get_logger
from src.plugin_system.apis import llm_api, config_api, generator_api, person_api
from src.plugin_system.apis import llm_api, config_api, generator_api
from src.chat.message_receive.chat_stream import get_chat_manager
from maim_message import UserInfo
from src.llm_models.utils_model import LLMRequest

View File

@@ -16,7 +16,7 @@ from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.apis.logging_api import get_logger
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
from src.plugin_system.base.config_types import ConfigField
from src.plugin_system.utils.permission_decorators import require_permission, require_master, PermissionChecker
from src.plugin_system.utils.permission_decorators import require_permission
logger = get_logger("Permission")

View File

@@ -411,7 +411,6 @@ class ScheduleManager:
通过关键词匹配、唤醒度、睡眠压力等综合判断是否处于休眠时间。
新增弹性睡眠机制,允许在压力低时延迟入睡,并在入睡前发送通知。
"""
from src.chat.chat_loop.wakeup_manager import WakeUpManager
# --- 基础检查 ---
if not global_config.schedule.enable_is_sleep:
return False