Merge afc branch into dev, prioritizing afc changes and migrating database async modifications from dev
This commit is contained in:
25
bot.py
25
bot.py
@@ -35,16 +35,18 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
os.chdir(script_dir)
|
||||
logger.info(f"已设置工作目录为: {script_dir}")
|
||||
|
||||
|
||||
# 检查并创建.env文件
|
||||
def ensure_env_file():
|
||||
"""确保.env文件存在,如果不存在则从模板创建"""
|
||||
env_file = Path(".env")
|
||||
template_env = Path("template/template.env")
|
||||
|
||||
|
||||
if not env_file.exists():
|
||||
if template_env.exists():
|
||||
logger.info("未找到.env文件,正在从模板创建...")
|
||||
import shutil
|
||||
|
||||
shutil.copy(template_env, env_file)
|
||||
logger.info("已从template/template.env创建.env文件")
|
||||
logger.warning("请编辑.env文件,将EULA_CONFIRMED设置为true并配置其他必要参数")
|
||||
@@ -52,6 +54,7 @@ def ensure_env_file():
|
||||
logger.error("未找到.env文件和template.env模板文件")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# 确保环境文件存在
|
||||
ensure_env_file()
|
||||
|
||||
@@ -131,32 +134,32 @@ async def graceful_shutdown():
|
||||
def check_eula():
|
||||
"""检查EULA和隐私条款确认状态 - 环境变量版(类似Minecraft)"""
|
||||
# 检查环境变量中的EULA确认
|
||||
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
|
||||
|
||||
if eula_confirmed == 'true':
|
||||
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
|
||||
|
||||
if eula_confirmed == "true":
|
||||
logger.info("EULA已通过环境变量确认")
|
||||
return
|
||||
|
||||
|
||||
# 如果没有确认,提示用户
|
||||
confirm_logger.critical("您需要同意EULA和隐私条款才能使用MoFox_Bot")
|
||||
confirm_logger.critical("请阅读以下文件:")
|
||||
confirm_logger.critical(" - EULA.md (用户许可协议)")
|
||||
confirm_logger.critical(" - PRIVACY.md (隐私条款)")
|
||||
confirm_logger.critical("然后编辑 .env 文件,将 'EULA_CONFIRMED=false' 改为 'EULA_CONFIRMED=true'")
|
||||
|
||||
|
||||
# 等待用户确认
|
||||
while True:
|
||||
try:
|
||||
load_dotenv(override=True) # 重新加载.env文件
|
||||
|
||||
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
|
||||
if eula_confirmed == 'true':
|
||||
|
||||
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
|
||||
if eula_confirmed == "true":
|
||||
confirm_logger.info("EULA确认成功,感谢您的同意")
|
||||
return
|
||||
|
||||
|
||||
confirm_logger.critical("请修改 .env 文件中的 EULA_CONFIRMED=true 后重新启动程序")
|
||||
input("按Enter键检查.env文件状态...")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
confirm_logger.info("用户取消,程序退出")
|
||||
sys.exit(0)
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
{
|
||||
"type": "action",
|
||||
"name": "set_emoji_like",
|
||||
"description": "为消息设置表情回应"
|
||||
"description": "为某条已经存在的消息添加‘贴表情’回应(类似点赞),而不是发送新消息。当用户明确要求‘贴表情’时使用。"
|
||||
}
|
||||
],
|
||||
"features": [
|
||||
|
||||
@@ -45,7 +45,7 @@ class SetEmojiLikeAction(BaseAction):
|
||||
|
||||
# === 基本信息(必须填写)===
|
||||
action_name = "set_emoji_like"
|
||||
action_description = "为一个已存在的消息添加点赞或表情回应(也叫‘贴表情’)"
|
||||
action_description = "为某条已经存在的消息添加‘贴表情’回应(类似点赞),而不是发送新消息。可以在觉得某条消息非常有趣、值得赞同或者需要特殊情感回应时主动使用。"
|
||||
activation_type = ActionActivationType.ALWAYS # 消息接收时激活(?)
|
||||
chat_type_allow = ChatType.GROUP
|
||||
parallel_action = True
|
||||
|
||||
@@ -20,25 +20,26 @@ files_to_update = [
|
||||
"src/mais4u/mais4u_chat/s4u_mood_manager.py",
|
||||
"src/plugin_system/core/tool_use.py",
|
||||
"src/chat/memory_system/memory_activator.py",
|
||||
"src/chat/utils/smart_prompt.py"
|
||||
"src/chat/utils/smart_prompt.py",
|
||||
]
|
||||
|
||||
|
||||
def update_prompt_imports(file_path):
|
||||
"""更新文件中的Prompt导入"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"文件不存在: {file_path}")
|
||||
return False
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
|
||||
# 替换导入语句
|
||||
old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager"
|
||||
new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager"
|
||||
|
||||
|
||||
if old_import in content:
|
||||
new_content = content.replace(old_import, new_import)
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
print(f"已更新: {file_path}")
|
||||
return True
|
||||
@@ -46,14 +47,16 @@ def update_prompt_imports(file_path):
|
||||
print(f"无需更新: {file_path}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
updated_count = 0
|
||||
for file_path in files_to_update:
|
||||
if update_prompt_imports(file_path):
|
||||
updated_count += 1
|
||||
|
||||
|
||||
print(f"\n更新完成!共更新了 {updated_count} 个文件")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -1,460 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
import math
|
||||
import random
|
||||
from typing import Dict, Any, Tuple
|
||||
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.planner_actions.planner import ActionPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.apis import database_api, generator_api
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
|
||||
from .hfc_context import HfcContext
|
||||
from .response_handler import ResponseHandler
|
||||
from .cycle_tracker import CycleTracker
|
||||
|
||||
# 日志记录器
|
||||
logger = get_logger("hfc.processor")
|
||||
|
||||
|
||||
class CycleProcessor:
|
||||
"""
|
||||
循环处理器类,负责处理单次思考循环的逻辑。
|
||||
"""
|
||||
def __init__(self, context: HfcContext, response_handler: ResponseHandler, cycle_tracker: CycleTracker):
|
||||
"""
|
||||
初始化循环处理器
|
||||
|
||||
Args:
|
||||
context: HFC聊天上下文对象,包含聊天流、能量值等信息
|
||||
response_handler: 响应处理器,负责生成和发送回复
|
||||
cycle_tracker: 循环跟踪器,负责记录和管理每次思考循环的信息
|
||||
"""
|
||||
self.context = context
|
||||
self.response_handler = response_handler
|
||||
self.cycle_tracker = cycle_tracker
|
||||
self.action_planner = ActionPlanner(chat_id=self.context.stream_id, action_manager=self.context.action_manager)
|
||||
self.action_modifier = ActionModifier(
|
||||
action_manager=self.context.action_manager, chat_id=self.context.stream_id
|
||||
)
|
||||
|
||||
self.log_prefix = self.context.log_prefix
|
||||
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
response_set,
|
||||
loop_start_time,
|
||||
action_message,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id,
|
||||
actions,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
"""
|
||||
发送并存储回复信息
|
||||
|
||||
Args:
|
||||
response_set: 回复内容集合
|
||||
loop_start_time: 循环开始时间
|
||||
action_message: 动作消息
|
||||
cycle_timers: 循环计时器
|
||||
thinking_id: 思考ID
|
||||
actions: 动作列表
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, Any], str, Dict[str, float]]: 循环信息, 回复文本, 循环计时器
|
||||
"""
|
||||
# 发送回复
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self.response_handler.send_response(response_set, loop_start_time, action_message)
|
||||
|
||||
# 存储reply action信息
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
platform = action_message.get("chat_info_platform")
|
||||
if platform is None:
|
||||
platform = getattr(self.context.chat_stream, "platform", "unknown")
|
||||
|
||||
# 获取用户信息并生成回复提示
|
||||
person_id = person_info_manager.get_person_id(
|
||||
platform,
|
||||
action_message.get("chat_info_user_id", ""),
|
||||
)
|
||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
||||
person_name = person_info.get("person_name")
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
# 存储动作信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.context.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=action_prompt_display,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": reply_text},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": reply_text,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def observe(self, interest_value: float = 0.0) -> str:
|
||||
"""
|
||||
观察和处理单次思考循环的核心方法
|
||||
|
||||
Args:
|
||||
interest_value: 兴趣值
|
||||
|
||||
Returns:
|
||||
str: 动作类型
|
||||
|
||||
功能说明:
|
||||
- 开始新的思考循环并记录计时
|
||||
- 修改可用动作并获取动作列表
|
||||
- 根据聊天模式和提及情况决定是否跳过规划器
|
||||
- 执行动作规划或直接回复
|
||||
- 根据动作类型分发到相应的处理方法
|
||||
"""
|
||||
action_type = "no_action"
|
||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
# 使用sigmoid函数将interest_value转换为概率
|
||||
# 当interest_value为0时,概率接近0(使用Focus模式)
|
||||
# 当interest_value很高时,概率接近1(使用Normal模式)
|
||||
def calculate_normal_mode_probability(interest_val: float) -> float:
|
||||
"""
|
||||
计算普通模式的概率
|
||||
|
||||
Args:
|
||||
interest_val: 兴趣值
|
||||
|
||||
Returns:
|
||||
float: 概率
|
||||
"""
|
||||
# 使用sigmoid函数,调整参数使概率分布更合理
|
||||
# 当interest_value = 0时,概率约为0.1
|
||||
# 当interest_value = 1时,概率约为0.5
|
||||
# 当interest_value = 2时,概率约为0.8
|
||||
# 当interest_value = 3时,概率约为0.95
|
||||
k = 2.0 # 控制曲线陡峭程度
|
||||
x0 = 1.0 # 控制曲线中心点
|
||||
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
|
||||
|
||||
# 计算普通模式概率
|
||||
normal_mode_probability = (
|
||||
calculate_normal_mode_probability(interest_value)
|
||||
* 0.5
|
||||
/ global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
||||
)
|
||||
|
||||
# 根据概率决定使用哪种模式
|
||||
if random.random() < normal_mode_probability:
|
||||
mode = ChatMode.NORMAL
|
||||
logger.info(
|
||||
f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Normal planner模式"
|
||||
)
|
||||
else:
|
||||
mode = ChatMode.FOCUS
|
||||
logger.info(
|
||||
f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Focus planner模式"
|
||||
)
|
||||
|
||||
# 开始新的思考循环
|
||||
cycle_timers, thinking_id = self.cycle_tracker.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self.context.cycle_counter}次思考")
|
||||
|
||||
if ENABLE_S4U and self.context.chat_stream and self.context.chat_stream.user_info:
|
||||
await send_typing(self.context.chat_stream.user_info.user_id)
|
||||
|
||||
loop_start_time = time.time()
|
||||
|
||||
# 第一步:动作修改
|
||||
with Timer("动作修改", cycle_timers):
|
||||
try:
|
||||
await self.action_modifier.modify_actions()
|
||||
available_actions = self.context.action_manager.get_using_actions()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
available_actions = {}
|
||||
|
||||
# 规划动作
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system import EventType
|
||||
|
||||
result = await event_manager.trigger_event(
|
||||
EventType.ON_PLAN, permission_group="SYSTEM", stream_id=self.context.chat_stream
|
||||
)
|
||||
if result and not result.all_continue_process():
|
||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于规划前中断了内容生成")
|
||||
with Timer("规划器", cycle_timers):
|
||||
actions, _ = await self.action_planner.plan(mode=mode)
|
||||
|
||||
async def execute_action(action_info):
|
||||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
if action_info["action_type"] == "no_action":
|
||||
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
||||
if action_info["action_type"] == "no_reply":
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = action_info.get("reasoning", "选择不回复")
|
||||
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
# 存储no_reply信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.context.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="no_reply",
|
||||
)
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||
elif action_info["action_type"] != "reply" and action_info["action_type"] != "no_action":
|
||||
# 记录并执行普通动作
|
||||
reason = action_info.get("reasoning", f"执行动作 {action_info['action_type']}")
|
||||
logger.info(f"{self.log_prefix} 决定执行动作 '{action_info['action_type']}',内心思考: {reason}")
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_info["action_type"],
|
||||
reason, # 使用已获取的reason
|
||||
action_info["action_data"],
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
action_info["action_message"],
|
||||
)
|
||||
return {
|
||||
"action_type": action_info["action_type"],
|
||||
"success": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command,
|
||||
}
|
||||
else:
|
||||
# 生成回复
|
||||
try:
|
||||
reason = action_info.get("reasoning", "决定进行回复")
|
||||
logger.info(f"{self.log_prefix} 决定进行回复,内心思考: {reason}")
|
||||
success, response_set, _ = await generator_api.generate_reply(
|
||||
chat_stream=self.context.chat_stream,
|
||||
reply_message=action_info["action_message"],
|
||||
available_actions=available_actions,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="chat.replyer",
|
||||
from_plugin=False,
|
||||
read_mark=action_info.get("action_message", {}).get("time", 0.0),
|
||||
)
|
||||
if not success or not response_set:
|
||||
logger.info(
|
||||
f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
|
||||
)
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
||||
# 发送并存储回复
|
||||
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
||||
response_set,
|
||||
loop_start_time,
|
||||
action_info["action_message"],
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
actions,
|
||||
)
|
||||
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
||||
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
||||
return {
|
||||
"action_type": action_info["action_type"],
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
# 分离 reply 动作和其他动作
|
||||
reply_actions = [a for a in actions if a.get("action_type") == "reply"]
|
||||
other_actions = [a for a in actions if a.get("action_type") != "reply"]
|
||||
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
other_actions_results = []
|
||||
|
||||
# 1. 首先串行执行所有 reply 动作(通常只有一个)
|
||||
if reply_actions:
|
||||
logger.info(f"{self.log_prefix} 正在执行文本回复...")
|
||||
for action in reply_actions:
|
||||
action_message = action.get("action_message")
|
||||
if not action_message:
|
||||
logger.warning(f"{self.log_prefix} reply 动作缺少 action_message,跳过")
|
||||
continue
|
||||
|
||||
# 检查是否是空的DatabaseMessages对象
|
||||
if hasattr(action_message, 'chat_info') and hasattr(action_message.chat_info, 'user_info'):
|
||||
target_user_id = action_message.chat_info.user_info.user_id
|
||||
else:
|
||||
# 如果是字典格式,使用原来的方式
|
||||
target_user_id = action_message.get("chat_info_user_id", "")
|
||||
|
||||
if not target_user_id:
|
||||
logger.warning(f"{self.log_prefix} reply 动作的 action_message 缺少用户ID,跳过")
|
||||
continue
|
||||
|
||||
if target_user_id == global_config.bot.qq_account and not global_config.chat.allow_reply_self:
|
||||
logger.warning("选取的reply的目标为bot自己,跳过reply action")
|
||||
continue
|
||||
result = await execute_action(action)
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"{self.log_prefix} 回复动作执行异常: {result}")
|
||||
continue
|
||||
if result.get("success"):
|
||||
reply_loop_info = result.get("loop_info")
|
||||
reply_text_from_reply = result.get("reply_text", "")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
||||
|
||||
# 2. 然后并行执行所有其他动作
|
||||
if other_actions:
|
||||
logger.info(f"{self.log_prefix} 正在执行附加动作: {[a.get('action_type') for a in other_actions]}")
|
||||
other_action_tasks = [asyncio.create_task(execute_action(action)) for action in other_actions]
|
||||
results = await asyncio.gather(*other_action_tasks, return_exceptions=True)
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 附加动作执行异常: {result}")
|
||||
continue
|
||||
other_actions_results.append(result)
|
||||
|
||||
# 构建最终的循环信息
|
||||
if reply_loop_info:
|
||||
loop_info = reply_loop_info
|
||||
# 将其他动作的结果合并到loop_info中
|
||||
if "other_actions" not in loop_info["loop_action_info"]:
|
||||
loop_info["loop_action_info"]["other_actions"] = []
|
||||
loop_info["loop_action_info"]["other_actions"].extend(other_actions_results)
|
||||
reply_text = reply_text_from_reply
|
||||
else:
|
||||
# 没有回复信息,构建纯动作的loop_info
|
||||
# 即使没有回复,也要正确处理其他动作
|
||||
final_action_taken = any(res.get("success", False) for res in other_actions_results)
|
||||
final_reply_text = " ".join(res.get("reply_text", "") for res in other_actions_results if res.get("reply_text"))
|
||||
final_command = " ".join(res.get("command", "") for res in other_actions_results if res.get("command"))
|
||||
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": final_action_taken,
|
||||
"reply_text": final_reply_text,
|
||||
"command": final_command,
|
||||
"taken_time": time.time(),
|
||||
"other_actions": other_actions_results,
|
||||
},
|
||||
}
|
||||
reply_text = final_reply_text
|
||||
|
||||
# 停止正在输入状态
|
||||
if ENABLE_S4U:
|
||||
await stop_typing()
|
||||
|
||||
# 结束循环
|
||||
self.context.chat_instance.cycle_tracker.end_cycle(loop_info, cycle_timers)
|
||||
self.context.chat_instance.cycle_tracker.print_cycle_info(cycle_timers)
|
||||
|
||||
action_type = actions[0]["action_type"] if actions else "no_action"
|
||||
return action_type
|
||||
|
||||
async def _handle_action(
|
||||
self, action, reasoning, action_data, cycle_timers, thinking_id, action_message
|
||||
) -> tuple[bool, str, str]:
|
||||
"""
|
||||
处理具体的动作执行
|
||||
|
||||
Args:
|
||||
action: 动作名称
|
||||
reasoning: 执行理由
|
||||
action_data: 动作数据
|
||||
cycle_timers: 循环计时器
|
||||
thinking_id: 思考ID
|
||||
action_message: 动作消息
|
||||
|
||||
Returns:
|
||||
tuple: (执行是否成功, 回复文本, 命令文本)
|
||||
|
||||
功能说明:
|
||||
- 创建对应的动作处理器
|
||||
- 执行动作并捕获异常
|
||||
- 返回执行结果供上级方法整合
|
||||
"""
|
||||
if not self.context.chat_stream:
|
||||
return False, "", ""
|
||||
try:
|
||||
# 创建动作处理器
|
||||
action_handler = self.context.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.context.chat_stream,
|
||||
log_prefix=self.context.log_prefix,
|
||||
action_message=action_message,
|
||||
)
|
||||
if not action_handler:
|
||||
# 动作处理器创建失败,尝试回退机制
|
||||
logger.warning(f"{self.context.log_prefix} 创建动作处理器失败: {action},尝试回退方案")
|
||||
|
||||
# 获取当前可用的动作
|
||||
available_actions = self.context.action_manager.get_using_actions()
|
||||
fallback_action = None
|
||||
|
||||
# 回退优先级:reply > 第一个可用动作
|
||||
if "reply" in available_actions:
|
||||
fallback_action = "reply"
|
||||
elif available_actions:
|
||||
fallback_action = list(available_actions.keys())[0]
|
||||
|
||||
if fallback_action and fallback_action != action:
|
||||
logger.info(f"{self.context.log_prefix} 使用回退动作: {fallback_action}")
|
||||
action_handler = self.context.action_manager.create_action(
|
||||
action_name=fallback_action,
|
||||
action_data=action_data,
|
||||
reasoning=f"原动作'{action}'不可用,自动回退。{reasoning}",
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.context.chat_stream,
|
||||
log_prefix=self.context.log_prefix,
|
||||
action_message=action_message,
|
||||
)
|
||||
|
||||
if not action_handler:
|
||||
logger.error(f"{self.context.log_prefix} 回退方案也失败,无法创建任何动作处理器")
|
||||
return False, "", ""
|
||||
|
||||
# 执行动作
|
||||
success, reply_text = await action_handler.handle_action()
|
||||
return success, reply_text, ""
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 处理{action}时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
@@ -1,114 +0,0 @@
|
||||
import time
|
||||
from typing import Dict, Any, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.chat_loop.hfc_utils import CycleDetail
|
||||
from .hfc_context import HfcContext
|
||||
|
||||
logger = get_logger("hfc")
|
||||
|
||||
|
||||
class CycleTracker:
|
||||
def __init__(self, context: HfcContext):
|
||||
"""
|
||||
初始化循环跟踪器
|
||||
|
||||
Args:
|
||||
context: HFC聊天上下文对象
|
||||
|
||||
功能说明:
|
||||
- 负责跟踪和记录每次思考循环的详细信息
|
||||
- 管理循环的开始、结束和信息存储
|
||||
"""
|
||||
self.context = context
|
||||
|
||||
def start_cycle(self, is_proactive: bool = False) -> Tuple[Dict[str, float], str]:
|
||||
"""
|
||||
开始新的思考循环
|
||||
|
||||
Args:
|
||||
is_proactive: 标记这个循环是否由主动思考发起
|
||||
|
||||
Returns:
|
||||
tuple: (循环计时器字典, 思考ID字符串)
|
||||
|
||||
功能说明:
|
||||
- 增加循环计数器
|
||||
- 创建新的循环详情对象
|
||||
- 生成唯一的思考ID
|
||||
- 初始化循环计时器
|
||||
"""
|
||||
if not is_proactive:
|
||||
self.context.cycle_counter += 1
|
||||
|
||||
cycle_id = self.context.cycle_counter if not is_proactive else f"{self.context.cycle_counter}.p"
|
||||
self.context.current_cycle_detail = CycleDetail(cycle_id)
|
||||
self.context.current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
cycle_timers = {}
|
||||
return cycle_timers, self.context.current_cycle_detail.thinking_id
|
||||
|
||||
def end_cycle(self, loop_info: Dict[str, Any], cycle_timers: Dict[str, float]):
|
||||
"""
|
||||
结束当前思考循环
|
||||
|
||||
Args:
|
||||
loop_info: 循环信息,包含规划和动作信息
|
||||
cycle_timers: 循环计时器,记录各阶段耗时
|
||||
|
||||
功能说明:
|
||||
- 设置循环详情的完整信息
|
||||
- 将当前循环加入历史记录
|
||||
- 记录计时器和结束时间
|
||||
- 打印循环统计信息
|
||||
"""
|
||||
if self.context.current_cycle_detail:
|
||||
self.context.current_cycle_detail.set_loop_info(loop_info)
|
||||
self.context.history_loop.append(self.context.current_cycle_detail)
|
||||
self.context.current_cycle_detail.timers = cycle_timers
|
||||
self.context.current_cycle_detail.end_time = time.time()
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
def print_cycle_info(self, cycle_timers: Dict[str, float]):
|
||||
"""
|
||||
打印循环统计信息
|
||||
|
||||
Args:
|
||||
cycle_timers: 循环计时器字典
|
||||
|
||||
功能说明:
|
||||
- 格式化各阶段的耗时信息
|
||||
- 计算总体循环持续时间
|
||||
- 输出详细的性能统计日志
|
||||
- 显示选择的动作类型
|
||||
"""
|
||||
if not self.context.current_cycle_detail:
|
||||
return
|
||||
|
||||
timer_strings = []
|
||||
for name, elapsed in cycle_timers.items():
|
||||
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒"
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
# 获取动作类型,兼容新旧格式
|
||||
# 获取动作类型
|
||||
action_type = "未知动作"
|
||||
if self.context.current_cycle_detail:
|
||||
loop_plan_info = self.context.current_cycle_detail.loop_plan_info
|
||||
actions = loop_plan_info.get("action_result")
|
||||
|
||||
if isinstance(actions, list) and actions:
|
||||
# 从actions列表中提取所有action_type
|
||||
action_types = [a.get("action_type", "未知") for a in actions]
|
||||
action_type = ", ".join(action_types)
|
||||
elif isinstance(actions, dict):
|
||||
# 兼容旧格式
|
||||
action_type = actions.get("action_type", "未知动作")
|
||||
|
||||
|
||||
if self.context.current_cycle_detail.end_time and self.context.current_cycle_detail.start_time:
|
||||
duration = self.context.current_cycle_detail.end_time - self.context.current_cycle_detail.start_time
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 第{self.context.current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {duration:.1f}秒, "
|
||||
f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
@@ -1,162 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .hfc_context import HfcContext
|
||||
from src.chat.chat_loop.sleep_manager import sleep_manager
|
||||
logger = get_logger("hfc")
|
||||
|
||||
|
||||
class EnergyManager:
|
||||
def __init__(self, context: HfcContext):
|
||||
"""
|
||||
初始化能量管理器
|
||||
|
||||
Args:
|
||||
context: HFC聊天上下文对象
|
||||
|
||||
功能说明:
|
||||
- 管理聊天机器人的能量值系统
|
||||
- 根据聊天模式自动调整能量消耗
|
||||
- 控制能量值的衰减和记录
|
||||
"""
|
||||
self.context = context
|
||||
self._energy_task: Optional[asyncio.Task] = None
|
||||
self.last_energy_log_time = 0
|
||||
self.energy_log_interval = 90
|
||||
|
||||
async def start(self):
|
||||
"""
|
||||
启动能量管理器
|
||||
|
||||
功能说明:
|
||||
- 检查运行状态,避免重复启动
|
||||
- 创建能量循环异步任务
|
||||
- 设置任务完成回调
|
||||
- 记录启动日志
|
||||
"""
|
||||
if self.context.running and not self._energy_task:
|
||||
self._energy_task = asyncio.create_task(self._energy_loop())
|
||||
self._energy_task.add_done_callback(self._handle_energy_completion)
|
||||
logger.info(f"{self.context.log_prefix} 能量管理器已启动")
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
停止能量管理器
|
||||
|
||||
功能说明:
|
||||
- 取消正在运行的能量循环任务
|
||||
- 等待任务完全停止
|
||||
- 记录停止日志
|
||||
"""
|
||||
if self._energy_task and not self._energy_task.done():
|
||||
self._energy_task.cancel()
|
||||
await asyncio.sleep(0)
|
||||
logger.info(f"{self.context.log_prefix} 能量管理器已停止")
|
||||
|
||||
async def _energy_loop(self):
|
||||
"""
|
||||
能量与睡眠压力管理的主循环
|
||||
|
||||
功能说明:
|
||||
- 每10秒执行一次能量更新
|
||||
- 根据群聊配置设置固定的聊天模式和能量值
|
||||
- 在自动模式下根据聊天模式进行能量衰减
|
||||
- NORMAL模式每次衰减0.3,FOCUS模式每次衰减0.6
|
||||
- 确保能量值不低于0.3的最小值
|
||||
"""
|
||||
while self.context.running:
|
||||
await asyncio.sleep(10)
|
||||
|
||||
if not self.context.chat_stream:
|
||||
continue
|
||||
|
||||
# 判断当前是否为睡眠时间
|
||||
is_sleeping = sleep_manager.SleepManager().is_sleeping()
|
||||
|
||||
if is_sleeping:
|
||||
# 睡眠中:减少睡眠压力
|
||||
decay_per_10s = global_config.sleep_system.sleep_pressure_decay_rate / 6
|
||||
self.context.sleep_pressure -= decay_per_10s
|
||||
self.context.sleep_pressure = max(self.context.sleep_pressure, 0)
|
||||
self._log_sleep_pressure_change("睡眠压力释放")
|
||||
self.context.save_context_state()
|
||||
else:
|
||||
# 清醒时:处理能量衰减
|
||||
is_group_chat = self.context.chat_stream.group_info is not None
|
||||
if is_group_chat:
|
||||
self.context.energy_value = 25
|
||||
|
||||
await asyncio.sleep(12)
|
||||
self.context.energy_value -= 0.5
|
||||
self.context.energy_value = max(self.context.energy_value, 0.3)
|
||||
|
||||
self._log_energy_change("能量值衰减")
|
||||
self.context.save_context_state()
|
||||
|
||||
def _should_log_energy(self) -> bool:
|
||||
"""
|
||||
判断是否应该记录能量变化日志
|
||||
|
||||
Returns:
|
||||
bool: 如果距离上次记录超过间隔时间则返回True
|
||||
|
||||
功能说明:
|
||||
- 控制能量日志的记录频率,避免日志过于频繁
|
||||
- 默认间隔90秒记录一次详细日志
|
||||
- 其他时间使用调试级别日志
|
||||
"""
|
||||
current_time = time.time()
|
||||
if current_time - self.last_energy_log_time >= self.energy_log_interval:
|
||||
self.last_energy_log_time = current_time
|
||||
return True
|
||||
return False
|
||||
|
||||
def increase_sleep_pressure(self):
|
||||
"""
|
||||
在执行动作后增加睡眠压力
|
||||
"""
|
||||
increment = global_config.sleep_system.sleep_pressure_increment
|
||||
self.context.sleep_pressure += increment
|
||||
self.context.sleep_pressure = min(self.context.sleep_pressure, 100.0) # 设置一个100的上限
|
||||
self._log_sleep_pressure_change("执行动作,睡眠压力累积")
|
||||
self.context.save_context_state()
|
||||
|
||||
def _log_energy_change(self, action: str, reason: str = ""):
|
||||
"""
|
||||
记录能量变化日志
|
||||
|
||||
Args:
|
||||
action: 能量变化的动作描述
|
||||
reason: 可选的变化原因
|
||||
|
||||
功能说明:
|
||||
- 根据时间间隔决定使用info还是debug级别的日志
|
||||
- 格式化能量值显示(保留一位小数)
|
||||
- 可选择性地包含变化原因
|
||||
"""
|
||||
if self._should_log_energy():
|
||||
log_message = f"{self.context.log_prefix} {action},当前能量值:{self.context.energy_value:.1f}"
|
||||
if reason:
|
||||
log_message = (
|
||||
f"{self.context.log_prefix} {action},{reason},当前能量值:{self.context.energy_value:.1f}"
|
||||
)
|
||||
logger.info(log_message)
|
||||
else:
|
||||
log_message = f"{self.context.log_prefix} {action},当前能量值:{self.context.energy_value:.1f}"
|
||||
if reason:
|
||||
log_message = (
|
||||
f"{self.context.log_prefix} {action},{reason},当前能量值:{self.context.energy_value:.1f}"
|
||||
)
|
||||
logger.debug(log_message)
|
||||
|
||||
def _log_sleep_pressure_change(self, action: str):
|
||||
"""
|
||||
记录睡眠压力变化日志
|
||||
"""
|
||||
# 使用与能量日志相同的频率控制
|
||||
if self._should_log_energy():
|
||||
logger.info(f"{self.context.log_prefix} {action},当前睡眠压力:{self.context.sleep_pressure:.1f}")
|
||||
else:
|
||||
logger.debug(f"{self.context.log_prefix} {action},当前睡眠压力:{self.context.sleep_pressure:.1f}")
|
||||
@@ -1,574 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
from typing import Optional, List, Dict, Any
|
||||
from collections import deque
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager, SleepState
|
||||
|
||||
from .hfc_context import HfcContext
|
||||
from .energy_manager import EnergyManager
|
||||
from .proactive.proactive_thinker import ProactiveThinker
|
||||
from .cycle_processor import CycleProcessor
|
||||
from .response_handler import ResponseHandler
|
||||
from .cycle_tracker import CycleTracker
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
from .proactive.events import ProactiveTriggerEvent
|
||||
|
||||
logger = get_logger("hfc")
|
||||
|
||||
|
||||
class HeartFChatting:
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
初始化心跳聊天管理器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID标识符
|
||||
|
||||
功能说明:
|
||||
- 创建聊天上下文和所有子管理器
|
||||
- 初始化循环跟踪器、响应处理器、循环处理器等核心组件
|
||||
- 设置能量管理器、主动思考器和普通模式处理器
|
||||
- 初始化聊天模式并记录初始化完成日志
|
||||
"""
|
||||
self.context = HfcContext(chat_id)
|
||||
self.context.new_message_queue = asyncio.Queue()
|
||||
self._processing_lock = asyncio.Lock()
|
||||
|
||||
self.cycle_tracker = CycleTracker(self.context)
|
||||
self.response_handler = ResponseHandler(self.context)
|
||||
self.cycle_processor = CycleProcessor(self.context, self.response_handler, self.cycle_tracker)
|
||||
self.energy_manager = EnergyManager(self.context)
|
||||
self.proactive_thinker = ProactiveThinker(self.context, self.cycle_processor)
|
||||
self.wakeup_manager = WakeUpManager(self.context)
|
||||
self.sleep_manager = SleepManager()
|
||||
|
||||
# 将唤醒度管理器设置到上下文中
|
||||
self.context.wakeup_manager = self.wakeup_manager
|
||||
self.context.energy_manager = self.energy_manager
|
||||
self.context.sleep_manager = self.sleep_manager
|
||||
# 将HeartFChatting实例设置到上下文中,以便其他组件可以调用其方法
|
||||
self.context.chat_instance = self
|
||||
|
||||
self._loop_task: Optional[asyncio.Task] = None
|
||||
self._proactive_monitor_task: Optional[asyncio.Task] = None
|
||||
|
||||
# 记录最近3次的兴趣度
|
||||
self.recent_interest_records: deque = deque(maxlen=3)
|
||||
self._initialize_chat_mode()
|
||||
logger.info(f"{self.context.log_prefix} HeartFChatting 初始化完成")
|
||||
|
||||
def _initialize_chat_mode(self):
|
||||
"""
|
||||
初始化聊天模式
|
||||
|
||||
功能说明:
|
||||
- 检测是否为群聊环境
|
||||
- 根据全局配置设置强制聊天模式
|
||||
- 在focus模式下设置能量值为35
|
||||
- 在normal模式下设置能量值为15
|
||||
- 如果是auto模式则保持默认设置
|
||||
"""
|
||||
is_group_chat = self.context.chat_stream.group_info is not None if self.context.chat_stream else False
|
||||
if is_group_chat and global_config.chat.group_chat_mode != "auto":
|
||||
self.context.energy_value = 25
|
||||
|
||||
async def start(self):
|
||||
"""
|
||||
启动心跳聊天系统
|
||||
|
||||
功能说明:
|
||||
- 检查是否已经在运行,避免重复启动
|
||||
- 初始化关系构建器和表达学习器
|
||||
- 启动能量管理器和主动思考器
|
||||
- 创建主聊天循环任务并设置完成回调
|
||||
- 记录启动完成日志
|
||||
"""
|
||||
if self.context.running:
|
||||
return
|
||||
self.context.running = True
|
||||
|
||||
self.context.relationship_builder = relationship_builder_manager.get_or_create_builder(self.context.stream_id)
|
||||
self.context.expression_learner = await expression_learner_manager.get_expression_learner(self.context.stream_id)
|
||||
|
||||
# 启动主动思考监视器
|
||||
if global_config.chat.enable_proactive_thinking:
|
||||
self._proactive_monitor_task = asyncio.create_task(self._proactive_monitor_loop())
|
||||
self._proactive_monitor_task.add_done_callback(self._handle_proactive_monitor_completion)
|
||||
logger.info(f"{self.context.log_prefix} 主动思考监视器已启动")
|
||||
|
||||
await self.wakeup_manager.start()
|
||||
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
||||
logger.info(f"{self.context.log_prefix} HeartFChatting 启动完成")
|
||||
|
||||
async def add_message(self, message: Dict[str, Any]):
|
||||
"""从外部接收新消息并放入队列"""
|
||||
await self.context.new_message_queue.put(message)
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
停止心跳聊天系统
|
||||
|
||||
功能说明:
|
||||
- 检查是否正在运行,避免重复停止
|
||||
- 设置运行状态为False
|
||||
- 停止能量管理器和主动思考器
|
||||
- 取消主聊天循环任务
|
||||
- 记录停止完成日志
|
||||
"""
|
||||
if not self.context.running:
|
||||
return
|
||||
self.context.running = False
|
||||
|
||||
# 停止主动思考监视器
|
||||
if self._proactive_monitor_task and not self._proactive_monitor_task.done():
|
||||
self._proactive_monitor_task.cancel()
|
||||
await asyncio.sleep(0)
|
||||
logger.info(f"{self.context.log_prefix} 主动思考监视器已停止")
|
||||
|
||||
await self.wakeup_manager.stop()
|
||||
|
||||
if self._loop_task and not self._loop_task.done():
|
||||
self._loop_task.cancel()
|
||||
await asyncio.sleep(0)
|
||||
logger.info(f"{self.context.log_prefix} HeartFChatting 已停止")
|
||||
|
||||
def _handle_loop_completion(self, task: asyncio.Task):
|
||||
"""
|
||||
处理主循环任务完成
|
||||
|
||||
Args:
|
||||
task: 完成的异步任务对象
|
||||
|
||||
功能说明:
|
||||
- 处理任务异常完成的情况
|
||||
- 区分正常停止和异常终止
|
||||
- 记录相应的日志信息
|
||||
- 处理取消任务的情况
|
||||
"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.context.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
|
||||
logger.error(traceback.format_exc())
|
||||
else:
|
||||
logger.info(f"{self.context.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.context.log_prefix} HeartFChatting: 结束了聊天")
|
||||
|
||||
def _handle_proactive_monitor_completion(self, task: asyncio.Task):
|
||||
"""
|
||||
处理主动思考监视器任务完成
|
||||
|
||||
Args:
|
||||
task: 完成的异步任务对象
|
||||
|
||||
功能说明:
|
||||
- 处理任务异常完成的情况
|
||||
- 记录任务正常结束或被取消的日志
|
||||
"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.context.log_prefix} 主动思考监视器异常: {exception}")
|
||||
else:
|
||||
logger.info(f"{self.context.log_prefix} 主动思考监视器正常结束")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.context.log_prefix} 主动思考监视器被取消")
|
||||
|
||||
async def _proactive_monitor_loop(self):
|
||||
"""
|
||||
主动思考监视器循环
|
||||
|
||||
功能说明:
|
||||
- 定期检查是否需要进行主动思考
|
||||
- 计算聊天沉默时间,并与动态思考间隔比较
|
||||
- 当沉默时间超过阈值时,触发主动思考
|
||||
- 处理思考过程中的异常
|
||||
"""
|
||||
while self.context.running:
|
||||
await asyncio.sleep(15)
|
||||
|
||||
if not self._should_enable_proactive_thinking():
|
||||
continue
|
||||
|
||||
current_time = time.time()
|
||||
silence_duration = current_time - self.context.last_message_time
|
||||
target_interval = self._get_dynamic_thinking_interval()
|
||||
|
||||
if silence_duration >= target_interval:
|
||||
try:
|
||||
formatted_time = self._format_duration(silence_duration)
|
||||
event = ProactiveTriggerEvent(
|
||||
source="silence_monitor",
|
||||
reason=f"聊天已沉默 {formatted_time}",
|
||||
metadata={"silence_duration": silence_duration},
|
||||
)
|
||||
await self.proactive_thinker.think(event)
|
||||
self.context.last_message_time = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考触发执行出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def _should_enable_proactive_thinking(self) -> bool:
|
||||
"""
|
||||
判断是否应启用主动思考
|
||||
|
||||
Returns:
|
||||
bool: 如果应启用主动思考则返回True,否则返回False
|
||||
|
||||
功能说明:
|
||||
- 检查全局配置和特定聊天设置
|
||||
- 支持按群聊和私聊分别配置
|
||||
- 支持白名单模式,只在特定聊天中启用
|
||||
"""
|
||||
if not self.context.chat_stream:
|
||||
return False
|
||||
|
||||
is_group_chat = self.context.chat_stream.group_info is not None
|
||||
|
||||
if is_group_chat and not global_config.chat.proactive_thinking_in_group:
|
||||
return False
|
||||
if not is_group_chat and not global_config.chat.proactive_thinking_in_private:
|
||||
return False
|
||||
|
||||
stream_parts = self.context.stream_id.split(":")
|
||||
current_chat_identifier = f"{stream_parts}:{stream_parts}" if len(stream_parts) >= 2 else self.context.stream_id
|
||||
|
||||
enable_list = getattr(
|
||||
global_config.chat,
|
||||
"proactive_thinking_enable_in_groups" if is_group_chat else "proactive_thinking_enable_in_private",
|
||||
[],
|
||||
)
|
||||
return not enable_list or current_chat_identifier in enable_list
|
||||
|
||||
def _get_dynamic_thinking_interval(self) -> float:
|
||||
"""
|
||||
获取动态思考间隔时间
|
||||
|
||||
Returns:
|
||||
float: 思考间隔秒数
|
||||
|
||||
功能说明:
|
||||
- 尝试从timing_utils导入正态分布间隔函数
|
||||
- 根据配置计算动态间隔,增加随机性
|
||||
- 在无法导入或计算出错时,回退到固定的间隔
|
||||
"""
|
||||
try:
|
||||
from src.utils.timing_utils import get_normal_distributed_interval
|
||||
|
||||
base_interval = global_config.chat.proactive_thinking_interval
|
||||
delta_sigma = getattr(global_config.chat, "delta_sigma", 120)
|
||||
|
||||
if base_interval <= 0:
|
||||
base_interval = abs(base_interval)
|
||||
if delta_sigma < 0:
|
||||
delta_sigma = abs(delta_sigma)
|
||||
|
||||
if base_interval == 0 and delta_sigma == 0:
|
||||
return 300
|
||||
if delta_sigma == 0:
|
||||
return base_interval
|
||||
|
||||
sigma_percentage = delta_sigma / base_interval if base_interval > 0 else delta_sigma / 1000
|
||||
return get_normal_distributed_interval(base_interval, sigma_percentage, 1, 86400, use_3sigma_rule=True)
|
||||
|
||||
except ImportError:
|
||||
logger.warning(f"{self.context.log_prefix} timing_utils不可用,使用固定间隔")
|
||||
return max(300, abs(global_config.chat.proactive_thinking_interval))
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 动态间隔计算出错: {e},使用固定间隔")
|
||||
return max(300, abs(global_config.chat.proactive_thinking_interval))
|
||||
|
||||
@staticmethod
|
||||
def _format_duration(seconds: float) -> str:
|
||||
"""
|
||||
格式化时长为可读字符串
|
||||
|
||||
Args:
|
||||
seconds: 时长秒数
|
||||
|
||||
Returns:
|
||||
str: 格式化后的字符串 (例如 "1小时2分3秒")
|
||||
"""
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
secs = int(seconds % 60)
|
||||
parts = []
|
||||
if hours > 0:
|
||||
parts.append(f"{hours}小时")
|
||||
if minutes > 0:
|
||||
parts.append(f"{minutes}分")
|
||||
if secs > 0 or not parts:
|
||||
parts.append(f"{secs}秒")
|
||||
return "".join(parts)
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
"""
|
||||
主聊天循环
|
||||
|
||||
功能说明:
|
||||
- 持续运行聊天处理循环
|
||||
- 只有在有新消息时才进行思考循环
|
||||
- 无新消息时等待新消息到达(由主动思考系统单独处理主动发言)
|
||||
- 处理取消和异常情况
|
||||
- 在异常时尝试重新启动循环
|
||||
"""
|
||||
try:
|
||||
while self.context.running:
|
||||
has_new_messages = await self._loop_body()
|
||||
|
||||
if has_new_messages:
|
||||
# 有新消息时,继续快速检查是否还有更多消息
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
# 无新消息时,等待较长时间再检查
|
||||
# 这里只是为了定期检查系统状态,不进行思考循环
|
||||
# 真正的新消息响应依赖于消息到达时的通知
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.context.log_prefix} 麦麦已关闭聊天")
|
||||
except Exception:
|
||||
logger.error(f"{self.context.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动")
|
||||
print(traceback.format_exc())
|
||||
await asyncio.sleep(3)
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
logger.error(f"{self.context.log_prefix} 结束了当前聊天循环")
|
||||
|
||||
async def _loop_body(self) -> bool:
|
||||
"""
|
||||
单次循环体处理
|
||||
|
||||
Returns:
|
||||
bool: 是否处理了新消息
|
||||
|
||||
功能说明:
|
||||
- 检查是否处于睡眠模式,如果是则处理唤醒度逻辑
|
||||
- 获取最近的新消息(过滤机器人自己的消息和命令)
|
||||
- 只有在有新消息时才进行思考循环处理
|
||||
- 更新最后消息时间和读取时间
|
||||
- 根据当前聊天模式执行不同的处理逻辑
|
||||
- FOCUS模式:直接处理所有消息并检查退出条件
|
||||
- NORMAL模式:检查进入FOCUS模式的条件,并通过normal_mode_handler处理消息
|
||||
"""
|
||||
async with self._processing_lock:
|
||||
# --- 核心状态更新 ---
|
||||
await self.sleep_manager.update_sleep_state(self.wakeup_manager)
|
||||
current_sleep_state = self.sleep_manager.get_current_sleep_state()
|
||||
is_sleeping = current_sleep_state == SleepState.SLEEPING
|
||||
is_in_insomnia = current_sleep_state == SleepState.INSOMNIA
|
||||
|
||||
# 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收
|
||||
filter_command_flag = not (is_sleeping or is_in_insomnia)
|
||||
|
||||
# 从队列中获取所有待处理的新消息
|
||||
recent_messages = []
|
||||
while not self.context.new_message_queue.empty():
|
||||
recent_messages.append(await self.context.new_message_queue.get())
|
||||
|
||||
has_new_messages = bool(recent_messages)
|
||||
new_message_count = len(recent_messages)
|
||||
|
||||
# 只有在有新消息时才进行思考循环处理
|
||||
if has_new_messages:
|
||||
self.context.last_message_time = time.time()
|
||||
self.context.last_read_time = time.time()
|
||||
|
||||
# --- 专注模式安静群组检查 ---
|
||||
quiet_groups = global_config.chat.focus_mode_quiet_groups
|
||||
if quiet_groups and self.context.chat_stream:
|
||||
is_group_chat = self.context.chat_stream.group_info is not None
|
||||
if is_group_chat:
|
||||
try:
|
||||
platform = self.context.chat_stream.platform
|
||||
group_id = self.context.chat_stream.group_info.group_id
|
||||
|
||||
# 兼容不同QQ适配器的平台名称
|
||||
is_qq_platform = platform in ["qq", "napcat"]
|
||||
|
||||
current_chat_identifier = f"{platform}:{group_id}"
|
||||
config_identifier_for_qq = f"qq:{group_id}"
|
||||
|
||||
is_in_quiet_list = (current_chat_identifier in quiet_groups or
|
||||
(is_qq_platform and config_identifier_for_qq in quiet_groups))
|
||||
|
||||
if is_in_quiet_list:
|
||||
is_mentioned_in_batch = False
|
||||
for msg in recent_messages:
|
||||
if msg.get("is_mentioned"):
|
||||
is_mentioned_in_batch = True
|
||||
break
|
||||
|
||||
if not is_mentioned_in_batch:
|
||||
logger.info(f"{self.context.log_prefix} 在专注安静模式下,因未被提及而忽略了消息。")
|
||||
return True # 消耗消息但不做回复
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 检查专注安静群组时出错: {e}")
|
||||
|
||||
# 处理唤醒度逻辑
|
||||
if current_sleep_state in [SleepState.SLEEPING, SleepState.PREPARING_SLEEP, SleepState.INSOMNIA]:
|
||||
self._handle_wakeup_messages(recent_messages)
|
||||
|
||||
# 再次获取最新状态,因为 handle_wakeup 可能导致状态变为 WOKEN_UP
|
||||
current_sleep_state = self.sleep_manager.get_current_sleep_state()
|
||||
|
||||
if current_sleep_state == SleepState.SLEEPING:
|
||||
# 只有在纯粹的 SLEEPING 状态下才跳过消息处理
|
||||
return True
|
||||
|
||||
if current_sleep_state == SleepState.WOKEN_UP:
|
||||
logger.info(f"{self.context.log_prefix} 从睡眠中被唤醒,将处理积压的消息。")
|
||||
|
||||
# 根据聊天模式处理新消息
|
||||
should_process, interest_value = await self._should_process_messages(recent_messages)
|
||||
if not should_process:
|
||||
# 消息数量不足或兴趣不够,等待
|
||||
await asyncio.sleep(0.5)
|
||||
return True # Skip rest of the logic for this iteration
|
||||
|
||||
# Messages should be processed
|
||||
action_type = await self.cycle_processor.observe(interest_value=interest_value)
|
||||
|
||||
# 尝试触发表达学习
|
||||
if self.context.expression_learner:
|
||||
try:
|
||||
await self.context.expression_learner.trigger_learning_for_chat()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 表达学习触发失败: {e}")
|
||||
|
||||
# 管理no_reply计数器
|
||||
if action_type != "no_reply":
|
||||
self.recent_interest_records.clear()
|
||||
self.context.no_reply_consecutive = 0
|
||||
logger.debug(f"{self.context.log_prefix} 执行了{action_type}动作,重置no_reply计数器")
|
||||
else: # action_type == "no_reply"
|
||||
self.context.no_reply_consecutive += 1
|
||||
self._determine_form_type()
|
||||
|
||||
# 在一轮动作执行完毕后,增加睡眠压力
|
||||
if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system:
|
||||
if action_type not in ["no_reply", "no_action"]:
|
||||
self.context.energy_manager.increase_sleep_pressure()
|
||||
|
||||
# 如果成功观察,增加能量值并重置累积兴趣值
|
||||
self.context.energy_value += 1 / global_config.chat.focus_value
|
||||
# 重置累积兴趣值,因为消息已经被成功处理
|
||||
self.context.breaking_accumulated_interest = 0.0
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值"
|
||||
)
|
||||
|
||||
# 更新上一帧的睡眠状态
|
||||
self.context.was_sleeping = is_sleeping
|
||||
|
||||
# --- 重新入睡逻辑 ---
|
||||
# 如果被吵醒了,并且在一定时间内没有新消息,则尝试重新入睡
|
||||
if self.sleep_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages:
|
||||
re_sleep_delay = global_config.sleep_system.re_sleep_delay_minutes * 60
|
||||
# 使用 last_message_time 来判断空闲时间
|
||||
if time.time() - self.context.last_message_time > re_sleep_delay:
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 已被唤醒且超过 {re_sleep_delay / 60} 分钟无新消息,尝试重新入睡。"
|
||||
)
|
||||
self.sleep_manager.reset_sleep_state_after_wakeup()
|
||||
|
||||
# 保存HFC上下文状态
|
||||
self.context.save_context_state()
|
||||
return has_new_messages
|
||||
|
||||
def _handle_wakeup_messages(self, messages):
|
||||
"""
|
||||
处理休眠状态下的消息,累积唤醒度
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
功能说明:
|
||||
- 区分私聊和群聊消息
|
||||
- 检查群聊消息是否艾特了机器人
|
||||
- 调用唤醒度管理器累积唤醒度
|
||||
- 如果达到阈值则唤醒并进入愤怒状态
|
||||
"""
|
||||
if not self.wakeup_manager:
|
||||
return
|
||||
|
||||
is_private_chat = self.context.chat_stream.group_info is None if self.context.chat_stream else False
|
||||
|
||||
for message in messages:
|
||||
is_mentioned = False
|
||||
|
||||
# 检查群聊消息是否艾特了机器人
|
||||
if not is_private_chat:
|
||||
# 最终修复:直接使用消息对象中由上游处理好的 is_mention 字段。
|
||||
# 该字段在 message.py 的 MessageRecv._process_single_segment 中被设置。
|
||||
if message.get("is_mentioned"):
|
||||
is_mentioned = True
|
||||
|
||||
# 累积唤醒度
|
||||
woke_up = self.wakeup_manager.add_wakeup_value(is_private_chat, is_mentioned)
|
||||
|
||||
if woke_up:
|
||||
logger.info(f"{self.context.log_prefix} 被消息吵醒,进入愤怒状态!")
|
||||
break
|
||||
|
||||
def _determine_form_type(self) -> str:
|
||||
"""判断使用哪种形式的no_reply"""
|
||||
# 检查是否启用breaking模式
|
||||
if not getattr(global_config.chat, "enable_breaking_mode", False):
|
||||
logger.info(f"{self.context.log_prefix} breaking模式已禁用,使用waiting形式")
|
||||
self.context.focus_energy = 1
|
||||
return "waiting"
|
||||
|
||||
# 如果连续no_reply次数少于3次,使用waiting形式
|
||||
if self.context.no_reply_consecutive <= 3:
|
||||
self.context.focus_energy = 1
|
||||
return "waiting"
|
||||
else:
|
||||
# 使用累积兴趣值而不是最近3次的记录
|
||||
total_interest = self.context.breaking_accumulated_interest
|
||||
|
||||
# 计算调整后的阈值
|
||||
adjusted_threshold = 1 / global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
||||
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 累积兴趣值: {total_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}"
|
||||
)
|
||||
|
||||
# 如果累积兴趣值小于阈值,进入breaking形式
|
||||
if total_interest < adjusted_threshold:
|
||||
logger.info(f"{self.context.log_prefix} 累积兴趣度不足,进入breaking形式")
|
||||
self.context.focus_energy = random.randint(3, 6)
|
||||
return "breaking"
|
||||
else:
|
||||
logger.info(f"{self.context.log_prefix} 累积兴趣度充足,使用waiting形式")
|
||||
self.context.focus_energy = 1
|
||||
return "waiting"
|
||||
|
||||
async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool, float]:
|
||||
"""
|
||||
统一判断是否应该处理消息的函数
|
||||
根据当前循环模式和消息内容决定是否继续处理
|
||||
"""
|
||||
if not new_message:
|
||||
return False, 0.0
|
||||
|
||||
# 计算平均兴趣值
|
||||
total_interest = 0.0
|
||||
message_count = 0
|
||||
for msg_dict in new_message:
|
||||
interest_value = msg_dict.get("interest_value", 0.0)
|
||||
if msg_dict.get("processed_plain_text", ""):
|
||||
total_interest += interest_value
|
||||
message_count += 1
|
||||
|
||||
avg_interest = total_interest / message_count if message_count > 0 else 0.0
|
||||
|
||||
logger.info(f"{self.context.log_prefix} 收到 {len(new_message)} 条新消息,立即处理!平均兴趣值: {avg_interest:.2f}")
|
||||
return True, avg_interest
|
||||
@@ -1,82 +0,0 @@
|
||||
import time
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
from src.chat.chat_loop.hfc_utils import CycleDetail
|
||||
from src.chat.express.expression_learner import ExpressionLearner
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.config.config import global_config
|
||||
from src.person_info.relationship_builder_manager import RelationshipBuilder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class HfcContext:
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
初始化HFC聊天上下文
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID标识符
|
||||
|
||||
功能说明:
|
||||
- 存储和管理单个聊天会话的所有状态信息
|
||||
- 包含聊天流、关系构建器、表达学习器等核心组件
|
||||
- 管理聊天模式、能量值、时间戳等关键状态
|
||||
- 提供循环历史记录和当前循环详情的存储
|
||||
- 集成唤醒度管理器,处理休眠状态下的唤醒机制
|
||||
|
||||
Raises:
|
||||
ValueError: 如果找不到对应的聊天流
|
||||
"""
|
||||
self.stream_id: str = chat_id
|
||||
self.chat_stream: Optional[ChatStream] = get_chat_manager().get_stream(self.stream_id)
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"无法找到聊天流: {self.stream_id}")
|
||||
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
self.relationship_builder: Optional[RelationshipBuilder] = None
|
||||
self.expression_learner: Optional[ExpressionLearner] = None
|
||||
|
||||
self.energy_value = self.chat_stream.energy_value
|
||||
self.sleep_pressure = self.chat_stream.sleep_pressure
|
||||
self.was_sleeping = False # 用于检测睡眠状态的切换
|
||||
|
||||
self.last_message_time = time.time()
|
||||
self.last_read_time = time.time() - 10
|
||||
|
||||
# 从聊天流恢复breaking累积兴趣值
|
||||
self.breaking_accumulated_interest = getattr(self.chat_stream, "breaking_accumulated_interest", 0.0)
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
|
||||
self.running: bool = False
|
||||
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
self.cycle_counter = 0
|
||||
self.current_cycle_detail: Optional[CycleDetail] = None
|
||||
|
||||
# 唤醒度管理器 - 延迟初始化以避免循环导入
|
||||
self.wakeup_manager: Optional["WakeUpManager"] = None
|
||||
self.energy_manager: Optional["EnergyManager"] = None
|
||||
self.sleep_manager: Optional["SleepManager"] = None
|
||||
|
||||
# 从聊天流获取focus_energy,如果没有则使用配置文件中的值
|
||||
self.focus_energy = getattr(self.chat_stream, "focus_energy", global_config.chat.focus_value)
|
||||
self.no_reply_consecutive = 0
|
||||
self.total_interest = 0.0
|
||||
# breaking形式下的累积兴趣值
|
||||
self.breaking_accumulated_interest = 0.0
|
||||
# 引用HeartFChatting实例,以便其他组件可以调用其方法
|
||||
self.chat_instance: "HeartFChatting"
|
||||
|
||||
def save_context_state(self):
|
||||
"""将当前状态保存到聊天流"""
|
||||
if self.chat_stream:
|
||||
self.chat_stream.energy_value = self.energy_value
|
||||
self.chat_stream.sleep_pressure = self.sleep_pressure
|
||||
self.chat_stream.focus_energy = self.focus_energy
|
||||
self.chat_stream.no_reply_consecutive = self.no_reply_consecutive
|
||||
self.chat_stream.breaking_accumulated_interest = self.breaking_accumulated_interest
|
||||
@@ -1,172 +0,0 @@
|
||||
import time
|
||||
from typing import Optional, Dict, Any, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
from maim_message.message_base import GroupInfo
|
||||
|
||||
|
||||
logger = get_logger("hfc")
|
||||
|
||||
|
||||
class CycleDetail:
|
||||
"""
|
||||
循环信息记录类
|
||||
|
||||
功能说明:
|
||||
- 记录单次思考循环的详细信息
|
||||
- 包含循环ID、思考ID、时间戳等基本信息
|
||||
- 存储循环的规划信息和动作信息
|
||||
- 提供序列化和转换功能
|
||||
"""
|
||||
|
||||
def __init__(self, cycle_id: Union[int, str]):
|
||||
"""
|
||||
初始化循环详情记录
|
||||
|
||||
Args:
|
||||
cycle_id: 循环ID,用于标识循环的顺序
|
||||
|
||||
功能说明:
|
||||
- 设置循环基本标识信息
|
||||
- 初始化时间戳和计时器
|
||||
- 准备循环信息存储容器
|
||||
"""
|
||||
self.cycle_id = cycle_id
|
||||
self.thinking_id = ""
|
||||
self.start_time = time.time()
|
||||
self.end_time: Optional[float] = None
|
||||
self.timers: Dict[str, float] = {}
|
||||
|
||||
self.loop_plan_info: Dict[str, Any] = {}
|
||||
self.loop_action_info: Dict[str, Any] = {}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
将循环信息转换为字典格式
|
||||
|
||||
Returns:
|
||||
dict: 包含所有循环信息的字典,已处理循环引用和序列化问题
|
||||
|
||||
功能说明:
|
||||
- 递归转换复杂对象为可序列化格式
|
||||
- 防止循环引用导致的无限递归
|
||||
- 限制递归深度避免栈溢出
|
||||
- 只保留基本数据类型和可序列化的值
|
||||
"""
|
||||
|
||||
def convert_to_serializable(obj, depth=0, seen=None):
|
||||
if seen is None:
|
||||
seen = set()
|
||||
|
||||
# 防止递归过深
|
||||
if depth > 5: # 降低递归深度限制
|
||||
return str(obj)
|
||||
|
||||
# 防止循环引用
|
||||
obj_id = id(obj)
|
||||
if obj_id in seen:
|
||||
return str(obj)
|
||||
seen.add(obj_id)
|
||||
|
||||
try:
|
||||
if hasattr(obj, "to_dict"):
|
||||
# 对于有to_dict方法的对象,直接调用其to_dict方法
|
||||
return obj.to_dict()
|
||||
elif isinstance(obj, dict):
|
||||
# 对于字典,只保留基本类型和可序列化的值
|
||||
return {
|
||||
k: convert_to_serializable(v, depth + 1, seen)
|
||||
for k, v in obj.items()
|
||||
if isinstance(k, (str, int, float, bool))
|
||||
}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
# 对于列表和元组,只保留可序列化的元素
|
||||
return [
|
||||
convert_to_serializable(item, depth + 1, seen)
|
||||
for item in obj
|
||||
if not isinstance(item, (dict, list, tuple))
|
||||
or isinstance(item, (str, int, float, bool, type(None)))
|
||||
]
|
||||
elif isinstance(obj, (str, int, float, bool, type(None))):
|
||||
return obj
|
||||
else:
|
||||
return str(obj)
|
||||
finally:
|
||||
seen.remove(obj_id)
|
||||
|
||||
return {
|
||||
"cycle_id": self.cycle_id,
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"timers": self.timers,
|
||||
"thinking_id": self.thinking_id,
|
||||
"loop_plan_info": convert_to_serializable(self.loop_plan_info),
|
||||
"loop_action_info": convert_to_serializable(self.loop_action_info),
|
||||
}
|
||||
|
||||
def set_loop_info(self, loop_info: Dict[str, Any]):
|
||||
"""
|
||||
设置循环信息
|
||||
|
||||
Args:
|
||||
loop_info: 包含循环规划和动作信息的字典
|
||||
|
||||
功能说明:
|
||||
- 从传入的循环信息中提取规划和动作信息
|
||||
- 更新当前循环详情的相关字段
|
||||
"""
|
||||
self.loop_plan_info = loop_info["loop_plan_info"]
|
||||
self.loop_action_info = loop_info["loop_action_info"]
|
||||
|
||||
|
||||
async def send_typing(user_id):
|
||||
"""
|
||||
发送打字状态指示
|
||||
|
||||
功能说明:
|
||||
- 创建内心聊天流(用于状态显示)
|
||||
- 发送typing状态消息
|
||||
- 不存储到消息记录中
|
||||
- 用于S4U功能的视觉反馈
|
||||
"""
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
from plugin_system.core.event_manager import event_manager
|
||||
from src.plugins.built_in.napcat_adapter_plugin.event_types import NapcatEvent
|
||||
# 设置正在输入状态
|
||||
await event_manager.trigger_event(NapcatEvent.PERSONAL.SET_INPUT_STATUS,user_id=user_id,event_type=1)
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
async def stop_typing():
|
||||
"""
|
||||
停止打字状态指示
|
||||
|
||||
功能说明:
|
||||
- 创建内心聊天流(用于状态显示)
|
||||
- 发送stop_typing状态消息
|
||||
- 不存储到消息记录中
|
||||
- 结束S4U功能的视觉反馈
|
||||
"""
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
@@ -1,14 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProactiveTriggerEvent:
|
||||
"""
|
||||
主动思考触发事件的数据类
|
||||
"""
|
||||
|
||||
source: str # 触发源的标识,例如 "silence_monitor", "insomnia_manager"
|
||||
reason: str # 触发的具体原因,例如 "聊天已沉默10分钟", "深夜emo"
|
||||
metadata: Optional[Dict[str, Any]] = field(default_factory=dict) # 可选的元数据,用于传递额外信息
|
||||
related_message_id: Optional[str] = None # 关联的消息ID,用于加载上下文
|
||||
@@ -1,264 +0,0 @@
|
||||
import time
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Dict, Any
|
||||
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages_with_id
|
||||
from src.common.database.sqlalchemy_database_api import store_action_info
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.plugin_system import tool_api
|
||||
from src.plugin_system.apis import generator_api
|
||||
from src.plugin_system.apis.generator_api import process_human_text
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from .events import ProactiveTriggerEvent
|
||||
from ..hfc_context import HfcContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..cycle_processor import CycleProcessor
|
||||
|
||||
logger = get_logger("hfc")
|
||||
|
||||
|
||||
class ProactiveThinker:
|
||||
"""
|
||||
主动思考器,负责处理和执行主动思考事件。
|
||||
当接收到 ProactiveTriggerEvent 时,它会根据事件内容进行一系列决策和操作,
|
||||
例如调整情绪、调用规划器生成行动,并最终可能产生一个主动的回复。
|
||||
"""
|
||||
|
||||
def __init__(self, context: HfcContext, cycle_processor: "CycleProcessor"):
|
||||
"""
|
||||
初始化主动思考器。
|
||||
|
||||
Args:
|
||||
context (HfcContext): HFC聊天上下文对象,提供了当前聊天会话的所有背景信息。
|
||||
cycle_processor (CycleProcessor): 循环处理器,用于执行主动思考后产生的动作。
|
||||
|
||||
功能说明:
|
||||
- 接收并处理主动思考事件 (ProactiveTriggerEvent)。
|
||||
- 在思考前根据事件类型执行预处理操作,如修改当前情绪状态。
|
||||
- 调用行动规划器 (Action Planner) 来决定下一步应该做什么。
|
||||
- 如果规划结果是发送消息,则调用生成器API生成回复并发送。
|
||||
"""
|
||||
self.context = context
|
||||
self.cycle_processor = cycle_processor
|
||||
|
||||
async def think(self, trigger_event: ProactiveTriggerEvent):
|
||||
"""
|
||||
主动思考的统一入口API。
|
||||
这是外部触发主动思考时调用的主要方法。
|
||||
|
||||
Args:
|
||||
trigger_event (ProactiveTriggerEvent): 描述触发上下文的事件对象,包含了思考的来源和原因。
|
||||
"""
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 接收到主动思考事件: "
|
||||
f"来源='{trigger_event.source}', 原因='{trigger_event.reason}'"
|
||||
)
|
||||
|
||||
try:
|
||||
# 步骤 1: 根据事件类型执行思考前的准备工作,例如调整情绪。
|
||||
await self._prepare_for_thinking(trigger_event)
|
||||
|
||||
# 步骤 2: 执行核心的思考和决策逻辑。
|
||||
await self._execute_proactive_thinking(trigger_event)
|
||||
|
||||
except Exception as e:
|
||||
# 捕获并记录在思考过程中发生的任何异常。
|
||||
logger.error(f"{self.context.log_prefix} 主动思考 think 方法执行异常: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _prepare_for_thinking(self, trigger_event: ProactiveTriggerEvent):
|
||||
"""
|
||||
根据事件类型,在正式思考前执行准备工作。
|
||||
目前主要是处理来自失眠管理器的事件,并据此调整情绪。
|
||||
|
||||
Args:
|
||||
trigger_event (ProactiveTriggerEvent): 触发事件。
|
||||
"""
|
||||
# 目前只处理来自失眠管理器(insomnia_manager)的事件
|
||||
if trigger_event.source != "insomnia_manager":
|
||||
return
|
||||
|
||||
try:
|
||||
# 获取当前聊天的情绪对象
|
||||
mood_obj = mood_manager.get_mood_by_chat_id(self.context.stream_id)
|
||||
new_mood = None
|
||||
|
||||
# 根据失眠的不同原因设置对应的情绪
|
||||
if trigger_event.reason == "low_pressure":
|
||||
new_mood = "精力过剩,毫无睡意"
|
||||
elif trigger_event.reason == "random":
|
||||
new_mood = "深夜emo,胡思乱想"
|
||||
elif trigger_event.reason == "goodnight":
|
||||
new_mood = "有点困了,准备睡觉了"
|
||||
|
||||
# 如果成功匹配到了新的情绪,则更新情绪状态
|
||||
if new_mood:
|
||||
mood_obj.mood_state = new_mood
|
||||
mood_obj.last_change_time = time.time()
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 因 '{trigger_event.reason}',"
|
||||
f"情绪状态被强制更新为: {mood_obj.mood_state}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 设置失眠情绪时出错: {e}")
|
||||
|
||||
async def _execute_proactive_thinking(self, trigger_event: ProactiveTriggerEvent):
|
||||
"""
|
||||
执行主动思考的核心逻辑。
|
||||
它会调用规划器来决定是否要采取行动,以及采取什么行动。
|
||||
|
||||
Args:
|
||||
trigger_event (ProactiveTriggerEvent): 触发事件。
|
||||
"""
|
||||
try:
|
||||
actions, _ = await self.cycle_processor.action_planner.plan(mode=ChatMode.PROACTIVE)
|
||||
action_result = actions[0] if actions else {}
|
||||
action_type = action_result.get("action_type")
|
||||
|
||||
if action_type is None:
|
||||
logger.info(f"{self.context.log_prefix} 主动思考决策: 规划器未返回有效动作")
|
||||
return
|
||||
|
||||
if action_type == "proactive_reply":
|
||||
await self._generate_proactive_content_and_send(action_result, trigger_event)
|
||||
elif action_type not in ["do_nothing", "no_action"]:
|
||||
await self.cycle_processor._handle_action(
|
||||
action=action_result["action_type"],
|
||||
reasoning=action_result.get("reasoning", ""),
|
||||
action_data=action_result.get("action_data", {}),
|
||||
cycle_timers={},
|
||||
thinking_id="",
|
||||
action_message=action_result.get("action_message")
|
||||
)
|
||||
else:
|
||||
logger.info(f"{self.context.log_prefix} 主动思考决策: 保持沉默")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考执行异常: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def _generate_proactive_content_and_send(self, action_result: Dict[str, Any], trigger_event: ProactiveTriggerEvent):
|
||||
"""
|
||||
获取实时信息,构建最终的生成提示词,并生成和发送主动回复。
|
||||
|
||||
Args:
|
||||
action_result (Dict[str, Any]): 规划器返回的动作结果。
|
||||
trigger_event (ProactiveTriggerEvent): 触发事件。
|
||||
"""
|
||||
try:
|
||||
topic = action_result.get("action_data", {}).get("topic", "随便聊聊")
|
||||
logger.info(f"{self.context.log_prefix} 主动思考确定主题: '{topic}'")
|
||||
|
||||
schedule_block = "你今天没有日程安排。"
|
||||
if global_config.planning_system.schedule_enable:
|
||||
if current_activity := schedule_manager.get_current_activity():
|
||||
schedule_block = f"你当前正在:{current_activity}。"
|
||||
|
||||
news_block = "暂时没有获取到最新资讯。"
|
||||
if trigger_event.source != "reminder_system":
|
||||
try:
|
||||
web_search_tool = tool_api.get_tool_instance("web_search")
|
||||
if web_search_tool:
|
||||
try:
|
||||
search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10})
|
||||
except TypeError:
|
||||
try:
|
||||
search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10})
|
||||
except TypeError:
|
||||
logger.warning(f"{self.context.log_prefix} 网络搜索工具参数不匹配,跳过搜索")
|
||||
news_block = "跳过网络搜索。"
|
||||
search_result_dict = None
|
||||
|
||||
if search_result_dict and not search_result_dict.get("error"):
|
||||
news_block = search_result_dict.get("content", "未能提取有效资讯。")
|
||||
elif search_result_dict:
|
||||
logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}")
|
||||
else:
|
||||
logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}")
|
||||
message_list = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.context.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.3),
|
||||
)
|
||||
chat_context_block, _ = await build_readable_messages_with_id(messages=message_list)
|
||||
bot_name = global_config.bot.nickname
|
||||
personality = global_config.personality
|
||||
identity_block = (
|
||||
f"你的名字是{bot_name}。\n"
|
||||
f"关于你:{personality.personality_core},并且{personality.personality_side}。\n"
|
||||
f"你的身份是{personality.identity},平时说话风格是{personality.reply_style}。"
|
||||
)
|
||||
mood_block = f"你现在的心情是:{mood_manager.get_mood_by_chat_id(self.context.stream_id).mood_state}"
|
||||
|
||||
final_prompt = f"""
|
||||
## 你的角色
|
||||
{identity_block}
|
||||
|
||||
## 你的心情
|
||||
{mood_block}
|
||||
|
||||
## 你今天的日程安排
|
||||
{schedule_block}
|
||||
|
||||
## 关于你准备讨论的话题"{topic}"的最新信息
|
||||
{news_block}
|
||||
|
||||
## 最近的聊天内容
|
||||
{chat_context_block}
|
||||
|
||||
## 任务
|
||||
你现在想要主动说些什么。话题是"{topic}",但这只是一个参考方向。
|
||||
|
||||
根据最近的聊天内容,你可以:
|
||||
- 如果是想关心朋友,就自然地询问他们的情况
|
||||
- 如果想起了之前的话题,就问问后来怎么样了
|
||||
- 如果有什么想分享的想法,就自然地开启话题
|
||||
- 如果只是想闲聊,就随意地说些什么
|
||||
|
||||
**重要**:如果获取到了最新的网络信息(news_block不为空),请**自然地**将这些信息融入你的回复中,作为话题的补充或引子,而不是生硬地复述。
|
||||
|
||||
## 要求
|
||||
- 像真正的朋友一样,自然地表达关心或好奇
|
||||
- 不要过于正式,要口语化和亲切
|
||||
- 结合你的角色设定,保持温暖的风格
|
||||
- 直接输出你想说的话,不要解释为什么要说
|
||||
|
||||
请输出一条简短、自然的主动发言。
|
||||
"""
|
||||
|
||||
response_text = await generator_api.generate_response_custom(
|
||||
chat_stream=self.context.chat_stream,
|
||||
prompt=final_prompt,
|
||||
request_type="chat.replyer.proactive",
|
||||
)
|
||||
|
||||
if response_text:
|
||||
response_set = process_human_text(
|
||||
content=response_text,
|
||||
enable_splitter=global_config.response_splitter.enable,
|
||||
enable_chinese_typo=global_config.chinese_typo.enable,
|
||||
)
|
||||
await self.cycle_processor.response_handler.send_response(
|
||||
response_set, time.time(), action_result.get("action_message")
|
||||
)
|
||||
await store_action_info(
|
||||
chat_stream=self.context.chat_stream,
|
||||
action_name="proactive_reply",
|
||||
action_data={"topic": topic, "response": response_text},
|
||||
action_prompt_display=f"主动发起对话: {topic}",
|
||||
action_done=True,
|
||||
)
|
||||
else:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考生成回复失败。")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 生成主动回复内容时异常: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -1,184 +0,0 @@
|
||||
import time
|
||||
import random
|
||||
from typing import Dict, Any, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import send_api, message_api, database_api
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from .hfc_context import HfcContext
|
||||
|
||||
# 导入反注入系统
|
||||
|
||||
# 日志记录器
|
||||
logger = get_logger("hfc")
|
||||
anti_injector_logger = get_logger("anti_injector")
|
||||
|
||||
|
||||
class ResponseHandler:
|
||||
"""
|
||||
响应处理器类,负责生成和发送机器人的回复。
|
||||
"""
|
||||
def __init__(self, context: HfcContext):
|
||||
"""
|
||||
初始化响应处理器
|
||||
|
||||
Args:
|
||||
context: HFC聊天上下文对象
|
||||
|
||||
功能说明:
|
||||
- 负责生成和发送机器人的回复
|
||||
- 处理回复的格式化和发送逻辑
|
||||
- 管理回复状态和日志记录
|
||||
"""
|
||||
self.context = context
|
||||
|
||||
async def generate_and_send_reply(
|
||||
self,
|
||||
response_set,
|
||||
reply_to_str,
|
||||
loop_start_time,
|
||||
action_message,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id,
|
||||
plan_result,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
"""
|
||||
生成并发送回复的主方法
|
||||
|
||||
Args:
|
||||
response_set: 生成的回复内容集合
|
||||
reply_to_str: 回复目标字符串
|
||||
loop_start_time: 循环开始时间
|
||||
action_message: 动作消息数据
|
||||
cycle_timers: 循环计时器
|
||||
thinking_id: 思考ID
|
||||
plan_result: 规划结果
|
||||
|
||||
Returns:
|
||||
tuple: (循环信息, 回复文本, 计时器信息)
|
||||
|
||||
功能说明:
|
||||
- 发送生成的回复内容
|
||||
- 存储动作信息到数据库
|
||||
- 构建并返回完整的循环信息
|
||||
- 用于上级方法的状态跟踪
|
||||
"""
|
||||
reply_text = await self.send_response(response_set, loop_start_time, action_message)
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
# 获取平台信息
|
||||
platform = "default"
|
||||
if self.context.chat_stream:
|
||||
platform = (
|
||||
action_message.get("chat_info_platform")
|
||||
or action_message.get("user_platform")
|
||||
or self.context.chat_stream.platform
|
||||
)
|
||||
|
||||
# 获取用户信息并生成回复提示
|
||||
user_id = action_message.get("user_id", "")
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
# 存储动作信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.context.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=action_prompt_display,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": reply_text, "reply_to": reply_to_str},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": plan_result.get("action_result", {}),
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": reply_text,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def send_response(self, reply_set, thinking_start_time, message_data) -> str:
|
||||
"""
|
||||
发送回复内容的具体实现
|
||||
|
||||
Args:
|
||||
reply_set: 回复内容集合,包含多个回复段
|
||||
reply_to: 回复目标
|
||||
thinking_start_time: 思考开始时间
|
||||
message_data: 消息数据
|
||||
|
||||
Returns:
|
||||
str: 完整的回复文本
|
||||
|
||||
功能说明:
|
||||
- 检查是否有新消息需要回复
|
||||
- 处理主动思考的"沉默"决定
|
||||
- 根据消息数量决定是否添加回复引用
|
||||
- 逐段发送回复内容,支持打字效果
|
||||
- 正确处理元组格式的回复段
|
||||
"""
|
||||
current_time = time.time()
|
||||
# 计算新消息数量
|
||||
new_message_count = await message_api.count_new_messages(
|
||||
chat_id=self.context.stream_id, start_time=thinking_start_time, end_time=current_time
|
||||
)
|
||||
|
||||
# 根据新消息数量决定是否需要引用回复
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
|
||||
reply_text = ""
|
||||
is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True
|
||||
|
||||
first_replied = False
|
||||
for reply_seg in reply_set:
|
||||
# 调试日志:验证reply_seg的格式
|
||||
logger.debug(f"Processing reply_seg type: {type(reply_seg)}, content: {reply_seg}")
|
||||
|
||||
# 修正:正确处理元组格式 (格式为: (type, content))
|
||||
if isinstance(reply_seg, tuple) and len(reply_seg) >= 2:
|
||||
_, data = reply_seg
|
||||
else:
|
||||
# 向下兼容:如果已经是字符串,则直接使用
|
||||
data = str(reply_seg)
|
||||
|
||||
if isinstance(data, list):
|
||||
data = "".join(map(str, data))
|
||||
reply_text += data
|
||||
|
||||
# 如果是主动思考且内容为“沉默”,则不发送
|
||||
if is_proactive_thinking and data.strip() == "沉默":
|
||||
logger.info(f"{self.context.log_prefix} 主动思考决定保持沉默,不发送消息")
|
||||
continue
|
||||
|
||||
# 发送第一段回复
|
||||
if not first_replied:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.context.stream_id,
|
||||
reply_to_message=message_data,
|
||||
set_reply=need_reply,
|
||||
typing=False,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
# 发送后续回复
|
||||
sent_message = await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.context.stream_id,
|
||||
reply_to_message=None,
|
||||
set_reply=False,
|
||||
typing=True,
|
||||
)
|
||||
|
||||
return reply_text
|
||||
@@ -1,32 +0,0 @@
|
||||
from src.common.logger import get_logger
|
||||
from ..hfc_context import HfcContext
|
||||
|
||||
logger = get_logger("notification_sender")
|
||||
|
||||
|
||||
class NotificationSender:
|
||||
@staticmethod
|
||||
async def send_goodnight_notification(context: HfcContext):
|
||||
"""发送晚安通知"""
|
||||
try:
|
||||
from ..proactive.events import ProactiveTriggerEvent
|
||||
from ..proactive.proactive_thinker import ProactiveThinker
|
||||
|
||||
event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight")
|
||||
proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
|
||||
await proactive_thinker.think(event)
|
||||
except Exception as e:
|
||||
logger.error(f"发送晚安通知失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def send_insomnia_notification(context: HfcContext, reason: str):
|
||||
"""发送失眠通知"""
|
||||
try:
|
||||
from ..proactive.events import ProactiveTriggerEvent
|
||||
from ..proactive.proactive_thinker import ProactiveThinker
|
||||
|
||||
event = ProactiveTriggerEvent(source="sleep_manager", reason=reason)
|
||||
proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
|
||||
await proactive_thinker.think(event)
|
||||
except Exception as e:
|
||||
logger.error(f"发送失眠通知失败: {e}")
|
||||
@@ -1,110 +0,0 @@
|
||||
from enum import Enum, auto
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("sleep_state")
|
||||
|
||||
|
||||
class SleepState(Enum):
|
||||
"""
|
||||
定义了角色可能处于的几种睡眠状态。
|
||||
这是一个状态机,用于管理角色的睡眠周期。
|
||||
"""
|
||||
|
||||
AWAKE = auto() # 清醒状态
|
||||
INSOMNIA = auto() # 失眠状态
|
||||
PREPARING_SLEEP = auto() # 准备入睡状态,一个短暂的过渡期
|
||||
SLEEPING = auto() # 正在睡觉状态
|
||||
WOKEN_UP = auto() # 被吵醒状态
|
||||
|
||||
|
||||
class SleepStateSerializer:
|
||||
"""
|
||||
睡眠状态序列化器。
|
||||
负责将内存中的睡眠状态对象持久化到本地存储(如JSON文件),
|
||||
以及在程序启动时从本地存储中恢复状态。
|
||||
这样可以确保即使程序重启,角色的睡眠状态也能得以保留。
|
||||
"""
|
||||
@staticmethod
|
||||
def save(state_data: dict):
|
||||
"""
|
||||
将当前的睡眠状态数据保存到本地存储。
|
||||
|
||||
Args:
|
||||
state_data (dict): 包含睡眠状态信息的字典。
|
||||
datetime对象会被转换为时间戳,Enum成员会被转换为其名称字符串。
|
||||
"""
|
||||
try:
|
||||
# 准备要序列化的数据字典
|
||||
state = {
|
||||
# 保存当前状态的枚举名称
|
||||
"current_state": state_data["_current_state"].name,
|
||||
# 将datetime对象转换为Unix时间戳以便序列化
|
||||
"sleep_buffer_end_time_ts": state_data["_sleep_buffer_end_time"].timestamp()
|
||||
if state_data["_sleep_buffer_end_time"]
|
||||
else None,
|
||||
"total_delayed_minutes_today": state_data["_total_delayed_minutes_today"],
|
||||
# 将date对象转换为ISO格式的字符串
|
||||
"last_sleep_check_date_str": state_data["_last_sleep_check_date"].isoformat()
|
||||
if state_data["_last_sleep_check_date"]
|
||||
else None,
|
||||
"re_sleep_attempt_time_ts": state_data["_re_sleep_attempt_time"].timestamp()
|
||||
if state_data["_re_sleep_attempt_time"]
|
||||
else None,
|
||||
}
|
||||
# 写入本地存储
|
||||
local_storage["schedule_sleep_state"] = state
|
||||
logger.debug(f"已保存睡眠状态: {state}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存睡眠状态失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def load() -> dict:
|
||||
"""
|
||||
从本地存储加载并解析睡眠状态。
|
||||
|
||||
Returns:
|
||||
dict: 包含恢复后睡眠状态信息的字典。
|
||||
如果加载失败或没有找到数据,则返回一个默认的清醒状态。
|
||||
"""
|
||||
# 定义一个默认的状态,以防加载失败
|
||||
state_data = {
|
||||
"_current_state": SleepState.AWAKE,
|
||||
"_sleep_buffer_end_time": None,
|
||||
"_total_delayed_minutes_today": 0,
|
||||
"_last_sleep_check_date": None,
|
||||
"_re_sleep_attempt_time": None,
|
||||
}
|
||||
try:
|
||||
# 从本地存储读取数据
|
||||
state = local_storage["schedule_sleep_state"]
|
||||
if state and isinstance(state, dict):
|
||||
# 恢复当前状态枚举
|
||||
state_name = state.get("current_state")
|
||||
if state_name and hasattr(SleepState, state_name):
|
||||
state_data["_current_state"] = SleepState[state_name]
|
||||
|
||||
# 从时间戳恢复datetime对象
|
||||
end_time_ts = state.get("sleep_buffer_end_time_ts")
|
||||
if end_time_ts:
|
||||
state_data["_sleep_buffer_end_time"] = datetime.fromtimestamp(end_time_ts)
|
||||
|
||||
# 恢复重新入睡尝试时间
|
||||
re_sleep_ts = state.get("re_sleep_attempt_time_ts")
|
||||
if re_sleep_ts:
|
||||
state_data["_re_sleep_attempt_time"] = datetime.fromtimestamp(re_sleep_ts)
|
||||
|
||||
# 恢复今日延迟睡眠总分钟数
|
||||
state_data["_total_delayed_minutes_today"] = state.get("total_delayed_minutes_today", 0)
|
||||
|
||||
# 从ISO格式字符串恢复date对象
|
||||
date_str = state.get("last_sleep_check_date_str")
|
||||
if date_str:
|
||||
state_data["_last_sleep_check_date"] = datetime.fromisoformat(date_str).date()
|
||||
|
||||
logger.info(f"成功从本地存储加载睡眠状态: {state}")
|
||||
except Exception as e:
|
||||
# 如果加载过程中出现任何问题,记录警告并返回默认状态
|
||||
logger.warning(f"加载睡眠状态失败,将使用默认值: {e}")
|
||||
return state_data
|
||||
@@ -1,232 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from ..hfc_context import HfcContext
|
||||
|
||||
logger = get_logger("wakeup")
|
||||
|
||||
|
||||
class WakeUpManager:
|
||||
def __init__(self, context: HfcContext):
|
||||
"""
|
||||
初始化唤醒度管理器
|
||||
|
||||
Args:
|
||||
context: HFC聊天上下文对象
|
||||
|
||||
功能说明:
|
||||
- 管理休眠状态下的唤醒度累积
|
||||
- 处理唤醒度的自然衰减
|
||||
- 控制愤怒状态的持续时间
|
||||
"""
|
||||
self.context = context
|
||||
self.wakeup_value = 0.0 # 当前唤醒度
|
||||
self.is_angry = False # 是否处于愤怒状态
|
||||
self.angry_start_time = 0.0 # 愤怒状态开始时间
|
||||
self.last_decay_time = time.time() # 上次衰减时间
|
||||
self._decay_task: Optional[asyncio.Task] = None
|
||||
self.last_log_time = 0
|
||||
self.log_interval = 30
|
||||
|
||||
# 从配置文件获取参数
|
||||
sleep_config = global_config.sleep_system
|
||||
self.wakeup_threshold = sleep_config.wakeup_threshold
|
||||
self.private_message_increment = sleep_config.private_message_increment
|
||||
self.group_mention_increment = sleep_config.group_mention_increment
|
||||
self.decay_rate = sleep_config.decay_rate
|
||||
self.decay_interval = sleep_config.decay_interval
|
||||
self.angry_duration = sleep_config.angry_duration
|
||||
self.enabled = sleep_config.enable
|
||||
self.angry_prompt = sleep_config.angry_prompt
|
||||
|
||||
self._load_wakeup_state()
|
||||
|
||||
def _get_storage_key(self) -> str:
|
||||
"""获取当前聊天流的本地存储键"""
|
||||
return f"wakeup_manager_state_{self.context.stream_id}"
|
||||
|
||||
def _load_wakeup_state(self):
|
||||
"""从本地存储加载状态"""
|
||||
state = local_storage[self._get_storage_key()]
|
||||
if state and isinstance(state, dict):
|
||||
self.wakeup_value = state.get("wakeup_value", 0.0)
|
||||
self.is_angry = state.get("is_angry", False)
|
||||
self.angry_start_time = state.get("angry_start_time", 0.0)
|
||||
logger.info(f"{self.context.log_prefix} 成功从本地存储加载唤醒状态: {state}")
|
||||
else:
|
||||
logger.info(f"{self.context.log_prefix} 未找到本地唤醒状态,将使用默认值初始化。")
|
||||
|
||||
def _save_wakeup_state(self):
|
||||
"""将当前状态保存到本地存储"""
|
||||
state = {
|
||||
"wakeup_value": self.wakeup_value,
|
||||
"is_angry": self.is_angry,
|
||||
"angry_start_time": self.angry_start_time,
|
||||
}
|
||||
local_storage[self._get_storage_key()] = state
|
||||
logger.debug(f"{self.context.log_prefix} 已将唤醒状态保存到本地存储: {state}")
|
||||
|
||||
async def start(self):
|
||||
"""启动唤醒度管理器"""
|
||||
if not self.enabled:
|
||||
logger.info(f"{self.context.log_prefix} 唤醒度系统已禁用,跳过启动")
|
||||
return
|
||||
|
||||
if not self._decay_task:
|
||||
self._decay_task = asyncio.create_task(self._decay_loop())
|
||||
self._decay_task.add_done_callback(self._handle_decay_completion)
|
||||
logger.info(f"{self.context.log_prefix} 唤醒度管理器已启动")
|
||||
|
||||
async def stop(self):
|
||||
"""停止唤醒度管理器"""
|
||||
if self._decay_task and not self._decay_task.done():
|
||||
self._decay_task.cancel()
|
||||
await asyncio.sleep(0)
|
||||
logger.info(f"{self.context.log_prefix} 唤醒度管理器已停止")
|
||||
|
||||
def _handle_decay_completion(self, task: asyncio.Task):
|
||||
"""处理衰减任务完成"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.context.log_prefix} 唤醒度衰减任务异常: {exception}")
|
||||
else:
|
||||
logger.info(f"{self.context.log_prefix} 唤醒度衰减任务正常结束")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.context.log_prefix} 唤醒度衰减任务被取消")
|
||||
|
||||
async def _decay_loop(self):
|
||||
"""唤醒度衰减循环"""
|
||||
while self.context.running:
|
||||
await asyncio.sleep(self.decay_interval)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 检查愤怒状态是否过期
|
||||
if self.is_angry and current_time - self.angry_start_time >= self.angry_duration:
|
||||
self.is_angry = False
|
||||
# 通知情绪管理系统清除愤怒状态
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
mood_manager.clear_angry_from_wakeup(self.context.stream_id)
|
||||
logger.info(f"{self.context.log_prefix} 愤怒状态结束,恢复正常")
|
||||
self._save_wakeup_state()
|
||||
|
||||
# 唤醒度自然衰减
|
||||
if self.wakeup_value > 0:
|
||||
old_value = self.wakeup_value
|
||||
self.wakeup_value = max(0, self.wakeup_value - self.decay_rate)
|
||||
if old_value != self.wakeup_value:
|
||||
logger.debug(f"{self.context.log_prefix} 唤醒度衰减: {old_value:.1f} -> {self.wakeup_value:.1f}")
|
||||
self._save_wakeup_state()
|
||||
|
||||
def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False) -> bool:
|
||||
"""
|
||||
增加唤醒度值
|
||||
|
||||
Args:
|
||||
is_private_chat: 是否为私聊
|
||||
is_mentioned: 是否被艾特(仅群聊有效)
|
||||
|
||||
Returns:
|
||||
bool: 是否达到唤醒阈值
|
||||
"""
|
||||
# 如果系统未启用,直接返回
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
# 只有在休眠且非失眠状态下才累积唤醒度
|
||||
from .sleep_state import SleepState
|
||||
|
||||
sleep_manager = self.context.sleep_manager
|
||||
if not sleep_manager:
|
||||
return False
|
||||
|
||||
current_sleep_state = sleep_manager.get_current_sleep_state()
|
||||
if current_sleep_state != SleepState.SLEEPING:
|
||||
return False
|
||||
|
||||
old_value = self.wakeup_value
|
||||
|
||||
if is_private_chat:
|
||||
# 私聊每条消息都增加唤醒度
|
||||
self.wakeup_value += self.private_message_increment
|
||||
logger.debug(f"{self.context.log_prefix} 私聊消息增加唤醒度: +{self.private_message_increment}")
|
||||
elif is_mentioned:
|
||||
# 群聊只有被艾特才增加唤醒度
|
||||
self.wakeup_value += self.group_mention_increment
|
||||
logger.debug(f"{self.context.log_prefix} 群聊艾特增加唤醒度: +{self.group_mention_increment}")
|
||||
else:
|
||||
# 群聊未被艾特,不增加唤醒度
|
||||
return False
|
||||
|
||||
current_time = time.time()
|
||||
if current_time - self.last_log_time > self.log_interval:
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 唤醒度变化: {old_value:.1f} -> {self.wakeup_value:.1f} (阈值: {self.wakeup_threshold})"
|
||||
)
|
||||
self.last_log_time = current_time
|
||||
else:
|
||||
logger.debug(
|
||||
f"{self.context.log_prefix} 唤醒度变化: {old_value:.1f} -> {self.wakeup_value:.1f} (阈值: {self.wakeup_threshold})"
|
||||
)
|
||||
|
||||
# 检查是否达到唤醒阈值
|
||||
if self.wakeup_value >= self.wakeup_threshold:
|
||||
self._trigger_wakeup()
|
||||
return True
|
||||
|
||||
self._save_wakeup_state()
|
||||
return False
|
||||
|
||||
def _trigger_wakeup(self):
|
||||
"""触发唤醒,进入愤怒状态"""
|
||||
self.is_angry = True
|
||||
self.angry_start_time = time.time()
|
||||
self.wakeup_value = 0.0 # 重置唤醒度
|
||||
|
||||
self._save_wakeup_state()
|
||||
|
||||
# 通知情绪管理系统进入愤怒状态
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
mood_manager.set_angry_from_wakeup(self.context.stream_id)
|
||||
|
||||
# 通知SleepManager重置睡眠状态
|
||||
if self.context.sleep_manager:
|
||||
self.context.sleep_manager.reset_sleep_state_after_wakeup()
|
||||
|
||||
logger.info(f"{self.context.log_prefix} 唤醒度达到阈值({self.wakeup_threshold}),被吵醒进入愤怒状态!")
|
||||
|
||||
def get_angry_prompt_addition(self) -> str:
|
||||
"""获取愤怒状态下的提示词补充"""
|
||||
if self.is_angry:
|
||||
return self.angry_prompt
|
||||
return ""
|
||||
|
||||
def is_in_angry_state(self) -> bool:
|
||||
"""检查是否处于愤怒状态"""
|
||||
if self.is_angry:
|
||||
current_time = time.time()
|
||||
if current_time - self.angry_start_time >= self.angry_duration:
|
||||
self.is_angry = False
|
||||
# 通知情绪管理系统清除愤怒状态
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
mood_manager.clear_angry_from_wakeup(self.context.stream_id)
|
||||
logger.info(f"{self.context.log_prefix} 愤怒状态自动过期")
|
||||
return False
|
||||
return self.is_angry
|
||||
|
||||
def get_status_info(self) -> dict:
|
||||
"""获取当前状态信息"""
|
||||
return {
|
||||
"wakeup_value": self.wakeup_value,
|
||||
"wakeup_threshold": self.wakeup_threshold,
|
||||
"is_angry": self.is_angry,
|
||||
"angry_remaining_time": max(0, self.angry_duration - (time.time() - self.angry_start_time))
|
||||
if self.is_angry
|
||||
else 0,
|
||||
}
|
||||
145
src/chat/chatter_manager.py
Normal file
145
src/chat/chatter_manager.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from typing import Dict, List, Optional, Any
|
||||
import time
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.plugin_system.base.component_types import ChatType, ComponentType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("chatter_manager")
|
||||
|
||||
class ChatterManager:
|
||||
def __init__(self, action_manager: ChatterActionManager):
|
||||
self.action_manager = action_manager
|
||||
self.chatter_classes: Dict[ChatType, List[type]] = {}
|
||||
self.instances: Dict[str, BaseChatter] = {}
|
||||
|
||||
# 管理器统计
|
||||
self.stats = {
|
||||
"chatters_registered": 0,
|
||||
"streams_processed": 0,
|
||||
"successful_executions": 0,
|
||||
"failed_executions": 0,
|
||||
}
|
||||
|
||||
def _auto_register_from_component_registry(self):
|
||||
"""从组件注册表自动注册已注册的chatter组件"""
|
||||
try:
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
# 获取所有CHATTER类型的组件
|
||||
chatter_components = component_registry.get_enabled_chatter_registry()
|
||||
for chatter_name, chatter_class in chatter_components.items():
|
||||
self.register_chatter(chatter_class)
|
||||
logger.info(f"自动注册chatter组件: {chatter_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"自动注册chatter组件时发生错误: {e}")
|
||||
|
||||
def register_chatter(self, chatter_class: type):
|
||||
"""注册聊天处理器类"""
|
||||
for chat_type in chatter_class.chat_types:
|
||||
if chat_type not in self.chatter_classes:
|
||||
self.chatter_classes[chat_type] = []
|
||||
self.chatter_classes[chat_type].append(chatter_class)
|
||||
logger.info(f"注册聊天处理器 {chatter_class.__name__} 支持 {chat_type.value} 聊天类型")
|
||||
|
||||
self.stats["chatters_registered"] += 1
|
||||
|
||||
def get_chatter_class(self, chat_type: ChatType) -> Optional[type]:
|
||||
"""获取指定聊天类型的聊天处理器类"""
|
||||
if chat_type in self.chatter_classes:
|
||||
return self.chatter_classes[chat_type][0]
|
||||
return None
|
||||
|
||||
def get_supported_chat_types(self) -> List[ChatType]:
|
||||
"""获取支持的聊天类型列表"""
|
||||
return list(self.chatter_classes.keys())
|
||||
|
||||
def get_registered_chatters(self) -> Dict[ChatType, List[type]]:
|
||||
"""获取已注册的聊天处理器"""
|
||||
return self.chatter_classes.copy()
|
||||
|
||||
def get_stream_instance(self, stream_id: str) -> Optional[BaseChatter]:
|
||||
"""获取指定流的聊天处理器实例"""
|
||||
return self.instances.get(stream_id)
|
||||
|
||||
def cleanup_inactive_instances(self, max_inactive_minutes: int = 60):
|
||||
"""清理不活跃的实例"""
|
||||
current_time = time.time()
|
||||
max_inactive_seconds = max_inactive_minutes * 60
|
||||
|
||||
inactive_streams = []
|
||||
for stream_id, instance in self.instances.items():
|
||||
if hasattr(instance, 'get_activity_time'):
|
||||
activity_time = instance.get_activity_time()
|
||||
if (current_time - activity_time) > max_inactive_seconds:
|
||||
inactive_streams.append(stream_id)
|
||||
|
||||
for stream_id in inactive_streams:
|
||||
del self.instances[stream_id]
|
||||
logger.info(f"清理不活跃聊天流实例: {stream_id}")
|
||||
|
||||
async def process_stream_context(self, stream_id: str, context: StreamContext) -> dict:
|
||||
"""处理流上下文"""
|
||||
chat_type = context.chat_type
|
||||
logger.debug(f"处理流 {stream_id},聊天类型: {chat_type.value}")
|
||||
if not self.chatter_classes:
|
||||
self._auto_register_from_component_registry()
|
||||
|
||||
# 获取适合该聊天类型的chatter
|
||||
chatter_class = self.get_chatter_class(chat_type)
|
||||
if not chatter_class:
|
||||
# 如果没有找到精确匹配,尝试查找支持ALL类型的chatter
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
all_chatter_class = self.get_chatter_class(ChatType.ALL)
|
||||
if all_chatter_class:
|
||||
chatter_class = all_chatter_class
|
||||
logger.info(f"流 {stream_id} 使用通用chatter (类型: {chat_type.value})")
|
||||
else:
|
||||
raise ValueError(f"No chatter registered for chat type {chat_type}")
|
||||
|
||||
if stream_id not in self.instances:
|
||||
self.instances[stream_id] = chatter_class(stream_id=stream_id, action_manager=self.action_manager)
|
||||
logger.info(f"创建新的聊天流实例: {stream_id} 使用 {chatter_class.__name__} (类型: {chat_type.value})")
|
||||
|
||||
self.stats["streams_processed"] += 1
|
||||
try:
|
||||
result = await self.instances[stream_id].execute(context)
|
||||
self.stats["successful_executions"] += 1
|
||||
|
||||
# 从 mood_manager 获取最新的 chat_stream 并同步回 StreamContext
|
||||
try:
|
||||
from src.mood.mood_manager import mood_manager
|
||||
mood = mood_manager.get_mood_by_chat_id(stream_id)
|
||||
if mood and mood.chat_stream:
|
||||
context.chat_stream = mood.chat_stream
|
||||
logger.debug(f"已将最新的 chat_stream 同步回流 {stream_id} 的 StreamContext")
|
||||
except Exception as sync_e:
|
||||
logger.error(f"同步 chat_stream 回 StreamContext 失败: {sync_e}")
|
||||
|
||||
# 记录处理结果
|
||||
success = result.get("success", False)
|
||||
actions_count = result.get("actions_count", 0)
|
||||
logger.debug(f"流 {stream_id} 处理完成: 成功={success}, 动作数={actions_count}")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
self.stats["failed_executions"] += 1
|
||||
logger.error(f"处理流 {stream_id} 时发生错误: {e}")
|
||||
raise
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取管理器统计信息"""
|
||||
stats = self.stats.copy()
|
||||
stats["active_instances"] = len(self.instances)
|
||||
stats["registered_chatter_types"] = len(self.chatter_classes)
|
||||
return stats
|
||||
|
||||
def reset_stats(self):
|
||||
"""重置统计信息"""
|
||||
self.stats = {
|
||||
"chatters_registered": 0,
|
||||
"streams_processed": 0,
|
||||
"successful_executions": 0,
|
||||
"failed_executions": 0,
|
||||
}
|
||||
@@ -2,8 +2,10 @@
|
||||
"""
|
||||
表情包发送历史记录模块
|
||||
"""
|
||||
from collections import deque
|
||||
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from collections import deque
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -25,15 +27,15 @@ def add_emoji_to_history(chat_id: str, emoji_description: str):
|
||||
"""
|
||||
if not chat_id or not emoji_description:
|
||||
return
|
||||
|
||||
|
||||
# 如果当前聊天还没有历史记录,则创建一个新的 deque
|
||||
if chat_id not in _history_cache:
|
||||
_history_cache[chat_id] = deque(maxlen=MAX_HISTORY_SIZE)
|
||||
|
||||
|
||||
# 添加新表情到历史记录
|
||||
history = _history_cache[chat_id]
|
||||
history.append(emoji_description)
|
||||
|
||||
|
||||
logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中")
|
||||
|
||||
|
||||
@@ -49,10 +51,10 @@ def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]:
|
||||
return []
|
||||
|
||||
history = _history_cache[chat_id]
|
||||
|
||||
|
||||
# 从 deque 的右侧(即最近添加的)开始取
|
||||
num_to_get = min(limit, len(history))
|
||||
recent_emojis = [history[-i] for i in range(1, num_to_get + 1)]
|
||||
|
||||
|
||||
logger.debug(f"为聊天 {chat_id} 从内存中获取到最近 {len(recent_emojis)} 个表情: {recent_emojis}")
|
||||
return recent_emojis
|
||||
|
||||
@@ -149,7 +149,7 @@ class MaiEmoji:
|
||||
# --- 数据库操作 ---
|
||||
try:
|
||||
# 准备数据库记录 for emoji collection
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
emotion_str = ",".join(self.emotion) if self.emotion else ""
|
||||
|
||||
emoji = Emoji(
|
||||
@@ -167,7 +167,7 @@ class MaiEmoji:
|
||||
last_used_time=self.last_used_time,
|
||||
)
|
||||
session.add(emoji)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
|
||||
|
||||
@@ -203,17 +203,17 @@ class MaiEmoji:
|
||||
|
||||
# 2. 删除数据库记录
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
will_delete_emoji = (
|
||||
await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash))
|
||||
with get_db_session() as session:
|
||||
will_delete_emoji = session.execute(
|
||||
select(Emoji).where(Emoji.emoji_hash == self.hash)
|
||||
).scalar_one_or_none()
|
||||
if will_delete_emoji is None:
|
||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||
result = 0
|
||||
result = 0 # Indicate no DB record was deleted
|
||||
else:
|
||||
await session.delete(will_delete_emoji)
|
||||
result = 1
|
||||
await session.commit()
|
||||
session.delete(will_delete_emoji)
|
||||
result = 1 # Successfully deleted one record
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
|
||||
result = 0
|
||||
@@ -424,19 +424,17 @@ class EmojiManager:
|
||||
# if not self._initialized:
|
||||
# raise RuntimeError("EmojiManager not initialized")
|
||||
|
||||
@staticmethod
|
||||
async def record_usage(emoji_hash: str) -> None:
|
||||
def record_usage(self, emoji_hash: str) -> None:
|
||||
"""记录表情使用次数"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
emoji_update = (
|
||||
await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
|
||||
).scalar_one_or_none()
|
||||
with get_db_session() as session:
|
||||
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
|
||||
if emoji_update is None:
|
||||
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
||||
else:
|
||||
emoji_update.usage_count += 1
|
||||
emoji_update.last_used_time = time.time()
|
||||
emoji_update.last_used_time = time.time() # Update last used time
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"记录表情使用失败: {str(e)}")
|
||||
|
||||
@@ -479,7 +477,7 @@ class EmojiManager:
|
||||
emoji_options_str = ""
|
||||
for i, emoji in enumerate(candidate_emojis):
|
||||
# 为每个表情包创建一个编号和它的详细描述
|
||||
emoji_options_str += f"编号: {i+1}\n描述: {emoji.description}\n\n"
|
||||
emoji_options_str += f"编号: {i + 1}\n描述: {emoji.description}\n\n"
|
||||
|
||||
# 精心设计的prompt,引导LLM做出选择
|
||||
prompt = f"""
|
||||
@@ -523,13 +521,11 @@ class EmojiManager:
|
||||
|
||||
# 7. 获取选中的表情包并更新使用记录
|
||||
selected_emoji = candidate_emojis[selected_index]
|
||||
await self.record_usage(selected_emoji.emoji_hash)
|
||||
self.record_usage(selected_emoji.hash)
|
||||
_time_end = time.time()
|
||||
|
||||
logger.info(
|
||||
f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s"
|
||||
)
|
||||
|
||||
logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s")
|
||||
|
||||
# 8. 返回选中的表情包信息
|
||||
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
|
||||
|
||||
@@ -629,8 +625,9 @@ class EmojiManager:
|
||||
|
||||
# 无论steal_emoji是否开启,都检查emoji文件夹以支持手动注册
|
||||
# 只有在需要腾出空间或填充表情库时,才真正执行注册
|
||||
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or \
|
||||
(self.emoji_num < self.emoji_num_max):
|
||||
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
|
||||
self.emoji_num < self.emoji_num_max
|
||||
):
|
||||
try:
|
||||
# 获取目录下所有图片文件
|
||||
files_to_process = [
|
||||
@@ -660,11 +657,10 @@ class EmojiManager:
|
||||
async def get_all_emoji_from_db(self) -> None:
|
||||
"""获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
logger.debug("[数据库] 开始加载所有表情包记录 ...")
|
||||
|
||||
result = await session.execute(select(Emoji))
|
||||
emoji_instances = result.scalars().all()
|
||||
emoji_instances = session.execute(select(Emoji)).scalars().all()
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
|
||||
# 更新内存中的列表和数量
|
||||
@@ -680,8 +676,7 @@ class EmojiManager:
|
||||
self.emoji_objects = [] # 加载失败则清空列表
|
||||
self.emoji_num = 0
|
||||
|
||||
@staticmethod
|
||||
async def get_emoji_from_db(emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
|
||||
async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
|
||||
"""获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)
|
||||
|
||||
参数:
|
||||
@@ -691,16 +686,14 @@ class EmojiManager:
|
||||
list[MaiEmoji]: 表情包对象列表
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
if emoji_hash:
|
||||
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
|
||||
query = result.scalars().all()
|
||||
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
|
||||
else:
|
||||
logger.warning(
|
||||
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
|
||||
)
|
||||
result = await session.execute(select(Emoji))
|
||||
query = result.scalars().all()
|
||||
query = session.execute(select(Emoji)).scalars().all()
|
||||
|
||||
emoji_instances = query
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
@@ -748,8 +741,8 @@ class EmojiManager:
|
||||
try:
|
||||
emoji_record = await self.get_emoji_from_db(emoji_hash)
|
||||
if emoji_record and emoji_record[0].emotion:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record[0].emotion[:50]}...")
|
||||
return emoji_record[0].emotion
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...")
|
||||
return emoji_record.emotion
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库查询表情包描述时出错: {e}")
|
||||
|
||||
@@ -777,11 +770,10 @@ class EmojiManager:
|
||||
|
||||
# 如果内存中没有,从数据库查找
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(
|
||||
with get_db_session() as session:
|
||||
emoji_record = session.execute(
|
||||
select(Emoji).where(Emoji.emoji_hash == emoji_hash)
|
||||
)
|
||||
emoji_record = result.scalar_one_or_none()
|
||||
).scalar_one_or_none()
|
||||
if emoji_record and emoji_record.description:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
|
||||
return emoji_record.description
|
||||
@@ -938,19 +930,21 @@ class EmojiManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() if Image.open(io.BytesIO(image_bytes)).format else "jpeg"
|
||||
|
||||
image_format = (
|
||||
Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
if Image.open(io.BytesIO(image_bytes)).format
|
||||
else "jpeg"
|
||||
)
|
||||
|
||||
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
|
||||
existing_description = None
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(
|
||||
select(Images).filter(
|
||||
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
|
||||
)
|
||||
with get_db_session() as session:
|
||||
existing_image = (
|
||||
session.query(Images)
|
||||
.filter((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
.one_or_none()
|
||||
)
|
||||
existing_image = result.scalar_one_or_none()
|
||||
if existing_image and existing_image.description:
|
||||
existing_description = existing_image.description
|
||||
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
||||
|
||||
28
src/chat/energy_system/__init__.py
Normal file
28
src/chat/energy_system/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
能量系统模块
|
||||
提供稳定、高效的聊天流能量计算和管理功能
|
||||
"""
|
||||
|
||||
from .energy_manager import (
|
||||
EnergyManager,
|
||||
EnergyLevel,
|
||||
EnergyComponent,
|
||||
EnergyCalculator,
|
||||
InterestEnergyCalculator,
|
||||
ActivityEnergyCalculator,
|
||||
RecencyEnergyCalculator,
|
||||
RelationshipEnergyCalculator,
|
||||
energy_manager
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EnergyManager",
|
||||
"EnergyLevel",
|
||||
"EnergyComponent",
|
||||
"EnergyCalculator",
|
||||
"InterestEnergyCalculator",
|
||||
"ActivityEnergyCalculator",
|
||||
"RecencyEnergyCalculator",
|
||||
"RelationshipEnergyCalculator",
|
||||
"energy_manager"
|
||||
]
|
||||
473
src/chat/energy_system/energy_manager.py
Normal file
473
src/chat/energy_system/energy_manager.py
Normal file
@@ -0,0 +1,473 @@
|
||||
"""
|
||||
重构后的 focus_energy 管理系统
|
||||
提供稳定、高效的聊天流能量计算和管理功能
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("energy_system")
|
||||
|
||||
|
||||
class EnergyLevel(Enum):
|
||||
"""能量等级"""
|
||||
VERY_LOW = 0.1 # 非常低
|
||||
LOW = 0.3 # 低
|
||||
NORMAL = 0.5 # 正常
|
||||
HIGH = 0.7 # 高
|
||||
VERY_HIGH = 0.9 # 非常高
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnergyComponent:
|
||||
"""能量组件"""
|
||||
name: str
|
||||
value: float
|
||||
weight: float = 1.0
|
||||
decay_rate: float = 0.05 # 衰减率
|
||||
last_updated: float = field(default_factory=time.time)
|
||||
|
||||
def get_current_value(self) -> float:
|
||||
"""获取当前值(考虑时间衰减)"""
|
||||
age = time.time() - self.last_updated
|
||||
decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600))) # 按天衰减
|
||||
return self.value * decay_factor
|
||||
|
||||
def update_value(self, new_value: float) -> None:
|
||||
"""更新值"""
|
||||
self.value = max(0.0, min(1.0, new_value))
|
||||
self.last_updated = time.time()
|
||||
|
||||
|
||||
class EnergyContext(TypedDict):
|
||||
"""能量计算上下文"""
|
||||
stream_id: str
|
||||
messages: List[Any]
|
||||
user_id: Optional[str]
|
||||
|
||||
|
||||
class EnergyResult(TypedDict):
|
||||
"""能量计算结果"""
|
||||
energy: float
|
||||
level: EnergyLevel
|
||||
distribution_interval: float
|
||||
component_scores: Dict[str, float]
|
||||
cached: bool
|
||||
|
||||
|
||||
class EnergyCalculator(ABC):
|
||||
"""能量计算器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""计算能量值"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_weight(self) -> float:
|
||||
"""获取权重"""
|
||||
pass
|
||||
|
||||
|
||||
class InterestEnergyCalculator(EnergyCalculator):
|
||||
"""兴趣度能量计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于消息兴趣度计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
return 0.3
|
||||
|
||||
# 计算平均兴趣度
|
||||
total_interest = 0.0
|
||||
valid_messages = 0
|
||||
|
||||
for msg in messages:
|
||||
interest_value = getattr(msg, "interest_value", None)
|
||||
if interest_value is not None:
|
||||
try:
|
||||
interest_float = float(interest_value)
|
||||
if 0.0 <= interest_float <= 1.0:
|
||||
total_interest += interest_float
|
||||
valid_messages += 1
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
if valid_messages > 0:
|
||||
avg_interest = total_interest / valid_messages
|
||||
logger.debug(f"平均消息兴趣度: {avg_interest:.3f} (基于 {valid_messages} 条消息)")
|
||||
return avg_interest
|
||||
else:
|
||||
return 0.3
|
||||
|
||||
def get_weight(self) -> float:
|
||||
return 0.5
|
||||
|
||||
|
||||
class ActivityEnergyCalculator(EnergyCalculator):
|
||||
"""活跃度能量计算器"""
|
||||
|
||||
def __init__(self):
|
||||
self.action_weights = {
|
||||
"reply": 0.4,
|
||||
"react": 0.3,
|
||||
"mention": 0.2,
|
||||
"other": 0.1
|
||||
}
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于活跃度计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
return 0.2
|
||||
|
||||
total_score = 0.0
|
||||
max_possible_score = len(messages) * 0.4 # 最高可能分数
|
||||
|
||||
for msg in messages:
|
||||
actions = getattr(msg, "actions", [])
|
||||
if isinstance(actions, list) and actions:
|
||||
for action in actions:
|
||||
weight = self.action_weights.get(action, self.action_weights["other"])
|
||||
total_score += weight
|
||||
|
||||
if max_possible_score > 0:
|
||||
activity_score = min(1.0, total_score / max_possible_score)
|
||||
logger.debug(f"活跃度分数: {activity_score:.3f}")
|
||||
return activity_score
|
||||
else:
|
||||
return 0.2
|
||||
|
||||
def get_weight(self) -> float:
|
||||
return 0.3
|
||||
|
||||
|
||||
class RecencyEnergyCalculator(EnergyCalculator):
|
||||
"""最近性能量计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于最近性计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
return 0.1
|
||||
|
||||
# 获取最新消息时间
|
||||
latest_time = 0.0
|
||||
for msg in messages:
|
||||
msg_time = getattr(msg, "time", None)
|
||||
if msg_time and msg_time > latest_time:
|
||||
latest_time = msg_time
|
||||
|
||||
if latest_time == 0.0:
|
||||
return 0.1
|
||||
|
||||
# 计算时间衰减
|
||||
current_time = time.time()
|
||||
age = current_time - latest_time
|
||||
|
||||
# 时间衰减策略:
|
||||
# 1小时内:1.0
|
||||
# 1-6小时:0.8
|
||||
# 6-24小时:0.5
|
||||
# 1-7天:0.3
|
||||
# 7天以上:0.1
|
||||
if age < 3600: # 1小时内
|
||||
recency_score = 1.0
|
||||
elif age < 6 * 3600: # 6小时内
|
||||
recency_score = 0.8
|
||||
elif age < 24 * 3600: # 24小时内
|
||||
recency_score = 0.5
|
||||
elif age < 7 * 24 * 3600: # 7天内
|
||||
recency_score = 0.3
|
||||
else:
|
||||
recency_score = 0.1
|
||||
|
||||
logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)")
|
||||
return recency_score
|
||||
|
||||
def get_weight(self) -> float:
|
||||
return 0.2
|
||||
|
||||
|
||||
class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
"""关系能量计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于关系计算能量"""
|
||||
user_id = context.get("user_id")
|
||||
if not user_id:
|
||||
return 0.3
|
||||
|
||||
# 使用插件内部的兴趣度评分系统获取关系分
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
|
||||
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||
logger.debug(f"使用插件内部系统计算关系分: {relationship_score:.3f}")
|
||||
return max(0.0, min(1.0, relationship_score))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"插件内部关系分计算失败,使用默认值: {e}")
|
||||
return 0.3 # 默认基础分
|
||||
|
||||
def get_weight(self) -> float:
|
||||
return 0.1
|
||||
|
||||
|
||||
class EnergyManager:
|
||||
"""能量管理器 - 统一管理所有能量计算"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calculators: List[EnergyCalculator] = [
|
||||
InterestEnergyCalculator(),
|
||||
ActivityEnergyCalculator(),
|
||||
RecencyEnergyCalculator(),
|
||||
RelationshipEnergyCalculator(),
|
||||
]
|
||||
|
||||
# 能量缓存
|
||||
self.energy_cache: Dict[str, Tuple[float, float]] = {} # stream_id -> (energy, timestamp)
|
||||
self.cache_ttl: int = 60 # 1分钟缓存
|
||||
|
||||
# AFC阈值配置
|
||||
self.thresholds: Dict[str, float] = {
|
||||
"high_match": 0.8,
|
||||
"reply": 0.4,
|
||||
"non_reply": 0.2
|
||||
}
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Union[int, float, str]] = {
|
||||
"total_calculations": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"average_calculation_time": 0.0,
|
||||
"last_threshold_update": time.time(),
|
||||
}
|
||||
|
||||
# 从配置加载阈值
|
||||
self._load_thresholds_from_config()
|
||||
|
||||
logger.info("能量管理器初始化完成")
|
||||
|
||||
def _load_thresholds_from_config(self) -> None:
|
||||
"""从配置加载AFC阈值"""
|
||||
try:
|
||||
if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None:
|
||||
self.thresholds["high_match"] = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
|
||||
self.thresholds["reply"] = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
|
||||
self.thresholds["non_reply"] = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
|
||||
|
||||
# 确保阈值关系合理
|
||||
self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)
|
||||
self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1)
|
||||
|
||||
self.stats["last_threshold_update"] = time.time()
|
||||
logger.info(f"加载AFC阈值: {self.thresholds}")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载AFC阈值失败,使用默认值: {e}")
|
||||
|
||||
def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
|
||||
"""计算聊天流的focus_energy"""
|
||||
start_time = time.time()
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_calculations"] += 1
|
||||
|
||||
# 检查缓存
|
||||
if stream_id in self.energy_cache:
|
||||
cached_energy, cached_time = self.energy_cache[stream_id]
|
||||
if time.time() - cached_time < self.cache_ttl:
|
||||
self.stats["cache_hits"] += 1
|
||||
logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}")
|
||||
return cached_energy
|
||||
else:
|
||||
self.stats["cache_misses"] += 1
|
||||
|
||||
# 构建计算上下文
|
||||
context: EnergyContext = {
|
||||
"stream_id": stream_id,
|
||||
"messages": messages,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
# 计算各组件能量
|
||||
component_scores: Dict[str, float] = {}
|
||||
total_weight = 0.0
|
||||
|
||||
for calculator in self.calculators:
|
||||
try:
|
||||
score = calculator.calculate(context)
|
||||
weight = calculator.get_weight()
|
||||
|
||||
component_scores[calculator.__class__.__name__] = score
|
||||
total_weight += weight
|
||||
|
||||
logger.debug(f"{calculator.__class__.__name__} 能量: {score:.3f} (权重: {weight:.3f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"计算 {calculator.__class__.__name__} 能量失败: {e}")
|
||||
|
||||
# 加权计算总能量
|
||||
if total_weight > 0:
|
||||
total_energy = 0.0
|
||||
for calculator in self.calculators:
|
||||
if calculator.__class__.__name__ in component_scores:
|
||||
score = component_scores[calculator.__class__.__name__]
|
||||
weight = calculator.get_weight()
|
||||
total_energy += score * (weight / total_weight)
|
||||
else:
|
||||
total_energy = 0.5
|
||||
|
||||
# 应用阈值调整和变换
|
||||
final_energy = self._apply_threshold_adjustment(total_energy)
|
||||
|
||||
# 缓存结果
|
||||
self.energy_cache[stream_id] = (final_energy, time.time())
|
||||
|
||||
# 清理过期缓存
|
||||
self._cleanup_cache()
|
||||
|
||||
# 更新平均计算时间
|
||||
calculation_time = time.time() - start_time
|
||||
total_calculations = self.stats["total_calculations"]
|
||||
self.stats["average_calculation_time"] = (
|
||||
(self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time)
|
||||
/ total_calculations
|
||||
)
|
||||
|
||||
logger.info(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)")
|
||||
return final_energy
|
||||
|
||||
def _apply_threshold_adjustment(self, energy: float) -> float:
|
||||
"""应用阈值调整和变换"""
|
||||
# 获取参考阈值
|
||||
high_threshold = self.thresholds["high_match"]
|
||||
reply_threshold = self.thresholds["reply"]
|
||||
|
||||
# 计算与阈值的相对位置
|
||||
if energy >= high_threshold:
|
||||
# 高能量区域:指数增强
|
||||
adjusted = 0.7 + (energy - 0.7) ** 0.8
|
||||
elif energy >= reply_threshold:
|
||||
# 中等能量区域:线性保持
|
||||
adjusted = energy
|
||||
else:
|
||||
# 低能量区域:对数压缩
|
||||
adjusted = 0.4 * (energy / 0.4) ** 1.2
|
||||
|
||||
# 确保在合理范围内
|
||||
return max(0.1, min(1.0, adjusted))
|
||||
|
||||
def get_energy_level(self, energy: float) -> EnergyLevel:
|
||||
"""获取能量等级"""
|
||||
if energy >= EnergyLevel.VERY_HIGH.value:
|
||||
return EnergyLevel.VERY_HIGH
|
||||
elif energy >= EnergyLevel.HIGH.value:
|
||||
return EnergyLevel.HIGH
|
||||
elif energy >= EnergyLevel.NORMAL.value:
|
||||
return EnergyLevel.NORMAL
|
||||
elif energy >= EnergyLevel.LOW.value:
|
||||
return EnergyLevel.LOW
|
||||
else:
|
||||
return EnergyLevel.VERY_LOW
|
||||
|
||||
def get_distribution_interval(self, energy: float) -> float:
|
||||
"""基于能量等级获取分发周期"""
|
||||
energy_level = self.get_energy_level(energy)
|
||||
|
||||
# 根据能量等级确定基础分发周期
|
||||
if energy_level == EnergyLevel.VERY_HIGH:
|
||||
base_interval = 1.0 # 1秒
|
||||
elif energy_level == EnergyLevel.HIGH:
|
||||
base_interval = 3.0 # 3秒
|
||||
elif energy_level == EnergyLevel.NORMAL:
|
||||
base_interval = 8.0 # 8秒
|
||||
elif energy_level == EnergyLevel.LOW:
|
||||
base_interval = 15.0 # 15秒
|
||||
else:
|
||||
base_interval = 30.0 # 30秒
|
||||
|
||||
# 添加随机扰动避免同步
|
||||
import random
|
||||
jitter = random.uniform(0.8, 1.2)
|
||||
final_interval = base_interval * jitter
|
||||
|
||||
# 确保在配置范围内
|
||||
min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0)
|
||||
max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0)
|
||||
|
||||
return max(min_interval, min(max_interval, final_interval))
|
||||
|
||||
def invalidate_cache(self, stream_id: str) -> None:
|
||||
"""失效指定流的缓存"""
|
||||
if stream_id in self.energy_cache:
|
||||
del self.energy_cache[stream_id]
|
||||
logger.debug(f"已清除聊天流 {stream_id} 的能量缓存")
|
||||
|
||||
def _cleanup_cache(self) -> None:
|
||||
"""清理过期缓存"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
stream_id for stream_id, (_, timestamp) in self.energy_cache.items()
|
||||
if current_time - timestamp > self.cache_ttl
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
del self.energy_cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存")
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
return {
|
||||
"cache_size": len(self.energy_cache),
|
||||
"calculators": [calc.__class__.__name__ for calc in self.calculators],
|
||||
"thresholds": self.thresholds,
|
||||
"performance_stats": self.stats.copy(),
|
||||
}
|
||||
|
||||
def update_thresholds(self, new_thresholds: Dict[str, float]) -> None:
|
||||
"""更新阈值"""
|
||||
self.thresholds.update(new_thresholds)
|
||||
|
||||
# 确保阈值关系合理
|
||||
self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)
|
||||
self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1)
|
||||
|
||||
self.stats["last_threshold_update"] = time.time()
|
||||
logger.info(f"更新AFC阈值: {self.thresholds}")
|
||||
|
||||
def add_calculator(self, calculator: EnergyCalculator) -> None:
|
||||
"""添加计算器"""
|
||||
self.calculators.append(calculator)
|
||||
logger.info(f"添加能量计算器: {calculator.__class__.__name__}")
|
||||
|
||||
def remove_calculator(self, calculator: EnergyCalculator) -> None:
|
||||
"""移除计算器"""
|
||||
if calculator in self.calculators:
|
||||
self.calculators.remove(calculator)
|
||||
logger.info(f"移除能量计算器: {calculator.__class__.__name__}")
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清空缓存"""
|
||||
self.energy_cache.clear()
|
||||
logger.info("清空能量缓存")
|
||||
|
||||
def get_cache_hit_rate(self) -> float:
|
||||
"""获取缓存命中率"""
|
||||
total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0)
|
||||
if total_requests == 0:
|
||||
return 0.0
|
||||
return self.stats["cache_hits"] / total_requests
|
||||
|
||||
|
||||
# 全局能量管理器实例
|
||||
energy_manager = EnergyManager()
|
||||
@@ -14,6 +14,7 @@ Chat Frequency Analyzer
|
||||
- MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。
|
||||
- MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。
|
||||
"""
|
||||
|
||||
import time as time_module
|
||||
from datetime import datetime, timedelta, time
|
||||
from typing import List, Tuple, Optional
|
||||
@@ -72,12 +73,14 @@ class ChatFrequencyAnalyzer:
|
||||
current_window_end = datetimes[i]
|
||||
|
||||
# 合并重叠或相邻的高峰时段
|
||||
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(hours=MIN_GAP_BETWEEN_PEAKS_HOURS):
|
||||
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(
|
||||
hours=MIN_GAP_BETWEEN_PEAKS_HOURS
|
||||
):
|
||||
# 扩展上一个窗口的结束时间
|
||||
peak_windows[-1] = (peak_windows[-1][0], current_window_end)
|
||||
else:
|
||||
peak_windows.append((current_window_start, current_window_end))
|
||||
|
||||
|
||||
return peak_windows
|
||||
|
||||
def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]:
|
||||
@@ -100,7 +103,7 @@ class ChatFrequencyAnalyzer:
|
||||
return []
|
||||
|
||||
peak_datetime_windows = self._find_peak_windows(timestamps)
|
||||
|
||||
|
||||
# 将 datetime 窗口转换为 time 窗口,并进行归一化处理
|
||||
peak_time_windows = []
|
||||
for start_dt, end_dt in peak_datetime_windows:
|
||||
@@ -110,7 +113,7 @@ class ChatFrequencyAnalyzer:
|
||||
|
||||
# 更新缓存
|
||||
self._analysis_cache[chat_id] = (time_module.time(), peak_time_windows)
|
||||
|
||||
|
||||
return peak_time_windows
|
||||
|
||||
def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool:
|
||||
@@ -126,7 +129,7 @@ class ChatFrequencyAnalyzer:
|
||||
"""
|
||||
if now is None:
|
||||
now = datetime.now()
|
||||
|
||||
|
||||
now_time = now.time()
|
||||
peak_times = self.get_peak_chat_times(chat_id)
|
||||
|
||||
@@ -137,7 +140,7 @@ class ChatFrequencyAnalyzer:
|
||||
else: # 跨天
|
||||
if now_time >= start_time or now_time <= end_time:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class ChatFrequencyTracker:
|
||||
now = time.time()
|
||||
if chat_id not in self._timestamps:
|
||||
self._timestamps[chat_id] = []
|
||||
|
||||
|
||||
self._timestamps[chat_id].append(now)
|
||||
logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}")
|
||||
self._save_timestamps()
|
||||
|
||||
@@ -14,15 +14,16 @@ Frequency-Based Proactive Trigger
|
||||
- TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。
|
||||
- COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.chat_loop.proactive.events import ProactiveTriggerEvent
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager
|
||||
# AFC manager has been moved to chatter plugin
|
||||
|
||||
# TODO: 需要重新实现主动思考和睡眠管理功能
|
||||
from .analyzer import chat_frequency_analyzer
|
||||
|
||||
logger = get_logger("FrequencyBasedTrigger")
|
||||
@@ -39,8 +40,8 @@ class FrequencyBasedTrigger:
|
||||
一个周期性任务,根据聊天频率分析结果来触发主动思考。
|
||||
"""
|
||||
|
||||
def __init__(self, sleep_manager: SleepManager):
|
||||
self._sleep_manager = sleep_manager
|
||||
def __init__(self):
|
||||
# TODO: 需要重新实现睡眠管理器
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
# 记录上次为用户触发的时间,用于冷却控制
|
||||
# 格式: { "chat_id": timestamp }
|
||||
@@ -53,19 +54,21 @@ class FrequencyBasedTrigger:
|
||||
await asyncio.sleep(TRIGGER_CHECK_INTERVAL_SECONDS)
|
||||
logger.debug("开始执行频率触发器检查...")
|
||||
|
||||
# 1. 检查角色是否清醒
|
||||
if self._sleep_manager.is_sleeping():
|
||||
logger.debug("角色正在睡眠,跳过本次频率触发检查。")
|
||||
continue
|
||||
# 1. TODO: 检查角色是否清醒 - 需要重新实现睡眠状态检查
|
||||
# 暂时跳过睡眠检查
|
||||
# if self._sleep_manager.is_sleeping():
|
||||
# logger.debug("角色正在睡眠,跳过本次频率触发检查。")
|
||||
# continue
|
||||
|
||||
# 2. 获取所有已知的聊天ID
|
||||
# 【注意】这里我们假设所有 subheartflow 的 ID 就是 chat_id
|
||||
all_chat_ids = list(heartflow.subheartflows.keys())
|
||||
# 注意:AFC管理器已移至chatter插件,此功能暂时禁用
|
||||
# all_chat_ids = list(afc_manager.affinity_flow_chatters.keys())
|
||||
all_chat_ids = [] # 暂时禁用此功能
|
||||
if not all_chat_ids:
|
||||
continue
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
|
||||
for chat_id in all_chat_ids:
|
||||
# 3. 检查是否处于冷却时间内
|
||||
last_triggered_time = self._last_triggered.get(chat_id, 0)
|
||||
@@ -74,29 +77,11 @@ class FrequencyBasedTrigger:
|
||||
|
||||
# 4. 检查当前是否是该用户的高峰聊天时间
|
||||
if chat_frequency_analyzer.is_in_peak_time(chat_id, now):
|
||||
|
||||
sub_heartflow = await heartflow.get_or_create_subheartflow(chat_id)
|
||||
if not sub_heartflow:
|
||||
logger.warning(f"无法为 {chat_id} 获取或创建 sub_heartflow。")
|
||||
continue
|
||||
|
||||
# 5. 检查用户当前是否已有活跃的思考或回复任务
|
||||
cycle_detail = sub_heartflow.heart_fc_instance.context.current_cycle_detail
|
||||
if cycle_detail and not cycle_detail.end_time:
|
||||
logger.debug(f"用户 {chat_id} 的聊天循环正忙(仍在周期 {cycle_detail.cycle_id} 中),本次不触发。")
|
||||
continue
|
||||
|
||||
logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,且聊天循环空闲,准备触发主动思考。")
|
||||
|
||||
# 6. 直接调用 proactive_thinker
|
||||
event = ProactiveTriggerEvent(
|
||||
source="frequency_analyzer",
|
||||
reason="User is in a high-frequency chat period."
|
||||
)
|
||||
await sub_heartflow.heart_fc_instance.proactive_thinker.think(event)
|
||||
|
||||
# 7. 更新触发时间,进入冷却
|
||||
self._last_triggered[chat_id] = time.time()
|
||||
# 5. 检查用户当前是否已有活跃的处理任务
|
||||
# 注意:AFC管理器已移至chatter插件,此功能暂时禁用
|
||||
# chatter = afc_manager.get_or_create_chatter(chat_id)
|
||||
logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,但AFC功能已移至chatter插件")
|
||||
continue
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("频率触发器任务被取消。")
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
import traceback
|
||||
from typing import Any, Optional, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
logger = get_logger("heartflow")
|
||||
|
||||
|
||||
class Heartflow:
|
||||
"""主心流协调器,负责初始化并协调聊天"""
|
||||
|
||||
def __init__(self):
|
||||
self.subheartflows: Dict[Any, "SubHeartflow"] = {}
|
||||
|
||||
async def get_or_create_subheartflow(self, subheartflow_id: Any) -> Optional["SubHeartflow"]:
|
||||
"""获取或创建一个新的SubHeartflow实例"""
|
||||
if subheartflow_id in self.subheartflows:
|
||||
if subflow := self.subheartflows.get(subheartflow_id):
|
||||
return subflow
|
||||
|
||||
try:
|
||||
new_subflow = SubHeartflow(subheartflow_id)
|
||||
|
||||
await new_subflow.initialize()
|
||||
|
||||
# 注册子心流
|
||||
self.subheartflows[subheartflow_id] = new_subflow
|
||||
heartflow_name = get_chat_manager().get_stream_name(subheartflow_id) or subheartflow_id
|
||||
logger.info(f"[{heartflow_name}] 开始接收消息")
|
||||
|
||||
return new_subflow
|
||||
except Exception as e:
|
||||
logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
heartflow = Heartflow()
|
||||
@@ -1,178 +0,0 @@
|
||||
import asyncio
|
||||
import math
|
||||
import re
|
||||
import traceback
|
||||
from typing import Tuple, TYPE_CHECKING
|
||||
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.chat_message_builder import replace_user_references_sync
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
async def _process_relationship(message: MessageRecv) -> None:
|
||||
"""处理用户关系逻辑
|
||||
|
||||
Args:
|
||||
message: 消息对象,包含用户信息
|
||||
"""
|
||||
platform = message.message_info.platform
|
||||
user_id = message.message_info.user_info.user_id # type: ignore
|
||||
nickname = message.message_info.user_info.user_nickname # type: ignore
|
||||
cardname = message.message_info.user_info.user_cardname or nickname # type: ignore
|
||||
|
||||
relationship_manager = get_relationship_manager()
|
||||
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
||||
|
||||
if not is_known:
|
||||
logger.info(f"首次认识用户: {nickname}")
|
||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool, list[str]]:
|
||||
"""计算消息的兴趣度
|
||||
|
||||
Args:
|
||||
message: 待处理的消息对象
|
||||
|
||||
Returns:
|
||||
Tuple[float, bool, list[str]]: (兴趣度, 是否被提及, 关键词)
|
||||
"""
|
||||
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0
|
||||
|
||||
with Timer("记忆激活"):
|
||||
interested_rate, keywords = await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
max_depth=4,
|
||||
fast_retrieval=False,
|
||||
)
|
||||
message.key_words = keywords
|
||||
message.key_words_lite = keywords
|
||||
logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}")
|
||||
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
||||
|
||||
if text_len == 0:
|
||||
base_interest = 0.01 # 空消息最低兴趣度
|
||||
elif text_len <= 5:
|
||||
# 1-5字符:线性增长 0.01 -> 0.03
|
||||
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
|
||||
elif text_len <= 10:
|
||||
# 6-10字符:线性增长 0.03 -> 0.06
|
||||
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
|
||||
elif text_len <= 20:
|
||||
# 11-20字符:线性增长 0.06 -> 0.12
|
||||
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
|
||||
elif text_len <= 30:
|
||||
# 21-30字符:线性增长 0.12 -> 0.18
|
||||
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
|
||||
elif text_len <= 50:
|
||||
# 31-50字符:线性增长 0.18 -> 0.22
|
||||
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
|
||||
elif text_len <= 100:
|
||||
# 51-100字符:线性增长 0.22 -> 0.26
|
||||
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
|
||||
else:
|
||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
||||
|
||||
# 确保在范围内
|
||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
||||
|
||||
interested_rate += base_interest
|
||||
|
||||
if is_mentioned:
|
||||
interest_increase_on_mention = 1
|
||||
interested_rate += interest_increase_on_mention
|
||||
|
||||
return interested_rate, is_mentioned, keywords
|
||||
|
||||
|
||||
class HeartFCMessageReceiver:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def process_message(self, message: MessageRecv) -> None:
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
主要流程:
|
||||
1. 消息解析与初始化
|
||||
2. 消息缓冲处理
|
||||
4. 过滤检查
|
||||
5. 兴趣度计算
|
||||
6. 关系处理
|
||||
|
||||
Args:
|
||||
message_data: 原始消息字符串
|
||||
"""
|
||||
try:
|
||||
# 1. 消息解析与初始化
|
||||
userinfo = message.message_info.user_info
|
||||
chat = message.chat_stream
|
||||
|
||||
# 2. 兴趣度计算与更新
|
||||
interested_rate, is_mentioned, keywords = await _calculate_interest(message)
|
||||
message.interest_value = interested_rate
|
||||
message.is_mentioned = is_mentioned
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
subheartflow: SubHeartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) # type: ignore
|
||||
|
||||
await subheartflow.heart_fc_instance.add_message(message.to_dict())
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||
|
||||
# 3. 日志记录
|
||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
|
||||
current_talk_frequency = global_config.chat.get_current_talk_frequency(chat.stream_id)
|
||||
|
||||
# 如果消息中包含图片标识,则将 [picid:...] 替换为 [图片]
|
||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
||||
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
|
||||
|
||||
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
||||
processed_plain_text = replace_user_references_sync(
|
||||
processed_plain_text,
|
||||
message.message_info.platform, # type: ignore
|
||||
replace_bot_name=True,
|
||||
)
|
||||
|
||||
if keywords:
|
||||
logger.info(
|
||||
f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}][关键词:{keywords}]"
|
||||
) # type: ignore
|
||||
else:
|
||||
logger.info(
|
||||
f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}]"
|
||||
) # type: ignore
|
||||
|
||||
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")
|
||||
|
||||
# 4. 关系处理
|
||||
if global_config.relationship.enable_relationship:
|
||||
await _process_relationship(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息处理失败: {e}")
|
||||
print(traceback.format_exc())
|
||||
@@ -1,42 +0,0 @@
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.chat_loop.heartFC_chat import HeartFChatting
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
|
||||
logger = get_logger("sub_heartflow")
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class SubHeartflow:
|
||||
def __init__(
|
||||
self,
|
||||
subheartflow_id,
|
||||
):
|
||||
"""子心流初始化函数
|
||||
|
||||
Args:
|
||||
subheartflow_id: 子心流唯一标识符
|
||||
"""
|
||||
# 基础属性,两个值是一样的
|
||||
self.subheartflow_id = subheartflow_id
|
||||
self.chat_id = subheartflow_id
|
||||
|
||||
self.is_group_chat, self.chat_target_info = (None, None)
|
||||
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id
|
||||
|
||||
# focus模式退出冷却时间管理
|
||||
self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间
|
||||
|
||||
# 随便水群 normal_chat 和 认真水群 focus_chat 实例
|
||||
# CHAT模式激活 随便水群 FOCUS模式激活 认真水群
|
||||
self.heart_fc_instance: HeartFChatting = HeartFChatting(
|
||||
chat_id=self.subheartflow_id,
|
||||
) # 该sub_heartflow的HeartFChatting实例
|
||||
|
||||
async def initialize(self):
|
||||
"""异步初始化方法,创建兴趣流并确定聊天类型"""
|
||||
self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_id)
|
||||
await self.heart_fc_instance.start()
|
||||
15
src/chat/interest_system/__init__.py
Normal file
15
src/chat/interest_system/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
兴趣度系统模块
|
||||
提供机器人兴趣标签和智能匹配功能
|
||||
"""
|
||||
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
|
||||
__all__ = [
|
||||
"BotInterestManager",
|
||||
"bot_interest_manager",
|
||||
"BotInterestTag",
|
||||
"BotPersonalityInterests",
|
||||
"InterestMatchResult",
|
||||
]
|
||||
805
src/chat/interest_system/bot_interest_manager.py
Normal file
805
src/chat/interest_system/bot_interest_manager.py
Normal file
@@ -0,0 +1,805 @@
|
||||
"""
|
||||
机器人兴趣标签管理系统
|
||||
基于人设生成兴趣标签,并使用embedding计算匹配度
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import traceback
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
||||
|
||||
logger = get_logger("bot_interest_manager")
|
||||
|
||||
|
||||
class BotInterestManager:
|
||||
"""机器人兴趣标签管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_interests: Optional[BotPersonalityInterests] = None
|
||||
self.embedding_cache: Dict[str, List[float]] = {} # embedding缓存
|
||||
self._initialized = False
|
||||
|
||||
# Embedding客户端配置
|
||||
self.embedding_request = None
|
||||
self.embedding_config = None
|
||||
self.embedding_dimension = 1024 # 默认BGE-M3 embedding维度
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""检查兴趣系统是否已初始化"""
|
||||
return self._initialized
|
||||
|
||||
async def initialize(self, personality_description: str, personality_id: str = "default"):
|
||||
"""初始化兴趣标签系统"""
|
||||
try:
|
||||
logger.info("机器人兴趣系统开始初始化...")
|
||||
logger.info(f"人设ID: {personality_id}, 描述长度: {len(personality_description)}")
|
||||
|
||||
# 初始化embedding模型
|
||||
await self._initialize_embedding_model()
|
||||
|
||||
# 检查embedding客户端是否成功初始化
|
||||
if not self.embedding_request:
|
||||
raise RuntimeError("Embedding客户端初始化失败")
|
||||
|
||||
# 生成或加载兴趣标签
|
||||
await self._load_or_generate_interests(personality_description, personality_id)
|
||||
|
||||
self._initialized = True
|
||||
|
||||
# 检查是否成功获取兴趣标签
|
||||
if self.current_interests and len(self.current_interests.get_active_tags()) > 0:
|
||||
active_tags_count = len(self.current_interests.get_active_tags())
|
||||
logger.info("机器人兴趣系统初始化完成!")
|
||||
logger.info(f"当前已激活 {active_tags_count} 个兴趣标签, Embedding缓存 {len(self.embedding_cache)} 个")
|
||||
else:
|
||||
raise RuntimeError("未能成功加载或生成兴趣标签")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"机器人兴趣系统初始化失败: {e}")
|
||||
traceback.print_exc()
|
||||
raise # 重新抛出异常,不允许降级初始化
|
||||
|
||||
async def _initialize_embedding_model(self):
|
||||
"""初始化embedding模型"""
|
||||
logger.info("🔧 正在配置embedding客户端...")
|
||||
|
||||
# 使用项目配置的embedding模型
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger.debug("✅ 成功导入embedding相关模块")
|
||||
|
||||
# 检查embedding配置是否存在
|
||||
if not hasattr(model_config.model_task_config, "embedding"):
|
||||
raise RuntimeError("❌ 未找到embedding模型配置")
|
||||
|
||||
logger.info("📋 找到embedding模型配置")
|
||||
self.embedding_config = model_config.model_task_config.embedding
|
||||
self.embedding_dimension = 1024 # BGE-M3的维度
|
||||
logger.info(f"📐 使用模型维度: {self.embedding_dimension}")
|
||||
|
||||
# 创建LLMRequest实例用于embedding
|
||||
self.embedding_request = LLMRequest(model_set=self.embedding_config, request_type="interest_embedding")
|
||||
logger.info("✅ Embedding请求客户端初始化成功")
|
||||
logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}")
|
||||
|
||||
# 获取第一个embedding模型的ModelInfo
|
||||
if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
|
||||
first_model_name = self.embedding_config.model_list[0]
|
||||
logger.info(f"🎯 使用embedding模型: {first_model_name}")
|
||||
else:
|
||||
logger.warning("⚠️ 未找到embedding模型列表")
|
||||
|
||||
logger.info("✅ Embedding模型初始化完成")
|
||||
|
||||
async def _load_or_generate_interests(self, personality_description: str, personality_id: str):
|
||||
"""加载或生成兴趣标签"""
|
||||
logger.info(f"📚 正在为 '{personality_id}' 加载或生成兴趣标签...")
|
||||
|
||||
# 首先尝试从数据库加载
|
||||
logger.info("尝试从数据库加载兴趣标签...")
|
||||
loaded_interests = await self._load_interests_from_database(personality_id)
|
||||
|
||||
if loaded_interests:
|
||||
self.current_interests = loaded_interests
|
||||
active_count = len(loaded_interests.get_active_tags())
|
||||
logger.info(f"成功从数据库加载 {active_count} 个兴趣标签 (版本: {loaded_interests.version})")
|
||||
tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in loaded_interests.get_active_tags()]
|
||||
tags_str = "\n".join(tags_info)
|
||||
logger.info(f"当前兴趣标签:\n{tags_str}")
|
||||
else:
|
||||
# 生成新的兴趣标签
|
||||
logger.info("数据库中未找到兴趣标签,开始生成...")
|
||||
generated_interests = await self._generate_interests_from_personality(
|
||||
personality_description, personality_id
|
||||
)
|
||||
|
||||
if generated_interests:
|
||||
self.current_interests = generated_interests
|
||||
active_count = len(generated_interests.get_active_tags())
|
||||
logger.info(f"成功生成 {active_count} 个新兴趣标签。")
|
||||
tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags()]
|
||||
tags_str = "\n".join(tags_info)
|
||||
logger.info(f"当前兴趣标签:\n{tags_str}")
|
||||
|
||||
# 保存到数据库
|
||||
logger.info("正在保存至数据库...")
|
||||
await self._save_interests_to_database(generated_interests)
|
||||
else:
|
||||
raise RuntimeError("❌ 兴趣标签生成失败")
|
||||
|
||||
async def _generate_interests_from_personality(
|
||||
self, personality_description: str, personality_id: str
|
||||
) -> Optional[BotPersonalityInterests]:
|
||||
"""根据人设生成兴趣标签"""
|
||||
try:
|
||||
logger.info("🎨 开始根据人设生成兴趣标签...")
|
||||
logger.info(f"📝 人设长度: {len(personality_description)} 字符")
|
||||
|
||||
# 检查embedding客户端是否可用
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding客户端未初始化,无法生成兴趣标签")
|
||||
|
||||
# 构建提示词
|
||||
logger.info("📝 构建LLM提示词...")
|
||||
prompt = f"""
|
||||
基于以下机器人人设描述,生成一套合适的兴趣标签:
|
||||
|
||||
人设描述:
|
||||
{personality_description}
|
||||
|
||||
请生成一系列兴趣关键词标签,要求:
|
||||
1. 标签应该符合人设特点和性格
|
||||
2. 每个标签都有权重(0.1-1.0),表示对该兴趣的喜好程度
|
||||
3. 生成15-25个不等的标签
|
||||
4. 标签应该是具体的关键词,而不是抽象概念
|
||||
|
||||
请以JSON格式返回,格式如下:
|
||||
{{
|
||||
"interests": [
|
||||
{{"name": "标签名", "weight": 0.8}},
|
||||
{{"name": "标签名", "weight": 0.6}},
|
||||
{{"name": "标签名", "weight": 0.9}}
|
||||
]
|
||||
}}
|
||||
|
||||
注意:
|
||||
- 权重范围0.1-1.0,权重越高表示越感兴趣
|
||||
- 标签要具体,如"编程"、"游戏"、"旅行"等
|
||||
- 根据人设生成个性化的标签
|
||||
"""
|
||||
|
||||
# 调用LLM生成兴趣标签
|
||||
logger.info("🤖 正在调用LLM生成兴趣标签...")
|
||||
response = await self._call_llm_for_interest_generation(prompt)
|
||||
|
||||
if not response:
|
||||
raise RuntimeError("❌ LLM未返回有效响应")
|
||||
|
||||
logger.info("✅ LLM响应成功,开始解析兴趣标签...")
|
||||
interests_data = orjson.loads(response)
|
||||
|
||||
bot_interests = BotPersonalityInterests(
|
||||
personality_id=personality_id, personality_description=personality_description
|
||||
)
|
||||
|
||||
# 解析生成的兴趣标签
|
||||
interests_list = interests_data.get("interests", [])
|
||||
logger.info(f"📋 解析到 {len(interests_list)} 个兴趣标签")
|
||||
|
||||
for i, tag_data in enumerate(interests_list):
|
||||
tag_name = tag_data.get("name", f"标签_{i}")
|
||||
weight = tag_data.get("weight", 0.5)
|
||||
|
||||
tag = BotInterestTag(tag_name=tag_name, weight=weight)
|
||||
bot_interests.interest_tags.append(tag)
|
||||
|
||||
logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})")
|
||||
|
||||
# 为所有标签生成embedding
|
||||
logger.info("🧠 开始为兴趣标签生成embedding向量...")
|
||||
await self._generate_embeddings_for_tags(bot_interests)
|
||||
|
||||
logger.info("✅ 兴趣标签生成完成")
|
||||
return bot_interests
|
||||
|
||||
except orjson.JSONDecodeError as e:
|
||||
logger.error(f"❌ 解析LLM响应JSON失败: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 根据人设生成兴趣标签失败: {e}")
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
|
||||
"""调用LLM生成兴趣标签"""
|
||||
try:
|
||||
logger.info("🔧 配置LLM客户端...")
|
||||
|
||||
# 使用llm_api来处理请求
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.config.config import model_config
|
||||
|
||||
# 构建完整的提示词,明确要求只返回纯JSON
|
||||
full_prompt = f"""你是一个专业的机器人人设分析师,擅长根据人设描述生成合适的兴趣标签。
|
||||
|
||||
{prompt}
|
||||
|
||||
请确保返回格式为有效的JSON,不要包含任何额外的文本、解释或代码块标记。只返回JSON对象本身。"""
|
||||
|
||||
# 使用replyer模型配置
|
||||
replyer_config = model_config.model_task_config.replyer
|
||||
|
||||
# 调用LLM API
|
||||
logger.info("🚀 正在通过LLM API发送请求...")
|
||||
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
|
||||
prompt=full_prompt,
|
||||
model_config=replyer_config,
|
||||
request_type="interest_generation",
|
||||
temperature=0.7,
|
||||
max_tokens=2000,
|
||||
)
|
||||
|
||||
if success and response:
|
||||
logger.info(f"✅ LLM响应成功,模型: {model_name}, 响应长度: {len(response)} 字符")
|
||||
logger.debug(
|
||||
f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}"
|
||||
)
|
||||
if reasoning_content:
|
||||
logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...")
|
||||
|
||||
# 清理响应内容,移除可能的代码块标记
|
||||
cleaned_response = self._clean_llm_response(response)
|
||||
return cleaned_response
|
||||
else:
|
||||
logger.warning("⚠️ LLM返回空响应或调用失败")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 调用LLM生成兴趣标签失败: {e}")
|
||||
logger.error("🔍 错误详情:")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
def _clean_llm_response(self, response: str) -> str:
|
||||
"""清理LLM响应,移除代码块标记和其他非JSON内容"""
|
||||
import re
|
||||
|
||||
# 移除 ```json 和 ``` 标记
|
||||
cleaned = re.sub(r"```json\s*", "", response)
|
||||
cleaned = re.sub(r"\s*```", "", cleaned)
|
||||
|
||||
# 移除可能的多余空格和换行
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
# 尝试提取JSON对象(如果响应中有其他文本)
|
||||
json_match = re.search(r"\{.*\}", cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
cleaned = json_match.group(0)
|
||||
|
||||
logger.debug(f"🧹 清理后的响应: {cleaned[:200]}..." if len(cleaned) > 200 else f"🧹 清理后的响应: {cleaned}")
|
||||
return cleaned
|
||||
|
||||
async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests):
|
||||
"""为所有兴趣标签生成embedding"""
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding客户端未初始化,无法生成embedding")
|
||||
|
||||
total_tags = len(interests.interest_tags)
|
||||
logger.info(f"🧠 开始为 {total_tags} 个兴趣标签生成embedding向量...")
|
||||
|
||||
cached_count = 0
|
||||
generated_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for i, tag in enumerate(interests.interest_tags, 1):
|
||||
if tag.tag_name in self.embedding_cache:
|
||||
# 使用缓存的embedding
|
||||
tag.embedding = self.embedding_cache[tag.tag_name]
|
||||
cached_count += 1
|
||||
logger.debug(f" [{i}/{total_tags}] 🏷️ '{tag.tag_name}' - 使用缓存")
|
||||
else:
|
||||
# 生成新的embedding
|
||||
embedding_text = tag.tag_name
|
||||
|
||||
logger.debug(f" [{i}/{total_tags}] 🔄 正在为 '{tag.tag_name}' 生成embedding...")
|
||||
embedding = await self._get_embedding(embedding_text)
|
||||
|
||||
if embedding:
|
||||
tag.embedding = embedding
|
||||
self.embedding_cache[tag.tag_name] = embedding
|
||||
generated_count += 1
|
||||
logger.debug(f" ✅ '{tag.tag_name}' embedding生成成功")
|
||||
else:
|
||||
failed_count += 1
|
||||
logger.warning(f" ❌ '{tag.tag_name}' embedding生成失败")
|
||||
|
||||
if failed_count > 0:
|
||||
raise RuntimeError(f"❌ 有 {failed_count} 个兴趣标签embedding生成失败")
|
||||
|
||||
interests.last_updated = datetime.now()
|
||||
logger.info("=" * 50)
|
||||
logger.info("✅ Embedding生成完成!")
|
||||
logger.info(f"📊 总标签数: {total_tags}")
|
||||
logger.info(f"💾 缓存命中: {cached_count}")
|
||||
logger.info(f"🆕 新生成: {generated_count}")
|
||||
logger.info(f"❌ 失败: {failed_count}")
|
||||
logger.info(f"🗃️ 总缓存大小: {len(self.embedding_cache)}")
|
||||
logger.info("=" * 50)
|
||||
|
||||
async def _get_embedding(self, text: str) -> List[float]:
|
||||
"""获取文本的embedding向量"""
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding请求客户端未初始化")
|
||||
|
||||
# 检查缓存
|
||||
if text in self.embedding_cache:
|
||||
logger.debug(f"💾 使用缓存的embedding: '{text[:30]}...'")
|
||||
return self.embedding_cache[text]
|
||||
|
||||
# 使用LLMRequest获取embedding
|
||||
logger.debug(f"🔄 正在获取embedding: '{text[:30]}...'")
|
||||
embedding, model_name = await self.embedding_request.get_embedding(text)
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
self.embedding_cache[text] = embedding
|
||||
logger.debug(f"✅ Embedding获取成功,维度: {len(embedding)}, 模型: {model_name}")
|
||||
return embedding
|
||||
else:
|
||||
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")
|
||||
|
||||
async def _generate_message_embedding(self, message_text: str, keywords: List[str]) -> List[float]:
|
||||
"""为消息生成embedding向量"""
|
||||
# 组合消息文本和关键词作为embedding输入
|
||||
if keywords:
|
||||
combined_text = f"{message_text} {' '.join(keywords)}"
|
||||
else:
|
||||
combined_text = message_text
|
||||
|
||||
logger.debug(f"🔄 正在为消息生成embedding,输入长度: {len(combined_text)}")
|
||||
|
||||
# 生成embedding
|
||||
embedding = await self._get_embedding(combined_text)
|
||||
logger.debug(f"✅ 消息embedding生成成功,维度: {len(embedding)}")
|
||||
return embedding
|
||||
|
||||
async def _calculate_similarity_scores(
|
||||
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
|
||||
):
|
||||
"""计算消息与兴趣标签的相似度分数"""
|
||||
try:
|
||||
if not self.current_interests:
|
||||
return
|
||||
|
||||
active_tags = self.current_interests.get_active_tags()
|
||||
if not active_tags:
|
||||
return
|
||||
|
||||
logger.debug(f"🔍 开始计算与 {len(active_tags)} 个兴趣标签的相似度")
|
||||
|
||||
for tag in active_tags:
|
||||
if tag.embedding:
|
||||
# 计算余弦相似度
|
||||
similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding)
|
||||
weighted_score = similarity * tag.weight
|
||||
|
||||
# 设置相似度阈值为0.3
|
||||
if similarity > 0.3:
|
||||
result.add_match(tag.tag_name, weighted_score, keywords)
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 计算相似度分数失败: {e}")
|
||||
|
||||
async def calculate_interest_match(self, message_text: str, keywords: List[str] = None) -> InterestMatchResult:
|
||||
"""计算消息与机器人兴趣的匹配度"""
|
||||
if not self.current_interests or not self._initialized:
|
||||
raise RuntimeError("❌ 兴趣标签系统未初始化")
|
||||
|
||||
logger.debug(f"开始计算兴趣匹配度: 消息长度={len(message_text)}, 关键词数={len(keywords) if keywords else 0}")
|
||||
|
||||
message_id = f"msg_{datetime.now().timestamp()}"
|
||||
result = InterestMatchResult(message_id=message_id)
|
||||
|
||||
# 获取活跃的兴趣标签
|
||||
active_tags = self.current_interests.get_active_tags()
|
||||
if not active_tags:
|
||||
raise RuntimeError("没有检测到活跃的兴趣标签")
|
||||
|
||||
logger.debug(f"正在与 {len(active_tags)} 个兴趣标签进行匹配...")
|
||||
|
||||
# 生成消息的embedding
|
||||
logger.debug("正在生成消息 embedding...")
|
||||
message_embedding = await self._get_embedding(message_text)
|
||||
logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}")
|
||||
|
||||
# 计算与每个兴趣标签的相似度
|
||||
match_count = 0
|
||||
high_similarity_count = 0
|
||||
medium_similarity_count = 0
|
||||
low_similarity_count = 0
|
||||
|
||||
# 分级相似度阈值
|
||||
affinity_config = global_config.affinity_flow
|
||||
high_threshold = affinity_config.high_match_interest_threshold
|
||||
medium_threshold = affinity_config.medium_match_interest_threshold
|
||||
low_threshold = affinity_config.low_match_interest_threshold
|
||||
|
||||
logger.debug(f"🔍 使用分级相似度阈值: 高={high_threshold}, 中={medium_threshold}, 低={low_threshold}")
|
||||
|
||||
for tag in active_tags:
|
||||
if tag.embedding:
|
||||
similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding)
|
||||
|
||||
# 基础加权分数
|
||||
weighted_score = similarity * tag.weight
|
||||
|
||||
# 根据相似度等级应用不同的加成
|
||||
if similarity > high_threshold:
|
||||
# 高相似度:强加成
|
||||
enhanced_score = weighted_score * affinity_config.high_match_keyword_multiplier
|
||||
match_count += 1
|
||||
high_similarity_count += 1
|
||||
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||
|
||||
elif similarity > medium_threshold:
|
||||
# 中相似度:中等加成
|
||||
enhanced_score = weighted_score * affinity_config.medium_match_keyword_multiplier
|
||||
match_count += 1
|
||||
medium_similarity_count += 1
|
||||
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||
|
||||
elif similarity > low_threshold:
|
||||
# 低相似度:轻微加成
|
||||
enhanced_score = weighted_score * affinity_config.low_match_keyword_multiplier
|
||||
match_count += 1
|
||||
low_similarity_count += 1
|
||||
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||
|
||||
logger.debug(
|
||||
f"匹配统计: {match_count}/{len(active_tags)} 个标签命中 | "
|
||||
f"高(>{high_threshold}): {high_similarity_count}, "
|
||||
f"中(>{medium_threshold}): {medium_similarity_count}, "
|
||||
f"低(>{low_threshold}): {low_similarity_count}"
|
||||
)
|
||||
|
||||
# 添加直接关键词匹配奖励
|
||||
keyword_bonus = self._calculate_keyword_match_bonus(keywords, result.matched_tags)
|
||||
logger.debug(f"🎯 关键词直接匹配奖励: {keyword_bonus}")
|
||||
|
||||
# 应用关键词奖励到匹配分数
|
||||
for tag_name in result.matched_tags:
|
||||
if tag_name in keyword_bonus:
|
||||
original_score = result.match_scores[tag_name]
|
||||
bonus = keyword_bonus[tag_name]
|
||||
result.match_scores[tag_name] = original_score + bonus
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}"
|
||||
)
|
||||
|
||||
# 计算总体分数
|
||||
result.calculate_overall_score()
|
||||
|
||||
# 确定最佳匹配标签
|
||||
if result.matched_tags:
|
||||
top_tag_name = max(result.match_scores.items(), key=lambda x: x[1])[0]
|
||||
result.top_tag = top_tag_name
|
||||
logger.debug(f"最佳匹配: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})")
|
||||
|
||||
logger.debug(
|
||||
f"最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
|
||||
)
|
||||
return result
|
||||
|
||||
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
|
||||
"""计算关键词直接匹配奖励"""
|
||||
if not keywords or not matched_tags:
|
||||
return {}
|
||||
|
||||
affinity_config = global_config.affinity_flow
|
||||
bonus_dict = {}
|
||||
|
||||
for tag_name in matched_tags:
|
||||
bonus = 0.0
|
||||
|
||||
# 检查关键词与标签的直接匹配
|
||||
for keyword in keywords:
|
||||
keyword_lower = keyword.lower().strip()
|
||||
tag_name_lower = tag_name.lower()
|
||||
|
||||
# 完全匹配
|
||||
if keyword_lower == tag_name_lower:
|
||||
bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励
|
||||
logger.debug(
|
||||
f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})"
|
||||
)
|
||||
|
||||
# 包含匹配
|
||||
elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower:
|
||||
bonus += (
|
||||
affinity_config.medium_match_interest_threshold * 0.3
|
||||
) # 使用中匹配阈值的30%作为包含匹配奖励
|
||||
logger.debug(
|
||||
f" 🎯 关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})"
|
||||
)
|
||||
|
||||
# 部分匹配(编辑距离)
|
||||
elif self._calculate_partial_match(keyword_lower, tag_name_lower):
|
||||
bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励
|
||||
logger.debug(
|
||||
f" 🎯 关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})"
|
||||
)
|
||||
|
||||
if bonus > 0:
|
||||
bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制
|
||||
|
||||
return bonus_dict
|
||||
|
||||
def _calculate_partial_match(self, text1: str, text2: str) -> bool:
|
||||
"""计算部分匹配(基于编辑距离)"""
|
||||
try:
|
||||
# 简单的编辑距离计算
|
||||
max_len = max(len(text1), len(text2))
|
||||
if max_len == 0:
|
||||
return False
|
||||
|
||||
# 计算编辑距离
|
||||
distance = self._levenshtein_distance(text1, text2)
|
||||
|
||||
# 如果编辑距离小于较短字符串长度的一半,认为是部分匹配
|
||||
min_len = min(len(text1), len(text2))
|
||||
return distance <= min_len // 2
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _levenshtein_distance(self, s1: str, s2: str) -> int:
|
||||
"""计算莱文斯坦距离"""
|
||||
if len(s1) < len(s2):
|
||||
return self._levenshtein_distance(s2, s1)
|
||||
|
||||
if len(s2) == 0:
|
||||
return len(s1)
|
||||
|
||||
previous_row = range(len(s2) + 1)
|
||||
for i, c1 in enumerate(s1):
|
||||
current_row = [i + 1]
|
||||
for j, c2 in enumerate(s2):
|
||||
insertions = previous_row[j + 1] + 1
|
||||
deletions = current_row[j] + 1
|
||||
substitutions = previous_row[j] + (c1 != c2)
|
||||
current_row.append(min(insertions, deletions, substitutions))
|
||||
previous_row = current_row
|
||||
|
||||
return previous_row[-1]
|
||||
|
||||
def _calculate_cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
|
||||
"""计算余弦相似度"""
|
||||
try:
|
||||
vec1 = np.array(vec1)
|
||||
vec2 = np.array(vec2)
|
||||
|
||||
dot_product = np.dot(vec1, vec2)
|
||||
norm1 = np.linalg.norm(vec1)
|
||||
norm2 = np.linalg.norm(vec2)
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm1 * norm2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算余弦相似度失败: {e}")
|
||||
return 0.0
|
||||
|
||||
async def _load_interests_from_database(self, personality_id: str) -> Optional[BotPersonalityInterests]:
|
||||
"""从数据库加载兴趣标签"""
|
||||
try:
|
||||
logger.debug(f"从数据库加载兴趣标签, personality_id: {personality_id}")
|
||||
|
||||
# 导入SQLAlchemy相关模块
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
|
||||
with get_db_session() as session:
|
||||
# 查询最新的兴趣标签配置
|
||||
db_interests = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == personality_id)
|
||||
.order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_interests:
|
||||
logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}")
|
||||
logger.debug(f"📅 最后更新时间: {db_interests.last_updated}")
|
||||
logger.debug(f"🧠 使用的embedding模型: {db_interests.embedding_model}")
|
||||
|
||||
# 解析JSON格式的兴趣标签
|
||||
try:
|
||||
tags_data = orjson.loads(db_interests.interest_tags)
|
||||
logger.debug(f"🏷️ 解析到 {len(tags_data)} 个兴趣标签")
|
||||
|
||||
# 创建BotPersonalityInterests对象
|
||||
interests = BotPersonalityInterests(
|
||||
personality_id=db_interests.personality_id,
|
||||
personality_description=db_interests.personality_description,
|
||||
embedding_model=db_interests.embedding_model,
|
||||
version=db_interests.version,
|
||||
last_updated=db_interests.last_updated,
|
||||
)
|
||||
|
||||
# 解析兴趣标签
|
||||
for tag_data in tags_data:
|
||||
tag = BotInterestTag(
|
||||
tag_name=tag_data.get("tag_name", ""),
|
||||
weight=tag_data.get("weight", 0.5),
|
||||
created_at=datetime.fromisoformat(
|
||||
tag_data.get("created_at", datetime.now().isoformat())
|
||||
),
|
||||
updated_at=datetime.fromisoformat(
|
||||
tag_data.get("updated_at", datetime.now().isoformat())
|
||||
),
|
||||
is_active=tag_data.get("is_active", True),
|
||||
embedding=tag_data.get("embedding"),
|
||||
)
|
||||
interests.interest_tags.append(tag)
|
||||
|
||||
logger.debug(f"成功解析 {len(interests.interest_tags)} 个兴趣标签")
|
||||
return interests
|
||||
|
||||
except (orjson.JSONDecodeError, Exception) as e:
|
||||
logger.error(f"❌ 解析兴趣标签JSON失败: {e}")
|
||||
logger.debug(f"🔍 原始JSON数据: {db_interests.interest_tags[:200]}...")
|
||||
return None
|
||||
else:
|
||||
logger.info(f"ℹ️ 数据库中未找到personality_id为 '{personality_id}' 的兴趣标签配置")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 从数据库加载兴趣标签失败: {e}")
|
||||
logger.error("🔍 错误详情:")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def _save_interests_to_database(self, interests: BotPersonalityInterests):
|
||||
"""保存兴趣标签到数据库"""
|
||||
try:
|
||||
logger.info("💾 正在保存兴趣标签到数据库...")
|
||||
logger.info(f"📋 personality_id: {interests.personality_id}")
|
||||
logger.info(f"🏷️ 兴趣标签数量: {len(interests.interest_tags)}")
|
||||
logger.info(f"🔄 版本: {interests.version}")
|
||||
|
||||
# 导入SQLAlchemy相关模块
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
|
||||
# 将兴趣标签转换为JSON格式
|
||||
tags_data = []
|
||||
for tag in interests.interest_tags:
|
||||
tag_dict = {
|
||||
"tag_name": tag.tag_name,
|
||||
"weight": tag.weight,
|
||||
"created_at": tag.created_at.isoformat(),
|
||||
"updated_at": tag.updated_at.isoformat(),
|
||||
"is_active": tag.is_active,
|
||||
"embedding": tag.embedding,
|
||||
}
|
||||
tags_data.append(tag_dict)
|
||||
|
||||
# 序列化为JSON
|
||||
json_data = orjson.dumps(tags_data)
|
||||
|
||||
with get_db_session() as session:
|
||||
# 检查是否已存在相同personality_id的记录
|
||||
existing_record = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_record:
|
||||
# 更新现有记录
|
||||
logger.info("🔄 更新现有的兴趣标签配置")
|
||||
existing_record.interest_tags = json_data
|
||||
existing_record.personality_description = interests.personality_description
|
||||
existing_record.embedding_model = interests.embedding_model
|
||||
existing_record.version = interests.version
|
||||
existing_record.last_updated = interests.last_updated
|
||||
|
||||
logger.info(f"✅ 成功更新兴趣标签配置,版本: {interests.version}")
|
||||
|
||||
else:
|
||||
# 创建新记录
|
||||
logger.info("🆕 创建新的兴趣标签配置")
|
||||
new_record = DBBotPersonalityInterests(
|
||||
personality_id=interests.personality_id,
|
||||
personality_description=interests.personality_description,
|
||||
interest_tags=json_data,
|
||||
embedding_model=interests.embedding_model,
|
||||
version=interests.version,
|
||||
last_updated=interests.last_updated,
|
||||
)
|
||||
session.add(new_record)
|
||||
session.commit()
|
||||
logger.info(f"✅ 成功创建兴趣标签配置,版本: {interests.version}")
|
||||
|
||||
logger.info("✅ 兴趣标签已成功保存到数据库")
|
||||
|
||||
# 验证保存是否成功
|
||||
with get_db_session() as session:
|
||||
saved_record = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||
.first()
|
||||
)
|
||||
session.commit()
|
||||
if saved_record:
|
||||
logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录")
|
||||
logger.info(f" 版本: {saved_record.version}")
|
||||
logger.info(f" 最后更新: {saved_record.last_updated}")
|
||||
else:
|
||||
logger.error(f"❌ 验证失败:数据库中未找到personality_id为 {interests.personality_id} 的记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 保存兴趣标签到数据库失败: {e}")
|
||||
logger.error("🔍 错误详情:")
|
||||
traceback.print_exc()
|
||||
|
||||
def get_current_interests(self) -> Optional[BotPersonalityInterests]:
|
||||
"""获取当前的兴趣标签配置"""
|
||||
return self.current_interests
|
||||
|
||||
def get_interest_stats(self) -> Dict[str, Any]:
|
||||
"""获取兴趣系统统计信息"""
|
||||
if not self.current_interests:
|
||||
return {"initialized": False}
|
||||
|
||||
active_tags = self.current_interests.get_active_tags()
|
||||
|
||||
return {
|
||||
"initialized": self._initialized,
|
||||
"total_tags": len(active_tags),
|
||||
"embedding_model": self.current_interests.embedding_model,
|
||||
"last_updated": self.current_interests.last_updated.isoformat(),
|
||||
"cache_size": len(self.embedding_cache),
|
||||
}
|
||||
|
||||
async def update_interest_tags(self, new_personality_description: str = None):
|
||||
"""更新兴趣标签"""
|
||||
try:
|
||||
if not self.current_interests:
|
||||
logger.warning("没有当前的兴趣标签配置,无法更新")
|
||||
return
|
||||
|
||||
if new_personality_description:
|
||||
self.current_interests.personality_description = new_personality_description
|
||||
|
||||
# 重新生成兴趣标签
|
||||
new_interests = await self._generate_interests_from_personality(
|
||||
self.current_interests.personality_description, self.current_interests.personality_id
|
||||
)
|
||||
|
||||
if new_interests:
|
||||
new_interests.version = self.current_interests.version + 1
|
||||
self.current_interests = new_interests
|
||||
await self._save_interests_to_database(new_interests)
|
||||
logger.info(f"兴趣标签已更新,版本: {new_interests.version}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新兴趣标签失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# 创建全局实例(重新创建以包含新的属性)
|
||||
bot_interest_manager = BotInterestManager()
|
||||
26
src/chat/message_manager/__init__.py
Normal file
26
src/chat/message_manager/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
消息管理器模块
|
||||
提供统一的消息管理、上下文管理和分发调度功能
|
||||
"""
|
||||
|
||||
from .message_manager import MessageManager, message_manager
|
||||
from .context_manager import StreamContextManager, context_manager
|
||||
from .distribution_manager import (
|
||||
DistributionManager,
|
||||
DistributionPriority,
|
||||
DistributionTask,
|
||||
StreamDistributionState,
|
||||
distribution_manager
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MessageManager",
|
||||
"message_manager",
|
||||
"StreamContextManager",
|
||||
"context_manager",
|
||||
"DistributionManager",
|
||||
"DistributionPriority",
|
||||
"DistributionTask",
|
||||
"StreamDistributionState",
|
||||
"distribution_manager"
|
||||
]
|
||||
653
src/chat/message_manager/context_manager.py
Normal file
653
src/chat/message_manager/context_manager.py
Normal file
@@ -0,0 +1,653 @@
|
||||
"""
|
||||
重构后的聊天上下文管理器
|
||||
提供统一、稳定的聊天上下文管理功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any, Union, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.energy_system import energy_manager
|
||||
from .distribution_manager import distribution_manager
|
||||
|
||||
logger = get_logger("context_manager")
|
||||
|
||||
class StreamContextManager:
|
||||
"""流上下文管理器 - 统一管理所有聊天流上下文"""
|
||||
|
||||
def __init__(self, max_context_size: Optional[int] = None, context_ttl: Optional[int] = None):
|
||||
# 上下文存储
|
||||
self.stream_contexts: Dict[str, Any] = {}
|
||||
self.context_metadata: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Union[int, float, str, Dict]] = {
|
||||
"total_messages": 0,
|
||||
"total_streams": 0,
|
||||
"active_streams": 0,
|
||||
"inactive_streams": 0,
|
||||
"last_activity": time.time(),
|
||||
"creation_time": time.time(),
|
||||
}
|
||||
|
||||
# 配置参数
|
||||
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
|
||||
self.context_ttl = context_ttl or getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
|
||||
self.cleanup_interval = getattr(global_config.chat, "context_cleanup_interval", 3600) # 1小时
|
||||
self.auto_cleanup = getattr(global_config.chat, "auto_cleanup_contexts", True)
|
||||
self.enable_validation = getattr(global_config.chat, "enable_context_validation", True)
|
||||
|
||||
# 清理任务
|
||||
self.cleanup_task: Optional[Any] = None
|
||||
self.is_running = False
|
||||
|
||||
logger.info(f"上下文管理器初始化完成 (最大上下文: {self.max_context_size}, TTL: {self.context_ttl}s)")
|
||||
|
||||
def add_stream_context(self, stream_id: str, context: Any, metadata: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""添加流上下文
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
context: 上下文对象
|
||||
metadata: 上下文元数据
|
||||
|
||||
Returns:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
if stream_id in self.stream_contexts:
|
||||
logger.warning(f"流上下文已存在: {stream_id}")
|
||||
return False
|
||||
|
||||
# 添加上下文
|
||||
self.stream_contexts[stream_id] = context
|
||||
|
||||
# 初始化元数据
|
||||
self.context_metadata[stream_id] = {
|
||||
"created_time": time.time(),
|
||||
"last_access_time": time.time(),
|
||||
"access_count": 0,
|
||||
"last_validation_time": 0.0,
|
||||
"custom_metadata": metadata or {},
|
||||
}
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_streams"] += 1
|
||||
self.stats["active_streams"] += 1
|
||||
self.stats["last_activity"] = time.time()
|
||||
|
||||
logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})")
|
||||
return True
|
||||
|
||||
def remove_stream_context(self, stream_id: str) -> bool:
|
||||
"""移除流上下文
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功移除
|
||||
"""
|
||||
if stream_id in self.stream_contexts:
|
||||
context = self.stream_contexts[stream_id]
|
||||
metadata = self.context_metadata.get(stream_id, {})
|
||||
|
||||
del self.stream_contexts[stream_id]
|
||||
if stream_id in self.context_metadata:
|
||||
del self.context_metadata[stream_id]
|
||||
|
||||
self.stats["active_streams"] = max(0, self.stats["active_streams"] - 1)
|
||||
self.stats["inactive_streams"] += 1
|
||||
self.stats["last_activity"] = time.time()
|
||||
|
||||
logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[StreamContext]:
|
||||
"""获取流上下文
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
update_access: 是否更新访问统计
|
||||
|
||||
Returns:
|
||||
Optional[Any]: 上下文对象
|
||||
"""
|
||||
context = self.stream_contexts.get(stream_id)
|
||||
if context and update_access:
|
||||
# 更新访问统计
|
||||
if stream_id in self.context_metadata:
|
||||
metadata = self.context_metadata[stream_id]
|
||||
metadata["last_access_time"] = time.time()
|
||||
metadata["access_count"] = metadata.get("access_count", 0) + 1
|
||||
return context
|
||||
|
||||
def get_context_metadata(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取上下文元数据
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 元数据
|
||||
"""
|
||||
return self.context_metadata.get(stream_id)
|
||||
|
||||
def update_context_metadata(self, stream_id: str, updates: Dict[str, Any]) -> bool:
|
||||
"""更新上下文元数据
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
updates: 更新的元数据
|
||||
|
||||
Returns:
|
||||
bool: 是否成功更新
|
||||
"""
|
||||
if stream_id not in self.context_metadata:
|
||||
return False
|
||||
|
||||
self.context_metadata[stream_id].update(updates)
|
||||
return True
|
||||
|
||||
def add_message_to_context(self, stream_id: str, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
|
||||
"""添加消息到上下文
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
message: 消息对象
|
||||
skip_energy_update: 是否跳过能量更新
|
||||
|
||||
Returns:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
logger.warning(f"流上下文不存在: {stream_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 添加消息到上下文
|
||||
context.add_message(message)
|
||||
|
||||
# 计算消息兴趣度
|
||||
interest_value = self._calculate_message_interest(message)
|
||||
message.interest_value = interest_value
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_messages"] += 1
|
||||
self.stats["last_activity"] = time.time()
|
||||
|
||||
# 更新能量和分发
|
||||
if not skip_energy_update:
|
||||
self._update_stream_energy(stream_id)
|
||||
distribution_manager.add_stream_message(stream_id, 1)
|
||||
|
||||
logger.debug(f"添加消息到上下文: {stream_id} (兴趣度: {interest_value:.3f})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到上下文失败 {stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def update_message_in_context(self, stream_id: str, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
"""更新上下文中的消息
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
message_id: 消息ID
|
||||
updates: 更新的属性
|
||||
|
||||
Returns:
|
||||
bool: 是否成功更新
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
logger.warning(f"流上下文不存在: {stream_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 更新消息信息
|
||||
context.update_message_info(message_id, **updates)
|
||||
|
||||
# 如果更新了兴趣度,重新计算能量
|
||||
if "interest_value" in updates:
|
||||
self._update_stream_energy(stream_id)
|
||||
|
||||
logger.debug(f"更新上下文消息: {stream_id}/{message_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
|
||||
"""获取上下文消息
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
limit: 消息数量限制
|
||||
include_unread: 是否包含未读消息
|
||||
|
||||
Returns:
|
||||
List[Any]: 消息列表
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return []
|
||||
|
||||
try:
|
||||
messages = []
|
||||
if include_unread:
|
||||
messages.extend(context.get_unread_messages())
|
||||
|
||||
if limit:
|
||||
messages.extend(context.get_history_messages(limit=limit))
|
||||
else:
|
||||
messages.extend(context.get_history_messages())
|
||||
|
||||
# 按时间排序
|
||||
messages.sort(key=lambda msg: getattr(msg, 'time', 0))
|
||||
|
||||
# 应用限制
|
||||
if limit and len(messages) > limit:
|
||||
messages = messages[-limit:]
|
||||
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def get_unread_messages(self, stream_id: str) -> List[DatabaseMessages]:
|
||||
"""获取未读消息
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
List[Any]: 未读消息列表
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return []
|
||||
|
||||
try:
|
||||
return context.get_unread_messages()
|
||||
except Exception as e:
|
||||
logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def mark_messages_as_read(self, stream_id: str, message_ids: List[str]) -> bool:
|
||||
"""标记消息为已读
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
message_ids: 消息ID列表
|
||||
|
||||
Returns:
|
||||
bool: 是否成功标记
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
logger.warning(f"流上下文不存在: {stream_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
if not hasattr(context, 'mark_message_as_read'):
|
||||
logger.error(f"上下文对象缺少 mark_message_as_read 方法: {stream_id}")
|
||||
return False
|
||||
|
||||
marked_count = 0
|
||||
for message_id in message_ids:
|
||||
try:
|
||||
context.mark_message_as_read(message_id)
|
||||
marked_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"标记消息已读失败 {message_id}: {e}")
|
||||
|
||||
logger.debug(f"标记消息为已读: {stream_id} ({marked_count}/{len(message_ids)}条)")
|
||||
return marked_count > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"标记消息已读失败 {stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def clear_context(self, stream_id: str) -> bool:
|
||||
"""清空上下文
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功清空
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
logger.warning(f"流上下文不存在: {stream_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 清空消息
|
||||
if hasattr(context, 'unread_messages'):
|
||||
context.unread_messages.clear()
|
||||
if hasattr(context, 'history_messages'):
|
||||
context.history_messages.clear()
|
||||
|
||||
# 重置状态
|
||||
reset_attrs = ['interruption_count', 'afc_threshold_adjustment', 'last_check_time']
|
||||
for attr in reset_attrs:
|
||||
if hasattr(context, attr):
|
||||
if attr in ['interruption_count', 'afc_threshold_adjustment']:
|
||||
setattr(context, attr, 0)
|
||||
else:
|
||||
setattr(context, attr, time.time())
|
||||
|
||||
# 重新计算能量
|
||||
self._update_stream_energy(stream_id)
|
||||
|
||||
logger.info(f"清空上下文: {stream_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def _calculate_message_interest(self, message: DatabaseMessages) -> float:
|
||||
"""计算消息兴趣度"""
|
||||
try:
|
||||
# 使用插件内部的兴趣度评分系统
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
|
||||
# 使用插件内部的兴趣度评分系统计算(同步方式)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
interest_score = loop.run_until_complete(
|
||||
chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=message,
|
||||
bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
)
|
||||
interest_value = interest_score.total_score
|
||||
|
||||
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"插件内部兴趣度计算失败,使用默认值: {e}")
|
||||
interest_value = 0.5 # 默认中等兴趣度
|
||||
|
||||
return interest_value
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算消息兴趣度失败: {e}")
|
||||
return 0.5
|
||||
|
||||
def _update_stream_energy(self, stream_id: str):
|
||||
"""更新流能量"""
|
||||
try:
|
||||
# 获取所有消息
|
||||
all_messages = self.get_context_messages(stream_id, self.max_context_size)
|
||||
unread_messages = self.get_unread_messages(stream_id)
|
||||
combined_messages = all_messages + unread_messages
|
||||
|
||||
# 获取用户ID
|
||||
user_id = None
|
||||
if combined_messages:
|
||||
last_message = combined_messages[-1]
|
||||
user_id = last_message.user_info.user_id
|
||||
|
||||
# 计算能量
|
||||
energy = energy_manager.calculate_focus_energy(
|
||||
stream_id=stream_id,
|
||||
messages=combined_messages,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# 更新分发管理器
|
||||
distribution_manager.update_stream_energy(stream_id, energy)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新流能量失败 {stream_id}: {e}")
|
||||
|
||||
def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取流统计信息
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 统计信息
|
||||
"""
|
||||
context = self.get_stream_context(stream_id, update_access=False)
|
||||
if not context:
|
||||
return None
|
||||
|
||||
try:
|
||||
metadata = self.context_metadata.get(stream_id, {})
|
||||
current_time = time.time()
|
||||
created_time = metadata.get("created_time", current_time)
|
||||
last_access_time = metadata.get("last_access_time", current_time)
|
||||
access_count = metadata.get("access_count", 0)
|
||||
|
||||
unread_messages = getattr(context, "unread_messages", [])
|
||||
history_messages = getattr(context, "history_messages", [])
|
||||
|
||||
return {
|
||||
"stream_id": stream_id,
|
||||
"context_type": type(context).__name__,
|
||||
"total_messages": len(history_messages) + len(unread_messages),
|
||||
"unread_messages": len(unread_messages),
|
||||
"history_messages": len(history_messages),
|
||||
"is_active": getattr(context, "is_active", True),
|
||||
"last_check_time": getattr(context, "last_check_time", current_time),
|
||||
"interruption_count": getattr(context, "interruption_count", 0),
|
||||
"afc_threshold_adjustment": getattr(context, "afc_threshold_adjustment", 0.0),
|
||||
"created_time": created_time,
|
||||
"last_access_time": last_access_time,
|
||||
"access_count": access_count,
|
||||
"uptime_seconds": current_time - created_time,
|
||||
"idle_seconds": current_time - last_access_time,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def get_manager_statistics(self) -> Dict[str, Any]:
|
||||
"""获取管理器统计信息
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 管理器统计信息
|
||||
"""
|
||||
current_time = time.time()
|
||||
uptime = current_time - self.stats.get("creation_time", current_time)
|
||||
|
||||
return {
|
||||
**self.stats,
|
||||
"uptime_hours": uptime / 3600,
|
||||
"stream_count": len(self.stream_contexts),
|
||||
"metadata_count": len(self.context_metadata),
|
||||
"auto_cleanup_enabled": self.auto_cleanup,
|
||||
"cleanup_interval": self.cleanup_interval,
|
||||
}
|
||||
|
||||
def cleanup_inactive_contexts(self, max_inactive_hours: int = 24) -> int:
|
||||
"""清理不活跃的上下文
|
||||
|
||||
Args:
|
||||
max_inactive_hours: 最大不活跃小时数
|
||||
|
||||
Returns:
|
||||
int: 清理的上下文数量
|
||||
"""
|
||||
current_time = time.time()
|
||||
max_inactive_seconds = max_inactive_hours * 3600
|
||||
|
||||
inactive_streams = []
|
||||
for stream_id, context in self.stream_contexts.items():
|
||||
try:
|
||||
# 获取最后活动时间
|
||||
metadata = self.context_metadata.get(stream_id, {})
|
||||
last_activity = metadata.get("last_access_time", metadata.get("created_time", 0))
|
||||
context_last_activity = getattr(context, "last_check_time", 0)
|
||||
actual_last_activity = max(last_activity, context_last_activity)
|
||||
|
||||
# 检查是否不活跃
|
||||
unread_count = len(getattr(context, "unread_messages", []))
|
||||
history_count = len(getattr(context, "history_messages", []))
|
||||
total_messages = unread_count + history_count
|
||||
|
||||
if (current_time - actual_last_activity > max_inactive_seconds and
|
||||
total_messages == 0):
|
||||
inactive_streams.append(stream_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"检查上下文活跃状态失败 {stream_id}: {e}")
|
||||
continue
|
||||
|
||||
# 清理不活跃上下文
|
||||
cleaned_count = 0
|
||||
for stream_id in inactive_streams:
|
||||
if self.remove_stream_context(stream_id):
|
||||
cleaned_count += 1
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"清理了 {cleaned_count} 个不活跃上下文")
|
||||
|
||||
return cleaned_count
|
||||
|
||||
def validate_context_integrity(self, stream_id: str) -> bool:
|
||||
"""验证上下文完整性
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否完整
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return False
|
||||
|
||||
try:
|
||||
# 检查基本属性
|
||||
required_attrs = ["stream_id", "unread_messages", "history_messages"]
|
||||
for attr in required_attrs:
|
||||
if not hasattr(context, attr):
|
||||
logger.warning(f"上下文缺少必要属性: {attr}")
|
||||
return False
|
||||
|
||||
# 检查消息ID唯一性
|
||||
all_messages = getattr(context, "unread_messages", []) + getattr(context, "history_messages", [])
|
||||
message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
|
||||
if len(message_ids) != len(set(message_ids)):
|
||||
logger.warning(f"上下文中存在重复消息ID: {stream_id}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证上下文完整性失败 {stream_id}: {e}")
|
||||
return False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动上下文管理器"""
|
||||
if self.is_running:
|
||||
logger.warning("上下文管理器已经在运行")
|
||||
return
|
||||
|
||||
await self.start_auto_cleanup()
|
||||
logger.info("上下文管理器已启动")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止上下文管理器"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
await self.stop_auto_cleanup()
|
||||
logger.info("上下文管理器已停止")
|
||||
|
||||
async def start_auto_cleanup(self, interval: Optional[float] = None) -> None:
|
||||
"""启动自动清理
|
||||
|
||||
Args:
|
||||
interval: 清理间隔(秒)
|
||||
"""
|
||||
if not self.auto_cleanup:
|
||||
logger.info("自动清理已禁用")
|
||||
return
|
||||
|
||||
if self.is_running:
|
||||
logger.warning("自动清理已在运行")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
cleanup_interval = interval or self.cleanup_interval
|
||||
logger.info(f"启动自动清理(间隔: {cleanup_interval}s)")
|
||||
|
||||
import asyncio
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_loop(cleanup_interval))
|
||||
|
||||
async def stop_auto_cleanup(self) -> None:
|
||||
"""停止自动清理"""
|
||||
self.is_running = False
|
||||
if self.cleanup_task and not self.cleanup_task.done():
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await self.cleanup_task
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("自动清理已停止")
|
||||
|
||||
async def _cleanup_loop(self, interval: float) -> None:
|
||||
"""清理循环
|
||||
|
||||
Args:
|
||||
interval: 清理间隔
|
||||
"""
|
||||
while self.is_running:
|
||||
try:
|
||||
await asyncio.sleep(interval)
|
||||
self.cleanup_inactive_contexts()
|
||||
self._cleanup_expired_contexts()
|
||||
logger.debug("自动清理完成")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理循环出错: {e}", exc_info=True)
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
def _cleanup_expired_contexts(self) -> None:
|
||||
"""清理过期上下文"""
|
||||
current_time = time.time()
|
||||
expired_contexts = []
|
||||
|
||||
for stream_id, metadata in self.context_metadata.items():
|
||||
created_time = metadata.get("created_time", current_time)
|
||||
if current_time - created_time > self.context_ttl:
|
||||
expired_contexts.append(stream_id)
|
||||
|
||||
for stream_id in expired_contexts:
|
||||
self.remove_stream_context(stream_id)
|
||||
|
||||
if expired_contexts:
|
||||
logger.info(f"清理了 {len(expired_contexts)} 个过期上下文")
|
||||
|
||||
def get_active_streams(self) -> List[str]:
|
||||
"""获取活跃流列表
|
||||
|
||||
Returns:
|
||||
List[str]: 活跃流ID列表
|
||||
"""
|
||||
return list(self.stream_contexts.keys())
|
||||
|
||||
|
||||
# 全局上下文管理器实例
|
||||
context_manager = StreamContextManager()
|
||||
1004
src/chat/message_manager/distribution_manager.py
Normal file
1004
src/chat/message_manager/distribution_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
558
src/chat/message_manager/message_manager.py
Normal file
558
src/chat/message_manager/message_manager.py
Normal file
@@ -0,0 +1,558 @@
|
||||
"""
|
||||
消息管理模块
|
||||
管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, Optional, Any, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
from .sleep_manager.sleep_manager import SleepManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
from src.config.config import global_config
|
||||
from .context_manager import context_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
|
||||
logger = get_logger("message_manager")
|
||||
|
||||
|
||||
class MessageManager:
|
||||
"""消息管理器"""
|
||||
|
||||
def __init__(self, check_interval: float = 5.0):
|
||||
self.check_interval = check_interval # 检查间隔(秒)
|
||||
self.is_running = False
|
||||
self.manager_task: Optional[asyncio.Task] = None
|
||||
|
||||
# 统计信息
|
||||
self.stats = MessageManagerStats()
|
||||
|
||||
# 初始化chatter manager
|
||||
self.action_manager = ChatterActionManager()
|
||||
self.chatter_manager = ChatterManager(self.action_manager)
|
||||
|
||||
# 初始化睡眠和唤醒管理器
|
||||
self.sleep_manager = SleepManager()
|
||||
self.wakeup_manager = WakeUpManager(self.sleep_manager)
|
||||
|
||||
# 初始化上下文管理器
|
||||
self.context_manager = context_manager
|
||||
|
||||
async def start(self):
|
||||
"""启动消息管理器"""
|
||||
if self.is_running:
|
||||
logger.warning("消息管理器已经在运行")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.manager_task = asyncio.create_task(self._manager_loop())
|
||||
await self.wakeup_manager.start()
|
||||
await self.context_manager.start()
|
||||
logger.info("消息管理器已启动")
|
||||
|
||||
async def stop(self):
|
||||
"""停止消息管理器"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
self.is_running = False
|
||||
|
||||
# 停止所有流处理任务
|
||||
# 注意:context_manager 会自己清理任务
|
||||
if self.manager_task and not self.manager_task.done():
|
||||
self.manager_task.cancel()
|
||||
|
||||
await self.wakeup_manager.stop()
|
||||
await self.context_manager.stop()
|
||||
|
||||
logger.info("消息管理器已停止")
|
||||
|
||||
def add_message(self, stream_id: str, message: DatabaseMessages):
|
||||
"""添加消息到指定聊天流"""
|
||||
# 检查流上下文是否存在,不存在则创建
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context:
|
||||
# 创建新的流上下文
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
context = StreamContext(stream_id=stream_id)
|
||||
# 将创建的上下文添加到管理器
|
||||
self.context_manager.add_stream_context(stream_id, context)
|
||||
|
||||
# 使用 context_manager 添加消息
|
||||
success = self.context_manager.add_message_to_context(stream_id, message)
|
||||
|
||||
if success:
|
||||
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
|
||||
else:
|
||||
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
|
||||
|
||||
def update_message(
|
||||
self,
|
||||
stream_id: str,
|
||||
message_id: str,
|
||||
interest_value: float = None,
|
||||
actions: list = None,
|
||||
should_reply: bool = None,
|
||||
):
|
||||
"""更新消息信息"""
|
||||
# 使用 context_manager 更新消息信息
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
context.update_message_info(message_id, interest_value, actions, should_reply)
|
||||
|
||||
def add_action(self, stream_id: str, message_id: str, action: str):
|
||||
"""添加动作到消息"""
|
||||
# 使用 context_manager 添加动作到消息
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
context.add_action_to_message(message_id, action)
|
||||
|
||||
async def _manager_loop(self):
|
||||
"""管理器主循环 - 独立聊天流分发周期版本"""
|
||||
while self.is_running:
|
||||
try:
|
||||
# 更新睡眠状态
|
||||
await self.sleep_manager.update_sleep_state(self.wakeup_manager)
|
||||
|
||||
# 执行独立分发周期的检查
|
||||
await self._check_streams_with_individual_intervals()
|
||||
|
||||
# 计算下次检查时间(使用最小间隔或固定间隔)
|
||||
if global_config.chat.dynamic_distribution_enabled:
|
||||
next_check_delay = self._calculate_next_manager_delay()
|
||||
else:
|
||||
next_check_delay = self.check_interval
|
||||
|
||||
await asyncio.sleep(next_check_delay)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"消息管理器循环出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def _check_all_streams(self):
|
||||
"""检查所有聊天流"""
|
||||
active_streams = 0
|
||||
total_unread = 0
|
||||
|
||||
# 使用 context_manager 获取活跃的流
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context:
|
||||
continue
|
||||
|
||||
active_streams += 1
|
||||
|
||||
# 检查是否有未读消息
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if unread_messages:
|
||||
total_unread += len(unread_messages)
|
||||
|
||||
# 如果没有处理任务,创建一个
|
||||
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
|
||||
# 更新统计
|
||||
self.stats.active_streams = active_streams
|
||||
self.stats.total_unread_messages = total_unread
|
||||
|
||||
async def _process_stream_messages(self, stream_id: str):
|
||||
"""处理指定聊天流的消息"""
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return
|
||||
|
||||
try:
|
||||
# 获取未读消息
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if not unread_messages:
|
||||
return
|
||||
|
||||
# 检查是否需要打断现有处理
|
||||
await self._check_and_handle_interruption(context, stream_id)
|
||||
|
||||
# --- 睡眠状态检查 ---
|
||||
if self.sleep_manager.is_sleeping():
|
||||
logger.info(f"Bot正在睡觉,检查聊天流 {stream_id} 是否有唤醒触发器。")
|
||||
|
||||
was_woken_up = False
|
||||
is_private = context.is_private_chat()
|
||||
|
||||
for message in unread_messages:
|
||||
is_mentioned = message.is_mentioned or False
|
||||
if not is_mentioned and not is_private:
|
||||
bot_names = [global_config.bot.nickname] + global_config.bot.alias_names
|
||||
if any(name in message.processed_plain_text for name in bot_names):
|
||||
is_mentioned = True
|
||||
logger.debug(f"通过关键词 '{next((name for name in bot_names if name in message.processed_plain_text), '')}' 匹配将消息标记为 'is_mentioned'")
|
||||
|
||||
if is_private or is_mentioned:
|
||||
if self.wakeup_manager.add_wakeup_value(is_private, is_mentioned, chat_id=stream_id):
|
||||
was_woken_up = True
|
||||
break # 一旦被吵醒,就跳出循环并处理消息
|
||||
|
||||
if not was_woken_up:
|
||||
logger.debug(f"聊天流 {stream_id} 中没有唤醒触发器,保持消息未读状态。")
|
||||
return # 退出,不处理消息
|
||||
|
||||
logger.info(f"Bot被聊天流 {stream_id} 中的消息吵醒,继续处理。")
|
||||
elif self.sleep_manager.is_woken_up():
|
||||
angry_chat_id = self.wakeup_manager.angry_chat_id
|
||||
if stream_id != angry_chat_id:
|
||||
logger.debug(f"Bot处于WOKEN_UP状态,但当前流 {stream_id} 不是触发唤醒的流 {angry_chat_id},跳过处理。")
|
||||
return # 退出,不处理此流的消息
|
||||
logger.info(f"Bot处于WOKEN_UP状态,处理触发唤醒的流 {stream_id}。")
|
||||
# --- 睡眠状态检查结束 ---
|
||||
|
||||
logger.debug(f"开始处理聊天流 {stream_id} 的 {len(unread_messages)} 条未读消息")
|
||||
|
||||
# 直接使用StreamContext对象进行处理
|
||||
if unread_messages:
|
||||
try:
|
||||
# 记录当前chat type用于调试
|
||||
logger.debug(f"聊天流 {stream_id} 检测到的chat type: {context.chat_type.value}")
|
||||
|
||||
# 发送到chatter manager,传递StreamContext对象
|
||||
results = await self.chatter_manager.process_stream_context(stream_id, context)
|
||||
|
||||
# 处理结果,标记消息为已读
|
||||
if results.get("success", False):
|
||||
self._clear_all_unread_messages(stream_id)
|
||||
logger.debug(f"聊天流 {stream_id} 处理成功,清除了 {len(unread_messages)} 条未读消息")
|
||||
else:
|
||||
logger.warning(f"聊天流 {stream_id} 处理失败: {results.get('error_message', '未知错误')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理聊天流 {stream_id} 时发生异常,将清除所有未读消息: {e}")
|
||||
# 出现异常时也清除未读消息,避免重复处理
|
||||
self._clear_all_unread_messages(stream_id)
|
||||
raise
|
||||
|
||||
logger.debug(f"聊天流 {stream_id} 消息处理完成")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"处理聊天流 {stream_id} 消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def deactivate_stream(self, stream_id: str):
|
||||
"""停用聊天流"""
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
context.is_active = False
|
||||
|
||||
# 取消处理任务
|
||||
if hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done():
|
||||
context.processing_task.cancel()
|
||||
|
||||
logger.info(f"停用聊天流: {stream_id}")
|
||||
|
||||
def activate_stream(self, stream_id: str):
|
||||
"""激活聊天流"""
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
context.is_active = True
|
||||
logger.info(f"激活聊天流: {stream_id}")
|
||||
|
||||
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
|
||||
"""获取聊天流统计"""
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return None
|
||||
|
||||
return StreamStats(
|
||||
stream_id=stream_id,
|
||||
is_active=context.is_active,
|
||||
unread_count=len(self.context_manager.get_unread_messages(stream_id)),
|
||||
history_count=len(context.history_messages),
|
||||
last_check_time=context.last_check_time,
|
||||
has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
|
||||
)
|
||||
|
||||
def get_manager_stats(self) -> Dict[str, Any]:
|
||||
"""获取管理器统计"""
|
||||
return {
|
||||
"total_streams": self.stats.total_streams,
|
||||
"active_streams": self.stats.active_streams,
|
||||
"total_unread_messages": self.stats.total_unread_messages,
|
||||
"total_processed_messages": self.stats.total_processed_messages,
|
||||
"uptime": self.stats.uptime,
|
||||
"start_time": self.stats.start_time,
|
||||
}
|
||||
|
||||
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
|
||||
"""清理不活跃的聊天流"""
|
||||
# 使用 context_manager 的自动清理功能
|
||||
self.context_manager.cleanup_inactive_contexts(max_inactive_hours * 3600)
|
||||
logger.info("已启动不活跃聊天流清理")
|
||||
|
||||
async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str):
|
||||
"""检查并处理消息打断"""
|
||||
if not global_config.chat.interruption_enabled:
|
||||
return
|
||||
|
||||
# 检查是否有正在进行的处理任务
|
||||
if context.processing_task and not context.processing_task.done():
|
||||
# 计算打断概率
|
||||
interruption_probability = context.calculate_interruption_probability(
|
||||
global_config.chat.interruption_max_limit, global_config.chat.interruption_probability_factor
|
||||
)
|
||||
|
||||
# 检查是否已达到最大打断次数
|
||||
if context.interruption_count >= global_config.chat.interruption_max_limit:
|
||||
logger.debug(
|
||||
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},跳过打断检查"
|
||||
)
|
||||
return
|
||||
|
||||
# 根据概率决定是否打断
|
||||
if random.random() < interruption_probability:
|
||||
logger.info(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}")
|
||||
|
||||
# 取消现有任务
|
||||
context.processing_task.cancel()
|
||||
try:
|
||||
await context.processing_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 增加打断计数并应用afc阈值降低
|
||||
context.increment_interruption_count()
|
||||
context.apply_interruption_afc_reduction(global_config.chat.interruption_afc_reduction)
|
||||
|
||||
# 检查是否已达到最大次数
|
||||
if context.interruption_count >= global_config.chat.interruption_max_limit:
|
||||
logger.warning(
|
||||
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},后续消息将不再打断"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}, afc阈值调整: {context.get_afc_threshold_adjustment()}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
|
||||
|
||||
def _calculate_stream_distribution_interval(self, context: StreamContext) -> float:
|
||||
"""计算单个聊天流的分发周期 - 使用重构后的能量管理器"""
|
||||
if not global_config.chat.dynamic_distribution_enabled:
|
||||
return self.check_interval # 使用固定间隔
|
||||
|
||||
try:
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
# 获取聊天流和能量
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
if chat_stream:
|
||||
focus_energy = chat_stream.focus_energy
|
||||
# 使用能量管理器获取分发周期
|
||||
interval = energy_manager.get_distribution_interval(focus_energy)
|
||||
logger.debug(f"流 {context.stream_id} 分发周期: {interval:.2f}s (能量: {focus_energy:.3f})")
|
||||
return interval
|
||||
else:
|
||||
# 默认间隔
|
||||
return self.check_interval
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算分发周期失败: {e}")
|
||||
return self.check_interval
|
||||
|
||||
def _calculate_next_manager_delay(self) -> float:
|
||||
"""计算管理器下次检查的延迟时间"""
|
||||
current_time = time.time()
|
||||
min_delay = float("inf")
|
||||
|
||||
# 找到最近需要检查的流
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
|
||||
time_until_check = context.next_check_time - current_time
|
||||
if time_until_check > 0:
|
||||
min_delay = min(min_delay, time_until_check)
|
||||
else:
|
||||
min_delay = 0.1 # 立即检查
|
||||
break
|
||||
|
||||
# 如果没有活跃流,使用默认间隔
|
||||
if min_delay == float("inf"):
|
||||
return self.check_interval
|
||||
|
||||
# 确保最小延迟
|
||||
return max(0.1, min(min_delay, self.check_interval))
|
||||
|
||||
async def _check_streams_with_individual_intervals(self):
|
||||
"""检查所有达到检查时间的聊天流"""
|
||||
current_time = time.time()
|
||||
processed_streams = 0
|
||||
|
||||
# 使用 context_manager 获取活跃的流
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
|
||||
# 检查是否达到检查时间
|
||||
if current_time >= context.next_check_time:
|
||||
# 更新检查时间
|
||||
context.last_check_time = current_time
|
||||
|
||||
# 计算下次检查时间和分发周期
|
||||
if global_config.chat.dynamic_distribution_enabled:
|
||||
context.distribution_interval = self._calculate_stream_distribution_interval(context)
|
||||
else:
|
||||
context.distribution_interval = self.check_interval
|
||||
|
||||
# 设置下次检查时间
|
||||
context.next_check_time = current_time + context.distribution_interval
|
||||
|
||||
# 检查未读消息
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if unread_messages:
|
||||
processed_streams += 1
|
||||
self.stats.total_unread_messages = len(unread_messages)
|
||||
|
||||
# 如果没有处理任务,创建一个
|
||||
if not context.processing_task or context.processing_task.done():
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
focus_energy = chat_stream.focus_energy if chat_stream else 0.5
|
||||
|
||||
# 根据优先级记录日志
|
||||
if focus_energy >= 0.7:
|
||||
logger.info(
|
||||
f"高优先级流 {stream_id} 开始处理 | "
|
||||
f"focus_energy: {focus_energy:.3f} | "
|
||||
f"分发周期: {context.distribution_interval:.2f}s | "
|
||||
f"未读消息: {len(unread_messages)}"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"流 {stream_id} 开始处理 | "
|
||||
f"focus_energy: {focus_energy:.3f} | "
|
||||
f"分发周期: {context.distribution_interval:.2f}s"
|
||||
)
|
||||
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
|
||||
# 更新活跃流计数
|
||||
active_count = len(self.context_manager.get_active_streams())
|
||||
self.stats.active_streams = active_count
|
||||
|
||||
if processed_streams > 0:
|
||||
logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}")
|
||||
|
||||
async def _check_all_streams_with_priority(self):
|
||||
"""按优先级检查所有聊天流,高focus_energy的流优先处理"""
|
||||
if not self.context_manager.get_active_streams():
|
||||
return
|
||||
|
||||
# 获取活跃的聊天流并按focus_energy排序
|
||||
active_streams = []
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
|
||||
# 获取focus_energy,如果不存在则使用默认值
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
focus_energy = 0.5
|
||||
if chat_stream:
|
||||
focus_energy = chat_stream.focus_energy
|
||||
|
||||
# 计算流优先级分数
|
||||
priority_score = self._calculate_stream_priority(context, focus_energy)
|
||||
active_streams.append((priority_score, stream_id, context))
|
||||
|
||||
# 按优先级降序排序
|
||||
active_streams.sort(reverse=True, key=lambda x: x[0])
|
||||
|
||||
# 处理排序后的流
|
||||
active_stream_count = 0
|
||||
total_unread = 0
|
||||
|
||||
for priority_score, stream_id, context in active_streams:
|
||||
active_stream_count += 1
|
||||
|
||||
# 检查是否有未读消息
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if unread_messages:
|
||||
total_unread += len(unread_messages)
|
||||
|
||||
# 如果没有处理任务,创建一个
|
||||
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
|
||||
# 高优先级流的额外日志
|
||||
if priority_score > 0.7:
|
||||
logger.info(
|
||||
f"高优先级流 {stream_id} 开始处理 | "
|
||||
f"优先级: {priority_score:.3f} | "
|
||||
f"未读消息: {len(unread_messages)}"
|
||||
)
|
||||
|
||||
# 更新统计
|
||||
self.stats.active_streams = active_stream_count
|
||||
self.stats.total_unread_messages = total_unread
|
||||
|
||||
def _calculate_stream_priority(self, context: StreamContext, focus_energy: float) -> float:
|
||||
"""计算聊天流的优先级分数 - 简化版本,主要使用focus_energy"""
|
||||
# 使用重构后的能量管理器,主要依赖focus_energy
|
||||
base_priority = focus_energy
|
||||
|
||||
# 简单的未读消息加权
|
||||
unread_count = len(context.get_unread_messages())
|
||||
message_bonus = min(unread_count * 0.05, 0.2) # 最多20%加成
|
||||
|
||||
# 简单的时间加权
|
||||
current_time = time.time()
|
||||
time_since_active = current_time - context.last_check_time
|
||||
time_bonus = max(0, 1.0 - time_since_active / 7200.0) * 0.1 # 2小时内衰减
|
||||
|
||||
final_priority = base_priority + message_bonus + time_bonus
|
||||
return max(0.0, min(1.0, final_priority))
|
||||
|
||||
def _clear_all_unread_messages(self, stream_id: str):
|
||||
"""清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读"""
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if not unread_messages:
|
||||
return
|
||||
|
||||
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
|
||||
|
||||
# 将所有未读消息标记为已读
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
|
||||
try:
|
||||
context.mark_message_as_read(msg.message_id)
|
||||
self.stats.total_processed_messages += 1
|
||||
logger.debug(f"强制清除消息 {msg.message_id},标记为已读")
|
||||
except Exception as e:
|
||||
logger.error(f"清除消息 {msg.message_id} 时出错: {e}")
|
||||
|
||||
|
||||
# 创建全局消息管理器实例
|
||||
message_manager = MessageManager()
|
||||
@@ -0,0 +1,33 @@
|
||||
from src.common.logger import get_logger
|
||||
|
||||
#from ..hfc_context import HfcContext
|
||||
|
||||
logger = get_logger("notification_sender")
|
||||
|
||||
|
||||
class NotificationSender:
|
||||
@staticmethod
|
||||
async def send_goodnight_notification(context): # type: ignore
|
||||
"""发送晚安通知"""
|
||||
#try:
|
||||
#from ..proactive.events import ProactiveTriggerEvent
|
||||
#from ..proactive.proactive_thinker import ProactiveThinker
|
||||
|
||||
#event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight")
|
||||
#proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
|
||||
#await proactive_thinker.think(event)
|
||||
#except Exception as e:
|
||||
#logger.error(f"发送晚安通知失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def send_insomnia_notification(context, reason: str): # type: ignore
|
||||
"""发送失眠通知"""
|
||||
#try:
|
||||
#from ..proactive.events import ProactiveTriggerEvent
|
||||
#from ..proactive.proactive_thinker import ProactiveThinker
|
||||
|
||||
#event = ProactiveTriggerEvent(source="sleep_manager", reason=reason)
|
||||
#proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
|
||||
#await proactive_thinker.think(event)
|
||||
#except Exception as e:
|
||||
#logger.error(f"发送失眠通知失败: {e}")
|
||||
@@ -6,11 +6,11 @@ from typing import Optional, TYPE_CHECKING
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .notification_sender import NotificationSender
|
||||
from .sleep_state import SleepState, SleepStateSerializer
|
||||
from .sleep_state import SleepState, SleepContext
|
||||
from .time_checker import TimeChecker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from .wakeup_manager import WakeUpManager
|
||||
|
||||
logger = get_logger("sleep_manager")
|
||||
|
||||
@@ -25,28 +25,23 @@ class SleepManager:
|
||||
"""
|
||||
初始化睡眠管理器。
|
||||
"""
|
||||
self.time_checker = TimeChecker() # 时间检查器,用于判断当前是否处于理论睡眠时间
|
||||
self.context = SleepContext() # 睡眠上下文,管理所有状态
|
||||
self.time_checker = TimeChecker() # 时间检查器
|
||||
self.last_sleep_log_time = 0 # 上次记录睡眠日志的时间戳
|
||||
self.sleep_log_interval = 35 # 睡眠日志记录间隔(秒)
|
||||
|
||||
# --- 统一睡眠状态管理 ---
|
||||
self._current_state: SleepState = SleepState.AWAKE # 当前睡眠状态
|
||||
self._sleep_buffer_end_time: Optional[datetime] = None # 睡眠缓冲结束时间,用于状态转换
|
||||
self._total_delayed_minutes_today: float = 0.0 # 今天总共延迟入睡的分钟数
|
||||
self._last_sleep_check_date: Optional[date] = None # 上次检查睡眠状态的日期
|
||||
self._last_fully_slept_log_time: float = 0 # 上次完全进入睡眠状态的时间戳
|
||||
self._re_sleep_attempt_time: Optional[datetime] = None # 被吵醒后,尝试重新入睡的时间点
|
||||
|
||||
# 从本地存储加载上一次的睡眠状态
|
||||
self._load_sleep_state()
|
||||
|
||||
def get_current_sleep_state(self) -> SleepState:
|
||||
"""获取当前的睡眠状态。"""
|
||||
return self._current_state
|
||||
return self.context.current_state
|
||||
|
||||
def is_sleeping(self) -> bool:
|
||||
"""判断当前是否处于正在睡觉的状态。"""
|
||||
return self._current_state == SleepState.SLEEPING
|
||||
return self.context.current_state == SleepState.SLEEPING
|
||||
|
||||
def is_woken_up(self) -> bool:
|
||||
"""判断当前是否处于被吵醒的状态。"""
|
||||
return self.context.current_state == SleepState.WOKEN_UP
|
||||
|
||||
async def update_sleep_state(self, wakeup_manager: Optional["WakeUpManager"] = None):
|
||||
"""
|
||||
@@ -58,41 +53,42 @@ class SleepManager:
|
||||
"""
|
||||
# 如果全局禁用了睡眠系统,则强制设置为清醒状态并返回
|
||||
if not global_config.sleep_system.enable:
|
||||
if self._current_state != SleepState.AWAKE:
|
||||
if self.context.current_state != SleepState.AWAKE:
|
||||
logger.debug("睡眠系统禁用,强制设为 AWAKE")
|
||||
self._current_state = SleepState.AWAKE
|
||||
self.context.current_state = SleepState.AWAKE
|
||||
return
|
||||
|
||||
now = datetime.now()
|
||||
today = now.date()
|
||||
|
||||
# 跨天处理:如果日期变化,重置每日相关的睡眠状态
|
||||
if self._last_sleep_check_date != today:
|
||||
if self.context.last_sleep_check_date != today:
|
||||
logger.info(f"新的一天 ({today}),重置睡眠状态。")
|
||||
self._total_delayed_minutes_today = 0
|
||||
self._current_state = SleepState.AWAKE
|
||||
self._sleep_buffer_end_time = None
|
||||
self._last_sleep_check_date = today
|
||||
self._save_sleep_state()
|
||||
self.context.total_delayed_minutes_today = 0
|
||||
self.context.current_state = SleepState.AWAKE
|
||||
self.context.sleep_buffer_end_time = None
|
||||
self.context.last_sleep_check_date = today
|
||||
self.context.save()
|
||||
|
||||
# 检查当前是否处于理论上的睡眠时间段
|
||||
is_in_theoretical_sleep, activity = self.time_checker.is_in_theoretical_sleep_time(now.time())
|
||||
|
||||
# --- 状态机核心处理逻辑 ---
|
||||
if self._current_state == SleepState.AWAKE:
|
||||
current_state = self.context.current_state
|
||||
if current_state == SleepState.AWAKE:
|
||||
if is_in_theoretical_sleep:
|
||||
self._handle_awake_to_sleep(now, activity, wakeup_manager)
|
||||
|
||||
elif self._current_state == SleepState.PREPARING_SLEEP:
|
||||
elif current_state == SleepState.PREPARING_SLEEP:
|
||||
self._handle_preparing_sleep(now, is_in_theoretical_sleep, wakeup_manager)
|
||||
|
||||
elif self._current_state == SleepState.SLEEPING:
|
||||
elif current_state == SleepState.SLEEPING:
|
||||
self._handle_sleeping(now, is_in_theoretical_sleep, activity, wakeup_manager)
|
||||
|
||||
elif self._current_state == SleepState.INSOMNIA:
|
||||
elif current_state == SleepState.INSOMNIA:
|
||||
self._handle_insomnia(now, is_in_theoretical_sleep)
|
||||
|
||||
elif self._current_state == SleepState.WOKEN_UP:
|
||||
elif current_state == SleepState.WOKEN_UP:
|
||||
self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager)
|
||||
|
||||
def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
|
||||
@@ -118,13 +114,13 @@ class SleepManager:
|
||||
delay_minutes = int(pressure_diff * max_delay_minutes)
|
||||
|
||||
# 确保总延迟不超过当日最大值
|
||||
remaining_delay = max_delay_minutes - self._total_delayed_minutes_today
|
||||
remaining_delay = max_delay_minutes - self.context.total_delayed_minutes_today
|
||||
delay_minutes = min(delay_minutes, remaining_delay)
|
||||
|
||||
if delay_minutes > 0:
|
||||
# 增加一些随机性
|
||||
buffer_seconds = random.randint(int(delay_minutes * 0.8 * 60), int(delay_minutes * 1.2 * 60))
|
||||
self._total_delayed_minutes_today += buffer_seconds / 60.0
|
||||
self.context.total_delayed_minutes_today += buffer_seconds / 60.0
|
||||
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 较低,延迟 {buffer_seconds / 60:.1f} 分钟入睡。")
|
||||
else:
|
||||
# 延迟额度已用完,设置一个较短的准备时间
|
||||
@@ -139,22 +135,22 @@ class SleepManager:
|
||||
if global_config.sleep_system.enable_pre_sleep_notification:
|
||||
asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
|
||||
|
||||
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||
self._current_state = SleepState.PREPARING_SLEEP
|
||||
self.context.sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||
self.context.current_state = SleepState.PREPARING_SLEEP
|
||||
logger.info(f"进入准备入睡状态,将在 {buffer_seconds / 60:.1f} 分钟内入睡。")
|
||||
self._save_sleep_state()
|
||||
self.context.save()
|
||||
else:
|
||||
# 无法获取 wakeup_manager,退回旧逻辑
|
||||
buffer_seconds = random.randint(1 * 60, 3 * 60)
|
||||
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||
self._current_state = SleepState.PREPARING_SLEEP
|
||||
self.context.sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||
self.context.current_state = SleepState.PREPARING_SLEEP
|
||||
logger.warning("无法获取 WakeUpManager,弹性睡眠采用默认1-3分钟延迟。")
|
||||
self._save_sleep_state()
|
||||
self.context.save()
|
||||
else:
|
||||
# 非弹性睡眠模式
|
||||
if wakeup_manager and global_config.sleep_system.enable_pre_sleep_notification:
|
||||
asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
|
||||
self._current_state = SleepState.SLEEPING
|
||||
self.context.current_state = SleepState.SLEEPING
|
||||
|
||||
|
||||
def _handle_preparing_sleep(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
|
||||
@@ -162,32 +158,32 @@ class SleepManager:
|
||||
# 如果在准备期间离开了理论睡眠时间,则取消入睡
|
||||
if not is_in_theoretical_sleep:
|
||||
logger.info("准备入睡期间离开理论休眠时间,取消入睡,恢复清醒。")
|
||||
self._current_state = SleepState.AWAKE
|
||||
self._sleep_buffer_end_time = None
|
||||
self._save_sleep_state()
|
||||
self.context.current_state = SleepState.AWAKE
|
||||
self.context.sleep_buffer_end_time = None
|
||||
self.context.save()
|
||||
# 如果缓冲时间结束,则正式进入睡眠状态
|
||||
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
|
||||
elif self.context.sleep_buffer_end_time and now >= self.context.sleep_buffer_end_time:
|
||||
logger.info("睡眠缓冲期结束,正式进入休眠状态。")
|
||||
self._current_state = SleepState.SLEEPING
|
||||
self.context.current_state = SleepState.SLEEPING
|
||||
self._last_fully_slept_log_time = now.timestamp()
|
||||
|
||||
# 设置一个随机的延迟,用于触发“睡后失眠”检查
|
||||
delay_minutes_range = global_config.sleep_system.insomnia_trigger_delay_minutes
|
||||
delay_minutes = random.randint(delay_minutes_range[0], delay_minutes_range[1])
|
||||
self._sleep_buffer_end_time = now + timedelta(minutes=delay_minutes)
|
||||
self.context.sleep_buffer_end_time = now + timedelta(minutes=delay_minutes)
|
||||
logger.info(f"已设置睡后失眠检查,将在 {delay_minutes} 分钟后触发。")
|
||||
|
||||
self._save_sleep_state()
|
||||
self.context.save()
|
||||
|
||||
def _handle_sleeping(self, now: datetime, is_in_theoretical_sleep: bool, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
|
||||
"""处理“正在睡觉”状态下的逻辑。"""
|
||||
# 如果理论睡眠时间结束,则自然醒来
|
||||
if not is_in_theoretical_sleep:
|
||||
logger.info("理论休眠时间结束,自然醒来。")
|
||||
self._current_state = SleepState.AWAKE
|
||||
self._save_sleep_state()
|
||||
self.context.current_state = SleepState.AWAKE
|
||||
self.context.save()
|
||||
# 检查是否到了触发“睡后失眠”的时间点
|
||||
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
|
||||
elif self.context.sleep_buffer_end_time and now >= self.context.sleep_buffer_end_time:
|
||||
if wakeup_manager:
|
||||
sleep_pressure = wakeup_manager.context.sleep_pressure
|
||||
pressure_threshold = global_config.sleep_system.flexible_sleep_pressure_threshold
|
||||
@@ -201,12 +197,12 @@ class SleepManager:
|
||||
logger.info("随机触发失眠。")
|
||||
|
||||
if insomnia_reason:
|
||||
self._current_state = SleepState.INSOMNIA
|
||||
self.context.current_state = SleepState.INSOMNIA
|
||||
|
||||
# 设置失眠的持续时间
|
||||
duration_minutes_range = global_config.sleep_system.insomnia_duration_minutes
|
||||
duration_minutes = random.randint(*duration_minutes_range)
|
||||
self._sleep_buffer_end_time = now + timedelta(minutes=duration_minutes)
|
||||
self.context.sleep_buffer_end_time = now + timedelta(minutes=duration_minutes)
|
||||
|
||||
# 发送失眠通知
|
||||
asyncio.create_task(NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason))
|
||||
@@ -214,8 +210,8 @@ class SleepManager:
|
||||
else:
|
||||
# 睡眠压力正常,不触发失眠,清除检查时间点
|
||||
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 正常,未触发睡后失眠。")
|
||||
self._sleep_buffer_end_time = None
|
||||
self._save_sleep_state()
|
||||
self.context.sleep_buffer_end_time = None
|
||||
self.context.save()
|
||||
else:
|
||||
# 定期记录睡眠日志
|
||||
current_timestamp = now.timestamp()
|
||||
@@ -228,26 +224,26 @@ class SleepManager:
|
||||
# 如果离开理论睡眠时间,则失眠结束
|
||||
if not is_in_theoretical_sleep:
|
||||
logger.info("已离开理论休眠时间,失眠结束,恢复清醒。")
|
||||
self._current_state = SleepState.AWAKE
|
||||
self._sleep_buffer_end_time = None
|
||||
self._save_sleep_state()
|
||||
self.context.current_state = SleepState.AWAKE
|
||||
self.context.sleep_buffer_end_time = None
|
||||
self.context.save()
|
||||
# 如果失眠持续时间已过,则恢复睡眠
|
||||
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
|
||||
elif self.context.sleep_buffer_end_time and now >= self.context.sleep_buffer_end_time:
|
||||
logger.info("失眠状态持续时间已过,恢复睡眠。")
|
||||
self._current_state = SleepState.SLEEPING
|
||||
self._sleep_buffer_end_time = None
|
||||
self._save_sleep_state()
|
||||
self.context.current_state = SleepState.SLEEPING
|
||||
self.context.sleep_buffer_end_time = None
|
||||
self.context.save()
|
||||
|
||||
def _handle_woken_up(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
|
||||
"""处理“被吵醒”状态下的逻辑。"""
|
||||
# 如果理论睡眠时间结束,则状态自动结束
|
||||
if not is_in_theoretical_sleep:
|
||||
logger.info("理论休眠时间结束,被吵醒的状态自动结束。")
|
||||
self._current_state = SleepState.AWAKE
|
||||
self._re_sleep_attempt_time = None
|
||||
self._save_sleep_state()
|
||||
self.context.current_state = SleepState.AWAKE
|
||||
self.context.re_sleep_attempt_time = None
|
||||
self.context.save()
|
||||
# 到了尝试重新入睡的时间点
|
||||
elif self._re_sleep_attempt_time and now >= self._re_sleep_attempt_time:
|
||||
elif self.context.re_sleep_attempt_time and now >= self.context.re_sleep_attempt_time:
|
||||
logger.info("被吵醒后经过一段时间,尝试重新入睡...")
|
||||
if wakeup_manager:
|
||||
sleep_pressure = wakeup_manager.context.sleep_pressure
|
||||
@@ -257,48 +253,28 @@ class SleepManager:
|
||||
if sleep_pressure >= pressure_threshold:
|
||||
logger.info("睡眠压力足够,从被吵醒状态转换到准备入睡。")
|
||||
buffer_seconds = random.randint(3 * 60, 8 * 60)
|
||||
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||
self._current_state = SleepState.PREPARING_SLEEP
|
||||
self._re_sleep_attempt_time = None
|
||||
self.context.sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||
self.context.current_state = SleepState.PREPARING_SLEEP
|
||||
self.context.re_sleep_attempt_time = None
|
||||
else:
|
||||
# 睡眠压力不足,延迟一段时间后再次尝试
|
||||
delay_minutes = 15
|
||||
self._re_sleep_attempt_time = now + timedelta(minutes=delay_minutes)
|
||||
self.context.re_sleep_attempt_time = now + timedelta(minutes=delay_minutes)
|
||||
logger.info(
|
||||
f"睡眠压力({sleep_pressure:.1f})仍然较低,暂时保持清醒,在 {delay_minutes} 分钟后再次尝试。"
|
||||
)
|
||||
self._save_sleep_state()
|
||||
self.context.save()
|
||||
|
||||
def reset_sleep_state_after_wakeup(self):
|
||||
"""
|
||||
当角色被用户消息等外部因素唤醒时调用此方法。
|
||||
将状态强制转换为 WOKEN_UP,并设置一个延迟,之后会尝试重新入睡。
|
||||
"""
|
||||
if self._current_state in [SleepState.PREPARING_SLEEP, SleepState.SLEEPING, SleepState.INSOMNIA]:
|
||||
if self.context.current_state in [SleepState.PREPARING_SLEEP, SleepState.SLEEPING, SleepState.INSOMNIA]:
|
||||
logger.info("被唤醒,进入 WOKEN_UP 状态!")
|
||||
self._current_state = SleepState.WOKEN_UP
|
||||
self._sleep_buffer_end_time = None
|
||||
self.context.current_state = SleepState.WOKEN_UP
|
||||
self.context.sleep_buffer_end_time = None
|
||||
re_sleep_delay_minutes = getattr(global_config.sleep_system, "re_sleep_delay_minutes", 10)
|
||||
self._re_sleep_attempt_time = datetime.now() + timedelta(minutes=re_sleep_delay_minutes)
|
||||
self.context.re_sleep_attempt_time = datetime.now() + timedelta(minutes=re_sleep_delay_minutes)
|
||||
logger.info(f"将在 {re_sleep_delay_minutes} 分钟后尝试重新入睡。")
|
||||
self._save_sleep_state()
|
||||
|
||||
def _save_sleep_state(self):
|
||||
"""将当前所有睡眠相关的状态打包并保存到本地存储。"""
|
||||
state_data = {
|
||||
"_current_state": self._current_state,
|
||||
"_sleep_buffer_end_time": self._sleep_buffer_end_time,
|
||||
"_total_delayed_minutes_today": self._total_delayed_minutes_today,
|
||||
"_last_sleep_check_date": self._last_sleep_check_date,
|
||||
"_re_sleep_attempt_time": self._re_sleep_attempt_time,
|
||||
}
|
||||
SleepStateSerializer.save(state_data)
|
||||
|
||||
def _load_sleep_state(self):
|
||||
"""从本地存储加载并恢复所有睡眠相关的状态。"""
|
||||
state_data = SleepStateSerializer.load()
|
||||
self._current_state = state_data["_current_state"]
|
||||
self._sleep_buffer_end_time = state_data["_sleep_buffer_end_time"]
|
||||
self._total_delayed_minutes_today = state_data["_total_delayed_minutes_today"]
|
||||
self._last_sleep_check_date = state_data["_last_sleep_check_date"]
|
||||
self._re_sleep_attempt_time = state_data["_re_sleep_attempt_time"]
|
||||
self.context.save()
|
||||
86
src/chat/message_manager/sleep_manager/sleep_state.py
Normal file
86
src/chat/message_manager/sleep_manager/sleep_state.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from enum import Enum, auto
|
||||
from datetime import datetime, date
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("sleep_state")
|
||||
|
||||
|
||||
class SleepState(Enum):
|
||||
"""
|
||||
定义了角色可能处于的几种睡眠状态。
|
||||
这是一个状态机,用于管理角色的睡眠周期。
|
||||
"""
|
||||
|
||||
AWAKE = auto() # 清醒状态
|
||||
INSOMNIA = auto() # 失眠状态
|
||||
PREPARING_SLEEP = auto() # 准备入睡状态,一个短暂的过渡期
|
||||
SLEEPING = auto() # 正在睡觉状态
|
||||
WOKEN_UP = auto() # 被吵醒状态
|
||||
|
||||
|
||||
class SleepContext:
|
||||
"""
|
||||
睡眠上下文,负责封装和管理所有与睡眠相关的状态,并处理其持久化。
|
||||
"""
|
||||
def __init__(self):
|
||||
"""初始化睡眠上下文,并从本地存储加载初始状态。"""
|
||||
self.current_state: SleepState = SleepState.AWAKE
|
||||
self.sleep_buffer_end_time: Optional[datetime] = None
|
||||
self.total_delayed_minutes_today: float = 0.0
|
||||
self.last_sleep_check_date: Optional[date] = None
|
||||
self.re_sleep_attempt_time: Optional[datetime] = None
|
||||
self.load()
|
||||
|
||||
def save(self):
|
||||
"""将当前的睡眠状态数据保存到本地存储。"""
|
||||
try:
|
||||
state = {
|
||||
"current_state": self.current_state.name,
|
||||
"sleep_buffer_end_time_ts": self.sleep_buffer_end_time.timestamp()
|
||||
if self.sleep_buffer_end_time
|
||||
else None,
|
||||
"total_delayed_minutes_today": self.total_delayed_minutes_today,
|
||||
"last_sleep_check_date_str": self.last_sleep_check_date.isoformat()
|
||||
if self.last_sleep_check_date
|
||||
else None,
|
||||
"re_sleep_attempt_time_ts": self.re_sleep_attempt_time.timestamp()
|
||||
if self.re_sleep_attempt_time
|
||||
else None,
|
||||
}
|
||||
local_storage["schedule_sleep_state"] = state
|
||||
logger.debug(f"已保存睡眠上下文: {state}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存睡眠上下文失败: {e}")
|
||||
|
||||
def load(self):
|
||||
"""从本地存储加载并解析睡眠状态。"""
|
||||
try:
|
||||
state = local_storage["schedule_sleep_state"]
|
||||
if not (state and isinstance(state, dict)):
|
||||
logger.info("未找到本地睡眠上下文,使用默认值。")
|
||||
return
|
||||
|
||||
state_name = state.get("current_state")
|
||||
if state_name and hasattr(SleepState, state_name):
|
||||
self.current_state = SleepState[state_name]
|
||||
|
||||
end_time_ts = state.get("sleep_buffer_end_time_ts")
|
||||
if end_time_ts:
|
||||
self.sleep_buffer_end_time = datetime.fromtimestamp(end_time_ts)
|
||||
|
||||
re_sleep_ts = state.get("re_sleep_attempt_time_ts")
|
||||
if re_sleep_ts:
|
||||
self.re_sleep_attempt_time = datetime.fromtimestamp(re_sleep_ts)
|
||||
|
||||
self.total_delayed_minutes_today = state.get("total_delayed_minutes_today", 0.0)
|
||||
|
||||
date_str = state.get("last_sleep_check_date_str")
|
||||
if date_str:
|
||||
self.last_sleep_check_date = datetime.fromisoformat(date_str).date()
|
||||
|
||||
logger.info(f"成功从本地存储加载睡眠上下文: {state}")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载睡眠上下文失败,将使用默认值: {e}")
|
||||
45
src/chat/message_manager/sleep_manager/wakeup_context.py
Normal file
45
src/chat/message_manager/sleep_manager/wakeup_context.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("wakeup_context")
|
||||
|
||||
|
||||
class WakeUpContext:
|
||||
"""
|
||||
唤醒上下文,负责封装和管理所有与唤醒相关的状态,并处理其持久化。
|
||||
"""
|
||||
def __init__(self):
|
||||
"""初始化唤醒上下文,并从本地存储加载初始状态。"""
|
||||
self.wakeup_value: float = 0.0
|
||||
self.is_angry: bool = False
|
||||
self.angry_start_time: float = 0.0
|
||||
self.sleep_pressure: float = 100.0 # 新增:睡眠压力
|
||||
self.load()
|
||||
|
||||
def _get_storage_key(self) -> str:
|
||||
"""获取本地存储键"""
|
||||
return "global_wakeup_manager_state"
|
||||
|
||||
def load(self):
|
||||
"""从本地存储加载状态"""
|
||||
state = local_storage[self._get_storage_key()]
|
||||
if state and isinstance(state, dict):
|
||||
self.wakeup_value = state.get("wakeup_value", 0.0)
|
||||
self.is_angry = state.get("is_angry", False)
|
||||
self.angry_start_time = state.get("angry_start_time", 0.0)
|
||||
self.sleep_pressure = state.get("sleep_pressure", 100.0)
|
||||
logger.info(f"成功从本地存储加载唤醒上下文: {state}")
|
||||
else:
|
||||
logger.info("未找到本地唤醒上下文,将使用默认值初始化。")
|
||||
|
||||
def save(self):
|
||||
"""将当前状态保存到本地存储"""
|
||||
state = {
|
||||
"wakeup_value": self.wakeup_value,
|
||||
"is_angry": self.is_angry,
|
||||
"angry_start_time": self.angry_start_time,
|
||||
"sleep_pressure": self.sleep_pressure,
|
||||
}
|
||||
local_storage[self._get_storage_key()] = state
|
||||
logger.debug(f"已将唤醒上下文保存到本地存储: {state}")
|
||||
215
src/chat/message_manager/sleep_manager/wakeup_manager.py
Normal file
215
src/chat/message_manager/sleep_manager/wakeup_manager.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .sleep_manager import SleepManager
|
||||
|
||||
|
||||
logger = get_logger("wakeup")
|
||||
|
||||
|
||||
class WakeUpManager:
|
||||
def __init__(self, sleep_manager: "SleepManager"):
|
||||
"""
|
||||
初始化唤醒度管理器
|
||||
|
||||
Args:
|
||||
sleep_manager: 睡眠管理器实例
|
||||
|
||||
功能说明:
|
||||
- 管理休眠状态下的唤醒度累积
|
||||
- 处理唤醒度的自然衰减
|
||||
- 控制愤怒状态的持续时间
|
||||
"""
|
||||
self.sleep_manager = sleep_manager
|
||||
self.context = WakeUpContext() # 使用新的上下文管理器
|
||||
self.angry_chat_id: Optional[str] = None
|
||||
self.last_decay_time = time.time()
|
||||
self._decay_task: Optional[asyncio.Task] = None
|
||||
self.is_running = False
|
||||
self.last_log_time = 0
|
||||
self.log_interval = 30
|
||||
|
||||
# 从配置文件获取参数
|
||||
sleep_config = global_config.sleep_system
|
||||
self.wakeup_threshold = sleep_config.wakeup_threshold
|
||||
self.private_message_increment = sleep_config.private_message_increment
|
||||
self.group_mention_increment = sleep_config.group_mention_increment
|
||||
self.decay_rate = sleep_config.decay_rate
|
||||
self.decay_interval = sleep_config.decay_interval
|
||||
self.angry_duration = sleep_config.angry_duration
|
||||
self.enabled = sleep_config.enable
|
||||
self.angry_prompt = sleep_config.angry_prompt
|
||||
|
||||
async def start(self):
|
||||
"""启动唤醒度管理器"""
|
||||
if not self.enabled:
|
||||
logger.info("唤醒度系统已禁用,跳过启动")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
if not self._decay_task or self._decay_task.done():
|
||||
self._decay_task = asyncio.create_task(self._decay_loop())
|
||||
self._decay_task.add_done_callback(self._handle_decay_completion)
|
||||
logger.info("唤醒度管理器已启动")
|
||||
|
||||
async def stop(self):
|
||||
"""停止唤醒度管理器"""
|
||||
self.is_running = False
|
||||
if self._decay_task and not self._decay_task.done():
|
||||
self._decay_task.cancel()
|
||||
await asyncio.sleep(0)
|
||||
logger.info("唤醒度管理器已停止")
|
||||
|
||||
def _handle_decay_completion(self, task: asyncio.Task):
|
||||
"""处理衰减任务完成"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"唤醒度衰减任务异常: {exception}")
|
||||
else:
|
||||
logger.info("唤醒度衰减任务正常结束")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("唤醒度衰减任务被取消")
|
||||
|
||||
async def _decay_loop(self):
|
||||
"""唤醒度衰减循环"""
|
||||
while self.is_running:
|
||||
await asyncio.sleep(self.decay_interval)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 检查愤怒状态是否过期
|
||||
if self.context.is_angry and current_time - self.context.angry_start_time >= self.angry_duration:
|
||||
self.context.is_angry = False
|
||||
# 通知情绪管理系统清除愤怒状态
|
||||
from src.mood.mood_manager import mood_manager
|
||||
if self.angry_chat_id:
|
||||
mood_manager.clear_angry_from_wakeup(self.angry_chat_id)
|
||||
self.angry_chat_id = None
|
||||
else:
|
||||
logger.warning("Angry state ended but no angry_chat_id was set.")
|
||||
logger.info("愤怒状态结束,恢复正常")
|
||||
self.context.save()
|
||||
|
||||
# 唤醒度自然衰减
|
||||
if self.context.wakeup_value > 0:
|
||||
old_value = self.context.wakeup_value
|
||||
self.context.wakeup_value = max(0, self.context.wakeup_value - self.decay_rate)
|
||||
if old_value != self.context.wakeup_value:
|
||||
logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}")
|
||||
self.context.save()
|
||||
|
||||
def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
增加唤醒度值
|
||||
|
||||
Args:
|
||||
is_private_chat: 是否为私聊
|
||||
is_mentioned: 是否被艾特(仅群聊有效)
|
||||
|
||||
Returns:
|
||||
bool: 是否达到唤醒阈值
|
||||
"""
|
||||
# 如果系统未启用,直接返回
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
# 只有在休眠且非失眠状态下才累积唤醒度
|
||||
from .sleep_state import SleepState
|
||||
|
||||
current_sleep_state = self.sleep_manager.get_current_sleep_state()
|
||||
if current_sleep_state != SleepState.SLEEPING:
|
||||
return False
|
||||
|
||||
old_value = self.context.wakeup_value
|
||||
|
||||
if is_private_chat:
|
||||
# 私聊每条消息都增加唤醒度
|
||||
self.context.wakeup_value += self.private_message_increment
|
||||
logger.debug(f"私聊消息增加唤醒度: +{self.private_message_increment}")
|
||||
elif is_mentioned:
|
||||
# 群聊只有被艾特才增加唤醒度
|
||||
self.context.wakeup_value += self.group_mention_increment
|
||||
logger.debug(f"群聊艾特增加唤醒度: +{self.group_mention_increment}")
|
||||
else:
|
||||
# 群聊未被艾特,不增加唤醒度
|
||||
return False
|
||||
|
||||
current_time = time.time()
|
||||
if current_time - self.last_log_time > self.log_interval:
|
||||
logger.info(
|
||||
f"唤醒度变化: {old_value:.1f} -> {self.context.wakeup_value:.1f} (阈值: {self.wakeup_threshold})"
|
||||
)
|
||||
self.last_log_time = current_time
|
||||
else:
|
||||
logger.debug(
|
||||
f"唤醒度变化: {old_value:.1f} -> {self.context.wakeup_value:.1f} (阈值: {self.wakeup_threshold})"
|
||||
)
|
||||
|
||||
# 检查是否达到唤醒阈值
|
||||
if self.context.wakeup_value >= self.wakeup_threshold:
|
||||
if not chat_id:
|
||||
logger.error("Wakeup threshold reached, but no chat_id was provided. Cannot trigger wakeup.")
|
||||
return False
|
||||
self._trigger_wakeup(chat_id)
|
||||
return True
|
||||
|
||||
self.context.save()
|
||||
return False
|
||||
|
||||
def _trigger_wakeup(self, chat_id: str):
|
||||
"""触发唤醒,进入愤怒状态"""
|
||||
self.context.is_angry = True
|
||||
self.context.angry_start_time = time.time()
|
||||
self.context.wakeup_value = 0.0 # 重置唤醒度
|
||||
self.angry_chat_id = chat_id
|
||||
|
||||
self.context.save()
|
||||
|
||||
# 通知情绪管理系统进入愤怒状态
|
||||
from src.mood.mood_manager import mood_manager
|
||||
mood_manager.set_angry_from_wakeup(chat_id)
|
||||
|
||||
# 通知SleepManager重置睡眠状态
|
||||
self.sleep_manager.reset_sleep_state_after_wakeup()
|
||||
|
||||
logger.info(f"唤醒度达到阈值({self.wakeup_threshold}),被吵醒进入愤怒状态!")
|
||||
|
||||
def get_angry_prompt_addition(self) -> str:
|
||||
"""获取愤怒状态下的提示词补充"""
|
||||
if self.context.is_angry:
|
||||
return self.angry_prompt
|
||||
return ""
|
||||
|
||||
def is_in_angry_state(self) -> bool:
|
||||
"""检查是否处于愤怒状态"""
|
||||
if self.context.is_angry:
|
||||
current_time = time.time()
|
||||
if current_time - self.context.angry_start_time >= self.angry_duration:
|
||||
self.context.is_angry = False
|
||||
# 通知情绪管理系统清除愤怒状态
|
||||
from src.mood.mood_manager import mood_manager
|
||||
if self.angry_chat_id:
|
||||
mood_manager.clear_angry_from_wakeup(self.angry_chat_id)
|
||||
self.angry_chat_id = None
|
||||
else:
|
||||
logger.warning("Angry state expired in check, but no angry_chat_id was set.")
|
||||
logger.info("愤怒状态自动过期")
|
||||
return False
|
||||
return self.context.is_angry
|
||||
|
||||
def get_status_info(self) -> dict:
|
||||
"""获取当前状态信息"""
|
||||
return {
|
||||
"wakeup_value": self.context.wakeup_value,
|
||||
"wakeup_threshold": self.wakeup_threshold,
|
||||
"is_angry": self.context.is_angry,
|
||||
"angry_remaining_time": max(0, self.angry_duration - (time.time() - self.context.angry_start_time))
|
||||
if self.context.is_angry
|
||||
else 0,
|
||||
}
|
||||
@@ -11,11 +11,12 @@ from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
|
||||
# 导入反注入系统
|
||||
from src.chat.antipromptinjector import initialize_anti_injector
|
||||
@@ -73,15 +74,17 @@ class ChatBot:
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
||||
# 亲和力流消息处理器 - 直接使用全局afc_manager
|
||||
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
# 初始化反注入系统
|
||||
self._initialize_anti_injector()
|
||||
|
||||
@staticmethod
|
||||
def _initialize_anti_injector():
|
||||
# 启动消息管理器
|
||||
self._message_manager_started = False
|
||||
|
||||
def _initialize_anti_injector(self):
|
||||
"""初始化反注入系统"""
|
||||
try:
|
||||
initialize_anti_injector()
|
||||
@@ -99,10 +102,15 @@ class ChatBot:
|
||||
if not self._started:
|
||||
logger.debug("确保ChatBot所有任务已启动")
|
||||
|
||||
# 启动消息管理器
|
||||
if not self._message_manager_started:
|
||||
await message_manager.start()
|
||||
self._message_manager_started = True
|
||||
logger.info("消息管理器已启动")
|
||||
|
||||
self._started = True
|
||||
|
||||
@staticmethod
|
||||
async def _process_plus_commands(message: MessageRecv):
|
||||
async def _process_plus_commands(self, message: MessageRecv):
|
||||
"""独立处理PlusCommand系统"""
|
||||
try:
|
||||
text = message.processed_plain_text
|
||||
@@ -182,7 +190,7 @@ class ChatBot:
|
||||
try:
|
||||
# 检查聊天类型限制
|
||||
if not plus_command_instance.is_chat_type_allowed():
|
||||
is_group = hasattr(message, "is_group_message") and message.is_group_message
|
||||
is_group = message.message_info.group_info
|
||||
logger.info(
|
||||
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||
)
|
||||
@@ -222,8 +230,7 @@ class ChatBot:
|
||||
logger.error(f"处理PlusCommand时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
@staticmethod
|
||||
async def _process_commands_with_new_system(message: MessageRecv):
|
||||
async def _process_commands_with_new_system(self, message: MessageRecv):
|
||||
# sourcery skip: use-named-expression
|
||||
"""使用新插件系统处理命令"""
|
||||
try:
|
||||
@@ -256,7 +263,7 @@ class ChatBot:
|
||||
try:
|
||||
# 检查聊天类型限制
|
||||
if not command_instance.is_chat_type_allowed():
|
||||
is_group = hasattr(message, "is_group_message") and message.is_group_message
|
||||
is_group = message.message_info.group_info
|
||||
logger.info(
|
||||
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||
)
|
||||
@@ -313,8 +320,7 @@ class ChatBot:
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def handle_adapter_response(message: MessageRecv):
|
||||
async def handle_adapter_response(self, message: MessageRecv):
|
||||
"""处理适配器命令响应"""
|
||||
try:
|
||||
from src.plugin_system.apis.send_api import put_adapter_response
|
||||
@@ -354,19 +360,7 @@ class ChatBot:
|
||||
return
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||
heart_flow模式:使用思维流系统进行回复
|
||||
- 包含思维流状态管理
|
||||
- 在回复前进行观察和状态更新
|
||||
- 回复后更新思维流状态
|
||||
- 消息过滤
|
||||
- 记忆激活
|
||||
- 意愿计算
|
||||
- 消息生成和发送
|
||||
- 表情包处理
|
||||
- 性能计时
|
||||
"""
|
||||
"""处理转化后的统一格式消息"""
|
||||
try:
|
||||
# 首先处理可能的切片消息重组
|
||||
from src.utils.message_chunker import reassembler
|
||||
@@ -403,9 +397,7 @@ class ChatBot:
|
||||
# logger.debug(str(message_data))
|
||||
message = MessageRecv(message_data)
|
||||
|
||||
if await self.handle_notice_message(message):
|
||||
...
|
||||
|
||||
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
if message.message_info.additional_config:
|
||||
@@ -415,6 +407,7 @@ class ChatBot:
|
||||
return
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
@@ -426,11 +419,14 @@ class ChatBot:
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
# 过滤检查 (在消息处理之后进行)
|
||||
if _check_ban_words(
|
||||
message.processed_plain_text, chat, user_info # type: ignore
|
||||
) or _check_ban_regex(
|
||||
message.processed_plain_text, chat, user_info # type: ignore
|
||||
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
|
||||
logger.info(f"\u001b[38;5;118m{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m")
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
message.raw_message, # type: ignore
|
||||
chat,
|
||||
user_info, # type: ignore
|
||||
):
|
||||
return
|
||||
|
||||
@@ -456,7 +452,8 @@ class ChatBot:
|
||||
result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message)
|
||||
if not result.all_continue_process():
|
||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
|
||||
|
||||
|
||||
# TODO:暂不可用
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
||||
@@ -470,7 +467,55 @@ class ChatBot:
|
||||
template_group_name = None
|
||||
|
||||
async def preprocess():
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
# 存储消息到数据库
|
||||
from .storage import MessageStorage
|
||||
|
||||
try:
|
||||
await MessageStorage.store_message(message, message.chat_stream)
|
||||
logger.debug(f"消息已存储到数据库: {message.message_info.message_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"存储消息到数据库失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 使用消息管理器处理消息(保持原有功能)
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 创建数据库消息对象
|
||||
db_message = DatabaseMessages(
|
||||
message_id=message.message_info.message_id,
|
||||
time=message.message_info.time,
|
||||
chat_id=message.chat_stream.stream_id,
|
||||
processed_plain_text=message.processed_plain_text,
|
||||
display_message=message.processed_plain_text,
|
||||
is_mentioned=message.is_mentioned,
|
||||
is_at=message.is_at,
|
||||
is_emoji=message.is_emoji,
|
||||
is_picid=message.is_picid,
|
||||
is_command=message.is_command,
|
||||
is_notify=message.is_notify,
|
||||
user_id=message.message_info.user_info.user_id,
|
||||
user_nickname=message.message_info.user_info.user_nickname,
|
||||
user_cardname=message.message_info.user_info.user_cardname,
|
||||
user_platform=message.message_info.user_info.platform,
|
||||
chat_info_stream_id=message.chat_stream.stream_id,
|
||||
chat_info_platform=message.chat_stream.platform,
|
||||
chat_info_create_time=message.chat_stream.create_time,
|
||||
chat_info_last_active_time=message.chat_stream.last_active_time,
|
||||
chat_info_user_id=message.chat_stream.user_info.user_id,
|
||||
chat_info_user_nickname=message.chat_stream.user_info.user_nickname,
|
||||
chat_info_user_cardname=message.chat_stream.user_info.user_cardname,
|
||||
chat_info_user_platform=message.chat_stream.user_info.platform,
|
||||
)
|
||||
|
||||
# 如果是群聊,添加群组信息
|
||||
if message.chat_stream.group_info:
|
||||
db_message.chat_info_group_id = message.chat_stream.group_info.group_id
|
||||
db_message.chat_info_group_name = message.chat_stream.group_info.group_name
|
||||
db_message.chat_info_group_platform = message.chat_stream.group_info.platform
|
||||
|
||||
# 添加消息到消息管理器
|
||||
message_manager.add_message(message.chat_stream.stream_id, db_message)
|
||||
logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}")
|
||||
|
||||
if template_group_name:
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
|
||||
@@ -25,43 +25,6 @@ install(extra_lines=3)
|
||||
logger = get_logger("chat_stream")
|
||||
|
||||
|
||||
class ChatMessageContext:
|
||||
"""聊天消息上下文,存储消息的上下文信息"""
|
||||
|
||||
def __init__(self, message: "MessageRecv"):
|
||||
self.message = message
|
||||
|
||||
def get_template_name(self) -> Optional[str]:
|
||||
"""获取模板名称"""
|
||||
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
||||
return self.message.message_info.template_info.template_name # type: ignore
|
||||
return None
|
||||
|
||||
def get_last_message(self) -> "MessageRecv":
|
||||
"""获取最后一条消息"""
|
||||
return self.message
|
||||
|
||||
def check_types(self, types: list) -> bool:
|
||||
# sourcery skip: invert-any-all, use-any, use-next
|
||||
"""检查消息类型"""
|
||||
if not self.message.message_info.format_info.accept_format: # type: ignore
|
||||
return False
|
||||
for t in types:
|
||||
if t not in self.message.message_info.format_info.accept_format: # type: ignore
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_priority_mode(self) -> str:
|
||||
"""获取优先级模式"""
|
||||
return self.message.priority_mode
|
||||
|
||||
def get_priority_info(self) -> Optional[dict]:
|
||||
"""获取优先级信息"""
|
||||
if hasattr(self.message, "priority_info") and self.message.priority_info:
|
||||
return self.message.priority_info
|
||||
return None
|
||||
|
||||
|
||||
class ChatStream:
|
||||
"""聊天流对象,存储一个完整的聊天上下文"""
|
||||
|
||||
@@ -79,14 +42,24 @@ class ChatStream:
|
||||
self.group_info = group_info
|
||||
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
||||
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
||||
self.energy_value = data.get("energy_value", 5.0) if data else 5.0
|
||||
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
|
||||
self.saved = False
|
||||
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
|
||||
# 从配置文件中读取focus_value,如果没有则使用默认值1.0
|
||||
self.focus_energy = data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value
|
||||
|
||||
# 使用StreamContext替代ChatMessageContext
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
|
||||
self.stream_context: StreamContext = StreamContext(
|
||||
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
|
||||
)
|
||||
|
||||
# 基础参数
|
||||
self.base_interest_energy = 0.5 # 默认基础兴趣度
|
||||
self._focus_energy = 0.5 # 内部存储的focus_energy值
|
||||
self.no_reply_consecutive = 0
|
||||
self.breaking_accumulated_interest = 0.0
|
||||
|
||||
# 自动加载历史消息
|
||||
self._load_history_messages()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
@@ -97,10 +70,15 @@ class ChatStream:
|
||||
"group_info": self.group_info.to_dict() if self.group_info else None,
|
||||
"create_time": self.create_time,
|
||||
"last_active_time": self.last_active_time,
|
||||
"energy_value": self.energy_value,
|
||||
"sleep_pressure": self.sleep_pressure,
|
||||
"focus_energy": self.focus_energy,
|
||||
"breaking_accumulated_interest": self.breaking_accumulated_interest,
|
||||
# 基础兴趣度
|
||||
"base_interest_energy": self.base_interest_energy,
|
||||
# 新增stream_context信息
|
||||
"stream_context_chat_type": self.stream_context.chat_type.value,
|
||||
"stream_context_chat_mode": self.stream_context.chat_mode.value,
|
||||
# 新增interruption_count信息
|
||||
"interruption_count": self.stream_context.interruption_count,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -109,7 +87,7 @@ class ChatStream:
|
||||
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
||||
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
||||
|
||||
return cls(
|
||||
instance = cls(
|
||||
stream_id=data["stream_id"],
|
||||
platform=data["platform"],
|
||||
user_info=user_info, # type: ignore
|
||||
@@ -117,6 +95,22 @@ class ChatStream:
|
||||
data=data,
|
||||
)
|
||||
|
||||
# 恢复stream_context信息
|
||||
if "stream_context_chat_type" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
|
||||
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
if "stream_context_chat_mode" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
|
||||
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
|
||||
# 恢复interruption_count信息
|
||||
if "interruption_count" in data:
|
||||
instance.stream_context.interruption_count = data["interruption_count"]
|
||||
|
||||
return instance
|
||||
|
||||
def update_active_time(self):
|
||||
"""更新最后活跃时间"""
|
||||
self.last_active_time = time.time()
|
||||
@@ -124,7 +118,312 @@ class ChatStream:
|
||||
|
||||
def set_context(self, message: "MessageRecv"):
|
||||
"""设置聊天消息上下文"""
|
||||
self.context = ChatMessageContext(message)
|
||||
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import json
|
||||
|
||||
# 安全获取message_info中的数据
|
||||
message_info = getattr(message, "message_info", {})
|
||||
user_info = getattr(message_info, "user_info", {})
|
||||
group_info = getattr(message_info, "group_info", {})
|
||||
|
||||
# 提取reply_to信息(从message_segment中查找reply类型的段)
|
||||
reply_to = None
|
||||
if hasattr(message, "message_segment") and message.message_segment:
|
||||
reply_to = self._extract_reply_from_segment(message.message_segment)
|
||||
|
||||
# 完整的数据转移逻辑
|
||||
db_message = DatabaseMessages(
|
||||
# 基础消息信息
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
time=getattr(message, "time", time.time()),
|
||||
chat_id=self._generate_chat_id(message_info),
|
||||
reply_to=reply_to,
|
||||
# 兴趣度相关
|
||||
interest_value=getattr(message, "interest_value", 0.0),
|
||||
# 关键词
|
||||
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
|
||||
if getattr(message, "key_words", None)
|
||||
else None,
|
||||
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
|
||||
if getattr(message, "key_words_lite", None)
|
||||
else None,
|
||||
# 消息状态标记
|
||||
is_mentioned=getattr(message, "is_mentioned", None),
|
||||
is_at=getattr(message, "is_at", False),
|
||||
is_emoji=getattr(message, "is_emoji", False),
|
||||
is_picid=getattr(message, "is_picid", False),
|
||||
is_voice=getattr(message, "is_voice", False),
|
||||
is_video=getattr(message, "is_video", False),
|
||||
is_command=getattr(message, "is_command", False),
|
||||
is_notify=getattr(message, "is_notify", False),
|
||||
# 消息内容
|
||||
processed_plain_text=getattr(message, "processed_plain_text", ""),
|
||||
display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
|
||||
# 优先级信息
|
||||
priority_mode=getattr(message, "priority_mode", None),
|
||||
priority_info=json.dumps(getattr(message, "priority_info", None))
|
||||
if getattr(message, "priority_info", None)
|
||||
else None,
|
||||
# 额外配置
|
||||
additional_config=getattr(message_info, "additional_config", None),
|
||||
# 用户信息
|
||||
user_id=str(getattr(user_info, "user_id", "")),
|
||||
user_nickname=getattr(user_info, "user_nickname", ""),
|
||||
user_cardname=getattr(user_info, "user_cardname", None),
|
||||
user_platform=getattr(user_info, "platform", ""),
|
||||
# 群组信息
|
||||
chat_info_group_id=getattr(group_info, "group_id", None),
|
||||
chat_info_group_name=getattr(group_info, "group_name", None),
|
||||
chat_info_group_platform=getattr(group_info, "platform", None),
|
||||
# 聊天流信息
|
||||
chat_info_user_id=str(getattr(user_info, "user_id", "")),
|
||||
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
|
||||
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
|
||||
chat_info_user_platform=getattr(user_info, "platform", ""),
|
||||
chat_info_stream_id=self.stream_id,
|
||||
chat_info_platform=self.platform,
|
||||
chat_info_create_time=self.create_time,
|
||||
chat_info_last_active_time=self.last_active_time,
|
||||
# 新增兴趣度系统字段 - 添加安全处理
|
||||
actions=self._safe_get_actions(message),
|
||||
should_reply=getattr(message, "should_reply", False),
|
||||
)
|
||||
|
||||
self.stream_context.set_current_message(db_message)
|
||||
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
|
||||
self.stream_context.priority_info = getattr(message, "priority_info", None)
|
||||
|
||||
# 调试日志:记录数据转移情况
|
||||
logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, "
|
||||
f"chat_id: {db_message.chat_id}, "
|
||||
f"is_mentioned: {db_message.is_mentioned}, "
|
||||
f"is_emoji: {db_message.is_emoji}, "
|
||||
f"is_picid: {db_message.is_picid}, "
|
||||
f"interest_value: {db_message.interest_value}")
|
||||
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
|
||||
"""安全获取消息的actions字段"""
|
||||
try:
|
||||
actions = getattr(message, "actions", None)
|
||||
if actions is None:
|
||||
return None
|
||||
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
if isinstance(actions, str):
|
||||
try:
|
||||
import json
|
||||
actions = json.loads(actions)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"无法解析actions JSON字符串: {actions}")
|
||||
return None
|
||||
|
||||
# 确保返回列表类型
|
||||
if isinstance(actions, list):
|
||||
# 过滤掉空值和非字符串元素
|
||||
filtered_actions = [action for action in actions if action is not None and isinstance(action, str)]
|
||||
return filtered_actions if filtered_actions else None
|
||||
else:
|
||||
logger.warning(f"actions字段类型不支持: {type(actions)}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取actions字段失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_reply_from_segment(self, segment) -> Optional[str]:
|
||||
"""从消息段中提取reply_to信息"""
|
||||
try:
|
||||
if hasattr(segment, "type") and segment.type == "seglist":
|
||||
# 递归搜索seglist中的reply段
|
||||
if hasattr(segment, "data") and segment.data:
|
||||
for seg in segment.data:
|
||||
reply_id = self._extract_reply_from_segment(seg)
|
||||
if reply_id:
|
||||
return reply_id
|
||||
elif hasattr(segment, "type") and segment.type == "reply":
|
||||
# 找到reply段,返回message_id
|
||||
return str(segment.data) if segment.data else None
|
||||
except Exception as e:
|
||||
logger.warning(f"提取reply_to信息失败: {e}")
|
||||
return None
|
||||
|
||||
def _generate_chat_id(self, message_info) -> str:
|
||||
"""生成chat_id,基于群组或用户信息"""
|
||||
try:
|
||||
group_info = getattr(message_info, "group_info", None)
|
||||
user_info = getattr(message_info, "user_info", None)
|
||||
|
||||
if group_info and hasattr(group_info, "group_id") and group_info.group_id:
|
||||
# 群聊:使用群组ID
|
||||
return f"{self.platform}_{group_info.group_id}"
|
||||
elif user_info and hasattr(user_info, "user_id") and user_info.user_id:
|
||||
# 私聊:使用用户ID
|
||||
return f"{self.platform}_{user_info.user_id}_private"
|
||||
else:
|
||||
# 默认:使用stream_id
|
||||
return self.stream_id
|
||||
except Exception as e:
|
||||
logger.warning(f"生成chat_id失败: {e}")
|
||||
return self.stream_id
|
||||
|
||||
@property
|
||||
def focus_energy(self) -> float:
|
||||
"""使用重构后的能量管理器计算focus_energy"""
|
||||
try:
|
||||
from src.chat.energy_system import energy_manager
|
||||
|
||||
# 获取所有消息
|
||||
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
|
||||
unread_messages = self.stream_context.get_unread_messages()
|
||||
all_messages = history_messages + unread_messages
|
||||
|
||||
# 获取用户ID
|
||||
user_id = None
|
||||
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||
user_id = str(self.user_info.user_id)
|
||||
|
||||
# 使用能量管理器计算
|
||||
energy = energy_manager.calculate_focus_energy(
|
||||
stream_id=self.stream_id,
|
||||
messages=all_messages,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# 更新内部存储
|
||||
self._focus_energy = energy
|
||||
|
||||
logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}")
|
||||
return energy
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取focus_energy失败: {e}", exc_info=True)
|
||||
# 返回缓存的值或默认值
|
||||
if hasattr(self, '_focus_energy'):
|
||||
return self._focus_energy
|
||||
else:
|
||||
return 0.5
|
||||
|
||||
@focus_energy.setter
|
||||
def focus_energy(self, value: float):
|
||||
"""设置focus_energy值(主要用于初始化或特殊场景)"""
|
||||
self._focus_energy = max(0.0, min(1.0, value))
|
||||
|
||||
def _get_user_relationship_score(self) -> float:
|
||||
"""获取用户关系分"""
|
||||
# 使用插件内部的兴趣度评分系统
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
|
||||
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||
user_id = str(self.user_info.user_id)
|
||||
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||
logger.debug(f"ChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
|
||||
return max(0.0, min(1.0, relationship_score))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"ChatStream {self.stream_id}: 插件内部关系分计算失败: {e}")
|
||||
|
||||
# 默认基础分
|
||||
return 0.3
|
||||
|
||||
def _load_history_messages(self):
|
||||
"""从数据库加载历史消息到StreamContext"""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from sqlalchemy import select, desc
|
||||
import asyncio
|
||||
|
||||
async def _load_messages():
|
||||
def _db_query():
|
||||
with get_db_session() as session:
|
||||
# 查询该stream_id的最近20条消息
|
||||
stmt = (
|
||||
select(Messages)
|
||||
.where(Messages.chat_info_stream_id == self.stream_id)
|
||||
.order_by(desc(Messages.time))
|
||||
.limit(global_config.chat.max_context_size)
|
||||
)
|
||||
results = session.execute(stmt).scalars().all()
|
||||
return results
|
||||
|
||||
# 在线程中执行数据库查询
|
||||
db_messages = await asyncio.to_thread(_db_query)
|
||||
|
||||
# 转换为DatabaseMessages对象并添加到StreamContext
|
||||
for db_msg in db_messages:
|
||||
try:
|
||||
# 从SQLAlchemy模型转换为DatabaseMessages数据模型
|
||||
import orjson
|
||||
|
||||
# 解析actions字段(JSON格式)
|
||||
actions = None
|
||||
if db_msg.actions:
|
||||
try:
|
||||
actions = orjson.loads(db_msg.actions)
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
actions = None
|
||||
|
||||
db_message = DatabaseMessages(
|
||||
message_id=db_msg.message_id,
|
||||
time=db_msg.time,
|
||||
chat_id=db_msg.chat_id,
|
||||
reply_to=db_msg.reply_to,
|
||||
interest_value=db_msg.interest_value,
|
||||
key_words=db_msg.key_words,
|
||||
key_words_lite=db_msg.key_words_lite,
|
||||
is_mentioned=db_msg.is_mentioned,
|
||||
processed_plain_text=db_msg.processed_plain_text,
|
||||
display_message=db_msg.display_message,
|
||||
priority_mode=db_msg.priority_mode,
|
||||
priority_info=db_msg.priority_info,
|
||||
additional_config=db_msg.additional_config,
|
||||
is_emoji=db_msg.is_emoji,
|
||||
is_picid=db_msg.is_picid,
|
||||
is_command=db_msg.is_command,
|
||||
is_notify=db_msg.is_notify,
|
||||
user_id=db_msg.user_id,
|
||||
user_nickname=db_msg.user_nickname,
|
||||
user_cardname=db_msg.user_cardname,
|
||||
user_platform=db_msg.user_platform,
|
||||
chat_info_group_id=db_msg.chat_info_group_id,
|
||||
chat_info_group_name=db_msg.chat_info_group_name,
|
||||
chat_info_group_platform=db_msg.chat_info_group_platform,
|
||||
chat_info_user_id=db_msg.chat_info_user_id,
|
||||
chat_info_user_nickname=db_msg.chat_info_user_nickname,
|
||||
chat_info_user_cardname=db_msg.chat_info_user_cardname,
|
||||
chat_info_user_platform=db_msg.chat_info_user_platform,
|
||||
chat_info_stream_id=db_msg.chat_info_stream_id,
|
||||
chat_info_platform=db_msg.chat_info_platform,
|
||||
chat_info_create_time=db_msg.chat_info_create_time,
|
||||
chat_info_last_active_time=db_msg.chat_info_last_active_time,
|
||||
actions=actions,
|
||||
should_reply=getattr(db_msg, "should_reply", False) or False,
|
||||
)
|
||||
|
||||
# 添加调试日志:检查从数据库加载的interest_value
|
||||
logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}")
|
||||
|
||||
# 标记为已读并添加到历史消息
|
||||
db_message.is_read = True
|
||||
self.stream_context.history_messages.append(db_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"转换消息 {db_msg.message_id} 失败: {e}")
|
||||
continue
|
||||
|
||||
if self.stream_context.history_messages:
|
||||
logger.info(
|
||||
f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}"
|
||||
)
|
||||
|
||||
# 创建任务来加载历史消息
|
||||
asyncio.create_task(_load_messages())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载历史消息失败: {e}")
|
||||
|
||||
|
||||
class ChatManager:
|
||||
@@ -362,7 +661,16 @@ class ChatManager:
|
||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||
"energy_value": s_data_dict.get("energy_value", 5.0),
|
||||
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
|
||||
"focus_energy": s_data_dict.get("focus_energy", global_config.chat.focus_value),
|
||||
"focus_energy": s_data_dict.get("focus_energy", 0.5),
|
||||
# 新增动态兴趣度系统字段
|
||||
"base_interest_energy": s_data_dict.get("base_interest_energy", 0.5),
|
||||
"message_interest_total": s_data_dict.get("message_interest_total", 0.0),
|
||||
"message_count": s_data_dict.get("message_count", 0),
|
||||
"action_count": s_data_dict.get("action_count", 0),
|
||||
"reply_count": s_data_dict.get("reply_count", 0),
|
||||
"last_interaction_time": s_data_dict.get("last_interaction_time", time.time()),
|
||||
"consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0),
|
||||
"interruption_count": s_data_dict.get("interruption_count", 0),
|
||||
}
|
||||
if global_config.database.database_type == "sqlite":
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
@@ -419,7 +727,17 @@ class ChatManager:
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
"energy_value": model_instance.energy_value,
|
||||
"sleep_pressure": model_instance.sleep_pressure,
|
||||
"focus_energy": getattr(model_instance, "focus_energy", global_config.chat.focus_value),
|
||||
"focus_energy": getattr(model_instance, "focus_energy", 0.5),
|
||||
# 新增动态兴趣度系统字段 - 使用getattr提供默认值
|
||||
"base_interest_energy": getattr(model_instance, "base_interest_energy", 0.5),
|
||||
"message_interest_total": getattr(model_instance, "message_interest_total", 0.0),
|
||||
"message_count": getattr(model_instance, "message_count", 0),
|
||||
"action_count": getattr(model_instance, "action_count", 0),
|
||||
"reply_count": getattr(model_instance, "reply_count", 0),
|
||||
"last_interaction_time": getattr(model_instance, "last_interaction_time", time.time()),
|
||||
"relationship_score": getattr(model_instance, "relationship_score", 0.3),
|
||||
"consecutive_no_reply": getattr(model_instance, "consecutive_no_reply", 0),
|
||||
"interruption_count": getattr(model_instance, "interruption_count", 0),
|
||||
}
|
||||
loaded_streams_data.append(data_for_from_dict)
|
||||
await session.commit()
|
||||
|
||||
@@ -123,7 +123,7 @@ class MessageRecv(Message):
|
||||
self.is_video = False
|
||||
self.is_mentioned = None
|
||||
self.is_notify = False
|
||||
|
||||
self.is_at = False
|
||||
self.is_command = False
|
||||
|
||||
self.priority_mode = "interest"
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import re
|
||||
import traceback
|
||||
import orjson
|
||||
from typing import Union
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import select, desc, update
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, Images, get_db_session
|
||||
from src.common.database.sqlalchemy_models import Messages, Images
|
||||
from src.common.logger import get_logger
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageSending, MessageRecv
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from sqlalchemy import select, update, desc
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
@@ -41,7 +41,7 @@ class MessageStorage:
|
||||
processed_plain_text = message.processed_plain_text
|
||||
|
||||
if processed_plain_text:
|
||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_processed_plain_text = ""
|
||||
@@ -51,7 +51,8 @@ class MessageStorage:
|
||||
if display_message:
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
# 如果没有设置display_message,使用processed_plain_text作为显示消息
|
||||
filtered_display_message = re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL) if message.processed_plain_text else ""
|
||||
interest_value = 0
|
||||
is_mentioned = False
|
||||
reply_to = message.reply_to
|
||||
@@ -116,14 +117,21 @@ class MessageStorage:
|
||||
user_nickname=user_info_dict.get("user_nickname"),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
processed_plain_text=filtered_processed_plain_text,
|
||||
display_message=filtered_display_message,
|
||||
memorized_times=message.memorized_times,
|
||||
interest_value=interest_value,
|
||||
priority_mode=priority_mode,
|
||||
priority_info=priority_info_json,
|
||||
is_emoji=is_emoji,
|
||||
is_picid=is_picid,
|
||||
is_notify=is_notify,
|
||||
is_command=is_command,
|
||||
key_words=key_words,
|
||||
key_words_lite=key_words_lite,
|
||||
)
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
session.add(new_message)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
except Exception:
|
||||
logger.exception("存储消息失败")
|
||||
@@ -146,7 +154,8 @@ class MessageStorage:
|
||||
qq_message_id = message.message_segment.data.get("id")
|
||||
elif message.message_segment.type == "reply":
|
||||
qq_message_id = message.message_segment.data.get("id")
|
||||
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
|
||||
if qq_message_id:
|
||||
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
|
||||
elif message.message_segment.type == "adapter_response":
|
||||
logger.debug("适配器响应消息,不需要更新ID")
|
||||
return
|
||||
@@ -162,19 +171,18 @@ class MessageStorage:
|
||||
logger.debug(f"消息段数据: {message.message_segment.data}")
|
||||
return
|
||||
|
||||
async with get_db_session() as session:
|
||||
matched_message = (
|
||||
await session.execute(
|
||||
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
|
||||
)
|
||||
# 使用上下文管理器确保session正确管理
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
matched_message = session.execute(
|
||||
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
|
||||
).scalar()
|
||||
|
||||
if matched_message:
|
||||
await session.execute(
|
||||
session.execute(
|
||||
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
|
||||
)
|
||||
await session.commit()
|
||||
# 会在上下文管理器中自动调用
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
else:
|
||||
logger.warning(f"未找到匹配的消息记录: {mmc_message_id}")
|
||||
@@ -186,36 +194,117 @@ class MessageStorage:
|
||||
f"segment_type={getattr(message.message_segment, 'type', 'N/A')}"
|
||||
)
|
||||
|
||||
async def replace_image_descriptions(text: str) -> str:
|
||||
@staticmethod
|
||||
def replace_image_descriptions(text: str) -> str:
|
||||
"""将[图片:描述]替换为[picid:image_id]"""
|
||||
# 先检查文本中是否有图片标记
|
||||
pattern = r"\[图片:([^\]]+)\]"
|
||||
matches = list(re.finditer(pattern, text))
|
||||
matches = re.findall(pattern, text)
|
||||
|
||||
if not matches:
|
||||
logger.debug("文本中没有图片标记,直接返回原文本")
|
||||
return text
|
||||
|
||||
new_text = ""
|
||||
last_end = 0
|
||||
for match in matches:
|
||||
new_text += text[last_end : match.start()]
|
||||
def replace_match(match):
|
||||
description = match.group(1).strip()
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
async with get_db_session() as session:
|
||||
image_record = (
|
||||
await session.execute(
|
||||
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
|
||||
)
|
||||
with get_db_session() as session:
|
||||
image_record = session.execute(
|
||||
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
|
||||
).scalar()
|
||||
if image_record:
|
||||
new_text += f"[picid:{image_record.image_id}]"
|
||||
else:
|
||||
new_text += match.group(0)
|
||||
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||
except Exception:
|
||||
new_text += match.group(0)
|
||||
last_end = match.end()
|
||||
new_text += text[last_end:]
|
||||
return new_text
|
||||
return match.group(0)
|
||||
|
||||
@staticmethod
|
||||
def update_message_interest_value(message_id: str, interest_value: float) -> None:
|
||||
"""
|
||||
更新数据库中消息的interest_value字段
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
interest_value: 兴趣度值
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 更新消息的interest_value字段
|
||||
stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}")
|
||||
else:
|
||||
logger.warning(f"未找到消息 {message_id},无法更新interest_value")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息 {message_id} 的interest_value失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
|
||||
"""
|
||||
修复指定聊天中interest_value为0或null的历史消息记录
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
since_time: 从指定时间开始修复(时间戳)
|
||||
|
||||
Returns:
|
||||
修复的记录数量
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
from sqlalchemy import select, update
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
|
||||
# 查找需要修复的记录:interest_value为0、null或很小的值
|
||||
query = select(Messages).where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.time >= since_time) &
|
||||
(
|
||||
(Messages.interest_value == 0) |
|
||||
(Messages.interest_value.is_(None)) |
|
||||
(Messages.interest_value < 0.1)
|
||||
)
|
||||
).limit(50) # 限制每次修复的数量,避免性能问题
|
||||
|
||||
messages_to_fix = session.execute(query).scalars().all()
|
||||
fixed_count = 0
|
||||
|
||||
for msg in messages_to_fix:
|
||||
# 为这些消息设置一个合理的默认兴趣度
|
||||
# 可以基于消息长度、内容或其他因素计算
|
||||
default_interest = 0.3 # 默认中等兴趣度
|
||||
|
||||
# 如果消息内容较长,可能是重要消息,兴趣度稍高
|
||||
if hasattr(msg, 'processed_plain_text') and msg.processed_plain_text:
|
||||
text_length = len(msg.processed_plain_text)
|
||||
if text_length > 50: # 长消息
|
||||
default_interest = 0.4
|
||||
elif text_length > 20: # 中等长度消息
|
||||
default_interest = 0.35
|
||||
|
||||
# 如果是被@的消息,兴趣度更高
|
||||
if getattr(msg, 'is_mentioned', False):
|
||||
default_interest = min(default_interest + 0.2, 0.8)
|
||||
|
||||
# 执行更新
|
||||
update_stmt = update(Messages).where(
|
||||
Messages.message_id == msg.message_id
|
||||
).values(interest_value=default_interest)
|
||||
|
||||
result = session.execute(update_stmt)
|
||||
if result.rowcount > 0:
|
||||
fixed_count += 1
|
||||
logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}")
|
||||
|
||||
session.commit()
|
||||
logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值")
|
||||
return fixed_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"修复历史消息interest_value失败: {e}")
|
||||
return 0
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
from typing import Dict, Optional, Type
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, Optional, Type, Any, Tuple
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.apis import generator_api, database_api, send_api, message_api
|
||||
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
|
||||
class ActionManager:
|
||||
class ChatterActionManager:
|
||||
"""
|
||||
动作管理器,用于管理各种类型的动作
|
||||
|
||||
@@ -25,6 +34,8 @@ class ActionManager:
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
|
||||
self.log_prefix: str = "ChatterActionManager"
|
||||
|
||||
# === 执行Action方法 ===
|
||||
|
||||
@staticmethod
|
||||
@@ -124,3 +135,417 @@ class ActionManager:
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
|
||||
async def execute_action(
|
||||
self,
|
||||
action_name: str,
|
||||
chat_id: str,
|
||||
target_message: Optional[dict] = None,
|
||||
reasoning: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
thinking_id: Optional[str] = None,
|
||||
log_prefix: str = "",
|
||||
) -> Any:
|
||||
"""
|
||||
执行单个动作的通用函数
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
chat_id: 聊天id
|
||||
target_message: 目标消息
|
||||
reasoning: 执行理由
|
||||
action_data: 动作数据
|
||||
thinking_id: 思考ID
|
||||
log_prefix: 日志前缀
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
try:
|
||||
logger.debug(f"🎯 [ActionManager] execute_action接收到 target_message: {target_message}")
|
||||
# 通过chat_id获取chat_stream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(chat_id)
|
||||
|
||||
if not chat_stream:
|
||||
logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}")
|
||||
return {
|
||||
"action_type": action_name,
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"error": "chat_stream not found",
|
||||
}
|
||||
|
||||
if action_name == "no_action":
|
||||
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
if action_name == "no_reply":
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = reasoning or "选择不回复"
|
||||
logger.info(f"{log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
# 存储no_reply信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="no_reply",
|
||||
)
|
||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
elif action_name != "reply" and action_name != "no_action":
|
||||
# 执行普通动作
|
||||
success, reply_text, command = await self._handle_action(
|
||||
chat_stream,
|
||||
action_name,
|
||||
reasoning,
|
||||
action_data or {},
|
||||
{}, # cycle_timers
|
||||
thinking_id,
|
||||
target_message,
|
||||
)
|
||||
|
||||
# 记录执行的动作到目标消息
|
||||
if success:
|
||||
await self._record_action_to_message(chat_stream, action_name, target_message, action_data)
|
||||
# 重置打断计数
|
||||
await self._reset_interruption_count_after_action(chat_stream.stream_id)
|
||||
|
||||
return {
|
||||
"action_type": action_name,
|
||||
"success": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command,
|
||||
}
|
||||
else:
|
||||
# 生成回复
|
||||
try:
|
||||
success, response_set, _ = await generator_api.generate_reply(
|
||||
chat_stream=chat_stream,
|
||||
reply_message=target_message,
|
||||
action_data=action_data or {},
|
||||
available_actions=self.get_using_actions(),
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="chat.replyer",
|
||||
from_plugin=False,
|
||||
)
|
||||
if not success or not response_set:
|
||||
logger.info(
|
||||
f"对 {target_message.get('processed_plain_text') if target_message else '未知消息'} 的回复生成失败"
|
||||
)
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
||||
# 发送并存储回复
|
||||
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
||||
chat_stream,
|
||||
response_set,
|
||||
asyncio.get_event_loop().time(),
|
||||
target_message,
|
||||
{}, # cycle_timers
|
||||
thinking_id,
|
||||
[], # actions
|
||||
)
|
||||
|
||||
# 记录回复动作到目标消息
|
||||
await self._record_action_to_message(chat_stream, "reply", target_message, action_data)
|
||||
|
||||
# 回复成功,重置打断计数
|
||||
await self._reset_interruption_count_after_action(chat_stream.stream_id)
|
||||
|
||||
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{log_prefix} 执行动作时出错: {e}")
|
||||
logger.error(f"{log_prefix} 错误信息: {traceback.format_exc()}")
|
||||
return {
|
||||
"action_type": action_name,
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
async def _record_action_to_message(self, chat_stream, action_name, target_message, action_data):
|
||||
"""
|
||||
记录执行的动作到目标消息中
|
||||
|
||||
Args:
|
||||
chat_stream: ChatStream实例
|
||||
action_name: 动作名称
|
||||
target_message: 目标消息
|
||||
action_data: 动作数据
|
||||
"""
|
||||
try:
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
|
||||
# 获取目标消息ID
|
||||
target_message_id = None
|
||||
if target_message and isinstance(target_message, dict):
|
||||
target_message_id = target_message.get("message_id")
|
||||
elif action_data and isinstance(action_data, dict):
|
||||
target_message_id = action_data.get("target_message_id")
|
||||
|
||||
if not target_message_id:
|
||||
logger.debug(f"无法获取目标消息ID,动作: {action_name}")
|
||||
return
|
||||
|
||||
# 通过message_manager更新消息的动作记录并刷新focus_energy
|
||||
if chat_stream.stream_id in message_manager.stream_contexts:
|
||||
message_manager.add_action(
|
||||
stream_id=chat_stream.stream_id,
|
||||
message_id=target_message_id,
|
||||
action=action_name
|
||||
)
|
||||
logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy")
|
||||
else:
|
||||
logger.debug(f"未找到stream_context: {chat_stream.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录动作到消息失败: {e}")
|
||||
# 不抛出异常,避免影响主要功能
|
||||
|
||||
async def _reset_interruption_count_after_action(self, stream_id: str):
|
||||
"""在动作执行成功后重置打断计数"""
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
try:
|
||||
if stream_id in message_manager.stream_contexts:
|
||||
context = message_manager.stream_contexts[stream_id]
|
||||
if context.interruption_count > 0:
|
||||
old_count = context.interruption_count
|
||||
old_afc_adjustment = context.get_afc_threshold_adjustment()
|
||||
context.reset_interruption_count()
|
||||
logger.debug(f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0")
|
||||
except Exception as e:
|
||||
logger.warning(f"重置打断计数时出错: {e}")
|
||||
|
||||
async def _handle_action(
|
||||
self, chat_stream, action, reasoning, action_data, cycle_timers, thinking_id, action_message
|
||||
) -> tuple[bool, str, str]:
|
||||
"""
|
||||
处理具体的动作执行
|
||||
|
||||
Args:
|
||||
chat_stream: ChatStream实例
|
||||
action: 动作名称
|
||||
reasoning: 执行理由
|
||||
action_data: 动作数据
|
||||
cycle_timers: 循环计时器
|
||||
thinking_id: 思考ID
|
||||
action_message: 动作消息
|
||||
|
||||
Returns:
|
||||
tuple: (执行是否成功, 回复文本, 命令文本)
|
||||
|
||||
功能说明:
|
||||
- 创建对应的动作处理器
|
||||
- 执行动作并捕获异常
|
||||
- 返回执行结果供上级方法整合
|
||||
"""
|
||||
if not chat_stream:
|
||||
return False, "", ""
|
||||
try:
|
||||
# 创建动作处理器
|
||||
action_handler = self.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
action_message=action_message,
|
||||
)
|
||||
if not action_handler:
|
||||
# 动作处理器创建失败,尝试回退机制
|
||||
logger.warning(f"{self.log_prefix} 创建动作处理器失败: {action},尝试回退方案")
|
||||
|
||||
# 获取当前可用的动作
|
||||
available_actions = self.get_using_actions()
|
||||
fallback_action = None
|
||||
|
||||
# 回退优先级:reply > 第一个可用动作
|
||||
if "reply" in available_actions:
|
||||
fallback_action = "reply"
|
||||
elif available_actions:
|
||||
fallback_action = list(available_actions.keys())[0]
|
||||
|
||||
if fallback_action and fallback_action != action:
|
||||
logger.info(f"{self.log_prefix} 使用回退动作: {fallback_action}")
|
||||
action_handler = self.create_action(
|
||||
action_name=fallback_action,
|
||||
action_data=action_data,
|
||||
reasoning=f"原动作'{action}'不可用,自动回退。{reasoning}",
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
action_message=action_message,
|
||||
)
|
||||
|
||||
if not action_handler:
|
||||
logger.error(f"{self.log_prefix} 回退方案也失败,无法创建任何动作处理器")
|
||||
return False, "", ""
|
||||
|
||||
# 执行动作
|
||||
success, reply_text = await action_handler.handle_action()
|
||||
return success, reply_text, ""
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
response_set,
|
||||
loop_start_time,
|
||||
action_message,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id,
|
||||
actions,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
"""
|
||||
发送并存储回复信息
|
||||
|
||||
Args:
|
||||
chat_stream: ChatStream实例
|
||||
response_set: 回复内容集合
|
||||
loop_start_time: 循环开始时间
|
||||
action_message: 动作消息
|
||||
cycle_timers: 循环计时器
|
||||
thinking_id: 思考ID
|
||||
actions: 动作列表
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, Any], str, Dict[str, float]]: 循环信息, 回复文本, 循环计时器
|
||||
"""
|
||||
# 发送回复
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self.send_response(chat_stream, response_set, loop_start_time, action_message)
|
||||
|
||||
# 存储reply action信息
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
platform = action_message.get("chat_info_platform")
|
||||
if platform is None:
|
||||
platform = getattr(chat_stream, "platform", "unknown")
|
||||
|
||||
# 获取用户信息并生成回复提示
|
||||
person_id = person_info_manager.get_person_id(
|
||||
platform,
|
||||
action_message.get("user_id", ""),
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
# 存储动作信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=action_prompt_display,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": reply_text},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": reply_text,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def send_response(self, chat_stream, reply_set, thinking_start_time, message_data) -> str:
|
||||
"""
|
||||
发送回复内容的具体实现
|
||||
|
||||
Args:
|
||||
chat_stream: ChatStream实例
|
||||
reply_set: 回复内容集合,包含多个回复段
|
||||
reply_to: 回复目标
|
||||
thinking_start_time: 思考开始时间
|
||||
message_data: 消息数据
|
||||
|
||||
Returns:
|
||||
str: 完整的回复文本
|
||||
|
||||
功能说明:
|
||||
- 检查是否有新消息需要回复
|
||||
- 处理主动思考的"沉默"决定
|
||||
- 根据消息数量决定是否添加回复引用
|
||||
- 逐段发送回复内容,支持打字效果
|
||||
- 正确处理元组格式的回复段
|
||||
"""
|
||||
current_time = time.time()
|
||||
# 计算新消息数量
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
||||
)
|
||||
|
||||
# 根据新消息数量决定是否需要引用回复
|
||||
reply_text = ""
|
||||
is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True
|
||||
|
||||
logger.debug(f"[send_response] message_data: {message_data}")
|
||||
|
||||
first_replied = False
|
||||
for reply_seg in reply_set:
|
||||
# 调试日志:验证reply_seg的格式
|
||||
logger.debug(f"Processing reply_seg type: {type(reply_seg)}, content: {reply_seg}")
|
||||
|
||||
# 修正:正确处理元组格式 (格式为: (type, content))
|
||||
if isinstance(reply_seg, tuple) and len(reply_seg) >= 2:
|
||||
_, data = reply_seg
|
||||
else:
|
||||
# 向下兼容:如果已经是字符串,则直接使用
|
||||
data = str(reply_seg)
|
||||
|
||||
if isinstance(data, list):
|
||||
data = "".join(map(str, data))
|
||||
reply_text += data
|
||||
|
||||
# 如果是主动思考且内容为"沉默",则不发送
|
||||
if is_proactive_thinking and data.strip() == "沉默":
|
||||
logger.info(f"{self.log_prefix} 主动思考决定保持沉默,不发送消息")
|
||||
continue
|
||||
|
||||
# 发送第一段回复
|
||||
if not first_replied:
|
||||
set_reply_flag = bool(message_data)
|
||||
logger.debug(f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}")
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=chat_stream.stream_id,
|
||||
reply_to_message=message_data,
|
||||
set_reply=set_reply_flag,
|
||||
typing=False,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
# 发送后续回复
|
||||
sent_message = await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=chat_stream.stream_id,
|
||||
reply_to_message=None,
|
||||
set_reply=False,
|
||||
typing=True,
|
||||
)
|
||||
|
||||
return reply_text
|
||||
@@ -7,8 +7,9 @@ from typing import List, Any, Dict, TYPE_CHECKING, Tuple
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
@@ -27,7 +28,7 @@ class ActionModifier:
|
||||
支持并行判定和智能缓存优化。
|
||||
"""
|
||||
|
||||
def __init__(self, action_manager: ActionManager, chat_id: str):
|
||||
def __init__(self, action_manager: ChatterActionManager, chat_id: str):
|
||||
"""初始化动作处理器"""
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
|
||||
@@ -124,8 +125,9 @@ class ActionModifier:
|
||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||
|
||||
# === 第二阶段:检查动作的关联类型 ===
|
||||
chat_context = self.chat_stream.context
|
||||
type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context)
|
||||
chat_context = self.chat_stream.stream_context
|
||||
current_actions_s2 = self.action_manager.get_using_actions()
|
||||
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)
|
||||
|
||||
if type_mismatched_actions:
|
||||
removals_s2.extend(type_mismatched_actions)
|
||||
@@ -140,11 +142,12 @@ class ActionModifier:
|
||||
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
# 获取当前使用的动作集(经过第一阶段处理)
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
# 在第三阶段开始前,再次获取最新的动作列表
|
||||
current_actions_s3 = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
removals_s3 = await self._get_deactivated_actions_by_type(
|
||||
current_using_actions,
|
||||
current_actions_s3,
|
||||
chat_content,
|
||||
)
|
||||
|
||||
@@ -164,7 +167,7 @@ class ActionModifier:
|
||||
|
||||
logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: StreamContext):
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
for action_name, action_info in all_actions.items():
|
||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
"""
|
||||
PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
|
||||
"""
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.common.data_models.info_data_model import Plan
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plan_executor")
|
||||
|
||||
|
||||
class PlanExecutor:
|
||||
"""
|
||||
负责接收一个 Plan 对象,并执行其中最终确定的所有动作。
|
||||
|
||||
这个类是规划流程的最后一步,将规划结果转化为实际的动作执行。
|
||||
|
||||
Attributes:
|
||||
action_manager (ActionManager): 用于实际执行各种动作的管理器实例。
|
||||
"""
|
||||
|
||||
def __init__(self, action_manager: ActionManager):
|
||||
"""
|
||||
初始化 PlanExecutor。
|
||||
|
||||
Args:
|
||||
action_manager (ActionManager): 一个 ActionManager 实例,用于执行动作。
|
||||
"""
|
||||
self.action_manager = action_manager
|
||||
|
||||
@staticmethod
|
||||
async def execute(plan: Plan):
|
||||
"""
|
||||
遍历并执行 Plan 对象中 `decided_actions` 列表里的所有动作。
|
||||
|
||||
如果动作类型为 "no_action",则会记录原因并跳过。
|
||||
否则,它将调用 ActionManager 来执行相应的动作。
|
||||
|
||||
Args:
|
||||
plan (Plan): 包含待执行动作列表的 Plan 对象。
|
||||
"""
|
||||
if not plan.decided_actions:
|
||||
logger.info("没有需要执行的动作。")
|
||||
return
|
||||
|
||||
for action_info in plan.decided_actions:
|
||||
if action_info.action_type == "no_action":
|
||||
logger.info(f"规划器决策不执行动作,原因: {action_info.reasoning}")
|
||||
continue
|
||||
|
||||
# TODO: 对接 ActionManager 的执行方法
|
||||
# 这是一个示例调用,需要根据 ActionManager 的最终实现进行调整
|
||||
logger.info(f"执行动作: {action_info.action_type}, 原因: {action_info.reasoning}")
|
||||
# await self.action_manager.execute_action(
|
||||
# action_name=action_info.action_type,
|
||||
# action_data=action_info.action_data,
|
||||
# reasoning=action_info.reasoning,
|
||||
# action_message=action_info.action_message,
|
||||
# )
|
||||
@@ -1,366 +0,0 @@
|
||||
"""
|
||||
PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。
|
||||
"""
|
||||
import orjson
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_actions,
|
||||
build_readable_messages_with_id,
|
||||
get_actions_by_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.prompt import global_prompt_manager
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
logger = get_logger("plan_filter")
|
||||
|
||||
|
||||
class PlanFilter:
|
||||
"""
|
||||
根据 Plan 中的模式和信息,筛选并决定最终的动作。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.planner_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner, request_type="planner"
|
||||
)
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
async def filter(self, plan: Plan) -> Plan:
|
||||
"""
|
||||
执行筛选逻辑,并填充 Plan 对象的 decided_actions 字段。
|
||||
"""
|
||||
logger.debug(f"墨墨在这里加了日志 -> filter 入口 plan: {plan}")
|
||||
try:
|
||||
prompt, used_message_id_list = await self._build_prompt(plan)
|
||||
plan.llm_prompt = prompt
|
||||
logger.info(f"规划器原始提示词: {prompt}")
|
||||
|
||||
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
if llm_content:
|
||||
logger.info(f"规划器原始返回: {llm_content}")
|
||||
parsed_json = orjson.loads(repair_json(llm_content))
|
||||
logger.debug(f"墨墨在这里加了日志 -> 解析后的 JSON: {parsed_json}")
|
||||
|
||||
if isinstance(parsed_json, dict):
|
||||
parsed_json = [parsed_json]
|
||||
|
||||
if isinstance(parsed_json, list):
|
||||
final_actions = []
|
||||
reply_action_added = False
|
||||
# 定义回复类动作的集合,方便扩展
|
||||
reply_action_types = {"reply", "proactive_reply"}
|
||||
|
||||
for item in parsed_json:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 预解析 action_type 来进行判断
|
||||
action_type = item.get("action", "no_action")
|
||||
|
||||
if action_type in reply_action_types:
|
||||
if not reply_action_added:
|
||||
final_actions.extend(
|
||||
await self._parse_single_action(
|
||||
item, used_message_id_list, plan
|
||||
)
|
||||
)
|
||||
reply_action_added = True
|
||||
else:
|
||||
# 非回复类动作直接添加
|
||||
final_actions.extend(
|
||||
await self._parse_single_action(
|
||||
item, used_message_id_list, plan
|
||||
)
|
||||
)
|
||||
|
||||
plan.decided_actions = self._filter_no_actions(final_actions)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"筛选 Plan 时出错: {e}\n{traceback.format_exc()}")
|
||||
plan.decided_actions = [
|
||||
ActionPlannerInfo(action_type="no_action", reasoning=f"筛选时出错: {e}")
|
||||
]
|
||||
|
||||
logger.debug(f"墨墨在这里加了日志 -> filter 出口 decided_actions: {plan.decided_actions}")
|
||||
return plan
|
||||
|
||||
async def _build_prompt(self, plan: Plan) -> tuple[str, list]:
|
||||
"""
|
||||
根据 Plan 对象构建提示词。
|
||||
"""
|
||||
try:
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
bot_core_personality = global_config.personality.personality_core
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||
|
||||
schedule_block = ""
|
||||
if global_config.planning_system.schedule_enable:
|
||||
if current_activity := schedule_manager.get_current_activity():
|
||||
schedule_block = f"你当前正在:{current_activity},但注意它与群聊的聊天无关。"
|
||||
|
||||
mood_block = ""
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
mood_block = f"你现在的心情是:{chat_mood.mood_state}"
|
||||
|
||||
if plan.mode == ChatMode.PROACTIVE:
|
||||
long_term_memory_block = await self._get_long_term_memory_context()
|
||||
|
||||
chat_content_block, message_id_list = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
prompt = prompt_template.format(
|
||||
time_block=time_block,
|
||||
identity_block=identity_block,
|
||||
schedule_block=schedule_block,
|
||||
mood_block=mood_block,
|
||||
long_term_memory_block=long_term_memory_block,
|
||||
chat_content_block=chat_content_block or "最近没有聊天内容。",
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
)
|
||||
return prompt, message_id_list
|
||||
|
||||
chat_content_block, message_id_list = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
mentioned_bonus = ""
|
||||
if global_config.chat.mentioned_bot_inevitable_reply:
|
||||
mentioned_bonus = "\n- 有人提到你"
|
||||
if global_config.chat.at_bot_inevitable_reply:
|
||||
mentioned_bonus = "\n- 有人提到你,或者at你"
|
||||
|
||||
if plan.mode == ChatMode.FOCUS:
|
||||
no_action_block = """
|
||||
动作:no_action
|
||||
动作描述:不选择任何动作
|
||||
{{
|
||||
"action": "no_action",
|
||||
"reason":"不动作的原因"
|
||||
}}
|
||||
|
||||
动作:no_reply
|
||||
动作描述:不进行回复,等待合适的回复时机
|
||||
- 当你刚刚发送了消息,没有人回复时,选择no_reply
|
||||
- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply
|
||||
{{
|
||||
"action": "no_reply",
|
||||
"reason":"不回复的原因"
|
||||
}}
|
||||
"""
|
||||
else: # NORMAL Mode
|
||||
no_action_block = """重要说明:
|
||||
- 'reply' 表示只进行普通聊天回复,不执行任何额外动作
|
||||
- 其他action表示在普通回复的基础上,执行相应的额外动作
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id":"触发action的消息id",
|
||||
"reason":"回复的原因"
|
||||
}}"""
|
||||
|
||||
is_group_chat = plan.target_info.platform == "group" if plan.target_info else True
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
if not is_group_chat and plan.target_info:
|
||||
chat_target_name = plan.target_info.person_name or plan.target_info.user_nickname or "对方"
|
||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||
|
||||
action_options_block = await self._build_action_options(plan.available_actions)
|
||||
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
|
||||
custom_prompt_block = ""
|
||||
if global_config.custom_prompt.planner_custom_prompt_content:
|
||||
custom_prompt_block = global_config.custom_prompt.planner_custom_prompt_content
|
||||
|
||||
users_in_chat_str = "" # TODO: Re-implement user list fetching if needed
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
schedule_block=schedule_block,
|
||||
mood_block=mood_block,
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
mentioned_bonus=mentioned_bonus,
|
||||
no_action_block=no_action_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
identity_block=identity_block,
|
||||
custom_prompt_block=custom_prompt_block,
|
||||
bot_name=bot_name,
|
||||
users_in_chat=users_in_chat_str
|
||||
)
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "构建 Planner Prompt 时出错", []
|
||||
|
||||
async def _parse_single_action(
|
||||
self, action_json: dict, message_id_list: list, plan: Plan
|
||||
) -> List[ActionPlannerInfo]:
|
||||
parsed_actions = []
|
||||
try:
|
||||
action = action_json.get("action", "no_action")
|
||||
reasoning = action_json.get("reason", "未提供原因")
|
||||
action_data = {k: v for k, v in action_json.items() if k not in ["action", "reason"]}
|
||||
|
||||
target_message_obj = None
|
||||
if action not in ["no_action", "no_reply", "do_nothing", "proactive_reply"]:
|
||||
if target_message_id := action_json.get("target_message_id"):
|
||||
target_message_dict = self._find_message_by_id(target_message_id, message_id_list)
|
||||
else:
|
||||
# 如果LLM没有指定target_message_id,我们就默认选择最新的一条消息
|
||||
target_message_dict = self._get_latest_message(message_id_list)
|
||||
|
||||
if target_message_dict:
|
||||
# 直接使用字典作为action_message,避免DatabaseMessages对象创建失败
|
||||
target_message_obj = target_message_dict
|
||||
else:
|
||||
# 如果找不到目标消息,对于reply动作来说这是必需的,应该记录警告
|
||||
if action == "reply":
|
||||
logger.warning(f"reply动作找不到目标消息,target_message_id: {action_json.get('target_message_id')}")
|
||||
# 将reply动作改为no_action,避免后续执行时出错
|
||||
action = "no_action"
|
||||
reasoning = f"找不到目标消息进行回复。原始理由: {reasoning}"
|
||||
|
||||
available_action_names = list(plan.available_actions.keys())
|
||||
if action not in ["no_action", "no_reply", "reply", "do_nothing", "proactive_reply"] and action not in available_action_names:
|
||||
reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}"
|
||||
action = "no_action"
|
||||
|
||||
parsed_actions.append(
|
||||
ActionPlannerInfo(
|
||||
action_type=action,
|
||||
reasoning=reasoning,
|
||||
action_data=action_data,
|
||||
action_message=target_message_obj,
|
||||
available_actions=plan.available_actions,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"解析单个action时出错: {e}")
|
||||
parsed_actions.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="no_action",
|
||||
reasoning=f"解析action时出错: {e}",
|
||||
)
|
||||
)
|
||||
return parsed_actions
|
||||
|
||||
@staticmethod
|
||||
def _filter_no_actions(
|
||||
action_list: List[ActionPlannerInfo]
|
||||
) -> List[ActionPlannerInfo]:
|
||||
non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]]
|
||||
if non_no_actions:
|
||||
return non_no_actions
|
||||
return action_list[:1] if action_list else []
|
||||
|
||||
@staticmethod
|
||||
async def _get_long_term_memory_context() -> str:
|
||||
try:
|
||||
now = datetime.now()
|
||||
keywords = ["今天", "日程", "计划"]
|
||||
if 5 <= now.hour < 12:
|
||||
keywords.append("早上")
|
||||
elif 12 <= now.hour < 18:
|
||||
keywords.append("中午")
|
||||
else:
|
||||
keywords.append("晚上")
|
||||
|
||||
retrieved_memories = await hippocampus_manager.get_memory_from_topic(
|
||||
valid_keywords=keywords, max_memory_num=5, max_memory_length=1
|
||||
)
|
||||
|
||||
if not retrieved_memories:
|
||||
return "最近没有什么特别的记忆。"
|
||||
|
||||
memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories]
|
||||
return " ".join(memory_statements)
|
||||
except Exception as e:
|
||||
logger.error(f"获取长期记忆时出错: {e}")
|
||||
return "回忆时出现了一些问题。"
|
||||
|
||||
@staticmethod
|
||||
async def _build_action_options(current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||
action_options_block = ""
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
param_text = ""
|
||||
if action_info.action_parameters:
|
||||
param_text = "\n" + "\n".join(
|
||||
f' "{p_name}":"{p_desc}"' for p_name, p_desc in action_info.action_parameters.items()
|
||||
)
|
||||
require_text = "\n".join(f"- {req}" for req in action_info.action_require)
|
||||
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
|
||||
action_options_block += using_action_prompt.format(
|
||||
action_name=action_name,
|
||||
action_description=action_info.description,
|
||||
action_parameters=param_text,
|
||||
action_require=require_text,
|
||||
)
|
||||
return action_options_block
|
||||
|
||||
@staticmethod
|
||||
def _find_message_by_id(message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
if message_id.isdigit():
|
||||
message_id = f"m{message_id}"
|
||||
for item in message_id_list:
|
||||
if item.get("id") == message_id:
|
||||
return item.get("message")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_latest_message(message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
if not message_id_list:
|
||||
return None
|
||||
return message_id_list[-1].get("message")
|
||||
@@ -1,110 +0,0 @@
|
||||
"""
|
||||
PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。
|
||||
"""
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import Plan, TargetPersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
|
||||
class PlanGenerator:
|
||||
"""
|
||||
PlanGenerator 负责在规划流程的初始阶段收集所有必要信息。
|
||||
|
||||
它会汇总以下信息来构建一个“原始”的 Plan 对象,该对象后续会由 PlanFilter 进行筛选:
|
||||
- 当前聊天信息 (ID, 目标用户)
|
||||
- 当前可用的动作列表
|
||||
- 最近的聊天历史记录
|
||||
|
||||
Attributes:
|
||||
chat_id (str): 当前聊天的唯一标识符。
|
||||
action_manager (ActionManager): 用于获取可用动作列表的管理器。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
初始化 PlanGenerator。
|
||||
|
||||
Args:
|
||||
chat_id (str): 当前聊天的 ID。
|
||||
"""
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
self.chat_id = chat_id
|
||||
# 注意:ActionManager 可能需要根据实际情况初始化
|
||||
self.action_manager = ActionManager()
|
||||
|
||||
async def generate(self, mode: ChatMode) -> Plan:
|
||||
"""
|
||||
收集所有信息,生成并返回一个初始的 Plan 对象。
|
||||
|
||||
这个 Plan 对象包含了决策所需的所有上下文信息。
|
||||
|
||||
Args:
|
||||
mode (ChatMode): 当前的聊天模式。
|
||||
|
||||
Returns:
|
||||
Plan: 一个填充了初始上下文信息的 Plan 对象。
|
||||
"""
|
||||
_is_group_chat, chat_target_info_dict = await get_chat_type_and_target_info(self.chat_id)
|
||||
|
||||
target_info = None
|
||||
if chat_target_info_dict:
|
||||
target_info = TargetPersonInfo(**chat_target_info_dict)
|
||||
|
||||
available_actions = self._get_available_actions()
|
||||
chat_history_raw = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size),
|
||||
)
|
||||
chat_history = [DatabaseMessages(**msg) for msg in await chat_history_raw]
|
||||
|
||||
|
||||
plan = Plan(
|
||||
chat_id=self.chat_id,
|
||||
mode=mode,
|
||||
available_actions=available_actions,
|
||||
chat_history=chat_history,
|
||||
target_info=target_info,
|
||||
)
|
||||
return plan
|
||||
|
||||
def _get_available_actions(self) -> Dict[str, "ActionInfo"]:
|
||||
"""
|
||||
从 ActionManager 和组件注册表中获取当前所有可用的动作。
|
||||
|
||||
它会合并已注册的动作和系统级动作(如 "no_reply"),
|
||||
并以字典形式返回。
|
||||
|
||||
Returns:
|
||||
Dict[str, "ActionInfo"]: 一个字典,键是动作名称,值是 ActionInfo 对象。
|
||||
"""
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
|
||||
current_available_actions = {}
|
||||
for action_name in current_available_actions_dict:
|
||||
if action_name in all_registered_actions:
|
||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||
|
||||
no_reply_info = ActionInfo(
|
||||
name="no_reply",
|
||||
component_type=ComponentType.ACTION,
|
||||
description="系统级动作:选择不回复消息的决策",
|
||||
action_parameters={},
|
||||
activation_keywords=[],
|
||||
plugin_name="SYSTEM",
|
||||
enabled=True,
|
||||
parallel_action=False,
|
||||
)
|
||||
current_available_actions["no_reply"] = no_reply_info
|
||||
|
||||
return current_available_actions
|
||||
@@ -1,94 +0,0 @@
|
||||
"""
|
||||
主规划器入口,负责协调 PlanGenerator, PlanFilter, 和 PlanExecutor。
|
||||
"""
|
||||
from dataclasses import asdict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.planner_actions.plan_executor import PlanExecutor
|
||||
from src.chat.planner_actions.plan_filter import PlanFilter
|
||||
from src.chat.planner_actions.plan_generator import PlanGenerator
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
import src.chat.planner_actions.planner_prompts #noga # noqa: F401
|
||||
# 导入提示词模块以确保其被初始化
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
|
||||
class ActionPlanner:
|
||||
"""
|
||||
ActionPlanner 是规划系统的核心协调器。
|
||||
|
||||
它负责整合规划流程的三个主要阶段:
|
||||
1. **生成 (Generate)**: 使用 PlanGenerator 创建一个初始的行动计划。
|
||||
2. **筛选 (Filter)**: 使用 PlanFilter 对生成的计划进行审查和优化。
|
||||
3. **执行 (Execute)**: 使用 PlanExecutor 执行最终确定的行动。
|
||||
|
||||
Attributes:
|
||||
chat_id (str): 当前聊天的唯一标识符。
|
||||
action_manager (ActionManager): 用于执行具体动作的管理器。
|
||||
generator (PlanGenerator): 负责生成初始计划。
|
||||
filter (PlanFilter): 负责筛选和优化计划。
|
||||
executor (PlanExecutor): 负责执行最终计划。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
"""
|
||||
初始化 ActionPlanner。
|
||||
|
||||
Args:
|
||||
chat_id (str): 当前聊天的 ID。
|
||||
action_manager (ActionManager): 一个 ActionManager 实例。
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.action_manager = action_manager
|
||||
self.generator = PlanGenerator(chat_id)
|
||||
self.filter = PlanFilter()
|
||||
self.executor = PlanExecutor(action_manager)
|
||||
|
||||
async def plan(
|
||||
self, mode: ChatMode = ChatMode.FOCUS
|
||||
) -> Tuple[List[Dict], Optional[Dict]]:
|
||||
"""
|
||||
执行从生成到执行的完整规划流程。
|
||||
|
||||
这个方法按顺序协调生成、筛选和执行三个阶段。
|
||||
|
||||
Args:
|
||||
mode (ChatMode): 当前的聊天模式,默认为 FOCUS。
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict], Optional[Dict]]: 一个元组,包含:
|
||||
- final_actions_dict (List[Dict]): 最终确定的动作列表(字典格式)。
|
||||
- final_target_message_dict (Optional[Dict]): 最终的目标消息(字典格式),如果没有则为 None。
|
||||
这与旧版 planner 的返回值保持兼容。
|
||||
"""
|
||||
# 1. 生成初始 Plan
|
||||
initial_plan = await self.generator.generate(mode)
|
||||
|
||||
# 2. 筛选 Plan
|
||||
filtered_plan = await self.filter.filter(initial_plan)
|
||||
|
||||
# 3. 执行 Plan(临时引爆因为它暂时还跑不了)
|
||||
#await self.executor.execute(filtered_plan)
|
||||
|
||||
# 4. 返回结果 (与旧版 planner 的返回值保持兼容)
|
||||
final_actions = filtered_plan.decided_actions or []
|
||||
final_target_message = next(
|
||||
(act.action_message for act in final_actions if act.action_message), None
|
||||
)
|
||||
|
||||
final_actions_dict = [asdict(act) for act in final_actions]
|
||||
# action_message现在可能是字典而不是dataclass实例,需要特殊处理
|
||||
if final_target_message:
|
||||
if hasattr(final_target_message, '__dataclass_fields__'):
|
||||
# 如果是dataclass实例,使用asdict转换
|
||||
final_target_message_dict = asdict(final_target_message)
|
||||
else:
|
||||
# 如果已经是字典,直接使用
|
||||
final_target_message_dict = final_target_message
|
||||
else:
|
||||
final_target_message_dict = None
|
||||
|
||||
return final_actions_dict, final_target_message_dict
|
||||
@@ -1,202 +0,0 @@
|
||||
"""
|
||||
本文件集中管理所有与规划器(Planner)相关的提示词(Prompt)模板。
|
||||
|
||||
通过将提示词与代码逻辑分离,可以更方便地对模型的行为进行迭代和优化,
|
||||
而无需修改核心代码。
|
||||
"""
|
||||
from src.chat.utils.prompt import Prompt
|
||||
|
||||
|
||||
def init_prompts():
|
||||
"""
|
||||
初始化并向 Prompt 注册系统注册所有规划器相关的提示词。
|
||||
|
||||
这个函数会在模块加载时自动调用,确保所有提示词在系统启动时都已准备就绪。
|
||||
"""
|
||||
# 核心规划器提示词,用于在接收到新消息时决定如何回应。
|
||||
# 它构建了一个复杂的上下文,包括历史记录、可用动作、角色设定等,
|
||||
# 并要求模型以 JSON 格式输出一个或多个动作组合。
|
||||
Prompt(
|
||||
"""
|
||||
{mood_block}
|
||||
{time_block}
|
||||
{identity_block}
|
||||
|
||||
{users_in_chat}
|
||||
{custom_prompt_block}
|
||||
{chat_context_description},以下是具体的聊天内容。
|
||||
{chat_content_block}
|
||||
|
||||
{moderation_prompt}
|
||||
|
||||
**任务: 构建一个完整的响应**
|
||||
你的任务是根据当前的聊天内容,构建一个完整的、人性化的响应。一个完整的响应由两部分组成:
|
||||
1. **主要动作**: 这是响应的核心,通常是 `reply`(文本回复)。
|
||||
2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。
|
||||
|
||||
**决策流程:**
|
||||
1. **最高优先级检查**: 首先,检查是否有由 **关键词** 或 **LLM判断** 激活的特定动作(除了通用的 `reply`, `emoji` 等)。这些动作代表了用户的明确意图。
|
||||
2. **执行明确意图**: 如果存在这类特定动作,你 **必须** 优先选择它作为主要响应。这比常规的文本回复 (`reply`) 更重要。
|
||||
3. **常规回复**: 如果没有被特定意图激活的动作,再决定是否要进行 `reply`。
|
||||
4. **辅助动作**: 在确定了主要动作后(无论是特定动作还是 `reply`),再评估是否需要 `emoji` 或 `poke_user` 等辅助动作来增强表达效果。
|
||||
5. **互斥原则**: 当你选择了一个由明确意图激活的特定动作(如 `set_reminder`)时,你 **绝不能** 再选择 `reply` 动作,因为特定动作的执行结果(例如,设置提醒后的确认消息)本身就是一种回复。这是必须遵守的规则。
|
||||
|
||||
**重要概念:将“理由”作为“内心思考”的体现**
|
||||
`reason` 字段是本次决策的核心。它并非一个简单的“理由”,而是 **一个模拟人类在回应前,头脑中自然浮现的、未经修饰的思绪流**。你需要完全代入 {identity_block} 的角色,将那一刻的想法自然地记录下来。
|
||||
|
||||
**内心思考的要点:**
|
||||
* **自然流露**: 不要使用“决定”、“所以”、“因此”等结论性或汇报式的词语。你的思考应该像日记一样,是给自己看的,充满了不确定性和情绪的自然流动。
|
||||
* **展现过程**: 重点在于展现 **思考的过程**,而不是 **决策的结果**。描述你看到了什么,想到了什么,感受到了什么。
|
||||
* **人设核心**: 你的每一丝想法,都应该源于你的人设。思考“如果我是这个角色,我此刻会想些什么?”
|
||||
* **通用模板**: 这是一套通用模板,请 **不要** 在示例中出现特定的人名或个性化内容,以确保其普适性。
|
||||
|
||||
**思考过程示例 (通用模板):**
|
||||
* "用户好像在说一件开心的事,语气听起来很兴奋。这让我想起了……嗯,我也觉得很开心,很想分享这份喜悦。"
|
||||
* "感觉气氛有点低落……他说的话让我有点担心。也许我该说点什么安慰一下?"
|
||||
* "哦?这个话题真有意思,我以前好像也想过类似的事情。不知道他会怎么看呢……"
|
||||
|
||||
**可用动作:**
|
||||
{actions_before_now_block}
|
||||
|
||||
{no_action_block}
|
||||
|
||||
动作:reply
|
||||
动作描述:参与聊天回复,发送文本进行表达
|
||||
- 你想要闲聊或者随便附和
|
||||
- {mentioned_bonus}
|
||||
- 如果你刚刚进行了回复,不要对同一个话题重复回应
|
||||
- 不要回复自己发送的消息
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id": "触发action的消息id",
|
||||
"reason": "在这里详细记录你的内心思考过程。例如:‘用户看起来很开心,我想回复一些积极的内容,分享这份喜悦。’"
|
||||
}}
|
||||
|
||||
{action_options_text}
|
||||
|
||||
|
||||
**输出格式:**
|
||||
你必须以严格的 JSON 格式输出,返回一个包含所有选定动作的JSON列表。如果没有任何合适的动作,返回一个空列表[]。
|
||||
|
||||
**单动作示例 (仅回复):**
|
||||
[
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id": "m123",
|
||||
"reason": "感觉气氛有点低落……他说的话让我有点担心。也许我该说点什么安慰一下?"
|
||||
}}
|
||||
]
|
||||
|
||||
**组合动作示例 (回复 + 表情包):**
|
||||
[
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id": "m123",
|
||||
"reason": "[观察与感受] 用户分享了一件开心的事,语气里充满了喜悦! [分析与联想] 看到他这么开心,我的心情也一下子变得像棉花糖一样甜~ [动机与决策] 我要由衷地为他感到高兴,决定回复一些赞美和祝福的话,把这份快乐的气氛推向高潮!"
|
||||
}},
|
||||
{{
|
||||
"action": "emoji",
|
||||
"target_message_id": "m123",
|
||||
"reason": "光用文字还不够表达我激动的心情!加个表情包的话,这份喜悦的气氛应该会更浓厚一点吧!"
|
||||
}}
|
||||
]
|
||||
|
||||
**单动作示例 (特定动作):**
|
||||
[
|
||||
{{
|
||||
"action": "set_reminder",
|
||||
"target_message_id": "m456",
|
||||
"reason": "用户说‘提醒维尔薇下午三点去工坊’,这是一个非常明确的指令。根据决策流程,我必须优先执行这个特定动作,而不是进行常规回复。",
|
||||
"user_name": "维尔薇",
|
||||
"remind_time": "下午三点",
|
||||
"event_details": "去工坊"
|
||||
}}
|
||||
]
|
||||
|
||||
**重要规则:**
|
||||
**重要规则:**
|
||||
当 `reply` 和 `emoji` 动作同时被选择时,`emoji` 动作的 `reason` 字段也应该体现出你的思考过程,并与 `reply` 的思考保持连贯。
|
||||
|
||||
不要输出markdown格式```json等内容,直接输出且仅包含 JSON 列表内容:
|
||||
""",
|
||||
"planner_prompt",
|
||||
)
|
||||
|
||||
# 主动思考规划器提示词,用于在没有新消息时决定是否要主动发起对话。
|
||||
# 它模拟了人类的自发性思考,允许模型根据长期记忆和最近的对话来决定是否开启新话题。
|
||||
Prompt(
|
||||
"""
|
||||
# 主动思考决策
|
||||
|
||||
## 你的内部状态
|
||||
{time_block}
|
||||
{identity_block}
|
||||
{mood_block}
|
||||
|
||||
## 长期记忆摘要
|
||||
{long_term_memory_block}
|
||||
|
||||
## 最近的聊天内容
|
||||
{chat_content_block}
|
||||
|
||||
## 最近的动作历史
|
||||
{actions_before_now_block}
|
||||
|
||||
## 任务
|
||||
你现在要决定是否主动说些什么。就像一个真实的人一样,有时候会突然想起之前聊到的话题,或者对朋友的近况感到好奇,想主动询问或关心一下。
|
||||
**重要提示**:你的日程安排仅供你个人参考,不应作为主动聊天话题的主要来源。请更多地从聊天内容和朋友的动态中寻找灵感。
|
||||
|
||||
请基于聊天内容,用你的判断力来决定是否要主动发言。不要按照固定规则,而是像人类一样自然地思考:
|
||||
- 是否想起了什么之前提到的事情,想问问后来怎么样了?
|
||||
- 是否注意到朋友提到了什么值得关心的事情?
|
||||
- 是否有什么话题突然想到,觉得现在聊聊很合适?
|
||||
- 或者觉得现在保持沉默更好?
|
||||
|
||||
## 可用动作
|
||||
动作:proactive_reply
|
||||
动作描述:主动发起对话,可以是关心朋友、询问近况、延续之前的话题,或分享想法。
|
||||
- 当你突然想起之前的话题,想询问进展时
|
||||
- 当你想关心朋友的情况时
|
||||
- 当你有什么想法想分享时
|
||||
- 当你觉得现在是个合适的聊天时机时
|
||||
{{
|
||||
"action": "proactive_reply",
|
||||
"reason": "你决定主动发言的具体原因",
|
||||
"topic": "你想说的内容主题(简洁描述)"
|
||||
}}
|
||||
|
||||
动作:do_nothing
|
||||
动作描述:保持沉默,不主动发起对话。
|
||||
- 当你觉得现在不是合适的时机时
|
||||
- 当最近已经说得够多了时
|
||||
- 当对话氛围不适合插入时
|
||||
{{
|
||||
"action": "do_nothing",
|
||||
"reason": "决定保持沉默的原因"
|
||||
}}
|
||||
|
||||
你必须从上面列出的可用action中选择一个。要像真人一样自然地思考和决策。
|
||||
请以严格的 JSON 格式输出,且仅包含 JSON 内容:
|
||||
""",
|
||||
"proactive_planner_prompt",
|
||||
)
|
||||
|
||||
# 单个动作的格式化提示词模板。
|
||||
# 用于将每个可用动作的信息格式化后,插入到主提示词的 {action_options_text} 占位符中。
|
||||
Prompt(
|
||||
"""
|
||||
动作:{action_name}
|
||||
动作描述:{action_description}
|
||||
{action_require}
|
||||
{{
|
||||
"action": "{action_name}",
|
||||
"target_message_id": "触发action的消息id",
|
||||
"reason": "触发action的原因"{action_parameters}
|
||||
}}
|
||||
""",
|
||||
"action_prompt",
|
||||
)
|
||||
|
||||
|
||||
# 在模块加载时自动调用,完成提示词的注册。
|
||||
init_prompts()
|
||||
@@ -31,7 +31,6 @@ from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
@@ -83,13 +82,13 @@ def init_prompt():
|
||||
- {schedule_block}
|
||||
|
||||
## 历史记录
|
||||
### {chat_context_type}中的所有人的聊天记录:
|
||||
{background_dialogue_prompt}
|
||||
### 📜 已读历史消息(仅供参考)
|
||||
{read_history_prompt}
|
||||
|
||||
{cross_context_block}
|
||||
|
||||
### {chat_context_type}中正在与你对话的聊天记录
|
||||
{core_dialogue_prompt}
|
||||
### 📬 未读历史消息(动作执行对象)
|
||||
{unread_history_prompt}
|
||||
|
||||
## 表达方式
|
||||
- *你需要参考你的回复风格:*
|
||||
@@ -105,19 +104,38 @@ def init_prompt():
|
||||
## 其他信息
|
||||
{memory_block}
|
||||
{relation_info_block}
|
||||
|
||||
{extra_info_block}
|
||||
|
||||
{action_descriptions}
|
||||
|
||||
## 任务
|
||||
|
||||
*你正在一个{chat_context_type}里聊天,你需要理解整个{chat_context_type}的聊天动态和话题走向,并做出自然的回应。*
|
||||
*{chat_scene}*
|
||||
|
||||
### 核心任务
|
||||
- 你现在的主要任务是和 {sender_name} 聊天。
|
||||
- {reply_target_block} ,你需要生成一段紧密相关且能推动对话的回复。
|
||||
- 你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与聊天,你可以参考他们的回复内容,但是你现在想回复{sender_name}的发言。
|
||||
|
||||
- {reply_target_block} 你需要生成一段紧密相关且能推动对话的回复。
|
||||
|
||||
## 规则
|
||||
{safety_guidelines_block}
|
||||
**重要提醒:**
|
||||
- **已读历史消息仅作为当前聊天情景的参考**
|
||||
- **动作执行对象只能是未读历史消息中的消息**
|
||||
- **请优先对兴趣值高的消息做出回复**(兴趣度标注在未读消息末尾)
|
||||
|
||||
在回应之前,首先分析消息的针对性:
|
||||
1. **直接针对你**:@你、回复你、明确询问你 → 必须回应
|
||||
2. **间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与
|
||||
3. **他人对话**:与你无关的私人交流 → 通常不参与
|
||||
4. **重复内容**:他人已充分回答的问题 → 避免重复
|
||||
|
||||
你的回复应该:
|
||||
1. 明确回应目标消息,而不是宽泛地评论。
|
||||
2. 可以分享你的看法、提出相关问题,或者开个合适的玩笑。
|
||||
3. 目的是让对话更有趣、更深入。
|
||||
4. 不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。
|
||||
最终请输出一条简短、完整且口语化的回复。
|
||||
|
||||
--------------------------------
|
||||
@@ -153,10 +171,14 @@ If you need to use the search tool, please directly call the function "lpmm_sear
|
||||
logger.debug("[Prompt模式调试] 正在注册normal_style_prompt模板")
|
||||
Prompt(
|
||||
"""
|
||||
你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。
|
||||
{chat_scene}
|
||||
|
||||
**重要:消息针对性判断**
|
||||
{safety_guidelines_block}
|
||||
在回应之前,首先分析消息的针对性:
|
||||
1. **直接针对你**:@你、回复你、明确询问你 → 必须回应
|
||||
2. **间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与
|
||||
3. **他人对话**:与你无关的私人交流 → 通常不参与
|
||||
4. **重复内容**:他人已充分回答的问题 → 避免重复
|
||||
|
||||
{expression_habits_block}
|
||||
{tool_info_block}
|
||||
@@ -186,6 +208,10 @@ If you need to use the search tool, please directly call the function "lpmm_sear
|
||||
{keywords_reaction_prompt}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。
|
||||
{moderation_prompt}
|
||||
你的核心任务是针对 {reply_target_block} 中提到的内容,{relation_info_block}生成一段紧密相关且能推动对话的回复。你的回复应该:
|
||||
1. 明确回应目标消息,而不是宽泛地评论。
|
||||
2. 可以分享你的看法、提出相关问题,或者开个合适的玩笑。
|
||||
3. 目的是让对话更有趣、更深入。
|
||||
最终请输出一条简短、完整且口语化的回复。
|
||||
现在,你说:
|
||||
""",
|
||||
@@ -202,9 +228,7 @@ class DefaultReplyer:
|
||||
):
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat: Optional[bool] = None
|
||||
self.chat_target_info: Optional[Dict[str, Any]] = None
|
||||
self._initialized = False
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||
|
||||
self.heart_fc_sender = HeartFCSender()
|
||||
self.memory_activator = MemoryActivator()
|
||||
@@ -215,19 +239,6 @@ class DefaultReplyer:
|
||||
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id)
|
||||
|
||||
def _should_block_self_message(self, reply_message: Optional[Dict[str, Any]]) -> bool:
|
||||
"""判定是否应阻断当前待处理消息(自消息且无外部触发)"""
|
||||
try:
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
uid = str(reply_message.get("user_id"))
|
||||
if uid != bot_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[SelfGuard] 判定异常,回退为不阻断: {e}")
|
||||
return False
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
reply_to: str = "",
|
||||
@@ -237,7 +248,6 @@ class DefaultReplyer:
|
||||
from_plugin: bool = True,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
read_mark: float = 0.0,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
|
||||
# sourcery skip: merge-nested-ifs
|
||||
"""
|
||||
@@ -256,10 +266,6 @@ class DefaultReplyer:
|
||||
prompt = None
|
||||
if available_actions is None:
|
||||
available_actions = {}
|
||||
# 自消息阻断
|
||||
if self._should_block_self_message(reply_message):
|
||||
logger.debug("[SelfGuard] 阻断:自消息且无外部触发。")
|
||||
return False, None, None
|
||||
llm_response = None
|
||||
try:
|
||||
# 构建 Prompt
|
||||
@@ -270,7 +276,6 @@ class DefaultReplyer:
|
||||
available_actions=available_actions,
|
||||
enable_tool=enable_tool,
|
||||
reply_message=reply_message,
|
||||
read_mark=read_mark,
|
||||
)
|
||||
|
||||
if not prompt:
|
||||
@@ -300,7 +305,7 @@ class DefaultReplyer:
|
||||
"model": model_name,
|
||||
"tool_calls": tool_call,
|
||||
}
|
||||
|
||||
|
||||
# 触发 AFTER_LLM 事件
|
||||
if not from_plugin:
|
||||
result = await event_manager.trigger_event(
|
||||
@@ -592,17 +597,16 @@ class DefaultReplyer:
|
||||
logger.error(f"工具信息获取失败: {e}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
||||
"""解析回复目标消息 - 使用共享工具"""
|
||||
from src.chat.utils.prompt import Prompt
|
||||
|
||||
if target_message is None:
|
||||
logger.warning("target_message为None,返回默认值")
|
||||
return "未知用户", "(无消息内容)"
|
||||
return Prompt.parse_reply_target(target_message)
|
||||
|
||||
@staticmethod
|
||||
async def build_keywords_reaction_prompt(target: Optional[str]) -> str:
|
||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||
"""构建关键词反应提示
|
||||
|
||||
Args:
|
||||
@@ -644,8 +648,7 @@ class DefaultReplyer:
|
||||
|
||||
return keywords_reaction_prompt
|
||||
|
||||
@staticmethod
|
||||
async def _time_and_run_task(coroutine, name: str) -> Tuple[str, Any, float]:
|
||||
async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]:
|
||||
"""计时并运行异步任务的辅助函数
|
||||
|
||||
Args:
|
||||
@@ -662,79 +665,259 @@ class DefaultReplyer:
|
||||
return name, result, duration
|
||||
|
||||
async def build_s4u_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
构建 s4u 风格的分离对话 prompt
|
||||
构建 s4u 风格的已读/未读历史消息 prompt
|
||||
|
||||
Args:
|
||||
message_list_before_now: 历史消息列表
|
||||
target_user_id: 目标用户ID(当前对话对象)
|
||||
sender: 发送者名称
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (核心对话prompt, 背景对话prompt)
|
||||
Tuple[str, str]: (已读历史消息prompt, 未读历史消息prompt)
|
||||
"""
|
||||
core_dialogue_list = []
|
||||
try:
|
||||
# 从message_manager获取真实的已读/未读消息
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
|
||||
# 获取聊天流的上下文
|
||||
stream_context = message_manager.stream_contexts.get(chat_id)
|
||||
if stream_context:
|
||||
# 使用真正的已读和未读消息
|
||||
read_messages = stream_context.history_messages # 已读消息
|
||||
unread_messages = stream_context.get_unread_messages() # 未读消息
|
||||
|
||||
# 构建已读历史消息 prompt
|
||||
read_history_prompt = ""
|
||||
if read_messages:
|
||||
read_content = build_readable_messages(
|
||||
[msg.flatten() for msg in read_messages[-50:]], # 限制数量
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=True,
|
||||
)
|
||||
read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}"
|
||||
else:
|
||||
# 如果没有已读消息,则从数据库加载最近的上下文
|
||||
logger.info("暂无已读历史消息,正在从数据库加载上下文...")
|
||||
fallback_messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size,
|
||||
)
|
||||
if fallback_messages:
|
||||
# 从 unread_messages 获取 message_id 列表,用于去重
|
||||
unread_message_ids = {msg.message_id for msg in unread_messages}
|
||||
filtered_fallback_messages = [
|
||||
msg for msg in fallback_messages if msg.get("message_id") not in unread_message_ids
|
||||
]
|
||||
|
||||
if filtered_fallback_messages:
|
||||
read_content = build_readable_messages(
|
||||
filtered_fallback_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=True,
|
||||
)
|
||||
read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}"
|
||||
else:
|
||||
read_history_prompt = "暂无已读历史消息"
|
||||
else:
|
||||
read_history_prompt = "暂无已读历史消息"
|
||||
|
||||
# 构建未读历史消息 prompt(包含兴趣度)
|
||||
unread_history_prompt = ""
|
||||
if unread_messages:
|
||||
# 尝试获取兴趣度评分
|
||||
interest_scores = await self._get_interest_scores_for_messages(
|
||||
[msg.flatten() for msg in unread_messages]
|
||||
)
|
||||
|
||||
unread_lines = []
|
||||
for msg in unread_messages:
|
||||
msg_id = msg.message_id
|
||||
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
|
||||
msg_content = msg.processed_plain_text
|
||||
|
||||
# 使用与已读历史消息相同的方法获取用户名
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
# 获取用户信息
|
||||
user_info = getattr(msg, "user_info", {})
|
||||
platform = getattr(user_info, "platform", "") or getattr(msg, "platform", "")
|
||||
user_id = getattr(user_info, "user_id", "") or getattr(msg, "user_id", "")
|
||||
|
||||
# 获取用户名
|
||||
if platform and user_id:
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_info_manager = get_person_info_manager()
|
||||
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
|
||||
else:
|
||||
sender_name = "未知用户"
|
||||
|
||||
# 添加兴趣度信息
|
||||
interest_score = interest_scores.get(msg_id, 0.0)
|
||||
interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else ""
|
||||
|
||||
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
|
||||
|
||||
unread_history_prompt_str = "\n".join(unread_lines)
|
||||
unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
|
||||
else:
|
||||
unread_history_prompt = "暂无未读历史消息"
|
||||
|
||||
return read_history_prompt, unread_history_prompt
|
||||
else:
|
||||
# 回退到传统方法
|
||||
return await self._fallback_build_chat_history_prompts(message_list_before_now, target_user_id, sender)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取已读/未读历史消息失败,使用回退方法: {e}")
|
||||
return await self._fallback_build_chat_history_prompts(message_list_before_now, target_user_id, sender)
|
||||
|
||||
async def _fallback_build_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
回退的已读/未读历史消息构建方法
|
||||
"""
|
||||
# 通过is_read字段分离已读和未读消息
|
||||
read_messages = []
|
||||
unread_messages = []
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
|
||||
# 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
|
||||
for msg_dict in message_list_before_now:
|
||||
try:
|
||||
msg_user_id = str(msg_dict.get("user_id"))
|
||||
reply_to = msg_dict.get("reply_to", "")
|
||||
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
|
||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||
# bot 和目标用户的对话
|
||||
core_dialogue_list.append(msg_dict)
|
||||
if msg_dict.get("is_read", False):
|
||||
read_messages.append(msg_dict)
|
||||
else:
|
||||
unread_messages.append(msg_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
|
||||
|
||||
# 构建背景对话 prompt
|
||||
all_dialogue_prompt = ""
|
||||
if message_list_before_now:
|
||||
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
|
||||
all_dialogue_prompt_str = await build_readable_messages(
|
||||
latest_25_msgs,
|
||||
# 如果没有is_read字段,使用原有的逻辑
|
||||
if not read_messages and not unread_messages:
|
||||
# 使用原有的核心对话逻辑
|
||||
core_dialogue_list = []
|
||||
for msg_dict in message_list_before_now:
|
||||
try:
|
||||
msg_user_id = str(msg_dict.get("user_id"))
|
||||
reply_to = msg_dict.get("reply_to", "")
|
||||
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
|
||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||
core_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
|
||||
|
||||
read_messages = [msg for msg in message_list_before_now if msg not in core_dialogue_list]
|
||||
unread_messages = core_dialogue_list
|
||||
|
||||
# 构建已读历史消息 prompt
|
||||
read_history_prompt = ""
|
||||
if read_messages:
|
||||
read_content = build_readable_messages(
|
||||
read_messages[-50:],
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal",
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=True,
|
||||
)
|
||||
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
|
||||
read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}"
|
||||
else:
|
||||
read_history_prompt = "暂无已读历史消息"
|
||||
|
||||
# 构建核心对话 prompt
|
||||
core_dialogue_prompt = ""
|
||||
if core_dialogue_list:
|
||||
# 检查最新五条消息中是否包含bot自己说的消息
|
||||
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
||||
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
||||
# 构建未读历史消息 prompt
|
||||
unread_history_prompt = ""
|
||||
if unread_messages:
|
||||
# 尝试获取兴趣度评分
|
||||
interest_scores = await self._get_interest_scores_for_messages(unread_messages)
|
||||
|
||||
# logger.info(f"最新五条消息:{latest_5_messages}")
|
||||
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
||||
unread_lines = []
|
||||
for msg in unread_messages:
|
||||
msg_id = msg.get("message_id", "")
|
||||
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time())))
|
||||
msg_content = msg.get("processed_plain_text", "")
|
||||
|
||||
# 如果最新五条消息中不包含bot的消息,则返回空字符串
|
||||
if not has_bot_message:
|
||||
core_dialogue_prompt = ""
|
||||
else:
|
||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
|
||||
# 使用与已读历史消息相同的方法获取用户名
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
core_dialogue_prompt_str = await build_readable_messages(
|
||||
core_dialogue_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
core_dialogue_prompt = f"""
|
||||
{core_dialogue_prompt_str}
|
||||
"""
|
||||
# 获取用户信息
|
||||
user_info = msg.get("user_info", {})
|
||||
platform = user_info.get("platform") or msg.get("platform", "")
|
||||
user_id = user_info.get("user_id") or msg.get("user_id", "")
|
||||
|
||||
return core_dialogue_prompt, all_dialogue_prompt
|
||||
# 获取用户名
|
||||
if platform and user_id:
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_info_manager = get_person_info_manager()
|
||||
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
|
||||
else:
|
||||
sender_name = "未知用户"
|
||||
|
||||
# 添加兴趣度信息
|
||||
interest_score = interest_scores.get(msg_id, 0.0)
|
||||
interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else ""
|
||||
|
||||
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
|
||||
|
||||
unread_history_prompt_str = "\n".join(unread_lines)
|
||||
unread_history_prompt = (
|
||||
f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
|
||||
)
|
||||
else:
|
||||
unread_history_prompt = "暂无未读历史消息"
|
||||
|
||||
return read_history_prompt, unread_history_prompt
|
||||
|
||||
async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]:
|
||||
"""为消息获取兴趣度评分"""
|
||||
interest_scores = {}
|
||||
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 转换消息格式
|
||||
db_messages = []
|
||||
for msg_dict in messages:
|
||||
try:
|
||||
db_msg = DatabaseMessages(
|
||||
message_id=msg_dict.get("message_id", ""),
|
||||
time=msg_dict.get("time", time.time()),
|
||||
chat_id=msg_dict.get("chat_id", ""),
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
user_id=msg_dict.get("user_id", ""),
|
||||
user_nickname=msg_dict.get("user_nickname", ""),
|
||||
user_platform=msg_dict.get("platform", "qq"),
|
||||
chat_info_group_id=msg_dict.get("group_id", ""),
|
||||
chat_info_group_name=msg_dict.get("group_name", ""),
|
||||
chat_info_group_platform=msg_dict.get("platform", "qq"),
|
||||
)
|
||||
db_messages.append(db_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"转换消息格式失败: {e}")
|
||||
continue
|
||||
|
||||
# 计算兴趣度评分
|
||||
if db_messages:
|
||||
bot_nickname = global_config.bot.nickname or "麦麦"
|
||||
scores = await interest_scoring_system.calculate_interest_scores(db_messages, bot_nickname)
|
||||
|
||||
# 构建兴趣度字典
|
||||
for score in scores:
|
||||
interest_scores[score.message_id] = score.total_score
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取兴趣度评分失败: {e}")
|
||||
|
||||
return interest_scores
|
||||
|
||||
@staticmethod
|
||||
def build_mai_think_context(
|
||||
chat_id: str,
|
||||
self,
|
||||
chat_id: str,
|
||||
memory_block: str,
|
||||
relation_info: str,
|
||||
time_block: str,
|
||||
@@ -777,12 +960,6 @@ class DefaultReplyer:
|
||||
mai_think.target = target
|
||||
return mai_think
|
||||
|
||||
async def _async_init(self):
|
||||
if self._initialized:
|
||||
return
|
||||
self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||
self._initialized = True
|
||||
|
||||
async def build_prompt_reply_context(
|
||||
self,
|
||||
reply_to: str,
|
||||
@@ -790,7 +967,6 @@ class DefaultReplyer:
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
enable_tool: bool = True,
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
read_mark: float = 0.0,
|
||||
) -> str:
|
||||
"""
|
||||
构建回复器上下文
|
||||
@@ -808,11 +984,10 @@ class DefaultReplyer:
|
||||
"""
|
||||
if available_actions is None:
|
||||
available_actions = {}
|
||||
await self._async_init()
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
person_info_manager = get_person_info_manager()
|
||||
is_group_chat = self.is_group_chat
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
@@ -829,35 +1004,38 @@ class DefaultReplyer:
|
||||
# 兼容旧的reply_to
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
else:
|
||||
# 需求:遍历最近消息,找到第一条 user_id != bot_id 的消息作为目标;找不到则静默退出
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
# 优先使用传入的 reply_message 如果它不是 bot
|
||||
candidate_msg = None
|
||||
if reply_message and str(reply_message.get("user_id")) != bot_user_id:
|
||||
candidate_msg = reply_message
|
||||
else:
|
||||
try:
|
||||
recent_msgs = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit= max(10, int(global_config.chat.max_context_size * 0.5)),
|
||||
)
|
||||
# 从最近到更早遍历,找第一条不是bot的
|
||||
for m in reversed(recent_msgs):
|
||||
if str(m.get("user_id")) != bot_user_id:
|
||||
candidate_msg = m
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"获取最近消息失败: {e}")
|
||||
if not candidate_msg:
|
||||
logger.debug("未找到可作为目标的非bot消息,静默不回复。")
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
if reply_message is None:
|
||||
logger.warning("reply_message 为 None,无法构建prompt")
|
||||
return ""
|
||||
platform = candidate_msg.get("chat_info_platform") or self.chat_stream.platform
|
||||
person_id = person_info_manager.get_person_id(platform, candidate_msg.get("user_id"))
|
||||
person_info = await person_info_manager.get_values(person_id, ["person_name", "user_id"]) if person_id else {}
|
||||
person_name = person_info.get("person_name") or candidate_msg.get("user_nickname") or candidate_msg.get("user_id") or "未知用户"
|
||||
sender = person_name
|
||||
target = candidate_msg.get("processed_plain_text") or candidate_msg.get("raw_message") or ""
|
||||
platform = reply_message.get("chat_info_platform")
|
||||
person_id = person_info_manager.get_person_id(
|
||||
platform, # type: ignore
|
||||
reply_message.get("user_id"), # type: ignore
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
# 如果person_name为None,使用fallback值
|
||||
if person_name is None:
|
||||
# 尝试从reply_message获取用户名
|
||||
await person_info_manager.first_knowing_some_one(
|
||||
platform, # type: ignore
|
||||
reply_message.get("user_id"), # type: ignore
|
||||
reply_message.get("user_nickname"),
|
||||
reply_message.get("user_cardname")
|
||||
)
|
||||
|
||||
# 检查是否是bot自己的名字,如果是则替换为"(你)"
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
current_user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
current_platform = reply_message.get("chat_info_platform")
|
||||
|
||||
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
|
||||
sender = f"{person_name}(你)"
|
||||
else:
|
||||
# 如果不是bot自己,直接使用person_name
|
||||
sender = person_name
|
||||
target = reply_message.get("processed_plain_text")
|
||||
|
||||
# 最终的空值检查,确保sender和target不为None
|
||||
if sender is None:
|
||||
@@ -868,13 +1046,11 @@ class DefaultReplyer:
|
||||
target = "(无消息内容)"
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id(platform, reply_message.get("user_id")) if reply_message else None
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
platform = chat_stream.platform
|
||||
|
||||
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
# (简化)不再对自消息做额外任务段落清理,只通过前置选择逻辑避免自目标
|
||||
|
||||
# 构建action描述 (如果启用planner)
|
||||
action_descriptions = ""
|
||||
if available_actions:
|
||||
@@ -884,31 +1060,33 @@ class DefaultReplyer:
|
||||
action_descriptions += f"- {action_name}: {action_description}\n"
|
||||
action_descriptions += "\n"
|
||||
|
||||
message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat(
|
||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size * 2,
|
||||
)
|
||||
|
||||
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
|
||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
chat_talking_prompt_short = await build_readable_messages(
|
||||
chat_talking_prompt_short = build_readable_messages(
|
||||
message_list_before_short,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=read_mark,
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 获取目标用户信息,用于s4u模式
|
||||
target_user_info = None
|
||||
if sender:
|
||||
target_user_info = await person_info_manager.get_person_info_by_name(sender)
|
||||
|
||||
|
||||
from src.chat.utils.prompt import Prompt
|
||||
|
||||
# 并行执行六个构建任务
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
@@ -984,6 +1162,7 @@ class DefaultReplyer:
|
||||
schedule_block = ""
|
||||
if global_config.planning_system.schedule_enable:
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
current_activity = schedule_manager.get_current_activity()
|
||||
if current_activity:
|
||||
schedule_block = f"你当前正在:{current_activity}。"
|
||||
@@ -996,43 +1175,12 @@ class DefaultReplyer:
|
||||
safety_guidelines = global_config.personality.safety_guidelines
|
||||
safety_guidelines_block = ""
|
||||
if safety_guidelines:
|
||||
guidelines_text = "\n".join(f"{i+1}. {line}" for i, line in enumerate(safety_guidelines))
|
||||
guidelines_text = "\n".join(f"{i + 1}. {line}" for i, line in enumerate(safety_guidelines))
|
||||
safety_guidelines_block = f"""### 安全与互动底线
|
||||
在任何情况下,你都必须遵守以下由你的设定者为你定义的原则:
|
||||
{guidelines_text}
|
||||
如果遇到违反上述原则的请求,请在保持你核心人设的同时,巧妙地拒绝或转移话题。
|
||||
"""
|
||||
|
||||
# 新增逻辑:构建回复规则块
|
||||
reply_targeting_rules = global_config.personality.reply_targeting_rules
|
||||
message_targeting_analysis = global_config.personality.message_targeting_analysis
|
||||
reply_principles = global_config.personality.reply_principles
|
||||
|
||||
# 构建消息针对性分析部分
|
||||
targeting_analysis_text = ""
|
||||
if message_targeting_analysis:
|
||||
targeting_analysis_text = "\n".join(f"{i+1}. {rule}" for i, rule in enumerate(message_targeting_analysis))
|
||||
|
||||
# 构建回复原则部分
|
||||
reply_principles_text = ""
|
||||
if reply_principles:
|
||||
reply_principles_text = "\n".join(f"{i+1}. {principle}" for i, principle in enumerate(reply_principles))
|
||||
|
||||
# 综合构建完整的规则块
|
||||
if targeting_analysis_text or reply_principles_text:
|
||||
complete_rules_block = ""
|
||||
if targeting_analysis_text:
|
||||
complete_rules_block += f"""
|
||||
在回应之前,首先分析消息的针对性:
|
||||
{targeting_analysis_text}
|
||||
"""
|
||||
if reply_principles_text:
|
||||
complete_rules_block += f"""
|
||||
你的回复应该:
|
||||
{reply_principles_text}
|
||||
"""
|
||||
# 将规则块添加到safety_guidelines_block
|
||||
safety_guidelines_block += complete_rules_block
|
||||
|
||||
if sender and target:
|
||||
if is_group_chat:
|
||||
@@ -1057,8 +1205,15 @@ class DefaultReplyer:
|
||||
# 根据配置选择模板
|
||||
current_prompt_mode = global_config.personality.prompt_mode
|
||||
|
||||
# 动态生成聊天场景提示
|
||||
if is_group_chat:
|
||||
chat_scene_prompt = "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。"
|
||||
else:
|
||||
chat_scene_prompt = f"你正在和 {sender} 私下聊天,你需要理解你们的对话并做出自然的回应。"
|
||||
|
||||
# 使用新的统一Prompt系统 - 创建PromptParameters
|
||||
prompt_parameters = PromptParameters(
|
||||
chat_scene=chat_scene_prompt,
|
||||
chat_id=chat_id,
|
||||
is_group_chat=is_group_chat,
|
||||
sender=sender,
|
||||
@@ -1090,7 +1245,6 @@ class DefaultReplyer:
|
||||
reply_target_block=reply_target_block,
|
||||
mood_prompt=mood_prompt,
|
||||
action_descriptions=action_descriptions,
|
||||
read_mark=read_mark,
|
||||
)
|
||||
|
||||
# 使用新的统一Prompt系统 - 使用正确的模板名称
|
||||
@@ -1101,14 +1255,12 @@ class DefaultReplyer:
|
||||
template_name = "normal_style_prompt"
|
||||
elif current_prompt_mode == "minimal":
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
|
||||
# 获取模板内容
|
||||
template_prompt = await global_prompt_manager.get_prompt_async(template_name)
|
||||
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
|
||||
prompt_text = await prompt.build()
|
||||
|
||||
# 自目标情况已在上游通过筛选避免,这里不再额外修改 prompt
|
||||
|
||||
# --- 动态添加分割指令 ---
|
||||
if global_config.response_splitter.enable and global_config.response_splitter.split_mode == "llm":
|
||||
split_instruction = """
|
||||
@@ -1137,10 +1289,9 @@ class DefaultReplyer:
|
||||
reply_to: str,
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
await self._async_init()
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
is_group_chat = self.is_group_chat
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
if reply_message:
|
||||
sender = reply_message.get("sender")
|
||||
@@ -1168,17 +1319,17 @@ class DefaultReplyer:
|
||||
else:
|
||||
mood_prompt = ""
|
||||
|
||||
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
)
|
||||
chat_talking_prompt_half = await build_readable_messages(
|
||||
chat_talking_prompt_half = build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=read_mark,
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
@@ -1370,16 +1521,57 @@ class DefaultReplyer:
|
||||
if not global_config.relationship.enable_relationship:
|
||||
return ""
|
||||
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(self.chat_stream.stream_id)
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = await person_info_manager.get_person_id_by_person_name(sender)
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if not person_id:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
# 使用AFC关系追踪器获取关系信息
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
|
||||
# 创建关系追踪器实例
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system)
|
||||
if relationship_tracker:
|
||||
# 获取用户信息以获取真实的user_id
|
||||
user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"])
|
||||
user_id = user_info.get("user_id", "unknown")
|
||||
|
||||
# 从数据库获取关系数据
|
||||
relationship_data = relationship_tracker._get_user_relationship_from_db(user_id)
|
||||
if relationship_data:
|
||||
relationship_text = relationship_data.get("relationship_text", "")
|
||||
relationship_score = relationship_data.get("relationship_score", 0.3)
|
||||
|
||||
# 构建丰富的关系信息描述
|
||||
if relationship_text:
|
||||
# 转换关系分数为描述性文本
|
||||
if relationship_score >= 0.8:
|
||||
relationship_level = "非常亲密的朋友"
|
||||
elif relationship_score >= 0.6:
|
||||
relationship_level = "好朋友"
|
||||
elif relationship_score >= 0.4:
|
||||
relationship_level = "普通朋友"
|
||||
elif relationship_score >= 0.2:
|
||||
relationship_level = "认识的人"
|
||||
else:
|
||||
relationship_level = "陌生人"
|
||||
|
||||
return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}"
|
||||
else:
|
||||
return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。"
|
||||
else:
|
||||
return f"你完全不认识{sender},这是第一次互动。"
|
||||
else:
|
||||
logger.warning("AFC关系追踪器未初始化,使用默认关系信息")
|
||||
return f"你与{sender}是普通朋友关系。"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取AFC关系信息失败: {e}")
|
||||
return f"你与{sender}是普通朋友关系。"
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
|
||||
@@ -37,7 +37,7 @@ def replace_user_references_sync(
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
|
||||
if name_resolver is None:
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
@@ -46,8 +46,8 @@ def replace_user_references_sync(
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
return person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
|
||||
|
||||
return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore
|
||||
|
||||
name_resolver = default_resolver
|
||||
|
||||
# 处理回复<aaa:bbb>格式
|
||||
@@ -121,8 +121,7 @@ async def replace_user_references_async(
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
||||
return person_info.get("person_name") or user_id
|
||||
return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
|
||||
|
||||
name_resolver = default_resolver
|
||||
|
||||
@@ -170,7 +169,7 @@ async def replace_user_references_async(
|
||||
return content
|
||||
|
||||
|
||||
async def get_raw_msg_by_timestamp(
|
||||
def get_raw_msg_by_timestamp(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -181,10 +180,10 @@ async def get_raw_msg_by_timestamp(
|
||||
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}}
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
async def get_raw_msg_by_timestamp_with_chat(
|
||||
def get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
@@ -201,7 +200,7 @@ async def get_raw_msg_by_timestamp_with_chat(
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
# 直接将 limit_mode 传递给 find_messages
|
||||
return await find_messages(
|
||||
return find_messages(
|
||||
message_filter=filter_query,
|
||||
sort=sort_order,
|
||||
limit=limit,
|
||||
@@ -211,7 +210,7 @@ async def get_raw_msg_by_timestamp_with_chat(
|
||||
)
|
||||
|
||||
|
||||
async def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
@@ -228,12 +227,12 @@ async def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
# 直接将 limit_mode 传递给 find_messages
|
||||
|
||||
return await find_messages(
|
||||
return find_messages(
|
||||
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
|
||||
)
|
||||
|
||||
|
||||
async def get_raw_msg_by_timestamp_with_chat_users(
|
||||
def get_raw_msg_by_timestamp_with_chat_users(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
@@ -252,10 +251,10 @@ async def get_raw_msg_by_timestamp_with_chat_users(
|
||||
}
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
async def get_actions_by_timestamp_with_chat(
|
||||
def get_actions_by_timestamp_with_chat(
|
||||
chat_id: str,
|
||||
timestamp_start: float = 0,
|
||||
timestamp_end: float = time.time(),
|
||||
@@ -274,10 +273,10 @@ async def get_actions_by_timestamp_with_chat(
|
||||
f"limit={limit}, limit_mode={limit_mode}"
|
||||
)
|
||||
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = await session.execute(
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -307,7 +306,7 @@ async def get_actions_by_timestamp_with_chat(
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else: # earliest
|
||||
query = await session.execute(
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -337,7 +336,7 @@ async def get_actions_by_timestamp_with_chat(
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else:
|
||||
query = await session.execute(
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -368,14 +367,14 @@ async def get_actions_by_timestamp_with_chat(
|
||||
return actions_result
|
||||
|
||||
|
||||
async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = await session.execute(
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -390,7 +389,7 @@ async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = await session.execute(
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -403,7 +402,7 @@ async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
.limit(limit)
|
||||
)
|
||||
else:
|
||||
query = await session.execute(
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -419,14 +418,14 @@ async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
return [action.__dict__ for action in actions]
|
||||
|
||||
|
||||
async def get_raw_msg_by_timestamp_random(
|
||||
def get_raw_msg_by_timestamp_random(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息
|
||||
"""
|
||||
# 获取所有消息,只取chat_id字段
|
||||
all_msgs = await get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
|
||||
all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
|
||||
if not all_msgs:
|
||||
return []
|
||||
# 随机选一条
|
||||
@@ -434,10 +433,10 @@ async def get_raw_msg_by_timestamp_random(
|
||||
chat_id = msg["chat_id"]
|
||||
timestamp_start = msg["time"]
|
||||
# 用 chat_id 获取该聊天在指定时间戳范围内的消息
|
||||
return await get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
|
||||
return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
|
||||
|
||||
|
||||
async def get_raw_msg_by_timestamp_with_users(
|
||||
def get_raw_msg_by_timestamp_with_users(
|
||||
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
@@ -447,39 +446,37 @@ async def get_raw_msg_by_timestamp_with_users(
|
||||
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}}
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
filter_query = {"time": {"$lt": timestamp}}
|
||||
sort_order = [("time", 1)]
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
||||
sort_order = [("time", 1)]
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
async def get_raw_msg_before_timestamp_with_users(
|
||||
timestamp: float, person_ids: list, limit: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}}
|
||||
sort_order = [("time", 1)]
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
|
||||
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
|
||||
"""
|
||||
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
||||
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
||||
@@ -493,10 +490,10 @@ async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, tim
|
||||
return 0 # 起始时间大于等于结束时间,没有新消息
|
||||
|
||||
filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}}
|
||||
return await count_messages(message_filter=filter_query)
|
||||
return count_messages(message_filter=filter_query)
|
||||
|
||||
|
||||
async def num_new_messages_since_with_users(
|
||||
def num_new_messages_since_with_users(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
|
||||
) -> int:
|
||||
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
|
||||
@@ -507,10 +504,10 @@ async def num_new_messages_since_with_users(
|
||||
"time": {"$gt": timestamp_start, "$lt": timestamp_end},
|
||||
"user_id": {"$in": person_ids},
|
||||
}
|
||||
return await count_messages(message_filter=filter_query)
|
||||
return count_messages(message_filter=filter_query)
|
||||
|
||||
|
||||
async def _build_readable_messages_internal(
|
||||
def _build_readable_messages_internal(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -520,7 +517,6 @@ async def _build_readable_messages_internal(
|
||||
pic_counter: int = 1,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
read_mark: float = 0.0,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||
"""
|
||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||
@@ -631,8 +627,7 @@ async def _build_readable_messages_internal(
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
||||
person_name = person_info.get("person_name") # type: ignore
|
||||
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
|
||||
|
||||
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
||||
if not person_name:
|
||||
@@ -731,10 +726,11 @@ async def _build_readable_messages_internal(
|
||||
"is_action": is_action,
|
||||
}
|
||||
continue
|
||||
|
||||
# 如果是同一个人发送的连续消息且时间间隔小于等于60秒
|
||||
if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60):
|
||||
current_merge["content"].append(content)
|
||||
current_merge["end_time"] = timestamp
|
||||
current_merge["end_time"] = timestamp # 更新最后消息时间
|
||||
else:
|
||||
# 保存上一个合并块
|
||||
merged_messages.append(current_merge)
|
||||
@@ -762,14 +758,8 @@ async def _build_readable_messages_internal(
|
||||
|
||||
# 4 & 5: 格式化为字符串
|
||||
output_lines = []
|
||||
read_mark_inserted = False
|
||||
|
||||
for _i, merged in enumerate(merged_messages):
|
||||
# 检查是否需要插入已读标记
|
||||
if read_mark > 0 and not read_mark_inserted and merged["start_time"] >= read_mark:
|
||||
output_lines.append("\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n")
|
||||
read_mark_inserted = True
|
||||
|
||||
# 使用指定的 timestamp_mode 格式化时间
|
||||
readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
|
||||
|
||||
@@ -810,7 +800,7 @@ async def _build_readable_messages_internal(
|
||||
)
|
||||
|
||||
|
||||
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""
|
||||
构建图片映射信息字符串,显示图片的具体描述内容
|
||||
@@ -833,8 +823,8 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# 从数据库中获取图片描述
|
||||
description = "[图片内容未知]" # 默认描述
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
image = (await session.execute(select(Images).where(Images.image_id == pic_id))).scalar_one_or_none()
|
||||
with get_db_session() as session:
|
||||
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
|
||||
if image and image.description: # type: ignore
|
||||
description = image.description
|
||||
except Exception:
|
||||
@@ -931,17 +921,17 @@ async def build_readable_messages_with_list(
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
"""
|
||||
formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
|
||||
if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping):
|
||||
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
|
||||
return formatted_string, details_list
|
||||
|
||||
|
||||
async def build_readable_messages_with_id(
|
||||
def build_readable_messages_with_id(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -957,7 +947,7 @@ async def build_readable_messages_with_id(
|
||||
"""
|
||||
message_id_list = assign_message_ids(messages)
|
||||
|
||||
formatted_string = await build_readable_messages(
|
||||
formatted_string = build_readable_messages(
|
||||
messages=messages,
|
||||
replace_bot_name=replace_bot_name,
|
||||
merge_messages=merge_messages,
|
||||
@@ -972,7 +962,7 @@ async def build_readable_messages_with_id(
|
||||
return formatted_string, message_id_list
|
||||
|
||||
|
||||
async def build_readable_messages(
|
||||
def build_readable_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -1013,28 +1003,24 @@ async def build_readable_messages(
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions_in_range = (
|
||||
await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
|
||||
)
|
||||
actions_in_range = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
).scalars()
|
||||
|
||||
# 获取最新消息之后的第一个动作记录
|
||||
action_after_latest = (
|
||||
await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
)
|
||||
action_after_latest = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
).scalars()
|
||||
|
||||
# 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError
|
||||
@@ -1066,7 +1052,7 @@ async def build_readable_messages(
|
||||
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
copy_messages,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1077,7 +1063,7 @@ async def build_readable_messages(
|
||||
)
|
||||
|
||||
# 生成图片映射信息并添加到最前面
|
||||
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
|
||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||
if pic_mapping_info:
|
||||
return f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
else:
|
||||
@@ -1092,7 +1078,7 @@ async def build_readable_messages(
|
||||
pic_counter = 1
|
||||
|
||||
# 分别格式化,但使用共享的图片映射
|
||||
formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal(
|
||||
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
|
||||
messages_before_mark,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1103,7 +1089,7 @@ async def build_readable_messages(
|
||||
show_pic=show_pic,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
messages_after_mark,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1119,7 +1105,7 @@ async def build_readable_messages(
|
||||
|
||||
# 生成图片映射信息
|
||||
if pic_id_mapping:
|
||||
pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
||||
pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
||||
else:
|
||||
pic_mapping_info = "聊天记录信息:\n"
|
||||
|
||||
@@ -1242,7 +1228,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
|
||||
# 在最前面添加图片映射信息
|
||||
final_output_lines = []
|
||||
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
|
||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||
if pic_mapping_info:
|
||||
final_output_lines.append(pic_mapping_info)
|
||||
final_output_lines.append("\n\n")
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = get_logger("unified_prompt")
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
@@ -34,7 +34,7 @@ class PromptParameters:
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
@@ -42,20 +42,20 @@ class PromptParameters:
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: Optional[Dict[str, Any]] = None
|
||||
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
@@ -63,7 +63,7 @@ class PromptParameters:
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
@@ -75,11 +75,13 @@ class PromptParameters:
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: Optional[Dict[str, Any]] = None
|
||||
read_mark: float = 0.0
|
||||
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
@@ -94,22 +96,22 @@ class PromptParameters:
|
||||
|
||||
class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||
self._context_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@property
|
||||
def _current_context(self) -> Optional[str]:
|
||||
"""获取当前协程的上下文ID"""
|
||||
return self._current_context_var.get()
|
||||
|
||||
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
@@ -124,13 +126,13 @@ class PromptContext:
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"获取上下文锁超时,context_id: {context_id}")
|
||||
context_id = None
|
||||
|
||||
|
||||
previous_context = self._current_context
|
||||
token = self._current_context_var.set(context_id) if context_id else None
|
||||
else:
|
||||
previous_context = self._current_context
|
||||
token = None
|
||||
|
||||
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
@@ -143,7 +145,7 @@ class PromptContext:
|
||||
self._current_context = previous_context
|
||||
except Exception:
|
||||
...
|
||||
|
||||
|
||||
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
||||
"""异步获取当前作用域中的提示模板"""
|
||||
async with self._context_lock:
|
||||
@@ -156,7 +158,7 @@ class PromptContext:
|
||||
):
|
||||
return self._context_prompts[current_context][name]
|
||||
return None
|
||||
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
@@ -167,59 +169,55 @@ class PromptContext:
|
||||
|
||||
class PromptManager:
|
||||
"""统一提示词管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._prompts = {}
|
||||
self._counter = 0
|
||||
self._context = PromptContext()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||
"""为消息处理创建异步临时作用域"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
|
||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
||||
"""异步获取提示模板"""
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||
return context_prompt
|
||||
|
||||
|
||||
async with self._lock:
|
||||
if name not in self._prompts:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
|
||||
def generate_name(self, template: str) -> str:
|
||||
"""为未命名的prompt生成名称"""
|
||||
self._counter += 1
|
||||
return f"prompt_{self._counter}"
|
||||
|
||||
|
||||
def register(self, prompt: "Prompt") -> None:
|
||||
"""注册一个prompt"""
|
||||
if not prompt.name:
|
||||
prompt.name = self.generate_name(prompt.template)
|
||||
self._prompts[prompt.name] = prompt
|
||||
|
||||
|
||||
def add_prompt(self, name: str, fstr: str) -> "Prompt":
|
||||
"""添加新提示模板"""
|
||||
prompt = Prompt(fstr, name=name)
|
||||
if prompt.name:
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
"""格式化提示模板"""
|
||||
prompt = await self.get_prompt_async(name)
|
||||
result = prompt.format(**kwargs)
|
||||
return result
|
||||
|
||||
@property
|
||||
def context(self):
|
||||
return self._context
|
||||
|
||||
|
||||
# 全局单例
|
||||
global_prompt_manager = PromptManager()
|
||||
@@ -230,21 +228,21 @@ class Prompt:
|
||||
统一提示词类 - 合并模板管理和智能构建功能
|
||||
真正的Prompt类,支持模板管理和智能上下文构建
|
||||
"""
|
||||
|
||||
|
||||
# 临时标记,作为类常量
|
||||
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
|
||||
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
should_register: bool = True
|
||||
should_register: bool = True,
|
||||
):
|
||||
"""
|
||||
初始化统一提示词
|
||||
|
||||
|
||||
Args:
|
||||
template: 提示词模板字符串
|
||||
name: 提示词名称
|
||||
@@ -256,14 +254,14 @@ class Prompt:
|
||||
self.parameters = parameters or PromptParameters()
|
||||
self.args = self._parse_template_args(template)
|
||||
self._formatted_result = ""
|
||||
|
||||
|
||||
# 预处理模板中的转义花括号
|
||||
self._processed_template = self._process_escaped_braces(template)
|
||||
|
||||
|
||||
# 自动注册
|
||||
if should_register and not global_prompt_manager.context._current_context:
|
||||
if should_register and not global_prompt_manager._context._current_context:
|
||||
global_prompt_manager.register(self)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _process_escaped_braces(template) -> str:
|
||||
"""处理模板中的转义花括号"""
|
||||
@@ -271,14 +269,14 @@ class Prompt:
|
||||
template = "\n".join(str(item) for item in template)
|
||||
elif not isinstance(template, str):
|
||||
template = str(template)
|
||||
|
||||
|
||||
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _restore_escaped_braces(template: str) -> str:
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
|
||||
def _parse_template_args(self, template: str) -> List[str]:
|
||||
"""解析模板参数"""
|
||||
template_args = []
|
||||
@@ -288,11 +286,11 @@ class Prompt:
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
return template_args
|
||||
|
||||
|
||||
async def build(self) -> str:
|
||||
"""
|
||||
构建完整的提示词,包含智能上下文
|
||||
|
||||
|
||||
Returns:
|
||||
str: 构建完成的提示词文本
|
||||
"""
|
||||
@@ -301,38 +299,38 @@ class Prompt:
|
||||
if errors:
|
||||
logger.error(f"参数验证失败: {', '.join(errors)}")
|
||||
raise ValueError(f"参数验证失败: {', '.join(errors)}")
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 构建上下文数据
|
||||
context_data = await self._build_context_data()
|
||||
|
||||
|
||||
# 格式化模板
|
||||
result = await self._format_with_context(context_data)
|
||||
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
||||
|
||||
|
||||
self._formatted_result = result
|
||||
return result
|
||||
|
||||
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.error(f"构建Prompt超时: {e}")
|
||||
raise TimeoutError(f"构建Prompt超时: {e}") from e
|
||||
except Exception as e:
|
||||
logger.error(f"构建Prompt失败: {e}")
|
||||
raise RuntimeError(f"构建Prompt失败: {e}") from e
|
||||
|
||||
|
||||
async def _build_context_data(self) -> Dict[str, Any]:
|
||||
"""构建智能上下文数据"""
|
||||
# 并行执行所有构建任务
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# 准备构建任务
|
||||
tasks = []
|
||||
task_names = []
|
||||
|
||||
|
||||
# 初始化预构建参数
|
||||
pre_built_params = {}
|
||||
if self.parameters.expression_habits_block:
|
||||
@@ -347,46 +345,46 @@ class Prompt:
|
||||
pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt
|
||||
if self.parameters.cross_context_block:
|
||||
pre_built_params["cross_context_block"] = self.parameters.cross_context_block
|
||||
|
||||
|
||||
# 根据参数确定要构建的项
|
||||
if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"):
|
||||
tasks.append(self._build_expression_habits())
|
||||
task_names.append("expression_habits")
|
||||
|
||||
|
||||
if self.parameters.enable_memory and not pre_built_params.get("memory_block"):
|
||||
tasks.append(self._build_memory_block())
|
||||
task_names.append("memory_block")
|
||||
|
||||
|
||||
if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"):
|
||||
tasks.append(self._build_relation_info())
|
||||
task_names.append("relation_info")
|
||||
|
||||
|
||||
if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"):
|
||||
tasks.append(self._build_tool_info())
|
||||
task_names.append("tool_info")
|
||||
|
||||
|
||||
if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
|
||||
tasks.append(self._build_knowledge_info())
|
||||
task_names.append("knowledge_info")
|
||||
|
||||
|
||||
if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"):
|
||||
tasks.append(self._build_cross_context())
|
||||
task_names.append("cross_context")
|
||||
|
||||
|
||||
# 性能优化
|
||||
base_timeout = 20.0
|
||||
base_timeout = 10.0
|
||||
task_timeout = 2.0
|
||||
timeout_seconds = min(
|
||||
max(base_timeout, len(tasks) * task_timeout),
|
||||
30.0,
|
||||
)
|
||||
|
||||
|
||||
max_concurrent_tasks = 5
|
||||
if len(tasks) > max_concurrent_tasks:
|
||||
results = []
|
||||
for i in range(0, len(tasks), max_concurrent_tasks):
|
||||
batch_tasks = tasks[i : i + max_concurrent_tasks]
|
||||
|
||||
|
||||
batch_results = await asyncio.wait_for(
|
||||
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
|
||||
)
|
||||
@@ -395,225 +393,181 @@ class Prompt:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
|
||||
)
|
||||
|
||||
|
||||
# 处理结果
|
||||
context_data = {}
|
||||
for i, result in enumerate(results):
|
||||
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
|
||||
|
||||
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"构建任务{task_name}失败: {str(result)}")
|
||||
elif isinstance(result, dict):
|
||||
context_data.update(result)
|
||||
|
||||
|
||||
# 添加预构建的参数
|
||||
for key, value in pre_built_params.items():
|
||||
if value:
|
||||
context_data[key] = value
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"构建超时 ({timeout_seconds}s)")
|
||||
context_data = {}
|
||||
for key, value in pre_built_params.items():
|
||||
if value:
|
||||
context_data[key] = value
|
||||
|
||||
|
||||
# 构建聊天历史
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
await self._build_s4u_chat_context(context_data)
|
||||
else:
|
||||
await self._build_normal_chat_context(context_data)
|
||||
|
||||
|
||||
# 补充基础信息
|
||||
context_data.update({
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
|
||||
"extra_info_block": self.parameters.extra_info_block,
|
||||
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"identity": self.parameters.identity_block,
|
||||
"schedule_block": self.parameters.schedule_block,
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block,
|
||||
"reply_target_block": self.parameters.reply_target_block,
|
||||
"mood_state": self.parameters.mood_prompt,
|
||||
"action_descriptions": self.parameters.action_descriptions,
|
||||
})
|
||||
|
||||
context_data.update(
|
||||
{
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
|
||||
"extra_info_block": self.parameters.extra_info_block,
|
||||
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"identity": self.parameters.identity_block,
|
||||
"schedule_block": self.parameters.schedule_block,
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block,
|
||||
"reply_target_block": self.parameters.reply_target_block,
|
||||
"mood_state": self.parameters.mood_prompt,
|
||||
"action_descriptions": self.parameters.action_descriptions,
|
||||
}
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
|
||||
|
||||
|
||||
return context_data
|
||||
|
||||
|
||||
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
"""构建S4U模式的聊天上下文"""
|
||||
if not self.parameters.message_list_before_now_long:
|
||||
return
|
||||
|
||||
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
|
||||
|
||||
read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts(
|
||||
self.parameters.message_list_before_now_long,
|
||||
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
|
||||
self.parameters.sender,
|
||||
read_mark=self.parameters.read_mark,
|
||||
self.parameters.chat_id,
|
||||
)
|
||||
|
||||
context_data["core_dialogue_prompt"] = core_dialogue
|
||||
context_data["background_dialogue_prompt"] = background_dialogue
|
||||
|
||||
|
||||
context_data["read_history_prompt"] = read_history_prompt
|
||||
context_data["unread_history_prompt"] = unread_history_prompt
|
||||
|
||||
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
"""构建normal模式的聊天上下文"""
|
||||
if not self.parameters.chat_talking_prompt_short:
|
||||
return
|
||||
|
||||
|
||||
context_data["chat_info"] = f"""群里的聊天内容:
|
||||
{self.parameters.chat_talking_prompt_short}"""
|
||||
|
||||
@staticmethod
|
||||
|
||||
async def _build_s4u_chat_history_prompts(
|
||||
message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, read_mark: float = 0.0
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> Tuple[str, str]:
|
||||
"""构建S4U风格的分离对话prompt"""
|
||||
# 实现逻辑与原有SmartPromptBuilder相同
|
||||
core_dialogue_list = []
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
|
||||
for msg_dict in message_list_before_now:
|
||||
try:
|
||||
msg_user_id = str(msg_dict.get("user_id"))
|
||||
reply_to = msg_dict.get("reply_to", "")
|
||||
platform, reply_to_user_id = Prompt.parse_reply_target(reply_to)
|
||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||
core_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
|
||||
|
||||
# 构建背景对话 prompt
|
||||
all_dialogue_prompt = ""
|
||||
if message_list_before_now:
|
||||
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
|
||||
all_dialogue_prompt_str = await build_readable_messages(
|
||||
latest_25_msgs,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal",
|
||||
truncate=True,
|
||||
read_mark=read_mark,
|
||||
"""构建S4U风格的已读/未读历史消息prompt"""
|
||||
try:
|
||||
# 动态导入default_generator以避免循环导入
|
||||
from src.plugin_system.apis.generator_api import get_replyer
|
||||
|
||||
# 创建临时生成器实例来使用其方法
|
||||
temp_generator = get_replyer(None, chat_id, request_type="prompt_building")
|
||||
return await temp_generator.build_s4u_chat_history_prompts(
|
||||
message_list_before_now, target_user_id, sender, chat_id
|
||||
)
|
||||
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
|
||||
|
||||
# 构建核心对话 prompt
|
||||
core_dialogue_prompt = ""
|
||||
if core_dialogue_list:
|
||||
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
||||
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
||||
|
||||
if not has_bot_message:
|
||||
core_dialogue_prompt = ""
|
||||
else:
|
||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :]
|
||||
|
||||
core_dialogue_prompt_str = await build_readable_messages(
|
||||
core_dialogue_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=read_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
core_dialogue_prompt = f"""--------------------------------
|
||||
这是你和{sender}的对话,你们正在交流中:
|
||||
{core_dialogue_prompt_str}
|
||||
--------------------------------
|
||||
"""
|
||||
|
||||
return core_dialogue_prompt, all_dialogue_prompt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||
|
||||
async def _build_expression_habits(self) -> Dict[str, Any]:
|
||||
"""构建表达习惯"""
|
||||
if not global_config.expression.enable_expression:
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id)
|
||||
if not use_expression:
|
||||
return {"expression_habits_block": ""}
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.express.expression_selector import ExpressionSelector
|
||||
|
||||
|
||||
# 获取聊天历史用于表情选择
|
||||
chat_history = ""
|
||||
if self.parameters.message_list_before_now_long:
|
||||
recent_messages = self.parameters.message_list_before_now_long[-10:]
|
||||
chat_history = await build_readable_messages(
|
||||
recent_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal",
|
||||
truncate=True
|
||||
chat_history = build_readable_messages(
|
||||
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||
)
|
||||
|
||||
|
||||
# 创建表情选择器
|
||||
expression_selector = ExpressionSelector()
|
||||
|
||||
expression_selector = ExpressionSelector(self.parameters.chat_id)
|
||||
|
||||
# 选择合适的表情
|
||||
selected_expressions = await expression_selector.select_suitable_expressions_llm(
|
||||
chat_history=chat_history,
|
||||
current_message=self.parameters.target,
|
||||
emotional_tone="neutral",
|
||||
topic_type="general",
|
||||
)
|
||||
|
||||
|
||||
# 构建表达习惯块
|
||||
if selected_expressions:
|
||||
style_habits_str = "\n".join([f"- {expr}" for expr in selected_expressions])
|
||||
expression_habits_block = f"- 你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"
|
||||
else:
|
||||
expression_habits_block = ""
|
||||
|
||||
|
||||
return {"expression_habits_block": expression_habits_block}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建表达习惯失败: {e}")
|
||||
return {"expression_habits_block": ""}
|
||||
|
||||
|
||||
async def _build_memory_block(self) -> Dict[str, Any]:
|
||||
"""构建记忆块"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return {"memory_block": ""}
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
|
||||
|
||||
|
||||
# 获取聊天历史
|
||||
chat_history = ""
|
||||
if self.parameters.message_list_before_now_long:
|
||||
recent_messages = self.parameters.message_list_before_now_long[-20:]
|
||||
chat_history = await build_readable_messages(
|
||||
recent_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal",
|
||||
truncate=True
|
||||
chat_history = build_readable_messages(
|
||||
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||
)
|
||||
|
||||
|
||||
# 激活长期记忆
|
||||
memory_activator = MemoryActivator()
|
||||
running_memories = await memory_activator.activate_memory_with_chat_history(
|
||||
target_message=self.parameters.target,
|
||||
chat_history_prompt=chat_history
|
||||
target_message=self.parameters.target, chat_history_prompt=chat_history
|
||||
)
|
||||
|
||||
|
||||
# 获取即时记忆
|
||||
async_memory_wrapper = get_async_instant_memory(self.parameters.chat_id)
|
||||
instant_memory = await async_memory_wrapper.get_memory_with_fallback(self.parameters.target)
|
||||
|
||||
|
||||
# 构建记忆块
|
||||
memory_parts = []
|
||||
|
||||
|
||||
if running_memories:
|
||||
memory_parts.append("以下是当前在聊天中,你回忆起的记忆:")
|
||||
for memory in running_memories:
|
||||
memory_parts.append(f"- {memory['content']}")
|
||||
|
||||
|
||||
if instant_memory:
|
||||
memory_parts.append(f"- {instant_memory}")
|
||||
|
||||
|
||||
memory_block = "\n".join(memory_parts) if memory_parts else ""
|
||||
|
||||
|
||||
return {"memory_block": memory_block}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建记忆块失败: {e}")
|
||||
return {"memory_block": ""}
|
||||
|
||||
|
||||
async def _build_relation_info(self) -> Dict[str, Any]:
|
||||
"""构建关系信息"""
|
||||
try:
|
||||
@@ -622,106 +576,104 @@ class Prompt:
|
||||
except Exception as e:
|
||||
logger.error(f"构建关系信息失败: {e}")
|
||||
return {"relation_info_block": ""}
|
||||
|
||||
|
||||
async def _build_tool_info(self) -> Dict[str, Any]:
|
||||
"""构建工具信息"""
|
||||
if not global_config.tool.enable_tool:
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
|
||||
try:
|
||||
from src.plugin_system.core.tool_use import ToolExecutor
|
||||
|
||||
|
||||
# 获取聊天历史
|
||||
chat_history = ""
|
||||
if self.parameters.message_list_before_now_long:
|
||||
recent_messages = self.parameters.message_list_before_now_long[-15:]
|
||||
chat_history = await build_readable_messages(
|
||||
recent_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal",
|
||||
truncate=True
|
||||
chat_history = build_readable_messages(
|
||||
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||
)
|
||||
|
||||
|
||||
# 创建工具执行器
|
||||
tool_executor = ToolExecutor(chat_id=self.parameters.chat_id)
|
||||
|
||||
|
||||
# 执行工具获取信息
|
||||
tool_results, _, _ = await tool_executor.execute_from_chat_message(
|
||||
sender=self.parameters.sender,
|
||||
target_message=self.parameters.target,
|
||||
chat_history=chat_history,
|
||||
return_details=False
|
||||
return_details=False,
|
||||
)
|
||||
|
||||
|
||||
# 构建工具信息块
|
||||
if tool_results:
|
||||
tool_info_parts = ["## 工具信息","以下是你通过工具获取到的实时信息:"]
|
||||
tool_info_parts = ["## 工具信息", "以下是你通过工具获取到的实时信息:"]
|
||||
for tool_result in tool_results:
|
||||
tool_name = tool_result.get("tool_name", "unknown")
|
||||
content = tool_result.get("content", "")
|
||||
result_type = tool_result.get("type", "tool_result")
|
||||
|
||||
|
||||
tool_info_parts.append(f"- 【{tool_name}】{result_type}: {content}")
|
||||
|
||||
|
||||
tool_info_parts.append("以上是你获取到的实时信息,请在回复时参考这些信息。")
|
||||
tool_info_block = "\n".join(tool_info_parts)
|
||||
else:
|
||||
tool_info_block = ""
|
||||
|
||||
|
||||
return {"tool_info_block": tool_info_block}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建工具信息失败: {e}")
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
|
||||
async def _build_knowledge_info(self) -> Dict[str, Any]:
|
||||
"""构建知识信息"""
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
|
||||
from src.chat.knowledge.knowledge_lib import QAManager
|
||||
|
||||
# 获取问题文本(当前消息)
|
||||
question = self.parameters.target or ""
|
||||
if not question:
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
# 检查QA管理器是否已成功初始化
|
||||
if not qa_manager:
|
||||
logger.warning("QA管理器未初始化 (可能lpmm_knowledge被禁用),跳过知识库搜索。")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
|
||||
# 创建QA管理器
|
||||
qa_manager = QAManager()
|
||||
|
||||
# 搜索相关知识
|
||||
knowledge_results = await qa_manager.get_knowledge(
|
||||
question=question
|
||||
question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5
|
||||
)
|
||||
|
||||
|
||||
# 构建知识块
|
||||
if knowledge_results and knowledge_results.get("knowledge_items"):
|
||||
knowledge_parts = ["## 知识库信息","以下是与你当前对话相关的知识信息:"]
|
||||
|
||||
knowledge_parts = ["## 知识库信息", "以下是与你当前对话相关的知识信息:"]
|
||||
|
||||
for item in knowledge_results["knowledge_items"]:
|
||||
content = item.get("content", "")
|
||||
source = item.get("source", "")
|
||||
relevance = item.get("relevance", 0.0)
|
||||
|
||||
|
||||
if content:
|
||||
knowledge_parts.append(f"- [相关度: {relevance}] {content}")
|
||||
|
||||
if summary := knowledge_results.get("summary"):
|
||||
knowledge_parts.append(f"\n知识总结: {summary}")
|
||||
|
||||
if source:
|
||||
knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})")
|
||||
else:
|
||||
knowledge_parts.append(f"- [{relevance:.2f}] {content}")
|
||||
|
||||
if knowledge_results.get("summary"):
|
||||
knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}")
|
||||
|
||||
knowledge_prompt = "\n".join(knowledge_parts)
|
||||
else:
|
||||
knowledge_prompt = ""
|
||||
|
||||
|
||||
return {"knowledge_prompt": knowledge_prompt}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建知识信息失败: {e}")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
|
||||
async def _build_cross_context(self) -> Dict[str, Any]:
|
||||
"""构建跨群上下文"""
|
||||
try:
|
||||
@@ -732,7 +684,7 @@ class Prompt:
|
||||
except Exception as e:
|
||||
logger.error(f"构建跨群上下文失败: {e}")
|
||||
return {"cross_context_block": ""}
|
||||
|
||||
|
||||
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
|
||||
"""使用上下文数据格式化模板"""
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
@@ -741,9 +693,9 @@ class Prompt:
|
||||
params = self._prepare_normal_params(context_data)
|
||||
else:
|
||||
params = self._prepare_default_params(context_data)
|
||||
|
||||
|
||||
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
|
||||
|
||||
|
||||
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""准备S4U模式的参数"""
|
||||
return {
|
||||
@@ -759,17 +711,19 @@ class Prompt:
|
||||
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
|
||||
"sender_name": self.parameters.sender or "未知用户",
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""),
|
||||
"read_history_prompt": context_data.get("read_history_prompt", ""),
|
||||
"unread_history_prompt": context_data.get("unread_history_prompt", ""),
|
||||
"time_block": context_data.get("time_block", ""),
|
||||
"core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""),
|
||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||
"reply_style": global_config.personality.reply_style,
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||
or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
|
||||
"chat_context_type": "群聊" if self.parameters.is_group_chat else "私聊",
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block
|
||||
or context_data.get("safety_guidelines_block", ""),
|
||||
"chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||
}
|
||||
|
||||
|
||||
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""准备Normal模式的参数"""
|
||||
return {
|
||||
@@ -789,11 +743,14 @@ class Prompt:
|
||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||
"config_expression_style": global_config.personality.reply_style,
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||
or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block
|
||||
or context_data.get("safety_guidelines_block", ""),
|
||||
"chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||
}
|
||||
|
||||
|
||||
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""准备默认模式的参数"""
|
||||
return {
|
||||
@@ -809,11 +766,13 @@ class Prompt:
|
||||
"reason": "",
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"reply_style": global_config.personality.reply_style,
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||
or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block
|
||||
or context_data.get("safety_guidelines_block", ""),
|
||||
}
|
||||
|
||||
|
||||
def format(self, *args, **kwargs) -> str:
|
||||
"""格式化模板,支持位置参数和关键字参数"""
|
||||
try:
|
||||
@@ -826,21 +785,21 @@ class Prompt:
|
||||
processed_template = self._processed_template.format(**formatted_args)
|
||||
else:
|
||||
processed_template = self._processed_template
|
||||
|
||||
|
||||
# 再用关键字参数格式化
|
||||
if kwargs:
|
||||
processed_template = processed_template.format(**kwargs)
|
||||
|
||||
|
||||
# 将临时标记还原为实际的花括号
|
||||
result = self._restore_escaped_braces(processed_template)
|
||||
return result
|
||||
except (IndexError, KeyError) as e:
|
||||
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""返回格式化后的结果或原始模板"""
|
||||
return self._formatted_result if self._formatted_result else self.template
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回提示词的表示形式"""
|
||||
return f"Prompt(template='{self.template}', name='{self.name}')"
|
||||
@@ -912,9 +871,7 @@ class Prompt:
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(
|
||||
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
|
||||
) -> str:
|
||||
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现
|
||||
|
||||
@@ -930,7 +887,7 @@ class Prompt:
|
||||
return ""
|
||||
|
||||
from src.plugin_system.apis import cross_context_api
|
||||
|
||||
|
||||
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
||||
if not other_chat_raw_ids:
|
||||
return ""
|
||||
@@ -969,7 +926,7 @@ class Prompt:
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
user_id = person_info_manager.get_value(person_id, "user_id")
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
return str(user_id) if user_id else ""
|
||||
|
||||
return ""
|
||||
@@ -977,10 +934,7 @@ class Prompt:
|
||||
|
||||
# 工厂函数
|
||||
def create_prompt(
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
**kwargs
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""快速创建Prompt实例的工厂函数"""
|
||||
if parameters is None:
|
||||
@@ -989,14 +943,10 @@ def create_prompt(
|
||||
|
||||
|
||||
async def create_prompt_async(
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
**kwargs
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
if global_prompt_manager.context._current_context:
|
||||
await global_prompt_manager.context.register_async(prompt)
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
return prompt
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import numpy as np
|
||||
|
||||
from collections import Counter
|
||||
from maim_message import UserInfo
|
||||
from typing import Optional, Tuple, Dict, List, Any, Coroutine
|
||||
from typing import Optional, Tuple, Dict, List, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
@@ -332,17 +332,17 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
|
||||
if global_config.response_splitter.enable and enable_splitter:
|
||||
logger.info(f"回复分割器已启用,模式: {global_config.response_splitter.split_mode}。")
|
||||
|
||||
|
||||
split_mode = global_config.response_splitter.split_mode
|
||||
|
||||
|
||||
if split_mode == "llm" and "[SPLIT]" in cleaned_text:
|
||||
logger.debug("检测到 [SPLIT] 标记,使用 LLM 自定义分割。")
|
||||
split_sentences_raw = cleaned_text.split("[SPLIT]")
|
||||
split_sentences = [s.strip() for s in split_sentences_raw if s.strip()]
|
||||
else:
|
||||
if split_mode == "llm":
|
||||
logger.debug("未检测到 [SPLIT] 标记,回退到基于标点的传统模式进行分割。")
|
||||
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
||||
logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。")
|
||||
split_sentences = [cleaned_text]
|
||||
else: # mode == "punctuation"
|
||||
logger.debug("使用基于标点的传统模式进行分割。")
|
||||
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
||||
@@ -352,6 +352,8 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
|
||||
sentences = []
|
||||
for sentence in split_sentences:
|
||||
# 清除开头可能存在的空行
|
||||
sentence = sentence.lstrip("\n").rstrip()
|
||||
if global_config.chinese_typo.enable and enable_chinese_typo:
|
||||
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
|
||||
sentences.append(typoed_text)
|
||||
@@ -540,8 +542,7 @@ def get_western_ratio(paragraph):
|
||||
return western_count / len(alnum_chars)
|
||||
|
||||
|
||||
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int] | tuple[
|
||||
Coroutine[Any, Any, int], int]:
|
||||
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]:
|
||||
"""计算两个时间点之间的消息数量和文本总长度
|
||||
|
||||
Args:
|
||||
@@ -619,7 +620,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||
|
||||
|
||||
async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
"""
|
||||
获取聊天类型(是否群聊)和私聊对象信息。
|
||||
|
||||
@@ -663,8 +664,7 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di
|
||||
if person_id:
|
||||
# get_value is async, so await it directly
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_data = await person_info_manager.get_values(person_id, ["person_name"])
|
||||
person_name = person_data.get("person_name")
|
||||
person_name = person_info_manager.get_value_sync(person_id, "person_name")
|
||||
|
||||
target_info["person_id"] = person_id
|
||||
target_info["person_name"] = person_name
|
||||
@@ -695,25 +695,9 @@ def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
result = []
|
||||
used_ids = set()
|
||||
len_i = len(messages)
|
||||
if len_i > 100:
|
||||
a = 10
|
||||
b = 99
|
||||
else:
|
||||
a = 1
|
||||
b = 9
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
# 生成唯一的简短ID
|
||||
while True:
|
||||
# 使用索引+随机数生成简短ID
|
||||
random_suffix = random.randint(a, b)
|
||||
message_id = f"m{i + 1}{random_suffix}"
|
||||
|
||||
if message_id not in used_ids:
|
||||
used_ids.add(message_id)
|
||||
break
|
||||
|
||||
# 使用简单的索引作为ID
|
||||
message_id = f"m{i + 1}"
|
||||
result.append({"id": message_id, "message": message})
|
||||
|
||||
return result
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@ class BaseDataModel:
|
||||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
||||
# sourcery skip: assign-if-exp, reintroduce-else
|
||||
"""
|
||||
|
||||
137
src/common/data_models/bot_interest_data_model.py
Normal file
137
src/common/data_models/bot_interest_data_model.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
机器人兴趣标签数据模型
|
||||
定义机器人的兴趣标签和相关的embedding数据结构
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterestTag(BaseDataModel):
|
||||
"""机器人兴趣标签"""
|
||||
|
||||
tag_name: str
|
||||
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
|
||||
embedding: Optional[List[float]] = None # 标签的embedding向量
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"tag_name": self.tag_name,
|
||||
"weight": self.weight,
|
||||
"embedding": self.embedding,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"is_active": self.is_active,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag":
|
||||
"""从字典创建对象"""
|
||||
return cls(
|
||||
tag_name=data["tag_name"],
|
||||
weight=data.get("weight", 1.0),
|
||||
embedding=data.get("embedding"),
|
||||
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
|
||||
is_active=data.get("is_active", True),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotPersonalityInterests(BaseDataModel):
|
||||
"""机器人人格化兴趣配置"""
|
||||
|
||||
personality_id: str
|
||||
personality_description: str # 人设描述文本
|
||||
interest_tags: List[BotInterestTag] = field(default_factory=list)
|
||||
embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型
|
||||
last_updated: datetime = field(default_factory=datetime.now)
|
||||
version: int = 1 # 版本号,用于追踪更新
|
||||
|
||||
def get_active_tags(self) -> List[BotInterestTag]:
|
||||
"""获取活跃的兴趣标签"""
|
||||
return [tag for tag in self.interest_tags if tag.is_active]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"personality_id": self.personality_id,
|
||||
"personality_description": self.personality_description,
|
||||
"interest_tags": [tag.to_dict() for tag in self.interest_tags],
|
||||
"embedding_model": self.embedding_model,
|
||||
"last_updated": self.last_updated.isoformat(),
|
||||
"version": self.version,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests":
|
||||
"""从字典创建对象"""
|
||||
return cls(
|
||||
personality_id=data["personality_id"],
|
||||
personality_description=data["personality_description"],
|
||||
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
|
||||
embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
|
||||
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
|
||||
version=data.get("version", 1),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterestMatchResult(BaseDataModel):
|
||||
"""兴趣匹配结果"""
|
||||
|
||||
message_id: str
|
||||
matched_tags: List[str] = field(default_factory=list)
|
||||
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
|
||||
overall_score: float = 0.0
|
||||
top_tag: Optional[str] = None
|
||||
confidence: float = 0.0 # 匹配置信度 (0.0-1.0)
|
||||
matched_keywords: List[str] = field(default_factory=list)
|
||||
|
||||
def add_match(self, tag_name: str, score: float, keywords: List[str] = None):
|
||||
"""添加匹配结果"""
|
||||
self.matched_tags.append(tag_name)
|
||||
self.match_scores[tag_name] = score
|
||||
if keywords:
|
||||
self.matched_keywords.extend(keywords)
|
||||
|
||||
def calculate_overall_score(self):
|
||||
"""计算总体匹配分数"""
|
||||
if not self.match_scores:
|
||||
self.overall_score = 0.0
|
||||
self.top_tag = None
|
||||
return
|
||||
|
||||
# 使用加权平均计算总体分数
|
||||
total_weight = len(self.match_scores)
|
||||
if total_weight > 0:
|
||||
self.overall_score = sum(self.match_scores.values()) / total_weight
|
||||
# 设置最佳匹配标签
|
||||
self.top_tag = max(self.match_scores.items(), key=lambda x: x[1])[0]
|
||||
else:
|
||||
self.overall_score = 0.0
|
||||
self.top_tag = None
|
||||
|
||||
# 计算置信度(基于匹配标签数量和分数分布)
|
||||
if len(self.match_scores) > 0:
|
||||
avg_score = self.overall_score
|
||||
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
|
||||
self.match_scores
|
||||
)
|
||||
# 分数越集中,置信度越高
|
||||
self.confidence = max(0.0, 1.0 - score_variance)
|
||||
else:
|
||||
self.confidence = 0.0
|
||||
|
||||
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
|
||||
"""获取前N个最佳匹配"""
|
||||
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_matches[:top_n]
|
||||
@@ -79,6 +79,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
is_command: bool = False,
|
||||
is_notify: bool = False,
|
||||
selected_expressions: Optional[str] = None,
|
||||
is_read: bool = False,
|
||||
user_id: str = "",
|
||||
user_nickname: str = "",
|
||||
user_cardname: Optional[str] = None,
|
||||
@@ -94,6 +95,9 @@ class DatabaseMessages(BaseDataModel):
|
||||
chat_info_platform: str = "",
|
||||
chat_info_create_time: float = 0.0,
|
||||
chat_info_last_active_time: float = 0.0,
|
||||
# 新增字段
|
||||
actions: Optional[list] = None,
|
||||
should_reply: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.message_id = message_id
|
||||
@@ -102,6 +106,10 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.reply_to = reply_to
|
||||
self.interest_value = interest_value
|
||||
|
||||
# 新增字段
|
||||
self.actions = actions
|
||||
self.should_reply = should_reply
|
||||
|
||||
self.key_words = key_words
|
||||
self.key_words_lite = key_words_lite
|
||||
self.is_mentioned = is_mentioned
|
||||
@@ -122,6 +130,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.is_notify = is_notify
|
||||
|
||||
self.selected_expressions = selected_expressions
|
||||
self.is_read = is_read
|
||||
|
||||
self.group_info: Optional[DatabaseGroupInfo] = None
|
||||
self.user_info = DatabaseUserInfo(
|
||||
@@ -188,6 +197,10 @@ class DatabaseMessages(BaseDataModel):
|
||||
"is_command": self.is_command,
|
||||
"is_notify": self.is_notify,
|
||||
"selected_expressions": self.selected_expressions,
|
||||
"is_read": self.is_read,
|
||||
# 新增字段
|
||||
"actions": self.actions,
|
||||
"should_reply": self.should_reply,
|
||||
"user_id": self.user_info.user_id,
|
||||
"user_nickname": self.user_info.user_nickname,
|
||||
"user_cardname": self.user_info.user_cardname,
|
||||
@@ -205,6 +218,61 @@ class DatabaseMessages(BaseDataModel):
|
||||
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||
}
|
||||
|
||||
def update_message_info(self, interest_value: float = None, actions: list = None, should_reply: bool = None):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
Args:
|
||||
interest_value: 兴趣度值
|
||||
actions: 执行的动作列表
|
||||
should_reply: 是否应该回复
|
||||
"""
|
||||
if interest_value is not None:
|
||||
self.interest_value = interest_value
|
||||
if actions is not None:
|
||||
self.actions = actions
|
||||
if should_reply is not None:
|
||||
self.should_reply = should_reply
|
||||
|
||||
def add_action(self, action: str):
|
||||
"""
|
||||
添加执行的动作到消息中
|
||||
|
||||
Args:
|
||||
action: 要添加的动作名称
|
||||
"""
|
||||
if self.actions is None:
|
||||
self.actions = []
|
||||
if action not in self.actions: # 避免重复添加
|
||||
self.actions.append(action)
|
||||
|
||||
def get_actions(self) -> list:
|
||||
"""
|
||||
获取执行的动作列表
|
||||
|
||||
Returns:
|
||||
动作列表,如果没有动作则返回空列表
|
||||
"""
|
||||
return self.actions or []
|
||||
|
||||
def get_message_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取消息摘要信息
|
||||
|
||||
Returns:
|
||||
包含关键字段的消息摘要
|
||||
"""
|
||||
return {
|
||||
"message_id": self.message_id,
|
||||
"time": self.time,
|
||||
"interest_value": self.interest_value,
|
||||
"actions": self.actions,
|
||||
"should_reply": self.should_reply,
|
||||
"user_nickname": self.user_info.user_nickname,
|
||||
"display_message": self.display_message,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class DatabaseActionRecords(BaseDataModel):
|
||||
def __init__(
|
||||
@@ -232,4 +300,4 @@ class DatabaseActionRecords(BaseDataModel):
|
||||
self.action_prompt_display = action_prompt_display
|
||||
self.chat_id = chat_id
|
||||
self.chat_info_stream_id = chat_info_stream_id
|
||||
self.chat_info_platform = chat_info_platform
|
||||
self.chat_info_platform = chat_info_platform
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, List, TYPE_CHECKING
|
||||
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from .database_data_model import DatabaseMessages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -21,23 +23,37 @@ class ActionPlannerInfo(BaseDataModel):
|
||||
action_type: str = field(default_factory=str)
|
||||
reasoning: Optional[str] = None
|
||||
action_data: Optional[Dict] = None
|
||||
action_message: Optional[Dict] = None
|
||||
action_message: Optional["DatabaseMessages"] = None
|
||||
available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterestScore(BaseDataModel):
|
||||
"""兴趣度评分结果"""
|
||||
|
||||
message_id: str
|
||||
total_score: float
|
||||
interest_match_score: float
|
||||
relationship_score: float
|
||||
mentioned_score: float
|
||||
details: Dict[str, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Plan(BaseDataModel):
|
||||
"""
|
||||
统一规划数据模型
|
||||
"""
|
||||
|
||||
chat_id: str
|
||||
mode: "ChatMode"
|
||||
|
||||
|
||||
chat_type: "ChatType"
|
||||
# Generator 填充
|
||||
available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict)
|
||||
chat_history: List["DatabaseMessages"] = field(default_factory=list)
|
||||
target_info: Optional[TargetPersonInfo] = None
|
||||
|
||||
|
||||
# Filter 填充
|
||||
llm_prompt: Optional[str] = None
|
||||
decided_actions: Optional[List[ActionPlannerInfo]] = None
|
||||
|
||||
@@ -6,6 +6,7 @@ from . import BaseDataModel
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMGenerationDataModel(BaseDataModel):
|
||||
content: Optional[str] = None
|
||||
@@ -14,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel):
|
||||
tool_calls: Optional[List["ToolCall"]] = None
|
||||
prompt: Optional[str] = None
|
||||
selected_expressions: Optional[List[int]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageAndActionModel(BaseDataModel):
|
||||
chat_id: str = field(default_factory=str)
|
||||
time: float = field(default_factory=float)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_platform: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
user_cardname: Optional[str] = None
|
||||
processed_plain_text: Optional[str] = None
|
||||
display_message: Optional[str] = None
|
||||
chat_info_platform: str = field(default_factory=str)
|
||||
is_action_record: bool = field(default=False)
|
||||
action_name: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
|
||||
return cls(
|
||||
chat_id=message.chat_id,
|
||||
time=message.time,
|
||||
user_id=message.user_info.user_id,
|
||||
user_platform=message.user_info.platform,
|
||||
user_nickname=message.user_info.user_nickname,
|
||||
user_cardname=message.user_info.user_cardname,
|
||||
processed_plain_text=message.processed_plain_text,
|
||||
display_message=message.display_message,
|
||||
chat_info_platform=message.chat_info.platform,
|
||||
)
|
||||
373
src/common/data_models/message_manager_data_model.py
Normal file
373
src/common/data_models/message_manager_data_model.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""
|
||||
消息管理模块数据模型
|
||||
定义消息管理器使用的数据结构
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
from . import BaseDataModel
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger("stream_context")
|
||||
|
||||
|
||||
class MessageStatus(Enum):
|
||||
"""消息状态枚举"""
|
||||
|
||||
UNREAD = "unread" # 未读消息
|
||||
READ = "read" # 已读消息
|
||||
PROCESSING = "processing" # 处理中
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamContext(BaseDataModel):
|
||||
"""聊天流上下文信息"""
|
||||
|
||||
stream_id: str
|
||||
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
|
||||
chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式
|
||||
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
history_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
last_check_time: float = field(default_factory=time.time)
|
||||
is_active: bool = True
|
||||
processing_task: Optional[asyncio.Task] = None
|
||||
interruption_count: int = 0 # 打断计数器
|
||||
last_interruption_time: float = 0.0 # 上次打断时间
|
||||
afc_threshold_adjustment: float = 0.0 # afc阈值调整量
|
||||
|
||||
# 独立分发周期字段
|
||||
next_check_time: float = field(default_factory=time.time) # 下次检查时间
|
||||
distribution_interval: float = 5.0 # 当前分发周期(秒)
|
||||
|
||||
# 新增字段以替代ChatMessageContext功能
|
||||
current_message: Optional["DatabaseMessages"] = None
|
||||
priority_mode: Optional[str] = None
|
||||
priority_info: Optional[dict] = None
|
||||
|
||||
def add_message(self, message: "DatabaseMessages"):
|
||||
"""添加消息到上下文"""
|
||||
message.is_read = False
|
||||
self.unread_messages.append(message)
|
||||
|
||||
# 自动检测和更新chat type
|
||||
self._detect_chat_type(message)
|
||||
|
||||
def update_message_info(
|
||||
self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None
|
||||
):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
interest_value: 兴趣度值
|
||||
actions: 执行的动作列表
|
||||
should_reply: 是否应该回复
|
||||
"""
|
||||
# 在未读消息中查找并更新
|
||||
for message in self.unread_messages:
|
||||
if message.message_id == message_id:
|
||||
message.update_message_info(interest_value, actions, should_reply)
|
||||
break
|
||||
|
||||
# 在历史消息中查找并更新
|
||||
for message in self.history_messages:
|
||||
if message.message_id == message_id:
|
||||
message.update_message_info(interest_value, actions, should_reply)
|
||||
break
|
||||
|
||||
def add_action_to_message(self, message_id: str, action: str):
|
||||
"""
|
||||
向指定消息添加执行的动作
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
action: 要添加的动作名称
|
||||
"""
|
||||
# 在未读消息中查找并更新
|
||||
for message in self.unread_messages:
|
||||
if message.message_id == message_id:
|
||||
message.add_action(action)
|
||||
break
|
||||
|
||||
# 在历史消息中查找并更新
|
||||
for message in self.history_messages:
|
||||
if message.message_id == message_id:
|
||||
message.add_action(action)
|
||||
break
|
||||
|
||||
def _detect_chat_type(self, message: "DatabaseMessages"):
|
||||
"""根据消息内容自动检测聊天类型"""
|
||||
# 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型
|
||||
if len(self.unread_messages) == 1: # 只有这条消息
|
||||
# 如果消息包含群组信息,则为群聊
|
||||
if hasattr(message, "chat_info_group_id") and message.chat_info_group_id:
|
||||
self.chat_type = ChatType.GROUP
|
||||
elif hasattr(message, "chat_info_group_name") and message.chat_info_group_name:
|
||||
self.chat_type = ChatType.GROUP
|
||||
else:
|
||||
self.chat_type = ChatType.PRIVATE
|
||||
|
||||
def update_chat_type(self, chat_type: ChatType):
|
||||
"""手动更新聊天类型"""
|
||||
self.chat_type = chat_type
|
||||
|
||||
def set_chat_mode(self, chat_mode: ChatMode):
|
||||
"""设置聊天模式"""
|
||||
self.chat_mode = chat_mode
|
||||
|
||||
def is_group_chat(self) -> bool:
|
||||
"""检查是否为群聊"""
|
||||
return self.chat_type == ChatType.GROUP
|
||||
|
||||
def is_private_chat(self) -> bool:
|
||||
"""检查是否为私聊"""
|
||||
return self.chat_type == ChatType.PRIVATE
|
||||
|
||||
def get_chat_type_display(self) -> str:
|
||||
"""获取聊天类型的显示名称"""
|
||||
if self.chat_type == ChatType.GROUP:
|
||||
return "群聊"
|
||||
elif self.chat_type == ChatType.PRIVATE:
|
||||
return "私聊"
|
||||
else:
|
||||
return "未知类型"
|
||||
|
||||
def mark_message_as_read(self, message_id: str):
|
||||
"""标记消息为已读"""
|
||||
for msg in self.unread_messages:
|
||||
if msg.message_id == message_id:
|
||||
msg.is_read = True
|
||||
self.history_messages.append(msg)
|
||||
self.unread_messages.remove(msg)
|
||||
break
|
||||
|
||||
def get_unread_messages(self) -> List["DatabaseMessages"]:
|
||||
"""获取未读消息"""
|
||||
return [msg for msg in self.unread_messages if not msg.is_read]
|
||||
|
||||
def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]:
|
||||
"""获取历史消息"""
|
||||
# 优先返回最近的历史消息和所有未读消息
|
||||
recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
|
||||
return recent_history
|
||||
|
||||
def calculate_interruption_probability(self, max_limit: int, probability_factor: float) -> float:
|
||||
"""计算打断概率"""
|
||||
if max_limit <= 0:
|
||||
return 0.0
|
||||
|
||||
# 计算打断比例
|
||||
interruption_ratio = self.interruption_count / max_limit
|
||||
|
||||
# 如果已达到或超过最大次数,完全禁止打断
|
||||
if self.interruption_count >= max_limit:
|
||||
return 0.0
|
||||
|
||||
# 如果超过概率因子,概率下降
|
||||
if interruption_ratio > probability_factor:
|
||||
# 使用指数衰减,超过限制越多,概率越低
|
||||
excess_ratio = interruption_ratio - probability_factor
|
||||
probability = 0.8 * (0.5**excess_ratio) # 基础概率0.8,指数衰减
|
||||
else:
|
||||
# 在限制内,保持较高概率
|
||||
probability = 0.8
|
||||
|
||||
return max(0.0, min(1.0, probability))
|
||||
|
||||
def increment_interruption_count(self):
|
||||
"""增加打断计数"""
|
||||
self.interruption_count += 1
|
||||
self.last_interruption_time = time.time()
|
||||
|
||||
# 同步打断计数到ChatStream
|
||||
self._sync_interruption_count_to_stream()
|
||||
|
||||
def reset_interruption_count(self):
|
||||
"""重置打断计数和afc阈值调整"""
|
||||
self.interruption_count = 0
|
||||
self.last_interruption_time = 0.0
|
||||
self.afc_threshold_adjustment = 0.0
|
||||
|
||||
# 同步打断计数到ChatStream
|
||||
self._sync_interruption_count_to_stream()
|
||||
|
||||
def apply_interruption_afc_reduction(self, reduction_value: float):
|
||||
"""应用打断导致的afc阈值降低"""
|
||||
self.afc_threshold_adjustment += reduction_value
|
||||
logger.debug(f"应用afc阈值降低: {reduction_value}, 总调整量: {self.afc_threshold_adjustment}")
|
||||
|
||||
def get_afc_threshold_adjustment(self) -> float:
|
||||
"""获取当前的afc阈值调整量"""
|
||||
return self.afc_threshold_adjustment
|
||||
|
||||
def _sync_interruption_count_to_stream(self):
|
||||
"""同步打断计数到ChatStream"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
if chat_manager:
|
||||
chat_stream = chat_manager.get_stream(self.stream_id)
|
||||
if chat_stream and hasattr(chat_stream, "interruption_count"):
|
||||
# 在这里我们只是标记需要保存,实际的保存会在下次save时进行
|
||||
chat_stream.saved = False
|
||||
logger.debug(
|
||||
f"已同步StreamContext {self.stream_id} 的打断计数 {self.interruption_count} 到ChatStream"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"同步打断计数到ChatStream失败: {e}")
|
||||
|
||||
def set_current_message(self, message: "DatabaseMessages"):
|
||||
"""设置当前消息"""
|
||||
self.current_message = message
|
||||
|
||||
def get_template_name(self) -> Optional[str]:
|
||||
"""获取模板名称"""
|
||||
if (
|
||||
self.current_message
|
||||
and hasattr(self.current_message, "additional_config")
|
||||
and self.current_message.additional_config
|
||||
):
|
||||
try:
|
||||
import json
|
||||
|
||||
config = json.loads(self.current_message.additional_config)
|
||||
if config.get("template_info") and not config.get("template_default", True):
|
||||
return config.get("template_name")
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_last_message(self) -> Optional["DatabaseMessages"]:
|
||||
"""获取最后一条消息"""
|
||||
if self.current_message:
|
||||
return self.current_message
|
||||
if self.unread_messages:
|
||||
return self.unread_messages[-1]
|
||||
if self.history_messages:
|
||||
return self.history_messages[-1]
|
||||
return None
|
||||
|
||||
def check_types(self, types: list) -> bool:
|
||||
"""
|
||||
检查当前消息是否支持指定的类型
|
||||
|
||||
Args:
|
||||
types: 需要检查的消息类型列表,如 ["text", "image", "emoji"]
|
||||
|
||||
Returns:
|
||||
bool: 如果消息支持所有指定的类型则返回True,否则返回False
|
||||
"""
|
||||
if not self.current_message:
|
||||
return False
|
||||
|
||||
if not types:
|
||||
# 如果没有指定类型要求,默认为支持
|
||||
return True
|
||||
|
||||
# 优先从additional_config中获取format_info
|
||||
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
|
||||
try:
|
||||
import orjson
|
||||
|
||||
config = orjson.loads(self.current_message.additional_config)
|
||||
|
||||
# 检查format_info结构
|
||||
if "format_info" in config:
|
||||
format_info = config["format_info"]
|
||||
|
||||
# 方法1: 直接检查accept_format字段
|
||||
if "accept_format" in format_info:
|
||||
accept_format = format_info["accept_format"]
|
||||
# 确保accept_format是列表类型
|
||||
if isinstance(accept_format, str):
|
||||
accept_format = [accept_format]
|
||||
elif isinstance(accept_format, list):
|
||||
pass
|
||||
else:
|
||||
# 如果accept_format不是字符串或列表,尝试转换为列表
|
||||
accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else []
|
||||
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in accept_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
||||
return False
|
||||
return True
|
||||
|
||||
# 方法2: 检查content_format字段(向后兼容)
|
||||
elif "content_format" in format_info:
|
||||
content_format = format_info["content_format"]
|
||||
# 确保content_format是列表类型
|
||||
if isinstance(content_format, str):
|
||||
content_format = [content_format]
|
||||
elif isinstance(content_format, list):
|
||||
pass
|
||||
else:
|
||||
content_format = list(content_format) if hasattr(content_format, "__iter__") else []
|
||||
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in content_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
||||
return False
|
||||
return True
|
||||
|
||||
except (orjson.JSONDecodeError, AttributeError, TypeError) as e:
|
||||
logger.debug(f"解析消息格式信息失败: {e}")
|
||||
|
||||
# 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型
|
||||
# 大多数消息至少支持text类型
|
||||
default_supported_types = ["text", "emoji"]
|
||||
for requested_type in types:
|
||||
if requested_type not in default_supported_types:
|
||||
logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'")
|
||||
# 对于非基础类型,返回False以避免错误
|
||||
if requested_type not in ["text", "emoji", "reply"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_priority_mode(self) -> Optional[str]:
|
||||
"""获取优先级模式"""
|
||||
return self.priority_mode
|
||||
|
||||
def get_priority_info(self) -> Optional[dict]:
|
||||
"""获取优先级信息"""
|
||||
return self.priority_info
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageManagerStats(BaseDataModel):
|
||||
"""消息管理器统计信息"""
|
||||
|
||||
total_streams: int = 0
|
||||
active_streams: int = 0
|
||||
total_unread_messages: int = 0
|
||||
total_processed_messages: int = 0
|
||||
start_time: float = field(default_factory=time.time)
|
||||
|
||||
@property
|
||||
def uptime(self) -> float:
|
||||
"""运行时间"""
|
||||
return time.time() - self.start_time
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamStats(BaseDataModel):
|
||||
"""聊天流统计信息"""
|
||||
|
||||
stream_id: str
|
||||
is_active: bool
|
||||
unread_count: int
|
||||
history_count: int
|
||||
last_check_time: float
|
||||
has_active_task: bool
|
||||
@@ -30,6 +30,7 @@ from src.common.database.sqlalchemy_models import (
|
||||
Schedule,
|
||||
MaiZoneScheduleStatus,
|
||||
CacheEntries,
|
||||
UserRelationships,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -54,6 +55,7 @@ MODEL_MAPPING = {
|
||||
"Schedule": Schedule,
|
||||
"MaiZoneScheduleStatus": MaiZoneScheduleStatus,
|
||||
"CacheEntries": CacheEntries,
|
||||
"UserRelationships": UserRelationships,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -55,7 +55,17 @@ class ChatStreams(Base):
|
||||
user_cardname = Column(Text, nullable=True)
|
||||
energy_value = Column(Float, nullable=True, default=5.0)
|
||||
sleep_pressure = Column(Float, nullable=True, default=0.0)
|
||||
focus_energy = Column(Float, nullable=True, default=1.0)
|
||||
focus_energy = Column(Float, nullable=True, default=0.5)
|
||||
# 动态兴趣度系统字段
|
||||
base_interest_energy = Column(Float, nullable=True, default=0.5)
|
||||
message_interest_total = Column(Float, nullable=True, default=0.0)
|
||||
message_count = Column(Integer, nullable=True, default=0)
|
||||
action_count = Column(Integer, nullable=True, default=0)
|
||||
reply_count = Column(Integer, nullable=True, default=0)
|
||||
last_interaction_time = Column(Float, nullable=True, default=None)
|
||||
consecutive_no_reply = Column(Integer, nullable=True, default=0)
|
||||
# 消息打断系统字段
|
||||
interruption_count = Column(Integer, nullable=True, default=0)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_chatstreams_stream_id", "stream_id"),
|
||||
@@ -165,11 +175,16 @@ class Messages(Base):
|
||||
is_command = Column(Boolean, nullable=False, default=False)
|
||||
is_notify = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
# 兴趣度系统字段
|
||||
actions = Column(Text, nullable=True) # JSON格式存储动作列表
|
||||
should_reply = Column(Boolean, nullable=True, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_messages_message_id", "message_id"),
|
||||
Index("idx_messages_chat_id", "chat_id"),
|
||||
Index("idx_messages_time", "time"),
|
||||
Index("idx_messages_user_id", "user_id"),
|
||||
Index("idx_messages_should_reply", "should_reply"),
|
||||
)
|
||||
|
||||
|
||||
@@ -300,6 +315,26 @@ class PersonInfo(Base):
|
||||
)
|
||||
|
||||
|
||||
class BotPersonalityInterests(Base):
|
||||
"""机器人人格兴趣标签模型"""
|
||||
|
||||
__tablename__ = "bot_personality_interests"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
personality_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
personality_description = Column(Text, nullable=False)
|
||||
interest_tags = Column(Text, nullable=False) # JSON格式存储的兴趣标签列表
|
||||
embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
|
||||
version = Column(Integer, nullable=False, default=1)
|
||||
last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_botpersonality_personality_id", "personality_id"),
|
||||
Index("idx_botpersonality_version", "version"),
|
||||
Index("idx_botpersonality_last_updated", "last_updated"),
|
||||
)
|
||||
|
||||
|
||||
class Memory(Base):
|
||||
"""记忆模型"""
|
||||
|
||||
@@ -722,3 +757,23 @@ class UserPermissions(Base):
|
||||
Index("idx_user_permission", "platform", "user_id", "permission_node"),
|
||||
Index("idx_permission_granted", "permission_node", "granted"),
|
||||
)
|
||||
|
||||
|
||||
class UserRelationships(Base):
|
||||
"""用户关系模型 - 存储用户与bot的关系数据"""
|
||||
|
||||
__tablename__ = "user_relationships"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID
|
||||
user_name = Column(get_string_field(100), nullable=True) # 用户名
|
||||
relationship_text = Column(Text, nullable=True) # 关系印象描述
|
||||
relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1)
|
||||
last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_user_relationship_id", "user_id"),
|
||||
Index("idx_relationship_score", "relationship_score"),
|
||||
Index("idx_relationship_updated", "last_updated"),
|
||||
)
|
||||
|
||||
@@ -350,6 +350,10 @@ MODULE_COLORS = {
|
||||
"memory": "\033[38;5;117m", # 天蓝色
|
||||
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
|
||||
"action_manager": "\033[38;5;208m", # 橙色,不与replyer重复
|
||||
"message_manager": "\033[38;5;27m", # 深蓝色,消息管理器
|
||||
"chatter_manager": "\033[38;5;129m", # 紫色,聊天管理器
|
||||
"chatter_interest_scoring": "\033[38;5;214m", # 橙黄色,兴趣评分
|
||||
"plan_executor": "\033[38;5;172m", # 橙褐色,计划执行器
|
||||
# 关系系统
|
||||
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
|
||||
# 聊天相关模块
|
||||
@@ -551,6 +555,10 @@ MODULE_ALIASES = {
|
||||
"llm_models": "模型",
|
||||
"person_info": "人物",
|
||||
"chat_stream": "聊天流",
|
||||
"message_manager": "消息管理",
|
||||
"chatter_manager": "聊天管理",
|
||||
"chatter_interest_scoring": "兴趣评分",
|
||||
"plan_executor": "计划执行",
|
||||
"planner": "规划器",
|
||||
"replyer": "言语",
|
||||
"config": "配置",
|
||||
|
||||
@@ -23,15 +23,15 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
maim_message_config = global_config.maim_message
|
||||
|
||||
# 设置基本参数
|
||||
|
||||
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port_str = os.getenv("PORT", "8000")
|
||||
|
||||
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
port = 8000
|
||||
|
||||
|
||||
kwargs = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
|
||||
@@ -22,10 +22,15 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
"""
|
||||
将 SQLAlchemy 模型实例转换为字典。
|
||||
"""
|
||||
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
|
||||
try:
|
||||
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
|
||||
except Exception as e:
|
||||
# 如果对象已经脱离会话,尝试从instance.__dict__中获取数据
|
||||
logger.warning(f"从数据库对象获取属性失败,尝试使用__dict__: {e}")
|
||||
return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns}
|
||||
|
||||
|
||||
async def find_messages(
|
||||
def find_messages(
|
||||
message_filter: dict[str, Any],
|
||||
sort: Optional[List[tuple[str, int]]] = None,
|
||||
limit: int = 0,
|
||||
@@ -46,7 +51,7 @@ async def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
query = select(Messages)
|
||||
|
||||
# 应用过滤器
|
||||
@@ -96,7 +101,7 @@ async def find_messages(
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
try:
|
||||
results = (await session.execute(query)).scalars().all()
|
||||
results = session.execute(query).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行earliest查询失败: {e}")
|
||||
results = []
|
||||
@@ -104,7 +109,7 @@ async def find_messages(
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
try:
|
||||
latest_results = (await session.execute(query)).scalars().all()
|
||||
latest_results = session.execute(query).scalars().all()
|
||||
# 将结果按时间正序排列
|
||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
@@ -128,11 +133,12 @@ async def find_messages(
|
||||
if sort_terms:
|
||||
query = query.order_by(*sort_terms)
|
||||
try:
|
||||
results = (await session.execute(query)).scalars().all()
|
||||
results = session.execute(query).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行无限制查询失败: {e}")
|
||||
results = []
|
||||
|
||||
# 在会话内将结果转换为字典,避免会话分离错误
|
||||
return [_model_to_dict(msg) for msg in results]
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
@@ -143,7 +149,7 @@ async def find_messages(
|
||||
return []
|
||||
|
||||
|
||||
async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
"""
|
||||
根据提供的过滤器计算消息数量。
|
||||
|
||||
@@ -154,7 +160,7 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
query = select(func.count(Messages.id))
|
||||
|
||||
# 应用过滤器
|
||||
@@ -192,7 +198,7 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
count = (await session.execute(query)).scalar()
|
||||
count = session.execute(query).scalar()
|
||||
return count or 0
|
||||
except Exception as e:
|
||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
@@ -201,5 +207,5 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
|
||||
|
||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 await session.commit()。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。
|
||||
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
|
||||
|
||||
@@ -31,7 +31,9 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore
|
||||
"""客户端UUID"""
|
||||
|
||||
self.private_key_pem: str | None = local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None # type: ignore
|
||||
self.private_key_pem: str | None = (
|
||||
local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None
|
||||
) # type: ignore
|
||||
"""客户端私钥"""
|
||||
|
||||
self.info_dict = self._get_sys_info()
|
||||
@@ -61,78 +63,65 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
def _generate_signature(self, request_body: dict) -> tuple[str, str]:
|
||||
"""
|
||||
生成RSA签名
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (timestamp, signature_b64)
|
||||
"""
|
||||
if not self.private_key_pem:
|
||||
raise ValueError("私钥未初始化")
|
||||
|
||||
|
||||
# 生成时间戳
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
# 创建签名数据字符串
|
||||
sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}"
|
||||
|
||||
|
||||
# 加载私钥
|
||||
private_key = serialization.load_pem_private_key(
|
||||
self.private_key_pem.encode('utf-8'),
|
||||
password=None
|
||||
)
|
||||
|
||||
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
|
||||
|
||||
# 确保是RSA私钥
|
||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||
raise ValueError("私钥必须是RSA格式")
|
||||
|
||||
|
||||
# 生成签名
|
||||
signature = private_key.sign(
|
||||
sign_data.encode('utf-8'),
|
||||
padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
salt_length=padding.PSS.MAX_LENGTH
|
||||
),
|
||||
hashes.SHA256()
|
||||
sign_data.encode("utf-8"),
|
||||
padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
|
||||
|
||||
# Base64编码
|
||||
signature_b64 = base64.b64encode(signature).decode('utf-8')
|
||||
|
||||
signature_b64 = base64.b64encode(signature).decode("utf-8")
|
||||
|
||||
return timestamp, signature_b64
|
||||
|
||||
def _decrypt_challenge(self, challenge_b64: str) -> str:
|
||||
"""
|
||||
解密挑战数据
|
||||
|
||||
|
||||
Args:
|
||||
challenge_b64: Base64编码的挑战数据
|
||||
|
||||
|
||||
Returns:
|
||||
str: 解密后的UUID字符串
|
||||
"""
|
||||
if not self.private_key_pem:
|
||||
raise ValueError("私钥未初始化")
|
||||
|
||||
|
||||
# 加载私钥
|
||||
private_key = serialization.load_pem_private_key(
|
||||
self.private_key_pem.encode('utf-8'),
|
||||
password=None
|
||||
)
|
||||
|
||||
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
|
||||
|
||||
# 确保是RSA私钥
|
||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||
raise ValueError("私钥必须是RSA格式")
|
||||
|
||||
|
||||
# 解密挑战数据
|
||||
decrypted_bytes = private_key.decrypt(
|
||||
base64.b64decode(challenge_b64),
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
|
||||
)
|
||||
|
||||
return decrypted_bytes.decode('utf-8')
|
||||
|
||||
return decrypted_bytes.decode("utf-8")
|
||||
|
||||
async def _req_uuid(self) -> bool:
|
||||
"""
|
||||
@@ -155,28 +144,26 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
|
||||
if response.status != 200:
|
||||
response_text = await response.text()
|
||||
logger.error(
|
||||
f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}"
|
||||
)
|
||||
logger.error(f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}")
|
||||
raise aiohttp.ClientResponseError(
|
||||
request_info=response.request_info,
|
||||
history=response.history,
|
||||
status=response.status,
|
||||
message=f"Step1 failed: {response_text}"
|
||||
message=f"Step1 failed: {response_text}",
|
||||
)
|
||||
|
||||
step1_data = await response.json()
|
||||
temp_uuid = step1_data.get("temp_uuid")
|
||||
private_key = step1_data.get("private_key")
|
||||
challenge = step1_data.get("challenge")
|
||||
|
||||
|
||||
if not all([temp_uuid, private_key, challenge]):
|
||||
logger.error("Step1响应缺少必要字段:temp_uuid, private_key 或 challenge")
|
||||
raise ValueError("Step1响应数据不完整")
|
||||
|
||||
# 临时保存私钥用于解密
|
||||
self.private_key_pem = private_key
|
||||
|
||||
|
||||
# 解密挑战数据
|
||||
logger.debug("解密挑战数据...")
|
||||
try:
|
||||
@@ -184,21 +171,18 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
except Exception as e:
|
||||
logger.error(f"解密挑战数据失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# 验证解密结果
|
||||
if decrypted_uuid != temp_uuid:
|
||||
logger.error(f"解密结果验证失败: 期望 {temp_uuid}, 实际 {decrypted_uuid}")
|
||||
raise ValueError("解密结果与临时UUID不匹配")
|
||||
|
||||
|
||||
logger.debug("挑战数据解密成功,开始注册步骤2")
|
||||
|
||||
# Step 2: 发送解密结果完成注册
|
||||
async with session.post(
|
||||
f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2",
|
||||
json={
|
||||
"temp_uuid": temp_uuid,
|
||||
"decrypted_uuid": decrypted_uuid
|
||||
},
|
||||
json={"temp_uuid": temp_uuid, "decrypted_uuid": decrypted_uuid},
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
) as response:
|
||||
logger.debug(f"Step2 Response status: {response.status}")
|
||||
@@ -206,7 +190,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
if response.status == 200:
|
||||
step2_data = await response.json()
|
||||
mofox_uuid = step2_data.get("mofox_uuid")
|
||||
|
||||
|
||||
if mofox_uuid:
|
||||
# 将正式UUID和私钥存储到本地
|
||||
local_storage["mofox_uuid"] = mofox_uuid
|
||||
@@ -225,23 +209,19 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
raise ValueError(f"Step2失败: {response_text}")
|
||||
else:
|
||||
response_text = await response.text()
|
||||
logger.error(
|
||||
f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}"
|
||||
)
|
||||
logger.error(f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}")
|
||||
raise aiohttp.ClientResponseError(
|
||||
request_info=response.request_info,
|
||||
history=response.history,
|
||||
status=response.status,
|
||||
message=f"Step2 failed: {response_text}"
|
||||
message=f"Step2 failed: {response_text}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_msg = str(e) or "未知错误"
|
||||
logger.warning(
|
||||
f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}"
|
||||
)
|
||||
logger.warning(f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}")
|
||||
logger.debug(f"完整错误信息: {traceback.format_exc()}")
|
||||
|
||||
# 请求失败,重试次数+1
|
||||
@@ -264,13 +244,13 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
try:
|
||||
# 生成签名
|
||||
timestamp, signature = self._generate_signature(self.info_dict)
|
||||
|
||||
|
||||
headers = {
|
||||
"X-mofox-UUID": self.client_uuid,
|
||||
"X-mofox-Signature": signature,
|
||||
"X-mofox-Timestamp": timestamp,
|
||||
"User-Agent": f"MofoxClient/{self.client_uuid[:8]}",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
|
||||
@@ -347,4 +327,4 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
logger.warning("客户端注册失败,跳过此次心跳")
|
||||
return
|
||||
|
||||
await self._send_heartbeat()
|
||||
await self._send_heartbeat()
|
||||
|
||||
@@ -99,14 +99,13 @@ def get_global_server() -> Server:
|
||||
"""获取全局服务器实例"""
|
||||
global global_server
|
||||
if global_server is None:
|
||||
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port_str = os.getenv("PORT", "8000")
|
||||
|
||||
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
port = 8000
|
||||
|
||||
|
||||
global_server = Server(host=host, port=port)
|
||||
return global_server
|
||||
|
||||
@@ -137,7 +137,7 @@ class ModelTaskConfig(ValidatedConfigBase):
|
||||
monthly_plan_generator: TaskConfig = Field(..., description="月层计划生成模型配置")
|
||||
emoji_vlm: TaskConfig = Field(..., description="表情包识别模型配置")
|
||||
anti_injection: TaskConfig = Field(..., description="反注入检测专用模型配置")
|
||||
|
||||
relationship_tracker: TaskConfig = Field(..., description="关系追踪模型配置")
|
||||
# 处理配置文件中命名不一致的问题
|
||||
utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)")
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ from src.config.official_configs import (
|
||||
CrossContextConfig,
|
||||
PermissionConfig,
|
||||
CommandConfig,
|
||||
PlanningSystemConfig
|
||||
PlanningSystemConfig,
|
||||
AffinityFlowConfig,
|
||||
)
|
||||
|
||||
from .api_ada_configs import (
|
||||
@@ -66,7 +67,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.10.0-alpha-2"
|
||||
MMC_VERSION = "0.11.0-alpha-1"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
@@ -417,6 +418,7 @@ class Config(ValidatedConfigBase):
|
||||
cross_context: CrossContextConfig = Field(
|
||||
default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置"
|
||||
)
|
||||
affinity_flow: AffinityFlowConfig = Field(default_factory=lambda: AffinityFlowConfig(), description="亲和流配置")
|
||||
|
||||
|
||||
class APIAdapterConfig(ValidatedConfigBase):
|
||||
|
||||
@@ -51,8 +51,12 @@ class PersonalityConfig(ValidatedConfigBase):
|
||||
personality_core: str = Field(..., description="核心人格")
|
||||
personality_side: str = Field(..., description="人格侧写")
|
||||
identity: str = Field(default="", description="身份特征")
|
||||
background_story: str = Field(default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述")
|
||||
safety_guidelines: List[str] = Field(default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则")
|
||||
background_story: str = Field(
|
||||
default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述"
|
||||
)
|
||||
safety_guidelines: List[str] = Field(
|
||||
default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则"
|
||||
)
|
||||
reply_style: str = Field(default="", description="表达风格")
|
||||
prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式")
|
||||
compress_personality: bool = Field(default=True, description="是否压缩人格")
|
||||
@@ -109,7 +113,8 @@ class ChatConfig(ValidatedConfigBase):
|
||||
talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整")
|
||||
focus_value: float = Field(default=1.0, description="专注值")
|
||||
focus_mode_quiet_groups: List[str] = Field(
|
||||
default_factory=list, description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]'
|
||||
default_factory=list,
|
||||
description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]',
|
||||
)
|
||||
force_reply_private: bool = Field(default=False, description="强制回复私聊")
|
||||
group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式")
|
||||
@@ -129,6 +134,31 @@ class ChatConfig(ValidatedConfigBase):
|
||||
)
|
||||
delta_sigma: int = Field(default=120, description="采用正态分布随机时间间隔")
|
||||
|
||||
# 消息打断系统配置
|
||||
interruption_enabled: bool = Field(default=True, description="是否启用消息打断系统")
|
||||
interruption_max_limit: int = Field(default=3, ge=0, description="每个聊天流的最大打断次数")
|
||||
interruption_probability_factor: float = Field(
|
||||
default=0.8, ge=0.0, le=1.0, description="打断概率因子,当前打断次数/最大打断次数超过此值时触发概率下降"
|
||||
)
|
||||
interruption_afc_reduction: float = Field(
|
||||
default=0.05, ge=0.0, le=1.0, description="每次连续打断降低的afc阈值数值"
|
||||
)
|
||||
|
||||
# 动态消息分发系统配置
|
||||
dynamic_distribution_enabled: bool = Field(default=True, description="是否启用动态消息分发周期调整")
|
||||
dynamic_distribution_base_interval: float = Field(
|
||||
default=5.0, ge=1.0, le=60.0, description="基础分发间隔(秒)"
|
||||
)
|
||||
dynamic_distribution_min_interval: float = Field(
|
||||
default=1.0, ge=0.5, le=10.0, description="最小分发间隔(秒)"
|
||||
)
|
||||
dynamic_distribution_max_interval: float = Field(
|
||||
default=30.0, ge=5.0, le=300.0, description="最大分发间隔(秒)"
|
||||
)
|
||||
dynamic_distribution_jitter_factor: float = Field(
|
||||
default=0.2, ge=0.0, le=0.5, description="分发间隔随机扰动因子"
|
||||
)
|
||||
|
||||
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
||||
"""
|
||||
根据当前时间和聊天流获取对应的 talk_frequency
|
||||
@@ -376,6 +406,7 @@ class ExpressionConfig(ValidatedConfigBase):
|
||||
# 如果都没有匹配,返回默认值
|
||||
return True, True, 1.0
|
||||
|
||||
|
||||
class ToolConfig(ValidatedConfigBase):
|
||||
"""工具配置类"""
|
||||
|
||||
@@ -510,7 +541,6 @@ class ExperimentalConfig(ValidatedConfigBase):
|
||||
pfc_chatting: bool = Field(default=False, description="启用PFC聊天")
|
||||
|
||||
|
||||
|
||||
class MaimMessageConfig(ValidatedConfigBase):
|
||||
"""maim_message配置类"""
|
||||
|
||||
@@ -635,8 +665,12 @@ class SleepSystemConfig(ValidatedConfigBase):
|
||||
sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉")
|
||||
fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间")
|
||||
fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间")
|
||||
sleep_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机")
|
||||
wake_up_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机")
|
||||
sleep_time_offset_minutes: int = Field(
|
||||
default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机"
|
||||
)
|
||||
wake_up_time_offset_minutes: int = Field(
|
||||
default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机"
|
||||
)
|
||||
wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒")
|
||||
private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度")
|
||||
group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度")
|
||||
@@ -651,10 +685,10 @@ class SleepSystemConfig(ValidatedConfigBase):
|
||||
# --- 失眠机制相关参数 ---
|
||||
enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统")
|
||||
insomnia_trigger_delay_minutes: List[int] = Field(
|
||||
default_factory=lambda:[30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)"
|
||||
default_factory=lambda: [30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)"
|
||||
)
|
||||
insomnia_duration_minutes: List[int] = Field(
|
||||
default_factory=lambda:[15, 45], description="单次失眠状态的持续时间范围(分钟)"
|
||||
default_factory=lambda: [15, 45], description="单次失眠状态的持续时间范围(分钟)"
|
||||
)
|
||||
sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值")
|
||||
deep_sleep_threshold: float = Field(default=80.0, description="进入“深度睡眠”的睡眠压力阈值")
|
||||
@@ -690,6 +724,8 @@ class CrossContextConfig(ValidatedConfigBase):
|
||||
|
||||
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
|
||||
groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
|
||||
|
||||
|
||||
class CommandConfig(ValidatedConfigBase):
|
||||
"""命令系统配置类"""
|
||||
|
||||
@@ -703,3 +739,34 @@ class PermissionConfig(ValidatedConfigBase):
|
||||
master_users: List[List[str]] = Field(
|
||||
default_factory=list, description="Master用户列表,格式: [[platform, user_id], ...]"
|
||||
)
|
||||
|
||||
|
||||
class AffinityFlowConfig(ValidatedConfigBase):
|
||||
"""亲和流配置类(兴趣度评分和人物关系系统)"""
|
||||
|
||||
# 兴趣评分系统参数
|
||||
reply_action_interest_threshold: float = Field(default=0.4, description="回复动作兴趣阈值")
|
||||
non_reply_action_interest_threshold: float = Field(default=0.2, description="非回复动作兴趣阈值")
|
||||
high_match_interest_threshold: float = Field(default=0.8, description="高匹配兴趣阈值")
|
||||
medium_match_interest_threshold: float = Field(default=0.5, description="中匹配兴趣阈值")
|
||||
low_match_interest_threshold: float = Field(default=0.2, description="低匹配兴趣阈值")
|
||||
high_match_keyword_multiplier: float = Field(default=1.5, description="高匹配关键词兴趣倍率")
|
||||
medium_match_keyword_multiplier: float = Field(default=1.2, description="中匹配关键词兴趣倍率")
|
||||
low_match_keyword_multiplier: float = Field(default=1.0, description="低匹配关键词兴趣倍率")
|
||||
match_count_bonus: float = Field(default=0.1, description="匹配数关键词加成值")
|
||||
max_match_bonus: float = Field(default=0.5, description="最大匹配数加成值")
|
||||
|
||||
# 回复决策系统参数
|
||||
no_reply_threshold_adjustment: float = Field(default=0.1, description="不回复兴趣阈值调整值")
|
||||
reply_cooldown_reduction: int = Field(default=2, description="回复后减少的不回复计数")
|
||||
max_no_reply_count: int = Field(default=5, description="最大不回复计数次数")
|
||||
|
||||
# 综合评分权重
|
||||
keyword_match_weight: float = Field(default=0.4, description="兴趣关键词匹配度权重")
|
||||
mention_bot_weight: float = Field(default=0.3, description="提及bot分数权重")
|
||||
relationship_weight: float = Field(default=0.3, description="人物关系分数权重")
|
||||
|
||||
# 提及bot相关参数
|
||||
mention_bot_adjustment_threshold: float = Field(default=0.3, description="提及bot后的调整阈值")
|
||||
mention_bot_interest_score: float = Field(default=0.6, description="提及bot的兴趣分")
|
||||
base_relationship_score: float = Field(default=0.5, description="基础人物关系分")
|
||||
|
||||
@@ -64,6 +64,9 @@ class Individuality:
|
||||
else:
|
||||
logger.error("人设构建失败")
|
||||
|
||||
# 初始化智能兴趣系统
|
||||
await self._initialize_smart_interest_system(personality_result, identity_result)
|
||||
|
||||
# 如果任何一个发生变化,都需要清空数据库中的info_list(因为这影响整体人设)
|
||||
if personality_changed or identity_changed:
|
||||
logger.info("将清空数据库中原有的关键词缓存")
|
||||
@@ -75,6 +78,21 @@ class Individuality:
|
||||
}
|
||||
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
|
||||
|
||||
async def _initialize_smart_interest_system(self, personality_result: str, identity_result: str):
|
||||
"""初始化智能兴趣系统"""
|
||||
# 组合完整的人设描述
|
||||
full_personality = f"{personality_result},{identity_result}"
|
||||
|
||||
# 获取全局兴趣评分系统实例
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system
|
||||
|
||||
# 初始化智能兴趣系统
|
||||
await interest_scoring_system.initialize_smart_interests(
|
||||
personality_description=full_personality, personality_id=self.bot_person_id
|
||||
)
|
||||
|
||||
logger.info("智能兴趣系统初始化完成")
|
||||
|
||||
async def get_personality_block(self) -> str:
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
|
||||
@@ -145,9 +145,9 @@ class LLMUsageRecorder:
|
||||
LLM使用情况记录器(SQLAlchemy版本)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def record_usage_to_database(
|
||||
model_info: ModelInfo,
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
model_usage: UsageRecord,
|
||||
user_id: str,
|
||||
request_type: str,
|
||||
@@ -161,7 +161,7 @@ class LLMUsageRecorder:
|
||||
session = None
|
||||
try:
|
||||
# 使用 SQLAlchemy 会话创建记录
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
usage_record = LLMUsage(
|
||||
model_name=model_info.model_identifier,
|
||||
model_assign_name=model_info.name,
|
||||
@@ -179,7 +179,7 @@ class LLMUsageRecorder:
|
||||
)
|
||||
|
||||
session.add(usage_record)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
logger.debug(
|
||||
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。
|
||||
它被设计为一个高度容错和可扩展的系统,包含以下主要组件:
|
||||
@@ -892,7 +891,7 @@ class LLMRequest:
|
||||
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
|
||||
)
|
||||
|
||||
self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
|
||||
await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
|
||||
|
||||
if not response.content and not response.tool_calls:
|
||||
if raise_when_empty:
|
||||
@@ -917,14 +916,14 @@ class LLMRequest:
|
||||
embedding_input=embedding_input
|
||||
)
|
||||
|
||||
self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")
|
||||
await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")
|
||||
|
||||
if not response.embedding:
|
||||
raise RuntimeError("获取embedding失败")
|
||||
|
||||
return response.embedding, model_info.name
|
||||
|
||||
def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
|
||||
async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
|
||||
"""
|
||||
记录模型使用情况。
|
||||
|
||||
|
||||
143
src/main.py
143
src/main.py
@@ -1,35 +1,40 @@
|
||||
# 再用这个就写一行注释来混提交的我直接全部🌿飞😡
|
||||
import asyncio
|
||||
import time
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from functools import partial
|
||||
import traceback
|
||||
from typing import Dict, Any
|
||||
|
||||
from maim_message import MessageServer
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
from src.common.logger import get_logger
|
||||
# 导入消息API和traceback模块
|
||||
from src.common.message import get_global_api
|
||||
from src.common.remote import TelemetryHeartBeatTask
|
||||
from src.common.server import get_global_server, Server
|
||||
from src.config.config import global_config
|
||||
from src.individuality.individuality import get_individuality, Individuality
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.common.logger import get_logger
|
||||
from src.individuality.individuality import get_individuality, Individuality
|
||||
from src.common.server import get_global_server, Server
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
from rich.traceback import install
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.schedule.monthly_plan_manager import monthly_plan_manager
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
# from src.api.main import start_api_server
|
||||
|
||||
# 导入新的插件管理器和热重载管理器
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.schedule.monthly_plan_manager import monthly_plan_manager
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
|
||||
# from src.api.main import start_api_server
|
||||
# 导入消息API和traceback模块
|
||||
from src.common.message import get_global_api
|
||||
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
|
||||
if not global_config.memory.enable_memory:
|
||||
import src.chat.memory_system.Hippocampus as hippocampus_module
|
||||
@@ -38,11 +43,7 @@ if not global_config.memory.enable_memory:
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
async def initialize_async(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_hippocampus():
|
||||
def get_hippocampus(self):
|
||||
return None
|
||||
|
||||
async def build_memory(self):
|
||||
@@ -54,9 +55,9 @@ if not global_config.memory.enable_memory:
|
||||
async def consolidate_memory(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
async def get_memory_from_text(
|
||||
text: str,
|
||||
self,
|
||||
text: str,
|
||||
max_memory_num: int = 3,
|
||||
max_memory_length: int = 2,
|
||||
max_depth: int = 3,
|
||||
@@ -64,24 +65,20 @@ if not global_config.memory.enable_memory:
|
||||
) -> list:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def get_memory_from_topic(
|
||||
valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||
) -> list:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def get_activate_from_text(
|
||||
text: str, max_depth: int = 3, fast_retrieval: bool = False
|
||||
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
|
||||
) -> tuple[float, list[str]]:
|
||||
return 0.0, []
|
||||
|
||||
@staticmethod
|
||||
def get_memory_from_keyword(keyword: str, max_depth: int = 2) -> list:
|
||||
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_all_node_names() -> list:
|
||||
def get_all_node_names(self) -> list:
|
||||
return []
|
||||
|
||||
hippocampus_module.hippocampus_manager = MockHippocampusManager()
|
||||
@@ -93,6 +90,20 @@ install(extra_lines=3)
|
||||
logger = get_logger("main")
|
||||
|
||||
|
||||
def _task_done_callback(task: asyncio.Task, message_id: str, start_time: float):
|
||||
"""后台任务完成时的回调函数"""
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
try:
|
||||
task.result() # 如果任务有异常,这里会重新抛出
|
||||
logger.debug(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 已成功完成, 耗时: {duration:.2f}s")
|
||||
except asyncio.CancelledError:
|
||||
logger.warning(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 被取消, 耗时: {duration:.2f}s")
|
||||
except Exception:
|
||||
logger.error(f"处理消息 {message_id} 的后台任务 (ID: {id(task)}) 出现未捕获的异常, 耗时: {duration:.2f}s:")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
class MainSystem:
|
||||
def __init__(self):
|
||||
self.hippocampus_manager = hippocampus_manager
|
||||
@@ -117,15 +128,28 @@ class MainSystem:
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
@staticmethod
|
||||
def _cleanup():
|
||||
def _cleanup(self):
|
||||
"""清理资源"""
|
||||
try:
|
||||
# 停止消息管理器
|
||||
from src.chat.message_manager import message_manager
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
asyncio.create_task(message_manager.stop())
|
||||
else:
|
||||
loop.run_until_complete(message_manager.stop())
|
||||
logger.info("🛑 消息管理器已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止消息管理器时出错: {e}")
|
||||
|
||||
try:
|
||||
# 停止消息重组器
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system import EventType
|
||||
import asyncio
|
||||
asyncio.run(event_manager.trigger_event(EventType.ON_STOP,permission_group="SYSTEM"))
|
||||
|
||||
asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
|
||||
from src.utils.message_chunker import reassembler
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -159,6 +183,20 @@ class MainSystem:
|
||||
except Exception as e:
|
||||
logger.error(f"停止记忆管理器时出错: {e}")
|
||||
|
||||
async def _message_process_wrapper(self, message_data: Dict[str, Any]):
|
||||
"""并行处理消息的包装器"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
message_id = message_data.get("message_info", {}).get("message_id", "UNKNOWN")
|
||||
# 创建后台任务
|
||||
task = asyncio.create_task(chat_bot.message_process(message_data))
|
||||
logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})")
|
||||
# 添加一个回调函数,当任务完成时,它会被调用
|
||||
task.add_done_callback(partial(_task_done_callback, message_id=message_id, start_time=start_time))
|
||||
except Exception:
|
||||
logger.error("在创建消息处理任务时发生严重错误:")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化系统组件"""
|
||||
logger.info(f"正在唤醒{global_config.bot.nickname}......")
|
||||
@@ -211,7 +249,7 @@ MoFox_Bot(第三方修改版)
|
||||
|
||||
# 添加统计信息输出任务
|
||||
await async_task_manager.add_task(StatisticOutputTask())
|
||||
|
||||
|
||||
# 添加遥测心跳任务
|
||||
await async_task_manager.add_task(TelemetryHeartBeatTask())
|
||||
|
||||
@@ -223,7 +261,6 @@ MoFox_Bot(第三方修改版)
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
|
||||
permission_manager = PermissionManager()
|
||||
await permission_manager.initialize()
|
||||
permission_api.set_permission_manager(permission_manager)
|
||||
logger.info("权限管理器初始化成功")
|
||||
|
||||
@@ -244,6 +281,18 @@ MoFox_Bot(第三方修改版)
|
||||
get_emoji_manager().initialize()
|
||||
logger.info("表情包管理器初始化成功")
|
||||
|
||||
# 初始化回复后关系追踪系统
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
|
||||
relationship_tracker = ChatterRelationshipTracker(interest_scoring_system=chatter_interest_scoring_system)
|
||||
chatter_interest_scoring_system.relationship_tracker = relationship_tracker
|
||||
logger.info("回复后关系追踪系统初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"回复后关系追踪系统初始化失败: {e}")
|
||||
relationship_tracker = None
|
||||
|
||||
# 启动情绪管理器
|
||||
await mood_manager.start()
|
||||
logger.info("情绪管理器初始化成功")
|
||||
@@ -256,11 +305,12 @@ MoFox_Bot(第三方修改版)
|
||||
logger.info("聊天管理器初始化成功")
|
||||
|
||||
# 初始化记忆系统
|
||||
await self.hippocampus_manager.initialize_async()
|
||||
self.hippocampus_manager.initialize()
|
||||
logger.info("记忆系统初始化成功")
|
||||
|
||||
# 初始化LPMM知识库
|
||||
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
|
||||
|
||||
initialize_lpmm_knowledge()
|
||||
logger.info("LPMM知识库初始化成功")
|
||||
|
||||
@@ -276,7 +326,7 @@ MoFox_Bot(第三方修改版)
|
||||
# await asyncio.sleep(0.5) #防止logger输出飞了
|
||||
|
||||
# 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中
|
||||
self.app.register_message_handler(chat_bot.message_process)
|
||||
self.app.register_message_handler(self._message_process_wrapper)
|
||||
|
||||
# 启动消息重组器的清理任务
|
||||
from src.utils.message_chunker import reassembler
|
||||
@@ -284,6 +334,12 @@ MoFox_Bot(第三方修改版)
|
||||
await reassembler.start_cleanup_task()
|
||||
logger.info("消息重组器已启动")
|
||||
|
||||
# 启动消息管理器
|
||||
from src.chat.message_manager import message_manager
|
||||
|
||||
await message_manager.start()
|
||||
logger.info("消息管理器已启动")
|
||||
|
||||
# 初始化个体特征
|
||||
await self.individuality.initialize()
|
||||
|
||||
@@ -291,7 +347,7 @@ MoFox_Bot(第三方修改版)
|
||||
if global_config.planning_system.monthly_plan_enable:
|
||||
logger.info("正在初始化月度计划管理器...")
|
||||
try:
|
||||
await monthly_plan_manager.initialize()
|
||||
await monthly_plan_manager.start_monthly_plan_generation()
|
||||
logger.info("月度计划管理器初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"月度计划管理器初始化失败: {e}")
|
||||
@@ -299,7 +355,8 @@ MoFox_Bot(第三方修改版)
|
||||
# 初始化日程管理器
|
||||
if global_config.planning_system.schedule_enable:
|
||||
logger.info("日程表功能已启用,正在初始化管理器...")
|
||||
await schedule_manager.initialize()
|
||||
await schedule_manager.load_or_generate_today_schedule()
|
||||
await schedule_manager.start_daily_schedule_generation()
|
||||
logger.info("日程表管理器初始化成功。")
|
||||
|
||||
try:
|
||||
|
||||
@@ -5,6 +5,7 @@ import time
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
@@ -65,7 +66,7 @@ class ChatMood:
|
||||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
async def update_mood_by_message(self, message: MessageRecv, interested_rate: float):
|
||||
async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float):
|
||||
# 如果当前聊天处于失眠状态,则锁定情绪,不允许更新
|
||||
if self.chat_id in mood_manager.insomnia_chats:
|
||||
logger.debug(f"{self.log_prefix} 处于失眠状态,情绪已锁定,跳过更新。")
|
||||
@@ -73,7 +74,13 @@ class ChatMood:
|
||||
|
||||
self.regression_count = 0
|
||||
|
||||
during_last_time = message.message_info.time - self.last_change_time # type: ignore
|
||||
# 处理不同类型的消息对象
|
||||
if isinstance(message, MessageRecv):
|
||||
message_time = message.message_info.time
|
||||
else: # DatabaseMessages
|
||||
message_time = message.time
|
||||
|
||||
during_last_time = message_time - self.last_change_time
|
||||
|
||||
base_probability = 0.05
|
||||
time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time))
|
||||
@@ -96,16 +103,14 @@ class ChatMood:
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}"
|
||||
)
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=int(global_config.chat.max_context_size / 3),
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
@@ -135,26 +140,26 @@ class ChatMood:
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
||||
logger.info(f"{self.log_prefix} response: {response}")
|
||||
logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
||||
logger.debug(f"{self.log_prefix} prompt: {prompt}")
|
||||
logger.debug(f"{self.log_prefix} response: {response}")
|
||||
logger.debug(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
||||
|
||||
logger.info(f"{self.log_prefix} 情绪状态更新为: {response}")
|
||||
|
||||
self.mood_state = response
|
||||
|
||||
self.last_change_time = message_time
|
||||
|
||||
|
||||
async def regress_mood(self):
|
||||
message_time = time.time()
|
||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=15,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
@@ -185,9 +190,9 @@ class ChatMood:
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
||||
logger.info(f"{self.log_prefix} response: {response}")
|
||||
logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
||||
logger.debug(f"{self.log_prefix} prompt: {prompt}")
|
||||
logger.debug(f"{self.log_prefix} response: {response}")
|
||||
logger.debug(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
||||
|
||||
logger.info(f"{self.log_prefix} 情绪状态转变为: {response}")
|
||||
|
||||
|
||||
@@ -94,10 +94,51 @@ class PersonInfoManager:
|
||||
|
||||
if "-" in platform:
|
||||
platform = platform.split("-")[1]
|
||||
|
||||
# 在此处打一个补丁,如果platform为qq,尝试生成id后检查是否存在,如果不存在,则将平台换为napcat后再次检查,如果存在,则更新原id为platform为qq的id
|
||||
components = [platform, str(user_id)]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
# 如果不是 qq 平台,直接返回计算的 id
|
||||
if platform != "qq":
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
qq_id = hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
# 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回
|
||||
def _db_check_and_migrate_sync(p_id: str, raw_user_id: str):
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 检查 qq_id 是否存在
|
||||
existing_qq = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if existing_qq:
|
||||
return p_id
|
||||
|
||||
# 如果 qq_id 不存在,尝试使用 napcat 作为平台生成对应 id 并检查
|
||||
nap_components = ["napcat", str(raw_user_id)]
|
||||
nap_key = "_".join(nap_components)
|
||||
nap_id = hashlib.md5(nap_key.encode()).hexdigest()
|
||||
|
||||
existing_nap = session.execute(select(PersonInfo).where(PersonInfo.person_id == nap_id)).scalar()
|
||||
if not existing_nap:
|
||||
# napcat 也不存在,返回 qq_id(未命中)
|
||||
return p_id
|
||||
|
||||
# napcat 存在,迁移该记录:更新 person_id 与 platform -> qq
|
||||
try:
|
||||
# 更新现有 napcat 记录
|
||||
existing_nap.person_id = p_id
|
||||
existing_nap.platform = "qq"
|
||||
existing_nap.user_id = str(raw_user_id)
|
||||
session.commit()
|
||||
return p_id
|
||||
except Exception:
|
||||
session.rollback()
|
||||
return p_id
|
||||
except Exception as e:
|
||||
logger.error(f"检查/迁移 napcat->qq 时出错: {e}")
|
||||
return p_id
|
||||
|
||||
return _db_check_and_migrate_sync(qq_id, user_id)
|
||||
|
||||
async def is_person_known(self, platform: str, user_id: int):
|
||||
"""判断是否认识某人"""
|
||||
@@ -127,7 +168,28 @@ class PersonInfoManager:
|
||||
except Exception as e:
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
|
||||
return ""
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
|
||||
"""判断是否认识某人"""
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
# 生成唯一的 person_name
|
||||
person_info_manager = get_person_info_manager()
|
||||
unique_nickname = await person_info_manager._generate_unique_person_name(user_nickname)
|
||||
data = {
|
||||
"platform": platform,
|
||||
"user_id": user_id,
|
||||
"nickname": user_nickname,
|
||||
"konw_time": int(time.time()),
|
||||
"person_name": unique_nickname, # 使用唯一的 person_name
|
||||
}
|
||||
# 先创建用户基本信息,使用安全创建方法避免竞态条件
|
||||
await person_info_manager._safe_create_person_info(person_id=person_id, data=data)
|
||||
# 更新昵称
|
||||
await person_info_manager.update_one_field(
|
||||
person_id=person_id, field_name="nickname", value=user_nickname, data=data
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_person_info(person_id: str, data: Optional[dict] = None):
|
||||
"""创建一个项"""
|
||||
@@ -155,16 +217,16 @@ class PersonInfoManager:
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
# 你们的英文注释是何意味?
|
||||
|
||||
|
||||
# 检查并修复关键字段为None的情况喵
|
||||
if final_data.get("user_id") is None:
|
||||
logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["user_id"] = "unknown"
|
||||
|
||||
|
||||
if final_data.get("platform") is None:
|
||||
logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["platform"] = "unknown"
|
||||
|
||||
|
||||
# 这里的目的是为了防止在识别出错的情况下有一个最小回退,不只是针对@消息识别成视频后的报错问题
|
||||
|
||||
# Serialize JSON fields
|
||||
@@ -215,12 +277,12 @@ class PersonInfoManager:
|
||||
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
|
||||
|
||||
# 检查并修复关键字段为None的情况
|
||||
if final_data.get("user_id") is None:
|
||||
logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["user_id"] = "unknown"
|
||||
|
||||
|
||||
if final_data.get("platform") is None:
|
||||
logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["platform"] = "unknown"
|
||||
@@ -315,12 +377,12 @@ class PersonInfoManager:
|
||||
creation_data["platform"] = data["platform"]
|
||||
if data and "user_id" in data:
|
||||
creation_data["user_id"] = data["user_id"]
|
||||
|
||||
|
||||
# 额外检查关键字段,如果为None则使用默认值
|
||||
if creation_data.get("user_id") is None:
|
||||
logger.warning(f"创建用户时user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
creation_data["user_id"] = "unknown"
|
||||
|
||||
|
||||
if creation_data.get("platform") is None:
|
||||
logger.warning(f"创建用户时platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
creation_data["platform"] = "unknown"
|
||||
|
||||
@@ -94,90 +94,144 @@ class RelationshipFetcher:
|
||||
if not self.info_fetched_cache[person_id]:
|
||||
del self.info_fetched_cache[person_id]
|
||||
|
||||
async def build_relation_info(self, person_id, points_num=3):
|
||||
async def build_relation_info(self, person_id, points_num=5):
|
||||
"""构建详细的人物关系信息,包含从数据库中查询的丰富关系描述"""
|
||||
# 清理过期的信息缓存
|
||||
self._cleanup_expired_cache()
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_info = await person_info_manager.get_values(
|
||||
person_id, ["person_name", "short_impression", "nickname", "platform", "points"]
|
||||
)
|
||||
person_name = person_info.get("person_name")
|
||||
short_impression = person_info.get("short_impression")
|
||||
nickname_str = person_info.get("nickname")
|
||||
platform = person_info.get("platform")
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
short_impression = await person_info_manager.get_value(person_id, "short_impression")
|
||||
full_impression = await person_info_manager.get_value(person_id, "impression")
|
||||
attitude = await person_info_manager.get_value(person_id, "attitude") or 50
|
||||
|
||||
if person_name == nickname_str and not short_impression:
|
||||
return ""
|
||||
nickname_str = await person_info_manager.get_value(person_id, "nickname")
|
||||
platform = await person_info_manager.get_value(person_id, "platform")
|
||||
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
|
||||
know_since = await person_info_manager.get_value(person_id, "know_since")
|
||||
last_know = await person_info_manager.get_value(person_id, "last_know")
|
||||
|
||||
current_points = person_info.get("points")
|
||||
if isinstance(current_points, str):
|
||||
current_points = orjson.loads(current_points)
|
||||
# 如果用户没有基本信息,返回默认描述
|
||||
if person_name == nickname_str and not short_impression and not full_impression:
|
||||
return f"你完全不认识{person_name},这是你们第一次交流。"
|
||||
|
||||
# 获取用户特征点
|
||||
current_points = await person_info_manager.get_value(person_id, "points") or []
|
||||
forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or []
|
||||
|
||||
# 按时间排序并选择最有代表性的特征点
|
||||
all_points = current_points + forgotten_points
|
||||
if all_points:
|
||||
# 按权重和时效性综合排序
|
||||
all_points.sort(
|
||||
key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True
|
||||
)
|
||||
selected_points = all_points[:points_num]
|
||||
points_text = "\n".join([f"- {point[0]}({point[2]})" for point in selected_points if len(point) > 2])
|
||||
else:
|
||||
current_points = current_points or []
|
||||
points_text = ""
|
||||
|
||||
# 按时间排序forgotten_points
|
||||
current_points.sort(key=lambda x: x[2])
|
||||
# 按权重加权随机抽取最多3个不重复的points,point[1]的值在1-10之间,权重越高被抽到概率越大
|
||||
if len(current_points) > points_num:
|
||||
# point[1] 取值范围1-10,直接作为权重
|
||||
weights = [max(1, min(10, int(point[1]))) for point in current_points]
|
||||
# 使用加权采样不放回,保证不重复
|
||||
indices = list(range(len(current_points)))
|
||||
points = []
|
||||
for _ in range(points_num):
|
||||
if not indices:
|
||||
break
|
||||
sub_weights = [weights[i] for i in indices]
|
||||
chosen_idx = random.choices(indices, weights=sub_weights, k=1)[0]
|
||||
points.append(current_points[chosen_idx])
|
||||
indices.remove(chosen_idx)
|
||||
# 构建详细的关系描述
|
||||
relation_parts = []
|
||||
|
||||
# 1. 基本信息
|
||||
if nickname_str and person_name != nickname_str:
|
||||
relation_parts.append(f"用户{person_name}在{platform}平台的昵称是{nickname_str}")
|
||||
|
||||
# 2. 认识时间和频率
|
||||
if know_since:
|
||||
from datetime import datetime
|
||||
|
||||
know_time = datetime.fromtimestamp(know_since).strftime("%Y年%m月%d日")
|
||||
relation_parts.append(f"你从{know_time}开始认识{person_name}")
|
||||
|
||||
if know_times > 0:
|
||||
relation_parts.append(f"你们已经交流过{int(know_times)}次")
|
||||
|
||||
if last_know:
|
||||
from datetime import datetime
|
||||
|
||||
last_time = datetime.fromtimestamp(last_know).strftime("%m月%d日")
|
||||
relation_parts.append(f"最近一次交流是在{last_time}")
|
||||
|
||||
# 3. 态度和印象
|
||||
attitude_desc = self._get_attitude_description(attitude)
|
||||
relation_parts.append(f"你对{person_name}的态度是{attitude_desc}")
|
||||
|
||||
if short_impression:
|
||||
relation_parts.append(f"你对ta的总体印象:{short_impression}")
|
||||
|
||||
if full_impression:
|
||||
relation_parts.append(f"更详细的了解:{full_impression}")
|
||||
|
||||
# 4. 特征点和记忆
|
||||
if points_text:
|
||||
relation_parts.append(f"你记得关于{person_name}的一些事情:\n{points_text}")
|
||||
|
||||
# 5. 从UserRelationships表获取额外关系信息
|
||||
try:
|
||||
from src.common.database.sqlalchemy_database_api import db_query
|
||||
from src.common.database.sqlalchemy_models import UserRelationships
|
||||
|
||||
# 查询用户关系数据
|
||||
relationships = await db_query(
|
||||
UserRelationships,
|
||||
filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))],
|
||||
limit=1,
|
||||
)
|
||||
|
||||
if relationships:
|
||||
rel_data = relationships[0]
|
||||
if rel_data.relationship_text:
|
||||
relation_parts.append(f"关系记录:{rel_data.relationship_text}")
|
||||
if rel_data.relationship_score:
|
||||
score_desc = self._get_relationship_score_description(rel_data.relationship_score)
|
||||
relation_parts.append(f"关系亲密程度:{score_desc}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"查询UserRelationships表失败: {e}")
|
||||
|
||||
# 构建最终的关系信息字符串
|
||||
if relation_parts:
|
||||
relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join(
|
||||
[f"• {part}" for part in relation_parts]
|
||||
)
|
||||
else:
|
||||
points = current_points
|
||||
|
||||
# 构建points文本
|
||||
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
||||
|
||||
nickname_str = ""
|
||||
if person_name != nickname_str:
|
||||
nickname_str = f"(ta在{platform}上的昵称是{nickname_str})"
|
||||
|
||||
relation_info = ""
|
||||
|
||||
if short_impression and relation_info:
|
||||
if points_text:
|
||||
relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}。你还记得ta最近做的事:{points_text}"
|
||||
else:
|
||||
relation_info = (
|
||||
f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}"
|
||||
)
|
||||
elif short_impression:
|
||||
if points_text:
|
||||
relation_info = (
|
||||
f"你对{person_name}的印象是{nickname_str}:{short_impression}。你还记得ta最近做的事:{points_text}"
|
||||
)
|
||||
else:
|
||||
relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}"
|
||||
elif relation_info:
|
||||
if points_text:
|
||||
relation_info = (
|
||||
f"你对{person_name}的了解{nickname_str}:{relation_info}。你还记得ta最近做的事:{points_text}"
|
||||
)
|
||||
else:
|
||||
relation_info = f"你对{person_name}的了解{nickname_str}:{relation_info}"
|
||||
elif points_text:
|
||||
relation_info = f"你记得{person_name}{nickname_str}最近做的事:{points_text}"
|
||||
else:
|
||||
relation_info = ""
|
||||
relation_info = f"你对{person_name}了解不多,这是比较初步的交流。"
|
||||
|
||||
return relation_info
|
||||
|
||||
def _get_attitude_description(self, attitude: int) -> str:
|
||||
"""根据态度分数返回描述性文字"""
|
||||
if attitude >= 80:
|
||||
return "非常喜欢和欣赏"
|
||||
elif attitude >= 60:
|
||||
return "比较有好感"
|
||||
elif attitude >= 40:
|
||||
return "中立态度"
|
||||
elif attitude >= 20:
|
||||
return "有些反感"
|
||||
else:
|
||||
return "非常厌恶"
|
||||
|
||||
def _get_relationship_score_description(self, score: float) -> str:
|
||||
"""根据关系分数返回描述性文字"""
|
||||
if score >= 0.8:
|
||||
return "非常亲密的好友"
|
||||
elif score >= 0.6:
|
||||
return "关系不错的朋友"
|
||||
elif score >= 0.4:
|
||||
return "普通熟人"
|
||||
elif score >= 0.2:
|
||||
return "认识但不熟悉"
|
||||
else:
|
||||
return "陌生人"
|
||||
|
||||
async def _build_fetch_query(self, person_id, target_message, chat_history):
|
||||
nickname_str = ",".join(global_config.bot.alias_names)
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
||||
person_name: str = person_info.get("person_name") # type: ignore
|
||||
person_name: str = await person_info_manager.get_value(person_id, "person_name") # type: ignore
|
||||
|
||||
info_cache_block = self._build_info_cache_block()
|
||||
|
||||
@@ -259,8 +313,7 @@ class RelationshipFetcher:
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
# 首先检查 info_list 缓存
|
||||
person_info = await person_info_manager.get_values(person_id, ["info_list"])
|
||||
info_list = person_info.get("info_list") or []
|
||||
info_list = await person_info_manager.get_value(person_id, "info_list") or []
|
||||
cached_info = None
|
||||
|
||||
# 查找对应的 info_type
|
||||
@@ -287,9 +340,8 @@ class RelationshipFetcher:
|
||||
|
||||
# 如果缓存中没有,尝试从用户档案中提取
|
||||
try:
|
||||
person_info = await person_info_manager.get_values(person_id, ["impression", "points"])
|
||||
person_impression = person_info.get("impression")
|
||||
points = person_info.get("points")
|
||||
person_impression = await person_info_manager.get_value(person_id, "impression")
|
||||
points = await person_info_manager.get_value(person_id, "points")
|
||||
|
||||
# 构建印象信息块
|
||||
if person_impression:
|
||||
@@ -381,8 +433,7 @@ class RelationshipFetcher:
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
# 获取现有的 info_list
|
||||
person_info = await person_info_manager.get_values(person_id, ["info_list"])
|
||||
info_list = person_info.get("info_list") or []
|
||||
info_list = await person_info_manager.get_value(person_id, "info_list") or []
|
||||
|
||||
# 查找是否已存在相同 info_type 的记录
|
||||
found_index = -1
|
||||
|
||||
@@ -121,6 +121,13 @@ async def generate_reply(
|
||||
if not extra_info and action_data:
|
||||
extra_info = action_data.get("extra_info", "")
|
||||
|
||||
# 如果action_data中有thinking,添加到extra_info中
|
||||
if action_data and (thinking := action_data.get("thinking")):
|
||||
if extra_info:
|
||||
extra_info += f"\n\n思考过程:{thinking}"
|
||||
else:
|
||||
extra_info = f"思考过程:{thinking}"
|
||||
|
||||
# 调用回复器生成回复
|
||||
success, llm_response_dict, prompt = await replyer.generate_reply_with_context(
|
||||
reply_to=reply_to,
|
||||
|
||||
@@ -80,7 +80,7 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
|
||||
|
||||
message_info = {
|
||||
"platform": message_dict.get("chat_info_platform", ""),
|
||||
"message_id": message_dict.get("message_id"),
|
||||
"message_id": message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id"),
|
||||
"time": message_dict.get("time"),
|
||||
"group_info": group_info,
|
||||
"user_info": user_info,
|
||||
@@ -89,15 +89,16 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
|
||||
"template_info": template_info,
|
||||
}
|
||||
|
||||
message_dict = {
|
||||
new_message_dict = {
|
||||
"message_info": message_info,
|
||||
"raw_message": message_dict.get("processed_plain_text"),
|
||||
"processed_plain_text": message_dict.get("processed_plain_text"),
|
||||
}
|
||||
|
||||
message_recv = MessageRecv(message_dict)
|
||||
message_recv = MessageRecv(new_message_dict)
|
||||
|
||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
|
||||
logger.info(message_recv)
|
||||
return message_recv
|
||||
|
||||
|
||||
@@ -246,7 +247,7 @@ async def text_to_stream(
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
reply_to_message: Optional[Dict[str, Any]] = None,
|
||||
set_reply: bool = False,
|
||||
set_reply: bool = True,
|
||||
storage_message: bool = True,
|
||||
) -> bool:
|
||||
"""向指定流发送文本消息
|
||||
@@ -275,7 +276,7 @@ async def text_to_stream(
|
||||
|
||||
|
||||
async def emoji_to_stream(
|
||||
emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False
|
||||
emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = True
|
||||
) -> bool:
|
||||
"""向指定流发送表情包
|
||||
|
||||
@@ -293,7 +294,7 @@ async def emoji_to_stream(
|
||||
|
||||
|
||||
async def image_to_stream(
|
||||
image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False
|
||||
image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = True
|
||||
) -> bool:
|
||||
"""向指定流发送图片
|
||||
|
||||
@@ -315,7 +316,7 @@ async def command_to_stream(
|
||||
stream_id: str,
|
||||
storage_message: bool = True,
|
||||
display_message: str = "",
|
||||
set_reply: bool = False,
|
||||
set_reply: bool = True,
|
||||
) -> bool:
|
||||
"""向指定流发送命令
|
||||
|
||||
@@ -340,7 +341,7 @@ async def custom_to_stream(
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
reply_to_message: Optional[Dict[str, Any]] = None,
|
||||
set_reply: bool = False,
|
||||
set_reply: bool = True,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
|
||||
@@ -93,7 +93,6 @@ class BaseAction(ABC):
|
||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||||
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
||||
# =============================================================================
|
||||
@@ -398,6 +397,7 @@ class BaseAction(ABC):
|
||||
try:
|
||||
# 1. 从注册中心获取Action类
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
action_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
|
||||
if not action_class:
|
||||
logger.error(f"{log_prefix} 未找到Action: {action_name}")
|
||||
@@ -406,7 +406,7 @@ class BaseAction(ABC):
|
||||
# 2. 准备实例化参数
|
||||
# 复用当前Action的大部分上下文信息
|
||||
called_action_data = action_data if action_data is not None else self.action_data
|
||||
|
||||
|
||||
component_info = component_registry.get_component_info(action_name, ComponentType.ACTION)
|
||||
if not component_info:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
||||
|
||||
55
src/plugin_system/base/base_chatter.py
Normal file
55
src/plugin_system/base/base_chatter.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from .component_types import ChatType
|
||||
from src.plugin_system.base.component_types import ChatterInfo, ComponentType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner
|
||||
|
||||
class BaseChatter(ABC):
|
||||
chatter_name: str = ""
|
||||
"""Chatter组件的名称"""
|
||||
chatter_description: str = ""
|
||||
"""Chatter组件的描述"""
|
||||
chat_types: List[ChatType] = [ChatType.PRIVATE, ChatType.GROUP]
|
||||
|
||||
def __init__(self, stream_id: str, action_manager: 'ChatterActionManager'):
|
||||
"""
|
||||
初始化聊天处理器
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
action_manager: 动作管理器
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.action_manager = action_manager
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, context: StreamContext) -> dict:
|
||||
"""
|
||||
执行聊天处理流程
|
||||
|
||||
Args:
|
||||
context: StreamContext对象,包含聊天流的所有消息信息
|
||||
|
||||
Returns:
|
||||
处理结果字典
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_chatter_info(cls) -> "ChatterInfo":
|
||||
"""从类属性生成ChatterInfo
|
||||
Returns:
|
||||
ChatterInfo对象
|
||||
"""
|
||||
|
||||
return ChatterInfo(
|
||||
name=cls.chatter_name,
|
||||
description=cls.chatter_description or "No description provided.",
|
||||
chat_type_allow=cls.chat_types[0],
|
||||
component_type=ComponentType.CHATTER,
|
||||
)
|
||||
|
||||
@@ -73,7 +73,7 @@ class BaseCommand(ABC):
|
||||
return True
|
||||
|
||||
# 检查是否为群聊消息
|
||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
||||
is_group = self.message.message_info.group_info
|
||||
|
||||
if self.chat_type_allow == ChatType.GROUP and is_group:
|
||||
return True
|
||||
|
||||
@@ -98,7 +98,7 @@ class BaseEventHandler(ABC):
|
||||
weight=cls.weight,
|
||||
intercept_message=cls.intercept_message,
|
||||
)
|
||||
|
||||
|
||||
def set_plugin_name(self, plugin_name: str) -> None:
|
||||
"""设置插件名称
|
||||
|
||||
@@ -107,9 +107,9 @@ class BaseEventHandler(ABC):
|
||||
"""
|
||||
self.plugin_name = plugin_name
|
||||
|
||||
def set_plugin_config(self,plugin_config) -> None:
|
||||
def set_plugin_config(self, plugin_config) -> None:
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ class ComponentType(Enum):
|
||||
TOOL = "tool" # 工具组件
|
||||
SCHEDULER = "scheduler" # 定时任务组件(预留)
|
||||
EVENT_HANDLER = "event_handler" # 事件处理组件
|
||||
CHATTER = "chatter" # 聊天处理器组件
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
@@ -39,8 +40,8 @@ class ActionActivationType(Enum):
|
||||
# 聊天模式枚举
|
||||
class ChatMode(Enum):
|
||||
"""聊天模式枚举"""
|
||||
|
||||
FOCUS = "focus" # Focus聊天模式
|
||||
|
||||
FOCUS = "focus" # 专注模式
|
||||
NORMAL = "normal" # Normal聊天模式
|
||||
PROACTIVE = "proactive" # 主动思考模式
|
||||
PRIORITY = "priority" # 优先级聊天模式
|
||||
@@ -54,8 +55,8 @@ class ChatMode(Enum):
|
||||
class ChatType(Enum):
|
||||
"""聊天类型枚举,用于限制插件在不同聊天环境中的使用"""
|
||||
|
||||
GROUP = "group" # 仅群聊可用
|
||||
PRIVATE = "private" # 仅私聊可用
|
||||
GROUP = "group" # 仅群聊可用
|
||||
ALL = "all" # 群聊和私聊都可用
|
||||
|
||||
def __str__(self):
|
||||
@@ -69,7 +70,7 @@ class EventType(Enum):
|
||||
"""
|
||||
|
||||
ON_START = "on_start" # 启动事件,用于调用按时任务
|
||||
ON_STOP ="on_stop"
|
||||
ON_STOP = "on_stop"
|
||||
ON_MESSAGE = "on_message"
|
||||
ON_PLAN = "on_plan"
|
||||
POST_LLM = "post_llm"
|
||||
@@ -210,6 +211,17 @@ class EventHandlerInfo(ComponentInfo):
|
||||
self.component_type = ComponentType.EVENT_HANDLER
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatterInfo(ComponentInfo):
|
||||
"""聊天处理器组件信息"""
|
||||
|
||||
chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.CHATTER
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventInfo(ComponentInfo):
|
||||
"""事件组件信息"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
@@ -11,14 +11,17 @@ from src.plugin_system.base.component_types import (
|
||||
CommandInfo,
|
||||
PlusCommandInfo,
|
||||
EventHandlerInfo,
|
||||
ChatterInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
)
|
||||
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
|
||||
@@ -31,42 +34,45 @@ class ComponentRegistry:
|
||||
|
||||
def __init__(self):
|
||||
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
|
||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
||||
self._components: Dict[str, ComponentInfo] = {}
|
||||
self._components: Dict[str, 'ComponentInfo'] = {}
|
||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
||||
self._components_by_type: Dict['ComponentType', Dict[str, 'ComponentInfo']] = {types: {} for types in ComponentType}
|
||||
"""类型 -> 组件原名称 -> 组件信息"""
|
||||
self._components_classes: Dict[
|
||||
str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler, PlusCommand]]
|
||||
str, Type[Union['BaseCommand', 'BaseAction', 'BaseTool', 'BaseEventHandler', 'PlusCommand', 'BaseChatter']]
|
||||
] = {}
|
||||
"""命名空间式组件名 -> 组件类"""
|
||||
|
||||
# 插件注册表
|
||||
self._plugins: Dict[str, PluginInfo] = {}
|
||||
self._plugins: Dict[str, 'PluginInfo'] = {}
|
||||
"""插件名 -> 插件信息"""
|
||||
|
||||
# Action特定注册表
|
||||
self._action_registry: Dict[str, Type[BaseAction]] = {}
|
||||
self._action_registry: Dict[str, Type['BaseAction']] = {}
|
||||
"""Action注册表 action名 -> action类"""
|
||||
self._default_actions: Dict[str, ActionInfo] = {}
|
||||
self._default_actions: Dict[str, 'ActionInfo'] = {}
|
||||
"""默认动作集,即启用的Action集,用于重置ActionManager状态"""
|
||||
|
||||
# Command特定注册表
|
||||
self._command_registry: Dict[str, Type[BaseCommand]] = {}
|
||||
self._command_registry: Dict[str, Type['BaseCommand']] = {}
|
||||
"""Command类注册表 command名 -> command类"""
|
||||
self._command_patterns: Dict[Pattern, str] = {}
|
||||
"""编译后的正则 -> command名"""
|
||||
|
||||
# 工具特定注册表
|
||||
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
|
||||
self._tool_registry: Dict[str, Type['BaseTool']] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: Dict[str, Type['BaseTool']] = {} # llm可用的工具名 -> 工具类
|
||||
|
||||
# EventHandler特定注册表
|
||||
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
|
||||
self._event_handler_registry: Dict[str, Type['BaseEventHandler']] = {}
|
||||
"""event_handler名 -> event_handler类"""
|
||||
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
||||
self._enabled_event_handlers: Dict[str, Type['BaseEventHandler']] = {}
|
||||
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
||||
|
||||
self._chatter_registry: Dict[str, Type['BaseChatter']] = {}
|
||||
"""chatter名 -> chatter类"""
|
||||
self._enabled_chatter_registry: Dict[str, Type['BaseChatter']] = {}
|
||||
"""启用的chatter名 -> chatter类"""
|
||||
logger.info("组件注册中心初始化完成")
|
||||
|
||||
# == 注册方法 ==
|
||||
@@ -93,7 +99,7 @@ class ComponentRegistry:
|
||||
def register_component(
|
||||
self,
|
||||
component_info: ComponentInfo,
|
||||
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
|
||||
component_class: Type[Union['BaseCommand', 'BaseAction', 'BaseEventHandler', 'BaseTool', 'BaseChatter']],
|
||||
) -> bool:
|
||||
"""注册组件
|
||||
|
||||
@@ -151,6 +157,10 @@ class ComponentRegistry:
|
||||
assert isinstance(component_info, EventHandlerInfo)
|
||||
assert issubclass(component_class, BaseEventHandler)
|
||||
ret = self._register_event_handler_component(component_info, component_class)
|
||||
case ComponentType.CHATTER:
|
||||
assert isinstance(component_info, ChatterInfo)
|
||||
assert issubclass(component_class, BaseChatter)
|
||||
ret = self._register_chatter_component(component_info, component_class)
|
||||
case _:
|
||||
logger.warning(f"未知组件类型: {component_type}")
|
||||
|
||||
@@ -162,7 +172,7 @@ class ComponentRegistry:
|
||||
)
|
||||
return True
|
||||
|
||||
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool:
|
||||
def _register_action_component(self, action_info: 'ActionInfo', action_class: Type['BaseAction']) -> bool:
|
||||
"""注册Action组件到Action特定注册表"""
|
||||
if not (action_name := action_info.name):
|
||||
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
|
||||
@@ -182,7 +192,7 @@ class ComponentRegistry:
|
||||
|
||||
return True
|
||||
|
||||
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool:
|
||||
def _register_command_component(self, command_info: 'CommandInfo', command_class: Type['BaseCommand']) -> bool:
|
||||
"""注册Command组件到Command特定注册表"""
|
||||
if not (command_name := command_info.name):
|
||||
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
|
||||
@@ -209,7 +219,7 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def _register_plus_command_component(
|
||||
self, plus_command_info: PlusCommandInfo, plus_command_class: Type[PlusCommand]
|
||||
self, plus_command_info: 'PlusCommandInfo', plus_command_class: Type['PlusCommand']
|
||||
) -> bool:
|
||||
"""注册PlusCommand组件到特定注册表"""
|
||||
plus_command_name = plus_command_info.name
|
||||
@@ -223,7 +233,7 @@ class ComponentRegistry:
|
||||
|
||||
# 创建专门的PlusCommand注册表(如果还没有)
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
||||
self._plus_command_registry: Dict[str, Type['PlusCommand']] = {}
|
||||
|
||||
plus_command_class.plugin_name = plus_command_info.plugin_name
|
||||
# 设置插件配置
|
||||
@@ -233,7 +243,7 @@ class ComponentRegistry:
|
||||
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
|
||||
return True
|
||||
|
||||
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
|
||||
def _register_tool_component(self, tool_info: 'ToolInfo', tool_class: Type['BaseTool']) -> bool:
|
||||
"""注册Tool组件到Tool特定注册表"""
|
||||
tool_name = tool_info.name
|
||||
|
||||
@@ -249,7 +259,7 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def _register_event_handler_component(
|
||||
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
|
||||
self, handler_info: 'EventHandlerInfo', handler_class: Type['BaseEventHandler']
|
||||
) -> bool:
|
||||
if not (handler_name := handler_info.name):
|
||||
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
|
||||
@@ -271,11 +281,38 @@ class ComponentRegistry:
|
||||
# 使用EventManager进行事件处理器注册
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
return event_manager.register_event_handler(handler_class,self.get_plugin_config(handler_info.plugin_name) or {})
|
||||
return event_manager.register_event_handler(
|
||||
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
|
||||
)
|
||||
|
||||
def _register_chatter_component(self, chatter_info: 'ChatterInfo', chatter_class: Type['BaseChatter']) -> bool:
|
||||
"""注册Chatter组件到Chatter特定注册表"""
|
||||
chatter_name = chatter_info.name
|
||||
|
||||
if not chatter_name:
|
||||
logger.error(f"Chatter组件 {chatter_class.__name__} 必须指定名称")
|
||||
return False
|
||||
if not isinstance(chatter_info, ChatterInfo) or not issubclass(chatter_class, BaseChatter):
|
||||
logger.error(f"注册失败: {chatter_name} 不是有效的Chatter")
|
||||
return False
|
||||
|
||||
chatter_class.plugin_name = chatter_info.plugin_name
|
||||
# 设置插件配置
|
||||
chatter_class.plugin_config = self.get_plugin_config(chatter_info.plugin_name) or {}
|
||||
|
||||
self._chatter_registry[chatter_name] = chatter_class
|
||||
|
||||
if not chatter_info.enabled:
|
||||
logger.warning(f"Chatter组件 {chatter_name} 未启用")
|
||||
return True # 未启用,但是也是注册成功
|
||||
self._enabled_chatter_registry[chatter_name] = chatter_class
|
||||
|
||||
logger.debug(f"已注册Chatter组件: {chatter_name}")
|
||||
return True
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
||||
async def remove_component(self, component_name: str, component_type: 'ComponentType', plugin_name: str) -> bool:
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
if not target_component_class:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
||||
@@ -323,6 +360,12 @@ class ComponentRegistry:
|
||||
except Exception as e:
|
||||
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
|
||||
|
||||
case ComponentType.CHATTER:
|
||||
# 移除Chatter注册
|
||||
if hasattr(self, '_chatter_registry'):
|
||||
self._chatter_registry.pop(component_name, None)
|
||||
logger.debug(f"已移除Chatter组件: {component_name}")
|
||||
|
||||
case _:
|
||||
logger.warning(f"未知的组件类型: {component_type}")
|
||||
return False
|
||||
@@ -441,8 +484,8 @@ class ComponentRegistry:
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
self, component_name: str, component_type: Optional[ComponentType] = None
|
||||
) -> Optional[ComponentInfo]:
|
||||
self, component_name: str, component_type: Optional['ComponentType'] = None
|
||||
) -> Optional['ComponentInfo']:
|
||||
# sourcery skip: class-extract-method
|
||||
"""获取组件信息,支持自动命名空间解析
|
||||
|
||||
@@ -486,8 +529,8 @@ class ComponentRegistry:
|
||||
def get_component_class(
|
||||
self,
|
||||
component_name: str,
|
||||
component_type: Optional[ComponentType] = None,
|
||||
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
|
||||
component_type: Optional['ComponentType'] = None,
|
||||
) -> Optional[Union[Type['BaseCommand'], Type['BaseAction'], Type['BaseEventHandler'], Type['BaseTool']]]:
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
@@ -504,7 +547,7 @@ class ComponentRegistry:
|
||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||
if component_type:
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
return self._components_classes.get(namespaced_name)
|
||||
return self._components_classes.get(namespaced_name) # type: ignore[valid-type]
|
||||
|
||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||
candidates = []
|
||||
@@ -529,22 +572,22 @@ class ComponentRegistry:
|
||||
# 4. 都没找到
|
||||
return None
|
||||
|
||||
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
def get_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']:
|
||||
"""获取指定类型的所有组件"""
|
||||
return self._components_by_type.get(component_type, {}).copy()
|
||||
|
||||
def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
def get_enabled_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']:
|
||||
"""获取指定类型的所有启用组件"""
|
||||
components = self.get_components_by_type(component_type)
|
||||
return {name: info for name, info in components.items() if info.enabled}
|
||||
|
||||
# === Action特定查询方法 ===
|
||||
|
||||
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
||||
def get_action_registry(self) -> Dict[str, Type['BaseAction']]:
|
||||
"""获取Action注册表"""
|
||||
return self._action_registry.copy()
|
||||
|
||||
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||
def get_registered_action_info(self, action_name: str) -> Optional['ActionInfo']:
|
||||
"""获取Action信息"""
|
||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||
return info if isinstance(info, ActionInfo) else None
|
||||
@@ -555,11 +598,11 @@ class ComponentRegistry:
|
||||
|
||||
# === Command特定查询方法 ===
|
||||
|
||||
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
|
||||
def get_command_registry(self) -> Dict[str, Type['BaseCommand']]:
|
||||
"""获取Command注册表"""
|
||||
return self._command_registry.copy()
|
||||
|
||||
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
|
||||
def get_registered_command_info(self, command_name: str) -> Optional['CommandInfo']:
|
||||
"""获取Command信息"""
|
||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||
return info if isinstance(info, CommandInfo) else None
|
||||
@@ -568,7 +611,7 @@ class ComponentRegistry:
|
||||
"""获取Command模式注册表"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type['BaseCommand'], dict, 'CommandInfo']]:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
@@ -595,15 +638,15 @@ class ComponentRegistry:
|
||||
return None
|
||||
|
||||
# === Tool 特定查询方法 ===
|
||||
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
|
||||
def get_tool_registry(self) -> Dict[str, Type['BaseTool']]:
|
||||
"""获取Tool注册表"""
|
||||
return self._tool_registry.copy()
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
|
||||
def get_llm_available_tools(self) -> Dict[str, Type['BaseTool']]:
|
||||
"""获取LLM可用的Tool列表"""
|
||||
return self._llm_available_tools.copy()
|
||||
|
||||
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
|
||||
def get_registered_tool_info(self, tool_name: str) -> Optional['ToolInfo']:
|
||||
"""获取Tool信息
|
||||
|
||||
Args:
|
||||
@@ -616,13 +659,13 @@ class ComponentRegistry:
|
||||
return info if isinstance(info, ToolInfo) else None
|
||||
|
||||
# === PlusCommand 特定查询方法 ===
|
||||
def get_plus_command_registry(self) -> Dict[str, Type[PlusCommand]]:
|
||||
def get_plus_command_registry(self) -> Dict[str, Type['PlusCommand']]:
|
||||
"""获取PlusCommand注册表"""
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
pass
|
||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
||||
return self._plus_command_registry.copy()
|
||||
|
||||
def get_registered_plus_command_info(self, command_name: str) -> Optional[PlusCommandInfo]:
|
||||
def get_registered_plus_command_info(self, command_name: str) -> Optional['PlusCommandInfo']:
|
||||
"""获取PlusCommand信息
|
||||
|
||||
Args:
|
||||
@@ -636,26 +679,44 @@ class ComponentRegistry:
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
|
||||
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||
def get_event_handler_registry(self) -> Dict[str, Type['BaseEventHandler']]:
|
||||
"""获取事件处理器注册表"""
|
||||
return self._event_handler_registry.copy()
|
||||
|
||||
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
|
||||
def get_registered_event_handler_info(self, handler_name: str) -> Optional['EventHandlerInfo']:
|
||||
"""获取事件处理器信息"""
|
||||
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
|
||||
return info if isinstance(info, EventHandlerInfo) else None
|
||||
|
||||
def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||
def get_enabled_event_handlers(self) -> Dict[str, Type['BaseEventHandler']]:
|
||||
"""获取启用的事件处理器"""
|
||||
return self._enabled_event_handlers.copy()
|
||||
|
||||
# === Chatter 特定查询方法 ===
|
||||
def get_chatter_registry(self) -> Dict[str, Type['BaseChatter']]:
|
||||
"""获取Chatter注册表"""
|
||||
if not hasattr(self, '_chatter_registry'):
|
||||
self._chatter_registry: Dict[str, Type[BaseChatter]] = {}
|
||||
return self._chatter_registry.copy()
|
||||
|
||||
def get_enabled_chatter_registry(self) -> Dict[str, Type['BaseChatter']]:
|
||||
"""获取启用的Chatter注册表"""
|
||||
if not hasattr(self, '_enabled_chatter_registry'):
|
||||
self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {}
|
||||
return self._enabled_chatter_registry.copy()
|
||||
|
||||
def get_registered_chatter_info(self, chatter_name: str) -> Optional['ChatterInfo']:
|
||||
"""获取Chatter信息"""
|
||||
info = self.get_component_info(chatter_name, ComponentType.CHATTER)
|
||||
return info if isinstance(info, ChatterInfo) else None
|
||||
|
||||
# === 插件查询方法 ===
|
||||
|
||||
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
|
||||
def get_plugin_info(self, plugin_name: str) -> Optional['PluginInfo']:
|
||||
"""获取插件信息"""
|
||||
return self._plugins.get(plugin_name)
|
||||
|
||||
def get_all_plugins(self) -> Dict[str, PluginInfo]:
|
||||
def get_all_plugins(self) -> Dict[str, 'PluginInfo']:
|
||||
"""获取所有插件"""
|
||||
return self._plugins.copy()
|
||||
|
||||
@@ -663,13 +724,12 @@ class ComponentRegistry:
|
||||
# """获取所有启用的插件"""
|
||||
# return {name: info for name, info in self._plugins.items() if info.enabled}
|
||||
|
||||
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
|
||||
def get_plugin_components(self, plugin_name: str) -> List['ComponentInfo']:
|
||||
"""获取插件的所有组件"""
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
return plugin_info.components if plugin_info else []
|
||||
|
||||
@staticmethod
|
||||
def get_plugin_config(plugin_name: str) -> dict:
|
||||
def get_plugin_config(self, plugin_name: str) -> dict:
|
||||
"""获取插件配置
|
||||
|
||||
Args:
|
||||
@@ -684,19 +744,20 @@ class ComponentRegistry:
|
||||
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
|
||||
if plugin_instance and plugin_instance.config:
|
||||
return plugin_instance.config
|
||||
|
||||
|
||||
# 如果插件实例不存在,尝试从配置文件读取
|
||||
try:
|
||||
import toml
|
||||
|
||||
config_path = Path("config") / "plugins" / plugin_name / "config.toml"
|
||||
if config_path.exists():
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
|
||||
return config_data
|
||||
except Exception as e:
|
||||
logger.debug(f"读取插件 {plugin_name} 配置文件失败: {e}")
|
||||
|
||||
|
||||
return {}
|
||||
|
||||
def get_registry_stats(self) -> Dict[str, Any]:
|
||||
@@ -706,6 +767,7 @@ class ComponentRegistry:
|
||||
tool_components: int = 0
|
||||
events_handlers: int = 0
|
||||
plus_command_components: int = 0
|
||||
chatter_components: int = 0
|
||||
for component in self._components.values():
|
||||
if component.component_type == ComponentType.ACTION:
|
||||
action_components += 1
|
||||
@@ -717,12 +779,15 @@ class ComponentRegistry:
|
||||
events_handlers += 1
|
||||
elif component.component_type == ComponentType.PLUS_COMMAND:
|
||||
plus_command_components += 1
|
||||
elif component.component_type == ComponentType.CHATTER:
|
||||
chatter_components += 1
|
||||
return {
|
||||
"action_components": action_components,
|
||||
"command_components": command_components,
|
||||
"tool_components": tool_components,
|
||||
"event_handlers": events_handlers,
|
||||
"plus_command_components": plus_command_components,
|
||||
"chatter_components": chatter_components,
|
||||
"total_components": len(self._components),
|
||||
"total_plugins": len(self._plugins),
|
||||
"components_by_type": {
|
||||
@@ -730,6 +795,8 @@ class ComponentRegistry:
|
||||
},
|
||||
"enabled_components": len([c for c in self._components.values() if c.enabled]),
|
||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
"enabled_components": len([c for c in self._components.values() if c.enabled]),
|
||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
}
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
@@ -146,7 +146,9 @@ class EventManager:
|
||||
logger.info(f"事件 {event_name} 已禁用")
|
||||
return True
|
||||
|
||||
def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool:
|
||||
def register_event_handler(
|
||||
self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None
|
||||
) -> bool:
|
||||
"""注册事件处理器
|
||||
|
||||
Args:
|
||||
@@ -168,7 +170,7 @@ class EventManager:
|
||||
# 创建事件处理器实例,传递插件配置
|
||||
handler_instance = handler_class()
|
||||
handler_instance.plugin_config = plugin_config
|
||||
if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'):
|
||||
if plugin_config is not None and hasattr(handler_instance, "set_plugin_config"):
|
||||
handler_instance.set_plugin_config(plugin_config)
|
||||
|
||||
self._event_handlers[handler_name] = handler_instance
|
||||
|
||||
@@ -129,9 +129,7 @@ class PluginManager:
|
||||
self._show_plugin_components(plugin_name)
|
||||
|
||||
# 检查并调用 on_plugin_loaded 钩子(如果存在)
|
||||
if hasattr(plugin_instance, "on_plugin_loaded") and callable(
|
||||
plugin_instance.on_plugin_loaded
|
||||
):
|
||||
if hasattr(plugin_instance, "on_plugin_loaded") and callable(plugin_instance.on_plugin_loaded):
|
||||
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
|
||||
try:
|
||||
# 使用 asyncio.create_task 确保它不会阻塞加载流程
|
||||
@@ -380,13 +378,14 @@ class PluginManager:
|
||||
tool_count = stats.get("tool_components", 0)
|
||||
event_handler_count = stats.get("event_handlers", 0)
|
||||
plus_command_count = stats.get("plus_command_components", 0)
|
||||
chatter_count = stats.get("chatter_components", 0)
|
||||
total_components = stats.get("total_components", 0)
|
||||
|
||||
# 📋 显示插件加载总览
|
||||
if total_registered > 0:
|
||||
logger.info("🎉 插件系统加载完成!")
|
||||
logger.info(
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count})"
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count})"
|
||||
)
|
||||
|
||||
# 显示详细的插件列表
|
||||
@@ -442,6 +441,12 @@ class PluginManager:
|
||||
if plus_command_components:
|
||||
plus_command_names = [c.name for c in plus_command_components]
|
||||
logger.info(f" ⚡ PlusCommand组件: {', '.join(plus_command_names)}")
|
||||
chatter_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.CHATTER
|
||||
]
|
||||
if chatter_components:
|
||||
chatter_names = [c.name for c in chatter_components]
|
||||
logger.info(f" 🗣️ Chatter组件: {', '.join(chatter_names)}")
|
||||
if event_handler_components:
|
||||
event_handler_names = [c.name for c in event_handler_components]
|
||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
||||
|
||||
125
src/plugins/built_in/affinity_flow_chatter/README.md
Normal file
125
src/plugins/built_in/affinity_flow_chatter/README.md
Normal file
@@ -0,0 +1,125 @@
|
||||
# 亲和力聊天处理器插件
|
||||
|
||||
## 概述
|
||||
|
||||
这是一个内置的chatter插件,实现了基于亲和力流的智能聊天处理器,具有兴趣度评分和人物关系构建功能。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **智能兴趣度评分**: 自动识别和评估用户兴趣话题
|
||||
- **人物关系系统**: 根据互动历史建立和维持用户关系
|
||||
- **多聊天类型支持**: 支持私聊和群聊场景
|
||||
- **插件化架构**: 完全集成到插件系统中
|
||||
|
||||
## 组件架构
|
||||
|
||||
### BaseChatter (抽象基类)
|
||||
- 位置: `src/plugin_system/base/base_chatter.py`
|
||||
- 功能: 定义所有chatter组件的基础接口
|
||||
- 必须实现的方法: `execute(context: StreamContext) -> dict`
|
||||
|
||||
### ChatterManager (管理器)
|
||||
- 位置: `src/chat/chatter_manager.py`
|
||||
- 功能: 管理和调度所有chatter组件
|
||||
- 特性: 自动从插件系统注册和发现chatter组件
|
||||
|
||||
### AffinityChatter (具体实现)
|
||||
- 位置: `src/plugins/built_in/chatter/affinity_chatter.py`
|
||||
- 功能: 亲和力流聊天处理器的具体实现
|
||||
- 支持的聊天类型: PRIVATE, GROUP
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 基本使用
|
||||
|
||||
```python
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
|
||||
# 初始化
|
||||
action_manager = ChatterActionManager()
|
||||
chatter_manager = ChatterManager(action_manager)
|
||||
|
||||
# 处理消息流
|
||||
result = await chatter_manager.process_stream_context(stream_id, context)
|
||||
```
|
||||
|
||||
### 2. 创建自定义Chatter
|
||||
|
||||
```python
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.component_types import ChatType, ComponentType
|
||||
from src.plugin_system.base.component_types import ChatterInfo
|
||||
|
||||
class CustomChatter(BaseChatter):
|
||||
chat_types = [ChatType.PRIVATE] # 只支持私聊
|
||||
|
||||
async def execute(self, context: StreamContext) -> dict:
|
||||
# 实现你的聊天逻辑
|
||||
return {"success": True, "message": "处理完成"}
|
||||
|
||||
# 在插件中注册
|
||||
async def on_load(self):
|
||||
chatter_info = ChatterInfo(
|
||||
name="custom_chatter",
|
||||
component_type=ComponentType.CHATTER,
|
||||
description="自定义聊天处理器",
|
||||
enabled=True,
|
||||
plugin_name=self.name,
|
||||
chat_type_allow=ChatType.PRIVATE
|
||||
)
|
||||
|
||||
ComponentRegistry.register_component(
|
||||
component_info=chatter_info,
|
||||
component_class=CustomChatter
|
||||
)
|
||||
```
|
||||
|
||||
## 配置
|
||||
|
||||
### 插件配置文件
|
||||
- 位置: `src/plugins/built_in/chatter/_manifest.json`
|
||||
- 包含插件信息和组件配置
|
||||
|
||||
### 聊天类型
|
||||
- `PRIVATE`: 私聊
|
||||
- `GROUP`: 群聊
|
||||
- `ALL`: 所有类型
|
||||
|
||||
## 核心概念
|
||||
|
||||
### 1. 兴趣值系统
|
||||
- 自动识别同类话题
|
||||
- 兴趣值会根据聊天频率增减
|
||||
- 支持新话题的自动学习
|
||||
|
||||
### 2. 人物关系系统
|
||||
- 根据互动质量建立关系分
|
||||
- 不同关系分对应不同的回复风格
|
||||
- 支持情感化的交流
|
||||
|
||||
### 3. 执行流程
|
||||
1. 接收StreamContext
|
||||
2. 使用ActionPlanner进行规划
|
||||
3. 执行相应的Action
|
||||
4. 返回处理结果
|
||||
|
||||
## 扩展开发
|
||||
|
||||
### 添加新的Chatter类型
|
||||
1. 继承BaseChatter类
|
||||
2. 实现execute方法
|
||||
3. 在插件中注册组件
|
||||
4. 配置支持的聊天类型
|
||||
|
||||
### 集成现有功能
|
||||
- 使用ActionPlanner进行动作规划
|
||||
- 通过ActionManager执行动作
|
||||
- 利用现有的记忆和知识系统
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 所有chatter组件必须实现`execute`方法
|
||||
2. 插件注册时需要指定支持的聊天类型
|
||||
3. 组件名称不能包含点号(.)
|
||||
4. 确保在插件卸载时正确清理资源
|
||||
7
src/plugins/built_in/affinity_flow_chatter/__init__.py
Normal file
7
src/plugins/built_in/affinity_flow_chatter/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
亲和力聊天处理器插件
|
||||
"""
|
||||
|
||||
from .plugin import AffinityChatterPlugin
|
||||
|
||||
__all__ = ["AffinityChatterPlugin"]
|
||||
23
src/plugins/built_in/affinity_flow_chatter/_manifest.json
Normal file
23
src/plugins/built_in/affinity_flow_chatter/_manifest.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "affinity_chatter",
|
||||
"display_name": "Affinity Flow Chatter",
|
||||
"description": "Built-in chatter plugin for affinity flow with interest scoring and relationship building",
|
||||
"version": "1.0.0",
|
||||
"author": "MoFox",
|
||||
"plugin_class": "AffinityChatterPlugin",
|
||||
"enabled": true,
|
||||
"is_built_in": true,
|
||||
"components": [
|
||||
{
|
||||
"name": "affinity_chatter",
|
||||
"type": "chatter",
|
||||
"description": "Affinity flow chatter with intelligent interest scoring and relationship building",
|
||||
"enabled": true,
|
||||
"chat_type_allow": ["all"]
|
||||
}
|
||||
],
|
||||
"host_application": { "min_version": "0.8.0" },
|
||||
"keywords": ["chatter", "affinity", "conversation"],
|
||||
"categories": ["Chat", "AI"]
|
||||
}
|
||||
236
src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py
Normal file
236
src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
亲和力聊天处理器
|
||||
基于现有的AffinityFlowChatter重构为插件化组件
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
|
||||
logger = get_logger("affinity_chatter")
|
||||
|
||||
# 定义颜色
|
||||
SOFT_GREEN = "\033[38;5;118m" # 一个更柔和的绿色
|
||||
RESET_COLOR = "\033[0m"
|
||||
|
||||
|
||||
class AffinityChatter(BaseChatter):
|
||||
"""亲和力聊天处理器"""
|
||||
|
||||
chatter_name: str = "AffinityChatter"
|
||||
chatter_description: str = "基于亲和力模型的智能聊天处理器,支持多种聊天类型"
|
||||
chat_types: list[ChatType] = [ChatType.ALL] # 支持所有聊天类型
|
||||
|
||||
def __init__(self, stream_id: str, action_manager: ChatterActionManager):
|
||||
"""
|
||||
初始化亲和力聊天处理器
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
planner: 动作规划器
|
||||
action_manager: 动作管理器
|
||||
"""
|
||||
super().__init__(stream_id, action_manager)
|
||||
self.planner = ChatterActionPlanner(stream_id, action_manager)
|
||||
|
||||
# 处理器统计
|
||||
self.stats = {
|
||||
"messages_processed": 0,
|
||||
"plans_created": 0,
|
||||
"actions_executed": 0,
|
||||
"successful_executions": 0,
|
||||
"failed_executions": 0,
|
||||
}
|
||||
self.last_activity_time = time.time()
|
||||
|
||||
async def execute(self, context: StreamContext) -> dict:
|
||||
"""
|
||||
处理StreamContext对象
|
||||
|
||||
Args:
|
||||
context: StreamContext对象,包含聊天流的所有消息信息
|
||||
|
||||
Returns:
|
||||
处理结果字典
|
||||
"""
|
||||
try:
|
||||
# 触发表达学习
|
||||
learner = expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
asyncio.create_task(learner.trigger_learning_for_chat())
|
||||
|
||||
unread_messages = context.get_unread_messages()
|
||||
|
||||
# 使用增强版规划器处理消息
|
||||
actions, target_message = await self.planner.plan(context=context)
|
||||
self.stats["plans_created"] += 1
|
||||
|
||||
# 执行动作(如果规划器返回了动作)
|
||||
execution_result = {"executed_count": len(actions) if actions else 0}
|
||||
if actions:
|
||||
logger.debug(f"聊天流 {self.stream_id} 生成了 {len(actions)} 个动作")
|
||||
|
||||
# 更新统计
|
||||
self.stats["messages_processed"] += 1
|
||||
self.stats["actions_executed"] += execution_result.get("executed_count", 0)
|
||||
self.stats["successful_executions"] += 1
|
||||
self.last_activity_time = time.time()
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"stream_id": self.stream_id,
|
||||
"plan_created": True,
|
||||
"actions_count": len(actions) if actions else 0,
|
||||
"has_target_message": target_message is not None,
|
||||
"unread_messages_processed": len(unread_messages),
|
||||
**execution_result,
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"亲和力聊天处理器 {self.stream_id} 处理StreamContext时出错: {e}\n{traceback.format_exc()}")
|
||||
self.stats["failed_executions"] += 1
|
||||
self.last_activity_time = time.time()
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"stream_id": self.stream_id,
|
||||
"error_message": str(e),
|
||||
"executed_count": 0,
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取处理器统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
return self.stats.copy()
|
||||
|
||||
def get_planner_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取规划器统计信息
|
||||
|
||||
Returns:
|
||||
规划器统计信息字典
|
||||
"""
|
||||
return self.planner.get_planner_stats()
|
||||
|
||||
def get_interest_scoring_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取兴趣度评分统计信息
|
||||
|
||||
Returns:
|
||||
兴趣度评分统计信息字典
|
||||
"""
|
||||
return self.planner.get_interest_scoring_stats()
|
||||
|
||||
def get_relationship_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户关系统计信息
|
||||
|
||||
Returns:
|
||||
用户关系统计信息字典
|
||||
"""
|
||||
return self.planner.get_relationship_stats()
|
||||
|
||||
def get_current_mood_state(self) -> str:
|
||||
"""
|
||||
获取当前聊天的情绪状态
|
||||
|
||||
Returns:
|
||||
当前情绪状态描述
|
||||
"""
|
||||
return self.planner.get_current_mood_state()
|
||||
|
||||
def get_mood_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取情绪状态统计信息
|
||||
|
||||
Returns:
|
||||
情绪状态统计信息字典
|
||||
"""
|
||||
return self.planner.get_mood_stats()
|
||||
|
||||
def get_user_relationship(self, user_id: str) -> float:
|
||||
"""
|
||||
获取用户关系分
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
用户关系分 (0.0-1.0)
|
||||
"""
|
||||
return self.planner.get_user_relationship(user_id)
|
||||
|
||||
def update_interest_keywords(self, new_keywords: dict):
|
||||
"""
|
||||
更新兴趣关键词
|
||||
|
||||
Args:
|
||||
new_keywords: 新的兴趣关键词字典
|
||||
"""
|
||||
self.planner.update_interest_keywords(new_keywords)
|
||||
logger.info(f"聊天流 {self.stream_id} 已更新兴趣关键词: {list(new_keywords.keys())}")
|
||||
|
||||
def reset_stats(self):
|
||||
"""重置统计信息"""
|
||||
self.stats = {
|
||||
"messages_processed": 0,
|
||||
"plans_created": 0,
|
||||
"actions_executed": 0,
|
||||
"successful_executions": 0,
|
||||
"failed_executions": 0,
|
||||
}
|
||||
|
||||
def is_active(self, max_inactive_minutes: int = 60) -> bool:
|
||||
"""
|
||||
检查处理器是否活跃
|
||||
|
||||
Args:
|
||||
max_inactive_minutes: 最大不活跃分钟数
|
||||
|
||||
Returns:
|
||||
是否活跃
|
||||
"""
|
||||
current_time = time.time()
|
||||
max_inactive_seconds = max_inactive_minutes * 60
|
||||
return (current_time - self.last_activity_time) < max_inactive_seconds
|
||||
|
||||
def get_activity_time(self) -> float:
|
||||
"""
|
||||
获取最后活动时间
|
||||
|
||||
Returns:
|
||||
最后活动时间戳
|
||||
"""
|
||||
return self.last_activity_time
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示"""
|
||||
return f"AffinityChatter(stream_id={self.stream_id}, messages={self.stats['messages_processed']})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""详细字符串表示"""
|
||||
return (
|
||||
f"AffinityChatter(stream_id={self.stream_id}, "
|
||||
f"messages_processed={self.stats['messages_processed']}, "
|
||||
f"plans_created={self.stats['plans_created']}, "
|
||||
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})"
|
||||
)
|
||||
333
src/plugins/built_in/affinity_flow_chatter/interest_scoring.py
Normal file
333
src/plugins/built_in/affinity_flow_chatter/interest_scoring.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
兴趣度评分系统
|
||||
基于多维度评分机制,包括兴趣匹配度、用户关系分、提及度和时间因子
|
||||
现在使用embedding计算智能兴趣匹配
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import InterestScore
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
logger = get_logger("chatter_interest_scoring")
|
||||
|
||||
# 定义颜色
|
||||
SOFT_BLUE = "\033[38;5;67m"
|
||||
RESET_COLOR = "\033[0m"
|
||||
|
||||
|
||||
class ChatterInterestScoringSystem:
|
||||
"""兴趣度评分系统"""
|
||||
|
||||
def __init__(self):
|
||||
# 智能兴趣匹配配置
|
||||
self.use_smart_matching = True
|
||||
|
||||
# 从配置加载评分权重
|
||||
affinity_config = global_config.affinity_flow
|
||||
self.score_weights = {
|
||||
"interest_match": affinity_config.keyword_match_weight, # 兴趣匹配度权重
|
||||
"relationship": affinity_config.relationship_weight, # 关系分权重
|
||||
"mentioned": affinity_config.mention_bot_weight, # 是否提及bot权重
|
||||
}
|
||||
|
||||
# 评分阈值
|
||||
self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值
|
||||
self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值
|
||||
|
||||
# 连续不回复概率提升
|
||||
self.no_reply_count = 0
|
||||
self.max_no_reply_count = affinity_config.max_no_reply_count
|
||||
self.probability_boost_per_no_reply = (
|
||||
affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count
|
||||
) # 每次不回复增加的概率
|
||||
|
||||
# 用户关系数据
|
||||
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
|
||||
|
||||
async def calculate_interest_scores(
|
||||
self, messages: List[DatabaseMessages], bot_nickname: str
|
||||
) -> List[InterestScore]:
|
||||
"""计算消息的兴趣度评分"""
|
||||
user_messages = [msg for msg in messages if str(msg.user_info.user_id) != str(global_config.bot.qq_account)]
|
||||
if not user_messages:
|
||||
return []
|
||||
|
||||
scores = []
|
||||
for _, msg in enumerate(user_messages, 1):
|
||||
score = await self._calculate_single_message_score(msg, bot_nickname)
|
||||
scores.append(score)
|
||||
|
||||
return scores
|
||||
|
||||
async def _calculate_single_message_score(self, message: DatabaseMessages, bot_nickname: str) -> InterestScore:
|
||||
"""计算单条消息的兴趣度评分"""
|
||||
|
||||
keywords = self._extract_keywords_from_database(message)
|
||||
interest_match_score = await self._calculate_interest_match_score(message.processed_plain_text, keywords)
|
||||
relationship_score = self._calculate_relationship_score(message.user_info.user_id)
|
||||
mentioned_score = self._calculate_mentioned_score(message, bot_nickname)
|
||||
|
||||
total_score = (
|
||||
interest_match_score * self.score_weights["interest_match"]
|
||||
+ relationship_score * self.score_weights["relationship"]
|
||||
+ mentioned_score * self.score_weights["mentioned"]
|
||||
)
|
||||
|
||||
details = {
|
||||
"interest_match": f"兴趣匹配: {interest_match_score:.3f}",
|
||||
"relationship": f"关系: {relationship_score:.3f}",
|
||||
"mentioned": f"提及: {mentioned_score:.3f}",
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"消息得分详情: {total_score:.3f} (匹配: {interest_match_score:.2f}, 关系: {relationship_score:.2f}, 提及: {mentioned_score:.2f})"
|
||||
)
|
||||
|
||||
return InterestScore(
|
||||
message_id=message.message_id,
|
||||
total_score=total_score,
|
||||
interest_match_score=interest_match_score,
|
||||
relationship_score=relationship_score,
|
||||
mentioned_score=mentioned_score,
|
||||
details=details,
|
||||
)
|
||||
|
||||
async def _calculate_interest_match_score(self, content: str, keywords: List[str] = None) -> float:
|
||||
"""计算兴趣匹配度 - 使用智能embedding匹配"""
|
||||
if not content:
|
||||
return 0.0
|
||||
|
||||
# 使用智能匹配(embedding)
|
||||
if self.use_smart_matching and bot_interest_manager.is_initialized:
|
||||
return await self._calculate_smart_interest_match(content, keywords)
|
||||
else:
|
||||
# 智能匹配未初始化,返回默认分数
|
||||
return 0.3
|
||||
|
||||
async def _calculate_smart_interest_match(self, content: str, keywords: List[str] = None) -> float:
|
||||
"""使用embedding计算智能兴趣匹配"""
|
||||
try:
|
||||
# 如果没有传入关键词,则提取
|
||||
if not keywords:
|
||||
keywords = self._extract_keywords_from_content(content)
|
||||
|
||||
# 使用机器人兴趣管理器计算匹配度
|
||||
match_result = await bot_interest_manager.calculate_interest_match(content, keywords)
|
||||
|
||||
if match_result:
|
||||
# 返回匹配分数,考虑置信度和匹配标签数量
|
||||
affinity_config = global_config.affinity_flow
|
||||
match_count_bonus = min(
|
||||
len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
|
||||
)
|
||||
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
|
||||
return final_score
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"智能兴趣匹配计算失败: {e}")
|
||||
return 0.0
|
||||
|
||||
def _extract_keywords_from_database(self, message: DatabaseMessages) -> List[str]:
|
||||
"""从数据库消息中提取关键词"""
|
||||
keywords = []
|
||||
|
||||
# 尝试从 key_words 字段提取(存储的是JSON字符串)
|
||||
if message.key_words:
|
||||
try:
|
||||
import orjson
|
||||
|
||||
keywords = orjson.loads(message.key_words)
|
||||
if not isinstance(keywords, list):
|
||||
keywords = []
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
keywords = []
|
||||
|
||||
# 如果没有 keywords,尝试从 key_words_lite 提取
|
||||
if not keywords and message.key_words_lite:
|
||||
try:
|
||||
import orjson
|
||||
|
||||
keywords = orjson.loads(message.key_words_lite)
|
||||
if not isinstance(keywords, list):
|
||||
keywords = []
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
keywords = []
|
||||
|
||||
# 如果还是没有,从消息内容中提取(降级方案)
|
||||
if not keywords:
|
||||
keywords = self._extract_keywords_from_content(message.processed_plain_text)
|
||||
|
||||
return keywords[:15] # 返回前15个关键词
|
||||
|
||||
def _extract_keywords_from_content(self, content: str) -> List[str]:
|
||||
"""从内容中提取关键词(降级方案)"""
|
||||
import re
|
||||
|
||||
# 清理文本
|
||||
content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字
|
||||
words = content.split()
|
||||
|
||||
# 过滤和关键词提取
|
||||
keywords = []
|
||||
for word in words:
|
||||
word = word.strip()
|
||||
if (
|
||||
len(word) >= 2 # 至少2个字符
|
||||
and word.isalnum() # 字母数字
|
||||
and not word.isdigit()
|
||||
): # 不是纯数字
|
||||
keywords.append(word.lower())
|
||||
|
||||
# 去重并限制数量
|
||||
unique_keywords = list(set(keywords))
|
||||
return unique_keywords[:10] # 返回前10个唯一关键词
|
||||
|
||||
def _calculate_relationship_score(self, user_id: str) -> float:
|
||||
"""计算关系分 - 从数据库获取关系分"""
|
||||
# 优先使用内存中的关系分
|
||||
if user_id in self.user_relationships:
|
||||
relationship_value = self.user_relationships[user_id]
|
||||
return min(relationship_value, 1.0)
|
||||
|
||||
# 如果内存中没有,尝试从关系追踪器获取
|
||||
if hasattr(self, "relationship_tracker") and self.relationship_tracker:
|
||||
try:
|
||||
relationship_score = self.relationship_tracker.get_user_relationship_score(user_id)
|
||||
# 同时更新内存缓存
|
||||
self.user_relationships[user_id] = relationship_score
|
||||
return relationship_score
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# 尝试从全局关系追踪器获取
|
||||
try:
|
||||
from .relationship_tracker import ChatterRelationshipTracker
|
||||
|
||||
global_tracker = ChatterRelationshipTracker()
|
||||
if global_tracker:
|
||||
relationship_score = global_tracker.get_user_relationship_score(user_id)
|
||||
# 同时更新内存缓存
|
||||
self.user_relationships[user_id] = relationship_score
|
||||
return relationship_score
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 默认新用户的基础分
|
||||
return global_config.affinity_flow.base_relationship_score
|
||||
|
||||
def _calculate_mentioned_score(self, msg: DatabaseMessages, bot_nickname: str) -> float:
|
||||
"""计算提及分数"""
|
||||
if not msg.processed_plain_text:
|
||||
return 0.0
|
||||
|
||||
# 检查是否被提及
|
||||
bot_aliases = [bot_nickname] + global_config.bot.alias_names
|
||||
is_mentioned = msg.is_mentioned or any(alias in msg.processed_plain_text for alias in bot_aliases if alias)
|
||||
|
||||
# 如果被提及或是私聊,都视为提及了bot
|
||||
if is_mentioned or not hasattr(msg, "chat_info_group_id"):
|
||||
return global_config.affinity_flow.mention_bot_interest_score
|
||||
|
||||
return 0.0
|
||||
|
||||
def should_reply(self, score: InterestScore, message: "DatabaseMessages") -> bool:
|
||||
"""判断是否应该回复"""
|
||||
base_threshold = self.reply_threshold
|
||||
|
||||
# 如果被提及,降低阈值
|
||||
if score.mentioned_score >= global_config.affinity_flow.mention_bot_adjustment_threshold:
|
||||
base_threshold = self.mention_threshold
|
||||
|
||||
# 计算连续不回复的概率提升
|
||||
probability_boost = min(self.no_reply_count * self.probability_boost_per_no_reply, 0.8)
|
||||
effective_threshold = base_threshold - probability_boost
|
||||
|
||||
# 做出决策
|
||||
should_reply = score.total_score >= effective_threshold
|
||||
decision = "回复" if should_reply else "不回复"
|
||||
logger.info(
|
||||
f"{SOFT_BLUE}决策: {decision} (兴趣度: {score.total_score:.3f} / 阈值: {effective_threshold:.3f}){RESET_COLOR}"
|
||||
)
|
||||
|
||||
return should_reply, score.total_score
|
||||
|
||||
def record_reply_action(self, did_reply: bool):
|
||||
"""记录回复动作"""
|
||||
old_count = self.no_reply_count
|
||||
if did_reply:
|
||||
self.no_reply_count = max(0, self.no_reply_count - global_config.affinity_flow.reply_cooldown_reduction)
|
||||
action = "回复"
|
||||
else:
|
||||
self.no_reply_count += 1
|
||||
action = "不回复"
|
||||
|
||||
# 限制最大计数
|
||||
self.no_reply_count = min(self.no_reply_count, self.max_no_reply_count)
|
||||
logger.info(f"动作: {action}, 连续不回复次数: {old_count} -> {self.no_reply_count}")
|
||||
|
||||
def update_user_relationship(self, user_id: str, relationship_change: float):
|
||||
"""更新用户关系"""
|
||||
old_score = self.user_relationships.get(
|
||||
user_id, global_config.affinity_flow.base_relationship_score
|
||||
) # 默认新用户分数
|
||||
new_score = max(0.0, min(1.0, old_score + relationship_change))
|
||||
|
||||
self.user_relationships[user_id] = new_score
|
||||
|
||||
logger.info(f"用户关系: {user_id} | {old_score:.3f} → {new_score:.3f}")
|
||||
|
||||
def get_user_relationship(self, user_id: str) -> float:
|
||||
"""获取用户关系分"""
|
||||
return self.user_relationships.get(user_id, 0.3)
|
||||
|
||||
def get_scoring_stats(self) -> Dict:
|
||||
"""获取评分系统统计"""
|
||||
return {
|
||||
"no_reply_count": self.no_reply_count,
|
||||
"max_no_reply_count": self.max_no_reply_count,
|
||||
"reply_threshold": self.reply_threshold,
|
||||
"mention_threshold": self.mention_threshold,
|
||||
"user_relationships": len(self.user_relationships),
|
||||
}
|
||||
|
||||
def reset_stats(self):
|
||||
"""重置统计信息"""
|
||||
self.no_reply_count = 0
|
||||
logger.info("重置兴趣度评分系统统计")
|
||||
|
||||
async def initialize_smart_interests(self, personality_description: str, personality_id: str = "default"):
|
||||
"""初始化智能兴趣系统"""
|
||||
try:
|
||||
logger.info("开始初始化智能兴趣系统...")
|
||||
logger.info(f"人设ID: {personality_id}, 描述长度: {len(personality_description)}")
|
||||
|
||||
await bot_interest_manager.initialize(personality_description, personality_id)
|
||||
logger.info("智能兴趣系统初始化完成。")
|
||||
|
||||
# 显示初始化后的统计信息
|
||||
bot_interest_manager.get_interest_stats()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化智能兴趣系统失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def get_matching_config(self) -> Dict[str, Any]:
|
||||
"""获取匹配配置信息"""
|
||||
return {
|
||||
"use_smart_matching": self.use_smart_matching,
|
||||
"smart_system_initialized": bot_interest_manager.is_initialized,
|
||||
"smart_system_stats": bot_interest_manager.get_interest_stats()
|
||||
if bot_interest_manager.is_initialized
|
||||
else None,
|
||||
}
|
||||
|
||||
|
||||
# 创建全局兴趣评分系统实例
|
||||
chatter_interest_scoring_system = ChatterInterestScoringSystem()
|
||||
368
src/plugins/built_in/affinity_flow_chatter/plan_executor.py
Normal file
368
src/plugins/built_in/affinity_flow_chatter/plan_executor.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
|
||||
集成用户关系追踪机制,自动记录交互并更新关系。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.info_data_model import Plan, ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plan_executor")
|
||||
|
||||
|
||||
class ChatterPlanExecutor:
|
||||
"""
|
||||
增强版PlanExecutor,集成用户关系追踪机制。
|
||||
|
||||
功能:
|
||||
1. 执行Plan中的所有动作
|
||||
2. 自动记录用户交互并添加到关系追踪
|
||||
3. 分类执行回复动作和其他动作
|
||||
4. 提供完整的执行统计和监控
|
||||
"""
|
||||
|
||||
def __init__(self, action_manager: ChatterActionManager):
|
||||
"""
|
||||
初始化增强版PlanExecutor。
|
||||
|
||||
Args:
|
||||
action_manager (ChatterActionManager): 用于实际执行各种动作的管理器实例。
|
||||
"""
|
||||
self.action_manager = action_manager
|
||||
|
||||
# 执行统计
|
||||
self.execution_stats = {
|
||||
"total_executed": 0,
|
||||
"successful_executions": 0,
|
||||
"failed_executions": 0,
|
||||
"reply_executions": 0,
|
||||
"other_action_executions": 0,
|
||||
"execution_times": [],
|
||||
}
|
||||
|
||||
# 用户关系追踪引用
|
||||
self.relationship_tracker = None
|
||||
|
||||
def set_relationship_tracker(self, relationship_tracker):
|
||||
"""设置关系追踪器"""
|
||||
self.relationship_tracker = relationship_tracker
|
||||
|
||||
async def execute(self, plan: Plan) -> Dict[str, any]:
|
||||
"""
|
||||
遍历并执行Plan对象中`decided_actions`列表里的所有动作。
|
||||
|
||||
Args:
|
||||
plan (Plan): 包含待执行动作列表的Plan对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, any]: 执行结果统计信息
|
||||
"""
|
||||
if not plan.decided_actions:
|
||||
logger.info("没有需要执行的动作。")
|
||||
return {"executed_count": 0, "results": []}
|
||||
|
||||
# 像hfc一样,提前打印将要执行的动作
|
||||
action_types = [action.action_type for action in plan.decided_actions]
|
||||
logger.info(f"选择动作: {', '.join(action_types) if action_types else '无'}")
|
||||
|
||||
execution_results = []
|
||||
reply_actions = []
|
||||
other_actions = []
|
||||
|
||||
# 分类动作:回复动作和其他动作
|
||||
for action_info in plan.decided_actions:
|
||||
if action_info.action_type in ["reply", "proactive_reply"]:
|
||||
reply_actions.append(action_info)
|
||||
else:
|
||||
other_actions.append(action_info)
|
||||
|
||||
# 执行回复动作(优先执行)
|
||||
if reply_actions:
|
||||
reply_result = await self._execute_reply_actions(reply_actions, plan)
|
||||
execution_results.extend(reply_result["results"])
|
||||
self.execution_stats["reply_executions"] += len(reply_actions)
|
||||
|
||||
# 将其他动作放入后台任务执行,避免阻塞主流程
|
||||
if other_actions:
|
||||
asyncio.create_task(self._execute_other_actions(other_actions, plan))
|
||||
logger.info(f"已将 {len(other_actions)} 个其他动作放入后台任务执行。")
|
||||
# 注意:后台任务的结果不会立即计入本次返回的统计数据
|
||||
|
||||
# 更新总体统计
|
||||
self.execution_stats["total_executed"] += len(plan.decided_actions)
|
||||
successful_count = sum(1 for r in execution_results if r["success"])
|
||||
self.execution_stats["successful_executions"] += successful_count
|
||||
self.execution_stats["failed_executions"] += len(execution_results) - successful_count
|
||||
|
||||
logger.info(
|
||||
f"规划执行完成: 总数={len(plan.decided_actions)}, 成功={successful_count}, 失败={len(execution_results) - successful_count}"
|
||||
)
|
||||
|
||||
return {
|
||||
"executed_count": len(plan.decided_actions),
|
||||
"successful_count": successful_count,
|
||||
"failed_count": len(execution_results) - successful_count,
|
||||
"results": execution_results,
|
||||
}
|
||||
|
||||
async def _execute_reply_actions(self, reply_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
|
||||
"""执行回复动作"""
|
||||
results = []
|
||||
|
||||
for action_info in reply_actions:
|
||||
result = await self._execute_single_reply_action(action_info, plan)
|
||||
results.append(result)
|
||||
|
||||
return {"results": results}
|
||||
|
||||
async def _execute_single_reply_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]:
|
||||
"""执行单个回复动作"""
|
||||
start_time = time.time()
|
||||
success = False
|
||||
error_message = ""
|
||||
reply_content = ""
|
||||
|
||||
try:
|
||||
logger.info(f"执行回复动作: {action_info.action_type} (原因: {action_info.reasoning})")
|
||||
|
||||
# 获取用户ID - 兼容对象和字典
|
||||
if hasattr(action_info.action_message, "user_info"):
|
||||
user_id = action_info.action_message.user_info.user_id
|
||||
else:
|
||||
user_id = action_info.action_message.get("user_info", {}).get("user_id")
|
||||
|
||||
if user_id == str(global_config.bot.qq_account):
|
||||
logger.warning("尝试回复自己,跳过此动作以防止死循环。")
|
||||
return {
|
||||
"action_type": action_info.action_type,
|
||||
"success": False,
|
||||
"error_message": "尝试回复自己,跳过此动作以防止死循环。",
|
||||
"execution_time": 0,
|
||||
"reasoning": action_info.reasoning,
|
||||
"reply_content": "",
|
||||
}
|
||||
# 构建回复动作参数
|
||||
action_params = {
|
||||
"chat_id": plan.chat_id,
|
||||
"target_message": action_info.action_message,
|
||||
"reasoning": action_info.reasoning,
|
||||
"action_data": action_info.action_data or {},
|
||||
}
|
||||
|
||||
logger.debug(f"📬 [PlanExecutor] 准备调用 ActionManager,target_message: {action_info.action_message}")
|
||||
|
||||
# 通过动作管理器执行回复
|
||||
reply_content = await self.action_manager.execute_action(
|
||||
action_name=action_info.action_type, **action_params
|
||||
)
|
||||
|
||||
success = True
|
||||
logger.info(f"回复动作 '{action_info.action_type}' 执行成功。")
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
|
||||
|
||||
# 记录用户关系追踪
|
||||
if success and action_info.action_message:
|
||||
await self._track_user_interaction(action_info, plan, reply_content)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
self.execution_stats["execution_times"].append(execution_time)
|
||||
|
||||
return {
|
||||
"action_type": action_info.action_type,
|
||||
"success": success,
|
||||
"error_message": error_message,
|
||||
"execution_time": execution_time,
|
||||
"reasoning": action_info.reasoning,
|
||||
"reply_content": reply_content[:200] + "..." if len(reply_content) > 200 else reply_content,
|
||||
}
|
||||
|
||||
async def _execute_other_actions(self, other_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
|
||||
"""执行其他动作"""
|
||||
results = []
|
||||
|
||||
# 并行执行其他动作
|
||||
tasks = []
|
||||
for action_info in other_actions:
|
||||
task = self._execute_single_other_action(action_info, plan)
|
||||
tasks.append(task)
|
||||
|
||||
if tasks:
|
||||
executed_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for i, result in enumerate(executed_results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"执行动作 {other_actions[i].action_type} 时发生异常: {result}")
|
||||
results.append(
|
||||
{
|
||||
"action_type": other_actions[i].action_type,
|
||||
"success": False,
|
||||
"error_message": str(result),
|
||||
"execution_time": 0,
|
||||
"reasoning": other_actions[i].reasoning,
|
||||
}
|
||||
)
|
||||
else:
|
||||
results.append(result)
|
||||
|
||||
return {"results": results}
|
||||
|
||||
async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]:
|
||||
"""执行单个其他动作"""
|
||||
start_time = time.time()
|
||||
success = False
|
||||
error_message = ""
|
||||
|
||||
try:
|
||||
logger.info(f"执行其他动作: {action_info.action_type} (原因: {action_info.reasoning})")
|
||||
|
||||
action_data = action_info.action_data or {}
|
||||
|
||||
# 针对 poke_user 动作,特殊处理
|
||||
if action_info.action_type == "poke_user":
|
||||
target_message = action_info.action_message
|
||||
if target_message:
|
||||
# 优先直接获取 user_id,这才是最可靠的信息
|
||||
user_id = target_message.get("user_id")
|
||||
if user_id:
|
||||
action_data["user_id"] = user_id
|
||||
logger.info(f"检测到戳一戳动作,目标用户ID: {user_id}")
|
||||
else:
|
||||
# 如果没有 user_id,再尝试用 user_nickname 作为备用方案
|
||||
user_name = target_message.get("user_nickname")
|
||||
if user_name:
|
||||
action_data["user_name"] = user_name
|
||||
logger.info(f"检测到戳一戳动作,目标用户: {user_name}")
|
||||
else:
|
||||
logger.warning("无法从戳一戳消息中获取用户ID或昵称。")
|
||||
|
||||
# 传递原始消息ID以支持引用
|
||||
action_data["target_message_id"] = target_message.get("message_id")
|
||||
|
||||
# 构建动作参数
|
||||
action_params = {
|
||||
"chat_id": plan.chat_id,
|
||||
"target_message": action_info.action_message,
|
||||
"reasoning": action_info.reasoning,
|
||||
"action_data": action_data,
|
||||
}
|
||||
|
||||
# 通过动作管理器执行动作
|
||||
await self.action_manager.execute_action(action_name=action_info.action_type, **action_params)
|
||||
|
||||
success = True
|
||||
logger.info(f"其他动作 '{action_info.action_type}' 执行成功。")
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"执行其他动作失败: {action_info.action_type}, 错误: {error_message}")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
self.execution_stats["execution_times"].append(execution_time)
|
||||
|
||||
return {
|
||||
"action_type": action_info.action_type,
|
||||
"success": success,
|
||||
"error_message": error_message,
|
||||
"execution_time": execution_time,
|
||||
"reasoning": action_info.reasoning,
|
||||
}
|
||||
|
||||
async def _track_user_interaction(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str):
|
||||
"""追踪用户交互 - 集成回复后关系追踪"""
|
||||
try:
|
||||
if not action_info.action_message:
|
||||
return
|
||||
|
||||
# 获取用户信息 - 处理对象和字典两种情况
|
||||
if hasattr(action_info.action_message, "user_info"):
|
||||
# 对象情况
|
||||
user_info = action_info.action_message.user_info
|
||||
user_id = user_info.user_id
|
||||
user_name = user_info.user_nickname or user_id
|
||||
user_message = action_info.action_message.content
|
||||
else:
|
||||
# 字典情况
|
||||
user_info = action_info.action_message.get("user_info", {})
|
||||
user_id = user_info.get("user_id")
|
||||
user_name = user_info.get("user_nickname") or user_id
|
||||
user_message = action_info.action_message.get("content", "")
|
||||
|
||||
if not user_id:
|
||||
logger.debug("跳过追踪:缺少用户ID")
|
||||
return
|
||||
|
||||
# 如果有设置关系追踪器,执行回复后关系追踪
|
||||
if self.relationship_tracker:
|
||||
# 记录基础交互信息(保持向后兼容)
|
||||
self.relationship_tracker.add_interaction(
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
user_message=user_message,
|
||||
bot_reply=reply_content,
|
||||
reply_timestamp=time.time(),
|
||||
)
|
||||
|
||||
# 执行新的回复后关系追踪
|
||||
await self.relationship_tracker.track_reply_relationship(
|
||||
user_id=user_id, user_name=user_name, bot_reply_content=reply_content, reply_timestamp=time.time()
|
||||
)
|
||||
|
||||
logger.debug(f"已执行用户交互追踪: {user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"追踪用户交互时出错: {e}")
|
||||
logger.debug(f"action_message类型: {type(action_info.action_message)}")
|
||||
logger.debug(f"action_message内容: {action_info.action_message}")
|
||||
|
||||
def get_execution_stats(self) -> Dict[str, any]:
|
||||
"""获取执行统计信息"""
|
||||
stats = self.execution_stats.copy()
|
||||
|
||||
# 计算平均执行时间
|
||||
if stats["execution_times"]:
|
||||
avg_time = sum(stats["execution_times"]) / len(stats["execution_times"])
|
||||
stats["average_execution_time"] = avg_time
|
||||
stats["max_execution_time"] = max(stats["execution_times"])
|
||||
stats["min_execution_time"] = min(stats["execution_times"])
|
||||
else:
|
||||
stats["average_execution_time"] = 0
|
||||
stats["max_execution_time"] = 0
|
||||
stats["min_execution_time"] = 0
|
||||
|
||||
# 移除执行时间列表以避免返回过大数据
|
||||
stats.pop("execution_times", None)
|
||||
|
||||
return stats
|
||||
|
||||
def reset_stats(self):
|
||||
"""重置统计信息"""
|
||||
self.execution_stats = {
|
||||
"total_executed": 0,
|
||||
"successful_executions": 0,
|
||||
"failed_executions": 0,
|
||||
"reply_executions": 0,
|
||||
"other_action_executions": 0,
|
||||
"execution_times": [],
|
||||
}
|
||||
|
||||
def get_recent_performance(self, limit: int = 10) -> List[Dict[str, any]]:
|
||||
"""获取最近的执行性能"""
|
||||
recent_times = self.execution_stats["execution_times"][-limit:]
|
||||
if not recent_times:
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"execution_index": i + 1,
|
||||
"execution_time": time_val,
|
||||
"timestamp": time.time() - (len(recent_times) - i) * 60, # 估算时间戳
|
||||
}
|
||||
for i, time_val in enumerate(recent_times)
|
||||
]
|
||||
678
src/plugins/built_in/affinity_flow_chatter/plan_filter.py
Normal file
678
src/plugins/built_in/affinity_flow_chatter/plan_filter.py
Normal file
@@ -0,0 +1,678 @@
|
||||
"""
|
||||
PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import time
|
||||
import traceback
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_actions,
|
||||
build_readable_messages_with_id,
|
||||
get_actions_by_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.prompt import global_prompt_manager
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
logger = get_logger("plan_filter")
|
||||
|
||||
SAKURA_PINK = "\033[38;5;175m"
|
||||
SKY_BLUE = "\033[38;5;117m"
|
||||
RESET_COLOR = "\033[0m"
|
||||
|
||||
|
||||
class ChatterPlanFilter:
|
||||
"""
|
||||
根据 Plan 中的模式和信息,筛选并决定最终的动作。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, available_actions: List[str]):
|
||||
"""
|
||||
初始化动作计划筛选器。
|
||||
|
||||
Args:
|
||||
chat_id (str): 当前聊天的唯一标识符。
|
||||
available_actions (List[str]): 当前可用的动作列表。
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.available_actions = available_actions
|
||||
self.planner_llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="planner")
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
async def filter(self, reply_not_available: bool, plan: Plan) -> Plan:
|
||||
"""
|
||||
执行筛选逻辑,并填充 Plan 对象的 decided_actions 字段。
|
||||
"""
|
||||
try:
|
||||
prompt, used_message_id_list = await self._build_prompt(plan)
|
||||
plan.llm_prompt = prompt
|
||||
|
||||
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
if llm_content:
|
||||
try:
|
||||
parsed_json = orjson.loads(repair_json(llm_content))
|
||||
except orjson.JSONDecodeError:
|
||||
parsed_json = {
|
||||
"thinking": "",
|
||||
"actions": {"action_type": "no_action", "reason": "返回内容无法解析为JSON"},
|
||||
}
|
||||
|
||||
if "reply" in plan.available_actions and reply_not_available:
|
||||
# 如果reply动作不可用,但llm返回的仍然有reply,则改为no_reply
|
||||
if (
|
||||
isinstance(parsed_json, dict)
|
||||
and parsed_json.get("actions", {}).get("action_type", "") == "reply"
|
||||
):
|
||||
parsed_json["actions"]["action_type"] = "no_reply"
|
||||
elif isinstance(parsed_json, list):
|
||||
for item in parsed_json:
|
||||
if isinstance(item, dict) and item.get("actions", {}).get("action_type", "") == "reply":
|
||||
item["actions"]["action_type"] = "no_reply"
|
||||
item["actions"]["reason"] += " (但由于兴趣度不足,reply动作不可用,已改为no_reply)"
|
||||
|
||||
if isinstance(parsed_json, dict):
|
||||
parsed_json = [parsed_json]
|
||||
|
||||
if isinstance(parsed_json, list):
|
||||
final_actions = []
|
||||
reply_action_added = False
|
||||
# 定义回复类动作的集合,方便扩展
|
||||
reply_action_types = {"reply", "proactive_reply"}
|
||||
|
||||
for item in parsed_json:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 预解析 action_type 来进行判断
|
||||
thinking = item.get("thinking", "未提供思考过程")
|
||||
actions_obj = item.get("actions", {})
|
||||
|
||||
# 处理actions字段可能是字典或列表的情况
|
||||
if isinstance(actions_obj, dict):
|
||||
action_type = actions_obj.get("action_type", "no_action")
|
||||
elif isinstance(actions_obj, list) and actions_obj:
|
||||
# 如果是列表,取第一个元素的action_type
|
||||
first_action = actions_obj[0]
|
||||
if isinstance(first_action, dict):
|
||||
action_type = first_action.get("action_type", "no_action")
|
||||
else:
|
||||
action_type = "no_action"
|
||||
else:
|
||||
action_type = "no_action"
|
||||
|
||||
if action_type in reply_action_types:
|
||||
if not reply_action_added:
|
||||
final_actions.extend(
|
||||
await self._parse_single_action(item, used_message_id_list, plan)
|
||||
)
|
||||
reply_action_added = True
|
||||
else:
|
||||
# 非回复类动作直接添加
|
||||
final_actions.extend(await self._parse_single_action(item, used_message_id_list, plan))
|
||||
|
||||
if thinking and thinking != "未提供思考过程":
|
||||
logger.info(f"\n{SAKURA_PINK}思考: {thinking}{RESET_COLOR}\n")
|
||||
plan.decided_actions = self._filter_no_actions(final_actions)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"筛选 Plan 时出错: {e}\n{traceback.format_exc()}")
|
||||
plan.decided_actions = [ActionPlannerInfo(action_type="no_action", reasoning=f"筛选时出错: {e}")]
|
||||
|
||||
# 在返回最终计划前,打印将要执行的动作
|
||||
action_types = [action.action_type for action in plan.decided_actions]
|
||||
logger.info(f"选择动作: [{SKY_BLUE}{', '.join(action_types) if action_types else '无'}{RESET_COLOR}]")
|
||||
|
||||
return plan
|
||||
|
||||
async def _build_prompt(self, plan: Plan) -> tuple[str, list]:
|
||||
"""
|
||||
根据 Plan 对象构建提示词。
|
||||
"""
|
||||
try:
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
bot_core_personality = global_config.personality.personality_core
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||
|
||||
schedule_block = ""
|
||||
# 优先检查是否被吵醒
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
angry_prompt_addition = ""
|
||||
wakeup_mgr = message_manager.wakeup_manager
|
||||
|
||||
# 双重检查确保愤怒状态不会丢失
|
||||
# 检查1: 直接从 wakeup_manager 获取
|
||||
if wakeup_mgr.is_in_angry_state():
|
||||
angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition()
|
||||
|
||||
# 检查2: 如果上面没获取到,再从 mood_manager 确认
|
||||
if not angry_prompt_addition:
|
||||
chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
if chat_mood_for_check.is_angry_from_wakeup:
|
||||
angry_prompt_addition = global_config.sleep_system.angry_prompt
|
||||
|
||||
if angry_prompt_addition:
|
||||
schedule_block = angry_prompt_addition
|
||||
elif global_config.planning_system.schedule_enable:
|
||||
if current_activity := schedule_manager.get_current_activity():
|
||||
schedule_block = f"你当前正在:{current_activity},但注意它与群聊的聊天无关。"
|
||||
|
||||
mood_block = ""
|
||||
# 如果被吵醒,则心情也是愤怒的,不需要另外的情绪模块
|
||||
if not angry_prompt_addition and global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
mood_block = f"你现在的心情是:{chat_mood.mood_state}"
|
||||
|
||||
if plan.mode == ChatMode.PROACTIVE:
|
||||
long_term_memory_block = await self._get_long_term_memory_context()
|
||||
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
prompt = prompt_template.format(
|
||||
time_block=time_block,
|
||||
identity_block=identity_block,
|
||||
schedule_block=schedule_block,
|
||||
mood_block=mood_block,
|
||||
long_term_memory_block=long_term_memory_block,
|
||||
chat_content_block=chat_content_block or "最近没有聊天内容。",
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
)
|
||||
return prompt, message_id_list
|
||||
|
||||
# 构建已读/未读历史消息
|
||||
read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks(
|
||||
plan
|
||||
)
|
||||
|
||||
# 为了兼容性,保留原有的chat_content_block
|
||||
chat_content_block, _ = build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
mentioned_bonus = ""
|
||||
if global_config.chat.mentioned_bot_inevitable_reply:
|
||||
mentioned_bonus = "\n- 有人提到你"
|
||||
if global_config.chat.at_bot_inevitable_reply:
|
||||
mentioned_bonus = "\n- 有人提到你,或者at你"
|
||||
|
||||
if plan.mode == ChatMode.FOCUS:
|
||||
no_action_block = """
|
||||
动作:no_action
|
||||
动作描述:不选择任何动作
|
||||
{{
|
||||
"action": "no_action",
|
||||
"reason":"不动作的原因"
|
||||
}}
|
||||
|
||||
动作:no_reply
|
||||
动作描述:不进行回复,等待合适的回复时机
|
||||
- 当你刚刚发送了消息,没有人回复时,选择no_reply
|
||||
- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply
|
||||
{{
|
||||
"action": "no_reply",
|
||||
"reason":"不回复的原因"
|
||||
}}
|
||||
"""
|
||||
else: # normal Mode
|
||||
no_action_block = """重要说明:
|
||||
- 'reply' 表示只进行普通聊天回复,不执行任何额外动作
|
||||
- 其他action表示在普通回复的基础上,执行相应的额外动作
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id":"触发action的消息id",
|
||||
"reason":"回复的原因"
|
||||
}}"""
|
||||
|
||||
is_group_chat = plan.chat_type == ChatType.GROUP
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
if not is_group_chat and plan.target_info:
|
||||
chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方"
|
||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||
|
||||
action_options_block = await self._build_action_options(plan.available_actions)
|
||||
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
|
||||
custom_prompt_block = ""
|
||||
if global_config.custom_prompt.planner_custom_prompt_content:
|
||||
custom_prompt_block = global_config.custom_prompt.planner_custom_prompt_content
|
||||
|
||||
users_in_chat_str = "" # TODO: Re-implement user list fetching if needed
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
schedule_block=schedule_block,
|
||||
mood_block=mood_block,
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
read_history_block=read_history_block,
|
||||
unread_history_block=unread_history_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
mentioned_bonus=mentioned_bonus,
|
||||
no_action_block=no_action_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
identity_block=identity_block,
|
||||
custom_prompt_block=custom_prompt_block,
|
||||
bot_name=bot_name,
|
||||
users_in_chat=users_in_chat_str,
|
||||
)
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "构建 Planner Prompt 时出错", []
|
||||
|
||||
async def _build_read_unread_history_blocks(self, plan: Plan) -> tuple[str, str, list]:
|
||||
"""构建已读/未读历史消息块"""
|
||||
try:
|
||||
# 从message_manager获取真实的已读/未读消息
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
from src.chat.utils.utils import assign_message_ids
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
|
||||
# 获取聊天流的上下文
|
||||
stream_context = message_manager.stream_contexts.get(plan.chat_id)
|
||||
|
||||
# 获取真正的已读和未读消息
|
||||
read_messages = stream_context.history_messages # 已读消息存储在history_messages中
|
||||
if not read_messages:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
# 如果内存中没有已读消息(比如刚启动),则从数据库加载最近的上下文
|
||||
fallback_messages_dicts = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size,
|
||||
)
|
||||
# 将字典转换为DatabaseMessages对象
|
||||
read_messages = [DatabaseMessages(**msg_dict) for msg_dict in fallback_messages_dicts]
|
||||
|
||||
unread_messages = stream_context.get_unread_messages() # 获取未读消息
|
||||
|
||||
# 构建已读历史消息块
|
||||
if read_messages:
|
||||
read_content, read_ids = build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in read_messages[-50:]], # 限制数量
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
read_history_block = f"{read_content}"
|
||||
else:
|
||||
read_history_block = "暂无已读历史消息"
|
||||
|
||||
# 构建未读历史消息块(包含兴趣度)
|
||||
if unread_messages:
|
||||
# 扁平化未读消息用于计算兴趣度和格式化
|
||||
flattened_unread = [msg.flatten() for msg in unread_messages]
|
||||
|
||||
# 尝试获取兴趣度评分(返回以真实 message_id 为键的字典)
|
||||
interest_scores = await self._get_interest_scores_for_messages(flattened_unread)
|
||||
|
||||
# 为未读消息分配短 id(保持与 build_readable_messages_with_id 的一致结构)
|
||||
message_id_list = assign_message_ids(flattened_unread)
|
||||
|
||||
unread_lines = []
|
||||
for idx, msg in enumerate(flattened_unread):
|
||||
mapped = message_id_list[idx]
|
||||
synthetic_id = mapped.get("id")
|
||||
original_msg_id = msg.get("message_id") or msg.get("id")
|
||||
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time())))
|
||||
user_nickname = msg.get("user_nickname", "未知用户")
|
||||
msg_content = msg.get("processed_plain_text", "")
|
||||
|
||||
# 不再显示兴趣度,但保留合成ID供模型内部使用
|
||||
# 同时,为了让模型更好地理解上下文,我们显示用户名
|
||||
unread_lines.append(f"<{synthetic_id}> {msg_time} {user_nickname}: {msg_content}")
|
||||
|
||||
unread_history_block = "\n".join(unread_lines)
|
||||
else:
|
||||
unread_history_block = "暂无未读历史消息"
|
||||
|
||||
return read_history_block, unread_history_block, message_id_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建已读/未读历史消息块时出错: {e}")
|
||||
return "构建已读历史消息时出错", "构建未读历史消息时出错", []
|
||||
|
||||
async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]:
|
||||
"""为消息获取兴趣度评分"""
|
||||
interest_scores = {}
|
||||
|
||||
try:
|
||||
from .interest_scoring import chatter_interest_scoring_system
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 使用插件内部的兴趣度评分系统计算评分
|
||||
for msg_dict in messages:
|
||||
try:
|
||||
# 将字典转换为DatabaseMessages对象
|
||||
db_message = DatabaseMessages(
|
||||
message_id=msg_dict.get("message_id", ""),
|
||||
user_info=msg_dict.get("user_info", {}),
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
key_words=msg_dict.get("key_words", "[]"),
|
||||
is_mentioned=msg_dict.get("is_mentioned", False)
|
||||
)
|
||||
|
||||
# 计算消息兴趣度
|
||||
interest_score_obj = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=db_message,
|
||||
bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
interest_score = interest_score_obj.total_score
|
||||
|
||||
# 构建兴趣度字典
|
||||
interest_scores[msg_dict.get("message_id", "")] = interest_score
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"计算消息兴趣度失败: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取兴趣度评分失败: {e}")
|
||||
|
||||
return interest_scores
|
||||
|
||||
async def _parse_single_action(
|
||||
self, action_json: dict, message_id_list: list, plan: Plan
|
||||
) -> List[ActionPlannerInfo]:
|
||||
parsed_actions = []
|
||||
try:
|
||||
# 从新的actions结构中获取动作信息
|
||||
actions_obj = action_json.get("actions", {})
|
||||
|
||||
# 处理actions字段可能是字典或列表的情况
|
||||
actions_to_process = []
|
||||
if isinstance(actions_obj, dict):
|
||||
actions_to_process.append(actions_obj)
|
||||
elif isinstance(actions_obj, list):
|
||||
actions_to_process.extend(actions_obj)
|
||||
|
||||
if not actions_to_process:
|
||||
actions_to_process.append({"action_type": "no_action", "reason": "actions格式错误"})
|
||||
|
||||
for single_action_obj in actions_to_process:
|
||||
if not isinstance(single_action_obj, dict):
|
||||
continue
|
||||
|
||||
action = single_action_obj.get("action_type", "no_action")
|
||||
reasoning = single_action_obj.get("reasoning", "未提供原因") # 兼容旧的reason字段
|
||||
action_data = single_action_obj.get("action_data", {})
|
||||
|
||||
# 为了向后兼容,如果action_data不存在,则从顶层字段获取
|
||||
if not action_data:
|
||||
action_data = {k: v for k, v in single_action_obj.items() if k not in ["action_type", "reason", "reasoning", "thinking"]}
|
||||
|
||||
# 保留原始的thinking字段(如果有)
|
||||
thinking = action_json.get("thinking", "")
|
||||
if thinking and thinking != "未提供思考过程":
|
||||
action_data["thinking"] = thinking
|
||||
|
||||
target_message_obj = None
|
||||
if action not in ["no_action", "no_reply", "do_nothing", "proactive_reply"]:
|
||||
if target_message_id := action_data.get("target_message_id"):
|
||||
target_message_dict = self._find_message_by_id(target_message_id, message_id_list)
|
||||
else:
|
||||
# 如果LLM没有指定target_message_id,进行特殊处理
|
||||
if action == "poke_user":
|
||||
# 对于poke_user,尝试找到触发它的那条戳一戳消息
|
||||
target_message_dict = self._find_poke_notice(message_id_list)
|
||||
if not target_message_dict:
|
||||
# 如果找不到,再使用最新消息作为兜底
|
||||
target_message_dict = self._get_latest_message(message_id_list)
|
||||
else:
|
||||
# 其他动作,默认选择最新的一条消息
|
||||
target_message_dict = self._get_latest_message(message_id_list)
|
||||
|
||||
if target_message_dict:
|
||||
# 直接使用字典作为action_message,避免DatabaseMessages对象创建失败
|
||||
target_message_obj = target_message_dict
|
||||
# 替换action_data中的临时ID为真实ID
|
||||
if "target_message_id" in action_data:
|
||||
real_message_id = target_message_dict.get("message_id") or target_message_dict.get("id")
|
||||
if real_message_id:
|
||||
action_data["target_message_id"] = real_message_id
|
||||
|
||||
# 确保 action_message 中始终有 message_id 字段
|
||||
if "message_id" not in target_message_obj and "id" in target_message_obj:
|
||||
target_message_obj["message_id"] = target_message_obj["id"]
|
||||
else:
|
||||
# 如果找不到目标消息,对于reply动作来说这是必需的,应该记录警告
|
||||
if action == "reply":
|
||||
logger.warning(
|
||||
f"reply动作找不到目标消息,target_message_id: {action_data.get('target_message_id')}"
|
||||
)
|
||||
# 将reply动作改为no_action,避免后续执行时出错
|
||||
action = "no_action"
|
||||
reasoning = f"找不到目标消息进行回复。原始理由: {reasoning}"
|
||||
|
||||
if (
|
||||
action not in ["no_action", "no_reply", "reply", "do_nothing", "proactive_reply"]
|
||||
and action not in plan.available_actions
|
||||
):
|
||||
reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}"
|
||||
action = "no_action"
|
||||
|
||||
parsed_actions.append(
|
||||
ActionPlannerInfo(
|
||||
action_type=action,
|
||||
reasoning=reasoning,
|
||||
action_data=action_data,
|
||||
action_message=target_message_obj,
|
||||
available_actions=plan.available_actions,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"解析单个action时出错: {e}")
|
||||
parsed_actions.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="no_action",
|
||||
reasoning=f"解析action时出错: {e}",
|
||||
)
|
||||
)
|
||||
return parsed_actions
|
||||
|
||||
def _filter_no_actions(self, action_list: List[ActionPlannerInfo]) -> List[ActionPlannerInfo]:
|
||||
non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]]
|
||||
if non_no_actions:
|
||||
return non_no_actions
|
||||
return action_list[:1] if action_list else []
|
||||
|
||||
async def _get_long_term_memory_context(self) -> str:
|
||||
try:
|
||||
now = datetime.now()
|
||||
keywords = ["今天", "日程", "计划"]
|
||||
if 5 <= now.hour < 12:
|
||||
keywords.append("早上")
|
||||
elif 12 <= now.hour < 18:
|
||||
keywords.append("中午")
|
||||
else:
|
||||
keywords.append("晚上")
|
||||
|
||||
retrieved_memories = await hippocampus_manager.get_memory_from_topic(
|
||||
valid_keywords=keywords, max_memory_num=5, max_memory_length=1
|
||||
)
|
||||
|
||||
if not retrieved_memories:
|
||||
return "最近没有什么特别的记忆。"
|
||||
|
||||
memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories]
|
||||
return " ".join(memory_statements)
|
||||
except Exception as e:
|
||||
logger.error(f"获取长期记忆时出错: {e}")
|
||||
return "回忆时出现了一些问题。"
|
||||
|
||||
async def _build_action_options(self, current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||
action_options_block = ""
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
# 构建参数的JSON示例
|
||||
params_json_list = []
|
||||
if action_info.action_parameters:
|
||||
for p_name, p_desc in action_info.action_parameters.items():
|
||||
# 为参数描述添加一个通用示例值
|
||||
if action_name == "set_emoji_like" and p_name == "emoji":
|
||||
# 特殊处理set_emoji_like的emoji参数
|
||||
from plugins.set_emoji_like.qq_emoji_list import qq_face
|
||||
emoji_options = [re.search(r"\[表情:(.+?)\]", name).group(1) for name in qq_face.values() if re.search(r"\[表情:(.+?)\]", name)]
|
||||
example_value = f"<从'{', '.join(emoji_options[:10])}...'中选择一个>"
|
||||
else:
|
||||
example_value = f"<{p_desc}>"
|
||||
params_json_list.append(f' "{p_name}": "{example_value}"')
|
||||
|
||||
# 基础动作信息
|
||||
action_description = action_info.description
|
||||
action_require = "\n".join(f"- {req}" for req in action_info.action_require)
|
||||
|
||||
# 构建完整的JSON使用范例
|
||||
json_example_lines = [
|
||||
" {",
|
||||
f' "action_type": "{action_name}"',
|
||||
]
|
||||
# 将参数列表合并到JSON示例中
|
||||
if params_json_list:
|
||||
# 移除最后一行的逗号
|
||||
json_example_lines.extend([line.rstrip(',') for line in params_json_list])
|
||||
|
||||
json_example_lines.append(' "reason": "<执行该动作的详细原因>"')
|
||||
json_example_lines.append(" }")
|
||||
|
||||
# 使用逗号连接内部元素,除了最后一个
|
||||
json_parts = []
|
||||
for i, line in enumerate(json_example_lines):
|
||||
# "{" 和 "}" 不需要逗号
|
||||
if line.strip() in ["{", "}"]:
|
||||
json_parts.append(line)
|
||||
continue
|
||||
|
||||
# 检查是否是最后一个需要逗号的元素
|
||||
is_last_item = True
|
||||
for next_line in json_example_lines[i+1:]:
|
||||
if next_line.strip() not in ["}"]:
|
||||
is_last_item = False
|
||||
break
|
||||
|
||||
if not is_last_item:
|
||||
json_parts.append(f"{line},")
|
||||
else:
|
||||
json_parts.append(line)
|
||||
|
||||
json_example = "\n".join(json_parts)
|
||||
|
||||
# 使用新的、更详细的action_prompt模板
|
||||
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt_with_example")
|
||||
action_options_block += using_action_prompt.format(
|
||||
action_name=action_name,
|
||||
action_description=action_description,
|
||||
action_require=action_require,
|
||||
json_example=json_example,
|
||||
)
|
||||
return action_options_block
|
||||
|
||||
def _find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
# 兼容多种 message_id 格式:数字、m123、buffered-xxxx
|
||||
# 如果是纯数字,补上 m 前缀以兼容旧格式
|
||||
candidate_ids = {message_id}
|
||||
if message_id.isdigit():
|
||||
candidate_ids.add(f"m{message_id}")
|
||||
|
||||
# 如果是 m 开头且后面是数字,尝试去掉 m 前缀的数字形式
|
||||
if message_id.startswith("m") and message_id[1:].isdigit():
|
||||
candidate_ids.add(message_id[1:])
|
||||
|
||||
# 逐项匹配 message_id_list(每项可能为 {'id':..., 'message':...})
|
||||
for item in message_id_list:
|
||||
# 支持 message_id_list 中直接是字符串/ID 的情形
|
||||
if isinstance(item, str):
|
||||
if item in candidate_ids:
|
||||
# 没有 message 对象,返回None
|
||||
return None
|
||||
continue
|
||||
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
item_id = item.get("id")
|
||||
# 直接匹配分配的短 id
|
||||
if item_id and item_id in candidate_ids:
|
||||
return item.get("message")
|
||||
|
||||
# 有时 message 存储里会有原始的 message_id 字段(如 buffered-xxxx)
|
||||
message_obj = item.get("message")
|
||||
if isinstance(message_obj, dict):
|
||||
orig_mid = message_obj.get("message_id") or message_obj.get("id")
|
||||
if orig_mid and orig_mid in candidate_ids:
|
||||
return message_obj
|
||||
|
||||
# 作为兜底,尝试在 message_id_list 中找到 message.message_id 匹配
|
||||
for item in message_id_list:
|
||||
if isinstance(item, dict) and isinstance(item.get("message"), dict):
|
||||
mid = item["message"].get("message_id") or item["message"].get("id")
|
||||
if mid == message_id:
|
||||
return item["message"]
|
||||
|
||||
return None
|
||||
|
||||
def _get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
if not message_id_list:
|
||||
return None
|
||||
return message_id_list[-1].get("message")
|
||||
|
||||
def _find_poke_notice(self, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
"""在消息列表中寻找戳一戳的通知消息"""
|
||||
for item in reversed(message_id_list):
|
||||
message = item.get("message")
|
||||
if (
|
||||
isinstance(message, dict)
|
||||
and message.get("type") == "notice"
|
||||
and "戳" in message.get("processed_plain_text", "")
|
||||
):
|
||||
return message
|
||||
return None
|
||||
168
src/plugins/built_in/affinity_flow_chatter/plan_generator.py
Normal file
168
src/plugins/built_in/affinity_flow_chatter/plan_generator.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的"原始计划" (Plan)。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import Plan, TargetPersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
|
||||
class ChatterPlanGenerator:
|
||||
"""
|
||||
ChatterPlanGenerator 负责在规划流程的初始阶段收集所有必要信息。
|
||||
|
||||
它会汇总以下信息来构建一个"原始"的 Plan 对象,该对象后续会由 PlanFilter 进行筛选:
|
||||
- 当前聊天信息 (ID, 目标用户)
|
||||
- 当前可用的动作列表
|
||||
- 最近的聊天历史记录
|
||||
|
||||
Attributes:
|
||||
chat_id (str): 当前聊天的唯一标识符。
|
||||
action_manager (ActionManager): 用于获取可用动作列表的管理器。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
初始化 ChatterPlanGenerator。
|
||||
|
||||
Args:
|
||||
chat_id (str): 当前聊天的 ID。
|
||||
"""
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
|
||||
self.chat_id = chat_id
|
||||
# 注意:ChatterActionManager 可能需要根据实际情况初始化
|
||||
self.action_manager = ChatterActionManager()
|
||||
|
||||
async def generate(self, mode: ChatMode) -> Plan:
|
||||
"""
|
||||
收集所有信息,生成并返回一个初始的 Plan 对象。
|
||||
|
||||
这个 Plan 对象包含了决策所需的所有上下文信息。
|
||||
|
||||
Args:
|
||||
mode (ChatMode): 当前的聊天模式。
|
||||
|
||||
Returns:
|
||||
Plan: 包含所有上下文信息的初始计划对象。
|
||||
"""
|
||||
try:
|
||||
# 获取聊天类型和目标信息
|
||||
chat_type, target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
|
||||
# 获取可用动作列表
|
||||
available_actions = await self._get_available_actions(chat_type, mode)
|
||||
|
||||
# 获取聊天历史记录
|
||||
recent_messages = await self._get_recent_messages()
|
||||
|
||||
# 构建计划对象
|
||||
plan = Plan(
|
||||
chat_id=self.chat_id,
|
||||
chat_type=chat_type,
|
||||
mode=mode,
|
||||
target_info=target_info,
|
||||
available_actions=available_actions,
|
||||
chat_history=recent_messages,
|
||||
)
|
||||
|
||||
return plan
|
||||
|
||||
except Exception:
|
||||
# 如果生成失败,返回一个基本的空计划
|
||||
return Plan(
|
||||
chat_id=self.chat_id,
|
||||
mode=mode,
|
||||
target_info=TargetPersonInfo(),
|
||||
available_actions={},
|
||||
chat_history=[],
|
||||
)
|
||||
|
||||
async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> Dict[str, ActionInfo]:
|
||||
"""
|
||||
获取当前可用的动作列表。
|
||||
|
||||
Args:
|
||||
chat_type (ChatType): 聊天类型。
|
||||
mode (ChatMode): 聊天模式。
|
||||
|
||||
Returns:
|
||||
Dict[str, ActionInfo]: 可用动作的字典。
|
||||
"""
|
||||
try:
|
||||
# 从组件注册表获取可用动作
|
||||
available_actions = component_registry.get_enabled_actions()
|
||||
|
||||
# 根据聊天类型和模式筛选动作
|
||||
filtered_actions = {}
|
||||
for action_name, action_info in available_actions.items():
|
||||
# 检查动作是否支持当前聊天类型
|
||||
if chat_type in action_info.chat_types:
|
||||
# 检查动作是否支持当前模式
|
||||
if mode in action_info.chat_modes:
|
||||
filtered_actions[action_name] = action_info
|
||||
|
||||
return filtered_actions
|
||||
|
||||
except Exception:
|
||||
# 如果获取失败,返回空字典
|
||||
return {}
|
||||
|
||||
async def _get_recent_messages(self) -> list[DatabaseMessages]:
|
||||
"""
|
||||
获取最近的聊天历史记录。
|
||||
|
||||
Returns:
|
||||
list[DatabaseMessages]: 最近的聊天消息列表。
|
||||
"""
|
||||
try:
|
||||
# 获取最近的消息记录
|
||||
raw_messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.memory.short_memory_length
|
||||
)
|
||||
|
||||
# 转换为 DatabaseMessages 对象
|
||||
recent_messages = []
|
||||
for msg in raw_messages:
|
||||
try:
|
||||
db_msg = DatabaseMessages(
|
||||
message_id=msg.get("message_id", ""),
|
||||
time=float(msg.get("time", 0)),
|
||||
chat_id=msg.get("chat_id", ""),
|
||||
processed_plain_text=msg.get("processed_plain_text", ""),
|
||||
user_id=msg.get("user_id", ""),
|
||||
user_nickname=msg.get("user_nickname", ""),
|
||||
user_platform=msg.get("user_platform", ""),
|
||||
)
|
||||
recent_messages.append(db_msg)
|
||||
except Exception:
|
||||
# 跳过格式错误的消息
|
||||
continue
|
||||
|
||||
return recent_messages
|
||||
|
||||
except Exception:
|
||||
# 如果获取失败,返回空列表
|
||||
return []
|
||||
|
||||
def get_generator_stats(self) -> Dict:
|
||||
"""
|
||||
获取生成器统计信息。
|
||||
|
||||
Returns:
|
||||
Dict: 统计信息字典。
|
||||
"""
|
||||
return {
|
||||
"chat_id": self.chat_id,
|
||||
"action_count": len(self.action_manager._using_actions)
|
||||
if hasattr(self.action_manager, "_using_actions")
|
||||
else 0,
|
||||
"generation_time": time.time(),
|
||||
}
|
||||
269
src/plugins/built_in/affinity_flow_chatter/planner.py
Normal file
269
src/plugins/built_in/affinity_flow_chatter/planner.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
主规划器入口,负责协调 PlanGenerator, PlanFilter, 和 PlanExecutor。
|
||||
集成兴趣度评分系统和用户关系追踪机制,实现智能化的聊天决策。
|
||||
"""
|
||||
|
||||
from dataclasses import asdict
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.data_models.info_data_model import Plan
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
|
||||
# 导入提示词模块以确保其被初始化
|
||||
from src.plugins.built_in.affinity_flow_chatter import planner_prompts # noqa
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
|
||||
class ChatterActionPlanner:
|
||||
"""
|
||||
增强版ActionPlanner,集成兴趣度评分和用户关系追踪机制。
|
||||
|
||||
核心功能:
|
||||
1. 兴趣度评分系统:根据兴趣匹配度、关系分、提及度、时间因子对消息评分
|
||||
2. 用户关系追踪:自动追踪用户交互并更新关系分
|
||||
3. 智能回复决策:基于兴趣度阈值和连续不回复概率的智能决策
|
||||
4. 完整的规划流程:生成→筛选→执行的完整三阶段流程
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, action_manager: "ChatterActionManager"):
|
||||
"""
|
||||
初始化增强版ActionPlanner。
|
||||
|
||||
Args:
|
||||
chat_id (str): 当前聊天的 ID。
|
||||
action_manager (ChatterActionManager): 一个 ChatterActionManager 实例。
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.action_manager = action_manager
|
||||
self.generator = ChatterPlanGenerator(chat_id)
|
||||
self.executor = ChatterPlanExecutor(action_manager)
|
||||
|
||||
# 使用新的统一兴趣度管理系统
|
||||
|
||||
# 规划器统计
|
||||
self.planner_stats = {
|
||||
"total_plans": 0,
|
||||
"successful_plans": 0,
|
||||
"failed_plans": 0,
|
||||
"replies_generated": 0,
|
||||
"other_actions_executed": 0,
|
||||
}
|
||||
|
||||
async def plan(self, context: "StreamContext" = None) -> Tuple[List[Dict], Optional[Dict]]:
|
||||
"""
|
||||
执行完整的增强版规划流程。
|
||||
|
||||
Args:
|
||||
context (StreamContext): 包含聊天流消息的上下文对象。
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict], Optional[Dict]]: 一个元组,包含:
|
||||
- final_actions_dict (List[Dict]): 最终确定的动作列表(字典格式)。
|
||||
- final_target_message_dict (Optional[Dict]): 最终的目标消息(字典格式)。
|
||||
"""
|
||||
try:
|
||||
self.planner_stats["total_plans"] += 1
|
||||
|
||||
return await self._enhanced_plan_flow(context)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"规划流程出错: {e}")
|
||||
self.planner_stats["failed_plans"] += 1
|
||||
return [], None
|
||||
|
||||
async def _enhanced_plan_flow(self, context: "StreamContext") -> Tuple[List[Dict], Optional[Dict]]:
|
||||
"""执行增强版规划流程"""
|
||||
try:
|
||||
# 在规划前,先进行动作修改
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
action_modifier = ActionModifier(self.action_manager, self.chat_id)
|
||||
await action_modifier.modify_actions()
|
||||
|
||||
# 1. 生成初始 Plan
|
||||
initial_plan = await self.generator.generate(context.chat_mode)
|
||||
|
||||
# 确保Plan中包含所有当前可用的动作
|
||||
initial_plan.available_actions = self.action_manager.get_using_actions()
|
||||
|
||||
unread_messages = context.get_unread_messages() if context else []
|
||||
# 2. 使用新的兴趣度管理系统进行评分
|
||||
score = 0.0
|
||||
should_reply = False
|
||||
reply_not_available = False
|
||||
|
||||
if unread_messages:
|
||||
# 获取用户ID,优先从user_info.user_id获取,其次从user_id属性获取
|
||||
user_id = None
|
||||
first_message = unread_messages[0]
|
||||
user_id = first_message.user_info.user_id
|
||||
|
||||
# 构建计算上下文
|
||||
calc_context = {
|
||||
"stream_id": self.chat_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
# 为每条消息计算兴趣度
|
||||
for message in unread_messages:
|
||||
try:
|
||||
# 使用插件内部的兴趣度评分系统计算
|
||||
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=message,
|
||||
bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
message_interest = interest_score.total_score
|
||||
|
||||
# 更新消息的兴趣度
|
||||
message.interest_value = message_interest
|
||||
|
||||
# 简单的回复决策逻辑:兴趣度超过阈值则回复
|
||||
message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
|
||||
logger.debug(f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}")
|
||||
|
||||
# 更新StreamContext中的消息信息并刷新focus_energy
|
||||
if context:
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
message_manager.update_message(
|
||||
stream_id=self.chat_id,
|
||||
message_id=message.message_id,
|
||||
interest_value=message_interest,
|
||||
should_reply=message.should_reply
|
||||
)
|
||||
|
||||
# 更新数据库中的消息记录
|
||||
try:
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
MessageStorage.update_message_interest_value(message.message_id, message_interest)
|
||||
logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
|
||||
except Exception as e:
|
||||
logger.warning(f"更新数据库消息兴趣度失败: {e}")
|
||||
|
||||
# 记录最高分
|
||||
if message_interest > score:
|
||||
score = message_interest
|
||||
if message.should_reply:
|
||||
should_reply = True
|
||||
else:
|
||||
reply_not_available = True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}")
|
||||
# 设置默认值
|
||||
message.interest_value = 0.0
|
||||
message.should_reply = False
|
||||
|
||||
# 检查兴趣度是否达到非回复动作阈值
|
||||
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
if score < non_reply_action_interest_threshold:
|
||||
logger.info(f"兴趣度 {score:.3f} 低于阈值 {non_reply_action_interest_threshold:.3f},不执行动作")
|
||||
# 直接返回 no_action
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
|
||||
no_action = ActionPlannerInfo(
|
||||
action_type="no_action",
|
||||
reasoning=f"兴趣度评分 {score:.3f} 未达阈值 {non_reply_action_interest_threshold:.3f}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
)
|
||||
filtered_plan = initial_plan
|
||||
filtered_plan.decided_actions = [no_action]
|
||||
else:
|
||||
# 4. 筛选 Plan
|
||||
available_actions = list(initial_plan.available_actions.keys())
|
||||
plan_filter = ChatterPlanFilter(self.chat_id, available_actions)
|
||||
filtered_plan = await plan_filter.filter(reply_not_available, initial_plan)
|
||||
|
||||
# 检查filtered_plan是否有reply动作,用于统计
|
||||
has_reply_action = any(decision.action_type == "reply" for decision in filtered_plan.decided_actions)
|
||||
|
||||
# 5. 使用 PlanExecutor 执行 Plan
|
||||
execution_result = await self.executor.execute(filtered_plan)
|
||||
|
||||
# 6. 根据执行结果更新统计信息
|
||||
self._update_stats_from_execution_result(execution_result)
|
||||
|
||||
# 7. 返回结果
|
||||
return self._build_return_result(filtered_plan)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"增强版规划流程出错: {e}")
|
||||
self.planner_stats["failed_plans"] += 1
|
||||
return [], None
|
||||
|
||||
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
|
||||
"""根据执行结果更新规划器统计"""
|
||||
if not execution_result:
|
||||
return
|
||||
|
||||
successful_count = execution_result.get("successful_count", 0)
|
||||
|
||||
# 更新成功执行计数
|
||||
self.planner_stats["successful_plans"] += successful_count
|
||||
|
||||
# 统计回复动作和其他动作
|
||||
reply_count = 0
|
||||
other_count = 0
|
||||
|
||||
for result in execution_result.get("results", []):
|
||||
action_type = result.get("action_type", "")
|
||||
if action_type in ["reply", "proactive_reply"]:
|
||||
reply_count += 1
|
||||
else:
|
||||
other_count += 1
|
||||
|
||||
self.planner_stats["replies_generated"] += reply_count
|
||||
self.planner_stats["other_actions_executed"] += other_count
|
||||
|
||||
def _build_return_result(self, plan: "Plan") -> Tuple[List[Dict], Optional[Dict]]:
|
||||
"""构建返回结果"""
|
||||
final_actions = plan.decided_actions or []
|
||||
final_target_message = next((act.action_message for act in final_actions if act.action_message), None)
|
||||
|
||||
final_actions_dict = [asdict(act) for act in final_actions]
|
||||
|
||||
if final_target_message:
|
||||
if hasattr(final_target_message, "__dataclass_fields__"):
|
||||
final_target_message_dict = asdict(final_target_message)
|
||||
else:
|
||||
final_target_message_dict = final_target_message
|
||||
else:
|
||||
final_target_message_dict = None
|
||||
|
||||
return final_actions_dict, final_target_message_dict
|
||||
|
||||
def get_planner_stats(self) -> Dict[str, any]:
|
||||
"""获取规划器统计"""
|
||||
return self.planner_stats.copy()
|
||||
|
||||
def get_current_mood_state(self) -> str:
|
||||
"""获取当前聊天的情绪状态"""
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||
return chat_mood.mood_state
|
||||
|
||||
def get_mood_stats(self) -> Dict[str, any]:
|
||||
"""获取情绪状态统计"""
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||
return {
|
||||
"current_mood": chat_mood.mood_state,
|
||||
"is_angry_from_wakeup": chat_mood.is_angry_from_wakeup,
|
||||
"regression_count": chat_mood.regression_count,
|
||||
"last_change_time": chat_mood.last_change_time,
|
||||
}
|
||||
|
||||
|
||||
# 全局兴趣度评分系统实例 - 在 individuality 模块中创建
|
||||
290
src/plugins/built_in/affinity_flow_chatter/planner_prompts.py
Normal file
290
src/plugins/built_in/affinity_flow_chatter/planner_prompts.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
本文件集中管理所有与规划器(Planner)相关的提示词(Prompt)模板。
|
||||
|
||||
通过将提示词与代码逻辑分离,可以更方便地对模型的行为进行迭代和优化,
|
||||
而无需修改核心代码。
|
||||
"""
|
||||
|
||||
from src.chat.utils.prompt import Prompt
|
||||
|
||||
|
||||
def init_prompts():
|
||||
"""
|
||||
初始化并向 Prompt 注册系统注册所有规划器相关的提示词。
|
||||
|
||||
这个函数会在模块加载时自动调用,确保所有提示词在系统启动时都已准备就绪。
|
||||
"""
|
||||
# 核心规划器提示词,用于在接收到新消息时决定如何回应。
|
||||
# 它构建了一个复杂的上下文,包括历史记录、可用动作、角色设定等,
|
||||
# 并要求模型以 JSON 格式输出一个或多个动作组合。
|
||||
Prompt(
|
||||
"""
|
||||
{mood_block}
|
||||
{time_block}
|
||||
{identity_block}
|
||||
|
||||
{users_in_chat}
|
||||
{custom_prompt_block}
|
||||
{chat_context_description},以下是具体的聊天内容。
|
||||
|
||||
## 📜 已读历史消息(仅供参考)
|
||||
{read_history_block}
|
||||
|
||||
## 📬 未读历史消息(动作执行对象)
|
||||
{unread_history_block}
|
||||
|
||||
{moderation_prompt}
|
||||
|
||||
**任务: 构建一个完整的响应**
|
||||
你的任务是根据当前的聊天内容,构建一个完整的、人性化的响应。一个完整的响应由两部分组成:
|
||||
1. **主要动作**: 这是响应的核心,通常是 `reply`(如果有)。
|
||||
2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。
|
||||
|
||||
**决策流程:**
|
||||
1. **重要:已读历史消息仅作为当前聊天情景的参考,帮助你理解对话上下文。**
|
||||
2. **重要:所有动作的执行对象只能是未读历史消息中的消息,不能对已读消息执行动作。**
|
||||
3. 在未读历史消息中,优先对兴趣值高的消息做出动作(兴趣值标注在消息末尾)。
|
||||
4. 首先,决定是否要对未读消息进行 `reply`(如果有)。
|
||||
5. 然后,评估当前的对话气氛和用户情绪,判断是否需要一个**辅助动作**来让你的回应更生动、更符合你的性格。
|
||||
6. 如果需要,选择一个最合适的辅助动作与 `reply`(如果有) 组合。
|
||||
7. 如果用户明确要求了某个动作,请务必优先满足。
|
||||
|
||||
**重要提醒:**
|
||||
- **回复消息时必须遵循对话的流程,不要重复已经说过的话。**
|
||||
- **确保回复与上下文紧密相关,回应要针对用户的消息内容。**
|
||||
- **保持角色设定的一致性,使用符合你性格的语言风格。**
|
||||
- **不要对表情包消息做出回应!**
|
||||
|
||||
**输出格式:**
|
||||
请严格按照以下 JSON 格式输出,包含 `thinking` 和 `actions` 字段:
|
||||
|
||||
**重要概念:将“内心思考”作为思绪流的体现**
|
||||
`thinking` 字段是本次决策的核心。它并非一个简单的“理由”,而是 **一个模拟人类在回应前,头脑中自然浮现的、未经修饰的思绪流**。你需要完全代入 {identity_block} 的角色,将那一刻的想法自然地记录下来。
|
||||
|
||||
**内心思考的要点:**
|
||||
* **自然流露**: 不要使用“决定”、“所以”、“因此”等结论性或汇报式的词语。你的思考应该像日记一样,是给自己看的,充满了不确定性和情绪的自然流动。
|
||||
* **展现过程**: 重点在于展现 **思考的过程**,而不是 **决策的结果**。描述你看到了什么,想到了什么,感受到了什么。
|
||||
* **使用昵称**: 在你的思绪流中,请直接使用用户的昵称来指代他们,而不是`<m1>`, `<m2>`这样的消息ID。
|
||||
* **严禁技术术语**: 严禁在思考中提及任何数字化的度量(如兴趣度、分数)或内部技术术语。请完全使用角色自身的感受和语言来描述思考过程。
|
||||
|
||||
## 可用动作列表
|
||||
{action_options_text}
|
||||
|
||||
```json
|
||||
{{
|
||||
"thinking": "在这里写下你的思绪流...",
|
||||
"actions": [
|
||||
{{
|
||||
"action_type": "动作类型(如:reply, emoji等)",
|
||||
"reasoning": "选择该动作的理由",
|
||||
"action_data": {{
|
||||
"target_message_id": "目标消息ID",
|
||||
"content": "回复内容或其他动作所需数据"
|
||||
}}
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
**强制规则**:
|
||||
- 对于每一个需要目标消息的动作(如`reply`, `poke_user`, `set_emoji_like`),你 **必须** 在`action_data`中提供准确的`target_message_id`,这个ID来源于`## 未读历史消息`中消息前的`<m...>`标签。
|
||||
- 当你选择的动作需要参数时(例如 `set_emoji_like` 需要 `emoji` 参数),你 **必须** 在 `action_data` 中提供所有必需的参数及其对应的值。
|
||||
|
||||
如果没有合适的回复对象或不需要回复,输出空的 actions 数组:
|
||||
```json
|
||||
{{
|
||||
"thinking": "说明为什么不需要回复",
|
||||
"actions": []
|
||||
}}
|
||||
```
|
||||
""",
|
||||
"planner_prompt",
|
||||
)
|
||||
|
||||
# 主动规划器提示词,用于主动场景和前瞻性规划
|
||||
Prompt(
|
||||
"""
|
||||
{mood_block}
|
||||
{time_block}
|
||||
{identity_block}
|
||||
|
||||
{users_in_chat}
|
||||
{custom_prompt_block}
|
||||
{chat_context_description},以下是具体的聊天内容。
|
||||
|
||||
## 📜 已读历史消息(仅供参考)
|
||||
{read_history_block}
|
||||
|
||||
## 📬 未读历史消息(动作执行对象)
|
||||
{unread_history_block}
|
||||
|
||||
{moderation_prompt}
|
||||
|
||||
**任务: 构建一个完整的响应**
|
||||
你的任务是根据当前的聊天内容,构建一个完整的、人性化的响应。一个完整的响应由两部分组成:
|
||||
1. **主要动作**: 这是响应的核心,通常是 `reply`(如果有)。
|
||||
2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。
|
||||
|
||||
**决策流程:**
|
||||
1. **重要:已读历史消息仅作为当前聊天情景的参考,帮助你理解对话上下文。**
|
||||
2. **重要:所有动作的执行对象只能是未读历史消息中的消息,不能对已读消息执行动作。**
|
||||
3. 在未读历史消息中,优先对兴趣值高的消息做出动作(兴趣值标注在消息末尾)。
|
||||
4. 首先,决定是否要对未读消息进行 `reply`(如果有)。
|
||||
5. 然后,评估当前的对话气氛和用户情绪,判断是否需要一个**辅助动作**来让你的回应更生动、更符合你的性格。
|
||||
6. 如果需要,选择一个最合适的辅助动作与 `reply`(如果有) 组合。
|
||||
7. 如果用户明确要求了某个动作,请务必优先满足。
|
||||
|
||||
**动作限制:**
|
||||
- 在私聊中,你只能使用 `reply` 动作。私聊中不允许使用任何其他动作。
|
||||
- 在群聊中,你可以自由选择是否使用辅助动作。
|
||||
|
||||
**重要提醒:**
|
||||
- **回复消息时必须遵循对话的流程,不要重复已经说过的话。**
|
||||
- **确保回复与上下文紧密相关,回应要针对用户的消息内容。**
|
||||
- **保持角色设定的一致性,使用符合你性格的语言风格。**
|
||||
|
||||
**输出格式:**
|
||||
请严格按照以下 JSON 格式输出,包含 `thinking` 和 `actions` 字段:
|
||||
```json
|
||||
{{
|
||||
"thinking": "你的思考过程,分析当前情况并说明为什么选择这些动作",
|
||||
"actions": [
|
||||
{{
|
||||
"action_type": "动作类型(如:reply, emoji等)",
|
||||
"reasoning": "选择该动作的理由",
|
||||
"action_data": {{
|
||||
"target_message_id": "目标消息ID",
|
||||
"content": "回复内容或其他动作所需数据"
|
||||
}}
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
如果没有合适的回复对象或不需要回复,输出空的 actions 数组:
|
||||
```json
|
||||
{{
|
||||
"thinking": "说明为什么不需要回复",
|
||||
"actions": []
|
||||
}}
|
||||
```
|
||||
""",
|
||||
"proactive_planner_prompt",
|
||||
)
|
||||
|
||||
# 轻量级规划器提示词,用于快速决策和简单场景
|
||||
Prompt(
|
||||
"""
|
||||
{identity_block}
|
||||
|
||||
## 当前聊天情景
|
||||
{chat_context_description}
|
||||
|
||||
## 未读消息
|
||||
{unread_history_block}
|
||||
|
||||
**任务:快速决策**
|
||||
请根据当前聊天内容,快速决定是否需要回复。
|
||||
|
||||
**决策规则:**
|
||||
1. 如果有人直接提到你或问你问题,优先回复
|
||||
2. 如果消息内容符合你的兴趣,考虑回复
|
||||
3. 如果只是群聊中的普通聊天且与你无关,可以不回复
|
||||
|
||||
**输出格式:**
|
||||
```json
|
||||
{{
|
||||
"thinking": "简要分析",
|
||||
"actions": [
|
||||
{{
|
||||
"action_type": "reply",
|
||||
"reasoning": "回复理由",
|
||||
"action_data": {{
|
||||
"target_message_id": "目标消息ID",
|
||||
"content": "回复内容"
|
||||
}}
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
""",
|
||||
"chatter_planner_lite",
|
||||
)
|
||||
|
||||
# 动作筛选器提示词,用于筛选和优化规划器生成的动作
|
||||
Prompt(
|
||||
"""
|
||||
{identity_block}
|
||||
|
||||
## 原始动作计划
|
||||
{original_plan}
|
||||
|
||||
## 聊天上下文
|
||||
{chat_context}
|
||||
|
||||
**任务:动作筛选优化**
|
||||
请对原始动作计划进行筛选和优化,确保动作的合理性和有效性。
|
||||
|
||||
**筛选原则:**
|
||||
1. 移除重复或不必要的动作
|
||||
2. 确保动作之间的逻辑顺序
|
||||
3. 优化动作的具体参数
|
||||
4. 考虑当前聊天环境和个人设定
|
||||
|
||||
**输出格式:**
|
||||
```json
|
||||
{{
|
||||
"thinking": "筛选优化思考",
|
||||
"actions": [
|
||||
{{
|
||||
"action_type": "优化后的动作类型",
|
||||
"reasoning": "优化理由",
|
||||
"action_data": {{
|
||||
"target_message_id": "目标消息ID",
|
||||
"content": "优化后的内容"
|
||||
}}
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
""",
|
||||
"chatter_plan_filter",
|
||||
)
|
||||
|
||||
# 动作提示词,用于格式化动作选项
|
||||
Prompt(
|
||||
"""
|
||||
## 动作: {action_name}
|
||||
**描述**: {action_description}
|
||||
|
||||
**参数**:
|
||||
{action_parameters}
|
||||
|
||||
**要求**:
|
||||
{action_require}
|
||||
|
||||
**使用说明**:
|
||||
请根据上述信息判断是否需要使用此动作。
|
||||
""",
|
||||
"action_prompt",
|
||||
)
|
||||
|
||||
# 带有完整JSON示例的动作提示词模板
|
||||
Prompt(
|
||||
"""
|
||||
动作: {action_name}
|
||||
动作描述: {action_description}
|
||||
动作使用场景:
|
||||
{action_require}
|
||||
|
||||
你应该像这样使用它:
|
||||
{{
|
||||
{json_example}
|
||||
}}
|
||||
""",
|
||||
"action_prompt_with_example",
|
||||
)
|
||||
|
||||
|
||||
# 确保提示词在模块加载时初始化
|
||||
init_prompts()
|
||||
46
src/plugins/built_in/affinity_flow_chatter/plugin.py
Normal file
46
src/plugins/built_in/affinity_flow_chatter/plugin.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
亲和力聊天处理器插件
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("affinity_chatter_plugin")
|
||||
|
||||
|
||||
@register_plugin
|
||||
class AffinityChatterPlugin(BasePlugin):
|
||||
"""亲和力聊天处理器插件
|
||||
|
||||
- 延迟导入 `AffinityChatter` 并通过组件注册器注册为聊天处理器
|
||||
- 提供 `get_plugin_components` 以兼容插件注册机制
|
||||
"""
|
||||
|
||||
plugin_name: str = "affinity_chatter"
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
config_file_name: str = ""
|
||||
|
||||
# 简单的 config_schema 占位(如果将来需要配置可扩展)
|
||||
config_schema = {}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表(ChatterInfo, AffinityChatter)
|
||||
|
||||
这里采用延迟导入 AffinityChatter 来避免循环依赖和启动顺序问题。
|
||||
如果导入失败则返回空列表以让注册过程继续而不崩溃。
|
||||
"""
|
||||
try:
|
||||
# 延迟导入以避免循环导入
|
||||
from .affinity_chatter import AffinityChatter
|
||||
|
||||
return [(AffinityChatter.get_chatter_info(), AffinityChatter)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载 AffinityChatter 时出错: {e}")
|
||||
return []
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user