refactor(chat): 迁移数据库操作为异步模式并修复相关调用
将同步数据库操作全面迁移为异步模式,主要涉及: - 将 `with get_db_session()` 改为 `async with get_db_session()` - 修复相关异步调用链,确保 await 正确传递 - 优化消息管理器、上下文管理器等核心组件的异步处理 - 移除同步的 person_id 获取方法,避免协程对象传递问题 修复 deepcopy 在 StreamContext 中的序列化问题,跳过不可序列化的 asyncio.Task 对象 删除无用的测试文件和废弃的插件清单文件
This commit is contained in:
@@ -46,7 +46,7 @@ def replace_user_references_sync(
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore
|
||||
return person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
|
||||
|
||||
name_resolver = default_resolver
|
||||
|
||||
@@ -254,7 +254,7 @@ def get_raw_msg_by_timestamp_with_chat_users(
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat(
|
||||
async def get_actions_by_timestamp_with_chat(
|
||||
chat_id: str,
|
||||
timestamp_start: float = 0,
|
||||
timestamp_end: float = time.time(),
|
||||
@@ -273,22 +273,21 @@ def get_actions_by_timestamp_with_chat(
|
||||
f"limit={limit}, limit_mode={limit_mode}"
|
||||
)
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = session.execute(
|
||||
result = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end,
|
||||
)
|
||||
)
|
||||
.order_by(ActionRecords.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
actions = list(result.scalars())
|
||||
actions_result = []
|
||||
for action in reversed(actions):
|
||||
action_dict = {
|
||||
@@ -305,38 +304,39 @@ def get_actions_by_timestamp_with_chat(
|
||||
"chat_info_platform": action.chat_info_platform,
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else: # earliest
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end,
|
||||
)
|
||||
)
|
||||
.order_by(ActionRecords.time.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
actions_result = []
|
||||
for action in actions:
|
||||
action_dict = {
|
||||
"id": action.id,
|
||||
"action_id": action.action_id,
|
||||
"time": action.time,
|
||||
"action_name": action.action_name,
|
||||
"action_data": action.action_data,
|
||||
"action_done": action.action_done,
|
||||
"action_build_into_prompt": action.action_build_into_prompt,
|
||||
"action_prompt_display": action.action_prompt_display,
|
||||
"chat_id": action.chat_id,
|
||||
"chat_info_stream_id": action.chat_info_stream_id,
|
||||
"chat_info_platform": action.chat_info_platform,
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else: # earliest
|
||||
result = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end,
|
||||
)
|
||||
)
|
||||
.order_by(ActionRecords.time.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
actions = list(result.scalars())
|
||||
actions_result = []
|
||||
for action in actions:
|
||||
action_dict = {
|
||||
"id": action.id,
|
||||
"action_id": action.action_id,
|
||||
"time": action.time,
|
||||
"action_name": action.action_name,
|
||||
"action_data": action.action_data,
|
||||
"action_done": action.action_done,
|
||||
"action_build_into_prompt": action.action_build_into_prompt,
|
||||
"action_prompt_display": action.action_prompt_display,
|
||||
"chat_id": action.chat_id,
|
||||
"chat_info_stream_id": action.chat_info_stream_id,
|
||||
"chat_info_platform": action.chat_info_platform,
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else:
|
||||
query = session.execute(
|
||||
result = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -347,7 +347,7 @@ def get_actions_by_timestamp_with_chat(
|
||||
)
|
||||
.order_by(ActionRecords.time.asc())
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
actions = list(result.scalars())
|
||||
actions_result = []
|
||||
for action in actions:
|
||||
action_dict = {
|
||||
@@ -367,14 +367,14 @@ def get_actions_by_timestamp_with_chat(
|
||||
return actions_result
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat_inclusive(
|
||||
async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = session.execute(
|
||||
result = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -386,10 +386,10 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||
.order_by(ActionRecords.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
actions = list(result.scalars())
|
||||
return [action.__dict__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = session.execute(
|
||||
result = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -402,7 +402,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||
.limit(limit)
|
||||
)
|
||||
else:
|
||||
query = session.execute(
|
||||
query = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -507,7 +507,7 @@ def num_new_messages_since_with_users(
|
||||
return count_messages(message_filter=filter_query)
|
||||
|
||||
|
||||
def _build_readable_messages_internal(
|
||||
async def _build_readable_messages_internal(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -627,7 +627,7 @@ def _build_readable_messages_internal(
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name") # type: ignore
|
||||
|
||||
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
||||
if not person_name:
|
||||
@@ -800,7 +800,7 @@ def _build_readable_messages_internal(
|
||||
)
|
||||
|
||||
|
||||
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""
|
||||
构建图片映射信息字符串,显示图片的具体描述内容
|
||||
@@ -823,8 +823,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# 从数据库中获取图片描述
|
||||
description = "[图片内容未知]" # 默认描述
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
result = session.execute(select(Images).where(Images.image_id == pic_id))
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(Images).where(Images.image_id == pic_id))
|
||||
image = result.scalar_one_or_none()
|
||||
if image and image.description: # type: ignore
|
||||
description = image.description
|
||||
@@ -922,17 +922,17 @@ async def build_readable_messages_with_list(
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
"""
|
||||
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
|
||||
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
||||
if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping):
|
||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
|
||||
return formatted_string, details_list
|
||||
|
||||
|
||||
def build_readable_messages_with_id(
|
||||
async def build_readable_messages_with_id(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -948,7 +948,7 @@ def build_readable_messages_with_id(
|
||||
"""
|
||||
message_id_list = assign_message_ids(messages)
|
||||
|
||||
formatted_string = build_readable_messages(
|
||||
formatted_string = await build_readable_messages(
|
||||
messages=messages,
|
||||
replace_bot_name=replace_bot_name,
|
||||
merge_messages=merge_messages,
|
||||
@@ -960,10 +960,16 @@ def build_readable_messages_with_id(
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
|
||||
# 如果存在图片映射信息,附加之
|
||||
if pic_mapping_info := await build_pic_mapping_info({}):
|
||||
# 如果当前没有图片映射则不附加
|
||||
if pic_mapping_info:
|
||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
|
||||
return formatted_string, message_id_list
|
||||
|
||||
|
||||
def build_readable_messages(
|
||||
async def build_readable_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -1004,9 +1010,9 @@ def build_readable_messages(
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions_in_range = session.execute(
|
||||
actions_in_range = (await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -1014,15 +1020,15 @@ def build_readable_messages(
|
||||
)
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
).scalars()
|
||||
)).scalars()
|
||||
|
||||
# 获取最新消息之后的第一个动作记录
|
||||
action_after_latest = session.execute(
|
||||
action_after_latest = (await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
).scalars()
|
||||
)).scalars()
|
||||
|
||||
# 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError
|
||||
actions = [
|
||||
@@ -1053,7 +1059,7 @@ def build_readable_messages(
|
||||
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
copy_messages,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1064,7 +1070,7 @@ def build_readable_messages(
|
||||
)
|
||||
|
||||
# 生成图片映射信息并添加到最前面
|
||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
|
||||
if pic_mapping_info:
|
||||
return f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
else:
|
||||
@@ -1079,7 +1085,7 @@ def build_readable_messages(
|
||||
pic_counter = 1
|
||||
|
||||
# 分别格式化,但使用共享的图片映射
|
||||
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
|
||||
formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal(
|
||||
messages_before_mark,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1090,7 +1096,7 @@ def build_readable_messages(
|
||||
show_pic=show_pic,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
messages_after_mark,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1106,7 +1112,7 @@ def build_readable_messages(
|
||||
|
||||
# 生成图片映射信息
|
||||
if pic_id_mapping:
|
||||
pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
||||
pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
||||
else:
|
||||
pic_mapping_info = "聊天记录信息:\n"
|
||||
|
||||
@@ -1229,7 +1235,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
|
||||
# 在最前面添加图片映射信息
|
||||
final_output_lines = []
|
||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
|
||||
if pic_mapping_info:
|
||||
final_output_lines.append(pic_mapping_info)
|
||||
final_output_lines.append("\n\n")
|
||||
|
||||
Reference in New Issue
Block a user