Merge branch 'master' of https://github.com/MoFox-Studio/MoFox_Bot
This commit is contained in:
@@ -10,7 +10,6 @@ from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.plugin_system.apis import message_api
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
from .hfc_context import HfcContext
|
||||
from .energy_manager import EnergyManager
|
||||
|
||||
@@ -4,7 +4,7 @@ import hashlib
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import faiss
|
||||
from typing import Any, Dict, Optional, Union, List
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -14,6 +14,7 @@ from src.common.vector_db import vector_db_service
|
||||
|
||||
logger = get_logger("cache_manager")
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""
|
||||
一个支持分层和语义缓存的通用工具缓存管理器。
|
||||
@@ -21,6 +22,7 @@ class CacheManager:
|
||||
L1缓存: 内存字典 (KV) + FAISS (Vector)。
|
||||
L2缓存: 数据库 (KV) + ChromaDB (Vector)。
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
@@ -32,7 +34,7 @@ class CacheManager:
|
||||
"""
|
||||
初始化缓存管理器。
|
||||
"""
|
||||
if not hasattr(self, '_initialized'):
|
||||
if not hasattr(self, "_initialized"):
|
||||
self.default_ttl = default_ttl
|
||||
self.semantic_cache_collection_name = "semantic_cache"
|
||||
|
||||
@@ -41,7 +43,7 @@ class CacheManager:
|
||||
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.l1_vector_id_to_key: Dict[int, str] = {}
|
||||
|
||||
|
||||
# L2 向量缓存 (使用新的服务)
|
||||
vector_db_service.get_or_create_collection(self.semantic_cache_collection_name)
|
||||
|
||||
@@ -58,32 +60,32 @@ class CacheManager:
|
||||
try:
|
||||
if embedding_result is None:
|
||||
return None
|
||||
|
||||
|
||||
# 确保embedding_result是一维数组或列表
|
||||
if isinstance(embedding_result, (list, tuple, np.ndarray)):
|
||||
# 转换为numpy数组进行处理
|
||||
embedding_array = np.array(embedding_result)
|
||||
|
||||
|
||||
# 如果是多维数组,展平它
|
||||
if embedding_array.ndim > 1:
|
||||
embedding_array = embedding_array.flatten()
|
||||
|
||||
|
||||
# 检查维度是否符合预期
|
||||
expected_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
if embedding_array.shape[0] != expected_dim:
|
||||
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
|
||||
return None
|
||||
|
||||
|
||||
# 检查是否包含有效的数值
|
||||
if np.isnan(embedding_array).any() or np.isinf(embedding_array).any():
|
||||
logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)")
|
||||
return None
|
||||
|
||||
return embedding_array.astype('float32')
|
||||
|
||||
return embedding_array.astype("float32")
|
||||
else:
|
||||
logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证嵌入向量时发生错误: {e}")
|
||||
return None
|
||||
@@ -102,14 +104,20 @@ class CacheManager:
|
||||
except (OSError, TypeError) as e:
|
||||
file_hash = "unknown"
|
||||
logger.warning(f"无法获取文件信息: {tool_file_path},错误: {e}")
|
||||
|
||||
|
||||
try:
|
||||
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
|
||||
sorted_args = orjson.dumps(function_args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
||||
except TypeError:
|
||||
sorted_args = repr(sorted(function_args.items()))
|
||||
return f"{tool_name}::{sorted_args}::{file_hash}"
|
||||
|
||||
async def get(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], semantic_query: Optional[str] = None) -> Optional[Any]:
|
||||
async def get(
|
||||
self,
|
||||
tool_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
tool_file_path: Union[str, Path],
|
||||
semantic_query: Optional[str] = None,
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
|
||||
"""
|
||||
@@ -136,13 +144,13 @@ class CacheManager:
|
||||
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
|
||||
validated_embedding = self._validate_embedding(embedding_vector)
|
||||
if validated_embedding is not None:
|
||||
query_embedding = np.array([validated_embedding], dtype='float32')
|
||||
query_embedding = np.array([validated_embedding], dtype="float32")
|
||||
|
||||
# 步骤 2a: L1 语义缓存 (FAISS)
|
||||
if query_embedding is not None and self.l1_vector_index.ntotal > 0:
|
||||
faiss.normalize_L2(query_embedding)
|
||||
distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore
|
||||
if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似
|
||||
distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore
|
||||
if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似
|
||||
hit_index = indices[0][0]
|
||||
l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
|
||||
if l1_hit_key and l1_hit_key in self.l1_kv_cache:
|
||||
@@ -151,12 +159,9 @@ class CacheManager:
|
||||
|
||||
# 步骤 2b: L2 精确缓存 (数据库)
|
||||
cache_results_obj = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": key},
|
||||
single_result=True
|
||||
model_class=CacheEntries, query_type="get", filters={"cache_key": key}, single_result=True
|
||||
)
|
||||
|
||||
|
||||
if cache_results_obj:
|
||||
# 使用 getattr 安全访问属性,避免 Pylance 类型检查错误
|
||||
expires_at = getattr(cache_results_obj, "expires_at", 0)
|
||||
@@ -164,7 +169,7 @@ class CacheManager:
|
||||
logger.info(f"命中L2键值缓存: {key}")
|
||||
cache_value = getattr(cache_results_obj, "cache_value", "{}")
|
||||
data = orjson.loads(cache_value)
|
||||
|
||||
|
||||
# 更新访问统计
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
@@ -172,20 +177,16 @@ class CacheManager:
|
||||
filters={"cache_key": key},
|
||||
data={
|
||||
"last_accessed": time.time(),
|
||||
"access_count": getattr(cache_results_obj, "access_count", 0) + 1
|
||||
}
|
||||
"access_count": getattr(cache_results_obj, "access_count", 0) + 1,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# 回填 L1
|
||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||
return data
|
||||
else:
|
||||
# 删除过期的缓存条目
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="delete",
|
||||
filters={"cache_key": key}
|
||||
)
|
||||
await db_query(model_class=CacheEntries, query_type="delete", filters={"cache_key": key})
|
||||
|
||||
# 步骤 2c: L2 语义缓存 (VectorDB Service)
|
||||
if query_embedding is not None:
|
||||
@@ -193,31 +194,33 @@ class CacheManager:
|
||||
results = vector_db_service.query(
|
||||
collection_name=self.semantic_cache_collection_name,
|
||||
query_embeddings=query_embedding.tolist(),
|
||||
n_results=1
|
||||
n_results=1,
|
||||
)
|
||||
if results and results.get('ids') and results['ids'][0]:
|
||||
distance = results['distances'][0][0] if results.get('distances') and results['distances'][0] else 'N/A'
|
||||
if results and results.get("ids") and results["ids"][0]:
|
||||
distance = (
|
||||
results["distances"][0][0] if results.get("distances") and results["distances"][0] else "N/A"
|
||||
)
|
||||
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
|
||||
|
||||
if distance != 'N/A' and distance < 0.75:
|
||||
l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
|
||||
|
||||
if distance != "N/A" and distance < 0.75:
|
||||
l2_hit_key = results["ids"][0][0] if isinstance(results["ids"][0], list) else results["ids"][0]
|
||||
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
|
||||
|
||||
|
||||
# 从数据库获取缓存数据
|
||||
semantic_cache_results_obj = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": l2_hit_key},
|
||||
single_result=True
|
||||
single_result=True,
|
||||
)
|
||||
|
||||
|
||||
if semantic_cache_results_obj:
|
||||
expires_at = getattr(semantic_cache_results_obj, "expires_at", 0)
|
||||
if time.time() < expires_at:
|
||||
cache_value = getattr(semantic_cache_results_obj, "cache_value", "{}")
|
||||
data = orjson.loads(cache_value)
|
||||
logger.debug(f"L2语义缓存返回的数据: {data}")
|
||||
|
||||
|
||||
# 回填 L1
|
||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||
if query_embedding is not None:
|
||||
@@ -235,7 +238,15 @@ class CacheManager:
|
||||
logger.debug(f"缓存未命中: {key}")
|
||||
return None
|
||||
|
||||
async def set(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path], data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
|
||||
async def set(
|
||||
self,
|
||||
tool_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
tool_file_path: Union[str, Path],
|
||||
data: Any,
|
||||
ttl: Optional[int] = None,
|
||||
semantic_query: Optional[str] = None,
|
||||
):
|
||||
"""将结果存入所有缓存层。"""
|
||||
if ttl is None:
|
||||
ttl = self.default_ttl
|
||||
@@ -244,27 +255,22 @@ class CacheManager:
|
||||
|
||||
key = self._generate_key(tool_name, function_args, tool_file_path)
|
||||
expires_at = time.time() + ttl
|
||||
|
||||
|
||||
# 写入 L1
|
||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||
|
||||
# 写入 L2 (数据库)
|
||||
cache_data = {
|
||||
"cache_key": key,
|
||||
"cache_value": orjson.dumps(data).decode('utf-8'),
|
||||
"cache_value": orjson.dumps(data).decode("utf-8"),
|
||||
"expires_at": expires_at,
|
||||
"tool_name": tool_name,
|
||||
"created_at": time.time(),
|
||||
"last_accessed": time.time(),
|
||||
"access_count": 1
|
||||
"access_count": 1,
|
||||
}
|
||||
|
||||
await db_save(
|
||||
model_class=CacheEntries,
|
||||
data=cache_data,
|
||||
key_field="cache_key",
|
||||
key_value=key
|
||||
)
|
||||
|
||||
await db_save(model_class=CacheEntries, data=cache_data, key_field="cache_key", key_value=key)
|
||||
|
||||
# 写入语义缓存
|
||||
if semantic_query and self.embedding_model:
|
||||
@@ -274,19 +280,19 @@ class CacheManager:
|
||||
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
|
||||
validated_embedding = self._validate_embedding(embedding_vector)
|
||||
if validated_embedding is not None:
|
||||
embedding = np.array([validated_embedding], dtype='float32')
|
||||
|
||||
embedding = np.array([validated_embedding], dtype="float32")
|
||||
|
||||
# 写入 L1 Vector
|
||||
new_id = self.l1_vector_index.ntotal
|
||||
faiss.normalize_L2(embedding)
|
||||
self.l1_vector_index.add(x=embedding) # type: ignore
|
||||
self.l1_vector_id_to_key[new_id] = key
|
||||
|
||||
|
||||
# 写入 L2 Vector (使用新的服务)
|
||||
vector_db_service.add(
|
||||
collection_name=self.semantic_cache_collection_name,
|
||||
embeddings=embedding.tolist(),
|
||||
ids=[key]
|
||||
ids=[key],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"语义缓存写入失败: {e}")
|
||||
@@ -306,16 +312,16 @@ class CacheManager:
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="delete",
|
||||
filters={} # 删除所有记录
|
||||
filters={}, # 删除所有记录
|
||||
)
|
||||
|
||||
|
||||
# 清空 VectorDB
|
||||
try:
|
||||
vector_db_service.delete_collection(name=self.semantic_cache_collection_name)
|
||||
vector_db_service.get_or_create_collection(name=self.semantic_cache_collection_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"清空 VectorDB 集合失败: {e}")
|
||||
|
||||
|
||||
logger.info("L2 (数据库 & VectorDB) 缓存已清空。")
|
||||
|
||||
async def clear_all(self):
|
||||
@@ -327,85 +333,23 @@ class CacheManager:
|
||||
async def clean_expired(self):
|
||||
"""清理过期的缓存条目"""
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 清理L1过期条目
|
||||
expired_keys = []
|
||||
for key, entry in self.l1_kv_cache.items():
|
||||
if current_time >= entry["expires_at"]:
|
||||
expired_keys.append(key)
|
||||
|
||||
|
||||
for key in expired_keys:
|
||||
del self.l1_kv_cache[key]
|
||||
|
||||
|
||||
# 清理L2过期条目
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="delete",
|
||||
filters={"expires_at": {"$lt": current_time}}
|
||||
)
|
||||
|
||||
await db_query(model_class=CacheEntries, query_type="delete", filters={"expires_at": {"$lt": current_time}})
|
||||
|
||||
if expired_keys:
|
||||
logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目")
|
||||
|
||||
|
||||
# 全局实例
|
||||
tool_cache = CacheManager()
|
||||
|
||||
import inspect
|
||||
import time
|
||||
|
||||
def wrap_tool_executor():
|
||||
"""
|
||||
包装工具执行器以添加缓存功能
|
||||
这个函数应该在系统启动时被调用一次
|
||||
"""
|
||||
from src.plugin_system.core.tool_use import ToolExecutor
|
||||
from src.plugin_system.apis.tool_api import get_tool_instance
|
||||
original_execute = ToolExecutor.execute_tool_call
|
||||
|
||||
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
|
||||
if not tool_instance:
|
||||
tool_instance = get_tool_instance(tool_call.func_name)
|
||||
|
||||
if not tool_instance or not tool_instance.enable_cache:
|
||||
return await original_execute(self, tool_call, tool_instance)
|
||||
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
cached_result = await tool_cache.get(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
if cached_result:
|
||||
logger.info(f"{getattr(self, 'log_prefix', '')}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
return cached_result
|
||||
except Exception as e:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')}检查工具缓存时出错: {e}")
|
||||
|
||||
result = await original_execute(self, tool_call, tool_instance)
|
||||
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=tool_instance.cache_ttl,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')}设置工具缓存时出错: {e}")
|
||||
|
||||
return result
|
||||
|
||||
ToolExecutor.execute_tool_call = wrapped_execute_tool_call
|
||||
@@ -5,7 +5,7 @@ import random
|
||||
|
||||
from enum import Enum
|
||||
from rich.traceback import install
|
||||
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine
|
||||
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
@@ -283,131 +283,130 @@ class LLMRequest:
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
raise_when_empty: bool = True,
|
||||
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
|
||||
"""执行单次请求"""
|
||||
# 模型选择和请求准备
|
||||
start_time = time.time()
|
||||
model_info, api_provider, client = self._select_model()
|
||||
model_name = model_info.name
|
||||
|
||||
# 检查是否启用反截断
|
||||
use_anti_truncation = getattr(api_provider, "anti_truncation", False)
|
||||
|
||||
processed_prompt = prompt
|
||||
if use_anti_truncation:
|
||||
processed_prompt += self.anti_truncation_instruction
|
||||
logger.info(f"{api_provider} '{self.task_name}' 已启用反截断功能")
|
||||
|
||||
processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
|
||||
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(processed_prompt)
|
||||
messages = [message_builder.build()]
|
||||
tool_built = self._build_tool_options(tools)
|
||||
|
||||
# 空回复重试逻辑
|
||||
empty_retry_count = 0
|
||||
max_empty_retry = api_provider.max_retry
|
||||
empty_retry_interval = api_provider.retry_interval
|
||||
|
||||
while empty_retry_count <= max_empty_retry:
|
||||
"""
|
||||
执行单次请求,并在模型失败时按顺序切换到下一个可用模型。
|
||||
"""
|
||||
failed_models = set()
|
||||
last_exception: Optional[Exception] = None
|
||||
|
||||
model_scheduler = self._model_scheduler(failed_models)
|
||||
|
||||
for model_info, api_provider, client in model_scheduler:
|
||||
start_time = time.time()
|
||||
model_name = model_info.name
|
||||
logger.info(f"正在尝试使用模型: {model_name}")
|
||||
|
||||
try:
|
||||
response = await self._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
tool_options=tool_built,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
content = response.content or ""
|
||||
reasoning_content = response.reasoning_content or ""
|
||||
tool_calls = response.tool_calls
|
||||
# 从内容中提取<think>标签的推理内容(向后兼容)
|
||||
if not reasoning_content and content:
|
||||
content, extracted_reasoning = self._extract_reasoning(content)
|
||||
reasoning_content = extracted_reasoning
|
||||
|
||||
is_empty_reply = False
|
||||
is_truncated = False
|
||||
# 检测是否为空回复或截断
|
||||
if not tool_calls:
|
||||
is_empty_reply = not content or content.strip() == ""
|
||||
is_truncated = False
|
||||
|
||||
# 检查是否启用反截断
|
||||
use_anti_truncation = getattr(api_provider, "anti_truncation", False)
|
||||
processed_prompt = prompt
|
||||
if use_anti_truncation:
|
||||
if content.endswith("[done]"):
|
||||
content = content[:-6].strip()
|
||||
logger.debug("检测到并已移除 [done] 标记")
|
||||
else:
|
||||
is_truncated = True
|
||||
logger.warning("未检测到 [done] 标记,判定为截断")
|
||||
processed_prompt += self.anti_truncation_instruction
|
||||
logger.info(f"'{model_name}' for task '{self.task_name}' 已启用反截断功能")
|
||||
|
||||
if is_empty_reply or is_truncated:
|
||||
if empty_retry_count < max_empty_retry:
|
||||
empty_retry_count += 1
|
||||
reason = "空回复" if is_empty_reply else "截断"
|
||||
logger.warning(f"检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成")
|
||||
processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
|
||||
|
||||
if empty_retry_interval > 0:
|
||||
await asyncio.sleep(empty_retry_interval)
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(processed_prompt)
|
||||
messages = [message_builder.build()]
|
||||
tool_built = self._build_tool_options(tools)
|
||||
|
||||
model_info, api_provider, client = self._select_model()
|
||||
continue
|
||||
else:
|
||||
# 已达到最大重试次数,但仍然是空回复或截断
|
||||
reason = "空回复" if is_empty_reply else "截断"
|
||||
# 抛出异常,由外层重试逻辑或最终的异常处理器捕获
|
||||
raise RuntimeError(f"经过 {max_empty_retry + 1} 次尝试后仍然是{reason}的回复")
|
||||
# 针对当前模型的空回复/截断重试逻辑
|
||||
empty_retry_count = 0
|
||||
max_empty_retry = api_provider.max_retry
|
||||
empty_retry_interval = api_provider.retry_interval
|
||||
|
||||
# 记录使用情况
|
||||
if usage := response.usage:
|
||||
llm_usage_recorder.record_usage_to_database(
|
||||
while empty_retry_count <= max_empty_retry:
|
||||
response = await self._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
model_usage=usage,
|
||||
time_cost=time.time() - start_time,
|
||||
user_id="system",
|
||||
request_type=self.request_type,
|
||||
endpoint="/chat/completions",
|
||||
message_list=messages,
|
||||
tool_options=tool_built,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# 处理空回复
|
||||
if not content and not tool_calls:
|
||||
if raise_when_empty:
|
||||
raise RuntimeError(f"经过 {empty_retry_count} 次重试后仍然生成空回复")
|
||||
content = "生成的响应为空,请检查模型配置或输入内容是否正确"
|
||||
elif empty_retry_count > 0:
|
||||
logger.info(f"经过 {empty_retry_count} 次重试后成功生成回复")
|
||||
content = response.content or ""
|
||||
reasoning_content = response.reasoning_content or ""
|
||||
tool_calls = response.tool_calls
|
||||
|
||||
return content, (reasoning_content, model_info.name, tool_calls)
|
||||
if not reasoning_content and content:
|
||||
content, extracted_reasoning = self._extract_reasoning(content)
|
||||
reasoning_content = extracted_reasoning
|
||||
|
||||
is_empty_reply = not tool_calls and (not content or content.strip() == "")
|
||||
is_truncated = False
|
||||
if use_anti_truncation:
|
||||
if content.endswith("[done]"):
|
||||
content = content[:-6].strip()
|
||||
else:
|
||||
is_truncated = True
|
||||
|
||||
if is_empty_reply or is_truncated:
|
||||
empty_retry_count += 1
|
||||
if empty_retry_count <= max_empty_retry:
|
||||
reason = "空回复" if is_empty_reply else "截断"
|
||||
logger.warning(f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成...")
|
||||
if empty_retry_interval > 0:
|
||||
await asyncio.sleep(empty_retry_interval)
|
||||
continue # 继续使用当前模型重试
|
||||
else:
|
||||
# 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型
|
||||
reason = "空回复" if is_empty_reply else "截断"
|
||||
logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。")
|
||||
raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数")
|
||||
|
||||
# 成功获取响应
|
||||
if usage := response.usage:
|
||||
llm_usage_recorder.record_usage_to_database(
|
||||
model_info=model_info, model_usage=usage, time_cost=time.time() - start_time,
|
||||
user_id="system", request_type=self.request_type, endpoint="/chat/completions",
|
||||
)
|
||||
|
||||
if not content and not tool_calls:
|
||||
if raise_when_empty:
|
||||
raise RuntimeError("生成空回复")
|
||||
content = "生成的响应为空"
|
||||
|
||||
logger.info(f"模型 '{model_name}' 成功生成回复。")
|
||||
return content, (reasoning_content, model_name, tool_calls)
|
||||
|
||||
except RespNotOkException as e:
|
||||
if e.status_code in [401, 403]:
|
||||
logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。")
|
||||
failed_models.add(model_name)
|
||||
last_exception = e
|
||||
continue # 切换到下一个模型
|
||||
else:
|
||||
logger.error(f"模型 '{model_name}' 请求失败,HTTP状态码: {e.status_code}")
|
||||
if raise_when_empty:
|
||||
raise
|
||||
# 对于其他HTTP错误,直接抛出,不再尝试其他模型
|
||||
return f"请求失败: {e}", ("", model_name, None)
|
||||
|
||||
except RuntimeError as e:
|
||||
# 捕获所有重试失败(包括空回复和网络问题)
|
||||
logger.error(f"模型 '{model_name}' 在所有重试后仍然失败: {e},将尝试下一个模型。")
|
||||
failed_models.add(model_name)
|
||||
last_exception = e
|
||||
continue # 切换到下一个模型
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"请求执行失败: {e}")
|
||||
if raise_when_empty:
|
||||
# 在非并发模式下,如果第一次尝试就失败,则直接抛出异常
|
||||
if empty_retry_count == 0:
|
||||
raise
|
||||
logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}")
|
||||
failed_models.add(model_name)
|
||||
last_exception = e
|
||||
continue # 切换到下一个模型
|
||||
|
||||
# 如果在重试过程中失败,则继续重试
|
||||
empty_retry_count += 1
|
||||
if empty_retry_count <= max_empty_retry:
|
||||
logger.warning(f"请求失败,将在 {empty_retry_interval} 秒后进行第 {empty_retry_count}/{max_empty_retry} 次重试...")
|
||||
if empty_retry_interval > 0:
|
||||
await asyncio.sleep(empty_retry_interval)
|
||||
continue
|
||||
else:
|
||||
logger.error(f"经过 {max_empty_retry} 次重试后仍然失败")
|
||||
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复") from e
|
||||
else:
|
||||
# 在并发模式下,单个请求的失败不应中断整个并发流程,
|
||||
# 而是将异常返回给调用者(即 execute_concurrently)进行统一处理
|
||||
raise # 重新抛出异常,由 execute_concurrently 中的 gather 捕获
|
||||
|
||||
# 重试失败
|
||||
# 所有模型都尝试失败
|
||||
logger.error("所有可用模型都已尝试失败。")
|
||||
if raise_when_empty:
|
||||
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复")
|
||||
return "生成的响应为空,请检查模型配置或输入内容是否正确", ("", model_name, None)
|
||||
if last_exception:
|
||||
raise RuntimeError("所有模型都请求失败") from last_exception
|
||||
raise RuntimeError("所有模型都请求失败,且没有具体的异常信息")
|
||||
|
||||
return "所有模型都请求失败", ("", "unknown", None)
|
||||
|
||||
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
|
||||
"""获取嵌入向量
|
||||
@@ -446,9 +445,24 @@ class LLMRequest:
|
||||
|
||||
return embedding, model_info.name
|
||||
|
||||
def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
|
||||
"""
|
||||
一个模型调度器,按顺序提供模型,并跳过已失败的模型。
|
||||
"""
|
||||
for model_name in self.model_for_task.model_list:
|
||||
if model_name in failed_models:
|
||||
continue
|
||||
|
||||
model_info = model_config.get_model_info(model_name)
|
||||
api_provider = model_config.get_provider(model_info.api_provider)
|
||||
force_new_client = (self.request_type == "embedding")
|
||||
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
|
||||
|
||||
yield model_info, api_provider, client
|
||||
|
||||
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
|
||||
"""
|
||||
根据总tokens和惩罚值选择的模型
|
||||
根据总tokens和惩罚值选择的模型 (负载均衡)
|
||||
"""
|
||||
least_used_model_name = min(
|
||||
self.model_usage,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import Optional, Type
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Type, Tuple, Union, TYPE_CHECKING
|
||||
from typing import List, Type, Tuple, Union
|
||||
from .plugin_base import PluginBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Tuple, Optional, List
|
||||
from typing import Tuple, Optional, List
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -7,8 +7,10 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
import inspect
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
@@ -184,21 +186,65 @@ class ToolExecutor:
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行单个工具调用
|
||||
"""执行单个工具调用,并处理缓存"""
|
||||
|
||||
function_args = tool_call.args or {}
|
||||
tool_instance = tool_instance or get_tool_instance(tool_call.func_name)
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
# 如果工具不存在或未启用缓存,则直接执行
|
||||
if not tool_instance or not tool_instance.enable_cache:
|
||||
return await self._original_execute_tool_call(tool_call, tool_instance)
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 工具调用结果,如果失败则返回None
|
||||
"""
|
||||
# --- 缓存逻辑开始 ---
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
cached_result = await tool_cache.get(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=function_args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
if cached_result:
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
return cached_result
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
|
||||
|
||||
# 缓存未命中,执行原始工具调用
|
||||
result = await self._original_execute_tool_call(tool_call, tool_instance)
|
||||
|
||||
# 将结果存入缓存
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=function_args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=tool_instance.cache_ttl,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
|
||||
# --- 缓存逻辑结束 ---
|
||||
|
||||
return result
|
||||
|
||||
async def _original_execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
|
||||
"""执行单个工具调用的原始逻辑"""
|
||||
try:
|
||||
function_name = tool_call.func_name
|
||||
function_args = tool_call.args or {}
|
||||
logger.info(f"🤖 {self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
|
||||
logger.info(f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = tool_instance or get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
|
||||
@@ -24,6 +24,7 @@ from .services.qzone_service import QZoneService
|
||||
from .services.scheduler_service import SchedulerService
|
||||
from .services.monitor_service import MonitorService
|
||||
from .services.cookie_service import CookieService
|
||||
from .services.reply_tracker_service import ReplyTrackerService
|
||||
from .services.manager import register_service
|
||||
|
||||
logger = get_logger("MaiZone.Plugin")
|
||||
@@ -92,11 +93,13 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
content_service = ContentService(self.get_config)
|
||||
image_service = ImageService(self.get_config)
|
||||
cookie_service = CookieService(self.get_config)
|
||||
reply_tracker_service = ReplyTrackerService()
|
||||
qzone_service = QZoneService(self.get_config, content_service, image_service, cookie_service)
|
||||
scheduler_service = SchedulerService(self.get_config, qzone_service)
|
||||
monitor_service = MonitorService(self.get_config, qzone_service)
|
||||
|
||||
register_service("qzone", qzone_service)
|
||||
register_service("reply_tracker", reply_tracker_service)
|
||||
register_service("get_config", self.get_config)
|
||||
|
||||
asyncio.create_task(scheduler_service.start())
|
||||
|
||||
@@ -9,12 +9,9 @@ import datetime
|
||||
import base64
|
||||
import aiohttp
|
||||
from src.common.logger import get_logger
|
||||
import base64
|
||||
import aiohttp
|
||||
import imghdr
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import llm_api, config_api, generator_api, person_api
|
||||
from src.plugin_system.apis import llm_api, config_api, generator_api
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from maim_message import UserInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
@@ -27,6 +27,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
from .content_service import ContentService
|
||||
from .image_service import ImageService
|
||||
from .cookie_service import CookieService
|
||||
from .reply_tracker_service import ReplyTrackerService
|
||||
|
||||
logger = get_logger("MaiZone.QZoneService")
|
||||
|
||||
@@ -55,6 +56,7 @@ class QZoneService:
|
||||
self.content_service = content_service
|
||||
self.image_service = image_service
|
||||
self.cookie_service = cookie_service
|
||||
self.reply_tracker = ReplyTrackerService()
|
||||
|
||||
# --- Public Methods (High-Level Business Logic) ---
|
||||
|
||||
@@ -154,7 +156,8 @@ class QZoneService:
|
||||
# --- 第一步: 单独处理自己说说的评论 ---
|
||||
if self.get_config("monitor.enable_auto_reply", False):
|
||||
try:
|
||||
own_feeds = await api_client["list_feeds"](qq_account, 5) # 获取自己最近5条说说
|
||||
# 传入新参数,表明正在检查自己的说说
|
||||
own_feeds = await api_client["list_feeds"](qq_account, 5, is_monitoring_own_feeds=True)
|
||||
if own_feeds:
|
||||
logger.info(f"获取到自己 {len(own_feeds)} 条说说,检查评论...")
|
||||
for feed in own_feeds:
|
||||
@@ -248,42 +251,83 @@ class QZoneService:
|
||||
content = feed.get("content", "")
|
||||
fid = feed.get("tid", "")
|
||||
|
||||
if not comments:
|
||||
if not comments or not fid:
|
||||
return
|
||||
|
||||
# 筛选出未被自己回复过的评论
|
||||
if not comments:
|
||||
# 1. 将评论分为用户评论和自己的回复
|
||||
user_comments = [c for c in comments if str(c.get('qq_account')) != str(qq_account)]
|
||||
my_replies = [c for c in comments if str(c.get('qq_account')) == str(qq_account)]
|
||||
|
||||
if not user_comments:
|
||||
return
|
||||
|
||||
# 找到所有我已经回复过的评论的ID
|
||||
replied_to_tids = {
|
||||
c['parent_tid'] for c in comments
|
||||
if c.get('parent_tid') and str(c.get('qq_account')) == str(qq_account)
|
||||
}
|
||||
# 2. 验证已记录的回复是否仍然存在,清理已删除的回复记录
|
||||
await self._validate_and_cleanup_reply_records(fid, my_replies)
|
||||
|
||||
# 找出所有非我发出且我未回复过的评论
|
||||
comments_to_reply = [
|
||||
c for c in comments
|
||||
if str(c.get('qq_account')) != str(qq_account) and c.get('comment_tid') not in replied_to_tids
|
||||
]
|
||||
# 3. 使用验证后的持久化记录来筛选未回复的评论
|
||||
comments_to_reply = []
|
||||
for comment in user_comments:
|
||||
comment_tid = comment.get('comment_tid')
|
||||
if not comment_tid:
|
||||
continue
|
||||
|
||||
# 检查是否已经在持久化记录中标记为已回复
|
||||
if not self.reply_tracker.has_replied(fid, comment_tid):
|
||||
comments_to_reply.append(comment)
|
||||
|
||||
if not comments_to_reply:
|
||||
logger.debug(f"说说 {fid} 下的所有评论都已回复过")
|
||||
return
|
||||
|
||||
logger.info(f"发现自己说说下的 {len(comments_to_reply)} 条新评论,准备回复...")
|
||||
for comment in comments_to_reply:
|
||||
reply_content = await self.content_service.generate_comment_reply(
|
||||
content, comment.get("content", ""), comment.get("nickname", "")
|
||||
)
|
||||
if reply_content:
|
||||
success = await api_client["reply"](
|
||||
fid, qq_account, comment.get("nickname", ""), reply_content, comment.get("comment_tid")
|
||||
comment_tid = comment.get("comment_tid")
|
||||
nickname = comment.get("nickname", "")
|
||||
comment_content = comment.get("content", "")
|
||||
|
||||
try:
|
||||
reply_content = await self.content_service.generate_comment_reply(
|
||||
content, comment_content, nickname
|
||||
)
|
||||
if success:
|
||||
logger.info(f"成功回复'{comment.get('nickname', '')}'的评论: '{reply_content}'")
|
||||
if reply_content:
|
||||
success = await api_client["reply"](
|
||||
fid, qq_account, nickname, reply_content, comment_tid
|
||||
)
|
||||
if success:
|
||||
# 标记为已回复
|
||||
self.reply_tracker.mark_as_replied(fid, comment_tid)
|
||||
logger.info(f"成功回复'{nickname}'的评论: '{reply_content}'")
|
||||
else:
|
||||
logger.error(f"回复'{nickname}'的评论失败")
|
||||
await asyncio.sleep(random.uniform(10, 20))
|
||||
else:
|
||||
logger.error(f"回复'{comment.get('nickname', '')}'的评论失败")
|
||||
await asyncio.sleep(random.uniform(10, 20))
|
||||
logger.warning(f"生成回复内容失败,跳过回复'{nickname}'的评论")
|
||||
except Exception as e:
|
||||
logger.error(f"回复'{nickname}'的评论时发生异常: {e}", exc_info=True)
|
||||
|
||||
async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: List[Dict]):
|
||||
"""验证并清理已删除的回复记录"""
|
||||
# 获取当前记录中该说说的所有已回复评论ID
|
||||
recorded_replied_comments = self.reply_tracker.get_replied_comments(fid)
|
||||
|
||||
if not recorded_replied_comments:
|
||||
return
|
||||
|
||||
# 从API返回的我的回复中提取parent_tid(即被回复的评论ID)
|
||||
current_replied_comments = set()
|
||||
for reply in my_replies:
|
||||
parent_tid = reply.get('parent_tid')
|
||||
if parent_tid:
|
||||
current_replied_comments.add(parent_tid)
|
||||
|
||||
# 找出记录中有但实际已不存在的回复
|
||||
deleted_replies = recorded_replied_comments - current_replied_comments
|
||||
|
||||
if deleted_replies:
|
||||
logger.info(f"检测到 {len(deleted_replies)} 个回复已被删除,清理记录...")
|
||||
for comment_tid in deleted_replies:
|
||||
self.reply_tracker.remove_reply_record(fid, comment_tid)
|
||||
logger.debug(f"已清理删除的回复记录: feed_id={fid}, comment_id={comment_tid}")
|
||||
|
||||
async def _process_single_feed(self, feed: Dict, api_client: Dict, target_qq: str, target_name: str):
|
||||
"""处理单条说说,决定是否评论和点赞"""
|
||||
@@ -641,7 +685,7 @@ class QZoneService:
|
||||
logger.error(f"上传图片 {index+1} 异常: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _list_feeds(t_qq: str, num: int) -> List[Dict]:
|
||||
async def _list_feeds(t_qq: str, num: int, is_monitoring_own_feeds: bool = False) -> List[Dict]:
|
||||
"""获取指定用户说说列表"""
|
||||
try:
|
||||
params = {
|
||||
@@ -667,37 +711,41 @@ class QZoneService:
|
||||
feeds_list = []
|
||||
my_name = json_data.get("logininfo", {}).get("name", "")
|
||||
for msg in json_data.get("msglist", []):
|
||||
is_commented = any(
|
||||
c.get("name") == my_name for c in msg.get("commentlist", []) if isinstance(c, dict)
|
||||
)
|
||||
if not is_commented:
|
||||
images = [pic['url1'] for pic in msg.get('pictotal', []) if 'url1' in pic]
|
||||
|
||||
comments = []
|
||||
if 'commentlist' in msg:
|
||||
for c in msg['commentlist']:
|
||||
comments.append({
|
||||
'qq_account': c.get('uin'),
|
||||
'nickname': c.get('name'),
|
||||
'content': c.get('content'),
|
||||
'comment_tid': c.get('tid'),
|
||||
'parent_tid': c.get('parent_tid') # API直接返回了父ID
|
||||
})
|
||||
|
||||
feeds_list.append(
|
||||
{
|
||||
"tid": msg.get("tid", ""),
|
||||
"content": msg.get("content", ""),
|
||||
"created_time": time.strftime(
|
||||
"%Y-%m-%d %H:%M:%S", time.localtime(msg.get("created_time", 0))
|
||||
),
|
||||
"rt_con": msg.get("rt_con", {}).get("content", "")
|
||||
if isinstance(msg.get("rt_con"), dict)
|
||||
else "",
|
||||
"images": images,
|
||||
"comments": comments
|
||||
}
|
||||
# 只有在处理好友说说时,才检查是否已评论并跳过
|
||||
if not is_monitoring_own_feeds:
|
||||
is_commented = any(
|
||||
c.get("name") == my_name for c in msg.get("commentlist", []) if isinstance(c, dict)
|
||||
)
|
||||
if is_commented:
|
||||
continue
|
||||
|
||||
images = [pic['url1'] for pic in msg.get('pictotal', []) if 'url1' in pic]
|
||||
|
||||
comments = []
|
||||
if 'commentlist' in msg:
|
||||
for c in msg['commentlist']:
|
||||
comments.append({
|
||||
'qq_account': c.get('uin'),
|
||||
'nickname': c.get('name'),
|
||||
'content': c.get('content'),
|
||||
'comment_tid': c.get('tid'),
|
||||
'parent_tid': c.get('parent_tid') # API直接返回了父ID
|
||||
})
|
||||
|
||||
feeds_list.append(
|
||||
{
|
||||
"tid": msg.get("tid", ""),
|
||||
"content": msg.get("content", ""),
|
||||
"created_time": time.strftime(
|
||||
"%Y-%m-%d %H:%M:%S", time.localtime(msg.get("created_time", 0))
|
||||
),
|
||||
"rt_con": msg.get("rt_con", {}).get("content", "")
|
||||
if isinstance(msg.get("rt_con"), dict)
|
||||
else "",
|
||||
"images": images,
|
||||
"comments": comments
|
||||
}
|
||||
)
|
||||
return feeds_list
|
||||
except Exception as e:
|
||||
logger.error(f"获取说说列表失败: {e}", exc_info=True)
|
||||
|
||||
@@ -0,0 +1,195 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
评论回复跟踪服务
|
||||
负责记录和管理已回复过的评论ID,避免重复回复
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Set, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("MaiZone.ReplyTrackerService")
|
||||
|
||||
|
||||
class ReplyTrackerService:
|
||||
"""
|
||||
评论回复跟踪服务
|
||||
使用本地JSON文件持久化存储已回复的评论ID
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 数据存储路径
|
||||
self.data_dir = Path(__file__).resolve().parent.parent / "data"
|
||||
self.data_dir.mkdir(exist_ok=True)
|
||||
self.reply_record_file = self.data_dir / "replied_comments.json"
|
||||
|
||||
# 内存中的已回复评论记录
|
||||
# 格式: {feed_id: {comment_id: timestamp, ...}, ...}
|
||||
self.replied_comments: Dict[str, Dict[str, float]] = {}
|
||||
|
||||
# 数据清理配置
|
||||
self.max_record_days = 30 # 保留30天的记录
|
||||
|
||||
# 加载已有数据
|
||||
self._load_data()
|
||||
|
||||
def _load_data(self):
|
||||
"""从文件加载已回复评论数据"""
|
||||
try:
|
||||
if self.reply_record_file.exists():
|
||||
with open(self.reply_record_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
self.replied_comments = data
|
||||
logger.info(f"已加载 {len(self.replied_comments)} 条说说的回复记录")
|
||||
else:
|
||||
logger.info("未找到回复记录文件,将创建新的记录")
|
||||
except Exception as e:
|
||||
logger.error(f"加载回复记录失败: {e}")
|
||||
self.replied_comments = {}
|
||||
|
||||
def _save_data(self):
|
||||
"""保存已回复评论数据到文件"""
|
||||
try:
|
||||
# 清理过期数据
|
||||
self._cleanup_old_records()
|
||||
|
||||
with open(self.reply_record_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.replied_comments, f, ensure_ascii=False, indent=2)
|
||||
logger.debug("回复记录已保存")
|
||||
except Exception as e:
|
||||
logger.error(f"保存回复记录失败: {e}")
|
||||
|
||||
def _cleanup_old_records(self):
|
||||
"""清理超过保留期限的记录"""
|
||||
current_time = time.time()
|
||||
cutoff_time = current_time - (self.max_record_days * 24 * 60 * 60)
|
||||
|
||||
feeds_to_remove = []
|
||||
total_removed = 0
|
||||
|
||||
for feed_id, comments in self.replied_comments.items():
|
||||
comments_to_remove = []
|
||||
|
||||
for comment_id, timestamp in comments.items():
|
||||
if timestamp < cutoff_time:
|
||||
comments_to_remove.append(comment_id)
|
||||
|
||||
# 移除过期的评论记录
|
||||
for comment_id in comments_to_remove:
|
||||
del comments[comment_id]
|
||||
total_removed += 1
|
||||
|
||||
# 如果该说说下没有任何记录了,标记删除整个说说记录
|
||||
if not comments:
|
||||
feeds_to_remove.append(feed_id)
|
||||
|
||||
# 移除空的说说记录
|
||||
for feed_id in feeds_to_remove:
|
||||
del self.replied_comments[feed_id]
|
||||
|
||||
if total_removed > 0:
|
||||
logger.info(f"清理了 {total_removed} 条过期的回复记录")
|
||||
|
||||
def has_replied(self, feed_id: str, comment_id: str) -> bool:
|
||||
"""
|
||||
检查是否已经回复过指定的评论
|
||||
|
||||
Args:
|
||||
feed_id: 说说ID
|
||||
comment_id: 评论ID
|
||||
|
||||
Returns:
|
||||
bool: 如果已回复过返回True,否则返回False
|
||||
"""
|
||||
if not feed_id or not comment_id:
|
||||
return False
|
||||
|
||||
return (feed_id in self.replied_comments and
|
||||
comment_id in self.replied_comments[feed_id])
|
||||
|
||||
def mark_as_replied(self, feed_id: str, comment_id: str):
|
||||
"""
|
||||
标记指定评论为已回复
|
||||
|
||||
Args:
|
||||
feed_id: 说说ID
|
||||
comment_id: 评论ID
|
||||
"""
|
||||
if not feed_id or not comment_id:
|
||||
logger.warning("feed_id 或 comment_id 为空,无法标记为已回复")
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
if feed_id not in self.replied_comments:
|
||||
self.replied_comments[feed_id] = {}
|
||||
|
||||
self.replied_comments[feed_id][comment_id] = current_time
|
||||
|
||||
# 保存到文件
|
||||
self._save_data()
|
||||
|
||||
logger.info(f"已标记评论为已回复: feed_id={feed_id}, comment_id={comment_id}")
|
||||
|
||||
def get_replied_comments(self, feed_id: str) -> Set[str]:
|
||||
"""
|
||||
获取指定说说下所有已回复的评论ID
|
||||
|
||||
Args:
|
||||
feed_id: 说说ID
|
||||
|
||||
Returns:
|
||||
Set[str]: 已回复的评论ID集合
|
||||
"""
|
||||
if feed_id in self.replied_comments:
|
||||
return set(self.replied_comments[feed_id].keys())
|
||||
return set()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取回复记录统计信息
|
||||
|
||||
Returns:
|
||||
Dict: 包含统计信息的字典
|
||||
"""
|
||||
total_feeds = len(self.replied_comments)
|
||||
total_replies = sum(len(comments) for comments in self.replied_comments.values())
|
||||
|
||||
return {
|
||||
"total_feeds_with_replies": total_feeds,
|
||||
"total_replied_comments": total_replies,
|
||||
"data_file": str(self.reply_record_file),
|
||||
"max_record_days": self.max_record_days
|
||||
}
|
||||
|
||||
def remove_reply_record(self, feed_id: str, comment_id: str):
|
||||
"""
|
||||
移除指定评论的回复记录
|
||||
|
||||
Args:
|
||||
feed_id: 说说ID
|
||||
comment_id: 评论ID
|
||||
"""
|
||||
if feed_id in self.replied_comments and comment_id in self.replied_comments[feed_id]:
|
||||
del self.replied_comments[feed_id][comment_id]
|
||||
|
||||
# 如果该说说下没有任何回复记录了,删除整个说说记录
|
||||
if not self.replied_comments[feed_id]:
|
||||
del self.replied_comments[feed_id]
|
||||
|
||||
self._save_data()
|
||||
logger.debug(f"已移除回复记录: feed_id={feed_id}, comment_id={comment_id}")
|
||||
|
||||
def remove_feed_records(self, feed_id: str):
|
||||
"""
|
||||
移除指定说说的所有回复记录
|
||||
|
||||
Args:
|
||||
feed_id: 说说ID
|
||||
"""
|
||||
if feed_id in self.replied_comments:
|
||||
del self.replied_comments[feed_id]
|
||||
self._save_data()
|
||||
logger.info(f"已移除说说 {feed_id} 的所有回复记录")
|
||||
@@ -16,7 +16,7 @@ from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.apis.logging_api import get_logger
|
||||
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugin_system.utils.permission_decorators import require_permission, require_master, PermissionChecker
|
||||
from src.plugin_system.utils.permission_decorators import require_permission
|
||||
|
||||
|
||||
logger = get_logger("Permission")
|
||||
|
||||
@@ -411,7 +411,6 @@ class ScheduleManager:
|
||||
通过关键词匹配、唤醒度、睡眠压力等综合判断是否处于休眠时间。
|
||||
新增弹性睡眠机制,允许在压力低时延迟入睡,并在入睡前发送通知。
|
||||
"""
|
||||
from src.chat.chat_loop.wakeup_manager import WakeUpManager
|
||||
# --- 基础检查 ---
|
||||
if not global_config.schedule.enable_is_sleep:
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user