refactor: 统一类型注解风格并优化代码结构

- 将裸 except 改为显式 Exception 捕获
- 用列表推导式替换冗余 for 循环
- 为类属性添加 ClassVar 注解
- 统一 Union/Optional 写法为 |
- 移除未使用的导入
- 修复 SQLAlchemy 空值比较语法
- 优化字符串拼接与字典更新逻辑
- 补充缺失的 noqa 注释与异常链

BREAKING CHANGE: 所有插件基类的类级字段现要求显式 ClassVar 注解,自定义插件需同步更新
This commit is contained in:
明天好像没什么
2025-10-31 22:42:39 +08:00
parent 5080cfccfc
commit 0e129d385e
105 changed files with 592 additions and 561 deletions

View File

@@ -251,14 +251,14 @@ class ExpressionSelector:
) -> list[dict[str, Any]]:
"""
统一的表达方式选择入口,根据配置自动选择模式
Args:
chat_id: 聊天ID
chat_history: 聊天历史(列表或字符串)
target_message: 目标消息
max_num: 最多返回数量
min_num: 最少返回数量
Returns:
选中的表达方式列表
"""
@@ -403,12 +403,12 @@ class ExpressionSelector:
) -> list[dict[str, Any]]:
"""
根据StyleLearner预测的风格获取表达方式
Args:
chat_id: 聊天ID
predicted_styles: 预测的风格列表,格式: [(style, score), ...]
max_num: 最多返回数量
Returns:
表达方式列表
"""
@@ -430,7 +430,7 @@ class ExpressionSelector:
.where(Expression.type == "style")
.distinct()
)
db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
db_chat_ids = list(db_chat_ids_result.scalars())
logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
# 获取所有相关 chat_id 的表达方式(用于模糊匹配)
@@ -509,15 +509,16 @@ class ExpressionSelector:
)
# 转换为字典格式
expressions = []
for expr in expressions_objs:
expressions.append({
expressions = [
{
"situation": expr.situation or "",
"style": expr.style or "",
"type": expr.type or "style",
"count": float(expr.count) if expr.count else 0.0,
"last_active_time": expr.last_active_time or 0.0
})
}
for expr in expressions_objs
]
logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
return expressions
@@ -617,7 +618,7 @@ class ExpressionSelector:
# 对选中的所有表达方式一次性更新count数
if valid_expressions:
asyncio.create_task(self.update_expressions_count_batch(valid_expressions, 0.006))
asyncio.create_task(self.update_expressions_count_batch(valid_expressions, 0.006)) # noqa: RUF006
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
return valid_expressions

View File

@@ -61,7 +61,7 @@ class ExpressorModel:
if cid not in self.nb.token_counts:
self.nb.token_counts[cid] = defaultdict(float)
def predict(self, text: str, k: int = None) -> tuple[str | None, dict[str, float]]:
def predict(self, text: str, k: int | None = None) -> tuple[str | None, dict[str, float]]:
"""
直接对所有候选进行朴素贝叶斯评分

View File

@@ -10,7 +10,7 @@ logger = get_logger("expressor.tokenizer")
class Tokenizer:
"""文本分词器支持中文Jieba分词"""
def __init__(self, stopwords: set = None, use_jieba: bool = True):
def __init__(self, stopwords: set | None = None, use_jieba: bool = True):
"""
Args:
stopwords: 停用词集合
@@ -21,7 +21,7 @@ class Tokenizer:
if use_jieba:
try:
import rjieba
import rjieba # noqa: F401
# rjieba 会自动初始化,无需手动调用
logger.info("RJieba分词器初始化成功")

View File

@@ -55,12 +55,12 @@ class SituationExtractor:
) -> list[str]:
"""
从聊天历史中提取情境
Args:
chat_history: 聊天历史(列表或字符串)
target_message: 目标消息(可选)
max_situations: 最多提取的情境数量
Returns:
情境描述列表
"""
@@ -115,11 +115,11 @@ class SituationExtractor:
def _parse_situations(response: str, max_situations: int) -> list[str]:
"""
解析 LLM 返回的情境描述
Args:
response: LLM 响应
max_situations: 最多返回的情境数量
Returns:
情境描述列表
"""

View File

@@ -391,7 +391,7 @@ class StyleLearnerManager:
是否全部保存成功
"""
success = True
for chat_id, learner in self.learners.items():
for learner in self.learners.values():
if not learner.save(self.model_save_path):
success = False

View File

@@ -306,10 +306,8 @@ class EmbeddingStore:
def save_to_file(self) -> None:
"""保存到文件"""
data = []
logger.info(f"正在保存{self.namespace}嵌入库到文件{self.embedding_file_path}")
for item in self.store.values():
data.append(item.to_dict())
data = [item.to_dict() for item in self.store.values()]
data_frame = pd.DataFrame(data)
if not os.path.exists(self.dir):

View File

@@ -15,15 +15,14 @@ def dyn_select_top_k(
# 归一化
max_score = sorted_score[0][1]
min_score = sorted_score[-1][1]
normalized_score = []
for score_item in sorted_score:
normalized_score.append(
(
score_item[0],
score_item[1],
(score_item[1] - min_score) / (max_score - min_score),
)
normalized_score = [
(
score_item[0],
score_item[1],
(score_item[1] - min_score) / (max_score - min_score),
)
for score_item in sorted_score
]
# 寻找跳变点score变化最大的位置
jump_idx = 0

View File

@@ -468,10 +468,10 @@ class HippocampusSampler:
merged_groups.append(current_group)
# 过滤掉只有一条消息的组(除非内容较长)
result_groups = []
for group in merged_groups:
if len(group) > 1 or any(len(msg.get("processed_plain_text", "")) > 100 for msg in group):
result_groups.append(group)
result_groups = [
group for group in merged_groups
if len(group) > 1 or any(len(msg.get("processed_plain_text", "")) > 100 for msg in group)
]
return result_groups

View File

@@ -634,9 +634,7 @@ class MemoryBuilder:
if cleaned:
participants.append(cleaned)
elif isinstance(value, str):
for part in self._split_subject_string(value):
if part:
participants.append(part)
participants.extend(part for part in self._split_subject_string(value) if part)
fallback = self._resolve_user_display(context, user_id)
if fallback:

View File

@@ -1265,9 +1265,7 @@ class MemorySystem:
)
if relevant_memories:
memory_contexts = []
for memory in relevant_memories:
memory_contexts.append(f"[历史记忆] {memory.text_content}")
memory_contexts = [f"[历史记忆] {memory.text_content}" for memory in relevant_memories]
memory_transcript = "\n".join(memory_contexts)
cleaned_fallback = (fallback_text or "").strip()

View File

@@ -122,8 +122,7 @@ class MessageCollectionStorage:
collections = []
if results and results.get("ids") and results["ids"][0]:
for metadata in results["metadatas"][0]:
collections.append(MessageCollection.from_dict(metadata))
collections.extend(MessageCollection.from_dict(metadata) for metadata in results["metadatas"][0])
return collections
except Exception as e:

View File

@@ -115,7 +115,7 @@ class StreamLoopManager:
if not force and context.stream_loop_task and not context.stream_loop_task.done():
logger.debug(f"{stream_id} 循环已在运行")
return True
# 如果是强制启动且任务仍在运行,先取消旧任务
if force and context.stream_loop_task and not context.stream_loop_task.done():
logger.info(f"强制启动模式:先取消现有流循环任务: {stream_id}")
@@ -438,7 +438,7 @@ class StreamLoopManager:
async def _update_stream_energy(self, stream_id: str, context: Any) -> None:
"""更新流的能量值
Args:
stream_id: 流ID
context: 流上下文 (StreamContext)

View File

@@ -161,7 +161,7 @@ class GlobalNoticeManager:
self._cleanup_expired_notices()
# 收集可访问的notice
for storage_key, notices in self._notices.items():
for notices in self._notices.values():
for notice in notices:
if notice.is_expired():
continue

View File

@@ -355,7 +355,7 @@ class MessageManager:
try:
stream_loop_task.cancel()
logger.info(f"已发送取消信号到流循环任务: {chat_stream.stream_id}")
# 等待任务真正结束(设置超时避免死锁)
try:
await asyncio.wait_for(stream_loop_task, timeout=2.0)
@@ -625,7 +625,7 @@ class MessageManager:
def _determine_notice_scope(self, message: DatabaseMessages, stream_id: str) -> NoticeScope:
"""确定notice的作用域
作用域完全由 additional_config 中的 is_public_notice 字段决定:
- is_public_notice=True: 公共notice,所有聊天流可见
- is_public_notice=False 或未设置: 特定聊天流notice

View File

@@ -125,7 +125,7 @@ class ChatStream:
async def set_context(self, message: DatabaseMessages):
"""设置聊天消息上下文
Args:
message: DatabaseMessages 对象,直接使用不需要转换
"""

View File

@@ -22,17 +22,17 @@ logger = get_logger("message_processor")
async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages:
"""从适配器消息字典处理并生成 DatabaseMessages
这个函数整合了原 MessageRecv 的所有处理逻辑:
1. 解析 message_segment 并异步处理内容(图片、语音、视频等)
2. 提取所有消息元数据
3. 直接构造 DatabaseMessages 对象
Args:
message_dict: MessageCQ序列化后的字典
stream_id: 聊天流ID
platform: 平台标识
Returns:
DatabaseMessages: 处理完成的数据库消息对象
"""
@@ -98,7 +98,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
mentioned_value = processing_state.get("is_mentioned")
if isinstance(mentioned_value, bool):
is_mentioned = mentioned_value
elif isinstance(mentioned_value, (int, float)):
elif isinstance(mentioned_value, int | float):
is_mentioned = mentioned_value != 0
db_message = DatabaseMessages(
@@ -151,12 +151,12 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
state: 处理状态字典(用于记录消息类型标记)
message_info: 消息基础信息(用于某些处理逻辑)
Returns:
str: 处理后的文本
"""
@@ -175,12 +175,12 @@ async def _process_message_segments(segment: Seg, state: dict, message_info: Bas
async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
"""处理单个消息段
Args:
segment: 消息段
state: 处理状态字典
message_info: 消息基础信息
Returns:
str: 处理后的文本
"""
@@ -337,13 +337,13 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None:
"""准备 additional_config包含 format_info 和 notice 信息
Args:
message_info: 消息基础信息
is_notify: 是否为notice消息
is_public_notice: 是否为公共notice
notice_type: notice类型
Returns:
str | None: JSON 字符串格式的 additional_config,如果为空则返回 None
"""
@@ -387,10 +387,10 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
def _extract_reply_from_segment(segment: Seg) -> str | None:
"""从消息段中提取reply_to信息
Args:
segment: 消息段
Returns:
str | None: 回复的消息ID,如果没有则返回None
"""
@@ -416,10 +416,10 @@ def _extract_reply_from_segment(segment: Seg) -> str | None:
def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo:
"""从 DatabaseMessages 重建 BaseMessageInfo用于需要 message_info 的遗留代码)
Args:
db_message: DatabaseMessages 对象
Returns:
BaseMessageInfo: 重建的消息信息对象
"""
@@ -466,7 +466,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
def set_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, value: Any) -> None:
"""安全地为 DatabaseMessages 设置运行时属性
Args:
db_message: DatabaseMessages 对象
attr_name: 属性名
@@ -477,12 +477,12 @@ def set_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, va
def get_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, default: Any = None) -> Any:
"""安全地获取 DatabaseMessages 的运行时属性
Args:
db_message: DatabaseMessages 对象
attr_name: 属性名
default: 默认值
Returns:
属性值或默认值
"""

View File

@@ -275,8 +275,8 @@ class MessageStorage:
except Exception as e:
logger.error(f"更新消息ID失败: {e}")
logger.error(
f"消息信息: message_id={getattr(message.message_info, 'message_id', 'N/A')}, "
f"segment_type={getattr(message.message_segment, 'type', 'N/A')}"
f"消息信息: message_id={message_data.get('message_info', {}).get('message_id', 'N/A')}, "
f"segment_type={message_data.get('message_segment', {}).get('type', 'N/A')}"
)
@staticmethod

View File

@@ -47,7 +47,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
logger.error(f"[事件触发] 异步触发事件失败: {e}", exc_info=True)
# 创建异步任务,不等待完成
asyncio.create_task(trigger_event_async())
asyncio.create_task(trigger_event_async()) # noqa: RUF006
logger.info("[发送完成] AFTER_SEND 事件已提交到异步任务")
except Exception as event_error:
logger.error(f"触发 AFTER_SEND 事件时出错: {event_error}", exc_info=True)

View File

@@ -204,7 +204,7 @@ class ChatterActionManager:
action_prompt_display=reason,
)
else:
asyncio.create_task(
asyncio.create_task( # noqa: RUF006
database_api.store_action_info(
chat_stream=chat_stream,
action_build_into_prompt=False,
@@ -217,7 +217,7 @@ class ChatterActionManager:
)
# 自动清空所有未读消息
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "no_reply"))
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "no_reply")) # noqa: RUF006
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
@@ -235,14 +235,14 @@ class ChatterActionManager:
# 记录执行的动作到目标消息
if success:
asyncio.create_task(
asyncio.create_task( # noqa: RUF006
self._record_action_to_message(chat_stream, action_name, target_message, action_data)
)
# 自动清空所有未读消息
if clear_unread_messages:
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, action_name))
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, action_name)) # noqa: RUF006
# 重置打断计数
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id))
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id)) # noqa: RUF006
return {
"action_type": action_name,
@@ -289,13 +289,13 @@ class ChatterActionManager:
)
# 记录回复动作到目标消息
asyncio.create_task(self._record_action_to_message(chat_stream, "reply", target_message, action_data))
asyncio.create_task(self._record_action_to_message(chat_stream, "reply", target_message, action_data)) # noqa: RUF006
if clear_unread_messages:
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "reply"))
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "reply")) # noqa: RUF006
# 回复成功,重置打断计数
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id))
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id)) # noqa: RUF006
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}

View File

@@ -196,7 +196,7 @@ class ActionModifier:
) -> list[tuple[str, str]]:
"""
根据激活类型过滤,返回需要停用的动作列表及原因
新的实现:调用每个 Action 类的 go_activate 方法来判断是否激活
Args:
@@ -271,8 +271,7 @@ class ActionModifier:
except Exception as e:
logger.error(f"{self.log_prefix}并行激活判断失败: {e}")
# 如果并行执行失败,为所有任务默认不激活
for action_name in task_action_names:
deactivated_actions.append((action_name, f"并行判断失败: {e}"))
deactivated_actions.extend((action_name, f"并行判断失败: {e}") for action_name in task_action_names)
return deactivated_actions

View File

@@ -501,9 +501,7 @@ class Prompt:
context_data.update(result)
# 合并预构建的参数,这会覆盖任何同名的实时构建结果
for key, value in pre_built_params.items():
if value:
context_data[key] = value
context_data.update({key: value for key, value in pre_built_params.items() if value})
except asyncio.TimeoutError:
# 这是一个不太可能发生的、总体的构建超时,作为最后的保障

View File

@@ -18,7 +18,7 @@ def get_voice_key(base64_content: str) -> str:
def register_self_voice(base64_content: str, text: str):
"""
为机器人自己发送的语音消息注册其原始文本。
Args:
base64_content (str): 语音的base64编码内容。
text (str): 原始文本。
@@ -30,10 +30,10 @@ def consume_self_voice_text(base64_content: str) -> str | None:
"""
获取并移除机器人自己发送的语音消息的原始文本。
这是一个一次性操作,获取后即从缓存中删除。
Args:
base64_content (str): 语音的base64编码内容。
Returns:
str | None: 如果找到则返回原始文本否则返回None。
"""

View File

@@ -234,7 +234,7 @@ class StatisticOutputTask(AsyncTask):
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
# 创建后台任务,立即返回
asyncio.create_task(_async_collect_and_output())
asyncio.create_task(_async_collect_and_output()) # noqa: RUF006
# -- 以下为统计数据收集方法 --

View File

@@ -44,10 +44,10 @@ def db_message_to_str(message_dict: dict) -> str:
def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
"""检查消息是否提到了机器人
Args:
message: DatabaseMessages 消息对象
Returns:
tuple[bool, float]: (是否提及, 提及概率)
"""