refactor(chat): migrate database operations to async and fix related call sites

Migrate synchronous database operations to the async pattern across the chat module. Main changes:
- Replace `with get_db_session()` with `async with get_db_session()` (see the sketch after this list)
- Fix the affected async call chains so that `await` is propagated correctly
- Improve async handling in core components such as the message manager and context manager
- Remove the synchronous person_id getter to avoid passing un-awaited coroutine objects around
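As referenced in the first bullet, a minimal, hypothetical sketch of the call-site pattern after this migration. get_db_session, Messages, and Messages.time are the project's own names as seen in the diff below; the load_recent_messages wrapper and the since cutoff are illustrative assumptions, not real call sites.

    from sqlalchemy import select

    async def load_recent_messages(since: float, limit: int = 10):
        # Before: `with get_db_session() as session:` and a blocking session.execute(...)
        async with get_db_session() as session:       # session is now an async context manager
            query = (
                select(Messages)
                .where(Messages.time > since)          # hypothetical cutoff filter
                .order_by(Messages.time.desc())
                .limit(limit)
            )
            result = await session.execute(query)      # execute() must now be awaited
            return result.scalars().all()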

Fix the deepcopy serialization issue in StreamContext by skipping the non-serializable asyncio.Task object.
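A self-contained sketch of the behavior this fix provides; the Ctx class below is a hypothetical stand-in for StreamContext that mirrors the __deepcopy__ added in the diff, not the real model.

    import asyncio
    import copy

    class Ctx:  # stand-in for StreamContext
        def __init__(self, stream_id: str, processing_task: asyncio.Task | None):
            self.stream_id = stream_id
            self.processing_task = processing_task

        def __deepcopy__(self, memo):
            new = self.__class__.__new__(self.__class__)
            memo[id(self)] = new
            for k, v in self.__dict__.items():
                # A running asyncio.Task cannot be pickled/deep-copied, so drop it
                setattr(new, k, None if k == "processing_task" else copy.deepcopy(v, memo))
            return new

    async def main():
        ctx = Ctx("stream-1", asyncio.create_task(asyncio.sleep(1)))
        clone = copy.deepcopy(ctx)   # no longer raises TypeError
        assert clone.processing_task is None and clone.stream_id == "stream-1"
        ctx.processing_task.cancel()

    asyncio.run(main())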

Delete unused test files and the obsolete plugin manifest files.
Windpicker-owo
2025-09-28 20:40:46 +08:00
parent 08ef960947
commit fd76e36320
30 changed files with 481 additions and 625 deletions


@@ -344,6 +344,39 @@ class StreamContext(BaseDataModel):
"""获取优先级信息"""
return self.priority_info
def __deepcopy__(self, memo):
"""自定义深拷贝,跳过不可序列化的 asyncio.Task (processing_task)。
deepcopy 在内部可能会尝试 pickle 某些对象(如 asyncio.Task
这会在多线程或运行时事件循环中导致 TypeError。这里我们手动复制
__dict__ 中的字段,确保 processing_task 被设置为 None其他字段使用
copy.deepcopy 递归复制。
"""
import copy
# 如果已经复制过,直接返回缓存结果
obj_id = id(self)
if obj_id in memo:
return memo[obj_id]
# 创建一个未初始化的新实例,然后逐个字段深拷贝
cls = self.__class__
new = cls.__new__(cls)
memo[obj_id] = new
for k, v in self.__dict__.items():
if k == "processing_task":
# 不复制 asyncio.Task避免无法 pickling
setattr(new, k, None)
else:
try:
setattr(new, k, copy.deepcopy(v, memo))
except Exception:
# 如果某个字段无法深拷贝,退回到原始引用(安全性谨慎)
setattr(new, k, v)
return new
@dataclass
class MessageManagerStats(BaseDataModel):


@@ -30,7 +30,7 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
     return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns}
 
 
-def find_messages(
+async def find_messages(
     message_filter: dict[str, Any],
     sort: Optional[List[tuple[str, int]]] = None,
     limit: int = 0,
@@ -51,7 +51,7 @@ def find_messages(
         A list of message dictionaries, or an empty list on error.
     """
     try:
-        with get_db_session() as session:
+        async with get_db_session() as session:
             query = select(Messages)
 
             # Apply the filters
@@ -101,8 +101,8 @@ def find_messages(
                 # Fetch the earliest `limit` records, already in ascending time order
                 query = query.order_by(Messages.time.asc()).limit(limit)
                 try:
-                    results = result = session.execute(query)
-                    result.scalars().all()
+                    result = await session.execute(query)
+                    results = result.scalars().all()
                 except Exception as e:
                     logger.error(f"执行earliest查询失败: {e}")
                     results = []
@@ -110,8 +110,8 @@ def find_messages(
                 # Fetch the latest `limit` records
                 query = query.order_by(Messages.time.desc()).limit(limit)
                 try:
-                    latest_results = result = session.execute(query)
-                    result.scalars().all()
+                    result = await session.execute(query)
+                    latest_results = result.scalars().all()
                     # Re-sort the results into ascending time order
                     results = sorted(latest_results, key=lambda msg: msg.time)
                 except Exception as e:
@@ -135,8 +135,8 @@ def find_messages(
             if sort_terms:
                 query = query.order_by(*sort_terms)
             try:
-                results = result = session.execute(query)
-                result.scalars().all()
+                result = await session.execute(query)
+                results = result.scalars().all()
             except Exception as e:
                 logger.error(f"执行无限制查询失败: {e}")
                 results = []
@@ -152,7 +152,7 @@ def find_messages(
         return []
 
 
-def count_messages(message_filter: dict[str, Any]) -> int:
+async def count_messages(message_filter: dict[str, Any]) -> int:
     """
     Count messages matching the provided filter.
@@ -163,7 +163,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
         The number of matching messages, or 0 on error.
     """
     try:
-        with get_db_session() as session:
+        async with get_db_session() as session:
             query = select(func.count(Messages.id))
 
             # Apply the filters
@@ -201,7 +201,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
             if conditions:
                 query = query.where(*conditions)
 
-            count = session.execute(query).scalar()
+            count = (await session.execute(query)).scalar()
             return count or 0
     except Exception as e:
         log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
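For callers, the new signatures mean find_messages and count_messages must now be awaited from async code. A minimal, hypothetical caller sketch (the filter contents are illustrative and depend on the filter format these helpers actually accept):

    async def show_message_stats(message_filter: dict) -> None:
        # Before: count_messages(...) / find_messages(...) were plain blocking calls
        total = await count_messages(message_filter)
        latest = await find_messages(message_filter, limit=20)
        print(f"{total} matching messages, showing {len(latest)} of them")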