style: 统一代码风格并采用现代化类型注解
对整个代码库进行了一次全面的代码风格清理和现代化改造,主要包括: - 移除了所有文件中多余的行尾空格。 - 将类型提示更新为 PEP 585 和 PEP 604 引入的现代语法(例如,使用 `list` 代替 `List`,使用 `|` 代替 `Optional`)。 - 清理了多个模块中未被使用的导入语句。 - 移除了不含插值变量的冗余 f-string。 - 调整了部分 `__init__.py` 文件中的 `__all__` 导出顺序,以保持一致性。 这些改动旨在提升代码的可读性和可维护性,使其与现代 Python 最佳实践保持一致,但未修改任何核心逻辑。
This commit is contained in:
committed by
Windpicker-owo
parent
5fa004503c
commit
f44ece0b29
@@ -161,16 +161,16 @@ class EmbeddingStore:
|
||||
# 限制 chunk_size 和 max_workers 在合理范围内
|
||||
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
|
||||
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
|
||||
|
||||
|
||||
semaphore = asyncio.Semaphore(max_workers)
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
results = {}
|
||||
|
||||
|
||||
# 将字符串列表分成多个 chunk
|
||||
chunks = []
|
||||
for i in range(0, len(strs), chunk_size):
|
||||
chunks.append(strs[i : i + chunk_size])
|
||||
|
||||
|
||||
async def _process_chunk(chunk: list[str]):
|
||||
"""处理一个 chunk 的字符串(批量获取 embedding)"""
|
||||
async with semaphore:
|
||||
@@ -180,12 +180,12 @@ class EmbeddingStore:
|
||||
embedding = await EmbeddingStore._get_embedding_async(llm, s)
|
||||
embeddings.append(embedding)
|
||||
results[s] = embedding
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(len(chunk))
|
||||
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
# 并发处理所有 chunks
|
||||
tasks = [_process_chunk(chunk) for chunk in chunks]
|
||||
await asyncio.gather(*tasks)
|
||||
@@ -418,22 +418,22 @@ class EmbeddingStore:
|
||||
# 🔧 修复:检查所有 embedding 的维度是否一致
|
||||
dimensions = [len(emb) for emb in array]
|
||||
unique_dims = set(dimensions)
|
||||
|
||||
|
||||
if len(unique_dims) > 1:
|
||||
logger.error(f"检测到不一致的 embedding 维度: {unique_dims}")
|
||||
logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}")
|
||||
|
||||
|
||||
# 获取期望的维度(使用最常见的维度)
|
||||
from collections import Counter
|
||||
dim_counter = Counter(dimensions)
|
||||
expected_dim = dim_counter.most_common(1)[0][0]
|
||||
logger.warning(f"将使用最常见的维度: {expected_dim}")
|
||||
|
||||
|
||||
# 过滤掉维度不匹配的 embedding
|
||||
filtered_array = []
|
||||
filtered_idx2hash = {}
|
||||
skipped_count = 0
|
||||
|
||||
|
||||
for i, emb in enumerate(array):
|
||||
if len(emb) == expected_dim:
|
||||
filtered_array.append(emb)
|
||||
@@ -442,11 +442,11 @@ class EmbeddingStore:
|
||||
skipped_count += 1
|
||||
hash_key = self.idx2hash[str(i)]
|
||||
logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}")
|
||||
|
||||
|
||||
logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding")
|
||||
array = filtered_array
|
||||
self.idx2hash = filtered_idx2hash
|
||||
|
||||
|
||||
if not array:
|
||||
logger.error("过滤后没有可用的 embedding,无法构建索引")
|
||||
embedding_dim = expected_dim
|
||||
|
||||
@@ -13,4 +13,4 @@ __all__ = [
|
||||
"StreamLoopManager",
|
||||
"message_manager",
|
||||
"stream_loop_manager",
|
||||
]
|
||||
]
|
||||
|
||||
@@ -82,7 +82,7 @@ class SingleStreamContextManager:
|
||||
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
|
||||
|
||||
# 如果使用了缓存系统,输出调试信息
|
||||
if cache_enabled and self.context.is_cache_enabled:
|
||||
if self.context.is_chatter_processing:
|
||||
|
||||
@@ -111,9 +111,9 @@ class StreamLoopManager:
|
||||
# 获取或创建该流的启动锁
|
||||
if stream_id not in self._stream_start_locks:
|
||||
self._stream_start_locks[stream_id] = asyncio.Lock()
|
||||
|
||||
|
||||
lock = self._stream_start_locks[stream_id]
|
||||
|
||||
|
||||
# 使用锁防止并发启动同一个流的多个循环任务
|
||||
async with lock:
|
||||
# 获取流上下文
|
||||
@@ -148,7 +148,7 @@ class StreamLoopManager:
|
||||
# 紧急取消
|
||||
context.stream_loop_task.cancel()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
|
||||
|
||||
# 将任务记录到 StreamContext 中
|
||||
@@ -252,7 +252,7 @@ class StreamLoopManager:
|
||||
self.stats["total_process_cycles"] += 1
|
||||
if success:
|
||||
logger.info(f"✅ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理成功")
|
||||
|
||||
|
||||
# 🔒 处理成功后,等待一小段时间确保清理操作完成
|
||||
# 这样可以避免在 chatter_manager 清除未读消息之前就进入下一轮循环
|
||||
await asyncio.sleep(0.1)
|
||||
@@ -382,7 +382,7 @@ class StreamLoopManager:
|
||||
self.chatter_manager.process_stream_context(stream_id, context),
|
||||
name=f"chatter_process_{stream_id}"
|
||||
)
|
||||
|
||||
|
||||
# 等待 chatter 任务完成
|
||||
results = await chatter_task
|
||||
success = results.get("success", False)
|
||||
@@ -398,8 +398,8 @@ class StreamLoopManager:
|
||||
else:
|
||||
logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}")
|
||||
|
||||
return success
|
||||
except asyncio.CancelledError:
|
||||
return success
|
||||
except asyncio.CancelledError:
|
||||
if chatter_task and not chatter_task.done():
|
||||
chatter_task.cancel()
|
||||
raise
|
||||
@@ -709,4 +709,4 @@ class StreamLoopManager:
|
||||
|
||||
|
||||
# 全局流循环管理器实例
|
||||
stream_loop_manager = StreamLoopManager()
|
||||
stream_loop_manager = StreamLoopManager()
|
||||
|
||||
@@ -417,7 +417,7 @@ class MessageManager:
|
||||
return
|
||||
|
||||
# 记录详细信息
|
||||
msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}"
|
||||
msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}"
|
||||
for msg in unread_messages[:3]] # 只显示前3条
|
||||
logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 开始清除 {len(unread_messages)} 条未读消息, 示例: {msg_previews}")
|
||||
|
||||
@@ -446,15 +446,15 @@ class MessageManager:
|
||||
context = chat_stream.context_manager.context
|
||||
if hasattr(context, "unread_messages") and context.unread_messages:
|
||||
unread_count = len(context.unread_messages)
|
||||
|
||||
|
||||
# 如果还有未读消息,说明 action_manager 可能遗漏了,标记它们
|
||||
if unread_count > 0:
|
||||
if unread_count > 0:
|
||||
# 获取所有未读消息的 ID
|
||||
message_ids = [msg.message_id for msg in context.unread_messages]
|
||||
|
||||
|
||||
# 标记为已读(会移到历史消息)
|
||||
success = chat_stream.context_manager.mark_messages_as_read(message_ids)
|
||||
|
||||
|
||||
if success:
|
||||
logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读")
|
||||
else:
|
||||
@@ -481,7 +481,7 @@ class MessageManager:
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(stream_id)
|
||||
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
|
||||
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
|
||||
chat_stream.context_manager.context.is_chatter_processing = is_processing
|
||||
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
|
||||
except Exception as e:
|
||||
@@ -517,7 +517,7 @@ class MessageManager:
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(stream_id)
|
||||
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
|
||||
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
|
||||
return chat_stream.context_manager.context.is_chatter_processing
|
||||
except Exception:
|
||||
pass
|
||||
@@ -677,4 +677,4 @@ class MessageManager:
|
||||
|
||||
|
||||
# 创建全局消息管理器实例
|
||||
message_manager = MessageManager()
|
||||
message_manager = MessageManager()
|
||||
|
||||
@@ -248,16 +248,16 @@ class ChatterActionManager:
|
||||
try:
|
||||
# 根据动作类型确定提示词模式
|
||||
prompt_mode = "s4u" if action_name == "reply" else "normal"
|
||||
|
||||
|
||||
# 将prompt_mode传递给generate_reply
|
||||
action_data_with_mode = (action_data or {}).copy()
|
||||
action_data_with_mode["prompt_mode"] = prompt_mode
|
||||
|
||||
|
||||
# 只传递当前正在执行的动作,而不是所有可用动作
|
||||
# 这样可以让LLM明确知道"已决定执行X动作",而不是"有这些动作可用"
|
||||
current_action_info = self._using_actions.get(action_name)
|
||||
current_actions: dict[str, Any] = {action_name: current_action_info} if current_action_info else {}
|
||||
|
||||
|
||||
# 附加目标消息信息(如果存在)
|
||||
if target_message:
|
||||
# 提取目标消息的关键信息
|
||||
@@ -268,7 +268,7 @@ class ChatterActionManager:
|
||||
"time": getattr(target_message, "time", 0),
|
||||
}
|
||||
current_actions["_target_message"] = target_msg_info
|
||||
|
||||
|
||||
success, response_set, _ = await generator_api.generate_reply(
|
||||
chat_stream=chat_stream,
|
||||
reply_message=target_message,
|
||||
@@ -295,12 +295,12 @@ class ChatterActionManager:
|
||||
should_quote_reply = None
|
||||
if action_data and isinstance(action_data, dict):
|
||||
should_quote_reply = action_data.get("should_quote_reply", None)
|
||||
|
||||
|
||||
# respond动作默认不引用回复,保持对话流畅
|
||||
if action_name == "respond" and should_quote_reply is None:
|
||||
should_quote_reply = False
|
||||
|
||||
async def _after_reply():
|
||||
async def _after_reply():
|
||||
# 发送并存储回复
|
||||
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
||||
chat_stream,
|
||||
|
||||
@@ -372,7 +372,7 @@ class DefaultReplyer:
|
||||
# 确保类型安全
|
||||
if isinstance(mode, str):
|
||||
prompt_mode_value = mode
|
||||
|
||||
|
||||
# 构建 Prompt
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt = await self.build_prompt_reply_context(
|
||||
@@ -1166,16 +1166,16 @@ class DefaultReplyer:
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream_obj = await chat_manager.get_stream(chat_id)
|
||||
|
||||
|
||||
if chat_stream_obj:
|
||||
unread_messages = chat_stream_obj.context_manager.get_unread_messages()
|
||||
if unread_messages:
|
||||
# 使用最后一条未读消息作为参考
|
||||
last_msg = unread_messages[-1]
|
||||
platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform
|
||||
user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else ""
|
||||
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else ""
|
||||
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else ""
|
||||
platform = last_msg.chat_info.platform if hasattr(last_msg, "chat_info") else chat_stream.platform
|
||||
user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else ""
|
||||
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else ""
|
||||
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else ""
|
||||
processed_plain_text = last_msg.processed_plain_text or ""
|
||||
else:
|
||||
# 没有未读消息,使用默认值
|
||||
@@ -1258,19 +1258,19 @@ class DefaultReplyer:
|
||||
if available_actions:
|
||||
# 过滤掉特殊键(以_开头)
|
||||
action_items = {k: v for k, v in available_actions.items() if not k.startswith("_")}
|
||||
|
||||
|
||||
# 提取目标消息信息(如果存在)
|
||||
target_msg_info = available_actions.get("_target_message") # type: ignore
|
||||
|
||||
|
||||
if action_items:
|
||||
if len(action_items) == 1:
|
||||
# 单个动作
|
||||
action_name, action_info = list(action_items.items())[0]
|
||||
action_desc = action_info.description
|
||||
|
||||
|
||||
# 构建基础决策信息
|
||||
action_descriptions = f"## 决策信息\n\n你已经决定要执行 **{action_name}** 动作({action_desc})。\n\n"
|
||||
|
||||
|
||||
# 只有需要目标消息的动作才显示目标消息详情
|
||||
# respond 动作是统一回应所有未读消息,不应该显示特定目标消息
|
||||
if action_name not in ["respond"] and target_msg_info and isinstance(target_msg_info, dict):
|
||||
@@ -1279,7 +1279,7 @@ class DefaultReplyer:
|
||||
content = target_msg_info.get("content", "")
|
||||
msg_time = target_msg_info.get("time", 0)
|
||||
time_str = time_module.strftime("%H:%M:%S", time_module.localtime(msg_time)) if msg_time else "未知时间"
|
||||
|
||||
|
||||
action_descriptions += f"**目标消息**: {time_str} {sender} 说: {content}\n\n"
|
||||
else:
|
||||
# 多个动作
|
||||
@@ -2166,7 +2166,7 @@ class DefaultReplyer:
|
||||
except Exception as e:
|
||||
logger.error(f"存储聊天记忆失败: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
"""
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
插件可以通过实现这些接口来扩展安全功能。
|
||||
"""
|
||||
|
||||
from .interfaces import SecurityCheckResult, SecurityChecker
|
||||
from .interfaces import SecurityChecker, SecurityCheckResult
|
||||
from .manager import SecurityManager, get_security_manager
|
||||
|
||||
__all__ = [
|
||||
"SecurityChecker",
|
||||
"SecurityCheckResult",
|
||||
"SecurityChecker",
|
||||
"SecurityManager",
|
||||
"get_security_manager",
|
||||
]
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel
|
||||
from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel
|
||||
|
||||
logger = get_logger("security.manager")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user