refactor(db): 修正SQLAlchemy异步操作调用方式

移除session.add()方法的不必要await调用,修正异步数据库操作模式。主要变更包括:

- 将 `await session.add()` 统一改为 `session.add()`
- 修正部分函数调用为异步版本(如消息查询函数)
- 重构SQLAlchemyTransaction为完全异步实现
- 重写napcat_adapter_plugin数据库层以符合异步规范
- 添加aiomysql和aiosqlite依赖支持
This commit is contained in:
雅诺狐
2025-09-20 17:26:28 +08:00
committed by Windpicker-owo
parent 0cffc0aa95
commit 679195d792
23 changed files with 248 additions and 247 deletions

View File

@@ -3,7 +3,7 @@ import traceback
import os
import pickle
import random
from typing import List, Dict, Any, Coroutine
from typing import List, Dict, Any
from src.config.config import global_config
from src.common.logger import get_logger
from src.person_info.relationship_manager import get_relationship_manager
@@ -114,7 +114,7 @@ class RelationshipBuilder:
# 负责跟踪用户消息活动、管理消息段、清理过期数据
# ================================
def _update_message_segments(self, person_id: str, message_time: float):
async def _update_message_segments(self, person_id: str, message_time: float):
"""更新用户的消息段
Args:
@@ -127,11 +127,8 @@ class RelationshipBuilder:
segments = self.person_engaged_cache[person_id]
# 获取该消息前5条消息的时间作为潜在的开始时间
before_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, message_time, limit=5)
if before_messages:
potential_start_time = before_messages[0]["time"]
else:
potential_start_time = message_time
before_messages = await get_raw_msg_before_timestamp_with_chat(self.chat_id, message_time, limit=5)
potential_start_time = before_messages[0]["time"] if before_messages else message_time
# 如果没有现有消息段,创建新的
if not segments:
@@ -139,12 +136,10 @@ class RelationshipBuilder:
"start_time": potential_start_time,
"end_time": message_time,
"last_msg_time": message_time,
"message_count": self._count_messages_in_timerange(potential_start_time, message_time),
"message_count": await self._count_messages_in_timerange(potential_start_time, message_time),
}
segments.append(new_segment)
person = Person(person_id=person_id)
person_name = person.person_name or person_id
person_name = get_person_info_manager().get_value_sync(person_id, "person_name") or person_id
logger.debug(
f"{self.log_prefix} 眼熟用户 {person_name}{time.strftime('%H:%M:%S', time.localtime(potential_start_time))} - {time.strftime('%H:%M:%S', time.localtime(message_time))} 之间有 {new_segment['message_count']} 条消息"
)
@@ -155,39 +150,32 @@ class RelationshipBuilder:
last_segment = segments[-1]
# 计算从最后一条消息到当前消息之间的消息数量(不包含边界)
messages_between = self._count_messages_between(last_segment["last_msg_time"], message_time)
messages_between = await self._count_messages_between(last_segment["last_msg_time"], message_time)
if messages_between <= 10:
# 在10条消息内延伸当前消息段
last_segment["end_time"] = message_time
last_segment["last_msg_time"] = message_time
# 重新计算整个消息段的消息数量
last_segment["message_count"] = self._count_messages_in_timerange(
last_segment["message_count"] = await self._count_messages_in_timerange(
last_segment["start_time"], last_segment["end_time"]
)
logger.debug(f"{self.log_prefix} 延伸用户 {person_id} 的消息段: {last_segment}")
else:
# 超过10条消息结束当前消息段并创建新的
# 结束当前消息段延伸到原消息段最后一条消息后5条消息的时间
current_time = time.time()
after_messages = get_raw_msg_by_timestamp_with_chat(
after_messages = await get_raw_msg_by_timestamp_with_chat(
self.chat_id, last_segment["last_msg_time"], current_time, limit=5, limit_mode="earliest"
)
if after_messages and len(after_messages) >= 5:
# 如果有足够的后续消息使用第5条消息的时间作为结束时间
last_segment["end_time"] = after_messages[4]["time"]
# 重新计算当前消息段的消息数量
last_segment["message_count"] = self._count_messages_in_timerange(
last_segment["message_count"] = await self._count_messages_in_timerange(
last_segment["start_time"], last_segment["end_time"]
)
# 创建新的消息段
new_segment = {
"start_time": potential_start_time,
"end_time": message_time,
"last_msg_time": message_time,
"message_count": self._count_messages_in_timerange(potential_start_time, message_time),
"message_count": await self._count_messages_in_timerange(potential_start_time, message_time),
}
segments.append(new_segment)
person = Person(person_id=person_id)
@@ -198,14 +186,14 @@ class RelationshipBuilder:
self._save_cache()
def _count_messages_in_timerange(self, start_time: float, end_time: float) -> int:
async def _count_messages_in_timerange(self, start_time: float, end_time: float) -> int:
"""计算指定时间范围内的消息数量(包含边界)"""
messages = get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time)
messages = await get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time)
return len(messages)
def _count_messages_between(self, start_time: float, end_time: float) -> Coroutine[Any, Any, int]:
async def _count_messages_between(self, start_time: float, end_time: float) -> int:
"""计算两个时间点之间的消息数量(不包含边界),用于间隔检查"""
return num_new_messages_since(self.chat_id, start_time, end_time)
return await num_new_messages_since(self.chat_id, start_time, end_time)
def _get_total_message_count(self, person_id: str) -> int:
"""获取用户所有消息段的总消息数量"""
@@ -352,7 +340,7 @@ class RelationshipBuilder:
self._cleanup_old_segments()
current_time = time.time()
if latest_messages := get_raw_msg_by_timestamp_with_chat(
if latest_messages := await get_raw_msg_by_timestamp_with_chat(
self.chat_id,
self.last_processed_message_time,
current_time,
@@ -370,8 +358,8 @@ class RelationshipBuilder:
and user_id != global_config.bot.qq_account
and msg_time > self.last_processed_message_time
):
person_id = get_person_id(platform, user_id)
self._update_message_segments(person_id, msg_time)
person_id = PersonInfoManager.get_person_id(platform, user_id)
await self._update_message_segments(person_id, msg_time)
logger.debug(
f"{self.log_prefix} 更新用户 {person_id} 的消息段,消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(msg_time))}"
)
@@ -441,7 +429,7 @@ class RelationshipBuilder:
start_date = time.strftime("%Y-%m-%d %H:%M", time.localtime(start_time))
# 获取该段的消息(包含边界)
segment_messages = get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time)
segment_messages = await get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time)
logger.debug(
f"消息段: {start_date} - {time.strftime('%Y-%m-%d %H:%M', time.localtime(end_time))}, 消息数: {len(segment_messages)}"
)