完成所有类型注解的修复

This commit is contained in:
UnCLAS-Prommer
2025-07-13 00:19:54 +08:00
parent d2ad6ea1d8
commit 7ef0bfb7c8
32 changed files with 358 additions and 434 deletions

View File

@@ -1,17 +1,18 @@
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import PersonInfo # 新增导入
import copy
import hashlib
from typing import Any, Callable, Dict, Union
import datetime
import asyncio
import json
from json_repair import repair_json
from typing import Any, Callable, Dict, Union, Optional
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import json # 新增导入
from json_repair import repair_json
"""
PersonInfoManager 类方法功能摘要:
@@ -42,7 +43,7 @@ person_info_default = {
"last_know": None,
# "user_cardname": None, # This field is not in Peewee model PersonInfo
# "user_avatar": None, # This field is not in Peewee model PersonInfo
"impression": None, # Corrected from persion_impression
"impression": None, # Corrected from person_impression
"short_impression": None,
"info_list": None,
"points": None,
@@ -106,27 +107,24 @@ class PersonInfoManager:
logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
return False
def get_person_id_by_person_name(self, person_name: str):
def get_person_id_by_person_name(self, person_name: str) -> str:
"""根据用户名获取用户ID"""
try:
record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
if record:
return record.person_id
else:
return ""
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
return ""
@staticmethod
async def create_person_info(person_id: str, data: dict = None):
async def create_person_info(person_id: str, data: Optional[dict] = None):
"""创建一个项"""
if not person_id:
logger.debug("创建失败personid不存在")
logger.debug("创建失败person_id不存在")
return
_person_info_default = copy.deepcopy(person_info_default)
model_fields = PersonInfo._meta.fields.keys()
model_fields = PersonInfo._meta.fields.keys() # type: ignore
final_data = {"person_id": person_id}
@@ -163,9 +161,9 @@ class PersonInfoManager:
await asyncio.to_thread(_db_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
"""更新某一个字段,会补全"""
if field_name not in PersonInfo._meta.fields:
if field_name not in PersonInfo._meta.fields: # type: ignore
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
return
@@ -228,15 +226,13 @@ class PersonInfoManager:
@staticmethod
async def has_one_field(person_id: str, field_name: str):
"""判断是否存在某一个字段"""
if field_name not in PersonInfo._meta.fields:
if field_name not in PersonInfo._meta.fields: # type: ignore
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
return False
def _db_has_field_sync(p_id: str, f_name: str):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
return True
return False
return bool(record)
try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
@@ -435,9 +431,7 @@ class PersonInfoManager:
except Exception as e:
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
# Fallback to default in case of any error during DB access
if field_name in person_info_default:
return default_value_for_field
return None
return default_value_for_field if field_name in person_info_default else None
@staticmethod
def get_value_sync(person_id: str, field_name: str):
@@ -446,8 +440,7 @@ class PersonInfoManager:
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if record:
if record := PersonInfo.get_or_none(PersonInfo.person_id == person_id):
val = getattr(record, field_name, None)
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
@@ -481,7 +474,7 @@ class PersonInfoManager:
record = await asyncio.to_thread(_db_get_record_sync, person_id)
for field_name in field_names:
if field_name not in PersonInfo._meta.fields:
if field_name not in PersonInfo._meta.fields: # type: ignore
if field_name in person_info_default:
result[field_name] = copy.deepcopy(person_info_default[field_name])
logger.debug(f"字段'{field_name}'不在Peewee模型中使用默认配置值。")
@@ -509,7 +502,7 @@ class PersonInfoManager:
"""
获取满足条件的字段值字典
"""
if field_name not in PersonInfo._meta.fields:
if field_name not in PersonInfo._meta.fields: # type: ignore
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义")
return {}
@@ -531,7 +524,7 @@ class PersonInfoManager:
return {}
async def get_or_create_person(
self, platform: str, user_id: int, nickname: str = None, user_cardname: str = None, user_avatar: str = None
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None
) -> str:
"""
根据 platform 和 user_id 获取 person_id。
@@ -561,7 +554,7 @@ class PersonInfoManager:
"points": [],
"forgotten_points": [],
}
model_fields = PersonInfo._meta.fields.keys()
model_fields = PersonInfo._meta.fields.keys() # type: ignore
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
await self.create_person_info(person_id, data=filtered_initial_data)
@@ -610,7 +603,9 @@ class PersonInfoManager:
"name_reason",
]
valid_fields_to_get = [
f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
f
for f in required_fields
if f in PersonInfo._meta.fields or f in person_info_default # type: ignore
]
person_data = await self.get_values(found_person_id, valid_fields_to_get)

View File

@@ -3,12 +3,12 @@ import traceback
import os
import pickle
import random
from typing import List, Dict
from typing import List, Dict, Any
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.relationship_manager import get_relationship_manager
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
@@ -45,7 +45,7 @@ class RelationshipBuilder:
self.chat_id = chat_id
# 新的消息段缓存结构:
# {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]}
self.person_engaged_cache: Dict[str, List[Dict[str, any]]] = {}
self.person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {}
# 持久化存储文件路径
self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl")
@@ -210,11 +210,7 @@ class RelationshipBuilder:
if person_id not in self.person_engaged_cache:
return 0
total_count = 0
for segment in self.person_engaged_cache[person_id]:
total_count += segment["message_count"]
return total_count
return sum(segment["message_count"] for segment in self.person_engaged_cache[person_id])
def _cleanup_old_segments(self) -> bool:
"""清理老旧的消息段"""
@@ -289,7 +285,7 @@ class RelationshipBuilder:
self.last_cleanup_time = current_time
# 保存缓存
if cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0:
if cleanup_stats["segments_removed"] > 0 or users_to_remove:
self._save_cache()
logger.info(
f"{self.log_prefix} 清理完成 - 影响用户: {cleanup_stats['users_cleaned']}, 移除消息段: {cleanup_stats['segments_removed']}, 移除用户: {len(users_to_remove)}"
@@ -313,6 +309,7 @@ class RelationshipBuilder:
return False
def get_cache_status(self) -> str:
# sourcery skip: merge-list-append, merge-list-appends-into-extend
"""获取缓存状态信息,用于调试和监控"""
if not self.person_engaged_cache:
return f"{self.log_prefix} 关系缓存为空"
@@ -357,13 +354,12 @@ class RelationshipBuilder:
self._cleanup_old_segments()
current_time = time.time()
latest_messages = get_raw_msg_by_timestamp_with_chat(
if latest_messages := get_raw_msg_by_timestamp_with_chat(
self.chat_id,
self.last_processed_message_time,
current_time,
limit=50, # 获取自上次处理后的消息
)
if latest_messages:
):
# 处理所有新的非bot消息
for latest_msg in latest_messages:
user_id = latest_msg.get("user_id")
@@ -414,7 +410,7 @@ class RelationshipBuilder:
# 负责触发关系构建、整合消息段、更新用户印象
# ================================
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, any]]):
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]):
"""基于消息段更新用户印象"""
original_segment_count = len(segments)
logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")

View File

@@ -1,4 +1,5 @@
from typing import Dict, Optional, List
from typing import Dict, Optional, List, Any
from src.common.logger import get_logger
from .relationship_builder import RelationshipBuilder
@@ -63,7 +64,7 @@ class RelationshipBuilderManager:
"""
return list(self.builders.keys())
def get_status(self) -> Dict[str, any]:
def get_status(self) -> Dict[str, Any]:
"""获取管理器状态
Returns:
@@ -94,9 +95,7 @@ class RelationshipBuilderManager:
bool: 是否成功清理
"""
builder = self.get_builder(chat_id)
if builder:
return builder.force_cleanup_user_segments(person_id)
return False
return builder.force_cleanup_user_segments(person_id) if builder else False
# 全局管理器实例

View File

@@ -1,16 +1,19 @@
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
import time
import traceback
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.person_info.person_info import get_person_info_manager
from typing import List, Dict
from json_repair import repair_json
from src.chat.message_receive.chat_stream import get_chat_manager
import json
import random
from typing import List, Dict, Any
from json_repair import repair_json
from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
logger = get_logger("relationship_fetcher")
@@ -62,11 +65,11 @@ class RelationshipFetcher:
self.chat_id = chat_id
# 信息获取缓存:记录正在获取的信息请求
self.info_fetching_cache: List[Dict[str, any]] = []
self.info_fetching_cache: List[Dict[str, Any]] = []
# 信息结果缓存存储已获取的信息结果带TTL
self.info_fetched_cache: Dict[str, Dict[str, any]] = {}
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknow": bool}}}
self.info_fetched_cache: Dict[str, Dict[str, Any]] = {}
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}}
# LLM模型配置
self.llm_model = LLMRequest(
@@ -184,7 +187,7 @@ class RelationshipFetcher:
nickname_str = ",".join(global_config.bot.alias_names)
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
person_info_manager = get_person_info_manager()
person_name = await person_info_manager.get_value(person_id, "person_name")
person_name: str = await person_info_manager.get_value(person_id, "person_name") # type: ignore
info_cache_block = self._build_info_cache_block()
@@ -208,8 +211,7 @@ class RelationshipFetcher:
logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息{content_json.get('none', '')}")
return None
info_type = content_json.get("info_type")
if info_type:
if info_type := content_json.get("info_type"):
# 记录信息获取请求
self.info_fetching_cache.append(
{
@@ -287,7 +289,7 @@ class RelationshipFetcher:
"ttl": 2,
"start_time": start_time,
"person_name": person_name,
"unknow": cached_info == "none",
"unknown": cached_info == "none",
}
logger.info(f"{self.log_prefix} 记得 {person_name}{info_type}: {cached_info}")
return
@@ -321,7 +323,7 @@ class RelationshipFetcher:
"ttl": 2,
"start_time": start_time,
"person_name": person_name,
"unknow": True,
"unknown": True,
}
logger.info(f"{self.log_prefix} 完全不认识 {person_name}")
await self._save_info_to_cache(person_id, info_type, "none")
@@ -353,15 +355,15 @@ class RelationshipFetcher:
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}
self.info_fetched_cache[person_id][info_type] = {
"info": "unknow" if is_unknown else info_content,
"info": "unknown" if is_unknown else info_content,
"ttl": 3,
"start_time": start_time,
"person_name": person_name,
"unknow": is_unknown,
"unknown": is_unknown,
}
# 保存到持久化缓存 (info_list)
await self._save_info_to_cache(person_id, info_type, info_content if not is_unknown else "none")
await self._save_info_to_cache(person_id, info_type, "none" if is_unknown else info_content)
if not is_unknown:
logger.info(f"{self.log_prefix} 思考得到,{person_name}{info_type}: {info_content}")
@@ -393,7 +395,7 @@ class RelationshipFetcher:
for info_type in self.info_fetched_cache[person_id]:
person_name = self.info_fetched_cache[person_id][info_type]["person_name"]
if not self.info_fetched_cache[person_id][info_type]["unknow"]:
if not self.info_fetched_cache[person_id][info_type]["unknown"]:
info_content = self.info_fetched_cache[person_id][info_type]["info"]
person_known_infos.append(f"[{info_type}]{info_content}")
else:
@@ -430,6 +432,7 @@ class RelationshipFetcher:
return persons_infos_str
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
# sourcery skip: use-next
"""将提取到的信息保存到 person_info 的 info_list 字段中
Args:

View File

@@ -1,5 +1,5 @@
from src.common.logger import get_logger
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from .person_info import PersonInfoManager, get_person_info_manager
import time
import random
from src.llm_models.utils_model import LLMRequest
@@ -12,7 +12,7 @@ from difflib import SequenceMatcher
import jieba
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict, Any
logger = get_logger("relation")
@@ -28,8 +28,7 @@ class RelationshipManager:
async def is_known_some_one(platform, user_id):
"""判断是否认识某人"""
person_info_manager = get_person_info_manager()
is_known = await person_info_manager.is_person_known(platform, user_id)
return is_known
return await person_info_manager.is_person_known(platform, user_id)
@staticmethod
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
@@ -110,7 +109,7 @@ class RelationshipManager:
return relation_prompt
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages=None):
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
"""更新用户印象
Args:
@@ -123,7 +122,7 @@ class RelationshipManager:
person_info_manager = get_person_info_manager()
person_name = await person_info_manager.get_value(person_id, "person_name")
nickname = await person_info_manager.get_value(person_id, "nickname")
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore
alias_str = ", ".join(global_config.bot.alias_names)
# personality_block =get_individuality().get_personality_prompt(x_person=2, level=2)
@@ -142,13 +141,13 @@ class RelationshipManager:
# 遍历消息,构建映射
for msg in user_messages:
await person_info_manager.get_or_create_person(
platform=msg.get("chat_info_platform"),
user_id=msg.get("user_id"),
nickname=msg.get("user_nickname"),
user_cardname=msg.get("user_cardname"),
platform=msg.get("chat_info_platform"), # type: ignore
user_id=msg.get("user_id"), # type: ignore
nickname=msg.get("user_nickname"), # type: ignore
user_cardname=msg.get("user_cardname"), # type: ignore
)
replace_user_id = msg.get("user_id")
replace_platform = msg.get("chat_info_platform")
replace_user_id: str = msg.get("user_id") # type: ignore
replace_platform: str = msg.get("chat_info_platform") # type: ignore
replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id)
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
@@ -354,8 +353,8 @@ class RelationshipManager:
person_name = await person_info_manager.get_value(person_id, "person_name")
nickname = await person_info_manager.get_value(person_id, "nickname")
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
attitude = await person_info_manager.get_value(person_id, "attitude") or 50
know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore
attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
# 根据熟悉度,调整印象和简短印象的最大长度
if know_times > 300:
@@ -414,16 +413,14 @@ class RelationshipManager:
if len(remaining_points) < 10:
# 如果还没达到30条直接保留
remaining_points.append(point)
elif random.random() < keep_probability:
# 保留这个点,随机移除一个已保留的点
idx_to_remove = random.randrange(len(remaining_points))
points_to_move.append(remaining_points[idx_to_remove])
remaining_points[idx_to_remove] = point
else:
# 随机决定是否保留
if random.random() < keep_probability:
# 保留这个点,随机移除一个已保留的点
idx_to_remove = random.randrange(len(remaining_points))
points_to_move.append(remaining_points[idx_to_remove])
remaining_points[idx_to_remove] = point
else:
# 不保留这个点
points_to_move.append(point)
# 不保留这个点
points_to_move.append(point)
# 更新points和forgotten_points
current_points = remaining_points
@@ -520,7 +517,7 @@ class RelationshipManager:
new_attitude = int(relation_value_json.get("attitude", 50))
# 获取当前的关系值
old_attitude = await person_info_manager.get_value(person_id, "attitude") or 50
old_attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
# 更新熟悉度
if new_attitude > 25: