refactor(chat): 迁移数据库操作为异步模式并修复相关调用

将同步数据库操作全面迁移为异步模式,主要涉及:
- 将 `with get_db_session()` 改为 `async with get_db_session()`
- 修复相关异步调用链,确保 await 正确传递
- 优化消息管理器、上下文管理器等核心组件的异步处理
- 移除同步的 person_id 获取方法,避免协程对象传递问题

修复 deepcopy 在 StreamContext 中的序列化问题,跳过不可序列化的 asyncio.Task 对象

删除无用的测试文件和废弃的插件清单文件
This commit is contained in:
Windpicker-owo
2025-09-28 20:40:46 +08:00
parent 08ef960947
commit fd76e36320
30 changed files with 481 additions and 625 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import copy
import datetime
import hashlib
@@ -57,7 +58,7 @@ class PersonInfoManager:
self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
# try:
# with get_db_session() as session:
# async with get_db_session() as session:
# db.connect(reuse_if_open=True)
# # 设置连接池参数仅对SQLite有效
# if hasattr(db, "execute_sql"):
@@ -75,7 +76,7 @@ class PersonInfoManager:
try:
pass
# 在这里获取会话
# with get_db_session() as session:
# async with get_db_session() as session:
# for record in session.execute(
# select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
# ).fetchall():
@@ -87,58 +88,25 @@ class PersonInfoManager:
@staticmethod
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id"""
"""获取唯一id(同步)
说明: 原来该方法为异步并在内部尝试执行数据库检查/迁移,导致在许多调用处未 await 时返回 coroutine 对象。
为了避免将 coroutine 传递到其它同步调用(例如数据库查询条件)中,这里将方法改为同步并仅返回基于 platform 和 user_id 的 MD5 哈希值。
注意: 这会跳过原有的 napcat->qq 迁移检查逻辑。如需保留迁移,请使用显式的、在合适时机执行的迁移任务。
"""
# 检查platform是否为None或空
if platform is None:
platform = "unknown"
if "-" in platform:
platform = platform.split("-")[1]
# 在此处打一个补丁如果platform为qq尝试生成id后检查是否存在如果不存在则将平台换为napcat后再次检查如果存在则更新原id为platform为qq的id
components = [platform, str(user_id)]
key = "_".join(components)
# 如果不是 qq 平台,直接返回计算的 id
if platform != "qq":
return hashlib.md5(key.encode()).hexdigest()
qq_id = hashlib.md5(key.encode()).hexdigest()
# 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回
def _db_check_and_migrate_sync(p_id: str, raw_user_id: str):
try:
with get_db_session() as session:
# 检查 qq_id 是否存在
existing_qq = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if existing_qq:
return p_id
# 如果 qq_id 不存在,尝试使用 napcat 作为平台生成对应 id 并检查
nap_components = ["napcat", str(raw_user_id)]
nap_key = "_".join(nap_components)
nap_id = hashlib.md5(nap_key.encode()).hexdigest()
existing_nap = session.execute(select(PersonInfo).where(PersonInfo.person_id == nap_id)).scalar()
if not existing_nap:
# napcat 也不存在,返回 qq_id未命中
return p_id
# napcat 存在,迁移该记录:更新 person_id 与 platform -> qq
try:
# 更新现有 napcat 记录
existing_nap.person_id = p_id
existing_nap.platform = "qq"
existing_nap.user_id = str(raw_user_id)
session.commit()
return p_id
except Exception:
session.rollback()
return p_id
except Exception as e:
logger.error(f"检查/迁移 napcat->qq 时出错: {e}")
return p_id
return _db_check_and_migrate_sync(qq_id, user_id)
# 直接返回计算的 id(同步)
return hashlib.md5(key.encode()).hexdigest()
async def is_person_known(self, platform: str, user_id: int):
"""判断是否认识某人"""
@@ -157,17 +125,25 @@ class PersonInfoManager:
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
return False
@staticmethod
async def get_person_id_by_person_name(person_name: str) -> str:
"""根据用户名获取用户ID"""
async def get_person_id_by_person_name(self, person_name: str) -> str:
"""
根据用户名获取用户ID(同步)
说明: 为了避免在多个调用点将 coroutine 误传递到数据库查询中,
此处提供一个同步实现。优先在内存缓存 `self.person_name_list` 中查找,
若未命中则返回空字符串。若后续需要更强的一致性,可在异步上下文
额外实现带 await 的查询方法。
"""
try:
# 在需要时获取会话
async with get_db_session() as session:
record = result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))
result.scalar()
return record.person_id if record else ""
# 优先使用内存缓存加速查找self.person_name_list maps person_id -> person_name
for pid, pname in self.person_name_list.items():
if pname == person_name:
return pid
# 未找到缓存命中,避免在同步路径中进行阻塞的数据库查询,直接返回空字符串
return ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
return ""
@staticmethod
@@ -578,26 +554,15 @@ class PersonInfoManager:
@staticmethod
def get_value(person_id: str, field_name: str) -> Any:
async def get_value(person_id: str, field_name: str) -> Any:
"""获取单个字段值(同步版本)"""
if not person_id:
logger.debug("get_value获取失败person_id不能为空")
return None
import asyncio
async def _get_record_sync():
async with get_db_session() as session:
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
record = result.scalar()
return record
try:
record = asyncio.run(_get_record_sync())
except RuntimeError:
# 如果当前线程已经有事件循环在运行,则使用现有的循环
loop = asyncio.get_running_loop()
record = loop.run_until_complete(_get_record_sync())
async with get_db_session() as session:
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
record = result.scalar()
model_fields = [column.name for column in PersonInfo.__table__.columns]