refactor(chat): 迁移数据库操作为异步模式并修复相关调用
将同步数据库操作全面迁移为异步模式,主要涉及: - 将 `with get_db_session()` 改为 `async with get_db_session()` - 修复相关异步调用链,确保 await 正确传递 - 优化消息管理器、上下文管理器等核心组件的异步处理 - 移除同步的 person_id 获取方法,避免协程对象传递问题 修复 deepcopy 在 StreamContext 中的序列化问题,跳过不可序列化的 asyncio.Task 对象 删除无用的测试文件和废弃的插件清单文件
This commit is contained in:
@@ -344,6 +344,39 @@ class StreamContext(BaseDataModel):
|
||||
"""获取优先级信息"""
|
||||
return self.priority_info
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""自定义深拷贝,跳过不可序列化的 asyncio.Task (processing_task)。
|
||||
|
||||
deepcopy 在内部可能会尝试 pickle 某些对象(如 asyncio.Task),
|
||||
这会在多线程或运行时事件循环中导致 TypeError。这里我们手动复制
|
||||
__dict__ 中的字段,确保 processing_task 被设置为 None,其他字段使用
|
||||
copy.deepcopy 递归复制。
|
||||
"""
|
||||
import copy
|
||||
|
||||
# 如果已经复制过,直接返回缓存结果
|
||||
obj_id = id(self)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
|
||||
# 创建一个未初始化的新实例,然后逐个字段深拷贝
|
||||
cls = self.__class__
|
||||
new = cls.__new__(cls)
|
||||
memo[obj_id] = new
|
||||
|
||||
for k, v in self.__dict__.items():
|
||||
if k == "processing_task":
|
||||
# 不复制 asyncio.Task,避免无法 pickling
|
||||
setattr(new, k, None)
|
||||
else:
|
||||
try:
|
||||
setattr(new, k, copy.deepcopy(v, memo))
|
||||
except Exception:
|
||||
# 如果某个字段无法深拷贝,退回到原始引用(安全性谨慎)
|
||||
setattr(new, k, v)
|
||||
|
||||
return new
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageManagerStats(BaseDataModel):
|
||||
|
||||
@@ -30,7 +30,7 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns}
|
||||
|
||||
|
||||
def find_messages(
|
||||
async def find_messages(
|
||||
message_filter: dict[str, Any],
|
||||
sort: Optional[List[tuple[str, int]]] = None,
|
||||
limit: int = 0,
|
||||
@@ -51,7 +51,7 @@ def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
query = select(Messages)
|
||||
|
||||
# 应用过滤器
|
||||
@@ -101,8 +101,8 @@ def find_messages(
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
try:
|
||||
results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行earliest查询失败: {e}")
|
||||
results = []
|
||||
@@ -110,8 +110,8 @@ def find_messages(
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
try:
|
||||
latest_results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
result = await session.execute(query)
|
||||
latest_results = result.scalars().all()
|
||||
# 将结果按时间正序排列
|
||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
@@ -135,8 +135,8 @@ def find_messages(
|
||||
if sort_terms:
|
||||
query = query.order_by(*sort_terms)
|
||||
try:
|
||||
results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行无限制查询失败: {e}")
|
||||
results = []
|
||||
@@ -152,7 +152,7 @@ def find_messages(
|
||||
return []
|
||||
|
||||
|
||||
def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
"""
|
||||
根据提供的过滤器计算消息数量。
|
||||
|
||||
@@ -163,7 +163,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
query = select(func.count(Messages.id))
|
||||
|
||||
# 应用过滤器
|
||||
@@ -201,7 +201,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
count = session.execute(query).scalar()
|
||||
count = (await session.execute(query)).scalar()
|
||||
return count or 0
|
||||
except Exception as e:
|
||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
|
||||
Reference in New Issue
Block a user