refactor(chat): 迁移数据库操作为异步模式并修复相关调用
将同步数据库操作全面迁移为异步模式,主要涉及: - 将 `with get_db_session()` 改为 `async with get_db_session()` - 修复相关异步调用链,确保 await 正确传递 - 优化消息管理器、上下文管理器等核心组件的异步处理 - 移除同步的 person_id 获取方法,避免协程对象传递问题 修复 deepcopy 在 StreamContext 中的序列化问题,跳过不可序列化的 asyncio.Task 对象 删除无用的测试文件和废弃的插件清单文件
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import datetime
|
||||
import hashlib
|
||||
@@ -57,7 +58,7 @@ class PersonInfoManager:
|
||||
self.person_name_list = {}
|
||||
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
||||
# try:
|
||||
# with get_db_session() as session:
|
||||
# async with get_db_session() as session:
|
||||
# db.connect(reuse_if_open=True)
|
||||
# # 设置连接池参数(仅对SQLite有效)
|
||||
# if hasattr(db, "execute_sql"):
|
||||
@@ -75,7 +76,7 @@ class PersonInfoManager:
|
||||
try:
|
||||
pass
|
||||
# 在这里获取会话
|
||||
# with get_db_session() as session:
|
||||
# async with get_db_session() as session:
|
||||
# for record in session.execute(
|
||||
# select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
|
||||
# ).fetchall():
|
||||
@@ -87,58 +88,25 @@ class PersonInfoManager:
|
||||
|
||||
@staticmethod
|
||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
"""获取唯一id"""
|
||||
"""获取唯一id(同步)
|
||||
|
||||
说明: 原来该方法为异步并在内部尝试执行数据库检查/迁移,导致在许多调用处未 await 时返回 coroutine 对象。
|
||||
为了避免将 coroutine 传递到其它同步调用(例如数据库查询条件)中,这里将方法改为同步并仅返回基于 platform 和 user_id 的 MD5 哈希值。
|
||||
|
||||
注意: 这会跳过原有的 napcat->qq 迁移检查逻辑。如需保留迁移,请使用显式的、在合适时机执行的迁移任务。
|
||||
"""
|
||||
# 检查platform是否为None或空
|
||||
if platform is None:
|
||||
platform = "unknown"
|
||||
|
||||
if "-" in platform:
|
||||
platform = platform.split("-")[1]
|
||||
# 在此处打一个补丁,如果platform为qq,尝试生成id后检查是否存在,如果不存在,则将平台换为napcat后再次检查,如果存在,则更新原id为platform为qq的id
|
||||
|
||||
components = [platform, str(user_id)]
|
||||
key = "_".join(components)
|
||||
|
||||
# 如果不是 qq 平台,直接返回计算的 id
|
||||
if platform != "qq":
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
qq_id = hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
# 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回
|
||||
def _db_check_and_migrate_sync(p_id: str, raw_user_id: str):
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 检查 qq_id 是否存在
|
||||
existing_qq = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if existing_qq:
|
||||
return p_id
|
||||
|
||||
# 如果 qq_id 不存在,尝试使用 napcat 作为平台生成对应 id 并检查
|
||||
nap_components = ["napcat", str(raw_user_id)]
|
||||
nap_key = "_".join(nap_components)
|
||||
nap_id = hashlib.md5(nap_key.encode()).hexdigest()
|
||||
|
||||
existing_nap = session.execute(select(PersonInfo).where(PersonInfo.person_id == nap_id)).scalar()
|
||||
if not existing_nap:
|
||||
# napcat 也不存在,返回 qq_id(未命中)
|
||||
return p_id
|
||||
|
||||
# napcat 存在,迁移该记录:更新 person_id 与 platform -> qq
|
||||
try:
|
||||
# 更新现有 napcat 记录
|
||||
existing_nap.person_id = p_id
|
||||
existing_nap.platform = "qq"
|
||||
existing_nap.user_id = str(raw_user_id)
|
||||
session.commit()
|
||||
return p_id
|
||||
except Exception:
|
||||
session.rollback()
|
||||
return p_id
|
||||
except Exception as e:
|
||||
logger.error(f"检查/迁移 napcat->qq 时出错: {e}")
|
||||
return p_id
|
||||
|
||||
return _db_check_and_migrate_sync(qq_id, user_id)
|
||||
# 直接返回计算的 id(同步)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def is_person_known(self, platform: str, user_id: int):
|
||||
"""判断是否认识某人"""
|
||||
@@ -157,17 +125,25 @@ class PersonInfoManager:
|
||||
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def get_person_id_by_person_name(person_name: str) -> str:
|
||||
"""根据用户名获取用户ID"""
|
||||
async def get_person_id_by_person_name(self, person_name: str) -> str:
|
||||
"""
|
||||
根据用户名获取用户ID(同步)
|
||||
|
||||
说明: 为了避免在多个调用点将 coroutine 误传递到数据库查询中,
|
||||
此处提供一个同步实现。优先在内存缓存 `self.person_name_list` 中查找,
|
||||
若未命中则返回空字符串。若后续需要更强的一致性,可在异步上下文
|
||||
额外实现带 await 的查询方法。
|
||||
"""
|
||||
try:
|
||||
# 在需要时获取会话
|
||||
async with get_db_session() as session:
|
||||
record = result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))
|
||||
result.scalar()
|
||||
return record.person_id if record else ""
|
||||
# 优先使用内存缓存加速查找:self.person_name_list maps person_id -> person_name
|
||||
for pid, pname in self.person_name_list.items():
|
||||
if pname == person_name:
|
||||
return pid
|
||||
|
||||
# 未找到缓存命中,避免在同步路径中进行阻塞的数据库查询,直接返回空字符串
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
@@ -578,26 +554,15 @@ class PersonInfoManager:
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_value(person_id: str, field_name: str) -> Any:
|
||||
async def get_value(person_id: str, field_name: str) -> Any:
|
||||
"""获取单个字段值(同步版本)"""
|
||||
if not person_id:
|
||||
logger.debug("get_value获取失败:person_id不能为空")
|
||||
return None
|
||||
|
||||
import asyncio
|
||||
|
||||
async def _get_record_sync():
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
|
||||
record = result.scalar()
|
||||
return record
|
||||
|
||||
try:
|
||||
record = asyncio.run(_get_record_sync())
|
||||
except RuntimeError:
|
||||
# 如果当前线程已经有事件循环在运行,则使用现有的循环
|
||||
loop = asyncio.get_running_loop()
|
||||
record = loop.run_until_complete(_get_record_sync())
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
|
||||
record = result.scalar()
|
||||
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user