style: 统一代码风格并采用现代化类型注解

对整个代码库进行了一次全面的代码风格清理和现代化改造,主要包括:

- 移除了所有文件中多余的行尾空格。
- 将类型提示更新为 PEP 585 和 PEP 604 引入的现代语法(例如,使用 `list` 代替 `List`,使用 `|` 代替 `Optional`)。
- 清理了多个模块中未被使用的导入语句。
- 移除了不含插值变量的冗余 f-string。
- 调整了部分 `__init__.py` 文件中的 `__all__` 导出顺序,以保持一致性。

这些改动旨在提升代码的可读性和可维护性,使其与现代 Python 最佳实践保持一致,但未修改任何核心逻辑。
This commit is contained in:
minecraft1024a
2025-11-12 12:49:40 +08:00
parent daf8ea7e6a
commit 0e1e9935b2
33 changed files with 227 additions and 229 deletions

View File

@@ -161,16 +161,16 @@ class EmbeddingStore:
# 限制 chunk_size 和 max_workers 在合理范围内
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
semaphore = asyncio.Semaphore(max_workers)
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
results = {}
# 将字符串列表分成多个 chunk
chunks = []
for i in range(0, len(strs), chunk_size):
chunks.append(strs[i : i + chunk_size])
async def _process_chunk(chunk: list[str]):
"""处理一个 chunk 的字符串(批量获取 embedding"""
async with semaphore:
@@ -180,12 +180,12 @@ class EmbeddingStore:
embedding = await EmbeddingStore._get_embedding_async(llm, s)
embeddings.append(embedding)
results[s] = embedding
if progress_callback:
progress_callback(len(chunk))
return embeddings
# 并发处理所有 chunks
tasks = [_process_chunk(chunk) for chunk in chunks]
await asyncio.gather(*tasks)
@@ -418,22 +418,22 @@ class EmbeddingStore:
# 🔧 修复:检查所有 embedding 的维度是否一致
dimensions = [len(emb) for emb in array]
unique_dims = set(dimensions)
if len(unique_dims) > 1:
logger.error(f"检测到不一致的 embedding 维度: {unique_dims}")
logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}")
# 获取期望的维度(使用最常见的维度)
from collections import Counter
dim_counter = Counter(dimensions)
expected_dim = dim_counter.most_common(1)[0][0]
logger.warning(f"将使用最常见的维度: {expected_dim}")
# 过滤掉维度不匹配的 embedding
filtered_array = []
filtered_idx2hash = {}
skipped_count = 0
for i, emb in enumerate(array):
if len(emb) == expected_dim:
filtered_array.append(emb)
@@ -442,11 +442,11 @@ class EmbeddingStore:
skipped_count += 1
hash_key = self.idx2hash[str(i)]
logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}")
logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding")
array = filtered_array
self.idx2hash = filtered_idx2hash
if not array:
logger.error("过滤后没有可用的 embedding无法构建索引")
embedding_dim = expected_dim

View File

@@ -13,4 +13,4 @@ __all__ = [
"StreamLoopManager",
"message_manager",
"stream_loop_manager",
]
]

View File

@@ -82,7 +82,7 @@ class SingleStreamContextManager:
self.total_messages += 1
self.last_access_time = time.time()
# 如果使用了缓存系统,输出调试信息
if cache_enabled and self.context.is_cache_enabled:
if self.context.is_chatter_processing:

View File

@@ -111,9 +111,9 @@ class StreamLoopManager:
# 获取或创建该流的启动锁
if stream_id not in self._stream_start_locks:
self._stream_start_locks[stream_id] = asyncio.Lock()
lock = self._stream_start_locks[stream_id]
# 使用锁防止并发启动同一个流的多个循环任务
async with lock:
# 获取流上下文
@@ -148,7 +148,7 @@ class StreamLoopManager:
# 紧急取消
context.stream_loop_task.cancel()
await asyncio.sleep(0.1)
loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
# 将任务记录到 StreamContext 中
@@ -249,7 +249,7 @@ class StreamLoopManager:
self.stats["total_process_cycles"] += 1
if success:
logger.info(f"✅ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理成功")
# 🔒 处理成功后,等待一小段时间确保清理操作完成
# 这样可以避免在 chatter_manager 清除未读消息之前就进入下一轮循环
await asyncio.sleep(0.1)
@@ -379,7 +379,7 @@ class StreamLoopManager:
self.chatter_manager.process_stream_context(stream_id, context),
name=f"chatter_process_{stream_id}"
)
# 等待 chatter 任务完成
results = await chatter_task
success = results.get("success", False)
@@ -395,8 +395,8 @@ class StreamLoopManager:
else:
logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}")
return success
except asyncio.CancelledError:
return success
except asyncio.CancelledError:
if chatter_task and not chatter_task.done():
chatter_task.cancel()
raise
@@ -706,4 +706,4 @@ class StreamLoopManager:
# 全局流循环管理器实例
stream_loop_manager = StreamLoopManager()
stream_loop_manager = StreamLoopManager()

View File

@@ -417,7 +417,7 @@ class MessageManager:
return
# 记录详细信息
msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}"
msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}"
for msg in unread_messages[:3]] # 只显示前3条
logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 开始清除 {len(unread_messages)} 条未读消息, 示例: {msg_previews}")
@@ -446,15 +446,15 @@ class MessageManager:
context = chat_stream.context_manager.context
if hasattr(context, "unread_messages") and context.unread_messages:
unread_count = len(context.unread_messages)
# 如果还有未读消息,说明 action_manager 可能遗漏了,标记它们
if unread_count > 0:
if unread_count > 0:
# 获取所有未读消息的 ID
message_ids = [msg.message_id for msg in context.unread_messages]
# 标记为已读(会移到历史消息)
success = chat_stream.context_manager.mark_messages_as_read(message_ids)
if success:
logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读")
else:
@@ -481,7 +481,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
chat_stream.context_manager.context.is_chatter_processing = is_processing
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
except Exception as e:
@@ -517,7 +517,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
return chat_stream.context_manager.context.is_chatter_processing
except Exception:
pass
@@ -677,4 +677,4 @@ class MessageManager:
# 创建全局消息管理器实例
message_manager = MessageManager()
message_manager = MessageManager()

View File

@@ -248,16 +248,16 @@ class ChatterActionManager:
try:
# 根据动作类型确定提示词模式
prompt_mode = "s4u" if action_name == "reply" else "normal"
# 将prompt_mode传递给generate_reply
action_data_with_mode = (action_data or {}).copy()
action_data_with_mode["prompt_mode"] = prompt_mode
# 只传递当前正在执行的动作,而不是所有可用动作
# 这样可以让LLM明确知道"已决定执行X动作",而不是"有这些动作可用"
current_action_info = self._using_actions.get(action_name)
current_actions: dict[str, Any] = {action_name: current_action_info} if current_action_info else {}
# 附加目标消息信息(如果存在)
if target_message:
# 提取目标消息的关键信息
@@ -268,7 +268,7 @@ class ChatterActionManager:
"time": getattr(target_message, "time", 0),
}
current_actions["_target_message"] = target_msg_info
success, response_set, _ = await generator_api.generate_reply(
chat_stream=chat_stream,
reply_message=target_message,
@@ -295,12 +295,12 @@ class ChatterActionManager:
should_quote_reply = None
if action_data and isinstance(action_data, dict):
should_quote_reply = action_data.get("should_quote_reply", None)
# respond动作默认不引用回复保持对话流畅
if action_name == "respond" and should_quote_reply is None:
should_quote_reply = False
async def _after_reply():
async def _after_reply():
# 发送并存储回复
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
chat_stream,

View File

@@ -365,7 +365,7 @@ class DefaultReplyer:
# 确保类型安全
if isinstance(mode, str):
prompt_mode_value = mode
# 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_reply_context(
@@ -1171,16 +1171,16 @@ class DefaultReplyer:
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream_obj = await chat_manager.get_stream(chat_id)
if chat_stream_obj:
unread_messages = chat_stream_obj.context_manager.get_unread_messages()
if unread_messages:
# 使用最后一条未读消息作为参考
last_msg = unread_messages[-1]
platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform
user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else ""
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else ""
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else ""
platform = last_msg.chat_info.platform if hasattr(last_msg, "chat_info") else chat_stream.platform
user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else ""
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else ""
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else ""
processed_plain_text = last_msg.processed_plain_text or ""
else:
# 没有未读消息,使用默认值
@@ -1263,19 +1263,19 @@ class DefaultReplyer:
if available_actions:
# 过滤掉特殊键以_开头
action_items = {k: v for k, v in available_actions.items() if not k.startswith("_")}
# 提取目标消息信息(如果存在)
target_msg_info = available_actions.get("_target_message") # type: ignore
if action_items:
if len(action_items) == 1:
# 单个动作
action_name, action_info = list(action_items.items())[0]
action_desc = action_info.description
# 构建基础决策信息
action_descriptions = f"## 决策信息\n\n你已经决定要执行 **{action_name}** 动作({action_desc})。\n\n"
# 只有需要目标消息的动作才显示目标消息详情
# respond 动作是统一回应所有未读消息,不应该显示特定目标消息
if action_name not in ["respond"] and target_msg_info and isinstance(target_msg_info, dict):
@@ -1284,7 +1284,7 @@ class DefaultReplyer:
content = target_msg_info.get("content", "")
msg_time = target_msg_info.get("time", 0)
time_str = time_module.strftime("%H:%M:%S", time_module.localtime(msg_time)) if msg_time else "未知时间"
action_descriptions += f"**目标消息**: {time_str} {sender} 说: {content}\n\n"
else:
# 多个动作
@@ -2137,7 +2137,7 @@ class DefaultReplyer:
except Exception as e:
logger.error(f"存储聊天记忆失败: {e}")
def weighted_sample_no_replacement(items, weights, k) -> list:
"""

View File

@@ -5,12 +5,12 @@
插件可以通过实现这些接口来扩展安全功能。
"""
from .interfaces import SecurityCheckResult, SecurityChecker
from .interfaces import SecurityChecker, SecurityCheckResult
from .manager import SecurityManager, get_security_manager
__all__ = [
"SecurityChecker",
"SecurityCheckResult",
"SecurityChecker",
"SecurityManager",
"get_security_manager",
]

View File

@@ -10,7 +10,7 @@ from typing import Any
from src.common.logger import get_logger
from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel
from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel
logger = get_logger("security.manager")