feat(person_info): 实施基于稳健 ID 的用户信息同步

本次提交重构了用户识别和信息检索系统,使其基于稳定的平台和用户 ID,不再依赖脆弱的姓名解析机制。同时引入了自动后台进程,以保持用户信息的实时更新。

主要变更包括:
- 在 `PersonInfoManager` 中新增 `sync_user_info` 方法,根据 `platform` 和 `user_id` 来创建和更新用户记录。
- `ChatManager` 现在会在处理消息时触发该同步作为非阻塞后台任务,确保用户数据(如昵称)保持最新。
- 提示生成逻辑,特别是关系和上下文信息的生成,已重构为使用稳定的 `user_id`,而非从回复消息内容中解析姓名。
- `PromptParameters` 已被扩展,以在整个回复生成流程中传递 `platform` 和 `user_id`。
- 弃用依赖名称到 ID 查找的脆弱方法。
This commit is contained in:
@@ -20,6 +20,9 @@ install(extra_lines=3)
|
|||||||
|
|
||||||
logger = get_logger("chat_stream")
|
logger = get_logger("chat_stream")
|
||||||
|
|
||||||
|
# 用于存储后台任务的集合,防止被垃圾回收
|
||||||
|
_background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
class ChatStream:
|
class ChatStream:
|
||||||
"""聊天流对象,存储一个完整的聊天上下文"""
|
"""聊天流对象,存储一个完整的聊天上下文"""
|
||||||
@@ -406,6 +409,40 @@ class ChatManager:
|
|||||||
key = "_".join(components)
|
key = "_".join(components)
|
||||||
return hashlib.sha256(key.encode()).hexdigest()
|
return hashlib.sha256(key.encode()).hexdigest()
|
||||||
|
|
||||||
|
async def _process_message(self, message: DatabaseMessages):
    """Schedule a background sync of the sender's user info for ``message``.

    Reads ``user_info`` from the message and, when both ``platform`` and
    ``user_id`` are present, starts the person-info manager's
    ``sync_user_info`` coroutine as a background task so the caller is not
    blocked. Returns ``None`` without scheduling anything when the message
    carries no ``user_info`` or the identity fields are empty.
    """
    # 1. Extract the sender's identity from the message.
    user_info = getattr(message, "user_info", None)
    if not user_info:
        return

    platform = getattr(user_info, "platform", None)
    user_id = getattr(user_info, "user_id", None)
    nickname = getattr(user_info, "user_nickname", None)
    cardname = getattr(user_info, "user_cardname", None)

    if not platform or not user_id:
        return

    def _on_sync_done(task: asyncio.Task) -> None:
        # Release the strong reference, then surface any failure: the
        # try/except below only covers task creation, so without this an
        # exception raised inside the task would never be retrieved or logged.
        _background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.error(f"用户信息同步任务执行失败: {error}")

    # 2. Run the sync in the background.
    try:
        from src.person_info.person_info import get_person_info_manager

        person_info_manager = get_person_info_manager()

        # Fire-and-forget: the current flow does not await the result.
        sync_task = asyncio.create_task(
            person_info_manager.sync_user_info(platform, user_id, nickname, cardname)
        )
        # Hold a strong reference so the task is not garbage-collected before
        # it finishes; the done-callback removes it again.
        _background_tasks.add(sync_task)
        sync_task.add_done_callback(_on_sync_done)
    except Exception as e:
        logger.error(f"创建用户信息同步任务失败: {e}")
|
||||||
|
|
||||||
async def get_or_create_stream(
|
async def get_or_create_stream(
|
||||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||||
) -> ChatStream:
|
) -> ChatStream:
|
||||||
@@ -437,7 +474,10 @@ class ChatManager:
|
|||||||
# 检查是否有最后一条消息(现在使用 DatabaseMessages)
|
# 检查是否有最后一条消息(现在使用 DatabaseMessages)
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
||||||
await stream.set_context(self.last_messages[stream_id])
|
last_message = self.last_messages[stream_id]
|
||||||
|
await stream.set_context(last_message)
|
||||||
|
# 在这里调用消息处理
|
||||||
|
await self._process_message(last_message)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
|
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
|
||||||
return stream
|
return stream
|
||||||
@@ -634,23 +674,23 @@ class ChatManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"批量写入器保存聊天流失败,使用原始方法: {e}")
|
logger.debug(f"批量写入器保存聊天流失败,使用原始方法: {e}")
|
||||||
|
|
||||||
# 尝试使用数据库批量调度器(回退方案1)
|
# 尝试使用数据库批量调度器(回退方案1) - [已废弃]
|
||||||
try:
|
# try:
|
||||||
from src.common.database.db_batch_scheduler import batch_update, get_batch_session
|
# from src.common.database.optimization.batch_scheduler import batch_update, get_batch_session
|
||||||
|
#
|
||||||
async with get_batch_session():
|
# async with get_batch_session():
|
||||||
# 使用批量更新
|
# # 使用批量更新
|
||||||
result = await batch_update(
|
# result = await batch_update(
|
||||||
model_class=ChatStreams,
|
# model_class=ChatStreams,
|
||||||
conditions={"stream_id": stream_data_dict["stream_id"]},
|
# conditions={"stream_id": stream_data_dict["stream_id"]},
|
||||||
data=ChatManager._prepare_stream_data(stream_data_dict),
|
# data=ChatManager._prepare_stream_data(stream_data_dict),
|
||||||
)
|
# )
|
||||||
if result and result > 0:
|
# if result and result > 0:
|
||||||
stream.saved = True
|
# stream.saved = True
|
||||||
logger.debug(f"聊天流 {stream.stream_id} 通过批量调度器保存成功")
|
# logger.debug(f"聊天流 {stream.stream_id} 通过批量调度器保存成功")
|
||||||
return
|
# return
|
||||||
except (ImportError, Exception) as e:
|
# except (ImportError, Exception) as e:
|
||||||
logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}")
|
# logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}")
|
||||||
|
|
||||||
# 回退到原始方法(最终方案)
|
# 回退到原始方法(最终方案)
|
||||||
async def _db_save_stream_async(s_data_dict: dict):
|
async def _db_save_stream_async(s_data_dict: dict):
|
||||||
|
|||||||
@@ -1332,8 +1332,8 @@ class DefaultReplyer:
|
|||||||
),
|
),
|
||||||
"cross_context": asyncio.create_task(
|
"cross_context": asyncio.create_task(
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
Prompt.build_cross_context(chat_id, "s4u", target_user_info),
|
# cross_context 的构建已移至 prompt.py
|
||||||
"cross_context",
|
asyncio.sleep(0, result=""), "cross_context"
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
"notice_block": asyncio.create_task(
|
"notice_block": asyncio.create_task(
|
||||||
@@ -1521,6 +1521,8 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
# 使用新的统一Prompt系统 - 创建PromptParameters
|
# 使用新的统一Prompt系统 - 创建PromptParameters
|
||||||
prompt_parameters = PromptParameters(
|
prompt_parameters = PromptParameters(
|
||||||
|
platform=platform,
|
||||||
|
user_id=user_id,
|
||||||
chat_scene=chat_scene_prompt,
|
chat_scene=chat_scene_prompt,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
is_group_chat=is_group_chat,
|
is_group_chat=is_group_chat,
|
||||||
|
|||||||
@@ -709,9 +709,17 @@ class Prompt:
|
|||||||
async def _build_relation_info(self) -> dict[str, Any]:
|
async def _build_relation_info(self) -> dict[str, Any]:
|
||||||
"""构建与对话目标相关的关系信息."""
|
"""构建与对话目标相关的关系信息."""
|
||||||
try:
|
try:
|
||||||
# 调用静态方法来执行实际的构建逻辑
|
# [重构] 直接从 PromptParameters 获取稳定的用户身份信息
|
||||||
relation_info = await Prompt.build_relation_info(
|
platform = self.parameters.platform
|
||||||
self.parameters.chat_id, self.parameters.reply_to
|
user_id = self.parameters.user_id
|
||||||
|
|
||||||
|
if not platform or not user_id:
|
||||||
|
logger.warning("无法从参数中获取platform或user_id,跳过关系信息构建")
|
||||||
|
return {"relation_info_block": ""}
|
||||||
|
|
||||||
|
# 调用新的、基于ID的静态方法
|
||||||
|
relation_info = await Prompt.build_relation_info_by_user_id(
|
||||||
|
self.parameters.chat_id, platform, user_id
|
||||||
)
|
)
|
||||||
return {"relation_info_block": relation_info}
|
return {"relation_info_block": relation_info}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1063,43 +1071,29 @@ class Prompt:
|
|||||||
return sender, target
|
return sender, target
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def build_relation_info(chat_id: str, reply_to: str) -> str:
|
async def build_relation_info_by_user_id(chat_id: str, platform: str, user_id: str) -> str:
|
||||||
"""构建关于回复目标用户的关系信息字符串.
|
"""
|
||||||
|
[新] 根据用户ID构建关系信息字符串。
|
||||||
Args:
|
|
||||||
chat_id: 当前聊天的ID。
|
|
||||||
reply_to: 被回复的原始消息字符串。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 格式化后的关系信息字符串,或在失败时返回空字符串。
|
|
||||||
"""
|
"""
|
||||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||||
|
|
||||||
|
person_info_manager = get_person_info_manager()
|
||||||
|
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||||
|
|
||||||
|
if not person_id:
|
||||||
|
logger.warning(f"构建关系信息时未找到用户 platform={platform}, user_id={user_id}")
|
||||||
|
return f"你似乎还不认识这位用户(ID: {user_id}),这是你们的第一次互动。"
|
||||||
|
|
||||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||||
|
|
||||||
if not reply_to:
|
# 并行构建用户信息和聊天流印象
|
||||||
return ""
|
user_relation_info_task = relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||||
# 解析出回复目标的发送者
|
stream_impression_task = relationship_fetcher.build_chat_stream_impression(chat_id)
|
||||||
sender, text = Prompt.parse_reply_target(reply_to)
|
|
||||||
if not sender or not text:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 根据发送者名称查找其用户ID
|
user_relation_info, stream_impression = await asyncio.gather(
|
||||||
person_info_manager = get_person_info_manager()
|
user_relation_info_task, stream_impression_task
|
||||||
person_id = await person_info_manager.get_person_id_by_person_name(sender)
|
|
||||||
if not person_id:
|
|
||||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
|
||||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
|
||||||
|
|
||||||
# 使用关系提取器构建用户关系信息和聊天流印象
|
|
||||||
user_relation_info = await relationship_fetcher.build_relation_info(
|
|
||||||
person_id, points_num=5
|
|
||||||
)
|
|
||||||
stream_impression = await relationship_fetcher.build_chat_stream_impression(
|
|
||||||
chat_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 组合两部分信息
|
|
||||||
info_parts = []
|
info_parts = []
|
||||||
if user_relation_info:
|
if user_relation_info:
|
||||||
info_parts.append(user_relation_info)
|
info_parts.append(user_relation_info)
|
||||||
@@ -1149,6 +1143,7 @@ class Prompt:
|
|||||||
Returns:
|
Returns:
|
||||||
str: 构建好的跨群聊上下文字符串。
|
str: 构建好的跨群聊上下文字符串。
|
||||||
"""
|
"""
|
||||||
|
logger.info(f"Building cross context with target_user_info: {target_user_info}")
|
||||||
if not global_config.cross_context.enable:
|
if not global_config.cross_context.enable:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@@ -1167,32 +1162,22 @@ class Prompt:
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@staticmethod
|
# [废弃] 该函数完全依赖于不稳定的名称解析,应被移除
|
||||||
async def parse_reply_target_id(reply_to: str) -> str:
|
# @staticmethod
|
||||||
"""从回复目标字符串中解析出原始发送者的用户ID.
|
# async def parse_reply_target_id(reply_to: str) -> str:
|
||||||
|
# """从回复目标字符串中解析出原始发送者的用户ID."""
|
||||||
Args:
|
# if not reply_to:
|
||||||
reply_to: 回复目标字符串。
|
# return ""
|
||||||
|
# sender, _ = Prompt.parse_reply_target(reply_to)
|
||||||
Returns:
|
# if not sender:
|
||||||
str: 找到的用户ID,如果找不到则返回空字符串。
|
# return ""
|
||||||
"""
|
# person_info_manager = get_person_info_manager()
|
||||||
if not reply_to:
|
# # [脆弱点] 使用了不稳健的按名称查询
|
||||||
return ""
|
# person_id = await person_info_manager.get_person_id_by_name_robust(sender)
|
||||||
|
# if person_id:
|
||||||
# 首先,解析出发送者的名称
|
# user_id = await person_info_manager.get_value(person_id, "user_id")
|
||||||
sender, _ = Prompt.parse_reply_target(reply_to)
|
# return str(user_id) if user_id else ""
|
||||||
if not sender:
|
# return ""
|
||||||
return ""
|
|
||||||
|
|
||||||
# 然后,通过名称查询用户ID
|
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
person_id = await person_info_manager.get_person_id_by_person_name(sender)
|
|
||||||
if person_id:
|
|
||||||
user_id = await person_info_manager.get_value(person_id, "user_id")
|
|
||||||
return str(user_id) if user_id else ""
|
|
||||||
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
# 工厂函数
|
# 工厂函数
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ class PromptParameters:
|
|||||||
|
|
||||||
# 基础参数
|
# 基础参数
|
||||||
chat_id: str = ""
|
chat_id: str = ""
|
||||||
|
platform: str = ""
|
||||||
|
user_id: str = ""
|
||||||
is_group_chat: bool = False
|
is_group_chat: bool = False
|
||||||
sender: str = ""
|
sender: str = ""
|
||||||
target: str = ""
|
target: str = ""
|
||||||
|
|||||||
@@ -135,6 +135,113 @@ class PersonInfoManager:
|
|||||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
|
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@staticmethod
@cached(ttl=600, key_prefix="person_info_by_user_id", use_kwargs=False)
async def get_person_info_by_user_id(platform: str, user_id: str) -> dict | None:
    """Return the stored person record for ``(platform, user_id)`` as a dict.

    The pair is mapped to a ``person_id`` with ``PersonInfoManager.get_person_id``
    and the matching ``PersonInfo`` row is fetched. Returns ``None`` when either
    argument is empty or no row is found; otherwise a mapping of column name to
    the row's value for that column.
    """
    if not (platform and user_id):
        return None

    person_id = PersonInfoManager.get_person_id(platform, user_id)
    row = await CRUDBase(PersonInfo).get_by(person_id=person_id)
    if not row:
        return None

    # Flatten the ORM row into a plain dict keyed by column name.
    return {column.name: getattr(row, column.name) for column in row.__table__.columns}
|
||||||
|
|
||||||
|
@staticmethod
@cached(ttl=600, key_prefix="person_info_by_person_id", use_kwargs=False)
async def get_person_info_by_person_id(person_id: str) -> dict | None:
    """Return the stored person record for ``person_id`` as a dict.

    Returns ``None`` when ``person_id`` is empty or no ``PersonInfo`` row
    matches; otherwise a mapping of column name to the row's value.
    """
    if not person_id:
        return None

    row = await CRUDBase(PersonInfo).get_by(person_id=person_id)
    if not row:
        return None

    # Flatten the ORM row into a plain dict keyed by column name.
    return {column.name: getattr(row, column.name) for column in row.__table__.columns}
|
||||||
|
|
||||||
|
@staticmethod
async def get_person_id_by_name_robust(name: str) -> str | None:
    """Resolve a display name to a ``person_id``.

    Looks for a ``PersonInfo`` row whose ``person_name`` equals ``name``; if
    none is found, falls back to a row whose ``nickname`` equals ``name``.
    Returns the ``person_id`` of the first match, or ``None`` when ``name`` is
    empty or neither lookup finds a row.
    """
    if not name:
        return None

    crud = CRUDBase(PersonInfo)

    # Lookup order matters: person_name takes precedence over nickname.
    for column in ("person_name", "nickname"):
        matches = await crud.get_multi(**{column: name}, limit=1)
        if matches:
            return matches[0].person_id

    return None
|
||||||
|
|
||||||
|
# A single @staticmethod only: stacking it twice wraps a staticmethod object in
# another one, which is not callable on Python versions before 3.10.
@staticmethod
@cached(ttl=600, key_prefix="person_info_by_name_robust", use_kwargs=False)
async def get_person_info_by_name_robust(name: str) -> dict | None:
    """Return the stored person record for a display name as a dict.

    Resolves ``name`` to a ``person_id`` with
    ``PersonInfoManager.get_person_id_by_name_robust`` and then loads the
    record with ``PersonInfoManager.get_person_info_by_person_id``. Returns
    ``None`` when no ``person_id`` is found for ``name``.
    """
    person_id = await PersonInfoManager.get_person_id_by_name_robust(name)
    if person_id:
        return await PersonInfoManager.get_person_info_by_person_id(person_id)
    return None
|
||||||
|
|
||||||
|
@staticmethod
async def sync_user_info(platform: str, user_id: str, nickname: str | None, cardname: str | None) -> str:
    """Make sure a person record exists for ``(platform, user_id)`` and refresh its nickname.

    If a ``PersonInfo`` row already exists, its ``nickname`` is updated when a
    non-empty ``nickname`` is supplied and differs from the stored one. If no
    row exists, a new one is created with a generated unique ``person_name``
    derived from ``cardname``, then ``nickname``, then a placeholder.

    Returns:
        The ``person_id`` computed from ``platform`` and ``user_id``.

    Raises:
        ValueError: If ``platform`` or ``user_id`` is empty.
    """
    if not (platform and user_id):
        raise ValueError("platform 和 user_id 不能为空")

    person_id = PersonInfoManager.get_person_id(platform, user_id)
    crud = CRUDBase(PersonInfo)
    existing = await crud.get_by(person_id=person_id)

    if not existing:
        # First sighting of this user: create the record.
        logger.info(f"新用户 {platform}:{user_id},将创建记录。")
        display_name = cardname or nickname or "未知用户"
        unique_person_name = await PersonInfoManager._generate_unique_person_name(display_name)

        await PersonInfoManager._safe_create_person_info(
            person_id,
            {
                "person_id": person_id,
                "platform": platform,
                "user_id": str(user_id),
                "nickname": nickname,
                "person_name": unique_person_name,
                "name_reason": "首次遇见时自动设置",
                "know_since": int(time.time()),
                "last_know": int(time.time()),
            },
        )
        return person_id

    # Known user: write back only the fields that actually changed.
    changed = {}
    if nickname and existing.nickname != nickname:
        changed["nickname"] = nickname

    if changed:
        await crud.update(existing.id, changed)
        logger.debug(f"用户 {person_id} 信息已更新: {changed}")

    return person_id
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
|
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
|
||||||
"""判断是否认识某人"""
|
"""判断是否认识某人"""
|
||||||
|
|||||||
Reference in New Issue
Block a user