Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -3,13 +3,13 @@ MaiBot模块系统
|
||||
包含聊天、情绪、记忆、日程等功能模块
|
||||
"""
|
||||
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.normal_chat.willing.willing_manager import willing_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.normal_chat.willing.willing_manager import get_willing_manager
|
||||
|
||||
# 导出主要组件供外部使用
|
||||
__all__ = [
|
||||
"chat_manager",
|
||||
"emoji_manager",
|
||||
"willing_manager",
|
||||
"get_chat_manager",
|
||||
"get_emoji_manager",
|
||||
"get_willing_manager",
|
||||
]
|
||||
|
||||
@@ -15,9 +15,9 @@ import re
|
||||
from src.common.database.database_model import Emoji
|
||||
from src.common.database.database import db as peewee_db
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.utils_image import image_path_to_base64, image_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -163,7 +163,7 @@ class MaiEmoji:
|
||||
last_used_time=self.last_used_time,
|
||||
)
|
||||
|
||||
logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
|
||||
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
|
||||
|
||||
return True
|
||||
|
||||
@@ -317,7 +317,7 @@ async def clear_temp_emoji() -> None:
|
||||
os.remove(file_path)
|
||||
logger.debug(f"[清理] 删除: {filename}")
|
||||
|
||||
logger.success("[清理] 完成")
|
||||
logger.info("[清理] 完成")
|
||||
|
||||
|
||||
async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"]) -> None:
|
||||
@@ -349,7 +349,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"]) -
|
||||
logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {str(e)}")
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.success(f"[清理] 在目录 {emoji_dir} 中清理了 {cleaned_count} 个破损表情包。")
|
||||
logger.info(f"[清理] 在目录 {emoji_dir} 中清理了 {cleaned_count} 个破损表情包。")
|
||||
else:
|
||||
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
|
||||
|
||||
@@ -568,7 +568,7 @@ class EmojiManager:
|
||||
|
||||
# 输出清理结果
|
||||
if removed_count > 0:
|
||||
logger.success(f"[清理] 已清理 {removed_count} 个失效/文件丢失的表情包记录")
|
||||
logger.info(f"[清理] 已清理 {removed_count} 个失效/文件丢失的表情包记录")
|
||||
logger.info(f"[统计] 清理前记录数: {total_count} | 清理后有效记录数: {len(self.emoji_objects)}")
|
||||
else:
|
||||
logger.info(f"[检查] 已检查 {total_count} 个表情包记录,全部完好")
|
||||
@@ -645,7 +645,7 @@ class EmojiManager:
|
||||
self.emoji_objects = emoji_objects
|
||||
self.emoji_num = len(emoji_objects)
|
||||
|
||||
logger.success(f"[数据库] 加载完成: 共加载 {self.emoji_num} 个表情包记录。")
|
||||
logger.info(f"[数据库] 加载完成: 共加载 {self.emoji_num} 个表情包记录。")
|
||||
if load_errors > 0:
|
||||
logger.warning(f"[数据库] 加载过程中出现 {load_errors} 个错误。")
|
||||
|
||||
@@ -808,7 +808,7 @@ class EmojiManager:
|
||||
if register_success:
|
||||
self.emoji_objects.append(new_emoji)
|
||||
self.emoji_num += 1
|
||||
logger.success(f"[成功] 注册: {new_emoji.filename}")
|
||||
logger.info(f"[成功] 注册: {new_emoji.filename}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[错误] 注册表情包到数据库失败: {new_emoji.filename}")
|
||||
@@ -844,7 +844,7 @@ class EmojiManager:
|
||||
|
||||
# 调用AI获取描述
|
||||
if image_format == "gif" or image_format == "GIF":
|
||||
image_base64 = image_manager.transform_gif(image_base64)
|
||||
image_base64 = get_image_manager().transform_gif(image_base64)
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
|
||||
else:
|
||||
@@ -973,7 +973,7 @@ class EmojiManager:
|
||||
# 注册成功后,添加到内存列表
|
||||
self.emoji_objects.append(new_emoji)
|
||||
self.emoji_num += 1
|
||||
logger.success(f"[成功] 注册新表情包: {filename} (当前: {self.emoji_num}/{self.emoji_num_max})")
|
||||
logger.info(f"[成功] 注册新表情包: {filename} (当前: {self.emoji_num}/{self.emoji_num_max})")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[注册失败] 保存表情包到数据库/移动文件失败: {filename}")
|
||||
@@ -1000,5 +1000,11 @@ class EmojiManager:
|
||||
return False
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
emoji_manager = EmojiManager()
|
||||
emoji_manager = None
|
||||
|
||||
|
||||
def get_emoji_manager():
|
||||
global emoji_manager
|
||||
if emoji_manager is None:
|
||||
emoji_manager = EmojiManager()
|
||||
return emoji_manager
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
import traceback
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import get_expression_learner
|
||||
from src.chat.message_receive.message import MessageRecv, MessageThinking, MessageSending
|
||||
from src.chat.message_receive.message import Seg # Local import needed after move
|
||||
from src.chat.message_receive.message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.focus_chat.heartFC_sender import HeartFCSender
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.chat.utils.info_catcher import info_catcher_manager
|
||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
import time
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
|
||||
import random
|
||||
|
||||
logger = get_logger("expressor")
|
||||
@@ -110,6 +110,7 @@ class DefaultExpressor:
|
||||
# logger.debug(f"创建思考消息thinking_message:{thinking_message}")
|
||||
|
||||
await self.heart_fc_sender.register_thinking(thinking_message)
|
||||
return None
|
||||
|
||||
async def deal_reply(
|
||||
self,
|
||||
@@ -181,14 +182,6 @@ class DefaultExpressor:
|
||||
(已整合原 HeartFCGenerator 的功能)
|
||||
"""
|
||||
try:
|
||||
# 1. 获取情绪影响因子并调整模型温度
|
||||
# arousal_multiplier = mood_manager.get_arousal_multiplier()
|
||||
# current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
|
||||
# self.express_model.params["temperature"] = current_temp # 动态调整温度
|
||||
|
||||
# 2. 获取信息捕捉器
|
||||
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||
|
||||
# --- Determine sender_name for private chat ---
|
||||
sender_name_for_prompt = "某人" # Default for group or if info unavailable
|
||||
if not self.is_group_chat and self.chat_target_info:
|
||||
@@ -227,15 +220,9 @@ class DefaultExpressor:
|
||||
# logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n")
|
||||
content, (reasoning_content, model_name) = await self.express_model.generate_response_async(prompt)
|
||||
|
||||
# logger.info(f"{self.log_prefix}\nPrompt:\n{prompt}\n---------------------------\n")
|
||||
|
||||
logger.info(f"想要表达:{in_mind_reply}||理由:{reason}")
|
||||
logger.info(f"最终回复: {content}\n")
|
||||
|
||||
info_catcher.catch_after_llm_generated(
|
||||
prompt=prompt, response=content, reasoning_content=reasoning_content, model_name=model_name
|
||||
)
|
||||
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
||||
@@ -288,6 +275,7 @@ class DefaultExpressor:
|
||||
truncate=True,
|
||||
)
|
||||
|
||||
expression_learner = get_expression_learner()
|
||||
(
|
||||
learnt_style_expressions,
|
||||
learnt_grammar_expressions,
|
||||
@@ -379,7 +367,7 @@ class DefaultExpressor:
|
||||
logger.error(f"{self.log_prefix} 无法发送回复,anchor_message 为空。")
|
||||
return None
|
||||
|
||||
stream_name = chat_manager.get_stream_name(chat_id) or chat_id # 获取流名称用于日志
|
||||
stream_name = get_chat_manager().get_stream_name(chat_id) or chat_id # 获取流名称用于日志
|
||||
|
||||
# 检查思考过程是否仍在进行,并获取开始时间
|
||||
if thinking_id:
|
||||
@@ -468,7 +456,7 @@ class DefaultExpressor:
|
||||
选择表情,根据send_emoji文本选择表情,返回表情base64
|
||||
"""
|
||||
emoji_base64 = ""
|
||||
emoji_raw = await emoji_manager.get_emoji_for_text(send_emoji)
|
||||
emoji_raw = await get_emoji_manager().get_emoji_for_text(send_emoji)
|
||||
if emoji_raw:
|
||||
emoji_path, _description, _emotion = emoji_raw
|
||||
emoji_base64 = image_path_to_base64(emoji_path)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import time
|
||||
import random
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
import os
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
import json
|
||||
|
||||
|
||||
@@ -113,25 +113,25 @@ class ExpressionLearner:
|
||||
同时对所有已存储的表达方式进行全局衰减
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 全局衰减所有已存储的表达方式
|
||||
for type in ["style", "grammar"]:
|
||||
base_dir = os.path.join("data", "expression", f"learnt_{type}")
|
||||
if not os.path.exists(base_dir):
|
||||
continue
|
||||
|
||||
|
||||
for chat_id in os.listdir(base_dir):
|
||||
file_path = os.path.join(base_dir, chat_id, "expressions.json")
|
||||
if not os.path.exists(file_path):
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
|
||||
|
||||
# 应用全局衰减
|
||||
decayed_expressions = self.apply_decay_to_expressions(expressions, current_time)
|
||||
|
||||
|
||||
# 保存衰减后的结果
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(decayed_expressions, f, ensure_ascii=False, indent=2)
|
||||
@@ -140,12 +140,12 @@ class ExpressionLearner:
|
||||
continue
|
||||
|
||||
# 学习新的表达方式(这里会进行局部衰减)
|
||||
for i in range(3):
|
||||
for _ in range(3):
|
||||
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
|
||||
if not learnt_style:
|
||||
return []
|
||||
|
||||
for j in range(1):
|
||||
for _ in range(1):
|
||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
|
||||
if not learnt_grammar:
|
||||
return []
|
||||
@@ -162,23 +162,25 @@ class ExpressionLearner:
|
||||
"""
|
||||
if time_diff_days <= 0 or time_diff_days >= DECAY_DAYS:
|
||||
return 0.001
|
||||
|
||||
|
||||
# 使用二次函数进行插值
|
||||
# 将7天作为顶点,0天和30天作为两个端点
|
||||
# 使用顶点式:y = a(x-h)^2 + k,其中(h,k)为顶点
|
||||
h = 7.0 # 顶点x坐标
|
||||
k = 0.001 # 顶点y坐标
|
||||
|
||||
|
||||
# 计算a值,使得x=0和x=30时y=0.001
|
||||
# 0.001 = a(0-7)^2 + 0.001
|
||||
# 解得a = 0
|
||||
a = 0
|
||||
|
||||
|
||||
# 计算衰减值
|
||||
decay = a * (time_diff_days - h) ** 2 + k
|
||||
return min(0.001, decay)
|
||||
|
||||
def apply_decay_to_expressions(self, expressions: List[Dict[str, Any]], current_time: float) -> List[Dict[str, Any]]:
|
||||
def apply_decay_to_expressions(
|
||||
self, expressions: List[Dict[str, Any]], current_time: float
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
对表达式列表应用衰减
|
||||
返回衰减后的表达式列表,移除count小于0的项
|
||||
@@ -188,16 +190,16 @@ class ExpressionLearner:
|
||||
# 确保last_active_time存在,如果不存在则使用current_time
|
||||
if "last_active_time" not in expr:
|
||||
expr["last_active_time"] = current_time
|
||||
|
||||
|
||||
last_active = expr["last_active_time"]
|
||||
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
||||
|
||||
|
||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||
expr["count"] = max(0.01, expr.get("count", 1) - decay_value)
|
||||
|
||||
|
||||
if expr["count"] > 0:
|
||||
result.append(expr)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
@@ -211,14 +213,14 @@ class ExpressionLearner:
|
||||
type_str = "句法特点"
|
||||
else:
|
||||
raise ValueError(f"Invalid type: {type}")
|
||||
|
||||
|
||||
res = await self.learn_expression(type, num)
|
||||
|
||||
if res is None:
|
||||
return []
|
||||
learnt_expressions, chat_id = res
|
||||
|
||||
chat_stream = chat_manager.get_stream(chat_id)
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if chat_stream.group_info:
|
||||
group_name = chat_stream.group_info.group_name
|
||||
else:
|
||||
@@ -238,15 +240,15 @@ class ExpressionLearner:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
chat_dict[chat_id].append({"situation": situation, "style": style})
|
||||
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 存储到/data/expression/对应chat_id/expressions.json
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
dir_path = os.path.join("data", "expression", f"learnt_{type}", str(chat_id))
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
file_path = os.path.join(dir_path, "expressions.json")
|
||||
|
||||
|
||||
# 若已存在,先读出合并
|
||||
old_data: List[Dict[str, Any]] = []
|
||||
if os.path.exists(file_path):
|
||||
@@ -255,10 +257,10 @@ class ExpressionLearner:
|
||||
old_data = json.load(f)
|
||||
except Exception:
|
||||
old_data = []
|
||||
|
||||
|
||||
# 应用衰减
|
||||
# old_data = self.apply_decay_to_expressions(old_data, current_time)
|
||||
|
||||
|
||||
# 合并逻辑
|
||||
for new_expr in expr_list:
|
||||
found = False
|
||||
@@ -278,43 +280,43 @@ class ExpressionLearner:
|
||||
new_expr["count"] = 1
|
||||
new_expr["last_active_time"] = current_time
|
||||
old_data.append(new_expr)
|
||||
|
||||
|
||||
# 处理超限问题
|
||||
if len(old_data) > MAX_EXPRESSION_COUNT:
|
||||
# 计算每个表达方式的权重(count的倒数,这样count越小的越容易被选中)
|
||||
weights = [1 / (expr.get("count", 1) + 0.1) for expr in old_data]
|
||||
|
||||
|
||||
# 随机选择要移除的表达方式,避免重复索引
|
||||
remove_count = len(old_data) - MAX_EXPRESSION_COUNT
|
||||
|
||||
|
||||
# 使用一种不会选到重复索引的方法
|
||||
indices = list(range(len(old_data)))
|
||||
|
||||
|
||||
# 方法1:使用numpy.random.choice
|
||||
# 把列表转成一个映射字典,保证不会有重复
|
||||
remove_set = set()
|
||||
total_attempts = 0
|
||||
|
||||
|
||||
# 尝试按权重随机选择,直到选够数量
|
||||
while len(remove_set) < remove_count and total_attempts < len(old_data) * 2:
|
||||
idx = random.choices(indices, weights=weights, k=1)[0]
|
||||
remove_set.add(idx)
|
||||
total_attempts += 1
|
||||
|
||||
|
||||
# 如果没选够,随机补充
|
||||
if len(remove_set) < remove_count:
|
||||
remaining = set(indices) - remove_set
|
||||
remove_set.update(random.sample(list(remaining), remove_count - len(remove_set)))
|
||||
|
||||
|
||||
remove_indices = list(remove_set)
|
||||
|
||||
|
||||
# 从后往前删除,避免索引变化
|
||||
for idx in sorted(remove_indices, reverse=True):
|
||||
old_data.pop(idx)
|
||||
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(old_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
@@ -397,4 +399,11 @@ class ExpressionLearner:
|
||||
|
||||
init_prompt()
|
||||
|
||||
expression_learner = ExpressionLearner()
|
||||
expression_learner = None
|
||||
|
||||
|
||||
def get_expression_learner():
|
||||
global expression_learner
|
||||
if expression_learner is None:
|
||||
expression_learner = ExpressionLearner()
|
||||
return expression_learner
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
import json
|
||||
|
||||
logger = get_logger("hfc") # Logger Name Changed
|
||||
@@ -97,7 +97,7 @@ class CycleDetail:
|
||||
)
|
||||
|
||||
# current_time_minute = time.strftime("%Y%m%d_%H%M", time.localtime())
|
||||
|
||||
|
||||
# try:
|
||||
# self.log_cycle_to_file(
|
||||
# log_dir + self.prefix + f"/{current_time_minute}_cycle_" + str(self.cycle_id) + ".json"
|
||||
@@ -117,7 +117,6 @@ class CycleDetail:
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
# 写入文件
|
||||
|
||||
|
||||
file_path = os.path.join(dir_name, os.path.basename(file_path))
|
||||
# print("file_path:", file_path)
|
||||
|
||||
@@ -4,20 +4,17 @@ import time
|
||||
import traceback
|
||||
from collections import deque
|
||||
from typing import List, Optional, Dict, Any, Deque, Callable, Awaitable
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from rich.traceback import install
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from src.chat.focus_chat.info_processors.chattinginfo_processor import ChattingInfoProcessor
|
||||
from src.chat.focus_chat.info_processors.relationship_processor import RelationshipProcessor
|
||||
from src.chat.focus_chat.info_processors.mind_processor import MindProcessor
|
||||
from src.chat.focus_chat.info_processors.working_memory_processor import WorkingMemoryProcessor
|
||||
|
||||
# from src.chat.focus_chat.info_processors.action_processor import ActionProcessor
|
||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
@@ -47,11 +44,10 @@ OBSERVATION_CLASSES = {
|
||||
# 定义处理器映射:键是处理器名称,值是 (处理器类, 可选的配置键名)
|
||||
PROCESSOR_CLASSES = {
|
||||
"ChattingInfoProcessor": (ChattingInfoProcessor, None),
|
||||
"MindProcessor": (MindProcessor, "mind_processor"),
|
||||
"ToolProcessor": (ToolProcessor, "tool_use_processor"),
|
||||
"WorkingMemoryProcessor": (WorkingMemoryProcessor, "working_memory_processor"),
|
||||
"SelfProcessor": (SelfProcessor, "self_identify_processor"),
|
||||
"RelationshipProcessor": (RelationshipProcessor, "relationship_processor"),
|
||||
"RelationshipProcessor": (RelationshipProcessor, "relation_processor"),
|
||||
}
|
||||
|
||||
logger = get_logger("hfc") # Logger Name Changed
|
||||
@@ -97,31 +93,39 @@ class HeartFChatting:
|
||||
"""
|
||||
# 基础属性
|
||||
self.stream_id: str = chat_id # 聊天流ID
|
||||
self.chat_stream = chat_manager.get_stream(self.stream_id)
|
||||
self.log_prefix = f"[{chat_manager.get_stream_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
self.chat_stream = get_chat_manager().get_stream(self.stream_id)
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
self.memory_activator = MemoryActivator()
|
||||
|
||||
|
||||
# 初始化观察器
|
||||
self.observations: List[Observation] = []
|
||||
self._register_observations()
|
||||
|
||||
|
||||
# 根据配置文件和默认规则确定启用的处理器
|
||||
config_processor_settings = global_config.focus_chat_processor
|
||||
self.enabled_processor_names = [
|
||||
proc_name for proc_name, (_proc_class, config_key) in PROCESSOR_CLASSES.items()
|
||||
if not config_key or getattr(config_processor_settings, config_key, True)
|
||||
]
|
||||
self.enabled_processor_names = []
|
||||
|
||||
for proc_name, (_proc_class, config_key) in PROCESSOR_CLASSES.items():
|
||||
# 对于关系处理器,需要同时检查两个配置项
|
||||
if proc_name == "RelationshipProcessor":
|
||||
if global_config.relationship.enable_relationship and getattr(
|
||||
config_processor_settings, config_key, True
|
||||
):
|
||||
self.enabled_processor_names.append(proc_name)
|
||||
else:
|
||||
# 其他处理器的原有逻辑
|
||||
if not config_key or getattr(config_processor_settings, config_key, True):
|
||||
self.enabled_processor_names.append(proc_name)
|
||||
|
||||
# logger.info(f"{self.log_prefix} 将启用的处理器: {self.enabled_processor_names}")
|
||||
|
||||
|
||||
self.processors: List[BaseProcessor] = []
|
||||
self._register_default_processors()
|
||||
|
||||
self.expressor = DefaultExpressor(chat_stream=self.chat_stream)
|
||||
self.replyer = DefaultReplyer(chat_stream=self.chat_stream)
|
||||
|
||||
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = PlannerFactory.create_planner(
|
||||
log_prefix=self.log_prefix, action_manager=self.action_manager
|
||||
@@ -130,7 +134,6 @@ class HeartFChatting:
|
||||
self.action_observation = ActionObservation(observe_id=self.stream_id)
|
||||
self.action_observation.set_action_manager(self.action_manager)
|
||||
|
||||
|
||||
self._processing_lock = asyncio.Lock()
|
||||
|
||||
# 循环控制内部状态
|
||||
@@ -152,6 +155,13 @@ class HeartFChatting:
|
||||
|
||||
for name, (observation_class, param_name) in OBSERVATION_CLASSES.items():
|
||||
try:
|
||||
# 检查是否需要跳过WorkingMemoryObservation
|
||||
if name == "WorkingMemoryObservation":
|
||||
# 如果工作记忆处理器被禁用,则跳过WorkingMemoryObservation
|
||||
if not global_config.focus_chat_processor.working_memory_processor:
|
||||
logger.debug(f"{self.log_prefix} 工作记忆处理器已禁用,跳过注册观察器 {name}")
|
||||
continue
|
||||
|
||||
# 根据参数名使用正确的参数
|
||||
kwargs = {param_name: self.stream_id}
|
||||
observation = observation_class(**kwargs)
|
||||
@@ -174,7 +184,12 @@ class HeartFChatting:
|
||||
if processor_info:
|
||||
processor_actual_class = processor_info[0] # 获取实际的类定义
|
||||
# 根据处理器类名判断是否需要 subheartflow_id
|
||||
if name in ["MindProcessor", "ToolProcessor", "WorkingMemoryProcessor", "SelfProcessor", "RelationshipProcessor"]:
|
||||
if name in [
|
||||
"ToolProcessor",
|
||||
"WorkingMemoryProcessor",
|
||||
"SelfProcessor",
|
||||
"RelationshipProcessor",
|
||||
]:
|
||||
self.processors.append(processor_actual_class(subheartflow_id=self.stream_id))
|
||||
elif name == "ChattingInfoProcessor":
|
||||
self.processors.append(processor_actual_class())
|
||||
@@ -195,9 +210,7 @@ class HeartFChatting:
|
||||
)
|
||||
|
||||
if self.processors:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 已注册处理器: {[p.__class__.__name__ for p in self.processors]}"
|
||||
)
|
||||
logger.info(f"{self.log_prefix} 已注册处理器: {[p.__class__.__name__ for p in self.processors]}")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 没有注册任何处理器。这可能是由于配置错误或所有处理器都被禁用了。")
|
||||
|
||||
@@ -284,7 +297,9 @@ class HeartFChatting:
|
||||
self._current_cycle_detail.set_loop_info(loop_info)
|
||||
|
||||
# 从observations列表中获取HFCloopObservation
|
||||
hfcloop_observation = next((obs for obs in self.observations if isinstance(obs, HFCloopObservation)), None)
|
||||
hfcloop_observation = next(
|
||||
(obs for obs in self.observations if isinstance(obs, HFCloopObservation)), None
|
||||
)
|
||||
if hfcloop_observation:
|
||||
hfcloop_observation.add_loop_info(self._current_cycle_detail)
|
||||
else:
|
||||
@@ -356,9 +371,7 @@ class HeartFChatting:
|
||||
if acquired and self._processing_lock.locked():
|
||||
self._processing_lock.release()
|
||||
|
||||
async def _process_processors(
|
||||
self, observations: List[Observation], running_memorys: List[Dict[str, Any]]
|
||||
) -> tuple[List[InfoBase], Dict[str, float]]:
|
||||
async def _process_processors(self, observations: List[Observation]) -> tuple[List[InfoBase], Dict[str, float]]:
|
||||
# 记录并行任务开始时间
|
||||
parallel_start_time = time.time()
|
||||
logger.debug(f"{self.log_prefix} 开始信息处理器并行任务")
|
||||
@@ -372,7 +385,7 @@ class HeartFChatting:
|
||||
|
||||
async def run_with_timeout(proc=processor):
|
||||
return await asyncio.wait_for(
|
||||
proc.process_info(observations=observations, running_memorys=running_memorys),
|
||||
proc.process_info(observations=observations),
|
||||
timeout=global_config.focus_chat.processor_max_time,
|
||||
)
|
||||
|
||||
@@ -443,32 +456,29 @@ class HeartFChatting:
|
||||
|
||||
# 根据配置决定是否并行执行调整动作、回忆和处理器阶段
|
||||
|
||||
# 并行执行调整动作、回忆和处理器阶段
|
||||
# 并行执行调整动作、回忆和处理器阶段
|
||||
with Timer("并行调整动作、处理", cycle_timers):
|
||||
# 创建并行任务
|
||||
async def modify_actions_task():
|
||||
async def modify_actions_task():
|
||||
# 调用完整的动作修改流程
|
||||
await self.action_modifier.modify_actions(
|
||||
observations=self.observations,
|
||||
)
|
||||
|
||||
|
||||
await self.action_observation.observe()
|
||||
self.observations.append(self.action_observation)
|
||||
return True
|
||||
|
||||
|
||||
# 创建三个并行任务
|
||||
action_modify_task = asyncio.create_task(modify_actions_task())
|
||||
memory_task = asyncio.create_task(self.memory_activator.activate_memory(self.observations))
|
||||
processor_task = asyncio.create_task(self._process_processors(self.observations, []))
|
||||
processor_task = asyncio.create_task(self._process_processors(self.observations))
|
||||
|
||||
# 等待三个任务完成
|
||||
_, running_memorys, (all_plan_info, processor_time_costs) = await asyncio.gather(
|
||||
action_modify_task, memory_task, processor_task
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
loop_processor_info = {
|
||||
"all_plan_info": all_plan_info,
|
||||
"processor_time_costs": processor_time_costs,
|
||||
@@ -479,7 +489,6 @@ class HeartFChatting:
|
||||
|
||||
loop_plan_info = {
|
||||
"action_result": plan_result.get("action_result", {}),
|
||||
"current_mind": plan_result.get("current_mind", ""),
|
||||
"observed_messages": plan_result.get("observed_messages", ""),
|
||||
}
|
||||
|
||||
@@ -552,9 +561,6 @@ class HeartFChatting:
|
||||
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
|
||||
"""
|
||||
try:
|
||||
action_time = time.time()
|
||||
action_id = f"{action_time}_{thinking_id}"
|
||||
|
||||
# 使用工厂创建动作处理器实例
|
||||
try:
|
||||
action_handler = self.action_manager.create_action(
|
||||
@@ -586,9 +592,13 @@ class HeartFChatting:
|
||||
else:
|
||||
success, reply_text = result
|
||||
command = ""
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 麦麦执行了'{action}', 返回结果'{success}', '{reply_text}', '{command}'"
|
||||
)
|
||||
|
||||
# 检查action_data中是否有系统命令,优先使用系统命令
|
||||
if "_system_command" in action_data:
|
||||
command = action_data["_system_command"]
|
||||
logger.debug(f"{self.log_prefix} 从action_data中获取系统命令: {command}")
|
||||
|
||||
logger.debug(f"{self.log_prefix} 麦麦执行了'{action}', 返回结果'{success}', '{reply_text}', '{command}'")
|
||||
|
||||
return success, reply_text, command
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
from typing import Dict, Optional # 重新导入类型
|
||||
from src.chat.message_receive.message import MessageSending, MessageThinking
|
||||
from src.common.message.api import global_api
|
||||
from src.common.message.api import get_global_api
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.utils import truncate_message
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.utils import calculate_typing_time
|
||||
from rich.traceback import install
|
||||
import traceback
|
||||
@@ -21,8 +21,8 @@ async def send_message(message: MessageSending) -> str:
|
||||
|
||||
try:
|
||||
# 直接调用API发送消息
|
||||
await global_api.send_message(message)
|
||||
logger.success(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
await get_global_api().send_message(message)
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
return message.processed_plain_text
|
||||
|
||||
except Exception as e:
|
||||
@@ -88,10 +88,10 @@ class HeartFCSender:
|
||||
"""
|
||||
if not message.chat_stream:
|
||||
logger.error("消息缺少 chat_stream,无法发送")
|
||||
return
|
||||
raise Exception("消息缺少 chat_stream,无法发送")
|
||||
if not message.message_info or not message.message_info.message_id:
|
||||
logger.error("消息缺少 message_info 或 message_id,无法发送")
|
||||
return
|
||||
raise Exception("消息缺少 message_info 或 message_id,无法发送")
|
||||
|
||||
chat_id = message.chat_stream.stream_id
|
||||
message_id = message.message_info.message_id
|
||||
@@ -110,7 +110,9 @@ class HeartFCSender:
|
||||
message.set_reply()
|
||||
logger.debug(f"[{chat_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}...")
|
||||
|
||||
# print(f"message.display_message: {message.display_message}")
|
||||
await message.process()
|
||||
# print(f"message.display_message: {message.display_message}")
|
||||
|
||||
if typing:
|
||||
if has_thinking:
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from src.chat.memory_system.Hippocampus import HippocampusManager
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.message_receive.chat_stream import chat_manager, ChatStream
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.person_info.relationship_manager import relationship_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
import math
|
||||
import re
|
||||
@@ -15,6 +14,8 @@ import traceback
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
|
||||
# from ..message_receive.message_buffer import message_buffer
|
||||
|
||||
logger = get_logger("chat")
|
||||
@@ -45,14 +46,15 @@ async def _process_relationship(message: MessageRecv) -> None:
|
||||
nickname = message.message_info.user_info.user_nickname
|
||||
cardname = message.message_info.user_info.user_cardname or nickname
|
||||
|
||||
relationship_manager = get_relationship_manager()
|
||||
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
||||
|
||||
if not is_known:
|
||||
logger.info(f"首次认识用户: {nickname}")
|
||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname)
|
||||
# elif not await relationship_manager.is_qved_name(platform, user_id):
|
||||
# logger.info(f"给用户({nickname},{cardname})取名: {nickname}")
|
||||
# await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
|
||||
# logger.info(f"给用户({nickname},{cardname})取名: {nickname}")
|
||||
# await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
@@ -67,21 +69,22 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0
|
||||
|
||||
with Timer("记忆激活"):
|
||||
interested_rate = await HippocampusManager.get_instance().get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
fast_retrieval=True,
|
||||
)
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05
|
||||
# 采用对数函数实现递减增长
|
||||
if global_config.memory.enable_memory:
|
||||
with Timer("记忆激活"):
|
||||
interested_rate = await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
fast_retrieval=True,
|
||||
)
|
||||
logger.debug(f"记忆激活率: {interested_rate:.2f}")
|
||||
|
||||
base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1))
|
||||
base_interest = min(max(base_interest, 0.01), 0.05)
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05
|
||||
# 采用对数函数实现递减增长
|
||||
|
||||
interested_rate += base_interest
|
||||
base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1))
|
||||
base_interest = min(max(base_interest, 0.01), 0.05)
|
||||
|
||||
logger.trace(f"记忆激活率: {interested_rate:.2f}")
|
||||
interested_rate += base_interest
|
||||
|
||||
if is_mentioned:
|
||||
interest_increase_on_mention = 1
|
||||
@@ -180,8 +183,7 @@ class HeartFCMessageReceiver:
|
||||
userinfo = message.message_info.user_info
|
||||
messageinfo = message.message_info
|
||||
|
||||
|
||||
chat = await chat_manager.get_or_create_stream(
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
@@ -210,7 +212,7 @@ class HeartFCMessageReceiver:
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||
|
||||
# 8. 关系处理
|
||||
if global_config.relationship.give_name:
|
||||
if global_config.relationship.enable_relationship and global_config.relationship.give_name:
|
||||
await _process_relationship(message)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Optional
|
||||
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.message import UserInfo
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
import json
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Any, Optional, Dict
|
||||
from typing import List, Any
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("base_processor")
|
||||
|
||||
@@ -23,8 +23,7 @@ class BaseProcessor(ABC):
|
||||
@abstractmethod
|
||||
async def process_info(
|
||||
self,
|
||||
observations: Optional[List[Observation]] = None,
|
||||
running_memorys: Optional[List[Dict]] = None,
|
||||
observations: List[Observation] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[InfoBase]:
|
||||
"""处理信息对象的抽象方法
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
from typing import List, Optional, Any
|
||||
from typing import List, Any
|
||||
from src.chat.focus_chat.info.obs_info import ObsInfo
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from .base_processor import BaseProcessor
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||
from src.chat.focus_chat.info.cycle_info import CycleInfo
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import asyncio
|
||||
|
||||
logger = get_logger("processor")
|
||||
|
||||
@@ -36,8 +34,7 @@ class ChattingInfoProcessor(BaseProcessor):
|
||||
|
||||
async def process_info(
|
||||
self,
|
||||
observations: Optional[List[Observation]] = None,
|
||||
running_memorys: Optional[List[Dict]] = None,
|
||||
observations: List[Observation] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[InfoBase]:
|
||||
"""处理Observation对象
|
||||
|
||||
@@ -4,18 +4,17 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
import traceback
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.individuality.individuality import individuality
|
||||
from src.common.logger import get_logger
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.json_utils import safe_json_dumps
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.person_info.relationship_manager import relationship_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
from .base_processor import BaseProcessor
|
||||
from src.chat.focus_chat.info.mind_info import MindInfo
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||
from src.chat.heart_flow.observation.actions_observation import ActionObservation
|
||||
from typing import Dict
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
|
||||
logger = get_logger("processor")
|
||||
@@ -77,7 +76,7 @@ class MindProcessor(BaseProcessor):
|
||||
self.structured_info = []
|
||||
self.structured_info_str = ""
|
||||
|
||||
name = chat_manager.get_stream_name(self.subheartflow_id)
|
||||
name = get_chat_manager().get_stream_name(self.subheartflow_id)
|
||||
self.log_prefix = f"[{name}] "
|
||||
self._update_structured_info_str()
|
||||
|
||||
@@ -110,7 +109,8 @@ class MindProcessor(BaseProcessor):
|
||||
logger.debug(f"{self.log_prefix} 更新 structured_info_str: \n{self.structured_info_str}")
|
||||
|
||||
async def process_info(
|
||||
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
|
||||
self,
|
||||
observations: List[Observation] = None,
|
||||
) -> List[InfoBase]:
|
||||
"""处理信息对象
|
||||
|
||||
@@ -120,16 +120,14 @@ class MindProcessor(BaseProcessor):
|
||||
Returns:
|
||||
List[InfoBase]: 处理后的结构化信息列表
|
||||
"""
|
||||
current_mind = await self.do_thinking_before_reply(observations, running_memorys)
|
||||
current_mind = await self.do_thinking_before_reply(observations)
|
||||
|
||||
mind_info = MindInfo()
|
||||
mind_info.set_current_mind(current_mind)
|
||||
|
||||
return [mind_info]
|
||||
|
||||
async def do_thinking_before_reply(
|
||||
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None
|
||||
):
|
||||
async def do_thinking_before_reply(self, observations: List[Observation] = None):
|
||||
"""
|
||||
在回复前进行思考,生成内心想法并收集工具调用结果
|
||||
|
||||
@@ -157,13 +155,6 @@ class MindProcessor(BaseProcessor):
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 当前完整的 structured_info: {safe_json_dumps(self.structured_info, ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
memory_str = ""
|
||||
if running_memorys:
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
for running_memory in running_memorys:
|
||||
memory_str += f"{running_memory['topic']}: {running_memory['content']}\n"
|
||||
|
||||
# ---------- 1. 准备基础数据 ----------
|
||||
# 获取现有想法和情绪状态
|
||||
previous_mind = self.current_mind if self.current_mind else ""
|
||||
@@ -193,15 +184,16 @@ class MindProcessor(BaseProcessor):
|
||||
# 获取个性化信息
|
||||
|
||||
relation_prompt = ""
|
||||
for person in person_list:
|
||||
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
|
||||
if global_config.relationship.enable_relationship:
|
||||
for person in person_list:
|
||||
relationship_manager = get_relationship_manager()
|
||||
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
|
||||
|
||||
template_name = "sub_heartflow_prompt_before" if is_group_chat else "sub_heartflow_prompt_private_before"
|
||||
logger.debug(f"{self.log_prefix} 使用{'群聊' if is_group_chat else '私聊'}思考模板")
|
||||
|
||||
prompt = (await global_prompt_manager.get_prompt_async(template_name)).format(
|
||||
bot_name=individuality.name,
|
||||
memory_str=memory_str,
|
||||
bot_name=get_individuality().name,
|
||||
extra_info=self.structured_info_str,
|
||||
relation_prompt=relation_prompt,
|
||||
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
|
||||
@@ -4,21 +4,27 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
import traceback
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.person_info.relationship_manager import relationship_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
from .base_processor import BaseProcessor
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from src.chat.focus_chat.info.relation_info import RelationInfo
|
||||
from json_repair import repair_json
|
||||
from src.person_info.person_info import person_info_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
import json
|
||||
import asyncio
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
||||
|
||||
|
||||
# 配置常量:是否启用小模型即时信息提取
|
||||
# 开启时:使用小模型并行即时提取,速度更快,但精度可能略低
|
||||
# 关闭时:使用原来的异步模式,精度更高但速度较慢
|
||||
ENABLE_INSTANT_INFO_EXTRACTION = True
|
||||
|
||||
logger = get_logger("processor")
|
||||
|
||||
|
||||
@@ -58,7 +64,7 @@ def init_prompt():
|
||||
|
||||
"""
|
||||
Prompt(relationship_prompt, "relationship_prompt")
|
||||
|
||||
|
||||
fetch_info_prompt = """
|
||||
|
||||
{name_block}
|
||||
@@ -79,7 +85,6 @@ def init_prompt():
|
||||
Prompt(fetch_info_prompt, "fetch_info_prompt")
|
||||
|
||||
|
||||
|
||||
class RelationshipProcessor(BaseProcessor):
|
||||
log_prefix = "关系"
|
||||
|
||||
@@ -87,8 +92,10 @@ class RelationshipProcessor(BaseProcessor):
|
||||
super().__init__()
|
||||
|
||||
self.subheartflow_id = subheartflow_id
|
||||
self.info_fetching_cache: List[Dict[str, any]] = []
|
||||
self.info_fetched_cache: Dict[str, Dict[str, any]] = {} # {person_id: {"info": str, "ttl": int, "start_time": float}}
|
||||
self.info_fetching_cache: List[Dict[str, any]] = []
|
||||
self.info_fetched_cache: Dict[
|
||||
str, Dict[str, any]
|
||||
] = {} # {person_id: {"info": str, "ttl": int, "start_time": float}}
|
||||
self.person_engaged_cache: List[Dict[str, any]] = [] # [{person_id: str, start_time: float, rounds: int}]
|
||||
self.grace_period_rounds = 5
|
||||
|
||||
@@ -97,12 +104,17 @@ class RelationshipProcessor(BaseProcessor):
|
||||
request_type="focus.relationship",
|
||||
)
|
||||
|
||||
name = chat_manager.get_stream_name(self.subheartflow_id)
|
||||
# 小模型用于即时信息提取
|
||||
if ENABLE_INSTANT_INFO_EXTRACTION:
|
||||
self.instant_llm_model = LLMRequest(
|
||||
model=global_config.model.utils_small,
|
||||
request_type="focus.relationship.instant",
|
||||
)
|
||||
|
||||
name = get_chat_manager().get_stream_name(self.subheartflow_id)
|
||||
self.log_prefix = f"[{name}] "
|
||||
|
||||
async def process_info(
|
||||
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
|
||||
) -> List[InfoBase]:
|
||||
async def process_info(self, observations: List[Observation] = None, *infos) -> List[InfoBase]:
|
||||
"""处理信息对象
|
||||
|
||||
Args:
|
||||
@@ -124,7 +136,7 @@ class RelationshipProcessor(BaseProcessor):
|
||||
|
||||
async def relation_identify(
|
||||
self,
|
||||
observations: Optional[List[Observation]] = None,
|
||||
observations: List[Observation] = None,
|
||||
):
|
||||
"""
|
||||
在回复前进行思考,生成内心想法并收集工具调用结果
|
||||
@@ -144,18 +156,27 @@ class RelationshipProcessor(BaseProcessor):
|
||||
for record in list(self.person_engaged_cache):
|
||||
record["rounds"] += 1
|
||||
time_elapsed = current_time - record["start_time"]
|
||||
message_count = len(get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time))
|
||||
|
||||
if (record["rounds"] > 50 or
|
||||
time_elapsed > 1800 or # 30分钟
|
||||
message_count > 75):
|
||||
logger.info(f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。")
|
||||
message_count = len(
|
||||
get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time)
|
||||
)
|
||||
|
||||
print(record)
|
||||
|
||||
# 根据消息数量和时间设置不同的触发条件
|
||||
should_trigger = (
|
||||
message_count >= 50 # 50条消息必定满足
|
||||
or (message_count >= 35 and time_elapsed >= 300) # 35条且10分钟
|
||||
or (message_count >= 25 and time_elapsed >= 900) # 25条且30分钟
|
||||
or (message_count >= 10 and time_elapsed >= 2000) # 10条且1小时
|
||||
)
|
||||
|
||||
if should_trigger:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。消息数:{message_count},时长:{time_elapsed:.0f}秒"
|
||||
)
|
||||
asyncio.create_task(
|
||||
self.update_impression_on_cache_expiry(
|
||||
record["person_id"],
|
||||
self.subheartflow_id,
|
||||
record["start_time"],
|
||||
current_time
|
||||
record["person_id"], self.subheartflow_id, record["start_time"], current_time
|
||||
)
|
||||
)
|
||||
self.person_engaged_cache.remove(record)
|
||||
@@ -167,20 +188,24 @@ class RelationshipProcessor(BaseProcessor):
|
||||
if self.info_fetched_cache[person_id][info_type]["ttl"] <= 0:
|
||||
# 在删除前查找匹配的info_fetching_cache记录
|
||||
matched_record = None
|
||||
min_time_diff = float('inf')
|
||||
min_time_diff = float("inf")
|
||||
for record in self.info_fetching_cache:
|
||||
if (record["person_id"] == person_id and
|
||||
record["info_type"] == info_type and
|
||||
not record["forget"]):
|
||||
time_diff = abs(record["start_time"] - self.info_fetched_cache[person_id][info_type]["start_time"])
|
||||
if (
|
||||
record["person_id"] == person_id
|
||||
and record["info_type"] == info_type
|
||||
and not record["forget"]
|
||||
):
|
||||
time_diff = abs(
|
||||
record["start_time"] - self.info_fetched_cache[person_id][info_type]["start_time"]
|
||||
)
|
||||
if time_diff < min_time_diff:
|
||||
min_time_diff = time_diff
|
||||
matched_record = record
|
||||
|
||||
|
||||
if matched_record:
|
||||
matched_record["forget"] = True
|
||||
logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已过期,标记为遗忘。")
|
||||
|
||||
|
||||
del self.info_fetched_cache[person_id][info_type]
|
||||
if not self.info_fetched_cache[person_id]:
|
||||
del self.info_fetched_cache[person_id]
|
||||
@@ -188,7 +213,7 @@ class RelationshipProcessor(BaseProcessor):
|
||||
# 5. 为需要处理的人员准备LLM prompt
|
||||
nickname_str = ",".join(global_config.bot.alias_names)
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
|
||||
info_cache_block = ""
|
||||
if self.info_fetching_cache:
|
||||
for info_fetching in self.info_fetching_cache:
|
||||
@@ -203,37 +228,63 @@ class RelationshipProcessor(BaseProcessor):
|
||||
chat_observe_info=chat_observe_info,
|
||||
info_cache_block=info_cache_block,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"{self.log_prefix} 人物信息prompt: \n{prompt}\n")
|
||||
logger.debug(f"{self.log_prefix} 人物信息prompt: \n{prompt}\n")
|
||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
if content:
|
||||
print(f"content: {content}")
|
||||
content_json = json.loads(repair_json(content))
|
||||
|
||||
# 收集即时提取任务
|
||||
instant_tasks = []
|
||||
async_tasks = []
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
for person_name, info_type in content_json.items():
|
||||
person_id = person_info_manager.get_person_id_by_person_name(person_name)
|
||||
if person_id:
|
||||
self.info_fetching_cache.append({
|
||||
"person_id": person_id,
|
||||
"person_name": person_name,
|
||||
"info_type": info_type,
|
||||
"start_time": time.time(),
|
||||
"forget": False,
|
||||
})
|
||||
self.info_fetching_cache.append(
|
||||
{
|
||||
"person_id": person_id,
|
||||
"person_name": person_name,
|
||||
"info_type": info_type,
|
||||
"start_time": time.time(),
|
||||
"forget": False,
|
||||
}
|
||||
)
|
||||
if len(self.info_fetching_cache) > 20:
|
||||
self.info_fetching_cache.pop(0)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 未找到用户 {person_name} 的ID,跳过调取信息。")
|
||||
|
||||
continue
|
||||
|
||||
logger.info(f"{self.log_prefix} 调取用户 {person_name} 的 {info_type} 信息。")
|
||||
|
||||
self.person_engaged_cache.append({
|
||||
"person_id": person_id,
|
||||
"start_time": time.time(),
|
||||
"rounds": 0
|
||||
})
|
||||
asyncio.create_task(self.fetch_person_info(person_id, [info_type], start_time=time.time()))
|
||||
|
||||
# 检查person_engaged_cache中是否已存在该person_id
|
||||
person_exists = any(record["person_id"] == person_id for record in self.person_engaged_cache)
|
||||
if not person_exists:
|
||||
self.person_engaged_cache.append(
|
||||
{"person_id": person_id, "start_time": time.time(), "rounds": 0}
|
||||
)
|
||||
|
||||
if ENABLE_INSTANT_INFO_EXTRACTION:
|
||||
# 收集即时提取任务
|
||||
instant_tasks.append((person_id, info_type, time.time()))
|
||||
else:
|
||||
# 使用原来的异步模式
|
||||
async_tasks.append(
|
||||
asyncio.create_task(self.fetch_person_info(person_id, [info_type], start_time=time.time()))
|
||||
)
|
||||
|
||||
# 执行即时提取任务
|
||||
if ENABLE_INSTANT_INFO_EXTRACTION and instant_tasks:
|
||||
await self._execute_instant_extraction_batch(instant_tasks)
|
||||
|
||||
# 启动异步任务(如果不是即时模式)
|
||||
if async_tasks:
|
||||
# 异步任务不需要等待完成
|
||||
pass
|
||||
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} LLM返回空结果,关系识别失败。")
|
||||
@@ -254,86 +305,179 @@ class RelationshipProcessor(BaseProcessor):
|
||||
info_content = self.info_fetched_cache[person_id][info_type]["info"]
|
||||
person_infos_str += f"[{info_type}]:{info_content};"
|
||||
else:
|
||||
person_infos_str += f"你不了解{person_name}有关[{info_type}]的信息,不要胡乱回答;"
|
||||
person_infos_str += f"你不了解{person_name}有关[{info_type}]的信息,不要胡乱回答,你可以直接说你不知道,或者你忘记了;"
|
||||
if person_infos_str:
|
||||
persons_infos_str += f"你对 {person_name} 的了解:{person_infos_str}\n"
|
||||
|
||||
# 处理正在调取但还没有结果的项目
|
||||
pending_info_dict = {}
|
||||
for record in self.info_fetching_cache:
|
||||
if not record["forget"]:
|
||||
current_time = time.time()
|
||||
# 只处理不超过2分钟的调取请求,避免过期请求一直显示
|
||||
if current_time - record["start_time"] <= 120: # 10分钟内的请求
|
||||
person_id = record["person_id"]
|
||||
person_name = record["person_name"]
|
||||
info_type = record["info_type"]
|
||||
|
||||
# 检查是否已经在info_fetched_cache中有结果
|
||||
if (person_id in self.info_fetched_cache and
|
||||
info_type in self.info_fetched_cache[person_id]):
|
||||
continue
|
||||
|
||||
# 按人物组织正在调取的信息
|
||||
if person_name not in pending_info_dict:
|
||||
pending_info_dict[person_name] = []
|
||||
pending_info_dict[person_name].append(info_type)
|
||||
|
||||
# 添加正在调取的信息到返回字符串
|
||||
for person_name, info_types in pending_info_dict.items():
|
||||
info_types_str = "、".join(info_types)
|
||||
persons_infos_str += f"你正在识图回忆有关 {person_name} 的 {info_types_str} 信息,稍等一下再回答...\n"
|
||||
|
||||
# 处理正在调取但还没有结果的项目(只在非即时提取模式下显示)
|
||||
if not ENABLE_INSTANT_INFO_EXTRACTION:
|
||||
pending_info_dict = {}
|
||||
for record in self.info_fetching_cache:
|
||||
if not record["forget"]:
|
||||
current_time = time.time()
|
||||
# 只处理不超过2分钟的调取请求,避免过期请求一直显示
|
||||
if current_time - record["start_time"] <= 120: # 10分钟内的请求
|
||||
person_id = record["person_id"]
|
||||
person_name = record["person_name"]
|
||||
info_type = record["info_type"]
|
||||
|
||||
# 检查是否已经在info_fetched_cache中有结果
|
||||
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
|
||||
continue
|
||||
|
||||
# 按人物组织正在调取的信息
|
||||
if person_name not in pending_info_dict:
|
||||
pending_info_dict[person_name] = []
|
||||
pending_info_dict[person_name].append(info_type)
|
||||
|
||||
# 添加正在调取的信息到返回字符串
|
||||
for person_name, info_types in pending_info_dict.items():
|
||||
info_types_str = "、".join(info_types)
|
||||
persons_infos_str += f"你正在识图回忆有关 {person_name} 的 {info_types_str} 信息,稍等一下再回答...\n"
|
||||
|
||||
return persons_infos_str
|
||||
|
||||
|
||||
async def _execute_instant_extraction_batch(self, instant_tasks: list):
|
||||
"""
|
||||
批量执行即时提取任务
|
||||
"""
|
||||
if not instant_tasks:
|
||||
return
|
||||
|
||||
logger.info(f"{self.log_prefix} [即时提取] 开始批量提取 {len(instant_tasks)} 个信息")
|
||||
|
||||
# 创建所有提取任务
|
||||
extraction_tasks = []
|
||||
for person_id, info_type, start_time in instant_tasks:
|
||||
# 检查缓存中是否已存在且未过期的信息
|
||||
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
|
||||
logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已存在且未过期,跳过调取。")
|
||||
continue
|
||||
|
||||
task = asyncio.create_task(self._fetch_single_info_instant(person_id, info_type, start_time))
|
||||
extraction_tasks.append(task)
|
||||
|
||||
# 并行执行所有提取任务并等待完成
|
||||
if extraction_tasks:
|
||||
await asyncio.gather(*extraction_tasks, return_exceptions=True)
|
||||
logger.info(f"{self.log_prefix} [即时提取] 批量提取完成")
|
||||
|
||||
async def _fetch_single_info_instant(self, person_id: str, info_type: str, start_time: float):
|
||||
"""
|
||||
使用小模型提取单个信息类型
|
||||
"""
|
||||
person_info_manager = get_person_info_manager()
|
||||
nickname_str = ",".join(global_config.bot.alias_names)
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
person_impression = await person_info_manager.get_value(person_id, "impression")
|
||||
if not person_impression:
|
||||
impression_block = "你对ta没有什么深刻的印象"
|
||||
else:
|
||||
impression_block = f"{person_impression}"
|
||||
|
||||
points = await person_info_manager.get_value(person_id, "points")
|
||||
if points:
|
||||
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
||||
else:
|
||||
points_text = "你不记得ta最近发生了什么"
|
||||
|
||||
prompt = (await global_prompt_manager.get_prompt_async("fetch_info_prompt")).format(
|
||||
name_block=name_block,
|
||||
info_type=info_type,
|
||||
person_impression=impression_block,
|
||||
person_name=person_name,
|
||||
info_json_str=f'"{info_type}": "信息内容"',
|
||||
points_text=points_text,
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用小模型进行即时提取
|
||||
content, _ = await self.instant_llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
logger.info(f"{self.log_prefix} [即时提取] {person_name} 的 {info_type} 结果: {content}")
|
||||
|
||||
if content:
|
||||
content_json = json.loads(repair_json(content))
|
||||
if info_type in content_json:
|
||||
info_content = content_json[info_type]
|
||||
if info_content != "none" and info_content:
|
||||
if person_id not in self.info_fetched_cache:
|
||||
self.info_fetched_cache[person_id] = {}
|
||||
self.info_fetched_cache[person_id][info_type] = {
|
||||
"info": info_content,
|
||||
"ttl": 8, # 小模型提取的信息TTL稍短
|
||||
"start_time": start_time,
|
||||
"person_name": person_name,
|
||||
"unknow": False,
|
||||
}
|
||||
logger.info(
|
||||
f"{self.log_prefix} [即时提取] 成功获取 {person_name} 的 {info_type}: {info_content}"
|
||||
)
|
||||
else:
|
||||
if person_id not in self.info_fetched_cache:
|
||||
self.info_fetched_cache[person_id] = {}
|
||||
self.info_fetched_cache[person_id][info_type] = {
|
||||
"info": "unknow",
|
||||
"ttl": 8,
|
||||
"start_time": start_time,
|
||||
"person_name": person_name,
|
||||
"unknow": True,
|
||||
}
|
||||
logger.info(f"{self.log_prefix} [即时提取] {person_name} 的 {info_type} 信息不明确")
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self.log_prefix} [即时提取] 小模型返回空结果,获取 {person_name} 的 {info_type} 信息失败。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} [即时提取] 执行小模型请求获取用户信息时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def fetch_person_info(self, person_id: str, info_types: list[str], start_time: float):
|
||||
"""
|
||||
获取某个人的信息
|
||||
"""
|
||||
# 检查缓存中是否已存在且未过期的信息
|
||||
info_types_to_fetch = []
|
||||
|
||||
|
||||
for info_type in info_types:
|
||||
if (person_id in self.info_fetched_cache and
|
||||
info_type in self.info_fetched_cache[person_id]):
|
||||
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
|
||||
logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已存在且未过期,跳过调取。")
|
||||
continue
|
||||
info_types_to_fetch.append(info_type)
|
||||
|
||||
|
||||
if not info_types_to_fetch:
|
||||
return
|
||||
|
||||
|
||||
nickname_str = ",".join(global_config.bot.alias_names)
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
|
||||
info_type_str = ""
|
||||
info_json_str = ""
|
||||
for info_type in info_types_to_fetch:
|
||||
info_type_str += f"{info_type},"
|
||||
info_json_str += f"\"{info_type}\": \"信息内容\","
|
||||
info_json_str += f'"{info_type}": "信息内容",'
|
||||
info_type_str = info_type_str[:-1]
|
||||
info_json_str = info_json_str[:-1]
|
||||
|
||||
|
||||
person_impression = await person_info_manager.get_value(person_id, "impression")
|
||||
if not person_impression:
|
||||
impression_block = "你对ta没有什么深刻的印象"
|
||||
else:
|
||||
impression_block = f"{person_impression}"
|
||||
|
||||
|
||||
|
||||
points = await person_info_manager.get_value(person_id, "points")
|
||||
|
||||
if points:
|
||||
points_text = "\n".join([
|
||||
f"{point[2]}:{point[0]}"
|
||||
for point in points
|
||||
])
|
||||
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
||||
else:
|
||||
points_text = "你不记得ta最近发生了什么"
|
||||
|
||||
|
||||
|
||||
prompt = (await global_prompt_manager.get_prompt_async("fetch_info_prompt")).format(
|
||||
name_block=name_block,
|
||||
info_type=info_type_str,
|
||||
@@ -345,10 +489,10 @@ class RelationshipProcessor(BaseProcessor):
|
||||
|
||||
try:
|
||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
|
||||
# logger.info(f"{self.log_prefix} fetch_person_info prompt: \n{prompt}\n")
|
||||
logger.info(f"{self.log_prefix} fetch_person_info 结果: {content}")
|
||||
|
||||
|
||||
if content:
|
||||
try:
|
||||
content_json = json.loads(repair_json(content))
|
||||
@@ -366,9 +510,9 @@ class RelationshipProcessor(BaseProcessor):
|
||||
else:
|
||||
if person_id not in self.info_fetched_cache:
|
||||
self.info_fetched_cache[person_id] = {}
|
||||
|
||||
|
||||
self.info_fetched_cache[person_id][info_type] = {
|
||||
"info":"unknow",
|
||||
"info": "unknow",
|
||||
"ttl": 10,
|
||||
"start_time": start_time,
|
||||
"person_name": person_name,
|
||||
@@ -383,19 +527,16 @@ class RelationshipProcessor(BaseProcessor):
|
||||
logger.error(f"{self.log_prefix} 执行LLM请求获取用户信息时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def update_impression_on_cache_expiry(
|
||||
self, person_id: str, chat_id: str, start_time: float, end_time: float
|
||||
):
|
||||
async def update_impression_on_cache_expiry(self, person_id: str, chat_id: str, start_time: float, end_time: float):
|
||||
"""
|
||||
在缓存过期时,获取聊天记录并更新用户印象
|
||||
"""
|
||||
logger.info(f"缓存过期,开始为 {person_id} 更新印象。时间范围:{start_time} -> {end_time}")
|
||||
try:
|
||||
|
||||
|
||||
impression_messages = get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time)
|
||||
if impression_messages:
|
||||
logger.info(f"为 {person_id} 获取到 {len(impression_messages)} 条消息用于印象更新。")
|
||||
relationship_manager = get_relationship_manager()
|
||||
await relationship_manager.update_person_impression(
|
||||
person_id=person_id, timestamp=end_time, bot_engaged_messages=impression_messages
|
||||
)
|
||||
|
||||
@@ -4,14 +4,13 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
import traceback
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.individuality.individuality import individuality
|
||||
from src.common.logger import get_logger
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from .base_processor import BaseProcessor
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||
from typing import Dict
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from src.chat.focus_chat.info.self_info import SelfInfo
|
||||
|
||||
@@ -59,12 +58,10 @@ class SelfProcessor(BaseProcessor):
|
||||
request_type="focus.processor.self_identify",
|
||||
)
|
||||
|
||||
name = chat_manager.get_stream_name(self.subheartflow_id)
|
||||
name = get_chat_manager().get_stream_name(self.subheartflow_id)
|
||||
self.log_prefix = f"[{name}] "
|
||||
|
||||
async def process_info(
|
||||
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
|
||||
) -> List[InfoBase]:
|
||||
async def process_info(self, observations: List[Observation] = None, *infos) -> List[InfoBase]:
|
||||
"""处理信息对象
|
||||
|
||||
Args:
|
||||
@@ -73,7 +70,7 @@ class SelfProcessor(BaseProcessor):
|
||||
Returns:
|
||||
List[InfoBase]: 处理后的结构化信息列表
|
||||
"""
|
||||
self_info_str = await self.self_indentify(observations, running_memorys)
|
||||
self_info_str = await self.self_indentify(observations)
|
||||
|
||||
if self_info_str:
|
||||
self_info = SelfInfo()
|
||||
@@ -85,7 +82,8 @@ class SelfProcessor(BaseProcessor):
|
||||
return [self_info]
|
||||
|
||||
async def self_indentify(
|
||||
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None
|
||||
self,
|
||||
observations: List[Observation] = None,
|
||||
):
|
||||
"""
|
||||
在回复前进行思考,生成内心想法并收集工具调用结果
|
||||
@@ -100,13 +98,6 @@ class SelfProcessor(BaseProcessor):
|
||||
tuple: (current_mind, past_mind, prompt) 当前想法、过去的想法列表和使用的prompt
|
||||
"""
|
||||
|
||||
for observation in observations:
|
||||
if isinstance(observation, ChattingObservation):
|
||||
is_group_chat = observation.is_group_chat
|
||||
chat_target_info = observation.chat_target_info
|
||||
chat_target_name = "对方" # 私聊默认名称
|
||||
person_list = observation.person_list
|
||||
|
||||
if observations is None:
|
||||
observations = []
|
||||
for observation in observations:
|
||||
@@ -122,9 +113,7 @@ class SelfProcessor(BaseProcessor):
|
||||
)
|
||||
# 获取聊天内容
|
||||
chat_observe_info = observation.get_observe_info()
|
||||
person_list = observation.person_list
|
||||
if isinstance(observation, HFCloopObservation):
|
||||
# hfcloop_observe_info = observation.get_observe_info()
|
||||
pass
|
||||
|
||||
nickname_str = ""
|
||||
@@ -132,8 +121,9 @@ class SelfProcessor(BaseProcessor):
|
||||
nickname_str += f"{nicknames},"
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
personality_block = individuality.get_personality_prompt(x_person=2, level=2)
|
||||
identity_block = individuality.get_identity_prompt(x_person=2, level=2)
|
||||
personality_block = get_individuality().get_personality_prompt(x_person=2, level=2)
|
||||
|
||||
identity_block = get_individuality().get_identity_prompt(x_person=2, level=2)
|
||||
|
||||
prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format(
|
||||
name_block=name_block,
|
||||
|
||||
@@ -2,13 +2,13 @@ from src.chat.heart_flow.observation.chatting_observation import ChattingObserva
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.individuality.individuality import individuality
|
||||
from src.common.logger import get_logger
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.tools.tool_use import ToolUser
|
||||
from src.chat.utils.json_utils import process_llm_tool_calls
|
||||
from .base_processor import BaseProcessor
|
||||
from typing import List, Optional, Dict
|
||||
from typing import List, Optional
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.focus_chat.info.structured_info import StructuredInfo
|
||||
from src.chat.heart_flow.observation.structure_observation import StructureObservation
|
||||
@@ -47,12 +47,12 @@ class ToolProcessor(BaseProcessor):
|
||||
)
|
||||
self.structured_info = []
|
||||
|
||||
async def process_info(
|
||||
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
|
||||
) -> List[dict]:
|
||||
async def process_info(self, observations: Optional[List[Observation]] = None) -> List[StructuredInfo]:
|
||||
"""处理信息对象
|
||||
|
||||
Args:
|
||||
observations: 可选的观察列表,包含ChattingObservation和StructureObservation类型
|
||||
running_memories: 可选的运行时记忆列表,包含字典类型的记忆信息
|
||||
*infos: 可变数量的InfoBase类型的信息对象
|
||||
|
||||
Returns:
|
||||
@@ -60,15 +60,15 @@ class ToolProcessor(BaseProcessor):
|
||||
"""
|
||||
|
||||
working_infos = []
|
||||
result = []
|
||||
|
||||
if observations:
|
||||
for observation in observations:
|
||||
if isinstance(observation, ChattingObservation):
|
||||
result, used_tools, prompt = await self.execute_tools(observation, running_memorys)
|
||||
result, used_tools, prompt = await self.execute_tools(observation)
|
||||
|
||||
# 更新WorkingObservation中的结构化信息
|
||||
logger.debug(f"工具调用结果: {result}")
|
||||
|
||||
# 更新WorkingObservation中的结构化信息
|
||||
for observation in observations:
|
||||
if isinstance(observation, StructureObservation):
|
||||
for structured_info in result:
|
||||
@@ -81,16 +81,11 @@ class ToolProcessor(BaseProcessor):
|
||||
structured_info = StructuredInfo()
|
||||
if working_infos:
|
||||
for working_info in working_infos:
|
||||
# print(f"working_info: {working_info}")
|
||||
# print(f"working_info.get('type'): {working_info.get('type')}")
|
||||
# print(f"working_info.get('content'): {working_info.get('content')}")
|
||||
structured_info.set_info(key=working_info.get("type"), value=working_info.get("content"))
|
||||
# info = structured_info.get_processed_info()
|
||||
# print(f"info: {info}")
|
||||
|
||||
return [structured_info]
|
||||
|
||||
async def execute_tools(self, observation: ChattingObservation, running_memorys: Optional[List[Dict]] = None):
|
||||
async def execute_tools(self, observation: ChattingObservation):
|
||||
"""
|
||||
并行执行工具,返回结构化信息
|
||||
|
||||
@@ -118,13 +113,7 @@ class ToolProcessor(BaseProcessor):
|
||||
is_group_chat = observation.is_group_chat
|
||||
|
||||
chat_observe_info = observation.get_observe_info()
|
||||
person_list = observation.person_list
|
||||
|
||||
memory_str = ""
|
||||
if running_memorys:
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
for running_memory in running_memorys:
|
||||
memory_str += f"{running_memory['topic']}: {running_memory['content']}\n"
|
||||
# person_list = observation.person_list
|
||||
|
||||
# 获取时间信息
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
@@ -132,18 +121,15 @@ class ToolProcessor(BaseProcessor):
|
||||
# 构建专用于工具调用的提示词
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"tool_executor_prompt",
|
||||
memory_str=memory_str,
|
||||
chat_observe_info=chat_observe_info,
|
||||
is_group_chat=is_group_chat,
|
||||
bot_name=individuality.name,
|
||||
bot_name=get_individuality().name,
|
||||
time_now=time_now,
|
||||
)
|
||||
|
||||
# 调用LLM,专注于工具使用
|
||||
# logger.info(f"开始执行工具调用{prompt}")
|
||||
response, other_info = await self.llm_model.generate_response_async(
|
||||
prompt=prompt, tools=tools
|
||||
)
|
||||
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
|
||||
|
||||
if len(other_info) == 3:
|
||||
reasoning_content, model_name, tool_calls = other_info
|
||||
|
||||
@@ -4,15 +4,14 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
import traceback
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from .base_processor import BaseProcessor
|
||||
from src.chat.focus_chat.info.mind_info import MindInfo
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
|
||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||
from typing import Dict
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from json_repair import repair_json
|
||||
from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo
|
||||
@@ -64,12 +63,10 @@ class WorkingMemoryProcessor(BaseProcessor):
|
||||
request_type="focus.processor.working_memory",
|
||||
)
|
||||
|
||||
name = chat_manager.get_stream_name(self.subheartflow_id)
|
||||
name = get_chat_manager().get_stream_name(self.subheartflow_id)
|
||||
self.log_prefix = f"[{name}] "
|
||||
|
||||
async def process_info(
|
||||
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
|
||||
) -> List[InfoBase]:
|
||||
async def process_info(self, observations: List[Observation] = None, *infos) -> List[InfoBase]:
|
||||
"""处理信息对象
|
||||
|
||||
Args:
|
||||
@@ -118,9 +115,7 @@ class WorkingMemoryProcessor(BaseProcessor):
|
||||
memory_str=memory_choose_str,
|
||||
)
|
||||
|
||||
|
||||
# print(f"prompt: {prompt}")
|
||||
|
||||
|
||||
# 调用LLM处理记忆
|
||||
content = ""
|
||||
|
||||
@@ -3,10 +3,10 @@ from src.chat.heart_flow.observation.structure_observation import StructureObser
|
||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from datetime import datetime
|
||||
from src.chat.memory_system.Hippocampus import HippocampusManager
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from typing import List, Dict
|
||||
import difflib
|
||||
import json
|
||||
@@ -87,6 +87,10 @@ class MemoryActivator:
|
||||
Returns:
|
||||
List[Dict]: 激活的记忆列表
|
||||
"""
|
||||
# 如果记忆系统被禁用,直接返回空列表
|
||||
if not global_config.memory.enable_memory:
|
||||
return []
|
||||
|
||||
obs_info_text = ""
|
||||
for observation in observations:
|
||||
if isinstance(observation, ChattingObservation):
|
||||
@@ -128,10 +132,10 @@ class MemoryActivator:
|
||||
logger.debug(f"当前激活的记忆关键词: {self.cached_keywords}")
|
||||
|
||||
# 调用记忆系统获取相关记忆
|
||||
related_memory = await HippocampusManager.get_instance().get_memory_from_topic(
|
||||
related_memory = await hippocampus_manager.get_memory_from_topic(
|
||||
valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
|
||||
)
|
||||
# related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
# related_memory = await hippocampus_manager.get_memory_from_text(
|
||||
# text=obs_info_text, max_memory_num=5, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||
# )
|
||||
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
from typing import Dict, List, Optional, Type, Any
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
|
||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger_manager import get_logger
|
||||
import importlib
|
||||
import pkgutil
|
||||
import os
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入动作类,确保装饰器被执行
|
||||
import src.chat.focus_chat.planners.actions # noqa
|
||||
# 不再需要导入动作类,因为已经在main.py中导入
|
||||
# import src.chat.actions.default_actions # noqa
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
@@ -18,6 +15,84 @@ logger = get_logger("action_manager")
|
||||
ActionInfo = Dict[str, Any]
|
||||
|
||||
|
||||
class PluginActionWrapper(BaseAction):
|
||||
"""
|
||||
新插件系统Action组件的兼容性包装器
|
||||
|
||||
将新插件系统的Action组件包装为旧系统兼容的BaseAction接口
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, plugin_action, action_name: str, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str
|
||||
):
|
||||
"""初始化包装器"""
|
||||
# 调用旧系统BaseAction初始化,只传递它能接受的参数
|
||||
super().__init__(
|
||||
action_data=action_data, reasoning=reasoning, cycle_timers=cycle_timers, thinking_id=thinking_id
|
||||
)
|
||||
|
||||
# 存储插件Action实例(它已经包含了所有必要的服务对象)
|
||||
self.plugin_action = plugin_action
|
||||
self.action_name = action_name
|
||||
|
||||
# 从插件Action实例复制属性到包装器
|
||||
self._sync_attributes_from_plugin_action()
|
||||
|
||||
def _sync_attributes_from_plugin_action(self):
|
||||
"""从插件Action实例同步属性到包装器"""
|
||||
# 基本属性
|
||||
self.action_name = getattr(self.plugin_action, "action_name", self.action_name)
|
||||
|
||||
# 设置兼容的默认值
|
||||
self.action_description = f"插件Action: {self.action_name}"
|
||||
self.action_parameters = {}
|
||||
self.action_require = []
|
||||
|
||||
# 激活类型属性(从新插件系统转换)
|
||||
plugin_focus_type = getattr(self.plugin_action, "focus_activation_type", None)
|
||||
plugin_normal_type = getattr(self.plugin_action, "normal_activation_type", None)
|
||||
|
||||
if plugin_focus_type:
|
||||
self.focus_activation_type = (
|
||||
plugin_focus_type.value if hasattr(plugin_focus_type, "value") else str(plugin_focus_type)
|
||||
)
|
||||
if plugin_normal_type:
|
||||
self.normal_activation_type = (
|
||||
plugin_normal_type.value if hasattr(plugin_normal_type, "value") else str(plugin_normal_type)
|
||||
)
|
||||
|
||||
# 其他属性
|
||||
self.random_activation_probability = getattr(self.plugin_action, "random_activation_probability", 0.0)
|
||||
self.llm_judge_prompt = getattr(self.plugin_action, "llm_judge_prompt", "")
|
||||
self.activation_keywords = getattr(self.plugin_action, "activation_keywords", [])
|
||||
self.keyword_case_sensitive = getattr(self.plugin_action, "keyword_case_sensitive", False)
|
||||
|
||||
# 模式和并行设置
|
||||
plugin_mode = getattr(self.plugin_action, "mode_enable", None)
|
||||
if plugin_mode:
|
||||
self.mode_enable = plugin_mode.value if hasattr(plugin_mode, "value") else str(plugin_mode)
|
||||
|
||||
self.parallel_action = getattr(self.plugin_action, "parallel_action", True)
|
||||
self.enable_plugin = True
|
||||
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""实现抽象方法execute,委托给插件Action的execute方法"""
|
||||
try:
|
||||
# 调用插件Action的execute方法
|
||||
success, response = await self.plugin_action.execute()
|
||||
|
||||
logger.debug(f"插件Action {self.action_name} 执行{'成功' if success else '失败'}: {response}")
|
||||
return success, response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"插件Action {self.action_name} 执行异常: {e}")
|
||||
return False, f"插件Action执行失败: {str(e)}"
|
||||
|
||||
async def handle_action(self) -> tuple[bool, str]:
|
||||
"""兼容旧系统的动作处理接口,委托给execute方法"""
|
||||
return await self.execute()
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""
|
||||
动作管理器,用于管理各种类型的动作
|
||||
@@ -41,7 +116,7 @@ class ActionManager:
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = self._default_actions.copy()
|
||||
|
||||
|
||||
# 添加系统核心动作
|
||||
self._add_system_core_actions()
|
||||
|
||||
@@ -50,8 +125,13 @@ class ActionManager:
|
||||
加载所有通过装饰器注册的动作
|
||||
"""
|
||||
try:
|
||||
# 从_ACTION_REGISTRY获取所有已注册动作
|
||||
for action_name, action_class in _ACTION_REGISTRY.items():
|
||||
# 从组件注册中心获取所有已注册的action
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
action_registry = component_registry.get_action_registry()
|
||||
|
||||
# 从action_registry获取所有已注册动作
|
||||
for action_name, action_class in action_registry.items():
|
||||
# 获取动作相关信息
|
||||
|
||||
# 不读取插件动作和基类
|
||||
@@ -63,19 +143,33 @@ class ActionManager:
|
||||
action_require: list[str] = getattr(action_class, "action_require", [])
|
||||
associated_types: list[str] = getattr(action_class, "associated_types", [])
|
||||
is_enabled: bool = getattr(action_class, "enable_plugin", True)
|
||||
|
||||
|
||||
# 获取激活类型相关属性
|
||||
focus_activation_type: str = getattr(action_class, "focus_activation_type", "always")
|
||||
normal_activation_type: str = getattr(action_class, "normal_activation_type", "always")
|
||||
|
||||
focus_activation_type_attr = getattr(action_class, "focus_activation_type", "always")
|
||||
normal_activation_type_attr = getattr(action_class, "normal_activation_type", "always")
|
||||
|
||||
# 处理枚举值,提取.value
|
||||
focus_activation_type = (
|
||||
focus_activation_type_attr.value
|
||||
if hasattr(focus_activation_type_attr, "value")
|
||||
else str(focus_activation_type_attr)
|
||||
)
|
||||
normal_activation_type = (
|
||||
normal_activation_type_attr.value
|
||||
if hasattr(normal_activation_type_attr, "value")
|
||||
else str(normal_activation_type_attr)
|
||||
)
|
||||
|
||||
# 其他属性
|
||||
random_probability: float = getattr(action_class, "random_activation_probability", 0.3)
|
||||
llm_judge_prompt: str = getattr(action_class, "llm_judge_prompt", "")
|
||||
activation_keywords: list[str] = getattr(action_class, "activation_keywords", [])
|
||||
keyword_case_sensitive: bool = getattr(action_class, "keyword_case_sensitive", False)
|
||||
|
||||
# 获取模式启用属性
|
||||
mode_enable: str = getattr(action_class, "mode_enable", "all")
|
||||
|
||||
|
||||
# 处理模式启用属性
|
||||
mode_enable_attr = getattr(action_class, "mode_enable", "all")
|
||||
mode_enable = mode_enable_attr.value if hasattr(mode_enable_attr, "value") else str(mode_enable_attr)
|
||||
|
||||
# 获取并行执行属性
|
||||
parallel_action: bool = getattr(action_class, "parallel_action", False)
|
||||
|
||||
@@ -114,45 +208,76 @@ class ActionManager:
|
||||
def _load_plugin_actions(self) -> None:
|
||||
"""
|
||||
加载所有插件目录中的动作
|
||||
|
||||
注意:插件动作的实际导入已经在main.py中完成,这里只需要从action_registry获取
|
||||
同时也从新插件系统的component_registry获取Action组件
|
||||
"""
|
||||
try:
|
||||
# 检查插件目录是否存在
|
||||
plugin_path = "src.plugins"
|
||||
plugin_dir = plugin_path.replace(".", os.path.sep)
|
||||
if not os.path.exists(plugin_dir):
|
||||
logger.info(f"插件目录 {plugin_dir} 不存在,跳过插件动作加载")
|
||||
return
|
||||
|
||||
# 导入插件包
|
||||
try:
|
||||
plugins_package = importlib.import_module(plugin_path)
|
||||
except ImportError as e:
|
||||
logger.error(f"导入插件包失败: {e}")
|
||||
return
|
||||
|
||||
# 遍历插件包中的所有子包
|
||||
for _, plugin_name, is_pkg in pkgutil.iter_modules(
|
||||
plugins_package.__path__, plugins_package.__name__ + "."
|
||||
):
|
||||
if not is_pkg:
|
||||
continue
|
||||
|
||||
# 检查插件是否有actions子包
|
||||
plugin_actions_path = f"{plugin_name}.actions"
|
||||
try:
|
||||
# 尝试导入插件的actions包
|
||||
importlib.import_module(plugin_actions_path)
|
||||
logger.info(f"成功加载插件动作模块: {plugin_actions_path}")
|
||||
except ImportError as e:
|
||||
logger.debug(f"插件 {plugin_name} 没有actions子包或导入失败: {e}")
|
||||
continue
|
||||
|
||||
# 再次从_ACTION_REGISTRY获取所有动作(包括刚刚从插件加载的)
|
||||
# 从旧的action_registry获取插件动作
|
||||
self._load_registered_actions()
|
||||
logger.debug("从旧注册表加载插件动作成功")
|
||||
|
||||
# 从新插件系统获取Action组件
|
||||
self._load_plugin_system_actions()
|
||||
logger.debug("从新插件系统加载Action组件成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件动作失败: {e}")
|
||||
|
||||
def _load_plugin_system_actions(self) -> None:
|
||||
"""从新插件系统的component_registry加载Action组件"""
|
||||
try:
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
# 获取所有Action组件
|
||||
action_components = component_registry.get_components_by_type(ComponentType.ACTION)
|
||||
|
||||
for action_name, action_info in action_components.items():
|
||||
if action_name in self._registered_actions:
|
||||
logger.debug(f"Action组件 {action_name} 已存在,跳过")
|
||||
continue
|
||||
|
||||
# 将新插件系统的ActionInfo转换为旧系统格式
|
||||
converted_action_info = {
|
||||
"description": action_info.description,
|
||||
"parameters": getattr(action_info, "action_parameters", {}),
|
||||
"require": getattr(action_info, "action_require", []),
|
||||
"associated_types": getattr(action_info, "associated_types", []),
|
||||
"enable_plugin": action_info.enabled,
|
||||
# 激活类型相关
|
||||
"focus_activation_type": action_info.focus_activation_type.value,
|
||||
"normal_activation_type": action_info.normal_activation_type.value,
|
||||
"random_activation_probability": action_info.random_activation_probability,
|
||||
"llm_judge_prompt": action_info.llm_judge_prompt,
|
||||
"activation_keywords": action_info.activation_keywords,
|
||||
"keyword_case_sensitive": action_info.keyword_case_sensitive,
|
||||
# 模式和并行设置
|
||||
"mode_enable": action_info.mode_enable.value,
|
||||
"parallel_action": action_info.parallel_action,
|
||||
# 标记这是来自新插件系统的组件
|
||||
"_plugin_system_component": True,
|
||||
"_plugin_name": getattr(action_info, "plugin_name", ""),
|
||||
}
|
||||
|
||||
self._registered_actions[action_name] = converted_action_info
|
||||
|
||||
# 如果启用,也添加到默认动作集
|
||||
if action_info.enabled:
|
||||
self._default_actions[action_name] = converted_action_info
|
||||
|
||||
logger.debug(
|
||||
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
|
||||
)
|
||||
|
||||
logger.info(f"从新插件系统加载了 {len(action_components)} 个Action组件")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从插件系统加载Action组件失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def create_action(
|
||||
self,
|
||||
action_name: str,
|
||||
@@ -191,7 +316,28 @@ class ActionManager:
|
||||
# logger.warning(f"当前不可用的动作类型: {action_name}")
|
||||
# return None
|
||||
|
||||
handler_class = _ACTION_REGISTRY.get(action_name)
|
||||
# 检查是否是新插件系统的Action组件
|
||||
action_info = self._registered_actions.get(action_name)
|
||||
if action_info and action_info.get("_plugin_system_component", False):
|
||||
return self._create_plugin_system_action(
|
||||
action_name,
|
||||
action_data,
|
||||
reasoning,
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
observations,
|
||||
chat_stream,
|
||||
log_prefix,
|
||||
shutting_down,
|
||||
expressor,
|
||||
replyer,
|
||||
)
|
||||
|
||||
# 旧系统的动作创建逻辑
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
action_registry = component_registry.get_action_registry()
|
||||
handler_class = action_registry.get(action_name)
|
||||
if not handler_class:
|
||||
logger.warning(f"未注册的动作类型: {action_name}")
|
||||
return None
|
||||
@@ -217,6 +363,75 @@ class ActionManager:
|
||||
logger.error(f"创建动作处理器实例失败: {e}")
|
||||
return None
|
||||
|
||||
def _create_plugin_system_action(
|
||||
self,
|
||||
action_name: str,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: List[Observation],
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
expressor: DefaultExpressor = None,
|
||||
replyer: DefaultReplyer = None,
|
||||
) -> Optional["PluginActionWrapper"]:
|
||||
"""
|
||||
创建新插件系统的Action组件实例,并包装为兼容旧系统的接口
|
||||
|
||||
Returns:
|
||||
Optional[PluginActionWrapper]: 包装后的Action实例
|
||||
"""
|
||||
try:
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
# 获取组件类
|
||||
component_class = component_registry.get_component_class(action_name)
|
||||
if not component_class:
|
||||
logger.error(f"未找到插件Action组件类: {action_name}")
|
||||
return None
|
||||
|
||||
# 获取插件配置
|
||||
component_info = component_registry.get_component_info(action_name)
|
||||
plugin_config = None
|
||||
if component_info and component_info.plugin_name:
|
||||
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
|
||||
|
||||
# 创建插件Action实例
|
||||
plugin_action_instance = component_class(
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=chat_stream,
|
||||
expressor=expressor,
|
||||
replyer=replyer,
|
||||
observations=observations,
|
||||
log_prefix=log_prefix,
|
||||
plugin_config=plugin_config,
|
||||
)
|
||||
|
||||
# 创建兼容性包装器
|
||||
wrapper = PluginActionWrapper(
|
||||
plugin_action=plugin_action_instance,
|
||||
action_name=action_name,
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
)
|
||||
|
||||
logger.debug(f"创建插件Action实例成功: {action_name}")
|
||||
return wrapper
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建插件Action实例失败 {action_name}: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def get_registered_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取所有已注册的动作集"""
|
||||
return self._registered_actions.copy()
|
||||
@@ -232,26 +447,30 @@ class ActionManager:
|
||||
def get_using_actions_for_mode(self, mode: str) -> Dict[str, ActionInfo]:
|
||||
"""
|
||||
根据聊天模式获取可用的动作集合
|
||||
|
||||
|
||||
Args:
|
||||
mode: 聊天模式 ("focus", "normal", "all")
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, ActionInfo]: 在指定模式下可用的动作集合
|
||||
"""
|
||||
filtered_actions = {}
|
||||
|
||||
|
||||
# print(self._using_actions)
|
||||
|
||||
for action_name, action_info in self._using_actions.items():
|
||||
# print(f"action_info: {action_info}")
|
||||
# print(f"action_name: {action_name}")
|
||||
action_mode = action_info.get("mode_enable", "all")
|
||||
|
||||
|
||||
# 检查动作是否在当前模式下启用
|
||||
if action_mode == "all" or action_mode == mode:
|
||||
filtered_actions[action_name] = action_info
|
||||
logger.debug(f"动作 {action_name} 在模式 {mode} 下可用 (mode_enable: {action_mode})")
|
||||
else:
|
||||
logger.debug(f"动作 {action_name} 在模式 {mode} 下不可用 (mode_enable: {action_mode})")
|
||||
|
||||
logger.info(f"模式 {mode} 下可用动作: {list(filtered_actions.keys())}")
|
||||
|
||||
logger.debug(f"模式 {mode} 下可用动作: {list(filtered_actions.keys())}")
|
||||
return filtered_actions
|
||||
|
||||
def add_action_to_using(self, action_name: str) -> bool:
|
||||
@@ -291,7 +510,7 @@ class ActionManager:
|
||||
return False
|
||||
|
||||
del self._using_actions[action_name]
|
||||
logger.info(f"已从使用集中移除动作 {action_name}")
|
||||
logger.debug(f"已从使用集中移除动作 {action_name}")
|
||||
return True
|
||||
|
||||
def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
|
||||
@@ -354,19 +573,19 @@ class ActionManager:
|
||||
系统核心动作是那些enable_plugin为False但是系统必需的动作
|
||||
"""
|
||||
system_core_actions = ["exit_focus_chat"] # 可以根据需要扩展
|
||||
|
||||
|
||||
for action_name in system_core_actions:
|
||||
if action_name in self._registered_actions and action_name not in self._using_actions:
|
||||
self._using_actions[action_name] = self._registered_actions[action_name]
|
||||
logger.info(f"添加系统核心动作到使用集: {action_name}")
|
||||
logger.debug(f"添加系统核心动作到使用集: {action_name}")
|
||||
|
||||
def add_system_action_if_needed(self, action_name: str) -> bool:
|
||||
"""
|
||||
根据需要添加系统动作到使用集
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
@@ -386,4 +605,7 @@ class ActionManager:
|
||||
Returns:
|
||||
Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None
|
||||
"""
|
||||
return _ACTION_REGISTRY.get(action_name)
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
action_registry = component_registry.get_action_registry()
|
||||
return action_registry.get(action_name)
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
# 导入所有动作模块以确保装饰器被执行
|
||||
from . import reply_action # noqa
|
||||
from . import no_reply_action # noqa
|
||||
from . import exit_focus_chat_action # noqa
|
||||
from . import emoji_action # noqa
|
||||
|
||||
# 在此处添加更多动作模块导入
|
||||
@@ -1,124 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Dict, Type
|
||||
from src.common.logger_manager import get_logger
|
||||
|
||||
logger = get_logger("base_action")
|
||||
|
||||
# 全局动作注册表
|
||||
_ACTION_REGISTRY: Dict[str, Type["BaseAction"]] = {}
|
||||
_DEFAULT_ACTIONS: Dict[str, str] = {}
|
||||
|
||||
# 动作激活类型枚举
|
||||
class ActionActivationType:
|
||||
ALWAYS = "always" # 默认参与到planner
|
||||
LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner
|
||||
RANDOM = "random" # 随机启用action到planner
|
||||
KEYWORD = "keyword" # 关键词触发启用action到planner
|
||||
|
||||
# 聊天模式枚举
|
||||
class ChatMode:
|
||||
FOCUS = "focus" # Focus聊天模式
|
||||
NORMAL = "normal" # Normal聊天模式
|
||||
ALL = "all" # 所有聊天模式
|
||||
|
||||
def register_action(cls):
|
||||
"""
|
||||
动作注册装饰器
|
||||
|
||||
用法:
|
||||
@register_action
|
||||
class MyAction(BaseAction):
|
||||
action_name = "my_action"
|
||||
action_description = "我的动作"
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
normal_activation_type = ActionActivationType.ALWAYS
|
||||
mode_enable = ChatMode.ALL
|
||||
parallel_action = False
|
||||
...
|
||||
"""
|
||||
# 检查类是否有必要的属性
|
||||
if not hasattr(cls, "action_name") or not hasattr(cls, "action_description"):
|
||||
logger.error(f"动作类 {cls.__name__} 缺少必要的属性: action_name 或 action_description")
|
||||
return cls
|
||||
|
||||
action_name = cls.action_name
|
||||
action_description = cls.action_description
|
||||
is_enabled = getattr(cls, "enable_plugin", True) # 默认启用插件
|
||||
|
||||
if not action_name or not action_description:
|
||||
logger.error(f"动作类 {cls.__name__} 的 action_name 或 action_description 为空")
|
||||
return cls
|
||||
|
||||
# 将动作类注册到全局注册表
|
||||
_ACTION_REGISTRY[action_name] = cls
|
||||
|
||||
# 如果启用插件,添加到默认动作集
|
||||
if is_enabled:
|
||||
_DEFAULT_ACTIONS[action_name] = action_description
|
||||
|
||||
logger.info(f"已注册动作: {action_name} -> {cls.__name__},插件启用: {is_enabled}")
|
||||
return cls
|
||||
|
||||
|
||||
class BaseAction(ABC):
|
||||
"""动作基类接口
|
||||
|
||||
所有具体的动作类都应该继承这个基类,并实现handle_action方法。
|
||||
"""
|
||||
|
||||
def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str):
|
||||
"""初始化动作
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_data: 动作数据
|
||||
reasoning: 执行该动作的理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
"""
|
||||
# 每个动作必须实现
|
||||
self.action_name: str = "base_action"
|
||||
self.action_description: str = "基础动作"
|
||||
self.action_parameters: dict = {}
|
||||
self.action_require: list[str] = []
|
||||
|
||||
# 动作激活类型设置
|
||||
# Focus模式下的激活类型,默认为always
|
||||
self.focus_activation_type: str = ActionActivationType.ALWAYS
|
||||
# Normal模式下的激活类型,默认为always
|
||||
self.normal_activation_type: str = ActionActivationType.ALWAYS
|
||||
|
||||
# 随机激活的概率(0.0-1.0),用于RANDOM激活类型
|
||||
self.random_activation_probability: float = 0.3
|
||||
# LLM判定的提示词,用于LLM_JUDGE激活类型
|
||||
self.llm_judge_prompt: str = ""
|
||||
# 关键词触发列表,用于KEYWORD激活类型
|
||||
self.activation_keywords: list[str] = []
|
||||
# 关键词匹配是否区分大小写
|
||||
self.keyword_case_sensitive: bool = False
|
||||
|
||||
# 模式启用设置:指定在哪些聊天模式下启用此动作
|
||||
# 可选值: "focus"(仅Focus模式), "normal"(仅Normal模式), "all"(所有模式)
|
||||
self.mode_enable: str = ChatMode.ALL
|
||||
|
||||
# 并行执行设置:仅在Normal模式下生效,设置为True的动作可以与回复动作并行执行
|
||||
# 而不是替代回复动作,适用于图片生成、TTS、禁言等不需要覆盖回复的动作
|
||||
self.parallel_action: bool = False
|
||||
|
||||
self.associated_types: list[str] = []
|
||||
|
||||
self.enable_plugin: bool = True # 是否启用插件,默认启用
|
||||
|
||||
self.action_data = action_data
|
||||
self.reasoning = reasoning
|
||||
self.cycle_timers = cycle_timers
|
||||
self.thinking_id = thinking_id
|
||||
|
||||
@abstractmethod
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""处理动作的抽象方法,需要被子类实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
pass
|
||||
@@ -1,150 +0,0 @@
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode
|
||||
from typing import Tuple, List
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("action_taken")
|
||||
|
||||
|
||||
@register_action
|
||||
class EmojiAction(BaseAction):
|
||||
"""表情动作处理类
|
||||
|
||||
处理构建和发送消息表情的动作。
|
||||
"""
|
||||
|
||||
action_name: str = "emoji"
|
||||
action_description: str = "当你想单独发送一个表情包辅助你的回复表达"
|
||||
action_parameters: dict[str:str] = {
|
||||
"description": "文字描述你想要发送的表情包内容",
|
||||
}
|
||||
action_require: list[str] = [
|
||||
"表达情绪时可以选择使用",
|
||||
"重点:不要连续发,如果你已经发过[表情包],就不要选择此动作"]
|
||||
|
||||
associated_types: list[str] = ["emoji"]
|
||||
|
||||
enable_plugin = True
|
||||
|
||||
focus_activation_type = ActionActivationType.LLM_JUDGE
|
||||
normal_activation_type = ActionActivationType.RANDOM
|
||||
|
||||
random_activation_probability = global_config.normal_chat.emoji_chance
|
||||
|
||||
parallel_action = True
|
||||
|
||||
|
||||
llm_judge_prompt = """
|
||||
判定是否需要使用表情动作的条件:
|
||||
1. 用户明确要求使用表情包
|
||||
2. 这是一个适合表达强烈情绪的场合
|
||||
3. 不要发送太多表情包,如果你已经发送过多个表情包
|
||||
"""
|
||||
|
||||
# 模式启用设置 - 表情动作只在Focus模式下使用
|
||||
mode_enable = ChatMode.ALL
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: List[Observation],
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str,
|
||||
replyer: DefaultReplyer,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化回复动作处理器
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_data: 动作数据,包含 message, emojis, target 等
|
||||
reasoning: 执行该动作的理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
observations: 观察列表
|
||||
replyer: 回复器
|
||||
chat_stream: 聊天流
|
||||
log_prefix: 日志前缀
|
||||
"""
|
||||
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
|
||||
self.observations = observations
|
||||
self.replyer = replyer
|
||||
self.chat_stream = chat_stream
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""
|
||||
处理回复动作
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
# 注意: 此处可能会使用不同的expressor实现根据任务类型切换不同的回复策略
|
||||
return await self._handle_reply(
|
||||
reasoning=self.reasoning,
|
||||
reply_data=self.action_data,
|
||||
cycle_timers=self.cycle_timers,
|
||||
thinking_id=self.thinking_id,
|
||||
)
|
||||
|
||||
async def _handle_reply(
|
||||
self, reasoning: str, reply_data: dict, cycle_timers: dict, thinking_id: str
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
处理统一的回复动作 - 可包含文本和表情,顺序任意
|
||||
|
||||
reply_data格式:
|
||||
{
|
||||
"description": "描述你想要发送的表情"
|
||||
}
|
||||
"""
|
||||
logger.info(f"{self.log_prefix} 决定发送表情")
|
||||
# 从聊天观察获取锚定消息
|
||||
# chatting_observation: ChattingObservation = next(
|
||||
# obs for obs in self.observations if isinstance(obs, ChattingObservation)
|
||||
# )
|
||||
# if reply_data.get("target"):
|
||||
# anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
||||
# else:
|
||||
# anchor_message = None
|
||||
|
||||
# 如果没有找到锚点消息,创建一个占位符
|
||||
# if not anchor_message:
|
||||
# logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
# anchor_message = await create_empty_anchor_message(
|
||||
# self.chat_stream.platform, self.chat_stream.group_info, self.chat_stream
|
||||
# )
|
||||
# else:
|
||||
# anchor_message.update_chat_stream(self.chat_stream)
|
||||
|
||||
logger.info(f"{self.log_prefix} 为了表情包创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
self.chat_stream.platform, self.chat_stream.group_info, self.chat_stream
|
||||
)
|
||||
|
||||
success, reply_set = await self.replyer.deal_emoji(
|
||||
cycle_timers=cycle_timers,
|
||||
action_data=reply_data,
|
||||
anchor_message=anchor_message,
|
||||
# reasoning=reasoning,
|
||||
thinking_id=thinking_id,
|
||||
)
|
||||
|
||||
reply_text = ""
|
||||
if reply_set:
|
||||
for reply in reply_set:
|
||||
type = reply[0]
|
||||
data = reply[1]
|
||||
if type == "text":
|
||||
reply_text += data
|
||||
elif type == "emoji":
|
||||
reply_text += data
|
||||
|
||||
return success, reply_text
|
||||
@@ -1,88 +0,0 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action, ChatMode
|
||||
from typing import Tuple, List
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("action_taken")
|
||||
|
||||
|
||||
@register_action
|
||||
class ExitFocusChatAction(BaseAction):
|
||||
"""退出专注聊天动作处理类
|
||||
|
||||
处理决定退出专注聊天的动作。
|
||||
执行后会将所属的sub heartflow转变为normal_chat状态。
|
||||
"""
|
||||
|
||||
action_name = "exit_focus_chat"
|
||||
action_description = "退出专注聊天,转为普通聊天模式"
|
||||
action_parameters = {}
|
||||
action_require = [
|
||||
"很长时间没有回复,你决定退出专注聊天",
|
||||
"当前内容不需要持续专注关注,你决定退出专注聊天",
|
||||
"聊天内容已经完成,你决定退出专注聊天",
|
||||
]
|
||||
# 退出专注聊天是系统核心功能,不是插件,但默认不启用(需要特定条件触发)
|
||||
enable_plugin = False
|
||||
|
||||
# 模式启用设置 - 退出专注聊天动作只在Focus模式下使用
|
||||
mode_enable = ChatMode.FOCUS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: List[Observation],
|
||||
log_prefix: str,
|
||||
chat_stream: ChatStream,
|
||||
shutting_down: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化退出专注聊天动作处理器
|
||||
|
||||
Args:
|
||||
action_data: 动作数据
|
||||
reasoning: 执行该动作的理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
observations: 观察列表
|
||||
log_prefix: 日志前缀
|
||||
shutting_down: 是否正在关闭
|
||||
"""
|
||||
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
|
||||
self.observations = observations
|
||||
self.log_prefix = log_prefix
|
||||
self._shutting_down = shutting_down
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""
|
||||
处理退出专注聊天的情况
|
||||
|
||||
工作流程:
|
||||
1. 将sub heartflow转换为normal_chat状态
|
||||
2. 等待新消息、超时或关闭信号
|
||||
3. 根据等待结果更新连续不回复计数
|
||||
4. 如果达到阈值,触发回调
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 状态转换消息)
|
||||
"""
|
||||
try:
|
||||
# 转换状态
|
||||
status_message = ""
|
||||
command = "stop_focus_chat"
|
||||
return True, status_message, command
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 处理 'exit_focus_chat' 时等待被中断 (CancelledError)")
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"处理 'exit_focus_chat' 时发生错误: {str(e)}"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False, "", ""
|
||||
@@ -1,139 +0,0 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode
|
||||
from typing import Tuple, List
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
|
||||
|
||||
logger = get_logger("action_taken")
|
||||
|
||||
# 常量定义
|
||||
WAITING_TIME_THRESHOLD = 1200 # 等待新消息时间阈值,单位秒
|
||||
|
||||
|
||||
@register_action
|
||||
class NoReplyAction(BaseAction):
|
||||
"""不回复动作处理类
|
||||
|
||||
处理决定不回复的动作。
|
||||
"""
|
||||
|
||||
action_name = "no_reply"
|
||||
action_description = "暂时不回复消息"
|
||||
action_parameters = {}
|
||||
action_require = [
|
||||
"你连续发送了太多消息,且无人回复",
|
||||
"想要休息一下",
|
||||
]
|
||||
enable_plugin = True
|
||||
|
||||
# 激活类型设置
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
|
||||
# 模式启用设置 - no_reply动作只在Focus模式下使用
|
||||
mode_enable = ChatMode.FOCUS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: List[Observation],
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化不回复动作处理器
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_data: 动作数据
|
||||
reasoning: 执行该动作的理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
observations: 观察列表
|
||||
log_prefix: 日志前缀
|
||||
shutting_down: 是否正在关闭
|
||||
"""
|
||||
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
|
||||
self.observations = observations
|
||||
self.log_prefix = log_prefix
|
||||
self._shutting_down = shutting_down
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""
|
||||
处理不回复的情况
|
||||
|
||||
工作流程:
|
||||
1. 等待新消息、超时或关闭信号
|
||||
2. 根据等待结果更新连续不回复计数
|
||||
3. 如果达到阈值,触发回调
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 空字符串)
|
||||
"""
|
||||
logger.info(f"{self.log_prefix} 决定不回复: {self.reasoning}")
|
||||
|
||||
observation = self.observations[0] if self.observations else None
|
||||
|
||||
try:
|
||||
with Timer("等待新消息", self.cycle_timers):
|
||||
# 等待新消息、超时或关闭信号,并获取结果
|
||||
await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix)
|
||||
|
||||
return True, "" # 不回复动作没有回复文本
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 处理 'no_reply' 时等待被中断 (CancelledError)")
|
||||
raise
|
||||
except Exception as e: # 捕获调用管理器或其他地方可能发生的错误
|
||||
logger.error(f"{self.log_prefix} 处理 'no_reply' 时发生错误: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False, ""
|
||||
|
||||
async def _wait_for_new_message(self, observation: ChattingObservation, thinking_id: str, log_prefix: str) -> bool:
|
||||
"""
|
||||
等待新消息 或 检测到关闭信号
|
||||
|
||||
参数:
|
||||
observation: 观察实例
|
||||
thinking_id: 思考ID
|
||||
log_prefix: 日志前缀
|
||||
|
||||
返回:
|
||||
bool: 是否检测到新消息 (如果因关闭信号退出则返回 False)
|
||||
"""
|
||||
wait_start_time = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
# --- 在每次循环开始时检查关闭标志 ---
|
||||
if self._shutting_down:
|
||||
logger.info(f"{log_prefix} 等待新消息时检测到关闭信号,中断等待。")
|
||||
return False # 表示因为关闭而退出
|
||||
# -----------------------------------
|
||||
|
||||
thinking_id_timestamp = parse_thinking_id_to_timestamp(thinking_id)
|
||||
|
||||
# 检查新消息
|
||||
if await observation.has_new_messages_since(thinking_id_timestamp):
|
||||
logger.info(f"{log_prefix} 检测到新消息")
|
||||
return True
|
||||
|
||||
# 检查超时 (放在检查新消息和关闭之后)
|
||||
if asyncio.get_event_loop().time() - wait_start_time > WAITING_TIME_THRESHOLD:
|
||||
logger.warning(f"{log_prefix} 等待新消息超时({WAITING_TIME_THRESHOLD}秒)")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 短暂休眠,让其他任务有机会运行,并能更快响应取消或关闭
|
||||
await asyncio.sleep(0.5) # 缩短休眠时间
|
||||
except asyncio.CancelledError:
|
||||
# 如果在休眠时被取消,再次检查关闭标志
|
||||
# 如果是正常关闭,则不需要警告
|
||||
if not self._shutting_down:
|
||||
logger.warning(f"{log_prefix} _wait_for_new_message 的休眠被意外取消")
|
||||
# 无论如何,重新抛出异常,让上层处理
|
||||
raise
|
||||
@@ -1,779 +0,0 @@
|
||||
import traceback
|
||||
from typing import Tuple, Dict, List, Any, Optional, Union, Type
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode # noqa F401
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import person_info_manager
|
||||
from abc import abstractmethod
|
||||
from src.config.config import global_config
|
||||
import os
|
||||
import inspect
|
||||
import toml # 导入 toml 库
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database import db
|
||||
from peewee import Model, DoesNotExist
|
||||
import json
|
||||
import time
|
||||
|
||||
# 以下为类型注解需要
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
|
||||
from src.chat.focus_chat.info.obs_info import ObsInfo
|
||||
|
||||
logger = get_logger("plugin_action")
|
||||
|
||||
|
||||
class PluginAction(BaseAction):
|
||||
"""插件动作基类
|
||||
|
||||
封装了主程序内部依赖,提供简化的API接口给插件开发者
|
||||
"""
|
||||
|
||||
action_config_file_name: Optional[str] = None # 插件可以覆盖此属性来指定配置文件名
|
||||
|
||||
# 默认激活类型设置,插件可以覆盖
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
normal_activation_type = ActionActivationType.ALWAYS
|
||||
random_activation_probability: float = 0.3
|
||||
llm_judge_prompt: str = ""
|
||||
activation_keywords: list[str] = []
|
||||
keyword_case_sensitive: bool = False
|
||||
|
||||
# 默认模式启用设置 - 插件动作默认在所有模式下可用,插件可以覆盖
|
||||
mode_enable = ChatMode.ALL
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
global_config: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化插件动作基类"""
|
||||
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
|
||||
|
||||
# 存储内部服务和对象引用
|
||||
self._services = {}
|
||||
self.config: Dict[str, Any] = {} # 用于存储插件自身的配置
|
||||
|
||||
# 从kwargs提取必要的内部服务
|
||||
if "observations" in kwargs:
|
||||
self._services["observations"] = kwargs["observations"]
|
||||
if "expressor" in kwargs:
|
||||
self._services["expressor"] = kwargs["expressor"]
|
||||
if "chat_stream" in kwargs:
|
||||
self._services["chat_stream"] = kwargs["chat_stream"]
|
||||
if "replyer" in kwargs:
|
||||
self._services["replyer"] = kwargs["replyer"]
|
||||
|
||||
self.log_prefix = kwargs.get("log_prefix", "")
|
||||
self._load_plugin_config() # 初始化时加载插件配置
|
||||
|
||||
def _load_plugin_config(self):
|
||||
"""
|
||||
加载插件自身的配置文件。
|
||||
配置文件应与插件模块在同一目录下。
|
||||
插件可以通过覆盖 `action_config_file_name` 类属性来指定文件名。
|
||||
如果 `action_config_file_name` 未指定,则不加载配置。
|
||||
仅支持 TOML (.toml) 格式。
|
||||
"""
|
||||
if not self.action_config_file_name:
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 插件 {self.__class__.__name__} 未指定 action_config_file_name,不加载插件配置。"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
plugin_module_path = inspect.getfile(self.__class__)
|
||||
plugin_dir = os.path.dirname(plugin_module_path)
|
||||
config_file_path = os.path.join(plugin_dir, self.action_config_file_name)
|
||||
|
||||
if not os.path.exists(config_file_path):
|
||||
logger.warning(
|
||||
f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置文件 {config_file_path} 不存在。"
|
||||
)
|
||||
return
|
||||
|
||||
file_ext = os.path.splitext(self.action_config_file_name)[1].lower()
|
||||
|
||||
if file_ext == ".toml":
|
||||
with open(config_file_path, "r", encoding="utf-8") as f:
|
||||
self.config = toml.load(f) or {}
|
||||
logger.info(f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置已从 {config_file_path} 加载。")
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self.log_prefix} 不支持的插件配置文件格式: {file_ext}。仅支持 .toml。插件配置未加载。"
|
||||
)
|
||||
self.config = {} # 确保未加载时为空字典
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self.log_prefix} 加载插件 {self.__class__.__name__} 的配置文件 {self.action_config_file_name} 时出错: {e}"
|
||||
)
|
||||
self.config = {} # 出错时确保 config 是一个空字典
|
||||
|
||||
def get_global_config(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
安全地从全局配置中获取一个值。
|
||||
插件应使用此方法读取全局配置,以保证只读和隔离性。
|
||||
"""
|
||||
|
||||
return global_config.get(key, default)
|
||||
|
||||
async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]:
|
||||
"""根据用户名获取用户ID"""
|
||||
person_id = person_info_manager.get_person_id_by_person_name(person_name)
|
||||
user_id = await person_info_manager.get_value(person_id, "user_id")
|
||||
platform = await person_info_manager.get_value(person_id, "platform")
|
||||
return platform, user_id
|
||||
|
||||
# 提供简化的API方法
|
||||
async def send_message(self, type: str, data: str, target: Optional[str] = "", display_message: str = "") -> bool:
|
||||
"""发送消息的简化方法
|
||||
|
||||
Args:
|
||||
text: 要发送的消息文本
|
||||
target: 目标消息(可选)
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
expressor: DefaultExpressor = self._services.get("expressor")
|
||||
chat_stream: ChatStream = self._services.get("chat_stream")
|
||||
|
||||
if not expressor or not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
|
||||
return False
|
||||
|
||||
# 构造简化的动作数据
|
||||
# reply_data = {"text": text, "target": target or "", "emojis": []}
|
||||
|
||||
# 获取锚定消息(如果有)
|
||||
observations = self._services.get("observations", [])
|
||||
|
||||
if len(observations) > 0:
|
||||
chatting_observation: ChattingObservation = next(
|
||||
obs for obs in observations if isinstance(obs, ChattingObservation)
|
||||
)
|
||||
|
||||
anchor_message = chatting_observation.search_message_by_text(target)
|
||||
else:
|
||||
anchor_message = None
|
||||
|
||||
# 如果没有找到锚点消息,创建一个占位符
|
||||
if not anchor_message:
|
||||
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message.update_chat_stream(chat_stream)
|
||||
|
||||
response_set = [
|
||||
(type, data),
|
||||
]
|
||||
|
||||
# 调用内部方法发送消息
|
||||
success = await expressor.send_response_messages(
|
||||
anchor_message=anchor_message,
|
||||
response_set=response_set,
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool:
|
||||
"""发送消息的简化方法
|
||||
|
||||
Args:
|
||||
text: 要发送的消息文本
|
||||
target: 目标消息(可选)
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
expressor: DefaultExpressor = self._services.get("expressor")
|
||||
chat_stream: ChatStream = self._services.get("chat_stream")
|
||||
|
||||
if not expressor or not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
|
||||
return False
|
||||
|
||||
# 构造简化的动作数据
|
||||
reply_data = {"text": text, "target": target or "", "emojis": []}
|
||||
|
||||
# 获取锚定消息(如果有)
|
||||
observations = self._services.get("observations", [])
|
||||
|
||||
# 查找 ChattingObservation 实例
|
||||
chatting_observation = None
|
||||
for obs in observations:
|
||||
if isinstance(obs, ChattingObservation):
|
||||
chatting_observation = obs
|
||||
break
|
||||
|
||||
if not chatting_observation:
|
||||
logger.warning(f"{self.log_prefix} 未找到 ChattingObservation 实例,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
||||
if not anchor_message:
|
||||
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message.update_chat_stream(chat_stream)
|
||||
|
||||
# 调用内部方法发送消息
|
||||
success, _ = await expressor.deal_reply(
|
||||
cycle_timers=self.cycle_timers,
|
||||
action_data=reply_data,
|
||||
anchor_message=anchor_message,
|
||||
reasoning=self.reasoning,
|
||||
thinking_id=self.thinking_id,
|
||||
)
|
||||
|
||||
return success
|
||||
|
||||
async def send_message_by_replyer(self, target: Optional[str] = None, extra_info_block: Optional[str] = None) -> bool:
|
||||
"""通过 replyer 发送消息的简化方法
|
||||
|
||||
Args:
|
||||
text: 要发送的消息文本
|
||||
target: 目标消息(可选)
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
replyer: DefaultReplyer = self._services.get("replyer")
|
||||
chat_stream: ChatStream = self._services.get("chat_stream")
|
||||
|
||||
if not replyer or not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
|
||||
return False
|
||||
|
||||
# 构造简化的动作数据
|
||||
reply_data = {"target": target or "", "extra_info_block": extra_info_block}
|
||||
|
||||
# 获取锚定消息(如果有)
|
||||
observations = self._services.get("observations", [])
|
||||
|
||||
# 查找 ChattingObservation 实例
|
||||
chatting_observation = None
|
||||
for obs in observations:
|
||||
if isinstance(obs, ChattingObservation):
|
||||
chatting_observation = obs
|
||||
break
|
||||
|
||||
if not chatting_observation:
|
||||
logger.warning(f"{self.log_prefix} 未找到 ChattingObservation 实例,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
||||
if not anchor_message:
|
||||
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message.update_chat_stream(chat_stream)
|
||||
|
||||
# 调用内部方法发送消息
|
||||
success, _ = await replyer.deal_reply(
|
||||
cycle_timers=self.cycle_timers,
|
||||
action_data=reply_data,
|
||||
anchor_message=anchor_message,
|
||||
reasoning=self.reasoning,
|
||||
thinking_id=self.thinking_id,
|
||||
)
|
||||
|
||||
return success
|
||||
|
||||
def get_chat_type(self) -> str:
|
||||
"""获取当前聊天类型
|
||||
|
||||
Returns:
|
||||
str: 聊天类型 ("group" 或 "private")
|
||||
"""
|
||||
chat_stream: ChatStream = self._services.get("chat_stream")
|
||||
if chat_stream and hasattr(chat_stream, "group_info"):
|
||||
return "group" if chat_stream.group_info else "private"
|
||||
return "unknown"
|
||||
|
||||
def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息
|
||||
|
||||
Args:
|
||||
count: 要获取的消息数量
|
||||
|
||||
Returns:
|
||||
List[Dict]: 消息列表,每个消息包含发送者、内容等信息
|
||||
"""
|
||||
messages = []
|
||||
observations = self._services.get("observations", [])
|
||||
|
||||
if observations and len(observations) > 0:
|
||||
obs = observations[0]
|
||||
if hasattr(obs, "get_talking_message"):
|
||||
obs: ObsInfo
|
||||
raw_messages = obs.get_talking_message()
|
||||
# 转换为简化格式
|
||||
for msg in raw_messages[-count:]:
|
||||
simple_msg = {
|
||||
"sender": msg.get("sender", "未知"),
|
||||
"content": msg.get("content", ""),
|
||||
"timestamp": msg.get("timestamp", 0),
|
||||
}
|
||||
messages.append(simple_msg)
|
||||
|
||||
return messages
|
||||
|
||||
def get_available_models(self) -> Dict[str, Any]:
|
||||
"""获取所有可用的模型配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
|
||||
"""
|
||||
if not hasattr(global_config, "model"):
|
||||
logger.error(f"{self.log_prefix} 无法获取模型列表:全局配置中未找到 model 配置")
|
||||
return {}
|
||||
|
||||
models = global_config.model
|
||||
|
||||
return models
|
||||
|
||||
async def generate_with_model(
|
||||
self,
|
||||
prompt: str,
|
||||
model_config: Dict[str, Any],
|
||||
request_type: str = "plugin.generate",
|
||||
**kwargs
|
||||
) -> Tuple[bool, str]:
|
||||
"""使用指定模型生成内容
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
temperature: 温度参数,控制随机性 (0-1)
|
||||
max_tokens: 最大生成token数
|
||||
request_type: 请求类型标识
|
||||
**kwargs: 其他模型特定参数
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否成功, 生成的内容或错误信息)
|
||||
"""
|
||||
try:
|
||||
|
||||
|
||||
logger.info(f"prompt: {prompt}")
|
||||
|
||||
llm_request = LLMRequest(
|
||||
model=model_config,
|
||||
request_type=request_type,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
response,(resoning , model_name) = await llm_request.generate_response_async(prompt)
|
||||
return True, response, resoning, model_name
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
@abstractmethod
|
||||
async def process(self) -> Tuple[bool, str]:
|
||||
"""插件处理逻辑,子类必须实现此方法
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
pass
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""实现BaseAction的抽象方法,调用子类的process方法
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
return await self.process()
|
||||
|
||||
async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None:
|
||||
"""存储action执行信息到数据库
|
||||
|
||||
Args:
|
||||
action_build_into_prompt: 是否构建到提示中
|
||||
action_prompt_display: 动作显示内容
|
||||
"""
|
||||
try:
|
||||
chat_stream: ChatStream = self._services.get("chat_stream")
|
||||
if not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法存储action信息:缺少chat_stream服务")
|
||||
return
|
||||
|
||||
action_time = time.time()
|
||||
action_id = f"{action_time}_{self.thinking_id}"
|
||||
|
||||
ActionRecords.create(
|
||||
action_id=action_id,
|
||||
time=action_time,
|
||||
action_name=self.__class__.__name__,
|
||||
action_data=str(self.action_data),
|
||||
action_done=action_done,
|
||||
action_build_into_prompt=action_build_into_prompt,
|
||||
action_prompt_display=action_prompt_display,
|
||||
chat_id=chat_stream.stream_id,
|
||||
chat_info_stream_id=chat_stream.stream_id,
|
||||
chat_info_platform=chat_stream.platform,
|
||||
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
|
||||
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
|
||||
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else ""
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 存储action信息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def db_query(
|
||||
self,
|
||||
model_class: Type[Model],
|
||||
query_type: str = "get",
|
||||
filters: Dict[str, Any] = None,
|
||||
data: Dict[str, Any] = None,
|
||||
limit: int = None,
|
||||
order_by: List[str] = None,
|
||||
single_result: bool = False
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""执行数据库查询操作
|
||||
|
||||
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
|
||||
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
|
||||
filters: 过滤条件字典,键为字段名,值为要匹配的值
|
||||
data: 用于创建或更新的数据字典
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
根据查询类型返回不同的结果:
|
||||
- "get": 返回查询结果列表或单个结果(如果 single_result=True)
|
||||
- "create": 返回创建的记录
|
||||
- "update": 返回受影响的行数
|
||||
- "delete": 返回受影响的行数
|
||||
- "count": 返回记录数量
|
||||
|
||||
示例:
|
||||
# 查询最近10条消息
|
||||
messages = await self.db_query(
|
||||
Messages,
|
||||
query_type="get",
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by=["-time"]
|
||||
)
|
||||
|
||||
# 创建一条记录
|
||||
new_record = await self.db_query(
|
||||
ActionRecords,
|
||||
query_type="create",
|
||||
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
|
||||
)
|
||||
|
||||
# 更新记录
|
||||
updated_count = await self.db_query(
|
||||
ActionRecords,
|
||||
query_type="update",
|
||||
filters={"action_id": "123"},
|
||||
data={"action_done": True}
|
||||
)
|
||||
|
||||
# 删除记录
|
||||
deleted_count = await self.db_query(
|
||||
ActionRecords,
|
||||
query_type="delete",
|
||||
filters={"action_id": "123"}
|
||||
)
|
||||
|
||||
# 计数
|
||||
count = await self.db_query(
|
||||
Messages,
|
||||
query_type="count",
|
||||
filters={"chat_id": chat_stream.stream_id}
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 构建基本查询
|
||||
if query_type in ["get", "update", "delete", "count"]:
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 执行查询
|
||||
if query_type == "get":
|
||||
# 应用排序
|
||||
if order_by:
|
||||
for field in order_by:
|
||||
if field.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, field[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, field))
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
elif query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("创建记录需要提供data参数")
|
||||
|
||||
# 创建记录
|
||||
record = model_class.create(**data)
|
||||
# 返回创建的记录
|
||||
return model_class.select().where(model_class.id == record.id).dicts().get()
|
||||
|
||||
elif query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("更新记录需要提供data参数")
|
||||
|
||||
# 更新记录
|
||||
return query.update(**data).execute()
|
||||
|
||||
elif query_type == "delete":
|
||||
# 删除记录
|
||||
return query.delete().execute()
|
||||
|
||||
elif query_type == "count":
|
||||
# 计数
|
||||
return query.count()
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的查询类型: {query_type}")
|
||||
|
||||
except DoesNotExist:
|
||||
# 记录不存在
|
||||
if query_type == "get" and single_result:
|
||||
return None
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 根据查询类型返回合适的默认值
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
elif query_type in ["create", "update", "delete", "count"]:
|
||||
return None
|
||||
|
||||
async def db_raw_query(
|
||||
self,
|
||||
sql: str,
|
||||
params: List[Any] = None,
|
||||
fetch_results: bool = True
|
||||
) -> Union[List[Dict[str, Any]], int, None]:
|
||||
"""执行原始SQL查询
|
||||
|
||||
警告: 使用此方法需要小心,确保SQL语句已正确构造以避免SQL注入风险。
|
||||
|
||||
Args:
|
||||
sql: 原始SQL查询字符串
|
||||
params: 查询参数列表,用于替换SQL中的占位符
|
||||
fetch_results: 是否获取查询结果,对于SELECT查询设为True,对于
|
||||
UPDATE/INSERT/DELETE等操作设为False
|
||||
|
||||
Returns:
|
||||
如果fetch_results为True,返回查询结果列表;
|
||||
如果fetch_results为False,返回受影响的行数;
|
||||
如果出错,返回None
|
||||
"""
|
||||
try:
|
||||
cursor = db.execute_sql(sql, params or [])
|
||||
|
||||
if fetch_results:
|
||||
# 获取列名
|
||||
columns = [col[0] for col in cursor.description]
|
||||
|
||||
# 构建结果字典列表
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
results.append(dict(zip(columns, row)))
|
||||
|
||||
return results
|
||||
else:
|
||||
# 返回受影响的行数
|
||||
return cursor.rowcount
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def db_save(
|
||||
self,
|
||||
model_class: Type[Model],
|
||||
data: Dict[str, Any],
|
||||
key_field: str = None,
|
||||
key_value: Any = None
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
|
||||
如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类,如ActionRecords, Messages等
|
||||
data: 要保存的数据字典
|
||||
key_field: 用于查找现有记录的字段名,例如"action_id"
|
||||
key_value: 用于查找现有记录的字段值
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 保存后的记录数据
|
||||
None: 如果操作失败
|
||||
|
||||
示例:
|
||||
# 创建或更新一条记录
|
||||
record = await self.db_save(
|
||||
ActionRecords,
|
||||
{
|
||||
"action_id": "123",
|
||||
"time": time.time(),
|
||||
"action_name": "TestAction",
|
||||
"action_done": True
|
||||
},
|
||||
key_field="action_id",
|
||||
key_value="123"
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
# 查找现有记录
|
||||
existing_records = list(model_class.select().where(
|
||||
getattr(model_class, key_field) == key_value
|
||||
).limit(1))
|
||||
|
||||
if existing_records:
|
||||
# 更新现有记录
|
||||
existing_record = existing_records[0]
|
||||
for field, value in data.items():
|
||||
setattr(existing_record, field, value)
|
||||
existing_record.save()
|
||||
|
||||
# 返回更新后的记录
|
||||
updated_record = model_class.select().where(
|
||||
model_class.id == existing_record.id
|
||||
).dicts().get()
|
||||
return updated_record
|
||||
|
||||
# 如果没有找到现有记录或未提供key_field和key_value,创建新记录
|
||||
new_record = model_class.create(**data)
|
||||
|
||||
# 返回创建的记录
|
||||
created_record = model_class.select().where(
|
||||
model_class.id == new_record.id
|
||||
).dicts().get()
|
||||
return created_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def db_get(
|
||||
self,
|
||||
model_class: Type[Model],
|
||||
filters: Dict[str, Any] = None,
|
||||
order_by: str = None,
|
||||
limit: int = None
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""从数据库获取记录
|
||||
|
||||
这是db_query方法的简化版本,专注于数据检索操作。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类
|
||||
filters: 过滤条件,字段名和值的字典
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
|
||||
limit: 结果数量限制,如果为1则返回单个记录而不是列表
|
||||
|
||||
Returns:
|
||||
如果limit=1,返回单个记录字典或None;
|
||||
否则返回记录字典列表或空列表。
|
||||
|
||||
示例:
|
||||
# 获取单个记录
|
||||
record = await self.db_get(
|
||||
ActionRecords,
|
||||
filters={"action_id": "123"},
|
||||
limit=1
|
||||
)
|
||||
|
||||
# 获取最近10条记录
|
||||
records = await self.db_get(
|
||||
Messages,
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
order_by="-time",
|
||||
limit=10
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 应用排序
|
||||
if order_by:
|
||||
if order_by.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, order_by[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, order_by))
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if limit == 1:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None if limit == 1 else []
|
||||
@@ -1,196 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode
|
||||
from typing import Tuple, List
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
import time
|
||||
import traceback
|
||||
from src.common.database.database_model import ActionRecords
|
||||
import re
|
||||
|
||||
logger = get_logger("action_taken")
|
||||
|
||||
|
||||
@register_action
|
||||
class ReplyAction(BaseAction):
|
||||
"""回复动作处理类
|
||||
|
||||
处理构建和发送消息回复的动作。
|
||||
"""
|
||||
|
||||
action_name: str = "reply"
|
||||
action_description: str = "当你想要参与回复或者聊天"
|
||||
action_parameters: dict[str:str] = {
|
||||
"reply_to": "如果是明确回复某个人的发言,请在reply_to参数中指定,格式:(用户名:发言内容),如果不是,reply_to的值设为none"
|
||||
}
|
||||
action_require: list[str] = [
|
||||
"你想要闲聊或者随便附和",
|
||||
"有人提到你",
|
||||
"如果你刚刚进行了回复,不要对同一个话题重复回应"
|
||||
]
|
||||
|
||||
associated_types: list[str] = ["text"]
|
||||
|
||||
enable_plugin = True
|
||||
|
||||
# 激活类型设置
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
|
||||
# 模式启用设置 - 回复动作只在Focus模式下使用
|
||||
mode_enable = ChatMode.FOCUS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: List[Observation],
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str,
|
||||
replyer: DefaultReplyer,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化回复动作处理器
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_data: 动作数据,包含 message, emojis, target 等
|
||||
reasoning: 执行该动作的理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
observations: 观察列表
|
||||
replyer: 回复器
|
||||
chat_stream: 聊天流
|
||||
log_prefix: 日志前缀
|
||||
"""
|
||||
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
|
||||
self.observations = observations
|
||||
self.replyer = replyer
|
||||
self.chat_stream = chat_stream
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""
|
||||
处理回复动作
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
# 注意: 此处可能会使用不同的expressor实现根据任务类型切换不同的回复策略
|
||||
success, reply_text = await self._handle_reply(
|
||||
reasoning=self.reasoning,
|
||||
reply_data=self.action_data,
|
||||
cycle_timers=self.cycle_timers,
|
||||
thinking_id=self.thinking_id,
|
||||
)
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=f"{reply_text}",
|
||||
)
|
||||
|
||||
return success, reply_text
|
||||
|
||||
async def _handle_reply(
|
||||
self, reasoning: str, reply_data: dict, cycle_timers: dict, thinking_id: str
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
处理统一的回复动作 - 可包含文本和表情,顺序任意
|
||||
|
||||
reply_data格式:
|
||||
{
|
||||
"text": "你好啊" # 文本内容列表(可选)
|
||||
"target": "锚定消息", # 锚定消息的文本内容
|
||||
}
|
||||
"""
|
||||
logger.info(f"{self.log_prefix} 决定回复: {self.reasoning}")
|
||||
|
||||
# 从聊天观察获取锚定消息
|
||||
chatting_observation: ChattingObservation = next(
|
||||
obs for obs in self.observations if isinstance(obs, ChattingObservation)
|
||||
)
|
||||
|
||||
reply_to = reply_data.get("reply_to", "none")
|
||||
|
||||
# sender = ""
|
||||
target = ""
|
||||
if ":" in reply_to or ":" in reply_to:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r'[::]', string=reply_to, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
# sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
anchor_message = chatting_observation.search_message_by_text(target)
|
||||
else:
|
||||
anchor_message = None
|
||||
|
||||
if anchor_message:
|
||||
anchor_message.update_chat_stream(self.chat_stream)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
self.chat_stream.platform, self.chat_stream.group_info, self.chat_stream
|
||||
)
|
||||
|
||||
|
||||
success, reply_set = await self.replyer.deal_reply(
|
||||
cycle_timers=cycle_timers,
|
||||
action_data=reply_data,
|
||||
anchor_message=anchor_message,
|
||||
reasoning=reasoning,
|
||||
thinking_id=thinking_id,
|
||||
)
|
||||
|
||||
reply_text = ""
|
||||
for reply in reply_set:
|
||||
type = reply[0]
|
||||
data = reply[1]
|
||||
if type == "text":
|
||||
reply_text += data
|
||||
elif type == "emoji":
|
||||
reply_text += data
|
||||
|
||||
return success, reply_text
|
||||
|
||||
|
||||
async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None:
|
||||
"""存储action执行信息到数据库
|
||||
|
||||
Args:
|
||||
action_build_into_prompt: 是否构建到提示中
|
||||
action_prompt_display: 动作显示内容
|
||||
"""
|
||||
try:
|
||||
chat_stream = self.chat_stream
|
||||
if not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法存储action信息:缺少chat_stream服务")
|
||||
return
|
||||
|
||||
action_time = time.time()
|
||||
action_id = f"{action_time}_{self.thinking_id}"
|
||||
|
||||
ActionRecords.create(
|
||||
action_id=action_id,
|
||||
time=action_time,
|
||||
action_name=self.__class__.__name__,
|
||||
action_data=str(self.action_data),
|
||||
action_done=action_done,
|
||||
action_build_into_prompt=action_build_into_prompt,
|
||||
action_prompt_display=action_prompt_display,
|
||||
chat_id=chat_stream.stream_id,
|
||||
chat_info_stream_id=chat_stream.stream_id,
|
||||
chat_info_platform=chat_stream.platform,
|
||||
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
|
||||
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
|
||||
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else ""
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 存储action信息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
@@ -1,12 +1,11 @@
|
||||
from typing import List, Optional, Any, Dict
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.focus_chat.planners.actions.base_action import ActionActivationType, ChatMode
|
||||
import random
|
||||
import asyncio
|
||||
import hashlib
|
||||
@@ -29,14 +28,14 @@ class ActionModifier:
|
||||
def __init__(self, action_manager: ActionManager):
|
||||
"""初始化动作处理器"""
|
||||
self.action_manager = action_manager
|
||||
self.all_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
|
||||
|
||||
self.all_actions = self.action_manager.get_using_actions_for_mode("focus")
|
||||
|
||||
# 用于LLM判定的小模型
|
||||
self.llm_judge = LLMRequest(
|
||||
model=global_config.model.utils_small,
|
||||
request_type="action.judge",
|
||||
)
|
||||
|
||||
|
||||
# 缓存相关属性
|
||||
self._llm_judge_cache = {} # 缓存LLM判定结果
|
||||
self._cache_expiry_time = 30 # 缓存过期时间(秒)
|
||||
@@ -49,15 +48,15 @@ class ActionModifier:
|
||||
):
|
||||
"""
|
||||
完整的动作修改流程,整合传统观察处理和新的激活类型判定
|
||||
|
||||
|
||||
这个方法处理完整的动作管理流程:
|
||||
1. 基于观察的传统动作修改(循环历史分析、类型匹配等)
|
||||
2. 基于激活类型的智能动作判定,最终确定可用动作集
|
||||
|
||||
|
||||
处理后,ActionManager 将包含最终的可用动作集,供规划器直接使用
|
||||
"""
|
||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||
|
||||
|
||||
# === 第一阶段:传统观察处理 ===
|
||||
if observations:
|
||||
hfc_obs = None
|
||||
@@ -79,14 +78,17 @@ class ActionModifier:
|
||||
if hfc_obs:
|
||||
obs = hfc_obs
|
||||
# 获取适用于FOCUS模式的动作
|
||||
all_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
|
||||
all_actions = self.action_manager.get_using_actions_for_mode("focus")
|
||||
# print("=======================")
|
||||
# print(all_actions)
|
||||
# print("=======================")
|
||||
action_changes = await self.analyze_loop_actions(obs)
|
||||
if action_changes["add"] or action_changes["remove"]:
|
||||
# 合并动作变更
|
||||
merged_action_changes["add"].extend(action_changes["add"])
|
||||
merged_action_changes["remove"].extend(action_changes["remove"])
|
||||
reasons.append("基于循环历史分析")
|
||||
|
||||
|
||||
# 详细记录循环历史分析的变更原因
|
||||
for action_name in action_changes["add"]:
|
||||
logger.info(f"{self.log_prefix}添加动作: {action_name},原因: 循环历史分析建议添加")
|
||||
@@ -97,7 +99,7 @@ class ActionModifier:
|
||||
if chat_obs:
|
||||
obs = chat_obs
|
||||
# 检查动作的关联类型
|
||||
chat_context = chat_manager.get_stream(obs.chat_id).context
|
||||
chat_context = get_chat_manager().get_stream(obs.chat_id).context
|
||||
type_mismatched_actions = []
|
||||
|
||||
for action_name in all_actions.keys():
|
||||
@@ -106,7 +108,9 @@ class ActionModifier:
|
||||
if not chat_context.check_types(data["associated_types"]):
|
||||
type_mismatched_actions.append(action_name)
|
||||
associated_types_str = ", ".join(data["associated_types"])
|
||||
logger.info(f"{self.log_prefix}移除动作: {action_name},原因: 关联类型不匹配(需要: {associated_types_str})")
|
||||
logger.info(
|
||||
f"{self.log_prefix}移除动作: {action_name},原因: 关联类型不匹配(需要: {associated_types_str})"
|
||||
)
|
||||
|
||||
if type_mismatched_actions:
|
||||
# 合并到移除列表中
|
||||
@@ -123,17 +127,19 @@ class ActionModifier:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix}应用移除动作: {action_name},原因集合: {reasons}")
|
||||
|
||||
logger.info(f"{self.log_prefix}传统动作修改完成,当前使用动作: {list(self.action_manager.get_using_actions().keys())}")
|
||||
logger.info(
|
||||
f"{self.log_prefix}传统动作修改完成,当前使用动作: {list(self.action_manager.get_using_actions().keys())}"
|
||||
)
|
||||
|
||||
# === 第二阶段:激活类型判定 ===
|
||||
# 如果提供了聊天上下文,则进行激活类型判定
|
||||
if chat_content is not None:
|
||||
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
|
||||
# 获取当前使用的动作集(经过第一阶段处理,且适用于FOCUS模式)
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
all_registered_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
|
||||
|
||||
all_registered_actions = self.action_manager.get_registered_actions()
|
||||
|
||||
# 构建完整的动作信息
|
||||
current_actions_with_info = {}
|
||||
for action_name in current_using_actions.keys():
|
||||
@@ -141,46 +147,49 @@ class ActionModifier:
|
||||
current_actions_with_info[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
|
||||
# 应用激活类型判定
|
||||
final_activated_actions = await self._apply_activation_type_filtering(
|
||||
current_actions_with_info,
|
||||
chat_content,
|
||||
)
|
||||
|
||||
|
||||
# 更新ActionManager,移除未激活的动作
|
||||
actions_to_remove = []
|
||||
removal_reasons = {}
|
||||
|
||||
|
||||
for action_name in current_using_actions.keys():
|
||||
if action_name not in final_activated_actions:
|
||||
actions_to_remove.append(action_name)
|
||||
# 确定移除原因
|
||||
if action_name in all_registered_actions:
|
||||
action_info = all_registered_actions[action_name]
|
||||
activation_type = action_info.get("focus_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
if activation_type == ActionActivationType.RANDOM:
|
||||
activation_type = action_info.get("focus_activation_type", "always")
|
||||
|
||||
# 处理字符串格式的激活类型值
|
||||
if activation_type == "random":
|
||||
probability = action_info.get("random_probability", 0.3)
|
||||
removal_reasons[action_name] = f"RANDOM类型未触发(概率{probability})"
|
||||
elif activation_type == ActionActivationType.LLM_JUDGE:
|
||||
elif activation_type == "llm_judge":
|
||||
removal_reasons[action_name] = "LLM判定未激活"
|
||||
elif activation_type == ActionActivationType.KEYWORD:
|
||||
elif activation_type == "keyword":
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
removal_reasons[action_name] = f"关键词未匹配(关键词: {keywords})"
|
||||
else:
|
||||
removal_reasons[action_name] = "激活判定未通过"
|
||||
else:
|
||||
removal_reasons[action_name] = "动作信息不完整"
|
||||
|
||||
|
||||
for action_name in actions_to_remove:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
reason = removal_reasons.get(action_name, "未知原因")
|
||||
logger.info(f"{self.log_prefix}移除动作: {action_name},原因: {reason}")
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix}激活类型判定完成,最终可用动作: {list(final_activated_actions.keys())}")
|
||||
|
||||
logger.info(f"{self.log_prefix}完整动作修改流程结束,最终动作集: {list(self.action_manager.get_using_actions().keys())}")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix}完整动作修改流程结束,最终动作集: {list(self.action_manager.get_using_actions().keys())}"
|
||||
)
|
||||
|
||||
async def _apply_activation_type_filtering(
|
||||
self,
|
||||
@@ -189,43 +198,42 @@ class ActionModifier:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
应用激活类型过滤逻辑,支持四种激活类型的并行处理
|
||||
|
||||
|
||||
Args:
|
||||
actions_with_info: 带完整信息的动作字典
|
||||
observed_messages_str: 观察到的聊天消息
|
||||
chat_context: 聊天上下文信息
|
||||
extra_context: 额外的上下文信息
|
||||
|
||||
chat_content: 聊天内容
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 过滤后激活的actions字典
|
||||
"""
|
||||
activated_actions = {}
|
||||
|
||||
|
||||
# 分类处理不同激活类型的actions
|
||||
always_actions = {}
|
||||
random_actions = {}
|
||||
llm_judge_actions = {}
|
||||
keyword_actions = {}
|
||||
|
||||
|
||||
for action_name, action_info in actions_with_info.items():
|
||||
activation_type = action_info.get("focus_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
if activation_type == ActionActivationType.ALWAYS:
|
||||
activation_type = action_info.get("focus_activation_type", "always")
|
||||
|
||||
# 现在统一是字符串格式的激活类型值
|
||||
if activation_type == "always":
|
||||
always_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.RANDOM:
|
||||
elif activation_type == "random":
|
||||
random_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.LLM_JUDGE:
|
||||
elif activation_type == "llm_judge":
|
||||
llm_judge_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.KEYWORD:
|
||||
elif activation_type == "keyword":
|
||||
keyword_actions[action_name] = action_info
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
|
||||
|
||||
|
||||
# 1. 处理ALWAYS类型(直接激活)
|
||||
for action_name, action_info in always_actions.items():
|
||||
activated_actions[action_name] = action_info
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: ALWAYS类型直接激活")
|
||||
|
||||
|
||||
# 2. 处理RANDOM类型
|
||||
for action_name, action_info in random_actions.items():
|
||||
probability = action_info.get("random_probability", 0.3)
|
||||
@@ -235,7 +243,7 @@ class ActionModifier:
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: RANDOM类型触发(概率{probability})")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: RANDOM类型未触发(概率{probability})")
|
||||
|
||||
|
||||
# 3. 处理KEYWORD类型(快速判定)
|
||||
for action_name, action_info in keyword_actions.items():
|
||||
should_activate = self._check_keyword_activation(
|
||||
@@ -250,7 +258,7 @@ class ActionModifier:
|
||||
else:
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: KEYWORD类型未匹配关键词({keywords})")
|
||||
|
||||
|
||||
# 4. 处理LLM_JUDGE类型(并行判定)
|
||||
if llm_judge_actions:
|
||||
# 直接并行处理所有LLM判定actions
|
||||
@@ -258,7 +266,7 @@ class ActionModifier:
|
||||
llm_judge_actions,
|
||||
chat_content,
|
||||
)
|
||||
|
||||
|
||||
# 添加激活的LLM判定actions
|
||||
for action_name, should_activate in llm_results.items():
|
||||
if should_activate:
|
||||
@@ -266,46 +274,43 @@ class ActionModifier:
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: LLM_JUDGE类型判定通过")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: LLM_JUDGE类型判定未通过")
|
||||
|
||||
|
||||
logger.debug(f"{self.log_prefix}激活类型过滤完成: {list(activated_actions.keys())}")
|
||||
return activated_actions
|
||||
|
||||
async def process_actions_for_planner(
|
||||
self,
|
||||
observed_messages_str: str = "",
|
||||
chat_context: Optional[str] = None,
|
||||
extra_context: Optional[str] = None
|
||||
self, observed_messages_str: str = "", chat_context: Optional[str] = None, extra_context: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
[已废弃] 此方法现在已被整合到 modify_actions() 中
|
||||
|
||||
|
||||
为了保持向后兼容性而保留,但建议直接使用 ActionManager.get_using_actions()
|
||||
规划器应该直接从 ActionManager 获取最终的可用动作集,而不是调用此方法
|
||||
|
||||
|
||||
新的架构:
|
||||
1. 主循环调用 modify_actions() 处理完整的动作管理流程
|
||||
2. 规划器直接使用 ActionManager.get_using_actions() 获取最终动作集
|
||||
"""
|
||||
logger.warning(f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()")
|
||||
|
||||
logger.warning(
|
||||
f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()"
|
||||
)
|
||||
|
||||
# 为了向后兼容,仍然返回当前使用的动作集
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
all_registered_actions = self.action_manager.get_registered_actions()
|
||||
|
||||
|
||||
# 构建完整的动作信息
|
||||
result = {}
|
||||
for action_name in current_using_actions.keys():
|
||||
if action_name in all_registered_actions:
|
||||
result[action_name] = all_registered_actions[action_name]
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def _generate_context_hash(self, chat_content: str) -> str:
|
||||
"""生成上下文的哈希值用于缓存"""
|
||||
context_content = f"{chat_content}"
|
||||
return hashlib.md5(context_content.encode('utf-8')).hexdigest()
|
||||
|
||||
|
||||
return hashlib.md5(context_content.encode("utf-8")).hexdigest()
|
||||
|
||||
async def _process_llm_judge_actions_parallel(
|
||||
self,
|
||||
@@ -314,85 +319,83 @@ class ActionModifier:
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
并行处理LLM判定actions,支持智能缓存
|
||||
|
||||
|
||||
Args:
|
||||
llm_judge_actions: 需要LLM判定的actions
|
||||
observed_messages_str: 观察到的聊天消息
|
||||
chat_context: 聊天上下文
|
||||
extra_context: 额外上下文
|
||||
|
||||
chat_content: 聊天内容
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: action名称到激活结果的映射
|
||||
"""
|
||||
|
||||
|
||||
# 生成当前上下文的哈希值
|
||||
current_context_hash = self._generate_context_hash(chat_content)
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
results = {}
|
||||
tasks_to_run = {}
|
||||
|
||||
|
||||
# 检查缓存
|
||||
for action_name, action_info in llm_judge_actions.items():
|
||||
cache_key = f"{action_name}_{current_context_hash}"
|
||||
|
||||
|
||||
# 检查是否有有效的缓存
|
||||
if (cache_key in self._llm_judge_cache and
|
||||
current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time):
|
||||
|
||||
if (
|
||||
cache_key in self._llm_judge_cache
|
||||
and current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time
|
||||
):
|
||||
results[action_name] = self._llm_judge_cache[cache_key]["result"]
|
||||
logger.debug(f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}")
|
||||
logger.debug(
|
||||
f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}"
|
||||
)
|
||||
else:
|
||||
# 需要进行LLM判定
|
||||
tasks_to_run[action_name] = action_info
|
||||
|
||||
|
||||
# 如果有需要运行的任务,并行执行
|
||||
if tasks_to_run:
|
||||
logger.debug(f"{self.log_prefix}并行执行LLM判定,任务数: {len(tasks_to_run)}")
|
||||
|
||||
|
||||
# 创建并行任务
|
||||
tasks = []
|
||||
task_names = []
|
||||
|
||||
|
||||
for action_name, action_info in tasks_to_run.items():
|
||||
task = self._llm_judge_action(
|
||||
action_name,
|
||||
action_info,
|
||||
chat_content,
|
||||
action_name,
|
||||
action_info,
|
||||
chat_content,
|
||||
)
|
||||
tasks.append(task)
|
||||
task_names.append(action_name)
|
||||
|
||||
|
||||
# 并行执行所有任务
|
||||
try:
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 处理结果并更新缓存
|
||||
for i, (action_name, result) in enumerate(zip(task_names, task_results)):
|
||||
for _, (action_name, result) in enumerate(zip(task_names, task_results)):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
|
||||
results[action_name] = False
|
||||
else:
|
||||
results[action_name] = result
|
||||
|
||||
|
||||
# 更新缓存
|
||||
cache_key = f"{action_name}_{current_context_hash}"
|
||||
self._llm_judge_cache[cache_key] = {
|
||||
"result": result,
|
||||
"timestamp": current_time
|
||||
}
|
||||
|
||||
self._llm_judge_cache[cache_key] = {"result": result, "timestamp": current_time}
|
||||
|
||||
logger.debug(f"{self.log_prefix}并行LLM判定完成,耗时: {time.time() - current_time:.2f}s")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
|
||||
# 如果并行执行失败,为所有任务返回False
|
||||
for action_name in tasks_to_run.keys():
|
||||
results[action_name] = False
|
||||
|
||||
|
||||
# 清理过期缓存
|
||||
self._cleanup_expired_cache(current_time)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def _cleanup_expired_cache(self, current_time: float):
|
||||
@@ -401,40 +404,39 @@ class ActionModifier:
|
||||
for cache_key, cache_data in self._llm_judge_cache.items():
|
||||
if current_time - cache_data["timestamp"] > self._cache_expiry_time:
|
||||
expired_keys.append(cache_key)
|
||||
|
||||
|
||||
for key in expired_keys:
|
||||
del self._llm_judge_cache[key]
|
||||
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"{self.log_prefix}清理了 {len(expired_keys)} 个过期缓存条目")
|
||||
|
||||
async def _llm_judge_action(
|
||||
self,
|
||||
action_name: str,
|
||||
self,
|
||||
action_name: str,
|
||||
action_info: Dict[str, Any],
|
||||
chat_content: str = "",
|
||||
) -> bool:
|
||||
"""
|
||||
使用LLM判定是否应该激活某个action
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_info: 动作信息
|
||||
observed_messages_str: 观察到的聊天消息
|
||||
chat_context: 聊天上下文
|
||||
extra_context: 额外上下文
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
# 构建判定提示词
|
||||
action_description = action_info.get("description", "")
|
||||
action_require = action_info.get("require", [])
|
||||
custom_prompt = action_info.get("llm_judge_prompt", "")
|
||||
|
||||
|
||||
|
||||
# 构建基础判定提示词
|
||||
base_prompt = f"""
|
||||
你需要判断在当前聊天情况下,是否应该激活名为"{action_name}"的动作。
|
||||
@@ -445,34 +447,34 @@ class ActionModifier:
|
||||
"""
|
||||
for req in action_require:
|
||||
base_prompt += f"- {req}\n"
|
||||
|
||||
|
||||
if custom_prompt:
|
||||
base_prompt += f"\n额外判定条件:\n{custom_prompt}\n"
|
||||
|
||||
|
||||
if chat_content:
|
||||
base_prompt += f"\n当前聊天记录:\n{chat_content}\n"
|
||||
|
||||
|
||||
|
||||
base_prompt += """
|
||||
请根据以上信息判断是否应该激活这个动作。
|
||||
只需要回答"是"或"否",不要有其他内容。
|
||||
"""
|
||||
|
||||
|
||||
# 调用LLM进行判定
|
||||
response, _ = await self.llm_judge.generate_response_async(prompt=base_prompt)
|
||||
|
||||
|
||||
# 解析响应
|
||||
response = response.strip().lower()
|
||||
|
||||
|
||||
# print(base_prompt)
|
||||
print(f"LLM判定动作 {action_name}:响应='{response}'")
|
||||
|
||||
|
||||
|
||||
should_activate = "是" in response or "yes" in response or "true" in response
|
||||
|
||||
logger.debug(f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}")
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}"
|
||||
)
|
||||
return should_activate
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}LLM判定动作 {action_name} 时出错: {e}")
|
||||
# 出错时默认不激活
|
||||
@@ -486,45 +488,45 @@ class ActionModifier:
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否匹配关键词触发条件
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_info: 动作信息
|
||||
observed_messages_str: 观察到的聊天消息
|
||||
chat_context: 聊天上下文
|
||||
extra_context: 额外上下文
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
|
||||
activation_keywords = action_info.get("activation_keywords", [])
|
||||
case_sensitive = action_info.get("keyword_case_sensitive", False)
|
||||
|
||||
|
||||
if not activation_keywords:
|
||||
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
||||
return False
|
||||
|
||||
|
||||
# 构建检索文本
|
||||
search_text = ""
|
||||
if chat_content:
|
||||
search_text += chat_content
|
||||
# if chat_context:
|
||||
# search_text += f" {chat_context}"
|
||||
# search_text += f" {chat_context}"
|
||||
# if extra_context:
|
||||
# search_text += f" {extra_context}"
|
||||
|
||||
# search_text += f" {extra_context}"
|
||||
|
||||
# 如果不区分大小写,转换为小写
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
|
||||
|
||||
# 检查每个关键词
|
||||
matched_keywords = []
|
||||
for keyword in activation_keywords:
|
||||
check_keyword = keyword if case_sensitive else keyword.lower()
|
||||
if check_keyword in search_text:
|
||||
matched_keywords.append(keyword)
|
||||
|
||||
|
||||
if matched_keywords:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}")
|
||||
return True
|
||||
@@ -560,15 +562,17 @@ class ActionModifier:
|
||||
reply_sequence.append(action_type == "reply")
|
||||
|
||||
# 检查no_reply比例
|
||||
if len(recent_cycles) >= (5 * global_config.chat.exit_focus_threshold) and (
|
||||
if len(recent_cycles) >= (4 * global_config.chat.exit_focus_threshold) and (
|
||||
no_reply_count / len(recent_cycles)
|
||||
) >= (0.8 * global_config.chat.exit_focus_threshold):
|
||||
) >= (0.7 * global_config.chat.exit_focus_threshold):
|
||||
if global_config.chat.chat_mode == "auto":
|
||||
result["add"].append("exit_focus_chat")
|
||||
result["remove"].append("no_reply")
|
||||
result["remove"].append("reply")
|
||||
no_reply_ratio = no_reply_count / len(recent_cycles)
|
||||
logger.info(f"{self.log_prefix}检测到高no_reply比例: {no_reply_ratio:.2f},达到退出聊天阈值,将添加exit_focus_chat并移除no_reply/reply动作")
|
||||
logger.info(
|
||||
f"{self.log_prefix}检测到高no_reply比例: {no_reply_ratio:.2f},达到退出聊天阈值,将添加exit_focus_chat并移除no_reply/reply动作"
|
||||
)
|
||||
|
||||
# 计算连续回复的相关阈值
|
||||
|
||||
@@ -593,7 +597,7 @@ class ActionModifier:
|
||||
if len(last_max_reply_num) >= max_reply_num and all(last_max_reply_num):
|
||||
# 如果最近max_reply_num次都是reply,直接移除
|
||||
result["remove"].append("reply")
|
||||
reply_count = len(last_max_reply_num) - no_reply_count
|
||||
# reply_count = len(last_max_reply_num) - no_reply_count
|
||||
logger.info(
|
||||
f"{self.log_prefix}移除reply动作,原因: 连续回复过多(最近{len(last_max_reply_num)}次全是reply,超过阈值{max_reply_num})"
|
||||
)
|
||||
@@ -622,8 +626,6 @@ class ActionModifier:
|
||||
f"{self.log_prefix}连续回复检测:最近{one_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,未触发"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常"
|
||||
)
|
||||
logger.debug(f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常")
|
||||
|
||||
return result
|
||||
|
||||
@@ -2,8 +2,7 @@ from typing import Dict, Type
|
||||
from src.chat.focus_chat.planners.base_planner import BasePlanner
|
||||
from src.chat.focus_chat.planners.planner_simple import ActionPlanner as SimpleActionPlanner
|
||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
from src.config.config import global_config
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("planner_factory")
|
||||
|
||||
@@ -40,12 +39,7 @@ class PlannerFactory:
|
||||
Returns:
|
||||
BasePlanner: 规划器实例
|
||||
"""
|
||||
planner_type = global_config.focus_chat.planner_type
|
||||
|
||||
if planner_type not in cls._planner_types:
|
||||
logger.warning(f"{log_prefix} 未知的规划器类型: {planner_type},使用默认规划器")
|
||||
planner_type = "complex"
|
||||
|
||||
planner_class = cls._planner_types[planner_type]
|
||||
logger.info(f"{log_prefix} 使用{planner_type}规划器")
|
||||
planner_class = cls._planner_types["simple"]
|
||||
logger.info(f"{log_prefix} 使用simple规划器")
|
||||
return planner_class(log_prefix=log_prefix, action_manager=action_manager)
|
||||
|
||||
@@ -11,12 +11,10 @@ from src.chat.focus_chat.info.action_info import ActionInfo
|
||||
from src.chat.focus_chat.info.structured_info import StructuredInfo
|
||||
from src.chat.focus_chat.info.self_info import SelfInfo
|
||||
from src.chat.focus_chat.info.relation_info import RelationInfo
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.individuality.individuality import individuality
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
from src.chat.focus_chat.planners.modify_actions import ActionModifier
|
||||
from src.chat.focus_chat.planners.actions.base_action import ChatMode
|
||||
from json_repair import repair_json
|
||||
from src.chat.focus_chat.planners.base_planner import BasePlanner
|
||||
from datetime import datetime
|
||||
@@ -110,8 +108,8 @@ class ActionPlanner(BasePlanner):
|
||||
nickname_str += f"{nicknames},"
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
personality_block = individuality.get_personality_prompt(x_person=2, level=2)
|
||||
identity_block = individuality.get_identity_prompt(x_person=2, level=2)
|
||||
personality_block = get_individuality().get_personality_prompt(x_person=2, level=2)
|
||||
identity_block = get_individuality().get_identity_prompt(x_person=2, level=2)
|
||||
|
||||
self_info = name_block + personality_block + identity_block
|
||||
current_mind = "你思考了很久,没有想清晰要做什么"
|
||||
@@ -146,8 +144,8 @@ class ActionPlanner(BasePlanner):
|
||||
# 获取经过modify_actions处理后的最终可用动作集
|
||||
# 注意:动作的激活判定现在在主循环的modify_actions中完成
|
||||
# 使用Focus模式过滤动作
|
||||
current_available_actions_dict = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
|
||||
|
||||
current_available_actions_dict = self.action_manager.get_using_actions_for_mode("focus")
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions = self.action_manager.get_registered_actions()
|
||||
current_available_actions = {}
|
||||
@@ -166,7 +164,7 @@ class ActionPlanner(BasePlanner):
|
||||
logger.info(f"{self.log_prefix}{reasoning}")
|
||||
self.action_manager.restore_actions()
|
||||
logger.debug(
|
||||
f"{self.log_prefix}沉默后恢复到默认动作集, 当前可用: {list(self.action_manager.get_using_actions().keys())}"
|
||||
f"{self.log_prefix}[focus]沉默后恢复到默认动作集, 当前可用: {list(self.action_manager.get_using_actions().keys())}"
|
||||
)
|
||||
return {
|
||||
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning},
|
||||
@@ -193,12 +191,11 @@ class ActionPlanner(BasePlanner):
|
||||
try:
|
||||
prompt = f"{prompt}"
|
||||
llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
|
||||
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
reasoning = f"LLM 请求失败,你的模型出现问题: {req_e}"
|
||||
@@ -219,7 +216,6 @@ class ActionPlanner(BasePlanner):
|
||||
|
||||
# 提取决策,提供默认值
|
||||
extracted_action = parsed_json.get("action", "no_reply")
|
||||
# extracted_reasoning = parsed_json.get("reasoning", "LLM未提供理由")
|
||||
extracted_reasoning = ""
|
||||
|
||||
# 将所有其他属性添加到action_data
|
||||
@@ -238,10 +234,10 @@ class ActionPlanner(BasePlanner):
|
||||
extra_info_block = ""
|
||||
|
||||
action_data["extra_info_block"] = extra_info_block
|
||||
|
||||
|
||||
if relation_info:
|
||||
action_data["relation_info_block"] = relation_info
|
||||
|
||||
|
||||
# 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data
|
||||
|
||||
if extracted_action not in current_available_actions:
|
||||
@@ -267,10 +263,6 @@ class ActionPlanner(BasePlanner):
|
||||
action = "no_reply"
|
||||
reasoning = f"Planner 内部处理错误: {outer_e}"
|
||||
|
||||
# logger.debug(
|
||||
# f"{self.log_prefix}规划器Prompt:\n{prompt}\n\n决策动作:{action},\n动作信息: '{action_data}'\n理由: {reasoning}"
|
||||
# )
|
||||
|
||||
# 恢复到默认动作集
|
||||
self.action_manager.restore_actions()
|
||||
logger.debug(
|
||||
@@ -304,12 +296,11 @@ class ActionPlanner(BasePlanner):
|
||||
) -> str:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
|
||||
if relation_info_block:
|
||||
relation_info_block = f"以下是你和别人的关系描述:\n{relation_info_block}"
|
||||
else:
|
||||
relation_info_block = ""
|
||||
|
||||
|
||||
memory_str = ""
|
||||
if running_memorys:
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
@@ -332,11 +323,11 @@ class ActionPlanner(BasePlanner):
|
||||
|
||||
# mind_info_block = ""
|
||||
# if current_mind:
|
||||
# mind_info_block = f"对聊天的规划:{current_mind}"
|
||||
# mind_info_block = f"对聊天的规划:{current_mind}"
|
||||
# else:
|
||||
# mind_info_block = "你刚参与聊天"
|
||||
# mind_info_block = "你刚参与聊天"
|
||||
|
||||
personality_block = individuality.get_prompt(x_person=2, level=2)
|
||||
personality_block = get_individuality().get_prompt(x_person=2, level=2)
|
||||
|
||||
action_options_block = ""
|
||||
for using_actions_name, using_actions_info in current_available_actions.items():
|
||||
@@ -352,16 +343,14 @@ class ActionPlanner(BasePlanner):
|
||||
param_text = "\n"
|
||||
for param_name, param_description in using_actions_info["parameters"].items():
|
||||
param_text += f' "{param_name}":"{param_description}"\n'
|
||||
param_text = param_text.rstrip('\n')
|
||||
param_text = param_text.rstrip("\n")
|
||||
else:
|
||||
param_text = ""
|
||||
|
||||
|
||||
require_text = ""
|
||||
for require_item in using_actions_info["require"]:
|
||||
require_text += f"- {require_item}\n"
|
||||
require_text = require_text.rstrip('\n')
|
||||
|
||||
require_text = require_text.rstrip("\n")
|
||||
|
||||
using_action_prompt = using_action_prompt.format(
|
||||
action_name=using_actions_name,
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
import traceback
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import get_expression_learner
|
||||
from src.chat.message_receive.message import MessageRecv, MessageThinking, MessageSending
|
||||
from src.chat.message_receive.message import Seg # Local import needed after move
|
||||
from src.chat.message_receive.message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.focus_chat.heartFC_sender import HeartFCSender
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.chat.utils.info_catcher import info_catcher_manager
|
||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
import time
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
|
||||
import random
|
||||
from datetime import datetime
|
||||
import re
|
||||
@@ -94,7 +94,7 @@ class DefaultReplyer:
|
||||
|
||||
self.chat_id = chat_stream.stream_id
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
|
||||
async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str):
|
||||
"""创建思考消息 (尝试锚定到 anchor_message)"""
|
||||
@@ -121,6 +121,7 @@ class DefaultReplyer:
|
||||
# logger.debug(f"创建思考消息thinking_message:{thinking_message}")
|
||||
|
||||
await self.heart_fc_sender.register_thinking(thinking_message)
|
||||
return None
|
||||
|
||||
async def deal_reply(
|
||||
self,
|
||||
@@ -140,6 +141,8 @@ class DefaultReplyer:
|
||||
# 处理文本部分
|
||||
# text_part = action_data.get("text", [])
|
||||
# if text_part:
|
||||
sent_msg_list = []
|
||||
|
||||
with Timer("生成回复", cycle_timers):
|
||||
# 可以保留原有的文本处理逻辑或进行适当调整
|
||||
reply = await self.reply(
|
||||
@@ -238,24 +241,21 @@ class DefaultReplyer:
|
||||
# current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
|
||||
# self.express_model.params["temperature"] = current_temp # 动态调整温度
|
||||
|
||||
# 2. 获取信息捕捉器
|
||||
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||
|
||||
reply_to = action_data.get("reply_to", "none")
|
||||
|
||||
|
||||
sender = ""
|
||||
targer = ""
|
||||
if ":" in reply_to or ":" in reply_to:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r'[::]', string=reply_to, maxsplit=1)
|
||||
parts = re.split(pattern=r"[::]", string=reply_to, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
targer = parts[1].strip()
|
||||
|
||||
|
||||
identity = action_data.get("identity", "")
|
||||
extra_info_block = action_data.get("extra_info_block", "")
|
||||
relation_info_block = action_data.get("relation_info_block", "")
|
||||
|
||||
|
||||
# 3. 构建 Prompt
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt = await self.build_prompt_focus(
|
||||
@@ -286,10 +286,6 @@ class DefaultReplyer:
|
||||
# logger.info(f"prompt: {prompt}")
|
||||
logger.info(f"最终回复: {content}")
|
||||
|
||||
info_catcher.catch_after_llm_generated(
|
||||
prompt=prompt, response=content, reasoning_content=reasoning_content, model_name=model_name
|
||||
)
|
||||
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
||||
@@ -340,13 +336,14 @@ class DefaultReplyer:
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
expression_learner = get_expression_learner()
|
||||
(
|
||||
learnt_style_expressions,
|
||||
learnt_grammar_expressions,
|
||||
@@ -378,8 +375,6 @@ class DefaultReplyer:
|
||||
|
||||
style_habbits_str = "\n".join(style_habbits)
|
||||
grammar_habbits_str = "\n".join(grammar_habbits)
|
||||
|
||||
|
||||
|
||||
# 关键词检测与反应
|
||||
keywords_reaction_prompt = ""
|
||||
@@ -411,16 +406,15 @@ class DefaultReplyer:
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
# logger.debug("开始构建 focus prompt")
|
||||
|
||||
|
||||
if sender_name:
|
||||
reply_target_block = f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
reply_target_block = (
|
||||
f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target_message:
|
||||
reply_target_block = f"现在{target_message}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
|
||||
|
||||
|
||||
|
||||
# --- Choose template based on chat type ---
|
||||
if is_group_chat:
|
||||
@@ -494,7 +488,7 @@ class DefaultReplyer:
|
||||
logger.error(f"{self.log_prefix} 无法发送回复,anchor_message 为空。")
|
||||
return None
|
||||
|
||||
stream_name = chat_manager.get_stream_name(chat_id) or chat_id # 获取流名称用于日志
|
||||
stream_name = get_chat_manager().get_stream_name(chat_id) or chat_id # 获取流名称用于日志
|
||||
|
||||
# 检查思考过程是否仍在进行,并获取开始时间
|
||||
if thinking_id:
|
||||
@@ -586,7 +580,7 @@ class DefaultReplyer:
|
||||
"""
|
||||
emoji_base64 = ""
|
||||
description = ""
|
||||
emoji_raw = await emoji_manager.get_emoji_for_text(send_emoji)
|
||||
emoji_raw = await get_emoji_manager().get_emoji_for_text(send_emoji)
|
||||
if emoji_raw:
|
||||
emoji_path, description, _emotion = emoji_raw
|
||||
emoji_base64 = image_path_to_base64(emoji_path)
|
||||
@@ -669,30 +663,30 @@ def find_similar_expressions(input_text: str, expressions: List[Dict], top_k: in
|
||||
"""使用TF-IDF和余弦相似度找出与输入文本最相似的top_k个表达方式"""
|
||||
if not expressions:
|
||||
return []
|
||||
|
||||
|
||||
# 准备文本数据
|
||||
texts = [expr['situation'] for expr in expressions]
|
||||
texts = [expr["situation"] for expr in expressions]
|
||||
texts.append(input_text) # 添加输入文本
|
||||
|
||||
|
||||
# 使用TF-IDF向量化
|
||||
vectorizer = TfidfVectorizer()
|
||||
tfidf_matrix = vectorizer.fit_transform(texts)
|
||||
|
||||
|
||||
# 计算余弦相似度
|
||||
similarity_matrix = cosine_similarity(tfidf_matrix)
|
||||
|
||||
|
||||
# 获取输入文本的相似度分数(最后一行)
|
||||
scores = similarity_matrix[-1][:-1] # 排除与自身的相似度
|
||||
|
||||
|
||||
# 获取top_k的索引
|
||||
top_indices = np.argsort(scores)[::-1][:top_k]
|
||||
|
||||
|
||||
# 获取相似表达
|
||||
similar_exprs = []
|
||||
for idx in top_indices:
|
||||
if scores[idx] > 0: # 只保留有相似度的
|
||||
similar_exprs.append(expressions[idx])
|
||||
|
||||
|
||||
return similar_exprs
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Dict, Any, Type, TypeVar, List, Optional
|
||||
import traceback
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import List, Any, Optional
|
||||
import asyncio
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -33,8 +34,11 @@ class WorkingMemory:
|
||||
# 衰减任务
|
||||
self.decay_task = None
|
||||
|
||||
# 启动自动衰减任务
|
||||
self._start_auto_decay()
|
||||
# 只有在工作记忆处理器启用时才启动自动衰减任务
|
||||
if global_config.focus_chat_processor.working_memory_processor:
|
||||
self._start_auto_decay()
|
||||
else:
|
||||
logger.debug(f"工作记忆处理器已禁用,跳过启动自动衰减任务 (chat_id: {chat_id})")
|
||||
|
||||
def _start_auto_decay(self):
|
||||
"""启动自动衰减任务"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Optional, Coroutine, Callable, Any, List
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from typing import Any, Optional, List
|
||||
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
|
||||
from src.chat.heart_flow.background_tasks import BackgroundTaskManager # Import BackgroundTaskManager
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# 定义了来自外部世界的信息
|
||||
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
||||
from datetime import datetime
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
|
||||
logger = get_logger("observation")
|
||||
|
||||
@@ -14,7 +14,7 @@ import difflib
|
||||
from src.chat.message_receive.message import MessageRecv # 添加 MessageRecv 导入
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
|
||||
@@ -62,13 +62,12 @@ class ChattingObservation(Observation):
|
||||
self.oldest_messages = []
|
||||
self.oldest_messages_str = ""
|
||||
self.compressor_prompt = ""
|
||||
|
||||
|
||||
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
|
||||
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
|
||||
self.talking_message = initial_messages
|
||||
self.talking_message_str = build_readable_messages(self.talking_message, show_actions=True)
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""将观察对象转换为可序列化的字典"""
|
||||
return {
|
||||
@@ -283,12 +282,12 @@ class ChattingObservation(Observation):
|
||||
show_actions=True,
|
||||
)
|
||||
# print(f"构建中:self.talking_message_str_truncate: {self.talking_message_str_truncate}")
|
||||
|
||||
|
||||
self.person_list = await get_person_id_list(self.talking_message)
|
||||
|
||||
# print(f"构建中:self.person_list: {self.person_list}")
|
||||
|
||||
logger.trace(
|
||||
logger.debug(
|
||||
f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# 定义了来自外部世界的信息
|
||||
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
||||
from datetime import datetime
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
|
||||
from typing import List
|
||||
# Import the new utility function
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# 定义了来自外部世界的信息
|
||||
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
||||
from datetime import datetime
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("observation")
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# Import the new utility function
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# 定义了来自外部世界的信息
|
||||
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
||||
from datetime import datetime
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
||||
from typing import List
|
||||
|
||||
@@ -4,9 +4,9 @@ import asyncio
|
||||
import time
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
import traceback
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.focus_chat.heartFC_chat import HeartFChatting
|
||||
from src.chat.normal_chat.normal_chat import NormalChat
|
||||
from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo
|
||||
@@ -42,9 +42,7 @@ class SubHeartflow:
|
||||
self.history_chat_state: List[Tuple[ChatState, float]] = []
|
||||
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
self.log_prefix = (
|
||||
chat_manager.get_stream_name(self.subheartflow_id) or self.subheartflow_id
|
||||
)
|
||||
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id
|
||||
# 兴趣消息集合
|
||||
self.interest_dict: Dict[str, tuple[MessageRecv, float, bool]] = {}
|
||||
|
||||
@@ -105,7 +103,7 @@ class SubHeartflow:
|
||||
log_prefix = self.log_prefix
|
||||
try:
|
||||
# 获取聊天流并创建 NormalChat 实例 (同步部分)
|
||||
chat_stream = chat_manager.get_stream(self.chat_id)
|
||||
chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
if not chat_stream:
|
||||
logger.error(f"{log_prefix} 无法获取 chat_stream,无法启动 NormalChat。")
|
||||
return False
|
||||
@@ -199,7 +197,6 @@ class SubHeartflow:
|
||||
# 如果实例不存在,则创建并启动
|
||||
logger.info(f"{log_prefix} 麦麦准备开始专注聊天...")
|
||||
try:
|
||||
|
||||
self.heart_fc_instance = HeartFChatting(
|
||||
chat_id=self.subheartflow_id,
|
||||
# observations=self.observations,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Any, Optional, List
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ async def _try_set_subflow_absent_internal(subflow: "SubHeartflow", log_prefix:
|
||||
bool: 如果状态成功变为 ABSENT 或原本就是 ABSENT,返回 True;否则返回 False。
|
||||
"""
|
||||
flow_id = subflow.subheartflow_id
|
||||
stream_name = chat_manager.get_stream_name(flow_id) or flow_id
|
||||
stream_name = get_chat_manager().get_stream_name(flow_id) or flow_id
|
||||
|
||||
if subflow.chat_state.chat_status != ChatState.ABSENT:
|
||||
logger.debug(f"{log_prefix} 设置 {stream_name} 状态为 ABSENT")
|
||||
@@ -106,7 +106,7 @@ class SubHeartflowManager:
|
||||
|
||||
# 注册子心流
|
||||
self.subheartflows[subheartflow_id] = new_subflow
|
||||
heartflow_name = chat_manager.get_stream_name(subheartflow_id) or subheartflow_id
|
||||
heartflow_name = get_chat_manager().get_stream_name(subheartflow_id) or subheartflow_id
|
||||
logger.info(f"[{heartflow_name}] 开始接收消息")
|
||||
|
||||
return new_subflow
|
||||
@@ -120,7 +120,7 @@ class SubHeartflowManager:
|
||||
async with self._lock: # 加锁以安全访问字典
|
||||
subheartflow = self.subheartflows.get(subheartflow_id)
|
||||
|
||||
stream_name = chat_manager.get_stream_name(subheartflow_id) or subheartflow_id
|
||||
stream_name = get_chat_manager().get_stream_name(subheartflow_id) or subheartflow_id
|
||||
logger.info(f"{log_prefix} 正在停止 {stream_name}, 原因: {reason}")
|
||||
|
||||
# 调用内部方法处理状态变更
|
||||
@@ -170,7 +170,9 @@ class SubHeartflowManager:
|
||||
changed_count += 1
|
||||
else:
|
||||
# 这种情况理论上不应发生,如果内部方法返回 True 的话
|
||||
stream_name = chat_manager.get_stream_name(subflow.subheartflow_id) or subflow.subheartflow_id
|
||||
stream_name = (
|
||||
get_chat_manager().get_stream_name(subflow.subheartflow_id) or subflow.subheartflow_id
|
||||
)
|
||||
logger.warning(f"{log_prefix} 内部方法声称成功但 {stream_name} 状态未变为 ABSENT。")
|
||||
# 锁在此处自动释放
|
||||
|
||||
@@ -183,7 +185,7 @@ class SubHeartflowManager:
|
||||
# try:
|
||||
# for sub_hf in list(self.subheartflows.values()):
|
||||
# flow_id = sub_hf.subheartflow_id
|
||||
# stream_name = chat_manager.get_stream_name(flow_id) or flow_id
|
||||
# stream_name = get_chat_manager().get_stream_name(flow_id) or flow_id
|
||||
|
||||
# # 跳过已经是FOCUSED状态的子心流
|
||||
# if sub_hf.chat_state.chat_status == ChatState.FOCUSED:
|
||||
@@ -229,7 +231,7 @@ class SubHeartflowManager:
|
||||
logger.warning(f"[状态转换请求] 尝试转换不存在的子心流 {subflow_id} 到 NORMAL")
|
||||
return
|
||||
|
||||
stream_name = chat_manager.get_stream_name(subflow_id) or subflow_id
|
||||
stream_name = get_chat_manager().get_stream_name(subflow_id) or subflow_id
|
||||
current_state = subflow.chat_state.chat_status
|
||||
|
||||
if current_state == ChatState.FOCUSED:
|
||||
@@ -298,7 +300,7 @@ class SubHeartflowManager:
|
||||
# --- 遍历评估每个符合条件的私聊 --- #
|
||||
for sub_hf in eligible_subflows:
|
||||
flow_id = sub_hf.subheartflow_id
|
||||
stream_name = chat_manager.get_stream_name(flow_id) or flow_id
|
||||
stream_name = get_chat_manager().get_stream_name(flow_id) or flow_id
|
||||
log_prefix = f"[{stream_name}]({log_prefix_task})"
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Optional, Tuple, Dict
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.person_info.person_info import person_info_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
logger = get_logger("heartflow_utils")
|
||||
|
||||
@@ -23,7 +23,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
chat_target_info = None
|
||||
|
||||
try:
|
||||
chat_stream = chat_manager.get_stream(chat_id)
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
|
||||
if chat_stream:
|
||||
if chat_stream.group_info:
|
||||
@@ -47,10 +47,11 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
# Try to fetch person info
|
||||
try:
|
||||
# Assume get_person_id is sync (as per original code), keep using to_thread
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_name = None
|
||||
if person_id:
|
||||
# get_value is async, so await it directly
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = person_info_manager.get_value_sync(person_id, "person_name")
|
||||
|
||||
target_info["person_id"] = person_id
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Configure logger
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("lpmm")
|
||||
|
||||
@@ -132,9 +132,6 @@ global_config = dict(
|
||||
}
|
||||
)
|
||||
|
||||
# _load_config(global_config, parser.parse_args().config_path)
|
||||
# file_path = os.path.abspath(__file__)
|
||||
# dir_path = os.path.dirname(file_path)
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
config_path = os.path.join(ROOT_PATH, "config", "lpmm_config.toml")
|
||||
_load_config(global_config, config_path)
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
|
||||
from .global_logger import logger
|
||||
from .lpmmconfig import global_config
|
||||
from src.chat.knowledge.utils import get_sha256
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
|
||||
|
||||
def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
|
||||
@@ -25,10 +25,10 @@ def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
|
||||
import_json = json.loads(f.read())
|
||||
else:
|
||||
raise Exception(f"原始数据文件读取失败: {json_path}")
|
||||
# import_json内容示例:
|
||||
# import_json = [
|
||||
# "The capital of China is Beijing. The capital of France is Paris.",
|
||||
# ]
|
||||
"""
|
||||
import_json 内容示例:
|
||||
import_json = ["The capital of China is Beijing. The capital of France is Paris.",]
|
||||
"""
|
||||
raw_data = []
|
||||
sha256_list = []
|
||||
sha256_set = set()
|
||||
|
||||
@@ -12,7 +12,7 @@ import networkx as nx
|
||||
import numpy as np
|
||||
from collections import Counter
|
||||
from ...llm_models.utils_model import LLMRequest
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
|
||||
from ..utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
@@ -346,7 +346,9 @@ class Hippocampus:
|
||||
# 使用LLM提取关键词
|
||||
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
|
||||
# logger.info(f"提取关键词数量: {topic_num}")
|
||||
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
|
||||
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
|
||||
self.find_topic_llm(text, topic_num)
|
||||
)
|
||||
|
||||
# 提取关键词
|
||||
keywords = re.findall(r"<([^>]+)>", topics_response)
|
||||
@@ -407,9 +409,9 @@ class Hippocampus:
|
||||
activation_values[neighbor] = new_activation
|
||||
visited_nodes.add(neighbor)
|
||||
nodes_to_process.append((neighbor, new_activation, current_depth + 1))
|
||||
logger.trace(
|
||||
f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})"
|
||||
) # noqa: E501
|
||||
# logger.debug(
|
||||
# f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})"
|
||||
# ) # noqa: E501
|
||||
|
||||
# 更新激活映射
|
||||
for node, activation_value in activation_values.items():
|
||||
@@ -578,9 +580,9 @@ class Hippocampus:
|
||||
activation_values[neighbor] = new_activation
|
||||
visited_nodes.add(neighbor)
|
||||
nodes_to_process.append((neighbor, new_activation, current_depth + 1))
|
||||
logger.trace(
|
||||
f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})"
|
||||
) # noqa: E501
|
||||
# logger.debug(
|
||||
# f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})"
|
||||
# ) # noqa: E501
|
||||
|
||||
# 更新激活映射
|
||||
for node, activation_value in activation_values.items():
|
||||
@@ -701,7 +703,9 @@ class Hippocampus:
|
||||
# 使用LLM提取关键词
|
||||
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
|
||||
# logger.info(f"提取关键词数量: {topic_num}")
|
||||
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
|
||||
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
|
||||
self.find_topic_llm(text, topic_num)
|
||||
)
|
||||
|
||||
# 提取关键词
|
||||
keywords = re.findall(r"<([^>]+)>", topics_response)
|
||||
@@ -729,7 +733,7 @@ class Hippocampus:
|
||||
|
||||
# 对每个关键词进行扩散式检索
|
||||
for keyword in valid_keywords:
|
||||
logger.trace(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):")
|
||||
logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):")
|
||||
# 初始化激活值
|
||||
activation_values = {keyword: 1.0}
|
||||
# 记录已访问的节点
|
||||
@@ -780,7 +784,7 @@ class Hippocampus:
|
||||
|
||||
# 计算激活节点数与总节点数的比值
|
||||
total_activation = sum(activate_map.values())
|
||||
logger.trace(f"总激活值: {total_activation:.2f}")
|
||||
logger.debug(f"总激活值: {total_activation:.2f}")
|
||||
total_nodes = len(self.memory_graph.G.nodes())
|
||||
# activated_nodes = len(activate_map)
|
||||
activation_ratio = total_activation / total_nodes if total_nodes > 0 else 0
|
||||
@@ -825,7 +829,7 @@ class EntorhinalCortex:
|
||||
)
|
||||
if messages:
|
||||
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
|
||||
logger.success(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
|
||||
logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
|
||||
chat_samples.append(messages)
|
||||
else:
|
||||
logger.debug(f"时间戳 {timestamp} 的消息无需记忆")
|
||||
@@ -893,7 +897,7 @@ class EntorhinalCortex:
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
db_nodes = {node.concept: node for node in GraphNodes.select()}
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
|
||||
# 批量准备节点数据
|
||||
nodes_to_create = []
|
||||
nodes_to_update = []
|
||||
@@ -929,22 +933,26 @@ class EntorhinalCortex:
|
||||
continue
|
||||
|
||||
if concept not in db_nodes:
|
||||
nodes_to_create.append({
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": memory_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
})
|
||||
else:
|
||||
db_node = db_nodes[concept]
|
||||
if db_node.hash != memory_hash:
|
||||
nodes_to_update.append({
|
||||
nodes_to_create.append(
|
||||
{
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": memory_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
})
|
||||
}
|
||||
)
|
||||
else:
|
||||
db_node = db_nodes[concept]
|
||||
if db_node.hash != memory_hash:
|
||||
nodes_to_update.append(
|
||||
{
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": memory_hash,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
)
|
||||
|
||||
# 计算需要删除的节点
|
||||
memory_concepts = {concept for concept, _ in memory_nodes}
|
||||
@@ -954,13 +962,13 @@ class EntorhinalCortex:
|
||||
if nodes_to_create:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_create), batch_size):
|
||||
batch = nodes_to_create[i:i + batch_size]
|
||||
batch = nodes_to_create[i : i + batch_size]
|
||||
GraphNodes.insert_many(batch).execute()
|
||||
|
||||
if nodes_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_update), batch_size):
|
||||
batch = nodes_to_update[i:i + batch_size]
|
||||
batch = nodes_to_update[i : i + batch_size]
|
||||
for node_data in batch:
|
||||
GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
|
||||
GraphNodes.concept == node_data["concept"]
|
||||
@@ -992,22 +1000,26 @@ class EntorhinalCortex:
|
||||
last_modified = data.get("last_modified", current_time)
|
||||
|
||||
if edge_key not in db_edge_dict:
|
||||
edges_to_create.append({
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": strength,
|
||||
"hash": edge_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
})
|
||||
edges_to_create.append(
|
||||
{
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": strength,
|
||||
"hash": edge_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
)
|
||||
elif db_edge_dict[edge_key]["hash"] != edge_hash:
|
||||
edges_to_update.append({
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": strength,
|
||||
"hash": edge_hash,
|
||||
"last_modified": last_modified,
|
||||
})
|
||||
edges_to_update.append(
|
||||
{
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": strength,
|
||||
"hash": edge_hash,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
)
|
||||
|
||||
# 计算需要删除的边
|
||||
memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
|
||||
@@ -1017,13 +1029,13 @@ class EntorhinalCortex:
|
||||
if edges_to_create:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_create), batch_size):
|
||||
batch = edges_to_create[i:i + batch_size]
|
||||
batch = edges_to_create[i : i + batch_size]
|
||||
GraphEdges.insert_many(batch).execute()
|
||||
|
||||
if edges_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_update), batch_size):
|
||||
batch = edges_to_update[i:i + batch_size]
|
||||
batch = edges_to_update[i : i + batch_size]
|
||||
for edge_data in batch:
|
||||
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
|
||||
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
|
||||
@@ -1031,13 +1043,11 @@ class EntorhinalCortex:
|
||||
|
||||
if edges_to_delete:
|
||||
for source, target in edges_to_delete:
|
||||
GraphEdges.delete().where(
|
||||
(GraphEdges.source == source) & (GraphEdges.target == target)
|
||||
).execute()
|
||||
GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
|
||||
|
||||
end_time = time.time()
|
||||
logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}秒")
|
||||
logger.success(f"[同步] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边")
|
||||
logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}秒")
|
||||
logger.info(f"[同步] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边")
|
||||
|
||||
async def resync_memory_to_db(self):
|
||||
"""清空数据库并重新同步所有记忆数据"""
|
||||
@@ -1069,13 +1079,15 @@ class EntorhinalCortex:
|
||||
if not memory_items_json:
|
||||
continue
|
||||
|
||||
nodes_data.append({
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
|
||||
"created_time": data.get("created_time", current_time),
|
||||
"last_modified": data.get("last_modified", current_time),
|
||||
})
|
||||
nodes_data.append(
|
||||
{
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
|
||||
"created_time": data.get("created_time", current_time),
|
||||
"last_modified": data.get("last_modified", current_time),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"准备节点 {concept} 数据时发生错误: {e}")
|
||||
continue
|
||||
@@ -1084,14 +1096,16 @@ class EntorhinalCortex:
|
||||
edges_data = []
|
||||
for source, target, data in memory_edges:
|
||||
try:
|
||||
edges_data.append({
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": data.get("strength", 1),
|
||||
"hash": self.hippocampus.calculate_edge_hash(source, target),
|
||||
"created_time": data.get("created_time", current_time),
|
||||
"last_modified": data.get("last_modified", current_time),
|
||||
})
|
||||
edges_data.append(
|
||||
{
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": data.get("strength", 1),
|
||||
"hash": self.hippocampus.calculate_edge_hash(source, target),
|
||||
"created_time": data.get("created_time", current_time),
|
||||
"last_modified": data.get("last_modified", current_time),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}")
|
||||
continue
|
||||
@@ -1102,7 +1116,7 @@ class EntorhinalCortex:
|
||||
batch_size = 500 # 增加批量大小
|
||||
with GraphNodes._meta.database.atomic():
|
||||
for i in range(0, len(nodes_data), batch_size):
|
||||
batch = nodes_data[i:i + batch_size]
|
||||
batch = nodes_data[i : i + batch_size]
|
||||
GraphNodes.insert_many(batch).execute()
|
||||
node_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒")
|
||||
@@ -1113,14 +1127,14 @@ class EntorhinalCortex:
|
||||
batch_size = 500 # 增加批量大小
|
||||
with GraphEdges._meta.database.atomic():
|
||||
for i in range(0, len(edges_data), batch_size):
|
||||
batch = edges_data[i:i + batch_size]
|
||||
batch = edges_data[i : i + batch_size]
|
||||
GraphEdges.insert_many(batch).execute()
|
||||
edge_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒")
|
||||
|
||||
end_time = time.time()
|
||||
logger.success(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}秒")
|
||||
logger.success(f"[数据库] 同步了 {len(nodes_data)} 个节点和 {len(edges_data)} 条边")
|
||||
logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}秒")
|
||||
logger.info(f"[数据库] 同步了 {len(nodes_data)} 个节点和 {len(edges_data)} 条边")
|
||||
|
||||
def sync_memory_from_db(self):
|
||||
"""从数据库同步数据到内存中的图结构"""
|
||||
@@ -1195,7 +1209,7 @@ class EntorhinalCortex:
|
||||
)
|
||||
|
||||
if need_update:
|
||||
logger.success("[数据库] 已为缺失的时间字段进行补充")
|
||||
logger.info("[数据库] 已为缺失的时间字段进行补充")
|
||||
|
||||
|
||||
# 负责整合,遗忘,合并记忆
|
||||
@@ -1240,9 +1254,8 @@ class ParahippocampalGyrus:
|
||||
logger.warning("无法从提供的消息生成可读文本,跳过记忆压缩。")
|
||||
return set(), {}
|
||||
|
||||
current_YMD_time = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
current_YMD_time_str = f"当前日期: {current_YMD_time}"
|
||||
input_text = f"{current_YMD_time_str}\n{input_text}"
|
||||
current_date = f"当前日期: {datetime.datetime.now().isoformat()}"
|
||||
input_text = f"{current_date}\n{input_text}"
|
||||
|
||||
logger.debug(f"记忆来源:\n{input_text}")
|
||||
|
||||
@@ -1374,7 +1387,7 @@ class ParahippocampalGyrus:
|
||||
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
|
||||
|
||||
if all_added_nodes:
|
||||
logger.success(f"更新记忆: {', '.join(all_added_nodes)}")
|
||||
logger.info(f"更新记忆: {', '.join(all_added_nodes)}")
|
||||
if all_added_edges:
|
||||
logger.debug(f"强化连接: {', '.join(all_added_edges)}")
|
||||
if all_connected_nodes:
|
||||
@@ -1383,7 +1396,7 @@ class ParahippocampalGyrus:
|
||||
await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
|
||||
|
||||
end_time = time.time()
|
||||
logger.success(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------")
|
||||
logger.info(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------")
|
||||
|
||||
async def operation_forget_topic(self, percentage=0.005):
|
||||
start_time = time.time()
|
||||
@@ -1592,8 +1605,8 @@ class ParahippocampalGyrus:
|
||||
|
||||
if similarity >= similarity_threshold:
|
||||
logger.debug(f"[整合] 节点 '{node}' 中发现相似项 (相似度: {similarity:.2f}):")
|
||||
logger.trace(f" - '{item1}'")
|
||||
logger.trace(f" - '{item2}'")
|
||||
logger.debug(f" - '{item1}'")
|
||||
logger.debug(f" - '{item2}'")
|
||||
|
||||
# 比较信息量
|
||||
info1 = calculate_information_content(item1)
|
||||
@@ -1655,21 +1668,9 @@ class ParahippocampalGyrus:
|
||||
|
||||
|
||||
class HippocampusManager:
|
||||
_instance = None
|
||||
_hippocampus = None
|
||||
_initialized = False
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_hippocampus(cls):
|
||||
if not cls._initialized:
|
||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
return cls._hippocampus
|
||||
def __init__(self):
|
||||
self._hippocampus = None
|
||||
self._initialized = False
|
||||
|
||||
def initialize(self):
|
||||
"""初始化海马体实例"""
|
||||
@@ -1685,7 +1686,7 @@ class HippocampusManager:
|
||||
node_count = len(memory_graph.nodes())
|
||||
edge_count = len(memory_graph.edges())
|
||||
|
||||
logger.success(f"""--------------------------------
|
||||
logger.info(f"""--------------------------------
|
||||
记忆系统参数配置:
|
||||
构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate}
|
||||
记忆构建分布: {global_config.memory.memory_build_distribution}
|
||||
@@ -1695,6 +1696,11 @@ class HippocampusManager:
|
||||
|
||||
return self._hippocampus
|
||||
|
||||
def get_hippocampus(self):
|
||||
if not self._initialized:
|
||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
return self._hippocampus
|
||||
|
||||
async def build_memory(self):
|
||||
"""构建记忆的公共接口"""
|
||||
if not self._initialized:
|
||||
@@ -1772,3 +1778,7 @@ class HippocampusManager:
|
||||
if not self._initialized:
|
||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
return self._hippocampus.get_all_node_names()
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
hippocampus_manager = HippocampusManager()
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message_sender import message_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
|
||||
|
||||
__all__ = [
|
||||
"emoji_manager",
|
||||
"chat_manager",
|
||||
"get_emoji_manager",
|
||||
"get_chat_manager",
|
||||
"message_manager",
|
||||
"MessageStorage",
|
||||
]
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import traceback
|
||||
from typing import Dict, Any
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.experimental.only_message_process import MessageProcessor
|
||||
from src.experimental.PFC.pfc_manager import PFCManager
|
||||
from src.chat.focus_chat.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
|
||||
|
||||
# 定义日志配置
|
||||
|
||||
@@ -32,7 +33,7 @@ class ChatBot:
|
||||
async def _ensure_started(self):
|
||||
"""确保所有任务已启动"""
|
||||
if not self._started:
|
||||
logger.trace("确保ChatBot所有任务已启动")
|
||||
logger.debug("确保ChatBot所有任务已启动")
|
||||
|
||||
self._started = True
|
||||
|
||||
@@ -47,6 +48,60 @@ class ChatBot:
|
||||
except Exception as e:
|
||||
logger.error(f"创建PFC聊天失败: {e}")
|
||||
|
||||
async def _process_commands_with_new_system(self, message: MessageRecv):
|
||||
"""使用新插件系统处理命令"""
|
||||
try:
|
||||
if not message.processed_plain_text:
|
||||
await message.process()
|
||||
|
||||
text = message.processed_plain_text
|
||||
|
||||
# 使用新的组件注册中心查找命令
|
||||
command_result = component_registry.find_command_by_text(text)
|
||||
if command_result:
|
||||
command_class, matched_groups, intercept_message, plugin_name = command_result
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||
|
||||
# 创建命令实例
|
||||
command_instance = command_class(message, plugin_config)
|
||||
command_instance.set_matched_groups(matched_groups)
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
success, response = await command_instance.execute()
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
||||
|
||||
# 根据命令的拦截设置决定是否继续处理消息
|
||||
return True, response, not intercept_message # 找到命令,根据intercept_message决定是否继续
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await command_instance.send_reply(f"命令执行出错: {str(e)}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||
return True, str(e), not intercept_message
|
||||
|
||||
# 没有找到命令,继续处理消息
|
||||
return False, None, True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||
@@ -73,11 +128,30 @@ class ChatBot:
|
||||
message_data["message_info"]["user_info"]["user_id"]
|
||||
)
|
||||
# print(message_data)
|
||||
logger.trace(f"处理消息:{str(message_data)[:120]}...")
|
||||
# logger.debug(str(message_data))
|
||||
message = MessageRecv(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
chat_manager.register_message(message)
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
# 创建聊天流
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
)
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
# 命令处理 - 使用新插件系统检查并处理命令
|
||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and not continue_process:
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return
|
||||
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
@@ -92,29 +166,23 @@ class ChatBot:
|
||||
template_group_name = None
|
||||
|
||||
async def preprocess():
|
||||
logger.trace("开始预处理消息...")
|
||||
logger.debug("开始预处理消息...")
|
||||
# 如果在私聊中
|
||||
if group_info is None:
|
||||
logger.trace("检测到私聊消息")
|
||||
logger.debug("检测到私聊消息")
|
||||
if global_config.experimental.pfc_chatting:
|
||||
logger.trace("进入PFC私聊处理流程")
|
||||
logger.debug("进入PFC私聊处理流程")
|
||||
# 创建聊天流
|
||||
logger.trace(f"为{user_info.user_id}创建/获取聊天流")
|
||||
chat = await chat_manager.get_or_create_stream(
|
||||
platform=message.message_info.platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
)
|
||||
message.update_chat_stream(chat)
|
||||
logger.debug(f"为{user_info.user_id}创建/获取聊天流")
|
||||
await self.only_process_chat.process_message(message)
|
||||
await self._create_pfc_chat(message)
|
||||
# 禁止PFC,进入普通的心流消息处理逻辑
|
||||
else:
|
||||
logger.trace("进入普通心流私聊处理")
|
||||
logger.debug("进入普通心流私聊处理")
|
||||
await self.heartflow_message_receiver.process_message(message_data)
|
||||
# 群聊默认进入心流消息处理逻辑
|
||||
else:
|
||||
logger.trace(f"检测到群聊消息,群ID: {group_info.group_id}")
|
||||
logger.debug(f"检测到群聊消息,群ID: {group_info.group_id}")
|
||||
await self.heartflow_message_receiver.process_message(message_data)
|
||||
|
||||
if template_group_name:
|
||||
|
||||
@@ -13,7 +13,7 @@ from maim_message import GroupInfo, UserInfo
|
||||
if TYPE_CHECKING:
|
||||
from .message import MessageRecv
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -135,7 +135,7 @@ class ChatManager:
|
||||
"""异步初始化"""
|
||||
try:
|
||||
await self.load_all_streams()
|
||||
logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
|
||||
logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天管理器启动失败: {str(e)}")
|
||||
|
||||
@@ -377,5 +377,11 @@ class ChatManager:
|
||||
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
chat_manager = ChatManager()
|
||||
chat_manager = None
|
||||
|
||||
|
||||
def get_chat_manager():
|
||||
global chat_manager
|
||||
if chat_manager is None:
|
||||
chat_manager = ChatManager()
|
||||
return chat_manager
|
||||
|
||||
@@ -5,11 +5,11 @@ from typing import Optional, Any, TYPE_CHECKING
|
||||
|
||||
import urllib3
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .chat_stream import ChatStream
|
||||
from ..utils.utils_image import image_manager
|
||||
from ..utils.utils_image import get_image_manager
|
||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
from rich.traceback import install
|
||||
|
||||
@@ -138,12 +138,12 @@ class MessageRecv(Message):
|
||||
elif seg.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_image_description(seg.data)
|
||||
return await get_image_manager().get_image_description(seg.data)
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif seg.type == "emoji":
|
||||
self.is_emoji = True
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_emoji_description(seg.data)
|
||||
return await get_image_manager().get_emoji_description(seg.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
else:
|
||||
return f"[{seg.type}:{str(seg.data)}]"
|
||||
@@ -207,11 +207,11 @@ class MessageProcessBase(Message):
|
||||
elif seg.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_image_description(seg.data)
|
||||
return await get_image_manager().get_image_description(seg.data)
|
||||
return "[图片,网卡了加载不出来]"
|
||||
elif seg.type == "emoji":
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_emoji_description(seg.data)
|
||||
return await get_image_manager().get_emoji_description(seg.data)
|
||||
return "[表情,网卡了加载不出来]"
|
||||
elif seg.type == "at":
|
||||
return f"[@{seg.data}]"
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
import time
|
||||
from asyncio import Task
|
||||
from typing import Union
|
||||
from src.common.message.api import global_api
|
||||
from src.common.message.api import get_global_api
|
||||
|
||||
# from ...common.database import db # 数据库依赖似乎不需要了,注释掉
|
||||
from .message import MessageSending, MessageThinking, MessageSet
|
||||
@@ -12,7 +12,7 @@ from .storage import MessageStorage
|
||||
from ...config.config import global_config
|
||||
from ..utils.utils import truncate_message, calculate_typing_time, count_messages_between
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -24,7 +24,7 @@ logger = get_logger("sender")
|
||||
async def send_via_ws(message: MessageSending) -> None:
|
||||
"""通过 WebSocket 发送消息"""
|
||||
try:
|
||||
await global_api.send_message(message)
|
||||
await get_global_api().send_message(message)
|
||||
except Exception as e:
|
||||
logger.error(f"WS发送失败: {e}")
|
||||
raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置,请检查配置文件") from e
|
||||
@@ -41,16 +41,16 @@ async def send_message(
|
||||
thinking_start_time=message.thinking_start_time,
|
||||
is_emoji=message.is_emoji,
|
||||
)
|
||||
# logger.trace(f"{message.processed_plain_text},{typing_time},计算输入时间结束") # 减少日志
|
||||
# logger.debug(f"{message.processed_plain_text},{typing_time},计算输入时间结束") # 减少日志
|
||||
await asyncio.sleep(typing_time)
|
||||
# logger.trace(f"{message.processed_plain_text},{typing_time},等待输入时间结束") # 减少日志
|
||||
# logger.debug(f"{message.processed_plain_text},{typing_time},等待输入时间结束") # 减少日志
|
||||
# --- 结束打字延迟 ---
|
||||
|
||||
message_preview = truncate_message(message.processed_plain_text)
|
||||
|
||||
try:
|
||||
await send_via_ws(message)
|
||||
logger.success(f"发送消息 '{message_preview}' 成功") # 调整日志格式
|
||||
logger.info(f"发送消息 '{message_preview}' 成功") # 调整日志格式
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 失败: {str(e)}")
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ from typing import Union
|
||||
from .message import MessageSending, MessageRecv
|
||||
from .chat_stream import ChatStream
|
||||
from ...common.database.database_model import Messages, RecalledMessages # Import Peewee models
|
||||
from src.common.logger import get_module_logger
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_module_logger("message_storage")
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
|
||||
class MessageStorage:
|
||||
|
||||
@@ -4,19 +4,16 @@ import traceback
|
||||
from random import random
|
||||
from typing import List, Optional # 导入 Optional
|
||||
from maim_message import UserInfo, Seg
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||
from src.manager.mood_manager import mood_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, chat_manager
|
||||
from src.chat.utils.info_catcher import info_catcher_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from .normal_chat_generator import NormalChatGenerator
|
||||
from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
from src.chat.message_receive.message_sender import message_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.normal_chat.willing.willing_manager import willing_manager
|
||||
from src.chat.normal_chat.willing.willing_manager import get_willing_manager
|
||||
from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats
|
||||
from src.config.config import global_config
|
||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
@@ -25,6 +22,7 @@ from src.chat.normal_chat.normal_chat_action_modifier import NormalChatActionMod
|
||||
from src.chat.normal_chat.normal_chat_expressor import NormalChatExpressor
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
|
||||
|
||||
willing_manager = get_willing_manager()
|
||||
|
||||
logger = get_logger("normal_chat")
|
||||
|
||||
@@ -35,7 +33,7 @@ class NormalChat:
|
||||
|
||||
self.chat_stream = chat_stream
|
||||
self.stream_id = chat_stream.stream_id
|
||||
self.stream_name = chat_manager.get_stream_name(self.stream_id) or self.stream_id
|
||||
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
|
||||
|
||||
# 初始化Normal Chat专用表达器
|
||||
self.expressor = NormalChatExpressor(self.chat_stream)
|
||||
@@ -150,50 +148,6 @@ class NormalChat:
|
||||
|
||||
return first_bot_msg
|
||||
|
||||
# 改为实例方法
|
||||
async def _handle_emoji(self, message: MessageRecv, response: str):
|
||||
"""处理表情包"""
|
||||
if random() < global_config.normal_chat.emoji_chance:
|
||||
emoji_raw = await emoji_manager.get_emoji_for_text(response)
|
||||
if emoji_raw:
|
||||
emoji_path, description, _emotion = emoji_raw
|
||||
emoji_cq = image_path_to_base64(emoji_path)
|
||||
|
||||
thinking_time_point = round(message.message_info.time, 2)
|
||||
|
||||
message_segment = Seg(type="emoji", data=emoji_cq)
|
||||
bot_message = MessageSending(
|
||||
message_id="mt" + str(thinking_time_point),
|
||||
chat_stream=self.chat_stream, # 使用 self.chat_stream
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=message.message_info.platform,
|
||||
),
|
||||
sender_info=message.message_info.user_info,
|
||||
message_segment=message_segment,
|
||||
reply=message,
|
||||
is_head=False,
|
||||
is_emoji=True,
|
||||
apply_set_reply_logic=True,
|
||||
)
|
||||
await message_manager.add_message(bot_message)
|
||||
|
||||
# 改为实例方法 (虽然它只用 message.chat_stream, 但逻辑上属于实例)
|
||||
# async def _update_relationship(self, message: MessageRecv, response_set):
|
||||
# """更新关系情绪"""
|
||||
# ori_response = ",".join(response_set)
|
||||
# stance, emotion = await self.gpt._get_emotion_tags(ori_response, message.processed_plain_text)
|
||||
# user_info = message.message_info.user_info
|
||||
# platform = user_info.platform
|
||||
# await relationship_manager.calculate_update_relationship_value(
|
||||
# user_info,
|
||||
# platform,
|
||||
# label=emotion,
|
||||
# stance=stance, # 使用 self.chat_stream
|
||||
# )
|
||||
# self.mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor)
|
||||
|
||||
async def _reply_interested_message(self) -> None:
|
||||
"""
|
||||
后台任务方法,轮询当前实例关联chat的兴趣消息
|
||||
@@ -298,9 +252,6 @@ class NormalChat:
|
||||
|
||||
logger.debug(f"[{self.stream_name}] 创建捕捉器,thinking_id:{thinking_id}")
|
||||
|
||||
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||
info_catcher.catch_decide_to_response(message)
|
||||
|
||||
# 如果启用planner,预先修改可用actions(避免在并行任务中重复调用)
|
||||
available_actions = None
|
||||
if self.enable_planner:
|
||||
@@ -336,13 +287,17 @@ class NormalChat:
|
||||
try:
|
||||
# 获取发送者名称(动作修改已在并行执行前完成)
|
||||
sender_name = self._get_sender_name(message)
|
||||
|
||||
|
||||
no_action = {
|
||||
"action_result": {"action_type": "no_action", "action_data": {}, "reasoning": "规划器初始化默认", "is_parallel": True},
|
||||
"action_result": {
|
||||
"action_type": "no_action",
|
||||
"action_data": {},
|
||||
"reasoning": "规划器初始化默认",
|
||||
"is_parallel": True,
|
||||
},
|
||||
"chat_context": "",
|
||||
"action_prompt": "",
|
||||
}
|
||||
|
||||
|
||||
# 检查是否应该跳过规划
|
||||
if self.action_modifier.should_skip_planning():
|
||||
@@ -357,7 +312,9 @@ class NormalChat:
|
||||
reasoning = plan_result["action_result"]["reasoning"]
|
||||
is_parallel = plan_result["action_result"].get("is_parallel", False)
|
||||
|
||||
logger.info(f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}, 并行执行: {is_parallel}")
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}, 并行执行: {is_parallel}"
|
||||
)
|
||||
self.action_type = action_type # 更新实例属性
|
||||
self.is_parallel_action = is_parallel # 新增:保存并行执行标志
|
||||
|
||||
@@ -376,7 +333,12 @@ class NormalChat:
|
||||
else:
|
||||
logger.warning(f"[{self.stream_name}] 额外动作 {action_type} 执行失败")
|
||||
|
||||
return {"action_type": action_type, "action_data": action_data, "reasoning": reasoning, "is_parallel": is_parallel}
|
||||
return {
|
||||
"action_type": action_type,
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"is_parallel": is_parallel,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] Planner执行失败: {e}")
|
||||
@@ -394,21 +356,25 @@ class NormalChat:
|
||||
if isinstance(response_set, Exception):
|
||||
logger.error(f"[{self.stream_name}] 回复生成异常: {response_set}")
|
||||
response_set = None
|
||||
elif response_set:
|
||||
info_catcher.catch_after_generate_response(timing_results["并行生成回复和规划"])
|
||||
|
||||
# 处理规划结果(可选,不影响回复)
|
||||
if isinstance(plan_result, Exception):
|
||||
logger.error(f"[{self.stream_name}] 动作规划异常: {plan_result}")
|
||||
elif plan_result:
|
||||
logger.debug(f"[{self.stream_name}] 额外动作处理完成: {self.action_type}")
|
||||
|
||||
|
||||
if not response_set or (
|
||||
self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action
|
||||
self.enable_planner
|
||||
and self.action_type not in ["no_action", "change_to_focus_chat"]
|
||||
and not self.is_parallel_action
|
||||
):
|
||||
if not response_set:
|
||||
logger.info(f"[{self.stream_name}] 模型未生成回复内容")
|
||||
elif self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action:
|
||||
elif (
|
||||
self.enable_planner
|
||||
and self.action_type not in ["no_action", "change_to_focus_chat"]
|
||||
and not self.is_parallel_action
|
||||
):
|
||||
logger.info(f"[{self.stream_name}] 模型选择其他动作(非并行动作)")
|
||||
# 如果模型未生成回复,移除思考消息
|
||||
container = await message_manager.get_container(self.stream_id) # 使用 self.stream_id
|
||||
@@ -435,8 +401,6 @@ class NormalChat:
|
||||
|
||||
# 检查 first_bot_msg 是否为 None (例如思考消息已被移除的情况)
|
||||
if first_bot_msg:
|
||||
info_catcher.catch_after_response(timing_results["消息发送"], response_set, first_bot_msg)
|
||||
|
||||
# 记录回复信息到最近回复列表中
|
||||
reply_info = {
|
||||
"time": time.time(),
|
||||
@@ -465,14 +429,9 @@ class NormalChat:
|
||||
logger.warning(f"[{self.stream_name}] 没有设置切换到focus聊天模式的回调函数,无法执行切换")
|
||||
return
|
||||
else:
|
||||
# await self._check_switch_to_focus()
|
||||
await self._check_switch_to_focus()
|
||||
pass
|
||||
|
||||
info_catcher.done_catch()
|
||||
|
||||
with Timer("处理表情包", timing_results):
|
||||
await self._handle_emoji(message, response_set[0])
|
||||
|
||||
# with Timer("关系更新", timing_results):
|
||||
# await self._update_relationship(message, response_set)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import List, Any, Dict
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
from src.chat.focus_chat.planners.actions.base_action import ActionActivationType, ChatMode
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.config.config import global_config
|
||||
import random
|
||||
@@ -35,7 +34,7 @@ class NormalChatActionModifier:
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""为Normal Chat修改可用动作集合
|
||||
|
||||
|
||||
实现动作激活策略:
|
||||
1. 基于关联类型的动态过滤
|
||||
2. 基于激活类型的智能判定(LLM_JUDGE转为概率激活)
|
||||
@@ -49,7 +48,7 @@ class NormalChatActionModifier:
|
||||
reasons = []
|
||||
merged_action_changes = {"add": [], "remove": []}
|
||||
type_mismatched_actions = [] # 在外层定义避免作用域问题
|
||||
|
||||
|
||||
self.action_manager.restore_default_actions()
|
||||
|
||||
# 第一阶段:基于关联类型的动态过滤
|
||||
@@ -57,7 +56,7 @@ class NormalChatActionModifier:
|
||||
chat_context = chat_stream.context if hasattr(chat_stream, "context") else None
|
||||
if chat_context:
|
||||
# 获取Normal模式下的可用动作(已经过滤了mode_enable)
|
||||
current_using_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
|
||||
current_using_actions = self.action_manager.get_using_actions_for_mode("normal")
|
||||
# print(f"current_using_actions: {current_using_actions}")
|
||||
for action_name in current_using_actions.keys():
|
||||
if action_name in self.all_actions:
|
||||
@@ -74,7 +73,7 @@ class NormalChatActionModifier:
|
||||
# 第二阶段:应用激活类型判定
|
||||
# 构建聊天内容 - 使用与planner一致的方式
|
||||
chat_content = ""
|
||||
if chat_stream and hasattr(chat_stream, 'stream_id'):
|
||||
if chat_stream and hasattr(chat_stream, "stream_id"):
|
||||
try:
|
||||
# 获取消息历史,使用与normal_chat_planner相同的方法
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
@@ -82,7 +81,7 @@ class NormalChatActionModifier:
|
||||
timestamp=time.time(),
|
||||
limit=global_config.focus_chat.observation_context_size, # 使用相同的配置
|
||||
)
|
||||
|
||||
|
||||
# 构建可读的聊天上下文
|
||||
chat_content = build_readable_messages(
|
||||
message_list_before_now,
|
||||
@@ -92,39 +91,41 @@ class NormalChatActionModifier:
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
|
||||
logger.debug(f"{self.log_prefix} 成功构建聊天内容,长度: {len(chat_content)}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{self.log_prefix} 构建聊天内容失败: {e}")
|
||||
chat_content = ""
|
||||
|
||||
|
||||
# 获取当前Normal模式下的动作集进行激活判定
|
||||
current_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
|
||||
current_actions = self.action_manager.get_using_actions_for_mode("normal")
|
||||
|
||||
# print(f"current_actions: {current_actions}")
|
||||
# print(f"chat_content: {chat_content}")
|
||||
final_activated_actions = await self._apply_normal_activation_filtering(
|
||||
current_actions,
|
||||
chat_content,
|
||||
message_content
|
||||
current_actions, chat_content, message_content
|
||||
)
|
||||
# print(f"final_activated_actions: {final_activated_actions}")
|
||||
|
||||
|
||||
# 统一处理所有需要移除的动作,避免重复移除
|
||||
all_actions_to_remove = set() # 使用set避免重复
|
||||
|
||||
|
||||
# 添加关联类型不匹配的动作
|
||||
if type_mismatched_actions:
|
||||
all_actions_to_remove.update(type_mismatched_actions)
|
||||
|
||||
|
||||
# 添加激活类型判定未通过的动作
|
||||
for action_name in current_actions.keys():
|
||||
if action_name not in final_activated_actions:
|
||||
all_actions_to_remove.add(action_name)
|
||||
|
||||
|
||||
# 统计移除原因(避免重复)
|
||||
activation_failed_actions = [name for name in current_actions.keys() if name not in final_activated_actions and name not in type_mismatched_actions]
|
||||
activation_failed_actions = [
|
||||
name
|
||||
for name in current_actions.keys()
|
||||
if name not in final_activated_actions and name not in type_mismatched_actions
|
||||
]
|
||||
if activation_failed_actions:
|
||||
reasons.append(f"移除{activation_failed_actions}(激活类型判定未通过)")
|
||||
|
||||
@@ -146,9 +147,9 @@ class NormalChatActionModifier:
|
||||
# 记录变更原因
|
||||
if reasons:
|
||||
logger.info(f"{self.log_prefix} 动作调整完成: {' | '.join(reasons)}")
|
||||
|
||||
|
||||
# 获取最终的Normal模式可用动作并记录
|
||||
final_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
|
||||
final_actions = self.action_manager.get_using_actions_for_mode("normal")
|
||||
logger.debug(f"{self.log_prefix} 当前Normal模式可用动作: {list(final_actions.keys())}")
|
||||
|
||||
async def _apply_normal_activation_filtering(
|
||||
@@ -159,73 +160,69 @@ class NormalChatActionModifier:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
应用Normal模式的激活类型过滤逻辑
|
||||
|
||||
|
||||
与Focus模式的区别:
|
||||
1. LLM_JUDGE类型转换为概率激活(避免LLM调用)
|
||||
2. RANDOM类型保持概率激活
|
||||
3. KEYWORD类型保持关键词匹配
|
||||
4. ALWAYS类型直接激活
|
||||
|
||||
|
||||
Args:
|
||||
actions_with_info: 带完整信息的动作字典
|
||||
chat_content: 聊天内容
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 过滤后激活的actions字典
|
||||
"""
|
||||
activated_actions = {}
|
||||
|
||||
|
||||
# 分类处理不同激活类型的actions
|
||||
always_actions = {}
|
||||
random_actions = {}
|
||||
keyword_actions = {}
|
||||
|
||||
|
||||
for action_name, action_info in actions_with_info.items():
|
||||
# 使用normal_activation_type
|
||||
activation_type = action_info.get("normal_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
if activation_type == ActionActivationType.ALWAYS:
|
||||
activation_type = action_info.get("normal_activation_type", "always")
|
||||
|
||||
# 现在统一是字符串格式的激活类型值
|
||||
if activation_type == "always":
|
||||
always_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.RANDOM or activation_type == ActionActivationType.LLM_JUDGE:
|
||||
elif activation_type == "random" or activation_type == "llm_judge":
|
||||
random_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.KEYWORD:
|
||||
elif activation_type == "keyword":
|
||||
keyword_actions[action_name] = action_info
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
|
||||
|
||||
|
||||
# 1. 处理ALWAYS类型(直接激活)
|
||||
for action_name, action_info in always_actions.items():
|
||||
activated_actions[action_name] = action_info
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: ALWAYS类型直接激活")
|
||||
|
||||
|
||||
# 2. 处理RANDOM类型(概率激活)
|
||||
for action_name, action_info in random_actions.items():
|
||||
probability = action_info.get("random_probability", 0.3)
|
||||
should_activate = random.random() < probability
|
||||
if should_activate:
|
||||
activated_actions[action_name] = action_info
|
||||
logger.info(f"{self.log_prefix}激活动作: {action_name},原因: RANDOM类型触发(概率{probability})")
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: RANDOM类型触发(概率{probability})")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: RANDOM类型未触发(概率{probability})")
|
||||
|
||||
|
||||
# 3. 处理KEYWORD类型(关键词匹配)
|
||||
for action_name, action_info in keyword_actions.items():
|
||||
should_activate = self._check_keyword_activation(
|
||||
action_name,
|
||||
action_info,
|
||||
chat_content,
|
||||
message_content
|
||||
)
|
||||
should_activate = self._check_keyword_activation(action_name, action_info, chat_content, message_content)
|
||||
if should_activate:
|
||||
activated_actions[action_name] = action_info
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
logger.info(f"{self.log_prefix}激活动作: {action_name},原因: KEYWORD类型匹配关键词({keywords})")
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: KEYWORD类型匹配关键词({keywords})")
|
||||
else:
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
logger.info(f"{self.log_prefix}未激活动作: {action_name},原因: KEYWORD类型未匹配关键词({keywords})")
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: KEYWORD类型未匹配关键词({keywords})")
|
||||
# print(f"keywords: {keywords}")
|
||||
# print(f"chat_content: {chat_content}")
|
||||
|
||||
|
||||
logger.debug(f"{self.log_prefix}Normal模式激活类型过滤完成: {list(activated_actions.keys())}")
|
||||
return activated_actions
|
||||
|
||||
@@ -238,51 +235,50 @@ class NormalChatActionModifier:
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否匹配关键词触发条件
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_info: 动作信息
|
||||
chat_content: 聊天内容(已经是格式化后的可读消息)
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
|
||||
activation_keywords = action_info.get("activation_keywords", [])
|
||||
case_sensitive = action_info.get("keyword_case_sensitive", False)
|
||||
|
||||
|
||||
if not activation_keywords:
|
||||
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
||||
return False
|
||||
|
||||
|
||||
# 使用构建好的聊天内容作为检索文本
|
||||
search_text = chat_content +message_content
|
||||
|
||||
search_text = chat_content + message_content
|
||||
|
||||
# 如果不区分大小写,转换为小写
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
|
||||
|
||||
# 检查每个关键词
|
||||
matched_keywords = []
|
||||
for keyword in activation_keywords:
|
||||
check_keyword = keyword if case_sensitive else keyword.lower()
|
||||
if check_keyword in search_text:
|
||||
matched_keywords.append(keyword)
|
||||
|
||||
|
||||
|
||||
# print(f"search_text: {search_text}")
|
||||
# print(f"activation_keywords: {activation_keywords}")
|
||||
|
||||
|
||||
if matched_keywords:
|
||||
logger.info(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}")
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}")
|
||||
return True
|
||||
else:
|
||||
logger.info(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
|
||||
return False
|
||||
|
||||
def get_available_actions_count(self) -> int:
|
||||
"""获取当前可用动作数量(排除默认的no_action)"""
|
||||
current_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
|
||||
current_actions = self.action_manager.get_using_actions_for_mode("normal")
|
||||
# 排除no_action(如果存在)
|
||||
filtered_actions = {k: v for k, v in current_actions.items() if k != "no_action"}
|
||||
return len(filtered_actions)
|
||||
|
||||
@@ -1,258 +1,262 @@
|
||||
"""
|
||||
Normal Chat Expressor
|
||||
|
||||
为Normal Chat专门设计的表达器,不需要经过LLM风格化处理,
|
||||
直接发送消息,主要用于插件动作中需要发送消息的场景。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending, MessageThinking, Seg
|
||||
from src.chat.message_receive.message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import ChatStream,chat_manager
|
||||
from src.chat.message_receive.message_sender import message_manager
|
||||
from src.config.config import global_config
|
||||
from src.common.logger_manager import get_logger
|
||||
|
||||
logger = get_logger("normal_chat_expressor")
|
||||
|
||||
|
||||
class NormalChatExpressor:
|
||||
"""Normal Chat专用表达器
|
||||
|
||||
特点:
|
||||
1. 不经过LLM风格化,直接发送消息
|
||||
2. 支持文本和表情包发送
|
||||
3. 为插件动作提供简化的消息发送接口
|
||||
4. 保持与focus_chat expressor相似的API,但去掉复杂的风格化流程
|
||||
"""
|
||||
|
||||
def __init__(self, chat_stream: ChatStream):
|
||||
"""初始化Normal Chat表达器
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
stream_name: 流名称
|
||||
"""
|
||||
self.chat_stream = chat_stream
|
||||
self.stream_name = chat_manager.get_stream_name(self.chat_stream.stream_id) or self.chat_stream.stream_id
|
||||
self.log_prefix = f"[{self.stream_name}]Normal表达器"
|
||||
|
||||
logger.debug(f"{self.log_prefix} 初始化完成")
|
||||
|
||||
async def create_thinking_message(
|
||||
self, anchor_message: Optional[MessageRecv], thinking_id: str
|
||||
) -> Optional[MessageThinking]:
|
||||
"""创建思考消息
|
||||
|
||||
Args:
|
||||
anchor_message: 锚点消息
|
||||
thinking_id: 思考ID
|
||||
|
||||
Returns:
|
||||
MessageThinking: 创建的思考消息,如果失败返回None
|
||||
"""
|
||||
if not anchor_message or not anchor_message.chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法创建思考消息,缺少有效的锚点消息或聊天流")
|
||||
return None
|
||||
|
||||
messageinfo = anchor_message.message_info
|
||||
thinking_time_point = time.time()
|
||||
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=messageinfo.platform,
|
||||
)
|
||||
|
||||
thinking_message = MessageThinking(
|
||||
message_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
reply=anchor_message,
|
||||
thinking_start_time=thinking_time_point,
|
||||
)
|
||||
|
||||
await message_manager.add_message(thinking_message)
|
||||
logger.debug(f"{self.log_prefix} 创建思考消息: {thinking_id}")
|
||||
return thinking_message
|
||||
|
||||
async def send_response_messages(
|
||||
self,
|
||||
anchor_message: Optional[MessageRecv],
|
||||
response_set: List[Tuple[str, str]],
|
||||
thinking_id: str = "",
|
||||
display_message: str = "",
|
||||
) -> Optional[MessageSending]:
|
||||
"""发送回复消息
|
||||
|
||||
Args:
|
||||
anchor_message: 锚点消息
|
||||
response_set: 回复内容集合,格式为 [(type, content), ...]
|
||||
thinking_id: 思考ID
|
||||
display_message: 显示消息
|
||||
|
||||
Returns:
|
||||
MessageSending: 发送的第一条消息,如果失败返回None
|
||||
"""
|
||||
try:
|
||||
if not response_set:
|
||||
logger.warning(f"{self.log_prefix} 回复内容为空")
|
||||
return None
|
||||
|
||||
# 如果没有thinking_id,生成一个
|
||||
if not thinking_id:
|
||||
thinking_time_point = round(time.time(), 2)
|
||||
thinking_id = "mt" + str(thinking_time_point)
|
||||
|
||||
# 创建思考消息
|
||||
if anchor_message:
|
||||
await self.create_thinking_message(anchor_message, thinking_id)
|
||||
|
||||
# 创建消息集
|
||||
|
||||
first_bot_msg = None
|
||||
mark_head = False
|
||||
is_emoji = False
|
||||
if len(response_set) == 0:
|
||||
return None
|
||||
message_id = f"{thinking_id}_{len(response_set)}"
|
||||
response_type, content = response_set[0]
|
||||
if len(response_set) > 1:
|
||||
message_segment = Seg(type="seglist", data=[Seg(type=t, data=c) for t, c in response_set])
|
||||
else:
|
||||
message_segment = Seg(type=response_type, data=content)
|
||||
if response_type == "emoji":
|
||||
is_emoji = True
|
||||
|
||||
bot_msg = await self._build_sending_message(
|
||||
message_id=message_id,
|
||||
message_segment=message_segment,
|
||||
thinking_id=thinking_id,
|
||||
anchor_message=anchor_message,
|
||||
thinking_start_time=time.time(),
|
||||
reply_to=mark_head,
|
||||
is_emoji=is_emoji,
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 添加{response_type}类型消息: {content}")
|
||||
|
||||
# 提交消息集
|
||||
if bot_msg:
|
||||
await message_manager.add_message(bot_msg)
|
||||
logger.info(f"{self.log_prefix} 成功发送 {response_type}类型消息: {content}")
|
||||
container = await message_manager.get_container(self.chat_stream.stream_id) # 使用 self.stream_id
|
||||
for msg in container.messages[:]:
|
||||
if isinstance(msg, MessageThinking) and msg.message_info.message_id == thinking_id:
|
||||
container.messages.remove(msg)
|
||||
logger.debug(f"[{self.stream_name}] 已移除未产生回复的思考消息 {thinking_id}")
|
||||
break
|
||||
return first_bot_msg
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 没有有效的消息被创建")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送消息失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def _build_sending_message(
|
||||
self,
|
||||
message_id: str,
|
||||
message_segment: Seg,
|
||||
thinking_id: str,
|
||||
anchor_message: Optional[MessageRecv],
|
||||
thinking_start_time: float,
|
||||
reply_to: bool = False,
|
||||
is_emoji: bool = False,
|
||||
) -> MessageSending:
|
||||
"""构建发送消息
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
message_segment: 消息段
|
||||
thinking_id: 思考ID
|
||||
anchor_message: 锚点消息
|
||||
thinking_start_time: 思考开始时间
|
||||
reply_to: 是否回复
|
||||
is_emoji: 是否为表情包
|
||||
|
||||
Returns:
|
||||
MessageSending: 构建的发送消息
|
||||
"""
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=anchor_message.message_info.platform if anchor_message else "unknown",
|
||||
)
|
||||
|
||||
message_sending = MessageSending(
|
||||
message_id=message_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
message_segment=message_segment,
|
||||
sender_info=self.chat_stream.user_info,
|
||||
reply=anchor_message if reply_to else None,
|
||||
thinking_start_time=thinking_start_time,
|
||||
is_emoji=is_emoji,
|
||||
)
|
||||
|
||||
return message_sending
|
||||
|
||||
async def deal_reply(
|
||||
self,
|
||||
cycle_timers: dict,
|
||||
action_data: Dict[str, Any],
|
||||
reasoning: str,
|
||||
anchor_message: MessageRecv,
|
||||
thinking_id: str,
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""处理回复动作 - 兼容focus_chat expressor API
|
||||
|
||||
Args:
|
||||
cycle_timers: 周期计时器(normal_chat中不使用)
|
||||
action_data: 动作数据,包含text、target、emojis等
|
||||
reasoning: 推理说明
|
||||
anchor_message: 锚点消息
|
||||
thinking_id: 思考ID
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否成功, 回复文本)
|
||||
"""
|
||||
try:
|
||||
response_set = []
|
||||
|
||||
# 处理文本内容
|
||||
text_content = action_data.get("text", "")
|
||||
if text_content:
|
||||
response_set.append(("text", text_content))
|
||||
|
||||
# 处理表情包
|
||||
emoji_content = action_data.get("emojis", "")
|
||||
if emoji_content:
|
||||
response_set.append(("emoji", emoji_content))
|
||||
|
||||
if not response_set:
|
||||
logger.warning(f"{self.log_prefix} deal_reply: 没有有效的回复内容")
|
||||
return False, None
|
||||
|
||||
# 发送消息
|
||||
result = await self.send_response_messages(
|
||||
anchor_message=anchor_message,
|
||||
response_set=response_set,
|
||||
thinking_id=thinking_id,
|
||||
)
|
||||
|
||||
if result:
|
||||
return True, text_content if text_content else "发送成功"
|
||||
else:
|
||||
return False, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} deal_reply执行失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False, None
|
||||
"""
|
||||
Normal Chat Expressor
|
||||
|
||||
为Normal Chat专门设计的表达器,不需要经过LLM风格化处理,
|
||||
直接发送消息,主要用于插件动作中需要发送消息的场景。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending, MessageThinking, Seg
|
||||
from src.chat.message_receive.message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.message_sender import message_manager
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("normal_chat_expressor")
|
||||
|
||||
|
||||
class NormalChatExpressor:
|
||||
"""Normal Chat专用表达器
|
||||
|
||||
特点:
|
||||
1. 不经过LLM风格化,直接发送消息
|
||||
2. 支持文本和表情包发送
|
||||
3. 为插件动作提供简化的消息发送接口
|
||||
4. 保持与focus_chat expressor相似的API,但去掉复杂的风格化流程
|
||||
"""
|
||||
|
||||
def __init__(self, chat_stream: ChatStream):
|
||||
"""初始化Normal Chat表达器
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
stream_name: 流名称
|
||||
"""
|
||||
self.chat_stream = chat_stream
|
||||
self.stream_name = get_chat_manager().get_stream_name(self.chat_stream.stream_id) or self.chat_stream.stream_id
|
||||
self.log_prefix = f"[{self.stream_name}]Normal表达器"
|
||||
|
||||
logger.debug(f"{self.log_prefix} 初始化完成")
|
||||
|
||||
async def create_thinking_message(
|
||||
self, anchor_message: Optional[MessageRecv], thinking_id: str
|
||||
) -> Optional[MessageThinking]:
|
||||
"""创建思考消息
|
||||
|
||||
Args:
|
||||
anchor_message: 锚点消息
|
||||
thinking_id: 思考ID
|
||||
|
||||
Returns:
|
||||
MessageThinking: 创建的思考消息,如果失败返回None
|
||||
"""
|
||||
if not anchor_message or not anchor_message.chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法创建思考消息,缺少有效的锚点消息或聊天流")
|
||||
return None
|
||||
|
||||
messageinfo = anchor_message.message_info
|
||||
thinking_time_point = time.time()
|
||||
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=messageinfo.platform,
|
||||
)
|
||||
|
||||
thinking_message = MessageThinking(
|
||||
message_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
reply=anchor_message,
|
||||
thinking_start_time=thinking_time_point,
|
||||
)
|
||||
|
||||
await message_manager.add_message(thinking_message)
|
||||
logger.debug(f"{self.log_prefix} 创建思考消息: {thinking_id}")
|
||||
return thinking_message
|
||||
|
||||
async def send_response_messages(
|
||||
self,
|
||||
anchor_message: Optional[MessageRecv],
|
||||
response_set: List[Tuple[str, str]],
|
||||
thinking_id: str = "",
|
||||
display_message: str = "",
|
||||
) -> Optional[MessageSending]:
|
||||
"""发送回复消息
|
||||
|
||||
Args:
|
||||
anchor_message: 锚点消息
|
||||
response_set: 回复内容集合,格式为 [(type, content), ...]
|
||||
thinking_id: 思考ID
|
||||
display_message: 显示消息
|
||||
|
||||
Returns:
|
||||
MessageSending: 发送的第一条消息,如果失败返回None
|
||||
"""
|
||||
try:
|
||||
if not response_set:
|
||||
logger.warning(f"{self.log_prefix} 回复内容为空")
|
||||
return None
|
||||
|
||||
# 如果没有thinking_id,生成一个
|
||||
if not thinking_id:
|
||||
thinking_time_point = round(time.time(), 2)
|
||||
thinking_id = "mt" + str(thinking_time_point)
|
||||
|
||||
# 创建思考消息
|
||||
if anchor_message:
|
||||
await self.create_thinking_message(anchor_message, thinking_id)
|
||||
|
||||
# 创建消息集
|
||||
|
||||
mark_head = False
|
||||
is_emoji = False
|
||||
if len(response_set) == 0:
|
||||
return None
|
||||
message_id = f"{thinking_id}_{len(response_set)}"
|
||||
response_type, content = response_set[0]
|
||||
if len(response_set) > 1:
|
||||
message_segment = Seg(type="seglist", data=[Seg(type=t, data=c) for t, c in response_set])
|
||||
else:
|
||||
message_segment = Seg(type=response_type, data=content)
|
||||
if response_type == "emoji":
|
||||
is_emoji = True
|
||||
|
||||
bot_msg = await self._build_sending_message(
|
||||
message_id=message_id,
|
||||
message_segment=message_segment,
|
||||
thinking_id=thinking_id,
|
||||
anchor_message=anchor_message,
|
||||
thinking_start_time=time.time(),
|
||||
reply_to=mark_head,
|
||||
is_emoji=is_emoji,
|
||||
display_message=display_message,
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 添加{response_type}类型消息: {content}")
|
||||
|
||||
# 提交消息集
|
||||
if bot_msg:
|
||||
await message_manager.add_message(bot_msg)
|
||||
logger.info(
|
||||
f"{self.log_prefix} 成功发送 {response_type}类型消息: {str(content)[:200] + '...' if len(str(content)) > 200 else content}"
|
||||
)
|
||||
container = await message_manager.get_container(self.chat_stream.stream_id) # 使用 self.stream_id
|
||||
for msg in container.messages[:]:
|
||||
if isinstance(msg, MessageThinking) and msg.message_info.message_id == thinking_id:
|
||||
container.messages.remove(msg)
|
||||
logger.debug(f"[{self.stream_name}] 已移除未产生回复的思考消息 {thinking_id}")
|
||||
break
|
||||
return bot_msg
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 没有有效的消息被创建")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送消息失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def _build_sending_message(
|
||||
self,
|
||||
message_id: str,
|
||||
message_segment: Seg,
|
||||
thinking_id: str,
|
||||
anchor_message: Optional[MessageRecv],
|
||||
thinking_start_time: float,
|
||||
reply_to: bool = False,
|
||||
is_emoji: bool = False,
|
||||
display_message: str = "",
|
||||
) -> MessageSending:
|
||||
"""构建发送消息
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
message_segment: 消息段
|
||||
thinking_id: 思考ID
|
||||
anchor_message: 锚点消息
|
||||
thinking_start_time: 思考开始时间
|
||||
reply_to: 是否回复
|
||||
is_emoji: 是否为表情包
|
||||
|
||||
Returns:
|
||||
MessageSending: 构建的发送消息
|
||||
"""
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=anchor_message.message_info.platform if anchor_message else "unknown",
|
||||
)
|
||||
|
||||
message_sending = MessageSending(
|
||||
message_id=message_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
message_segment=message_segment,
|
||||
sender_info=self.chat_stream.user_info,
|
||||
reply=anchor_message if reply_to else None,
|
||||
thinking_start_time=thinking_start_time,
|
||||
is_emoji=is_emoji,
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
return message_sending
|
||||
|
||||
async def deal_reply(
|
||||
self,
|
||||
cycle_timers: dict,
|
||||
action_data: Dict[str, Any],
|
||||
reasoning: str,
|
||||
anchor_message: MessageRecv,
|
||||
thinking_id: str,
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""处理回复动作 - 兼容focus_chat expressor API
|
||||
|
||||
Args:
|
||||
cycle_timers: 周期计时器(normal_chat中不使用)
|
||||
action_data: 动作数据,包含text、target、emojis等
|
||||
reasoning: 推理说明
|
||||
anchor_message: 锚点消息
|
||||
thinking_id: 思考ID
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否成功, 回复文本)
|
||||
"""
|
||||
try:
|
||||
response_set = []
|
||||
|
||||
# 处理文本内容
|
||||
text_content = action_data.get("text", "")
|
||||
if text_content:
|
||||
response_set.append(("text", text_content))
|
||||
|
||||
# 处理表情包
|
||||
emoji_content = action_data.get("emojis", "")
|
||||
if emoji_content:
|
||||
response_set.append(("emoji", emoji_content))
|
||||
|
||||
if not response_set:
|
||||
logger.warning(f"{self.log_prefix} deal_reply: 没有有效的回复内容")
|
||||
return False, None
|
||||
|
||||
# 发送消息
|
||||
result = await self.send_response_messages(
|
||||
anchor_message=anchor_message,
|
||||
response_set=response_set,
|
||||
thinking_id=thinking_id,
|
||||
)
|
||||
|
||||
if result:
|
||||
return True, text_content if text_content else "发送成功"
|
||||
else:
|
||||
return False, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} deal_reply执行失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False, None
|
||||
|
||||
@@ -5,9 +5,8 @@ from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageThinking
|
||||
from src.chat.normal_chat.normal_prompt import prompt_builder
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.utils.info_catcher import info_catcher_manager
|
||||
from src.person_info.person_info import person_info_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
|
||||
|
||||
@@ -26,9 +25,7 @@ class NormalChatGenerator:
|
||||
request_type="normal.chat_2",
|
||||
)
|
||||
|
||||
self.model_sum = LLMRequest(
|
||||
model=global_config.model.memory_summary, temperature=0.7, request_type="relation"
|
||||
)
|
||||
self.model_sum = LLMRequest(model=global_config.model.memory_summary, temperature=0.7, request_type="relation")
|
||||
self.current_model_type = "r1" # 默认使用 R1
|
||||
self.current_model_name = "unknown model"
|
||||
|
||||
@@ -69,12 +66,10 @@ class NormalChatGenerator:
|
||||
enable_planner: bool = False,
|
||||
available_actions=None,
|
||||
):
|
||||
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||
|
||||
person_id = person_info_manager.get_person_id(
|
||||
person_id = PersonInfoManager.get_person_id(
|
||||
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||
)
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
|
||||
@@ -105,10 +100,6 @@ class NormalChatGenerator:
|
||||
|
||||
logger.info(f"对 {message.processed_plain_text} 的回复:{content}")
|
||||
|
||||
info_catcher.catch_after_llm_generated(
|
||||
prompt=prompt, response=content, reasoning_content=reasoning_content, model_name=self.current_model_name
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("生成回复时出错")
|
||||
return None
|
||||
|
||||
@@ -3,11 +3,10 @@ from typing import Dict, Any
|
||||
from rich.traceback import install
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.individuality.individuality import individuality
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
from src.chat.focus_chat.planners.actions.base_action import ChatMode
|
||||
from src.chat.message_receive.message import MessageThinking
|
||||
from json_repair import repair_json
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
@@ -26,6 +25,11 @@ def init_prompt():
|
||||
{self_info_block}
|
||||
请记住你的性格,身份和特点。
|
||||
|
||||
你是群内的一员,你现在正在参与群内的闲聊,以下是群内的聊天内容:
|
||||
{chat_context}
|
||||
|
||||
基于以上聊天上下文和用户的最新消息,选择最合适的action。
|
||||
|
||||
注意,除了下面动作选项之外,你在聊天中不能做其他任何事情,这是你能力的边界,现在请你选择合适的action:
|
||||
|
||||
{action_options_text}
|
||||
@@ -38,11 +42,6 @@ def init_prompt():
|
||||
你必须从上面列出的可用action中选择一个,并说明原因。
|
||||
{moderation_prompt}
|
||||
|
||||
你是群内的一员,你现在正在参与群内的闲聊,以下是群内的聊天内容:
|
||||
{chat_context}
|
||||
|
||||
基于以上聊天上下文和用户的最新消息,选择最合适的action。
|
||||
|
||||
请以动作的输出要求,以严格的 JSON 格式输出,且仅包含 JSON 内容。不要有任何其他文字或解释:
|
||||
""",
|
||||
"normal_chat_planner_prompt",
|
||||
@@ -94,14 +93,14 @@ class NormalChatPlanner:
|
||||
nickname_str += f"{nicknames},"
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
personality_block = individuality.get_personality_prompt(x_person=2, level=2)
|
||||
identity_block = individuality.get_identity_prompt(x_person=2, level=2)
|
||||
personality_block = get_individuality().get_personality_prompt(x_person=2, level=2)
|
||||
identity_block = get_individuality().get_identity_prompt(x_person=2, level=2)
|
||||
|
||||
self_info = name_block + personality_block + identity_block
|
||||
|
||||
# 获取当前可用的动作,使用Normal模式过滤
|
||||
current_available_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
|
||||
|
||||
current_available_actions = self.action_manager.get_using_actions_for_mode("normal")
|
||||
|
||||
# 注意:动作的激活判定现在在 normal_chat_action_modifier 中完成
|
||||
# 这里直接使用经过 action_modifier 处理后的最终动作集
|
||||
# 符合职责分离原则:ActionModifier负责动作管理,Planner专注于决策
|
||||
@@ -110,7 +109,12 @@ class NormalChatPlanner:
|
||||
if not current_available_actions:
|
||||
logger.debug(f"{self.log_prefix}规划器: 没有可用动作,返回no_action")
|
||||
return {
|
||||
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": True},
|
||||
"action_result": {
|
||||
"action_type": action,
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"is_parallel": True,
|
||||
},
|
||||
"chat_context": "",
|
||||
"action_prompt": "",
|
||||
}
|
||||
@@ -121,7 +125,7 @@ class NormalChatPlanner:
|
||||
timestamp=time.time(),
|
||||
limit=global_config.focus_chat.observation_context_size,
|
||||
)
|
||||
|
||||
|
||||
chat_context = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
@@ -130,7 +134,7 @@ class NormalChatPlanner:
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
|
||||
# 构建planner的prompt
|
||||
prompt = await self.build_planner_prompt(
|
||||
self_info_block=self_info,
|
||||
@@ -141,7 +145,12 @@ class NormalChatPlanner:
|
||||
if not prompt:
|
||||
logger.warning(f"{self.log_prefix}规划器: 构建提示词失败")
|
||||
return {
|
||||
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": False},
|
||||
"action_result": {
|
||||
"action_type": action,
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"is_parallel": False,
|
||||
},
|
||||
"chat_context": chat_context,
|
||||
"action_prompt": "",
|
||||
}
|
||||
@@ -149,8 +158,8 @@ class NormalChatPlanner:
|
||||
# 使用LLM生成动作决策
|
||||
try:
|
||||
content, (reasoning_content, model_name) = await self.planner_llm.generate_response_async(prompt)
|
||||
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
|
||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {content}")
|
||||
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
logger.info(f"{self.log_prefix}规划器模型: {model_name}")
|
||||
@@ -201,8 +210,10 @@ class NormalChatPlanner:
|
||||
if action in current_available_actions:
|
||||
action_info = current_available_actions[action]
|
||||
is_parallel = action_info.get("parallel_action", False)
|
||||
|
||||
logger.debug(f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}, 并行执行: {is_parallel}")
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}, 并行执行: {is_parallel}"
|
||||
)
|
||||
|
||||
# 恢复到默认动作集
|
||||
self.action_manager.restore_actions()
|
||||
@@ -216,15 +227,15 @@ class NormalChatPlanner:
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"timestamp": time.time(),
|
||||
"model_name": model_name if 'model_name' in locals() else None
|
||||
"model_name": model_name if "model_name" in locals() else None,
|
||||
}
|
||||
|
||||
action_result = {
|
||||
"action_type": action,
|
||||
"action_data": action_data,
|
||||
"action_type": action,
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"is_parallel": is_parallel,
|
||||
"action_record": json.dumps(action_record, ensure_ascii=False)
|
||||
"action_record": json.dumps(action_record, ensure_ascii=False),
|
||||
}
|
||||
|
||||
plan_result = {
|
||||
@@ -248,24 +259,19 @@ class NormalChatPlanner:
|
||||
|
||||
# 添加特殊的change_to_focus_chat动作
|
||||
action_options_text += "动作:change_to_focus_chat\n"
|
||||
action_options_text += (
|
||||
"该动作的描述:当聊天变得热烈、自己回复条数很多或需要深入交流时使用,正常回复消息并切换到focus_chat模式\n"
|
||||
)
|
||||
action_options_text += "该动作的描述:当聊天变得热烈、自己回复条数很多或需要深入交流时使用,正常回复消息并切换到focus_chat模式\n"
|
||||
|
||||
action_options_text += "使用该动作的场景:\n"
|
||||
action_options_text += "- 聊天上下文中自己的回复条数较多(超过3-4条)\n"
|
||||
action_options_text += "- 对话进行得非常热烈活跃\n"
|
||||
action_options_text += "- 用户表现出深入交流的意图\n"
|
||||
action_options_text += "- 话题需要更专注和深入的讨论\n\n"
|
||||
|
||||
|
||||
action_options_text += "输出要求:\n"
|
||||
action_options_text += "{{"
|
||||
action_options_text += " \"action\": \"change_to_focus_chat\""
|
||||
action_options_text += ' "action": "change_to_focus_chat"'
|
||||
action_options_text += "}}\n\n"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
action_description = action_info.get("description", "")
|
||||
action_parameters = action_info.get("parameters", {})
|
||||
@@ -276,15 +282,14 @@ class NormalChatPlanner:
|
||||
print(action_parameters)
|
||||
for param_name, param_description in action_parameters.items():
|
||||
param_text += f' "{param_name}":"{param_description}"\n'
|
||||
param_text = param_text.rstrip('\n')
|
||||
param_text = param_text.rstrip("\n")
|
||||
else:
|
||||
param_text = ""
|
||||
|
||||
|
||||
require_text = ""
|
||||
for require_item in action_require:
|
||||
require_text += f"- {require_item}\n"
|
||||
require_text = require_text.rstrip('\n')
|
||||
require_text = require_text.rstrip("\n")
|
||||
|
||||
# 构建单个动作的提示
|
||||
action_prompt = await global_prompt_manager.format_prompt(
|
||||
@@ -316,6 +321,4 @@ class NormalChatPlanner:
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import get_expression_learner
|
||||
from src.config.config import global_config
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.individuality.individuality import individuality
|
||||
from src.common.logger import get_logger
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.person_info.relationship_manager import relationship_manager
|
||||
import time
|
||||
from src.chat.utils.utils import get_recent_group_speaker
|
||||
from src.manager.mood_manager import mood_manager
|
||||
from src.chat.memory_system.Hippocampus import HippocampusManager
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
|
||||
import random
|
||||
import re
|
||||
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
|
||||
logger = get_logger("prompt")
|
||||
|
||||
@@ -96,7 +96,7 @@ class PromptBuilder:
|
||||
enable_planner: bool = False,
|
||||
available_actions=None,
|
||||
) -> str:
|
||||
prompt_personality = individuality.get_prompt(x_person=2, level=2)
|
||||
prompt_personality = get_individuality().get_prompt(x_person=2, level=2)
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
who_chat_in_group = []
|
||||
@@ -112,11 +112,13 @@ class PromptBuilder:
|
||||
)
|
||||
|
||||
relation_prompt = ""
|
||||
for person in who_chat_in_group:
|
||||
relation_prompt += await relationship_manager.build_relationship_info(person)
|
||||
if global_config.relationship.enable_relationship:
|
||||
for person in who_chat_in_group:
|
||||
relationship_manager = get_relationship_manager()
|
||||
relation_prompt += await relationship_manager.build_relationship_info(person)
|
||||
|
||||
mood_prompt = mood_manager.get_mood_prompt()
|
||||
|
||||
expression_learner = get_expression_learner()
|
||||
(
|
||||
learnt_style_expressions,
|
||||
learnt_grammar_expressions,
|
||||
@@ -159,18 +161,19 @@ class PromptBuilder:
|
||||
)[0]
|
||||
memory_prompt = ""
|
||||
|
||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||
)
|
||||
|
||||
related_memory_info = ""
|
||||
if related_memory:
|
||||
for memory in related_memory:
|
||||
related_memory_info += memory[1]
|
||||
memory_prompt = await global_prompt_manager.format_prompt(
|
||||
"memory_prompt", related_memory_info=related_memory_info
|
||||
if global_config.memory.enable_memory:
|
||||
related_memory = await hippocampus_manager.get_memory_from_text(
|
||||
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||
)
|
||||
|
||||
related_memory_info = ""
|
||||
if related_memory:
|
||||
for memory in related_memory:
|
||||
related_memory_info += memory[1]
|
||||
memory_prompt = await global_prompt_manager.format_prompt(
|
||||
"memory_prompt", related_memory_info=related_memory_info
|
||||
)
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
@@ -212,7 +215,6 @@ class PromptBuilder:
|
||||
except Exception as e:
|
||||
logger.error(f"关键词检测与反应时发生异常: {str(e)}", exc_info=True)
|
||||
|
||||
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
|
||||
# 构建action描述 (如果启用planner)
|
||||
|
||||
@@ -42,9 +42,7 @@ class ClassicalWillingManager(BaseWillingManager):
|
||||
|
||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||
|
||||
reply_probability = min(
|
||||
max((current_willing - 0.5), 0.01) * 2, 1
|
||||
)
|
||||
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1)
|
||||
|
||||
# 检查群组权限(如果是群聊)
|
||||
if (
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
|
||||
from src.common.logger import get_logger
|
||||
from dataclasses import dataclass
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.person_info.person_info import person_info_manager, PersonInfoManager
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from abc import ABC, abstractmethod
|
||||
import importlib
|
||||
from typing import Dict, Optional
|
||||
@@ -33,12 +33,8 @@ set_willing 设置某聊天流意愿
|
||||
示例: 在 `mode_aggressive.py` 中,类名应为 `AggressiveWillingManager`
|
||||
"""
|
||||
|
||||
willing_config = LogConfig(
|
||||
# 使用消息发送专用样式
|
||||
console_format=WILLING_STYLE_CONFIG["console_format"],
|
||||
file_format=WILLING_STYLE_CONFIG["file_format"],
|
||||
)
|
||||
logger = get_module_logger("willing", config=willing_config)
|
||||
|
||||
logger = get_logger("willing")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -93,14 +89,14 @@ class BaseWillingManager(ABC):
|
||||
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id)
|
||||
self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id)
|
||||
self.lock = asyncio.Lock()
|
||||
self.logger: LoguruLogger = logger
|
||||
self.logger = logger
|
||||
|
||||
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
|
||||
person_id = person_info_manager.get_person_id(chat.platform, chat.user_info.user_id)
|
||||
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)
|
||||
self.ongoing_messages[message.message_info.message_id] = WillingInfo(
|
||||
message=message,
|
||||
chat=chat,
|
||||
person_info_manager=person_info_manager,
|
||||
person_info_manager=get_person_info_manager(),
|
||||
chat_id=chat.stream_id,
|
||||
person_id=person_id,
|
||||
group_info=chat.group_info,
|
||||
@@ -177,4 +173,11 @@ def init_willing_manager() -> BaseWillingManager:
|
||||
|
||||
|
||||
# 全局willing_manager对象
|
||||
willing_manager = init_willing_manager()
|
||||
willing_manager = None
|
||||
|
||||
|
||||
def get_willing_manager():
|
||||
global willing_manager
|
||||
if willing_manager is None:
|
||||
willing_manager = init_willing_manager()
|
||||
return willing_manager
|
||||
|
||||
@@ -4,7 +4,7 @@ import time # 导入 time 模块以获取当前时间
|
||||
import random
|
||||
import re
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.person_info.person_info import person_info_manager
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||
from rich.traceback import install
|
||||
from src.common.database.database_model import ActionRecords
|
||||
@@ -219,7 +219,8 @@ def _build_readable_messages_internal(
|
||||
if not all([platform, user_id, timestamp is not None]):
|
||||
continue
|
||||
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_info_manager = get_person_info_manager()
|
||||
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
person_name = f"{global_config.bot.nickname}(你)"
|
||||
@@ -241,7 +242,7 @@ def _build_readable_messages_internal(
|
||||
if match:
|
||||
aaa = match.group(1)
|
||||
bbb = match.group(2)
|
||||
reply_person_id = person_info_manager.get_person_id(platform, bbb)
|
||||
reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
||||
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name")
|
||||
if not reply_person_name:
|
||||
reply_person_name = aaa
|
||||
@@ -258,7 +259,7 @@ def _build_readable_messages_internal(
|
||||
new_content += content[last_end : m.start()]
|
||||
aaa = m.group(1)
|
||||
bbb = m.group(2)
|
||||
at_person_id = person_info_manager.get_person_id(platform, bbb)
|
||||
at_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
||||
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name")
|
||||
if not at_person_name:
|
||||
at_person_name = aaa
|
||||
@@ -286,7 +287,7 @@ def _build_readable_messages_internal(
|
||||
message_details_with_flags.append((timestamp, name, content, is_action))
|
||||
# print(f"content:{content}")
|
||||
# print(f"is_action:{is_action}")
|
||||
|
||||
|
||||
# print(f"message_details_with_flags:{message_details_with_flags}")
|
||||
|
||||
# 应用截断逻辑 (如果 truncate 为 True)
|
||||
@@ -324,7 +325,7 @@ def _build_readable_messages_internal(
|
||||
else:
|
||||
# 如果不截断,直接使用原始列表
|
||||
message_details = message_details_with_flags
|
||||
|
||||
|
||||
# print(f"message_details:{message_details}")
|
||||
|
||||
# 3: 合并连续消息 (如果 merge_messages 为 True)
|
||||
@@ -336,12 +337,12 @@ def _build_readable_messages_internal(
|
||||
"start_time": message_details[0][0],
|
||||
"end_time": message_details[0][0],
|
||||
"content": [message_details[0][2]],
|
||||
"is_action": message_details[0][3]
|
||||
"is_action": message_details[0][3],
|
||||
}
|
||||
|
||||
for i in range(1, len(message_details)):
|
||||
timestamp, name, content, is_action = message_details[i]
|
||||
|
||||
|
||||
# 对于动作记录,不进行合并
|
||||
if is_action or current_merge["is_action"]:
|
||||
# 保存当前的合并块
|
||||
@@ -352,7 +353,7 @@ def _build_readable_messages_internal(
|
||||
"start_time": timestamp,
|
||||
"end_time": timestamp,
|
||||
"content": [content],
|
||||
"is_action": is_action
|
||||
"is_action": is_action,
|
||||
}
|
||||
continue
|
||||
|
||||
@@ -365,11 +366,11 @@ def _build_readable_messages_internal(
|
||||
merged_messages.append(current_merge)
|
||||
# 开始新的合并块
|
||||
current_merge = {
|
||||
"name": name,
|
||||
"start_time": timestamp,
|
||||
"end_time": timestamp,
|
||||
"name": name,
|
||||
"start_time": timestamp,
|
||||
"end_time": timestamp,
|
||||
"content": [content],
|
||||
"is_action": is_action
|
||||
"is_action": is_action,
|
||||
}
|
||||
# 添加最后一个合并块
|
||||
merged_messages.append(current_merge)
|
||||
@@ -381,10 +382,9 @@ def _build_readable_messages_internal(
|
||||
"start_time": timestamp, # 起始和结束时间相同
|
||||
"end_time": timestamp,
|
||||
"content": [content], # 内容只有一个元素
|
||||
"is_action": is_action
|
||||
"is_action": is_action,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 4 & 5: 格式化为字符串
|
||||
output_lines = []
|
||||
@@ -451,7 +451,7 @@ def build_readable_messages(
|
||||
将消息列表转换为可读的文本格式。
|
||||
如果提供了 read_mark,则在相应位置插入已读标记。
|
||||
允许通过参数控制格式化行为。
|
||||
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
replace_bot_name: 是否替换机器人名称为"你"
|
||||
@@ -463,22 +463,24 @@ def build_readable_messages(
|
||||
"""
|
||||
# 创建messages的深拷贝,避免修改原始列表
|
||||
copy_messages = [msg.copy() for msg in messages]
|
||||
|
||||
|
||||
if show_actions and copy_messages:
|
||||
# 获取所有消息的时间范围
|
||||
min_time = min(msg.get("time", 0) for msg in copy_messages)
|
||||
max_time = max(msg.get("time", 0) for msg in copy_messages)
|
||||
|
||||
|
||||
# 从第一条消息中获取chat_id
|
||||
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
|
||||
|
||||
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions = ActionRecords.select().where(
|
||||
(ActionRecords.time >= min_time) &
|
||||
(ActionRecords.time <= max_time) &
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
).order_by(ActionRecords.time)
|
||||
|
||||
actions = (
|
||||
ActionRecords.select()
|
||||
.where(
|
||||
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
)
|
||||
|
||||
# 将动作记录转换为消息格式
|
||||
for action in actions:
|
||||
# 只有当build_into_prompt为True时才添加动作记录
|
||||
@@ -495,25 +497,22 @@ def build_readable_messages(
|
||||
"action_name": action.action_name, # 保存动作名称
|
||||
}
|
||||
copy_messages.append(action_msg)
|
||||
|
||||
|
||||
# 重新按时间排序
|
||||
copy_messages.sort(key=lambda x: x.get("time", 0))
|
||||
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
|
||||
|
||||
# for message in messages:
|
||||
# print(f"message:{message}")
|
||||
|
||||
|
||||
# print(f"message:{message}")
|
||||
|
||||
formatted_string, _ = _build_readable_messages_internal(
|
||||
copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
|
||||
|
||||
# print(f"formatted_string:{formatted_string}")
|
||||
|
||||
|
||||
|
||||
|
||||
return formatted_string
|
||||
else:
|
||||
# 按 read_mark 分割消息
|
||||
@@ -521,10 +520,10 @@ def build_readable_messages(
|
||||
messages_after_mark = [msg for msg in copy_messages if msg.get("time", 0) > read_mark]
|
||||
|
||||
# for message in messages_before_mark:
|
||||
# print(f"message:{message}")
|
||||
|
||||
# print(f"message:{message}")
|
||||
|
||||
# for message in messages_after_mark:
|
||||
# print(f"message:{message}")
|
||||
# print(f"message:{message}")
|
||||
|
||||
# 分别格式化
|
||||
formatted_before, _ = _build_readable_messages_internal(
|
||||
@@ -536,7 +535,7 @@ def build_readable_messages(
|
||||
merge_messages,
|
||||
timestamp_mode,
|
||||
)
|
||||
|
||||
|
||||
# print(f"formatted_before:{formatted_before}")
|
||||
# print(f"formatted_after:{formatted_after}")
|
||||
|
||||
@@ -574,7 +573,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
# print("SELF11111111111111")
|
||||
return "SELF"
|
||||
try:
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
except Exception as _e:
|
||||
person_id = None
|
||||
if not person_id:
|
||||
@@ -587,14 +586,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
|
||||
for msg in messages:
|
||||
try:
|
||||
# user_info = msg.get("user_info", {})
|
||||
platform = msg.get("chat_info_platform")
|
||||
user_id = msg.get("user_id")
|
||||
_timestamp = msg.get("time")
|
||||
# print(f"msg:{msg}")
|
||||
# print(f"platform:{platform}")
|
||||
# print(f"user_id:{user_id}")
|
||||
# print(f"timestamp:{timestamp}")
|
||||
if msg.get("display_message"):
|
||||
content = msg.get("display_message")
|
||||
else:
|
||||
@@ -680,7 +674,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
||||
continue
|
||||
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
|
||||
# 只有当获取到有效 person_id 时才添加
|
||||
if person_id:
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending, Message
|
||||
from src.common.database.database_model import Messages, ThinkingLog
|
||||
import time
|
||||
import traceback
|
||||
from typing import List
|
||||
import json
|
||||
|
||||
|
||||
class InfoCatcher:
|
||||
def __init__(self):
|
||||
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~
|
||||
self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~
|
||||
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~
|
||||
|
||||
self.chat_id = ""
|
||||
self.trigger_response_text = ""
|
||||
self.response_text = ""
|
||||
|
||||
self.trigger_response_time = 0
|
||||
self.trigger_response_message = None
|
||||
|
||||
self.response_time = 0
|
||||
self.response_messages = []
|
||||
|
||||
# 使用字典来存储 heartflow 模式的数据
|
||||
self.heartflow_data = {
|
||||
"heart_flow_prompt": "",
|
||||
"sub_heartflow_before": "",
|
||||
"sub_heartflow_now": "",
|
||||
"sub_heartflow_after": "",
|
||||
"sub_heartflow_model": "",
|
||||
"prompt": "",
|
||||
"response": "",
|
||||
"model": "",
|
||||
}
|
||||
|
||||
# 使用字典来存储 reasoning 模式的数据喵~
|
||||
self.reasoning_data = {"thinking_log": "", "prompt": "", "response": "", "model": ""}
|
||||
|
||||
# 耗时喵~
|
||||
self.timing_results = {
|
||||
"interested_rate_time": 0,
|
||||
"sub_heartflow_observe_time": 0,
|
||||
"sub_heartflow_step_time": 0,
|
||||
"make_response_time": 0,
|
||||
}
|
||||
|
||||
def catch_decide_to_response(self, message: MessageRecv):
|
||||
# 搜集决定回复时的信息
|
||||
self.trigger_response_message = message
|
||||
self.trigger_response_text = message.detailed_plain_text
|
||||
|
||||
self.trigger_response_time = time.time()
|
||||
|
||||
self.chat_id = message.chat_stream.stream_id
|
||||
|
||||
self.chat_history = self.get_message_from_db_before_msg(message)
|
||||
|
||||
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
|
||||
self.timing_results["sub_heartflow_observe_time"] = obs_duration
|
||||
|
||||
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
|
||||
self.timing_results["sub_heartflow_step_time"] = step_duration
|
||||
if len(past_mind) > 1:
|
||||
self.heartflow_data["sub_heartflow_before"] = past_mind[-1]
|
||||
self.heartflow_data["sub_heartflow_now"] = current_mind
|
||||
else:
|
||||
self.heartflow_data["sub_heartflow_before"] = past_mind[-1]
|
||||
self.heartflow_data["sub_heartflow_now"] = current_mind
|
||||
|
||||
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
|
||||
self.reasoning_data["thinking_log"] = reasoning_content
|
||||
self.reasoning_data["prompt"] = prompt
|
||||
self.reasoning_data["response"] = response
|
||||
self.reasoning_data["model"] = model_name
|
||||
|
||||
self.response_text = response
|
||||
|
||||
def catch_after_generate_response(self, response_duration: float):
|
||||
self.timing_results["make_response_time"] = response_duration
|
||||
|
||||
def catch_after_response(
|
||||
self, response_duration: float, response_message: List[str], first_bot_msg: MessageSending
|
||||
):
|
||||
self.timing_results["make_response_time"] = response_duration
|
||||
self.response_time = time.time()
|
||||
self.response_messages = []
|
||||
for msg in response_message:
|
||||
self.response_messages.append(msg)
|
||||
|
||||
self.chat_history_in_thinking = self.get_message_from_db_between_msgs(
|
||||
self.trigger_response_message, first_bot_msg
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
|
||||
try:
|
||||
time_start = message_start.message_info.time
|
||||
time_end = message_end.message_info.time
|
||||
chat_id = message_start.chat_stream.stream_id
|
||||
|
||||
# print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
|
||||
|
||||
messages_between_query = (
|
||||
Messages.select()
|
||||
.where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
|
||||
.order_by(Messages.time.desc())
|
||||
)
|
||||
|
||||
result = list(messages_between_query)
|
||||
# print(f"查询结果数量: {len(result)}")
|
||||
# if result:
|
||||
# print(f"第一条消息时间: {result[0].time}")
|
||||
# print(f"最后一条消息时间: {result[-1].time}")
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"获取消息时出错: {str(e)}")
|
||||
print(traceback.format_exc())
|
||||
return []
|
||||
|
||||
def get_message_from_db_before_msg(self, message: MessageRecv):
|
||||
message_id_val = message.message_info.message_id
|
||||
chat_id_val = message.chat_stream.stream_id
|
||||
|
||||
messages_before_query = (
|
||||
Messages.select()
|
||||
.where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
|
||||
.order_by(Messages.time.desc())
|
||||
.limit(global_config.focus_chat.observation_context_size * 3)
|
||||
)
|
||||
|
||||
return list(messages_before_query)
|
||||
|
||||
def message_list_to_dict(self, message_list):
|
||||
result = []
|
||||
for msg_item in message_list:
|
||||
processed_msg_item = msg_item
|
||||
if not isinstance(msg_item, dict):
|
||||
processed_msg_item = self.message_to_dict(msg_item)
|
||||
|
||||
if not processed_msg_item:
|
||||
continue
|
||||
|
||||
lite_message = {
|
||||
"time": processed_msg_item.get("time"),
|
||||
"user_nickname": processed_msg_item.get("user_nickname"),
|
||||
"processed_plain_text": processed_msg_item.get("processed_plain_text"),
|
||||
}
|
||||
result.append(lite_message)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def message_to_dict(msg_obj):
|
||||
if not msg_obj:
|
||||
return None
|
||||
if isinstance(msg_obj, dict):
|
||||
return msg_obj
|
||||
|
||||
if isinstance(msg_obj, Messages):
|
||||
return {
|
||||
"time": msg_obj.time,
|
||||
"user_id": msg_obj.user_id,
|
||||
"user_nickname": msg_obj.user_nickname,
|
||||
"processed_plain_text": msg_obj.processed_plain_text,
|
||||
}
|
||||
|
||||
if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"):
|
||||
return {
|
||||
"time": msg_obj.message_info.time,
|
||||
"user_id": msg_obj.message_info.user_info.user_id,
|
||||
"user_nickname": msg_obj.message_info.user_info.user_nickname,
|
||||
"processed_plain_text": msg_obj.processed_plain_text,
|
||||
}
|
||||
|
||||
print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
|
||||
return {}
|
||||
|
||||
def done_catch(self):
|
||||
"""将收集到的信息存储到数据库的 thinking_log 表中喵~"""
|
||||
try:
|
||||
trigger_info_dict = self.message_to_dict(self.trigger_response_message)
|
||||
response_info_dict = {
|
||||
"time": self.response_time,
|
||||
"message": self.response_messages,
|
||||
}
|
||||
chat_history_list = self.message_list_to_dict(self.chat_history)
|
||||
chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
|
||||
chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
|
||||
|
||||
log_entry = ThinkingLog(
|
||||
chat_id=self.chat_id,
|
||||
trigger_text=self.trigger_response_text,
|
||||
response_text=self.response_text,
|
||||
trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
|
||||
response_info_json=json.dumps(response_info_dict),
|
||||
timing_results_json=json.dumps(self.timing_results),
|
||||
chat_history_json=json.dumps(chat_history_list),
|
||||
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
|
||||
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
|
||||
heartflow_data_json=json.dumps(self.heartflow_data),
|
||||
reasoning_data_json=json.dumps(self.reasoning_data),
|
||||
)
|
||||
log_entry.save()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"存储思考日志时出错: {str(e)} 喵~")
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
class InfoCatcherManager:
|
||||
def __init__(self):
|
||||
self.info_catchers = {}
|
||||
|
||||
def get_info_catcher(self, thinking_id: str) -> InfoCatcher:
|
||||
if thinking_id not in self.info_catchers:
|
||||
self.info_catchers[thinking_id] = InfoCatcher()
|
||||
return self.info_catchers[thinking_id]
|
||||
|
||||
|
||||
info_catcher_manager = InfoCatcherManager()
|
||||
@@ -1,88 +0,0 @@
|
||||
import sys
|
||||
import loguru
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogClassification(Enum):
|
||||
BASE = "base"
|
||||
MEMORY = "memory"
|
||||
EMOJI = "emoji"
|
||||
CHAT = "chat"
|
||||
PBUILDER = "promptbuilder"
|
||||
|
||||
|
||||
class LogModule:
|
||||
logger = loguru.logger.opt()
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def setup_logger(self, log_type: LogClassification):
|
||||
"""配置日志格式
|
||||
|
||||
Args:
|
||||
log_type: 日志类型,可选值:BASE(基础日志)、MEMORY(记忆系统日志)、EMOJI(表情包系统日志)
|
||||
"""
|
||||
# 移除默认日志处理器
|
||||
self.logger.remove()
|
||||
|
||||
# 基础日志格式
|
||||
base_format = (
|
||||
"<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | "
|
||||
" d<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
chat_format = (
|
||||
"<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
# 记忆系统日志格式
|
||||
memory_format = (
|
||||
"<green>{time:HH:mm}</green> | <level>{level: <8}</level> | "
|
||||
"<light-magenta>海马体</light-magenta> | <level>{message}</level>"
|
||||
)
|
||||
|
||||
# 表情包系统日志格式
|
||||
emoji_format = (
|
||||
"<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>表情包</yellow> | "
|
||||
"<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
promptbuilder_format = (
|
||||
"<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>Prompt</yellow> | "
|
||||
"<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
# 根据日志类型选择日志格式和输出
|
||||
if log_type == LogClassification.CHAT:
|
||||
self.logger.add(
|
||||
sys.stderr,
|
||||
format=chat_format,
|
||||
# level="INFO"
|
||||
)
|
||||
elif log_type == LogClassification.PBUILDER:
|
||||
self.logger.add(
|
||||
sys.stderr,
|
||||
format=promptbuilder_format,
|
||||
# level="INFO"
|
||||
)
|
||||
elif log_type == LogClassification.MEMORY:
|
||||
# 同时输出到控制台和文件
|
||||
self.logger.add(
|
||||
sys.stderr,
|
||||
format=memory_format,
|
||||
# level="INFO"
|
||||
)
|
||||
self.logger.add("logs/memory.log", format=memory_format, level="INFO", rotation="1 day", retention="7 days")
|
||||
elif log_type == LogClassification.EMOJI:
|
||||
self.logger.add(
|
||||
sys.stderr,
|
||||
format=emoji_format,
|
||||
# level="INFO"
|
||||
)
|
||||
self.logger.add("logs/emoji.log", format=emoji_format, level="INFO", rotation="1 day", retention="7 days")
|
||||
else: # BASE
|
||||
self.logger.add(sys.stderr, format=base_format, level="INFO")
|
||||
|
||||
return self.logger
|
||||
@@ -3,14 +3,14 @@ import re
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
import contextvars
|
||||
from src.common.logger import get_module_logger
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# import traceback
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_module_logger("prompt_build")
|
||||
logger = get_logger("prompt_build")
|
||||
|
||||
|
||||
class PromptContext:
|
||||
|
||||
@@ -3,14 +3,14 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
|
||||
from src.common.logger import get_module_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
|
||||
from ...common.database.database import db # This db is the Peewee database instance
|
||||
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_module_logger("maibot_statistic")
|
||||
logger = get_logger("maibot_statistic")
|
||||
|
||||
# 统计数据的键
|
||||
TOTAL_REQ_CNT = "total_requests"
|
||||
|
||||
@@ -111,11 +111,13 @@ class Timer:
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
with self:
|
||||
return await func(*args, **kwargs)
|
||||
return None
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
return None
|
||||
|
||||
wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
wrapper.__timer__ = self # 保留计时器引用
|
||||
|
||||
@@ -13,9 +13,9 @@ from pathlib import Path
|
||||
import jieba
|
||||
from pypinyin import Style, pinyin
|
||||
|
||||
from src.common.logger import get_module_logger
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_module_logger("typo_gen")
|
||||
logger = get_logger("typo_gen")
|
||||
|
||||
|
||||
class ChineseTypoGenerator:
|
||||
|
||||
@@ -7,7 +7,7 @@ import jieba
|
||||
import numpy as np
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.common.logger import get_module_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.mood_manager import mood_manager
|
||||
from ..message_receive.message import MessageRecv
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -15,7 +15,7 @@ from .typo_generator import ChineseTypoGenerator
|
||||
from ...config.config import global_config
|
||||
from ...common.message_repository import find_messages, count_messages
|
||||
|
||||
logger = get_module_logger("chat_utils")
|
||||
logger = get_logger("chat_utils")
|
||||
|
||||
|
||||
def is_english_letter(char: str) -> bool:
|
||||
@@ -247,8 +247,6 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
|
||||
# 如果分割后为空(例如,输入全是分隔符且不满足保留条件),恢复颜文字并返回
|
||||
if not segments:
|
||||
# recovered_text = recover_kaomoji([text], mapping) # 恢复原文本中的颜文字 - 已移至上层处理
|
||||
# return [s for s in recovered_text if s] # 返回非空结果
|
||||
return [text] if text else [] # 如果原始文本非空,则返回原始文本(可能只包含未被分割的字符或颜文字占位符)
|
||||
|
||||
# 2. 概率合并
|
||||
@@ -324,16 +322,18 @@ def random_remove_punctuation(text: str) -> str:
|
||||
|
||||
|
||||
def process_llm_response(text: str) -> list[str]:
|
||||
if not global_config.response_post_process.enable_response_post_process:
|
||||
return [text]
|
||||
|
||||
# 先保护颜文字
|
||||
if global_config.response_splitter.enable_kaomoji_protection:
|
||||
protected_text, kaomoji_mapping = protect_kaomoji(text)
|
||||
logger.trace(f"保护颜文字后的文本: {protected_text}")
|
||||
logger.debug(f"保护颜文字后的文本: {protected_text}")
|
||||
else:
|
||||
protected_text = text
|
||||
kaomoji_mapping = {}
|
||||
# 提取被 () 或 [] 或 ()包裹且包含中文的内容
|
||||
pattern = re.compile(r"[(\[(](?=.*[一-鿿]).*?[)\])]")
|
||||
# _extracted_contents = pattern.findall(text)
|
||||
_extracted_contents = pattern.findall(protected_text) # 在保护后的文本上查找
|
||||
# 去除 () 和 [] 及其包裹的内容
|
||||
cleaned_text = pattern.sub("", protected_text)
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.common.database.database_model import Images, ImageDescriptions
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -228,7 +228,7 @@ class ImageManager:
|
||||
description=description,
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
logger.trace(f"保存图片元数据: {file_path}")
|
||||
logger.debug(f"保存图片元数据: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
|
||||
@@ -288,7 +288,7 @@ class ImageManager:
|
||||
# 计算和上一张选中帧的差异(均方误差 MSE)
|
||||
if last_selected_frame_np is not None:
|
||||
mse = np.mean((current_frame_np - last_selected_frame_np) ** 2)
|
||||
# logger.trace(f"帧 {i} 与上一选中帧的 MSE: {mse}") # 可以取消注释来看差异值
|
||||
# logger.debug(f"帧 {i} 与上一选中帧的 MSE: {mse}") # 可以取消注释来看差异值
|
||||
|
||||
# 如果差异够大,就选它!
|
||||
if mse > similarity_threshold:
|
||||
@@ -362,7 +362,15 @@ class ImageManager:
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
image_manager = ImageManager()
|
||||
image_manager = None
|
||||
|
||||
|
||||
def get_image_manager() -> ImageManager:
|
||||
"""获取全局图片管理器单例"""
|
||||
global image_manager
|
||||
if image_manager is None:
|
||||
image_manager = ImageManager()
|
||||
return image_manager
|
||||
|
||||
|
||||
def image_path_to_base64(image_path: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user