🤖 Auto-format code [skip ci]
@@ -754,14 +754,11 @@ class HeartFChatting:
         if relation_info:
             updated_action_data["relation_info"] = relation_info
-
 
         if structured_info:
             updated_action_data["structured_info"] = structured_info
 
         if all_post_plan_info:
-            logger.info(
-                f"{self.log_prefix} 后期处理完成,产生了 {len(all_post_plan_info)} 个信息项"
-            )
+            logger.info(f"{self.log_prefix} 后期处理完成,产生了 {len(all_post_plan_info)} 个信息项")
 
         # 输出详细统计信息
         if post_processor_time_costs:
@@ -1,5 +1,3 @@
-from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
-from src.chat.heart_flow.observation.structure_observation import StructureObservation
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import global_config
 from src.common.logger import get_logger
@@ -10,7 +8,6 @@ from typing import List, Dict
 import difflib
 import json
 from json_repair import repair_json
-from src.person_info.person_info import get_person_info_manager
 
 
 logger = get_logger("memory_activator")
@@ -1,4 +1,3 @@
-from typing import List, Optional, Union
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import global_config
 from src.chat.message_receive.message import MessageThinking
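The three hunks above are unused-import cleanup rather than pure reformatting: `ChattingObservation`, `StructureObservation`, `get_person_info_manager`, and the `typing` names are presumably no longer referenced in their modules, so the lint/format pass dropped them.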
@@ -19,8 +18,8 @@ class NormalChatGenerator:
 
         prob_first = global_config.normal_chat.normal_chat_first_probability
 
-        model_config_1['weight'] = prob_first
-        model_config_2['weight'] = 1.0 - prob_first
+        model_config_1["weight"] = prob_first
+        model_config_2["weight"] = 1.0 - prob_first
 
         self.model_configs = [model_config_1, model_config_2]
 
@@ -54,7 +53,7 @@ class NormalChatGenerator:
             available_actions=available_actions,
             model_configs=self.model_configs,
             request_type="normal.replyer",
-            return_prompt=True
+            return_prompt=True,
         )
 
         if not success or not reply_set:
@@ -31,15 +31,12 @@ logger = get_logger("replyer")
 
 
 def init_prompt():
 
     Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1")
     Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
     Prompt("在群里聊天", "chat_target_group2")
     Prompt("和{sender_name}私聊", "chat_target_private2")
     Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
-
-
-
     Prompt(
         """
 {expression_habits_block}
@@ -134,7 +131,12 @@ def init_prompt():
 
 
 class DefaultReplyer:
-    def __init__(self, chat_stream: ChatStream, model_configs: Optional[List[Dict[str, Any]]] = None, request_type: str = "focus.replyer"):
+    def __init__(
+        self,
+        chat_stream: ChatStream,
+        model_configs: Optional[List[Dict[str, Any]]] = None,
+        request_type: str = "focus.replyer",
+    ):
         self.log_prefix = "replyer"
         self.request_type = request_type
 
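This is the formatter's standard treatment of an over-long `def`: once the one-liner exceeds the configured line length, every parameter moves onto its own line and the last one gains a trailing comma, which keeps the signature exploded on later runs. A minimal sketch of reproducing the behavior programmatically, assuming black (or the black-compatible ruff formatter) produced this commit — the commit message only says the code was auto-formatted:

    import black

    # A one-line signature that is too long for a 60-column limit.
    src = 'def __init__(self, chat_stream, model_configs=None, request_type="focus.replyer"):\n    pass\n'

    # black explodes the signature, one parameter per line, trailing comma included;
    # line_length=60 here is illustrative, not this repo's actual setting.
    print(black.format_str(src, mode=black.Mode(line_length=60)))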
@@ -143,14 +145,14 @@ class DefaultReplyer:
         else:
             # 当未提供配置时,使用默认配置并赋予默认权重
             default_config = global_config.model.replyer_1.copy()
-            default_config.setdefault('weight', 1.0)
+            default_config.setdefault("weight", 1.0)
             self.express_model_configs = [default_config]
 
         if not self.express_model_configs:
             logger.warning("未找到有效的模型配置,回复生成可能会失败。")
             # 提供一个最终的回退,以防止在空列表上调用 random.choice
             fallback_config = global_config.model.replyer_1.copy()
-            fallback_config.setdefault('weight', 1.0)
+            fallback_config.setdefault("weight", 1.0)
             self.express_model_configs = [fallback_config]
 
         self.heart_fc_sender = HeartFCSender()
@@ -163,7 +165,7 @@ class DefaultReplyer:
         """使用加权随机选择来挑选一个模型配置"""
         configs = self.express_model_configs
         # 提取权重,如果模型配置中没有'weight'键,则默认为1.0
-        weights = [config.get('weight', 1.0) for config in configs]
+        weights = [config.get("weight", 1.0) for config in configs]
 
         # random.choices 返回一个列表,我们取第一个元素
         selected_config = random.choices(population=configs, weights=weights, k=1)[0]
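For context, the weighted pick in this hunk is plain `random.choices` sampling; a standalone sketch of the same logic with illustrative config values (real entries come from `global_config.model.*`):

    import random

    configs = [
        {"model_name": "replyer_1", "weight": 0.7},
        {"model_name": "replyer_2"},  # no "weight" key -> treated as 1.0
    ]

    # Missing weights default to 1.0, as in the diff above.
    weights = [config.get("weight", 1.0) for config in configs]

    # random.choices returns a list, so take the first element.
    selected = random.choices(population=configs, weights=weights, k=1)[0]
    print(selected["model_name"])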
@@ -198,18 +200,21 @@ class DefaultReplyer:
 
     async def generate_reply_with_context(
         self,
-        reply_data: Dict[str, Any] = {},
+        reply_data: Dict[str, Any] = None,
         reply_to: str = "",
         relation_info: str = "",
         structured_info: str = "",
         extra_info: str = "",
-        available_actions: List[str] = [],
-
+        available_actions: List[str] = None,
     ) -> Tuple[bool, Optional[str]]:
         """
         回复器 (Replier): 核心逻辑,负责生成回复文本。
         (已整合原 HeartFCGenerator 的功能)
         """
+        if available_actions is None:
+            available_actions = []
+        if reply_data is None:
+            reply_data = {}
         try:
             if not reply_data:
                 reply_data = {
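Beyond formatting, this hunk fixes a genuine bug: `reply_data: Dict[str, Any] = {}` and `available_actions: List[str] = []` are mutable default arguments, evaluated once at definition time and shared across every call. The `None`-plus-guard rewrite is the standard remedy; a minimal sketch of the pitfall:

    def broken(items=[]):  # the default list is created once and shared
        items.append(1)
        return items

    print(broken())  # [1]
    print(broken())  # [1, 1]  <- state leaks between calls

    def fixed(items=None):  # the pattern applied in the diff above
        if items is None:
            items = []
        items.append(1)
        return items

    print(fixed())  # [1]
    print(fixed())  # [1]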
@@ -226,7 +231,7 @@ class DefaultReplyer:
             with Timer("构建Prompt", {}):  # 内部计时器,可选保留
                 prompt = await self.build_prompt_reply_context(
                     reply_data=reply_data,  # 传递action_data
-                    available_actions=available_actions
+                    available_actions=available_actions,
                 )
 
             # 4. 调用 LLM 生成回复
@@ -238,7 +243,9 @@ class DefaultReplyer:
             with Timer("LLM生成", {}):  # 内部计时器,可选保留
                 # 加权随机选择一个模型配置
                 selected_model_config = self._select_weighted_model_config()
-                logger.info(f"{self.log_prefix} 使用模型配置: {selected_model_config.get('model_name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})")
+                logger.info(
+                    f"{self.log_prefix} 使用模型配置: {selected_model_config.get('model_name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})"
+                )
 
                 express_model = LLMRequest(
                     model=selected_model_config,
@@ -262,9 +269,7 @@ class DefaultReplyer:
             traceback.print_exc()
             return False, None
 
-    async def rewrite_reply_with_context(
-        self, reply_data: Dict[str, Any]
-    ) -> Tuple[bool, Optional[str]]:
+    async def rewrite_reply_with_context(self, reply_data: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
         """
         表达器 (Expressor): 核心逻辑,负责生成回复文本。
         """
@@ -291,7 +296,9 @@ class DefaultReplyer:
             with Timer("LLM生成", {}):  # 内部计时器,可选保留
                 # 加权随机选择一个模型配置
                 selected_model_config = self._select_weighted_model_config()
-                logger.info(f"{self.log_prefix} 使用模型配置进行重写: {selected_model_config.get('model_name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})")
+                logger.info(
+                    f"{self.log_prefix} 使用模型配置进行重写: {selected_model_config.get('model_name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})"
+                )
 
                 express_model = LLMRequest(
                     model=selected_model_config,
@@ -315,11 +322,7 @@ class DefaultReplyer:
             traceback.print_exc()
             return False, None
 
-    async def build_prompt_reply_context(
-        self,
-        reply_data=None,
-        available_actions: List[str] = []
-    ) -> str:
+    async def build_prompt_reply_context(self, reply_data=None, available_actions: List[str] = None) -> str:
         """
         构建回复器上下文
 
@@ -336,6 +339,8 @@ class DefaultReplyer:
         Returns:
             str: 构建好的上下文
         """
+        if available_actions is None:
+            available_actions = []
         chat_stream = self.chat_stream
         chat_id = chat_stream.stream_id
         person_info_manager = get_person_info_manager()
@@ -403,7 +408,6 @@ class DefaultReplyer:
         person_info_manager = get_person_info_manager()
         bot_person_id = person_info_manager.get_person_id("system", "bot_id")
-
 
         is_group_chat = bool(chat_stream.group_info)
 
         style_habbits = []
@@ -415,7 +419,6 @@ class DefaultReplyer:
             chat_id, chat_talking_prompt_half, max_num=12, min_num=2, target_message=target
         )
-
 
         if selected_expressions:
             logger.info(f"{self.log_prefix} 使用处理器选中的{len(selected_expressions)}个表达方式")
             for expr in selected_expressions:
@@ -450,9 +453,7 @@ class DefaultReplyer:
         # 由于无法直接访问 HeartFChatting 的 observations 列表,
         # 我们直接使用聊天记录作为上下文来激活记忆
         running_memorys = await self.memory_activator.activate_memory_with_chat_history(
-            chat_id=chat_id,
-            target_message=target,
-            chat_history_prompt=chat_talking_prompt_half
+            chat_id=chat_id, target_message=target, chat_history_prompt=chat_talking_prompt_half
         )
 
         if running_memorys:
@@ -468,7 +469,9 @@ class DefaultReplyer:
             memory_block = ""
 
         if structured_info:
-            structured_info_block = f"以下是你了解的额外信息信息,现在请你阅读以下内容,进行决策\n{structured_info}\n以上是一些额外的信息。"
+            structured_info_block = (
+                f"以下是你了解的额外信息信息,现在请你阅读以下内容,进行决策\n{structured_info}\n以上是一些额外的信息。"
+            )
         else:
             structured_info_block = ""
 
@@ -558,7 +561,6 @@ class DefaultReplyer:
         if prompt_info:
             prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=prompt_info)
-
 
         # --- Choose template based on chat type ---
         if is_group_chat:
             template_name = "default_generator_prompt"
@@ -5,6 +5,7 @@ from src.common.logger import get_logger
 
 logger = get_logger("ReplyerManager")
 
+
 class ReplyerManager:
     def __init__(self):
         self._replyers: Dict[str, DefaultReplyer] = {}
@@ -14,7 +15,7 @@ class ReplyerManager:
         chat_stream: Optional[ChatStream] = None,
         chat_id: Optional[str] = None,
         model_configs: Optional[List[Dict[str, Any]]] = None,
-        request_type: str = "replyer"
+        request_type: str = "replyer",
     ) -> Optional[DefaultReplyer]:
         """
         获取或创建回复器实例。
@@ -49,10 +50,11 @@ class ReplyerManager:
             replyer = DefaultReplyer(
                 chat_stream=target_stream,
                 model_configs=model_configs,  # 可以是None,此时使用默认模型
-                request_type=request_type
+                request_type=request_type,
             )
             self._replyers[stream_id] = replyer
             return replyer
 
+
 # 创建一个全局实例
 replyer_manager = ReplyerManager()
@@ -27,7 +27,7 @@ def get_replyer(
     chat_stream: Optional[ChatStream] = None,
     chat_id: Optional[str] = None,
     model_configs: Optional[List[Dict[str, Any]]] = None,
-    request_type: str = "replyer"
+    request_type: str = "replyer",
 ) -> Optional[DefaultReplyer]:
     """获取回复器对象
 
@@ -46,10 +46,7 @@ def get_replyer(
     try:
         logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
         return replyer_manager.get_replyer(
-            chat_stream=chat_stream,
-            chat_id=chat_id,
-            model_configs=model_configs,
-            request_type=request_type
+            chat_stream=chat_stream, chat_id=chat_id, model_configs=model_configs, request_type=request_type
         )
     except Exception as e:
         logger.error(f"[GeneratorAPI] 获取回复器时发生意外错误: {e}", exc_info=True)
@@ -154,9 +151,7 @@ async def rewrite_reply(
     logger.info("[GeneratorAPI] 开始重写回复")
 
     # 调用回复器重写回复
-    success, content = await replyer.rewrite_reply_with_context(
-        reply_data=reply_data or {}
-    )
+    success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {})
 
     reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
@@ -172,11 +167,7 @@ async def rewrite_reply(
         return False, []
 
 
-async def process_human_text(
-    content:str,
-    enable_splitter:bool,
-    enable_chinese_typo:bool
-) -> List[Tuple[str, Any]]:
+async def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
     """将文本处理为更拟人化的文本
 
     Args: