refactor(chat): 迁移数据库操作为异步模式并修复相关调用
将同步数据库操作全面迁移为异步模式,主要涉及: - 将 `with get_db_session()` 改为 `async with get_db_session()` - 修复相关异步调用链,确保 await 正确传递 - 优化消息管理器、上下文管理器等核心组件的异步处理 - 移除同步的 person_id 获取方法,避免协程对象传递问题 修复 deepcopy 在 StreamContext 中的序列化问题,跳过不可序列化的 asyncio.Task 对象 删除无用的测试文件和废弃的插件清单文件
This commit is contained in:
@@ -30,7 +30,7 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns}
|
||||
|
||||
|
||||
def find_messages(
|
||||
async def find_messages(
|
||||
message_filter: dict[str, Any],
|
||||
sort: Optional[List[tuple[str, int]]] = None,
|
||||
limit: int = 0,
|
||||
@@ -51,7 +51,7 @@ def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
query = select(Messages)
|
||||
|
||||
# 应用过滤器
|
||||
@@ -101,8 +101,8 @@ def find_messages(
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
try:
|
||||
results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行earliest查询失败: {e}")
|
||||
results = []
|
||||
@@ -110,8 +110,8 @@ def find_messages(
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
try:
|
||||
latest_results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
result = await session.execute(query)
|
||||
latest_results = result.scalars().all()
|
||||
# 将结果按时间正序排列
|
||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
@@ -135,8 +135,8 @@ def find_messages(
|
||||
if sort_terms:
|
||||
query = query.order_by(*sort_terms)
|
||||
try:
|
||||
results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行无限制查询失败: {e}")
|
||||
results = []
|
||||
@@ -152,7 +152,7 @@ def find_messages(
|
||||
return []
|
||||
|
||||
|
||||
def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
"""
|
||||
根据提供的过滤器计算消息数量。
|
||||
|
||||
@@ -163,7 +163,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
query = select(func.count(Messages.id))
|
||||
|
||||
# 应用过滤器
|
||||
@@ -201,7 +201,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
count = session.execute(query).scalar()
|
||||
count = (await session.execute(query)).scalar()
|
||||
return count or 0
|
||||
except Exception as e:
|
||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
|
||||
Reference in New Issue
Block a user