1145 lines
49 KiB
Python
1145 lines
49 KiB
Python
import copy
|
||
import datetime
|
||
import hashlib
|
||
import time
|
||
from collections.abc import Callable
|
||
from typing import Any
|
||
|
||
import orjson
|
||
from json_repair import repair_json
|
||
|
||
from src.common.database.api.crud import CRUDBase
|
||
from src.common.database.core.models import PersonInfo
|
||
from src.common.database.utils.decorators import cached
|
||
from src.common.logger import get_logger
|
||
from src.config.config import global_config, model_config
|
||
from src.llm_models.utils_model import LLMRequest
|
||
|
||
logger = get_logger("person_info")
|
||
|
||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||
"""获取唯一id"""
|
||
if "-" in platform:
|
||
platform = platform.split("-")[1]
|
||
components = [platform, str(user_id)]
|
||
key = "_".join(components)
|
||
return hashlib.md5(key.encode()).hexdigest()
|
||
|
||
def get_person_id_by_person_name(person_name: str) -> str:
|
||
"""根据用户名获取用户ID"""
|
||
try:
|
||
record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
|
||
return record.person_id if record else ""
|
||
except Exception as e:
|
||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
|
||
return ""
|
||
|
||
def is_person_known(person_id: str = None,user_id: str = None,platform: str = None,person_name: str = None) -> bool:
|
||
if person_id:
|
||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||
return person.is_known if person else False
|
||
elif user_id and platform:
|
||
person_id = get_person_id(platform, user_id)
|
||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||
return person.is_known if person else False
|
||
elif person_name:
|
||
person_id = get_person_id_by_person_name(person_name)
|
||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||
return person.is_known if person else False
|
||
else:
|
||
return False
|
||
|
||
class Person:
|
||
@classmethod
|
||
def register_person(cls, platform: str, user_id: str, nickname: str):
|
||
"""
|
||
注册新用户的类方法
|
||
必须输入 platform、user_id 和 nickname 参数
|
||
|
||
Args:
|
||
platform: 平台名称
|
||
user_id: 用户ID
|
||
nickname: 用户昵称
|
||
|
||
Returns:
|
||
Person: 新注册的Person实例
|
||
"""
|
||
if not platform or not user_id or not nickname:
|
||
logger.error("注册用户失败:platform、user_id 和 nickname 都是必需参数")
|
||
return None
|
||
|
||
# 生成唯一的person_id
|
||
person_id = get_person_id(platform, user_id)
|
||
|
||
if is_person_known(person_id=person_id):
|
||
logger.info(f"用户 {nickname} 已存在")
|
||
return Person(person_id=person_id)
|
||
|
||
# 创建Person实例
|
||
person = cls.__new__(cls)
|
||
|
||
# 设置基本属性
|
||
person.person_id = person_id
|
||
person.platform = platform
|
||
person.user_id = user_id
|
||
person.nickname = nickname
|
||
|
||
# 初始化默认值
|
||
person.is_known = True # 注册后立即标记为已认识
|
||
person.person_name = nickname # 使用nickname作为初始person_name
|
||
person.name_reason = "用户注册时设置的昵称"
|
||
person.know_times = 1
|
||
person.know_since = time.time()
|
||
person.last_know = time.time()
|
||
person.points = []
|
||
|
||
# 初始化性格特征相关字段
|
||
person.attitude_to_me = 0
|
||
person.attitude_to_me_confidence = 1
|
||
|
||
person.neuroticism = 5
|
||
person.neuroticism_confidence = 1
|
||
|
||
person.friendly_value = 50
|
||
person.friendly_value_confidence = 1
|
||
|
||
person.rudeness = 50
|
||
person.rudeness_confidence = 1
|
||
|
||
person.conscientiousness = 50
|
||
person.conscientiousness_confidence = 1
|
||
|
||
person.likeness = 50
|
||
person.likeness_confidence = 1
|
||
|
||
# 同步到数据库
|
||
person.sync_to_database()
|
||
|
||
logger.info(f"成功注册新用户:{person_id},平台:{platform},昵称:{nickname}")
|
||
|
||
return person
|
||
|
||
def __init__(self, platform: str = "", user_id: str = "",person_id: str = "",person_name: str = ""):
|
||
if platform == global_config.bot.platform and user_id == global_config.bot.qq_account:
|
||
self.is_known = True
|
||
self.person_id = get_person_id(platform, user_id)
|
||
self.user_id = user_id
|
||
self.platform = platform
|
||
self.nickname = global_config.bot.nickname
|
||
self.person_name = global_config.bot.nickname
|
||
return
|
||
|
||
self.user_id = ""
|
||
self.platform = ""
|
||
|
||
if person_id:
|
||
self.person_id = person_id
|
||
elif person_name:
|
||
self.person_id = get_person_id_by_person_name(person_name)
|
||
if not self.person_id:
|
||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错,不存在用户{person_name}")
|
||
return
|
||
elif platform and user_id:
|
||
self.person_id = get_person_id(platform, user_id)
|
||
self.user_id = user_id
|
||
self.platform = platform
|
||
else:
|
||
logger.error("Person 初始化失败,缺少必要参数")
|
||
raise ValueError("Person 初始化失败,缺少必要参数")
|
||
|
||
if not is_person_known(person_id=self.person_id):
|
||
self.is_known = False
|
||
logger.warning(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
|
||
self.person_name = f"未知用户{self.person_id[:4]}"
|
||
return
|
||
|
||
self.is_known = False
|
||
|
||
# 初始化默认值
|
||
self.nickname = ""
|
||
self.person_name = None
|
||
self.name_reason = None
|
||
self.know_times = 0
|
||
self.know_since = None
|
||
self.last_know = None
|
||
self.points = []
|
||
|
||
# 初始化性格特征相关字段
|
||
self.attitude_to_me:float = 0
|
||
self.attitude_to_me_confidence:float = 1
|
||
|
||
self.neuroticism:float = 5
|
||
self.neuroticism_confidence:float = 1
|
||
|
||
self.friendly_value:float = 50
|
||
self.friendly_value_confidence:float = 1
|
||
|
||
self.rudeness:float = 50
|
||
self.rudeness_confidence:float = 1
|
||
|
||
self.conscientiousness:float = 50
|
||
self.conscientiousness_confidence:float = 1
|
||
|
||
self.likeness:float = 50
|
||
self.likeness_confidence:float = 1
|
||
|
||
# 从数据库加载数据
|
||
self.load_from_database()
|
||
|
||
def load_from_database(self):
|
||
"""从数据库加载个人信息数据"""
|
||
try:
|
||
# 查询数据库中的记录
|
||
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
|
||
|
||
if record:
|
||
self.user_id = record.user_id if record.user_id else ""
|
||
self.platform = record.platform if record.platform else ""
|
||
self.is_known = record.is_known if record.is_known else False
|
||
self.nickname = record.nickname if record.nickname else ""
|
||
self.person_name = record.person_name if record.person_name else self.nickname
|
||
self.name_reason = record.name_reason if record.name_reason else None
|
||
self.know_times = record.know_times if record.know_times else 0
|
||
|
||
# 处理points字段(JSON格式的列表)
|
||
if record.points:
|
||
try:
|
||
self.points = json.loads(record.points)
|
||
except (json.JSONDecodeError, TypeError):
|
||
logger.warning(f"解析用户 {self.person_id} 的points字段失败,使用默认值")
|
||
self.points = []
|
||
else:
|
||
self.points = []
|
||
|
||
# 加载性格特征相关字段
|
||
if record.attitude_to_me and not isinstance(record.attitude_to_me, str):
|
||
self.attitude_to_me = record.attitude_to_me
|
||
|
||
if record.attitude_to_me_confidence is not None:
|
||
self.attitude_to_me_confidence = float(record.attitude_to_me_confidence)
|
||
|
||
if record.friendly_value is not None:
|
||
self.friendly_value = float(record.friendly_value)
|
||
|
||
if record.friendly_value_confidence is not None:
|
||
self.friendly_value_confidence = float(record.friendly_value_confidence)
|
||
|
||
if record.rudeness is not None:
|
||
self.rudeness = float(record.rudeness)
|
||
|
||
if record.rudeness_confidence is not None:
|
||
self.rudeness_confidence = float(record.rudeness_confidence)
|
||
|
||
if record.neuroticism and not isinstance(record.neuroticism, str):
|
||
self.neuroticism = float(record.neuroticism)
|
||
|
||
if record.neuroticism_confidence is not None:
|
||
self.neuroticism_confidence = float(record.neuroticism_confidence)
|
||
|
||
if record.conscientiousness is not None:
|
||
self.conscientiousness = float(record.conscientiousness)
|
||
|
||
if record.conscientiousness_confidence is not None:
|
||
self.conscientiousness_confidence = float(record.conscientiousness_confidence)
|
||
|
||
if record.likeness is not None:
|
||
self.likeness = float(record.likeness)
|
||
|
||
if record.likeness_confidence is not None:
|
||
self.likeness_confidence = float(record.likeness_confidence)
|
||
|
||
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
||
else:
|
||
self.sync_to_database()
|
||
logger.info(f"用户 {self.person_id} 在数据库中不存在,使用默认值并创建")
|
||
|
||
except Exception as e:
|
||
logger.error(f"从数据库加载用户 {self.person_id} 信息时出错: {e}")
|
||
# 出错时保持默认值
|
||
|
||
def sync_to_database(self):
|
||
"""将所有属性同步回数据库"""
|
||
if not self.is_known:
|
||
return
|
||
try:
|
||
# 准备数据
|
||
data = {
|
||
'person_id': self.person_id,
|
||
'is_known': self.is_known,
|
||
'platform': self.platform,
|
||
'user_id': self.user_id,
|
||
'nickname': self.nickname,
|
||
'person_name': self.person_name,
|
||
'name_reason': self.name_reason,
|
||
'know_times': self.know_times,
|
||
'know_since': self.know_since,
|
||
'last_know': self.last_know,
|
||
'points': json.dumps(self.points, ensure_ascii=False) if self.points else json.dumps([], ensure_ascii=False),
|
||
'attitude_to_me': self.attitude_to_me,
|
||
'attitude_to_me_confidence': self.attitude_to_me_confidence,
|
||
'friendly_value': self.friendly_value,
|
||
'friendly_value_confidence': self.friendly_value_confidence,
|
||
'rudeness': self.rudeness,
|
||
'rudeness_confidence': self.rudeness_confidence,
|
||
'neuroticism': self.neuroticism,
|
||
'neuroticism_confidence': self.neuroticism_confidence,
|
||
'conscientiousness': self.conscientiousness,
|
||
'conscientiousness_confidence': self.conscientiousness_confidence,
|
||
'likeness': self.likeness,
|
||
'likeness_confidence': self.likeness_confidence,
|
||
}
|
||
|
||
# 检查记录是否存在
|
||
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
|
||
|
||
if record:
|
||
# 更新现有记录
|
||
for field, value in data.items():
|
||
if hasattr(record, field):
|
||
setattr(record, field, value)
|
||
record.save()
|
||
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
|
||
else:
|
||
# 创建新记录
|
||
PersonInfo.create(**data)
|
||
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
||
|
||
except Exception as e:
|
||
logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}")
|
||
|
||
def build_relationship(self,points_num=3):
|
||
# print(self.person_name,self.nickname,self.platform,self.is_known)
|
||
|
||
|
||
if not self.is_known:
|
||
return ""
|
||
|
||
# 按时间排序forgotten_points
|
||
current_points = self.points
|
||
current_points.sort(key=lambda x: x[2])
|
||
# 按权重加权随机抽取最多3个不重复的points,point[1]的值在1-10之间,权重越高被抽到概率越大
|
||
if len(current_points) > points_num:
|
||
# point[1] 取值范围1-10,直接作为权重
|
||
weights = [max(1, min(10, int(point[1]))) for point in current_points]
|
||
# 使用加权采样不放回,保证不重复
|
||
indices = list(range(len(current_points)))
|
||
points = []
|
||
for _ in range(points_num):
|
||
if not indices:
|
||
break
|
||
sub_weights = [weights[i] for i in indices]
|
||
chosen_idx = random.choices(indices, weights=sub_weights, k=1)[0]
|
||
points.append(current_points[chosen_idx])
|
||
indices.remove(chosen_idx)
|
||
else:
|
||
points = current_points
|
||
|
||
# 构建points文本
|
||
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
||
|
||
nickname_str = ""
|
||
if self.person_name != self.nickname:
|
||
nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})"
|
||
|
||
relation_info = ""
|
||
|
||
attitude_info = ""
|
||
if self.attitude_to_me:
|
||
if self.attitude_to_me > 8:
|
||
attitude_info = f"{self.person_name}对你的态度十分好,"
|
||
elif self.attitude_to_me > 5:
|
||
attitude_info = f"{self.person_name}对你的态度较好,"
|
||
|
||
|
||
if self.attitude_to_me < -8:
|
||
attitude_info = f"{self.person_name}对你的态度十分恶劣,"
|
||
elif self.attitude_to_me < -4:
|
||
attitude_info = f"{self.person_name}对你的态度不好,"
|
||
elif self.attitude_to_me < 0:
|
||
attitude_info = f"{self.person_name}对你的态度一般,"
|
||
|
||
neuroticism_info = ""
|
||
if self.neuroticism:
|
||
if self.neuroticism > 8:
|
||
neuroticism_info = f"{self.person_name}的情绪十分活跃,容易情绪化,"
|
||
elif self.neuroticism > 6:
|
||
neuroticism_info = f"{self.person_name}的情绪比较活跃,"
|
||
elif self.neuroticism > 4:
|
||
neuroticism_info = ""
|
||
elif self.neuroticism > 2:
|
||
neuroticism_info = f"{self.person_name}的情绪比较稳定,"
|
||
else:
|
||
neuroticism_info = f"{self.person_name}的情绪非常稳定,毫无波动"
|
||
|
||
points_info = ""
|
||
if points_text:
|
||
points_info = f"你还记得ta最近做的事:{points_text}"
|
||
|
||
if not (nickname_str or attitude_info or neuroticism_info or points_info):
|
||
return ""
|
||
relation_info = f"{self.person_name}:{nickname_str}{attitude_info}{neuroticism_info}{points_info}"
|
||
|
||
return relation_info
|
||
|
||
|
||
class PersonInfoManager:
|
||
def __init__(self):
|
||
"""初始化PersonInfoManager"""
|
||
# 移除self.person_name_list缓存,统一使用数据库缓存系统
|
||
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
||
# try:
|
||
# async with get_db_session() as session:
|
||
# db.connect(reuse_if_open=True)
|
||
# # 设置连接池参数(仅对SQLite有效)
|
||
# if hasattr(db, "execute_sql"):
|
||
# # 检查数据库类型,只对SQLite执行PRAGMA语句
|
||
# if global_config.database.database_type == "sqlite":
|
||
# # 设置SQLite优化参数
|
||
# db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
|
||
# db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
|
||
# db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
|
||
# db.create_tables([PersonInfo], safe=True)
|
||
# except Exception as e:
|
||
# logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
|
||
|
||
# 移除初始化时读取person_name_list的逻辑,统一使用数据库缓存
|
||
|
||
@staticmethod
|
||
def get_person_id(platform: str, user_id: int | str) -> str:
|
||
"""获取唯一id(同步)
|
||
|
||
说明: 原来该方法为异步并在内部尝试执行数据库检查/迁移,导致在许多调用处未 await 时返回 coroutine 对象。
|
||
为了避免将 coroutine 传递到其它同步调用(例如数据库查询条件)中,这里将方法改为同步并仅返回基于 platform 和 user_id 的 MD5 哈希值。
|
||
|
||
注意: 这会跳过原有的 napcat->qq 迁移检查逻辑。如需保留迁移,请使用显式的、在合适时机执行的迁移任务。
|
||
"""
|
||
# 检查platform是否为None或空
|
||
if platform is None:
|
||
platform = "unknown"
|
||
|
||
if "-" in platform:
|
||
platform = platform.split("-")[1]
|
||
|
||
components = [platform, str(user_id)]
|
||
key = "_".join(components)
|
||
|
||
# 直接返回计算的 id(同步)
|
||
return hashlib.md5(key.encode()).hexdigest()
|
||
|
||
@cached(ttl=300, key_prefix="person_known", use_kwargs=False)
|
||
async def is_person_known(self, platform: str, user_id: int):
|
||
"""判断是否认识某人(带5分钟缓存)"""
|
||
person_id = self.get_person_id(platform, user_id)
|
||
|
||
try:
|
||
# 使用CRUD进行查询
|
||
crud = CRUDBase(PersonInfo)
|
||
record = await crud.get_by(person_id=person_id)
|
||
return record is not None
|
||
except Exception as e:
|
||
logger.error(f"检查用户 {person_id} 是否已知时出错: {e}")
|
||
return False
|
||
|
||
@staticmethod
|
||
@cached(ttl=600, key_prefix="person_name_to_id", use_kwargs=False)
|
||
async def get_person_id_by_person_name(person_name: str) -> str:
|
||
"""
|
||
根据用户名获取用户ID(异步)
|
||
|
||
统一使用数据库缓存系统,移除内存缓存
|
||
"""
|
||
if not person_name:
|
||
return ""
|
||
|
||
try:
|
||
# 使用CRUD接口查询,使用装饰器缓存
|
||
crud = CRUDBase(PersonInfo)
|
||
records = await crud.get_multi(person_name=person_name, limit=1)
|
||
|
||
if records:
|
||
return records[0].person_id
|
||
|
||
# 数据库中没有找到
|
||
return ""
|
||
except Exception as e:
|
||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
|
||
return ""
|
||
|
||
@staticmethod
|
||
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
|
||
"""判断是否认识某人"""
|
||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||
# 生成唯一的 person_name
|
||
unique_nickname = await PersonInfoManager._generate_unique_person_name(user_nickname)
|
||
data = {
|
||
"platform": platform,
|
||
"user_id": user_id,
|
||
"nickname": user_nickname,
|
||
"konw_time": int(time.time()),
|
||
"person_name": unique_nickname, # 使用唯一的 person_name
|
||
}
|
||
# 先创建用户基本信息,使用安全创建方法避免竞态条件
|
||
await PersonInfoManager._safe_create_person_info(person_id=person_id, data=data)
|
||
# 更新昵称
|
||
await get_person_info_manager().update_one_field(
|
||
person_id=person_id, field_name="nickname", value=user_nickname, data=data
|
||
)
|
||
|
||
@staticmethod
|
||
async def create_person_info(person_id: str, data: dict | None = None):
|
||
"""创建一个项"""
|
||
if not person_id:
|
||
logger.debug("创建失败,person_id不存在")
|
||
return
|
||
|
||
_person_info_default = copy.deepcopy(person_info_default)
|
||
# 获取 SQLAlchemy 模型的所有字段名
|
||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||
|
||
final_data = {"person_id": person_id}
|
||
|
||
# Start with defaults for all model fields
|
||
for key, default_value in _person_info_default.items():
|
||
if key in model_fields:
|
||
final_data[key] = default_value
|
||
|
||
# Override with provided data
|
||
if data:
|
||
for key, value in data.items():
|
||
if key in model_fields:
|
||
final_data[key] = value
|
||
|
||
# Ensure person_id is correctly set from the argument
|
||
final_data["person_id"] = person_id
|
||
# 你们的英文注释是何意味?
|
||
|
||
# 检查并修复关键字段为None的情况喵
|
||
if final_data.get("user_id") is None:
|
||
logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||
final_data["user_id"] = "unknown"
|
||
|
||
if final_data.get("platform") is None:
|
||
logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||
final_data["platform"] = "unknown"
|
||
|
||
# 这里的目的是为了防止在识别出错的情况下有一个最小回退,不只是针对@消息识别成视频后的报错问题
|
||
|
||
# Serialize JSON fields
|
||
for key in JSON_SERIALIZED_FIELDS:
|
||
if key in final_data:
|
||
if isinstance(final_data[key], list | dict):
|
||
final_data[key] = orjson.dumps(final_data[key]).decode("utf-8")
|
||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||
final_data[key] = orjson.dumps([]).decode("utf-8")
|
||
# If it's already a string, assume it's valid JSON or a non-JSON string field
|
||
|
||
# 使用CRUD接口创建记录
|
||
try:
|
||
crud = CRUDBase(PersonInfo)
|
||
await crud.create(final_data)
|
||
except Exception as e:
|
||
logger.error(f"创建 PersonInfo 记录 {final_data.get('person_id')} 失败 (SQLAlchemy): {e}")
|
||
|
||
@staticmethod
|
||
async def _safe_create_person_info(person_id: str, data: dict | None = None):
|
||
"""安全地创建用户信息,处理竞态条件"""
|
||
if not person_id:
|
||
logger.debug("创建失败,person_id不存在")
|
||
return
|
||
|
||
_person_info_default = copy.deepcopy(person_info_default)
|
||
# 获取 SQLAlchemy 模型的所有字段名
|
||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||
|
||
final_data = {"person_id": person_id}
|
||
|
||
# Start with defaults for all model fields
|
||
for key, default_value in _person_info_default.items():
|
||
if key in model_fields:
|
||
final_data[key] = default_value
|
||
|
||
# Override with provided data
|
||
if data:
|
||
for key, value in data.items():
|
||
if key in model_fields:
|
||
final_data[key] = value
|
||
|
||
# Ensure person_id is correctly set from the argument
|
||
final_data["person_id"] = person_id
|
||
|
||
# 检查并修复关键字段为None的情况
|
||
if final_data.get("user_id") is None:
|
||
logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||
final_data["user_id"] = "unknown"
|
||
|
||
if final_data.get("platform") is None:
|
||
logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||
final_data["platform"] = "unknown"
|
||
|
||
# Serialize JSON fields
|
||
for key in JSON_SERIALIZED_FIELDS:
|
||
if key in final_data:
|
||
if isinstance(final_data[key], list | dict):
|
||
final_data[key] = orjson.dumps(final_data[key]).decode("utf-8")
|
||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||
final_data[key] = orjson.dumps([]).decode("utf-8")
|
||
|
||
async def _db_safe_create_async(p_data: dict):
|
||
try:
|
||
# 使用CRUD进行检查和创建
|
||
crud = CRUDBase(PersonInfo)
|
||
existing = await crud.get_by(person_id=p_data["person_id"])
|
||
if existing:
|
||
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
|
||
return True
|
||
|
||
# 创建新记录
|
||
await crud.create(p_data)
|
||
return True
|
||
except Exception as e:
|
||
if "UNIQUE constraint failed" in str(e):
|
||
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
|
||
return True
|
||
else:
|
||
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败: {e}")
|
||
return False
|
||
|
||
await _db_safe_create_async(final_data)
|
||
|
||
async def update_one_field(self, person_id: str, field_name: str, value, data: dict | None = None):
|
||
"""更新某一个字段,会补全"""
|
||
# 获取 SQLAlchemy 模型的所有字段名
|
||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||
if field_name not in model_fields:
|
||
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义的字段。")
|
||
return
|
||
|
||
processed_value = value
|
||
if field_name in JSON_SERIALIZED_FIELDS:
|
||
if isinstance(value, list | dict):
|
||
processed_value = orjson.dumps(value).decode("utf-8")
|
||
elif value is None: # Store None as "[]" for JSON list fields
|
||
processed_value = orjson.dumps([]).decode("utf-8")
|
||
|
||
async def _db_update_async(p_id: str, f_name: str, val_to_set):
|
||
start_time = time.time()
|
||
try:
|
||
# 使用CRUD进行更新
|
||
crud = CRUDBase(PersonInfo)
|
||
record = await crud.get_by(person_id=p_id)
|
||
query_time = time.time()
|
||
|
||
if record:
|
||
# 更新记录
|
||
await crud.update(record.id, {f_name: val_to_set})
|
||
save_time = time.time()
|
||
total_time = save_time - start_time
|
||
|
||
if total_time > 0.5:
|
||
logger.warning(
|
||
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
|
||
)
|
||
|
||
# 使缓存失效
|
||
from src.common.database.optimization.cache_manager import get_cache
|
||
from src.common.database.utils.decorators import generate_cache_key
|
||
cache = await get_cache()
|
||
# 使相关缓存失效
|
||
await cache.delete(generate_cache_key("person_value", p_id, f_name))
|
||
await cache.delete(generate_cache_key("person_values", p_id))
|
||
await cache.delete(generate_cache_key("person_has_field", p_id, f_name))
|
||
|
||
return True, False
|
||
else:
|
||
total_time = time.time() - start_time
|
||
if total_time > 0.5:
|
||
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}")
|
||
return False, True
|
||
except Exception as e:
|
||
total_time = time.time() - start_time
|
||
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
|
||
raise
|
||
|
||
_found, needs_creation = await _db_update_async(person_id, field_name, processed_value)
|
||
|
||
if needs_creation:
|
||
logger.info(f"{person_id} 不存在,将新建。")
|
||
creation_data = data if data is not None else {}
|
||
# Ensure platform and user_id are present for context if available from 'data'
|
||
# but primarily, set the field that triggered the update.
|
||
# The create_person_info will handle defaults and serialization.
|
||
creation_data[field_name] = value # Pass original value to create_person_info
|
||
|
||
# Ensure platform and user_id are in creation_data if available,
|
||
# otherwise create_person_info will use defaults.
|
||
if data and "platform" in data:
|
||
creation_data["platform"] = data["platform"]
|
||
if data and "user_id" in data:
|
||
creation_data["user_id"] = data["user_id"]
|
||
|
||
# 额外检查关键字段,如果为None则使用默认值
|
||
if creation_data.get("user_id") is None:
|
||
logger.warning(f"创建用户时user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||
creation_data["user_id"] = "unknown"
|
||
|
||
if creation_data.get("platform") is None:
|
||
logger.warning(f"创建用户时platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||
creation_data["platform"] = "unknown"
|
||
|
||
# 使用安全的创建方法,处理竞态条件
|
||
await self._safe_create_person_info(person_id, creation_data)
|
||
|
||
@staticmethod
|
||
@cached(ttl=300, key_prefix="person_has_field")
|
||
async def has_one_field(person_id: str, field_name: str):
|
||
"""判断是否存在某一个字段(带5分钟缓存)"""
|
||
# 获取 SQLAlchemy 模型的所有字段名
|
||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||
if field_name not in model_fields:
|
||
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义。")
|
||
return False
|
||
|
||
try:
|
||
# 使用CRUD进行查询
|
||
crud = CRUDBase(PersonInfo)
|
||
record = await crud.get_by(person_id=person_id)
|
||
return bool(record)
|
||
except Exception as e:
|
||
logger.error(f"检查字段 {field_name} for {person_id} 时出错: {e}")
|
||
return False
|
||
|
||
@staticmethod
|
||
def _extract_json_from_text(text: str) -> dict:
|
||
"""从文本中提取JSON数据的高容错方法"""
|
||
try:
|
||
fixed_json = repair_json(text)
|
||
if isinstance(fixed_json, str):
|
||
parsed_json = orjson.loads(fixed_json)
|
||
else:
|
||
parsed_json = fixed_json
|
||
|
||
if isinstance(parsed_json, list) and parsed_json:
|
||
parsed_json = parsed_json[0]
|
||
|
||
if isinstance(parsed_json, dict):
|
||
return parsed_json
|
||
|
||
except Exception as e:
|
||
logger.warning(f"JSON提取失败: {e}")
|
||
|
||
logger.warning(f"无法从文本中提取有效的JSON字典: {text}")
|
||
logger.info(f"文本: {text}")
|
||
return {"nickname": "", "reason": ""}
|
||
|
||
@staticmethod
|
||
async def _generate_unique_person_name(base_name: str) -> str:
|
||
"""生成唯一的 person_name,如果存在重复则添加数字后缀"""
|
||
# 处理空昵称的情况
|
||
if not base_name or base_name.isspace():
|
||
base_name = "空格"
|
||
|
||
try:
|
||
# 使用CRUD接口检查基础名称是否已存在于数据库中
|
||
crud = CRUDBase(PersonInfo)
|
||
existing_record = await crud.get_by(person_name=base_name)
|
||
if not existing_record:
|
||
return base_name
|
||
|
||
# 如果存在,添加数字后缀并检查
|
||
counter = 1
|
||
while True:
|
||
new_name = f"{base_name}[{counter}]"
|
||
existing_new_record = await crud.get_by(person_name=new_name)
|
||
if not existing_new_record:
|
||
return new_name
|
||
counter += 1
|
||
except Exception as e:
|
||
logger.error(f"生成唯一person_name时出错: {e}")
|
||
# 出错时返回带时间戳的唯一名称
|
||
import time
|
||
return f"{base_name}_{int(time.time())}"
|
||
|
||
async def qv_person_name(
|
||
self, person_id: str, user_nickname: str, user_cardname: str, user_avatar: str, request: str = ""
|
||
):
|
||
"""给某个用户取名"""
|
||
if not person_id:
|
||
logger.debug("取名失败:person_id不能为空")
|
||
return None
|
||
|
||
person = Person(person_id=person_id)
|
||
old_name = person.person_name
|
||
old_reason = person.name_reason
|
||
|
||
max_retries = 8
|
||
current_try = 0
|
||
existing_names_str = ""
|
||
# 获取数据库中已存在的名称用于重复检查
|
||
try:
|
||
# 使用CRUD接口获取所有已存在的名称
|
||
crud = CRUDBase(PersonInfo)
|
||
all_records = await crud.get_multi(limit=1000) # 限制数量避免性能问题
|
||
current_name_set = set(record.person_name for record in all_records if record.person_name)
|
||
except Exception as e:
|
||
logger.warning(f"获取现有名称列表失败: {e}")
|
||
current_name_set = set()
|
||
|
||
while current_try < max_retries:
|
||
# prompt_personality =get_individuality().get_prompt(x_person=2, level=1)
|
||
bot_name = global_config.bot.nickname
|
||
|
||
qv_name_prompt = f"你是{bot_name},一个伪装成人类的AI,你不能让别人发现这一点,"
|
||
qv_name_prompt += f"现在你想给一个用户取一个昵称,用户的qq昵称是{user_nickname},"
|
||
qv_name_prompt += f"用户的qq群昵称名是{user_cardname},"
|
||
if user_avatar:
|
||
qv_name_prompt += f"用户的qq头像是{user_avatar},"
|
||
if old_name:
|
||
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason},"
|
||
|
||
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸,简短,"
|
||
qv_name_prompt += "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称或群昵称原文,可以稍作修改,优先使用原文。优先使用用户的qq昵称或者群昵称原文。"
|
||
|
||
if existing_names_str:
|
||
qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n"
|
||
|
||
if len(current_name_set) < 50 and current_name_set:
|
||
qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n"
|
||
|
||
qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:"
|
||
qv_name_prompt += """{
|
||
"nickname": "昵称",
|
||
"reason": "理由"
|
||
}"""
|
||
response, _ = await self.qv_name_llm.generate_response_async(qv_name_prompt)
|
||
# logger.info(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
|
||
result = self._extract_json_from_text(response)
|
||
|
||
if not result or not result.get("nickname"):
|
||
logger.error("生成的昵称为空或结果格式不正确,重试中...")
|
||
current_try += 1
|
||
continue
|
||
|
||
generated_nickname = result["nickname"]
|
||
|
||
is_duplicate = False
|
||
if generated_nickname in current_name_set:
|
||
is_duplicate = True
|
||
logger.info(f"尝试给用户{user_nickname} {person_id} 取名,但是 {generated_nickname} 已存在,重试中...")
|
||
else:
|
||
|
||
async def _db_check_name_exists_async(name_to_check):
|
||
# 使用CRUD接口检查名称是否存在
|
||
crud = CRUDBase(PersonInfo)
|
||
existing_record = await crud.get_by(person_name=name_to_check)
|
||
return existing_record is not None
|
||
|
||
if await _db_check_name_exists_async(generated_nickname):
|
||
is_duplicate = True
|
||
current_name_set.add(generated_nickname)
|
||
|
||
if not is_duplicate:
|
||
person.person_name = generated_nickname
|
||
person.name_reason = result.get("reason", "未提供理由")
|
||
person.sync_to_database()
|
||
|
||
logger.info(
|
||
f"成功给用户{user_nickname} {person_id} 取名 {generated_nickname},理由:{result.get('reason', '未提供理由')}"
|
||
)
|
||
|
||
# 移除内存缓存更新,统一使用数据库缓存
|
||
return result
|
||
else:
|
||
if existing_names_str:
|
||
existing_names_str += "、"
|
||
existing_names_str += generated_nickname
|
||
logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...")
|
||
current_try += 1
|
||
|
||
# 如果多次尝试后仍未成功,使用唯一的 user_nickname 作为默认值
|
||
unique_nickname = await PersonInfoManager._generate_unique_person_name(user_nickname)
|
||
logger.warning(f"在{max_retries}次尝试后未能生成唯一昵称,使用默认昵称 {unique_nickname}")
|
||
await self.update_one_field(person_id, "person_name", unique_nickname)
|
||
await self.update_one_field(person_id, "name_reason", "使用用户原始昵称作为默认值")
|
||
# 移除内存缓存更新,统一使用数据库缓存
|
||
return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"}
|
||
|
||
|
||
@staticmethod
|
||
async def del_one_document(person_id: str):
|
||
"""删除指定 person_id 的文档"""
|
||
if not person_id:
|
||
logger.debug("删除失败:person_id 不能为空")
|
||
return
|
||
|
||
async def _db_delete_async(p_id: str):
|
||
try:
|
||
# 使用CRUD进行删除
|
||
crud = CRUDBase(PersonInfo)
|
||
record = await crud.get_by(person_id=p_id)
|
||
if record:
|
||
await crud.delete(record.id)
|
||
|
||
# 注意: 删除操作很少发生,缓存会在TTL过期后自动清除
|
||
# 无法从person_id反向得到platform和user_id,因此无法精确清除缓存
|
||
# 删除后的查询仍会返回正确结果(None/False)
|
||
return 1
|
||
return 0
|
||
except Exception as e:
|
||
logger.error(f"删除 PersonInfo {p_id} 失败: {e}")
|
||
return 0
|
||
|
||
deleted_count = await _db_delete_async(person_id)
|
||
|
||
if deleted_count > 0:
|
||
logger.debug(f"删除成功:person_id={person_id}")
|
||
else:
|
||
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行")
|
||
|
||
@staticmethod
|
||
@cached(ttl=600, key_prefix="person_value")
|
||
async def get_value(person_id: str, field_name: str) -> Any:
|
||
"""获取单个字段值(带10分钟缓存)"""
|
||
if not person_id:
|
||
logger.debug("get_value获取失败:person_id不能为空")
|
||
return None
|
||
|
||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||
|
||
if field_name not in model_fields:
|
||
if field_name in person_info_default:
|
||
logger.debug(f"字段'{field_name}'不在SQLAlchemy模型中,使用默认配置值。")
|
||
return copy.deepcopy(person_info_default[field_name])
|
||
else:
|
||
logger.debug(f"get_value查询失败:字段'{field_name}'未在SQLAlchemy模型和默认配置中定义。")
|
||
return None
|
||
|
||
# 使用CRUD进行查询
|
||
crud = CRUDBase(PersonInfo)
|
||
record = await crud.get_by(person_id=person_id)
|
||
|
||
if record:
|
||
# 在访问属性前确保对象已加载所有数据
|
||
# 使用 try-except 捕获可能的延迟加载错误
|
||
try:
|
||
value = getattr(record, field_name)
|
||
if value is not None:
|
||
return value
|
||
else:
|
||
return copy.deepcopy(person_info_default.get(field_name))
|
||
except Exception as e:
|
||
logger.warning(f"访问字段 {field_name} 失败: {e}, 使用默认值")
|
||
return copy.deepcopy(person_info_default.get(field_name))
|
||
else:
|
||
return copy.deepcopy(person_info_default.get(field_name))
|
||
|
||
@staticmethod
|
||
@cached(ttl=600, key_prefix="person_values")
|
||
async def get_values(person_id: str, field_names: list) -> dict:
|
||
"""获取指定person_id文档的多个字段值(带10分钟缓存)"""
|
||
if not person_id:
|
||
logger.debug("get_values获取失败:person_id不能为空")
|
||
return {}
|
||
|
||
result = {}
|
||
|
||
# 使用CRUD进行查询
|
||
crud = CRUDBase(PersonInfo)
|
||
record = await crud.get_by(person_id=person_id)
|
||
|
||
# 获取 SQLAlchemy 模型的所有字段名
|
||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||
|
||
for field_name in field_names:
|
||
if field_name not in model_fields:
|
||
if field_name in person_info_default:
|
||
result[field_name] = copy.deepcopy(person_info_default[field_name])
|
||
logger.debug(f"字段'{field_name}'不在SQLAlchemy模型中,使用默认配置值。")
|
||
else:
|
||
logger.debug(f"get_values查询失败:字段'{field_name}'未在SQLAlchemy模型和默认配置中定义。")
|
||
result[field_name] = None
|
||
continue
|
||
|
||
if record:
|
||
try:
|
||
value = getattr(record, field_name)
|
||
if value is not None:
|
||
result[field_name] = value
|
||
else:
|
||
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||
except Exception as e:
|
||
logger.warning(f"访问字段 {field_name} 失败: {e}, 使用默认值")
|
||
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||
else:
|
||
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||
|
||
return result
|
||
|
||
@staticmethod
|
||
@cached(ttl=300, key_prefix="person_specific_list", use_kwargs=False)
|
||
async def get_specific_value_list(
|
||
field_name: str,
|
||
way: Callable[[Any], bool],
|
||
) -> dict[str, Any]:
|
||
"""
|
||
获取满足条件的字段值字典
|
||
"""
|
||
# 获取 SQLAlchemy 模型的所有字段名
|
||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||
if field_name not in model_fields:
|
||
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模型中定义")
|
||
return {}
|
||
|
||
async def _db_get_specific_async(f_name: str):
|
||
found_results = {}
|
||
try:
|
||
# 使用CRUD获取所有记录
|
||
crud = CRUDBase(PersonInfo)
|
||
all_records = await crud.get_multi(limit=100000) # 获取所有记录
|
||
for record in all_records:
|
||
try:
|
||
value = getattr(record, f_name, None)
|
||
if value is not None and way(value):
|
||
person_id_value = getattr(record, "person_id", None)
|
||
if person_id_value:
|
||
found_results[person_id_value] = value
|
||
except Exception as e:
|
||
logger.warning(f"访问记录字段失败: {e}")
|
||
continue
|
||
except Exception as e_query:
|
||
logger.error(
|
||
f"数据库查询失败 (specific_value_list for {f_name}): {e_query!s}", exc_info=True
|
||
)
|
||
return found_results
|
||
|
||
try:
|
||
return await _db_get_specific_async(field_name)
|
||
except Exception as e:
|
||
logger.error(f"执行 get_specific_value_list 时出错: {e!s}", exc_info=True)
|
||
return {}
|
||
|
||
async def get_or_create_person(
|
||
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: str | None = None
|
||
) -> str:
|
||
"""
|
||
根据 platform 和 user_id 获取 person_id。
|
||
如果对应的用户不存在,则使用提供的可选信息创建新用户。
|
||
使用try-except处理竞态条件,避免重复创建错误。
|
||
"""
|
||
person_id = self.get_person_id(platform, user_id)
|
||
|
||
async def _db_get_or_create_async(p_id: str, init_data: dict):
|
||
"""原子性的获取或创建操作"""
|
||
# 使用CRUD进行获取或创建
|
||
crud = CRUDBase(PersonInfo)
|
||
|
||
# 首先尝试获取现有记录
|
||
record = await crud.get_by(person_id=p_id)
|
||
if record:
|
||
return record, False # 记录存在,未创建
|
||
|
||
# 记录不存在,尝试创建
|
||
try:
|
||
new_person = await crud.create(init_data)
|
||
return new_person, True # 创建成功
|
||
except Exception as e:
|
||
# 如果创建失败(可能是因为竞态条件),再次尝试获取
|
||
if "UNIQUE constraint failed" in str(e):
|
||
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
|
||
record = await crud.get_by(person_id=p_id)
|
||
if record:
|
||
return record, False # 其他协程已创建,返回现有记录
|
||
# 如果仍然失败,重新抛出异常
|
||
raise e
|
||
|
||
unique_nickname = await PersonInfoManager._generate_unique_person_name(nickname)
|
||
initial_data = {
|
||
"person_id": person_id,
|
||
"platform": platform,
|
||
"user_id": str(user_id),
|
||
"nickname": nickname,
|
||
"person_name": unique_nickname,
|
||
"name_reason": "从群昵称获取",
|
||
"know_times": 0,
|
||
"know_since": int(datetime.datetime.now().timestamp()),
|
||
"last_know": int(datetime.datetime.now().timestamp()),
|
||
"impression": None,
|
||
"points": [],
|
||
"forgotten_points": [],
|
||
}
|
||
|
||
for key in JSON_SERIALIZED_FIELDS:
|
||
if key in initial_data:
|
||
if isinstance(initial_data[key], list | dict):
|
||
initial_data[key] = orjson.dumps(initial_data[key]).decode("utf-8")
|
||
elif initial_data[key] is None:
|
||
initial_data[key] = orjson.dumps([]).decode("utf-8")
|
||
|
||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
|
||
|
||
_record, was_created = await _db_get_or_create_async(person_id, filtered_initial_data)
|
||
|
||
if was_created:
|
||
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
|
||
logger.info(f"已为 {person_id} 创建新记录,初始数据: {filtered_initial_data}")
|
||
else:
|
||
logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。")
|
||
|
||
return person_id
|
||
|
||
@staticmethod
|
||
@cached(ttl=600, key_prefix="person_info_by_name", use_kwargs=False)
|
||
async def get_person_info_by_name(person_name: str) -> dict | None:
|
||
"""根据 person_name 查找用户并返回基本信息 (如果找到)"""
|
||
if not person_name:
|
||
logger.debug("get_person_info_by_name 获取失败:person_name 不能为空")
|
||
return None
|
||
|
||
# 直接查询数据库,移除内存缓存逻辑
|
||
# 使用CRUD进行查询 (person_name不是唯一字段,可能返回多条)
|
||
crud = CRUDBase(PersonInfo)
|
||
records = await crud.get_multi(person_name=person_name, limit=1)
|
||
if records:
|
||
record = records[0]
|
||
found_person_id = record.person_id
|
||
else:
|
||
logger.debug(f"数据库中未找到名为 '{person_name}' 的用户")
|
||
return None
|
||
|
||
if found_person_id:
|
||
required_fields = [
|
||
"person_id",
|
||
"platform",
|
||
"user_id",
|
||
"nickname",
|
||
"user_cardname",
|
||
"user_avatar",
|
||
"person_name",
|
||
"name_reason",
|
||
]
|
||
# 获取 SQLAlchemy 模型的所有字段名
|
||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||
valid_fields_to_get = [f for f in required_fields if f in model_fields or f in person_info_default]
|
||
|
||
person_data = await PersonInfoManager.get_values(found_person_id, valid_fields_to_get)
|
||
|
||
if person_data:
|
||
final_result = {key: person_data.get(key) for key in required_fields}
|
||
return final_result
|
||
else:
|
||
logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
|
||
return None
|
||
|
||
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
|
||
return None
|
||
|
||
|
||
person_info_manager = None
|
||
|
||
|
||
def get_person_info_manager():
|
||
global person_info_manager
|
||
if person_info_manager is None:
|
||
person_info_manager = PersonInfoManager()
|
||
return person_info_manager
|