re-style: 格式化代码
This commit is contained in:
@@ -3,19 +3,20 @@
|
||||
将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类
|
||||
"""
|
||||
|
||||
import re
|
||||
import asyncio
|
||||
import time
|
||||
import contextvars
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List, Literal, Tuple
|
||||
import re
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -50,11 +51,11 @@ class PromptParameters:
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: Optional[Dict[str, Any]] = None
|
||||
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: Optional[Dict[str, Any]] = None
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
@@ -77,12 +78,12 @@ class PromptParameters:
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: Optional[Dict[str, Any]] = None
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
@@ -98,22 +99,22 @@ class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
self._context_prompts: dict[str, dict[str, "Prompt"]] = {}
|
||||
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||
self._context_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def _current_context(self) -> Optional[str]:
|
||||
def _current_context(self) -> str | None:
|
||||
"""获取当前协程的上下文ID"""
|
||||
return self._current_context_var.get()
|
||||
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
def _current_context(self, value: str | None):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
async def async_scope(self, context_id: str | None = None):
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
if context_id is not None:
|
||||
try:
|
||||
@@ -159,7 +160,7 @@ class PromptContext:
|
||||
return self._context_prompts[current_context][name]
|
||||
return None
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
async def register_async(self, prompt: "Prompt", context_id: str | None = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
if target_context := context_id or self._current_context:
|
||||
@@ -177,7 +178,7 @@ class PromptManager:
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||
async def async_message_scope(self, message_id: str | None = None):
|
||||
"""为消息处理创建异步临时作用域"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
@@ -236,8 +237,8 @@ class Prompt:
|
||||
def __init__(
|
||||
self,
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
name: str | None = None,
|
||||
parameters: PromptParameters | None = None,
|
||||
should_register: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -277,7 +278,7 @@ class Prompt:
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
def _parse_template_args(self, template: str) -> List[str]:
|
||||
def _parse_template_args(self, template: str) -> list[str]:
|
||||
"""解析模板参数"""
|
||||
template_args = []
|
||||
processed_template = self._process_escaped_braces(template)
|
||||
@@ -321,7 +322,7 @@ class Prompt:
|
||||
logger.error(f"构建Prompt失败: {e}")
|
||||
raise RuntimeError(f"构建Prompt失败: {e}") from e
|
||||
|
||||
async def _build_context_data(self) -> Dict[str, Any]:
|
||||
async def _build_context_data(self) -> dict[str, Any]:
|
||||
"""构建智能上下文数据"""
|
||||
# 并行执行所有构建任务
|
||||
start_time = time.time()
|
||||
@@ -401,7 +402,7 @@ class Prompt:
|
||||
default_result = self._get_default_result_for_task(task_name)
|
||||
results.append(default_result)
|
||||
except Exception as e:
|
||||
logger.error(f"构建任务{task_name}失败: {str(e)}")
|
||||
logger.error(f"构建任务{task_name}失败: {e!s}")
|
||||
default_result = self._get_default_result_for_task(task_name)
|
||||
results.append(default_result)
|
||||
|
||||
@@ -411,7 +412,7 @@ class Prompt:
|
||||
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
|
||||
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"构建任务{task_name}失败: {str(result)}")
|
||||
logger.error(f"构建任务{task_name}失败: {result!s}")
|
||||
elif isinstance(result, dict):
|
||||
context_data.update(result)
|
||||
|
||||
@@ -453,7 +454,7 @@ class Prompt:
|
||||
|
||||
return context_data
|
||||
|
||||
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
async def _build_s4u_chat_context(self, context_data: dict[str, Any]) -> None:
|
||||
"""构建S4U模式的聊天上下文"""
|
||||
if not self.parameters.message_list_before_now_long:
|
||||
return
|
||||
@@ -468,7 +469,7 @@ class Prompt:
|
||||
context_data["read_history_prompt"] = read_history_prompt
|
||||
context_data["unread_history_prompt"] = unread_history_prompt
|
||||
|
||||
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
async def _build_normal_chat_context(self, context_data: dict[str, Any]) -> None:
|
||||
"""构建normal模式的聊天上下文"""
|
||||
if not self.parameters.chat_talking_prompt_short:
|
||||
return
|
||||
@@ -477,8 +478,8 @@ class Prompt:
|
||||
{self.parameters.chat_talking_prompt_short}"""
|
||||
|
||||
async def _build_s4u_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> Tuple[str, str]:
|
||||
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""构建S4U风格的已读/未读历史消息prompt"""
|
||||
try:
|
||||
# 动态导入default_generator以避免循环导入
|
||||
@@ -492,7 +493,7 @@ class Prompt:
|
||||
except Exception as e:
|
||||
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||
|
||||
async def _build_expression_habits(self) -> Dict[str, Any]:
|
||||
async def _build_expression_habits(self) -> dict[str, Any]:
|
||||
"""构建表达习惯"""
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id)
|
||||
if not use_expression:
|
||||
@@ -533,7 +534,7 @@ class Prompt:
|
||||
logger.error(f"构建表达习惯失败: {e}")
|
||||
return {"expression_habits_block": ""}
|
||||
|
||||
async def _build_memory_block(self) -> Dict[str, Any]:
|
||||
async def _build_memory_block(self) -> dict[str, Any]:
|
||||
"""构建记忆块"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return {"memory_block": ""}
|
||||
@@ -653,7 +654,7 @@ class Prompt:
|
||||
logger.error(f"构建记忆块失败: {e}")
|
||||
return {"memory_block": ""}
|
||||
|
||||
async def _build_memory_block_fast(self) -> Dict[str, Any]:
|
||||
async def _build_memory_block_fast(self) -> dict[str, Any]:
|
||||
"""快速构建记忆块(简化版本,用于未预构建时的后备方案)"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return {"memory_block": ""}
|
||||
@@ -677,7 +678,7 @@ class Prompt:
|
||||
logger.warning(f"快速构建记忆块失败: {e}")
|
||||
return {"memory_block": ""}
|
||||
|
||||
async def _build_relation_info(self) -> Dict[str, Any]:
|
||||
async def _build_relation_info(self) -> dict[str, Any]:
|
||||
"""构建关系信息"""
|
||||
try:
|
||||
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
|
||||
@@ -686,7 +687,7 @@ class Prompt:
|
||||
logger.error(f"构建关系信息失败: {e}")
|
||||
return {"relation_info_block": ""}
|
||||
|
||||
async def _build_tool_info(self) -> Dict[str, Any]:
|
||||
async def _build_tool_info(self) -> dict[str, Any]:
|
||||
"""构建工具信息"""
|
||||
if not global_config.tool.enable_tool:
|
||||
return {"tool_info_block": ""}
|
||||
@@ -734,7 +735,7 @@ class Prompt:
|
||||
logger.error(f"构建工具信息失败: {e}")
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
async def _build_knowledge_info(self) -> Dict[str, Any]:
|
||||
async def _build_knowledge_info(self) -> dict[str, Any]:
|
||||
"""构建知识信息"""
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
return {"knowledge_prompt": ""}
|
||||
@@ -783,7 +784,7 @@ class Prompt:
|
||||
logger.error(f"构建知识信息失败: {e}")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
async def _build_cross_context(self) -> Dict[str, Any]:
|
||||
async def _build_cross_context(self) -> dict[str, Any]:
|
||||
"""构建跨群上下文"""
|
||||
try:
|
||||
cross_context = await Prompt.build_cross_context(
|
||||
@@ -794,7 +795,7 @@ class Prompt:
|
||||
logger.error(f"构建跨群上下文失败: {e}")
|
||||
return {"cross_context_block": ""}
|
||||
|
||||
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
|
||||
async def _format_with_context(self, context_data: dict[str, Any]) -> str:
|
||||
"""使用上下文数据格式化模板"""
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
params = self._prepare_s4u_params(context_data)
|
||||
@@ -805,7 +806,7 @@ class Prompt:
|
||||
|
||||
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
|
||||
|
||||
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备S4U模式的参数"""
|
||||
return {
|
||||
**context_data,
|
||||
@@ -834,7 +835,7 @@ class Prompt:
|
||||
or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||
}
|
||||
|
||||
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备Normal模式的参数"""
|
||||
return {
|
||||
**context_data,
|
||||
@@ -862,7 +863,7 @@ class Prompt:
|
||||
or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||
}
|
||||
|
||||
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备默认模式的参数"""
|
||||
return {
|
||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||
@@ -905,7 +906,7 @@ class Prompt:
|
||||
result = self._restore_escaped_braces(processed_template)
|
||||
return result
|
||||
except (IndexError, KeyError) as e:
|
||||
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
|
||||
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {e!s}") from e
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""返回格式化后的结果或原始模板"""
|
||||
@@ -922,7 +923,7 @@ class Prompt:
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
def parse_reply_target(target_message: str) -> tuple[str, str]:
|
||||
"""
|
||||
解析回复目标消息 - 统一实现
|
||||
|
||||
@@ -981,7 +982,7 @@ class Prompt:
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]:
|
||||
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
|
||||
"""
|
||||
为超时的任务提供默认结果
|
||||
|
||||
@@ -1008,7 +1009,7 @@ class Prompt:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
|
||||
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: dict[str, Any] | None) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现
|
||||
|
||||
@@ -1071,7 +1072,7 @@ class Prompt:
|
||||
|
||||
# 工厂函数
|
||||
def create_prompt(
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""快速创建Prompt实例的工厂函数"""
|
||||
if parameters is None:
|
||||
@@ -1080,7 +1081,7 @@ def create_prompt(
|
||||
|
||||
|
||||
async def create_prompt_async(
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user