Still fixing pyright, meow~

ikun-11451
2025-11-29 21:26:42 +08:00
parent 28719c1c89
commit 72e7492953
25 changed files with 170 additions and 104 deletions

View File

@@ -83,7 +83,7 @@ class ChatterManager:
         inactive_streams = []
         for stream_id, instance in self.instances.items():
             if hasattr(instance, "get_activity_time"):
-                activity_time = instance.get_activity_time()
+                activity_time = getattr(instance, "get_activity_time")()
                 if (current_time - activity_time) > max_inactive_seconds:
                     inactive_streams.append(stream_id)
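The hunk above trades a direct method call for a dynamic `getattr(...)()` lookup: when the declared type of `instance` does not define `get_activity_time`, the `getattr` form evaluates to `Any`, so the call site is no longer type-checked. A hypothetical, self-contained sketch of that trade-off (class names made up, not from this repo):

# Hypothetical sketch: dynamic lookup as a pyright workaround.
class BaseChatter:
    """Base class that does not declare get_activity_time()."""


class TimedChatter(BaseChatter):
    def get_activity_time(self) -> float:
        return 123.0


def last_activity(instance: BaseChatter) -> float | None:
    if hasattr(instance, "get_activity_time"):
        # A direct instance.get_activity_time() call may be flagged because
        # BaseChatter has no such attribute; getattr(...) is typed as Any.
        return getattr(instance, "get_activity_time")()
    return None


print(last_activity(TimedChatter()))  # 123.0
print(last_activity(BaseChatter()))   # None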

View File

@@ -104,7 +104,7 @@ class ChatterActionManager:
                 log_prefix=log_prefix,
                 shutting_down=shutting_down,
                 plugin_config=plugin_config,
-                action_message=action_message,
+                action_message=action_message,  # type: ignore
             )
             logger.debug(f"创建Action实例成功: {action_name}")
@@ -173,6 +173,7 @@ class ChatterActionManager:
         Returns:
             执行结果
         """
+        assert global_config is not None
         chat_stream = None
         try:
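Most hunks in this commit follow the same pattern: `global_config` and `model_config` are module-level globals typed as Optional (they are populated at startup, not at import time), so pyright reports Optional member access wherever they are used. An `assert ... is not None` at the top of a function narrows the type for the rest of the body. A minimal sketch under that assumption (names hypothetical):

# Hypothetical sketch of the narrowing pattern used throughout this commit.
from dataclasses import dataclass


@dataclass
class BotConfig:
    nickname: str = "bot"


# Typed as Optional because it is filled in during startup.
global_config: BotConfig | None = None


def build_prompt() -> str:
    # Without the assert, pyright reports that "nickname" is not an attribute of "None".
    assert global_config is not None
    return f"You are {global_config.nickname}."


global_config = BotConfig(nickname="Mai")
print(build_prompt())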

View File

@@ -30,6 +30,7 @@ class ActionModifier:
     def __init__(self, action_manager: ChatterActionManager, chat_id: str):
         """初始化动作处理器"""
+        assert model_config is not None
         self.chat_id = chat_id
         # chat_stream 和 log_prefix 将在异步方法中初始化
         self.chat_stream: "ChatStream | None" = None
@@ -67,6 +68,7 @@ class ActionModifier:
         处理后ActionManager 将包含最终的可用动作集,供规划器直接使用
         """
+        assert global_config is not None
         # 初始化log_prefix
         await self._initialize_log_prefix()
         # 根据 stream_id 加载当前可用的动作

View File

@@ -240,6 +240,8 @@ class DefaultReplyer:
chat_stream: "ChatStream", chat_stream: "ChatStream",
request_type: str = "replyer", request_type: str = "replyer",
): ):
assert global_config is not None
assert model_config is not None
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream self.chat_stream = chat_stream
# 这些将在异步初始化中设置 # 这些将在异步初始化中设置
@@ -267,6 +269,7 @@ class DefaultReplyer:
async def _build_auth_role_prompt(self) -> str: async def _build_auth_role_prompt(self) -> str:
"""根据主人配置生成额外提示词""" """根据主人配置生成额外提示词"""
assert global_config is not None
master_config = global_config.permission.master_prompt master_config = global_config.permission.master_prompt
if not master_config or not master_config.enable: if not master_config or not master_config.enable:
return "" return ""
@@ -515,6 +518,7 @@ class DefaultReplyer:
Returns: Returns:
str: 表达习惯信息字符串 str: 表达习惯信息字符串
""" """
assert global_config is not None
# 检查是否允许在此聊天流中使用表达 # 检查是否允许在此聊天流中使用表达
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id) use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id)
if not use_expression: if not use_expression:
@@ -583,6 +587,7 @@ class DefaultReplyer:
Returns: Returns:
str: 记忆信息字符串 str: 记忆信息字符串
""" """
assert global_config is not None
# 检查是否启用三层记忆系统 # 检查是否启用三层记忆系统
if not (global_config.memory and global_config.memory.enable): if not (global_config.memory and global_config.memory.enable):
return "" return ""
@@ -776,6 +781,7 @@ class DefaultReplyer:
Returns: Returns:
str: 关键词反应提示字符串,如果没有触发任何反应则为空字符串 str: 关键词反应提示字符串,如果没有触发任何反应则为空字符串
""" """
assert global_config is not None
if target is None: if target is None:
return "" return ""
@@ -834,6 +840,7 @@ class DefaultReplyer:
Returns: Returns:
str: 格式化的notice信息文本如果没有notice或未启用则返回空字符串 str: 格式化的notice信息文本如果没有notice或未启用则返回空字符串
""" """
assert global_config is not None
try: try:
logger.debug(f"开始构建notice块chat_id={chat_id}") logger.debug(f"开始构建notice块chat_id={chat_id}")
@@ -902,6 +909,7 @@ class DefaultReplyer:
Returns: Returns:
Tuple[str, str]: (已读历史消息prompt, 未读历史消息prompt) Tuple[str, str]: (已读历史消息prompt, 未读历史消息prompt)
""" """
assert global_config is not None
try: try:
# 从message_manager获取真实的已读/未读消息 # 从message_manager获取真实的已读/未读消息
@@ -1002,6 +1010,7 @@ class DefaultReplyer:
""" """
回退的已读/未读历史消息构建方法 回退的已读/未读历史消息构建方法
""" """
assert global_config is not None
# 通过is_read字段分离已读和未读消息 # 通过is_read字段分离已读和未读消息
read_messages = [] read_messages = []
unread_messages = [] unread_messages = []
@@ -1115,6 +1124,7 @@ class DefaultReplyer:
Returns: Returns:
str: 构建好的上下文 str: 构建好的上下文
""" """
assert global_config is not None
if available_actions is None: if available_actions is None:
available_actions = {} available_actions = {}
chat_stream = self.chat_stream chat_stream = self.chat_stream
@@ -1607,6 +1617,7 @@ class DefaultReplyer:
reply_to: str, reply_to: str,
reply_message: dict[str, Any] | DatabaseMessages | None = None, reply_message: dict[str, Any] | DatabaseMessages | None = None,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
assert global_config is not None
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info) is_group_chat = bool(chat_stream.group_info)
@@ -1767,6 +1778,7 @@ class DefaultReplyer:
return prompt_text return prompt_text
async def llm_generate_content(self, prompt: str): async def llm_generate_content(self, prompt: str):
assert global_config is not None
with Timer("LLM生成", {}): # 内部计时器,可选保留 with Timer("LLM生成", {}): # 内部计时器,可选保留
# 直接使用已初始化的模型实例 # 直接使用已初始化的模型实例
logger.info(f"使用模型集生成回复: {self.express_model.model_for_task}") logger.info(f"使用模型集生成回复: {self.express_model.model_for_task}")
@@ -1792,6 +1804,8 @@ class DefaultReplyer:
return content, reasoning_content, model_name, tool_calls return content, reasoning_content, model_name, tool_calls
async def get_prompt_info(self, message: str, sender: str, target: str): async def get_prompt_info(self, message: str, sender: str, target: str):
assert global_config is not None
assert model_config is not None
related_info = "" related_info = ""
start_time = time.time() start_time = time.time()
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
@@ -1843,6 +1857,7 @@ class DefaultReplyer:
return "" return ""
async def build_relation_info(self, sender: str, target: str): async def build_relation_info(self, sender: str, target: str):
assert global_config is not None
# 获取用户ID # 获取用户ID
if sender == f"{global_config.bot.nickname}(你)": if sender == f"{global_config.bot.nickname}(你)":
return "你将要回复的是你自己发送的消息。" return "你将要回复的是你自己发送的消息。"
@@ -1927,6 +1942,7 @@ class DefaultReplyer:
reply_to: 回复对象 reply_to: 回复对象
reply_message: 回复的原始消息 reply_message: 回复的原始消息
""" """
assert global_config is not None
return # 已禁用,保留函数签名以防其他地方有引用 return # 已禁用,保留函数签名以防其他地方有引用
# 以下代码已废弃,不再执行 # 以下代码已废弃,不再执行

View File

@@ -173,9 +173,10 @@ class SecurityManager:
         pre_check_results = await asyncio.gather(*pre_check_tasks, return_exceptions=True)

         # 筛选需要完整检查的检测器
-        checkers_to_run = [
-            c for c, need_check in zip(enabled_checkers, pre_check_results) if need_check is True
-        ]
+        checkers_to_run = []
+        for c, need_check in zip(enabled_checkers, pre_check_results):
+            if need_check is True:
+                checkers_to_run.append(c)

         if not checkers_to_run:
             return SecurityCheckResult(
@@ -192,20 +193,22 @@ class SecurityManager:
         results = await asyncio.gather(*check_tasks, return_exceptions=True)

         # 过滤异常结果
-        valid_results = []
+        valid_results: list[SecurityCheckResult] = []
         for checker, result in zip(checkers_to_run, results):
-            if isinstance(result, Exception):
+            if isinstance(result, BaseException):
                 logger.error(f"检测器 '{checker.name}' 执行失败: {result}")
                 continue
-            result.checker_name = checker.name
-            valid_results.append(result)
+
+            if isinstance(result, SecurityCheckResult):
+                result.checker_name = checker.name
+                valid_results.append(result)

         # 合并结果
         return self._merge_results(valid_results, time.time() - start_time)

     async def _check_all(self, message: str, context: dict, start_time: float) -> SecurityCheckResult:
         """检测所有模式(顺序执行所有检测器)"""
-        results = []
+        results: list[SecurityCheckResult] = []
         for checker in self._checkers:
             if not checker.enabled:
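The SecurityManager hunks also show how this commit handles `asyncio.gather(..., return_exceptions=True)`: its result type is `T | BaseException`, so the exception branch is checked against `BaseException` and the success branch is narrowed with `isinstance` before attribute access. A small self-contained sketch of that narrowing (checker names and result class are made up):

import asyncio
from dataclasses import dataclass


@dataclass
class CheckResult:
    checker_name: str = ""
    passed: bool = True


async def fake_checker(name: str, fail: bool) -> CheckResult:
    if fail:
        raise RuntimeError(f"{name} exploded")
    return CheckResult()


async def main() -> None:
    names = ["keyword", "llm"]
    results = await asyncio.gather(
        fake_checker("keyword", fail=False),
        fake_checker("llm", fail=True),
        return_exceptions=True,
    )

    valid: list[CheckResult] = []
    for name, result in zip(names, results):
        if isinstance(result, BaseException):  # covers Exception and e.g. CancelledError
            print(f"checker {name} failed: {result}")
            continue
        if isinstance(result, CheckResult):  # narrows T | BaseException to CheckResult
            result.checker_name = name
            valid.append(result)

    print([r.checker_name for r in valid])


asyncio.run(main())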

View File

@@ -39,11 +39,13 @@ def replace_user_references_sync(
     Returns:
         str: 处理后的内容字符串
     """
+    assert global_config is not None
     if not content:
         return ""

     if name_resolver is None:
         def default_resolver(platform: str, user_id: str) -> str:
+            assert global_config is not None
             # 检查是否是机器人自己
             if replace_bot_name and (user_id == str(global_config.bot.qq_account)):
                 return f"{global_config.bot.nickname}(你)"
@@ -116,10 +118,12 @@ async def replace_user_references_async(
     Returns:
         str: 处理后的内容字符串
     """
+    assert global_config is not None
     if name_resolver is None:
         person_info_manager = get_person_info_manager()

         async def default_resolver(platform: str, user_id: str) -> str:
+            assert global_config is not None
             # 检查是否是机器人自己
             if replace_bot_name and (user_id == str(global_config.bot.qq_account)):
                 return f"{global_config.bot.nickname}(你)"
@@ -392,7 +396,7 @@ async def get_actions_by_timestamp_with_chat_inclusive(
             actions = list(result.scalars())
             return [action.__dict__ for action in reversed(actions)]
         else:  # earliest
-            result = await session.execute(
+            query = await session.execute(
                 select(ActionRecords)
                 .where(
                     and_(
@@ -540,6 +544,7 @@ async def _build_readable_messages_internal(
     Returns:
         包含格式化消息的字符串、原始消息详情列表、图片映射字典和更新后的计数器的元组。
     """
+    assert global_config is not None
     if not messages:
         return "", [], pic_id_mapping or {}, pic_counter
@@ -694,6 +699,7 @@ async def _build_readable_messages_internal(
         percentile = i / n_messages  # 计算消息在列表中的位置百分比 (0 <= percentile < 1)
         original_len = len(content)
         limit = -1  # 默认不截断
+        replace_content = ""
         if percentile < 0.2:  # 60% 之前的消息 (即最旧的 60%)
             limit = 50
@@ -973,6 +979,7 @@ async def build_readable_messages(
         truncate: 是否截断长消息
         show_actions: 是否显示动作记录
     """
+    assert global_config is not None
     # 创建messages的深拷贝避免修改原始列表
     if not messages:
         return ""
@@ -1112,6 +1119,7 @@ async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
    构建匿名可读消息将不同人的名称转为唯一占位符A、B、C...bot自己用SELF。
    处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段将bbb映射为匿名占位符。
    """
+    assert global_config is not None
    if not messages:
        print("111111111111没有消息无法构建匿名消息")
        return ""
@@ -1127,6 +1135,7 @@ async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
    def get_anon_name(platform, user_id):
        # print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
        # print(f"global_config.bot.qq_account:{global_config.bot.qq_account}")
+        assert global_config is not None
        if user_id == global_config.bot.qq_account:
            # print("SELF11111111111111")
@@ -1204,6 +1213,7 @@ async def get_person_id_list(messages: list[dict[str, Any]]) -> list[str]:
    Returns:
        一个包含唯一 person_id 的列表。
    """
+    assert global_config is not None
    person_ids_set = set()  # 使用集合来自动去重

    for msg in messages:

View File

@@ -649,6 +649,7 @@ class Prompt:
     async def _build_expression_habits(self) -> dict[str, Any]:
         """构建表达习惯(如表情、口癖)的上下文块."""
+        assert global_config is not None
         # 检查当前聊天是否启用了表达习惯功能
         use_expression, _, _ = global_config.expression.get_expression_config_for_chat(
             self.parameters.chat_id
@@ -728,6 +729,7 @@ class Prompt:
     async def _build_tool_info(self) -> dict[str, Any]:
         """构建工具调用结果的上下文块."""
+        assert global_config is not None
         if not global_config.tool.enable_tool:
             return {"tool_info_block": ""}
@@ -779,6 +781,7 @@ class Prompt:
     async def _build_knowledge_info(self) -> dict[str, Any]:
         """构建从知识库检索到的相关信息的上下文块."""
+        assert global_config is not None
         if not global_config.lpmm_knowledge.enable:
             return {"knowledge_prompt": ""}
@@ -873,6 +876,7 @@ class Prompt:
     def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
         """为S4UScene for You模式准备最终用于格式化的参数字典."""
+        assert global_config is not None
         return {
             **context_data,
             "expression_habits_block": context_data.get("expression_habits_block", ""),
@@ -915,6 +919,7 @@ class Prompt:
     def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
         """为Normal模式准备最终用于格式化的参数字典."""
+        assert global_config is not None
         return {
             **context_data,
             "expression_habits_block": context_data.get("expression_habits_block", ""),
@@ -959,6 +964,7 @@ class Prompt:
     def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
         """为默认模式(或其他未指定模式)准备最终用于格式化的参数字典."""
+        assert global_config is not None
         return {
             "expression_habits_block": context_data.get("expression_habits_block", ""),
             "relation_info_block": context_data.get("relation_info_block", ""),
@@ -1143,6 +1149,7 @@ class Prompt:
         Returns:
             str: 构建好的跨群聊上下文字符串。
         """
+        assert global_config is not None
         if not global_config.cross_context.enable:
             return ""

View File

@@ -338,6 +338,7 @@ class HTMLReportGenerator:
         # 渲染模板
         # 读取CSS和JS文件内容
+        assert isinstance(self.jinja_env.loader, FileSystemLoader)
         async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.css"), encoding="utf-8") as f:
             report_css = await f.read()
         async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.js"), encoding="utf-8") as f:

View File

@@ -192,7 +192,7 @@ class StatisticOutputTask(AsyncTask):
         self._statistic_console_output(stats, now)
         # 使用新的 HTMLReportGenerator 生成报告
         chart_data = await self._collect_chart_data(stats)
-        deploy_time = datetime.fromtimestamp(local_storage.get("deploy_time", now.timestamp()))
+        deploy_time = datetime.fromtimestamp(float(local_storage.get("deploy_time", now.timestamp())))  # type: ignore
         report_generator = HTMLReportGenerator(
             name_mapping=self.name_mapping,
             stat_period=self.stat_period,
@@ -219,7 +219,7 @@ class StatisticOutputTask(AsyncTask):
         # 使用新的 HTMLReportGenerator 生成报告
         chart_data = await self._collect_chart_data(stats)
-        deploy_time = datetime.fromtimestamp(local_storage.get("deploy_time", now.timestamp()))
+        deploy_time = datetime.fromtimestamp(float(local_storage.get("deploy_time", now.timestamp())))  # type: ignore
         report_generator = HTMLReportGenerator(
             name_mapping=self.name_mapping,
             stat_period=self.stat_period,

View File

@@ -49,6 +49,7 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
         tuple[bool, float]: (是否提及, 提及类型)
         提及类型: 0=未提及, 1=弱提及(文本匹配), 2=强提及(@/回复/私聊)
     """
+    assert global_config is not None
     nicknames = global_config.bot.alias_names
     mention_type = 0  # 0=未提及, 1=弱提及, 2=强提及
@@ -132,6 +133,7 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
 async def get_embedding(text, request_type="embedding") -> list[float] | None:
     """获取文本的embedding向量"""
+    assert model_config is not None
     # 每次都创建新的LLMRequest实例以避免事件循环冲突
     llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
     try:
@@ -139,11 +141,12 @@ async def get_embedding(text, request_type="embedding") -> list[float] | None:
     except Exception as e:
         logger.error(f"获取embedding失败: {e!s}")
         embedding = None
-    return embedding
+    return embedding  # type: ignore


 async def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
     # 获取当前群聊记录内发言的人
+    assert global_config is not None
     filter_query = {"chat_id": chat_stream_id}
     sort_order = [("time", -1)]
     recent_messages = await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
@@ -400,11 +403,12 @@ def recover_quoted_content(sentences: list[str], placeholder_map: dict[str, str]
         recovered_sentences.append(sentence)
     return recovered_sentences


 def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese_typo: bool = True) -> list[str]:
+    assert global_config is not None
     if not global_config.response_post_process.enable_response_post_process:
         return [text]

     # --- 三层防护系统 ---
     # 第一层:保护颜文字
     protected_text, kaomoji_mapping = protect_kaomoji(text) if global_config.response_splitter.enable_kaomoji_protection else (text, {})

View File

@@ -64,8 +64,6 @@ class ImageManager:
         # except Exception as e:
         #     logger.error(f"数据库连接失败: {e}")

-        self._initialized = True
-
     def _ensure_image_dir(self):
         """确保图像存储目录存在"""
         os.makedirs(self.IMAGE_DIR, exist_ok=True)
@@ -159,6 +157,7 @@ class ImageManager:
     async def get_emoji_description(self, image_base64: str) -> str:
         """获取表情包描述统一使用EmojiManager中的逻辑进行处理和缓存"""
         try:
+            assert global_config is not None
             from src.chat.emoji_system.emoji_manager import get_emoji_manager

             emoji_manager = get_emoji_manager()
@@ -190,7 +189,7 @@ class ImageManager:
                 return "[表情包(描述生成失败)]"

             # 4. (可选) 如果启用了“偷表情包”,则将图片和完整描述存入待注册区
-            if global_config and global_config.emoji and global_config.emoji.steal_emoji:
+            if global_config.emoji and global_config.emoji.steal_emoji:
                 logger.debug(f"偷取表情包功能已开启,保存待注册表情包: {image_hash}")
                 try:
                     image_format = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower()

View File

@@ -44,6 +44,8 @@ class VideoAnalyzer:
"""基于 inkfox 的视频关键帧 + LLM 描述分析器""" """基于 inkfox 的视频关键帧 + LLM 描述分析器"""
def __init__(self) -> None: def __init__(self) -> None:
assert global_config is not None
assert model_config is not None
cfg = getattr(global_config, "video_analysis", object()) cfg = getattr(global_config, "video_analysis", object())
self.max_frames: int = getattr(cfg, "max_frames", 20) self.max_frames: int = getattr(cfg, "max_frames", 20)
self.frame_quality: int = getattr(cfg, "frame_quality", 85) self.frame_quality: int = getattr(cfg, "frame_quality", 85)

View File

@@ -135,6 +135,8 @@ class LegacyVideoAnalyzer:
     def __init__(self):
         """初始化视频分析器"""
+        assert global_config is not None
+        assert model_config is not None
         # 使用专用的视频分析配置
         try:
             self.video_llm = LLMRequest(

View File

@@ -11,6 +11,8 @@ logger = get_logger("chat_voice")
 async def get_voice_text(voice_base64: str) -> str:
     """获取音频文件转录文本"""
+    assert global_config is not None
+    assert model_config is not None
     if not global_config.voice.enable_asr:
         logger.warning("语音识别未启用,无法处理语音消息")
         return "[语音]"

View File

@@ -62,7 +62,9 @@ class StreamContext(BaseDataModel):
     stream_id: str
     chat_type: ChatType = ChatType.PRIVATE  # 聊天类型,默认为私聊
     chat_mode: ChatMode = ChatMode.FOCUS  # 聊天模式,默认为专注模式
-    max_context_size: int = field(default_factory=lambda: getattr(global_config.chat, "max_context_size", 100))
+    max_context_size: int = field(
+        default_factory=lambda: getattr(global_config.chat, "max_context_size", 100) if global_config else 100
+    )
     unread_messages: list["DatabaseMessages"] = field(default_factory=list)
     history_messages: list["DatabaseMessages"] = field(default_factory=list)
     last_check_time: float = field(default_factory=time.time)
@@ -98,7 +100,9 @@ class StreamContext(BaseDataModel):
     def __post_init__(self):
         """初始化历史消息异步加载"""
         if not self.max_context_size or self.max_context_size <= 0:
-            self.max_context_size = getattr(global_config.chat, "max_context_size", 100)
+            self.max_context_size = (
+                getattr(global_config.chat, "max_context_size", 100) if global_config else 100
+            )
         try:
             loop = asyncio.get_event_loop()
@@ -118,6 +122,7 @@ class StreamContext(BaseDataModel):
     async def add_message(self, message: "DatabaseMessages", skip_energy_update: bool = False) -> bool:
         """添加消息到上下文,支持跳过能量更新的选项"""
         try:
+            assert global_config is not None
             cache_enabled = global_config.chat.enable_message_cache
             if cache_enabled and not self.is_cache_enabled:
                 self.enable_cache(True)
@@ -150,7 +155,7 @@ class StreamContext(BaseDataModel):
             # ...
             try:
                 if global_config.memory and global_config.memory.enable:
-                    unified_manager = _get_unified_memory_manager()
+                    unified_manager: Any = _get_unified_memory_manager()
                     if unified_manager:
                         message_dict = {
                             "message_id": str(message.message_id),
@@ -161,7 +166,7 @@ class StreamContext(BaseDataModel):
                             "platform": message.chat_info.platform,
                             "stream_id": self.stream_id,
                         }
-                        await unified_manager.add_message(message_dict)
+                        await unified_manager.add_message(message_dict)  # type: ignore
                         logger.debug(f"...: {message.message_id}")
             except Exception as e:
                 logger.error(f"...: {e}")

View File

@@ -9,9 +9,10 @@
 import operator
 from collections.abc import Callable
 from functools import lru_cache
-from typing import Any, TypeVar
+from typing import Any, Generic, TypeVar

 from sqlalchemy import delete, func, select, update
+from sqlalchemy.engine import CursorResult, Result

 from src.common.database.core.models import Base
 from src.common.database.core.session import get_db_session
@@ -25,23 +26,23 @@ from src.common.logger import get_logger
 logger = get_logger("database.crud")

-T = TypeVar("T", bound=Base)
+T = TypeVar("T", bound=Any)


 @lru_cache(maxsize=256)
-def _get_model_column_names(model: type[Base]) -> tuple[str, ...]:
+def _get_model_column_names(model: type[Any]) -> tuple[str, ...]:
     """获取模型的列名称列表"""
     return tuple(column.name for column in model.__table__.columns)


 @lru_cache(maxsize=256)
-def _get_model_field_set(model: type[Base]) -> frozenset[str]:
+def _get_model_field_set(model: type[Any]) -> frozenset[str]:
     """获取模型的有效字段集合"""
     return frozenset(_get_model_column_names(model))


 @lru_cache(maxsize=256)
-def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
+def _get_model_value_fetcher(model: type[Any]) -> Callable[[Any], tuple[Any, ...]]:
     """为模型准备attrgetter用于批量获取属性值"""
     column_names = _get_model_column_names(model)
@@ -51,21 +52,21 @@ def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
     if len(column_names) == 1:
         attr_name = column_names[0]

-        def _single(instance: Base) -> tuple[Any, ...]:
+        def _single(instance: Any) -> tuple[Any, ...]:
             return (getattr(instance, attr_name),)

         return _single

     getter = operator.attrgetter(*column_names)

-    def _multi(instance: Base) -> tuple[Any, ...]:
+    def _multi(instance: Any) -> tuple[Any, ...]:
         values = getter(instance)
         return values if isinstance(values, tuple) else (values,)

     return _multi


-def _model_to_dict(instance: Base) -> dict[str, Any]:
+def _model_to_dict(instance: Any) -> dict[str, Any]:
     """将 SQLAlchemy 模型实例转换为字典

     Args:
@@ -113,7 +114,7 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
     return instance


-class CRUDBase:
+class CRUDBase(Generic[T]):
     """基础CRUD操作类

     提供通用的增删改查操作,自动集成缓存和批处理
@@ -249,7 +250,7 @@ class CRUDBase:
         if cached_dicts is not None:
             logger.debug(f"缓存命中: {cache_key}")
             # 从字典列表恢复对象列表
-            return [_dict_to_model(self.model, d) for d in cached_dicts]
+            return [_dict_to_model(self.model, d) for d in cached_dicts]  # type: ignore

         # 从数据库查询
         async with get_db_session() as session:
@@ -278,7 +279,7 @@ class CRUDBase:
             await cache.set(cache_key, instances_dicts)

         # 从字典列表重建对象列表返回detached状态所有字段已加载
-        return [_dict_to_model(self.model, d) for d in instances_dicts]
+        return [_dict_to_model(self.model, d) for d in instances_dicts]  # type: ignore

     async def create(
         self,
@@ -420,7 +421,7 @@ class CRUDBase:
         async with get_db_session() as session:
             stmt = delete(self.model).where(self.model.id == id)
             result = await session.execute(stmt)
-            success = result.rowcount > 0
+            success = result.rowcount > 0  # type: ignore
             # 注意commit在get_db_session的context manager退出时自动执行

             # 清除缓存
@@ -455,7 +456,7 @@ class CRUDBase:
                 stmt = stmt.where(getattr(self.model, key) == value)

             result = await session.execute(stmt)
-            return result.scalar()
+            return int(result.scalar() or 0)

     async def exists(
         self,
@@ -549,7 +550,7 @@ class CRUDBase:
                     .values(**obj_in)
                 )
                 result = await session.execute(stmt)
-                count += result.rowcount
+                count += result.rowcount  # type: ignore

                 # 清除缓存
                 cache_key = f"{self.model_name}:id:{id}"
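Turning `CRUDBase` into `CRUDBase(Generic[T])` lets each concrete repository declare which model it works with, so `get`-style methods return that model type instead of an untyped base. A stripped-down sketch of the shape this enables (no SQLAlchemy, names hypothetical):

from typing import Any, Generic, TypeVar

T = TypeVar("T", bound=Any)


class CRUDBase(Generic[T]):
    """Minimal generic repository: stores instances of one model type."""

    def __init__(self, model: type[T]) -> None:
        self.model = model
        self._rows: dict[int, T] = {}

    def create(self, obj_id: int, obj: T) -> T:
        self._rows[obj_id] = obj
        return obj

    def get(self, obj_id: int) -> T | None:
        # Callers see the concrete model type, not a bare base class.
        return self._rows.get(obj_id)


class Message:
    def __init__(self, text: str) -> None:
        self.text = text


messages = CRUDBase(Message)           # inferred as CRUDBase[Message]
messages.create(1, Message("hello"))
row = messages.get(1)                  # typed as Message | None
print(row.text if row else "missing")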

View File

@@ -20,7 +20,7 @@ from src.common.logger import get_logger
logger = get_logger("database.query") logger = get_logger("database.query")
T = TypeVar("T", bound="Base") T = TypeVar("T", bound=Any)
class QueryBuilder(Generic[T]): class QueryBuilder(Generic[T]):
@@ -330,7 +330,7 @@ class QueryBuilder(Generic[T]):
items = await self.all() items = await self.all()
return items, total return items, total # type: ignore
class AggregateQuery: class AggregateQuery:

View File

@@ -122,7 +122,7 @@ async def get_recent_actions(
         动作记录列表
     """
     query = QueryBuilder(ActionRecords)
-    return await query.filter(chat_id=chat_id).order_by("-time").limit(limit).all()
+    return await query.filter(chat_id=chat_id).order_by("-time").limit(limit).all()  # type: ignore


 # ===== Messages 业务API =====
@@ -148,7 +148,7 @@ async def get_chat_history(
         .limit(limit)
         .offset(offset)
         .all()
-    )
+    )  # type: ignore


 async def get_message_count(stream_id: str) -> int:
@@ -292,7 +292,7 @@ async def get_active_streams(
     if platform:
         query = query.filter(platform=platform)

-    return await query.order_by("-last_message_time").limit(limit).all()
+    return await query.order_by("-last_message_time").limit(limit).all()  # type: ignore


 # ===== LLMUsage 业务API =====
@@ -390,7 +390,7 @@ async def get_usage_statistics(
     # 聚合统计
     total_input = await query.sum("input_tokens")
     total_output = await query.sum("output_tokens")
-    total_count = await query.filter().count() if hasattr(query, "count") else 0
+    total_count = await getattr(query.filter(), "count")() if hasattr(query, "count") else 0

     return {
         "total_input_tokens": int(total_input),

View File

@@ -123,7 +123,7 @@ async def build_filters(model_class, filters: dict[str, Any]):
     return conditions


-def _model_to_dict(instance) -> dict[str, Any]:
+def _model_to_dict(instance) -> dict[str, Any] | None:
     """将数据库模型实例转换为字典兼容旧API

     Args:
@@ -238,7 +238,7 @@ async def db_query(
                 return None

             # 更新记录
-            updated = await crud.update(instance.id, data)
+            updated = await crud.update(instance.id, data)  # type: ignore
             return _model_to_dict(updated)

         elif query_type == "delete":
@@ -257,7 +257,7 @@ async def db_query(
                 return None

             # 删除记录
-            success = await crud.delete(instance.id)
+            success = await crud.delete(instance.id)  # type: ignore
             return {"deleted": success}

         elif query_type == "count":

View File

@@ -46,6 +46,7 @@ async def get_engine() -> AsyncEngine:
     if _engine_lock is None:
         _engine_lock = asyncio.Lock()
+    assert _engine_lock is not None

     # 使用锁保护初始化过程
     async with _engine_lock:
         # 双重检查锁定模式
@@ -55,6 +56,7 @@ async def get_engine() -> AsyncEngine:
         try:
             from src.config.config import global_config

+            assert global_config is not None
             config = global_config.database
             db_type = config.database_type

View File

@@ -44,6 +44,7 @@ def get_string_field(max_length=255, **kwargs):
""" """
from src.config.config import global_config from src.config.config import global_config
assert global_config is not None
db_type = global_config.database.database_type db_type = global_config.database.database_type
# MySQL 索引需要指定长度的 VARCHAR # MySQL 索引需要指定长度的 VARCHAR

View File

@@ -75,6 +75,7 @@ async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
         # 可以设置 schema 搜索路径等
         from src.config.config import global_config

+        assert global_config is not None
         schema = global_config.database.postgresql_schema
         if schema and schema != "public":
             await session.execute(text(f"SET search_path TO {schema}"))
@@ -114,6 +115,7 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
             # 获取数据库类型并应用特定设置
             from src.config.config import global_config

+            assert global_config is not None
             await _apply_session_settings(session, global_config.database.database_type)

             yield session
@@ -142,6 +144,7 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
             # 应用数据库特定设置
             from src.config.config import global_config

+            assert global_config is not None
             await _apply_session_settings(session, global_config.database.database_type)

             yield session

View File

@@ -13,7 +13,7 @@ from collections import defaultdict, deque
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from enum import IntEnum
-from typing import Any, TypeVar
+from typing import Any

 from sqlalchemy import delete, insert, select, update
@@ -23,8 +23,6 @@ from src.common.memory_utils import estimate_size_smart
 logger = get_logger("batch_scheduler")

-T = TypeVar("T")
-

 class Priority(IntEnum):
     """操作优先级"""
@@ -429,7 +427,7 @@ class AdaptiveBatchScheduler:
                     # 执行更新但不commit
                     result = await session.execute(stmt)
-                    results.append((op, result.rowcount))
+                    results.append((op, result.rowcount))  # type: ignore

             # 注意commit 由 get_db_session_direct 上下文管理器自动处理
@@ -471,7 +469,7 @@ class AdaptiveBatchScheduler:
                     # 执行删除但不commit
                     result = await session.execute(stmt)
-                    results.append((op, result.rowcount))
+                    results.append((op, result.rowcount))  # type: ignore

             # 注意commit 由 get_db_session_direct 上下文管理器自动处理

View File

@@ -398,47 +398,48 @@ class MultiLevelCache:
         l2_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l2_cache, "L2"))

         # 使用超时避免死锁
-        try:
-            l1_stats, l2_stats = await asyncio.gather(
-                asyncio.wait_for(l1_stats_task, timeout=1.0),
-                asyncio.wait_for(l2_stats_task, timeout=1.0),
-                return_exceptions=True
-            )
-        except asyncio.TimeoutError:
-            logger.warning("缓存统计获取超时,使用基本统计")
-            l1_stats = await self.l1_cache.get_stats()
-            l2_stats = await self.l2_cache.get_stats()
+        results = await asyncio.gather(
+            asyncio.wait_for(l1_stats_task, timeout=1.0),
+            asyncio.wait_for(l2_stats_task, timeout=1.0),
+            return_exceptions=True
+        )
+        l1_stats = results[0]
+        l2_stats = results[1]

         # 处理异常情况
-        if isinstance(l1_stats, Exception):
+        if isinstance(l1_stats, BaseException):
             logger.error(f"L1统计获取失败: {l1_stats}")
             l1_stats = CacheStats()
-        if isinstance(l2_stats, Exception):
+        if isinstance(l2_stats, BaseException):
             logger.error(f"L2统计获取失败: {l2_stats}")
             l2_stats = CacheStats()
+        assert isinstance(l1_stats, CacheStats)
+        assert isinstance(l2_stats, CacheStats)

         # 🔧 修复:并行获取键集合,避免锁嵌套
         l1_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l1_cache))
         l2_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l2_cache))

-        try:
-            l1_keys, l2_keys = await asyncio.gather(
-                asyncio.wait_for(l1_keys_task, timeout=1.0),
-                asyncio.wait_for(l2_keys_task, timeout=1.0),
-                return_exceptions=True
-            )
-        except asyncio.TimeoutError:
-            logger.warning("缓存键获取超时,使用默认值")
-            l1_keys, l2_keys = set(), set()
+        results = await asyncio.gather(
+            asyncio.wait_for(l1_keys_task, timeout=1.0),
+            asyncio.wait_for(l2_keys_task, timeout=1.0),
+            return_exceptions=True
+        )
+        l1_keys = results[0]
+        l2_keys = results[1]

         # 处理异常情况
-        if isinstance(l1_keys, Exception):
+        if isinstance(l1_keys, BaseException):
             logger.warning(f"L1键获取失败: {l1_keys}")
             l1_keys = set()
-        if isinstance(l2_keys, Exception):
+        if isinstance(l2_keys, BaseException):
             logger.warning(f"L2键获取失败: {l2_keys}")
             l2_keys = set()
+        assert isinstance(l1_keys, set)
+        assert isinstance(l2_keys, set)

         # 计算共享键和独占键
         shared_keys = l1_keys & l2_keys
         l1_only_keys = l1_keys - l2_keys
@@ -448,24 +449,25 @@ class MultiLevelCache:
         l1_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l1_cache, l1_keys))
         l2_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l2_cache, l2_keys))

-        try:
-            l1_size, l2_size = await asyncio.gather(
-                asyncio.wait_for(l1_size_task, timeout=1.0),
-                asyncio.wait_for(l2_size_task, timeout=1.0),
-                return_exceptions=True
-            )
-        except asyncio.TimeoutError:
-            logger.warning("内存计算超时,使用统计值")
-            l1_size, l2_size = l1_stats.total_size, l2_stats.total_size
+        results = await asyncio.gather(
+            asyncio.wait_for(l1_size_task, timeout=1.0),
+            asyncio.wait_for(l2_size_task, timeout=1.0),
+            return_exceptions=True
+        )
+        l1_size = results[0]
+        l2_size = results[1]

         # 处理异常情况
-        if isinstance(l1_size, Exception):
+        if isinstance(l1_size, BaseException):
             logger.warning(f"L1内存计算失败: {l1_size}")
             l1_size = l1_stats.total_size
-        if isinstance(l2_size, Exception):
+        if isinstance(l2_size, BaseException):
             logger.warning(f"L2内存计算失败: {l2_size}")
             l2_size = l2_stats.total_size
+        assert isinstance(l1_size, int)
+        assert isinstance(l2_size, int)

         # 计算实际总内存(避免重复计数)
         actual_total_size = l1_size + l2_size - min(l1_stats.total_size, l2_stats.total_size)
@@ -769,6 +771,7 @@ async def get_cache() -> MultiLevelCache:
     try:
         from src.config.config import global_config

+        assert global_config is not None
        db_config = global_config.database

        # 检查是否启用缓存

View File

@@ -11,7 +11,7 @@ import functools
 import hashlib
 import time
 from collections.abc import Awaitable, Callable
-from typing import Any, TypeVar
+from typing import Any, ParamSpec, TypeVar

 from sqlalchemy.exc import DBAPIError, OperationalError
 from sqlalchemy.exc import TimeoutError as SQLTimeoutError
@@ -56,8 +56,9 @@ def generate_cache_key(
     return ":".join(cache_key_parts)


 T = TypeVar("T")
-F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
+P = ParamSpec("P")
+R = TypeVar("R")


 def retry(
@@ -77,14 +78,13 @@ def retry(
         exceptions: 需要重试的异常类型

     Example:
-        @retry(max_attempts=3, delay=1.0)
         async def query_data():
             return await session.execute(stmt)
     """

-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
         @functools.wraps(func)
-        async def wrapper(*args: Any, **kwargs: Any) -> T:
+        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
             last_exception = None
             current_delay = delay
@@ -107,7 +107,9 @@ def retry(
             )

             # 所有尝试都失败
-            raise last_exception
+            if last_exception:
+                raise last_exception
+            raise RuntimeError(f"Retry failed after {max_attempts} attempts")

         return wrapper
@@ -128,9 +130,9 @@ def timeout(seconds: float):
             return await session.execute(complex_stmt)
     """

-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
         @functools.wraps(func)
-        async def wrapper(*args: Any, **kwargs: Any) -> T:
+        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
             try:
                 return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
             except asyncio.TimeoutError:
@@ -164,9 +166,9 @@ def cached(
             return await query_user(user_id)
     """

-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
         @functools.wraps(func)
-        async def wrapper(*args: Any, **kwargs: Any) -> T:
+        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
             # 延迟导入避免循环依赖
             from src.common.database.optimization import get_cache
@@ -226,9 +228,9 @@ def measure_time(log_slow: float | None = None):
             return await session.execute(stmt)
     """

-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
         @functools.wraps(func)
-        async def wrapper(*args: Any, **kwargs: Any) -> T:
+        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
             start_time = time.perf_counter()

             try:
@@ -268,21 +270,23 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
         函数需要接受session参数
     """

-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
         @functools.wraps(func)
-        async def wrapper(*args: Any, **kwargs: Any) -> T:
+        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
             # 查找session参数
-            session = None
-            if args:
-                from sqlalchemy.ext.asyncio import AsyncSession
+            from sqlalchemy.ext.asyncio import AsyncSession
+            session: AsyncSession | None = None
+            if args:
                 for arg in args:
                     if isinstance(arg, AsyncSession):
                         session = arg
                         break

             if not session and "session" in kwargs:
-                session = kwargs["session"]
+                possible_session = kwargs["session"]
+                if isinstance(possible_session, AsyncSession):
+                    session = possible_session

             if not session:
                 logger.warning(f"{func.__name__} 未找到session参数跳过事务管理")
@@ -331,7 +335,7 @@ def db_operation(
         return await complex_operation()
     """

-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
         # 从内到外应用装饰器
         wrapped = func
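The decorator rewrites above replace the loose `Callable[..., T]` annotations with `ParamSpec`, so the wrapped coroutine keeps its argument and return types through `@retry`, `@timeout`, and friends. A compact, self-contained sketch of that typing pattern (Python 3.10+, retry logic simplified relative to the real module):

import asyncio
import functools
from collections.abc import Awaitable, Callable
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def retry(max_attempts: int = 3) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
    """Re-run an async function a few times; the wrapper keeps the original signature."""

    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
        @functools.wraps(func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            last_exception: BaseException | None = None
            for _ in range(max_attempts):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:  # sketch only: real code filters exception types
                    last_exception = e
            # Guarded re-raise: never raises a possibly-None value, and the
            # trailing RuntimeError gives the function a defined end state.
            if last_exception:
                raise last_exception
            raise RuntimeError(f"Retry failed after {max_attempts} attempts")

        return wrapper

    return decorator


@retry(max_attempts=2)
async def flaky(x: int) -> int:  # type checkers still see (x: int) -> int
    return x * 2


print(asyncio.run(flaky(21)))  # 42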