三次修改

This commit is contained in:
tt-P607
2025-09-20 02:21:53 +08:00
committed by Windpicker-owo
parent 635311bc80
commit aba4f1a947
20 changed files with 923 additions and 479 deletions

View File

@@ -401,14 +401,15 @@ class PersonInfoManager:
# # 初始化时读取所有person_name
try:
pass
# 在这里获取会话
with get_db_session() as session:
for record in session.execute(
select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
).fetchall():
if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
# with get_db_session() as session:
# for record in session.execute(
# select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
# ).fetchall():
# if record.person_name:
# self.person_name_list[record.person_id] = record.person_name
# logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
except Exception as e:
logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}")
@@ -430,23 +431,25 @@ class PersonInfoManager:
"""判断是否认识某人"""
person_id = self.get_person_id(platform, user_id)
def _db_check_known_sync(p_id: str):
async def _db_check_known_async(p_id: str):
# 在需要时获取会话
with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None
async with get_db_session() as session:
return (
await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
).scalar() is not None
try:
return await asyncio.to_thread(_db_check_known_sync, person_id)
return await _db_check_known_async(person_id)
except Exception as e:
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
return False
def get_person_id_by_person_name(self, person_name: str) -> str:
async def get_person_id_by_person_name(self, person_name: str) -> str:
"""根据用户名获取用户ID"""
try:
# 在需要时获取会话
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar()
async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))).scalar()
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
@@ -500,19 +503,18 @@ class PersonInfoManager:
final_data[key] = orjson.dumps([]).decode("utf-8")
# If it's already a string, assume it's valid JSON or a non-JSON string field
def _db_create_sync(p_data: dict):
with get_db_session() as session:
async def _db_create_async(p_data: dict):
async with get_db_session() as session:
try:
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
await session.commit()
return True
except Exception as e:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
await _db_create_async(final_data)
async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None):
"""安全地创建用户信息,处理竞态条件"""
@@ -557,11 +559,11 @@ class PersonInfoManager:
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = orjson.dumps([]).decode("utf-8")
def _db_safe_create_sync(p_data: dict):
with get_db_session() as session:
async def _db_safe_create_async(p_data: dict):
async with get_db_session() as session:
try:
existing = session.execute(
select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])
existing = (
await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"]))
).scalar()
if existing:
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
@@ -570,18 +572,17 @@ class PersonInfoManager:
# 尝试创建
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
await session.commit()
return True
except Exception as e:
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
return True # 其他协程已创建,视为成功
return True
else:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
await asyncio.to_thread(_db_safe_create_sync, final_data)
await _db_safe_create_async(final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
"""更新某一个字段,会补全"""
@@ -598,37 +599,33 @@ class PersonInfoManager:
elif value is None: # Store None as "[]" for JSON list fields
processed_value = orjson.dumps([]).decode("utf-8")
def _db_update_sync(p_id: str, f_name: str, val_to_set):
async def _db_update_async(p_id: str, f_name: str, val_to_set):
start_time = time.time()
with get_db_session() as session:
async with get_db_session() as session:
try:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
query_time = time.time()
if record:
setattr(record, f_name, val_to_set)
save_time = time.time()
total_time = save_time - start_time
if total_time > 0.5: # 如果超过500ms就记录日志
if total_time > 0.5:
logger.warning(
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
)
session.commit()
return True, False # Found and updated, no creation needed
await session.commit()
return True, False
else:
total_time = time.time() - start_time
if total_time > 0.5:
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}")
return False, True # Not found, needs creation
return False, True
except Exception as e:
total_time = time.time() - start_time
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
raise
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value)
found, needs_creation = await _db_update_async(person_id, field_name, processed_value)
if needs_creation:
logger.info(f"{person_id} 不存在,将新建。")
@@ -666,13 +663,13 @@ class PersonInfoManager:
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义。")
return False
def _db_has_field_sync(p_id: str, f_name: str):
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
async def _db_has_field_async(p_id: str, f_name: str):
async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
return bool(record)
try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
return await _db_has_field_async(person_id, field_name)
except Exception as e:
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (SQLAlchemy): {e}")
return False
@@ -778,14 +775,14 @@ class PersonInfoManager:
logger.info(f"尝试给用户{user_nickname} {person_id} 取名,但是 {generated_nickname} 已存在,重试中...")
else:
def _db_check_name_exists_sync(name_to_check):
with get_db_session() as session:
async def _db_check_name_exists_async(name_to_check):
async with get_db_session() as session:
return (
session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar()
(await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))).scalar()
is not None
)
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
if await _db_check_name_exists_async(generated_nickname):
is_duplicate = True
current_name_set.add(generated_nickname)
@@ -824,91 +821,26 @@ class PersonInfoManager:
logger.debug("删除失败person_id 不能为空")
return
def _db_delete_sync(p_id: str):
async def _db_delete_async(p_id: str):
try:
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
if record:
session.delete(record)
session.commit()
return 1
await session.delete(record)
await session.commit()
return 1
return 0
except Exception as e:
logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}")
return 0
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
deleted_count = await _db_delete_async(person_id)
if deleted_count > 0:
logger.debug(f"删除成功person_id={person_id} (Peewee)")
logger.debug(f"删除成功person_id={person_id}")
else:
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行")
@staticmethod
async def get_value(person_id: str, field_name: str):
"""获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
def _db_get_value_sync(p_id: str, f_name: str):
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
val = getattr(record, f_name, None)
if f_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return orjson.loads(val)
except orjson.JSONDecodeError:
logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.")
return [] # Default for JSON fields on error
elif val is None: # Field exists in DB but is None
return [] # Default for JSON fields
# If val is already a list/dict (e.g. if somehow set without serialization)
return val # Should ideally not happen if update_one_field is always used
return val
return None # Record not found
try:
value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
if value_from_db is not None:
return value_from_db
if field_name in person_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None # Ultimate fallback
except Exception as e:
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
# Fallback to default in case of any error during DB access
return default_value_for_field if field_name in person_info_default else None
@staticmethod
def get_value_sync(person_id: str, field_name: str):
"""同步获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
with get_db_session() as session:
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar():
val = getattr(record, field_name, None)
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return orjson.loads(val)
except orjson.JSONDecodeError:
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
return []
elif val is None:
return []
return val
return val
if field_name in person_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None
@staticmethod
async def get_values(person_id: str, field_names: list) -> dict:
@@ -919,11 +851,11 @@ class PersonInfoManager:
result = {}
def _db_get_record_sync(p_id: str):
with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
async def _db_get_record_async(p_id: str):
async with get_db_session() as session:
return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
record = await asyncio.to_thread(_db_get_record_sync, person_id)
record = await _db_get_record_async(person_id)
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
@@ -960,14 +892,15 @@ class PersonInfoManager:
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
if field_name not in model_fields:
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模 modelo中定义")
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模中定义")
return {}
def _db_get_specific_sync(f_name: str):
async def _db_get_specific_async(f_name: str):
found_results = {}
try:
with get_db_session() as session:
for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall():
async with get_db_session() as session:
result = await session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name)))
for record in result.fetchall():
value = getattr(record, f_name)
if way(value):
found_results[record.person_id] = value
@@ -978,9 +911,9 @@ class PersonInfoManager:
return found_results
try:
return await asyncio.to_thread(_db_get_specific_sync, field_name)
return await _db_get_specific_async(field_name)
except Exception as e:
logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
logger.error(f"执行 get_specific_value_list 时出错: {str(e)}", exc_info=True)
return {}
async def get_or_create_person(
@@ -993,40 +926,38 @@ class PersonInfoManager:
"""
person_id = self.get_person_id(platform, user_id)
def _db_get_or_create_sync(p_id: str, init_data: dict):
async def _db_get_or_create_async(p_id: str, init_data: dict):
"""原子性的获取或创建操作"""
with get_db_session() as session:
async with get_db_session() as session:
# 首先尝试获取现有记录
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
if record:
return record, False # 记录存在,未创建
# 记录不存在,尝试创建
try:
new_person = PersonInfo(**init_data)
session.add(new_person)
session.commit()
return session.execute(
select(PersonInfo).where(PersonInfo.person_id == p_id)
).scalar(), True # 创建成功
except Exception as e:
# 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
return record, False # 其他协程已创建,返回现有记录
# 如果仍然失败,重新抛出异常
raise e
# 记录不存在,尝试创建
try:
new_person = PersonInfo(**init_data)
session.add(new_person)
await session.commit()
await session.refresh(new_person)
return new_person, True # 创建成功
except Exception as e:
# 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
if record:
return record, False # 其他协程已创建,返回现有记录
# 如果仍然失败,重新抛出异常
raise e
unique_nickname = await self._generate_unique_person_name(nickname)
initial_data = {
"person_id": person_id,
"platform": platform,
"user_id": str(user_id),
"nickname": nickname,
"person_name": unique_nickname, # 使用群昵称作为person_name
"person_name": unique_nickname,
"name_reason": "从群昵称获取",
"know_times": 0,
"know_since": int(datetime.datetime.now().timestamp()),
@@ -1036,7 +967,6 @@ class PersonInfoManager:
"forgotten_points": [],
}
# 序列化JSON字段
for key in JSON_SERIALIZED_FIELDS:
if key in initial_data:
if isinstance(initial_data[key], (list, dict)):
@@ -1044,15 +974,14 @@ class PersonInfoManager:
elif initial_data[key] is None:
initial_data[key] = orjson.dumps([]).decode("utf-8")
# 获取 SQLAlchemy 模odel的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
record, was_created = await _db_get_or_create_async(person_id, filtered_initial_data)
if was_created:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)")
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
logger.info(f"已为 {person_id} 创建新记录,初始数据: {filtered_initial_data}")
else:
logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。")
@@ -1072,11 +1001,13 @@ class PersonInfoManager:
if not found_person_id:
def _db_find_by_name_sync(p_name_to_find: str):
with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar()
async def _db_find_by_name_async(p_name_to_find: str):
async with get_db_session() as session:
return (
await session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find))
).scalar()
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
record = await _db_find_by_name_async(person_name)
if record:
found_person_id = record.person_id
if (

View File

@@ -0,0 +1,457 @@
import time
import traceback
import orjson
import random
from typing import List, Dict, Any
from json_repair import repair_json
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
logger = get_logger("relationship_fetcher")
def init_real_time_info_prompts():
"""初始化实时信息提取相关的提示词"""
relationship_prompt = """
<聊天记录>
{chat_observe_info}
</聊天记录>
{name_block}
现在,你想要回复{person_name}的消息,消息内容是:{target_message}。请根据聊天记录和你要回复的消息,从你对{person_name}的了解中提取有关的信息:
1.你需要提供你想要提取的信息具体是哪方面的信息,例如:年龄,性别,你们之间的交流方式,最近发生的事等等。
2.请注意,请不要重复调取相同的信息,已经调取的信息如下:
{info_cache_block}
3.如果当前聊天记录中没有需要查询的信息,或者现有信息已经足够回复,请返回{{"none": "不需要查询"}}
请以json格式输出例如
{{
"info_type": "信息类型",
}}
请严格按照json输出格式不要输出多余内容
"""
Prompt(relationship_prompt, "real_time_info_identify_prompt")
fetch_info_prompt = """
{name_block}
以下是你在之前与{person_name}的交流中,产生的对{person_name}的了解:
{person_impression_block}
{points_text_block}
请从中提取用户"{person_name}"的有关"{info_type}"信息
请以json格式输出例如
{{
{info_json_str}
}}
请严格按照json输出格式不要输出多余内容
"""
Prompt(fetch_info_prompt, "real_time_fetch_person_info_prompt")
class RelationshipFetcher:
def __init__(self, chat_id):
self.chat_id = chat_id
# 信息获取缓存:记录正在获取的信息请求
self.info_fetching_cache: List[Dict[str, Any]] = []
# 信息结果缓存存储已获取的信息结果带TTL
self.info_fetched_cache: Dict[str, Dict[str, Any]] = {}
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}}
# LLM模型配置
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="relation.fetcher"
)
# 小模型用于即时信息提取
self.instant_llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="relation.fetch"
)
name = get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{name}] 实时信息"
def _cleanup_expired_cache(self):
"""清理过期的信息缓存"""
for person_id in list(self.info_fetched_cache.keys()):
for info_type in list(self.info_fetched_cache[person_id].keys()):
self.info_fetched_cache[person_id][info_type]["ttl"] -= 1
if self.info_fetched_cache[person_id][info_type]["ttl"] <= 0:
del self.info_fetched_cache[person_id][info_type]
if not self.info_fetched_cache[person_id]:
del self.info_fetched_cache[person_id]
async def build_relation_info(self, person_id, points_num=3):
# 清理过期的信息缓存
self._cleanup_expired_cache()
person_info_manager = get_person_info_manager()
person_info = await person_info_manager.get_values(
person_id, ["person_name", "short_impression", "nickname", "platform", "points"]
)
person_name = person_info.get("person_name")
short_impression = person_info.get("short_impression")
nickname_str = person_info.get("nickname")
platform = person_info.get("platform")
if person_name == nickname_str and not short_impression:
return ""
current_points = person_info.get("points") or []
# 按时间排序forgotten_points
current_points.sort(key=lambda x: x[2])
# 按权重加权随机抽取最多3个不重复的pointspoint[1]的值在1-10之间权重越高被抽到概率越大
if len(current_points) > points_num:
# point[1] 取值范围1-10直接作为权重
weights = [max(1, min(10, int(point[1]))) for point in current_points]
# 使用加权采样不放回,保证不重复
indices = list(range(len(current_points)))
points = []
for _ in range(points_num):
if not indices:
break
sub_weights = [weights[i] for i in indices]
chosen_idx = random.choices(indices, weights=sub_weights, k=1)[0]
points.append(current_points[chosen_idx])
indices.remove(chosen_idx)
else:
points = current_points
# 构建points文本
points_text = "\n".join([f"{point[2]}{point[0]}" for point in points])
nickname_str = ""
if person_name != nickname_str:
nickname_str = f"(ta在{platform}上的昵称是{nickname_str})"
relation_info = ""
if short_impression and relation_info:
if points_text:
relation_info = f"你对{person_name}的印象是{nickname_str}{short_impression}。具体来说:{relation_info}。你还记得ta最近做的事{points_text}"
else:
relation_info = (
f"你对{person_name}的印象是{nickname_str}{short_impression}。具体来说:{relation_info}"
)
elif short_impression:
if points_text:
relation_info = (
f"你对{person_name}的印象是{nickname_str}{short_impression}。你还记得ta最近做的事{points_text}"
)
else:
relation_info = f"你对{person_name}的印象是{nickname_str}{short_impression}"
elif relation_info:
if points_text:
relation_info = (
f"你对{person_name}的了解{nickname_str}{relation_info}。你还记得ta最近做的事{points_text}"
)
else:
relation_info = f"你对{person_name}的了解{nickname_str}{relation_info}"
elif points_text:
relation_info = f"你记得{person_name}{nickname_str}最近做的事:{points_text}"
else:
relation_info = ""
return relation_info
async def _build_fetch_query(self, person_id, target_message, chat_history):
nickname_str = ",".join(global_config.bot.alias_names)
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
person_info_manager = get_person_info_manager()
person_info = await person_info_manager.get_values(person_id, ["person_name"])
person_name: str = person_info.get("person_name") # type: ignore
info_cache_block = self._build_info_cache_block()
prompt = (await global_prompt_manager.get_prompt_async("real_time_info_identify_prompt")).format(
chat_observe_info=chat_history,
name_block=name_block,
info_cache_block=info_cache_block,
person_name=person_name,
target_message=target_message,
)
try:
logger.debug(f"{self.log_prefix} 信息识别prompt: \n{prompt}\n")
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
if content:
content_json = orjson.loads(repair_json(content))
# 检查是否返回了不需要查询的标志
if "none" in content_json:
logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息{content_json.get('none', '')}")
return None
if info_type := content_json.get("info_type"):
# 记录信息获取请求
self.info_fetching_cache.append(
{
"person_id": get_person_info_manager().get_person_id_by_person_name(person_name),
"person_name": person_name,
"info_type": info_type,
"start_time": time.time(),
"forget": False,
}
)
# 限制缓存大小
if len(self.info_fetching_cache) > 10:
self.info_fetching_cache.pop(0)
logger.info(f"{self.log_prefix} 识别到需要调取用户 {person_name} 的[{info_type}]信息")
return info_type
else:
logger.warning(f"{self.log_prefix} LLM未返回有效的info_type。响应: {content}")
except Exception as e:
logger.error(f"{self.log_prefix} 执行信息识别LLM请求时出错: {e}")
logger.error(traceback.format_exc())
return None
def _build_info_cache_block(self) -> str:
"""构建已获取信息的缓存块"""
info_cache_block = ""
if self.info_fetching_cache:
# 对于每个(person_id, info_type)组合,只保留最新的记录
latest_records = {}
for info_fetching in self.info_fetching_cache:
key = (info_fetching["person_id"], info_fetching["info_type"])
if key not in latest_records or info_fetching["start_time"] > latest_records[key]["start_time"]:
latest_records[key] = info_fetching
# 按时间排序并生成显示文本
sorted_records = sorted(latest_records.values(), key=lambda x: x["start_time"])
for info_fetching in sorted_records:
info_cache_block += (
f"你已经调取了[{info_fetching['person_name']}]的[{info_fetching['info_type']}]信息\n"
)
return info_cache_block
async def _extract_single_info(self, person_id: str, info_type: str, person_name: str):
"""提取单个信息类型
Args:
person_id: 用户ID
info_type: 信息类型
person_name: 用户名
"""
start_time = time.time()
person_info_manager = get_person_info_manager()
# 首先检查 info_list 缓存
person_info = await person_info_manager.get_values(person_id, ["info_list"])
info_list = person_info.get("info_list") or []
cached_info = None
# 查找对应的 info_type
for info_item in info_list:
if info_item.get("info_type") == info_type:
cached_info = info_item.get("info_content")
logger.debug(f"{self.log_prefix} 在info_list中找到 {person_name}{info_type} 信息: {cached_info}")
break
# 如果缓存中有信息,直接使用
if cached_info:
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}
self.info_fetched_cache[person_id][info_type] = {
"info": cached_info,
"ttl": 2,
"start_time": start_time,
"person_name": person_name,
"unknown": cached_info == "none",
}
logger.info(f"{self.log_prefix} 记得 {person_name}{info_type}: {cached_info}")
return
# 如果缓存中没有,尝试从用户档案中提取
try:
person_info = await person_info_manager.get_values(person_id, ["impression", "points"])
person_impression = person_info.get("impression")
points = person_info.get("points")
# 构建印象信息块
if person_impression:
person_impression_block = (
f"<对{person_name}的总体了解>\n{person_impression}\n</对{person_name}的总体了解>"
)
else:
person_impression_block = ""
# 构建要点信息块
if points:
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
points_text_block = f"<对{person_name}的近期了解>\n{points_text}\n</对{person_name}的近期了解>"
else:
points_text_block = ""
# 如果完全没有用户信息
if not points_text_block and not person_impression_block:
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}
self.info_fetched_cache[person_id][info_type] = {
"info": "none",
"ttl": 2,
"start_time": start_time,
"person_name": person_name,
"unknown": True,
}
logger.info(f"{self.log_prefix} 完全不认识 {person_name}")
await self._save_info_to_cache(person_id, info_type, "none")
return
# 使用LLM提取信息
nickname_str = ",".join(global_config.bot.alias_names)
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
prompt = (await global_prompt_manager.get_prompt_async("real_time_fetch_person_info_prompt")).format(
name_block=name_block,
info_type=info_type,
person_impression_block=person_impression_block,
person_name=person_name,
info_json_str=f'"{info_type}": "有关{info_type}的信息内容"',
points_text_block=points_text_block,
)
# 使用小模型进行即时提取
content, _ = await self.instant_llm_model.generate_response_async(prompt=prompt)
if content:
content_json = orjson.loads(repair_json(content))
if info_type in content_json:
info_content = content_json[info_type]
is_unknown = info_content == "none" or not info_content
# 保存到运行时缓存
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}
self.info_fetched_cache[person_id][info_type] = {
"info": "unknown" if is_unknown else info_content,
"ttl": 3,
"start_time": start_time,
"person_name": person_name,
"unknown": is_unknown,
}
# 保存到持久化缓存 (info_list)
await self._save_info_to_cache(person_id, info_type, "none" if is_unknown else info_content)
if not is_unknown:
logger.info(f"{self.log_prefix} 思考得到,{person_name}{info_type}: {info_content}")
else:
logger.info(f"{self.log_prefix} 思考了也不知道{person_name}{info_type} 信息")
else:
logger.warning(f"{self.log_prefix} 小模型返回空结果,获取 {person_name}{info_type} 信息失败。")
except Exception as e:
logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}")
logger.error(traceback.format_exc())
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
# sourcery skip: use-next
"""将提取到的信息保存到 person_info 的 info_list 字段中
Args:
person_id: 用户ID
info_type: 信息类型
info_content: 信息内容
"""
try:
person_info_manager = get_person_info_manager()
# 获取现有的 info_list
person_info = await person_info_manager.get_values(person_id, ["info_list"])
info_list = person_info.get("info_list") or []
# 查找是否已存在相同 info_type 的记录
found_index = -1
for i, info_item in enumerate(info_list):
if isinstance(info_item, dict) and info_item.get("info_type") == info_type:
found_index = i
break
# 创建新的信息记录
new_info_item = {
"info_type": info_type,
"info_content": info_content,
}
if found_index >= 0:
# 更新现有记录
info_list[found_index] = new_info_item
logger.info(f"{self.log_prefix} [缓存更新] 更新 {person_id}{info_type} 信息缓存")
else:
# 添加新记录
info_list.append(new_info_item)
logger.info(f"{self.log_prefix} [缓存保存] 新增 {person_id}{info_type} 信息缓存")
# 保存更新后的 info_list
await person_info_manager.update_one_field(person_id, "info_list", info_list)
except Exception as e:
logger.error(f"{self.log_prefix} [缓存保存] 保存信息到缓存失败: {e}")
logger.error(traceback.format_exc())
class RelationshipFetcherManager:
"""关系提取器管理器
管理不同 chat_id 的 RelationshipFetcher 实例
"""
def __init__(self):
self._fetchers: Dict[str, RelationshipFetcher] = {}
def get_fetcher(self, chat_id: str) -> RelationshipFetcher:
"""获取或创建指定 chat_id 的 RelationshipFetcher
Args:
chat_id: 聊天ID
Returns:
RelationshipFetcher: 关系提取器实例
"""
if chat_id not in self._fetchers:
self._fetchers[chat_id] = RelationshipFetcher(chat_id)
return self._fetchers[chat_id]
def remove_fetcher(self, chat_id: str):
"""移除指定 chat_id 的 RelationshipFetcher
Args:
chat_id: 聊天ID
"""
if chat_id in self._fetchers:
del self._fetchers[chat_id]
def clear_all(self):
"""清空所有 RelationshipFetcher"""
self._fetchers.clear()
def get_active_chat_ids(self) -> List[str]:
"""获取所有活跃的 chat_id 列表"""
return list(self._fetchers.keys())
# 全局管理器实例
relationship_fetcher_manager = RelationshipFetcherManager()
init_real_time_info_prompts()