re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -2,7 +2,8 @@ import copy
import datetime
import hashlib
import time
from typing import Any, Callable, Dict, Union, Optional
from collections.abc import Callable
from typing import Any
import orjson
from json_repair import repair_json
@@ -86,7 +87,7 @@ class PersonInfoManager:
logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}")
@staticmethod
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
def get_person_id(platform: str, user_id: int | str) -> str:
"""获取唯一id同步
说明: 原来该方法为异步并在内部尝试执行数据库检查/迁移,导致在许多调用处未 await 时返回 coroutine 对象。
@@ -167,7 +168,7 @@ class PersonInfoManager:
)
@staticmethod
async def create_person_info(person_id: str, data: Optional[dict] = None):
async def create_person_info(person_id: str, data: dict | None = None):
"""创建一个项"""
if not person_id:
logger.debug("创建失败person_id不存在")
@@ -228,7 +229,7 @@ class PersonInfoManager:
await _db_create_async(final_data)
@staticmethod
async def _safe_create_person_info(person_id: str, data: Optional[dict] = None):
async def _safe_create_person_info(person_id: str, data: dict | None = None):
"""安全地创建用户信息,处理竞态条件"""
if not person_id:
logger.debug("创建失败person_id不存在")
@@ -296,7 +297,7 @@ class PersonInfoManager:
await _db_safe_create_async(final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
async def update_one_field(self, person_id: str, field_name: str, value, data: dict | None = None):
"""更新某一个字段,会补全"""
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
@@ -628,7 +629,7 @@ class PersonInfoManager:
async def get_specific_value_list(
field_name: str,
way: Callable[[Any], bool],
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
获取满足条件的字段值字典
"""
@@ -649,18 +650,18 @@ class PersonInfoManager:
found_results[record.person_id] = value
except Exception as e_query:
logger.error(
f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True
f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {e_query!s}", exc_info=True
)
return found_results
try:
return await _db_get_specific_async(field_name)
except Exception as e:
logger.error(f"执行 get_specific_value_list 时出错: {str(e)}", exc_info=True)
logger.error(f"执行 get_specific_value_list 时出错: {e!s}", exc_info=True)
return {}
async def get_or_create_person(
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: str | None = None
) -> str:
"""
根据 platform 和 user_id 获取 person_id。

View File

@@ -1,20 +1,21 @@
import time
import traceback
import os
import pickle
import random
from typing import List, Dict, Any
from src.config.config import global_config
from src.common.logger import get_logger
from src.person_info.relationship_manager import get_relationship_manager
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
import time
import traceback
from typing import Any
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_before_timestamp_with_chat,
num_new_messages_since,
)
from src.common.logger import get_logger
from src.config.config import global_config
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.person_info.relationship_manager import get_relationship_manager
logger = get_logger("relationship_builder")
@@ -45,7 +46,7 @@ class RelationshipBuilder:
self.chat_id = chat_id
# 新的消息段缓存结构:
# {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]}
self.person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {}
self.person_engaged_cache: dict[str, list[dict[str, Any]]] = {}
# 持久化存储文件路径
self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl")
@@ -401,7 +402,7 @@ class RelationshipBuilder:
# 负责触发关系构建、整合消息段、更新用户印象
# ================================
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]):
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: list[dict[str, Any]]):
"""基于消息段更新用户印象"""
original_segment_count = len(segments)
logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")

View File

@@ -1,6 +1,7 @@
from typing import Dict, Optional, List, Any
from typing import Any
from src.common.logger import get_logger
from .relationship_builder import RelationshipBuilder
logger = get_logger("relationship_builder_manager")
@@ -13,7 +14,7 @@ class RelationshipBuilderManager:
"""
def __init__(self):
self.builders: Dict[str, RelationshipBuilder] = {}
self.builders: dict[str, RelationshipBuilder] = {}
def get_or_create_builder(self, chat_id: str) -> RelationshipBuilder:
"""获取或创建关系构建器
@@ -30,7 +31,7 @@ class RelationshipBuilderManager:
return self.builders[chat_id]
def get_builder(self, chat_id: str) -> Optional[RelationshipBuilder]:
def get_builder(self, chat_id: str) -> RelationshipBuilder | None:
"""获取关系构建器
Args:
@@ -56,7 +57,7 @@ class RelationshipBuilderManager:
return True
return False
def get_all_chat_ids(self) -> List[str]:
def get_all_chat_ids(self) -> list[str]:
"""获取所有管理的聊天ID列表
Returns:
@@ -64,7 +65,7 @@ class RelationshipBuilderManager:
"""
return list(self.builders.keys())
def get_status(self) -> Dict[str, Any]:
def get_status(self) -> dict[str, Any]:
"""获取管理器状态
Returns:

View File

@@ -1,18 +1,17 @@
import time
import traceback
import orjson
from typing import Any
from typing import List, Dict, Any
import orjson
from json_repair import repair_json
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
logger = get_logger("relationship_fetcher")
@@ -64,10 +63,10 @@ class RelationshipFetcher:
self.chat_id = chat_id
# 信息获取缓存:记录正在获取的信息请求
self.info_fetching_cache: List[Dict[str, Any]] = []
self.info_fetching_cache: list[dict[str, Any]] = []
# 信息结果缓存存储已获取的信息结果带TTL
self.info_fetched_cache: Dict[str, Dict[str, Any]] = {}
self.info_fetched_cache: dict[str, dict[str, Any]] = {}
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}}
# LLM模型配置
@@ -471,7 +470,7 @@ class RelationshipFetcherManager:
"""
def __init__(self):
self._fetchers: Dict[str, RelationshipFetcher] = {}
self._fetchers: dict[str, RelationshipFetcher] = {}
def get_fetcher(self, chat_id: str) -> RelationshipFetcher:
"""获取或创建指定 chat_id 的 RelationshipFetcher
@@ -499,7 +498,7 @@ class RelationshipFetcherManager:
"""清空所有 RelationshipFetcher"""
self._fetchers.clear()
def get_active_chat_ids(self) -> List[str]:
def get_active_chat_ids(self) -> list[str]:
"""获取所有活跃的 chat_id 列表"""
return list(self._fetchers.keys())

View File

@@ -1,18 +1,21 @@
from src.common.logger import get_logger
from .person_info import PersonInfoManager, get_person_info_manager
import time
import random
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.chat.utils.chat_message_builder import build_readable_messages
import orjson
from json_repair import repair_json
import time
from datetime import datetime
from difflib import SequenceMatcher
from typing import Any
import jieba
import orjson
from json_repair import repair_json
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict, Any
from src.chat.utils.chat_message_builder import build_readable_messages
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from .person_info import PersonInfoManager, get_person_info_manager
logger = get_logger("relation")
@@ -54,7 +57,7 @@ class RelationshipManager:
# person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
# )
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: list[dict[str, Any]]):
"""更新用户印象
Args: