This commit is contained in:
Windpicker-owo
2025-08-28 21:05:17 +08:00
31 changed files with 861 additions and 664 deletions

View File

@@ -10,7 +10,6 @@ from src.chat.express.expression_learner import expression_learner_manager
from src.plugin_system.base.component_types import ChatMode
from src.schedule.schedule_manager import schedule_manager
from src.plugin_system.apis import message_api
from src.mood.mood_manager import mood_manager
from .hfc_context import HfcContext
from .energy_manager import EnergyManager

View File

@@ -4,7 +4,7 @@ import hashlib
from pathlib import Path
import numpy as np
import faiss
from typing import Any, Dict, Optional, Union, List
from typing import Any, Dict, Optional, Union
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
@@ -14,6 +14,7 @@ from src.common.vector_db import vector_db_service
logger = get_logger("cache_manager")
class CacheManager:
"""
一个支持分层和语义缓存的通用工具缓存管理器。
@@ -21,6 +22,7 @@ class CacheManager:
L1缓存: 内存字典 (KV) + FAISS (Vector)。
L2缓存: 数据库 (KV) + ChromaDB (Vector)。
"""
_instance = None
def __new__(cls, *args, **kwargs):
@@ -32,7 +34,7 @@ class CacheManager:
"""
初始化缓存管理器。
"""
if not hasattr(self, '_initialized'):
if not hasattr(self, "_initialized"):
self.default_ttl = default_ttl
self.semantic_cache_collection_name = "semantic_cache"
@@ -41,7 +43,7 @@ class CacheManager:
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
self.l1_vector_id_to_key: Dict[int, str] = {}
# L2 向量缓存 (使用新的服务)
vector_db_service.get_or_create_collection(self.semantic_cache_collection_name)
@@ -58,32 +60,32 @@ class CacheManager:
try:
if embedding_result is None:
return None
# 确保embedding_result是一维数组或列表
if isinstance(embedding_result, (list, tuple, np.ndarray)):
# 转换为numpy数组进行处理
embedding_array = np.array(embedding_result)
# 如果是多维数组,展平它
if embedding_array.ndim > 1:
embedding_array = embedding_array.flatten()
# 检查维度是否符合预期
expected_dim = global_config.lpmm_knowledge.embedding_dimension
if embedding_array.shape[0] != expected_dim:
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
return None
# 检查是否包含有效的数值
if np.isnan(embedding_array).any() or np.isinf(embedding_array).any():
logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)")
return None
return embedding_array.astype('float32')
return embedding_array.astype("float32")
else:
logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}")
return None
except Exception as e:
logger.error(f"验证嵌入向量时发生错误: {e}")
return None
@@ -102,14 +104,20 @@ class CacheManager:
except (OSError, TypeError) as e:
file_hash = "unknown"
logger.warning(f"无法获取文件信息: {tool_file_path},错误: {e}")
try:
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
except TypeError:
sorted_args = repr(sorted(function_args.items()))
return f"{tool_name}::{sorted_args}::{file_hash}"
async def get(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], semantic_query: Optional[str] = None) -> Optional[Any]:
async def get(
self,
tool_name: str,
function_args: Dict[str, Any],
tool_file_path: Union[str, Path],
semantic_query: Optional[str] = None,
) -> Optional[Any]:
"""
从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
"""
@@ -136,13 +144,13 @@ class CacheManager:
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
validated_embedding = self._validate_embedding(embedding_vector)
if validated_embedding is not None:
query_embedding = np.array([validated_embedding], dtype='float32')
query_embedding = np.array([validated_embedding], dtype="float32")
# 步骤 2a: L1 语义缓存 (FAISS)
if query_embedding is not None and self.l1_vector_index.ntotal > 0:
faiss.normalize_L2(query_embedding)
distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore
if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似
distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore
if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似
hit_index = indices[0][0]
l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
if l1_hit_key and l1_hit_key in self.l1_kv_cache:
@@ -151,12 +159,9 @@ class CacheManager:
# 步骤 2b: L2 精确缓存 (数据库)
cache_results_obj = await db_query(
model_class=CacheEntries,
query_type="get",
filters={"cache_key": key},
single_result=True
model_class=CacheEntries, query_type="get", filters={"cache_key": key}, single_result=True
)
if cache_results_obj:
# 使用 getattr 安全访问属性,避免 Pylance 类型检查错误
expires_at = getattr(cache_results_obj, "expires_at", 0)
@@ -164,7 +169,7 @@ class CacheManager:
logger.info(f"命中L2键值缓存: {key}")
cache_value = getattr(cache_results_obj, "cache_value", "{}")
data = orjson.loads(cache_value)
# 更新访问统计
await db_query(
model_class=CacheEntries,
@@ -172,20 +177,16 @@ class CacheManager:
filters={"cache_key": key},
data={
"last_accessed": time.time(),
"access_count": getattr(cache_results_obj, "access_count", 0) + 1
}
"access_count": getattr(cache_results_obj, "access_count", 0) + 1,
},
)
# 回填 L1
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
return data
else:
# 删除过期的缓存条目
await db_query(
model_class=CacheEntries,
query_type="delete",
filters={"cache_key": key}
)
await db_query(model_class=CacheEntries, query_type="delete", filters={"cache_key": key})
# 步骤 2c: L2 语义缓存 (VectorDB Service)
if query_embedding is not None:
@@ -193,31 +194,33 @@ class CacheManager:
results = vector_db_service.query(
collection_name=self.semantic_cache_collection_name,
query_embeddings=query_embedding.tolist(),
n_results=1
n_results=1,
)
if results and results.get('ids') and results['ids'][0]:
distance = results['distances'][0][0] if results.get('distances') and results['distances'][0] else 'N/A'
if results and results.get("ids") and results["ids"][0]:
distance = (
results["distances"][0][0] if results.get("distances") and results["distances"][0] else "N/A"
)
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
if distance != 'N/A' and distance < 0.75:
l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
if distance != "N/A" and distance < 0.75:
l2_hit_key = results["ids"][0][0] if isinstance(results["ids"][0], list) else results["ids"][0]
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
# 从数据库获取缓存数据
semantic_cache_results_obj = await db_query(
model_class=CacheEntries,
query_type="get",
filters={"cache_key": l2_hit_key},
single_result=True
single_result=True,
)
if semantic_cache_results_obj:
expires_at = getattr(semantic_cache_results_obj, "expires_at", 0)
if time.time() < expires_at:
cache_value = getattr(semantic_cache_results_obj, "cache_value", "{}")
data = orjson.loads(cache_value)
logger.debug(f"L2语义缓存返回的数据: {data}")
# 回填 L1
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
if query_embedding is not None:
@@ -235,7 +238,15 @@ class CacheManager:
logger.debug(f"缓存未命中: {key}")
return None
async def set(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
async def set(
self,
tool_name: str,
function_args: Dict[str, Any],
tool_file_path: Union[str, Path],
data: Any,
ttl: Optional[int] = None,
semantic_query: Optional[str] = None,
):
"""将结果存入所有缓存层。"""
if ttl is None:
ttl = self.default_ttl
@@ -244,27 +255,22 @@ class CacheManager:
key = self._generate_key(tool_name, function_args, tool_file_path)
expires_at = time.time() + ttl
# 写入 L1
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
# 写入 L2 (数据库)
cache_data = {
"cache_key": key,
"cache_value": orjson.dumps(data).decode('utf-8'),
"cache_value": orjson.dumps(data).decode("utf-8"),
"expires_at": expires_at,
"tool_name": tool_name,
"created_at": time.time(),
"last_accessed": time.time(),
"access_count": 1
"access_count": 1,
}
await db_save(
model_class=CacheEntries,
data=cache_data,
key_field="cache_key",
key_value=key
)
await db_save(model_class=CacheEntries, data=cache_data, key_field="cache_key", key_value=key)
# 写入语义缓存
if semantic_query and self.embedding_model:
@@ -274,19 +280,19 @@ class CacheManager:
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
validated_embedding = self._validate_embedding(embedding_vector)
if validated_embedding is not None:
embedding = np.array([validated_embedding], dtype='float32')
embedding = np.array([validated_embedding], dtype="float32")
# 写入 L1 Vector
new_id = self.l1_vector_index.ntotal
faiss.normalize_L2(embedding)
self.l1_vector_index.add(x=embedding) # type: ignore
self.l1_vector_id_to_key[new_id] = key
# 写入 L2 Vector (使用新的服务)
vector_db_service.add(
collection_name=self.semantic_cache_collection_name,
embeddings=embedding.tolist(),
ids=[key]
ids=[key],
)
except Exception as e:
logger.warning(f"语义缓存写入失败: {e}")
@@ -306,16 +312,16 @@ class CacheManager:
await db_query(
model_class=CacheEntries,
query_type="delete",
filters={} # 删除所有记录
filters={}, # 删除所有记录
)
# 清空 VectorDB
try:
vector_db_service.delete_collection(name=self.semantic_cache_collection_name)
vector_db_service.get_or_create_collection(name=self.semantic_cache_collection_name)
except Exception as e:
logger.warning(f"清空 VectorDB 集合失败: {e}")
logger.info("L2 (数据库 & VectorDB) 缓存已清空。")
async def clear_all(self):
@@ -327,85 +333,23 @@ class CacheManager:
async def clean_expired(self):
    """Purge expired entries from the L1 key-value cache and the L2 database table.

    An L1 entry is expired once the current time has reached its
    ``expires_at``; L2 rows are removed with a single delete filtered on
    ``expires_at`` being lower than the current time.
    """
    now = time.time()

    # Collect first, delete afterwards: the dict must not change size while
    # it is being iterated.
    stale_keys = [
        cache_key
        for cache_key, cached in self.l1_kv_cache.items()
        if now >= cached["expires_at"]
    ]
    for cache_key in stale_keys:
        del self.l1_kv_cache[cache_key]

    # L2: one bulk delete for every row whose expiry lies in the past.
    await db_query(model_class=CacheEntries, query_type="delete", filters={"expires_at": {"$lt": now}})

    if stale_keys:
        logger.info(f"清理了 {len(stale_keys)} 个过期的L1缓存条目")
# Module-level CacheManager instance; other modules import it as `tool_cache`.
tool_cache = CacheManager()
import inspect
import time
def wrap_tool_executor():
    """Patch ``ToolExecutor.execute_tool_call`` so cache-enabled tools go through ``tool_cache``.

    Meant to be called once at system start-up: each call wraps whatever
    ``execute_tool_call`` currently is, so calling it again stacks wrappers.

    The wrapper runs the original executor directly when the tool cannot be
    resolved or has caching disabled. Otherwise it looks the call up in
    ``tool_cache``, returns a truthy hit, and on a miss executes the tool and
    stores the result. Cache failures are logged and never prevent the tool
    from running.
    """
    from src.plugin_system.core.tool_use import ToolExecutor
    from src.plugin_system.apis.tool_api import get_tool_instance

    original_execute = ToolExecutor.execute_tool_call

    async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
        if not tool_instance:
            tool_instance = get_tool_instance(tool_call.func_name)

        if not tool_instance or not tool_instance.enable_cache:
            return await original_execute(self, tool_call, tool_instance)

        log_prefix = getattr(self, 'log_prefix', '')

        # Snapshot the arguments BEFORE the tool runs. The wrapped executor
        # adds bookkeeping entries (e.g. "llm_called") to ``tool_call.args``
        # in place; building the write key from the mutated dict would make it
        # differ from the read key, so exact-key lookups could never hit.
        # ``or {}`` also covers a tool call whose args are None.
        function_args = dict(tool_call.args or {})

        # Keyword arguments shared by the lookup and the store; stays None
        # when the cache context itself could not be built.
        cache_kwargs = None
        try:
            tool_file_path = inspect.getfile(tool_instance.__class__)
            semantic_query = None
            if tool_instance.semantic_cache_query_key:
                semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
            cache_kwargs = {
                "tool_name": tool_call.func_name,
                "function_args": function_args,
                "tool_file_path": tool_file_path,
                "semantic_query": semantic_query,
            }

            cached_result = await tool_cache.get(**cache_kwargs)
            if cached_result:
                logger.info(f"{log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
                return cached_result
        except Exception as e:
            logger.error(f"{log_prefix}检查工具缓存时出错: {e}")

        result = await original_execute(self, tool_call, tool_instance)

        # Only store when the context was built; a failure to build it has
        # already been logged above.
        if cache_kwargs is not None:
            try:
                await tool_cache.set(data=result, ttl=tool_instance.cache_ttl, **cache_kwargs)
            except Exception as e:
                logger.error(f"{log_prefix}设置工具缓存时出错: {e}")

        return result

    ToolExecutor.execute_tool_call = wrapped_execute_tool_call

View File

@@ -5,7 +5,7 @@ import random
from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator
from src.common.logger import get_logger
from src.config.config import model_config
@@ -283,131 +283,130 @@ class LLMRequest:
tools: Optional[List[Dict[str, Any]]] = None,
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""执行单次请求"""
# 模型选择和请求准备
start_time = time.time()
model_info, api_provider, client = self._select_model()
model_name = model_info.name
# 检查是否启用反截断
use_anti_truncation = getattr(api_provider, "anti_truncation", False)
processed_prompt = prompt
if use_anti_truncation:
processed_prompt += self.anti_truncation_instruction
logger.info(f"{api_provider} '{self.task_name}' 已启用反截断功能")
processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
message_builder = MessageBuilder()
message_builder.add_text_content(processed_prompt)
messages = [message_builder.build()]
tool_built = self._build_tool_options(tools)
# 空回复重试逻辑
empty_retry_count = 0
max_empty_retry = api_provider.max_retry
empty_retry_interval = api_provider.retry_interval
while empty_retry_count <= max_empty_retry:
"""
执行单次请求,并在模型失败时按顺序切换到下一个可用模型。
"""
failed_models = set()
last_exception: Optional[Exception] = None
model_scheduler = self._model_scheduler(failed_models)
for model_info, api_provider, client in model_scheduler:
start_time = time.time()
model_name = model_info.name
logger.info(f"正在尝试使用模型: {model_name}")
try:
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=messages,
tool_options=tool_built,
temperature=temperature,
max_tokens=max_tokens,
)
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
# 从内容中提取<think>标签的推理内容(向后兼容)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
is_empty_reply = False
is_truncated = False
# 检测是否为空回复或截断
if not tool_calls:
is_empty_reply = not content or content.strip() == ""
is_truncated = False
# 检查是否启用反截断
use_anti_truncation = getattr(api_provider, "anti_truncation", False)
processed_prompt = prompt
if use_anti_truncation:
if content.endswith("[done]"):
content = content[:-6].strip()
logger.debug("检测到并已移除 [done] 标记")
else:
is_truncated = True
logger.warning("未检测到 [done] 标记,判定为截断")
processed_prompt += self.anti_truncation_instruction
logger.info(f"'{model_name}' for task '{self.task_name}' 已启用反截断功能")
if is_empty_reply or is_truncated:
if empty_retry_count < max_empty_retry:
empty_retry_count += 1
reason = "空回复" if is_empty_reply else "截断"
logger.warning(f"检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成")
processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
message_builder = MessageBuilder()
message_builder.add_text_content(processed_prompt)
messages = [message_builder.build()]
tool_built = self._build_tool_options(tools)
model_info, api_provider, client = self._select_model()
continue
else:
# 已达到最大重试次数,但仍然是空回复或截断
reason = "空回复" if is_empty_reply else "截断"
# 抛出异常,由外层重试逻辑或最终的异常处理器捕获
raise RuntimeError(f"经过 {max_empty_retry + 1} 次尝试后仍然是{reason}的回复")
# 针对当前模型的空回复/截断重试逻辑
empty_retry_count = 0
max_empty_retry = api_provider.max_retry
empty_retry_interval = api_provider.retry_interval
# 记录使用情况
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
while empty_retry_count <= max_empty_retry:
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
model_usage=usage,
time_cost=time.time() - start_time,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions",
message_list=messages,
tool_options=tool_built,
temperature=temperature,
max_tokens=max_tokens,
)
# 处理空回复
if not content and not tool_calls:
if raise_when_empty:
raise RuntimeError(f"经过 {empty_retry_count} 次重试后仍然生成空回复")
content = "生成的响应为空,请检查模型配置或输入内容是否正确"
elif empty_retry_count > 0:
logger.info(f"经过 {empty_retry_count} 次重试后成功生成回复")
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
return content, (reasoning_content, model_info.name, tool_calls)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
is_empty_reply = not tool_calls and (not content or content.strip() == "")
is_truncated = False
if use_anti_truncation:
if content.endswith("[done]"):
content = content[:-6].strip()
else:
is_truncated = True
if is_empty_reply or is_truncated:
empty_retry_count += 1
if empty_retry_count <= max_empty_retry:
reason = "空回复" if is_empty_reply else "截断"
logger.warning(f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成...")
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
continue # 继续使用当前模型重试
else:
# 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型
reason = "空回复" if is_empty_reply else "截断"
logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。")
raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数")
# 成功获取响应
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info, model_usage=usage, time_cost=time.time() - start_time,
user_id="system", request_type=self.request_type, endpoint="/chat/completions",
)
if not content and not tool_calls:
if raise_when_empty:
raise RuntimeError("生成空回复")
content = "生成的响应为空"
logger.info(f"模型 '{model_name}' 成功生成回复。")
return content, (reasoning_content, model_name, tool_calls)
except RespNotOkException as e:
if e.status_code in [401, 403]:
logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。")
failed_models.add(model_name)
last_exception = e
continue # 切换到下一个模型
else:
logger.error(f"模型 '{model_name}' 请求失败HTTP状态码: {e.status_code}")
if raise_when_empty:
raise
# 对于其他HTTP错误直接抛出不再尝试其他模型
return f"请求失败: {e}", ("", model_name, None)
except RuntimeError as e:
# 捕获所有重试失败(包括空回复和网络问题)
logger.error(f"模型 '{model_name}' 在所有重试后仍然失败: {e},将尝试下一个模型。")
failed_models.add(model_name)
last_exception = e
continue # 切换到下一个模型
except Exception as e:
logger.error(f"请求执行失败: {e}")
if raise_when_empty:
# 在非并发模式下,如果第一次尝试就失败,则直接抛出异常
if empty_retry_count == 0:
raise
logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}")
failed_models.add(model_name)
last_exception = e
continue # 切换到下一个模型
# 如果在重试过程中失败,则继续重试
empty_retry_count += 1
if empty_retry_count <= max_empty_retry:
logger.warning(f"请求失败,将在 {empty_retry_interval} 秒后进行第 {empty_retry_count}/{max_empty_retry} 次重试...")
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
continue
else:
logger.error(f"经过 {max_empty_retry} 次重试后仍然失败")
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复") from e
else:
# 在并发模式下,单个请求的失败不应中断整个并发流程,
# 而是将异常返回给调用者(即 execute_concurrently进行统一处理
raise # 重新抛出异常,由 execute_concurrently 中的 gather 捕获
# 重试失败
# 所有模型都尝试失败
logger.error("所有可用模型都已尝试失败。")
if raise_when_empty:
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复")
return "生成的响应为空,请检查模型配置或输入内容是否正确", ("", model_name, None)
if last_exception:
raise RuntimeError("所有模型都请求失败") from last_exception
raise RuntimeError("所有模型都请求失败,且没有具体的异常信息")
return "所有模型都请求失败", ("", "unknown", None)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""获取嵌入向量
@@ -446,9 +445,24 @@ class LLMRequest:
return embedding, model_info.name
def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
    """Yield ``(model_info, api_provider, client)`` for each usable model, in configured order.

    Candidates come from ``self.model_for_task.model_list``. Membership in
    ``failed_models`` is tested lazily, right before a candidate is prepared,
    so names the caller adds to the set while consuming the generator are
    skipped as well.

    Args:
        failed_models: Names of models that must not be offered.

    Yields:
        The model info, its API provider and a client instance for that
        provider. A new client is requested (``force_new=True``) when
        ``self.request_type`` is ``"embedding"``.
    """
    for candidate in self.model_for_task.model_list:
        if candidate not in failed_models:
            info = model_config.get_model_info(candidate)
            provider = model_config.get_provider(info.api_provider)
            wants_fresh_client = self.request_type == "embedding"
            yield info, provider, client_registry.get_client_class_instance(provider, force_new=wants_fresh_client)
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
"""
根据总tokens和惩罚值选择的模型
根据总tokens和惩罚值选择的模型 (负载均衡)
"""
least_used_model_name = min(
self.model_usage,

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Type
from typing import Optional, Type
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Type, Tuple, Union, TYPE_CHECKING
from typing import List, Type, Tuple, Union
from .plugin_base import PluginBase
from src.common.logger import get_logger

View File

@@ -4,7 +4,7 @@
"""
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional, List
from typing import Tuple, Optional, List
import re
from src.common.logger import get_logger

View File

@@ -7,8 +7,10 @@ from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
import inspect
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache
logger = get_logger("tool_use")
@@ -184,21 +186,65 @@ class ToolExecutor:
return tool_results, used_tools
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
    # sourcery skip: use-assigned-variable
    """Execute a single tool call, going through the tool cache when the tool enables it.

    Args:
        tool_call: The tool call to execute.
        tool_instance: Tool to run; resolved from ``tool_call.func_name`` via
            ``get_tool_instance`` when omitted.

    Returns:
        A truthy cached result when one is found, otherwise whatever
        ``_original_execute_tool_call`` returns.

    Cache errors are logged and never prevent the tool from running.
    """
    tool_instance = tool_instance or get_tool_instance(tool_call.func_name)

    # Unknown tool or caching disabled: run the tool directly.
    if not tool_instance or not tool_instance.enable_cache:
        return await self._original_execute_tool_call(tool_call, tool_instance)

    # Snapshot the arguments BEFORE the tool runs. _original_execute_tool_call
    # writes "llm_called" into the very dict held by ``tool_call.args``; if the
    # store below built its key from that mutated dict, it would differ from
    # the key used for the lookup and exact-key hits could never happen.
    function_args = dict(tool_call.args or {})

    # Keyword arguments shared by the lookup and the store; stays None when
    # the cache context itself could not be built.
    cache_kwargs = None
    try:
        tool_file_path = inspect.getfile(tool_instance.__class__)
        semantic_query = None
        if tool_instance.semantic_cache_query_key:
            semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
        cache_kwargs = {
            "tool_name": tool_call.func_name,
            "function_args": function_args,
            "tool_file_path": tool_file_path,
            "semantic_query": semantic_query,
        }

        cached_result = await tool_cache.get(**cache_kwargs)
        if cached_result:
            logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
            return cached_result
    except Exception as e:
        logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")

    # Cache miss (or cache unavailable): run the real tool.
    result = await self._original_execute_tool_call(tool_call, tool_instance)

    # Only store when the context was built; a failure to build it has
    # already been logged above.
    if cache_kwargs is not None:
        try:
            await tool_cache.set(data=result, ttl=tool_instance.cache_ttl, **cache_kwargs)
        except Exception as e:
            logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")

    return result
async def _original_execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
"""执行单个工具调用的原始逻辑"""
try:
function_name = tool_call.func_name
function_args = tool_call.args or {}
logger.info(f"🤖 {self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
logger.info(f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name)
if not tool_instance:

View File

@@ -24,6 +24,7 @@ from .services.qzone_service import QZoneService
from .services.scheduler_service import SchedulerService
from .services.monitor_service import MonitorService
from .services.cookie_service import CookieService
from .services.reply_tracker_service import ReplyTrackerService
from .services.manager import register_service
logger = get_logger("MaiZone.Plugin")
@@ -92,11 +93,13 @@ class MaiZoneRefactoredPlugin(BasePlugin):
content_service = ContentService(self.get_config)
image_service = ImageService(self.get_config)
cookie_service = CookieService(self.get_config)
reply_tracker_service = ReplyTrackerService()
qzone_service = QZoneService(self.get_config, content_service, image_service, cookie_service)
scheduler_service = SchedulerService(self.get_config, qzone_service)
monitor_service = MonitorService(self.get_config, qzone_service)
register_service("qzone", qzone_service)
register_service("reply_tracker", reply_tracker_service)
register_service("get_config", self.get_config)
asyncio.create_task(scheduler_service.start())

View File

@@ -9,12 +9,9 @@ import datetime
import base64
import aiohttp
from src.common.logger import get_logger
import base64
import aiohttp
import imghdr
import asyncio
from src.common.logger import get_logger
from src.plugin_system.apis import llm_api, config_api, generator_api, person_api
from src.plugin_system.apis import llm_api, config_api, generator_api
from src.chat.message_receive.chat_stream import get_chat_manager
from maim_message import UserInfo
from src.llm_models.utils_model import LLMRequest

View File

@@ -27,6 +27,7 @@ from src.chat.utils.chat_message_builder import (
from .content_service import ContentService
from .image_service import ImageService
from .cookie_service import CookieService
from .reply_tracker_service import ReplyTrackerService
logger = get_logger("MaiZone.QZoneService")
@@ -55,6 +56,7 @@ class QZoneService:
self.content_service = content_service
self.image_service = image_service
self.cookie_service = cookie_service
self.reply_tracker = ReplyTrackerService()
# --- Public Methods (High-Level Business Logic) ---
@@ -154,7 +156,8 @@ class QZoneService:
# --- 第一步: 单独处理自己说说的评论 ---
if self.get_config("monitor.enable_auto_reply", False):
try:
own_feeds = await api_client["list_feeds"](qq_account, 5) # 获取自己最近5条说说
# 传入新参数,表明正在检查自己的说说
own_feeds = await api_client["list_feeds"](qq_account, 5, is_monitoring_own_feeds=True)
if own_feeds:
logger.info(f"获取到自己 {len(own_feeds)} 条说说,检查评论...")
for feed in own_feeds:
@@ -248,42 +251,83 @@ class QZoneService:
content = feed.get("content", "")
fid = feed.get("tid", "")
if not comments:
if not comments or not fid:
return
# 筛选出未被自己回复过的评论
if not comments:
# 1. 将评论分为用户评论和自己回复
user_comments = [c for c in comments if str(c.get('qq_account')) != str(qq_account)]
my_replies = [c for c in comments if str(c.get('qq_account')) == str(qq_account)]
if not user_comments:
return
# 找到所有我已经回复过的评论的ID
replied_to_tids = {
c['parent_tid'] for c in comments
if c.get('parent_tid') and str(c.get('qq_account')) == str(qq_account)
}
# 2. 验证已记录的回复是否仍然存在,清理已删除的回复记录
await self._validate_and_cleanup_reply_records(fid, my_replies)
# 找出所有非我发出且我未回复的评论
comments_to_reply = [
c for c in comments
if str(c.get('qq_account')) != str(qq_account) and c.get('comment_tid') not in replied_to_tids
]
# 3. 使用验证后的持久化记录来筛选未回复的评论
comments_to_reply = []
for comment in user_comments:
comment_tid = comment.get('comment_tid')
if not comment_tid:
continue
# 检查是否已经在持久化记录中标记为已回复
if not self.reply_tracker.has_replied(fid, comment_tid):
comments_to_reply.append(comment)
if not comments_to_reply:
logger.debug(f"说说 {fid} 下的所有评论都已回复过")
return
logger.info(f"发现自己说说下的 {len(comments_to_reply)} 条新评论,准备回复...")
for comment in comments_to_reply:
reply_content = await self.content_service.generate_comment_reply(
content, comment.get("content", ""), comment.get("nickname", "")
)
if reply_content:
success = await api_client["reply"](
fid, qq_account, comment.get("nickname", ""), reply_content, comment.get("comment_tid")
comment_tid = comment.get("comment_tid")
nickname = comment.get("nickname", "")
comment_content = comment.get("content", "")
try:
reply_content = await self.content_service.generate_comment_reply(
content, comment_content, nickname
)
if success:
logger.info(f"成功回复'{comment.get('nickname', '')}'的评论: '{reply_content}'")
if reply_content:
success = await api_client["reply"](
fid, qq_account, nickname, reply_content, comment_tid
)
if success:
# 标记为已回复
self.reply_tracker.mark_as_replied(fid, comment_tid)
logger.info(f"成功回复'{nickname}'的评论: '{reply_content}'")
else:
logger.error(f"回复'{nickname}'的评论失败")
await asyncio.sleep(random.uniform(10, 20))
else:
logger.error(f"回复'{comment.get('nickname', '')}'的评论失败")
await asyncio.sleep(random.uniform(10, 20))
logger.warning(f"生成回复内容失败,跳过回复'{nickname}'的评论")
except Exception as e:
logger.error(f"回复'{nickname}'的评论时发生异常: {e}", exc_info=True)
async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: List[Dict]):
"""验证并清理已删除的回复记录"""
# 获取当前记录中该说说的所有已回复评论ID
recorded_replied_comments = self.reply_tracker.get_replied_comments(fid)
if not recorded_replied_comments:
return
# 从API返回的我的回复中提取parent_tid即被回复的评论ID
current_replied_comments = set()
for reply in my_replies:
parent_tid = reply.get('parent_tid')
if parent_tid:
current_replied_comments.add(parent_tid)
# 找出记录中有但实际已不存在的回复
deleted_replies = recorded_replied_comments - current_replied_comments
if deleted_replies:
logger.info(f"检测到 {len(deleted_replies)} 个回复已被删除,清理记录...")
for comment_tid in deleted_replies:
self.reply_tracker.remove_reply_record(fid, comment_tid)
logger.debug(f"已清理删除的回复记录: feed_id={fid}, comment_id={comment_tid}")
async def _process_single_feed(self, feed: Dict, api_client: Dict, target_qq: str, target_name: str):
"""处理单条说说,决定是否评论和点赞"""
@@ -641,7 +685,7 @@ class QZoneService:
logger.error(f"上传图片 {index+1} 异常: {e}", exc_info=True)
return None
async def _list_feeds(t_qq: str, num: int) -> List[Dict]:
async def _list_feeds(t_qq: str, num: int, is_monitoring_own_feeds: bool = False) -> List[Dict]:
"""获取指定用户说说列表"""
try:
params = {
@@ -667,37 +711,41 @@ class QZoneService:
feeds_list = []
my_name = json_data.get("logininfo", {}).get("name", "")
for msg in json_data.get("msglist", []):
is_commented = any(
c.get("name") == my_name for c in msg.get("commentlist", []) if isinstance(c, dict)
)
if not is_commented:
images = [pic['url1'] for pic in msg.get('pictotal', []) if 'url1' in pic]
comments = []
if 'commentlist' in msg:
for c in msg['commentlist']:
comments.append({
'qq_account': c.get('uin'),
'nickname': c.get('name'),
'content': c.get('content'),
'comment_tid': c.get('tid'),
'parent_tid': c.get('parent_tid') # API直接返回了父ID
})
feeds_list.append(
{
"tid": msg.get("tid", ""),
"content": msg.get("content", ""),
"created_time": time.strftime(
"%Y-%m-%d %H:%M:%S", time.localtime(msg.get("created_time", 0))
),
"rt_con": msg.get("rt_con", {}).get("content", "")
if isinstance(msg.get("rt_con"), dict)
else "",
"images": images,
"comments": comments
}
# 只有在处理好友说说时,才检查是否已评论并跳过
if not is_monitoring_own_feeds:
is_commented = any(
c.get("name") == my_name for c in msg.get("commentlist", []) if isinstance(c, dict)
)
if is_commented:
continue
images = [pic['url1'] for pic in msg.get('pictotal', []) if 'url1' in pic]
comments = []
if 'commentlist' in msg:
for c in msg['commentlist']:
comments.append({
'qq_account': c.get('uin'),
'nickname': c.get('name'),
'content': c.get('content'),
'comment_tid': c.get('tid'),
'parent_tid': c.get('parent_tid') # API直接返回了父ID
})
feeds_list.append(
{
"tid": msg.get("tid", ""),
"content": msg.get("content", ""),
"created_time": time.strftime(
"%Y-%m-%d %H:%M:%S", time.localtime(msg.get("created_time", 0))
),
"rt_con": msg.get("rt_con", {}).get("content", "")
if isinstance(msg.get("rt_con"), dict)
else "",
"images": images,
"comments": comments
}
)
return feeds_list
except Exception as e:
logger.error(f"获取说说列表失败: {e}", exc_info=True)

View File

@@ -0,0 +1,195 @@
# -*- coding: utf-8 -*-
"""
评论回复跟踪服务
负责记录和管理已回复过的评论ID避免重复回复
"""
import json
import time
from pathlib import Path
from typing import Set, Dict, Any
from src.common.logger import get_logger
# Logger used by the reply tracker code in this module.
logger = get_logger("MaiZone.ReplyTrackerService")
class ReplyTrackerService:
    """
    Track which comments have already been replied to.

    Records are kept in memory as ``{feed_id: {comment_id: timestamp}}`` and
    persisted to a local JSON file so they survive restarts. All IDs are
    stored as strings: JSON object keys are always strings, so normalizing on
    the way in keeps lookups consistent before and after a reload.
    """

    def __init__(self):
        """Prepare the storage location and load any previously saved records."""
        # Storage lives in `<directory containing this file>/../data`.
        self.data_dir = Path(__file__).resolve().parent.parent / "data"
        # parents=True so a missing intermediate directory does not abort start-up.
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.reply_record_file = self.data_dir / "replied_comments.json"

        # In-memory record of replied comments.
        # Shape: {feed_id: {comment_id: timestamp, ...}, ...}
        self.replied_comments: Dict[str, Dict[str, float]] = {}

        # Retention policy: records older than this many days are dropped.
        self.max_record_days = 30

        # Load existing data, if any.
        self._load_data()

    @staticmethod
    def _normalize_records(data: Dict[Any, Any]) -> Dict[str, Dict[str, float]]:
        """Return a cleaned copy of *data* containing only well-formed entries.

        Feeds whose value is not a dict and comments whose timestamp is not a
        real number are discarded; feeds left without any comment are omitted.
        Keys are coerced to ``str``.
        """
        cleaned: Dict[str, Dict[str, float]] = {}
        for feed_id, comments in data.items():
            if not isinstance(comments, dict):
                continue
            valid = {
                str(comment_id): float(timestamp)
                for comment_id, timestamp in comments.items()
                # bool is a subclass of int but is never a valid timestamp.
                if isinstance(timestamp, (int, float)) and not isinstance(timestamp, bool)
            }
            if valid:
                cleaned[str(feed_id)] = valid
        return cleaned

    def _load_data(self):
        """Load replied-comment records from the JSON file.

        A missing file is not an error. Any failure (unreadable file, invalid
        JSON, unexpected top-level structure) is logged and leaves the service
        with an empty record set.
        """
        try:
            if not self.reply_record_file.exists():
                logger.info("未找到回复记录文件,将创建新的记录")
                return

            with open(self.reply_record_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Anything other than a JSON object would break later dict access.
            if not isinstance(data, dict):
                raise ValueError("回复记录文件的顶层结构不是字典")

            self.replied_comments = self._normalize_records(data)
            logger.info(f"已加载 {len(self.replied_comments)} 条说说的回复记录")
        except Exception as e:
            logger.error(f"加载回复记录失败: {e}")
            self.replied_comments = {}

    def _save_data(self):
        """Prune expired records and persist the remainder to the JSON file.

        The data is written to a sibling temporary file first and then moved
        over the real file, so an interrupted write cannot leave a truncated
        record file behind. Failures are logged, not raised.
        """
        try:
            # Drop expired entries before writing.
            self._cleanup_old_records()

            tmp_file = self.reply_record_file.with_name(self.reply_record_file.name + ".tmp")
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(self.replied_comments, f, ensure_ascii=False, indent=2)
            tmp_file.replace(self.reply_record_file)

            logger.debug("回复记录已保存")
        except Exception as e:
            logger.error(f"保存回复记录失败: {e}")

    def _cleanup_old_records(self):
        """Remove records older than ``max_record_days`` and feeds left empty."""
        cutoff_time = time.time() - (self.max_record_days * 24 * 60 * 60)
        total_removed = 0

        # Iterate over a snapshot of the keys because feeds may be deleted.
        for feed_id in list(self.replied_comments):
            comments = self.replied_comments[feed_id]
            expired = [cid for cid, timestamp in comments.items() if timestamp < cutoff_time]
            for cid in expired:
                del comments[cid]
            total_removed += len(expired)

            # A feed with no remaining comment records is removed entirely.
            if not comments:
                del self.replied_comments[feed_id]

        if total_removed > 0:
            logger.info(f"清理了 {total_removed} 条过期的回复记录")

    def has_replied(self, feed_id: str, comment_id: str) -> bool:
        """
        Check whether the given comment has already been replied to.

        Args:
            feed_id: ID of the feed.
            comment_id: ID of the comment.

        Returns:
            bool: True if a reply is recorded, False otherwise (including when
            either ID is empty).
        """
        if not feed_id or not comment_id:
            return False
        return str(comment_id) in self.replied_comments.get(str(feed_id), {})

    def mark_as_replied(self, feed_id: str, comment_id: str):
        """
        Record the given comment as replied and persist the change.

        Args:
            feed_id: ID of the feed.
            comment_id: ID of the comment.
        """
        if not feed_id or not comment_id:
            logger.warning("feed_id 或 comment_id 为空,无法标记为已回复")
            return

        # Store string keys so the entry matches itself after a JSON round trip.
        self.replied_comments.setdefault(str(feed_id), {})[str(comment_id)] = time.time()

        # Persist to file.
        self._save_data()
        logger.info(f"已标记评论为已回复: feed_id={feed_id}, comment_id={comment_id}")

    def get_replied_comments(self, feed_id: str) -> Set[str]:
        """
        Return the IDs of all replied comments under the given feed.

        Args:
            feed_id: ID of the feed.

        Returns:
            Set[str]: Replied comment IDs; empty if the feed has no records.
        """
        return set(self.replied_comments.get(str(feed_id), {}))

    def get_stats(self) -> Dict[str, Any]:
        """
        Return summary statistics about the stored reply records.

        Returns:
            Dict: Number of feeds with records, total number of replied
            comments, the data file path, and the retention period in days.
        """
        total_feeds = len(self.replied_comments)
        total_replies = sum(len(comments) for comments in self.replied_comments.values())
        return {
            "total_feeds_with_replies": total_feeds,
            "total_replied_comments": total_replies,
            "data_file": str(self.reply_record_file),
            "max_record_days": self.max_record_days
        }

    def remove_reply_record(self, feed_id: str, comment_id: str):
        """
        Remove the reply record of a single comment, if present.

        Args:
            feed_id: ID of the feed.
            comment_id: ID of the comment.
        """
        feed_key, comment_key = str(feed_id), str(comment_id)
        comments = self.replied_comments.get(feed_key)
        if comments is None or comment_key not in comments:
            return

        del comments[comment_key]
        # Drop the feed entry once its last comment record is gone.
        if not comments:
            del self.replied_comments[feed_key]

        self._save_data()
        logger.debug(f"已移除回复记录: feed_id={feed_id}, comment_id={comment_id}")

    def remove_feed_records(self, feed_id: str):
        """
        Remove every reply record stored for the given feed, if any.

        Args:
            feed_id: ID of the feed.
        """
        feed_key = str(feed_id)
        if feed_key in self.replied_comments:
            del self.replied_comments[feed_key]
            self._save_data()
            logger.info(f"已移除说说 {feed_id} 的所有回复记录")

View File

@@ -16,7 +16,7 @@ from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.apis.logging_api import get_logger
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
from src.plugin_system.base.config_types import ConfigField
from src.plugin_system.utils.permission_decorators import require_permission, require_master, PermissionChecker
from src.plugin_system.utils.permission_decorators import require_permission
logger = get_logger("Permission")

View File

@@ -411,7 +411,6 @@ class ScheduleManager:
通过关键词匹配、唤醒度、睡眠压力等综合判断是否处于休眠时间。
新增弹性睡眠机制,允许在压力低时延迟入睡,并在入睡前发送通知。
"""
from src.chat.chat_loop.wakeup_manager import WakeUpManager
# --- 基础检查 ---
if not global_config.schedule.enable_is_sleep:
return False