From f2c901bc988de8d3b64b69112e378e112f20970e Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sun, 20 Jul 2025 18:14:53 +0800 Subject: [PATCH] typing --- .../heart_flow/heartflow_message_processor.py | 2 +- src/chat/message_receive/message.py | 4 +- src/chat/utils/statistic.py | 2 +- src/chat/utils/utils.py | 4 +- src/chat/willing/willing_manager.py | 2 +- src/common/database/database_model.py | 30 ++++++------- src/llm_models/utils_model.py | 45 ++++++++++--------- src/mood/mood_manager.py | 2 +- src/tools/tool_can_use/rename_person_tool.py | 2 +- 9 files changed, 47 insertions(+), 46 deletions(-) diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index 076ef0c06..a9d118286 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -112,7 +112,7 @@ class HeartFCMessageReceiver: # subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) - chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) # type: ignore + chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate)) # 3. 日志记录 diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 36737eb77..b35b233ea 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -494,10 +494,10 @@ class MessageSending(MessageProcessBase): # ) -> "MessageSending": # """从思考状态消息创建发送状态消息""" # return cls( - # message_id=thinking.message_info.message_id, # type: ignore + # message_id=thinking.message_info.message_id, # chat_stream=thinking.chat_stream, # message_segment=message_segment, - # bot_user_info=thinking.message_info.user_info, # type: ignore + # bot_user_info=thinking.message_info.user_info, # reply=thinking.reply, # is_head=is_head, # is_emoji=is_emoji, diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 0aff5102e..bce8856e5 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -2348,7 +2348,7 @@ class AsyncStatisticOutputTask(AsyncTask): @staticmethod def _format_model_classified_stat(stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore + return StatisticOutputTask._format_model_classified_stat(stats) def _format_chat_stat(self, stats: Dict[str, Any]) -> str: return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index a329b3548..071f1886c 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -285,7 +285,7 @@ def random_remove_punctuation(text: str) -> str: continue elif char == ",": rand = random.random() - if rand < 0.25: # 5%概率删除逗号 + if rand < 0.05: # 5%概率删除逗号 continue elif rand < 0.25: # 20%概率把逗号变成空格 result += " " @@ -628,7 +628,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: elif chat_stream.user_info: # It's a private chat is_group_chat = False user_info = chat_stream.user_info - platform: str = chat_stream.platform # type: ignore + platform: str = chat_stream.platform user_id: str = user_info.user_id # type: ignore # Initialize target_info with basic info diff --git a/src/chat/willing/willing_manager.py b/src/chat/willing/willing_manager.py index 31ea49399..6b946f92c 100644 --- a/src/chat/willing/willing_manager.py +++ b/src/chat/willing/willing_manager.py @@ -94,7 +94,7 @@ class BaseWillingManager(ABC): def setup(self, message: dict, chat: ChatStream): person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) # type: ignore - self.ongoing_messages[message.get("message_id", "")] = WillingInfo( # type: ignore + self.ongoing_messages[message.get("message_id", "")] = WillingInfo( message=message, chat=chat, person_info_manager=get_person_info_manager(), diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 4b60dfa10..645b0a5d6 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -65,7 +65,7 @@ class ChatStreams(BaseModel): # user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。 user_cardname = TextField(null=True) - class Meta: # type: ignore + class Meta: # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 # 如果不使用带有数据库实例的 BaseModel,或者想覆盖它, # 请取消注释并在下面设置数据库实例: @@ -89,7 +89,7 @@ class LLMUsage(BaseModel): status = TextField() timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引 - class Meta: # type: ignore + class Meta: # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 # database = db table_name = "llm_usage" @@ -112,7 +112,7 @@ class Emoji(BaseModel): usage_count = IntegerField(default=0) # 使用次数(被使用的次数) last_used_time = FloatField(null=True) # 上次使用时间 - class Meta: # type: ignore + class Meta: # database = db # 继承自 BaseModel table_name = "emoji" @@ -163,7 +163,7 @@ class Messages(BaseModel): is_picid = BooleanField(default=False) is_command = BooleanField(default=False) - class Meta: # type: ignore + class Meta: # database = db # 继承自 BaseModel table_name = "messages" @@ -187,7 +187,7 @@ class ActionRecords(BaseModel): chat_info_stream_id = TextField() chat_info_platform = TextField() - class Meta: # type: ignore + class Meta: # database = db # 继承自 BaseModel table_name = "action_records" @@ -207,7 +207,7 @@ class Images(BaseModel): type = TextField() # 图像类型,例如 "emoji" vlm_processed = BooleanField(default=False) # 是否已经过VLM处理 - class Meta: # type: ignore + class Meta: table_name = "images" @@ -221,7 +221,7 @@ class ImageDescriptions(BaseModel): description = TextField() # 图像的描述 timestamp = FloatField() # 时间戳 - class Meta: # type: ignore + class Meta: # database = db # 继承自 BaseModel table_name = "image_descriptions" @@ -237,7 +237,7 @@ class OnlineTime(BaseModel): start_timestamp = DateTimeField(default=datetime.datetime.now) end_timestamp = DateTimeField(index=True) - class Meta: # type: ignore + class Meta: # database = db # 继承自 BaseModel table_name = "online_time" @@ -264,7 +264,7 @@ class PersonInfo(BaseModel): last_know = FloatField(null=True) # 最后一次印象总结时间 attitude = IntegerField(null=True, default=50) # 态度,0-100,从非常厌恶到十分喜欢 - class Meta: # type: ignore + class Meta: # database = db # 继承自 BaseModel table_name = "person_info" @@ -277,7 +277,7 @@ class Memory(BaseModel): create_time = FloatField(null=True) last_view_time = FloatField(null=True) - class Meta: # type: ignore + class Meta: table_name = "memory" @@ -290,7 +290,7 @@ class Knowledges(BaseModel): embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表 # 可以添加其他元数据字段,如 source, create_time 等 - class Meta: # type: ignore + class Meta: # database = db # 继承自 BaseModel table_name = "knowledges" @@ -307,7 +307,7 @@ class Expression(BaseModel): chat_id = TextField(index=True) type = TextField() - class Meta: # type: ignore + class Meta: table_name = "expression" @@ -331,7 +331,7 @@ class ThinkingLog(BaseModel): # And: import datetime created_at = DateTimeField(default=datetime.datetime.now) - class Meta: # type: ignore + class Meta: table_name = "thinking_logs" @@ -346,7 +346,7 @@ class GraphNodes(BaseModel): created_time = FloatField() # 创建时间戳 last_modified = FloatField() # 最后修改时间戳 - class Meta: # type: ignore + class Meta: table_name = "graph_nodes" @@ -362,7 +362,7 @@ class GraphEdges(BaseModel): created_time = FloatField() # 创建时间戳 last_modified = FloatField() # 最后修改时间戳 - class Meta: # type: ignore + class Meta: table_name = "graph_edges" diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 2e1d426a6..3621b4502 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -2,7 +2,7 @@ import asyncio import json import re from datetime import datetime -from typing import Tuple, Union, Dict, Any +from typing import Tuple, Union, Dict, Any, Callable import aiohttp from aiohttp.client import ClientResponse from src.common.logger import get_logger @@ -300,7 +300,7 @@ class LLMRequest: file_format: str = None, payload: dict = None, retry_policy: dict = None, - response_handler: callable = None, + response_handler: Callable = None, user_id: str = "system", request_type: str = None, ): @@ -336,19 +336,17 @@ class LLMRequest: headers["Accept"] = "text/event-stream" async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session: post_kwargs = {"headers": headers} - #form-data数据上传方式不同 + # form-data数据上传方式不同 if file_bytes: post_kwargs["data"] = request_content["payload"] else: post_kwargs["json"] = request_content["payload"] - async with session.post( - request_content["api_url"], **post_kwargs - ) as response: + async with session.post(request_content["api_url"], **post_kwargs) as response: handled_result = await self._handle_response( response, request_content, retry, response_handler, user_id, request_type, endpoint ) - return handled_result + return handled_result except Exception as e: handled_payload, count_delta = await self._handle_exception(e, retry, request_content) @@ -366,11 +364,11 @@ class LLMRequest: response: ClientResponse, request_content: Dict[str, Any], retry_count: int, - response_handler: callable, + response_handler: Callable, user_id, request_type, endpoint, - ) -> Union[Dict[str, Any], None]: + ): policy = request_content["policy"] stream_mode = request_content["stream_mode"] if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]: @@ -477,9 +475,7 @@ class LLMRequest: } return result - async def _handle_error_response( - self, response: ClientResponse, retry_count: int, policy: Dict[str, Any] - ) -> Union[Dict[str, any]]: + async def _handle_error_response(self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]): if response.status in policy["retry_codes"]: wait_time = policy["base_wait"] * (2**retry_count) logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试") @@ -629,7 +625,9 @@ class LLMRequest: ) # 安全地检查和记录请求详情 handled_payload = await _safely_record(request_content, payload) - logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}") + logger.critical( + f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}" + ) raise RuntimeError( f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}" ) @@ -643,7 +641,9 @@ class LLMRequest: logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}") # 安全地检查和记录请求详情 handled_payload = await _safely_record(request_content, payload) - logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}") + logger.critical( + f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}" + ) raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") async def _transform_parameters(self, params: dict) -> dict: @@ -682,15 +682,14 @@ class LLMRequest: logger.warning(f"暂不支持的文件类型: {file_format}") data.add_field( - "file",io.BytesIO(file_bytes), + "file", + io.BytesIO(file_bytes), filename=f"file.{file_format}", - content_type=f'{content_type}' # 根据实际文件类型设置 - ) - data.add_field( - "model", self.model_name + content_type=f"{content_type}", # 根据实际文件类型设置 ) + data.add_field("model", self.model_name) return data - + async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict: """构建请求体""" # 复制一份参数,避免直接修改 self.params @@ -819,9 +818,11 @@ class LLMRequest: async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple: """根据输入的语音文件生成模型的异步响应""" - response = await self._execute_request(endpoint="/audio/transcriptions",file_bytes=voice_bytes, file_format='wav') + response = await self._execute_request( + endpoint="/audio/transcriptions", file_bytes=voice_bytes, file_format="wav" + ) return response - + async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]: """异步方式根据输入的提示生成模型的响应""" # 构建请求体,不硬编码max_tokens diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 398b1f372..4134de9b9 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -134,7 +134,7 @@ class ChatMood: self.mood_state = response - self.last_change_time = message_time # type: ignore + self.last_change_time = message_time async def regress_mood(self): message_time = time.time() diff --git a/src/tools/tool_can_use/rename_person_tool.py b/src/tools/tool_can_use/rename_person_tool.py index cfc6ef4b0..2216b8245 100644 --- a/src/tools/tool_can_use/rename_person_tool.py +++ b/src/tools/tool_can_use/rename_person_tool.py @@ -71,7 +71,7 @@ class RenamePersonTool(BaseTool): user_nickname=user_nickname, # type: ignore user_cardname=user_cardname, # type: ignore user_avatar=user_avatar, # type: ignore - request=request_context, # type: ignore + request=request_context, ) # 3. 处理结果