refactor(person_info): 引入同步方法 get_value 并替换旧的 get_value_sync(因为根本就没有这个方法)
为了解决在不同异步上下文中同步调用数据库可能引发的运行时错误,实现了一个新的、更健壮的同步方法 `PersonInfoManager.get_value`。 - 新方法能够正确处理已在运行的 asyncio 事件循环,提高了在混合代码环境中调用的稳定性。 - 全面替换了原有的 `get_value_sync` 方法调用,统一了同步获取用户信息的接口。
This commit is contained in:
committed by
Windpicker-owo
parent
02067b6eeb
commit
79baac2797
@@ -43,9 +43,9 @@ def replace_user_references_sync(
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
return person.person_name or user_id # type: ignore
|
||||
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
return person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
|
||||
|
||||
name_resolver = default_resolver
|
||||
|
||||
# 处理回复<aaa:bbb>格式
|
||||
|
||||
@@ -965,7 +965,7 @@ class Prompt:
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
user_id = person_info_manager.get_value(person_id, "user_id")
|
||||
return str(user_id) if user_id else ""
|
||||
|
||||
return ""
|
||||
|
||||
@@ -666,7 +666,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
person_name = None
|
||||
if person_id:
|
||||
# get_value is async, so await it directly
|
||||
person_name = person.person_name
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
target_info["person_id"] = person_id
|
||||
target_info["person_name"] = person_name
|
||||
|
||||
@@ -846,6 +846,45 @@ class PersonInfoManager:
|
||||
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_value(person_id: str, field_name: str) -> Any:
|
||||
"""获取单个字段值(同步版本)"""
|
||||
if not person_id:
|
||||
logger.debug("get_value获取失败:person_id不能为空")
|
||||
return None
|
||||
|
||||
import asyncio
|
||||
|
||||
async def _get_record_sync():
|
||||
async with get_db_session() as session:
|
||||
return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))).scalar()
|
||||
|
||||
try:
|
||||
record = asyncio.run(_get_record_sync())
|
||||
except RuntimeError:
|
||||
# 如果当前线程已经有事件循环在运行,则使用现有的循环
|
||||
loop = asyncio.get_running_loop()
|
||||
record = loop.run_until_complete(_get_record_sync())
|
||||
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
|
||||
if field_name not in model_fields:
|
||||
if field_name in person_info_default:
|
||||
logger.debug(f"字段'{field_name}'不在SQLAlchemy模型中,使用默认配置值。")
|
||||
return copy.deepcopy(person_info_default[field_name])
|
||||
else:
|
||||
logger.debug(f"get_value查询失败:字段'{field_name}'未在SQLAlchemy模型和默认配置中定义。")
|
||||
return None
|
||||
|
||||
if record:
|
||||
value = getattr(record, field_name)
|
||||
if value is not None:
|
||||
return value
|
||||
else:
|
||||
return copy.deepcopy(person_info_default.get(field_name))
|
||||
else:
|
||||
return copy.deepcopy(person_info_default.get(field_name))
|
||||
|
||||
@staticmethod
|
||||
async def get_values(person_id: str, field_names: list) -> dict:
|
||||
"""获取指定person_id文档的多个字段值,若不存在该字段,则返回该字段的全局默认值"""
|
||||
@@ -884,7 +923,6 @@ class PersonInfoManager:
|
||||
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_specific_value_list(
|
||||
field_name: str,
|
||||
|
||||
@@ -139,7 +139,7 @@ class RelationshipBuilder:
|
||||
"message_count": await self._count_messages_in_timerange(potential_start_time, message_time),
|
||||
}
|
||||
segments.append(new_segment)
|
||||
person_name = get_person_info_manager().get_value_sync(person_id, "person_name") or person_id
|
||||
person_name = get_person_info_manager().get_value(person_id, "person_name") or person_id
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 眼熟用户 {person_name} 在 {time.strftime('%H:%M:%S', time.localtime(potential_start_time))} - {time.strftime('%H:%M:%S', time.localtime(message_time))} 之间有 {new_segment['message_count']} 条消息"
|
||||
)
|
||||
@@ -178,8 +178,8 @@ class RelationshipBuilder:
|
||||
"message_count": await self._count_messages_in_timerange(potential_start_time, message_time),
|
||||
}
|
||||
segments.append(new_segment)
|
||||
person = Person(person_id=person_id)
|
||||
person_name = person.person_name or person_id
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = person_info_manager.get_value(person_id, "person_name") or person_id
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 重新眼熟用户 {person_name} 创建新消息段(超过10条消息间隔): {new_segment}"
|
||||
)
|
||||
@@ -369,8 +369,8 @@ class RelationshipBuilder:
|
||||
users_to_build_relationship = []
|
||||
for person_id, segments in self.person_engaged_cache.items():
|
||||
total_message_count = self._get_total_message_count(person_id)
|
||||
person_name = get_person_info_manager().get_value_sync(person_id, "person_name") or person_id
|
||||
|
||||
person_name = get_person_info_manager().get_value(person_id, "person_name") or person_id
|
||||
|
||||
if total_message_count >= max_build_threshold or (
|
||||
total_message_count >= 5 and (immediate_build == person_id or immediate_build == "all")
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user