re-style: 格式化代码
This commit is contained in:
@@ -2,7 +2,8 @@ import copy
|
||||
import datetime
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Union, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
@@ -86,7 +87,7 @@ class PersonInfoManager:
|
||||
logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
def get_person_id(platform: str, user_id: int | str) -> str:
|
||||
"""获取唯一id(同步)
|
||||
|
||||
说明: 原来该方法为异步并在内部尝试执行数据库检查/迁移,导致在许多调用处未 await 时返回 coroutine 对象。
|
||||
@@ -167,7 +168,7 @@ class PersonInfoManager:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_person_info(person_id: str, data: Optional[dict] = None):
|
||||
async def create_person_info(person_id: str, data: dict | None = None):
|
||||
"""创建一个项"""
|
||||
if not person_id:
|
||||
logger.debug("创建失败,person_id不存在")
|
||||
@@ -228,7 +229,7 @@ class PersonInfoManager:
|
||||
await _db_create_async(final_data)
|
||||
|
||||
@staticmethod
|
||||
async def _safe_create_person_info(person_id: str, data: Optional[dict] = None):
|
||||
async def _safe_create_person_info(person_id: str, data: dict | None = None):
|
||||
"""安全地创建用户信息,处理竞态条件"""
|
||||
if not person_id:
|
||||
logger.debug("创建失败,person_id不存在")
|
||||
@@ -296,7 +297,7 @@ class PersonInfoManager:
|
||||
|
||||
await _db_safe_create_async(final_data)
|
||||
|
||||
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
|
||||
async def update_one_field(self, person_id: str, field_name: str, value, data: dict | None = None):
|
||||
"""更新某一个字段,会补全"""
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
@@ -628,7 +629,7 @@ class PersonInfoManager:
|
||||
async def get_specific_value_list(
|
||||
field_name: str,
|
||||
way: Callable[[Any], bool],
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取满足条件的字段值字典
|
||||
"""
|
||||
@@ -649,18 +650,18 @@ class PersonInfoManager:
|
||||
found_results[record.person_id] = value
|
||||
except Exception as e_query:
|
||||
logger.error(
|
||||
f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True
|
||||
f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {e_query!s}", exc_info=True
|
||||
)
|
||||
return found_results
|
||||
|
||||
try:
|
||||
return await _db_get_specific_async(field_name)
|
||||
except Exception as e:
|
||||
logger.error(f"执行 get_specific_value_list 时出错: {str(e)}", exc_info=True)
|
||||
logger.error(f"执行 get_specific_value_list 时出错: {e!s}", exc_info=True)
|
||||
return {}
|
||||
|
||||
async def get_or_create_person(
|
||||
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None
|
||||
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
根据 platform 和 user_id 获取 person_id。
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
import time
|
||||
import traceback
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
from typing import List, Dict, Any
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
num_new_messages_since,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
|
||||
logger = get_logger("relationship_builder")
|
||||
|
||||
@@ -45,7 +46,7 @@ class RelationshipBuilder:
|
||||
self.chat_id = chat_id
|
||||
# 新的消息段缓存结构:
|
||||
# {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]}
|
||||
self.person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {}
|
||||
self.person_engaged_cache: dict[str, list[dict[str, Any]]] = {}
|
||||
|
||||
# 持久化存储文件路径
|
||||
self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl")
|
||||
@@ -401,7 +402,7 @@ class RelationshipBuilder:
|
||||
# 负责触发关系构建、整合消息段、更新用户印象
|
||||
# ================================
|
||||
|
||||
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]):
|
||||
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: list[dict[str, Any]]):
|
||||
"""基于消息段更新用户印象"""
|
||||
original_segment_count = len(segments)
|
||||
logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Dict, Optional, List, Any
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .relationship_builder import RelationshipBuilder
|
||||
|
||||
logger = get_logger("relationship_builder_manager")
|
||||
@@ -13,7 +14,7 @@ class RelationshipBuilderManager:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.builders: Dict[str, RelationshipBuilder] = {}
|
||||
self.builders: dict[str, RelationshipBuilder] = {}
|
||||
|
||||
def get_or_create_builder(self, chat_id: str) -> RelationshipBuilder:
|
||||
"""获取或创建关系构建器
|
||||
@@ -30,7 +31,7 @@ class RelationshipBuilderManager:
|
||||
|
||||
return self.builders[chat_id]
|
||||
|
||||
def get_builder(self, chat_id: str) -> Optional[RelationshipBuilder]:
|
||||
def get_builder(self, chat_id: str) -> RelationshipBuilder | None:
|
||||
"""获取关系构建器
|
||||
|
||||
Args:
|
||||
@@ -56,7 +57,7 @@ class RelationshipBuilderManager:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_all_chat_ids(self) -> List[str]:
|
||||
def get_all_chat_ids(self) -> list[str]:
|
||||
"""获取所有管理的聊天ID列表
|
||||
|
||||
Returns:
|
||||
@@ -64,7 +65,7 @@ class RelationshipBuilderManager:
|
||||
"""
|
||||
return list(self.builders.keys())
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""获取管理器状态
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
import time
|
||||
import traceback
|
||||
import orjson
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Dict, Any
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
|
||||
logger = get_logger("relationship_fetcher")
|
||||
|
||||
|
||||
@@ -64,10 +63,10 @@ class RelationshipFetcher:
|
||||
self.chat_id = chat_id
|
||||
|
||||
# 信息获取缓存:记录正在获取的信息请求
|
||||
self.info_fetching_cache: List[Dict[str, Any]] = []
|
||||
self.info_fetching_cache: list[dict[str, Any]] = []
|
||||
|
||||
# 信息结果缓存:存储已获取的信息结果,带TTL
|
||||
self.info_fetched_cache: Dict[str, Dict[str, Any]] = {}
|
||||
self.info_fetched_cache: dict[str, dict[str, Any]] = {}
|
||||
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}}
|
||||
|
||||
# LLM模型配置
|
||||
@@ -471,7 +470,7 @@ class RelationshipFetcherManager:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._fetchers: Dict[str, RelationshipFetcher] = {}
|
||||
self._fetchers: dict[str, RelationshipFetcher] = {}
|
||||
|
||||
def get_fetcher(self, chat_id: str) -> RelationshipFetcher:
|
||||
"""获取或创建指定 chat_id 的 RelationshipFetcher
|
||||
@@ -499,7 +498,7 @@ class RelationshipFetcherManager:
|
||||
"""清空所有 RelationshipFetcher"""
|
||||
self._fetchers.clear()
|
||||
|
||||
def get_active_chat_ids(self) -> List[str]:
|
||||
def get_active_chat_ids(self) -> list[str]:
|
||||
"""获取所有活跃的 chat_id 列表"""
|
||||
return list(self._fetchers.keys())
|
||||
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
from src.common.logger import get_logger
|
||||
from .person_info import PersonInfoManager, get_person_info_manager
|
||||
import time
|
||||
import random
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Any
|
||||
|
||||
import jieba
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from .person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
logger = get_logger("relation")
|
||||
|
||||
@@ -54,7 +57,7 @@ class RelationshipManager:
|
||||
# person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
|
||||
# )
|
||||
|
||||
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
|
||||
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: list[dict[str, Any]]):
|
||||
"""更新用户印象
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user