diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 15d583f1c..7dd7df940 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -104,7 +104,7 @@ class HeartFCSender: # 将MessageSending转换为DatabaseMessages db_message = await self._convert_to_database_message(message) if db_message and message.chat_stream.context_manager: - await message.chat_stream.context_manager.add_message(db_message) + message.chat_stream.context_manager.context.history_messages.append(db_message) logger.debug(f"[{chat_id}] Send API消息已添加到流上下文: {message_id}") except Exception as context_error: logger.warning(f"[{chat_id}] 将Send API消息添加到流上下文失败: {context_error}") diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index cdfc03360..6165d1a2a 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -7,10 +7,8 @@ from typing import Any import orjson from json_repair import repair_json -from sqlalchemy import select from src.common.database.api.crud import CRUDBase -from src.common.database.compatibility import get_db_session from src.common.database.core.models import PersonInfo from src.common.database.utils.decorators import cached from src.common.logger import get_logger @@ -57,7 +55,7 @@ person_info_default = { class PersonInfoManager: def __init__(self): """初始化PersonInfoManager""" - self.person_name_list = {} + # 移除self.person_name_list缓存,统一使用数据库缓存系统 self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") # try: # async with get_db_session() as session: @@ -74,19 +72,7 @@ class PersonInfoManager: # except Exception as e: # logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}") - # # 初始化时读取所有person_name - try: - pass - # 在这里获取会话 - # async with get_db_session() as session: - # for record in session.execute( - # select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None)) - # ).fetchall(): - # if record.person_name: - # self.person_name_list[record.person_id] = record.person_name - # logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)") - except Exception as e: - logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}") + # 移除初始化时读取person_name_list的逻辑,统一使用数据库缓存 @staticmethod def get_person_id(platform: str, user_id: int | str) -> str: @@ -124,33 +110,26 @@ class PersonInfoManager: logger.error(f"检查用户 {person_id} 是否已知时出错: {e}") return False - async def get_person_id_by_person_name(self, person_name: str) -> str: + @staticmethod + @cached(ttl=600, key_prefix="person_name_to_id", use_kwargs=False) + async def get_person_id_by_person_name(person_name: str) -> str: """ 根据用户名获取用户ID(异步) - 说明: 优先在内存缓存 `self.person_name_list` 中查找, - 若未命中则查询数据库并更新缓存。 + 统一使用数据库缓存系统,移除内存缓存 """ + if not person_name: + return "" + try: - # 优先使用内存缓存加速查找:self.person_name_list maps person_id -> person_name - for pid, pname in self.person_name_list.items(): - if pname == person_name: - return pid + # 使用CRUD接口查询,使用装饰器缓存 + crud = CRUDBase(PersonInfo) + records = await crud.get_multi(person_name=person_name, limit=1) - # 缓存未命中,查询数据库 - async with get_db_session() as session: - result = await session.execute( - select(PersonInfo).where(PersonInfo.person_name == person_name) - ) - record = result.scalar() + if records: + return records[0].person_id - if record: - # 找到了,更新缓存 - self.person_name_list[record.person_id] = person_name - logger.debug(f"从数据库查到用户 '{person_name}',已更新缓存") - return record.person_id - - # 数据库也没有,返回空字符串 + # 数据库中没有找到 return "" except Exception as e: logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}") @@ -161,8 +140,7 @@ class PersonInfoManager: """判断是否认识某人""" person_id = PersonInfoManager.get_person_id(platform, user_id) # 生成唯一的 person_name - person_info_manager = get_person_info_manager() - unique_nickname = await person_info_manager._generate_unique_person_name(user_nickname) + unique_nickname = await PersonInfoManager._generate_unique_person_name(user_nickname) data = { "platform": platform, "user_id": user_id, @@ -171,9 +149,9 @@ class PersonInfoManager: "person_name": unique_nickname, # 使用唯一的 person_name } # 先创建用户基本信息,使用安全创建方法避免竞态条件 - await person_info_manager._safe_create_person_info(person_id=person_id, data=data) + await PersonInfoManager._safe_create_person_info(person_id=person_id, data=data) # 更新昵称 - await person_info_manager.update_one_field( + await get_person_info_manager().update_one_field( person_id=person_id, field_name="nickname", value=user_nickname, data=data ) @@ -225,18 +203,12 @@ class PersonInfoManager: final_data[key] = orjson.dumps([]).decode("utf-8") # If it's already a string, assume it's valid JSON or a non-JSON string field - async def _db_create_async(p_data: dict): - async with get_db_session() as session: - try: - new_person = PersonInfo(**p_data) - session.add(new_person) - await session.commit() - return True - except Exception as e: - logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}") - return False - - await _db_create_async(final_data) + # 使用CRUD接口创建记录 + try: + crud = CRUDBase(PersonInfo) + await crud.create(final_data) + except Exception as e: + logger.error(f"创建 PersonInfo 记录 {final_data.get('person_id')} 失败 (SQLAlchemy): {e}") @staticmethod async def _safe_create_person_info(person_id: str, data: dict | None = None): @@ -429,23 +401,33 @@ class PersonInfoManager: logger.info(f"文本: {text}") return {"nickname": "", "reason": ""} - async def _generate_unique_person_name(self, base_name: str) -> str: + @staticmethod + async def _generate_unique_person_name(base_name: str) -> str: """生成唯一的 person_name,如果存在重复则添加数字后缀""" # 处理空昵称的情况 if not base_name or base_name.isspace(): base_name = "空格" - # 检查基础名称是否已存在 - if base_name not in self.person_name_list.values(): - return base_name + try: + # 使用CRUD接口检查基础名称是否已存在于数据库中 + crud = CRUDBase(PersonInfo) + existing_record = await crud.get_by(person_name=base_name) + if not existing_record: + return base_name - # 如果存在,添加数字后缀 - counter = 1 - while True: - new_name = f"{base_name}[{counter}]" - if new_name not in self.person_name_list.values(): - return new_name - counter += 1 + # 如果存在,添加数字后缀并检查 + counter = 1 + while True: + new_name = f"{base_name}[{counter}]" + existing_new_record = await crud.get_by(person_name=new_name) + if not existing_new_record: + return new_name + counter += 1 + except Exception as e: + logger.error(f"生成唯一person_name时出错: {e}") + # 出错时返回带时间戳的唯一名称 + import time + return f"{base_name}_{int(time.time())}" async def qv_person_name( self, person_id: str, user_nickname: str, user_cardname: str, user_avatar: str, request: str = "" @@ -461,7 +443,15 @@ class PersonInfoManager: max_retries = 8 current_try = 0 existing_names_str = "" - current_name_set = set(self.person_name_list.values()) + # 获取数据库中已存在的名称用于重复检查 + try: + # 使用CRUD接口获取所有已存在的名称 + crud = CRUDBase(PersonInfo) + all_records = await crud.get_multi(limit=1000) # 限制数量避免性能问题 + current_name_set = set(record.person_name for record in all_records if record.person_name) + except Exception as e: + logger.warning(f"获取现有名称列表失败: {e}") + current_name_set = set() while current_try < max_retries: # prompt_personality =get_individuality().get_prompt(x_person=2, level=1) @@ -507,12 +497,10 @@ class PersonInfoManager: else: async def _db_check_name_exists_async(name_to_check): - async with get_db_session() as session: - result = await session.execute( - select(PersonInfo).where(PersonInfo.person_name == name_to_check) - ) - record = result.scalar() - return record is not None + # 使用CRUD接口检查名称是否存在 + crud = CRUDBase(PersonInfo) + existing_record = await crud.get_by(person_name=name_to_check) + return existing_record is not None if await _db_check_name_exists_async(generated_nickname): is_duplicate = True @@ -526,7 +514,7 @@ class PersonInfoManager: f"成功给用户{user_nickname} {person_id} 取名 {generated_nickname},理由:{result.get('reason', '未提供理由')}" ) - self.person_name_list[person_id] = generated_nickname + # 移除内存缓存更新,统一使用数据库缓存 return result else: if existing_names_str: @@ -536,11 +524,11 @@ class PersonInfoManager: current_try += 1 # 如果多次尝试后仍未成功,使用唯一的 user_nickname 作为默认值 - unique_nickname = await self._generate_unique_person_name(user_nickname) + unique_nickname = await PersonInfoManager._generate_unique_person_name(user_nickname) logger.warning(f"在{max_retries}次尝试后未能生成唯一昵称,使用默认昵称 {unique_nickname}") await self.update_one_field(person_id, "person_name", unique_nickname) await self.update_one_field(person_id, "name_reason", "使用用户原始昵称作为默认值") - self.person_name_list[person_id] = unique_nickname + # 移除内存缓存更新,统一使用数据库缓存 return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"} @staticmethod @@ -654,6 +642,7 @@ class PersonInfoManager: return result @staticmethod + @cached(ttl=300, key_prefix="person_specific_list", use_kwargs=False) async def get_specific_value_list( field_name: str, way: Callable[[Any], bool], @@ -729,7 +718,7 @@ class PersonInfoManager: # 如果仍然失败,重新抛出异常 raise e - unique_nickname = await self._generate_unique_person_name(nickname) + unique_nickname = await PersonInfoManager._generate_unique_person_name(nickname) initial_data = { "person_id": person_id, "platform": platform, @@ -765,34 +754,24 @@ class PersonInfoManager: return person_id - async def get_person_info_by_name(self, person_name: str) -> dict | None: + @staticmethod + @cached(ttl=600, key_prefix="person_info_by_name", use_kwargs=False) + async def get_person_info_by_name(person_name: str) -> dict | None: """根据 person_name 查找用户并返回基本信息 (如果找到)""" if not person_name: logger.debug("get_person_info_by_name 获取失败:person_name 不能为空") return None - found_person_id = None - for pid, name_in_cache in self.person_name_list.items(): - if name_in_cache == person_name: - found_person_id = pid - break - - if not found_person_id: - - # 使用CRUD进行查询 (person_name不是唯一字段,可能返回多条) - crud = CRUDBase(PersonInfo) - records = await crud.get_multi(person_name=person_name, limit=1) - if records: - record = records[0] - found_person_id = record.person_id - if ( - found_person_id not in self.person_name_list - or self.person_name_list[found_person_id] != person_name - ): - self.person_name_list[found_person_id] = person_name - else: - logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户") - return None + # 直接查询数据库,移除内存缓存逻辑 + # 使用CRUD进行查询 (person_name不是唯一字段,可能返回多条) + crud = CRUDBase(PersonInfo) + records = await crud.get_multi(person_name=person_name, limit=1) + if records: + record = records[0] + found_person_id = record.person_id + else: + logger.debug(f"数据库中未找到名为 '{person_name}' 的用户") + return None if found_person_id: required_fields = [ @@ -809,7 +788,7 @@ class PersonInfoManager: model_fields = [column.name for column in PersonInfo.__table__.columns] valid_fields_to_get = [f for f in required_fields if f in model_fields or f in person_info_default] - person_data = await self.get_values(found_person_id, valid_fields_to_get) + person_data = await PersonInfoManager.get_values(found_person_id, valid_fields_to_get) if person_data: final_result = {key: person_data.get(key) for key in required_fields} diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index ceeffd2bc..a4cd165c4 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -465,22 +465,31 @@ class RemindAction(BaseAction): # 2. 包含匹配 if not user_info: - for person_id, name in person_manager.person_name_list.items(): - if user_name in name: - user_info = await person_manager.get_values(person_id, ["user_id", "user_nickname"]) + # 使用数据库查询获取所有用户进行包含匹配 + from src.common.database.api.crud import CRUDBase + from src.common.database.core.models import PersonInfo + crud = CRUDBase(PersonInfo) + all_records = await crud.get_multi(limit=1000) # 限制数量避免性能问题 + for record in all_records: + if record.person_name and user_name in record.person_name: + user_info = await person_manager.get_values(record.person_id, ["user_id", "user_nickname"]) break # 3. 模糊匹配 (此处简化为字符串相似度) if not user_info: best_match = None highest_similarity = 0 - for person_id, name in person_manager.person_name_list.items(): - import difflib + import difflib - similarity = difflib.SequenceMatcher(None, user_name, name).ratio() - if similarity > highest_similarity: - highest_similarity = similarity - best_match = person_id + # 使用数据库查询获取所有用户进行模糊匹配 + crud = CRUDBase(PersonInfo) + all_records = await crud.get_multi(limit=1000) # 限制数量避免性能问题 + for record in all_records: + if record.person_name: + similarity = difflib.SequenceMatcher(None, user_name, record.person_name).ratio() + if similarity > highest_similarity: + highest_similarity = similarity + best_match = record.person_id if best_match and highest_similarity > 0.6: # 相似度阈值 user_info = await person_manager.get_values(best_match, ["user_id", "user_nickname"])