依旧修pyright喵~
This commit is contained in:
@@ -83,7 +83,7 @@ class ChatterManager:
|
|||||||
inactive_streams = []
|
inactive_streams = []
|
||||||
for stream_id, instance in self.instances.items():
|
for stream_id, instance in self.instances.items():
|
||||||
if hasattr(instance, "get_activity_time"):
|
if hasattr(instance, "get_activity_time"):
|
||||||
activity_time = instance.get_activity_time()
|
activity_time = getattr(instance, "get_activity_time")()
|
||||||
if (current_time - activity_time) > max_inactive_seconds:
|
if (current_time - activity_time) > max_inactive_seconds:
|
||||||
inactive_streams.append(stream_id)
|
inactive_streams.append(stream_id)
|
||||||
|
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ class ChatterActionManager:
|
|||||||
log_prefix=log_prefix,
|
log_prefix=log_prefix,
|
||||||
shutting_down=shutting_down,
|
shutting_down=shutting_down,
|
||||||
plugin_config=plugin_config,
|
plugin_config=plugin_config,
|
||||||
action_message=action_message,
|
action_message=action_message, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"创建Action实例成功: {action_name}")
|
logger.debug(f"创建Action实例成功: {action_name}")
|
||||||
@@ -173,6 +173,7 @@ class ChatterActionManager:
|
|||||||
Returns:
|
Returns:
|
||||||
执行结果
|
执行结果
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
|
|
||||||
chat_stream = None
|
chat_stream = None
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class ActionModifier:
|
|||||||
|
|
||||||
def __init__(self, action_manager: ChatterActionManager, chat_id: str):
|
def __init__(self, action_manager: ChatterActionManager, chat_id: str):
|
||||||
"""初始化动作处理器"""
|
"""初始化动作处理器"""
|
||||||
|
assert model_config is not None
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
# chat_stream 和 log_prefix 将在异步方法中初始化
|
# chat_stream 和 log_prefix 将在异步方法中初始化
|
||||||
self.chat_stream: "ChatStream | None" = None
|
self.chat_stream: "ChatStream | None" = None
|
||||||
@@ -67,6 +68,7 @@ class ActionModifier:
|
|||||||
|
|
||||||
处理后,ActionManager 将包含最终的可用动作集,供规划器直接使用
|
处理后,ActionManager 将包含最终的可用动作集,供规划器直接使用
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
# 初始化log_prefix
|
# 初始化log_prefix
|
||||||
await self._initialize_log_prefix()
|
await self._initialize_log_prefix()
|
||||||
# 根据 stream_id 加载当前可用的动作
|
# 根据 stream_id 加载当前可用的动作
|
||||||
|
|||||||
@@ -240,6 +240,8 @@ class DefaultReplyer:
|
|||||||
chat_stream: "ChatStream",
|
chat_stream: "ChatStream",
|
||||||
request_type: str = "replyer",
|
request_type: str = "replyer",
|
||||||
):
|
):
|
||||||
|
assert global_config is not None
|
||||||
|
assert model_config is not None
|
||||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
# 这些将在异步初始化中设置
|
# 这些将在异步初始化中设置
|
||||||
@@ -267,6 +269,7 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
async def _build_auth_role_prompt(self) -> str:
|
async def _build_auth_role_prompt(self) -> str:
|
||||||
"""根据主人配置生成额外提示词"""
|
"""根据主人配置生成额外提示词"""
|
||||||
|
assert global_config is not None
|
||||||
master_config = global_config.permission.master_prompt
|
master_config = global_config.permission.master_prompt
|
||||||
if not master_config or not master_config.enable:
|
if not master_config or not master_config.enable:
|
||||||
return ""
|
return ""
|
||||||
@@ -515,6 +518,7 @@ class DefaultReplyer:
|
|||||||
Returns:
|
Returns:
|
||||||
str: 表达习惯信息字符串
|
str: 表达习惯信息字符串
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
# 检查是否允许在此聊天流中使用表达
|
# 检查是否允许在此聊天流中使用表达
|
||||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
||||||
if not use_expression:
|
if not use_expression:
|
||||||
@@ -583,6 +587,7 @@ class DefaultReplyer:
|
|||||||
Returns:
|
Returns:
|
||||||
str: 记忆信息字符串
|
str: 记忆信息字符串
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
# 检查是否启用三层记忆系统
|
# 检查是否启用三层记忆系统
|
||||||
if not (global_config.memory and global_config.memory.enable):
|
if not (global_config.memory and global_config.memory.enable):
|
||||||
return ""
|
return ""
|
||||||
@@ -776,6 +781,7 @@ class DefaultReplyer:
|
|||||||
Returns:
|
Returns:
|
||||||
str: 关键词反应提示字符串,如果没有触发任何反应则为空字符串
|
str: 关键词反应提示字符串,如果没有触发任何反应则为空字符串
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
if target is None:
|
if target is None:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@@ -834,6 +840,7 @@ class DefaultReplyer:
|
|||||||
Returns:
|
Returns:
|
||||||
str: 格式化的notice信息文本,如果没有notice或未启用则返回空字符串
|
str: 格式化的notice信息文本,如果没有notice或未启用则返回空字符串
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
try:
|
try:
|
||||||
logger.debug(f"开始构建notice块,chat_id={chat_id}")
|
logger.debug(f"开始构建notice块,chat_id={chat_id}")
|
||||||
|
|
||||||
@@ -902,6 +909,7 @@ class DefaultReplyer:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple[str, str]: (已读历史消息prompt, 未读历史消息prompt)
|
Tuple[str, str]: (已读历史消息prompt, 未读历史消息prompt)
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
try:
|
try:
|
||||||
# 从message_manager获取真实的已读/未读消息
|
# 从message_manager获取真实的已读/未读消息
|
||||||
|
|
||||||
@@ -1002,6 +1010,7 @@ class DefaultReplyer:
|
|||||||
"""
|
"""
|
||||||
回退的已读/未读历史消息构建方法
|
回退的已读/未读历史消息构建方法
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
# 通过is_read字段分离已读和未读消息
|
# 通过is_read字段分离已读和未读消息
|
||||||
read_messages = []
|
read_messages = []
|
||||||
unread_messages = []
|
unread_messages = []
|
||||||
@@ -1115,6 +1124,7 @@ class DefaultReplyer:
|
|||||||
Returns:
|
Returns:
|
||||||
str: 构建好的上下文
|
str: 构建好的上下文
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
if available_actions is None:
|
if available_actions is None:
|
||||||
available_actions = {}
|
available_actions = {}
|
||||||
chat_stream = self.chat_stream
|
chat_stream = self.chat_stream
|
||||||
@@ -1607,6 +1617,7 @@ class DefaultReplyer:
|
|||||||
reply_to: str,
|
reply_to: str,
|
||||||
reply_message: dict[str, Any] | DatabaseMessages | None = None,
|
reply_message: dict[str, Any] | DatabaseMessages | None = None,
|
||||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||||
|
assert global_config is not None
|
||||||
chat_stream = self.chat_stream
|
chat_stream = self.chat_stream
|
||||||
chat_id = chat_stream.stream_id
|
chat_id = chat_stream.stream_id
|
||||||
is_group_chat = bool(chat_stream.group_info)
|
is_group_chat = bool(chat_stream.group_info)
|
||||||
@@ -1767,6 +1778,7 @@ class DefaultReplyer:
|
|||||||
return prompt_text
|
return prompt_text
|
||||||
|
|
||||||
async def llm_generate_content(self, prompt: str):
|
async def llm_generate_content(self, prompt: str):
|
||||||
|
assert global_config is not None
|
||||||
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
||||||
# 直接使用已初始化的模型实例
|
# 直接使用已初始化的模型实例
|
||||||
logger.info(f"使用模型集生成回复: {self.express_model.model_for_task}")
|
logger.info(f"使用模型集生成回复: {self.express_model.model_for_task}")
|
||||||
@@ -1792,6 +1804,8 @@ class DefaultReplyer:
|
|||||||
return content, reasoning_content, model_name, tool_calls
|
return content, reasoning_content, model_name, tool_calls
|
||||||
|
|
||||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||||
|
assert global_config is not None
|
||||||
|
assert model_config is not None
|
||||||
related_info = ""
|
related_info = ""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
|
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
|
||||||
@@ -1843,6 +1857,7 @@ class DefaultReplyer:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def build_relation_info(self, sender: str, target: str):
|
async def build_relation_info(self, sender: str, target: str):
|
||||||
|
assert global_config is not None
|
||||||
# 获取用户ID
|
# 获取用户ID
|
||||||
if sender == f"{global_config.bot.nickname}(你)":
|
if sender == f"{global_config.bot.nickname}(你)":
|
||||||
return "你将要回复的是你自己发送的消息。"
|
return "你将要回复的是你自己发送的消息。"
|
||||||
@@ -1927,6 +1942,7 @@ class DefaultReplyer:
|
|||||||
reply_to: 回复对象
|
reply_to: 回复对象
|
||||||
reply_message: 回复的原始消息
|
reply_message: 回复的原始消息
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
return # 已禁用,保留函数签名以防其他地方有引用
|
return # 已禁用,保留函数签名以防其他地方有引用
|
||||||
|
|
||||||
# 以下代码已废弃,不再执行
|
# 以下代码已废弃,不再执行
|
||||||
|
|||||||
@@ -173,9 +173,10 @@ class SecurityManager:
|
|||||||
pre_check_results = await asyncio.gather(*pre_check_tasks, return_exceptions=True)
|
pre_check_results = await asyncio.gather(*pre_check_tasks, return_exceptions=True)
|
||||||
|
|
||||||
# 筛选需要完整检查的检测器
|
# 筛选需要完整检查的检测器
|
||||||
checkers_to_run = [
|
checkers_to_run = []
|
||||||
c for c, need_check in zip(enabled_checkers, pre_check_results) if need_check is True
|
for c, need_check in zip(enabled_checkers, pre_check_results):
|
||||||
]
|
if need_check is True:
|
||||||
|
checkers_to_run.append(c)
|
||||||
|
|
||||||
if not checkers_to_run:
|
if not checkers_to_run:
|
||||||
return SecurityCheckResult(
|
return SecurityCheckResult(
|
||||||
@@ -192,20 +193,22 @@ class SecurityManager:
|
|||||||
results = await asyncio.gather(*check_tasks, return_exceptions=True)
|
results = await asyncio.gather(*check_tasks, return_exceptions=True)
|
||||||
|
|
||||||
# 过滤异常结果
|
# 过滤异常结果
|
||||||
valid_results = []
|
valid_results: list[SecurityCheckResult] = []
|
||||||
for checker, result in zip(checkers_to_run, results):
|
for checker, result in zip(checkers_to_run, results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, BaseException):
|
||||||
logger.error(f"检测器 '{checker.name}' 执行失败: {result}")
|
logger.error(f"检测器 '{checker.name}' 执行失败: {result}")
|
||||||
continue
|
continue
|
||||||
result.checker_name = checker.name
|
|
||||||
valid_results.append(result)
|
if isinstance(result, SecurityCheckResult):
|
||||||
|
result.checker_name = checker.name
|
||||||
|
valid_results.append(result)
|
||||||
|
|
||||||
# 合并结果
|
# 合并结果
|
||||||
return self._merge_results(valid_results, time.time() - start_time)
|
return self._merge_results(valid_results, time.time() - start_time)
|
||||||
|
|
||||||
async def _check_all(self, message: str, context: dict, start_time: float) -> SecurityCheckResult:
|
async def _check_all(self, message: str, context: dict, start_time: float) -> SecurityCheckResult:
|
||||||
"""检测所有模式(顺序执行所有检测器)"""
|
"""检测所有模式(顺序执行所有检测器)"""
|
||||||
results = []
|
results: list[SecurityCheckResult] = []
|
||||||
|
|
||||||
for checker in self._checkers:
|
for checker in self._checkers:
|
||||||
if not checker.enabled:
|
if not checker.enabled:
|
||||||
|
|||||||
@@ -39,11 +39,13 @@ def replace_user_references_sync(
|
|||||||
Returns:
|
Returns:
|
||||||
str: 处理后的内容字符串
|
str: 处理后的内容字符串
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
if not content:
|
if not content:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
if name_resolver is None:
|
if name_resolver is None:
|
||||||
def default_resolver(platform: str, user_id: str) -> str:
|
def default_resolver(platform: str, user_id: str) -> str:
|
||||||
|
assert global_config is not None
|
||||||
# 检查是否是机器人自己
|
# 检查是否是机器人自己
|
||||||
if replace_bot_name and (user_id == str(global_config.bot.qq_account)):
|
if replace_bot_name and (user_id == str(global_config.bot.qq_account)):
|
||||||
return f"{global_config.bot.nickname}(你)"
|
return f"{global_config.bot.nickname}(你)"
|
||||||
@@ -116,10 +118,12 @@ async def replace_user_references_async(
|
|||||||
Returns:
|
Returns:
|
||||||
str: 处理后的内容字符串
|
str: 处理后的内容字符串
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
if name_resolver is None:
|
if name_resolver is None:
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
|
|
||||||
async def default_resolver(platform: str, user_id: str) -> str:
|
async def default_resolver(platform: str, user_id: str) -> str:
|
||||||
|
assert global_config is not None
|
||||||
# 检查是否是机器人自己
|
# 检查是否是机器人自己
|
||||||
if replace_bot_name and (user_id == str(global_config.bot.qq_account)):
|
if replace_bot_name and (user_id == str(global_config.bot.qq_account)):
|
||||||
return f"{global_config.bot.nickname}(你)"
|
return f"{global_config.bot.nickname}(你)"
|
||||||
@@ -392,7 +396,7 @@ async def get_actions_by_timestamp_with_chat_inclusive(
|
|||||||
actions = list(result.scalars())
|
actions = list(result.scalars())
|
||||||
return [action.__dict__ for action in reversed(actions)]
|
return [action.__dict__ for action in reversed(actions)]
|
||||||
else: # earliest
|
else: # earliest
|
||||||
result = await session.execute(
|
query = await session.execute(
|
||||||
select(ActionRecords)
|
select(ActionRecords)
|
||||||
.where(
|
.where(
|
||||||
and_(
|
and_(
|
||||||
@@ -540,6 +544,7 @@ async def _build_readable_messages_internal(
|
|||||||
Returns:
|
Returns:
|
||||||
包含格式化消息的字符串、原始消息详情列表、图片映射字典和更新后的计数器的元组。
|
包含格式化消息的字符串、原始消息详情列表、图片映射字典和更新后的计数器的元组。
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
if not messages:
|
if not messages:
|
||||||
return "", [], pic_id_mapping or {}, pic_counter
|
return "", [], pic_id_mapping or {}, pic_counter
|
||||||
|
|
||||||
@@ -694,6 +699,7 @@ async def _build_readable_messages_internal(
|
|||||||
percentile = i / n_messages # 计算消息在列表中的位置百分比 (0 <= percentile < 1)
|
percentile = i / n_messages # 计算消息在列表中的位置百分比 (0 <= percentile < 1)
|
||||||
original_len = len(content)
|
original_len = len(content)
|
||||||
limit = -1 # 默认不截断
|
limit = -1 # 默认不截断
|
||||||
|
replace_content = ""
|
||||||
|
|
||||||
if percentile < 0.2: # 60% 之前的消息 (即最旧的 60%)
|
if percentile < 0.2: # 60% 之前的消息 (即最旧的 60%)
|
||||||
limit = 50
|
limit = 50
|
||||||
@@ -973,6 +979,7 @@ async def build_readable_messages(
|
|||||||
truncate: 是否截断长消息
|
truncate: 是否截断长消息
|
||||||
show_actions: 是否显示动作记录
|
show_actions: 是否显示动作记录
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
# 创建messages的深拷贝,避免修改原始列表
|
# 创建messages的深拷贝,避免修改原始列表
|
||||||
if not messages:
|
if not messages:
|
||||||
return ""
|
return ""
|
||||||
@@ -1112,6 +1119,7 @@ async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
|
|||||||
构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。
|
构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。
|
||||||
处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段,将bbb映射为匿名占位符。
|
处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段,将bbb映射为匿名占位符。
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
if not messages:
|
if not messages:
|
||||||
print("111111111111没有消息,无法构建匿名消息")
|
print("111111111111没有消息,无法构建匿名消息")
|
||||||
return ""
|
return ""
|
||||||
@@ -1127,6 +1135,7 @@ async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
|
|||||||
def get_anon_name(platform, user_id):
|
def get_anon_name(platform, user_id):
|
||||||
# print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
|
# print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
|
||||||
# print(f"global_config.bot.qq_account:{global_config.bot.qq_account}")
|
# print(f"global_config.bot.qq_account:{global_config.bot.qq_account}")
|
||||||
|
assert global_config is not None
|
||||||
|
|
||||||
if user_id == global_config.bot.qq_account:
|
if user_id == global_config.bot.qq_account:
|
||||||
# print("SELF11111111111111")
|
# print("SELF11111111111111")
|
||||||
@@ -1204,6 +1213,7 @@ async def get_person_id_list(messages: list[dict[str, Any]]) -> list[str]:
|
|||||||
Returns:
|
Returns:
|
||||||
一个包含唯一 person_id 的列表。
|
一个包含唯一 person_id 的列表。
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
person_ids_set = set() # 使用集合来自动去重
|
person_ids_set = set() # 使用集合来自动去重
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
|
|||||||
@@ -649,6 +649,7 @@ class Prompt:
|
|||||||
|
|
||||||
async def _build_expression_habits(self) -> dict[str, Any]:
|
async def _build_expression_habits(self) -> dict[str, Any]:
|
||||||
"""构建表达习惯(如表情、口癖)的上下文块."""
|
"""构建表达习惯(如表情、口癖)的上下文块."""
|
||||||
|
assert global_config is not None
|
||||||
# 检查当前聊天是否启用了表达习惯功能
|
# 检查当前聊天是否启用了表达习惯功能
|
||||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(
|
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(
|
||||||
self.parameters.chat_id
|
self.parameters.chat_id
|
||||||
@@ -728,6 +729,7 @@ class Prompt:
|
|||||||
|
|
||||||
async def _build_tool_info(self) -> dict[str, Any]:
|
async def _build_tool_info(self) -> dict[str, Any]:
|
||||||
"""构建工具调用结果的上下文块."""
|
"""构建工具调用结果的上下文块."""
|
||||||
|
assert global_config is not None
|
||||||
if not global_config.tool.enable_tool:
|
if not global_config.tool.enable_tool:
|
||||||
return {"tool_info_block": ""}
|
return {"tool_info_block": ""}
|
||||||
|
|
||||||
@@ -779,6 +781,7 @@ class Prompt:
|
|||||||
|
|
||||||
async def _build_knowledge_info(self) -> dict[str, Any]:
|
async def _build_knowledge_info(self) -> dict[str, Any]:
|
||||||
"""构建从知识库检索到的相关信息的上下文块."""
|
"""构建从知识库检索到的相关信息的上下文块."""
|
||||||
|
assert global_config is not None
|
||||||
if not global_config.lpmm_knowledge.enable:
|
if not global_config.lpmm_knowledge.enable:
|
||||||
return {"knowledge_prompt": ""}
|
return {"knowledge_prompt": ""}
|
||||||
|
|
||||||
@@ -873,6 +876,7 @@ class Prompt:
|
|||||||
|
|
||||||
def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""为S4U(Scene for You)模式准备最终用于格式化的参数字典."""
|
"""为S4U(Scene for You)模式准备最终用于格式化的参数字典."""
|
||||||
|
assert global_config is not None
|
||||||
return {
|
return {
|
||||||
**context_data,
|
**context_data,
|
||||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||||
@@ -915,6 +919,7 @@ class Prompt:
|
|||||||
|
|
||||||
def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""为Normal模式准备最终用于格式化的参数字典."""
|
"""为Normal模式准备最终用于格式化的参数字典."""
|
||||||
|
assert global_config is not None
|
||||||
return {
|
return {
|
||||||
**context_data,
|
**context_data,
|
||||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||||
@@ -959,6 +964,7 @@ class Prompt:
|
|||||||
|
|
||||||
def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""为默认模式(或其他未指定模式)准备最终用于格式化的参数字典."""
|
"""为默认模式(或其他未指定模式)准备最终用于格式化的参数字典."""
|
||||||
|
assert global_config is not None
|
||||||
return {
|
return {
|
||||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||||
"relation_info_block": context_data.get("relation_info_block", ""),
|
"relation_info_block": context_data.get("relation_info_block", ""),
|
||||||
@@ -1143,6 +1149,7 @@ class Prompt:
|
|||||||
Returns:
|
Returns:
|
||||||
str: 构建好的跨群聊上下文字符串。
|
str: 构建好的跨群聊上下文字符串。
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
if not global_config.cross_context.enable:
|
if not global_config.cross_context.enable:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|||||||
@@ -338,6 +338,7 @@ class HTMLReportGenerator:
|
|||||||
|
|
||||||
# 渲染模板
|
# 渲染模板
|
||||||
# 读取CSS和JS文件内容
|
# 读取CSS和JS文件内容
|
||||||
|
assert isinstance(self.jinja_env.loader, FileSystemLoader)
|
||||||
async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.css"), encoding="utf-8") as f:
|
async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.css"), encoding="utf-8") as f:
|
||||||
report_css = await f.read()
|
report_css = await f.read()
|
||||||
async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.js"), encoding="utf-8") as f:
|
async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.js"), encoding="utf-8") as f:
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
self._statistic_console_output(stats, now)
|
self._statistic_console_output(stats, now)
|
||||||
# 使用新的 HTMLReportGenerator 生成报告
|
# 使用新的 HTMLReportGenerator 生成报告
|
||||||
chart_data = await self._collect_chart_data(stats)
|
chart_data = await self._collect_chart_data(stats)
|
||||||
deploy_time = datetime.fromtimestamp(local_storage.get("deploy_time", now.timestamp()))
|
deploy_time = datetime.fromtimestamp(float(local_storage.get("deploy_time", now.timestamp()))) # type: ignore
|
||||||
report_generator = HTMLReportGenerator(
|
report_generator = HTMLReportGenerator(
|
||||||
name_mapping=self.name_mapping,
|
name_mapping=self.name_mapping,
|
||||||
stat_period=self.stat_period,
|
stat_period=self.stat_period,
|
||||||
@@ -219,7 +219,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 使用新的 HTMLReportGenerator 生成报告
|
# 使用新的 HTMLReportGenerator 生成报告
|
||||||
chart_data = await self._collect_chart_data(stats)
|
chart_data = await self._collect_chart_data(stats)
|
||||||
deploy_time = datetime.fromtimestamp(local_storage.get("deploy_time", now.timestamp()))
|
deploy_time = datetime.fromtimestamp(float(local_storage.get("deploy_time", now.timestamp()))) # type: ignore
|
||||||
report_generator = HTMLReportGenerator(
|
report_generator = HTMLReportGenerator(
|
||||||
name_mapping=self.name_mapping,
|
name_mapping=self.name_mapping,
|
||||||
stat_period=self.stat_period,
|
stat_period=self.stat_period,
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
|
|||||||
tuple[bool, float]: (是否提及, 提及类型)
|
tuple[bool, float]: (是否提及, 提及类型)
|
||||||
提及类型: 0=未提及, 1=弱提及(文本匹配), 2=强提及(@/回复/私聊)
|
提及类型: 0=未提及, 1=弱提及(文本匹配), 2=强提及(@/回复/私聊)
|
||||||
"""
|
"""
|
||||||
|
assert global_config is not None
|
||||||
nicknames = global_config.bot.alias_names
|
nicknames = global_config.bot.alias_names
|
||||||
mention_type = 0 # 0=未提及, 1=弱提及, 2=强提及
|
mention_type = 0 # 0=未提及, 1=弱提及, 2=强提及
|
||||||
|
|
||||||
@@ -132,6 +133,7 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
|
|||||||
|
|
||||||
async def get_embedding(text, request_type="embedding") -> list[float] | None:
|
async def get_embedding(text, request_type="embedding") -> list[float] | None:
|
||||||
"""获取文本的embedding向量"""
|
"""获取文本的embedding向量"""
|
||||||
|
assert model_config is not None
|
||||||
# 每次都创建新的LLMRequest实例以避免事件循环冲突
|
# 每次都创建新的LLMRequest实例以避免事件循环冲突
|
||||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
||||||
try:
|
try:
|
||||||
@@ -139,11 +141,12 @@ async def get_embedding(text, request_type="embedding") -> list[float] | None:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取embedding失败: {e!s}")
|
logger.error(f"获取embedding失败: {e!s}")
|
||||||
embedding = None
|
embedding = None
|
||||||
return embedding
|
return embedding # type: ignore
|
||||||
|
|
||||||
|
|
||||||
async def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
|
async def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
|
||||||
# 获取当前群聊记录内发言的人
|
# 获取当前群聊记录内发言的人
|
||||||
|
assert global_config is not None
|
||||||
filter_query = {"chat_id": chat_stream_id}
|
filter_query = {"chat_id": chat_stream_id}
|
||||||
sort_order = [("time", -1)]
|
sort_order = [("time", -1)]
|
||||||
recent_messages = await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
recent_messages = await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||||
@@ -400,11 +403,12 @@ def recover_quoted_content(sentences: list[str], placeholder_map: dict[str, str]
|
|||||||
recovered_sentences.append(sentence)
|
recovered_sentences.append(sentence)
|
||||||
return recovered_sentences
|
return recovered_sentences
|
||||||
|
|
||||||
|
|
||||||
def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese_typo: bool = True) -> list[str]:
|
def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese_typo: bool = True) -> list[str]:
|
||||||
|
assert global_config is not None
|
||||||
if not global_config.response_post_process.enable_response_post_process:
|
if not global_config.response_post_process.enable_response_post_process:
|
||||||
return [text]
|
return [text]
|
||||||
|
|
||||||
|
# --- 三层防护系统 ---
|
||||||
# --- 三层防护系统 ---
|
# --- 三层防护系统 ---
|
||||||
# 第一层:保护颜文字
|
# 第一层:保护颜文字
|
||||||
protected_text, kaomoji_mapping = protect_kaomoji(text) if global_config.response_splitter.enable_kaomoji_protection else (text, {})
|
protected_text, kaomoji_mapping = protect_kaomoji(text) if global_config.response_splitter.enable_kaomoji_protection else (text, {})
|
||||||
|
|||||||
@@ -64,8 +64,6 @@ class ImageManager:
|
|||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# logger.error(f"数据库连接失败: {e}")
|
# logger.error(f"数据库连接失败: {e}")
|
||||||
|
|
||||||
self._initialized = True
|
|
||||||
|
|
||||||
def _ensure_image_dir(self):
|
def _ensure_image_dir(self):
|
||||||
"""确保图像存储目录存在"""
|
"""确保图像存储目录存在"""
|
||||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||||
@@ -159,6 +157,7 @@ class ImageManager:
|
|||||||
async def get_emoji_description(self, image_base64: str) -> str:
|
async def get_emoji_description(self, image_base64: str) -> str:
|
||||||
"""获取表情包描述,统一使用EmojiManager中的逻辑进行处理和缓存"""
|
"""获取表情包描述,统一使用EmojiManager中的逻辑进行处理和缓存"""
|
||||||
try:
|
try:
|
||||||
|
assert global_config is not None
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||||
|
|
||||||
emoji_manager = get_emoji_manager()
|
emoji_manager = get_emoji_manager()
|
||||||
@@ -190,7 +189,7 @@ class ImageManager:
|
|||||||
return "[表情包(描述生成失败)]"
|
return "[表情包(描述生成失败)]"
|
||||||
|
|
||||||
# 4. (可选) 如果启用了“偷表情包”,则将图片和完整描述存入待注册区
|
# 4. (可选) 如果启用了“偷表情包”,则将图片和完整描述存入待注册区
|
||||||
if global_config and global_config.emoji and global_config.emoji.steal_emoji:
|
if global_config.emoji and global_config.emoji.steal_emoji:
|
||||||
logger.debug(f"偷取表情包功能已开启,保存待注册表情包: {image_hash}")
|
logger.debug(f"偷取表情包功能已开启,保存待注册表情包: {image_hash}")
|
||||||
try:
|
try:
|
||||||
image_format = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower()
|
image_format = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower()
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ class VideoAnalyzer:
|
|||||||
"""基于 inkfox 的视频关键帧 + LLM 描述分析器"""
|
"""基于 inkfox 的视频关键帧 + LLM 描述分析器"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
assert global_config is not None
|
||||||
|
assert model_config is not None
|
||||||
cfg = getattr(global_config, "video_analysis", object())
|
cfg = getattr(global_config, "video_analysis", object())
|
||||||
self.max_frames: int = getattr(cfg, "max_frames", 20)
|
self.max_frames: int = getattr(cfg, "max_frames", 20)
|
||||||
self.frame_quality: int = getattr(cfg, "frame_quality", 85)
|
self.frame_quality: int = getattr(cfg, "frame_quality", 85)
|
||||||
|
|||||||
@@ -135,6 +135,8 @@ class LegacyVideoAnalyzer:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化视频分析器"""
|
"""初始化视频分析器"""
|
||||||
|
assert global_config is not None
|
||||||
|
assert model_config is not None
|
||||||
# 使用专用的视频分析配置
|
# 使用专用的视频分析配置
|
||||||
try:
|
try:
|
||||||
self.video_llm = LLMRequest(
|
self.video_llm = LLMRequest(
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ logger = get_logger("chat_voice")
|
|||||||
|
|
||||||
async def get_voice_text(voice_base64: str) -> str:
|
async def get_voice_text(voice_base64: str) -> str:
|
||||||
"""获取音频文件转录文本"""
|
"""获取音频文件转录文本"""
|
||||||
|
assert global_config is not None
|
||||||
|
assert model_config is not None
|
||||||
if not global_config.voice.enable_asr:
|
if not global_config.voice.enable_asr:
|
||||||
logger.warning("语音识别未启用,无法处理语音消息")
|
logger.warning("语音识别未启用,无法处理语音消息")
|
||||||
return "[语音]"
|
return "[语音]"
|
||||||
|
|||||||
@@ -62,7 +62,9 @@ class StreamContext(BaseDataModel):
|
|||||||
stream_id: str
|
stream_id: str
|
||||||
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
|
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
|
||||||
chat_mode: ChatMode = ChatMode.FOCUS # 聊天模式,默认为专注模式
|
chat_mode: ChatMode = ChatMode.FOCUS # 聊天模式,默认为专注模式
|
||||||
max_context_size: int = field(default_factory=lambda: getattr(global_config.chat, "max_context_size", 100))
|
max_context_size: int = field(
|
||||||
|
default_factory=lambda: getattr(global_config.chat, "max_context_size", 100) if global_config else 100
|
||||||
|
)
|
||||||
unread_messages: list["DatabaseMessages"] = field(default_factory=list)
|
unread_messages: list["DatabaseMessages"] = field(default_factory=list)
|
||||||
history_messages: list["DatabaseMessages"] = field(default_factory=list)
|
history_messages: list["DatabaseMessages"] = field(default_factory=list)
|
||||||
last_check_time: float = field(default_factory=time.time)
|
last_check_time: float = field(default_factory=time.time)
|
||||||
@@ -98,7 +100,9 @@ class StreamContext(BaseDataModel):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""初始化历史消息异步加载"""
|
"""初始化历史消息异步加载"""
|
||||||
if not self.max_context_size or self.max_context_size <= 0:
|
if not self.max_context_size or self.max_context_size <= 0:
|
||||||
self.max_context_size = getattr(global_config.chat, "max_context_size", 100)
|
self.max_context_size = (
|
||||||
|
getattr(global_config.chat, "max_context_size", 100) if global_config else 100
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@@ -118,6 +122,7 @@ class StreamContext(BaseDataModel):
|
|||||||
async def add_message(self, message: "DatabaseMessages", skip_energy_update: bool = False) -> bool:
|
async def add_message(self, message: "DatabaseMessages", skip_energy_update: bool = False) -> bool:
|
||||||
"""添加消息到上下文,支持跳过能量更新的选项"""
|
"""添加消息到上下文,支持跳过能量更新的选项"""
|
||||||
try:
|
try:
|
||||||
|
assert global_config is not None
|
||||||
cache_enabled = global_config.chat.enable_message_cache
|
cache_enabled = global_config.chat.enable_message_cache
|
||||||
if cache_enabled and not self.is_cache_enabled:
|
if cache_enabled and not self.is_cache_enabled:
|
||||||
self.enable_cache(True)
|
self.enable_cache(True)
|
||||||
@@ -150,7 +155,7 @@ class StreamContext(BaseDataModel):
|
|||||||
# ͬ<><CDAC><EFBFBD><EFBFBD><EFBFBD>ݵ<EFBFBD>ͳһ<CDB3><D2BB><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
|
# ͬ<><CDAC><EFBFBD><EFBFBD><EFBFBD>ݵ<EFBFBD>ͳһ<CDB3><D2BB><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
|
||||||
try:
|
try:
|
||||||
if global_config.memory and global_config.memory.enable:
|
if global_config.memory and global_config.memory.enable:
|
||||||
unified_manager = _get_unified_memory_manager()
|
unified_manager: Any = _get_unified_memory_manager()
|
||||||
if unified_manager:
|
if unified_manager:
|
||||||
message_dict = {
|
message_dict = {
|
||||||
"message_id": str(message.message_id),
|
"message_id": str(message.message_id),
|
||||||
@@ -161,7 +166,7 @@ class StreamContext(BaseDataModel):
|
|||||||
"platform": message.chat_info.platform,
|
"platform": message.chat_info.platform,
|
||||||
"stream_id": self.stream_id,
|
"stream_id": self.stream_id,
|
||||||
}
|
}
|
||||||
await unified_manager.add_message(message_dict)
|
await unified_manager.add_message(message_dict) # type: ignore
|
||||||
logger.debug(f"<EFBFBD><EFBFBD>Ϣ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ӵ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ϵͳ: {message.message_id}")
|
logger.debug(f"<EFBFBD><EFBFBD>Ϣ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ӵ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ϵͳ: {message.message_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ϣ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ϵͳʧ<EFBFBD><EFBFBD>: {e}")
|
logger.error(f"<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ϣ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ϵͳʧ<EFBFBD><EFBFBD>: {e}")
|
||||||
|
|||||||
@@ -9,9 +9,10 @@
|
|||||||
import operator
|
import operator
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, TypeVar
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
from sqlalchemy import delete, func, select, update
|
from sqlalchemy import delete, func, select, update
|
||||||
|
from sqlalchemy.engine import CursorResult, Result
|
||||||
|
|
||||||
from src.common.database.core.models import Base
|
from src.common.database.core.models import Base
|
||||||
from src.common.database.core.session import get_db_session
|
from src.common.database.core.session import get_db_session
|
||||||
@@ -25,23 +26,23 @@ from src.common.logger import get_logger
|
|||||||
|
|
||||||
logger = get_logger("database.crud")
|
logger = get_logger("database.crud")
|
||||||
|
|
||||||
T = TypeVar("T", bound=Base)
|
T = TypeVar("T", bound=Any)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=256)
|
@lru_cache(maxsize=256)
|
||||||
def _get_model_column_names(model: type[Base]) -> tuple[str, ...]:
|
def _get_model_column_names(model: type[Any]) -> tuple[str, ...]:
|
||||||
"""获取模型的列名称列表"""
|
"""获取模型的列名称列表"""
|
||||||
return tuple(column.name for column in model.__table__.columns)
|
return tuple(column.name for column in model.__table__.columns)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=256)
|
@lru_cache(maxsize=256)
|
||||||
def _get_model_field_set(model: type[Base]) -> frozenset[str]:
|
def _get_model_field_set(model: type[Any]) -> frozenset[str]:
|
||||||
"""获取模型的有效字段集合"""
|
"""获取模型的有效字段集合"""
|
||||||
return frozenset(_get_model_column_names(model))
|
return frozenset(_get_model_column_names(model))
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=256)
|
@lru_cache(maxsize=256)
|
||||||
def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
|
def _get_model_value_fetcher(model: type[Any]) -> Callable[[Any], tuple[Any, ...]]:
|
||||||
"""为模型准备attrgetter,用于批量获取属性值"""
|
"""为模型准备attrgetter,用于批量获取属性值"""
|
||||||
column_names = _get_model_column_names(model)
|
column_names = _get_model_column_names(model)
|
||||||
|
|
||||||
@@ -51,21 +52,21 @@ def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, .
|
|||||||
if len(column_names) == 1:
|
if len(column_names) == 1:
|
||||||
attr_name = column_names[0]
|
attr_name = column_names[0]
|
||||||
|
|
||||||
def _single(instance: Base) -> tuple[Any, ...]:
|
def _single(instance: Any) -> tuple[Any, ...]:
|
||||||
return (getattr(instance, attr_name),)
|
return (getattr(instance, attr_name),)
|
||||||
|
|
||||||
return _single
|
return _single
|
||||||
|
|
||||||
getter = operator.attrgetter(*column_names)
|
getter = operator.attrgetter(*column_names)
|
||||||
|
|
||||||
def _multi(instance: Base) -> tuple[Any, ...]:
|
def _multi(instance: Any) -> tuple[Any, ...]:
|
||||||
values = getter(instance)
|
values = getter(instance)
|
||||||
return values if isinstance(values, tuple) else (values,)
|
return values if isinstance(values, tuple) else (values,)
|
||||||
|
|
||||||
return _multi
|
return _multi
|
||||||
|
|
||||||
|
|
||||||
def _model_to_dict(instance: Base) -> dict[str, Any]:
|
def _model_to_dict(instance: Any) -> dict[str, Any]:
|
||||||
"""将 SQLAlchemy 模型实例转换为字典
|
"""将 SQLAlchemy 模型实例转换为字典
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -113,7 +114,7 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
|
|||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
||||||
class CRUDBase:
|
class CRUDBase(Generic[T]):
|
||||||
"""基础CRUD操作类
|
"""基础CRUD操作类
|
||||||
|
|
||||||
提供通用的增删改查操作,自动集成缓存和批处理
|
提供通用的增删改查操作,自动集成缓存和批处理
|
||||||
@@ -249,7 +250,7 @@ class CRUDBase:
|
|||||||
if cached_dicts is not None:
|
if cached_dicts is not None:
|
||||||
logger.debug(f"缓存命中: {cache_key}")
|
logger.debug(f"缓存命中: {cache_key}")
|
||||||
# 从字典列表恢复对象列表
|
# 从字典列表恢复对象列表
|
||||||
return [_dict_to_model(self.model, d) for d in cached_dicts]
|
return [_dict_to_model(self.model, d) for d in cached_dicts] # type: ignore
|
||||||
|
|
||||||
# 从数据库查询
|
# 从数据库查询
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
@@ -278,7 +279,7 @@ class CRUDBase:
|
|||||||
await cache.set(cache_key, instances_dicts)
|
await cache.set(cache_key, instances_dicts)
|
||||||
|
|
||||||
# 从字典列表重建对象列表返回(detached状态,所有字段已加载)
|
# 从字典列表重建对象列表返回(detached状态,所有字段已加载)
|
||||||
return [_dict_to_model(self.model, d) for d in instances_dicts]
|
return [_dict_to_model(self.model, d) for d in instances_dicts] # type: ignore
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self,
|
self,
|
||||||
@@ -420,7 +421,7 @@ class CRUDBase:
|
|||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
stmt = delete(self.model).where(self.model.id == id)
|
stmt = delete(self.model).where(self.model.id == id)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
success = result.rowcount > 0
|
success = result.rowcount > 0 # type: ignore
|
||||||
# 注意:commit在get_db_session的context manager退出时自动执行
|
# 注意:commit在get_db_session的context manager退出时自动执行
|
||||||
|
|
||||||
# 清除缓存
|
# 清除缓存
|
||||||
@@ -455,7 +456,7 @@ class CRUDBase:
|
|||||||
stmt = stmt.where(getattr(self.model, key) == value)
|
stmt = stmt.where(getattr(self.model, key) == value)
|
||||||
|
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return result.scalar()
|
return int(result.scalar() or 0)
|
||||||
|
|
||||||
async def exists(
|
async def exists(
|
||||||
self,
|
self,
|
||||||
@@ -549,7 +550,7 @@ class CRUDBase:
|
|||||||
.values(**obj_in)
|
.values(**obj_in)
|
||||||
)
|
)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
count += result.rowcount
|
count += result.rowcount # type: ignore
|
||||||
|
|
||||||
# 清除缓存
|
# 清除缓存
|
||||||
cache_key = f"{self.model_name}:id:{id}"
|
cache_key = f"{self.model_name}:id:{id}"
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from src.common.logger import get_logger
|
|||||||
|
|
||||||
logger = get_logger("database.query")
|
logger = get_logger("database.query")
|
||||||
|
|
||||||
T = TypeVar("T", bound="Base")
|
T = TypeVar("T", bound=Any)
|
||||||
|
|
||||||
|
|
||||||
class QueryBuilder(Generic[T]):
|
class QueryBuilder(Generic[T]):
|
||||||
@@ -330,7 +330,7 @@ class QueryBuilder(Generic[T]):
|
|||||||
|
|
||||||
items = await self.all()
|
items = await self.all()
|
||||||
|
|
||||||
return items, total
|
return items, total # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class AggregateQuery:
|
class AggregateQuery:
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ async def get_recent_actions(
|
|||||||
动作记录列表
|
动作记录列表
|
||||||
"""
|
"""
|
||||||
query = QueryBuilder(ActionRecords)
|
query = QueryBuilder(ActionRecords)
|
||||||
return await query.filter(chat_id=chat_id).order_by("-time").limit(limit).all()
|
return await query.filter(chat_id=chat_id).order_by("-time").limit(limit).all() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# ===== Messages 业务API =====
|
# ===== Messages 业务API =====
|
||||||
@@ -148,7 +148,7 @@ async def get_chat_history(
|
|||||||
.limit(limit)
|
.limit(limit)
|
||||||
.offset(offset)
|
.offset(offset)
|
||||||
.all()
|
.all()
|
||||||
)
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
async def get_message_count(stream_id: str) -> int:
|
async def get_message_count(stream_id: str) -> int:
|
||||||
@@ -292,7 +292,7 @@ async def get_active_streams(
|
|||||||
if platform:
|
if platform:
|
||||||
query = query.filter(platform=platform)
|
query = query.filter(platform=platform)
|
||||||
|
|
||||||
return await query.order_by("-last_message_time").limit(limit).all()
|
return await query.order_by("-last_message_time").limit(limit).all() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# ===== LLMUsage 业务API =====
|
# ===== LLMUsage 业务API =====
|
||||||
@@ -390,7 +390,7 @@ async def get_usage_statistics(
|
|||||||
# 聚合统计
|
# 聚合统计
|
||||||
total_input = await query.sum("input_tokens")
|
total_input = await query.sum("input_tokens")
|
||||||
total_output = await query.sum("output_tokens")
|
total_output = await query.sum("output_tokens")
|
||||||
total_count = await query.filter().count() if hasattr(query, "count") else 0
|
total_count = await getattr(query.filter(), "count")() if hasattr(query, "count") else 0
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_input_tokens": int(total_input),
|
"total_input_tokens": int(total_input),
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ async def build_filters(model_class, filters: dict[str, Any]):
|
|||||||
return conditions
|
return conditions
|
||||||
|
|
||||||
|
|
||||||
def _model_to_dict(instance) -> dict[str, Any]:
|
def _model_to_dict(instance) -> dict[str, Any] | None:
|
||||||
"""将数据库模型实例转换为字典(兼容旧API
|
"""将数据库模型实例转换为字典(兼容旧API
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -238,7 +238,7 @@ async def db_query(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 更新记录
|
# 更新记录
|
||||||
updated = await crud.update(instance.id, data)
|
updated = await crud.update(instance.id, data) # type: ignore
|
||||||
return _model_to_dict(updated)
|
return _model_to_dict(updated)
|
||||||
|
|
||||||
elif query_type == "delete":
|
elif query_type == "delete":
|
||||||
@@ -257,7 +257,7 @@ async def db_query(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 删除记录
|
# 删除记录
|
||||||
success = await crud.delete(instance.id)
|
success = await crud.delete(instance.id) # type: ignore
|
||||||
return {"deleted": success}
|
return {"deleted": success}
|
||||||
|
|
||||||
elif query_type == "count":
|
elif query_type == "count":
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ async def get_engine() -> AsyncEngine:
|
|||||||
if _engine_lock is None:
|
if _engine_lock is None:
|
||||||
_engine_lock = asyncio.Lock()
|
_engine_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
assert _engine_lock is not None
|
||||||
# 使用锁保护初始化过程
|
# 使用锁保护初始化过程
|
||||||
async with _engine_lock:
|
async with _engine_lock:
|
||||||
# 双重检查锁定模式
|
# 双重检查锁定模式
|
||||||
@@ -55,6 +56,7 @@ async def get_engine() -> AsyncEngine:
|
|||||||
try:
|
try:
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
assert global_config is not None
|
||||||
config = global_config.database
|
config = global_config.database
|
||||||
db_type = config.database_type
|
db_type = config.database_type
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ def get_string_field(max_length=255, **kwargs):
|
|||||||
"""
|
"""
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
assert global_config is not None
|
||||||
db_type = global_config.database.database_type
|
db_type = global_config.database.database_type
|
||||||
|
|
||||||
# MySQL 索引需要指定长度的 VARCHAR
|
# MySQL 索引需要指定长度的 VARCHAR
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
|
|||||||
# 可以设置 schema 搜索路径等
|
# 可以设置 schema 搜索路径等
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
assert global_config is not None
|
||||||
schema = global_config.database.postgresql_schema
|
schema = global_config.database.postgresql_schema
|
||||||
if schema and schema != "public":
|
if schema and schema != "public":
|
||||||
await session.execute(text(f"SET search_path TO {schema}"))
|
await session.execute(text(f"SET search_path TO {schema}"))
|
||||||
@@ -114,6 +115,7 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
# 获取数据库类型并应用特定设置
|
# 获取数据库类型并应用特定设置
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
assert global_config is not None
|
||||||
await _apply_session_settings(session, global_config.database.database_type)
|
await _apply_session_settings(session, global_config.database.database_type)
|
||||||
|
|
||||||
yield session
|
yield session
|
||||||
@@ -142,6 +144,7 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
# 应用数据库特定设置
|
# 应用数据库特定设置
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
assert global_config is not None
|
||||||
await _apply_session_settings(session, global_config.database.database_type)
|
await _apply_session_settings(session, global_config.database.database_type)
|
||||||
|
|
||||||
yield session
|
yield session
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from collections import defaultdict, deque
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import delete, insert, select, update
|
from sqlalchemy import delete, insert, select, update
|
||||||
|
|
||||||
@@ -23,8 +23,6 @@ from src.common.memory_utils import estimate_size_smart
|
|||||||
|
|
||||||
logger = get_logger("batch_scheduler")
|
logger = get_logger("batch_scheduler")
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class Priority(IntEnum):
|
class Priority(IntEnum):
|
||||||
"""操作优先级"""
|
"""操作优先级"""
|
||||||
@@ -429,7 +427,7 @@ class AdaptiveBatchScheduler:
|
|||||||
|
|
||||||
# 执行更新(但不commit)
|
# 执行更新(但不commit)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
results.append((op, result.rowcount))
|
results.append((op, result.rowcount)) # type: ignore
|
||||||
|
|
||||||
# 注意:commit 由 get_db_session_direct 上下文管理器自动处理
|
# 注意:commit 由 get_db_session_direct 上下文管理器自动处理
|
||||||
|
|
||||||
@@ -471,7 +469,7 @@ class AdaptiveBatchScheduler:
|
|||||||
|
|
||||||
# 执行删除(但不commit)
|
# 执行删除(但不commit)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
results.append((op, result.rowcount))
|
results.append((op, result.rowcount)) # type: ignore
|
||||||
|
|
||||||
# 注意:commit 由 get_db_session_direct 上下文管理器自动处理
|
# 注意:commit 由 get_db_session_direct 上下文管理器自动处理
|
||||||
|
|
||||||
|
|||||||
@@ -398,47 +398,48 @@ class MultiLevelCache:
|
|||||||
l2_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l2_cache, "L2"))
|
l2_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l2_cache, "L2"))
|
||||||
|
|
||||||
# 使用超时避免死锁
|
# 使用超时避免死锁
|
||||||
try:
|
results = await asyncio.gather(
|
||||||
l1_stats, l2_stats = await asyncio.gather(
|
asyncio.wait_for(l1_stats_task, timeout=1.0),
|
||||||
asyncio.wait_for(l1_stats_task, timeout=1.0),
|
asyncio.wait_for(l2_stats_task, timeout=1.0),
|
||||||
asyncio.wait_for(l2_stats_task, timeout=1.0),
|
return_exceptions=True
|
||||||
return_exceptions=True
|
)
|
||||||
)
|
l1_stats = results[0]
|
||||||
except asyncio.TimeoutError:
|
l2_stats = results[1]
|
||||||
logger.warning("缓存统计获取超时,使用基本统计")
|
|
||||||
l1_stats = await self.l1_cache.get_stats()
|
|
||||||
l2_stats = await self.l2_cache.get_stats()
|
|
||||||
|
|
||||||
# 处理异常情况
|
# 处理异常情况
|
||||||
if isinstance(l1_stats, Exception):
|
if isinstance(l1_stats, BaseException):
|
||||||
logger.error(f"L1统计获取失败: {l1_stats}")
|
logger.error(f"L1统计获取失败: {l1_stats}")
|
||||||
l1_stats = CacheStats()
|
l1_stats = CacheStats()
|
||||||
if isinstance(l2_stats, Exception):
|
if isinstance(l2_stats, BaseException):
|
||||||
logger.error(f"L2统计获取失败: {l2_stats}")
|
logger.error(f"L2统计获取失败: {l2_stats}")
|
||||||
l2_stats = CacheStats()
|
l2_stats = CacheStats()
|
||||||
|
|
||||||
|
assert isinstance(l1_stats, CacheStats)
|
||||||
|
assert isinstance(l2_stats, CacheStats)
|
||||||
|
|
||||||
# 🔧 修复:并行获取键集合,避免锁嵌套
|
# 🔧 修复:并行获取键集合,避免锁嵌套
|
||||||
l1_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l1_cache))
|
l1_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l1_cache))
|
||||||
l2_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l2_cache))
|
l2_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l2_cache))
|
||||||
|
|
||||||
try:
|
results = await asyncio.gather(
|
||||||
l1_keys, l2_keys = await asyncio.gather(
|
asyncio.wait_for(l1_keys_task, timeout=1.0),
|
||||||
asyncio.wait_for(l1_keys_task, timeout=1.0),
|
asyncio.wait_for(l2_keys_task, timeout=1.0),
|
||||||
asyncio.wait_for(l2_keys_task, timeout=1.0),
|
return_exceptions=True
|
||||||
return_exceptions=True
|
)
|
||||||
)
|
l1_keys = results[0]
|
||||||
except asyncio.TimeoutError:
|
l2_keys = results[1]
|
||||||
logger.warning("缓存键获取超时,使用默认值")
|
|
||||||
l1_keys, l2_keys = set(), set()
|
|
||||||
|
|
||||||
# 处理异常情况
|
# 处理异常情况
|
||||||
if isinstance(l1_keys, Exception):
|
if isinstance(l1_keys, BaseException):
|
||||||
logger.warning(f"L1键获取失败: {l1_keys}")
|
logger.warning(f"L1键获取失败: {l1_keys}")
|
||||||
l1_keys = set()
|
l1_keys = set()
|
||||||
if isinstance(l2_keys, Exception):
|
if isinstance(l2_keys, BaseException):
|
||||||
logger.warning(f"L2键获取失败: {l2_keys}")
|
logger.warning(f"L2键获取失败: {l2_keys}")
|
||||||
l2_keys = set()
|
l2_keys = set()
|
||||||
|
|
||||||
|
assert isinstance(l1_keys, set)
|
||||||
|
assert isinstance(l2_keys, set)
|
||||||
|
|
||||||
# 计算共享键和独占键
|
# 计算共享键和独占键
|
||||||
shared_keys = l1_keys & l2_keys
|
shared_keys = l1_keys & l2_keys
|
||||||
l1_only_keys = l1_keys - l2_keys
|
l1_only_keys = l1_keys - l2_keys
|
||||||
@@ -448,24 +449,25 @@ class MultiLevelCache:
|
|||||||
l1_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l1_cache, l1_keys))
|
l1_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l1_cache, l1_keys))
|
||||||
l2_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l2_cache, l2_keys))
|
l2_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l2_cache, l2_keys))
|
||||||
|
|
||||||
try:
|
results = await asyncio.gather(
|
||||||
l1_size, l2_size = await asyncio.gather(
|
asyncio.wait_for(l1_size_task, timeout=1.0),
|
||||||
asyncio.wait_for(l1_size_task, timeout=1.0),
|
asyncio.wait_for(l2_size_task, timeout=1.0),
|
||||||
asyncio.wait_for(l2_size_task, timeout=1.0),
|
return_exceptions=True
|
||||||
return_exceptions=True
|
)
|
||||||
)
|
l1_size = results[0]
|
||||||
except asyncio.TimeoutError:
|
l2_size = results[1]
|
||||||
logger.warning("内存计算超时,使用统计值")
|
|
||||||
l1_size, l2_size = l1_stats.total_size, l2_stats.total_size
|
|
||||||
|
|
||||||
# 处理异常情况
|
# 处理异常情况
|
||||||
if isinstance(l1_size, Exception):
|
if isinstance(l1_size, BaseException):
|
||||||
logger.warning(f"L1内存计算失败: {l1_size}")
|
logger.warning(f"L1内存计算失败: {l1_size}")
|
||||||
l1_size = l1_stats.total_size
|
l1_size = l1_stats.total_size
|
||||||
if isinstance(l2_size, Exception):
|
if isinstance(l2_size, BaseException):
|
||||||
logger.warning(f"L2内存计算失败: {l2_size}")
|
logger.warning(f"L2内存计算失败: {l2_size}")
|
||||||
l2_size = l2_stats.total_size
|
l2_size = l2_stats.total_size
|
||||||
|
|
||||||
|
assert isinstance(l1_size, int)
|
||||||
|
assert isinstance(l2_size, int)
|
||||||
|
|
||||||
# 计算实际总内存(避免重复计数)
|
# 计算实际总内存(避免重复计数)
|
||||||
actual_total_size = l1_size + l2_size - min(l1_stats.total_size, l2_stats.total_size)
|
actual_total_size = l1_size + l2_size - min(l1_stats.total_size, l2_stats.total_size)
|
||||||
|
|
||||||
@@ -769,6 +771,7 @@ async def get_cache() -> MultiLevelCache:
|
|||||||
try:
|
try:
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
assert global_config is not None
|
||||||
db_config = global_config.database
|
db_config = global_config.database
|
||||||
|
|
||||||
# 检查是否启用缓存
|
# 检查是否启用缓存
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import functools
|
|||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any, TypeVar
|
from typing import Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
from sqlalchemy.exc import DBAPIError, OperationalError
|
from sqlalchemy.exc import DBAPIError, OperationalError
|
||||||
from sqlalchemy.exc import TimeoutError as SQLTimeoutError
|
from sqlalchemy.exc import TimeoutError as SQLTimeoutError
|
||||||
@@ -56,8 +56,9 @@ def generate_cache_key(
|
|||||||
|
|
||||||
return ":".join(cache_key_parts)
|
return ":".join(cache_key_parts)
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
def retry(
|
def retry(
|
||||||
@@ -77,14 +78,13 @@ def retry(
|
|||||||
exceptions: 需要重试的异常类型
|
exceptions: 需要重试的异常类型
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@retry(max_attempts=3, delay=1.0)
|
|
||||||
async def query_data():
|
async def query_data():
|
||||||
return await session.execute(stmt)
|
return await session.execute(stmt)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
last_exception = None
|
last_exception = None
|
||||||
current_delay = delay
|
current_delay = delay
|
||||||
|
|
||||||
@@ -107,7 +107,9 @@ def retry(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 所有尝试都失败
|
# 所有尝试都失败
|
||||||
raise last_exception
|
if last_exception:
|
||||||
|
raise last_exception
|
||||||
|
raise RuntimeError(f"Retry failed after {max_attempts} attempts")
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@@ -128,9 +130,9 @@ def timeout(seconds: float):
|
|||||||
return await session.execute(complex_stmt)
|
return await session.execute(complex_stmt)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
try:
|
try:
|
||||||
return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
|
return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
@@ -164,9 +166,9 @@ def cached(
|
|||||||
return await query_user(user_id)
|
return await query_user(user_id)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
# 延迟导入避免循环依赖
|
# 延迟导入避免循环依赖
|
||||||
from src.common.database.optimization import get_cache
|
from src.common.database.optimization import get_cache
|
||||||
|
|
||||||
@@ -226,9 +228,9 @@ def measure_time(log_slow: float | None = None):
|
|||||||
return await session.execute(stmt)
|
return await session.execute(stmt)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -268,21 +270,23 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
|
|||||||
函数需要接受session参数
|
函数需要接受session参数
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
# 查找session参数
|
# 查找session参数
|
||||||
session = None
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
if args:
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
|
session: AsyncSession | None = None
|
||||||
|
if args:
|
||||||
for arg in args:
|
for arg in args:
|
||||||
if isinstance(arg, AsyncSession):
|
if isinstance(arg, AsyncSession):
|
||||||
session = arg
|
session = arg
|
||||||
break
|
break
|
||||||
|
|
||||||
if not session and "session" in kwargs:
|
if not session and "session" in kwargs:
|
||||||
session = kwargs["session"]
|
possible_session = kwargs["session"]
|
||||||
|
if isinstance(possible_session, AsyncSession):
|
||||||
|
session = possible_session
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
logger.warning(f"{func.__name__} 未找到session参数,跳过事务管理")
|
logger.warning(f"{func.__name__} 未找到session参数,跳过事务管理")
|
||||||
@@ -331,7 +335,7 @@ def db_operation(
|
|||||||
return await complex_operation()
|
return await complex_operation()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||||
# 从内到外应用装饰器
|
# 从内到外应用装饰器
|
||||||
wrapped = func
|
wrapped = func
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user