🤖 自动格式化代码 [skip ci]

This commit is contained in:
github-actions[bot]
2025-07-01 04:27:28 +00:00
parent a1a81194f1
commit 6dee5a6333
6 changed files with 77 additions and 89 deletions

View File

@@ -754,14 +754,11 @@ class HeartFChatting:
if relation_info: if relation_info:
updated_action_data["relation_info"] = relation_info updated_action_data["relation_info"] = relation_info
if structured_info: if structured_info:
updated_action_data["structured_info"] = structured_info updated_action_data["structured_info"] = structured_info
if all_post_plan_info: if all_post_plan_info:
logger.info( logger.info(f"{self.log_prefix} 后期处理完成,产生了 {len(all_post_plan_info)} 个信息项")
f"{self.log_prefix} 后期处理完成,产生了 {len(all_post_plan_info)} 个信息项"
)
# 输出详细统计信息 # 输出详细统计信息
if post_processor_time_costs: if post_processor_time_costs:

View File

@@ -1,5 +1,3 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.structure_observation import StructureObservation
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -10,7 +8,6 @@ from typing import List, Dict
import difflib import difflib
import json import json
from json_repair import repair_json from json_repair import repair_json
from src.person_info.person_info import get_person_info_manager
logger = get_logger("memory_activator") logger = get_logger("memory_activator")

View File

@@ -1,4 +1,3 @@
from typing import List, Optional, Union
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.message import MessageThinking from src.chat.message_receive.message import MessageThinking
@@ -19,8 +18,8 @@ class NormalChatGenerator:
prob_first = global_config.normal_chat.normal_chat_first_probability prob_first = global_config.normal_chat.normal_chat_first_probability
model_config_1['weight'] = prob_first model_config_1["weight"] = prob_first
model_config_2['weight'] = 1.0 - prob_first model_config_2["weight"] = 1.0 - prob_first
self.model_configs = [model_config_1, model_config_2] self.model_configs = [model_config_1, model_config_2]
@@ -54,7 +53,7 @@ class NormalChatGenerator:
available_actions=available_actions, available_actions=available_actions,
model_configs=self.model_configs, model_configs=self.model_configs,
request_type="normal.replyer", request_type="normal.replyer",
return_prompt=True return_prompt=True,
) )
if not success or not reply_set: if not success or not reply_set:

View File

@@ -31,15 +31,12 @@ logger = get_logger("replyer")
def init_prompt(): def init_prompt():
Prompt("你正在qq群里聊天下面是群里在聊的内容", "chat_target_group1") Prompt("你正在qq群里聊天下面是群里在聊的内容", "chat_target_group1")
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1") Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
Prompt("在群里聊天", "chat_target_group2") Prompt("在群里聊天", "chat_target_group2")
Prompt("{sender_name}私聊", "chat_target_private2") Prompt("{sender_name}私聊", "chat_target_private2")
Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt") Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
Prompt( Prompt(
""" """
{expression_habits_block} {expression_habits_block}
@@ -134,7 +131,12 @@ def init_prompt():
class DefaultReplyer: class DefaultReplyer:
def __init__(self, chat_stream: ChatStream, model_configs: Optional[List[Dict[str, Any]]] = None, request_type: str = "focus.replyer"): def __init__(
self,
chat_stream: ChatStream,
model_configs: Optional[List[Dict[str, Any]]] = None,
request_type: str = "focus.replyer",
):
self.log_prefix = "replyer" self.log_prefix = "replyer"
self.request_type = request_type self.request_type = request_type
@@ -143,14 +145,14 @@ class DefaultReplyer:
else: else:
# 当未提供配置时,使用默认配置并赋予默认权重 # 当未提供配置时,使用默认配置并赋予默认权重
default_config = global_config.model.replyer_1.copy() default_config = global_config.model.replyer_1.copy()
default_config.setdefault('weight', 1.0) default_config.setdefault("weight", 1.0)
self.express_model_configs = [default_config] self.express_model_configs = [default_config]
if not self.express_model_configs: if not self.express_model_configs:
logger.warning("未找到有效的模型配置,回复生成可能会失败。") logger.warning("未找到有效的模型配置,回复生成可能会失败。")
# 提供一个最终的回退,以防止在空列表上调用 random.choice # 提供一个最终的回退,以防止在空列表上调用 random.choice
fallback_config = global_config.model.replyer_1.copy() fallback_config = global_config.model.replyer_1.copy()
fallback_config.setdefault('weight', 1.0) fallback_config.setdefault("weight", 1.0)
self.express_model_configs = [fallback_config] self.express_model_configs = [fallback_config]
self.heart_fc_sender = HeartFCSender() self.heart_fc_sender = HeartFCSender()
@@ -163,7 +165,7 @@ class DefaultReplyer:
"""使用加权随机选择来挑选一个模型配置""" """使用加权随机选择来挑选一个模型配置"""
configs = self.express_model_configs configs = self.express_model_configs
# 提取权重,如果模型配置中没有'weight'键则默认为1.0 # 提取权重,如果模型配置中没有'weight'键则默认为1.0
weights = [config.get('weight', 1.0) for config in configs] weights = [config.get("weight", 1.0) for config in configs]
# random.choices 返回一个列表,我们取第一个元素 # random.choices 返回一个列表,我们取第一个元素
selected_config = random.choices(population=configs, weights=weights, k=1)[0] selected_config = random.choices(population=configs, weights=weights, k=1)[0]
@@ -198,18 +200,21 @@ class DefaultReplyer:
async def generate_reply_with_context( async def generate_reply_with_context(
self, self,
reply_data: Dict[str, Any] = {}, reply_data: Dict[str, Any] = None,
reply_to: str = "", reply_to: str = "",
relation_info: str = "", relation_info: str = "",
structured_info: str = "", structured_info: str = "",
extra_info: str = "", extra_info: str = "",
available_actions: List[str] = [], available_actions: List[str] = None,
) -> Tuple[bool, Optional[str]]: ) -> Tuple[bool, Optional[str]]:
""" """
回复器 (Replier): 核心逻辑,负责生成回复文本。 回复器 (Replier): 核心逻辑,负责生成回复文本。
(已整合原 HeartFCGenerator 的功能) (已整合原 HeartFCGenerator 的功能)
""" """
if available_actions is None:
available_actions = []
if reply_data is None:
reply_data = {}
try: try:
if not reply_data: if not reply_data:
reply_data = { reply_data = {
@@ -226,7 +231,7 @@ class DefaultReplyer:
with Timer("构建Prompt", {}): # 内部计时器,可选保留 with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_reply_context( prompt = await self.build_prompt_reply_context(
reply_data=reply_data, # 传递action_data reply_data=reply_data, # 传递action_data
available_actions=available_actions available_actions=available_actions,
) )
# 4. 调用 LLM 生成回复 # 4. 调用 LLM 生成回复
@@ -238,7 +243,9 @@ class DefaultReplyer:
with Timer("LLM生成", {}): # 内部计时器,可选保留 with Timer("LLM生成", {}): # 内部计时器,可选保留
# 加权随机选择一个模型配置 # 加权随机选择一个模型配置
selected_model_config = self._select_weighted_model_config() selected_model_config = self._select_weighted_model_config()
logger.info(f"{self.log_prefix} 使用模型配置: {selected_model_config.get('model_name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})") logger.info(
f"{self.log_prefix} 使用模型配置: {selected_model_config.get('model_name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})"
)
express_model = LLMRequest( express_model = LLMRequest(
model=selected_model_config, model=selected_model_config,
@@ -262,9 +269,7 @@ class DefaultReplyer:
traceback.print_exc() traceback.print_exc()
return False, None return False, None
async def rewrite_reply_with_context( async def rewrite_reply_with_context(self, reply_data: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
self, reply_data: Dict[str, Any]
) -> Tuple[bool, Optional[str]]:
""" """
表达器 (Expressor): 核心逻辑,负责生成回复文本。 表达器 (Expressor): 核心逻辑,负责生成回复文本。
""" """
@@ -291,7 +296,9 @@ class DefaultReplyer:
with Timer("LLM生成", {}): # 内部计时器,可选保留 with Timer("LLM生成", {}): # 内部计时器,可选保留
# 加权随机选择一个模型配置 # 加权随机选择一个模型配置
selected_model_config = self._select_weighted_model_config() selected_model_config = self._select_weighted_model_config()
logger.info(f"{self.log_prefix} 使用模型配置进行重写: {selected_model_config.get('model_name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})") logger.info(
f"{self.log_prefix} 使用模型配置进行重写: {selected_model_config.get('model_name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})"
)
express_model = LLMRequest( express_model = LLMRequest(
model=selected_model_config, model=selected_model_config,
@@ -315,11 +322,7 @@ class DefaultReplyer:
traceback.print_exc() traceback.print_exc()
return False, None return False, None
async def build_prompt_reply_context( async def build_prompt_reply_context(self, reply_data=None, available_actions: List[str] = None) -> str:
self,
reply_data=None,
available_actions: List[str] = []
) -> str:
""" """
构建回复器上下文 构建回复器上下文
@@ -336,6 +339,8 @@ class DefaultReplyer:
Returns: Returns:
str: 构建好的上下文 str: 构建好的上下文
""" """
if available_actions is None:
available_actions = []
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.stream_id
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
@@ -403,7 +408,6 @@ class DefaultReplyer:
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
bot_person_id = person_info_manager.get_person_id("system", "bot_id") bot_person_id = person_info_manager.get_person_id("system", "bot_id")
is_group_chat = bool(chat_stream.group_info) is_group_chat = bool(chat_stream.group_info)
style_habbits = [] style_habbits = []
@@ -415,7 +419,6 @@ class DefaultReplyer:
chat_id, chat_talking_prompt_half, max_num=12, min_num=2, target_message=target chat_id, chat_talking_prompt_half, max_num=12, min_num=2, target_message=target
) )
if selected_expressions: if selected_expressions:
logger.info(f"{self.log_prefix} 使用处理器选中的{len(selected_expressions)}个表达方式") logger.info(f"{self.log_prefix} 使用处理器选中的{len(selected_expressions)}个表达方式")
for expr in selected_expressions: for expr in selected_expressions:
@@ -450,9 +453,7 @@ class DefaultReplyer:
# 由于无法直接访问 HeartFChatting 的 observations 列表, # 由于无法直接访问 HeartFChatting 的 observations 列表,
# 我们直接使用聊天记录作为上下文来激活记忆 # 我们直接使用聊天记录作为上下文来激活记忆
running_memorys = await self.memory_activator.activate_memory_with_chat_history( running_memorys = await self.memory_activator.activate_memory_with_chat_history(
chat_id=chat_id, chat_id=chat_id, target_message=target, chat_history_prompt=chat_talking_prompt_half
target_message=target,
chat_history_prompt=chat_talking_prompt_half
) )
if running_memorys: if running_memorys:
@@ -468,7 +469,9 @@ class DefaultReplyer:
memory_block = "" memory_block = ""
if structured_info: if structured_info:
structured_info_block = f"以下是你了解的额外信息信息,现在请你阅读以下内容,进行决策\n{structured_info}\n以上是一些额外的信息。" structured_info_block = (
f"以下是你了解的额外信息信息,现在请你阅读以下内容,进行决策\n{structured_info}\n以上是一些额外的信息。"
)
else: else:
structured_info_block = "" structured_info_block = ""
@@ -558,7 +561,6 @@ class DefaultReplyer:
if prompt_info: if prompt_info:
prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=prompt_info) prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=prompt_info)
# --- Choose template based on chat type --- # --- Choose template based on chat type ---
if is_group_chat: if is_group_chat:
template_name = "default_generator_prompt" template_name = "default_generator_prompt"

View File

@@ -5,6 +5,7 @@ from src.common.logger import get_logger
logger = get_logger("ReplyerManager") logger = get_logger("ReplyerManager")
class ReplyerManager: class ReplyerManager:
def __init__(self): def __init__(self):
self._replyers: Dict[str, DefaultReplyer] = {} self._replyers: Dict[str, DefaultReplyer] = {}
@@ -14,7 +15,7 @@ class ReplyerManager:
chat_stream: Optional[ChatStream] = None, chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None, model_configs: Optional[List[Dict[str, Any]]] = None,
request_type: str = "replyer" request_type: str = "replyer",
) -> Optional[DefaultReplyer]: ) -> Optional[DefaultReplyer]:
""" """
获取或创建回复器实例。 获取或创建回复器实例。
@@ -49,10 +50,11 @@ class ReplyerManager:
replyer = DefaultReplyer( replyer = DefaultReplyer(
chat_stream=target_stream, chat_stream=target_stream,
model_configs=model_configs, # 可以是None此时使用默认模型 model_configs=model_configs, # 可以是None此时使用默认模型
request_type=request_type request_type=request_type,
) )
self._replyers[stream_id] = replyer self._replyers[stream_id] = replyer
return replyer return replyer
# 创建一个全局实例 # 创建一个全局实例
replyer_manager = ReplyerManager() replyer_manager = ReplyerManager()

View File

@@ -27,7 +27,7 @@ def get_replyer(
chat_stream: Optional[ChatStream] = None, chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None, model_configs: Optional[List[Dict[str, Any]]] = None,
request_type: str = "replyer" request_type: str = "replyer",
) -> Optional[DefaultReplyer]: ) -> Optional[DefaultReplyer]:
"""获取回复器对象 """获取回复器对象
@@ -46,10 +46,7 @@ def get_replyer(
try: try:
logger.debug(f"[GeneratorAPI] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}") logger.debug(f"[GeneratorAPI] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}")
return replyer_manager.get_replyer( return replyer_manager.get_replyer(
chat_stream=chat_stream, chat_stream=chat_stream, chat_id=chat_id, model_configs=model_configs, request_type=request_type
chat_id=chat_id,
model_configs=model_configs,
request_type=request_type
) )
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 获取回复器时发生意外错误: {e}", exc_info=True) logger.error(f"[GeneratorAPI] 获取回复器时发生意外错误: {e}", exc_info=True)
@@ -154,9 +151,7 @@ async def rewrite_reply(
logger.info("[GeneratorAPI] 开始重写回复") logger.info("[GeneratorAPI] 开始重写回复")
# 调用回复器重写回复 # 调用回复器重写回复
success, content = await replyer.rewrite_reply_with_context( success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {})
reply_data=reply_data or {}
)
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
@@ -172,11 +167,7 @@ async def rewrite_reply(
return False, [] return False, []
async def process_human_text( async def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
content:str,
enable_splitter:bool,
enable_chinese_typo:bool
) -> List[Tuple[str, Any]]:
"""将文本处理为更拟人化的文本 """将文本处理为更拟人化的文本
Args: Args: