From 1aa2734d62d2eebb25fefbe96e56d41a8dcfc216 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 17 Jul 2025 00:10:41 +0800 Subject: [PATCH] typing fix --- bot.py | 27 ++--- src/chat/express/expression_learner.py | 71 ++++++----- src/chat/memory_system/instant_memory.py | 92 +++++++------- src/chat/message_receive/chat_stream.py | 2 +- src/chat/message_receive/message.py | 21 ++-- src/chat/planner_actions/action_manager.py | 2 +- src/chat/utils/chat_message_builder.py | 6 +- src/chat/utils/statistic.py | 37 +++--- src/chat/willing/willing_manager.py | 5 +- src/common/database/database_model.py | 39 +++--- src/config/auto_update.py | 16 +-- src/config/config.py | 37 +++--- src/individuality/not_using/offline_llm.py | 4 +- src/individuality/not_using/per_bf_gen.py | 7 +- src/main.py | 6 +- src/mood/mood_manager.py | 3 + src/person_info/relationship_builder.py | 2 +- src/plugin_system/__init__.py | 6 +- src/plugin_system/apis/send_api.py | 39 +++--- src/plugin_system/base/base_action.py | 24 ++-- src/plugin_system/base/base_command.py | 2 +- src/plugin_system/core/component_registry.py | 112 +++++++++--------- src/plugin_system/core/dependency_manager.py | 4 +- src/plugin_system/core/plugin_manager.py | 44 +++---- .../tool_can_use/compare_numbers_tool.py | 6 +- src/tools/tool_can_use/rename_person_tool.py | 8 +- 26 files changed, 329 insertions(+), 293 deletions(-) diff --git a/bot.py b/bot.py index 5548c1725..72ea65d29 100644 --- a/bot.py +++ b/bot.py @@ -146,7 +146,7 @@ def _calculate_file_hash(file_path: Path, file_type: str) -> str: if not file_path.exists(): logger.error(f"{file_type} 文件不存在") raise FileNotFoundError(f"{file_type} 文件不存在") - + with open(file_path, "r", encoding="utf-8") as f: content = f.read() return hashlib.md5(content.encode("utf-8")).hexdigest() @@ -154,21 +154,21 @@ def _calculate_file_hash(file_path: Path, file_type: str) -> str: def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]: """检查协议确认状态 - + Returns: tuple[bool, bool]: (已确认, 未更新) """ # 检查环境变量确认 if file_hash == os.getenv(env_var): return True, False - + # 检查确认文件 if confirm_file.exists(): with open(confirm_file, "r", encoding="utf-8") as f: confirmed_content = f.read() if file_hash == confirmed_content: return True, False - + return False, True @@ -178,7 +178,7 @@ def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None: confirm_logger.critical( f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_hash}"和"PRIVACY_AGREE={privacy_hash}"继续运行' ) - + while True: user_input = input().strip().lower() if user_input in ["同意", "confirmed"]: @@ -186,13 +186,12 @@ def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None: confirm_logger.critical('请输入"同意"或"confirmed"以继续运行') -def _save_confirmations(eula_updated: bool, privacy_updated: bool, - eula_hash: str, privacy_hash: str) -> None: +def _save_confirmations(eula_updated: bool, privacy_updated: bool, eula_hash: str, privacy_hash: str) -> None: """保存用户确认结果""" if eula_updated: logger.info(f"更新EULA确认文件{eula_hash}") Path("eula.confirmed").write_text(eula_hash, encoding="utf-8") - + if privacy_updated: logger.info(f"更新隐私条款确认文件{privacy_hash}") Path("privacy.confirmed").write_text(privacy_hash, encoding="utf-8") @@ -203,19 +202,17 @@ def check_eula(): # 计算文件哈希值 eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md") privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md") - + # 检查确认状态 - eula_confirmed, eula_updated = _check_agreement_status( - eula_hash, Path("eula.confirmed"), "EULA_AGREE" - ) + eula_confirmed, eula_updated = _check_agreement_status(eula_hash, Path("eula.confirmed"), "EULA_AGREE") privacy_confirmed, privacy_updated = _check_agreement_status( privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE" ) - + # 早期返回:如果都已确认且未更新 if eula_confirmed and privacy_confirmed: return - + # 如果有更新,需要重新确认 if eula_updated or privacy_updated: _prompt_user_confirmation(eula_hash, privacy_hash) @@ -225,7 +222,7 @@ def check_eula(): def raw_main(): # 利用 TZ 环境变量设定程序工作的时区 if platform.system().lower() != "windows": - time.tzset() + time.tzset() # type: ignore check_eula() logger.info("检查EULA和隐私条款完成") diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 4139c65a5..e02ff7311 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -107,11 +107,12 @@ class ExpressionLearner: last_active_time = expr.get("last_active_time", time.time()) # 查重:同chat_id+type+situation+style from src.common.database.database_model import Expression + query = Expression.select().where( - (Expression.chat_id == chat_id) & - (Expression.type == type_str) & - (Expression.situation == situation) & - (Expression.style == style_val) + (Expression.chat_id == chat_id) + & (Expression.type == type_str) + & (Expression.situation == situation) + & (Expression.style == style_val) ) if query.exists(): expr_obj = query.get() @@ -125,7 +126,7 @@ class ExpressionLearner: count=count, last_active_time=last_active_time, chat_id=chat_id, - type=type_str + type=type_str, ) logger.info(f"已迁移 {expr_file} 到数据库") except Exception as e: @@ -149,24 +150,28 @@ class ExpressionLearner: # 直接从数据库查询 style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style")) for expr in style_query: - learnt_style_expressions.append({ - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": chat_id, - "type": "style" - }) + learnt_style_expressions.append( + { + "situation": expr.situation, + "style": expr.style, + "count": expr.count, + "last_active_time": expr.last_active_time, + "source_id": chat_id, + "type": "style", + } + ) grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar")) for expr in grammar_query: - learnt_grammar_expressions.append({ - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": chat_id, - "type": "grammar" - }) + learnt_grammar_expressions.append( + { + "situation": expr.situation, + "style": expr.style, + "count": expr.count, + "last_active_time": expr.last_active_time, + "source_id": chat_id, + "type": "grammar", + } + ) return learnt_style_expressions, learnt_grammar_expressions def is_similar(self, s1: str, s2: str) -> bool: @@ -213,14 +218,16 @@ class ExpressionLearner: logger.error(f"全局衰减{type}表达方式失败: {e}") continue + learnt_style: Optional[List[Tuple[str, str, str]]] = [] + learnt_grammar: Optional[List[Tuple[str, str, str]]] = [] # 学习新的表达方式(这里会进行局部衰减) for _ in range(3): - learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25) + learnt_style = await self.learn_and_store(type="style", num=25) if not learnt_style: return [], [] for _ in range(1): - learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10) + learnt_grammar = await self.learn_and_store(type="grammar", num=10) if not learnt_grammar: return [], [] @@ -321,10 +328,10 @@ class ExpressionLearner: for new_expr in expr_list: # 查找是否已存在相似表达方式 query = Expression.select().where( - (Expression.chat_id == chat_id) & - (Expression.type == type) & - (Expression.situation == new_expr["situation"]) & - (Expression.style == new_expr["style"]) + (Expression.chat_id == chat_id) + & (Expression.type == type) + & (Expression.situation == new_expr["situation"]) + & (Expression.style == new_expr["style"]) ) if query.exists(): expr_obj = query.get() @@ -342,13 +349,17 @@ class ExpressionLearner: count=1, last_active_time=current_time, chat_id=chat_id, - type=type + type=type, ) # 限制最大数量 - exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc())) + exprs = list( + Expression.select() + .where((Expression.chat_id == chat_id) & (Expression.type == type)) + .order_by(Expression.count.asc()) + ) if len(exprs) > MAX_EXPRESSION_COUNT: # 删除count最小的多余表达方式 - for expr in exprs[:len(exprs) - MAX_EXPRESSION_COUNT]: + for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: expr.delete_instance() return learnt_expressions diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index 5b38bbb0b..f7e54f8e9 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -9,51 +9,49 @@ from src.common.logger import get_logger import traceback from src.config.config import global_config -from src.common.database.database_model import Memory # Peewee Models导入 +from src.common.database.database_model import Memory # Peewee Models导入 logger = get_logger(__name__) + class MemoryItem: - def __init__(self,memory_id:str,chat_id:str,memory_text:str,keywords:list[str]): + def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]): self.memory_id = memory_id self.chat_id = chat_id - self.memory_text:str = memory_text - self.keywords:list[str] = keywords - self.create_time:float = time.time() - self.last_view_time:float = time.time() - + self.memory_text: str = memory_text + self.keywords: list[str] = keywords + self.create_time: float = time.time() + self.last_view_time: float = time.time() + + class MemoryManager: def __init__(self): # self.memory_items:list[MemoryItem] = [] pass - - - class InstantMemory: - def __init__(self,chat_id): - self.chat_id = chat_id + def __init__(self, chat_id): + self.chat_id = chat_id self.last_view_time = time.time() self.summary_model = LLMRequest( model=global_config.model.memory, temperature=0.5, request_type="memory.summary", ) - - async def if_need_build(self,text): + + async def if_need_build(self, text): prompt = f""" 请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0 {text} 请只输出1或0就好 """ - + try: - response,_ = await self.summary_model.generate_response_async(prompt) + response, _ = await self.summary_model.generate_response_async(prompt) print(prompt) print(response) - - + if "1" in response: return True else: @@ -61,8 +59,8 @@ class InstantMemory: except Exception as e: logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}") return False - - async def build_memory(self,text): + + async def build_memory(self, text): prompt = f""" 以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出 {text} @@ -73,7 +71,7 @@ class InstantMemory: }} """ try: - response,_ = await self.summary_model.generate_response_async(prompt) + response, _ = await self.summary_model.generate_response_async(prompt) print(prompt) print(response) if not response: @@ -81,53 +79,53 @@ class InstantMemory: try: repaired = repair_json(response) result = json.loads(repaired) - memory_text = result.get('memory_text', '') - keywords = result.get('keywords', '') + memory_text = result.get("memory_text", "") + keywords = result.get("keywords", "") if isinstance(keywords, str): - keywords_list = [k.strip() for k in keywords.split('/') if k.strip()] + keywords_list = [k.strip() for k in keywords.split("/") if k.strip()] elif isinstance(keywords, list): keywords_list = keywords else: keywords_list = [] - return {'memory_text': memory_text, 'keywords': keywords_list} + return {"memory_text": memory_text, "keywords": keywords_list} except Exception as parse_e: logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}") return None except Exception as e: logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}") return None - - async def create_and_store_memory(self,text): + async def create_and_store_memory(self, text): if_need = await self.if_need_build(text) if if_need: logger.info(f"需要记忆:{text}") - memory = await self.build_memory(text) - if memory and memory.get('memory_text'): + memory = await self.build_memory(text) + if memory and memory.get("memory_text"): memory_id = f"{self.chat_id}_{time.time()}" memory_item = MemoryItem( memory_id=memory_id, chat_id=self.chat_id, - memory_text=memory['memory_text'], - keywords=memory.get('keywords', []) + memory_text=memory["memory_text"], + keywords=memory.get("keywords", []), ) await self.store_memory(memory_item) else: logger.info(f"不需要记忆:{text}") - - async def store_memory(self,memory_item:MemoryItem): + + async def store_memory(self, memory_item: MemoryItem): memory = Memory( memory_id=memory_item.memory_id, chat_id=memory_item.chat_id, memory_text=memory_item.memory_text, keywords=memory_item.keywords, create_time=memory_item.create_time, - last_view_time=memory_item.last_view_time + last_view_time=memory_item.last_view_time, ) memory.save() - - async def get_memory(self,target:str): + + async def get_memory(self, target: str): from json_repair import repair_json + prompt = f""" 请根据以下发言内容,判断是否需要提取记忆 {target} @@ -144,7 +142,7 @@ class InstantMemory: 请只输出json格式,不要输出其他多余内容 """ try: - response,_ = await self.summary_model.generate_response_async(prompt) + response, _ = await self.summary_model.generate_response_async(prompt) print(prompt) print(response) if not response: @@ -153,15 +151,15 @@ class InstantMemory: repaired = repair_json(response) result = json.loads(repaired) # 解析keywords - keywords = result.get('keywords', '') + keywords = result.get("keywords", "") if isinstance(keywords, str): - keywords_list = [k.strip() for k in keywords.split('/') if k.strip()] + keywords_list = [k.strip() for k in keywords.split("/") if k.strip()] elif isinstance(keywords, list): keywords_list = keywords else: keywords_list = [] # 解析time为时间段 - time_str = result.get('time', '').strip() + time_str = result.get("time", "").strip() start_time, end_time = self._parse_time_range(time_str) logger.info(f"start_time: {start_time}, end_time: {end_time}") # 检索包含关键词的记忆 @@ -170,16 +168,15 @@ class InstantMemory: start_ts = start_time.timestamp() end_ts = end_time.timestamp() query = Memory.select().where( - (Memory.chat_id == self.chat_id) & - (Memory.create_time >= start_ts) & - (Memory.create_time < end_ts) + (Memory.chat_id == self.chat_id) + & (Memory.create_time >= start_ts) # type: ignore + & (Memory.create_time < end_ts) # type: ignore ) else: query = Memory.select().where(Memory.chat_id == self.chat_id) - for mem in query: - #对每条记忆 + # 对每条记忆 mem_keywords = mem.keywords or [] parsed = ast.literal_eval(mem_keywords) if isinstance(parsed, list): @@ -212,6 +209,7 @@ class InstantMemory: - 空字符串:返回(None, None) """ from datetime import datetime, timedelta + now = datetime.now() if not time_str: return 0, now @@ -251,8 +249,8 @@ class InstantMemory: if m: months = int(m.group(1)) # 近似每月30天 - start = (now - timedelta(days=months*30)).replace(hour=0, minute=0, second=0, microsecond=0) + start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0) end = start + timedelta(days=1) return start, end # 其他无法解析 - return 0, now \ No newline at end of file + return 0, now diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 8b71314a6..e4a61900e 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -30,7 +30,7 @@ class ChatMessageContext: def get_template_name(self) -> Optional[str]: """获取模板名称""" if self.message.message_info.template_info and not self.message.message_info.template_info.template_default: - return self.message.message_info.template_info.template_name + return self.message.message_info.template_info.template_name # type: ignore return None def get_last_message(self) -> "MessageRecv": diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index e6b6741f0..487c7d036 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -107,9 +107,9 @@ class MessageRecv(Message): self.is_picid = False self.has_picid = False self.is_mentioned = None - + self.is_command = False - + self.priority_mode = "interest" self.priority_info = None self.interest_value: float = None # type: ignore @@ -181,6 +181,7 @@ class MessageRecv(Message): logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") return f"[处理失败的{segment.type}消息]" + @dataclass class MessageRecvS4U(MessageRecv): def __init__(self, message_dict: dict[str, Any]): @@ -194,10 +195,10 @@ class MessageRecvS4U(MessageRecv): self.superchat_price = None self.superchat_message_text = None self.is_screen = False - + async def process(self) -> None: self.processed_plain_text = await self._process_message_segments(self.message_segment) - + async def _process_single_segment(self, segment: Seg) -> str: """处理单个消息段 @@ -252,7 +253,7 @@ class MessageRecvS4U(MessageRecv): elif segment.type == "gift": self.is_gift = True # 解析gift_info,格式为"名称:数量" - name, count = segment.data.split(":", 1) + name, count = segment.data.split(":", 1) # type: ignore self.gift_info = segment.data self.gift_name = name.strip() self.gift_count = int(count.strip()) @@ -260,13 +261,15 @@ class MessageRecvS4U(MessageRecv): elif segment.type == "superchat": self.is_superchat = True self.superchat_info = segment.data - price,message_text = segment.data.split(":", 1) + price, message_text = segment.data.split(":", 1) # type: ignore self.superchat_price = price.strip() self.superchat_message_text = message_text.strip() - + self.processed_plain_text = str(self.superchat_message_text) - self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)" - + self.processed_plain_text += ( + f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)" + ) + return self.processed_plain_text elif segment.type == "screen": self.is_screen = True diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 6c82625b3..a4876a46d 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -80,7 +80,7 @@ class ActionManager: chat_stream: ChatStream, log_prefix: str, shutting_down: bool = False, - action_message: dict = None, + action_message: Optional[dict] = None, ) -> Optional[BaseAction]: """ 创建动作处理器实例 diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index bb32e63a2..3a08ca72b 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -252,7 +252,7 @@ def _build_readable_messages_internal( pic_id_mapping: Optional[Dict[str, str]] = None, pic_counter: int = 1, show_pic: bool = True, - message_id_list: List[Dict[str, Any]] = None, + message_id_list: Optional[List[Dict[str, Any]]] = None, ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: """ 内部辅助函数,构建可读消息字符串和原始消息详情列表。 @@ -615,7 +615,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str: for action in actions: action_time = action.get("time", current_time) action_name = action.get("action_name", "未知动作") - if action_name == "no_action" or action_name == "no_reply": + if action_name in ["no_action", "no_reply"]: continue action_prompt_display = action.get("action_prompt_display", "无具体内容") @@ -697,7 +697,7 @@ def build_readable_messages( truncate: bool = False, show_actions: bool = False, show_pic: bool = True, - message_id_list: List[Dict[str, Any]] = None, + message_id_list: Optional[List[Dict[str, Any]]] = None, ) -> str: # sourcery skip: extract-method """ 将消息列表转换为可读的文本格式。 diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 4e0edd31f..0aff5102e 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -1211,7 +1211,7 @@ class StatisticOutputTask(AsyncTask): f.write(html_template) def _generate_focus_tab(self, stat: dict[str, Any]) -> str: - # sourcery skip: for-append-to-extend, list-comprehension, use-any + # sourcery skip: for-append-to-extend, list-comprehension, use-any, use-named-expression, use-next """生成Focus统计独立分页的HTML内容""" # 为每个时间段准备Focus数据 @@ -1559,6 +1559,7 @@ class StatisticOutputTask(AsyncTask): """ def _generate_versions_tab(self, stat: dict[str, Any]) -> str: + # sourcery skip: use-named-expression, use-next """生成版本对比独立分页的HTML内容""" # 为每个时间段准备版本对比数据 @@ -2306,13 +2307,13 @@ class AsyncStatisticOutputTask(AsyncTask): # 复用 StatisticOutputTask 的所有方法 def _collect_all_statistics(self, now: datetime): - return StatisticOutputTask._collect_all_statistics(self, now) + return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore def _statistic_console_output(self, stats: Dict[str, Any], now: datetime): - return StatisticOutputTask._statistic_console_output(self, stats, now) + return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore def _generate_html_report(self, stats: dict[str, Any], now: datetime): - return StatisticOutputTask._generate_html_report(self, stats, now) + return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore # 其他需要的方法也可以类似复用... @staticmethod @@ -2324,10 +2325,10 @@ class AsyncStatisticOutputTask(AsyncTask): return StatisticOutputTask._collect_online_time_for_period(collect_period, now) def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - return StatisticOutputTask._collect_message_count_for_period(self, collect_period) + return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) + return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) # type: ignore def _process_focus_file_data( self, @@ -2336,10 +2337,10 @@ class AsyncStatisticOutputTask(AsyncTask): collect_period: List[Tuple[str, datetime]], file_time: datetime, ): - return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) + return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) # type: ignore def _calculate_focus_averages(self, stats: Dict[str, Any]): - return StatisticOutputTask._calculate_focus_averages(self, stats) + return StatisticOutputTask._calculate_focus_averages(self, stats) # type: ignore @staticmethod def _format_total_stat(stats: Dict[str, Any]) -> str: @@ -2347,31 +2348,31 @@ class AsyncStatisticOutputTask(AsyncTask): @staticmethod def _format_model_classified_stat(stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_model_classified_stat(stats) + return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore def _format_chat_stat(self, stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_chat_stat(self, stats) + return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore def _format_focus_stat(self, stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_focus_stat(self, stats) + return StatisticOutputTask._format_focus_stat(self, stats) # type: ignore def _generate_chart_data(self, stat: dict[str, Any]) -> dict: - return StatisticOutputTask._generate_chart_data(self, stat) + return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: - return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) + return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) # type: ignore def _generate_chart_tab(self, chart_data: dict) -> str: - return StatisticOutputTask._generate_chart_tab(self, chart_data) + return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore def _get_chat_display_name_from_id(self, chat_id: str) -> str: - return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) + return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore def _generate_focus_tab(self, stat: dict[str, Any]) -> str: - return StatisticOutputTask._generate_focus_tab(self, stat) + return StatisticOutputTask._generate_focus_tab(self, stat) # type: ignore def _generate_versions_tab(self, stat: dict[str, Any]) -> str: - return StatisticOutputTask._generate_versions_tab(self, stat) + return StatisticOutputTask._generate_versions_tab(self, stat) # type: ignore def _convert_defaultdict_to_dict(self, data): - return StatisticOutputTask._convert_defaultdict_to_dict(self, data) + return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore diff --git a/src/chat/willing/willing_manager.py b/src/chat/willing/willing_manager.py index 29110ef94..6c53273f5 100644 --- a/src/chat/willing/willing_manager.py +++ b/src/chat/willing/willing_manager.py @@ -2,14 +2,13 @@ import importlib import asyncio from abc import ABC, abstractmethod -from typing import Dict, Optional +from typing import Dict, Optional, Any from rich.traceback import install from dataclasses import dataclass from src.common.logger import get_logger from src.config.config import global_config from src.chat.message_receive.chat_stream import ChatStream, GroupInfo -from src.chat.message_receive.message import MessageRecv from src.person_info.person_info import PersonInfoManager, get_person_info_manager install(extra_lines=3) @@ -54,7 +53,7 @@ class WillingInfo: interested_rate (float): 兴趣度 """ - message: MessageRecv + message: Dict[str, Any] # 原始消息数据 chat: ChatStream person_info_manager: PersonInfoManager chat_id: str diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 8258ac9fb..4b60dfa10 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -65,7 +65,7 @@ class ChatStreams(BaseModel): # user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。 user_cardname = TextField(null=True) - class Meta: + class Meta: # type: ignore # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 # 如果不使用带有数据库实例的 BaseModel,或者想覆盖它, # 请取消注释并在下面设置数据库实例: @@ -89,7 +89,7 @@ class LLMUsage(BaseModel): status = TextField() timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引 - class Meta: + class Meta: # type: ignore # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 # database = db table_name = "llm_usage" @@ -112,7 +112,7 @@ class Emoji(BaseModel): usage_count = IntegerField(default=0) # 使用次数(被使用的次数) last_used_time = FloatField(null=True) # 上次使用时间 - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "emoji" @@ -162,7 +162,8 @@ class Messages(BaseModel): is_emoji = BooleanField(default=False) is_picid = BooleanField(default=False) is_command = BooleanField(default=False) - class Meta: + + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "messages" @@ -186,7 +187,7 @@ class ActionRecords(BaseModel): chat_info_stream_id = TextField() chat_info_platform = TextField() - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "action_records" @@ -206,7 +207,7 @@ class Images(BaseModel): type = TextField() # 图像类型,例如 "emoji" vlm_processed = BooleanField(default=False) # 是否已经过VLM处理 - class Meta: + class Meta: # type: ignore table_name = "images" @@ -220,7 +221,7 @@ class ImageDescriptions(BaseModel): description = TextField() # 图像的描述 timestamp = FloatField() # 时间戳 - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "image_descriptions" @@ -236,7 +237,7 @@ class OnlineTime(BaseModel): start_timestamp = DateTimeField(default=datetime.datetime.now) end_timestamp = DateTimeField(index=True) - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "online_time" @@ -263,10 +264,11 @@ class PersonInfo(BaseModel): last_know = FloatField(null=True) # 最后一次印象总结时间 attitude = IntegerField(null=True, default=50) # 态度,0-100,从非常厌恶到十分喜欢 - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "person_info" + class Memory(BaseModel): memory_id = TextField(index=True) chat_id = TextField(null=True) @@ -274,10 +276,11 @@ class Memory(BaseModel): keywords = TextField(null=True) create_time = FloatField(null=True) last_view_time = FloatField(null=True) - - class Meta: + + class Meta: # type: ignore table_name = "memory" + class Knowledges(BaseModel): """ 用于存储知识库条目的模型。 @@ -287,10 +290,11 @@ class Knowledges(BaseModel): embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表 # 可以添加其他元数据字段,如 source, create_time 等 - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "knowledges" + class Expression(BaseModel): """ 用于存储表达风格的模型。 @@ -302,10 +306,11 @@ class Expression(BaseModel): last_active_time = FloatField() chat_id = TextField(index=True) type = TextField() - - class Meta: + + class Meta: # type: ignore table_name = "expression" + class ThinkingLog(BaseModel): chat_id = TextField(index=True) trigger_text = TextField(null=True) @@ -326,7 +331,7 @@ class ThinkingLog(BaseModel): # And: import datetime created_at = DateTimeField(default=datetime.datetime.now) - class Meta: + class Meta: # type: ignore table_name = "thinking_logs" @@ -341,7 +346,7 @@ class GraphNodes(BaseModel): created_time = FloatField() # 创建时间戳 last_modified = FloatField() # 最后修改时间戳 - class Meta: + class Meta: # type: ignore table_name = "graph_nodes" @@ -357,7 +362,7 @@ class GraphEdges(BaseModel): created_time = FloatField() # 创建时间戳 last_modified = FloatField() # 最后修改时间戳 - class Meta: + class Meta: # type: ignore table_name = "graph_edges" diff --git a/src/config/auto_update.py b/src/config/auto_update.py index 355ebc55a..8d097ec49 100644 --- a/src/config/auto_update.py +++ b/src/config/auto_update.py @@ -7,13 +7,13 @@ from datetime import datetime def get_key_comment(toml_table, key): # 获取key的注释(如果有) - if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'): + if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"): return toml_table.trivia.comment - if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict): + if hasattr(toml_table, "value") and isinstance(toml_table.value, dict): item = toml_table.value.get(key) - if item is not None and hasattr(item, 'trivia'): + if item is not None and hasattr(item, "trivia"): return item.trivia.comment - if hasattr(toml_table, 'keys'): + if hasattr(toml_table, "keys"): for k in toml_table.keys(): if isinstance(k, KeyType) and k.key == key: return k.trivia.comment @@ -36,16 +36,16 @@ def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, log continue if key not in old: comment = get_key_comment(new, key) - logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): - compare_dicts(new[key], old[key], path+[str(key)], new_comments, old_comments, logs) + compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs) # 删减项 for key in old: if key == "version": continue if key not in new: comment = get_key_comment(old, key) - logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") return logs @@ -95,7 +95,7 @@ def update_config(): if old_version and new_version and old_version == new_version: print(f"检测到版本号相同 (v{old_version}),跳过更新") # 如果version相同,恢复旧配置文件并返回 - shutil.move(old_backup_path, old_config_path) + shutil.move(old_backup_path, old_config_path) # type: ignore return else: print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") diff --git a/src/config/config.py b/src/config/config.py index ed433dfd1..fcbde9871 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -53,13 +53,13 @@ MMC_VERSION = "0.9.0-snapshot.2" def get_key_comment(toml_table, key): # 获取key的注释(如果有) - if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'): + if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"): return toml_table.trivia.comment - if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict): + if hasattr(toml_table, "value") and isinstance(toml_table.value, dict): item = toml_table.value.get(key) - if item is not None and hasattr(item, 'trivia'): + if item is not None and hasattr(item, "trivia"): return item.trivia.comment - if hasattr(toml_table, 'keys'): + if hasattr(toml_table, "keys"): for k in toml_table.keys(): if isinstance(k, KeyType) and k.key == key: return k.trivia.comment @@ -78,16 +78,16 @@ def compare_dicts(new, old, path=None, logs=None): continue if key not in old: comment = get_key_comment(new, key) - logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): - compare_dicts(new[key], old[key], path+[str(key)], logs) + compare_dicts(new[key], old[key], path + [str(key)], logs) # 删减项 for key in old: if key == "version": continue if key not in new: comment = get_key_comment(old, key) - logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") return logs @@ -99,6 +99,7 @@ def get_value_by_path(d, path): return None return d + def set_value_by_path(d, path, value): for k in path[:-1]: if k not in d or not isinstance(d[k], dict): @@ -106,6 +107,7 @@ def set_value_by_path(d, path, value): d = d[k] d[path[-1]] = value + def compare_default_values(new, old, path=None, logs=None, changes=None): # 递归比较两个dict,找出默认值变化项 if path is None: @@ -119,12 +121,14 @@ def compare_default_values(new, old, path=None, logs=None, changes=None): continue if key in old: if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)): - compare_default_values(new[key], old[key], path+[str(key)], logs, changes) + compare_default_values(new[key], old[key], path + [str(key)], logs, changes) else: # 只要值发生变化就记录 if new[key] != old[key]: - logs.append(f"默认值变化: {'.'.join(path+[str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}") - changes.append((path+[str(key)], old[key], new[key])) + logs.append( + f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}" + ) + changes.append((path + [str(key)], old[key], new[key])) return logs, changes @@ -148,8 +152,8 @@ def update_config(): return None with open(toml_path, "r", encoding="utf-8") as f: doc = tomlkit.load(f) - if "inner" in doc and "version" in doc["inner"]: - return doc["inner"]["version"] + if "inner" in doc and "version" in doc["inner"]: # type: ignore + return doc["inner"]["version"] # type: ignore return None template_version = get_version_from_toml(template_path) @@ -186,7 +190,9 @@ def update_config(): old_value = get_value_by_path(old_config, path) if old_value == old_default: set_value_by_path(old_config, path, new_default) - logger.info(f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}") + logger.info( + f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" + ) else: logger.info("未检测到模板默认值变动") # 保存旧配置的变更(后续合并逻辑会用到 old_config) @@ -229,7 +235,9 @@ def update_config(): logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") return else: - logger.info(f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------") + logger.info( + f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" + ) else: logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") @@ -321,6 +329,7 @@ class Config(ConfigBase): debug: DebugConfig custom_prompt: CustomPromptConfig + def load_config(config_path: str) -> Config: """ 加载配置文件 diff --git a/src/individuality/not_using/offline_llm.py b/src/individuality/not_using/offline_llm.py index 83cb263c7..2bafb69aa 100644 --- a/src/individuality/not_using/offline_llm.py +++ b/src/individuality/not_using/offline_llm.py @@ -39,7 +39,7 @@ class LLMRequestOff: } # 发送请求到完整的 chat/completions 端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" + api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore logger.info(f"Request URL: {api_url}") # 记录请求的 URL max_retries = 3 @@ -89,7 +89,7 @@ class LLMRequestOff: } # 发送请求到完整的 chat/completions 端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" + api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore logger.info(f"Request URL: {api_url}") # 记录请求的 URL max_retries = 3 diff --git a/src/individuality/not_using/per_bf_gen.py b/src/individuality/not_using/per_bf_gen.py index 3b66d0551..aedbe00ee 100644 --- a/src/individuality/not_using/per_bf_gen.py +++ b/src/individuality/not_using/per_bf_gen.py @@ -83,8 +83,8 @@ class PersonalityEvaluatorDirect: def __init__(self): self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} self.scenarios = [] - self.final_scores = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} - self.dimension_counts = {trait: 0 for trait in self.final_scores.keys()} + self.final_scores: Dict[str, float] = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} + self.dimension_counts = {trait: 0 for trait in self.final_scores} # 为每个人格特质获取对应的场景 for trait in PERSONALITY_SCENES: @@ -119,8 +119,7 @@ class PersonalityEvaluatorDirect: # 构建维度描述 dimension_descriptions = [] for dim in dimensions: - desc = FACTOR_DESCRIPTIONS.get(dim, "") - if desc: + if desc := FACTOR_DESCRIPTIONS.get(dim, ""): dimension_descriptions.append(f"- {dim}:{desc}") dimensions_text = "\n".join(dimension_descriptions) diff --git a/src/main.py b/src/main.py index 3dc8c4c9a..dbd12f1a4 100644 --- a/src/main.py +++ b/src/main.py @@ -153,14 +153,14 @@ class MainSystem: while True: await asyncio.sleep(global_config.memory.memory_build_interval) logger.info("正在进行记忆构建") - await self.hippocampus_manager.build_memory() + await self.hippocampus_manager.build_memory() # type: ignore async def forget_memory_task(self): """记忆遗忘任务""" while True: await asyncio.sleep(global_config.memory.forget_memory_interval) logger.info("[记忆遗忘] 开始遗忘记忆...") - await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) + await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore logger.info("[记忆遗忘] 记忆遗忘完成") async def consolidate_memory_task(self): @@ -168,7 +168,7 @@ class MainSystem: while True: await asyncio.sleep(global_config.memory.consolidate_memory_interval) logger.info("[记忆整合] 开始整合记忆...") - await self.hippocampus_manager.consolidate_memory() + await self.hippocampus_manager.consolidate_memory() # type: ignore logger.info("[记忆整合] 记忆整合完成") @staticmethod diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index b47785401..398b1f372 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -49,6 +49,9 @@ class ChatMood: chat_manager = get_chat_manager() self.chat_stream = chat_manager.get_stream(self.chat_id) + + if not self.chat_stream: + raise ValueError(f"Chat stream for chat_id {chat_id} not found") self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]" diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index a489a34d5..69b9e84d2 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -26,7 +26,7 @@ SEGMENT_CLEANUP_CONFIG = { "cleanup_interval_hours": 0.5, # 清理间隔(小时) } -MAX_MESSAGE_COUNT = 80 / global_config.relationship.relation_frequency +MAX_MESSAGE_COUNT = int(80 / global_config.relationship.relation_frequency) class RelationshipBuilder: diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index b8701839d..59e240811 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -61,7 +61,7 @@ __all__ = [ "ConfigField", # 工具函数 "ManifestValidator", - "ManifestGenerator", - "validate_plugin_manifest", - "generate_plugin_manifest", + # "ManifestGenerator", + # "validate_plugin_manifest", + # "generate_plugin_manifest", ] diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 97bee9908..c8b03a0a6 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -111,7 +111,7 @@ async def _send_to_target( is_head=True, is_emoji=(message_type == "emoji"), thinking_start_time=current_time, - reply_to = reply_to_platform_id + reply_to=reply_to_platform_id, ) # 发送消息 @@ -137,6 +137,7 @@ async def _send_to_target( async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]: + # sourcery skip: inline-variable, use-named-expression """查找要回复的消息 Args: @@ -184,14 +185,11 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR # 检查是否有 回复 字段 reply_pattern = r"回复<([^:<>]+):([^:<>]+)>" - match = re.search(reply_pattern, translate_text) - if match: + if match := re.search(reply_pattern, translate_text): aaa = match.group(1) bbb = match.group(2) reply_person_id = get_person_info_manager().get_person_id(platform, bbb) - reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") - if not reply_person_name: - reply_person_name = aaa + reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") or aaa # 在内容前加上回复信息 translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1) @@ -206,9 +204,7 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR aaa = m.group(1) bbb = m.group(2) at_person_id = get_person_info_manager().get_person_id(platform, bbb) - at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") - if not at_person_name: - at_person_name = aaa + at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") or aaa new_content += f"@{at_person_name}" last_end = m.end() new_content += translate_text[last_end:] @@ -370,7 +366,14 @@ async def custom_to_stream( bool: 是否发送成功 """ return await _send_to_target( - message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log + message_type, + content, + stream_id, + display_message, + typing, + reply_to, + storage_message=storage_message, + show_log=show_log, ) @@ -396,7 +399,7 @@ async def text_to_group( """ stream_id = get_chat_manager().get_stream_id(platform, group_id, True) - return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message) + return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message) async def text_to_user( @@ -420,7 +423,7 @@ async def text_to_user( bool: 是否发送成功 """ stream_id = get_chat_manager().get_stream_id(platform, user_id, False) - return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message) + return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message) async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool: @@ -543,7 +546,9 @@ async def custom_to_group( bool: 是否发送成功 """ stream_id = get_chat_manager().get_stream_id(platform, group_id, True) - return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message) + return await _send_to_target( + message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message + ) async def custom_to_user( @@ -571,7 +576,9 @@ async def custom_to_user( bool: 是否发送成功 """ stream_id = get_chat_manager().get_stream_id(platform, user_id, False) - return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message) + return await _send_to_target( + message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message + ) async def custom_message( @@ -611,4 +618,6 @@ async def custom_message( await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好") """ stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group) - return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message) + return await _send_to_target( + message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message + ) diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 2c559a2c7..74ab22e67 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -38,7 +38,7 @@ class BaseAction(ABC): chat_stream: ChatStream, log_prefix: str = "", plugin_config: Optional[dict] = None, - action_message: dict = None, + action_message: Optional[dict] = None, **kwargs, ): """初始化Action组件 @@ -63,7 +63,7 @@ class BaseAction(ABC): self.cycle_timers = cycle_timers self.thinking_id = thinking_id self.log_prefix = log_prefix - + # 保存插件配置 self.plugin_config = plugin_config or {} @@ -92,10 +92,10 @@ class BaseAction(ABC): self.chat_stream = chat_stream or kwargs.get("chat_stream") self.chat_id = self.chat_stream.stream_id self.platform = getattr(self.chat_stream, "platform", None) - + # 初始化基础信息(带类型注解) self.action_message = action_message - + self.group_id = None self.group_name = None self.user_id = None @@ -103,15 +103,17 @@ class BaseAction(ABC): self.is_group = False self.target_id = None self.has_action_message = False - + if self.action_message: self.has_action_message = True - + else: + self.action_message = {} + if self.has_action_message: if self.action_name != "no_reply": self.group_id = str(self.action_message.get("chat_info_group_id", None)) self.group_name = self.action_message.get("chat_info_group_name", None) - + self.user_id = str(self.action_message.get("user_id", None)) self.user_nickname = self.action_message.get("user_nickname", None) if self.group_id: @@ -132,8 +134,6 @@ class BaseAction(ABC): self.is_group = False self.target_id = self.user_id - - logger.debug(f"{self.log_prefix} Action组件初始化完成") logger.info( f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}" @@ -199,7 +199,9 @@ class BaseAction(ABC): logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}") return False, f"等待新消息失败: {str(e)}" - async def send_text(self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False) -> bool: + async def send_text( + self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False + ) -> bool: """发送文本消息 Args: @@ -299,7 +301,7 @@ class BaseAction(ABC): ) async def send_command( - self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True + self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True ) -> bool: """发送命令消息 diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 2c2ddf81e..caf68567b 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -135,7 +135,7 @@ class BaseCommand(ABC): ) async def send_command( - self, command_name: str, args: dict = None, display_message: str = "", storage_message: bool = True + self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True ) -> bool: """发送命令消息 diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index b152a1abc..917069e11 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -346,67 +346,67 @@ class ComponentRegistry: # === 状态管理方法 === - def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool: - # -------------------------------- NEED REFACTORING -------------------------------- - # -------------------------------- LOGIC ERROR ------------------------------------- - """启用组件,支持命名空间解析""" - # 首先尝试找到正确的命名空间化名称 - component_info = self.get_component_info(component_name, component_type) - if not component_info: - return False + # def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool: + # # -------------------------------- NEED REFACTORING -------------------------------- + # # -------------------------------- LOGIC ERROR ------------------------------------- + # """启用组件,支持命名空间解析""" + # # 首先尝试找到正确的命名空间化名称 + # component_info = self.get_component_info(component_name, component_type) + # if not component_info: + # return False - # 根据组件类型构造正确的命名空间化名称 - if component_info.component_type == ComponentType.ACTION: - namespaced_name = f"action.{component_name}" if "." not in component_name else component_name - elif component_info.component_type == ComponentType.COMMAND: - namespaced_name = f"command.{component_name}" if "." not in component_name else component_name - else: - namespaced_name = ( - f"{component_info.component_type.value}.{component_name}" - if "." not in component_name - else component_name - ) + # # 根据组件类型构造正确的命名空间化名称 + # if component_info.component_type == ComponentType.ACTION: + # namespaced_name = f"action.{component_name}" if "." not in component_name else component_name + # elif component_info.component_type == ComponentType.COMMAND: + # namespaced_name = f"command.{component_name}" if "." not in component_name else component_name + # else: + # namespaced_name = ( + # f"{component_info.component_type.value}.{component_name}" + # if "." not in component_name + # else component_name + # ) - if namespaced_name in self._components: - self._components[namespaced_name].enabled = True - # 如果是Action,更新默认动作集 - # ---- HERE ---- - # if isinstance(component_info, ActionInfo): - # self._action_descriptions[component_name] = component_info.description - logger.debug(f"已启用组件: {component_name} -> {namespaced_name}") - return True - return False + # if namespaced_name in self._components: + # self._components[namespaced_name].enabled = True + # # 如果是Action,更新默认动作集 + # # ---- HERE ---- + # # if isinstance(component_info, ActionInfo): + # # self._action_descriptions[component_name] = component_info.description + # logger.debug(f"已启用组件: {component_name} -> {namespaced_name}") + # return True + # return False - def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool: - # -------------------------------- NEED REFACTORING -------------------------------- - # -------------------------------- LOGIC ERROR ------------------------------------- - """禁用组件,支持命名空间解析""" - # 首先尝试找到正确的命名空间化名称 - component_info = self.get_component_info(component_name, component_type) - if not component_info: - return False + # def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool: + # # -------------------------------- NEED REFACTORING -------------------------------- + # # -------------------------------- LOGIC ERROR ------------------------------------- + # """禁用组件,支持命名空间解析""" + # # 首先尝试找到正确的命名空间化名称 + # component_info = self.get_component_info(component_name, component_type) + # if not component_info: + # return False - # 根据组件类型构造正确的命名空间化名称 - if component_info.component_type == ComponentType.ACTION: - namespaced_name = f"action.{component_name}" if "." not in component_name else component_name - elif component_info.component_type == ComponentType.COMMAND: - namespaced_name = f"command.{component_name}" if "." not in component_name else component_name - else: - namespaced_name = ( - f"{component_info.component_type.value}.{component_name}" - if "." not in component_name - else component_name - ) + # # 根据组件类型构造正确的命名空间化名称 + # if component_info.component_type == ComponentType.ACTION: + # namespaced_name = f"action.{component_name}" if "." not in component_name else component_name + # elif component_info.component_type == ComponentType.COMMAND: + # namespaced_name = f"command.{component_name}" if "." not in component_name else component_name + # else: + # namespaced_name = ( + # f"{component_info.component_type.value}.{component_name}" + # if "." not in component_name + # else component_name + # ) - if namespaced_name in self._components: - self._components[namespaced_name].enabled = False - # 如果是Action,从默认动作集中移除 - # ---- HERE ---- - # if component_name in self._action_descriptions: - # del self._action_descriptions[component_name] - logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}") - return True - return False + # if namespaced_name in self._components: + # self._components[namespaced_name].enabled = False + # # 如果是Action,从默认动作集中移除 + # # ---- HERE ---- + # # if component_name in self._action_descriptions: + # # del self._action_descriptions[component_name] + # logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}") + # return True + # return False def get_registry_stats(self) -> Dict[str, Any]: """获取注册中心统计信息""" diff --git a/src/plugin_system/core/dependency_manager.py b/src/plugin_system/core/dependency_manager.py index 4a995e028..266254e72 100644 --- a/src/plugin_system/core/dependency_manager.py +++ b/src/plugin_system/core/dependency_manager.py @@ -7,7 +7,7 @@ import subprocess import sys import importlib -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Any from src.common.logger import get_logger from src.plugin_system.base.component_types import PythonDependency @@ -176,7 +176,7 @@ class DependencyManager: logger.error(f"生成requirements文件失败: {str(e)}") return False - def get_install_summary(self) -> Dict[str, any]: + def get_install_summary(self) -> Dict[str, Any]: """获取安装摘要""" return { "install_log": self.install_log.copy(), diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index cff28cb99..b4050794f 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -197,29 +197,29 @@ class PluginManager: """获取所有启用的插件信息""" return list(component_registry.get_enabled_plugins().values()) - def enable_plugin(self, plugin_name: str) -> bool: - # -------------------------------- NEED REFACTORING -------------------------------- - """启用插件""" - if plugin_info := component_registry.get_plugin_info(plugin_name): - plugin_info.enabled = True - # 启用插件的所有组件 - for component in plugin_info.components: - component_registry.enable_component(component.name) - logger.debug(f"已启用插件: {plugin_name}") - return True - return False + # def enable_plugin(self, plugin_name: str) -> bool: + # # -------------------------------- NEED REFACTORING -------------------------------- + # """启用插件""" + # if plugin_info := component_registry.get_plugin_info(plugin_name): + # plugin_info.enabled = True + # # 启用插件的所有组件 + # for component in plugin_info.components: + # component_registry.enable_component(component.name) + # logger.debug(f"已启用插件: {plugin_name}") + # return True + # return False - def disable_plugin(self, plugin_name: str) -> bool: - # -------------------------------- NEED REFACTORING -------------------------------- - """禁用插件""" - if plugin_info := component_registry.get_plugin_info(plugin_name): - plugin_info.enabled = False - # 禁用插件的所有组件 - for component in plugin_info.components: - component_registry.disable_component(component.name) - logger.debug(f"已禁用插件: {plugin_name}") - return True - return False + # def disable_plugin(self, plugin_name: str) -> bool: + # # -------------------------------- NEED REFACTORING -------------------------------- + # """禁用插件""" + # if plugin_info := component_registry.get_plugin_info(plugin_name): + # plugin_info.enabled = False + # # 禁用插件的所有组件 + # for component in plugin_info.components: + # component_registry.disable_component(component.name) + # logger.debug(f"已禁用插件: {plugin_name}") + # return True + # return False def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]: """获取插件实例 diff --git a/src/tools/tool_can_use/compare_numbers_tool.py b/src/tools/tool_can_use/compare_numbers_tool.py index e73f6e79f..2930f8f4b 100644 --- a/src/tools/tool_can_use/compare_numbers_tool.py +++ b/src/tools/tool_can_use/compare_numbers_tool.py @@ -28,10 +28,10 @@ class CompareNumbersTool(BaseTool): Returns: dict: 工具执行结果 """ - try: - num1 = function_args.get("num1") - num2 = function_args.get("num2") + num1: int | float = function_args.get("num1") # type: ignore + num2: int | float = function_args.get("num2") # type: ignore + try: if num1 > num2: result = f"{num1} 大于 {num2}" elif num1 < num2: diff --git a/src/tools/tool_can_use/rename_person_tool.py b/src/tools/tool_can_use/rename_person_tool.py index 0651e0c2c..cfc6ef4b0 100644 --- a/src/tools/tool_can_use/rename_person_tool.py +++ b/src/tools/tool_can_use/rename_person_tool.py @@ -68,10 +68,10 @@ class RenamePersonTool(BaseTool): ) result = await person_info_manager.qv_person_name( person_id=person_id, - user_nickname=user_nickname, - user_cardname=user_cardname, - user_avatar=user_avatar, - request=request_context, + user_nickname=user_nickname, # type: ignore + user_cardname=user_cardname, # type: ignore + user_avatar=user_avatar, # type: ignore + request=request_context, # type: ignore ) # 3. 处理结果