re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -3,19 +3,20 @@
将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类
"""
import re
import asyncio
import time
import contextvars
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Literal, Tuple
import re
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Literal, Optional
from rich.traceback import install
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import build_readable_messages
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
install(extra_lines=3)
@@ -50,11 +51,11 @@ class PromptParameters:
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_target_info: dict[str, Any] | None = None
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None
target_user_info: dict[str, Any] | None = None
# 已构建的内容块
expression_habits_block: str = ""
@@ -77,12 +78,12 @@ class PromptParameters:
action_descriptions: str = ""
# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
available_actions: dict[str, Any] | None = None
# 动态生成的聊天场景提示
chat_scene: str = ""
def validate(self) -> List[str]:
def validate(self) -> list[str]:
"""参数验证"""
errors = []
if not self.chat_id:
@@ -98,22 +99,22 @@ class PromptContext:
"""提示词上下文管理器"""
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._context_prompts: dict[str, dict[str, "Prompt"]] = {}
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock()
@property
def _current_context(self) -> Optional[str]:
def _current_context(self) -> str | None:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
def _current_context(self, value: str | None):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
async def async_scope(self, context_id: str | None = None):
"""创建一个异步的临时提示模板作用域"""
if context_id is not None:
try:
@@ -159,7 +160,7 @@ class PromptContext:
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
async def register_async(self, prompt: "Prompt", context_id: str | None = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
if target_context := context_id or self._current_context:
@@ -177,7 +178,7 @@ class PromptManager:
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: Optional[str] = None):
async def async_message_scope(self, message_id: str | None = None):
"""为消息处理创建异步临时作用域"""
async with self._context.async_scope(message_id):
yield self
@@ -236,8 +237,8 @@ class Prompt:
def __init__(
self,
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
name: str | None = None,
parameters: PromptParameters | None = None,
should_register: bool = True,
):
"""
@@ -277,7 +278,7 @@ class Prompt:
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def _parse_template_args(self, template: str) -> List[str]:
def _parse_template_args(self, template: str) -> list[str]:
"""解析模板参数"""
template_args = []
processed_template = self._process_escaped_braces(template)
@@ -321,7 +322,7 @@ class Prompt:
logger.error(f"构建Prompt失败: {e}")
raise RuntimeError(f"构建Prompt失败: {e}") from e
async def _build_context_data(self) -> Dict[str, Any]:
async def _build_context_data(self) -> dict[str, Any]:
"""构建智能上下文数据"""
# 并行执行所有构建任务
start_time = time.time()
@@ -401,7 +402,7 @@ class Prompt:
default_result = self._get_default_result_for_task(task_name)
results.append(default_result)
except Exception as e:
logger.error(f"构建任务{task_name}失败: {str(e)}")
logger.error(f"构建任务{task_name}失败: {e!s}")
default_result = self._get_default_result_for_task(task_name)
results.append(default_result)
@@ -411,7 +412,7 @@ class Prompt:
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
if isinstance(result, Exception):
logger.error(f"构建任务{task_name}失败: {str(result)}")
logger.error(f"构建任务{task_name}失败: {result!s}")
elif isinstance(result, dict):
context_data.update(result)
@@ -453,7 +454,7 @@ class Prompt:
return context_data
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
async def _build_s4u_chat_context(self, context_data: dict[str, Any]) -> None:
"""构建S4U模式的聊天上下文"""
if not self.parameters.message_list_before_now_long:
return
@@ -468,7 +469,7 @@ class Prompt:
context_data["read_history_prompt"] = read_history_prompt
context_data["unread_history_prompt"] = unread_history_prompt
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
async def _build_normal_chat_context(self, context_data: dict[str, Any]) -> None:
"""构建normal模式的聊天上下文"""
if not self.parameters.chat_talking_prompt_short:
return
@@ -477,8 +478,8 @@ class Prompt:
{self.parameters.chat_talking_prompt_short}"""
async def _build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
) -> Tuple[str, str]:
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str
) -> tuple[str, str]:
"""构建S4U风格的已读/未读历史消息prompt"""
try:
# 动态导入default_generator以避免循环导入
@@ -492,7 +493,7 @@ class Prompt:
except Exception as e:
logger.error(f"构建S4U历史消息prompt失败: {e}")
async def _build_expression_habits(self) -> Dict[str, Any]:
async def _build_expression_habits(self) -> dict[str, Any]:
"""构建表达习惯"""
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id)
if not use_expression:
@@ -533,7 +534,7 @@ class Prompt:
logger.error(f"构建表达习惯失败: {e}")
return {"expression_habits_block": ""}
async def _build_memory_block(self) -> Dict[str, Any]:
async def _build_memory_block(self) -> dict[str, Any]:
"""构建记忆块"""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
@@ -653,7 +654,7 @@ class Prompt:
logger.error(f"构建记忆块失败: {e}")
return {"memory_block": ""}
async def _build_memory_block_fast(self) -> Dict[str, Any]:
async def _build_memory_block_fast(self) -> dict[str, Any]:
"""快速构建记忆块(简化版本,用于未预构建时的后备方案)"""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
@@ -677,7 +678,7 @@ class Prompt:
logger.warning(f"快速构建记忆块失败: {e}")
return {"memory_block": ""}
async def _build_relation_info(self) -> Dict[str, Any]:
async def _build_relation_info(self) -> dict[str, Any]:
"""构建关系信息"""
try:
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
@@ -686,7 +687,7 @@ class Prompt:
logger.error(f"构建关系信息失败: {e}")
return {"relation_info_block": ""}
async def _build_tool_info(self) -> Dict[str, Any]:
async def _build_tool_info(self) -> dict[str, Any]:
"""构建工具信息"""
if not global_config.tool.enable_tool:
return {"tool_info_block": ""}
@@ -734,7 +735,7 @@ class Prompt:
logger.error(f"构建工具信息失败: {e}")
return {"tool_info_block": ""}
async def _build_knowledge_info(self) -> Dict[str, Any]:
async def _build_knowledge_info(self) -> dict[str, Any]:
"""构建知识信息"""
if not global_config.lpmm_knowledge.enable:
return {"knowledge_prompt": ""}
@@ -783,7 +784,7 @@ class Prompt:
logger.error(f"构建知识信息失败: {e}")
return {"knowledge_prompt": ""}
async def _build_cross_context(self) -> Dict[str, Any]:
async def _build_cross_context(self) -> dict[str, Any]:
"""构建跨群上下文"""
try:
cross_context = await Prompt.build_cross_context(
@@ -794,7 +795,7 @@ class Prompt:
logger.error(f"构建跨群上下文失败: {e}")
return {"cross_context_block": ""}
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
async def _format_with_context(self, context_data: dict[str, Any]) -> str:
"""使用上下文数据格式化模板"""
if self.parameters.prompt_mode == "s4u":
params = self._prepare_s4u_params(context_data)
@@ -805,7 +806,7 @@ class Prompt:
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
"""准备S4U模式的参数"""
return {
**context_data,
@@ -834,7 +835,7 @@ class Prompt:
or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
}
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
"""准备Normal模式的参数"""
return {
**context_data,
@@ -862,7 +863,7 @@ class Prompt:
or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
}
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
"""准备默认模式的参数"""
return {
"expression_habits_block": context_data.get("expression_habits_block", ""),
@@ -905,7 +906,7 @@ class Prompt:
result = self._restore_escaped_braces(processed_template)
return result
except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {e!s}") from e
def __str__(self) -> str:
"""返回格式化后的结果或原始模板"""
@@ -922,7 +923,7 @@ class Prompt:
# =============================================================================
@staticmethod
def parse_reply_target(target_message: str) -> Tuple[str, str]:
def parse_reply_target(target_message: str) -> tuple[str, str]:
"""
解析回复目标消息 - 统一实现
@@ -981,7 +982,7 @@ class Prompt:
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]:
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
"""
为超时的任务提供默认结果
@@ -1008,7 +1009,7 @@ class Prompt:
return {}
@staticmethod
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: dict[str, Any] | None) -> str:
"""
构建跨群聊上下文 - 统一实现
@@ -1071,7 +1072,7 @@ class Prompt:
# 工厂函数
def create_prompt(
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
) -> Prompt:
"""快速创建Prompt实例的工厂函数"""
if parameters is None:
@@ -1080,7 +1081,7 @@ def create_prompt(
async def create_prompt_async(
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
) -> Prompt:
"""异步创建Prompt实例"""
prompt = create_prompt(template, name, parameters, **kwargs)